use core::fmt::Debug;
use crate::Error;
/// Strategy that produces a gamma value from the number of trials seen so far.
///
/// NOTE(review): gamma is presumably the quantile a TPE sampler uses to split
/// "good" from "bad" trials — confirm against the sampler implementation.
///
/// Implementors must be `Send + Sync + Debug` and must supply
/// [`GammaStrategy::clone_box`] so that boxed trait objects can be cloned.
pub trait GammaStrategy: Send + Sync + Debug {
    /// Returns the gamma value to use once `n_trials` trials are available.
    fn gamma(&self, n_trials: usize) -> f64;

    /// Returns a boxed copy of this strategy.
    ///
    /// This exists because `Clone` itself cannot be called through a trait object.
    fn clone_box(&self) -> Box<dyn GammaStrategy>;
}

/// Makes `Box<dyn GammaStrategy>` cloneable by delegating to `clone_box`.
impl Clone for Box<dyn GammaStrategy> {
    fn clone(&self) -> Self {
        // Reach through the box to the trait object and let it duplicate itself.
        (**self).clone_box()
    }
}
/// Gamma strategy that reports the same value regardless of the trial count.
#[derive(Debug, Clone, Copy)]
pub struct FixedGamma {
    /// The constant gamma; values accepted by [`FixedGamma::new`] lie strictly
    /// inside `(0, 1)`.
    gamma: f64,
}

impl FixedGamma {
    /// Creates a strategy that always yields `gamma`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidGamma`] carrying the rejected value when `gamma`
    /// is NaN or lies outside the open interval `(0, 1)`.
    pub fn new(gamma: f64) -> crate::Result<Self> {
        // NaN compares false against every bound, so without the explicit
        // check it would slip past the range test and poison later results.
        if gamma.is_nan() || gamma <= 0.0 || gamma >= 1.0 {
            return Err(Error::InvalidGamma(gamma));
        }
        Ok(Self { gamma })
    }

    /// Returns the constant gamma value held by this strategy.
    #[must_use]
    pub fn value(&self) -> f64 {
        self.gamma
    }
}
impl Default for FixedGamma {
fn default() -> Self {
Self { gamma: 0.25 }
}
}
impl GammaStrategy for FixedGamma {
    /// Returns the stored gamma; the trial count is ignored.
    fn gamma(&self, _n_trials: usize) -> f64 {
        let Self { gamma } = *self;
        gamma
    }

    /// Boxes a bitwise copy of this strategy.
    fn clone_box(&self) -> Box<dyn GammaStrategy> {
        let copy: Self = *self;
        Box::new(copy)
    }
}
/// Gamma strategy described by a lower bound, an upper bound and the trial
/// count at which the upper bound is reached.
#[derive(Debug, Clone, Copy)]
pub struct LinearGamma {
    /// Gamma reported when no trials have run yet.
    gamma_min: f64,
    /// Gamma reported once `n_trials_max` trials (or more) have run.
    gamma_max: f64,
    /// Trial count at which the schedule saturates at `gamma_max`.
    n_trials_max: usize,
}

impl LinearGamma {
    /// Creates a linear schedule from `gamma_min` to `gamma_max` over
    /// `n_trials_max` trials.
    ///
    /// `n_trials_max` is not validated here; any `usize` is accepted.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidGamma`] when either bound is NaN or lies outside
    /// the open interval `(0, 1)` (the offending bound is reported), or when
    /// `gamma_min > gamma_max` (in which case `gamma_min` is reported).
    pub fn new(gamma_min: f64, gamma_max: f64, n_trials_max: usize) -> crate::Result<Self> {
        // NaN fails every ordered comparison, so each bound needs an explicit
        // NaN check or it would pass all three validations below.
        if gamma_min.is_nan() || gamma_min <= 0.0 || gamma_min >= 1.0 {
            return Err(Error::InvalidGamma(gamma_min));
        }
        if gamma_max.is_nan() || gamma_max <= 0.0 || gamma_max >= 1.0 {
            return Err(Error::InvalidGamma(gamma_max));
        }
        if gamma_min > gamma_max {
            return Err(Error::InvalidGamma(gamma_min));
        }
        Ok(Self {
            gamma_min,
            gamma_max,
            n_trials_max,
        })
    }

    /// Returns the lower gamma bound.
    #[must_use]
    pub fn gamma_min(&self) -> f64 {
        self.gamma_min
    }

    /// Returns the upper gamma bound.
    #[must_use]
    pub fn gamma_max(&self) -> f64 {
        self.gamma_max
    }

