use std::collections::HashMap;
use parking_lot::Mutex;
use crate::distribution::Distribution;
use crate::param::ParamValue;
use crate::rng_util;
use crate::sampler::{CompletedTrial, Sampler};
use super::common::{from_internal, internal_bounds, sample_random};
/// Mutation strategy used to build the DE mutant vector for each population member.
///
/// In the formulas below `F` is the mutation factor, `x_i` the target member,
/// `x_best` the member with the lowest objective value, and `x_r1`, `x_r2`, `x_r3`
/// distinct members drawn at random, none equal to the target.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum DEStrategy {
    /// `v = x_r1 + F * (x_r2 - x_r3)` (the default).
    #[default]
    Rand1,
    /// `v = x_best + F * (x_r1 - x_r2)`.
    Best1,
    /// `v = x_i + F * (x_best - x_i) + F * (x_r1 - x_r2)`.
    CurrentToBest1,
}
/// Differential Evolution (DE) sampler.
///
/// All mutable optimizer state sits behind one mutex, so `sample` can be called
/// through a shared reference.
pub struct DESampler {
    /// Optimizer state; locked for the duration of every `sample` call.
    state: Mutex<State>,
}
impl DESampler {
    /// Creates a sampler with the default configuration and a randomly seeded RNG.
    ///
    /// Equivalent to `DESampler::builder().build()`. The defaults (mutation factor,
    /// crossover rate, strategy, automatic population size) therefore live in one
    /// place: [`DESamplerBuilder::new`].
    #[must_use]
    pub fn new() -> Self {
        DESamplerBuilder::new().build()
    }

    /// Creates a sampler with the default configuration and a fixed RNG seed.
    /// Identical call sequences then produce identical samples.
    #[must_use]
    pub fn with_seed(seed: u64) -> Self {
        DESamplerBuilder::new().seed(seed).build()
    }

    /// Returns a builder for configuring population size, mutation factor,
    /// crossover rate, strategy and seed.
    #[must_use]
    pub fn builder() -> DESamplerBuilder {
        DESamplerBuilder::new()
    }
}
impl Default for DESampler {
fn default() -> Self {
Self::new()
}
}
/// Builder for [`DESampler`] with configurable DE hyper-parameters.
#[derive(Debug, Clone)]
pub struct DESamplerBuilder {
    /// Explicit population size; `None` lets the sampler choose one after discovery.
    population_size: Option<usize>,
    /// Mutation factor `F` that scales the difference vectors.
    mutation_factor: f64,
    /// Per-coordinate probability of taking the mutant's value during crossover.
    crossover_rate: f64,
    /// Mutation strategy.
    strategy: DEStrategy,
    /// RNG seed; `None` means a randomly seeded RNG.
    seed: Option<u64>,
}
impl Default for DESamplerBuilder {
fn default() -> Self {
Self::new()
}
}
impl DESamplerBuilder {
    /// Creates a builder holding the default configuration:
    /// - automatic population size,
    /// - mutation factor `0.8`,
    /// - crossover rate `0.9`,
    /// - [`DEStrategy::Rand1`],
    /// - no fixed seed.
    #[must_use]
    pub fn new() -> Self {
        Self {
            population_size: None,
            mutation_factor: 0.8,
            crossover_rate: 0.9,
            strategy: DEStrategy::Rand1,
            seed: None,
        }
    }

    /// Sets the population size.
    ///
    /// Sizes below 4 are raised to 4 when discovery finishes. If never set, the size
    /// is `max(10 * number_of_non_categorical_parameters, 15)`.
    #[must_use]
    pub fn population_size(self, size: usize) -> Self {
        Self {
            population_size: Some(size),
            ..self
        }
    }

    /// Sets the mutation factor `F`.
    #[must_use]
    pub fn mutation_factor(self, f: f64) -> Self {
        Self {
            mutation_factor: f,
            ..self
        }
    }

    /// Sets the crossover rate: the probability that a coordinate comes from the
    /// mutant rather than the current member. One random coordinate always does.
    #[must_use]
    pub fn crossover_rate(self, cr: f64) -> Self {
        Self {
            crossover_rate: cr,
            ..self
        }
    }

