optimization_solvers/quasi_newton/bfgs.rs

use super::*;

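/// BFGS (Broyden–Fletcher–Goldfarb–Shanno) quasi-Newton solver.
///
/// Keeps a dense approximation of the inverse Hessian and refines it
/// after every accepted step with the rank-two BFGS update.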
#[derive(derive_getters::Getters)]
pub struct BFGS {
    /// Current approximation of the inverse Hessian.
    approx_inv_hessian: DMatrix<Floating>,
    /// Current iterate.
    x: DVector<Floating>,
    /// Iteration counter.
    k: usize,
    /// Convergence tolerance.
    tol: Floating,
    /// Norm of the last step `s = x_{k+1} - x_k`, if any step has been taken.
    s_norm: Option<Floating>,
    /// Norm of the last gradient difference `y = g_{k+1} - g_k`, if any.
    y_norm: Option<Floating>,
    /// Cached identity matrix, reused by the inverse-Hessian update.
    identity: DMatrix<Floating>,
}

impl BFGS {
    /// True once the last step `s` is shorter than the tolerance.
    pub fn next_iterate_too_close(&self) -> bool {
        match self.s_norm() {
            Some(s) => s < &self.tol,
            None => false,
        }
    }

    /// True once the last gradient difference `y` is shorter than the tolerance.
    pub fn gradient_next_iterate_too_close(&self) -> bool {
        match self.y_norm() {
            Some(y) => y < &self.tol,
            None => false,
        }
    }

    /// Creates a solver starting at `x0`, with the inverse-Hessian
    /// approximation initialized to the identity.
    pub fn new(tol: Floating, x0: DVector<Floating>) -> Self {
        let n = x0.len();
        let identity = DMatrix::identity(n, n);
        BFGS {
            approx_inv_hessian: identity.clone(),
            x: x0,
            k: 0,
            tol,
            s_norm: None,
            y_norm: None,
            identity,
        }
    }
}

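// The search direction is the quasi-Newton direction d_k = -H_k * g_k,
// where H_k is the current inverse-Hessian approximation and g_k the
// gradient at the current iterate.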
impl ComputeDirection for BFGS {
    fn compute_direction(
        &mut self,
        eval: &FuncEvalMultivariate,
    ) -> Result<DVector<Floating>, SolverError> {
        Ok(-&self.approx_inv_hessian * eval.g())
    }
}

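// `LineSearchSolver` presumably drives the outer `minimize` loop (as used in
// the tests below); this impl supplies what that loop needs: accessors for
// the iterate and counter, a convergence test, and the per-iteration update.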
impl LineSearchSolver for BFGS {
    fn k(&self) -> &usize {
        &self.k
    }
    fn xk(&self) -> &DVector<Floating> {
        &self.x
    }
    fn xk_mut(&mut self) -> &mut DVector<Floating> {
        &mut self.x
    }
    fn k_mut(&mut self) -> &mut usize {
        &mut self.k
    }

    /// Converged when the last step or gradient change was negligible, or
    /// when the current gradient norm falls below the tolerance.
    fn has_converged(&self, eval: &FuncEvalMultivariate) -> bool {
        if self.next_iterate_too_close() {
            warn!(target: "bfgs", "Minimization completed: next iterate too close");
            true
        } else if self.gradient_next_iterate_too_close() {
            warn!(target: "bfgs", "Minimization completed: gradient change too small");
            true
        } else {
            eval.g().norm() < self.tol
        }
    }

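    /// One BFGS iteration: take a line-search step along `direction`, form
    /// the curvature pair (s, y), move to the next iterate, and refresh the
    /// inverse-Hessian approximation unless the step was degenerate.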
    fn update_next_iterate<LS: LineSearch>(
        &mut self,
        line_search: &mut LS,
        eval_x_k: &FuncEvalMultivariate,
        oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
        direction: &DVector<Floating>,
        max_iter_line_search: usize,
    ) -> Result<(), SolverError> {
        // Step length along `direction` chosen by the supplied line search.
        let step = line_search.compute_step_len(
            self.xk(),
            eval_x_k,
            direction,
            oracle,
            max_iter_line_search,
        );

        let next_iterate = self.xk() + step * direction;

        // Curvature pair: s = x_{k+1} - x_k and y = g_{k+1} - g_k.
        let s = &next_iterate - &self.x;
        self.s_norm = Some(s.norm());
        let y = oracle(&next_iterate).g() - eval_x_k.g();
        self.y_norm = Some(y.norm());

        *self.xk_mut() = next_iterate;

        // Skip the inverse-Hessian update when the step or the gradient
        // change is numerically negligible: rho = 1 / (yᵀ s) would blow up.
        if self.next_iterate_too_close() {
            return Ok(());
        }

        if self.gradient_next_iterate_too_close() {
            return Ok(());
        }

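        // BFGS inverse-Hessian update (Nocedal & Wright, "Numerical
        // Optimization", eq. 6.17):
        //   H_{k+1} = (I - rho s yᵀ) H_k (I - rho y sᵀ) + rho s sᵀ,
        // with rho = 1 / (yᵀ s).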
        let ys = y.dot(&s);
        let rho = 1.0 / ys;
        let w_a = &s * y.transpose(); // s yᵀ
        let w_b = w_a.transpose(); // y sᵀ
        let innovation = &s * s.transpose(); // s sᵀ
        let left_term = self.identity() - (w_a * rho);
        let right_term = self.identity() - (w_b * rho);
        self.approx_inv_hessian =
            (left_term * &self.approx_inv_hessian * right_term) + innovation * rho;
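        // When yᵀ s > 0 (the curvature condition) this update keeps the
        // approximation symmetric positive definite; the guards above only
        // rule out degenerate pairs, not negative curvature.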

        Ok(())
    }
}

#[cfg(test)]
mod test_bfgs {
    use super::*;

    /// Sanity check for nalgebra's outer product `a * bᵀ`.
    #[test]
    fn test_outer() {
        let a = DVector::from_vec(vec![1.0, 2.0]);
        let b = DVector::from_vec(vec![3.0, 4.0]);
        let c = a * b.transpose();
        println!("{:?}", c);
    }

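    // Both tests below minimize the convex quadratic
    //   f(x) = 0.5 * ((x_0 + 1)^2 + gamma * (x_1 - 1)^2),
    // whose unique minimizer is (-1, 1) with f* = 0; they differ only in the
    // line search used.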
    #[test]
    pub fn bfgs_morethuente() {
        std::env::set_var("RUST_LOG", "info");

        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();
        let gamma = 1.;
        let f_and_g = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let f = 0.5 * ((x[0] + 1.).powi(2) + gamma * (x[1] - 1.).powi(2));
            let g = DVector::from(vec![x[0] + 1., gamma * (x[1] - 1.)]);
            (f, g).into()
        };

        let mut ls = MoreThuente::default();

        let tol = 1e-12;
        let x_0 = DVector::from(vec![180.0, 152.0]);
        let mut gd = BFGS::new(tol, x_0);

        let max_iter_solver = 1000;
        let max_iter_line_search = 100000;

        gd.minimize(
            &mut ls,
            f_and_g,
            max_iter_solver,
            max_iter_line_search,
            None,
        )
        .unwrap();

        println!("Iterate: {:?}", gd.xk());

        let eval = f_and_g(gd.xk());
        println!("Function eval: {:?}", eval);
        println!("Gradient norm: {:?}", eval.g().norm());
        println!("tol: {:?}", tol);

        let convergence = gd.has_converged(&eval);
        println!("Convergence: {:?}", convergence);

        assert!((eval.f() - 0.0).abs() < 1e-6);
    }

    #[test]
    pub fn bfgs_backtracking() {
        std::env::set_var("RUST_LOG", "info");

        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();
        let gamma = 1.;
        let f_and_g = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let f = 0.5 * ((x[0] + 1.).powi(2) + gamma * (x[1] - 1.).powi(2));
            let g = DVector::from(vec![x[0] + 1., gamma * (x[1] - 1.)]);
            (f, g).into()
        };

        let alpha = 1e-4;
        let beta = 0.5;
        let mut ls = BackTracking::new(alpha, beta);

        let tol = 1e-12;
        let x_0 = DVector::from(vec![180.0, 152.0]);
        let mut gd = BFGS::new(tol, x_0);

        let max_iter_solver = 1000;
        let max_iter_line_search = 100000;

        gd.minimize(
            &mut ls,
            f_and_g,
            max_iter_solver,
            max_iter_line_search,
            None,
        )
        .unwrap();

        println!("Iterate: {:?}", gd.xk());

        let eval = f_and_g(gd.xk());
        println!("Function eval: {:?}", eval);
        println!("Gradient norm: {:?}", eval.g().norm());
        println!("tol: {:?}", tol);

        let convergence = gd.has_converged(&eval);
        println!("Convergence: {:?}", convergence);

        assert!((eval.f() - 0.0).abs() < 1e-6);
    }
}