use crate::core::{
ArgminFloat, CostFunction, Error, IterState, Problem, Solver, TerminationReason,
TerminationStatus, KV,
};
use rand::prelude::*;
use rand_xoshiro::Xoshiro256PlusPlus;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// Candidate-generation capability required by [`SimulatedAnnealing`].
///
/// Given the current parameter vector, an implementor returns a new candidate.
/// `extent` is a scale value supplied by the caller; the solver in this module
/// passes its current temperature. NOTE(review): by convention it presumably
/// bounds how far the candidate may move from `param` — the trait itself does
/// not enforce this.
pub trait Anneal {
    /// Floating point type of the `extent` argument.
    type Float;
    /// Type of the parameter vector that is perturbed.
    type Param;
    /// Type of the candidate returned by [`Anneal::anneal`].
    type Output;

    /// Produce a new candidate derived from `param`, scaled by `extent`.
    ///
    /// # Errors
    ///
    /// Implementor-defined; any error is propagated to the caller unchanged.
    fn anneal(&self, param: &Self::Param, extent: Self::Float) -> Result<Self::Output, Error>;
}
impl<O> Problem<O>
where
    O: Anneal,
{
    /// Forward an [`Anneal::anneal`] call to the wrapped problem.
    ///
    /// The call is routed through `Problem::problem` under the key
    /// `"anneal_count"`. NOTE(review): presumably this is what increments the
    /// per-function evaluation counter; confirm in `Problem::problem`.
    ///
    /// # Errors
    ///
    /// Returns whatever `Problem::problem` yields for the wrapped call, which
    /// includes any error from the problem's own `anneal`.
    pub fn anneal(&mut self, param: &O::Param, extent: O::Float) -> Result<O::Output, Error> {
        self.problem("anneal_count", |inner| inner.anneal(param, extent))
    }
}
/// Temperature schedule used by [`SimulatedAnnealing`].
///
/// With `t_0` the initial temperature and `k` the solver's temperature
/// iteration counter, the current temperature is computed as:
///
/// * `TemperatureFast`: `t_0 / (k + 1)`
/// * `Boltzmann`: `t_0 / ln(k + 1)` (infinite for `k = 0`, since `ln(1) = 0`)
/// * `Exponential(x)`: `t_0 * x^(k + 1)`
///
/// The default is `TemperatureFast`, matching the schedule that
/// `SimulatedAnnealing::new_with_rng` installs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum SATempFunc<F> {
    /// `t_0 / (k + 1)`
    // Default kept in sync with the solver's constructor; `Boltzmann` would
    // additionally divide by zero at `k = 0`.
    #[default]
    TemperatureFast,
    /// `t_0 / ln(k + 1)`
    Boltzmann,
    /// `t_0 * x^(k + 1)` for the contained factor `x`
    Exponential(F),
}
/// Simulated annealing solver.
///
/// Each iteration asks the problem (via [`Anneal`]) for a candidate near the
/// current parameter vector and accepts or rejects it based on the cost change
/// and the current temperature. The solver can terminate on stalls and can
/// reanneal, i.e. reset the temperature to its initial value.
///
/// `F` is the floating point type; `R` is the random number generator used for
/// the acceptance decision.
// NOTE: field order is part of the serde representation when `serde1` is
// enabled (positional formats), so fields should not be reordered.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct SimulatedAnnealing<F, R> {
    /// Initial temperature; the temperature is reset to this on reannealing.
    init_temp: F,
    /// Schedule used to derive `cur_temp` from `init_temp` and `temp_iter`.
    temp_func: SATempFunc<F>,
    /// Temperature iteration counter fed to the schedule; reset on reannealing.
    temp_iter: u64,
    /// Consecutive iterations without an accepted candidate.
    stall_iter_accepted: u64,
    /// Terminate once `stall_iter_accepted` exceeds this value.
    stall_iter_accepted_limit: u64,
    /// Consecutive iterations without a new best cost.
    stall_iter_best: u64,
    /// Terminate once `stall_iter_best` exceeds this value.
    stall_iter_best_limit: u64,
    /// Reanneal once `reanneal_iter_fixed` reaches this value.
    reanneal_fixed: u64,
    /// Iterations since the last reanneal.
    reanneal_iter_fixed: u64,
    /// Reanneal once `reanneal_iter_accepted` reaches this value.
    reanneal_accepted: u64,
    /// Consecutive iterations without an accepted candidate since the last reanneal.
    reanneal_iter_accepted: u64,
    /// Reanneal once `reanneal_iter_best` reaches this value.
    reanneal_best: u64,
    /// Consecutive iterations without a new best cost since the last reanneal.
    reanneal_iter_best: u64,
    /// Current temperature: passed to `anneal` as the extent and used in the
    /// acceptance test.
    cur_temp: F,
    /// Random number generator driving the probabilistic acceptance test.
    rng: R,
}
impl<F> SimulatedAnnealing<F, Xoshiro256PlusPlus>
where
    F: ArgminFloat,
{
    /// Construct a solver with the given initial temperature, using a
    /// `Xoshiro256PlusPlus` generator seeded from the operating system.
    ///
    /// All other settings take the defaults of
    /// [`SimulatedAnnealing::new_with_rng`].
    ///
    /// # Errors
    ///
    /// Fails if seeding from the OS random source fails, or if
    /// `initial_temperature` is rejected by `new_with_rng`.
    pub fn new(initial_temperature: F) -> Result<Self, Error> {
        let os_seeded = Xoshiro256PlusPlus::try_from_os_rng()?;
        Self::new_with_rng(initial_temperature, os_seeded)
    }
}
impl<F, R> SimulatedAnnealing<F, R>
where
    F: ArgminFloat,
{
    /// Construct a solver with initial temperature `init_temp` and the given
    /// random number generator `rng`.
    ///
    /// Defaults: schedule [`SATempFunc::TemperatureFast`]; both stall limits and
    /// all three reannealing thresholds are `u64::MAX`, i.e. effectively
    /// disabled until configured via the `with_*` builders.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidParameter` error if `init_temp` is not strictly
    /// positive. NaN counts as not strictly positive.
    pub fn new_with_rng(init_temp: F, rng: R) -> Result<Self, Error> {
        // `NaN <= 0` is false, so NaN has to be rejected explicitly. A NaN
        // temperature would be handed to `anneal` as the extent and would make
        // the probabilistic acceptance test in `next_iter` always fail.
        if init_temp.is_nan() || init_temp <= float!(0.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "`SimulatedAnnealing`: Initial temperature must be > 0."
            ));
        }
        Ok(SimulatedAnnealing {
            init_temp,
            temp_func: SATempFunc::TemperatureFast,
            temp_iter: 0,
            stall_iter_accepted: 0,
            stall_iter_accepted_limit: u64::MAX,
            stall_iter_best: 0,
            stall_iter_best_limit: u64::MAX,
            reanneal_fixed: u64::MAX,
            reanneal_iter_fixed: 0,
            reanneal_accepted: u64::MAX,
            reanneal_iter_accepted: 0,
            reanneal_best: u64::MAX,
            reanneal_iter_best: 0,
            cur_temp: init_temp,
            rng,
        })
    }

    /// Set the temperature schedule.
    #[must_use]
    pub fn with_temp_func(mut self, temperature_func: SATempFunc<F>) -> Self {
        self.temp_func = temperature_func;
        self
    }

    /// Terminate once more than `iter` consecutive iterations pass without an
    /// accepted candidate.
    #[must_use]
    pub fn with_stall_accepted(mut self, iter: u64) -> Self {
        self.stall_iter_accepted_limit = iter;
        self
    }

    /// Terminate once more than `iter` consecutive iterations pass without a
    /// new best cost.
    #[must_use]
    pub fn with_stall_best(mut self, iter: u64) -> Self {
        self.stall_iter_best_limit = iter;
        self
    }

    /// Reanneal once the number of iterations since the last reanneal reaches
    /// `iter`.
    #[must_use]
    pub fn with_reannealing_fixed(mut self, iter: u64) -> Self {
        self.reanneal_fixed = iter;
        self
    }

    /// Reanneal once the number of consecutive iterations without an accepted
    /// candidate reaches `iter`.
    #[must_use]
    pub fn with_reannealing_accepted(mut self, iter: u64) -> Self {
        self.reanneal_accepted = iter;
        self
    }

    /// Reanneal once the number of consecutive iterations without a new best
    /// cost reaches `iter`.
    #[must_use]
    pub fn with_reannealing_best(mut self, iter: u64) -> Self {
        self.reanneal_best = iter;
        self
    }

    /// Recompute `cur_temp` from `init_temp` and `temp_iter` using `temp_func`.
    ///
    /// With `k = temp_iter`, the schedules evaluate `k + 1`. For `Boltzmann` this
    /// is `ln(k + 1)`, which is zero (infinite temperature) when `k = 0`. Callers
    /// therefore increment `temp_iter` before calling, as `next_iter` does.
    fn update_temperature(&mut self) {
        let step = F::from_u64(self.temp_iter + 1).unwrap();
        self.cur_temp = match self.temp_func {
            SATempFunc::TemperatureFast => self.init_temp / step,
            SATempFunc::Boltzmann => self.init_temp / step.ln(),
            SATempFunc::Exponential(x) => self.init_temp * x.powf(step),
        };
    }

    /// Check the three reannealing criteria and reanneal if any of them holds.
    ///
    /// Returns `(fixed, accepted, best)`, telling which criteria triggered.
    /// If any did, the three reanneal counters and `temp_iter` are reset to 0,
    /// and `cur_temp` is reset to `init_temp`.
    fn reanneal(&mut self) -> (bool, bool, bool) {
        let triggered = (
            self.reanneal_iter_fixed >= self.reanneal_fixed,
            self.reanneal_iter_accepted >= self.reanneal_accepted,
            self.reanneal_iter_best >= self.reanneal_best,
        );
        if triggered.0 || triggered.1 || triggered.2 {
            self.reanneal_iter_fixed = 0;
            self.reanneal_iter_accepted = 0;
            self.reanneal_iter_best = 0;
            self.cur_temp = self.init_temp;
            self.temp_iter = 0;
        }
        triggered
    }

    /// Update the stall and reanneal counters after one iteration.
    ///
    /// Accepting a candidate resets the "accepted" counters, and finding a new
    /// best resets the "best" counters. Otherwise the respective counters are
    /// incremented by one.
    fn update_stall_and_reanneal_iter(&mut self, accepted: bool, new_best: bool) {
        if accepted {
            self.stall_iter_accepted = 0;
            self.reanneal_iter_accepted = 0;
        } else {
            self.stall_iter_accepted += 1;
            self.reanneal_iter_accepted += 1;
        }
        if new_best {
            self.stall_iter_best = 0;
            self.reanneal_iter_best = 0;
        } else {
            self.stall_iter_best += 1;
            self.reanneal_iter_best += 1;
        }
    }
}
impl<O, P, F, R> Solver<O, IterState<P, (), (), (), (), F>> for SimulatedAnnealing<F, R>
where
    O: CostFunction<Param = P, Output = F> + Anneal<Param = P, Output = P, Float = F>,
    P: Clone,
    F: ArgminFloat,
    R: Rng,
{
    fn name(&self) -> &str {
        "Simulated Annealing"
    }

    /// Prepare the initial state and report the solver configuration.
    ///
    /// Takes the initial parameter vector from `state`. If the state's cost is
    /// infinite (presumably meaning "not computed yet"), the cost of the initial
    /// parameter vector is evaluated once.
    ///
    /// # Errors
    ///
    /// `NotInitialized` if no initial parameter vector was provided. Any error
    /// from the cost function is also propagated.
    fn init(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, (), (), (), (), F>,
    ) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
        let param = state.take_param().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`SimulatedAnnealing` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;
        let cost = state.get_cost();
        let cost = if cost.is_infinite() {
            problem.cost(&param)?
        } else {
            cost
        };
        Ok((
            state.param(param).cost(cost),
            Some(kv!(
                "initial_temperature" => self.init_temp;
                "stall_iter_accepted_limit" => self.stall_iter_accepted_limit;
                "stall_iter_best_limit" => self.stall_iter_best_limit;
                "reanneal_fixed" => self.reanneal_fixed;
                "reanneal_accepted" => self.reanneal_accepted;
                "reanneal_best" => self.reanneal_best;
            )),
        ))
    }

    /// Perform one annealing step.
    ///
    /// The step does the following, in order:
    ///
    /// * Generate a candidate with the current temperature as the extent.
    /// * Decide whether to accept it.
    /// * Update the stall and reanneal counters.
    /// * Possibly reanneal.
    /// * Advance the temperature schedule.
    ///
    /// The returned state holds the accepted candidate, or the previous
    /// parameter vector if the candidate was rejected.
    ///
    /// # Errors
    ///
    /// `PotentialBug` if the state carries no parameter vector. Errors from
    /// `anneal` or the cost function are also propagated.
    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, (), (), (), (), F>,
    ) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
        let prev_param = state.take_param().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`SimulatedAnnealing`: Parameter vector in state not set."
        ))?;
        let prev_cost = state.get_cost();
        let new_param = problem.anneal(&prev_param, self.cur_temp)?;
        let new_cost = problem.cost(&new_param)?;

        // Acceptance rule: strict improvements are always accepted. Otherwise the
        // candidate is accepted with probability
        //   1 / (1 + exp((new_cost - prev_cost) / cur_temp)),
        // which is at most 0.5 when the cost did not decrease and cur_temp > 0.
        let prob: f64 = self.rng.random();
        let prob = float!(prob);
        let accepted = (new_cost < prev_cost)
            || (float!(1.0) / (float!(1.0) + ((new_cost - prev_cost) / self.cur_temp).exp())
                > prob);

        let new_best_found = new_cost < state.best_cost;
        self.update_stall_and_reanneal_iter(accepted, new_best_found);
        let (r_fixed, r_accepted, r_best) = self.reanneal();

        // Advance the schedule only after the reanneal check. A reanneal resets
        // `temp_iter` to 0, so the next temperature is computed from
        // `temp_iter == 1` (this also keeps `Boltzmann` away from ln(1) = 0).
        self.temp_iter += 1;
        self.reanneal_iter_fixed += 1;
        self.update_temperature();

        Ok((
            if accepted {
                state.param(new_param).cost(new_cost)
            } else {
                state.param(prev_param).cost(prev_cost)
            },
            Some(kv!(
                "t" => self.cur_temp;
                "new_be" => new_best_found;
                "acc" => accepted;
                "st_i_be" => self.stall_iter_best;
                "st_i_ac" => self.stall_iter_accepted;
                "ra_i_fi" => self.reanneal_iter_fixed;
                "ra_i_be" => self.reanneal_iter_best;
                "ra_i_ac" => self.reanneal_iter_accepted;
                "ra_fi" => r_fixed;
                "ra_be" => r_best;
                "ra_ac" => r_accepted;
            )),
        ))
    }

    /// Stop when either stall counter exceeds its limit.
    ///
    /// The accepted-stall limit is checked first. Its exit reason is
    /// `"AcceptedStallIterExceeded"`; the best-stall reason is
    /// `"BestStallIterExceeded"`.
    fn terminate(&mut self, _state: &IterState<P, (), (), (), (), F>) -> TerminationStatus {
        if self.stall_iter_accepted > self.stall_iter_accepted_limit {
            return TerminationStatus::Terminated(TerminationReason::SolverExit(
                "AcceptedStallIterExceeded".to_string(),
            ));
        }
        if self.stall_iter_best > self.stall_iter_best_limit {
            return TerminationStatus::Terminated(TerminationReason::SolverExit(
                "BestStallIterExceeded".to_string(),
            ));
        }
        TerminationStatus::NotTerminated
    }
}
// Unit tests for `SimulatedAnnealing`: constructors, builders, the private
// temperature/reanneal/stall helpers, and `Solver::init`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError, State};
    use approx::assert_relative_eq;

    // NOTE(review): project macro; presumably checks that the solver type meets the
    // crate's common trait requirements — confirm against its definition.
    test_trait_impl!(sa, SimulatedAnnealing<f64, StdRng>);

    // `new` must initialize every field to its default and reject non-positive
    // initial temperatures.
    #[test]
    fn test_new() {
        let sa: SimulatedAnnealing<f64, Xoshiro256PlusPlus> =
            SimulatedAnnealing::new(100.0).unwrap();
        // Exhaustive destructuring: adding a field makes this fail to compile,
        // forcing the test to be updated.
        let SimulatedAnnealing {
            init_temp,
            temp_func,
            temp_iter,
            stall_iter_accepted,
            stall_iter_accepted_limit,
            stall_iter_best,
            stall_iter_best_limit,
            reanneal_fixed,
            reanneal_iter_fixed,
            reanneal_accepted,
            reanneal_iter_accepted,
            reanneal_best,
            reanneal_iter_best,
            cur_temp,
            rng: _rng,
        } = sa;
        // Floats are compared bitwise for exact equality.
        assert_eq!(init_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());
        assert_eq!(temp_func, SATempFunc::TemperatureFast);
        assert_eq!(temp_iter, 0);
        assert_eq!(stall_iter_accepted, 0);
        assert_eq!(stall_iter_accepted_limit, u64::MAX);
        assert_eq!(stall_iter_best, 0);
        assert_eq!(stall_iter_best_limit, u64::MAX);
        assert_eq!(reanneal_fixed, u64::MAX);
        assert_eq!(reanneal_iter_fixed, 0);
        assert_eq!(reanneal_accepted, u64::MAX);
        assert_eq!(reanneal_iter_accepted, 0);
        assert_eq!(reanneal_best, u64::MAX);
        assert_eq!(reanneal_iter_best, 0);
        assert_eq!(cur_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());
        // Zero and negative temperatures (including the smallest negative step) are rejected.
        for temp in [0.0, -1.0, -f64::EPSILON, -100.0] {
            let res = SimulatedAnnealing::new(temp);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SimulatedAnnealing`: Initial temperature must be > 0.\""
            );
        }
    }

    // Same as `test_new`, but with a caller-supplied RNG that must be stored as-is.
    #[test]
    fn test_new_with_rng() {
        // Dummy RNG type: only equality and Debug are needed for the assertion below.
        #[derive(Eq, PartialEq, Debug)]
        struct MyRng {}
        let sa: SimulatedAnnealing<f64, MyRng> =
            SimulatedAnnealing::new_with_rng(100.0, MyRng {}).unwrap();
        let SimulatedAnnealing {
            init_temp,
            temp_func,
            temp_iter,
            stall_iter_accepted,
            stall_iter_accepted_limit,
            stall_iter_best,
            stall_iter_best_limit,
            reanneal_fixed,
            reanneal_iter_fixed,
            reanneal_accepted,
            reanneal_iter_accepted,
            reanneal_best,
            reanneal_iter_best,
            cur_temp,
            rng,
        } = sa;
        assert_eq!(init_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());
        assert_eq!(temp_func, SATempFunc::TemperatureFast);
        assert_eq!(temp_iter, 0);
        assert_eq!(stall_iter_accepted, 0);
        assert_eq!(stall_iter_accepted_limit, u64::MAX);
        assert_eq!(stall_iter_best, 0);
        assert_eq!(stall_iter_best_limit, u64::MAX);
        assert_eq!(reanneal_fixed, u64::MAX);
        assert_eq!(reanneal_iter_fixed, 0);
        assert_eq!(reanneal_accepted, u64::MAX);
        assert_eq!(reanneal_iter_accepted, 0);
        assert_eq!(reanneal_best, u64::MAX);
        assert_eq!(reanneal_iter_best, 0);
        assert_eq!(cur_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());
        assert_eq!(rng, MyRng {});
        for temp in [0.0, -1.0, -f64::EPSILON, -100.0] {
            let res = SimulatedAnnealing::new_with_rng(temp, MyRng {});
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SimulatedAnnealing`: Initial temperature must be > 0.\""
            );
        }
    }

    // Each builder below must store its argument in the corresponding field.
    #[test]
    fn test_with_temp_func() {
        for func in [
            SATempFunc::TemperatureFast,
            SATempFunc::Boltzmann,
            SATempFunc::Exponential(2.0),
        ] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_temp_func(func);
            assert_eq!(sa.temp_func, func);
        }
    }

    #[test]
    fn test_with_stall_accepted() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_stall_accepted(iter);
            assert_eq!(sa.stall_iter_accepted_limit, iter);
        }
    }

    #[test]
    fn test_with_stall_best() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_stall_best(iter);
            assert_eq!(sa.stall_iter_best_limit, iter);
        }
    }

    #[test]
    fn test_with_reannealing_fixed() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_reannealing_fixed(iter);
            assert_eq!(sa.reanneal_fixed, iter);
        }
    }

    #[test]
    fn test_with_reannealing_accepted() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_reannealing_accepted(iter);
            assert_eq!(sa.reanneal_accepted, iter);
        }
    }

    #[test]
    fn test_with_reannealing_best() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_reannealing_best(iter);
            assert_eq!(sa.reanneal_best, iter);
        }
    }

    // With temp_iter = 1 every schedule evaluates at k + 1 = 2.
    #[test]
    fn test_update_temperature() {
        for (func, val) in [
            (SATempFunc::TemperatureFast, 100.0f64 / 2.0),
            (SATempFunc::Boltzmann, 100.0f64 / 2.0f64.ln()),
            (SATempFunc::Exponential(3.0), 100.0 * 3.0f64.powi(2)),
        ] {
            let mut sa = SimulatedAnnealing::new(100.0f64)
                .unwrap()
                .with_temp_func(func);
            sa.temp_iter = 1;
            sa.update_temperature();
            assert_relative_eq!(sa.cur_temp, val, epsilon = f64::EPSILON);
        }
    }

    // Thresholds are fixed=10, accepted=20, best=30. Each case sets the three
    // counters and checks which criteria fire (at or above the threshold) and
    // that a triggered reanneal resets counters and temperature.
    #[test]
    fn test_reanneal() {
        let mut sa_t = SimulatedAnnealing::new(100.0f64).unwrap();
        sa_t.reanneal_fixed = 10;
        sa_t.reanneal_accepted = 20;
        sa_t.reanneal_best = 30;
        sa_t.temp_iter = 40;
        sa_t.cur_temp = 50.0;
        for ((f, a, b), expected) in [
            ((0, 0, 0), (false, false, false)),
            ((10, 0, 0), (true, false, false)),
            ((11, 0, 0), (true, false, false)),
            ((0, 20, 0), (false, true, false)),
            ((0, 21, 0), (false, true, false)),
            ((0, 0, 30), (false, false, true)),
            ((0, 0, 31), (false, false, true)),
            ((10, 20, 0), (true, true, false)),
            ((10, 0, 30), (true, false, true)),
            ((0, 20, 30), (false, true, true)),
            ((10, 20, 30), (true, true, true)),
        ] {
            let mut sa = sa_t.clone();
            sa.reanneal_iter_fixed = f;
            sa.reanneal_iter_accepted = a;
            sa.reanneal_iter_best = b;
            assert_eq!(sa.reanneal(), expected);
            if expected.0 || expected.1 || expected.2 {
                assert_eq!(sa.reanneal_iter_fixed, 0);
                assert_eq!(sa.reanneal_iter_accepted, 0);
                assert_eq!(sa.reanneal_iter_best, 0);
                assert_eq!(sa.temp_iter, 0);
                assert_eq!(sa.cur_temp.to_ne_bytes(), sa.init_temp.to_ne_bytes());
            }
        }
    }

    // accepted/new_best flags reset their counter pair to 0; otherwise both
    // counters of the pair increment by one.
    #[test]
    fn test_update_stall_and_reanneal_iter() {
        let mut sa_t = SimulatedAnnealing::new(100.0f64).unwrap();
        sa_t.stall_iter_accepted = 10;
        sa_t.reanneal_iter_accepted = 20;
        sa_t.stall_iter_best = 30;
        sa_t.reanneal_iter_best = 40;
        for ((a, b), (sia, ria, sib, rib)) in [
            ((false, false), (11, 21, 31, 41)),
            ((false, true), (11, 21, 0, 0)),
            ((true, false), (0, 0, 31, 41)),
            ((true, true), (0, 0, 0, 0)),
        ] {
            let mut sa = sa_t.clone();
            sa.update_stall_and_reanneal_iter(a, b);
            assert_eq!(sa.stall_iter_accepted, sia);
            assert_eq!(sa.reanneal_iter_accepted, ria);
            assert_eq!(sa.stall_iter_best, sib);
            assert_eq!(sa.reanneal_iter_best, rib);
        }
    }

    // `init` must fail without an initial parameter vector. With one, it must
    // keep the parameters, compute the cost, and report the configuration in the KV.
    #[test]
    fn test_init() {
        let param: Vec<f64> = vec![-1.0, 1.0];
        let stall_iter_accepted_limit = 10;
        let stall_iter_best_limit = 20;
        let reanneal_fixed = 30;
        let reanneal_accepted = 40;
        let reanneal_best = 50;
        let mut sa = SimulatedAnnealing::new(100.0f64)
            .unwrap()
            .with_stall_accepted(stall_iter_accepted_limit)
            .with_stall_best(stall_iter_best_limit)
            .with_reannealing_fixed(reanneal_fixed)
            .with_reannealing_accepted(reanneal_accepted)
            .with_reannealing_best(reanneal_best);
        // No parameter vector set: expect NotInitialized.
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let problem = TestProblem::new();
        let res = sa.init(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`SimulatedAnnealing` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new().param(param.clone());
        let problem = TestProblem::new();
        let (mut state_out, kv) = sa.init(&mut Problem::new(problem), state).unwrap();
        let kv_expected = kv!(
            "initial_temperature" => 100.0f64;
            "stall_iter_accepted_limit" => stall_iter_accepted_limit;
            "stall_iter_best_limit" => stall_iter_best_limit;
            "reanneal_fixed" => reanneal_fixed;
            "reanneal_accepted" => reanneal_accepted;
            "reanneal_best" => reanneal_best;
        );
        assert_eq!(kv.unwrap(), kv_expected);
        let s_param = state_out.take_param().unwrap();
        for (s, p) in s_param.iter().zip(param.iter()) {
            assert_eq!(s.to_ne_bytes(), p.to_ne_bytes());
        }
        // NOTE(review): relies on TestProblem's cost being 1.0 — confirm in test_utils.
        assert_eq!(state_out.get_cost().to_ne_bytes(), 1.0f64.to_ne_bytes())
    }
}