    /// Sets the mutation strategy.
    #[must_use]
    pub fn strategy(self, strategy: DEStrategy) -> Self {
        Self { strategy, ..self }
    }

    /// Fixes the RNG seed for reproducible sampling.
    #[must_use]
    pub fn seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Consumes the builder and creates the configured [`DESampler`].
    #[must_use]
    pub fn build(self) -> DESampler {
        let Self {
            population_size,
            mutation_factor,
            crossover_rate,
            strategy,
            seed,
        } = self;
        let state = State::new(
            population_size,
            mutation_factor,
            crossover_rate,
            strategy,
            seed,
        );
        DESampler {
            state: Mutex::new(state),
        }
    }
}
/// A parameter dimension recorded during the discovery trial.
#[derive(Clone, Debug)]
struct DimensionInfo {
    /// Distribution exactly as it was passed to `sample` during discovery.
    distribution: Distribution,
    /// `false` only for categorical distributions. Every other kind is evolved
    /// numerically.
    is_continuous: bool,
    /// Internal-space bounds as reported by `internal_bounds`, or `None` when that
    /// helper reports none.
    bounds: Option<(f64, f64)>,
}
/// A point handed to trials: continuous coordinates plus categorical choices.
#[derive(Clone, Debug)]
struct Candidate {
    /// Internal-space values, one per continuous dimension, in discovery order.
    x: Vec<f64>,
    /// Chosen category index, keyed by the dimension's position in `State::dimensions`.
    categorical_values: HashMap<usize, usize>,
    /// Population slot this candidate competes against during selection.
    target_idx: usize,
}
/// Per-trial cursor: the candidate the trial evaluates and the next dimension to serve.
#[derive(Clone, Debug)]
struct TrialProgress {
    /// Index into `State::candidates`.
    candidate_idx: usize,
    /// Index of the dimension that the trial's next `sample` call maps to.
    next_dim: usize,
}
/// Lifecycle of the sampler.
enum Phase {
    /// The first trial is sampled at random while its parameter dimensions are recorded.
    Discovery,
    /// Dimensions are fixed and trials receive DE candidates.
    Active,
}
/// Mutable optimizer state guarded by the mutex in `DESampler`.
struct State {
    /// Source of every random decision.
    rng: fastrand::Rng,
    /// Population size requested through the builder, if any.
    user_population_size: Option<usize>,
    /// Mutation factor `F`.
    mutation_factor: f64,
    /// Crossover probability.
    crossover_rate: f64,
    /// Mutation strategy.
    strategy: DEStrategy,
    /// Current lifecycle phase.
    phase: Phase,
    /// Dimensions recorded during discovery, in sampling order.
    dimensions: Vec<DimensionInfo>,
    /// Id of the trial used for discovery.
    discovery_trial_id: Option<u64>,
    /// Continuous coordinates of each population member.
    population: Vec<Vec<f64>>,
    /// Categorical choices of each population member.
    population_categorical: Vec<HashMap<usize, usize>>,
    /// Objective value of each population member (lower is better).
    population_values: Vec<f64>,
    /// Index of the member with the lowest objective value.
    best_idx: usize,
    /// Whether `population` has been filled from a completed generation.
    initialized: bool,
    /// Effective population size, fixed when discovery ends.
    population_size: usize,
    /// Candidates of the current generation, plus any random overflow candidates.
    candidates: Vec<Candidate>,
    /// Cursor of every trial assigned during the current generation.
    trial_progress: HashMap<u64, TrialProgress>,
    /// Number of candidates already handed out.
    assigned_count: usize,
    /// Trial ids of the current generation, in assignment order.
    generation_trial_ids: Vec<u64>,
}
impl State {
    /// Builds a fresh state in the discovery phase with empty population bookkeeping.
    ///
    /// `Some(seed)` yields a deterministic RNG; `None` uses a randomly seeded one.
    fn new(
        user_population_size: Option<usize>,
        mutation_factor: f64,
        crossover_rate: f64,
        strategy: DEStrategy,
        seed: Option<u64>,
    ) -> Self {
        let rng = match seed {
            Some(s) => fastrand::Rng::with_seed(s),
            None => fastrand::Rng::new(),
        };
        Self {
            rng,
            user_population_size,
            mutation_factor,
            crossover_rate,
            strategy,
            phase: Phase::Discovery,
            dimensions: vec![],
            discovery_trial_id: None,
            population: vec![],
            population_categorical: vec![],
            population_values: vec![],
            best_idx: 0,
            initialized: false,
            population_size: 0,
            candidates: vec![],
            trial_progress: HashMap::new(),
            assigned_count: 0,
            generation_trial_ids: vec![],
        }
    }
}
/// Draws a random value between `lo` and `hi` (internal space) via `rng_util::f64_range`.
fn sample_random_internal(rng: &mut fastrand::Rng, (lo, hi): (f64, f64)) -> f64 {
    rng_util::f64_range(rng, lo, hi)
}
/// Clamps `value` into `bounds` when they are known; otherwise returns it unchanged.
///
/// # Panics
/// Propagates the panic of `f64::clamp` if `lo > hi` or either bound is NaN.
fn clamp_to_bounds(value: f64, bounds: Option<(f64, f64)>) -> f64 {
    match bounds {
        None => value,
        Some((lo, hi)) => value.clamp(lo, hi),
    }
}
/// Draws `count` distinct indices from `0..n`, none of which appear in `exclude`.
///
/// Uses rejection sampling, so the number of RNG draws depends on how many draws
/// collide with `exclude` or with earlier picks.
///
/// # Panics
/// Panics if fewer than `count` indices in `0..n` are outside `exclude`. Without
/// this check the rejection loop would never terminate.
fn select_random_indices(
    rng: &mut fastrand::Rng,
    n: usize,
    count: usize,
    exclude: &[usize],
) -> Vec<usize> {
    let eligible = (0..n).filter(|idx| !exclude.contains(idx)).count();
    assert!(
        eligible >= count,
        "cannot draw {count} distinct indices: only {eligible} of {n} are eligible"
    );
    let mut selected = Vec::with_capacity(count);
    while selected.len() < count {
        let idx = rng.usize(0..n);
        if !exclude.contains(&idx) && !selected.contains(&idx) {
            selected.push(idx);
        }
    }
    selected
}
/// Builds the next generation's candidates by DE mutation plus binomial crossover.
///
/// For each population member `i`:
/// - a mutant is produced by `create_mutant_with_rng`;
/// - each continuous coordinate is taken from the mutant with probability
///   `crossover_rate`, and always for one random coordinate `j_rand`; otherwise it
///   comes from member `i`;
/// - each coordinate is then clamped to its dimension's bounds.
///
/// Categorical dimensions are drawn uniformly at random for every candidate.
///
/// Expects an initialized population of `population_size` members.
fn generate_trial_vectors(state: &mut State) -> Vec<Candidate> {
    // Bounds of each continuous dimension, in continuous-index order. They are computed
    // once here instead of rescanning `dimensions` for every coordinate of every candidate.
    let continuous_bounds: Vec<Option<(f64, f64)>> = state
        .dimensions
        .iter()
        .filter(|d| d.is_continuous)
        .map(|d| d.bounds)
        .collect();
    let n_continuous = continuous_bounds.len();
    let pop_size = state.population_size;
    let mut candidates = Vec::with_capacity(pop_size);
    for i in 0..pop_size {
        let mutant = create_mutant_with_rng(state, i, n_continuous);
        // `j_rand` guarantees at least one coordinate is inherited from the mutant.
        let j_rand = state.rng.usize(0..n_continuous.max(1));
        let trial_x: Vec<f64> = continuous_bounds
            .iter()
            .enumerate()
            .map(|(j, &bounds)| {
                // Short-circuit: no crossover draw is consumed for `j_rand`.
                let use_mutant = j == j_rand || state.rng.f64() < state.crossover_rate;
                let val = if use_mutant {
                    mutant[j]
                } else {
                    state.population[i][j]
                };
                clamp_to_bounds(val, bounds)
            })
            .collect();
        // NOTE(review): categorical choices do not inherit from the population
        // (`population_categorical` is recorded elsewhere but not read here).
        let mut categorical_values = HashMap::new();
        for (dim_idx, dim) in state.dimensions.iter().enumerate() {
            if !dim.is_continuous
                && let Distribution::Categorical(cat) = &dim.distribution
            {
                categorical_values.insert(dim_idx, state.rng.usize(0..cat.n_choices));
            }
        }
        candidates.push(Candidate {
            x: trial_x,
            categorical_values,
            target_idx: i,
        });
    }
    candidates
}
/// Builds the DE mutant vector for population member `target_idx`.
///
/// Random donor indices are drawn first: three for `Rand1`, two otherwise, never
/// equal to `target_idx`. They are then combined coordinate-wise according to the
/// configured strategy.
///
/// Returns an empty vector, without touching the RNG, when there are no
/// continuous dimensions.
fn create_mutant_with_rng(state: &mut State, target_idx: usize, n_continuous: usize) -> Vec<f64> {
    if n_continuous == 0 {
        return Vec::new();
    }
    let strategy = state.strategy;
    let donor_count = match strategy {
        DEStrategy::Rand1 => 3,
        DEStrategy::Best1 | DEStrategy::CurrentToBest1 => 2,
    };
    let donors = select_random_indices(
        &mut state.rng,
        state.population_size,
        donor_count,
        &[target_idx],
    );
    let f = state.mutation_factor;
    let pop = &state.population;
    match strategy {
        DEStrategy::Rand1 => {
            let (base, plus, minus) = (&pop[donors[0]], &pop[donors[1]], &pop[donors[2]]);
            (0..n_continuous)
                .map(|j| base[j] + f * (plus[j] - minus[j]))
                .collect()
        }
        DEStrategy::Best1 => {
            let (base, plus, minus) = (&pop[state.best_idx], &pop[donors[0]], &pop[donors[1]]);
            (0..n_continuous)
                .map(|j| base[j] + f * (plus[j] - minus[j]))
                .collect()
        }
        DEStrategy::CurrentToBest1 => {
            let (current, best) = (&pop[target_idx], &pop[state.best_idx]);
            let (plus, minus) = (&pop[donors[0]], &pop[donors[1]]);
            (0..n_continuous)
                .map(|j| current[j] + f * (best[j] - current[j]) + f * (plus[j] - minus[j]))
                .collect()
        }
    }
}
/// Returns the bounds of the `continuous_idx`-th continuous dimension, counting only
/// continuous dimensions in discovery order.
///
/// Returns `None` if there is no such dimension or it has no bounds.
fn continuous_dim_bounds(
    dimensions: &[DimensionInfo],
    continuous_idx: usize,
) -> Option<(f64, f64)> {
    dimensions
        .iter()
        .filter(|dim| dim.is_continuous)
        .nth(continuous_idx)
        .and_then(|dim| dim.bounds)
}
/// Creates `population_size` random candidates covering every discovered dimension.
///
/// Continuous coordinates are drawn within their bounds, or set to `0.0` when a
/// dimension has none. Categorical choices are drawn uniformly. Candidate `i`
/// targets population slot `i`.
fn generate_initial_population(state: &mut State) -> Vec<Candidate> {
    let n_continuous = state.dimensions.iter().filter(|d| d.is_continuous).count();
    let size = state.population_size;
    let mut population = Vec::with_capacity(size);
    for target_idx in 0..size {
        let mut x = Vec::with_capacity(n_continuous);
        for dim in state.dimensions.iter().filter(|d| d.is_continuous) {
            x.push(match dim.bounds {
                Some(bounds) => sample_random_internal(&mut state.rng, bounds),
                None => 0.0,
            });
        }
        let mut categorical_values = HashMap::new();
        for (dim_idx, dim) in state.dimensions.iter().enumerate() {
            if let (false, Distribution::Categorical(cat)) = (dim.is_continuous, &dim.distribution)
            {
                categorical_values.insert(dim_idx, state.rng.usize(0..cat.n_choices));
            }
        }
        population.push(Candidate {
            x,
            categorical_values,
            target_idx,
        });
    }
    population
}
impl Sampler for DESampler {
#[allow(clippy::cast_precision_loss)]
fn sample(
&self,
distribution: &Distribution,
trial_id: u64,
history: &[CompletedTrial],
) -> ParamValue {
let mut state = self.state.lock();
match &state.phase {
Phase::Discovery => sample_discovery(&mut state, distribution, trial_id),
Phase::Active => sample_active(&mut state, distribution, trial_id, history),
}
}
}
/// Handles a `sample` call while the search space is still being discovered.
///
/// Each call belonging to the first trial records a new dimension and returns a
/// random value. The first call from a different trial ends discovery (building the
/// initial population) and is then served as an active-phase sample.
fn sample_discovery(state: &mut State, distribution: &Distribution, trial_id: u64) -> ParamValue {
    let discovery_over = state
        .discovery_trial_id
        .is_some_and(|prev_id| prev_id != trial_id);
    if discovery_over {
        finalize_discovery(state);
        // No history is forwarded. Right after finalization the generation has no
        // trials, so it cannot advance anyway.
        return sample_active(state, distribution, trial_id, &[]);
    }
    state.discovery_trial_id = Some(trial_id);
    state.dimensions.push(DimensionInfo {
        distribution: distribution.clone(),
        is_continuous: !matches!(distribution, Distribution::Categorical(_)),
        bounds: internal_bounds(distribution),
    });
    sample_random(&mut state.rng, distribution)
}
/// Ends discovery: fixes the population size and creates the initial random population.
///
/// The size is the user's choice if given, otherwise `max(10 * n_continuous, 15)`.
/// Either way it is raised to at least 4, the minimum `Rand1` needs to pick three
/// donors distinct from the target.
fn finalize_discovery(state: &mut State) {
    let n_continuous = state.dimensions.iter().filter(|d| d.is_continuous).count();
    let automatic_size = || (10 * n_continuous).max(15);
    state.population_size = state
        .user_population_size
        .unwrap_or_else(automatic_size)
        .max(4);
    state.candidates = generate_initial_population(state);
    state.assigned_count = 0;
    state.generation_trial_ids.clear();
    state.trial_progress.clear();
    state.phase = Phase::Active;
}
fn sample_active(
state: &mut State,
distribution: &Distribution,
trial_id: u64,
history: &[CompletedTrial],
) -> ParamValue {
maybe_update_generation(state, history);
if !state.trial_progress.contains_key(&trial_id) {
assign_candidate(state, trial_id);
}
let progress = state.trial_progress.get_mut(&trial_id).unwrap();
let dim_idx = progress.next_dim;
progress.next_dim += 1;
if dim_idx >= state.dimensions.len() {
return sample_random(&mut state.rng, distribution);
}
let candidate = &state.candidates[progress.candidate_idx];
let dim_info = &state.dimensions[dim_idx];
if dim_info.is_continuous {
let ci = state.dimensions[..dim_idx]
.iter()
.filter(|d| d.is_continuous)
.count();
if ci < candidate.x.len() {
from_internal(candidate.x[ci], &dim_info.distribution)
} else {
sample_random(&mut state.rng, distribution)
}
} else {
if let Some(&cat_idx) = candidate.categorical_values.get(&dim_idx) {
ParamValue::Categorical(cat_idx)
} else {
sample_random(&mut state.rng, distribution)
}
}
}
/// Registers `trial_id` in the current generation and assigns it a candidate.
///
/// Unassigned generation candidates are handed out in order. Once they are exhausted
/// (more trials running than the population size), a fresh random candidate is
/// appended instead.
///
/// Such overflow candidates get `target_idx == population_size`, which is not a
/// valid population slot. `perform_selection` only replaces members whose
/// `target_idx < population_size`, so an overflow candidate can never overwrite one.
fn assign_candidate(state: &mut State, trial_id: u64) {
    let candidate_idx = if state.assigned_count < state.candidates.len() {
        let idx = state.assigned_count;
        state.assigned_count += 1;
        idx
    } else {
        // Walk the continuous dimensions directly. This gives the same order and RNG
        // draws as looking them up one index at a time, without a rescan per coordinate.
        let mut x = Vec::new();
        for dim in state.dimensions.iter().filter(|d| d.is_continuous) {
            x.push(match dim.bounds {
                Some(b) => sample_random_internal(&mut state.rng, b),
                None => 0.0,
            });
        }
        let mut categorical_values = HashMap::new();
        for (dim_idx, dim) in state.dimensions.iter().enumerate() {
            if !dim.is_continuous
                && let Distribution::Categorical(cat) = &dim.distribution
            {
                categorical_values.insert(dim_idx, state.rng.usize(0..cat.n_choices));
            }
        }
        state.candidates.push(Candidate {
            x,
            categorical_values,
            target_idx: state.population_size,
        });
        state.assigned_count = state.candidates.len();
        state.candidates.len() - 1
    };
    state.trial_progress.insert(
        trial_id,
        TrialProgress {
            candidate_idx,
            next_dim: 0,
        },
    );
    state.generation_trial_ids.push(trial_id);
}
/// Advances to the next generation once each of the first `population_size` trials
/// of the current generation appears in `history`.
///
/// The first completed generation seeds the population; later ones go through
/// one-to-one selection. Afterwards, new candidates are generated (DE trial vectors
/// when there are continuous dimensions, random candidates otherwise) and all
/// per-trial bookkeeping is reset.
///
/// NOTE(review): if one of those trials never appears in `history` (confirm how
/// failed trials are reported by the caller), the generation never advances. Later
/// trials then keep receiving random overflow candidates.
///
/// NOTE(review): trials still mid-sampling when this runs lose their cursor. Their
/// next call is assigned a new candidate starting again at dimension 0.
fn maybe_update_generation(state: &mut State, history: &[CompletedTrial]) {
    let pop_size = state.population_size;
    // Only the first `pop_size` trials of the generation take part in selection.
    let Some(trial_ids) = state
        .generation_trial_ids
        .get(..pop_size)
        .map(<[u64]>::to_vec)
    else {
        return;
    };
    let completed: HashMap<u64, f64> = history.iter().map(|t| (t.id, t.value)).collect();
    if trial_ids.iter().any(|id| !completed.contains_key(id)) {
        return;
    }
    let n_continuous = state.dimensions.iter().filter(|d| d.is_continuous).count();
    if state.initialized {
        perform_selection(state, &trial_ids, &completed);
    } else {
        initialize_population(state, &trial_ids, &completed, n_continuous);
    }
    let evolve = state.initialized && n_continuous > 0;
    state.candidates = if evolve {
        generate_trial_vectors(state)
    } else {
        generate_initial_population(state)
    };
    state.assigned_count = 0;
    state.generation_trial_ids.clear();
    state.trial_progress.clear();
}
/// Seeds the population from the first completed generation.
///
/// Member `i` is the candidate evaluated by `trial_ids[i]`, with its objective value
/// taken from `history_map`. `best_idx` becomes the first member with the strictly
/// lowest value. NaN values never win, and the result is 0 if no value is below +inf.
///
/// # Panics
/// Panics if a trial id has no progress entry or no value in `history_map`.
fn initialize_population(
    state: &mut State,
    trial_ids: &[u64],
    history_map: &HashMap<u64, f64>,
    _n_continuous: usize,
) {
    state.population.clear();
    state.population_categorical.clear();
    state.population_values.clear();
    for &trial_id in trial_ids {
        let candidate = &state.candidates[state.trial_progress[&trial_id].candidate_idx];
        let value = history_map[&trial_id];
        state.population.push(candidate.x.clone());
        state
            .population_categorical
            .push(candidate.categorical_values.clone());
        state.population_values.push(value);
    }
    let (best_idx, _) = state.population_values.iter().enumerate().fold(
        (0, f64::INFINITY),
        |(best_i, best_v), (i, &v)| if v < best_v { (i, v) } else { (best_i, best_v) },
    );
    state.best_idx = best_idx;
    state.initialized = true;
}
/// Decides whether a trial with objective `trial` replaces a member with `incumbent`.
///
/// Lower is better, ties favour the trial, and NaN ranks below every number. A NaN
/// trial never replaces anything, while a NaN incumbent is replaced by any non-NaN
/// trial. This matters because `x <= NaN` is always false, so a plain comparison
/// would keep a NaN member forever.
fn trial_replaces(trial: f64, incumbent: f64) -> bool {
    !trial.is_nan() && (incumbent.is_nan() || trial <= incumbent)
}

