1use crate::{
2 core::{utils::SampleFloat, Callbacks, Point, SimulatedAnnealingSummary},
3 error::{GaneshError, GaneshResult},
4 traits::{
5 Algorithm, GenericCostFunction, ProgressStatus, Status, StatusMessage, SupportsTransform,
6 Terminator, Transform,
7 },
8 Float,
9};
10use serde::{Deserialize, Serialize};
11use std::ops::ControlFlow;
12
13#[derive(Copy, Clone)]
15pub struct SimulatedAnnealingTerminator {
16 pub min_temperature: Float,
18}
19impl Default for SimulatedAnnealingTerminator {
20 fn default() -> Self {
21 Self {
22 min_temperature: 1e-3,
23 }
24 }
25}
26impl<P, U, E, I>
27 Terminator<SimulatedAnnealing, P, SimulatedAnnealingStatus<I>, U, E, SimulatedAnnealingConfig>
28 for SimulatedAnnealingTerminator
29where
30 P: SimulatedAnnealingGenerator<U, E, Input = I>,
31 I: Serialize + for<'a> Deserialize<'a> + Clone + Default,
32{
33 fn check_for_termination(
34 &mut self,
35 _current_step: usize,
36 _algorithm: &mut SimulatedAnnealing,
37 _problem: &P,
38 status: &mut SimulatedAnnealingStatus<I>,
39 _args: &U,
40 _config: &SimulatedAnnealingConfig,
41 ) -> ControlFlow<()> {
42 if status.temperature < self.min_temperature {
43 return ControlFlow::Break(());
44 }
45 ControlFlow::Continue(())
46 }
47}
48
49pub trait SimulatedAnnealingGenerator<U, E>: GenericCostFunction<U, E> {
51 fn initial(
53 &self,
54 transform: &Option<Box<dyn Transform>>,
55 status: &mut SimulatedAnnealingStatus<Self::Input>,
56 args: &U,
57 ) -> Self::Input;
58 fn generate(
60 &self,
61 transform: &Option<Box<dyn Transform>>,
62 status: &mut SimulatedAnnealingStatus<Self::Input>,
63 args: &U,
64 ) -> Self::Input;
65}
66
67pub struct SimulatedAnnealingConfig {
69 transform: Option<Box<dyn Transform>>,
70 pub initial_temperature: Float,
72 pub cooling_rate: Float,
74}
75impl Default for SimulatedAnnealingConfig {
76 fn default() -> Self {
77 Self {
78 transform: None,
79 initial_temperature: 1.0,
80 cooling_rate: 0.999,
81 }
82 }
83}
84impl SimulatedAnnealingConfig {
85 pub fn new(initial_temperature: Float, cooling_rate: Float) -> GaneshResult<Self> {
92 if initial_temperature <= 0.0 {
93 return Err(GaneshError::ConfigError(
94 "Initial temperature must be greater than 0".to_string(),
95 ));
96 }
97 if cooling_rate <= 0.0 || cooling_rate >= 1.0 {
98 return Err(GaneshError::ConfigError(
99 "Cooling rate must be in (0, 1)".to_string(),
100 ));
101 }
102 Ok(Self {
103 transform: None,
104 initial_temperature,
105 cooling_rate,
106 })
107 }
108}
109impl SupportsTransform for SimulatedAnnealingConfig {
110 fn get_transform_mut(&mut self) -> &mut Option<Box<dyn Transform>> {
111 &mut self.transform
112 }
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize, Default)]
117pub struct SimulatedAnnealingStatus<I> {
118 pub temperature: Float,
120 pub initial: Point<I>,
122 pub best: Point<I>,
124 pub current: Point<I>,
126 pub message: StatusMessage,
128 pub n_f_evals: usize,
130}
131
132impl<I> Status for SimulatedAnnealingStatus<I>
133where
134 I: Serialize + for<'a> Deserialize<'a> + Clone + Default,
135{
136 fn reset(&mut self) {
137 self.temperature = Default::default();
138 self.best = Default::default();
139 self.current = Default::default();
140 self.message = Default::default();
141 self.n_f_evals = Default::default();
142 }
143
144 fn message(&self) -> &StatusMessage {
145 &self.message
146 }
147
148 fn set_message(&mut self) -> &mut StatusMessage {
149 &mut self.message
150 }
151}
152
153impl<I> ProgressStatus for SimulatedAnnealingStatus<I>
154where
155 I: Serialize + for<'a> Deserialize<'a> + Clone + Default,
156{
157 fn write_progress(&self, out: &mut String) -> std::fmt::Result {
158 use std::fmt::Write;
159 write!(
160 out,
161 "status={} temperature={} best_fx={} current_fx={}",
162 self.message,
163 self.temperature,
164 self.best.fx.unwrap_or(Float::NAN),
165 self.current.fx.unwrap_or(Float::NAN)
166 )
167 }
168}
169
170pub struct SimulatedAnnealing {
172 rng: fastrand::Rng,
173}
174
175impl Default for SimulatedAnnealing {
176 fn default() -> Self {
177 Self::new(Some(0))
178 }
179}
180
181impl SimulatedAnnealing {
182 pub fn new(seed: Option<u64>) -> Self {
184 Self {
185 rng: seed.map_or_else(fastrand::Rng::new, fastrand::Rng::with_seed),
186 }
187 }
188}
189
190impl<P, U, E, I> Algorithm<P, SimulatedAnnealingStatus<I>, U, E> for SimulatedAnnealing
191where
192 P: SimulatedAnnealingGenerator<U, E, Input = I>,
193 I: Serialize + for<'a> Deserialize<'a> + Clone + Default,
194{
195 type Summary = SimulatedAnnealingSummary<I>;
196 type Config = SimulatedAnnealingConfig;
197 type Init = ();
198
199 #[allow(clippy::expect_used)]
200 fn initialize(
201 &mut self,
202 problem: &P,
203 status: &mut SimulatedAnnealingStatus<I>,
204 args: &U,
205 _init: &Self::Init,
206 config: &Self::Config,
207 ) -> Result<(), E> {
208 let x0 = problem.initial(&config.transform, status, args);
209 let fx0 = problem.evaluate_generic(&x0, args)?;
210 status.temperature = config.initial_temperature;
211 status.current = Point {
212 x: x0,
213 fx: Some(fx0),
214 };
215 status.initial = status.current.clone();
216 status.best = status.current.clone();
217 status.set_message().initialize();
218 Ok(())
219 }
220
221 fn step(
222 &mut self,
223 _current_step: usize,
224 problem: &P,
225 status: &mut SimulatedAnnealingStatus<I>,
226 args: &U,
227 config: &Self::Config,
228 ) -> Result<(), E> {
229 let x = problem.generate(&config.transform, status, args);
230 let fx = problem.evaluate_generic(&x, args)?;
231 status.n_f_evals += 1;
232
233 status.temperature *= config.cooling_rate;
234
235 if fx < status.best.fx_checked() {
236 status.current = Point { x, fx: Some(fx) };
237 status.best = status.current.clone();
238 return Ok(());
239 }
240
241 let d_fx = fx - status.current.fx_checked();
242 let acceptance_probability = (-d_fx / status.temperature).exp();
243
244 if acceptance_probability > self.rng.float() {
245 status.current = Point { x, fx: Some(fx) };
246 }
247 Ok(())
248 }
249
250 fn summarize(
251 &self,
252 _current_step: usize,
253 _problem: &P,
254 status: &SimulatedAnnealingStatus<I>,
255 _args: &U,
256 _init: &Self::Init,
257 _config: &Self::Config,
258 ) -> Result<Self::Summary, E> {
259 Ok(SimulatedAnnealingSummary {
260 bounds: None,
261 message: status.message.clone(),
262 x0: status.initial.x.clone(),
263 x: status.best.x.clone(),
264 fx: status.best.fx_checked(),
265 n_f_evals: status.n_f_evals,
266 n_g_evals: 0,
267 n_h_evals: 0,
268 })
269 }
270
271 fn default_callbacks() -> Callbacks<Self, P, SimulatedAnnealingStatus<I>, U, E, Self::Config>
272 where
273 Self: Sized,
274 {
275 Callbacks::empty().with_terminator(SimulatedAnnealingTerminator::default())
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::{Bounds, Callbacks, MaxSteps},
        test_functions::Rosenbrock,
        traits::cost_function::GenericGradient,
        DVector,
    };
    use approx::assert_relative_eq;
    use nalgebra::DMatrix;
    use std::{cell::RefCell, convert::Infallible, fmt::Debug};

    /// Adapter that turns a gradient-capable cost function into an annealing generator.
    ///
    /// Each proposal is a small gradient-descent step, scaled by the current temperature and
    /// taken in the transform's internal coordinates.
    pub struct GradientAnnealingProblem<U, E> {
        /// Wrapped cost function. Evaluation, gradient and Hessian are delegated to it.
        inner: Box<dyn GenericGradient<U, E, Input = DVector<Float>>>,
        /// Starting point returned by `initial`.
        start: DVector<Float>,
    }

    impl<U, E> GradientAnnealingProblem<U, E> {
        pub fn new<P>(problem: P, x0: &[Float]) -> Self
        where
            P: GenericGradient<U, E, Input = DVector<Float>> + 'static,
        {
            let start = DVector::from_row_slice(x0);
            Self {
                inner: Box::new(problem),
                start,
            }
        }
    }

    impl<U, E> GenericCostFunction<U, E> for GradientAnnealingProblem<U, E> {
        type Input = DVector<Float>;

        fn evaluate_generic(&self, x: &Self::Input, args: &U) -> Result<Float, E> {
            self.inner.evaluate_generic(x, args)
        }
    }

    impl<U, E> GenericGradient<U, E> for GradientAnnealingProblem<U, E> {
        fn gradient_generic(&self, x: &Self::Input, args: &U) -> Result<DVector<Float>, E> {
            self.inner.gradient_generic(x, args)
        }

        fn hessian_generic(&self, x: &Self::Input, args: &U) -> Result<DMatrix<Float>, E> {
            self.inner.hessian_generic(x, args)
        }
    }

    impl<U, E: Debug> SimulatedAnnealingGenerator<U, E> for GradientAnnealingProblem<U, E>
    where
        Self: GenericGradient<U, E, Input = DVector<Float>>,
    {
        fn initial(
            &self,
            _transform: &Option<Box<dyn Transform>>,
            _status: &mut SimulatedAnnealingStatus<Self::Input>,
            _args: &U,
        ) -> Self::Input {
            self.start.clone()
        }

        fn generate(
            &self,
            transform: &Option<Box<dyn Transform>>,
            status: &mut SimulatedAnnealingStatus<Self::Input>,
            args: &U,
        ) -> Self::Input {
            let internal = transform.to_owned_internal(&status.current.x);
            // Test-only helper: a failing gradient indicates a broken test setup.
            #[allow(clippy::expect_used)]
            let gradient_external = self
                .gradient_generic(&status.current.x, args)
                .expect("This should never fail");
            let gradient_internal = transform.pullback_gradient(&internal, &gradient_external);
            let step = status.temperature * 1e-4 * gradient_internal;
            let moved = internal - &step;
            transform.to_owned_external(&moved)
        }
    }

    #[test]
    fn test_simulated_annealing() {
        let problem = GradientAnnealingProblem::new(Rosenbrock { n: 2 }, &[0.0, 0.0]);
        let config = SimulatedAnnealingConfig::new(1.0, 0.999)
            .unwrap()
            .with_transform(&Bounds::from([(-5.0, 5.0), (-5.0, 5.0)]));
        let mut solver = SimulatedAnnealing::default();
        let summary = solver
            .process(
                &problem,
                &(),
                (),
                config,
                SimulatedAnnealing::default_callbacks(),
            )
            .unwrap();
        assert_relative_eq!(summary.fx, 0.0, epsilon = 0.5);
    }

    /// Cost function `f(x) = x[0]` whose generator replays a fixed list of proposals in order,
    /// so acceptance decisions can be tested deterministically.
    struct SequenceAnnealingProblem {
        initial: DVector<Float>,
        proposals: RefCell<Vec<DVector<Float>>>,
    }

    impl SequenceAnnealingProblem {
        fn new(initial: &[Float], proposals: Vec<&[Float]>) -> Self {
            let queued: Vec<DVector<Float>> = proposals
                .into_iter()
                .map(DVector::from_row_slice)
                .collect();
            Self {
                initial: DVector::from_row_slice(initial),
                proposals: RefCell::new(queued),
            }
        }
    }

    impl GenericCostFunction<(), Infallible> for SequenceAnnealingProblem {
        type Input = DVector<Float>;

        fn evaluate_generic(&self, x: &Self::Input, _: &()) -> Result<Float, Infallible> {
            Ok(x[0])
        }
    }

    impl SimulatedAnnealingGenerator<(), Infallible> for SequenceAnnealingProblem {
        fn initial(
            &self,
            _: &Option<Box<dyn Transform>>,
            _: &mut SimulatedAnnealingStatus<Self::Input>,
            _: &(),
        ) -> Self::Input {
            self.initial.clone()
        }

        fn generate(
            &self,
            _: &Option<Box<dyn Transform>>,
            _: &mut SimulatedAnnealingStatus<Self::Input>,
            _: &(),
        ) -> Self::Input {
            // Pops the next recorded proposal; panics if the test runs out of proposals.
            self.proposals.borrow_mut().remove(0)
        }
    }

    #[test]
    fn accepts_improving_proposal_even_if_not_new_best() {
        // Current point 2.0 with a forced best of 0.0, then a proposal of 1.0: it improves on
        // the current point but not on the best one.
        let problem = SequenceAnnealingProblem::new(&[2.0], vec![&[1.0]]);
        let config = SimulatedAnnealingConfig::new(0.01, 0.9).unwrap();
        let mut solver = SimulatedAnnealing::default();
        let mut status: SimulatedAnnealingStatus<DVector<Float>> = Default::default();

        solver
            .initialize(&problem, &mut status, &(), &(), &config)
            .unwrap();
        status.best = Point {
            x: DVector::from_row_slice(&[0.0]),
            fx: Some(0.0),
        };
        status.current = Point {
            x: DVector::from_row_slice(&[2.0]),
            fx: Some(2.0),
        };

        solver.step(0, &problem, &mut status, &(), &config).unwrap();

        assert_relative_eq!(status.current.x[0], 1.0);
        assert_relative_eq!(status.current.fx_checked(), 1.0);
        assert_relative_eq!(status.best.x[0], 0.0);
        assert_relative_eq!(status.best.fx_checked(), 0.0);
    }

    #[test]
    fn rejected_proposal_does_not_advance_current() {
        // At a tiny temperature the uphill move 0.0 -> 1.0 has an acceptance probability that
        // underflows to zero, so it must be rejected.
        let problem = SequenceAnnealingProblem::new(&[0.0], vec![&[1.0]]);
        let config = SimulatedAnnealingConfig::new(1e-6, 0.9).unwrap();
        let mut solver = SimulatedAnnealing::default();
        let mut status: SimulatedAnnealingStatus<DVector<Float>> = Default::default();

        solver
            .initialize(&problem, &mut status, &(), &(), &config)
            .unwrap();
        let snapshot = status.clone();

        solver.step(0, &problem, &mut status, &(), &config).unwrap();

        assert_eq!(status.current.x, snapshot.current.x);
        assert_eq!(status.current.fx, snapshot.current.fx);
        assert_eq!(status.best.x, snapshot.best.x);
        assert_eq!(status.best.fx, snapshot.best.fx);
    }

    #[test]
    fn summary_reports_nonzero_evals_and_terminal_message() {
        let problem = GradientAnnealingProblem::new(Rosenbrock { n: 2 }, &[0.0, 0.0]);
        let config = SimulatedAnnealingConfig::new(1.0, 0.999).unwrap();
        let mut solver = SimulatedAnnealing::default();
        let summary = solver
            .process(
                &problem,
                &(),
                (),
                config,
                Callbacks::empty().with_terminator(MaxSteps(2)),
            )
            .unwrap();

        assert!(summary.n_f_evals > 0);
        let message = summary.message.to_string();
        assert!(message.contains("Maximum number of steps reached"));
    }
}