optimization_solvers/line_search/
backtracking.rs1use super::*;
3pub struct BackTracking {
4 c1: Floating, beta: Floating, }
7impl BackTracking {
8 pub fn new(c1: Floating, beta: Floating) -> Self {
9 BackTracking { c1, beta }
10 }
11}
12
13impl SufficientDecreaseCondition for BackTracking {
14 fn c1(&self) -> Floating {
15 self.c1
16 }
17}
18
19impl LineSearch for BackTracking {
20 fn compute_step_len(
21 &mut self,
22 x_k: &DVector<Floating>,
23 eval_x_k: &FuncEvalMultivariate,
24 direction_k: &DVector<Floating>,
25 oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
26 max_iter: usize,
27 ) -> Floating {
28 let mut t = 1.0;
29 let mut i = 0;
30
31 while max_iter > i {
32 let x_kp1 = x_k + t * direction_k;
33
34 let eval_kp1 = oracle(&x_kp1);
35
36 if eval_kp1.f().is_nan() || eval_kp1.f().is_infinite() {
38 trace!(target: "backtracking line search", "Step size too big: next iterate is out of domain. Decreasing step by beta ({:?})", x_kp1);
39 t *= self.beta;
40 continue;
41 }
42
43 if self.sufficient_decrease(eval_x_k.f(), eval_kp1.f(), eval_x_k.g(), &t, direction_k) {
45 trace!(target: "backtracking line search", "Sufficient decrease condition met. Exiting with step size: {:?}", t);
46 return t;
47 }
48
49 t *= self.beta;
51 i += 1;
52 }
53 trace!(target: "backtracking line search", "Max iter reached. Early stopping.");
54 t
55 }
59}
60
#[cfg(test)]
mod backtracking_tests {
    use super::*;

    /// Runs steepest descent with backtracking on the quadratic
    /// `f(x) = 0.5 * (x0^2 + gamma * x1^2)` and checks that the iterate reaches its
    /// unique minimiser, the origin.
    #[test]
    pub fn test_backtracking() {
        std::env::set_var("RUST_LOG", "info");

        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();
        let gamma = 90.0;
        let mut f_and_g = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let f = 0.5 * (x[0].powi(2) + gamma * x[1].powi(2));
            let g = DVector::from(vec![x[0], gamma * x[1]]);
            (f, g).into()
        };
        let max_iter = 1000;
        let mut k = 1;
        let mut iterate = DVector::from(vec![180.0, 152.0]);
        let mut backtracking = BackTracking::new(1e-4, 0.5);
        // Compared against the *squared* gradient norm (g . g) below, so this is
        // equivalent to requiring ||g|| < 1e-6.
        let gradient_tol = 1e-12;

        while max_iter > k {
            trace!("Iterate: {:?}", iterate);
            let eval = f_and_g(&iterate);
            if eval.g().dot(eval.g()) < gradient_tol {
                trace!("Gradient norm is lower than tolerance. Convergence!.");
                break;
            }
            let direction = -eval.g();
            let t = <BackTracking as LineSearch>::compute_step_len(
                &mut backtracking,
                &iterate,
                &eval,
                &direction,
                &mut f_and_g,
                max_iter,
            );
            iterate += t * direction;
            k += 1;
        }
        println!("Iterate: {:?}", iterate);
        println!("Function eval: {:?}", f_and_g(&iterate));
        // The minimiser is the origin, so both coordinates must vanish, not just the first.
        assert!(iterate[0].abs() < 1e-6);
        assert!(iterate[1].abs() < 1e-6);
        info!("Test took {} iterations", k);
    }
}