/// One-to-one DE selection over the evaluated candidates of a generation.
///
/// Each candidate is compared with its target member (see `trial_replaces`) and
/// overwrites it when it is at least as good. Afterwards `best_idx` is recomputed as
/// the first member with the strictly lowest non-NaN value (0 if none is below +inf).
///
/// # Panics
/// Panics if a trial id has no progress entry or no value in `history_map`.
fn perform_selection(state: &mut State, trial_ids: &[u64], history_map: &HashMap<u64, f64>) {
    for &trial_id in trial_ids {
        let progress = &state.trial_progress[&trial_id];
        let candidate = &state.candidates[progress.candidate_idx];
        let trial_value = history_map[&trial_id];
        let target_idx = candidate.target_idx;
        // Candidates whose target is not a population slot never take part in selection.
        if target_idx >= state.population_size {
            continue;
        }
        if trial_replaces(trial_value, state.population_values[target_idx]) {
            // `clone_from` reuses the member's existing allocations.
            state.population[target_idx].clone_from(&candidate.x);
            state.population_categorical[target_idx].clone_from(&candidate.categorical_values);
            state.population_values[target_idx] = trial_value;
        }
    }
    let mut best_value = f64::INFINITY;
    let mut best_idx = 0;
    for (i, &val) in state.population_values.iter().enumerate() {
        if val < best_value {
            best_value = val;
            best_idx = i;
        }
    }
    state.best_idx = best_idx;
}
#[cfg(test)]
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
mod tests {
    use super::*;
    use crate::distribution::FloatDistribution;

