echidna_optim/line_search.rs

use num_traits::Float;

use crate::convergence::dot;
use crate::objective::Objective;

/// Parameters for backtracking line search with the Armijo condition.
#[derive(Debug, Clone)]
pub struct ArmijoParams<F> {
    /// Sufficient-decrease constant in the Armijo condition.
    pub c: F,
    /// Backtracking factor applied to the step length after each rejection (0 < rho < 1).
    pub rho: F,
    /// Initial trial step length.
    pub alpha_init: F,
    /// Smallest step length tried before the search gives up and returns `None`.
    pub alpha_min: F,
}

impl Default for ArmijoParams<f64> {
    fn default() -> Self {
        ArmijoParams {
            c: 1e-4,
            rho: 0.5,
            alpha_init: 1.0,
            alpha_min: 1e-16,
        }
    }
}

impl Default for ArmijoParams<f32> {
    fn default() -> Self {
        ArmijoParams {
            c: 1e-4,
            rho: 0.5,
            alpha_init: 1.0,
            alpha_min: 1e-8,
        }
    }
}

/// Outcome of a successful line search.
#[derive(Debug)]
pub struct LineSearchResult<F> {
    /// Accepted step length.
    pub alpha: F,
    /// Objective value at the accepted point `x + alpha * d`.
    pub value: F,
    /// Gradient at the accepted point.
    pub gradient: Vec<F>,
    /// Number of objective/gradient evaluations performed during the search.
    pub evals: usize,
}

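/// Backtracking line search enforcing the Armijo sufficient-decrease condition
/// `f(x + alpha * d) <= f(x) + c * alpha * grad_f(x)^T d`, shrinking `alpha` by
/// `rho` until the condition holds. Returns `None` if `d` is not a descent
/// direction or if `alpha` falls below `alpha_min`.
///
/// A minimal usage sketch (hypothetical caller; `obj` is assumed to implement
/// `Objective<f64>`), mirroring the tests below:
///
/// ```ignore
/// let (f_x, grad) = obj.eval_grad(&x);
/// let d: Vec<f64> = grad.iter().map(|&g| -g).collect(); // steepest-descent direction
/// if let Some(res) = backtracking_armijo(&mut obj, &x, &d, f_x, &grad, &ArmijoParams::default()) {
///     // res.alpha is the accepted step; res.value and res.gradient describe x + alpha * d
/// }
/// ```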
pub fn backtracking_armijo<F: Float, O: Objective<F>>(
    obj: &mut O,
    x: &[F],
    d: &[F],
    f_x: F,
    grad_x: &[F],
    params: &ArmijoParams<F>,
) -> Option<LineSearchResult<F>> {
    let n = x.len();
    let dg = dot(grad_x, d);

    // d must be a descent direction: the directional derivative has to be negative.
    if dg >= F::zero() {
        return None;
    }

    let mut alpha = params.alpha_init;
    let mut x_new = vec![F::zero(); n];
    let mut evals = 0;

    loop {
        if alpha < params.alpha_min {
            return None;
        }

        for i in 0..n {
            x_new[i] = x[i] + alpha * d[i];
        }

        let (f_new, g_new) = obj.eval_grad(&x_new);
        evals += 1;

        // Back off if the trial point produced a non-finite value or gradient.
        if !f_new.is_finite() || !g_new.iter().all(|g| g.is_finite()) {
            alpha = alpha * params.rho;
            continue;
        }

        // Armijo sufficient-decrease condition.
        if f_new <= f_x + params.c * alpha * dg {
            return Some(LineSearchResult {
                alpha,
                value: f_new,
                gradient: g_new,
                evals,
            });
        }

        alpha = alpha * params.rho;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Convex test objective: f(x) = 0.5 * (x0^2 + x1^2), with gradient (x0, x1).
    struct Quadratic;

    impl Objective<f64> for Quadratic {
        fn dim(&self) -> usize {
            2
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            let f = 0.5 * (x[0] * x[0] + x[1] * x[1]);
            let g = vec![x[0], x[1]];
            (f, g)
        }
    }

    #[test]
    fn armijo_quadratic_descent() {
        let mut obj = Quadratic;
        let x = vec![2.0, 3.0];
        let (f_x, grad) = obj.eval_grad(&x);
        let d: Vec<f64> = grad.iter().map(|&g| -g).collect();

        let result =
            backtracking_armijo(&mut obj, &x, &d, f_x, &grad, &ArmijoParams::default()).unwrap();

        assert!(result.alpha > 0.0);
        assert!(result.value < f_x, "line search should decrease objective");
    }

    #[test]
    fn armijo_full_step_on_quadratic() {
        let mut obj = Quadratic;
        let x = vec![2.0, 3.0];
        let (f_x, grad) = obj.eval_grad(&x);
        let d: Vec<f64> = grad.iter().map(|&g| -g).collect();

        let result =
            backtracking_armijo(&mut obj, &x, &d, f_x, &grad, &ArmijoParams::default()).unwrap();

        // The initial step alpha = 1.0 already satisfies the Armijo condition on this quadratic.
        assert!(
            (result.alpha - 1.0).abs() < 1e-12,
            "full step should be accepted on quadratic, got alpha={}",
            result.alpha
        );
    }

    #[test]
    fn armijo_non_descent_returns_none() {
        let mut obj = Quadratic;
        let x = vec![2.0, 3.0];
        let (f_x, grad) = obj.eval_grad(&x);
        // The gradient itself is an ascent direction, so the search must reject it.
        let d = grad.clone();

        let result = backtracking_armijo(&mut obj, &x, &d, f_x, &grad, &ArmijoParams::default());
        assert!(result.is_none());
    }
}