use std::collections::HashMap;
use parking_lot::Mutex;
use crate::distribution::{
CategoricalDistribution, Distribution, FloatDistribution, IntDistribution,
};
use crate::param::ParamValue;
use crate::sampler::{CompletedTrial, Sampler};
/// Builds the ascending list of integer grid points for `dist`.
///
/// * Returns an empty vector when `low > high`, and `[low]` when both bounds coincide.
/// * When `dist.step` is set, the points are `low, low + step, low + 2 * step, ...` for as long
///   as they stay within `high`; `log_scale` and `n_points` are ignored on this path. A
///   non-positive step yields `[low]`.
/// * Otherwise `n_points` values are spread over `[low, high]`: logarithmically when
///   `log_scale` is set and `low > 0`, linearly in every other case.
///
/// The result is clamped to `[low, high]`, sorted and de-duplicated, so it can hold fewer than
/// `n_points` entries.
#[must_use]
pub fn generate_int_grid_points(dist: &IntDistribution, n_points: usize) -> Vec<i64> {
    let (low, high) = (dist.low, dist.high);
    if low > high {
        return Vec::new();
    }
    if low == high {
        return vec![low];
    }

    let points = match dist.step {
        Some(step) if step <= 0 => return vec![low],
        Some(step) => {
            let mut result = vec![low];
            let mut current = low;
            // `checked_add` stops the walk at the top of the `i64` range, so a saturated
            // `i64::MAX` is never emitted as if it were a real grid point.
            while let Some(next) = current.checked_add(step).filter(|next| *next <= high) {
                result.push(next);
                current = next;
            }
            result
        }
        // Log spacing needs a strictly positive lower bound; otherwise fall back to linear.
        None if dist.log_scale && low > 0 => generate_log_int_points(low, high, n_points),
        None => generate_linear_int_points(low, high, n_points),
    };

    // Rounding in the helpers can produce repeated or marginally out-of-range values.
    let mut clamped: Vec<i64> = points.into_iter().map(|p| p.clamp(low, high)).collect();
    clamped.sort_unstable();
    clamped.dedup();
    clamped
}
/// Returns `n_points` integers evenly spaced from `low` to `high`, each rounded to the nearest
/// integer.
///
/// `n_points == 0` yields an empty vector and `n_points == 1` yields `[low]`. Rounding may
/// produce repeated values when the range holds fewer than `n_points` distinct integers; this
/// function does not remove them.
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
fn generate_linear_int_points(low: i64, high: i64, n_points: usize) -> Vec<i64> {
    match n_points {
        0 => Vec::new(),
        1 => vec![low],
        _ => {
            // Subtract in `i128`: `high - low` overflows `i64` once the range is wider than
            // `i64::MAX` (for example `i64::MIN..=i64::MAX`).
            let range = (i128::from(high) - i128::from(low)) as f64;
            let start = low as f64;
            let last_index = (n_points - 1) as f64;
            (0..n_points)
                .map(|i| {
                    let fraction = i as f64 / last_index;
                    // Float-to-int `as` casts saturate, so the top of the range cannot wrap.
                    (start + fraction * range).round() as i64
                })
                .collect()
        }
    }
}
/// Returns `n_points` integers spaced evenly in log space between `low` and `high`, each
/// rounded to the nearest integer.
///
/// `n_points == 0` yields an empty vector and `n_points == 1` yields `[low]`. `low` must be
/// positive; this is only checked in debug builds.
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
fn generate_log_int_points(low: i64, high: i64, n_points: usize) -> Vec<i64> {
    debug_assert!(low > 0, "log scale requires positive low bound");
    match n_points {
        0 => Vec::new(),
        1 => vec![low],
        count => {
            let ln_start = (low as f64).ln();
            let ln_end = (high as f64).ln();
            let denominator = (count - 1) as f64;
            (0..count)
                .map(|idx| {
                    // Interpolate the exponent, then map back to the original scale.
                    let t = idx as f64 / denominator;
                    let exponent = ln_start + t * (ln_end - ln_start);
                    exponent.exp().round() as i64
                })
                .collect()
        }
    }
}
/// Builds the list of floating-point grid points for `dist`.
///
/// * Returns an empty vector when `low > high` or when either bound is NaN, and `[low]` when
///   the bounds differ by less than `f64::EPSILON`.
/// * When `dist.step` is set, the points are `low + k * step` for `k = 0, 1, 2, ...` while they
///   stay within `high` (with an `f64::EPSILON` tolerance); `log_scale` and `n_points` are
///   ignored on this path. A step that is non-positive, NaN or infinite yields `[low]`.
/// * Otherwise `n_points` values are spread over `[low, high]`: logarithmically when
///   `log_scale` is set and `low > 0`, linearly in every other case.
///
/// Every returned value is clamped to `[low, high]`.
#[must_use]
pub fn generate_float_grid_points(dist: &FloatDistribution, n_points: usize) -> Vec<f64> {
    // Highest multiple of `step` that is ever generated; bounds the stepped grid at
    // 1_000_001 points even for a tiny step.
    const MAX_STEP_INDEX: u32 = 1_000_000;

    let (low, high) = (dist.low, dist.high);
    // NaN bounds are rejected here because `f64::clamp` below panics on a NaN limit.
    if low.is_nan() || high.is_nan() || low > high {
        return Vec::new();
    }
    if (low - high).abs() < f64::EPSILON {
        return vec![low];
    }

    let points: Vec<f64> = match dist.step {
        Some(step) if !step.is_finite() || step <= 0.0 => return vec![low],
        // Multiply instead of accumulating `current += step`: repeated addition compounds
        // rounding error (0.1 added ten times gives 0.9999999999999999, not 1.0).
        Some(step) => (0..=MAX_STEP_INDEX)
            .map(|k| low + f64::from(k) * step)
            .take_while(|point| *point <= high + f64::EPSILON)
            .collect(),
        // Log spacing needs a strictly positive lower bound; otherwise fall back to linear.
        None if dist.log_scale && low > 0.0 => generate_log_float_points(low, high, n_points),
        None => generate_linear_float_points(low, high, n_points),
    };

    points.into_iter().map(|p| p.clamp(low, high)).collect()
}
/// Returns `n_points` values evenly spaced from `low` to `high`.
///
/// `n_points == 0` yields an empty vector and `n_points == 1` yields `[low]`.
#[allow(clippy::cast_precision_loss)]
fn generate_linear_float_points(low: f64, high: f64, n_points: usize) -> Vec<f64> {
    match n_points {
        0 => Vec::new(),
        1 => vec![low],
        count => {
            let span = high - low;
            let denominator = (count - 1) as f64;
            (0..count)
                .map(|idx| {
                    // Position of this point within the range, from 0.0 to 1.0.
                    let t = idx as f64 / denominator;
                    low + t * span
                })
                .collect()
        }
    }
}
/// Returns `n_points` values spaced evenly in log space between `low` and `high`.
///
/// `n_points == 0` yields an empty vector and `n_points == 1` yields `[low]`. `low` must be
/// positive; this is only checked in debug builds.
#[allow(clippy::cast_precision_loss)]
fn generate_log_float_points(low: f64, high: f64, n_points: usize) -> Vec<f64> {
    debug_assert!(low > 0.0, "log scale requires positive low bound");
    if n_points < 2 {
        return if n_points == 0 { Vec::new() } else { vec![low] };
    }

    let ln_start = low.ln();
    let ln_span = high.ln() - ln_start;
    let denominator = (n_points - 1) as f64;
    (0..n_points)
        // Interpolate the exponent linearly, then map back to the original scale.
        .map(|idx| (ln_start + idx as f64 / denominator * ln_span).exp())
        .collect()
}
/// Returns every choice index of `dist` in ascending order: `0, 1, ..., n_choices - 1`.
///
/// A distribution with zero choices produces an empty vector.
#[must_use]
pub fn generate_categorical_grid_points(dist: &CategoricalDistribution) -> Vec<usize> {
    let mut indices = Vec::with_capacity(dist.n_choices);
    indices.extend(0..dist.n_choices);
    indices
}
/// Pre-computed grid for one distribution together with the read position inside it.
#[derive(Debug, Clone)]
struct CachedGrid {
/// Grid values in the order they are handed out.
points: Vec<ParamValue>,
/// Index into `points` of the next value to return; equals `points.len()` once every
/// value has been handed out.
current_index: usize,
}
/// Mutable sampler state: one cached grid per distinct distribution seen so far.
#[derive(Debug, Default)]
struct GridState {
/// Cached grids keyed by the string produced by `distribution_key`.
grids: HashMap<String, CachedGrid>,
}
/// Sampler that walks a pre-computed grid of values for each distribution it is asked about.
pub struct GridSampler {
/// Number of grid points requested for float and int distributions that have no step.
n_points_per_param: usize,
/// Cached grids and their cursors, behind a mutex because `sample` takes `&self`.
state: Mutex<GridState>,
}
impl GridSampler {
    /// Creates a sampler with the default resolution of 10 grid points per parameter and no
    /// cached grids.
    #[must_use]
    pub fn new() -> Self {
        // The builder's defaults are exactly the defaults of `new`.
        Self::builder().build()
    }