    /// Every sampled value must stay within the distribution's bounds. With an empty
    /// history this covers discovery, the initial generation and overflow candidates.
    #[test]
    fn test_de_sampler_basic_float() {
        let dist = Distribution::Float(FloatDistribution {
            low: -5.0,
            high: 5.0,
            log_scale: false,
            step: None,
        });
        let sampler = DESampler::with_seed(42);
        for i in 0..100 {
            let ParamValue::Float(v) = sampler.sample(&dist, i, &[]) else {
                panic!("Expected Float value");
            };
            assert!(
                (-5.0..=5.0).contains(&v),
                "value {v} out of bounds at trial {i}"
            );
        }
    }

    /// A fixed seed fully determines the sample stream; different seeds diverge.
    #[test]
    fn test_de_sampler_reproducibility() {
        let dist = Distribution::Float(FloatDistribution {
            low: 0.0,
            high: 1.0,
            log_scale: false,
            step: None,
        });
        let run = |seed: u64| -> Vec<ParamValue> {
            let sampler = DESampler::with_seed(seed);
            let mut out = Vec::with_capacity(20);
            for trial in 0..20 {
                out.push(sampler.sample(&dist, trial, &[]));
            }
            out
        };
        let first = run(42);
        let second = run(42);
        assert_eq!(first, second, "same seed should produce same results");
        let other = run(99);
        assert_ne!(first, other, "different seeds should produce different results");
    }

    /// `Rand1` is the default strategy.
    #[test]
    fn test_de_strategy_default() {
        assert!(matches!(DEStrategy::default(), DEStrategy::Rand1));
    }

    /// A fresh builder carries the documented defaults.
    #[test]
    fn test_builder_defaults() {
        let b = DESamplerBuilder::new();
        assert!(b.population_size.is_none());
        assert!((b.mutation_factor - 0.8).abs() < f64::EPSILON);
        assert!((b.crossover_rate - 0.9).abs() < f64::EPSILON);
        assert!(matches!(b.strategy, DEStrategy::Rand1));
        assert!(b.seed.is_none());
    }
}