optimization_solvers/
ls_solver.rs1use super::*;
2
3pub trait ComputeDirection {
4 fn compute_direction(
5 &mut self,
6 eval_x_k: &FuncEvalMultivariate,
7 ) -> Result<DVector<Floating>, SolverError>;
8}
9
10#[derive(thiserror::Error, Debug)]
11pub enum SolverError {
12 #[error("Max iter reached")]
13 MaxIterReached,
14 #[error("Out of domain")]
15 OutOfDomain,
16 #[error("Error in input parameters")]
17 ErrorInputParams,
18 #[error("Abnormal termination")]
19 AbnormalTermination,
20}
21
22pub trait LineSearchSolver: ComputeDirection {
24 fn xk(&self) -> &DVector<Floating>;
25 fn xk_mut(&mut self) -> &mut DVector<Floating>;
26 fn k(&self) -> &usize;
27 fn k_mut(&mut self) -> &mut usize;
28 fn has_converged(&self, eval_x_k: &FuncEvalMultivariate) -> bool;
29
30 fn setup(&mut self) {}
31
32 fn evaluate_x_k(
33 &mut self,
34 oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
35 ) -> Result<FuncEvalMultivariate, SolverError> {
36 let eval_x_k = oracle(self.xk());
37 if eval_x_k.f().is_nan() || eval_x_k.f().is_infinite() {
38 error!(target: "solver","Minimization completed: next iterate is out of domain");
39 return Err(SolverError::OutOfDomain);
40 }
41 Ok(eval_x_k)
42 }
43
44 fn update_next_iterate<LS: LineSearch>(
45 &mut self,
46 line_search: &mut LS,
47 eval_x_k: &FuncEvalMultivariate, oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
49 direction: &DVector<Floating>,
50 max_iter_line_search: usize,
51 ) -> Result<(), SolverError> {
52 let step = line_search.compute_step_len(
53 self.xk(),
54 eval_x_k,
55 direction,
56 oracle,
57 max_iter_line_search,
58 );
59
60 let next_iterate = self.xk() + step * direction;
61 *self.xk_mut() = next_iterate;
62
63 Ok(())
64 }
65
66 fn minimize<LS: LineSearch>(
67 &mut self,
68 line_search: &mut LS,
69 mut oracle: impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
70 max_iter_solver: usize,
71 max_iter_line_search: usize,
72 mut callback: Option<&mut dyn FnMut(&Self)>,
73 ) -> Result<(), SolverError> {
74 *self.k_mut() = 0;
75
76 self.setup();
77
78 while &max_iter_solver > self.k() {
79 let eval_x_k = self.evaluate_x_k(&mut oracle)?;
80
81 if self.has_converged(&eval_x_k) {
82 info!(
83 target: "solver",
84 "Minimization completed: convergence in {} iterations",
85 self.k()
86 );
87 return Ok(());
88 }
89
90 let direction = self.compute_direction(&eval_x_k)?;
91
92 debug!(target: "solver","Gradient: {:?}, Direction: {:?}", eval_x_k.g(), direction);
93 self.update_next_iterate(
94 line_search,
95 &eval_x_k,
96 &mut oracle,
97 &direction,
98 max_iter_line_search,
99 )?;
100
101 debug!(target: "solver","Iterate: {:?}", self.xk());
102 debug!(target: "solver","Function eval: {:?}", eval_x_k);
103
104 *self.k_mut() += 1;
105 if let Some(callback) = callback.as_mut() {
106 callback(self);
107 }
108 }
109 warn!(target: "solver","Minimization completed: max iter reached during minimization");
110 Err(SolverError::MaxIterReached)
111 }
112}
113
114pub trait HasBounds {
115 fn lower_bound(&self) -> &DVector<Floating>;
116 fn upper_bound(&self) -> &DVector<Floating>;
117 fn set_lower_bound(&mut self, lower_bound: DVector<Floating>);
118 fn set_upper_bound(&mut self, upper_bound: DVector<Floating>);
119}
120
121pub trait HasProjectedGradient: LineSearchSolver + HasBounds {
122 fn projected_gradient(&self, eval: &FuncEvalMultivariate) -> DVector<Floating> {
123 let mut proj_grad = eval.g().clone();
124 for (i, x) in self.xk().iter().enumerate() {
125 if (x == &self.lower_bound()[i] && proj_grad[i] > 0.0)
126 || (x == &self.upper_bound()[i] && proj_grad[i] < 0.0)
127 {
128 proj_grad[i] = 0.0;
129 }
130 }
131 proj_grad
132 }
133}
134
135impl<T> HasProjectedGradient for T where T: LineSearchSolver + HasBounds {}