optimization_solvers/steepest_descent/pnorm_descent.rs

use super::*;

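/// Steepest descent in the quadratic norm induced by a symmetric positive
/// definite matrix P: the search direction is -P⁻¹∇f(x). With P = I this
/// reduces to plain gradient descent; with P equal to the Hessian of the
/// objective it coincides with the Newton step.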
#[derive(derive_getters::Getters)]
pub struct PnormDescent {
    pub grad_tol: Floating,
    pub x: DVector<Floating>,
    pub k: usize,
    pub inverse_p: DMatrix<Floating>,
}

impl PnormDescent {
    pub fn new(grad_tol: Floating, x0: DVector<Floating>, inverse_p: DMatrix<Floating>) -> Self {
        PnormDescent {
            grad_tol,
            x: x0,
            k: 0,
            inverse_p,
        }
    }
}

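// P-norm steepest-descent direction d = -P⁻¹g. For symmetric positive
// definite P we have gᵀd = -gᵀP⁻¹g < 0 whenever g ≠ 0, so d is always a
// descent direction.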
impl ComputeDirection for PnormDescent {
    fn compute_direction(
        &mut self,
        eval: &FuncEvalMultivariate,
    ) -> Result<DVector<Floating>, SolverError> {
        Ok(-&self.inverse_p * eval.g())
    }
}

impl LineSearchSolver for PnormDescent {
    fn xk(&self) -> &DVector<Floating> {
        &self.x
    }
    fn xk_mut(&mut self) -> &mut DVector<Floating> {
        &mut self.x
    }
    fn k(&self) -> &usize {
        &self.k
    }
    fn k_mut(&mut self) -> &mut usize {
        &mut self.k
    }
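    // Convergence test: the infinity norm of the gradient, max_i |g_i|,
    // must drop below `grad_tol`.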
    fn has_converged(&self, eval: &FuncEvalMultivariate) -> bool {
        let grad = eval.g();
        grad.iter()
            .fold(Floating::NEG_INFINITY, |acc, x| x.abs().max(acc))
            < self.grad_tol
    }

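    // One outer iteration: let the line search pick a step length along
    // `direction`, then move to x_{k+1} = x_k + step * direction.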
    fn update_next_iterate<LS: LineSearch>(
        &mut self,
        line_search: &mut LS,
        eval_x_k: &FuncEvalMultivariate,
        oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
        direction: &DVector<Floating>,
        max_iter_line_search: usize,
    ) -> Result<(), SolverError> {
        let step = line_search.compute_step_len(
            self.xk(),
            eval_x_k,
            direction,
            oracle,
            max_iter_line_search,
        );

        debug!(target: "pnorm_descent", "ITERATE: {} + {} * {} = {}", self.xk(), step, direction, self.xk() + step * direction);

        let next_iterate = self.xk() + step * direction;

        *self.xk_mut() = next_iterate;

        Ok(())
    }
}

#[cfg(test)]
mod pnorm_descent_test {
    use super::*;

    #[test]
    pub fn pnorm_morethuente() {
        std::env::set_var("RUST_LOG", "info");

        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();
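        // Ill-conditioned quadratic f(x) = 0.5 * (x_1^2 + gamma * x_2^2)
        // with Hessian diag(1, gamma) and condition number gamma = 90.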
        let gamma = 90.0;
        let f_and_g = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let f = 0.5 * (x[0].powi(2) + gamma * x[1].powi(2));
            let g = DVector::from(vec![x[0], gamma * x[1]]);
            (f, g).into()
        };
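        // P is set to the exact Hessian, so P⁻¹ = diag(1, 1/gamma) and the
        // preconditioned direction coincides with the Newton step.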
        let inv_hessian = DMatrix::from_iterator(2, 2, vec![1.0, 0.0, 0.0, 1.0 / gamma]);
        let mut ls = MoreThuente::default();

        let tol = 1e-12;
        let x_0 = DVector::from(vec![180.0, 152.0]);
        let mut gd = PnormDescent::new(tol, x_0, inv_hessian);

        let max_iter_solver = 1000;
        let max_iter_line_search = 100;

        gd.minimize(
            &mut ls,
            f_and_g,
            max_iter_solver,
            max_iter_line_search,
            None,
        )
        .unwrap();

        println!("Iterate: {:?}", gd.xk());

        let eval = f_and_g(gd.xk());
        println!("Function eval: {:?}", eval);
        println!("Gradient norm: {:?}", eval.g().norm());
        println!("tol: {:?}", tol);

        let convergence = gd.has_converged(&eval);
        println!("Convergence: {:?}", convergence);

        assert!((eval.f() - 0.0).abs() < 1e-6);
    }

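    // Same quadratic and preconditioner as above, minimized with a
    // backtracking line search instead of Moré-Thuente.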
    #[test]
    pub fn pnorm_backtracking() {
        std::env::set_var("RUST_LOG", "info");

        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();
        let gamma = 90.0;
        let f_and_g = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let f = 0.5 * (x[0].powi(2) + gamma * x[1].powi(2));
            let g = DVector::from(vec![x[0], gamma * x[1]]);
            (f, g).into()
        };
        let inv_hessian = DMatrix::from_iterator(2, 2, vec![1.0, 0.0, 0.0, 1.0 / gamma]);
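        // Backtracking parameters: alpha as the sufficient-decrease constant
        // and beta as the step contraction factor (assumed Armijo-style
        // semantics for `BackTracking::new`).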
        let alpha = 1e-4;
        let beta = 0.5;
        let mut ls = BackTracking::new(alpha, beta);

        let tol = 1e-12;
        let x_0 = DVector::from(vec![180.0, 152.0]);
        let mut gd = PnormDescent::new(tol, x_0, inv_hessian);

        let max_iter_solver = 1000;
        let max_iter_line_search = 100;

        gd.minimize(
            &mut ls,
            f_and_g,
            max_iter_solver,
            max_iter_line_search,
            None,
        )
        .unwrap();

        println!("Iterate: {:?}", gd.xk());

        let eval = f_and_g(gd.xk());
        println!("Function eval: {:?}", eval);
        println!("Gradient norm: {:?}", eval.g().norm());
        println!("tol: {:?}", tol);

        let convergence = gd.has_converged(&eval);
        println!("Convergence: {:?}", convergence);

        assert!((eval.f() - 0.0).abs() < 1e-6);
    }
}