argmin/core/solver.rs
// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
8use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, KV};
9
10/// The interface all solvers are required to implement.
11///
12/// Every solver needs to implement this trait in order to function with the `Executor`.
13/// It handles initialization ([`init`](`Solver::init`)), each iteration of the solver
14/// ([`next_iter`](`Solver::next_iter`)), and termination of the algorithm
15/// ([`terminate`](`Solver::terminate`) and [`terminate_internal`](`Solver::terminate_internal`)).
16/// Only `next_iter` is mandatory to implement, all others provide default implementations.
17///
18/// A `Solver` should be (de)serializable in order to work with checkpointing.
19///
20/// # Example
21///
22/// ```
23/// use argmin::core::{
24/// ArgminFloat, Solver, IterState, CostFunction, Error, KV, Problem, TerminationReason, TerminationStatus
25/// };
26///
27/// #[derive(Clone)]
28/// struct OptimizationAlgorithm {}
29///
30/// impl<O, P, G, J, H, R, F> Solver<O, IterState<P, G, J, H, R, F>> for OptimizationAlgorithm
31/// where
32/// O: CostFunction<Param = P, Output = F>,
33/// P: Clone,
34/// F: ArgminFloat
35/// {
36/// fn name(&self) -> &str { "OptimizationAlgorithm" }
37///
38/// fn init(
39/// &mut self,
40/// problem: &mut Problem<O>,
41/// state: IterState<P, G, J, H, R, F>,
42/// ) -> Result<(IterState<P, G, J, H, R, F>, Option<KV>), Error> {
43/// // Initialize algorithm, update `state`.
44/// // Implementing this method is optional.
45/// Ok((state, None))
46/// }
47///
48/// fn next_iter(
49/// &mut self,
50/// problem: &mut Problem<O>,
51/// state: IterState<P, G, J, H, R, F>,
52/// ) -> Result<(IterState<P, G, J, H, R, F>, Option<KV>), Error> {
53/// // Compute single iteration of algorithm, update `state`.
54/// // Implementing this method is required.
55/// Ok((state, None))
56/// }
57///
58/// fn terminate(&mut self, state: &IterState<P, G, J, H, R, F>) -> TerminationStatus {
59/// // Check if stopping criteria are met.
60/// // Implementing this method is optional.
61/// TerminationStatus::NotTerminated
62/// }
63/// }
64/// ```
65pub trait Solver<O, I: State> {
66 /// Name of the solver. Mainly used in [Observers](`crate::core::observers::Observe`).
67 fn name(&self) -> &str;
68
69 /// Initializes the algorithm.
70 ///
71 /// Executed before any iterations are performed and has access to the optimization problem
72 /// definition and the internal state of the solver.
73 /// Returns an updated `state` and optionally a `KV` which holds key-value pairs used in
74 /// [Observers](`crate::core::observers::Observe`).
75 /// The default implementation returns the unaltered `state` and no `KV`.
76 fn init(&mut self, _problem: &mut Problem<O>, state: I) -> Result<(I, Option<KV>), Error> {
77 Ok((state, None))
78 }
79
80 /// Computes a single iteration of the algorithm and has access to the optimization problem
81 /// definition and the internal state of the solver.
82 /// Returns an updated `state` and optionally a `KV` which holds key-value pairs used in
83 /// [Observers](`crate::core::observers::Observe`).
84 fn next_iter(&mut self, problem: &mut Problem<O>, state: I) -> Result<(I, Option<KV>), Error>;
85
86 /// Checks whether basic termination reasons apply.
87 ///
88 /// Terminate if
89 ///
90 /// 1) algorithm was terminated somewhere else in the Executor
91 /// 2) iteration count exceeds maximum number of iterations
92 /// 3) best cost is lower than or equal to the target cost
93 ///
94 /// This can be overwritten; however it is not advised. It is recommended to implement other
95 /// stopping criteria via ([`terminate`](`Solver::terminate`).
96 fn terminate_internal(&mut self, state: &I) -> TerminationStatus {
97 let solver_status = self.terminate(state);
98 if solver_status.terminated() {
99 return solver_status;
100 }
101 if state.get_iter() >= state.get_max_iters() {
102 return TerminationStatus::Terminated(TerminationReason::MaxItersReached);
103 }
104 if state.get_best_cost() <= state.get_target_cost() {
105 return TerminationStatus::Terminated(TerminationReason::TargetCostReached);
106 }
107 TerminationStatus::NotTerminated
108 }
109
110 /// Used to implement stopping criteria, in particular criteria which are not covered by
111 /// ([`terminate_internal`](`Solver::terminate_internal`).
112 ///
113 /// This method has access to the internal state and returns an `TerminationReason`.
114 fn terminate(&mut self, _state: &I) -> TerminationStatus {
115 TerminationStatus::NotTerminated
116 }
117}