Skip to main content

ganesh/algorithms/gradient_free/
simulated_annealing.rs

1use crate::{
2    core::{utils::SampleFloat, Callbacks, Point, SimulatedAnnealingSummary},
3    error::{GaneshError, GaneshResult},
4    traits::{
5        Algorithm, GenericCostFunction, ProgressStatus, Status, StatusMessage, SupportsTransform,
6        Terminator, Transform,
7    },
8    Float,
9};
10use serde::{Deserialize, Serialize};
11use std::ops::ControlFlow;
12
13/// A temperature-activated terminator for [`SimulatedAnnealing`].
14#[derive(Copy, Clone)]
15pub struct SimulatedAnnealingTerminator {
16    /// The minimum temperature for the simulated annealing algorithm.
17    pub min_temperature: Float,
18}
19impl Default for SimulatedAnnealingTerminator {
20    fn default() -> Self {
21        Self {
22            min_temperature: 1e-3,
23        }
24    }
25}
26impl<P, U, E, I>
27    Terminator<SimulatedAnnealing, P, SimulatedAnnealingStatus<I>, U, E, SimulatedAnnealingConfig>
28    for SimulatedAnnealingTerminator
29where
30    P: SimulatedAnnealingGenerator<U, E, Input = I>,
31    I: Serialize + for<'a> Deserialize<'a> + Clone + Default,
32{
33    fn check_for_termination(
34        &mut self,
35        _current_step: usize,
36        _algorithm: &mut SimulatedAnnealing,
37        _problem: &P,
38        status: &mut SimulatedAnnealingStatus<I>,
39        _args: &U,
40        _config: &SimulatedAnnealingConfig,
41    ) -> ControlFlow<()> {
42        if status.temperature < self.min_temperature {
43            return ControlFlow::Break(());
44        }
45        ControlFlow::Continue(())
46    }
47}
48
/// A trait for generating new points in the simulated annealing algorithm.
///
/// Implementors are cost functions (via the [`GenericCostFunction`] supertrait) that can also
/// supply a starting point and propose candidate points. [`SimulatedAnnealing`] calls
/// [`initial`](Self::initial) once while initializing and [`generate`](Self::generate) once
/// per step, handing over the transform stored in [`SimulatedAnnealingConfig`].
pub trait SimulatedAnnealingGenerator<U, E>: GenericCostFunction<U, E> {
    /// Returns the initial state of the algorithm.
    ///
    /// * `transform` - the optional transform taken from the algorithm configuration.
    /// * `status` - the algorithm status; [`SimulatedAnnealing`] calls this method before it
    ///   writes the temperature or any point, so the contents are whatever the caller
    ///   supplied (presumably default values).
    /// * `args` - the user arguments passed through by the algorithm.
    fn initial(
        &self,
        transform: &Option<Box<dyn Transform>>,
        status: &mut SimulatedAnnealingStatus<Self::Input>,
        args: &U,
    ) -> Self::Input;
    /// Generates a new state based on the current state, cost function and the status.
    ///
    /// * `transform` - the optional transform taken from the algorithm configuration.
    /// * `status` - the algorithm status; [`SimulatedAnnealing`] calls this method before it
    ///   cools the temperature for the step, so `status.temperature` is the value left by the
    ///   previous step and `status.current` is the most recently accepted point.
    /// * `args` - the user arguments passed through by the algorithm.
    fn generate(
        &self,
        transform: &Option<Box<dyn Transform>>,
        status: &mut SimulatedAnnealingStatus<Self::Input>,
        args: &U,
    ) -> Self::Input;
}
66
/// The internal configuration struct for the [`SimulatedAnnealing`] algorithm.
///
/// Holds the annealing schedule parameters together with an optional [`Transform`] that is
/// handed to the problem's point generator.
pub struct SimulatedAnnealingConfig {
    /// Optional transform passed to [`SimulatedAnnealingGenerator::initial`] and
    /// [`SimulatedAnnealingGenerator::generate`]; reachable through [`SupportsTransform`].
    transform: Option<Box<dyn Transform>>,
    /// The initial temperature for the simulated annealing algorithm.
    ///
    /// Copied into the status temperature when the algorithm is initialized.
    pub initial_temperature: Float,
    /// The cooling rate for the simulated annealing algorithm.
    ///
    /// The status temperature is multiplied by this factor on every step.
    pub cooling_rate: Float,
}
75impl Default for SimulatedAnnealingConfig {
76    fn default() -> Self {
77        Self {
78            transform: None,
79            initial_temperature: 1.0,
80            cooling_rate: 0.999,
81        }
82    }
83}
84impl SimulatedAnnealingConfig {
85    /// Create a new [`SimulatedAnnealingConfig`] with the given parameters.
86    ///
87    /// # Errors
88    ///
89    /// Returns a configuration error if `initial_temperature <= 0` or `cooling_rate` is not in
90    /// the interval `(0, 1)`.
91    pub fn new(initial_temperature: Float, cooling_rate: Float) -> GaneshResult<Self> {
92        if initial_temperature <= 0.0 {
93            return Err(GaneshError::ConfigError(
94                "Initial temperature must be greater than 0".to_string(),
95            ));
96        }
97        if cooling_rate <= 0.0 || cooling_rate >= 1.0 {
98            return Err(GaneshError::ConfigError(
99                "Cooling rate must be in (0, 1)".to_string(),
100            ));
101        }
102        Ok(Self {
103            transform: None,
104            initial_temperature,
105            cooling_rate,
106        })
107    }
108}
109impl SupportsTransform for SimulatedAnnealingConfig {
110    fn get_transform_mut(&mut self) -> &mut Option<Box<dyn Transform>> {
111        &mut self.transform
112    }
113}
114
/// A struct for the status of the simulated annealing algorithm.
///
/// The status is the mutable state shared between [`SimulatedAnnealing`], the point generator
/// and the terminators while a run is in progress.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SimulatedAnnealingStatus<I> {
    /// The current temperature of the simulated annealing algorithm.
    ///
    /// Set from the configured initial temperature at initialization and multiplied by the
    /// configured cooling rate on every step.
    pub temperature: Float,
    /// The initial point in the simulated annealing algorithm.
    ///
    /// This is the generator's starting point together with its cost.
    pub initial: Point<I>,
    /// The best point in the simulated annealing algorithm.
    ///
    /// Replaced whenever a step produces a cost strictly lower than the stored best cost.
    pub best: Point<I>,
    /// The current point in the simulated annealing algorithm.
    ///
    /// This is the most recently accepted candidate, which need not be the best one.
    pub current: Point<I>,
    /// The message to be displayed at the end of the algorithm.
    pub message: StatusMessage,
    /// The number of function evaluations.
    pub n_f_evals: usize,
}
131
132impl<I> Status for SimulatedAnnealingStatus<I>
133where
134    I: Serialize + for<'a> Deserialize<'a> + Clone + Default,
135{
136    fn reset(&mut self) {
137        self.temperature = Default::default();
138        self.best = Default::default();
139        self.current = Default::default();
140        self.message = Default::default();
141        self.n_f_evals = Default::default();
142    }
143
144    fn message(&self) -> &StatusMessage {
145        &self.message
146    }
147
148    fn set_message(&mut self) -> &mut StatusMessage {
149        &mut self.message
150    }
151}
152
153impl<I> ProgressStatus for SimulatedAnnealingStatus<I>
154where
155    I: Serialize + for<'a> Deserialize<'a> + Clone + Default,
156{
157    fn write_progress(&self, out: &mut String) -> std::fmt::Result {
158        use std::fmt::Write;
159        write!(
160            out,
161            "status={} temperature={} best_fx={} current_fx={}",
162            self.message,
163            self.temperature,
164            self.best.fx.unwrap_or(Float::NAN),
165            self.current.fx.unwrap_or(Float::NAN)
166        )
167    }
168}
169
/// A struct for the simulated annealing algorithm.
///
/// The only state kept here is the random number generator; temperatures and points live in
/// [`SimulatedAnnealingStatus`].
pub struct SimulatedAnnealing {
    /// Random number generator sampled in the acceptance test of every step that does not
    /// produce a new best point.
    rng: fastrand::Rng,
}
174
175impl Default for SimulatedAnnealing {
176    fn default() -> Self {
177        Self::new(Some(0))
178    }
179}
180
181impl SimulatedAnnealing {
182    /// Creates a new instance of the simulated annealing algorithm.
183    pub fn new(seed: Option<u64>) -> Self {
184        Self {
185            rng: seed.map_or_else(fastrand::Rng::new, fastrand::Rng::with_seed),
186        }
187    }
188}
189
190impl<P, U, E, I> Algorithm<P, SimulatedAnnealingStatus<I>, U, E> for SimulatedAnnealing
191where
192    P: SimulatedAnnealingGenerator<U, E, Input = I>,
193    I: Serialize + for<'a> Deserialize<'a> + Clone + Default,
194{
195    type Summary = SimulatedAnnealingSummary<I>;
196    type Config = SimulatedAnnealingConfig;
197    type Init = ();
198
199    #[allow(clippy::expect_used)]
200    fn initialize(
201        &mut self,
202        problem: &P,
203        status: &mut SimulatedAnnealingStatus<I>,
204        args: &U,
205        _init: &Self::Init,
206        config: &Self::Config,
207    ) -> Result<(), E> {
208        let x0 = problem.initial(&config.transform, status, args);
209        let fx0 = problem.evaluate_generic(&x0, args)?;
210        status.temperature = config.initial_temperature;
211        status.current = Point {
212            x: x0,
213            fx: Some(fx0),
214        };
215        status.initial = status.current.clone();
216        status.best = status.current.clone();
217        status.set_message().initialize();
218        Ok(())
219    }
220
221    fn step(
222        &mut self,
223        _current_step: usize,
224        problem: &P,
225        status: &mut SimulatedAnnealingStatus<I>,
226        args: &U,
227        config: &Self::Config,
228    ) -> Result<(), E> {
229        let x = problem.generate(&config.transform, status, args);
230        let fx = problem.evaluate_generic(&x, args)?;
231        status.n_f_evals += 1;
232
233        status.temperature *= config.cooling_rate;
234
235        if fx < status.best.fx_checked() {
236            status.current = Point { x, fx: Some(fx) };
237            status.best = status.current.clone();
238            return Ok(());
239        }
240
241        let d_fx = fx - status.current.fx_checked();
242        let acceptance_probability = (-d_fx / status.temperature).exp();
243
244        if acceptance_probability > self.rng.float() {
245            status.current = Point { x, fx: Some(fx) };
246        }
247        Ok(())
248    }
249
250    fn summarize(
251        &self,
252        _current_step: usize,
253        _problem: &P,
254        status: &SimulatedAnnealingStatus<I>,
255        _args: &U,
256        _init: &Self::Init,
257        _config: &Self::Config,
258    ) -> Result<Self::Summary, E> {
259        Ok(SimulatedAnnealingSummary {
260            bounds: None,
261            message: status.message.clone(),
262            x0: status.initial.x.clone(),
263            x: status.best.x.clone(),
264            fx: status.best.fx_checked(),
265            n_f_evals: status.n_f_evals,
266            n_g_evals: 0,
267            n_h_evals: 0,
268        })
269    }
270
271    fn default_callbacks() -> Callbacks<Self, P, SimulatedAnnealingStatus<I>, U, E, Self::Config>
272    where
273        Self: Sized,
274    {
275        Callbacks::empty().with_terminator(SimulatedAnnealingTerminator::default())
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::{Bounds, Callbacks, MaxSteps},
        test_functions::Rosenbrock,
        traits::cost_function::GenericGradient,
        DVector,
    };
    use approx::assert_relative_eq;
    use nalgebra::DMatrix;
    use std::{cell::RefCell, collections::VecDeque, convert::Infallible, fmt::Debug};

    /// Test problem that wraps a differentiable cost function together with a fixed starting
    /// point and proposes candidates by stepping against the gradient.
    pub struct GradientAnnealingProblem<U, E>(
        Box<dyn GenericGradient<U, E, Input = DVector<Float>>>,
        DVector<Float>,
    );
    impl<U, E> GradientAnnealingProblem<U, E> {
        /// Boxes `problem` and stores `x0` as the starting point.
        pub fn new<P>(problem: P, x0: &[Float]) -> Self
        where
            P: GenericGradient<U, E, Input = DVector<Float>> + 'static,
        {
            Self(Box::new(problem), DVector::from_row_slice(x0))
        }
    }
    impl<U, E> GenericCostFunction<U, E> for GradientAnnealingProblem<U, E> {
        type Input = DVector<Float>;

        fn evaluate_generic(&self, x: &Self::Input, args: &U) -> Result<Float, E> {
            self.0.evaluate_generic(x, args)
        }
    }
    impl<U, E> GenericGradient<U, E> for GradientAnnealingProblem<U, E> {
        fn gradient_generic(&self, x: &Self::Input, args: &U) -> Result<DVector<Float>, E> {
            self.0.gradient_generic(x, args)
        }

        fn hessian_generic(&self, x: &Self::Input, args: &U) -> Result<DMatrix<Float>, E> {
            self.0.hessian_generic(x, args)
        }
    }
    impl<U, E: Debug> SimulatedAnnealingGenerator<U, E> for GradientAnnealingProblem<U, E>
    where
        Self: GenericGradient<U, E, Input = DVector<Float>>,
    {
        /// Proposes a point one gradient step away from the current point, with a step size
        /// proportional to the temperature. The step is taken in the transform's internal
        /// coordinates and mapped back afterwards.
        fn generate(
            &self,
            transform: &Option<Box<dyn Transform>>,
            status: &mut SimulatedAnnealingStatus<Self::Input>,
            args: &U,
        ) -> Self::Input {
            let x_int = transform.to_owned_internal(&status.current.x);
            #[allow(clippy::expect_used)]
            let g_ext = self
                .gradient_generic(&status.current.x, args)
                .expect("This should never fail");
            let g_int = transform.pullback_gradient(&x_int, &g_ext);
            let x_int_new = x_int - &(status.temperature * 1e-4 * g_int);
            transform.to_owned_external(&x_int_new)
        }

        fn initial(
            &self,
            _transform: &Option<Box<dyn Transform>>,
            _status: &mut SimulatedAnnealingStatus<Self::Input>,
            _args: &U,
        ) -> Self::Input {
            self.1.clone()
        }
    }

    /// Runs the full algorithm on the bounded two-dimensional Rosenbrock function and checks
    /// that the best cost ends up within `0.5` of the minimum value `0`.
    #[test]
    fn test_simulated_annealing() {
        let mut solver = SimulatedAnnealing::default();
        let problem = GradientAnnealingProblem::new(Rosenbrock { n: 2 }, &[0.0, 0.0]);
        let result = solver
            .process(
                &problem,
                &(),
                (),
                SimulatedAnnealingConfig::new(1.0, 0.999)
                    .unwrap()
                    .with_transform(&Bounds::from([(-5.0, 5.0), (-5.0, 5.0)])),
                SimulatedAnnealing::default_callbacks(),
            )
            .unwrap();
        assert_relative_eq!(result.fx, 0.0, epsilon = 0.5);
    }

    /// Test problem whose cost is the first coordinate and whose proposals are scripted:
    /// each call to `generate` hands out the next queued point.
    struct SequenceAnnealingProblem {
        initial: DVector<Float>,
        proposals: RefCell<VecDeque<DVector<Float>>>,
    }
    impl SequenceAnnealingProblem {
        /// Stores the starting point and queues `proposals` in the order given.
        fn new(initial: &[Float], proposals: Vec<&[Float]>) -> Self {
            Self {
                initial: DVector::from_row_slice(initial),
                proposals: RefCell::new(
                    proposals
                        .into_iter()
                        .map(DVector::from_row_slice)
                        .collect::<VecDeque<_>>(),
                ),
            }
        }
    }
    impl GenericCostFunction<(), Infallible> for SequenceAnnealingProblem {
        type Input = DVector<Float>;

        fn evaluate_generic(&self, x: &Self::Input, _: &()) -> Result<Float, Infallible> {
            Ok(x[0])
        }
    }
    impl SimulatedAnnealingGenerator<(), Infallible> for SequenceAnnealingProblem {
        fn initial(
            &self,
            _: &Option<Box<dyn Transform>>,
            _: &mut SimulatedAnnealingStatus<Self::Input>,
            _: &(),
        ) -> Self::Input {
            self.initial.clone()
        }

        // Running out of scripted proposals means the test itself is wrong, so failing
        // loudly with a descriptive message is the intended behaviour.
        #[allow(clippy::expect_used)]
        fn generate(
            &self,
            _: &Option<Box<dyn Transform>>,
            _: &mut SimulatedAnnealingStatus<Self::Input>,
            _: &(),
        ) -> Self::Input {
            self.proposals
                .borrow_mut()
                .pop_front()
                .expect("the scripted proposal sequence is exhausted")
        }
    }

    /// The candidate (cost 1) improves on the current point (cost 2) but not on the best
    /// point (cost 0). At a temperature of `0.01 * 0.9` the acceptance value
    /// `exp(1 / 0.009)` is far above one, so the move is accepted whatever number is drawn
    /// (presumably from `[0, 1)`), and the best point must stay untouched.
    #[test]
    fn accepts_improving_proposal_even_if_not_new_best() {
        let mut solver = SimulatedAnnealing::default();
        let problem = SequenceAnnealingProblem::new(&[2.0], vec![&[1.0]]);
        let config = SimulatedAnnealingConfig::new(0.01, 0.9).unwrap();
        let mut status = SimulatedAnnealingStatus::default();

        solver
            .initialize(&problem, &mut status, &(), &(), &config)
            .unwrap();
        status.best = Point {
            x: DVector::from_row_slice(&[0.0]),
            fx: Some(0.0),
        };
        status.current = Point {
            x: DVector::from_row_slice(&[2.0]),
            fx: Some(2.0),
        };

        solver.step(0, &problem, &mut status, &(), &config).unwrap();

        assert_relative_eq!(status.current.x[0], 1.0);
        assert_relative_eq!(status.current.fx_checked(), 1.0);
        assert_relative_eq!(status.best.x[0], 0.0);
        assert_relative_eq!(status.best.fx_checked(), 0.0);
    }

    /// The candidate (cost 1) is worse than the current point (cost 0). At a temperature of
    /// `1e-6 * 0.9` the acceptance value `exp(-1 / 9e-7)` underflows to zero, which cannot
    /// exceed a non-negative draw, so both the current and the best point must be unchanged.
    #[test]
    fn rejected_proposal_does_not_advance_current() {
        let mut solver = SimulatedAnnealing::default();
        let problem = SequenceAnnealingProblem::new(&[0.0], vec![&[1.0]]);
        let config = SimulatedAnnealingConfig::new(1e-6, 0.9).unwrap();
        let mut status = SimulatedAnnealingStatus::default();

        solver
            .initialize(&problem, &mut status, &(), &(), &config)
            .unwrap();
        let current_before = status.current.clone();
        let best_before = status.best.clone();

        solver.step(0, &problem, &mut status, &(), &config).unwrap();

        assert_eq!(status.current.x, current_before.x);
        assert_eq!(status.current.fx, current_before.fx);
        assert_eq!(status.best.x, best_before.x);
        assert_eq!(status.best.fx, best_before.fx);
    }

    /// A run capped by `MaxSteps(2)` must report at least one cost evaluation and carry the
    /// step-limit message in its summary.
    #[test]
    fn summary_reports_nonzero_evals_and_terminal_message() {
        let mut solver = SimulatedAnnealing::default();
        let problem = GradientAnnealingProblem::new(Rosenbrock { n: 2 }, &[0.0, 0.0]);
        let result = solver
            .process(
                &problem,
                &(),
                (),
                SimulatedAnnealingConfig::new(1.0, 0.999).unwrap(),
                Callbacks::empty().with_terminator(MaxSteps(2)),
            )
            .unwrap();

        assert!(result.n_f_evals > 0);
        assert!(result
            .message
            .to_string()
            .contains("Maximum number of steps reached"));
    }

    /// The constructor must refuse non-positive temperatures and cooling rates outside the
    /// open interval `(0, 1)`, and accept values inside the valid ranges.
    #[test]
    fn config_rejects_out_of_range_parameters() {
        assert!(SimulatedAnnealingConfig::new(0.0, 0.5).is_err());
        assert!(SimulatedAnnealingConfig::new(-1.0, 0.5).is_err());
        assert!(SimulatedAnnealingConfig::new(1.0, 0.0).is_err());
        assert!(SimulatedAnnealingConfig::new(1.0, 1.0).is_err());
        assert!(SimulatedAnnealingConfig::new(1.0, 0.5).is_ok());
    }
}