use rand::prelude::*;
use serde::{Deserialize, Serialize};
use super::{Budget, OptimizationResult, PerturbativeMetaheuristic, SearchSpace};
use crate::metaheuristics::budget::ConvergenceTracker;
use crate::metaheuristics::traits::TerminationReason;
/// Simulated annealing for continuous, box-bounded minimisation problems.
///
/// Each iteration perturbs the current point and accepts the candidate with the
/// Metropolis rule at the current temperature. It then multiplies the
/// temperature by `cooling_rate`. A run ends once the temperature drops below
/// `final_temp`, or when the [`Budget`] given to `optimize` says to stop.
///
/// Only the tuning parameters and the optional seed are (de)serialized; the
/// per-run state is skipped. Because of the container-level `#[serde(default)]`,
/// every field absent from the input is taken from `SimulatedAnnealing::default()`.
/// This includes the skipped run state. A freshly deserialized optimizer is
/// therefore in the same "no run yet" state as a default-constructed one, with
/// objective values at `f64::INFINITY` rather than `0.0`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct SimulatedAnnealing {
    /// Temperature at the start of every run.
    pub initial_temp: f64,
    /// A run stops as soon as the temperature is below this value.
    pub final_temp: f64,
    /// Factor the temperature is multiplied by after each iteration.
    pub cooling_rate: f64,
    /// Largest per-coordinate step, as a fraction of that coordinate's
    /// `upper - lower` range.
    pub step_scale: f64,
    /// RNG seed: `Some` makes runs reproducible, `None` uses the thread-local RNG.
    #[serde(default)]
    seed: Option<u64>,
    /// Point the search is currently at (empty before any run).
    #[serde(skip)]
    current: Vec<f64>,
    /// Objective value at `current` (`f64::INFINITY` before any run).
    #[serde(skip)]
    current_val: f64,
    /// Best point found by the latest run (empty before any run).
    #[serde(skip)]
    best: Vec<f64>,
    /// Objective value at `best` (`f64::INFINITY` before any run).
    #[serde(skip)]
    best_val: f64,
    /// Best objective value after the start point and after each iteration of
    /// the latest run.
    #[serde(skip)]
    history: Vec<f64>,
}
impl Default for SimulatedAnnealing {
fn default() -> Self {
Self {
initial_temp: 100.0,
final_temp: 1e-8,
cooling_rate: 0.95,
step_scale: 0.1,
seed: None,
current: Vec::new(),
current_val: f64::INFINITY,
best: Vec::new(),
best_val: f64::INFINITY,
history: Vec::new(),
}
}
}
impl SimulatedAnnealing {
    /// Sets the temperature every run starts from.
    #[must_use]
    pub fn with_initial_temp(mut self, t: f64) -> Self {
        self.initial_temp = t;
        self
    }

    /// Sets the temperature below which a run stops.
    #[must_use]
    pub fn with_final_temp(mut self, t: f64) -> Self {
        self.final_temp = t;
        self
    }

    /// Sets the factor the temperature is multiplied by after every iteration.
    #[must_use]
    pub fn with_cooling_rate(mut self, alpha: f64) -> Self {
        self.cooling_rate = alpha;
        self
    }

    /// Sets the largest per-coordinate step, as a fraction of that
    /// coordinate's `upper - lower` range.
    #[must_use]
    pub fn with_step_scale(mut self, scale: f64) -> Self {
        self.step_scale = scale;
        self
    }

    /// Fixes the RNG seed so every call to `optimize` replays the same random
    /// sequence.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Returns a random neighbour of `x`.
    ///
    /// Coordinate `i` moves by `u * (upper[i] - lower[i]) * step_scale`, where
    /// `u` is drawn uniformly from `[-1, 1]`. There is one draw per coordinate,
    /// in order. The result is then clamped back into `[lower[i], upper[i]]`.
    ///
    /// # Panics
    ///
    /// Panics in two cases:
    ///
    /// - `lower` or `upper` is shorter than `x`;
    /// - a bound pair is invalid for `f64::clamp`, i.e. `lower[i] > upper[i]`
    ///   or either bound is NaN.
    fn perturb(&self, x: &[f64], lower: &[f64], upper: &[f64], rng: &mut impl Rng) -> Vec<f64> {
        x.iter()
            .enumerate()
            .map(|(i, &xi)| {
                let (lo, hi) = (lower[i], upper[i]);
                let delta = rng.random_range(-1.0..=1.0) * (hi - lo) * self.step_scale;
                (xi + delta).clamp(lo, hi)
            })
            .collect()
    }
}
impl PerturbativeMetaheuristic for SimulatedAnnealing {
    type Solution = Vec<f64>;

