optimization_solvers/line_search/
backtracking.rs

// Inexact backtracking line search, as described in Section 9.2 of Boyd & Vandenberghe's "Convex Optimization".
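//
// Starting from t = 1, the candidate step is repeatedly shrunk by a factor `beta`
// until the Armijo (sufficient decrease) condition holds:
//
//     f(x_k + t * d_k) <= f(x_k) + c1 * t * <grad f(x_k), d_k>
//
// where d_k is a descent direction. (Boyd & Vandenberghe call these constants
// alpha and beta.)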
use super::*;
pub struct BackTracking {
    c1: Floating,   // recommended: [0.01, 0.3]
    beta: Floating, // recommended: [0.1, 0.8]
}
impl BackTracking {
    pub fn new(c1: Floating, beta: Floating) -> Self {
        BackTracking { c1, beta }
    }
}
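// A typical instantiation (also used in the test below) is `BackTracking::new(1e-4, 0.5)`;
// c1 = 1e-4 is a common textbook default for the sufficient decrease constant.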

impl SufficientDecreaseCondition for BackTracking {
    fn c1(&self) -> Floating {
        self.c1
    }
}
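// `c1()` above feeds the `sufficient_decrease` check invoked in `compute_step_len`
// below, presumably provided as a default method by the `SufficientDecreaseCondition`
// trait (defined in `super`), which evaluates the Armijo inequality for a given step `t`.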

impl LineSearch for BackTracking {
    fn compute_step_len(
        &mut self,
        x_k: &DVector<Floating>,
        eval_x_k: &FuncEvalMultivariate,
        direction_k: &DVector<Floating>,
        oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
        max_iter: usize,
    ) -> Floating {
        let mut t = 1.0;
        let mut i = 0;

        while max_iter > i {
            let x_kp1 = x_k + t * direction_k;

            let eval_kp1 = oracle(&x_kp1);

            // check whether the candidate iterate left the domain of the objective
            if eval_kp1.f().is_nan() || eval_kp1.f().is_infinite() {
                trace!(target: "backtracking line search", "Step size too big: next iterate is out of domain. Decreasing step by beta ({:?})", x_kp1);
                t *= self.beta;
                i += 1; // count this trial as well, so the loop is guaranteed to terminate within max_iter
                continue;
            }

            // armijo (sufficient decrease) condition
            if self.sufficient_decrease(eval_x_k.f(), eval_kp1.f(), eval_x_k.g(), &t, direction_k) {
                trace!(target: "backtracking line search", "Sufficient decrease condition met. Exiting with step size: {:?}", t);
                return t;
            }

            // the sufficient decrease condition is still not met, so we shrink the step size accordingly
            t *= self.beta;
            i += 1;
        }
        trace!(target: "backtracking line search", "Max iter reached. Early stopping.");
        t
        // worst case scenario: t has shrunk to (numerically) zero, or 0 < t < 1 because of early stopping.
        // if t is zero, the iterate is not updated;
        // if early stopping triggered, we still got some reduction in the objective value,
        // but not enough to be considered satisfactory.
    }
}

#[cfg(test)]
mod backtracking_tests {
    use super::*;

    #[test]
    pub fn test_backtracking() {
        std::env::set_var("RUST_LOG", "info");

        // in this example the objective function has a constant hessian, so its condition number is the same at every point.
        // recall that for the gradient descent method, the worst-case (log-scale) error bound grows with the
        // condition number of the hessian (the ratio between its largest and smallest eigenvalues).
        // this causes poor performance when the hessian is ill conditioned.
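        // concretely, the hessian of f below is diag(1, gamma) = diag(1, 90), so its condition number is
        // kappa = 90. for steepest descent with exact line search on a quadratic, the suboptimality
        // contracts by roughly ((kappa - 1) / (kappa + 1))^2 ~ 0.956 per iteration, so several hundred
        // iterations are expected here; backtracking behaves comparably up to constants.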
        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();
        let gamma = 90.0;
        let mut f_and_g = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let f = 0.5 * (x[0].powi(2) + gamma * x[1].powi(2));
            let g = DVector::from(vec![x[0], gamma * x[1]]);
            (f, g).into()
        };
        let max_iter = 1000;
        // here we define a rough gradient descent method that uses backtracking line search
        let mut k = 1;
        let mut iterate = DVector::from(vec![180.0, 152.0]);
        let mut backtracking = BackTracking::new(1e-4, 0.5);
        let gradient_tol = 1e-12;

        while max_iter > k {
            trace!("Iterate: {:?}", iterate);
            let eval = f_and_g(&iterate);
            // rough convergence check on the squared norm of the gradient
            if eval.g().dot(eval.g()) < gradient_tol {
                trace!("Gradient norm is lower than tolerance. Convergence!");
                break;
            }
            let direction = -eval.g();
            let t = backtracking.compute_step_len(&iterate, &eval, &direction, &mut f_and_g, max_iter);
            // we perform the update
            iterate += t * direction;
            k += 1;
        }
        println!("Iterate: {:?}", iterate);
        println!("Function eval: {:?}", f_and_g(&iterate));
        assert!((iterate[0] - 0.0).abs() < 1e-6);
        assert!((iterate[1] - 0.0).abs() < 1e-6);
        info!("Test took {} iterations", k);
    }
}