use core::fmt::Debug;
use core::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use crate::distribution::Distribution;
use crate::error::{Error, Result};
use crate::param::ParamValue;
use crate::rng_util;
use crate::sampler::common;
use crate::sampler::tpe::gamma::{FixedGamma, GammaStrategy};
use crate::sampler::{CompletedTrial, Sampler};
use super::common as tpe_common;
/// Tree-structured Parzen Estimator (TPE) sampler.
///
/// Until `n_startup_trials` trials have completed it samples uniformly at random; afterwards it
/// splits the history into "good" and "bad" groups and delegates to the TPE routines.
pub struct TpeSampler {
    /// Base seed mixed into the per-call RNG.
    seed: u64,
    /// Count of `sample` calls so far; mixed into the per-call RNG seed so repeated calls differ.
    call_seq: AtomicU64,
    /// Decides which fraction of the history is treated as "good".
    gamma_strategy: Arc<dyn GammaStrategy>,
    /// Completed trials required before TPE replaces random sampling.
    n_startup_trials: usize,
    /// Candidate count forwarded to the float/int TPE routines.
    n_ei_candidates: usize,
    /// Optional KDE bandwidth forwarded to the TPE routines (`None` presumably means "use the
    /// routine's own choice" — confirm in `tpe_common`).
    kde_bandwidth: Option<f64>,
}
impl TpeSampler {
    /// Creates a sampler with the default configuration.
    ///
    /// The defaults are:
    /// - a fixed gamma (`FixedGamma::default()`, i.e. 0.25);
    /// - 10 startup trials;
    /// - 24 EI candidates;
    /// - no explicit KDE bandwidth;
    /// - a randomly drawn seed.
    #[must_use]
    pub fn new() -> Self {
        Self {
            gamma_strategy: Arc::new(FixedGamma::default()),
            n_startup_trials: 10,
            n_ei_candidates: 24,
            kde_bandwidth: None,
            seed: fastrand::u64(..),
            call_seq: AtomicU64::new(0),
        }
    }

    /// Returns a [`TpeSamplerBuilder`] preloaded with the defaults of [`TpeSampler::new`].
    #[must_use]
    pub fn builder() -> TpeSamplerBuilder {
        TpeSamplerBuilder::new()
    }

    /// Creates a sampler that uses a fixed `gamma`.
    ///
    /// # Errors
    ///
    /// - Propagates the error from `FixedGamma::new` when `gamma` is rejected
    ///   (`Error::InvalidGamma` for e.g. `0.0` or `1.0`).
    /// - Returns `Error::InvalidBandwidth` when `kde_bandwidth` is `Some` but not a finite,
    ///   strictly positive number.
    pub fn with_config(
        gamma: f64,
        n_startup_trials: usize,
        n_ei_candidates: usize,
        kde_bandwidth: Option<f64>,
        seed: Option<u64>,
    ) -> Result<Self> {
        let gamma_strategy = FixedGamma::new(gamma)?;
        Self::with_strategy(
            gamma_strategy,
            n_startup_trials,
            n_ei_candidates,
            kde_bandwidth,
            seed,
        )
    }

