/// A single candidate configuration: one [`ParamValue`] per parameter key.
#[derive(Debug, Clone)]
pub struct Trial<P: ParamKey = GenericParam> {
    /// Parameter values of this trial, keyed by parameter identifier.
    pub values: HashMap<P, ParamValue>,
}

impl<P: ParamKey> Trial<P> {
    /// Returns the raw value stored under `key`, or `None` when the trial has
    /// no entry for it.
    #[must_use]
    pub fn get(&self, key: &P) -> Option<&ParamValue> {
        self.values.get(key)
    }

    /// Returns the value under `key` as an `f64`.
    ///
    /// `None` when the key is absent or `ParamValue::as_f64` yields `None`.
    #[must_use]
    pub fn get_f64(&self, key: &P) -> Option<f64> {
        self.get(key).and_then(ParamValue::as_f64)
    }

    /// Returns the value under `key` as an `i64`.
    ///
    /// `None` when the key is absent or `ParamValue::as_i64` yields `None`.
    #[must_use]
    pub fn get_i64(&self, key: &P) -> Option<i64> {
        self.get(key).and_then(ParamValue::as_i64)
    }

    /// Returns the value under `key` as a `usize`.
    ///
    /// `None` when the key is absent, when `ParamValue::as_i64` yields `None`,
    /// or when the integer does not fit in `usize` (negative, or too large for
    /// the target's pointer width).
    #[must_use]
    pub fn get_usize(&self, key: &P) -> Option<usize> {
        // A plain `as usize` cast would wrap negative values into huge sizes;
        // the checked conversion reports them as "no usable value" instead.
        // The fully qualified path keeps this independent of the edition prelude.
        self.get_i64(key)
            .and_then(|raw| <usize as std::convert::TryFrom<i64>>::try_from(raw).ok())
    }

    /// Returns the value under `key` as a `bool`.
    ///
    /// `None` when the key is absent or `ParamValue::as_bool` yields `None`.
    #[must_use]
    pub fn get_bool(&self, key: &P) -> Option<bool> {
        self.get(key).and_then(ParamValue::as_bool)
    }
}
/// Formats a trial as `{name=value, name=value, ...}`.
impl<P: ParamKey> std::fmt::Display for Trial<P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut params: Vec<String> = self
            .values
            .iter()
            .map(|(k, v)| format!("{}={}", k.name(), v))
            .collect();
        // `HashMap` iteration order is unspecified and varies between runs;
        // sorting the rendered entries makes the output stable so logs and
        // snapshot comparisons of the same trial always match.
        params.sort_unstable();
        write!(f, "{{{}}}", params.join(", "))
    }
}
/// Outcome of evaluating one [`Trial`].
#[derive(Debug, Clone)]
pub struct TrialResult<P: ParamKey = GenericParam> {
/// The configuration that was evaluated.
pub trial: Trial<P>,
/// Scalar score assigned to the trial.
/// NOTE(review): whether higher or lower is better is not visible here — confirm against the caller.
pub score: f64,
/// Additional named measurements recorded for the trial.
pub metrics: HashMap<String, f64>,
}
/// A strategy that proposes trials from a search space.
pub trait SearchStrategy<P: ParamKey> {
/// Proposes trials for `space`, with `n` as the requested batch size.
/// The implementations visible in this file return at most `n` trials and
/// may return fewer once their budget or grid is exhausted.
fn suggest(&mut self, space: &SearchSpace<P>, n: usize) -> Vec<Trial<P>>;
/// Receives results of evaluated trials. The default implementation does
/// nothing, so strategies that ignore feedback need not override it.
fn update(&mut self, _results: &[TrialResult<P>]) {}
}
/// Minimal random-number source used by the search strategies.
pub trait Rng {
    /// Returns the next floating-point sample.
    ///
    /// Implementations should stay within `[0, 1)`; `gen_f64_range` relies on
    /// that to keep its result below `high`.
    fn gen_f64(&mut self) -> f64;

