#![allow(deprecated)]
use irithyll_core::rng::{standard_normal, xorshift64, xorshift64_f64};
/// One tunable dimension of a legacy [`ConfigSpace`].
///
/// Every variant carries a `name` (see [`HyperParam::name`]). Whatever the
/// variant, [`ConfigSampler`] represents a sampled value as an `f64`.
#[deprecated(
since = "10.0.0",
note = "use the typed `SearchSpace` / `ParamMap` API in `irithyll::automl::space` instead"
)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum HyperParam {
/// Continuous parameter sampled between `low` and `high`.
Float {
/// Name used to look the parameter up.
name: &'static str,
/// Lower bound of the range.
low: f64,
/// Upper bound of the range.
high: f64,
/// When `true`, the sampler interpolates between `ln(low)` and
/// `ln(high)` and exponentiates the result instead of interpolating
/// between the raw bounds.
log_scale: bool,
},
/// Integer parameter; the sampler treats both `low` and `high` as
/// attainable values.
Int {
/// Name used to look the parameter up.
name: &'static str,
/// Smallest value the sampler produces.
low: i64,
/// Largest value the sampler produces.
high: i64,
},
/// Categorical parameter described only by how many choices it has; the
/// sampler produces a choice index, not the choice itself.
Categorical {
/// Name used to look the parameter up.
name: &'static str,
/// Number of choices; sampled indices are capped at `n_choices - 1`.
n_choices: usize,
},
}
impl HyperParam {
    /// Returns the parameter's name, regardless of which variant it is.
    pub fn name(&self) -> &str {
        // Every variant stores its name in a field called `name`, so a single
        // or-pattern covers them all.
        match *self {
            HyperParam::Float { name, .. }
            | HyperParam::Int { name, .. }
            | HyperParam::Categorical { name, .. } => name,
        }
    }
}
/// Ordered collection of [`HyperParam`]s describing a legacy search space.
///
/// Parameters keep the order in which they were pushed; the sampler emits one
/// value per parameter in that same order.
#[deprecated(
since = "10.0.0",
note = "use `irithyll::automl::SearchSpace` (typed, named-access) instead"
)]
#[derive(Debug, Clone)]
pub struct ConfigSpace {
// Parameters in insertion order.
params: Vec<HyperParam>,
}
impl ConfigSpace {
pub fn new() -> Self {
Self { params: Vec::new() }
}
pub fn push(mut self, param: HyperParam) -> Self {
self.params.push(param);
self
}
pub fn params(&self) -> &[HyperParam] {
&self.params
}
pub fn n_params(&self) -> usize {
self.params.len()
}
pub fn dim(&self) -> usize {
self.params.len()
}
pub fn set_range(&mut self, name: &str, low: f64, high: f64) {
let param = self
.params
.iter_mut()
.find(|p| p.name() == name)
.unwrap_or_else(|| panic!("ConfigSpace::set_range: unknown parameter '{name}'"));
match param {
HyperParam::Float {
low: ref mut lo,
high: ref mut hi,
..
} => {
*lo = low;
*hi = high;
}
HyperParam::Int {
low: ref mut lo,
high: ref mut hi,
..
} => {
*lo = low.round() as i64;
*hi = high.round() as i64;
}
HyperParam::Categorical { name, .. } => {
panic!(
"ConfigSpace::set_range: cannot set range on categorical parameter '{name}'"
);
}
}
}
}
impl Default for ConfigSpace {
fn default() -> Self {
Self::new()
}
}
/// One concrete configuration: a flat list of `f64` values.
///
/// When produced by [`ConfigSampler`], the list holds one value per parameter
/// of the sampler's space, in parameter order; integer values and categorical
/// indices are stored as whole-number `f64`s.
#[deprecated(
since = "10.0.0",
note = "use `irithyll::automl::ParamMap` (typed, named-access) instead"
)]
#[derive(Debug, Clone)]
pub struct HyperConfig {
// Positional values; nothing here ties them to parameter names.
values: Vec<f64>,
}
impl HyperConfig {
pub fn new(values: Vec<f64>) -> Self {
Self { values }
}
pub fn values(&self) -> &[f64] {
&self.values
}
pub fn get(&self, index: usize) -> f64 {
self.values[index]
}
pub fn len(&self) -> usize {
self.values.len()
}
pub fn is_empty(&self) -> bool {
self.values.is_empty()
}
}
/// Seeded sampler that draws [`HyperConfig`]s from a [`ConfigSpace`].
///
/// The sampler owns its space and a single `u64` random state, so two
/// samplers built from equal spaces and equal seeds produce the same
/// sequence of configurations. Cloning a sampler copies that state: the clone
/// continues with the same sequence the original would have produced.
#[deprecated(
    since = "10.0.0",
    note = "use `SearchSpace::sample(...)` (typed, named-access) instead"
)]
#[derive(Debug, Clone)]
pub struct ConfigSampler {
    // Space the sampler draws from.
    space: ConfigSpace,
    // Current random state, advanced by every draw.
    rng_state: u64,
}
impl ConfigSampler {
    /// Creates a sampler over `space` whose random state starts at `seed`.
    ///
    /// # Panics
    ///
    /// Panics if `seed` is zero.
    pub fn new(space: ConfigSpace, seed: u64) -> Self {
        assert!(seed != 0, "ConfigSampler seed must be non-zero");
        Self {
            space,
            rng_state: seed,
        }
    }

