use num_traits::Float;
2
3use crate::convergence::dot;
4use crate::objective::Objective;
5
/// Parameters controlling the backtracking Armijo line search.
#[derive(Debug, Clone)]
pub struct ArmijoParams<F> {
    /// Sufficient-decrease constant `c` in the Armijo condition
    /// `f(x + alpha*d) <= f(x) + c * alpha * (grad . d)`; typically small (e.g. 1e-4).
    pub c: F,
    /// Backtracking contraction factor; each rejected trial multiplies alpha by this.
    pub rho: F,
    /// Initial trial step length.
    pub alpha_init: F,
    /// Lower bound on alpha: the search gives up (returns `None`) once alpha drops below this.
    pub alpha_min: F,
}
18
19impl Default for ArmijoParams<f64> {
20 fn default() -> Self {
21 ArmijoParams {
22 c: 1e-4,
23 rho: 0.5,
24 alpha_init: 1.0,
25 alpha_min: 1e-16,
26 }
27 }
28}
29
30impl Default for ArmijoParams<f32> {
31 fn default() -> Self {
32 ArmijoParams {
33 c: 1e-4,
34 rho: 0.5,
35 alpha_init: 1.0,
36 alpha_min: 1e-8,
37 }
38 }
39}
40
/// Outcome of a successful line search.
#[derive(Debug)]
pub struct LineSearchResult<F> {
    /// Accepted step length.
    pub alpha: F,
    /// Objective value at the accepted point `x + alpha * d`.
    pub value: F,
    /// Gradient at the accepted point.
    pub gradient: Vec<F>,
    /// Number of objective (value + gradient) evaluations performed.
    pub evals: usize,
}
53
54pub fn backtracking_armijo<F: Float, O: Objective<F>>(
60 obj: &mut O,
61 x: &[F],
62 d: &[F],
63 f_x: F,
64 grad_x: &[F],
65 params: &ArmijoParams<F>,
66) -> Option<LineSearchResult<F>> {
67 let n = x.len();
68 let dg = dot(grad_x, d);
69
70 if dg >= F::zero() {
72 return None;
73 }
74
75 let mut alpha = params.alpha_init;
76 let mut x_new = vec![F::zero(); n];
77 let mut evals = 0;
78
79 loop {
80 if alpha < params.alpha_min {
81 return None;
82 }
83
84 for i in 0..n {
85 x_new[i] = x[i] + alpha * d[i];
86 }
87
88 let (f_new, g_new) = obj.eval_grad(&x_new);
89 evals += 1;
90
91 if f_new <= f_x + params.c * alpha * dg {
93 return Some(LineSearchResult {
94 alpha,
95 value: f_new,
96 gradient: g_new,
97 evals,
98 });
99 }
100
101 alpha = alpha * params.rho;
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Convex test objective f(x) = 0.5 * ||x||^2 in two dimensions,
    /// whose gradient is simply x.
    struct Quadratic;

    impl Objective<f64> for Quadratic {
        fn dim(&self) -> usize {
            2
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            (0.5 * (x[0] * x[0] + x[1] * x[1]), vec![x[0], x[1]])
        }
    }

    /// Evaluates the quadratic at the shared test point, returning
    /// `(x, f(x), grad f(x))`.
    fn setup() -> (Vec<f64>, f64, Vec<f64>) {
        let x = vec![2.0, 3.0];
        let (f_x, grad) = Quadratic.eval_grad(&x);
        (x, f_x, grad)
    }

    #[test]
    fn armijo_quadratic_descent() {
        let (x, f_x, grad) = setup();
        let d: Vec<f64> = grad.iter().map(|&g| -g).collect();
        let mut obj = Quadratic;

        let result = backtracking_armijo(&mut obj, &x, &d, f_x, &grad, &ArmijoParams::default())
            .expect("steepest descent is a descent direction");

        assert!(result.alpha > 0.0);
        assert!(result.value < f_x, "line search should decrease objective");
    }

    #[test]
    fn armijo_full_step_on_quadratic() {
        let (x, f_x, grad) = setup();
        let d: Vec<f64> = grad.iter().map(|&g| -g).collect();
        let mut obj = Quadratic;

        let result = backtracking_armijo(&mut obj, &x, &d, f_x, &grad, &ArmijoParams::default())
            .expect("steepest descent is a descent direction");

        // On this quadratic the unit step reaches the exact minimizer, so the
        // very first trial (alpha_init = 1) must satisfy the Armijo condition.
        assert!(
            (result.alpha - 1.0).abs() < 1e-12,
            "full step should be accepted on quadratic, got alpha={}",
            result.alpha
        );
    }

    #[test]
    fn armijo_non_descent_returns_none() {
        let (x, f_x, grad) = setup();
        // The gradient itself is an ascent direction.
        let d = grad.clone();
        let mut obj = Quadratic;

        let result = backtracking_armijo(&mut obj, &x, &d, f_x, &grad, &ArmijoParams::default());
        assert!(result.is_none());
    }
}