1use super::*;
2
/// Configuration of the Moré–Thuente line search implemented in this file.
///
/// Read-only accessors for every field are generated by `derive_getters::Getters`.
#[derive(Debug, Clone, derive_getters::Getters)]
pub struct MoreThuente {
    /// Constant returned by the `SufficientDecreaseCondition` impl and used to
    /// build the auxiliary function `psi`. `with_c1` requires `0 < c1 < c2`.
    c1: Floating,
    /// Constant returned by the `CurvatureCondition` impl. `with_c2` requires
    /// `c1 < c2 < 1`.
    c2: Floating,
    /// Lower bound on the step length; also the initial lower end of the
    /// search interval.
    t_min: Floating,
    /// Upper bound on the step length; also the initial upper end of the
    /// search interval.
    t_max: Floating,
    /// NOTE(review): stored and exposed via `with_deltas`, but not read by the
    /// line search visible in this file — confirm intended use.
    delta_min: Floating,
    /// Safeguard factor applied in the third trial-step case:
    /// the candidate is limited by `t + delta * (tu - t)`.
    delta: Floating,
    /// NOTE(review): stored and exposed via `with_deltas`, but not read by the
    /// line search visible in this file — confirm intended use.
    delta_max: Floating,
}
15
16impl Default for MoreThuente {
17 fn default() -> Self {
18 MoreThuente {
19 c1: 1e-4,
20 c2: 0.9,
21 t_min: 0.0,
22 t_max: Floating::INFINITY,
23 delta_min: 0.58333333,
24 delta: 0.66,
25 delta_max: 1.1,
26 }
27 }
28}
29
30impl MoreThuente {
31 pub fn with_deltas(
32 mut self,
33 delta_min: Floating,
34 delta: Floating,
35 delta_max: Floating,
36 ) -> Self {
37 self.delta_min = delta_min;
38 self.delta = delta;
39 self.delta_max = delta_max;
40 self
41 }
42 pub fn with_t_min(mut self, t_min: Floating) -> Self {
43 self.t_min = t_min;
44 self
45 }
46 pub fn with_t_max(mut self, t_max: Floating) -> Self {
47 self.t_max = t_max;
48 self
49 }
50 pub fn with_c1(mut self, c1: Floating) -> Self {
51 assert!(c1 > 0.0, "c1 must be positive");
52 assert!(c1 < self.c2, "c1 must be less than c2");
53 self.c1 = c1;
54 self
55 }
56 pub fn with_c2(mut self, c2: Floating) -> Self {
57 assert!(c2 > 0.0, "c2 must be positive");
58 assert!(c2 < 1.0, "c2 must be less than 1");
59 assert!(c2 > self.c1, "c2 must be greater than c1");
60 self.c2 = c2;
61 self
62 }
63
64 pub fn update_interval(
65 f_tl: &Floating,
66 f_t: &Floating,
67 g_t: &Floating,
68 tl: &mut Floating,
69 t: Floating,
70 tu: &mut Floating,
71 ) -> bool {
72 if f_t > f_tl {
74 *tu = t;
75 false
76 }
77 else if g_t * (*tl - t) > 0. {
79 *tl = t;
80 false
81 }
82 else if g_t * (*tl - t) < 0. {
84 *tu = *tl;
85 *tl = t;
86 false
87 } else {
88 true
90 }
91 }
92
93 pub fn cubic_minimizer(
94 ta: &Floating,
95 tb: &Floating,
96 f_ta: &Floating,
97 f_tb: &Floating,
98 g_ta: &Floating,
99 g_tb: &Floating,
100 ) -> Floating {
101 let s = 3. * (f_tb - f_ta) / (tb - ta);
104 let z = s - g_ta - g_tb;
105 let w = (z.powi(2) - g_ta * g_tb).sqrt();
106 ta + ((tb - ta) * ((w - g_ta - z) / (g_tb - g_ta + 2. * w)))
108 }
109
110 pub fn quadratic_minimzer_1(
111 ta: &Floating,
112 tb: &Floating,
113 f_ta: &Floating,
114 f_tb: &Floating,
115 g_ta: &Floating,
116 ) -> Floating {
117 let lin_int = (f_ta - f_tb) / (ta - tb);
119
120 ta - 0.5 * ((ta - tb) * g_ta / (g_ta - lin_int))
121 }
122
123 pub fn quadratic_minimizer_2(
124 ta: &Floating,
125 tb: &Floating,
126 g_ta: &Floating,
127 g_tb: &Floating,
128 ) -> Floating {
129 trace!(target: "morethuente line search", "Quadratic minimizer 2: ta: {}, tb: {}, g_ta: {}, g_tb: {}", ta, tb, g_ta, g_tb);
131 ta - g_ta * ((ta - tb) / (g_ta - g_tb))
132 }
133
134 pub fn phi(eval: &FuncEvalMultivariate, direction_k: &DVector<Floating>) -> FuncEvalUnivariate {
135 let image = eval.f();
137 let derivative = eval.g().dot(direction_k);
138 FuncEvalUnivariate::new(*image, derivative)
139 }
140 pub fn psi(
141 &self,
142 phi_0: &FuncEvalUnivariate,
143 phi_t: &FuncEvalUnivariate,
144 t: &Floating,
145 ) -> FuncEvalUnivariate {
146 let image = phi_t.f() - phi_0.f() - self.c1 * t * phi_0.g();
147 let derivative = phi_t.g() - self.c1 * phi_0.g();
148 FuncEvalUnivariate::new(image, derivative)
149 }
150}
151
/// Exposes the stored `c1` constant through the `SufficientDecreaseCondition` trait.
impl SufficientDecreaseCondition for MoreThuente {
    /// Returns the configured `c1` by value.
    fn c1(&self) -> Floating {
        self.c1
    }
}
157
/// Exposes the stored `c2` constant through the `CurvatureCondition` trait.
impl CurvatureCondition for MoreThuente {
    /// Returns the configured `c2` by value.
    fn c2(&self) -> Floating {
        self.c2
    }
}
163
164impl LineSearch for MoreThuente {
165 fn compute_step_len(
166 &mut self,
167 x_k: &DVector<Floating>, eval_x_k: &FuncEvalMultivariate, direction_k: &DVector<Floating>, oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate, max_iter: usize, ) -> Floating {
173 let mut use_modified_updating = false;
174 let mut interval_converged = false;
175
176 let mut t = 1.0f64.max(self.t_min).min(self.t_max);
177 let mut tl = self.t_min;
178 let mut tu = self.t_max;
179 let eval_0 = eval_x_k;
180
181 for i in 0..max_iter {
182 let eval_t = oracle(&(x_k + t * direction_k));
183 if self.strong_wolfe_conditions_with_directional_derivative(
185 eval_0.f(),
186 eval_t.f(),
187 eval_0.g(),
188 eval_t.g(),
189 &t,
190 direction_k,
191 ) {
192 trace!("Strong Wolfe conditions satisfied at iteration {}", i);
193 return t;
194 } else if interval_converged {
195 trace!("Interval converged at iteration {}", i);
196 return t;
197 } else if t == tl {
199 trace!("t is at the minimum value at iteration {}", i);
200 return t;
201 } else if t == tu {
203 trace!("t is at the maximum value at iteration {}", i);
204 return t;
205 }
206
207 let phi_t = Self::phi(&eval_t, direction_k);
208 let phi_0 = Self::phi(eval_0, direction_k);
209
210 let psi_t = self.psi(&phi_0, &phi_t, &t);
211
212 if !use_modified_updating && psi_t.f() <= &0. && phi_t.g() > &0. {
213 use_modified_updating = true;
215 }
216
217 let eval_tl = oracle(&(x_k + tl * direction_k));
218 let phi_tl = Self::phi(&eval_tl, direction_k);
219
220 let (f_tl, g_tl, f_t, g_t) = if use_modified_updating {
222 (*phi_tl.f(), *phi_tl.g(), phi_t.f(), phi_t.g())
223 } else {
224 let psi_tl = self.psi(&phi_0, &phi_tl, &tl);
225 (*psi_tl.f(), *psi_tl.g(), psi_t.f(), psi_t.g())
226 };
227
228 if f_t > &f_tl {
231 let tc = Self::cubic_minimizer(&tl, &t, &f_tl, f_t, &g_tl, g_t);
232 let tq = Self::quadratic_minimzer_1(&tl, &t, &f_tl, f_t, &g_tl);
233
234 trace!(target: "morethuente line search", "Case 1: tc: {}, tq: {}", tc, tq);
235
236 if (tc - tl).abs() < (tq - tl).abs() {
237 t = tc;
238 } else {
239 t = 0.5 * (tq + tc); }
241 }
242 else if g_t * g_tl < 0. {
244 let tc = Self::cubic_minimizer(&tl, &t, &f_tl, f_t, &g_tl, g_t);
245 let ts = Self::quadratic_minimizer_2(&tl, &t, &g_tl, g_t);
246
247 trace!(target: "morethuente line search", "Case 2: tc: {}, ts: {}", tc, ts);
248
249 if (tc - t).abs() >= (ts - t).abs() {
250 t = tc;
251 } else {
252 t = ts;
253 }
254 }
255 else if g_t.abs() <= g_tl.abs() {
257 let tc = Self::cubic_minimizer(&tl, &t, &f_tl, f_t, &g_tl, g_t);
258 let ts = Self::quadratic_minimizer_2(&tl, &t, &g_tl, g_t);
259
260 trace!(target: "morethuente line search", "Case 3: tc: {}, ts: {}", tc, ts);
261
262 let t_plus = if (tc - t).abs() < (ts - t).abs() {
263 tc
264 } else {
265 ts
266 };
267 if t > tl {
268 t = t_plus.min(t + self.delta * (tu - t));
269 } else {
270 t = t_plus.max(t + self.delta * (tu - t));
271 }
272 }
273 else {
275 let (f_tu, g_tu) = {
276 let eval_tu = oracle(&(x_k + tu * direction_k));
277 let phi_tu = Self::phi(&eval_tu, direction_k);
278 if use_modified_updating {
279 (*phi_tu.f(), *phi_tu.g())
280 } else {
281 let psi_tu = self.psi(&phi_0, &phi_tu, &tu);
282 (*psi_tu.f(), *psi_tu.g())
283 }
284 };
285 trace!(target: "morethuente line search", "Case 4: f_tu: {}, g_tu: {}", f_tu, g_tu);
286 t = Self::cubic_minimizer(&tu, &t, f_t, &f_tu, g_t, &g_tu);
287 }
288
289 t = t.max(self.t_min).min(self.t_max);
291
292 interval_converged = Self::update_interval(&f_tl, f_t, g_t, &mut tl, t, &mut tu)
294 }
295 trace!("Line search did not converge in {} iterations", max_iter);
296 t
297 }
298}
299
#[cfg(test)]
mod morethuente_test {
    use super::*;

