1use super::*;
2
3#[derive(Debug, Clone, derive_getters::Getters)]
6pub struct MoreThuenteB {
7 c1: Floating, c2: Floating, t_min: Floating,
10 t_max: Floating,
11 delta_min: Floating,
12 delta: Floating,
13 delta_max: Floating,
14 lower_bound: DVector<Floating>,
15 upper_bound: DVector<Floating>,
16}
17
18impl MoreThuenteB {
19 pub fn new(n: usize) -> Self {
20 MoreThuenteB {
21 c1: 1e-4,
22 c2: 0.9,
23 t_min: 0.0,
24 t_max: Floating::INFINITY,
25 delta_min: 0.58333333,
26 delta: 0.66,
27 delta_max: 1.1,
28 lower_bound: DVector::from_element(n, -Floating::INFINITY),
29 upper_bound: DVector::from_element(n, Floating::INFINITY),
30 }
31 }
32 pub fn with_lower_bound(mut self, lower_bound: DVector<Floating>) -> Self {
33 self.lower_bound = lower_bound;
34 self
35 }
36 pub fn with_upper_bound(mut self, upper_bound: DVector<Floating>) -> Self {
37 self.upper_bound = upper_bound;
38 self
39 }
40 pub fn with_deltas(
41 mut self,
42 delta_min: Floating,
43 delta: Floating,
44 delta_max: Floating,
45 ) -> Self {
46 self.delta_min = delta_min;
47 self.delta = delta;
48 self.delta_max = delta_max;
49 self
50 }
51 pub fn with_t_min(mut self, t_min: Floating) -> Self {
52 self.t_min = t_min;
53 self
54 }
55 pub fn with_t_max(mut self, t_max: Floating) -> Self {
56 self.t_max = t_max;
57 self
58 }
59 pub fn with_c1(mut self, c1: Floating) -> Self {
60 assert!(c1 > 0.0, "c1 must be positive");
61 assert!(c1 < self.c2, "c1 must be less than c2");
62 self.c1 = c1;
63 self
64 }
65 pub fn with_c2(mut self, c2: Floating) -> Self {
66 assert!(c2 > 0.0, "c2 must be positive");
67 assert!(c2 < 1.0, "c2 must be less than 1");
68 assert!(c2 > self.c1, "c2 must be greater than c1");
69 self.c2 = c2;
70 self
71 }
72
73 pub fn update_interval(
74 f_tl: &Floating,
75 f_t: &Floating,
76 g_t: &Floating,
77 tl: &mut Floating,
78 t: Floating,
79 tu: &mut Floating,
80 ) -> bool {
81 if f_t > f_tl {
83 *tu = t;
84 false
85 }
86 else if g_t * (*tl - t) > 0. {
88 *tl = t;
89 false
90 }
91 else if g_t * (*tl - t) < 0. {
93 *tu = *tl;
94 *tl = t;
95 false
96 } else {
97 true
99 }
100 }
101
102 pub fn cubic_minimizer(
103 ta: &Floating,
104 tb: &Floating,
105 f_ta: &Floating,
106 f_tb: &Floating,
107 g_ta: &Floating,
108 g_tb: &Floating,
109 ) -> Floating {
110 let s = 3. * (f_tb - f_ta) / (tb - ta);
113 let z = s - g_ta - g_tb;
114 let w = (z.powi(2) - g_ta * g_tb).sqrt();
115 ta + ((tb - ta) * ((w - g_ta - z) / (g_tb - g_ta + 2. * w)))
117 }
118
119 pub fn quadratic_minimzer_1(
120 ta: &Floating,
121 tb: &Floating,
122 f_ta: &Floating,
123 f_tb: &Floating,
124 g_ta: &Floating,
125 ) -> Floating {
126 let lin_int = (f_ta - f_tb) / (ta - tb);
128
129 ta - 0.5 * ((ta - tb) * g_ta / (g_ta - lin_int))
130 }
131
132 pub fn quadratic_minimizer_2(
133 ta: &Floating,
134 tb: &Floating,
135 g_ta: &Floating,
136 g_tb: &Floating,
137 ) -> Floating {
138 trace!(target: "morethuente line search", "Quadratic minimizer 2: ta: {}, tb: {}, g_ta: {}, g_tb: {}", ta, tb, g_ta, g_tb);
140 ta - g_ta * ((ta - tb) / (g_ta - g_tb))
141 }
142
143 pub fn phi(eval: &FuncEvalMultivariate, direction_k: &DVector<Floating>) -> FuncEvalUnivariate {
144 let image = eval.f();
146 let derivative = eval.g().dot(direction_k);
147 FuncEvalUnivariate::new(*image, derivative)
148 }
149 pub fn psi(
150 &self,
151 phi_0: &FuncEvalUnivariate,
152 phi_t: &FuncEvalUnivariate,
153 t: &Floating,
154 ) -> FuncEvalUnivariate {
155 let image = phi_t.f() - phi_0.f() - self.c1 * t * phi_0.g();
156 let derivative = phi_t.g() - self.c1 * phi_0.g();
157 FuncEvalUnivariate::new(image, derivative)
158 }
159}
160
161impl SufficientDecreaseCondition for MoreThuenteB {
162 fn c1(&self) -> Floating {
163 self.c1
164 }
165}
166
167impl CurvatureCondition for MoreThuenteB {
168 fn c2(&self) -> Floating {
169 self.c2
170 }
171}
172
173impl LineSearch for MoreThuenteB {
174 fn compute_step_len(
175 &mut self,
176 x_k: &DVector<Floating>, eval_x_k: &FuncEvalMultivariate, direction_k: &DVector<Floating>, oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate, max_iter: usize, ) -> Floating {
182 let mut use_modified_updating = false;
183 let mut interval_converged = false;
184
185 let t_max_candidate = direction_k
186 .iter()
187 .enumerate()
188 .map(|(i, d)| {
189 if *d > 0.0 {
190 (self.upper_bound[i] - x_k[i]) / d
191 } else if *d < 0.0 {
192 (self.lower_bound[i] - x_k[i]) / d
193 } else {
194 Floating::INFINITY
195 }
196 })
197 .fold(Floating::INFINITY, |acc, x| x.min(acc));
198
199 debug!(target: "morethuente line search", "t_max_candidate: {}",t_max_candidate);
200
201 self.t_max = self.t_max.min(t_max_candidate);
202
203 let mut t = 1.0f64.max(self.t_min).min(self.t_max);
204 let mut tl = self.t_min;
205 let mut tu = self.t_max;
206 let eval_0 = eval_x_k;
207
208 for i in 0..max_iter {
209 let eval_t = oracle(&(x_k + t * direction_k));
210 if self.strong_wolfe_conditions_with_directional_derivative(
212 eval_0.f(),
213 eval_t.f(),
214 eval_0.g(),
215 eval_t.g(),
216 &t,
217 direction_k,
218 ) {
219 trace!("Strong Wolfe conditions satisfied at iteration {}", i);
220 return t;
221 } else if interval_converged {
222 trace!("Interval converged at iteration {}", i);
223 return t;
224 } else if t == tl {
226 trace!("t is at the minimum value at iteration {}", i);
227 return t;
228 } else if t == tu {
230 trace!("t is at the maximum value at iteration {}", i);
231 return t;
232 }
233
234 let phi_t = Self::phi(&eval_t, direction_k);
235 let phi_0 = Self::phi(eval_0, direction_k);
236
237 let psi_t = self.psi(&phi_0, &phi_t, &t);
238
239 if !use_modified_updating && psi_t.f() <= &0. && phi_t.g() > &0. {
240 use_modified_updating = true;
242 }
243
244 let eval_tl = oracle(&(x_k + tl * direction_k));
245 let phi_tl = Self::phi(&eval_tl, direction_k);
246
247 let (f_tl, g_tl, f_t, g_t) = if use_modified_updating {
249 (*phi_tl.f(), *phi_tl.g(), phi_t.f(), phi_t.g())
250 } else {
251 let psi_tl = self.psi(&phi_0, &phi_tl, &tl);
252 (*psi_tl.f(), *psi_tl.g(), psi_t.f(), psi_t.g())
253 };
254
255 if f_t > &f_tl {
258 let tc = Self::cubic_minimizer(&tl, &t, &f_tl, f_t, &g_tl, g_t);
259 let tq = Self::quadratic_minimzer_1(&tl, &t, &f_tl, f_t, &g_tl);
260
261 trace!(target: "morethuente line search", "Case 1: tc: {}, tq: {}", tc, tq);
262
263 if (tc - tl).abs() < (tq - tl).abs() {
264 t = tc;
265 } else {
266 t = 0.5 * (tq + tc); }
268 }
269 else if g_t * g_tl < 0. {
271 let tc = Self::cubic_minimizer(&tl, &t, &f_tl, f_t, &g_tl, g_t);
272 let ts = Self::quadratic_minimizer_2(&tl, &t, &g_tl, g_t);
273
274 trace!(target: "morethuente line search", "Case 2: tc: {}, ts: {}", tc, ts);
275
276 if (tc - t).abs() >= (ts - t).abs() {
277 t = tc;
278 } else {
279 t = ts;
280 }
281 }
282 else if g_t.abs() <= g_tl.abs() {
284 let tc = Self::cubic_minimizer(&tl, &t, &f_tl, f_t, &g_tl, g_t);
285 let ts = Self::quadratic_minimizer_2(&tl, &t, &g_tl, g_t);
286
287 trace!(target: "morethuente line search", "Case 3: tc: {}, ts: {}", tc, ts);
288
289 let t_plus = if (tc - t).abs() < (ts - t).abs() {
290 tc
291 } else {
292 ts
293 };
294 if t > tl {
295 t = t_plus.min(t + self.delta * (tu - t));
296 } else {
297 t = t_plus.max(t + self.delta * (tu - t));
298 }
299 }
300 else {
302 let (f_tu, g_tu) = {
303 let eval_tu = oracle(&(x_k + tu * direction_k));
304 let phi_tu = Self::phi(&eval_tu, direction_k);
305 if use_modified_updating {
306 (*phi_tu.f(), *phi_tu.g())
307 } else {
308 let psi_tu = self.psi(&phi_0, &phi_tu, &tu);
309 (*psi_tu.f(), *psi_tu.g())
310 }
311 };
312 trace!(target: "morethuente line search", "Case 4: f_tu: {}, g_tu: {}", f_tu, g_tu);
313 t = Self::cubic_minimizer(&tu, &t, f_t, &f_tu, g_t, &g_tu);
314 }
315
316 t = t.max(self.t_min).min(self.t_max);
318
319 interval_converged = Self::update_interval(&f_tl, f_t, g_t, &mut tl, t, &mut tu)
321 }
322 trace!("Line search did not converge in {} iterations", max_iter);
323 t
324 }
325}
326
#[cfg(test)]
mod morethuente_test {
    use super::*;