    /// Returns a [`GridSearchSamplerBuilder`] for configuring a sampler before building it.
    #[must_use]
    pub fn builder() -> GridSearchSamplerBuilder {
        GridSearchSamplerBuilder::default()
    }
}
impl Default for GridSampler {
fn default() -> Self {
Self::new()
}
}
impl GridSampler {
    /// Reports whether every cached grid has been fully consumed.
    ///
    /// Also returns `true` when no grid has been created yet, i.e. before the first call to
    /// `sample`.
    #[must_use]
    pub fn is_exhausted(&self) -> bool {
        let guard = self.state.lock();
        // `any` is false for an empty map, which makes the "no grids yet" case report `true`.
        let has_remaining = guard
            .grids
            .values()
            .any(|grid| grid.current_index < grid.points.len());
        !has_remaining
    }

    /// Returns the total number of points across all cached grids.
    ///
    /// This is the sum of the individual grid lengths, not their product.
    #[must_use]
    pub fn grid_size(&self) -> usize {
        let guard = self.state.lock();
        guard
            .grids
            .values()
            .fold(0, |total, grid| total + grid.points.len())
    }
}
/// Builder for [`GridSampler`].
#[derive(Debug, Clone)]
pub struct GridSearchSamplerBuilder {
/// Number of grid points the built sampler requests per parameter.
n_points_per_param: usize,
}
impl GridSearchSamplerBuilder {
    /// Creates a builder preset to 10 grid points per parameter.
    #[must_use]
    pub fn new() -> Self {
        let n_points_per_param = 10;
        Self { n_points_per_param }
    }