    /// Minimises `objective` over a continuous `space` by simulated annealing.
    ///
    /// The start point is drawn uniformly from the box `[lower, upper]`. Each
    /// iteration does the following:
    ///
    /// 1. evaluates one perturbed candidate;
    /// 2. accepts it if it is better, or otherwise with probability
    ///    `exp(-delta / temp)`;
    /// 3. multiplies the temperature by `cooling_rate`.
    ///
    /// The loop ends at the first of these:
    ///
    /// - the temperature drops below `final_temp`;
    /// - `budget.max_evaluations(1)` iterations have run;
    /// - the convergence tracker built from `budget` signals a stop.
    ///
    /// A NaN objective value is treated as worse than any number. NaN
    /// candidates are never accepted, and a NaN start point is left for the
    /// first non-NaN candidate. An unlucky start inside a NaN region therefore
    /// does not freeze the whole run.
    ///
    /// Each call overwrites the optimizer's run state (`best`, `history`). A
    /// seeded instance replays the same random sequence on every call.
    ///
    /// # Panics
    ///
    /// Panics in these cases:
    ///
    /// - `space` is not `SearchSpace::Continuous`;
    /// - a bound pair is rejected by `random_range` or `f64::clamp`
    ///   (e.g. `lower[i] > upper[i]`);
    /// - the bounds hold fewer than `dim` entries.
    fn optimize<F>(
        &mut self,
        objective: &F,
        space: &SearchSpace,
        budget: Budget,
    ) -> OptimizationResult<Self::Solution>
    where
        F: Fn(&[f64]) -> f64,
    {
        // A fresh generator per call: seeded runs are reproducible, unseeded
        // runs draw from the thread-local RNG.
        let mut rng: Box<dyn RngCore> = match self.seed {
            Some(s) => Box::new(StdRng::seed_from_u64(s)),
            None => Box::new(rand::rng()),
        };
        // Borrow the bounds straight from `space`; nothing here needs a copy.
        let (lower, upper, dim) = match space {
            SearchSpace::Continuous { dim, lower, upper } => (lower, upper, *dim),
            _ => panic!("SA requires continuous search space"),
        };

        self.current = (0..dim)
            .map(|j| rng.random_range(lower[j]..=upper[j]))
            .collect();
        self.current_val = objective(&self.current);
        self.best = self.current.clone();
        self.best_val = self.current_val;
        self.history.clear();
        self.history.push(self.best_val);

        let mut tracker = ConvergenceTracker::from_budget(&budget);
        tracker.update(self.best_val, 1);

        let mut temp = self.initial_temp;
        let max_iter = budget.max_evaluations(1);
        for _ in 0..max_iter {
            if temp < self.final_temp {
                break;
            }
            let candidate = self.perturb(&self.current, lower, upper, &mut rng);
            let candidate_val = objective(&candidate);

            let accept = if self.current_val.is_nan() {
                // From a NaN point every `delta` is NaN and would be rejected
                // forever; move to the first candidate with a real value.
                !candidate_val.is_nan()
            } else {
                // Metropolis rule. A NaN candidate gives a NaN `delta`, which
                // fails both comparisons and is rejected. The uniform draw is
                // made only when `delta` is not negative, so the random stream of
                // a run with no NaN start point is unaffected by the NaN guard.
                let delta = candidate_val - self.current_val;
                delta < 0.0 || rng.random::<f64>() < (-delta / temp).exp()
            };

            if accept {
                self.current = candidate;
                self.current_val = candidate_val;
                // `best_val` is NaN only when the start point was; any accepted
                // value (never NaN) must then replace it.
                if self.best_val.is_nan() || self.current_val < self.best_val {
                    self.best.clone_from(&self.current);
                    self.best_val = self.current_val;
                }
            }

            self.history.push(self.best_val);
            temp *= self.cooling_rate;
            if !tracker.update(self.best_val, 1) {
                break;
            }
        }

        let termination = if tracker.is_converged() {
            TerminationReason::Converged
        } else if tracker.is_exhausted() {
            TerminationReason::BudgetExhausted
        } else {
            TerminationReason::MaxIterations
        };
        // `history` holds the start value plus one entry per completed
        // iteration; its length is reported as the iteration count.
        OptimizationResult::new(
            self.best.clone(),
            self.best_val,
            tracker.evaluations(),
            self.history.len(),
            self.history.clone(),
            termination,
        )
    }

    /// Best point of the latest run, or `None` before any run / after `reset`.
    fn best(&self) -> Option<&Self::Solution> {
        if self.best.is_empty() {
            None
        } else {
            Some(&self.best)
        }
    }

    /// Best-so-far objective values recorded by the latest run.
    fn history(&self) -> &[f64] {
        &self.history
    }