    /// Drives steepest descent with the Moré–Thuente line search on the
    /// quadratic `0.5 * (x0^2 + gamma * x1^2)` and checks that the first
    /// coordinate ends up within `1e-6` of zero.
    #[test]
    pub fn test_phi() {
        std::env::set_var("RUST_LOG", "debug");

        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();

        // Objective and gradient oracle.
        let gamma = 90.0;
        let mut f_and_g = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let (x0, x1) = (x[0], x[1]);
            let value = 0.5 * (x0.powi(2) + gamma * x1.powi(2));
            let gradient = DVector::from(vec![x0, gamma * x1]);
            (value, gradient).into()
        };

        // Solver settings and starting point.
        let max_iter = 10000;
        let gradient_tol = 1e-12;
        let mut ls = MoreThuente::default();
        let mut iterate = DVector::from(vec![180.0, 152.0]);
        let mut k = 1;

        while k < max_iter {
            trace!("Iterate: {:?}", iterate);
            let eval = f_and_g(&iterate);

            // Stop once the squared gradient norm drops below the tolerance.
            let grad = eval.g();
            let squared_norm = grad.dot(grad);
            if squared_norm < gradient_tol {
                trace!("Gradient norm is lower than tolerance. Convergence!.");
                break;
            }

            let direction = -eval.g();
            let t = <MoreThuente as LineSearch>::compute_step_len(
                &mut ls,
                &iterate,
                &eval,
                &direction,
                &mut f_and_g,
                max_iter,
            );
            iterate += t * direction;
            k += 1;
        }

        println!("Iterate: {:?}", iterate);
        println!("Function eval: {:?}", f_and_g(&iterate));
        assert!((iterate[0] - 0.0).abs() < 1e-6);
        trace!("Test took {} iterations", k);
    }
}