    /// Returns the trial count at which the schedule reaches `gamma_max`.
    #[must_use]
    pub fn n_trials_max(&self) -> usize {
        self.n_trials_max
    }
}
impl Default for LinearGamma {
fn default() -> Self {
Self {
gamma_min: 0.10,
gamma_max: 0.25,
n_trials_max: 100,
}
}
}
impl GammaStrategy for LinearGamma {
    /// Interpolates linearly from `gamma_min` (at zero trials) to `gamma_max`
    /// (at `n_trials_max` trials), staying at `gamma_max` beyond that point.
    ///
    /// When `n_trials_max` is zero the schedule is degenerate and `gamma_max`
    /// is returned directly.
    // Trial counts are converted to f64 only to form a ratio; losing precision
    // on astronomically large counts is acceptable here.
    #[allow(clippy::cast_precision_loss)]
    fn gamma(&self, n_trials: usize) -> f64 {
        let Self {
            gamma_min,
            gamma_max,
            n_trials_max,
        } = *self;

        // Guard against dividing by zero.
        if n_trials_max == 0 {
            return gamma_max;
        }

        // Fraction of the ramp completed, capped at 1.0.
        let progress = (n_trials as f64 / n_trials_max as f64).min(1.0);
        let span = gamma_max - gamma_min;
        gamma_min + span * progress
    }

    /// Boxes a bitwise copy of this strategy.
    fn clone_box(&self) -> Box<dyn GammaStrategy> {
        let copy: Self = *self;
        Box::new(copy)
    }
}
/// Gamma strategy parameterised by a scaling factor and an upper cap.
#[derive(Debug, Clone, Copy)]
pub struct SqrtGamma {
    /// Factor that is divided by the square root of the trial count.
    gamma_factor: f64,
    /// Upper cap on the reported gamma; also the value reported for zero trials.
    gamma_max: f64,
}

impl SqrtGamma {
    /// Creates a square-root-based strategy.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidGamma`] carrying the rejected value when
    /// `gamma_factor` is NaN or not strictly positive, or when `gamma_max` is
    /// NaN or lies outside the open interval `(0, 1)`.
    pub fn new(gamma_factor: f64, gamma_max: f64) -> crate::Result<Self> {
        // NaN fails every ordered comparison, so it has to be rejected
        // explicitly; a NaN `gamma_max` in particular would disable the cap.
        if gamma_factor.is_nan() || gamma_factor <= 0.0 {
            return Err(Error::InvalidGamma(gamma_factor));
        }
        if gamma_max.is_nan() || gamma_max <= 0.0 || gamma_max >= 1.0 {
            return Err(Error::InvalidGamma(gamma_max));
        }
        Ok(Self {
            gamma_factor,
            gamma_max,
        })
    }

    /// Returns the scaling factor.
    #[must_use]
    pub fn gamma_factor(&self) -> f64 {
        self.gamma_factor
    }

    /// Returns the upper cap on gamma.
    #[must_use]
    pub fn gamma_max(&self) -> f64 {
        self.gamma_max
    }
}
impl Default for SqrtGamma {
fn default() -> Self {
Self {
gamma_factor: 1.0,
gamma_max: 0.25,
}
}
}
impl GammaStrategy for SqrtGamma {
    /// Computes `max(gamma_factor / sqrt(n_trials), 1) / n_trials`, capped at
    /// `gamma_max`. With no trials, `gamma_max` is returned directly.
    // The trial count is converted to f64 for the square root and division;
    // precision loss on enormous counts is acceptable here.
    #[allow(clippy::cast_precision_loss)]
    fn gamma(&self, n_trials: usize) -> f64 {
        let trials = match n_trials {
            // Avoid dividing by zero.
            0 => return self.gamma_max,
            n => n as f64,
        };
        // Never let the numerator drop below one.
        let n_good = (self.gamma_factor / trials.sqrt()).max(1.0);
        let uncapped = n_good / trials;
        uncapped.min(self.gamma_max)
    }

    /// Boxes a bitwise copy of this strategy.
    fn clone_box(&self) -> Box<dyn GammaStrategy> {
        let copy: Self = *self;
        Box::new(copy)
    }
}
/// Gamma strategy parameterised by a base count and an upper cap.
#[derive(Debug, Clone, Copy)]
pub struct HyperoptGamma {
    /// Base value; one is added to it before dividing by the trial count.
    gamma_base: f64,
    /// Upper cap on the reported gamma; also the value reported for zero trials.
    gamma_max: f64,
}