    /// Returns the space this sampler draws from.
    pub fn space(&self) -> &ConfigSpace {
        &self.space
    }

    /// Draws one configuration, using one random draw per parameter in
    /// parameter order.
    pub fn random(&mut self) -> HyperConfig {
        let rng = &mut self.rng_state;
        let values = self
            .space
            .params
            .iter()
            .map(|param| {
                let u = xorshift64_f64(rng);
                Self::map_unit_to_param(u, param)
            })
            .collect();
        HyperConfig { values }
    }

    /// Draws `n` configurations with Latin-hypercube stratification.
    ///
    /// For every parameter the unit interval is cut into `n` equal strata, one
    /// point is drawn inside each stratum, and the resulting column is
    /// shuffled independently of the other parameters. Configuration `i` is
    /// row `i` across all columns. Returns an empty vector when `n == 0`.
    pub fn latin_hypercube(&mut self, n: usize) -> Vec<HyperConfig> {
        if n == 0 {
            return Vec::new();
        }
        let rng = &mut self.rng_state;
        let params = &self.space.params;

        // One shuffled column of stratified unit samples per parameter.
        let mut stratified: Vec<Vec<f64>> = Vec::with_capacity(params.len());
        for _ in 0..params.len() {
            let mut column: Vec<f64> = (0..n)
                .map(|i| {
                    let lo = i as f64 / n as f64;
                    let hi = (i + 1) as f64 / n as f64;
                    let u = xorshift64_f64(rng);
                    lo + u * (hi - lo)
                })
                .collect();
            // Fisher–Yates shuffle so strata are paired randomly across
            // parameters.
            for i in (1..n).rev() {
                let j = (xorshift64(rng) as usize) % (i + 1);
                column.swap(i, j);
            }
            stratified.push(column);
        }

        (0..n)
            .map(|i| {
                let values = params
                    .iter()
                    .zip(stratified.iter())
                    .map(|(param, column)| Self::map_unit_to_param(column[i], param))
                    .collect();
                HyperConfig { values }
            })
            .collect()
    }

    /// Returns a randomly perturbed copy of `config`.
    ///
    /// * `Float` / `Int`: adds noise from `standard_normal` scaled by `sigma`
    ///   times the width of the range (measured in `ln` space for log-scale
    ///   floats), then clamps to the bounds. `Int` results are rounded.
    /// * `Categorical`: when a random draw falls below `min(sigma, 1.0)` and
    ///   there is more than one choice, switches to a different index;
    ///   otherwise the value is kept.
    ///
    /// # Panics
    ///
    /// Panics if `config` holds fewer values than the space has parameters.
    /// The clamp also panics if a parameter's `low` exceeds its `high`.
    pub fn perturb(&mut self, config: &HyperConfig, sigma: f64) -> HyperConfig {
        let rng = &mut self.rng_state;
        let values = self
            .space
            .params
            .iter()
            .enumerate()
            .map(|(i, param)| Self::perturb_value(rng, config.values()[i], sigma, param))
            .collect();
        HyperConfig { values }
    }