    /// Runs steepest descent with this line search on the ill-conditioned quadratic
    /// f(x) = 0.5 * (x0^2 + 90 * x1^2) and checks the first coordinate reaches 0.
    #[test]
    pub fn test_phi() {
        std::env::set_var("RUST_LOG", "debug");

        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();

        let gamma = 90.0;
        let mut oracle = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let value = 0.5 * (x[0].powi(2) + gamma * x[1].powi(2));
            let gradient = DVector::from(vec![x[0], gamma * x[1]]);
            (value, gradient).into()
        };

        let max_iter = 10000;
        let gradient_tol = 1e-12;
        let mut line_search = MoreThuenteB::new(2);
        let mut iterate = DVector::from(vec![180.0, 152.0]);

        let mut k = 1;
        while k < max_iter {
            trace!("Iterate: {:?}", iterate);
            let eval = oracle(&iterate);
            // Squared gradient norm compared against the tolerance.
            if eval.g().dot(eval.g()) < gradient_tol {
                trace!("Gradient norm is lower than tolerance. Convergence!.");
                break;
            }
            let direction = -eval.g();
            let step = <MoreThuenteB as LineSearch>::compute_step_len(
                &mut line_search,
                &iterate,
                &eval,
                &direction,
                &mut oracle,
                max_iter,
            );
            iterate += step * direction;
            k += 1;
        }

        println!("Iterate: {:?}", iterate);
        println!("Function eval: {:?}", oracle(&iterate));
        assert!((iterate[0] - 0.0).abs() < 1e-6);
        trace!("Test took {} iterations", k);
    }
}