argmin/solver/simulatedannealing/
mod.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # Simulated Annealing
9//!
10//! Simulated Annealing (SA) is a stochastic optimization method which imitates annealing in
11//! metallurgy. For details see [`SimulatedAnnealing`].
12//!
13//! ## References
14//!
15//! [Wikipedia](https://en.wikipedia.org/wiki/Simulated_annealing)
16//!
17//! S Kirkpatrick, CD Gelatt Jr, MP Vecchi. (1983). "Optimization by Simulated Annealing".
18//! Science 13 May 1983, Vol. 220, Issue 4598, pp. 671-680
19//! DOI: 10.1126/science.220.4598.671
20
21use crate::core::{
22    ArgminFloat, CostFunction, Error, IterState, Problem, Solver, TerminationReason,
23    TerminationStatus, KV,
24};
25use rand::prelude::*;
26use rand_xoshiro::Xoshiro256PlusPlus;
27#[cfg(feature = "serde1")]
28use serde::{Deserialize, Serialize};
29
30/// This trait handles the annealing of a parameter vector. Problems which are to be solved using
31/// [`SimulatedAnnealing`] must implement this trait.
32pub trait Anneal {
33    /// Type of the parameter vector
34    type Param;
35    /// Return type of the anneal function
36    type Output;
37    /// Precision of floats
38    type Float;
39
40    /// Anneal a parameter vector
41    fn anneal(&self, param: &Self::Param, extent: Self::Float) -> Result<Self::Output, Error>;
42}
43
44/// Wraps a call to `anneal` defined in the `Anneal` trait and as such allows to call `anneal` on
45/// an instance of `Problem`. Internally, the number of evaluations of `anneal` is counted.
46impl<O: Anneal> Problem<O> {
47    /// Calls `anneal` defined in the `Anneal` trait and keeps track of the number of evaluations.
48    ///
49    /// # Example
50    ///
51    /// ```
52    /// # use argmin::core::{Problem, Error};
53    /// # use argmin::solver::simulatedannealing::Anneal;
54    /// #
55    /// # #[derive(Eq, PartialEq, Debug, Clone)]
56    /// # struct UserDefinedProblem {};
57    /// #
58    /// # impl Anneal for UserDefinedProblem {
59    /// #     type Param = Vec<f64>;
60    /// #     type Output = Vec<f64>;
61    /// #     type Float = f64;
62    /// #
63    /// #     fn anneal(&self, param: &Self::Param, extent: Self::Float) -> Result<Self::Output, Error> {
64    /// #         Ok(vec![1.0f64, 1.0f64])
65    /// #     }
66    /// # }
67    /// // `UserDefinedProblem` implements `Anneal`.
68    /// let mut problem1 = Problem::new(UserDefinedProblem {});
69    ///
70    /// let param = vec![2.0f64, 1.0f64];
71    ///
72    /// let res = problem1.anneal(&param, 1.0);
73    ///
74    /// assert_eq!(problem1.counts["anneal_count"], 1);
75    /// # assert_eq!(res.unwrap(), vec![1.0f64, 1.0f64]);
76    /// ```
77    pub fn anneal(&mut self, param: &O::Param, extent: O::Float) -> Result<O::Output, Error> {
78        self.problem("anneal_count", |problem| problem.anneal(param, extent))
79    }
80}
81
/// Temperature functions for Simulated Annealing.
///
/// Given the initial temperature `t_init` and the iteration number `i`, the current temperature
/// `t_i` is given as follows:
///
/// * `SATempFunc::TemperatureFast`: `t_i = t_init / i`
/// * `SATempFunc::Boltzmann`: `t_i = t_init / ln(i)`
/// * `SATempFunc::Exponential(x)`: `t_i = t_init * x^i`
///
/// For `Exponential(x)` the factor `x` is chosen by the user. Only `0 < x < 1` makes the
/// temperature decrease; `x >= 1` keeps it constant or makes it grow.
///
/// Note that the [`Default`] implementation of this enum yields `Boltzmann`. A
/// [`SimulatedAnnealing`] solver, however, uses `TemperatureFast` unless another function is
/// set via `SimulatedAnnealing::with_temp_func`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum SATempFunc<F> {
    /// `t_i = t_init / i`
    TemperatureFast,
    /// `t_i = t_init / ln(i)`
    #[default]
    Boltzmann,
    /// `t_i = t_init * x^i`, where `x` is the wrapped value
    Exponential(F),
    // NOTE: A user-provided (boxed closure) temperature function is not offered, since a
    // `Box<dyn Fn(..)>` variant would prevent deriving `Copy`, `PartialEq` and `Eq` on this enum.
}
104
105/// # Simulated Annealing
106///
107/// Simulated Annealing (SA) is a stochastic optimization method which imitates annealing in
108/// metallurgy. Parameter vectors are randomly modified in each iteration, where the degree of
109/// modification depends on the current temperature. The algorithm starts with a high temperature
110/// (a lot of modification and hence movement in parameter space) and continuously cools down as
111/// the iterations progress, hence narrowing down in the search. Under certain conditions,
112/// reannealing (increasing the temperature) can be performed. Solutions which are better than the
113/// previous one are always accepted and solutions which are worse are accepted with a probability
114/// proportional to the cost function value difference of previous to current parameter vector.
115/// These measures allow the algorithm to explore the parameter space in a large and a small scale
116/// and hence it is able to overcome local minima.
117///
118/// The initial temperature has to be provided by the user as well as the a initial parameter
119/// vector (via [`configure`](`crate::core::Executor::configure`) of
120/// [`Executor`](`crate::core::Executor`).
121///
122/// The cooling schedule can be set with [`SimulatedAnnealing::with_temp_func`]. For the available
123/// choices please see [`SATempFunc`].
124///
125/// Reannealing can be performed if no new best solution was found for `N` iterations
126/// ([`SimulatedAnnealing::with_reannealing_best`]), or if no new accepted solution was found for
127/// `N` iterations ([`SimulatedAnnealing::with_reannealing_accepted`]) or every `N` iterations
128/// without any other conditions ([`SimulatedAnnealing::with_reannealing_fixed`]).
129///
130/// The user-provided problem must implement [`Anneal`] which defines how parameter vectors are
131/// modified. Please see the Simulated Annealing example for one approach to do so for floating
132/// point parameters.
133///
134/// ## Requirements on the optimization problem
135///
136/// The optimization problem is required to implement [`CostFunction`].
137///
138/// ## References
139///
140/// [Wikipedia](https://en.wikipedia.org/wiki/Simulated_annealing)
141///
142/// S Kirkpatrick, CD Gelatt Jr, MP Vecchi. (1983). "Optimization by Simulated Annealing".
143/// Science 13 May 1983, Vol. 220, Issue 4598, pp. 671-680
144/// DOI: 10.1126/science.220.4598.671
145#[derive(Clone)]
146#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
147pub struct SimulatedAnnealing<F, R> {
148    /// Initial temperature
149    init_temp: F,
150    /// Temperature function used for decreasing the temperature
151    temp_func: SATempFunc<F>,
152    /// Number of iterations used for the calculation of temperature. Needed for reannealing
153    temp_iter: u64,
154    /// Number of iterations since the last accepted solution
155    stall_iter_accepted: u64,
156    /// Stop if `stall_iter_accepted` exceeds this number
157    stall_iter_accepted_limit: u64,
158    /// Number of iterations since the last best solution was found
159    stall_iter_best: u64,
160    /// Stop if `stall_iter_best` exceeds this number
161    stall_iter_best_limit: u64,
162    /// Reanneal after this number of iterations is reached
163    reanneal_fixed: u64,
164    /// Number of iterations since beginning or last reannealing
165    reanneal_iter_fixed: u64,
166    /// Reanneal after no accepted solution has been found for `reanneal_accepted` iterations
167    reanneal_accepted: u64,
168    /// Similar to `stall_iter_accepted`, but will be reset to 0 when reannealing  is performed
169    reanneal_iter_accepted: u64,
170    /// Reanneal after no new best solution has been found for `reanneal_best` iterations
171    reanneal_best: u64,
172    /// Similar to `stall_iter_best`, but will be reset to 0 when reannealing is performed
173    reanneal_iter_best: u64,
174    /// current temperature
175    cur_temp: F,
176    /// random number generator
177    rng: R,
178}
179
180impl<F> SimulatedAnnealing<F, Xoshiro256PlusPlus>
181where
182    F: ArgminFloat,
183{
184    /// Construct a new instance of [`SimulatedAnnealing`]
185    ///
186    /// Takes the initial temperature as input, which must be >0.
187    ///
188    /// Uses the `Xoshiro256PlusPlus` RNG internally. For use of another RNG, consider using
189    /// [`SimulatedAnnealing::new_with_rng`].
190    ///
191    /// # Example
192    ///
193    /// ```
194    /// # use argmin::solver::simulatedannealing::SimulatedAnnealing;
195    /// # use argmin::core::Error;
196    /// # fn main() -> Result<(), Error> {
197    /// let sa = SimulatedAnnealing::new(100.0f64)?;
198    /// # Ok(())
199    /// # }
200    /// ```
201    pub fn new(initial_temperature: F) -> Result<Self, Error> {
202        SimulatedAnnealing::new_with_rng(initial_temperature, Xoshiro256PlusPlus::from_entropy())
203    }
204}
205
206impl<F, R> SimulatedAnnealing<F, R>
207where
208    F: ArgminFloat,
209{
210    /// Construct a new instance of [`SimulatedAnnealing`]
211    ///
212    /// Takes the initial temperature as input, which must be >0.
213    /// Requires a RNG which must implement `rand::Rng` (and `serde::Serialize` if the `serde1`
214    /// feature is enabled).
215    ///
216    /// # Example
217    ///
218    /// ```
219    /// # use argmin::solver::simulatedannealing::SimulatedAnnealing;
220    /// # use argmin::core::Error;
221    /// # fn main() -> Result<(), Error> {
222    /// # let my_rng = ();
223    /// let sa = SimulatedAnnealing::new_with_rng(100.0f64, my_rng)?;
224    /// # Ok(())
225    /// # }
226    /// ```
227    pub fn new_with_rng(init_temp: F, rng: R) -> Result<Self, Error> {
228        if init_temp <= float!(0.0) {
229            Err(argmin_error!(
230                InvalidParameter,
231                "`SimulatedAnnealing`: Initial temperature must be > 0."
232            ))
233        } else {
234            Ok(SimulatedAnnealing {
235                init_temp,
236                temp_func: SATempFunc::TemperatureFast,
237                temp_iter: 0,
238                stall_iter_accepted: 0,
239                stall_iter_accepted_limit: std::u64::MAX,
240                stall_iter_best: 0,
241                stall_iter_best_limit: std::u64::MAX,
242                reanneal_fixed: std::u64::MAX,
243                reanneal_iter_fixed: 0,
244                reanneal_accepted: std::u64::MAX,
245                reanneal_iter_accepted: 0,
246                reanneal_best: std::u64::MAX,
247                reanneal_iter_best: 0,
248                cur_temp: init_temp,
249                rng,
250            })
251        }
252    }
253
254    /// Set temperature function
255    ///
256    /// The temperature function defines how the temperature is decreased over the course of the
257    /// iterations.
258    /// See [`SATempFunc`] for the available options. Defaults to [`SATempFunc::TemperatureFast`].
259    ///
260    /// # Example
261    ///
262    /// ```
263    /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
264    /// # use argmin::core::Error;
265    /// # fn main() -> Result<(), Error> {
266    /// let sa = SimulatedAnnealing::new(100.0f64)?.with_temp_func(SATempFunc::Boltzmann);
267    /// # Ok(())
268    /// # }
269    /// ```
270    #[must_use]
271    pub fn with_temp_func(mut self, temperature_func: SATempFunc<F>) -> Self {
272        self.temp_func = temperature_func;
273        self
274    }
275
276    /// If there are no accepted solutions for `iter` iterations, the algorithm stops.
277    ///
278    /// Defaults to `std::u64::MAX`.
279    ///
280    /// # Example
281    ///
282    /// ```
283    /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
284    /// # use argmin::core::Error;
285    /// # fn main() -> Result<(), Error> {
286    /// let sa = SimulatedAnnealing::new(100.0f64)?.with_stall_accepted(1000);
287    /// # Ok(())
288    /// # }
289    /// ```
290    #[must_use]
291    pub fn with_stall_accepted(mut self, iter: u64) -> Self {
292        self.stall_iter_accepted_limit = iter;
293        self
294    }
295
296    /// If there are no new best solutions for `iter` iterations, the algorithm stops.
297    ///
298    /// Defaults to `std::u64::MAX`.
299    ///
300    /// # Example
301    ///
302    /// ```
303    /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
304    /// # use argmin::core::Error;
305    /// # fn main() -> Result<(), Error> {
306    /// let sa = SimulatedAnnealing::new(100.0f64)?.with_stall_best(2000);
307    /// # Ok(())
308    /// # }
309    /// ```
310    #[must_use]
311    pub fn with_stall_best(mut self, iter: u64) -> Self {
312        self.stall_iter_best_limit = iter;
313        self
314    }
315
316    /// Set number of iterations after which reannealing is performed
317    ///
318    /// Every `iter` iterations, reannealing (resetting temperature to its initial value) will be
319    /// performed. This may help in overcoming local minima.
320    ///
321    /// Defaults to `std::u64::MAX`.
322    ///
323    /// # Example
324    ///
325    /// ```
326    /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
327    /// # use argmin::core::Error;
328    /// # fn main() -> Result<(), Error> {
329    /// let sa = SimulatedAnnealing::new(100.0f64)?.with_reannealing_fixed(5000);
330    /// # Ok(())
331    /// # }
332    /// ```
333    #[must_use]
334    pub fn with_reannealing_fixed(mut self, iter: u64) -> Self {
335        self.reanneal_fixed = iter;
336        self
337    }
338
339    /// Set the number of iterations that need to pass after the last accepted solution was found
340    /// for reannealing to be performed.
341    ///
342    /// If no new accepted solution is found for `iter` iterations, reannealing (resetting
343    /// temperature to its initial value) is performed. This may help in overcoming local minima.
344    ///
345    /// Defaults to `std::u64::MAX`.
346    ///
347    /// # Example
348    ///
349    /// ```
350    /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
351    /// # use argmin::core::Error;
352    /// # fn main() -> Result<(), Error> {
353    /// let sa = SimulatedAnnealing::new(100.0f64)?.with_reannealing_accepted(5000);
354    /// # Ok(())
355    /// # }
356    /// ```
357    #[must_use]
358    pub fn with_reannealing_accepted(mut self, iter: u64) -> Self {
359        self.reanneal_accepted = iter;
360        self
361    }
362
363    /// Set the number of iterations that need to pass after the last best solution was found
364    /// for reannealing to be performed.
365    ///
366    /// If no new best solution is found for `iter` iterations, reannealing (resetting temperature
367    /// to its initial value) is performed. This may help in overcoming local minima.
368    ///
369    /// Defaults to `std::u64::MAX`.
370    ///
371    /// # Example
372    ///
373    /// ```
374    /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
375    /// # use argmin::core::Error;
376    /// # fn main() -> Result<(), Error> {
377    /// let sa = SimulatedAnnealing::new(100.0f64)?.with_reannealing_best(5000);
378    /// # Ok(())
379    /// # }
380    /// ```
381    #[must_use]
382    pub fn with_reannealing_best(mut self, iter: u64) -> Self {
383        self.reanneal_best = iter;
384        self
385    }
386
387    /// Update the temperature based on the current iteration number.
388    ///
389    /// Updates are performed based on specific update functions. See `SATempFunc` for details.
390    fn update_temperature(&mut self) {
391        self.cur_temp = match self.temp_func {
392            SATempFunc::TemperatureFast => {
393                self.init_temp / F::from_u64(self.temp_iter + 1).unwrap()
394            }
395            SATempFunc::Boltzmann => self.init_temp / F::from_u64(self.temp_iter + 1).unwrap().ln(),
396            SATempFunc::Exponential(x) => {
397                self.init_temp * x.powf(F::from_u64(self.temp_iter + 1).unwrap())
398            }
399        };
400    }
401
402    /// Perform reannealing
403    fn reanneal(&mut self) -> (bool, bool, bool) {
404        let out = (
405            self.reanneal_iter_fixed >= self.reanneal_fixed,
406            self.reanneal_iter_accepted >= self.reanneal_accepted,
407            self.reanneal_iter_best >= self.reanneal_best,
408        );
409        if out.0 || out.1 || out.2 {
410            self.reanneal_iter_fixed = 0;
411            self.reanneal_iter_accepted = 0;
412            self.reanneal_iter_best = 0;
413            self.cur_temp = self.init_temp;
414            self.temp_iter = 0;
415        }
416        out
417    }
418
419    /// Update the stall iter variables
420    fn update_stall_and_reanneal_iter(&mut self, accepted: bool, new_best: bool) {
421        (self.stall_iter_accepted, self.reanneal_iter_accepted) = if accepted {
422            (0, 0)
423        } else {
424            (
425                self.stall_iter_accepted + 1,
426                self.reanneal_iter_accepted + 1,
427            )
428        };
429
430        (self.stall_iter_best, self.reanneal_iter_best) = if new_best {
431            (0, 0)
432        } else {
433            (self.stall_iter_best + 1, self.reanneal_iter_best + 1)
434        };
435    }
436}
437
438impl<O, P, F, R> Solver<O, IterState<P, (), (), (), (), F>> for SimulatedAnnealing<F, R>
439where
440    O: CostFunction<Param = P, Output = F> + Anneal<Param = P, Output = P, Float = F>,
441    P: Clone,
442    F: ArgminFloat,
443    R: Rng,
444{
445    const NAME: &'static str = "Simulated Annealing";
446    fn init(
447        &mut self,
448        problem: &mut Problem<O>,
449        mut state: IterState<P, (), (), (), (), F>,
450    ) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
451        let param = state.take_param().ok_or_else(argmin_error_closure!(
452            NotInitialized,
453            concat!(
454                "`SimulatedAnnealing` requires an initial parameter vector. ",
455                "Please provide an initial guess via `Executor`s `configure` method."
456            )
457        ))?;
458
459        let cost = state.get_cost();
460        let cost = if cost.is_infinite() {
461            problem.cost(&param)?
462        } else {
463            cost
464        };
465
466        Ok((
467            state.param(param).cost(cost),
468            Some(kv!(
469                "initial_temperature" => self.init_temp;
470                "stall_iter_accepted_limit" => self.stall_iter_accepted_limit;
471                "stall_iter_best_limit" => self.stall_iter_best_limit;
472                "reanneal_fixed" => self.reanneal_fixed;
473                "reanneal_accepted" => self.reanneal_accepted;
474                "reanneal_best" => self.reanneal_best;
475            )),
476        ))
477    }
478
479    /// Perform one iteration of SA algorithm
480    fn next_iter(
481        &mut self,
482        problem: &mut Problem<O>,
483        mut state: IterState<P, (), (), (), (), F>,
484    ) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
485        // Careful: The order in here is *very* important, even if it may not seem so. Everything
486        // is linked to the iteration number, and getting things mixed up may lead to unexpected
487        // behavior.
488
489        let prev_param = state.take_param().ok_or_else(argmin_error_closure!(
490            PotentialBug,
491            "`SimulatedAnnealing`: Parameter vector in state not set."
492        ))?;
493        let prev_cost = state.get_cost();
494
495        // Make a move
496        let new_param = problem.anneal(&prev_param, self.cur_temp)?;
497
498        // Evaluate cost function with new parameter vector
499        let new_cost = problem.cost(&new_param)?;
500
501        // Acceptance function
502        //
503        // Decide whether new parameter vector should be accepted.
504        // If no, move on with old parameter vector.
505        //
506        // Any solution which satisfies `next_cost < prev_cost` will be accepted. Solutions worse
507        // than the previous one are accepted with a probability given as:
508        //
509        // `1 / (1 + exp((next_cost - prev_cost) / current_temperature))`,
510        //
511        // which will always be between 0 and 0.5.
512        let prob: f64 = self.rng.gen();
513        let prob = float!(prob);
514        let accepted = (new_cost < prev_cost)
515            || (float!(1.0) / (float!(1.0) + ((new_cost - prev_cost) / self.cur_temp).exp())
516                > prob);
517
518        let new_best_found = new_cost < state.best_cost;
519
520        // Update stall iter variables
521        self.update_stall_and_reanneal_iter(accepted, new_best_found);
522
523        let (r_fixed, r_accepted, r_best) = self.reanneal();
524
525        // Update temperature for next iteration.
526        self.temp_iter += 1;
527        // Actually not necessary as it does the same as `temp_iter`, but I'll leave it here for
528        // better readability.
529        self.reanneal_iter_fixed += 1;
530
531        self.update_temperature();
532
533        Ok((
534            if accepted {
535                state.param(new_param).cost(new_cost)
536            } else {
537                state.param(prev_param).cost(prev_cost)
538            },
539            Some(kv!(
540                "t" => self.cur_temp;
541                "new_be" => new_best_found;
542                "acc" => accepted;
543                "st_i_be" => self.stall_iter_best;
544                "st_i_ac" => self.stall_iter_accepted;
545                "ra_i_fi" => self.reanneal_iter_fixed;
546                "ra_i_be" => self.reanneal_iter_best;
547                "ra_i_ac" => self.reanneal_iter_accepted;
548                "ra_fi" => r_fixed;
549                "ra_be" => r_best;
550                "ra_ac" => r_accepted;
551            )),
552        ))
553    }
554
555    fn terminate(&mut self, _state: &IterState<P, (), (), (), (), F>) -> TerminationStatus {
556        if self.stall_iter_accepted > self.stall_iter_accepted_limit {
557            return TerminationStatus::Terminated(TerminationReason::SolverExit(
558                "AcceptedStallIterExceeded".to_string(),
559            ));
560        }
561        if self.stall_iter_best > self.stall_iter_best_limit {
562            return TerminationStatus::Terminated(TerminationReason::SolverExit(
563                "BestStallIterExceeded".to_string(),
564            ));
565        }
566        TerminationStatus::NotTerminated
567    }
568}
569
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError, State};
    use crate::test_trait_impl;
    use approx::assert_relative_eq;

    test_trait_impl!(sa, SimulatedAnnealing<f64, StdRng>);

    #[test]
    fn test_new() {
        let sa: SimulatedAnnealing<f64, Xoshiro256PlusPlus> =
            SimulatedAnnealing::new(100.0).unwrap();
        let SimulatedAnnealing {
            init_temp,
            temp_func,
            temp_iter,
            stall_iter_accepted,
            stall_iter_accepted_limit,
            stall_iter_best,
            stall_iter_best_limit,
            reanneal_fixed,
            reanneal_iter_fixed,
            reanneal_accepted,
            reanneal_iter_accepted,
            reanneal_best,
            reanneal_iter_best,
            cur_temp,
            rng: _rng,
        } = sa;

        assert_eq!(init_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());
        assert_eq!(temp_func, SATempFunc::TemperatureFast);
        assert_eq!(temp_iter, 0);
        assert_eq!(stall_iter_accepted, 0);
        assert_eq!(stall_iter_accepted_limit, u64::MAX);
        assert_eq!(stall_iter_best, 0);
        assert_eq!(stall_iter_best_limit, u64::MAX);
        assert_eq!(reanneal_fixed, u64::MAX);
        assert_eq!(reanneal_iter_fixed, 0);
        assert_eq!(reanneal_accepted, u64::MAX);
        assert_eq!(reanneal_iter_accepted, 0);
        assert_eq!(reanneal_best, u64::MAX);
        assert_eq!(reanneal_iter_best, 0);
        assert_eq!(cur_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());

        for temp in [0.0, -1.0, -f64::EPSILON, -100.0] {
            let res = SimulatedAnnealing::new(temp);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SimulatedAnnealing`: Initial temperature must be > 0.\""
            );
        }
    }

    #[test]
    fn test_new_with_rng() {
        #[derive(Eq, PartialEq, Debug)]
        struct MyRng {}

        let sa: SimulatedAnnealing<f64, MyRng> =
            SimulatedAnnealing::new_with_rng(100.0, MyRng {}).unwrap();
        let SimulatedAnnealing {
            init_temp,
            temp_func,
            temp_iter,
            stall_iter_accepted,
            stall_iter_accepted_limit,
            stall_iter_best,
            stall_iter_best_limit,
            reanneal_fixed,
            reanneal_iter_fixed,
            reanneal_accepted,
            reanneal_iter_accepted,
            reanneal_best,
            reanneal_iter_best,
            cur_temp,
            rng,
        } = sa;

        assert_eq!(init_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());
        assert_eq!(temp_func, SATempFunc::TemperatureFast);
        assert_eq!(temp_iter, 0);
        assert_eq!(stall_iter_accepted, 0);
        assert_eq!(stall_iter_accepted_limit, u64::MAX);
        assert_eq!(stall_iter_best, 0);
        assert_eq!(stall_iter_best_limit, u64::MAX);
        assert_eq!(reanneal_fixed, u64::MAX);
        assert_eq!(reanneal_iter_fixed, 0);
        assert_eq!(reanneal_accepted, u64::MAX);
        assert_eq!(reanneal_iter_accepted, 0);
        assert_eq!(reanneal_best, u64::MAX);
        assert_eq!(reanneal_iter_best, 0);
        assert_eq!(cur_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());
        // important part
        assert_eq!(rng, MyRng {});

        for temp in [0.0, -1.0, -f64::EPSILON, -100.0] {
            let res = SimulatedAnnealing::new_with_rng(temp, MyRng {});
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SimulatedAnnealing`: Initial temperature must be > 0.\""
            );
        }
    }

    #[test]
    fn test_with_temp_func() {
        for func in [
            SATempFunc::TemperatureFast,
            SATempFunc::Boltzmann,
            SATempFunc::Exponential(2.0),
        ] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_temp_func(func);

            assert_eq!(sa.temp_func, func);
        }
    }

    #[test]
    fn test_with_stall_accepted() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_stall_accepted(iter);

            assert_eq!(sa.stall_iter_accepted_limit, iter);
        }
    }

    #[test]
    fn test_with_stall_best() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_stall_best(iter);

            assert_eq!(sa.stall_iter_best_limit, iter);
        }
    }

    #[test]
    fn test_with_reannealing_fixed() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_reannealing_fixed(iter);

            assert_eq!(sa.reanneal_fixed, iter);
        }
    }

    #[test]
    fn test_with_reannealing_accepted() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_reannealing_accepted(iter);

            assert_eq!(sa.reanneal_accepted, iter);
        }
    }

    #[test]
    fn test_with_reannealing_best() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_reannealing_best(iter);

            assert_eq!(sa.reanneal_best, iter);
        }
    }

    #[test]
    fn test_update_temperature() {
        for (func, val) in [
            (SATempFunc::TemperatureFast, 100.0f64 / 2.0),
            (SATempFunc::Boltzmann, 100.0f64 / 2.0f64.ln()),
            (SATempFunc::Exponential(3.0), 100.0 * 3.0f64.powi(2)),
        ] {
            let mut sa = SimulatedAnnealing::new(100.0f64)
                .unwrap()
                .with_temp_func(func);
            sa.temp_iter = 1;

            sa.update_temperature();

            assert_relative_eq!(sa.cur_temp, val, epsilon = f64::EPSILON);
        }
    }

    #[test]
    fn test_reanneal() {
        let mut sa_t = SimulatedAnnealing::new(100.0f64).unwrap();

        sa_t.reanneal_fixed = 10;
        sa_t.reanneal_accepted = 20;
        sa_t.reanneal_best = 30;
        sa_t.temp_iter = 40;
        sa_t.cur_temp = 50.0;

        for ((f, a, b), expected) in [
            ((0, 0, 0), (false, false, false)),
            ((10, 0, 0), (true, false, false)),
            ((11, 0, 0), (true, false, false)),
            ((0, 20, 0), (false, true, false)),
            ((0, 21, 0), (false, true, false)),
            ((0, 0, 30), (false, false, true)),
            ((0, 0, 31), (false, false, true)),
            ((10, 20, 0), (true, true, false)),
            ((10, 0, 30), (true, false, true)),
            ((0, 20, 30), (false, true, true)),
            ((10, 20, 30), (true, true, true)),
        ] {
            let mut sa = sa_t.clone();

            sa.reanneal_iter_fixed = f;
            sa.reanneal_iter_accepted = a;
            sa.reanneal_iter_best = b;

            assert_eq!(sa.reanneal(), expected);

            if expected.0 || expected.1 || expected.2 {
                assert_eq!(sa.reanneal_iter_fixed, 0);
                assert_eq!(sa.reanneal_iter_accepted, 0);
                assert_eq!(sa.reanneal_iter_best, 0);
                assert_eq!(sa.temp_iter, 0);
                assert_eq!(sa.cur_temp.to_ne_bytes(), sa.init_temp.to_ne_bytes());
            }
        }
    }

    #[test]
    fn test_update_stall_and_reanneal_iter() {
        let mut sa_t = SimulatedAnnealing::new(100.0f64).unwrap();

        sa_t.stall_iter_accepted = 10;
        sa_t.reanneal_iter_accepted = 20;
        sa_t.stall_iter_best = 30;
        sa_t.reanneal_iter_best = 40;

        for ((a, b), (sia, ria, sib, rib)) in [
            ((false, false), (11, 21, 31, 41)),
            ((false, true), (11, 21, 0, 0)),
            ((true, false), (0, 0, 31, 41)),
            ((true, true), (0, 0, 0, 0)),
        ] {
            let mut sa = sa_t.clone();

            sa.update_stall_and_reanneal_iter(a, b);

            assert_eq!(sa.stall_iter_accepted, sia);
            assert_eq!(sa.reanneal_iter_accepted, ria);
            assert_eq!(sa.stall_iter_best, sib);
            assert_eq!(sa.reanneal_iter_best, rib);
        }
    }

    #[test]
    fn test_init() {
        let param: Vec<f64> = vec![-1.0, 1.0];

        let stall_iter_accepted_limit = 10;
        let stall_iter_best_limit = 20;
        let reanneal_fixed = 30;
        let reanneal_accepted = 40;
        let reanneal_best = 50;

        let mut sa = SimulatedAnnealing::new(100.0f64)
            .unwrap()
            .with_stall_accepted(stall_iter_accepted_limit)
            .with_stall_best(stall_iter_best_limit)
            .with_reannealing_fixed(reanneal_fixed)
            .with_reannealing_accepted(reanneal_accepted)
            .with_reannealing_best(reanneal_best);

        // Forgot to initialize the parameter vector
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let problem = TestProblem::new();
        let res = sa.init(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`SimulatedAnnealing` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        // All good.
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new().param(param.clone());
        let problem = TestProblem::new();
        let (mut state_out, kv) = sa.init(&mut Problem::new(problem), state).unwrap();

        let kv_expected = kv!(
            "initial_temperature" => 100.0f64;
            "stall_iter_accepted_limit" => stall_iter_accepted_limit;
            "stall_iter_best_limit" => stall_iter_best_limit;
            "reanneal_fixed" => reanneal_fixed;
            "reanneal_accepted" => reanneal_accepted;
            "reanneal_best" => reanneal_best;
        );

        assert_eq!(kv.unwrap(), kv_expected);

        let s_param = state_out.take_param().unwrap();

        for (s, p) in s_param.iter().zip(param.iter()) {
            assert_eq!(s.to_ne_bytes(), p.to_ne_bytes());
        }

        assert_eq!(state_out.get_cost().to_ne_bytes(), 1.0f64.to_ne_bytes())
    }

    #[test]
    fn test_terminate() {
        type SaState = IterState<Vec<f64>, (), (), (), (), f64>;

        // `Solver` is generic over the problem type, so the call has to be fully qualified.
        fn check_termination(
            sa: &mut SimulatedAnnealing<f64, Xoshiro256PlusPlus>,
        ) -> TerminationStatus {
            let state: SaState = IterState::new();
            <SimulatedAnnealing<f64, Xoshiro256PlusPlus> as Solver<TestProblem, SaState>>::terminate(
                sa, &state,
            )
        }

        let mut sa = SimulatedAnnealing::new(100.0f64)
            .unwrap()
            .with_stall_accepted(5)
            .with_stall_best(10);

        // Counters equal to their limits do not terminate: the limits must be exceeded.
        sa.stall_iter_accepted = 5;
        sa.stall_iter_best = 10;
        assert!(matches!(
            check_termination(&mut sa),
            TerminationStatus::NotTerminated
        ));

        sa.stall_iter_accepted = 6;
        assert!(matches!(
            check_termination(&mut sa),
            TerminationStatus::Terminated(TerminationReason::SolverExit(ref reason))
                if reason.as_str() == "AcceptedStallIterExceeded"
        ));

        sa.stall_iter_accepted = 0;
        sa.stall_iter_best = 11;
        assert!(matches!(
            check_termination(&mut sa),
            TerminationStatus::Terminated(TerminationReason::SolverExit(ref reason))
                if reason.as_str() == "BestStallIterExceeded"
        ));
    }
}