    /// Creates a sampler with a custom [`GammaStrategy`].
    ///
    /// When `seed` is `None`, a random seed is drawn.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidBandwidth` when `kde_bandwidth` is `Some` but NaN, infinite, or
    /// `<= 0`.
    pub fn with_strategy<G: GammaStrategy + 'static>(
        gamma_strategy: G,
        n_startup_trials: usize,
        n_ei_candidates: usize,
        kde_bandwidth: Option<f64>,
        seed: Option<u64>,
    ) -> Result<Self> {
        // `bw <= 0.0` alone is false for NaN, so NaN (and ±inf) would slip through and poison
        // the KDE. Require a finite, strictly positive value.
        if let Some(bw) = kde_bandwidth
            && (!bw.is_finite() || bw <= 0.0)
        {
            return Err(Error::InvalidBandwidth(bw));
        }
        Ok(Self {
            gamma_strategy: Arc::new(gamma_strategy),
            n_startup_trials,
            n_ei_candidates,
            kde_bandwidth,
            seed: seed.unwrap_or_else(|| fastrand::u64(..)),
            call_seq: AtomicU64::new(0),
        })
    }

    /// Returns the gamma strategy used to size the "good" group.
    #[must_use]
    pub fn gamma_strategy(&self) -> &dyn GammaStrategy {
        self.gamma_strategy.as_ref()
    }

    /// Splits `history` into `(good, bad)` trial groups. Lower objective values are better.
    ///
    /// The good group holds `ceil(len * gamma)` trials. That count is clamped to at least 1 and
    /// at most `len - 1`, so for a single trial the good group is empty. NaN objectives (e.g.
    /// failed evaluations) always rank as the worst values. The order *within* each group is
    /// unspecified.
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    #[must_use]
    fn split_trials<'a>(
        &self,
        history: &'a [CompletedTrial],
    ) -> (Vec<&'a CompletedTrial>, Vec<&'a CompletedTrial>) {
        if history.is_empty() {
            return (vec![], vec![]);
        }
        let gamma = self
            .gamma_strategy
            .gamma(history.len())
            .clamp(f64::EPSILON, 1.0 - f64::EPSILON);
        // `max` then `min` rather than `clamp`: with one trial the upper bound (0) is below the
        // lower bound (1), and `usize::clamp` would panic. The result is then 0.
        let n_good = ((history.len() as f64 * gamma).ceil() as usize)
            .max(1)
            .min(history.len() - 1);
        // Map NaN to +inf so it ranks among the worst, then use `total_cmp`. This gives the
        // strict total order that `select_nth_unstable_by` requires; the previous
        // `partial_cmp(..).unwrap_or(Equal)` is not transitive in the presence of NaN.
        let rank_key = |i: usize| {
            let v = history[i].value;
            if v.is_nan() { f64::INFINITY } else { v }
        };
        let mut indices: Vec<usize> = (0..history.len()).collect();
        if n_good > 0 {
            indices.select_nth_unstable_by(n_good - 1, |&a, &b| {
                rank_key(a).total_cmp(&rank_key(b))
            });
        }
        let (good_idx, bad_idx) = indices.split_at(n_good);
        let good: Vec<_> = good_idx.iter().map(|&i| &history[i]).collect();
        let bad: Vec<_> = bad_idx.iter().map(|&i| &history[i]).collect();
        (good, bad)
    }
}
impl Default for TpeSampler {
fn default() -> Self {
Self::new()
}
}
/// Fluent builder for [`TpeSampler`]; validation is deferred to `build`.
#[derive(Debug, Clone)]
pub struct TpeSamplerBuilder {
    /// Strategy used when no raw gamma has been set.
    gamma_strategy: Box<dyn GammaStrategy>,
    /// Fixed gamma set via `gamma`; when present it is used instead of `gamma_strategy`.
    raw_gamma: Option<f64>,
    /// Completed trials required before TPE replaces random sampling.
    n_startup_trials: usize,
    /// Candidate count forwarded to the TPE routines.
    n_ei_candidates: usize,
    /// Optional KDE bandwidth, validated in `build`.
    kde_bandwidth: Option<f64>,
    /// Explicit seed; `None` means a random seed is drawn in `build`.
    seed: Option<u64>,
}
impl TpeSamplerBuilder {
    /// Creates a builder with the same defaults as [`TpeSampler::new`].
    ///
    /// The defaults are:
    /// - fixed gamma via `FixedGamma::default()`;
    /// - 10 startup trials;
    /// - 24 EI candidates;
    /// - no explicit KDE bandwidth;
    /// - no seed (a random one is drawn in [`build`](Self::build)).
    #[must_use]
    pub fn new() -> Self {
        Self {
            gamma_strategy: Box::new(FixedGamma::default()),
            raw_gamma: None,
            n_startup_trials: 10,
            n_ei_candidates: 24,
            kde_bandwidth: None,
            seed: None,
        }
    }

