argmin/solver/simulatedannealing/mod.rs
1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # Simulated Annealing
9//!
10//! Simulated Annealing (SA) is a stochastic optimization method which imitates annealing in
11//! metallurgy. For details see [`SimulatedAnnealing`].
12//!
13//! ## References
14//!
15//! [Wikipedia](https://en.wikipedia.org/wiki/Simulated_annealing)
16//!
17//! S Kirkpatrick, CD Gelatt Jr, MP Vecchi. (1983). "Optimization by Simulated Annealing".
18//! Science 13 May 1983, Vol. 220, Issue 4598, pp. 671-680
19//! DOI: 10.1126/science.220.4598.671
20
21use crate::core::{
22 ArgminFloat, CostFunction, Error, IterState, Problem, Solver, TerminationReason,
23 TerminationStatus, KV,
24};
25use rand::prelude::*;
26use rand_xoshiro::Xoshiro256PlusPlus;
27#[cfg(feature = "serde1")]
28use serde::{Deserialize, Serialize};
29
30/// This trait handles the annealing of a parameter vector. Problems which are to be solved using
31/// [`SimulatedAnnealing`] must implement this trait.
32pub trait Anneal {
33 /// Type of the parameter vector
34 type Param;
35 /// Return type of the anneal function
36 type Output;
37 /// Precision of floats
38 type Float;
39
40 /// Anneal a parameter vector
41 fn anneal(&self, param: &Self::Param, extent: Self::Float) -> Result<Self::Output, Error>;
42}
43
44/// Wraps a call to `anneal` defined in the `Anneal` trait and as such allows to call `anneal` on
45/// an instance of `Problem`. Internally, the number of evaluations of `anneal` is counted.
46impl<O: Anneal> Problem<O> {
47 /// Calls `anneal` defined in the `Anneal` trait and keeps track of the number of evaluations.
48 ///
49 /// # Example
50 ///
51 /// ```
52 /// # use argmin::core::{Problem, Error};
53 /// # use argmin::solver::simulatedannealing::Anneal;
54 /// #
55 /// # #[derive(Eq, PartialEq, Debug, Clone)]
56 /// # struct UserDefinedProblem {};
57 /// #
58 /// # impl Anneal for UserDefinedProblem {
59 /// # type Param = Vec<f64>;
60 /// # type Output = Vec<f64>;
61 /// # type Float = f64;
62 /// #
63 /// # fn anneal(&self, param: &Self::Param, extent: Self::Float) -> Result<Self::Output, Error> {
64 /// # Ok(vec![1.0f64, 1.0f64])
65 /// # }
66 /// # }
67 /// // `UserDefinedProblem` implements `Anneal`.
68 /// let mut problem1 = Problem::new(UserDefinedProblem {});
69 ///
70 /// let param = vec![2.0f64, 1.0f64];
71 ///
72 /// let res = problem1.anneal(¶m, 1.0);
73 ///
74 /// assert_eq!(problem1.counts["anneal_count"], 1);
75 /// # assert_eq!(res.unwrap(), vec![1.0f64, 1.0f64]);
76 /// ```
77 pub fn anneal(&mut self, param: &O::Param, extent: O::Float) -> Result<O::Output, Error> {
78 self.problem("anneal_count", |problem| problem.anneal(param, extent))
79 }
80}
81
/// Temperature functions for Simulated Annealing.
///
/// Given the initial temperature `t_init` and the iteration number `i`, the current temperature
/// `t_i` is given as follows:
///
/// * `SATempFunc::TemperatureFast`: `t_i = t_init / i`
/// * `SATempFunc::Boltzmann`: `t_i = t_init / ln(i)`
/// * `SATempFunc::Exponential(x)`: `t_i = t_init * x^i`, where `x` is the factor stored in the
///   variant (for instance `0.95`)
///
/// Note that `SATempFunc::default()` returns `SATempFunc::Boltzmann`, whereas the constructors
/// of [`SimulatedAnnealing`] select `SATempFunc::TemperatureFast`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum SATempFunc<F> {
    /// `t_i = t_init / i`
    TemperatureFast,
    /// `t_i = t_init / ln(i)`
    #[default]
    Boltzmann,
    /// `t_i = t_init * x^i`, with `x` being the wrapped value
    Exponential(F),
    // /// User-provided temperature function. The first parameter must be the current temperature and
    // /// the second parameter must be the iteration number.
    // Custom(Box<dyn Fn(f64, u64) -> f64 + 'static>),
}
104
105/// # Simulated Annealing
106///
107/// Simulated Annealing (SA) is a stochastic optimization method which imitates annealing in
108/// metallurgy. Parameter vectors are randomly modified in each iteration, where the degree of
109/// modification depends on the current temperature. The algorithm starts with a high temperature
110/// (a lot of modification and hence movement in parameter space) and continuously cools down as
111/// the iterations progress, hence narrowing down in the search. Under certain conditions,
112/// reannealing (increasing the temperature) can be performed. Solutions which are better than the
113/// previous one are always accepted and solutions which are worse are accepted with a probability
114/// proportional to the cost function value difference of previous to current parameter vector.
115/// These measures allow the algorithm to explore the parameter space in a large and a small scale
116/// and hence it is able to overcome local minima.
117///
118/// The initial temperature has to be provided by the user as well as the a initial parameter
119/// vector (via [`configure`](`crate::core::Executor::configure`) of
120/// [`Executor`](`crate::core::Executor`).
121///
122/// The cooling schedule can be set with [`SimulatedAnnealing::with_temp_func`]. For the available
123/// choices please see [`SATempFunc`].
124///
125/// Reannealing can be performed if no new best solution was found for `N` iterations
126/// ([`SimulatedAnnealing::with_reannealing_best`]), or if no new accepted solution was found for
127/// `N` iterations ([`SimulatedAnnealing::with_reannealing_accepted`]) or every `N` iterations
128/// without any other conditions ([`SimulatedAnnealing::with_reannealing_fixed`]).
129///
130/// The user-provided problem must implement [`Anneal`] which defines how parameter vectors are
131/// modified. Please see the Simulated Annealing example for one approach to do so for floating
132/// point parameters.
133///
134/// ## Requirements on the optimization problem
135///
136/// The optimization problem is required to implement [`CostFunction`].
137///
138/// ## References
139///
140/// [Wikipedia](https://en.wikipedia.org/wiki/Simulated_annealing)
141///
142/// S Kirkpatrick, CD Gelatt Jr, MP Vecchi. (1983). "Optimization by Simulated Annealing".
143/// Science 13 May 1983, Vol. 220, Issue 4598, pp. 671-680
144/// DOI: 10.1126/science.220.4598.671
145#[derive(Clone)]
146#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
147pub struct SimulatedAnnealing<F, R> {
148 /// Initial temperature
149 init_temp: F,
150 /// Temperature function used for decreasing the temperature
151 temp_func: SATempFunc<F>,
152 /// Number of iterations used for the calculation of temperature. Needed for reannealing
153 temp_iter: u64,
154 /// Number of iterations since the last accepted solution
155 stall_iter_accepted: u64,
156 /// Stop if `stall_iter_accepted` exceeds this number
157 stall_iter_accepted_limit: u64,
158 /// Number of iterations since the last best solution was found
159 stall_iter_best: u64,
160 /// Stop if `stall_iter_best` exceeds this number
161 stall_iter_best_limit: u64,
162 /// Reanneal after this number of iterations is reached
163 reanneal_fixed: u64,
164 /// Number of iterations since beginning or last reannealing
165 reanneal_iter_fixed: u64,
166 /// Reanneal after no accepted solution has been found for `reanneal_accepted` iterations
167 reanneal_accepted: u64,
168 /// Similar to `stall_iter_accepted`, but will be reset to 0 when reannealing is performed
169 reanneal_iter_accepted: u64,
170 /// Reanneal after no new best solution has been found for `reanneal_best` iterations
171 reanneal_best: u64,
172 /// Similar to `stall_iter_best`, but will be reset to 0 when reannealing is performed
173 reanneal_iter_best: u64,
174 /// current temperature
175 cur_temp: F,
176 /// random number generator
177 rng: R,
178}
179
180impl<F> SimulatedAnnealing<F, Xoshiro256PlusPlus>
181where
182 F: ArgminFloat,
183{
184 /// Construct a new instance of [`SimulatedAnnealing`]
185 ///
186 /// Takes the initial temperature as input, which must be >0.
187 ///
188 /// Uses the `Xoshiro256PlusPlus` RNG internally. For use of another RNG, consider using
189 /// [`SimulatedAnnealing::new_with_rng`].
190 ///
191 /// # Example
192 ///
193 /// ```
194 /// # use argmin::solver::simulatedannealing::SimulatedAnnealing;
195 /// # use argmin::core::Error;
196 /// # fn main() -> Result<(), Error> {
197 /// let sa = SimulatedAnnealing::new(100.0f64)?;
198 /// # Ok(())
199 /// # }
200 /// ```
201 pub fn new(initial_temperature: F) -> Result<Self, Error> {
202 SimulatedAnnealing::new_with_rng(initial_temperature, Xoshiro256PlusPlus::from_entropy())
203 }
204}
205
206impl<F, R> SimulatedAnnealing<F, R>
207where
208 F: ArgminFloat,
209{
210 /// Construct a new instance of [`SimulatedAnnealing`]
211 ///
212 /// Takes the initial temperature as input, which must be >0.
213 /// Requires a RNG which must implement `rand::Rng` (and `serde::Serialize` if the `serde1`
214 /// feature is enabled).
215 ///
216 /// # Example
217 ///
218 /// ```
219 /// # use argmin::solver::simulatedannealing::SimulatedAnnealing;
220 /// # use argmin::core::Error;
221 /// # fn main() -> Result<(), Error> {
222 /// # let my_rng = ();
223 /// let sa = SimulatedAnnealing::new_with_rng(100.0f64, my_rng)?;
224 /// # Ok(())
225 /// # }
226 /// ```
227 pub fn new_with_rng(init_temp: F, rng: R) -> Result<Self, Error> {
228 if init_temp <= float!(0.0) {
229 Err(argmin_error!(
230 InvalidParameter,
231 "`SimulatedAnnealing`: Initial temperature must be > 0."
232 ))
233 } else {
234 Ok(SimulatedAnnealing {
235 init_temp,
236 temp_func: SATempFunc::TemperatureFast,
237 temp_iter: 0,
238 stall_iter_accepted: 0,
239 stall_iter_accepted_limit: std::u64::MAX,
240 stall_iter_best: 0,
241 stall_iter_best_limit: std::u64::MAX,
242 reanneal_fixed: std::u64::MAX,
243 reanneal_iter_fixed: 0,
244 reanneal_accepted: std::u64::MAX,
245 reanneal_iter_accepted: 0,
246 reanneal_best: std::u64::MAX,
247 reanneal_iter_best: 0,
248 cur_temp: init_temp,
249 rng,
250 })
251 }
252 }
253
254 /// Set temperature function
255 ///
256 /// The temperature function defines how the temperature is decreased over the course of the
257 /// iterations.
258 /// See [`SATempFunc`] for the available options. Defaults to [`SATempFunc::TemperatureFast`].
259 ///
260 /// # Example
261 ///
262 /// ```
263 /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
264 /// # use argmin::core::Error;
265 /// # fn main() -> Result<(), Error> {
266 /// let sa = SimulatedAnnealing::new(100.0f64)?.with_temp_func(SATempFunc::Boltzmann);
267 /// # Ok(())
268 /// # }
269 /// ```
270 #[must_use]
271 pub fn with_temp_func(mut self, temperature_func: SATempFunc<F>) -> Self {
272 self.temp_func = temperature_func;
273 self
274 }
275
276 /// If there are no accepted solutions for `iter` iterations, the algorithm stops.
277 ///
278 /// Defaults to `std::u64::MAX`.
279 ///
280 /// # Example
281 ///
282 /// ```
283 /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
284 /// # use argmin::core::Error;
285 /// # fn main() -> Result<(), Error> {
286 /// let sa = SimulatedAnnealing::new(100.0f64)?.with_stall_accepted(1000);
287 /// # Ok(())
288 /// # }
289 /// ```
290 #[must_use]
291 pub fn with_stall_accepted(mut self, iter: u64) -> Self {
292 self.stall_iter_accepted_limit = iter;
293 self
294 }
295
296 /// If there are no new best solutions for `iter` iterations, the algorithm stops.
297 ///
298 /// Defaults to `std::u64::MAX`.
299 ///
300 /// # Example
301 ///
302 /// ```
303 /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
304 /// # use argmin::core::Error;
305 /// # fn main() -> Result<(), Error> {
306 /// let sa = SimulatedAnnealing::new(100.0f64)?.with_stall_best(2000);
307 /// # Ok(())
308 /// # }
309 /// ```
310 #[must_use]
311 pub fn with_stall_best(mut self, iter: u64) -> Self {
312 self.stall_iter_best_limit = iter;
313 self
314 }
315
316 /// Set number of iterations after which reannealing is performed
317 ///
318 /// Every `iter` iterations, reannealing (resetting temperature to its initial value) will be
319 /// performed. This may help in overcoming local minima.
320 ///
321 /// Defaults to `std::u64::MAX`.
322 ///
323 /// # Example
324 ///
325 /// ```
326 /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
327 /// # use argmin::core::Error;
328 /// # fn main() -> Result<(), Error> {
329 /// let sa = SimulatedAnnealing::new(100.0f64)?.with_reannealing_fixed(5000);
330 /// # Ok(())
331 /// # }
332 /// ```
333 #[must_use]
334 pub fn with_reannealing_fixed(mut self, iter: u64) -> Self {
335 self.reanneal_fixed = iter;
336 self
337 }
338
339 /// Set the number of iterations that need to pass after the last accepted solution was found
340 /// for reannealing to be performed.
341 ///
342 /// If no new accepted solution is found for `iter` iterations, reannealing (resetting
343 /// temperature to its initial value) is performed. This may help in overcoming local minima.
344 ///
345 /// Defaults to `std::u64::MAX`.
346 ///
347 /// # Example
348 ///
349 /// ```
350 /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
351 /// # use argmin::core::Error;
352 /// # fn main() -> Result<(), Error> {
353 /// let sa = SimulatedAnnealing::new(100.0f64)?.with_reannealing_accepted(5000);
354 /// # Ok(())
355 /// # }
356 /// ```
357 #[must_use]
358 pub fn with_reannealing_accepted(mut self, iter: u64) -> Self {
359 self.reanneal_accepted = iter;
360 self
361 }
362
363 /// Set the number of iterations that need to pass after the last best solution was found
364 /// for reannealing to be performed.
365 ///
366 /// If no new best solution is found for `iter` iterations, reannealing (resetting temperature
367 /// to its initial value) is performed. This may help in overcoming local minima.
368 ///
369 /// Defaults to `std::u64::MAX`.
370 ///
371 /// # Example
372 ///
373 /// ```
374 /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
375 /// # use argmin::core::Error;
376 /// # fn main() -> Result<(), Error> {
377 /// let sa = SimulatedAnnealing::new(100.0f64)?.with_reannealing_best(5000);
378 /// # Ok(())
379 /// # }
380 /// ```
381 #[must_use]
382 pub fn with_reannealing_best(mut self, iter: u64) -> Self {
383 self.reanneal_best = iter;
384 self
385 }
386
387 /// Update the temperature based on the current iteration number.
388 ///
389 /// Updates are performed based on specific update functions. See `SATempFunc` for details.
390 fn update_temperature(&mut self) {
391 self.cur_temp = match self.temp_func {
392 SATempFunc::TemperatureFast => {
393 self.init_temp / F::from_u64(self.temp_iter + 1).unwrap()
394 }
395 SATempFunc::Boltzmann => self.init_temp / F::from_u64(self.temp_iter + 1).unwrap().ln(),
396 SATempFunc::Exponential(x) => {
397 self.init_temp * x.powf(F::from_u64(self.temp_iter + 1).unwrap())
398 }
399 };
400 }
401
402 /// Perform reannealing
403 fn reanneal(&mut self) -> (bool, bool, bool) {
404 let out = (
405 self.reanneal_iter_fixed >= self.reanneal_fixed,
406 self.reanneal_iter_accepted >= self.reanneal_accepted,
407 self.reanneal_iter_best >= self.reanneal_best,
408 );
409 if out.0 || out.1 || out.2 {
410 self.reanneal_iter_fixed = 0;
411 self.reanneal_iter_accepted = 0;
412 self.reanneal_iter_best = 0;
413 self.cur_temp = self.init_temp;
414 self.temp_iter = 0;
415 }
416 out
417 }
418
419 /// Update the stall iter variables
420 fn update_stall_and_reanneal_iter(&mut self, accepted: bool, new_best: bool) {
421 (self.stall_iter_accepted, self.reanneal_iter_accepted) = if accepted {
422 (0, 0)
423 } else {
424 (
425 self.stall_iter_accepted + 1,
426 self.reanneal_iter_accepted + 1,
427 )
428 };
429
430 (self.stall_iter_best, self.reanneal_iter_best) = if new_best {
431 (0, 0)
432 } else {
433 (self.stall_iter_best + 1, self.reanneal_iter_best + 1)
434 };
435 }
436}
437
438impl<O, P, F, R> Solver<O, IterState<P, (), (), (), (), F>> for SimulatedAnnealing<F, R>
439where
440 O: CostFunction<Param = P, Output = F> + Anneal<Param = P, Output = P, Float = F>,
441 P: Clone,
442 F: ArgminFloat,
443 R: Rng,
444{
445 const NAME: &'static str = "Simulated Annealing";
446 fn init(
447 &mut self,
448 problem: &mut Problem<O>,
449 mut state: IterState<P, (), (), (), (), F>,
450 ) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
451 let param = state.take_param().ok_or_else(argmin_error_closure!(
452 NotInitialized,
453 concat!(
454 "`SimulatedAnnealing` requires an initial parameter vector. ",
455 "Please provide an initial guess via `Executor`s `configure` method."
456 )
457 ))?;
458
459 let cost = state.get_cost();
460 let cost = if cost.is_infinite() {
461 problem.cost(¶m)?
462 } else {
463 cost
464 };
465
466 Ok((
467 state.param(param).cost(cost),
468 Some(kv!(
469 "initial_temperature" => self.init_temp;
470 "stall_iter_accepted_limit" => self.stall_iter_accepted_limit;
471 "stall_iter_best_limit" => self.stall_iter_best_limit;
472 "reanneal_fixed" => self.reanneal_fixed;
473 "reanneal_accepted" => self.reanneal_accepted;
474 "reanneal_best" => self.reanneal_best;
475 )),
476 ))
477 }
478
479 /// Perform one iteration of SA algorithm
480 fn next_iter(
481 &mut self,
482 problem: &mut Problem<O>,
483 mut state: IterState<P, (), (), (), (), F>,
484 ) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
485 // Careful: The order in here is *very* important, even if it may not seem so. Everything
486 // is linked to the iteration number, and getting things mixed up may lead to unexpected
487 // behavior.
488
489 let prev_param = state.take_param().ok_or_else(argmin_error_closure!(
490 PotentialBug,
491 "`SimulatedAnnealing`: Parameter vector in state not set."
492 ))?;
493 let prev_cost = state.get_cost();
494
495 // Make a move
496 let new_param = problem.anneal(&prev_param, self.cur_temp)?;
497
498 // Evaluate cost function with new parameter vector
499 let new_cost = problem.cost(&new_param)?;
500
501 // Acceptance function
502 //
503 // Decide whether new parameter vector should be accepted.
504 // If no, move on with old parameter vector.
505 //
506 // Any solution which satisfies `next_cost < prev_cost` will be accepted. Solutions worse
507 // than the previous one are accepted with a probability given as:
508 //
509 // `1 / (1 + exp((next_cost - prev_cost) / current_temperature))`,
510 //
511 // which will always be between 0 and 0.5.
512 let prob: f64 = self.rng.gen();
513 let prob = float!(prob);
514 let accepted = (new_cost < prev_cost)
515 || (float!(1.0) / (float!(1.0) + ((new_cost - prev_cost) / self.cur_temp).exp())
516 > prob);
517
518 let new_best_found = new_cost < state.best_cost;
519
520 // Update stall iter variables
521 self.update_stall_and_reanneal_iter(accepted, new_best_found);
522
523 let (r_fixed, r_accepted, r_best) = self.reanneal();
524
525 // Update temperature for next iteration.
526 self.temp_iter += 1;
527 // Actually not necessary as it does the same as `temp_iter`, but I'll leave it here for
528 // better readability.
529 self.reanneal_iter_fixed += 1;
530
531 self.update_temperature();
532
533 Ok((
534 if accepted {
535 state.param(new_param).cost(new_cost)
536 } else {
537 state.param(prev_param).cost(prev_cost)
538 },
539 Some(kv!(
540 "t" => self.cur_temp;
541 "new_be" => new_best_found;
542 "acc" => accepted;
543 "st_i_be" => self.stall_iter_best;
544 "st_i_ac" => self.stall_iter_accepted;
545 "ra_i_fi" => self.reanneal_iter_fixed;
546 "ra_i_be" => self.reanneal_iter_best;
547 "ra_i_ac" => self.reanneal_iter_accepted;
548 "ra_fi" => r_fixed;
549 "ra_be" => r_best;
550 "ra_ac" => r_accepted;
551 )),
552 ))
553 }
554
555 fn terminate(&mut self, _state: &IterState<P, (), (), (), (), F>) -> TerminationStatus {
556 if self.stall_iter_accepted > self.stall_iter_accepted_limit {
557 return TerminationStatus::Terminated(TerminationReason::SolverExit(
558 "AcceptedStallIterExceeded".to_string(),
559 ));
560 }
561 if self.stall_iter_best > self.stall_iter_best_limit {
562 return TerminationStatus::Terminated(TerminationReason::SolverExit(
563 "BestStallIterExceeded".to_string(),
564 ));
565 }
566 TerminationStatus::NotTerminated
567 }
568}
569
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError, State};
    use crate::test_trait_impl;
    use approx::assert_relative_eq;

    test_trait_impl!(sa, SimulatedAnnealing<f64, StdRng>);

    /// Checks that `solver` is in the state a constructor produces for an initial temperature
    /// of `100.0` and hands back the RNG so that callers can inspect it.
    ///
    /// The exhaustive destructuring makes this fail to compile when a field is added without
    /// being covered here.
    fn assert_freshly_constructed<R>(solver: SimulatedAnnealing<f64, R>) -> R {
        let SimulatedAnnealing {
            init_temp,
            temp_func,
            temp_iter,
            stall_iter_accepted,
            stall_iter_accepted_limit,
            stall_iter_best,
            stall_iter_best_limit,
            reanneal_fixed,
            reanneal_iter_fixed,
            reanneal_accepted,
            reanneal_iter_accepted,
            reanneal_best,
            reanneal_iter_best,
            cur_temp,
            rng,
        } = solver;

        assert_eq!(init_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());
        assert_eq!(temp_func, SATempFunc::TemperatureFast);
        assert_eq!(temp_iter, 0);
        assert_eq!(stall_iter_accepted, 0);
        assert_eq!(stall_iter_accepted_limit, u64::MAX);
        assert_eq!(stall_iter_best, 0);
        assert_eq!(stall_iter_best_limit, u64::MAX);
        assert_eq!(reanneal_fixed, u64::MAX);
        assert_eq!(reanneal_iter_fixed, 0);
        assert_eq!(reanneal_accepted, u64::MAX);
        assert_eq!(reanneal_iter_accepted, 0);
        assert_eq!(reanneal_best, u64::MAX);
        assert_eq!(reanneal_iter_best, 0);
        assert_eq!(cur_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());

        rng
    }

    #[test]
    fn test_new() {
        let solver: SimulatedAnnealing<f64, Xoshiro256PlusPlus> =
            SimulatedAnnealing::new(100.0).unwrap();
        let _rng = assert_freshly_constructed(solver);

        // Zero and negative temperatures must be rejected.
        for temp in [0.0, -1.0, -f64::EPSILON, -100.0] {
            let res = SimulatedAnnealing::new(temp);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SimulatedAnnealing`: Initial temperature must be > 0.\""
            );
        }
    }

    #[test]
    fn test_new_with_rng() {
        #[derive(Eq, PartialEq, Debug)]
        struct MyRng {}

        let solver: SimulatedAnnealing<f64, MyRng> =
            SimulatedAnnealing::new_with_rng(100.0, MyRng {}).unwrap();
        let rng = assert_freshly_constructed(solver);
        // important part
        assert_eq!(rng, MyRng {});

        // Zero and negative temperatures must be rejected.
        for temp in [0.0, -1.0, -f64::EPSILON, -100.0] {
            let res = SimulatedAnnealing::new_with_rng(temp, MyRng {});
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SimulatedAnnealing`: Initial temperature must be > 0.\""
            );
        }
    }

    #[test]
    fn test_with_temp_func() {
        let candidates = [
            SATempFunc::TemperatureFast,
            SATempFunc::Boltzmann,
            SATempFunc::Exponential(2.0),
        ];
        for func in candidates {
            let solver = SimulatedAnnealing::new(100.0f64)
                .unwrap()
                .with_temp_func(func);

            assert_eq!(solver.temp_func, func);
        }
    }

    #[test]
    fn test_with_stall_accepted() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let solver = SimulatedAnnealing::new(100.0f64)
                .unwrap()
                .with_stall_accepted(iter);

            assert_eq!(solver.stall_iter_accepted_limit, iter);
        }
    }

    #[test]
    fn test_with_stall_best() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let solver = SimulatedAnnealing::new(100.0f64)
                .unwrap()
                .with_stall_best(iter);

            assert_eq!(solver.stall_iter_best_limit, iter);
        }
    }

    #[test]
    fn test_with_reannealing_fixed() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let solver = SimulatedAnnealing::new(100.0f64)
                .unwrap()
                .with_reannealing_fixed(iter);

            assert_eq!(solver.reanneal_fixed, iter);
        }
    }

    #[test]
    fn test_with_reannealing_accepted() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let solver = SimulatedAnnealing::new(100.0f64)
                .unwrap()
                .with_reannealing_accepted(iter);

            assert_eq!(solver.reanneal_accepted, iter);
        }
    }

    #[test]
    fn test_with_reannealing_best() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let solver = SimulatedAnnealing::new(100.0f64)
                .unwrap()
                .with_reannealing_best(iter);

            assert_eq!(solver.reanneal_best, iter);
        }
    }

    #[test]
    fn test_update_temperature() {
        // With `temp_iter == 1` every schedule is evaluated at iteration number 2.
        let cases = [
            (SATempFunc::TemperatureFast, 100.0f64 / 2.0),
            (SATempFunc::Boltzmann, 100.0f64 / 2.0f64.ln()),
            (SATempFunc::Exponential(3.0), 100.0 * 3.0f64.powi(2)),
        ];
        for (func, expected_temp) in cases {
            let mut solver = SimulatedAnnealing::new(100.0f64)
                .unwrap()
                .with_temp_func(func);
            solver.temp_iter = 1;

            solver.update_temperature();

            assert_relative_eq!(solver.cur_temp, expected_temp, epsilon = f64::EPSILON);
        }
    }

    #[test]
    fn test_reanneal() {
        let mut template = SimulatedAnnealing::new(100.0f64).unwrap();

        template.reanneal_fixed = 10;
        template.reanneal_accepted = 20;
        template.reanneal_best = 30;
        template.temp_iter = 40;
        template.cur_temp = 50.0;

        // ((iter_fixed, iter_accepted, iter_best), expected return value of `reanneal`)
        let cases = [
            ((0, 0, 0), (false, false, false)),
            ((10, 0, 0), (true, false, false)),
            ((11, 0, 0), (true, false, false)),
            ((0, 20, 0), (false, true, false)),
            ((0, 21, 0), (false, true, false)),
            ((0, 0, 30), (false, false, true)),
            ((0, 0, 31), (false, false, true)),
            ((10, 20, 0), (true, true, false)),
            ((10, 0, 30), (true, false, true)),
            ((0, 20, 30), (false, true, true)),
            ((10, 20, 30), (true, true, true)),
        ];
        for ((iter_fixed, iter_accepted, iter_best), expected) in cases {
            let mut solver = template.clone();

            solver.reanneal_iter_fixed = iter_fixed;
            solver.reanneal_iter_accepted = iter_accepted;
            solver.reanneal_iter_best = iter_best;

            assert_eq!(solver.reanneal(), expected);

            let reannealed = expected.0 || expected.1 || expected.2;
            if reannealed {
                assert_eq!(solver.reanneal_iter_fixed, 0);
                assert_eq!(solver.reanneal_iter_accepted, 0);
                assert_eq!(solver.reanneal_iter_best, 0);
                assert_eq!(solver.temp_iter, 0);
                assert_eq!(
                    solver.cur_temp.to_ne_bytes(),
                    solver.init_temp.to_ne_bytes()
                );
            }
        }
    }

    #[test]
    fn test_update_stall_and_reanneal_iter() {
        let mut template = SimulatedAnnealing::new(100.0f64).unwrap();

        template.stall_iter_accepted = 10;
        template.reanneal_iter_accepted = 20;
        template.stall_iter_best = 30;
        template.reanneal_iter_best = 40;

        // ((accepted, new_best), expected counters afterwards)
        let cases = [
            ((false, false), (11, 21, 31, 41)),
            ((false, true), (11, 21, 0, 0)),
            ((true, false), (0, 0, 31, 41)),
            ((true, true), (0, 0, 0, 0)),
        ];
        for ((accepted, new_best), (stall_acc, reanneal_acc, stall_best, reanneal_best)) in cases {
            let mut solver = template.clone();

            solver.update_stall_and_reanneal_iter(accepted, new_best);

            assert_eq!(solver.stall_iter_accepted, stall_acc);
            assert_eq!(solver.reanneal_iter_accepted, reanneal_acc);
            assert_eq!(solver.stall_iter_best, stall_best);
            assert_eq!(solver.reanneal_iter_best, reanneal_best);
        }
    }

    #[test]
    fn test_init() {
        let param: Vec<f64> = vec![-1.0, 1.0];

        let stall_iter_accepted_limit = 10;
        let stall_iter_best_limit = 20;
        let reanneal_fixed = 30;
        let reanneal_accepted = 40;
        let reanneal_best = 50;

        let mut solver = SimulatedAnnealing::new(100.0f64)
            .unwrap()
            .with_stall_accepted(stall_iter_accepted_limit)
            .with_stall_best(stall_iter_best_limit)
            .with_reannealing_fixed(reanneal_fixed)
            .with_reannealing_accepted(reanneal_accepted)
            .with_reannealing_best(reanneal_best);

        // Forgot to initialize the parameter vector
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let problem = TestProblem::new();
        let res = solver.init(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`SimulatedAnnealing` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        // All good.
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new().param(param.clone());
        let problem = TestProblem::new();
        let (mut state_out, kv) = solver.init(&mut Problem::new(problem), state).unwrap();

        let kv_expected = kv!(
            "initial_temperature" => 100.0f64;
            "stall_iter_accepted_limit" => stall_iter_accepted_limit;
            "stall_iter_best_limit" => stall_iter_best_limit;
            "reanneal_fixed" => reanneal_fixed;
            "reanneal_accepted" => reanneal_accepted;
            "reanneal_best" => reanneal_best;
        );

        assert_eq!(kv.unwrap(), kv_expected);

        // The parameter vector must pass through `init` unchanged.
        let s_param = state_out.take_param().unwrap();

        for (actual, expected) in s_param.iter().zip(param.iter()) {
            assert_eq!(actual.to_ne_bytes(), expected.to_ne_bytes());
        }

        assert_eq!(state_out.get_cost().to_ne_bytes(), 1.0f64.to_ne_bytes());
    }
}