impl HyperoptGamma {
    /// Creates a strategy from a base value and a cap.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidGamma`] carrying the rejected value when
    /// `gamma_base` is NaN or negative, or when `gamma_max` is NaN or lies
    /// outside the open interval `(0, 1)`.
    pub fn new(gamma_base: f64, gamma_max: f64) -> crate::Result<Self> {
        // NaN fails every ordered comparison, so it has to be rejected
        // explicitly; a NaN `gamma_max` in particular would disable the cap.
        if gamma_base.is_nan() || gamma_base < 0.0 {
            return Err(Error::InvalidGamma(gamma_base));
        }
        if gamma_max.is_nan() || gamma_max <= 0.0 || gamma_max >= 1.0 {
            return Err(Error::InvalidGamma(gamma_max));
        }
        Ok(Self {
            gamma_base,
            gamma_max,
        })
    }

    /// Returns the base value.
    #[must_use]
    pub fn gamma_base(&self) -> f64 {
        self.gamma_base
    }

    /// Returns the upper cap on gamma.
    #[must_use]
    pub fn gamma_max(&self) -> f64 {
        self.gamma_max
    }
}
impl Default for HyperoptGamma {
fn default() -> Self {
Self {
gamma_base: 24.0,
gamma_max: 0.25,
}
}
}
impl GammaStrategy for HyperoptGamma {
    /// Computes `(gamma_base + 1) / n_trials`, capped at `gamma_max`.
    /// With no trials, `gamma_max` is returned directly.
    // The trial count is converted to f64 only to divide by it; precision
    // loss on enormous counts is acceptable here.
    #[allow(clippy::cast_precision_loss)]
    fn gamma(&self, n_trials: usize) -> f64 {
        match n_trials {
            // Avoid dividing by zero.
            0 => self.gamma_max,
            n => {
                let uncapped = (self.gamma_base + 1.0) / n as f64;
                uncapped.min(self.gamma_max)
            }
        }
    }

