optimization_solvers/
ls_solver.rs

1use super::*;
2
3pub trait ComputeDirection {
4    fn compute_direction(
5        &mut self,
6        eval_x_k: &FuncEvalMultivariate,
7    ) -> Result<DVector<Floating>, SolverError>;
8}
9
10#[derive(thiserror::Error, Debug)]
11pub enum SolverError {
12    #[error("Max iter reached")]
13    MaxIterReached,
14    #[error("Out of domain")]
15    OutOfDomain,
16    #[error("Error in input parameters")]
17    ErrorInputParams,
18    #[error("Abnormal termination")]
19    AbnormalTermination,
20}
21
22//Template pattern for solvers. Methods that are already implemented can be freely overriden.
23pub trait LineSearchSolver: ComputeDirection {
24    fn xk(&self) -> &DVector<Floating>;
25    fn xk_mut(&mut self) -> &mut DVector<Floating>;
26    fn k(&self) -> &usize;
27    fn k_mut(&mut self) -> &mut usize;
28    fn has_converged(&self, eval_x_k: &FuncEvalMultivariate) -> bool;
29
30    fn setup(&mut self) {}
31
32    fn evaluate_x_k(
33        &mut self,
34        oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
35    ) -> Result<FuncEvalMultivariate, SolverError> {
36        let eval_x_k = oracle(self.xk());
37        if eval_x_k.f().is_nan() || eval_x_k.f().is_infinite() {
38            error!(target: "solver","Minimization completed: next iterate is out of domain");
39            return Err(SolverError::OutOfDomain);
40        }
41        Ok(eval_x_k)
42    }
43
44    fn update_next_iterate<LS: LineSearch>(
45        &mut self,
46        line_search: &mut LS,
47        eval_x_k: &FuncEvalMultivariate, //eval_x_k: &FuncEvalMultivariate,
48        oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
49        direction: &DVector<Floating>,
50        max_iter_line_search: usize,
51    ) -> Result<(), SolverError> {
52        let step = line_search.compute_step_len(
53            self.xk(),
54            eval_x_k,
55            direction,
56            oracle,
57            max_iter_line_search,
58        );
59
60        let next_iterate = self.xk() + step * direction;
61        *self.xk_mut() = next_iterate;
62
63        Ok(())
64    }
65
66    fn minimize<LS: LineSearch>(
67        &mut self,
68        line_search: &mut LS,
69        mut oracle: impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
70        max_iter_solver: usize,
71        max_iter_line_search: usize,
72        mut callback: Option<&mut dyn FnMut(&Self)>,
73    ) -> Result<(), SolverError> {
74        *self.k_mut() = 0;
75
76        self.setup();
77
78        while &max_iter_solver > self.k() {
79            let eval_x_k = self.evaluate_x_k(&mut oracle)?;
80
81            if self.has_converged(&eval_x_k) {
82                info!(
83                    target: "solver",
84                    "Minimization completed: convergence in {} iterations",
85                    self.k()
86                );
87                return Ok(());
88            }
89
90            let direction = self.compute_direction(&eval_x_k)?;
91
92            debug!(target: "solver","Gradient: {:?}, Direction: {:?}", eval_x_k.g(), direction);
93            self.update_next_iterate(
94                line_search,
95                &eval_x_k,
96                &mut oracle,
97                &direction,
98                max_iter_line_search,
99            )?;
100
101            debug!(target: "solver","Iterate: {:?}", self.xk());
102            debug!(target: "solver","Function eval: {:?}", eval_x_k);
103
104            *self.k_mut() += 1;
105            if let Some(callback) = callback.as_mut() {
106                callback(self);
107            }
108        }
109        warn!(target: "solver","Minimization completed: max iter reached during minimization");
110        Err(SolverError::MaxIterReached)
111    }
112}
113
114pub trait HasBounds {
115    fn lower_bound(&self) -> &DVector<Floating>;
116    fn upper_bound(&self) -> &DVector<Floating>;
117    fn set_lower_bound(&mut self, lower_bound: DVector<Floating>);
118    fn set_upper_bound(&mut self, upper_bound: DVector<Floating>);
119}
120
121pub trait HasProjectedGradient: LineSearchSolver + HasBounds {
122    fn projected_gradient(&self, eval: &FuncEvalMultivariate) -> DVector<Floating> {
123        let mut proj_grad = eval.g().clone();
124        for (i, x) in self.xk().iter().enumerate() {
125            if (x == &self.lower_bound()[i] && proj_grad[i] > 0.0)
126                || (x == &self.upper_bound()[i] && proj_grad[i] < 0.0)
127            {
128                proj_grad[i] = 0.0;
129            }
130        }
131        proj_grad
132    }
133}
134
135//Blanket implementation for all optimization solvers that have bounds
136impl<T> HasProjectedGradient for T where T: LineSearchSolver + HasBounds {}