optimization_solvers/line_search/morethuente.rs

use super::*;

// Implementation from https://www.ii.uib.no/~lennart/drgrad/More1994.pdf (More, Thuente 1994) and https://bayanbox.ir/view/1460469776013846613/Sun-Yuan-Optimization-theory.pdf (Sun, Yuan 2006)

#[derive(Debug, Clone, derive_getters::Getters)]
pub struct MoreThuente {
    c1: Floating, // mu in the paper (Armijo / sufficient decrease sensitivity)
    c2: Floating, // eta in the paper (curvature sensitivity)
    t_min: Floating, // lower bound for the step length
    t_max: Floating, // upper bound for the step length
    delta_min: Floating,
    delta: Floating, // safeguard factor used when selecting the trial step (Section 4 of the paper)
    delta_max: Floating,
}

impl Default for MoreThuente {
    fn default() -> Self {
        MoreThuente {
            c1: 1e-4,
            c2: 0.9,
            t_min: 0.0,
            t_max: Floating::INFINITY,
            delta_min: 0.58333333,
            delta: 0.66,
            delta_max: 1.1,
        }
    }
}

impl MoreThuente {
    pub fn with_deltas(
        mut self,
        delta_min: Floating,
        delta: Floating,
        delta_max: Floating,
    ) -> Self {
        self.delta_min = delta_min;
        self.delta = delta;
        self.delta_max = delta_max;
        self
    }
    pub fn with_t_min(mut self, t_min: Floating) -> Self {
        self.t_min = t_min;
        self
    }
    pub fn with_t_max(mut self, t_max: Floating) -> Self {
        self.t_max = t_max;
        self
    }
    pub fn with_c1(mut self, c1: Floating) -> Self {
        assert!(c1 > 0.0, "c1 must be positive");
        assert!(c1 < self.c2, "c1 must be less than c2");
        self.c1 = c1;
        self
    }
    pub fn with_c2(mut self, c2: Floating) -> Self {
        assert!(c2 > 0.0, "c2 must be positive");
        assert!(c2 < 1.0, "c2 must be less than 1");
        assert!(c2 > self.c1, "c2 must be greater than c1");
        self.c2 = c2;
        self
    }
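
    // Illustrative usage of the builders above (the specific values are arbitrary, not a
    // recommendation): `MoreThuente::default().with_c1(1e-4).with_c2(0.9).with_t_max(1e3)`.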

    pub fn update_interval(
        f_tl: &Floating,
        f_t: &Floating,
        g_t: &Floating,
        tl: &mut Floating,
        t: Floating,
        tu: &mut Floating,
    ) -> bool {
        // case U1 in the Update Algorithm and case a in the Modified Update Algorithm
        if f_t > f_tl {
            *tu = t;
            false
        }
        // case U2 in the Update Algorithm and case b in the Modified Update Algorithm
        else if g_t * (*tl - t) > 0. {
            *tl = t;
            false
        }
        // case U3 in the Update Algorithm and case c in the Modified Update Algorithm
        else if g_t * (*tl - t) < 0. {
            *tu = *tl;
            *tl = t;
            false
        } else {
            // the interval has converged to a point
            true
        }
    }
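
    // Worked instance (a sanity check added here, not taken from the paper): with tl = 0,
    // tu = +inf and trial t = 1, a higher value f_t > f_tl triggers case U1 and the interval
    // shrinks to [0, 1]; if instead f_t <= f_tl and g_t < 0 (so g_t * (tl - t) > 0), case U2
    // moves the lower endpoint to tl = 1. These cases are also exercised in the
    // `interpolation_sanity_checks` test module further down in this file.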

    pub fn cubic_minimizer(
        ta: &Floating,
        tb: &Floating,
        f_ta: &Floating,
        f_tb: &Floating,
        g_ta: &Floating,
        g_tb: &Floating,
    ) -> Floating {
        // Equation 2.4.51 [Sun, Yuan 2006]

        let s = 3. * (f_tb - f_ta) / (tb - ta);
        let z = s - g_ta - g_tb;
        let w = (z.powi(2) - g_ta * g_tb).sqrt();
        // Equation 2.4.56 [Sun, Yuan 2006]
        ta + ((tb - ta) * ((w - g_ta - z) / (g_tb - g_ta + 2. * w)))
    }
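
    // Quick numerical check (added here, not from the paper): for phi(t) = t^3 - t on
    // [ta, tb] = [0, 1] we have f_ta = 0, f_tb = 0, g_ta = -1, g_tb = 2, hence s = 0,
    // z = -1, w = sqrt(3), and the formula returns (sqrt(3) + 2) / (3 + 2 * sqrt(3))
    // = 1 / sqrt(3) ≈ 0.577, the exact minimizer of the cubic on that interval.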

    pub fn quadratic_minimizer_1(
        ta: &Floating,
        tb: &Floating,
        f_ta: &Floating,
        f_tb: &Floating,
        g_ta: &Floating,
    ) -> Floating {
        // Equation 2.4.2 [Sun, Yuan 2006]
        let lin_int = (f_ta - f_tb) / (ta - tb);

        ta - 0.5 * ((ta - tb) * g_ta / (g_ta - lin_int))
    }
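
    // Quick numerical check (added here, not from the paper): for phi(t) = (t - 0.5)^2 on
    // [ta, tb] = [0, 1] we have f_ta = f_tb = 0.25 and g_ta = -1, so lin_int = 0 and the
    // formula returns 0 - 0.5 * ((0 - 1) * (-1) / (-1 - 0)) = 0.5, the exact minimizer of
    // the parabola.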

    pub fn quadratic_minimizer_2(
        ta: &Floating,
        tb: &Floating,
        g_ta: &Floating,
        g_tb: &Floating,
    ) -> Floating {
        // Equation 2.4.5 [Sun, Yuan 2006]
        trace!(target: "morethuente line search", "Quadratic minimizer 2: ta: {}, tb: {}, g_ta: {}, g_tb: {}", ta, tb, g_ta, g_tb);
        ta - g_ta * ((ta - tb) / (g_ta - g_tb))
    }
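
    // Quick numerical check (added here, not from the paper): this is a secant step on the
    // derivative. For phi(t) = (t - 0.5)^2 with ta = 0, tb = 1, g_ta = -1, g_tb = 1 it
    // returns 0 - (-1) * ((0 - 1) / (-1 - 1)) = 0.5, the point where phi' vanishes.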

    pub fn phi(eval: &FuncEvalMultivariate, direction_k: &DVector<Floating>) -> FuncEvalUnivariate {
        // Recall that phi(t) = f(x + t * direction), so by the chain rule
        // phi'(t) = <nabla f(x + t * direction), direction>, i.e. the directional
        // derivative of f at x + t * direction along the search direction.
        let image = eval.f();
        let derivative = eval.g().dot(direction_k);
        FuncEvalUnivariate::new(*image, derivative)
    }
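
    // The auxiliary function of the paper: psi(t) = phi(t) - phi(0) - c1 * t * phi'(0), so
    // psi(t) <= 0 exactly when the sufficient decrease condition holds at t, and
    // psi'(t) = phi'(t) - c1 * phi'(0). The search works on psi until psi(t) <= 0 and
    // phi'(t) > 0, after which it switches to phi (see `compute_step_len` below).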
    pub fn psi(
        &self,
        phi_0: &FuncEvalUnivariate,
        phi_t: &FuncEvalUnivariate,
        t: &Floating,
    ) -> FuncEvalUnivariate {
        let image = phi_t.f() - phi_0.f() - self.c1 * t * phi_0.g();
        let derivative = phi_t.g() - self.c1 * phi_0.g();
        FuncEvalUnivariate::new(image, derivative)
    }
}

impl SufficientDecreaseCondition for MoreThuente {
    fn c1(&self) -> Floating {
        self.c1
    }
}

impl CurvatureCondition for MoreThuente {
    fn c2(&self) -> Floating {
        self.c2
    }
}

impl LineSearch for MoreThuente {
    fn compute_step_len(
        &mut self,
        x_k: &DVector<Floating>,         // current iterate
        eval_x_k: &FuncEvalMultivariate, // function evaluation at x_k
        direction_k: &DVector<Floating>, // direction of the ray along which we are going to search
        oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate, // oracle
        max_iter: usize, // maximum number of iterations during line search (if the direction update is costly, set this high to perform a more exact line search)
    ) -> Floating {
        let mut use_modified_updating = false;
        let mut interval_converged = false;

        let mut t = 1.0f64.max(self.t_min).min(self.t_max);
        let mut tl = self.t_min;
        let mut tu = self.t_max;
        let eval_0 = eval_x_k;

        for i in 0..max_iter {
            let eval_t = oracle(&(x_k + t * direction_k));
            // Check for convergence
            if self.strong_wolfe_conditions_with_directional_derivative(
                eval_0.f(),
                eval_t.f(),
                eval_0.g(),
                eval_t.g(),
                &t,
                direction_k,
            ) {
                trace!("Strong Wolfe conditions satisfied at iteration {}", i);
                return t;
            } else if interval_converged {
                trace!("Interval converged at iteration {}", i);
                return t;
            } else if t == tl {
                trace!("t is at the minimum value at iteration {}", i);
                return t;
            } else if t == tu {
                trace!("t is at the maximum value at iteration {}", i);
                return t;
            }

            let phi_t = Self::phi(&eval_t, direction_k);
            let phi_0 = Self::phi(eval_0, direction_k);

            let psi_t = self.psi(&phi_0, &phi_t, &t);

            if !use_modified_updating && psi_t.f() <= &0. && phi_t.g() > &0. {
                // The paper suggests that once this condition is verified, we switch to the modified updating and never go back.
                use_modified_updating = true;
            }

            let eval_tl = oracle(&(x_k + tl * direction_k));
            let phi_tl = Self::phi(&eval_tl, direction_k);

            // Use the auxiliary (psi) or the modified (phi) evaluation according to the flag.
            let (f_tl, g_tl, f_t, g_t) = if use_modified_updating {
                (*phi_tl.f(), *phi_tl.g(), phi_t.f(), phi_t.g())
            } else {
                let psi_tl = self.psi(&phi_0, &phi_tl, &tl);
                (*psi_tl.f(), *psi_tl.g(), psi_t.f(), psi_t.g())
            };

            // Trial value selection (Section 4 of the paper)
            // Case 1
            if f_t > &f_tl {
                let tc = Self::cubic_minimizer(&tl, &t, &f_tl, f_t, &g_tl, g_t);
                let tq = Self::quadratic_minimizer_1(&tl, &t, &f_tl, f_t, &g_tl);

                trace!(target: "morethuente line search", "Case 1: tc: {}, tq: {}", tc, tq);

                if (tc - tl).abs() < (tq - tl).abs() {
                    t = tc;
                } else {
                    t = 0.5 * (tq + tc); // midpoint
                }
            }
            // Case 2 (here f_t <= f_tl)
            else if g_t * g_tl < 0. {
                let tc = Self::cubic_minimizer(&tl, &t, &f_tl, f_t, &g_tl, g_t);
                let ts = Self::quadratic_minimizer_2(&tl, &t, &g_tl, g_t);

                trace!(target: "morethuente line search", "Case 2: tc: {}, ts: {}", tc, ts);

                if (tc - t).abs() >= (ts - t).abs() {
                    t = tc;
                } else {
                    t = ts;
                }
            }
            // Case 3 (here f_t <= f_tl and g_t * g_tl >= 0.)
            else if g_t.abs() <= g_tl.abs() {
                let tc = Self::cubic_minimizer(&tl, &t, &f_tl, f_t, &g_tl, g_t);
                let ts = Self::quadratic_minimizer_2(&tl, &t, &g_tl, g_t);

                trace!(target: "morethuente line search", "Case 3: tc: {}, ts: {}", tc, ts);

                let t_plus = if (tc - t).abs() < (ts - t).abs() {
                    tc
                } else {
                    ts
                };
                if t > tl {
                    t = t_plus.min(t + self.delta * (tu - t));
                } else {
                    t = t_plus.max(t + self.delta * (tu - t));
                }
            }
            // Case 4 (here f_t <= f_tl, g_t * g_tl >= 0. and g_t.abs() > g_tl.abs())
            else {
                let (f_tu, g_tu) = {
                    let eval_tu = oracle(&(x_k + tu * direction_k));
                    let phi_tu = Self::phi(&eval_tu, direction_k);
                    if use_modified_updating {
                        (*phi_tu.f(), *phi_tu.g())
                    } else {
                        let psi_tu = self.psi(&phi_0, &phi_tu, &tu);
                        (*psi_tu.f(), *psi_tu.g())
                    }
                };
                trace!(target: "morethuente line search", "Case 4: f_tu: {}, g_tu: {}", f_tu, g_tu);
                t = Self::cubic_minimizer(&tu, &t, f_t, &f_tu, g_t, &g_tu);
            }

            // Clamp t to the max and min values.
            t = t.max(self.t_min).min(self.t_max);

            // Updating algorithm (Sections 2 and 3 of the paper)
            interval_converged = Self::update_interval(&f_tl, f_t, g_t, &mut tl, t, &mut tu);
        }
        trace!("Line search did not converge in {} iterations", max_iter);
        t
    }
}
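
// The following sanity checks are editorial additions, not part of the original test suite
// for this module (the module name below is chosen here): they only exercise the
// interpolation and interval-update helpers defined above on closed-form polynomials whose
// minimizers are known exactly, matching the worked examples in the comments.
#[cfg(test)]
mod interpolation_sanity_checks {
    use super::*;

    #[test]
    fn cubic_minimizer_recovers_cubic_minimum() {
        // phi(t) = t^3 - t on [0, 1]: phi(0) = 0, phi(1) = 0, phi'(0) = -1, phi'(1) = 2.
        let t = MoreThuente::cubic_minimizer(&0.0, &1.0, &0.0, &0.0, &(-1.0), &2.0);
        assert!((t - 1.0 / 3.0f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn quadratic_minimizers_recover_parabola_minimum() {
        // phi(t) = (t - 0.5)^2 on [0, 1]: phi(0) = phi(1) = 0.25, phi'(0) = -1, phi'(1) = 1.
        let t1 = MoreThuente::quadratic_minimizer_1(&0.0, &1.0, &0.25, &0.25, &(-1.0));
        let t2 = MoreThuente::quadratic_minimizer_2(&0.0, &1.0, &(-1.0), &1.0);
        assert!((t1 - 0.5).abs() < 1e-12);
        assert!((t2 - 0.5).abs() < 1e-12);
    }

    #[test]
    fn update_interval_handles_the_three_cases() {
        // Case U1: the trial value has a higher function value, so it becomes the new upper bound.
        let (mut tl, mut tu) = (0.0, 10.0);
        assert!(!MoreThuente::update_interval(&1.0, &2.0, &(-1.0), &mut tl, 1.0, &mut tu));
        assert_eq!((tl, tu), (0.0, 1.0));
        // Case U2: lower value and still-negative derivative, so the trial becomes the new lower bound.
        let (mut tl, mut tu) = (0.0, 10.0);
        assert!(!MoreThuente::update_interval(&1.0, &0.5, &(-1.0), &mut tl, 1.0, &mut tu));
        assert_eq!((tl, tu), (1.0, 10.0));
        // Case U3: lower value but positive derivative, so tu takes the old tl and tl takes the trial t.
        let (mut tl, mut tu) = (0.0, 10.0);
        assert!(!MoreThuente::update_interval(&1.0, &0.5, &1.0, &mut tl, 1.0, &mut tu));
        assert_eq!((tl, tu), (1.0, 0.0));
    }
}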

#[cfg(test)]
mod morethuente_test {
    use super::*;
    #[test]
    pub fn test_phi() {
        std::env::set_var("RUST_LOG", "debug");

        // In this example the objective function has a constant hessian, so its condition number
        // (the ratio between its largest and smallest eigenvalues) is the same at every point; here it equals gamma = 90.
        // Recall that for the gradient descent method the convergence bound degrades as the condition number of the hessian grows,
        // which causes poor performance when the hessian is ill conditioned.
        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();
        let gamma = 90.0;
        let mut f_and_g = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let f = 0.5 * (x[0].powi(2) + gamma * x[1].powi(2));
            let g = DVector::from(vec![x[0], gamma * x[1]]);
            (f, g).into()
        };
        let max_iter = 10000;
        // Here we define a rough gradient descent method that uses this line search.
        let mut k = 1;
        let mut iterate = DVector::from(vec![180.0, 152.0]);
        let mut ls = MoreThuente::default();
        // let ls = BackTracking::new(1e-4, 0.5);
        let gradient_tol = 1e-12;

        while max_iter > k {
            trace!("Iterate: {:?}", iterate);
            let eval = f_and_g(&iterate);
            // We do a rough check on the squared norm of the gradient to verify convergence.
            if eval.g().dot(eval.g()) < gradient_tol {
                trace!("Gradient norm is lower than tolerance. Convergence!");
                break;
            }
            let direction = -eval.g();
            let t = <MoreThuente as LineSearch>::compute_step_len(
                &mut ls,
                &iterate,
                &eval,
                &direction,
                &mut f_and_g,
                max_iter,
            );
            // We perform the update.
            iterate += t * direction;
            k += 1;
        }
        println!("Iterate: {:?}", iterate);
        println!("Function eval: {:?}", f_and_g(&iterate));
        assert!((iterate[0] - 0.0).abs() < 1e-6);
        trace!("Test took {} iterations", k);
    }
}