    /// Uses a fixed gamma value.
    ///
    /// The value is validated in [`build`](Self::build). It takes precedence over a strategy
    /// passed to [`gamma_strategy`](Self::gamma_strategy) earlier. Calling `gamma_strategy`
    /// afterwards discards it again, so the last call wins.
    #[must_use]
    pub fn gamma(mut self, gamma: f64) -> Self {
        self.raw_gamma = Some(gamma);
        self
    }

    /// Uses a custom gamma strategy, discarding any value previously set via
    /// [`gamma`](Self::gamma).
    #[must_use]
    pub fn gamma_strategy<G: GammaStrategy + 'static>(mut self, strategy: G) -> Self {
        self.gamma_strategy = Box::new(strategy);
        self.raw_gamma = None;
        self
    }

    /// Sets how many completed trials are sampled randomly before TPE kicks in.
    #[must_use]
    pub fn n_startup_trials(mut self, n: usize) -> Self {
        self.n_startup_trials = n;
        self
    }

    /// Sets the candidate count forwarded to the TPE routines.
    #[must_use]
    pub fn n_ei_candidates(mut self, n: usize) -> Self {
        self.n_ei_candidates = n;
        self
    }

    /// Sets a fixed KDE bandwidth.
    ///
    /// The value must be finite and strictly positive; this is checked in
    /// [`build`](Self::build).
    #[must_use]
    pub fn kde_bandwidth(mut self, bandwidth: f64) -> Self {
        self.kde_bandwidth = Some(bandwidth);
        self
    }

    /// Sets the RNG seed, making sampling reproducible for a given call sequence.
    #[must_use]
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Validates the configuration and builds the sampler.
    ///
    /// # Errors
    ///
    /// - Propagates the error from `FixedGamma::new` when a gamma set via
    ///   [`gamma`](Self::gamma) is rejected (`Error::InvalidGamma`, e.g. for `1.5`).
    /// - Returns `Error::InvalidBandwidth` when the bandwidth is NaN, infinite, or `<= 0`.
    pub fn build(self) -> Result<TpeSampler> {
        let gamma_strategy: Arc<dyn GammaStrategy> = if let Some(raw) = self.raw_gamma {
            Arc::new(FixedGamma::new(raw)?)
        } else {
            Arc::from(self.gamma_strategy)
        };
        // `bw <= 0.0` alone is false for NaN, so also reject every non-finite value.
        if let Some(bw) = self.kde_bandwidth
            && (!bw.is_finite() || bw <= 0.0)
        {
            return Err(Error::InvalidBandwidth(bw));
        }
        Ok(TpeSampler {
            gamma_strategy,
            n_startup_trials: self.n_startup_trials,
            n_ei_candidates: self.n_ei_candidates,
            kde_bandwidth: self.kde_bandwidth,
            seed: self.seed.unwrap_or_else(|| fastrand::u64(..)),
            call_seq: AtomicU64::new(0),
        })
    }
}
impl Default for TpeSamplerBuilder {
fn default() -> Self {
Self::new()
}
}
/// Looks up the value in trial `t` of the parameter whose distribution equals `target_dist`.
///
/// Parameters are matched by distribution, not by name. If several parameters share the
/// distribution, the one with the smallest id is chosen, so the result does not depend on map
/// iteration order. Returns `None` if no parameter matches or its value is missing.
fn find_matching_value<'t>(
    t: &'t CompletedTrial,
    target_dist: &Distribution,
) -> Option<&'t ParamValue> {
    let smallest_matching_id = t
        .distributions
        .iter()
        .filter_map(|(id, dist)| (dist == target_dist).then_some(id))
        .min()?;
    t.params.get(smallest_matching_id)
}
impl TpeSampler {
fn sample_float(
&self,
d: &crate::distribution::FloatDistribution,
good_trials: &[&CompletedTrial],
bad_trials: &[&CompletedTrial],
rng: &mut fastrand::Rng,
) -> ParamValue {
let target_dist = Distribution::Float(d.clone());
let good_values: Vec<f64> = good_trials
.iter()
.filter_map(|t| match find_matching_value(t, &target_dist)? {
ParamValue::Float(f) => Some(*f),
_ => None,
})
.collect();
let bad_values: Vec<f64> = bad_trials
.iter()
.filter_map(|t| match find_matching_value(t, &target_dist)? {
ParamValue::Float(f) => Some(*f),
_ => None,
})
.collect();
if good_values.is_empty() || bad_values.is_empty() {
return ParamValue::Float(rng_util::f64_range(rng, d.low, d.high));
}
let value = tpe_common::sample_tpe_float(
d,
good_values,
bad_values,
self.n_ei_candidates,
self.kde_bandwidth,
rng,
);
ParamValue::Float(value)
}
fn sample_int(
&self,
d: &crate::distribution::IntDistribution,
good_trials: &[&CompletedTrial],
bad_trials: &[&CompletedTrial],
rng: &mut fastrand::Rng,
) -> ParamValue {
let target_dist = Distribution::Int(d.clone());
let good_values: Vec<i64> = good_trials
.iter()
.filter_map(|t| match find_matching_value(t, &target_dist)? {
ParamValue::Int(i) => Some(*i),
_ => None,
})
.collect();
let bad_values: Vec<i64> = bad_trials
.iter()
.filter_map(|t| match find_matching_value(t, &target_dist)? {
ParamValue::Int(i) => Some(*i),
_ => None,
})
.collect();
if good_values.is_empty() || bad_values.is_empty() {
return common::sample_random(rng, &Distribution::Int(d.clone()));
}
let value = tpe_common::sample_tpe_int(
d,
good_values,
bad_values,
self.n_ei_candidates,
self.kde_bandwidth,
rng,
);
ParamValue::Int(value)
}
#[allow(clippy::unused_self)]
fn sample_categorical(
&self,
d: &crate::distribution::CategoricalDistribution,
good_trials: &[&CompletedTrial],
bad_trials: &[&CompletedTrial],
rng: &mut fastrand::Rng,
) -> ParamValue {
let target_dist = Distribution::Categorical(d.clone());
let good_indices: Vec<usize> = good_trials
.iter()
.filter_map(|t| match find_matching_value(t, &target_dist)? {
ParamValue::Categorical(i) => Some(*i),
_ => None,
})
.collect();
let bad_indices: Vec<usize> = bad_trials
.iter()
.filter_map(|t| match find_matching_value(t, &target_dist)? {
ParamValue::Categorical(i) => Some(*i),
_ => None,
})
.collect();
if good_indices.is_empty() || bad_indices.is_empty() {
return common::sample_random(rng, &Distribution::Categorical(d.clone()));
}
let index =
tpe_common::sample_tpe_categorical(d.n_choices, &good_indices, &bad_indices, rng);
ParamValue::Categorical(index)
}
}
impl Sampler for TpeSampler {
    /// Samples a value for `distribution` given the completed `history`.
    ///
    /// Each call seeds a fresh RNG from three inputs:
    /// - the sampler seed;
    /// - `trial_id`;
    /// - the distribution fingerprint offset by a per-call sequence number.
    ///
    /// Results are therefore reproducible for the same seed and the same order of calls.
    /// Random sampling is used during the startup phase or when the history cannot be split
    /// into non-empty good and bad groups.
    fn sample(
        &self,
        distribution: &Distribution,
        trial_id: u64,
        history: &[CompletedTrial],
    ) -> ParamValue {
        // Relaxed suffices: the counter only has to hand out a distinct value per call.
        let call_index = self.call_seq.fetch_add(1, Ordering::Relaxed);
        let salt = rng_util::distribution_fingerprint(distribution).wrapping_add(call_index);
        let mut rng = fastrand::Rng::with_seed(rng_util::mix_seed(self.seed, trial_id, salt));

        let warming_up = history.len() < self.n_startup_trials;
        if warming_up {
            return common::sample_random(&mut rng, distribution);
        }

        let (good, bad) = self.split_trials(history);
        if good.is_empty() || bad.is_empty() {
            return common::sample_random(&mut rng, distribution);
        }

        match distribution {
            Distribution::Float(float_dist) => {
                self.sample_float(float_dist, &good, &bad, &mut rng)
            }
            Distribution::Int(int_dist) => self.sample_int(int_dist, &good, &bad, &mut rng),
            Distribution::Categorical(cat_dist) => {
                self.sample_categorical(cat_dist, &good, &bad, &mut rng)
            }
        }
    }
}
#[cfg(test)]
#[allow(
    clippy::similar_names,
    clippy::cast_sign_loss,
    clippy::cast_precision_loss
)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::distribution::{CategoricalDistribution, FloatDistribution, IntDistribution};
    use crate::parameter::ParamId;

    /// Builds a completed trial with objective `value` and the given
    /// `(id, value, distribution)` parameters.
    fn create_trial(
        id: u64,
        value: f64,
        params: Vec<(ParamId, ParamValue, Distribution)>,
    ) -> CompletedTrial {
        let mut param_map = HashMap::new();
        let mut dist_map = HashMap::new();
        for (param_id, pv, dist) in params {
            param_map.insert(param_id, pv);
            dist_map.insert(param_id, dist);
        }
        CompletedTrial::new(id, param_map, dist_map, HashMap::new(), value)
    }

    /// Linear, un-stepped float distribution over `[0, 1]`.
    fn unit_float_dist() -> Distribution {
        Distribution::Float(FloatDistribution {
            low: 0.0,
            high: 1.0,
            log_scale: false,
            step: None,
        })
    }

    /// Twenty trials where trial `i` has objective `i` and parameter `x_id = i / 20`.
    fn ascending_history(x_id: ParamId, dist: &Distribution) -> Vec<CompletedTrial> {
        (0..20u32)
            .map(|i| {
                create_trial(
                    u64::from(i),
                    f64::from(i),
                    vec![(x_id, ParamValue::Float(f64::from(i) / 20.0), dist.clone())],
                )
            })
            .collect()
    }

    #[test]
    fn test_tpe_sampler_new() {
        let sampler = TpeSampler::new();
        assert!((sampler.gamma_strategy().gamma(0) - 0.25).abs() < f64::EPSILON);
        assert_eq!(sampler.n_startup_trials, 10);
        assert_eq!(sampler.n_ei_candidates, 24);
    }

    #[test]
    fn test_tpe_sampler_with_config() {
        let sampler = TpeSampler::with_config(0.15, 20, 32, None, Some(42)).unwrap();
        assert!((sampler.gamma_strategy().gamma(0) - 0.15).abs() < f64::EPSILON);
        assert_eq!(sampler.n_startup_trials, 20);
        assert_eq!(sampler.n_ei_candidates, 32);
    }

    #[test]
    fn test_tpe_sampler_invalid_gamma_zero() {
        let result = TpeSampler::with_config(0.0, 10, 24, None, None);
        assert!(matches!(result, Err(Error::InvalidGamma(_))));
    }

    #[test]
    fn test_tpe_sampler_invalid_gamma_one() {
        let result = TpeSampler::with_config(1.0, 10, 24, None, None);
        assert!(matches!(result, Err(Error::InvalidGamma(_))));
    }

    #[test]
    fn test_tpe_sampler_invalid_bandwidth_non_positive() {
        let result = TpeSampler::with_config(0.25, 10, 24, Some(0.0), None);
        assert!(matches!(result, Err(Error::InvalidBandwidth(_))));
        let result = TpeSamplerBuilder::new().kde_bandwidth(-1.0).build();
        assert!(matches!(result, Err(Error::InvalidBandwidth(_))));
    }

    #[test]
    fn test_tpe_startup_random_sampling() {
        let sampler = TpeSampler::with_config(0.25, 10, 24, None, Some(42)).unwrap();
        let dist = unit_float_dist();
        let history: Vec<CompletedTrial> = vec![];
        for i in 0..100 {
            let value = sampler.sample(&dist, i, &history);
            if let ParamValue::Float(v) = value {
                assert!((0.0..=1.0).contains(&v));
            } else {
                panic!("Expected Float value");
            }
        }
    }

    #[test]
    fn test_tpe_split_trials() {
        let sampler = TpeSampler::with_config(0.25, 10, 24, None, Some(42)).unwrap();
        let dist = unit_float_dist();
        let history = ascending_history(ParamId::new(), &dist);
        let (good, bad) = sampler.split_trials(&history);
        assert_eq!(good.len(), 5);
        assert_eq!(bad.len(), 15);
        for trial in &good {
            assert!(trial.value < 5.0);
        }
    }

    #[test]
    fn test_tpe_split_trials_single_trial() {
        let sampler = TpeSampler::with_config(0.25, 10, 24, None, Some(42)).unwrap();
        let history = vec![create_trial(
            0,
            1.0,
            vec![(ParamId::new(), ParamValue::Float(0.5), unit_float_dist())],
        )];
        let (good, bad) = sampler.split_trials(&history);
        assert!(good.is_empty());
        assert_eq!(bad.len(), 1);
    }

    #[test]
    fn test_tpe_samples_float_with_history() {
        let sampler = TpeSampler::with_config(0.25, 5, 24, None, Some(42)).unwrap();
        let dist = unit_float_dist();
        let x_id = ParamId::new();
        let history: Vec<CompletedTrial> = (0..20)
            .map(|i| {
                let x = f64::from(i) / 20.0;
                let value = (x - 0.2).powi(2);
                create_trial(
                    i as u64,
                    value,
                    vec![(x_id, ParamValue::Float(x), dist.clone())],
                )
            })
            .collect();
        let mut samples = vec![];
        for i in 0..100 {
            let value = sampler.sample(&dist, 100 + i, &history);
            if let ParamValue::Float(v) = value {
                samples.push(v);
            }
        }
        let mean: f64 = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!(
            mean < 0.5,
            "Mean {mean} should be less than 0.5 (biased toward good region near 0.2)"
        );
    }

    #[test]
    fn test_tpe_categorical_sampling() {
        let sampler = TpeSampler::with_config(0.25, 5, 24, None, Some(42)).unwrap();
        let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 4 });
        let cat_id = ParamId::new();
        let history: Vec<CompletedTrial> = (0..20)
            .map(|i| {
                let category = i % 4;
                let value = if category == 1 { 0.0 } else { 1.0 };
                create_trial(
                    i as u64,
                    value,
                    vec![(
                        cat_id,
                        ParamValue::Categorical(category as usize),
                        dist.clone(),
                    )],
                )
            })
            .collect();
        let mut counts = vec![0usize; 4];
        for i in 0..100 {
            let value = sampler.sample(&dist, 100 + i, &history);
            if let ParamValue::Categorical(idx) = value {
                counts[idx] += 1;
            }
        }
        assert!(
            counts[1] > counts[0] && counts[1] > counts[2] && counts[1] > counts[3],
            "Category 1 should be most common: {counts:?}"
        );
    }

    #[test]
    fn test_tpe_int_sampling() {
        let sampler = TpeSampler::with_config(0.25, 5, 24, None, Some(42)).unwrap();
        let dist = Distribution::Int(IntDistribution {
            low: 0,
            high: 100,
            log_scale: false,
            step: None,
        });
        let x_id = ParamId::new();
        let history: Vec<CompletedTrial> = (0..20)
            .map(|i| {
                let x = i * 5;
                let value = ((x as f64) - 30.0).powi(2);
                create_trial(
                    i as u64,
                    value,
                    vec![(x_id, ParamValue::Int(x), dist.clone())],
                )
            })
            .collect();
        for i in 0..50 {
            let value = sampler.sample(&dist, 100 + i, &history);
            if let ParamValue::Int(v) = value {
                assert!((0..=100).contains(&v), "Value {v} out of range");
            } else {
                panic!("Expected Int value");
            }
        }
    }

    #[test]
    fn test_tpe_reproducibility() {
        let dist = unit_float_dist();
        let history = ascending_history(ParamId::new(), &dist);
        let sampler1 = TpeSampler::with_config(0.25, 5, 24, None, Some(12345)).unwrap();
        let sampler2 = TpeSampler::with_config(0.25, 5, 24, None, Some(12345)).unwrap();
        for i in 0..10 {
            let v1 = sampler1.sample(&dist, i, &history);
            let v2 = sampler2.sample(&dist, i, &history);
            assert_eq!(v1, v2, "Samples should be identical with same seed");
        }
    }

    #[test]
    fn test_tpe_sampler_builder_default() {
        let builder = TpeSamplerBuilder::new();
        let sampler = builder.build().unwrap();
        assert!((sampler.gamma_strategy().gamma(0) - 0.25).abs() < f64::EPSILON);
        assert_eq!(sampler.n_startup_trials, 10);
        assert_eq!(sampler.n_ei_candidates, 24);
    }

    #[test]
    fn test_tpe_sampler_builder_custom() {
        let sampler = TpeSamplerBuilder::new()
            .gamma(0.15)
            .n_startup_trials(20)
            .n_ei_candidates(32)
            .seed(42)
            .build()
            .unwrap();
        assert!((sampler.gamma_strategy().gamma(0) - 0.15).abs() < f64::EPSILON);
        assert_eq!(sampler.n_startup_trials, 20);
        assert_eq!(sampler.n_ei_candidates, 32);
    }

    #[test]
    fn test_tpe_sampler_builder_via_sampler() {
        let sampler = TpeSampler::builder()
            .gamma(0.10)
            .n_startup_trials(15)
            .n_ei_candidates(48)
            .build()
            .unwrap();
        assert!((sampler.gamma_strategy().gamma(0) - 0.10).abs() < f64::EPSILON);
        assert_eq!(sampler.n_startup_trials, 15);
        assert_eq!(sampler.n_ei_candidates, 48);
    }

    #[test]
    fn test_tpe_sampler_builder_partial() {
        let sampler = TpeSamplerBuilder::new().gamma(0.20).build().unwrap();
        assert!((sampler.gamma_strategy().gamma(0) - 0.20).abs() < f64::EPSILON);
        assert_eq!(sampler.n_startup_trials, 10);
        assert_eq!(sampler.n_ei_candidates, 24);
    }

    #[test]
    fn test_tpe_sampler_builder_last_gamma_setting_wins() {
        let sampler = TpeSampler::builder()
            .gamma(0.10)
            .gamma_strategy(FixedGamma::new(0.30).unwrap())
            .build()
            .unwrap();
        assert!((sampler.gamma_strategy().gamma(0) - 0.30).abs() < f64::EPSILON);
    }

    #[test]
    fn test_tpe_sampler_builder_invalid_gamma() {
        let result = TpeSamplerBuilder::new().gamma(1.5).build();
        assert!(matches!(result, Err(Error::InvalidGamma(_))));
    }

    #[test]
    fn test_tpe_sampler_builder_reproducibility() {
        let dist = unit_float_dist();
        let history = ascending_history(ParamId::new(), &dist);
        let sampler1 = TpeSampler::builder()
            .seed(99999)
            .n_startup_trials(5)
            .build()
            .unwrap();
        let sampler2 = TpeSampler::builder()
            .seed(99999)
            .n_startup_trials(5)
            .build()
            .unwrap();
        for i in 0..10 {
            let v1 = sampler1.sample(&dist, i, &history);
            let v2 = sampler2.sample(&dist, i, &history);
            assert_eq!(
                v1, v2,
                "Builder-created samplers with same seed should be identical"
            );
        }
    }
}