use parking_lot::Mutex;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use crate::distribution::Distribution;
use crate::param::ParamValue;
use crate::sampler::{CompletedTrial, Sampler};
pub struct RandomSampler {
rng: Mutex<StdRng>,
}
impl RandomSampler {
pub fn new() -> Self {
Self {
rng: Mutex::new(StdRng::from_os_rng()),
}
}
pub fn with_seed(seed: u64) -> Self {
Self {
rng: Mutex::new(StdRng::seed_from_u64(seed)),
}
}
}
impl Default for RandomSampler {
fn default() -> Self {
Self::new()
}
}
impl Sampler for RandomSampler {
fn sample(
&self,
distribution: &Distribution,
_trial_id: u64,
_history: &[CompletedTrial],
) -> ParamValue {
let mut rng = self.rng.lock();
match distribution {
Distribution::Float(d) => {
let value = if d.log_scale {
let log_low = d.low.ln();
let log_high = d.high.ln();
let log_value = rng.random_range(log_low..=log_high);
log_value.exp()
} else if let Some(step) = d.step {
let n_steps = ((d.high - d.low) / step).floor() as i64;
let k = rng.random_range(0..=n_steps);
d.low + (k as f64) * step
} else {
rng.random_range(d.low..=d.high)
};
ParamValue::Float(value)
}
Distribution::Int(d) => {
let value = if d.log_scale {
let log_low = (d.low as f64).ln();
let log_high = (d.high as f64).ln();
let log_value = rng.random_range(log_low..=log_high);
let raw = log_value.exp().round() as i64;
raw.clamp(d.low, d.high)
} else if let Some(step) = d.step {
let n_steps = (d.high - d.low) / step;
let k = rng.random_range(0..=n_steps);
d.low + k * step
} else {
rng.random_range(d.low..=d.high)
};
ParamValue::Int(value)
}
Distribution::Categorical(d) => {
let index = rng.random_range(0..d.n_choices);
ParamValue::Categorical(index)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::distribution::{CategoricalDistribution, FloatDistribution, IntDistribution};
#[test]
fn test_random_sampler_float() {
let sampler = RandomSampler::with_seed(42);
let dist = Distribution::Float(FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: false,
step: None,
});
for _ in 0..100 {
let value = sampler.sample(&dist, 0, &[]);
if let ParamValue::Float(v) = value {
assert!((0.0..=1.0).contains(&v));
} else {
panic!("Expected Float value");
}
}
}
#[test]
fn test_random_sampler_float_log() {
let sampler = RandomSampler::with_seed(42);
let dist = Distribution::Float(FloatDistribution {
low: 1e-5,
high: 1.0,
log_scale: true,
step: None,
});
for _ in 0..100 {
let value = sampler.sample(&dist, 0, &[]);
if let ParamValue::Float(v) = value {
assert!((1e-5..=1.0).contains(&v));
} else {
panic!("Expected Float value");
}
}
}
#[test]
fn test_random_sampler_float_step() {
let sampler = RandomSampler::with_seed(42);
let dist = Distribution::Float(FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: false,
step: Some(0.25),
});
for _ in 0..100 {
let value = sampler.sample(&dist, 0, &[]);
if let ParamValue::Float(v) = value {
assert!((0.0..=1.0).contains(&v));
let k = ((v - 0.0) / 0.25).round() as i64;
let expected = 0.0 + k as f64 * 0.25;
assert!((v - expected).abs() < 1e-10);
} else {
panic!("Expected Float value");
}
}
}
#[test]
fn test_random_sampler_int() {
let sampler = RandomSampler::with_seed(42);
let dist = Distribution::Int(IntDistribution {
low: 0,
high: 10,
log_scale: false,
step: None,
});
for _ in 0..100 {
let value = sampler.sample(&dist, 0, &[]);
if let ParamValue::Int(v) = value {
assert!((0..=10).contains(&v));
} else {
panic!("Expected Int value");
}
}
}
#[test]
fn test_random_sampler_int_log() {
let sampler = RandomSampler::with_seed(42);
let dist = Distribution::Int(IntDistribution {
low: 1,
high: 1000,
log_scale: true,
step: None,
});
for _ in 0..100 {
let value = sampler.sample(&dist, 0, &[]);
if let ParamValue::Int(v) = value {
assert!((1..=1000).contains(&v));
} else {
panic!("Expected Int value");
}
}
}
#[test]
fn test_random_sampler_int_step() {
let sampler = RandomSampler::with_seed(42);
let dist = Distribution::Int(IntDistribution {
low: 0,
high: 10,
log_scale: false,
step: Some(2),
});
for _ in 0..100 {
let value = sampler.sample(&dist, 0, &[]);
if let ParamValue::Int(v) = value {
assert!((0..=10).contains(&v));
assert!(v % 2 == 0);
} else {
panic!("Expected Int value");
}
}
}
#[test]
fn test_random_sampler_categorical() {
let sampler = RandomSampler::with_seed(42);
let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 5 });
for _ in 0..100 {
let value = sampler.sample(&dist, 0, &[]);
if let ParamValue::Categorical(idx) = value {
assert!(idx < 5);
} else {
panic!("Expected Categorical value");
}
}
}
#[test]
fn test_random_sampler_reproducibility() {
let sampler1 = RandomSampler::with_seed(42);
let sampler2 = RandomSampler::with_seed(42);
let dist = Distribution::Float(FloatDistribution {
low: 0.0,
high: 1.0,
log_scale: false,
step: None,
});
for _ in 0..10 {
let v1 = sampler1.sample(&dist, 0, &[]);
let v2 = sampler2.sample(&dist, 0, &[]);
assert_eq!(v1, v2);
}
}
}