    /// Boxes a bitwise copy of this strategy.
    fn clone_box(&self) -> Box<dyn GammaStrategy> {
        let copy: Self = *self;
        Box::new(copy)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampler::tpe::TpeSampler;

    /// Fails the calling test unless `actual` is within `f64::EPSILON` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < f64::EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    // ---- FixedGamma ----

    #[test]
    fn test_fixed_gamma_default() {
        let strategy = FixedGamma::default();
        assert_close(strategy.gamma(0), 0.25);
        assert_close(strategy.gamma(100), 0.25);
        assert_close(strategy.value(), 0.25);
    }

    #[test]
    fn test_fixed_gamma_custom() {
        let strategy = FixedGamma::new(0.15).unwrap();
        assert_close(strategy.gamma(0), 0.15);
        assert_close(strategy.gamma(50), 0.15);
        assert_close(strategy.gamma(1000), 0.15);
    }

    #[test]
    fn test_fixed_gamma_invalid() {
        // Both interval endpoints and values beyond them are rejected.
        assert!(FixedGamma::new(0.0).is_err());
        assert!(FixedGamma::new(1.0).is_err());
        assert!(FixedGamma::new(-0.1).is_err());
        assert!(FixedGamma::new(1.5).is_err());
    }

    // ---- LinearGamma ----

    #[test]
    fn test_linear_gamma_default() {
        let strategy = LinearGamma::default();
        assert_close(strategy.gamma(0), 0.10);
        assert_close(strategy.gamma(50), 0.175);
        assert_close(strategy.gamma(100), 0.25);
        // Past the horizon the value stays at the maximum.
        assert_close(strategy.gamma(200), 0.25);
    }

    #[test]
    fn test_linear_gamma_custom() {
        let strategy = LinearGamma::new(0.1, 0.4, 100).unwrap();
        assert_close(strategy.gamma(0), 0.1);
        assert_close(strategy.gamma(50), 0.25);
        assert_close(strategy.gamma(100), 0.4);
        assert_close(strategy.gamma(200), 0.4);
    }

    #[test]
    fn test_linear_gamma_invalid() {
        assert!(LinearGamma::new(0.0, 0.5, 100).is_err());
        assert!(LinearGamma::new(0.1, 1.0, 100).is_err());
        // Minimum above maximum.
        assert!(LinearGamma::new(0.5, 0.2, 100).is_err());
    }

    // ---- SqrtGamma ----

    #[test]
    fn test_sqrt_gamma_default() {
        let strategy = SqrtGamma::default();
        assert_close(strategy.gamma(0), 0.25);
        let few_trials = strategy.gamma(10);
        let many_trials = strategy.gamma(100);
        assert!(few_trials > many_trials);
    }

    #[test]
    fn test_sqrt_gamma_custom() {
        let strategy = SqrtGamma::new(2.0, 0.5).unwrap();
        assert_close(strategy.gamma(0), 0.5);
        assert_close(strategy.gamma(4), 0.25);
    }

    #[test]
    fn test_sqrt_gamma_invalid() {
        assert!(SqrtGamma::new(0.0, 0.25).is_err());
        assert!(SqrtGamma::new(-1.0, 0.25).is_err());
        assert!(SqrtGamma::new(1.0, 0.0).is_err());
        assert!(SqrtGamma::new(1.0, 1.0).is_err());
    }

    // ---- HyperoptGamma ----

    #[test]
    fn test_hyperopt_gamma_default() {
        let strategy = HyperoptGamma::default();
        assert_close(strategy.gamma(0), 0.25);
        assert_close(strategy.gamma(100), 0.25);
        assert_close(strategy.gamma(200), 0.125);
    }

    #[test]
    fn test_hyperopt_gamma_custom() {
        let strategy = HyperoptGamma::new(9.0, 0.5).unwrap();
        assert_close(strategy.gamma(20), 0.5);
        assert_close(strategy.gamma(100), 0.1);
    }

    #[test]
    fn test_hyperopt_gamma_invalid() {
        assert!(HyperoptGamma::new(-1.0, 0.25).is_err());
        assert!(HyperoptGamma::new(24.0, 0.0).is_err());
        assert!(HyperoptGamma::new(24.0, 1.0).is_err());
    }

    // ---- Trait objects and sampler integration ----

    #[test]
    fn test_gamma_strategy_clone_box() {
        let fixed: Box<dyn GammaStrategy> = Box::new(FixedGamma::new(0.3).unwrap());
        let fixed_copy = fixed.clone();
        assert_close(fixed_copy.gamma(0), 0.3);

        let linear: Box<dyn GammaStrategy> = Box::new(LinearGamma::default());
        let linear_copy = linear.clone();
        assert_close(linear_copy.gamma(0), 0.10);
    }

    #[test]
    fn test_tpe_with_linear_gamma_strategy() {
        let schedule = LinearGamma::new(0.1, 0.3, 50).unwrap();
        let sampler = TpeSampler::builder()
            .gamma_strategy(schedule)
            .n_startup_trials(5)
            .seed(42)
            .build()
            .unwrap();
        // Halfway through the ramp.
        assert_close(sampler.gamma_strategy().gamma(25), 0.2);
    }

    #[test]
    fn test_gamma_overrides_gamma_strategy() {
        // The later `gamma` call wins over the earlier strategy.
        let sampler = TpeSampler::builder()
            .gamma_strategy(SqrtGamma::default())
            .gamma(0.15)
            .build()
            .unwrap();
        assert_close(sampler.gamma_strategy().gamma(0), 0.15);
        assert_close(sampler.gamma_strategy().gamma(100), 0.15);
    }

    #[test]
    fn test_gamma_strategy_overrides_gamma() {
        // The later `gamma_strategy` call wins over the earlier fixed gamma.
        let sampler = TpeSampler::builder()
            .gamma(0.15)
            .gamma_strategy(SqrtGamma::default())
            .build()
            .unwrap();
        let few_trials = sampler.gamma_strategy().gamma(10);
        let many_trials = sampler.gamma_strategy().gamma(100);
        assert!(
            few_trials > many_trials,
            "SqrtGamma should decrease with more trials"
        );
    }

    #[test]
    fn test_custom_gamma_strategy() {
        /// User-defined strategy: grows by 0.01 per trial, capped at 0.5.
        #[derive(Debug, Clone)]
        struct DoubleGamma;

        impl GammaStrategy for DoubleGamma {
            // Precision loss on the trial count is irrelevant for this test strategy.
            #[allow(clippy::cast_precision_loss)]
            fn gamma(&self, n_trials: usize) -> f64 {
                (0.01 * n_trials as f64).min(0.5)
            }

            fn clone_box(&self) -> Box<dyn GammaStrategy> {
                Box::new(self.clone())
            }
        }

        let sampler = TpeSampler::builder()
            .gamma_strategy(DoubleGamma)
            .build()
            .unwrap();
        assert_close(sampler.gamma_strategy().gamma(10), 0.1);
        assert_close(sampler.gamma_strategy().gamma(50), 0.5);
        assert_close(sampler.gamma_strategy().gamma(100), 0.5);
    }
}