    /// Discards all run state; the tuning parameters and seed are kept.
    fn reset(&mut self) {
        self.current.clear();
        self.current_val = f64::INFINITY;
        self.best.clear();
        self.best_val = f64::INFINITY;
        self.history.clear();
    }
}
// Test-only module; per the `#[path]` attribute its source lives in
// `tests_sa_contract.rs` next to this file rather than in a `tests_sa_contract/` dir.
#[cfg(test)]
#[path = "tests_sa_contract.rs"]
mod tests_sa_contract;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sa_sphere() {
        let objective = |x: &[f64]| x.iter().map(|xi| xi * xi).sum();
        let mut sa = SimulatedAnnealing::default().with_seed(42);
        let space = SearchSpace::continuous(5, -5.0, 5.0);
        let result = sa.optimize(&objective, &space, Budget::Evaluations(10000));
        assert!(result.objective_value < 1.0);
    }

    #[test]
    fn test_sa_builder() {
        let sa = SimulatedAnnealing::default()
            .with_initial_temp(500.0)
            .with_cooling_rate(0.99)
            .with_seed(123);
        assert!((sa.initial_temp - 500.0).abs() < 1e-10);
        assert!((sa.cooling_rate - 0.99).abs() < 1e-10);
    }

    #[test]
    fn test_sa_reset() {
        let objective = |x: &[f64]| x.iter().sum::<f64>();
        let mut sa = SimulatedAnnealing::default().with_seed(42);
        let space = SearchSpace::continuous(2, -1.0, 1.0);
        let _ = sa.optimize(&objective, &space, Budget::Evaluations(100));
        assert!(sa.best().is_some());
        assert!(!sa.history().is_empty());
        sa.reset();
        assert!(sa.best().is_none());
        assert!(sa.history().is_empty());
    }

    #[test]
    fn test_sa_temperature_exhaustion() {
        let objective = |x: &[f64]| x.iter().map(|xi| xi * xi).sum();
        let mut sa = SimulatedAnnealing::default()
            .with_initial_temp(1.0)
            .with_cooling_rate(0.01)
            .with_seed(99);
        let space = SearchSpace::continuous(2, -5.0, 5.0);
        let result = sa.optimize(&objective, &space, Budget::Evaluations(100_000));
        assert!(
            result.termination == TerminationReason::MaxIterations
                || result.termination == TerminationReason::Converged
        );
        assert!(result.iterations < 100_000);
    }

    #[test]
    fn test_sa_budget_exhausted_termination() {
        let objective = |x: &[f64]| x.iter().map(|xi| xi * xi).sum();
        let mut sa = SimulatedAnnealing::default()
            .with_initial_temp(1e10)
            .with_cooling_rate(0.9999)
            .with_seed(7);
        let space = SearchSpace::continuous(2, -5.0, 5.0);
        let result = sa.optimize(&objective, &space, Budget::Evaluations(10));
        assert_eq!(result.termination, TerminationReason::BudgetExhausted);
    }

    #[test]
    fn test_sa_convergence_termination() {
        let objective = |_x: &[f64]| 0.0;
        let mut sa = SimulatedAnnealing::default()
            .with_initial_temp(100.0)
            .with_cooling_rate(0.95)
            .with_seed(42);
        let space = SearchSpace::continuous(2, -1.0, 1.0);
        let budget = Budget::Convergence {
            patience: 5,
            min_delta: 1e-6,
            max_evaluations: 100_000,
        };
        let result = sa.optimize(&objective, &space, budget);
        assert_eq!(result.termination, TerminationReason::Converged);
    }

    #[test]
    fn test_sa_best_before_optimize() {
        let sa = SimulatedAnnealing::default();
        assert!(sa.best().is_none());
        assert!(sa.history().is_empty());
    }

    #[test]
    fn test_sa_history_populated() {
        let objective = |x: &[f64]| x.iter().map(|xi| xi * xi).sum();
        let mut sa = SimulatedAnnealing::default().with_seed(42);
        let space = SearchSpace::continuous(2, -5.0, 5.0);
        let result = sa.optimize(&objective, &space, Budget::Evaluations(50));
        assert!(!result.history.is_empty());
        assert_eq!(sa.history().len(), result.history.len());
    }

    #[test]
    fn test_sa_history_is_non_increasing() {
        // The history records the best value so far, which can only improve.
        let objective = |x: &[f64]| x.iter().map(|xi| xi * xi).sum();
        let mut sa = SimulatedAnnealing::default().with_seed(3);
        let space = SearchSpace::continuous(2, -5.0, 5.0);
        let result = sa.optimize(&objective, &space, Budget::Evaluations(300));
        assert!(result.history.windows(2).all(|w| w[1] <= w[0]));
    }

    #[test]
    fn test_sa_seeded_runs_are_reproducible() {
        // A seeded optimizer rebuilds its RNG and run state on every call, so two
        // runs on the same problem must match exactly.
        let objective = |x: &[f64]| x.iter().map(|xi| xi * xi).sum();
        let mut sa = SimulatedAnnealing::default().with_seed(11);
        let space = SearchSpace::continuous(3, -5.0, 5.0);
        let first = sa.optimize(&objective, &space, Budget::Evaluations(200));
        let first_best = sa.best().cloned();
        let second = sa.optimize(&objective, &space, Budget::Evaluations(200));
        assert_eq!(first.history, second.history);
        assert_eq!(first_best.as_ref(), sa.best());
    }

    #[test]
    fn test_sa_default_values() {
        let sa = SimulatedAnnealing::default();
        assert!((sa.initial_temp - 100.0).abs() < 1e-10);
        assert!((sa.final_temp - 1e-8).abs() < 1e-15);
        assert!((sa.cooling_rate - 0.95).abs() < 1e-10);
        assert!((sa.step_scale - 0.1).abs() < 1e-10);
    }
}