    /// Perturbs the single value `current` that belongs to `param`.
    fn perturb_value(rng: &mut u64, current: f64, sigma: f64, param: &HyperParam) -> f64 {
        match param {
            HyperParam::Float {
                low,
                high,
                log_scale,
                ..
            } => {
                if *log_scale {
                    let ln_low = low.ln();
                    let ln_high = high.ln();
                    // Lift values below `low` to `low` before taking the log.
                    let ln_current = current.max(*low).ln();
                    let noise = standard_normal(rng) * sigma * (ln_high - ln_low);
                    (ln_current + noise).exp().clamp(*low, *high)
                } else {
                    let noise = standard_normal(rng) * sigma * (high - low);
                    (current + noise).clamp(*low, *high)
                }
            }
            HyperParam::Int { low, high, .. } => {
                // Widen before subtracting: `high - low` overflows `i64` for
                // very wide ranges. Identical to the `i64` result otherwise.
                let range = (i128::from(*high) - i128::from(*low)) as f64;
                let noise = standard_normal(rng) * sigma * range;
                let perturbed = (current + noise).round();
                perturbed.clamp(*low as f64, *high as f64)
            }
            HyperParam::Categorical { n_choices, .. } => {
                let p = xorshift64_f64(rng);
                if p < sigma.min(1.0) && *n_choices > 1 {
                    // Draw from the `n_choices - 1` other indices, then shift
                    // past the current one so it is never re-selected.
                    let new_choice = (xorshift64(rng) as usize) % (*n_choices - 1);
                    let current_idx = current as usize;
                    if new_choice >= current_idx {
                        (new_choice + 1) as f64
                    } else {
                        new_choice as f64
                    }
                } else {
                    current
                }
            }
        }
    }