    /// Sets how many grid points the sampler requests per parameter.
    #[must_use]
    pub fn n_points_per_param(self, n: usize) -> Self {
        Self {
            n_points_per_param: n,
        }
    }

    /// Consumes the builder and returns a [`GridSampler`] with no cached grids.
    #[must_use]
    pub fn build(self) -> GridSampler {
        let Self { n_points_per_param } = self;
        let state = Mutex::new(GridState::default());
        GridSampler {
            n_points_per_param,
            state,
        }
    }
}
impl Default for GridSearchSamplerBuilder {
fn default() -> Self {
Self::new()
}
}
/// Builds the cache key that identifies a distribution by its kind and parameters.
///
/// Two distributions with the same kind and the same field values map to the same key, whatever
/// parameter they belong to.
fn distribution_key(dist: &Distribution) -> String {
    /// Renders an optional step, using the literal `none` when it is absent.
    fn step_label<T: ToString>(step: Option<T>) -> String {
        step.map_or_else(|| "none".to_string(), |s| s.to_string())
    }

    match dist {
        Distribution::Float(d) => format!(
            "float:{}:{}:{}:{}",
            d.low,
            d.high,
            d.log_scale,
            step_label(d.step)
        ),
        Distribution::Int(d) => format!(
            "int:{}:{}:{}:{}",
            d.low,
            d.high,
            d.log_scale,
            step_label(d.step)
        ),
        Distribution::Categorical(d) => format!("cat:{}", d.n_choices),
    }
}
impl Sampler for GridSampler {
    /// Returns the next unused grid point for `distribution`.
    ///
    /// The grid is generated on the first call for a given distribution key and cached; later
    /// calls with an equal key advance through that same grid. The trial id and history are
    /// ignored.
    ///
    /// # Panics
    ///
    /// Panics when the grid for `distribution` has no points left, which includes a grid that
    /// was empty from the start.
    fn sample(
        &self,
        distribution: &Distribution,
        _trial_id: u64,
        _history: &[CompletedTrial],
    ) -> ParamValue {
        let key = distribution_key(distribution);
        let n_points = self.n_points_per_param;

        let mut guard = self.state.lock();
        let grid = guard.grids.entry(key).or_insert_with(|| {
            let points: Vec<ParamValue> = match distribution {
                Distribution::Float(d) => generate_float_grid_points(d, n_points)
                    .into_iter()
                    .map(ParamValue::Float)
                    .collect(),
                Distribution::Int(d) => generate_int_grid_points(d, n_points)
                    .into_iter()
                    .map(ParamValue::Int)
                    .collect(),
                Distribution::Categorical(d) => generate_categorical_grid_points(d)
                    .into_iter()
                    .map(ParamValue::Categorical)
                    .collect(),
            };
            CachedGrid {
                points,
                current_index: 0,
            }
        });

        // Read first, advance second: the cursor only moves when a value was handed out.
        let next = grid.points.get(grid.current_index).cloned();
        match next {
            Some(value) => {
                grid.current_index += 1;
                value
            }
            None => panic!("GridSampler: all grid points exhausted"),
        }
    }
}
// Unit tests for the grid-point generators, the sampler's cursor/exhaustion behaviour and the
// builder.
#[cfg(test)]
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
mod tests {
use super::*;
// A step that divides the range evenly reaches the upper bound.
#[test]
fn test_int_grid_with_step() {
let dist = IntDistribution {
low: 0,
high: 10,
log_scale: false,
step: Some(2),
};
let points = generate_int_grid_points(&dist, 10);
assert_eq!(points, vec![0, 2, 4, 6, 8, 10]);
}
// A step that does not divide the range stops at the last multiple within the bounds.
#[test]
fn test_int_grid_with_step_not_exact_multiple() {
let dist = IntDistribution {
low: 0,
high: 9,
log_scale: false,
step: Some(2),
};
let points = generate_int_grid_points(&dist, 10);
assert_eq!(points, vec![0, 2, 4, 6, 8]);
}
// Without a step, `n_points` values are spread linearly over the range.
#[test]
fn test_int_grid_without_step_linear() {
let dist = IntDistribution {
low: 0,
high: 100,
log_scale: false,
step: None,
};
let points = generate_int_grid_points(&dist, 5);
assert_eq!(points, vec![0, 25, 50, 75, 100]);
}
// Ten points over a range of ten integers enumerate every integer.
#[test]
fn test_int_grid_without_step_linear_10_points() {
let dist = IntDistribution {
low: 0,
high: 9,
log_scale: false,
step: None,
};
let points = generate_int_grid_points(&dist, 10);
assert_eq!(points, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
}
// Log-scale int grids keep both endpoints and stay inside the bounds.
#[test]
fn test_int_grid_with_log_scale() {
let dist = IntDistribution {
low: 1,
high: 1000,
log_scale: true,
step: None,
};
let points = generate_int_grid_points(&dist, 4);
assert_eq!(points.len(), 4);
assert_eq!(points[0], 1);
assert_eq!(*points.last().unwrap(), 1000);
for p in &points {
assert!(*p >= 1 && *p <= 1000);
}
}
// Log scale with a non-positive lower bound falls back to linear spacing.
#[test]
fn test_int_grid_log_scale_non_positive_fallback() {
let dist = IntDistribution {
low: 0,
high: 100,
log_scale: true,
step: None,
};
let points = generate_int_grid_points(&dist, 5);
assert_eq!(points, vec![0, 25, 50, 75, 100]);
}
// Equal bounds collapse to a single point.
#[test]
fn test_int_grid_single_point() {
let dist = IntDistribution {
low: 5,
high: 5,
log_scale: false,
step: None,
};
let points = generate_int_grid_points(&dist, 10);
assert_eq!(points, vec![5]);
}
// Inverted bounds produce an empty grid.
#[test]
fn test_int_grid_invalid_bounds() {
let dist = IntDistribution {
low: 10,
high: 5,
log_scale: false,
step: None,
};
let points = generate_int_grid_points(&dist, 10);
assert!(points.is_empty());
}
// A float step that divides the range evenly yields every multiple including the bound.
#[test]
fn test_float_grid_with_step() {
let dist = FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: false,
step: Some(0.25),
};
let points = generate_float_grid_points(&dist, 10);
assert_eq!(points.len(), 5);
assert!((points[0] - 0.0).abs() < f64::EPSILON);
assert!((points[1] - 0.25).abs() < f64::EPSILON);
assert!((points[2] - 0.5).abs() < f64::EPSILON);
assert!((points[3] - 0.75).abs() < f64::EPSILON);
assert!((points[4] - 1.0).abs() < f64::EPSILON);
}
// When a step is present it takes precedence over `log_scale`.
#[test]
fn test_float_grid_step_overrides_log_scale() {
let dist = FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: true, step: Some(0.5),
};
let points = generate_float_grid_points(&dist, 10);
assert_eq!(points.len(), 3);
assert!((points[0] - 0.0).abs() < f64::EPSILON);
assert!((points[1] - 0.5).abs() < f64::EPSILON);
assert!((points[2] - 1.0).abs() < f64::EPSILON);
}
// Without a step, `n_points` values are spread linearly over the range.
#[test]
fn test_float_grid_without_step_linear() {
let dist = FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: false,
step: None,
};
let points = generate_float_grid_points(&dist, 5);
assert_eq!(points.len(), 5);
assert!((points[0] - 0.0).abs() < f64::EPSILON);
assert!((points[1] - 0.25).abs() < f64::EPSILON);
assert!((points[2] - 0.5).abs() < f64::EPSILON);
assert!((points[3] - 0.75).abs() < f64::EPSILON);
assert!((points[4] - 1.0).abs() < f64::EPSILON);
}
// Log-scale float grids keep both endpoints and have a constant ratio between neighbours.
#[test]
fn test_float_grid_with_log_scale() {
let dist = FloatDistribution {
low: 1e-4,
high: 1.0,
log_scale: true,
step: None,
};
let points = generate_float_grid_points(&dist, 5);
assert_eq!(points.len(), 5);
assert!((points[0] - 1e-4).abs() < 1e-10);
assert!((points[4] - 1.0).abs() < 1e-10);
for p in &points {
assert!(*p >= 1e-4 && *p <= 1.0);
}
let ratio1 = points[1] / points[0];
let ratio2 = points[2] / points[1];
assert!((ratio1 - ratio2).abs() / ratio1 < 0.01);
}
// Log scale with a non-positive lower bound falls back to linear spacing.
#[test]
fn test_float_grid_log_scale_non_positive_fallback() {
let dist = FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: true,
step: None,
};
let points = generate_float_grid_points(&dist, 5);
assert_eq!(points.len(), 5);
assert!((points[0] - 0.0).abs() < f64::EPSILON);
assert!((points[4] - 1.0).abs() < f64::EPSILON);
}
// Equal bounds collapse to a single point.
#[test]
fn test_float_grid_single_point() {
let dist = FloatDistribution {
low: 0.5,
high: 0.5,
log_scale: false,
step: None,
};
let points = generate_float_grid_points(&dist, 10);
assert_eq!(points.len(), 1);
assert!((points[0] - 0.5).abs() < f64::EPSILON);
}
// Inverted bounds produce an empty grid.
#[test]
fn test_float_grid_invalid_bounds() {
let dist = FloatDistribution {
low: 1.0,
high: 0.0,
log_scale: false,
step: None,
};
let points = generate_float_grid_points(&dist, 10);
assert!(points.is_empty());
}
// Categorical grids enumerate every choice index in order.
#[test]
fn test_categorical_grid() {
let dist = CategoricalDistribution { n_choices: 5 };
let points = generate_categorical_grid_points(&dist);
assert_eq!(points, vec![0, 1, 2, 3, 4]);
}
// A single choice yields the single index 0.
#[test]
fn test_categorical_grid_single_choice() {
let dist = CategoricalDistribution { n_choices: 1 };
let points = generate_categorical_grid_points(&dist);
assert_eq!(points, vec![0]);
}
// Zero choices yield an empty grid.
#[test]
fn test_categorical_grid_empty() {
let dist = CategoricalDistribution { n_choices: 0 };
let points = generate_categorical_grid_points(&dist);
assert!(points.is_empty());
}
// Drawing as many samples as there are grid points exhausts the sampler.
#[test]
fn test_sampler_exhausts_after_expected_samples() {
let sampler = GridSampler::new();
let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 3 });
for _ in 0..3 {
let _ = sampler.sample(&dist, 0, &[]);
}
assert!(sampler.is_exhausted());
}
// The builder's point count controls when an int grid is exhausted.
#[test]
fn test_sampler_exhaustion_with_int_distribution() {
let sampler = GridSampler::builder().n_points_per_param(5).build();
let dist = Distribution::Int(IntDistribution {
low: 0,
high: 100,
log_scale: false,
step: None,
});
for _ in 0..5 {
let _ = sampler.sample(&dist, 0, &[]);
}
assert!(sampler.is_exhausted());
assert_eq!(sampler.grid_size(), 5);
}
// Sampling past the end of a grid panics with the documented message.
#[test]
#[should_panic(expected = "GridSampler: all grid points exhausted")]
fn test_sampler_panics_after_exhaustion() {
let sampler = GridSampler::new();
let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 2 });
sampler.sample(&dist, 0, &[]);
sampler.sample(&dist, 0, &[]);
sampler.sample(&dist, 0, &[]);
}
// A sampler with no cached grids reports itself as exhausted.
#[test]
fn test_is_exhausted_before_sampling() {
let sampler = GridSampler::new();
assert!(sampler.is_exhausted());
}
// Exhaustion flips to true only after the last grid point is drawn.
#[test]
fn test_is_exhausted_during_sampling() {
let sampler = GridSampler::new();
let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 3 });
sampler.sample(&dist, 0, &[]);
assert!(!sampler.is_exhausted());
sampler.sample(&dist, 0, &[]);
assert!(!sampler.is_exhausted());
sampler.sample(&dist, 0, &[]);
assert!(sampler.is_exhausted());
}
// With several distributions, every grid must be consumed before the sampler is exhausted.
#[test]
fn test_is_exhausted_multiple_distributions() {
let sampler = GridSampler::new();
let dist1 = Distribution::Categorical(CategoricalDistribution { n_choices: 2 });
let dist2 = Distribution::Categorical(CategoricalDistribution { n_choices: 3 });
sampler.sample(&dist1, 0, &[]);
sampler.sample(&dist1, 0, &[]);
sampler.sample(&dist2, 0, &[]);
assert!(!sampler.is_exhausted());
sampler.sample(&dist2, 0, &[]);
assert!(!sampler.is_exhausted());
sampler.sample(&dist2, 0, &[]);
assert!(sampler.is_exhausted());
}
// A builder left at its defaults produces a 10-point float grid.
#[test]
fn test_builder_default() {
let sampler = GridSampler::builder().build();
let dist = Distribution::Float(FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: false,
step: None,
});
for _ in 0..10 {
let _ = sampler.sample(&dist, 0, &[]);
}
assert!(sampler.is_exhausted());
}
// A custom point count set on the builder is honoured.
#[test]
fn test_builder_custom_n_points() {
let sampler = GridSampler::builder().n_points_per_param(3).build();
let dist = Distribution::Float(FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: false,
step: None,
});
for _ in 0..3 {
let _ = sampler.sample(&dist, 0, &[]);
}
assert!(sampler.is_exhausted());
assert_eq!(sampler.grid_size(), 3);
}
// `GridSampler::new` also produces a 10-point float grid.
#[test]
fn test_new_default() {
let sampler = GridSampler::new();
let dist = Distribution::Float(FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: false,
step: None,
});
for _ in 0..10 {
let _ = sampler.sample(&dist, 0, &[]);
}
assert!(sampler.is_exhausted());
}
// Two samplers with the same configuration hand out the same float sequence.
#[test]
fn test_reproducibility_same_grid_order() {
let sampler1 = GridSampler::builder().n_points_per_param(5).build();
let sampler2 = GridSampler::builder().n_points_per_param(5).build();
let dist = Distribution::Float(FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: false,
step: None,
});
for _ in 0..5 {
let v1 = sampler1.sample(&dist, 0, &[]);
let v2 = sampler2.sample(&dist, 0, &[]);
assert_eq!(v1, v2);
}
}
// Stepped int grids are handed out in ascending order by independent samplers.
#[test]
fn test_reproducibility_int_distribution() {
let sampler1 = GridSampler::new();
let sampler2 = GridSampler::new();
let dist = Distribution::Int(IntDistribution {
low: 0,
high: 10,
log_scale: false,
step: Some(2),
});
let expected = vec![0, 2, 4, 6, 8, 10];
for exp in &expected {
let v1 = sampler1.sample(&dist, 0, &[]);
let v2 = sampler2.sample(&dist, 0, &[]);
assert_eq!(v1, ParamValue::Int(*exp));
assert_eq!(v2, ParamValue::Int(*exp));
}
}
// Categorical indices are handed out in ascending order by independent samplers.
#[test]
fn test_reproducibility_categorical() {
let sampler1 = GridSampler::new();
let sampler2 = GridSampler::new();
let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 4 });
for i in 0..4 {
let v1 = sampler1.sample(&dist, 0, &[]);
let v2 = sampler2.sample(&dist, 0, &[]);
assert_eq!(v1, ParamValue::Categorical(i));
assert_eq!(v2, ParamValue::Categorical(i));
}
}
// A sampler with no cached grids has size 0.
#[test]
fn test_grid_size_empty() {
let sampler = GridSampler::new();
assert_eq!(sampler.grid_size(), 0);
}
// The grid is created, and counted, on the first sample for a distribution.
#[test]
fn test_grid_size_single_distribution() {
let sampler = GridSampler::builder().n_points_per_param(5).build();
let dist = Distribution::Float(FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: false,
step: None,
});
assert_eq!(sampler.grid_size(), 0);
sampler.sample(&dist, 0, &[]);
assert_eq!(sampler.grid_size(), 5);
}
// `grid_size` adds up the lengths of all cached grids.
#[test]
fn test_grid_size_multiple_distributions() {
let sampler = GridSampler::builder().n_points_per_param(3).build();
let dist1 = Distribution::Float(FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: false,
step: None,
});
let dist2 = Distribution::Categorical(CategoricalDistribution { n_choices: 5 });
sampler.sample(&dist1, 0, &[]);
assert_eq!(sampler.grid_size(), 3);
sampler.sample(&dist2, 0, &[]);
assert_eq!(sampler.grid_size(), 3 + 5);
}
// An int step wider than the range leaves only the lower bound.
#[test]
fn test_int_step_larger_than_range() {
let dist = IntDistribution {
low: 0,
high: 5,
log_scale: false,
step: Some(10),
};
let points = generate_int_grid_points(&dist, 10);
assert_eq!(points, vec![0]);
}
// A float step wider than the range leaves only the lower bound.
#[test]
fn test_float_step_larger_than_range() {
let dist = FloatDistribution {
low: 0.0,
high: 0.5,
log_scale: false,
step: Some(1.0),
};
let points = generate_float_grid_points(&dist, 10);
assert_eq!(points.len(), 1);
assert!((points[0] - 0.0).abs() < f64::EPSILON);
}
// Requesting a single point returns the lower bound.
#[test]
fn test_n_points_one() {
let dist = IntDistribution {
low: 0,
high: 100,
log_scale: false,
step: None,
};
let points = generate_int_grid_points(&dist, 1);
assert_eq!(points, vec![0]);
}
// Requesting zero points returns an empty grid.
#[test]
fn test_n_points_zero() {
let dist = IntDistribution {
low: 0,
high: 100,
log_scale: false,
step: None,
};
let points = generate_int_grid_points(&dist, 0);
assert!(points.is_empty());
}
}