    /// Returns `low + gen_f64() * (high - low)`.
    fn gen_f64_range(&mut self, low: f64, high: f64) -> f64 {
        low + self.gen_f64() * (high - low)
    }

    /// Returns an integer from the inclusive range `low..=high`.
    fn gen_i64_range(&mut self, low: i64, high: i64) -> i64;

    /// Returns an index in `0..len`.
    fn gen_usize(&mut self, len: usize) -> usize;
}

/// Deterministic 64-bit xorshift generator (shift triple 13 / 7 / 17).
#[derive(Debug, Clone)]
pub(crate) struct XorShift64 {
    /// Current generator state; never zero, because zero is a fixed point of
    /// the xorshift step.
    state: u64,
}

impl XorShift64 {
    /// Creates a generator from `seed`; a zero seed is replaced by `1`.
    #[must_use]
    pub(crate) fn new(seed: u64) -> Self {
        Self {
            state: if seed == 0 { 1 } else { seed },
        }
    }

    /// Advances the state by one xorshift step and returns the new state.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }
}

impl Rng for XorShift64 {
    /// Returns a sample in `[0, 1)`.
    fn gen_f64(&mut self) -> f64 {
        // Keep the top 53 bits (the width of an f64 mantissa) and scale by
        // 2^-53. Dividing the full word by `u64::MAX as f64` (which rounds to
        // 2^64) could produce exactly 1.0 for outputs near `u64::MAX`.
        (self.next_u64() >> 11) as f64 * (1.0 / ((1u64 << 53) as f64))
    }

    /// Returns an integer in `low..=high`; returns `low` without consuming
    /// randomness when `low >= high`.
    fn gen_i64_range(&mut self, low: i64, high: i64) -> i64 {
        if low >= high {
            return low;
        }
        // `high > low`, so the two's-complement difference reinterpreted as
        // u64 is the exact distance even when it exceeds `i64::MAX`.
        let width = high.wrapping_sub(low) as u64;
        match width.checked_add(1) {
            // The offset is below `span`, so the wrapping add lands inside
            // `low..=high` without the signed overflow the plain `+` hits on
            // wide ranges.
            Some(span) => low.wrapping_add((self.next_u64() % span) as i64),
            // The range covers every i64: any 64-bit pattern is a valid draw.
            None => self.next_u64() as i64,
        }
    }

    /// Returns an index in `0..len`; returns `0` without consuming randomness
    /// when `len == 0`.
    fn gen_usize(&mut self, len: usize) -> usize {
        if len == 0 {
            return 0;
        }
        // On targets where usize is narrower than 64 bits the cast keeps the
        // low bits, which is fine for picking an index.
        (self.next_u64() as usize) % len
    }
}
/// Random search: hands out independently sampled trials until a fixed
/// budget is used up.
#[derive(Debug, Clone)]
pub struct RandomSearch {
    /// Total number of trials this strategy is allowed to generate.
    pub n_iter: usize,
    /// Seed the internal generator was created from.
    pub seed: u64,
    /// Generator passed to the search space when sampling.
    rng: XorShift64,
    /// Number of trials handed out so far.
    trials_generated: usize,
}

impl RandomSearch {
    /// Builds a search limited to `n_iter` trials, seeded with `42`.
    #[must_use]
    pub fn new(n_iter: usize) -> Self {
        let seed = 42;
        Self {
            n_iter,
            seed,
            rng: XorShift64::new(seed),
            trials_generated: 0,
        }
    }