    /// Maps `u` onto the value range of `param`.
    ///
    /// `u` is assumed to lie in the unit interval (it comes from
    /// `xorshift64_f64` or from a Latin-hypercube stratum). Log-scale floats
    /// presumably need strictly positive bounds, since their logarithms are
    /// taken — this is not checked here.
    fn map_unit_to_param(u: f64, param: &HyperParam) -> f64 {
        match param {
            HyperParam::Float {
                low,
                high,
                log_scale,
                ..
            } => {
                if *log_scale {
                    let ln_low = low.ln();
                    let ln_high = high.ln();
                    (ln_low + u * (ln_high - ln_low)).exp()
                } else {
                    low + u * (high - low)
                }
            }
            HyperParam::Int { low, high, .. } => {
                // `+ 1` makes `high` attainable. Computed in `i128` so wide
                // ranges cannot overflow.
                let range = (i128::from(*high) - i128::from(*low) + 1) as f64;
                let v = *low as f64 + (u * range).floor();
                v.min(*high as f64).max(*low as f64)
            }
            HyperParam::Categorical { n_choices, .. } => {
                let v = (u * *n_choices as f64).floor();
                // `saturating_sub` keeps `n_choices == 0` from underflowing;
                // the result is then index 0.
                v.min(n_choices.saturating_sub(1) as f64)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder keeps parameters in insertion order and `dim` mirrors
    /// `n_params`.
    #[test]
    fn config_space_builder() {
        let space = ConfigSpace::new()
            .push(HyperParam::Float {
                name: "learning_rate",
                low: 0.001,
                high: 1.0,
                log_scale: true,
            })
            .push(HyperParam::Int {
                name: "n_trees",
                low: 10,
                high: 200,
            })
            .push(HyperParam::Categorical {
                name: "loss",
                n_choices: 3,
            });
        assert_eq!(
            space.n_params(),
            3,
            "expected 3 params, got {}",
            space.n_params()
        );
        assert_eq!(space.dim(), space.n_params(), "dim should alias n_params");
        assert_eq!(space.params()[0].name(), "learning_rate");
        assert_eq!(space.params()[1].name(), "n_trees");
        assert_eq!(space.params()[2].name(), "loss");
    }

    /// `set_range` stores float bounds verbatim and leaves other fields alone.
    #[test]
    fn set_range_updates_float_bounds() {
        let mut space = ConfigSpace::new().push(HyperParam::Float {
            name: "lr",
            low: 0.001,
            high: 1.0,
            log_scale: true,
        });
        space.set_range("lr", 0.25, 0.75);
        match &space.params()[0] {
            HyperParam::Float {
                low,
                high,
                log_scale,
                ..
            } => {
                assert_eq!(*low, 0.25, "float low should be overwritten");
                assert_eq!(*high, 0.75, "float high should be overwritten");
                assert!(*log_scale, "log_scale should be untouched");
            }
            other => panic!("expected Float param, got {:?}", other),
        }
    }

    /// `set_range` rounds the requested bounds for integer parameters.
    #[test]
    fn set_range_rounds_bounds_for_int() {
        let mut space = ConfigSpace::new()
            .push(HyperParam::Float {
                name: "lr",
                low: 0.001,
                high: 1.0,
                log_scale: false,
            })
            .push(HyperParam::Int {
                name: "depth",
                low: 1,
                high: 20,
            });
        space.set_range("depth", 1.6, 9.4);
        match &space.params()[1] {
            HyperParam::Int { low, high, .. } => {
                assert_eq!(*low, 2, "1.6 should round to 2");
                assert_eq!(*high, 9, "9.4 should round to 9");
            }
            other => panic!("expected Int param, got {:?}", other),
        }
    }

    /// Asking for a name that is not in the space is a programming error.
    #[test]
    #[should_panic(expected = "unknown parameter")]
    fn set_range_unknown_parameter_panics() {
        let mut space = ConfigSpace::new().push(HyperParam::Int {
            name: "depth",
            low: 1,
            high: 20,
        });
        space.set_range("missing", 0.0, 1.0);
    }

    /// Categorical parameters have no numeric range to set.
    #[test]
    #[should_panic(expected = "cannot set range on categorical parameter")]
    fn set_range_categorical_panics() {
        let mut space = ConfigSpace::new().push(HyperParam::Categorical {
            name: "act",
            n_choices: 5,
        });
        space.set_range("act", 0.0, 1.0);
    }

    /// Uniform random samples respect every parameter's bounds.
    #[test]
    fn random_sample_in_bounds() {
        let space = ConfigSpace::new()
            .push(HyperParam::Float {
                name: "lr",
                low: 0.001,
                high: 1.0,
                log_scale: false,
            })
            .push(HyperParam::Int {
                name: "depth",
                low: 1,
                high: 20,
            })
            .push(HyperParam::Categorical {
                name: "act",
                n_choices: 5,
            });
        let mut sampler = ConfigSampler::new(space, 42);
        for i in 0..100 {
            let cfg = sampler.random();
            assert_eq!(cfg.len(), 3, "config should have 3 values");
            let lr = cfg.get(0);
            assert!(
                (0.001..=1.0).contains(&lr),
                "sample {i}: lr={lr} out of [0.001, 1.0]"
            );
            let depth = cfg.get(1);
            assert!(
                (1.0..=20.0).contains(&depth),
                "sample {i}: depth={depth} out of [1, 20]"
            );
            assert_eq!(
                depth,
                depth.floor(),
                "sample {i}: depth={depth} should be integer"
            );
            let act = cfg.get(2);
            assert!(
                (0.0..=4.0).contains(&act),
                "sample {i}: act={act} out of [0, 4]"
            );
            assert_eq!(
                act,
                act.floor(),
                "sample {i}: act={act} should be integer index"
            );
        }
    }

    /// Log-scale floats stay inside their (linear) bounds.
    #[test]
    fn log_scale_sampling() {
        let space = ConfigSpace::new().push(HyperParam::Float {
            name: "lr",
            low: 1e-5,
            high: 1.0,
            log_scale: true,
        });
        let mut sampler = ConfigSampler::new(space, 77);
        for i in 0..200 {
            let cfg = sampler.random();
            let v = cfg.get(0);
            assert!(
                (1e-5..=1.0).contains(&v),
                "sample {i}: log-scale value {v} out of [1e-5, 1.0]"
            );
        }
    }

    /// Each of the `n` strata is hit exactly once per dimension.
    #[test]
    fn latin_hypercube_coverage() {
        let space = ConfigSpace::new()
            .push(HyperParam::Float {
                name: "x",
                low: 0.0,
                high: 1.0,
                log_scale: false,
            })
            .push(HyperParam::Float {
                name: "y",
                low: 0.0,
                high: 1.0,
                log_scale: false,
            });
        let n = 10;
        let mut sampler = ConfigSampler::new(space, 99);
        let configs = sampler.latin_hypercube(n);
        assert_eq!(configs.len(), n, "expected {n} configs from LHS");
        for dim in 0..2 {
            let mut strata = vec![false; n];
            for cfg in &configs {
                let v = cfg.get(dim);
                let stratum = (v * n as f64).floor() as usize;
                // A value sitting exactly on the upper bound belongs to the
                // last stratum.
                let stratum = stratum.min(n - 1);
                assert!(
                    !strata[stratum],
                    "dim {dim}: stratum {stratum} hit twice (value={v})"
                );
                strata[stratum] = true;
            }
            assert!(
                strata.iter().all(|&hit| hit),
                "dim {dim}: not all strata covered"
            );
        }
    }

    /// Latin-hypercube samples respect bounds for every parameter kind.
    #[test]
    fn latin_hypercube_in_bounds() {
        let space = ConfigSpace::new()
            .push(HyperParam::Float {
                name: "lr",
                low: 0.01,
                high: 0.5,
                log_scale: true,
            })
            .push(HyperParam::Int {
                name: "k",
                low: 5,
                high: 50,
            })
            .push(HyperParam::Categorical {
                name: "opt",
                n_choices: 4,
            });
        let n = 20;
        let mut sampler = ConfigSampler::new(space, 1337);
        let configs = sampler.latin_hypercube(n);
        assert_eq!(configs.len(), n, "expected {n} LHS configs");
        for (i, cfg) in configs.iter().enumerate() {
            let lr = cfg.get(0);
            assert!(
                (0.01..=0.5).contains(&lr),
                "LHS {i}: lr={lr} out of [0.01, 0.5]"
            );
            let k = cfg.get(1);
            assert!((5.0..=50.0).contains(&k), "LHS {i}: k={k} out of [5, 50]");
            assert_eq!(k, k.floor(), "LHS {i}: k={k} should be integer");
            let opt = cfg.get(2);
            assert!(
                (0.0..=3.0).contains(&opt),
                "LHS {i}: opt={opt} out of [0, 3]"
            );
            assert_eq!(
                opt,
                opt.floor(),
                "LHS {i}: opt={opt} should be integer index"
            );
        }
    }

    /// Plain accessors on `HyperConfig`.
    #[test]
    fn hyper_config_access() {
        let cfg = HyperConfig::new(vec![1.0, 2.0, 3.0]);
        assert_eq!(cfg.len(), 3, "expected 3 values");
        assert!(!cfg.is_empty(), "config with 3 values should not be empty");
        assert_eq!(cfg.get(0), 1.0, "expected first value to be 1.0");
        assert_eq!(cfg.get(1), 2.0, "expected second value to be 2.0");
        assert_eq!(cfg.get(2), 3.0, "expected third value to be 3.0");
        assert_eq!(cfg.values(), &[1.0, 2.0, 3.0], "values slice mismatch");
        let empty = HyperConfig::new(vec![]);
        assert!(empty.is_empty(), "empty config should be empty");
        assert_eq!(empty.len(), 0, "empty config len should be 0");
    }

    /// Categorical samples are whole-number indices below `n_choices`.
    #[test]
    fn categorical_sampling() {
        let space = ConfigSpace::new().push(HyperParam::Categorical {
            name: "color",
            n_choices: 7,
        });
        let mut sampler = ConfigSampler::new(space, 55);
        for i in 0..200 {
            let cfg = sampler.random();
            let v = cfg.get(0);
            assert!(
                (0.0..=6.0).contains(&v),
                "sample {i}: categorical={v} out of [0, 6]"
            );
            assert_eq!(
                v,
                v.floor(),
                "sample {i}: categorical={v} should be integer index"
            );
        }
    }

    /// A freshly created space is empty.
    #[test]
    fn empty_config_space() {
        let space = ConfigSpace::new();
        assert_eq!(space.n_params(), 0, "new space should have 0 params");
        assert_eq!(space.dim(), 0, "new space dim should be 0");
        assert!(
            space.params().is_empty(),
            "new space params should be empty"
        );
    }

    /// Perturbed values are clamped to bounds; ints and categorical indices
    /// stay whole numbers.
    #[test]
    fn perturb_stays_in_bounds() {
        let space = ConfigSpace::new()
            .push(HyperParam::Float {
                name: "lr",
                low: 0.001,
                high: 1.0,
                log_scale: false,
            })
            .push(HyperParam::Int {
                name: "depth",
                low: 1,
                high: 20,
            })
            .push(HyperParam::Categorical {
                name: "act",
                n_choices: 5,
            });
        let mut sampler = ConfigSampler::new(space, 42);
        let base = sampler.random();
        for i in 0..200 {
            let perturbed = sampler.perturb(&base, 0.3);
            assert_eq!(perturbed.len(), 3, "perturbed should have 3 values");
            let lr = perturbed.get(0);
            assert!(
                (0.001..=1.0).contains(&lr),
                "perturb {i}: lr={lr} out of [0.001, 1.0]"
            );
            let depth = perturbed.get(1);
            assert!(
                (1.0..=20.0).contains(&depth),
                "perturb {i}: depth={depth} out of [1, 20]"
            );
            assert_eq!(
                depth,
                depth.floor(),
                "perturb {i}: depth={depth} should be integer"
            );
            let act = perturbed.get(2);
            assert!(
                (0.0..=4.0).contains(&act),
                "perturb {i}: act={act} out of [0, 4]"
            );
            assert_eq!(act, act.floor(), "perturb {i}: act={act} should be integer");
        }
    }

    /// With `sigma == 0` the noise term vanishes for floats and ints.
    #[test]
    fn perturb_zero_sigma_preserves_float_and_int() {
        let space = ConfigSpace::new()
            .push(HyperParam::Float {
                name: "lr",
                low: 0.001,
                high: 1.0,
                log_scale: false,
            })
            .push(HyperParam::Int {
                name: "depth",
                low: 1,
                high: 20,
            });
        let mut sampler = ConfigSampler::new(space, 42);
        let base = HyperConfig::new(vec![0.5, 10.0]);
        for _ in 0..100 {
            let perturbed = sampler.perturb(&base, 0.0);
            assert!(
                (perturbed.get(0) - 0.5).abs() < 1e-10,
                "sigma=0 should not change float, got {}",
                perturbed.get(0)
            );
            assert!(
                (perturbed.get(1) - 10.0).abs() < 1e-10,
                "sigma=0 should not change int, got {}",
                perturbed.get(1)
            );
        }
    }

    /// Log-scale perturbation is clamped to the linear bounds.
    #[test]
    fn perturb_log_scale_in_bounds() {
        let space = ConfigSpace::new().push(HyperParam::Float {
            name: "lr",
            low: 1e-5,
            high: 1.0,
            log_scale: true,
        });
        let mut sampler = ConfigSampler::new(space, 77);
        let base = HyperConfig::new(vec![0.001]);
        for i in 0..200 {
            let perturbed = sampler.perturb(&base, 0.5);
            let v = perturbed.get(0);
            assert!(
                (1e-5..=1.0).contains(&v),
                "perturb {i}: log-scale value {v} out of [1e-5, 1.0]"
            );
        }
    }

    /// With `sigma == 1.0` a categorical value switches to another index at
    /// least once in 100 tries.
    #[test]
    fn perturb_categorical_changes() {
        let space = ConfigSpace::new().push(HyperParam::Categorical {
            name: "color",
            n_choices: 5,
        });
        let mut sampler = ConfigSampler::new(space, 99);
        let base = HyperConfig::new(vec![2.0]);
        let mut saw_different = false;
        for _ in 0..100 {
            let perturbed = sampler.perturb(&base, 1.0);
            let v = perturbed.get(0);
            assert!(
                (0.0..=4.0).contains(&v),
                "categorical should be in [0, 4], got {v}"
            );
            if (v - 2.0).abs() > 0.5 {
                saw_different = true;
            }
        }
        assert!(
            saw_different,
            "with sigma=1.0, categorical should sometimes change from index 2"
        );
    }

    /// Equal seeds over equal spaces give identical sample streams.
    #[test]
    fn deterministic_with_seed() {
        let make_space = || {
            ConfigSpace::new()
                .push(HyperParam::Float {
                    name: "lr",
                    low: 0.001,
                    high: 1.0,
                    log_scale: true,
                })
                .push(HyperParam::Int {
                    name: "n",
                    low: 1,
                    high: 100,
                })
                .push(HyperParam::Categorical {
                    name: "opt",
                    n_choices: 4,
                })
        };
        let seed = 123456;
        let mut s1 = ConfigSampler::new(make_space(), seed);
        let mut s2 = ConfigSampler::new(make_space(), seed);
        for i in 0..50 {
            let c1 = s1.random();
            let c2 = s2.random();
            assert_eq!(
                c1.values(),
                c2.values(),
                "random sample {i} differs between same-seed samplers"
            );
        }
        let mut s3 = ConfigSampler::new(make_space(), seed);
        let mut s4 = ConfigSampler::new(make_space(), seed);
        let lhs3 = s3.latin_hypercube(15);
        let lhs4 = s4.latin_hypercube(15);
        for (i, (c3, c4)) in lhs3.iter().zip(lhs4.iter()).enumerate() {
            assert_eq!(
                c3.values(),
                c4.values(),
                "LHS config {i} differs between same-seed samplers"
            );
        }
    }
}