optimization_solvers/line_search/
morethuente_b.rs

use super::*;

// Implementation from https://www.ii.uib.no/~lennart/drgrad/More1994.pdf (Moré, Thuente 1994) and https://bayanbox.ir/view/1460469776013846613/Sun-Yuan-Optimization-theory.pdf (Sun, Yuan 2006).
// Like the unconstrained case, but here the max step length is adjusted so that the iterate remains within the bounds.
#[derive(Debug, Clone, derive_getters::Getters)]
pub struct MoreThuenteB {
    c1: Floating, // mu (Armijo / sufficient decrease sensitivity)
    c2: Floating, // eta (curvature sensitivity)
    t_min: Floating,
    t_max: Floating,
    delta_min: Floating,
    delta: Floating,
    delta_max: Floating,
    lower_bound: DVector<Floating>,
    upper_bound: DVector<Floating>,
}

impl MoreThuenteB {
    pub fn new(n: usize) -> Self {
        MoreThuenteB {
            c1: 1e-4,
            c2: 0.9,
            t_min: 0.0,
            t_max: Floating::INFINITY,
            delta_min: 0.58333333,
            delta: 0.66,
            delta_max: 1.1,
            lower_bound: DVector::from_element(n, -Floating::INFINITY),
            upper_bound: DVector::from_element(n, Floating::INFINITY),
        }
    }
    pub fn with_lower_bound(mut self, lower_bound: DVector<Floating>) -> Self {
        self.lower_bound = lower_bound;
        self
    }
    pub fn with_upper_bound(mut self, upper_bound: DVector<Floating>) -> Self {
        self.upper_bound = upper_bound;
        self
    }
    pub fn with_deltas(
        mut self,
        delta_min: Floating,
        delta: Floating,
        delta_max: Floating,
    ) -> Self {
        self.delta_min = delta_min;
        self.delta = delta;
        self.delta_max = delta_max;
        self
    }
    pub fn with_t_min(mut self, t_min: Floating) -> Self {
        self.t_min = t_min;
        self
    }
    pub fn with_t_max(mut self, t_max: Floating) -> Self {
        self.t_max = t_max;
        self
    }
    pub fn with_c1(mut self, c1: Floating) -> Self {
        assert!(c1 > 0.0, "c1 must be positive");
        assert!(c1 < self.c2, "c1 must be less than c2");
        self.c1 = c1;
        self
    }
    pub fn with_c2(mut self, c2: Floating) -> Self {
        assert!(c2 > 0.0, "c2 must be positive");
        assert!(c2 < 1.0, "c2 must be less than 1");
        assert!(c2 > self.c1, "c2 must be greater than c1");
        self.c2 = c2;
        self
    }

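    // A minimal construction sketch (the bounds and parameters below are
    // illustrative values, not recommendations):
    //
    //     let ls = MoreThuenteB::new(2)
    //         .with_lower_bound(DVector::from(vec![0.0, 0.0]))
    //         .with_upper_bound(DVector::from(vec![10.0, 10.0]))
    //         .with_c1(1e-4)
    //         .with_c2(0.9);

    /// Bracket update (Update Algorithm, section 2, and Modified Update
    /// Algorithm, section 3, of the paper): given the trial step `t` and the
    /// endpoint `tl`, shrink the interval `[tl, tu]` so that it keeps
    /// containing an acceptable step. Returns `true` when none of the three
    /// cases applies, i.e. the interval has degenerated to a point.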
    pub fn update_interval(
        f_tl: &Floating,
        f_t: &Floating,
        g_t: &Floating,
        tl: &mut Floating,
        t: Floating,
        tu: &mut Floating,
    ) -> bool {
        // case U1 in Update Algorithm and case a in Modified Update Algorithm
        if f_t > f_tl {
            *tu = t;
            false
        }
        // case U2 in Update Algorithm and case b in Modified Update Algorithm
        else if g_t * (*tl - t) > 0. {
            *tl = t;
            false
        }
        // case U3 in Update Algorithm and case c in Modified Update Algorithm
        else if g_t * (*tl - t) < 0. {
            *tu = *tl;
            *tl = t;
            false
        } else {
            // the interval has converged to a point
            true
        }
    }

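    /// Minimizer of the cubic interpolating the function values `f_ta`,
    /// `f_tb` and the derivatives `g_ta`, `g_tb` at the endpoints. Note that
    /// the square root is real only when `z^2 >= g_ta * g_tb`, which always
    /// holds when the interval brackets a minimizer (opposite-sign
    /// derivatives).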
    pub fn cubic_minimizer(
        ta: &Floating,
        tb: &Floating,
        f_ta: &Floating,
        f_tb: &Floating,
        g_ta: &Floating,
        g_tb: &Floating,
    ) -> Floating {
        // Equation 2.4.51 [Sun, Yuan 2006]
        let s = 3. * (f_tb - f_ta) / (tb - ta);
        let z = s - g_ta - g_tb;
        let w = (z.powi(2) - g_ta * g_tb).sqrt();
        // Equation 2.4.56 [Sun, Yuan 2006]
        ta + ((tb - ta) * ((w - g_ta - z) / (g_tb - g_ta + 2. * w)))
    }

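    /// Minimizer of the quadratic interpolating two function values (`f_ta`,
    /// `f_tb`) and one derivative (`g_ta`).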
    pub fn quadratic_minimizer_1(
        ta: &Floating,
        tb: &Floating,
        f_ta: &Floating,
        f_tb: &Floating,
        g_ta: &Floating,
    ) -> Floating {
        // Equation 2.4.2 [Sun, Yuan 2006]
        let lin_int = (f_ta - f_tb) / (ta - tb);
        ta - 0.5 * ((ta - tb) * g_ta / (g_ta - lin_int))
    }

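    /// Minimizer of the quadratic interpolating the two derivatives `g_ta`
    /// and `g_tb` (a secant step on the derivative).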
    pub fn quadratic_minimizer_2(
        ta: &Floating,
        tb: &Floating,
        g_ta: &Floating,
        g_tb: &Floating,
    ) -> Floating {
        // Equation 2.4.5 [Sun, Yuan 2006]
        trace!(target: "morethuente line search", "Quadratic minimizer 2: ta: {}, tb: {}, g_ta: {}, g_tb: {}", ta, tb, g_ta, g_tb);
        ta - g_ta * ((ta - tb) / (g_ta - g_tb))
    }

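    /// Restriction of the objective to the search ray: `phi(t) = f(x + t * d)`.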
    pub fn phi(eval: &FuncEvalMultivariate, direction_k: &DVector<Floating>) -> FuncEvalUnivariate {
        // Since phi(t) = f(x + t * direction), the chain rule gives
        // phi'(t) = <nabla f(x + t * direction), direction>, i.e. the
        // directional derivative of f along the search direction.
        let image = eval.f();
        let derivative = eval.g().dot(direction_k);
        FuncEvalUnivariate::new(*image, derivative)
    }
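    /// Auxiliary function of the paper: `psi(t) = phi(t) - phi(0) - c1 * t * phi'(0)`.
    /// `psi(t) <= 0` is exactly the sufficient decrease condition, so the
    /// search first works on `psi` and later switches to `phi` (the
    /// "modified updating" below).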
    pub fn psi(
        &self,
        phi_0: &FuncEvalUnivariate,
        phi_t: &FuncEvalUnivariate,
        t: &Floating,
    ) -> FuncEvalUnivariate {
        let image = phi_t.f() - phi_0.f() - self.c1 * t * phi_0.g();
        let derivative = phi_t.g() - self.c1 * phi_0.g();
        FuncEvalUnivariate::new(image, derivative)
    }
}

impl SufficientDecreaseCondition for MoreThuenteB {
    fn c1(&self) -> Floating {
        self.c1
    }
}

impl CurvatureCondition for MoreThuenteB {
    fn c2(&self) -> Floating {
        self.c2
    }
}

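// `compute_step_len` first caps `t_max` so that every trial point stays
// inside the box `[lower_bound, upper_bound]`, then runs the Moré-Thuente
// bracketing loop on the capped interval.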
impl LineSearch for MoreThuenteB {
    fn compute_step_len(
        &mut self,
        x_k: &DVector<Floating>,         // current iterate
        eval_x_k: &FuncEvalMultivariate, // function evaluation at x_k
        direction_k: &DVector<Floating>, // direction of the ray along which we search
        oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate, // oracle
        max_iter: usize, // maximum number of line search iterations (if the direction update is costly, set this high to perform a more exact line search)
    ) -> Floating {
        let mut use_modified_updating = false;
        let mut interval_converged = false;

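        // Largest feasible step along the ray: for each coordinate, the step
        // at which x_k[i] + t * d[i] hits its bound; the minimum over all
        // coordinates keeps the whole trial point inside the box.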
        let t_max_candidate = direction_k
            .iter()
            .enumerate()
            .map(|(i, d)| {
                if *d > 0.0 {
                    (self.upper_bound[i] - x_k[i]) / d
                } else if *d < 0.0 {
                    (self.lower_bound[i] - x_k[i]) / d
                } else {
                    Floating::INFINITY
                }
            })
            .fold(Floating::INFINITY, |acc, x| x.min(acc));

        debug!(target: "morethuente line search", "t_max_candidate: {}", t_max_candidate);

        self.t_max = self.t_max.min(t_max_candidate);

        let mut t = 1.0f64.max(self.t_min).min(self.t_max);
        let mut tl = self.t_min;
        let mut tu = self.t_max;
        let eval_0 = eval_x_k;

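        // Main loop: evaluate the trial step, test the stopping criteria,
        // then pick the next trial value by interpolation (section 4 of the
        // paper) and update the bracket (sections 2 and 3).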
        for i in 0..max_iter {
            let eval_t = oracle(&(x_k + t * direction_k));
            // Check for convergence
            if self.strong_wolfe_conditions_with_directional_derivative(
                eval_0.f(),
                eval_t.f(),
                eval_0.g(),
                eval_t.g(),
                &t,
                direction_k,
            ) {
                trace!("Strong Wolfe conditions satisfied at iteration {}", i);
                return t;
            } else if interval_converged {
                trace!("Interval converged at iteration {}", i);
                return t;
            } else if t == tl {
                trace!("t is at the minimum value at iteration {}", i);
                return t;
            } else if t == tu {
                trace!("t is at the maximum value at iteration {}", i);
                return t;
            }

            let phi_t = Self::phi(&eval_t, direction_k);
            let phi_0 = Self::phi(eval_0, direction_k);

            let psi_t = self.psi(&phi_0, &phi_t, &t);

            if !use_modified_updating && psi_t.f() <= &0. && phi_t.g() > &0. {
                // the paper suggests that once this condition holds, you
                // switch to the modified updating and never go back
                use_modified_updating = true;
            }

            let eval_tl = oracle(&(x_k + tl * direction_k));
            let phi_tl = Self::phi(&eval_tl, direction_k);

            // use the auxiliary (psi) or modified (phi) evaluation according to the flag
            let (f_tl, g_tl, f_t, g_t) = if use_modified_updating {
                (*phi_tl.f(), *phi_tl.g(), phi_t.f(), phi_t.g())
            } else {
                let psi_tl = self.psi(&phi_0, &phi_tl, &tl);
                (*psi_tl.f(), *psi_tl.g(), psi_t.f(), psi_t.g())
            };

            // Trial value selection (section 4 of the paper)
            // case 1
            if f_t > &f_tl {
                let tc = Self::cubic_minimizer(&tl, &t, &f_tl, f_t, &g_tl, g_t);
                let tq = Self::quadratic_minimizer_1(&tl, &t, &f_tl, f_t, &g_tl);

                trace!(target: "morethuente line search", "Case 1: tc: {}, tq: {}", tc, tq);

                if (tc - tl).abs() < (tq - tl).abs() {
                    t = tc;
                } else {
                    t = 0.5 * (tq + tc); // midpoint
                }
            }
            // case 2 (here f_t <= f_tl)
            else if g_t * g_tl < 0. {
                let tc = Self::cubic_minimizer(&tl, &t, &f_tl, f_t, &g_tl, g_t);
                let ts = Self::quadratic_minimizer_2(&tl, &t, &g_tl, g_t);

                trace!(target: "morethuente line search", "Case 2: tc: {}, ts: {}", tc, ts);

                if (tc - t).abs() >= (ts - t).abs() {
                    t = tc;
                } else {
                    t = ts;
                }
            }
            // case 3 (here f_t <= f_tl and g_t * g_tl >= 0.)
            else if g_t.abs() <= g_tl.abs() {
                let tc = Self::cubic_minimizer(&tl, &t, &f_tl, f_t, &g_tl, g_t);
                let ts = Self::quadratic_minimizer_2(&tl, &t, &g_tl, g_t);

                trace!(target: "morethuente line search", "Case 3: tc: {}, ts: {}", tc, ts);

                let t_plus = if (tc - t).abs() < (ts - t).abs() {
                    tc
                } else {
                    ts
                };
                if t > tl {
                    t = t_plus.min(t + self.delta * (tu - t));
                } else {
                    t = t_plus.max(t + self.delta * (tu - t));
                }
            }
            // case 4 (here f_t <= f_tl, g_t * g_tl >= 0. and g_t.abs() > g_tl.abs())
            else {
                let (f_tu, g_tu) = {
                    let eval_tu = oracle(&(x_k + tu * direction_k));
                    let phi_tu = Self::phi(&eval_tu, direction_k);
                    if use_modified_updating {
                        (*phi_tu.f(), *phi_tu.g())
                    } else {
                        let psi_tu = self.psi(&phi_0, &phi_tu, &tu);
                        (*psi_tu.f(), *psi_tu.g())
                    }
                };
                trace!(target: "morethuente line search", "Case 4: f_tu: {}, g_tu: {}", f_tu, g_tu);
                t = Self::cubic_minimizer(&tu, &t, f_t, &f_tu, g_t, &g_tu);
            }

            // clamp t to the allowed range
            t = t.max(self.t_min).min(self.t_max);

            // Updating algorithm (sections 2 and 3 of the paper)
            interval_converged = Self::update_interval(&f_tl, f_t, g_t, &mut tl, t, &mut tu);
        }
        trace!("Line search did not converge in {} iterations", max_iter);
        t
    }
}

#[cfg(test)]
mod morethuente_test {
    use super::*;
    #[test]
    pub fn test_phi() {
        std::env::set_var("RUST_LOG", "debug");

        // In this example the objective function has a constant Hessian, so its condition number is the same at every point.
        // Recall that for the gradient descent method, the upper bound on the log error grows with the condition number of the Hessian (the ratio between its max and min eigenvalues),
        // which causes poor performance when the Hessian is ill conditioned.
        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();
        let gamma = 90.0;
        let mut f_and_g = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let f = 0.5 * (x[0].powi(2) + gamma * x[1].powi(2));
            let g = DVector::from(vec![x[0], gamma * x[1]]);
            (f, g).into()
        };
        let max_iter = 10000;
        // here we define a rough gradient descent method that uses this line search
        let mut k = 1;
        let mut iterate = DVector::from(vec![180.0, 152.0]);
        let mut ls = MoreThuenteB::new(2);
        let gradient_tol = 1e-12;

        while max_iter > k {
            trace!("Iterate: {:?}", iterate);
            let eval = f_and_g(&iterate);
            // we do a rough check on the squared norm of the gradient to verify convergence
            if eval.g().dot(eval.g()) < gradient_tol {
                trace!("Gradient norm is lower than tolerance. Convergence!");
                break;
            }
            let direction = -eval.g();
            let t = <MoreThuenteB as LineSearch>::compute_step_len(
                &mut ls,
                &iterate,
                &eval,
                &direction,
                &mut f_and_g,
                max_iter,
            );
            // we perform the update
            iterate += t * direction;
            k += 1;
        }
        println!("Iterate: {:?}", iterate);
        println!("Function eval: {:?}", f_and_g(&iterate));
        assert!((iterate[0] - 0.0).abs() < 1e-6);
        trace!("Test took {} iterations", k);
    }
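
    // Two small sanity checks for the interpolation and bracketing helpers;
    // the expected values below are hand-computed from the formulas above.
    #[test]
    pub fn test_quadratic_minimizer_2() {
        // For phi(t) = (t - 3)^2 we have phi'(0) = -6 and phi'(1) = -4;
        // since phi' is linear, the secant step recovers the exact minimizer t = 3.
        let t = MoreThuenteB::quadratic_minimizer_2(&0.0, &1.0, &(-6.0), &(-4.0));
        assert!((t - 3.0).abs() < 1e-12);
    }

    #[test]
    pub fn test_update_interval_case_u1() {
        // Case U1: the trial value has a higher function value than tl,
        // so it becomes the new upper endpoint and the bracket shrinks.
        let (mut tl, mut tu) = (0.0, 10.0);
        let converged =
            MoreThuenteB::update_interval(&1.0, &2.0, &(-1.0), &mut tl, 5.0, &mut tu);
        assert!(!converged);
        assert_eq!(tl, 0.0);
        assert_eq!(tu, 5.0);
    }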
}