    /// Returns the search with its seed replaced and its generator rebuilt
    /// from `seed`; the trial budget and the generated count carry over.
    #[must_use]
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed,
            rng: XorShift64::new(seed),
            ..self
        }
    }

    /// Number of trials that may still be generated (never negative).
    #[must_use]
    pub fn remaining(&self) -> usize {
        let (budget, used) = (self.n_iter, self.trials_generated);
        budget.saturating_sub(used)
    }
}
impl<P: ParamKey> SearchStrategy<P> for RandomSearch {
    /// Samples up to `n` trials from `space`, capped by the remaining budget,
    /// and records how many were handed out.
    fn suggest(&mut self, space: &SearchSpace<P>, n: usize) -> Vec<Trial<P>> {
        let budget = self.remaining().min(n);
        let mut batch: Vec<Trial<P>> = Vec::with_capacity(budget);
        for _ in 0..budget {
            batch.push(space.sample(&mut self.rng));
        }
        self.trials_generated += batch.len();
        batch
    }
}
/// Grid search: walks through the grid of a search space in order.
#[derive(Debug, Clone)]
pub struct GridSearch {
    /// Resolution passed to the search space when building the grid
    /// (always at least 2).
    pub points_per_param: usize,
    /// Index of the next grid entry to hand out.
    position: usize,
}

impl GridSearch {
    /// Creates a grid search starting at the first grid entry.
    ///
    /// Requests for fewer than two points per parameter are raised to 2.
    #[must_use]
    pub fn new(points_per_param: usize) -> Self {
        let points_per_param = match points_per_param {
            0 | 1 => 2,
            requested => requested,
        };
        Self {
            points_per_param,
            position: 0,
        }
    }
}
impl<P: ParamKey> SearchStrategy<P> for GridSearch {
    /// Returns the next (up to) `n` grid entries of `space` and advances the
    /// cursor past them.
    ///
    /// Returns an empty vector once the grid is exhausted, including when the
    /// cursor already lies beyond the grid (for example because a later call
    /// passes a space with a smaller grid).
    fn suggest(&mut self, space: &SearchSpace<P>, n: usize) -> Vec<Trial<P>> {
        // The grid is rebuilt from `space` on every call, as before.
        let grid = space.grid(self.points_per_param);
        // Consuming the grid avoids cloning the selected trials, and
        // `skip`/`take` cannot panic the way range indexing does when the
        // cursor is past the end of the grid.
        let trials: Vec<Trial<P>> = grid.into_iter().skip(self.position).take(n).collect();
        self.position += trials.len();
        trials
    }
}
/// State of the `DESearch` strategy.
/// NOTE(review): the name and fields suggest differential evolution, but the
/// implementation is outside this view — confirm details against the impl.
#[derive(Debug, Clone)]
pub struct DESearch {
/// Iteration budget. NOTE(review): presumably counted in trials, like `RandomSearch::n_iter` — confirm.
pub n_iter: usize,
/// Number of candidate vectors kept in `population` (presumably — confirm).
pub population_size: usize,
/// Seed value. NOTE(review): no generator is stored in this struct; confirm where the seed is consumed.
pub seed: u64,
/// Selected `DEStrategy` variant.
pub strategy: DEStrategy,
/// Flag named after JADE. NOTE(review): presumably enables adaptive control parameters — confirm.
pub use_jade: bool,
/// Candidate vectors, one `Vec<f64>` per individual.
population: Vec<Vec<f64>>,
/// Fitness values. NOTE(review): presumably parallel to `population` — confirm.
fitness: Vec<f64>,
/// Index of the best individual. NOTE(review): presumably indexes `population`/`fitness` — confirm.
best_idx: usize,
/// Parameter names. NOTE(review): presumably fixes the coordinate order of each population vector — confirm.
param_order: Vec<String>,
/// Per-parameter tuple of two floats and two flags.
/// NOTE(review): likely (low, high, <flag>, <flag>); the meaning of the flags is not visible here — confirm.
param_bounds: Vec<(f64, f64, bool, bool)>,
/// Counter of trials handed out so far (presumably — confirm).
trials_generated: usize,
/// Whether the internal state has been set up. NOTE(review): presumably set on first use — confirm.
initialized: bool,
/// Mutation factor. NOTE(review): presumably the DE scale factor `F` — confirm.
mutation_factor: f64,
/// Crossover rate. NOTE(review): presumably the DE crossover probability `CR` — confirm.
crossover_rate: f64,
}