1use numra_core::Scalar;
13
/// Reasons a strong-Wolfe line search can fail.
///
/// The type is comparable (`PartialEq`/`Eq`) so callers and tests can check
/// for a specific failure with `==` or `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LineSearchError {
    /// The directional derivative `g0 · d` at the starting point was not
    /// negative, so the supplied direction cannot decrease the objective.
    NotDescentDirection,
    /// The bracket handed to the zoom phase became narrower than the
    /// collapse tolerance before an acceptable step was found.
    BracketCollapsed,
    /// The iteration budget (`WolfeOptions::max_iter`) was used up without
    /// finding an acceptable step.
    MaxIterations,
}
24
25impl std::fmt::Display for LineSearchError {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 match self {
28 Self::NotDescentDirection => write!(f, "search direction is not a descent direction"),
29 Self::BracketCollapsed => write!(f, "zoom bracket collapsed"),
30 Self::MaxIterations => write!(f, "max line search iterations reached"),
31 }
32 }
33}
34
// Empty impl: every variant is a unit variant with no wrapped source error,
// so the trait's default method implementations are used unchanged.
impl std::error::Error for LineSearchError {}
36
/// Tunable parameters for [`wolfe_line_search`].
#[derive(Debug, Clone)]
pub struct WolfeOptions<S: Scalar> {
    /// Sufficient-decrease coefficient: a trial step `alpha` fails the test
    /// when `f(x + alpha * d) > f0 + c1 * alpha * (g0 · d)`.
    pub c1: S,
    /// Curvature coefficient: a trial step is accepted once
    /// `|g(x + alpha * d) · d| <= -c2 * (g0 · d)`.
    pub c2: S,
    /// Upper bound on the trial step length; larger trial steps are clamped
    /// to this value before being evaluated.
    pub max_step: S,
    /// Iteration cap, applied separately to the bracketing loop and to the
    /// zoom loop.
    pub max_iter: usize,
}
49
50impl<S: Scalar> Default for WolfeOptions<S> {
51 fn default() -> Self {
52 Self {
53 c1: S::from_f64(1e-4),
54 c2: S::from_f64(0.9),
55 max_step: S::from_f64(1e20),
56 max_iter: 40,
57 }
58 }
59}
60
/// Outcome of a successful line search.
#[derive(Debug, Clone)]
pub struct LineSearchResult<S: Scalar> {
    /// Accepted step length `alpha` along the search direction.
    pub step: S,
    /// Objective value at the accepted point `x + step * d`.
    pub f_new: S,
    /// Number of objective-function evaluations performed; gradient
    /// evaluations are not counted.
    pub n_eval: usize,
}
71
72fn dot<S: Scalar>(a: &[S], b: &[S]) -> S {
74 a.iter()
75 .zip(b.iter())
76 .map(|(&ai, &bi)| ai * bi)
77 .fold(S::ZERO, |acc, x| acc + x)
78}
79
/// Absolute width below which the zoom bracket `|alpha_hi - alpha_lo|` is
/// treated as collapsed and the search gives up with
/// `LineSearchError::BracketCollapsed`.
const BRACKET_COLLAPSE_TOL: f64 = 1e-16;
82
83#[allow(clippy::too_many_arguments)]
87fn zoom<S, F, G>(
88 f: &F,
89 grad: &G,
90 x: &[S],
91 d: &[S],
92 f0: S,
93 dg0: S,
94 mut alpha_lo: S,
95 mut f_lo: S,
96 mut alpha_hi: S,
97 opts: &WolfeOptions<S>,
98 n_eval: &mut usize,
99) -> Result<LineSearchResult<S>, LineSearchError>
100where
101 S: Scalar,
102 F: Fn(&[S]) -> S,
103 G: Fn(&[S], &mut [S]),
104{
105 let n = x.len();
106 let mut x_trial = vec![S::ZERO; n];
107 let mut g_trial = vec![S::ZERO; n];
108
109 for _ in 0..opts.max_iter {
110 if (alpha_hi - alpha_lo).abs() < S::from_f64(BRACKET_COLLAPSE_TOL) {
111 return Err(LineSearchError::BracketCollapsed);
112 }
113
114 let alpha_j = (alpha_lo + alpha_hi) / S::TWO;
115
116 for i in 0..n {
118 x_trial[i] = x[i] + alpha_j * d[i];
119 }
120 let f_j = f(&x_trial);
121 *n_eval += 1;
122
123 if f_j > f0 + opts.c1 * alpha_j * dg0 || f_j >= f_lo {
124 alpha_hi = alpha_j;
126 } else {
127 grad(&x_trial, &mut g_trial);
129 let dg_j = dot(&g_trial, d);
130
131 if dg_j.abs() <= -opts.c2 * dg0 {
132 return Ok(LineSearchResult {
134 step: alpha_j,
135 f_new: f_j,
136 n_eval: *n_eval,
137 });
138 }
139
140 if dg_j * (alpha_hi - alpha_lo) >= S::ZERO {
141 alpha_hi = alpha_lo;
142 }
143
144 alpha_lo = alpha_j;
145 f_lo = f_j;
146 }
147 }
148
149 Err(LineSearchError::MaxIterations)
151}
152
153pub fn wolfe_line_search<S, F, G>(
168 f: F,
169 grad: G,
170 x: &[S],
171 d: &[S],
172 f0: S,
173 g0: &[S],
174 opts: &WolfeOptions<S>,
175) -> Result<LineSearchResult<S>, LineSearchError>
176where
177 S: Scalar,
178 F: Fn(&[S]) -> S,
179 G: Fn(&[S], &mut [S]),
180{
181 let dg0 = dot(g0, d);
182 if dg0 >= S::ZERO {
183 return Err(LineSearchError::NotDescentDirection);
184 }
185
186 let n = x.len();
187 let mut x_trial = vec![S::ZERO; n];
188 let mut g_trial = vec![S::ZERO; n];
189
190 let mut alpha_prev = S::ZERO;
191 let mut f_prev = f0;
192 let mut alpha = S::ONE;
193 let mut n_eval: usize = 0;
194
195 for i in 1..=opts.max_iter {
196 if alpha > opts.max_step {
198 alpha = opts.max_step;
199 }
200
201 for j in 0..n {
203 x_trial[j] = x[j] + alpha * d[j];
204 }
205 let f_alpha = f(&x_trial);
206 n_eval += 1;
207
208 if f_alpha > f0 + opts.c1 * alpha * dg0 || (i > 1 && f_alpha >= f_prev) {
210 return zoom(
211 &f,
212 &grad,
213 x,
214 d,
215 f0,
216 dg0,
217 alpha_prev,
218 f_prev,
219 alpha,
220 opts,
221 &mut n_eval,
222 );
223 }
224
225 grad(&x_trial, &mut g_trial);
227 let dg_alpha = dot(&g_trial, d);
228
229 if dg_alpha.abs() <= -opts.c2 * dg0 {
230 return Ok(LineSearchResult {
231 step: alpha,
232 f_new: f_alpha,
233 n_eval,
234 });
235 }
236
237 if dg_alpha >= S::ZERO {
239 return zoom(
240 &f,
241 &grad,
242 x,
243 d,
244 f0,
245 dg0,
246 alpha,
247 f_alpha,
248 alpha_prev,
249 opts,
250 &mut n_eval,
251 );
252 }
253
254 alpha_prev = alpha;
256 f_prev = f_alpha;
257 alpha *= S::TWO;
258 }
259
260 Err(LineSearchError::MaxIterations)
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared checks for the `f64` success cases: the accepted step is
    /// positive and the objective value went down.
    fn assert_progress(result: &LineSearchResult<f64>, f0: f64) {
        assert!(result.step > 0.0, "step must be positive");
        assert!(
            result.f_new < f0,
            "function must decrease: f_new={} vs f0={}",
            result.f_new,
            f0
        );
    }

    /// `f(x) = x^2` from `x = 2` along the descent direction `d = -1`.
    #[test]
    fn test_wolfe_quadratic() {
        let objective = |p: &[f64]| p[0] * p[0];
        let gradient = |p: &[f64], out: &mut [f64]| {
            out[0] = 2.0 * p[0];
        };

        let start = [2.0];
        let direction = [-1.0];
        let f0 = objective(&start);
        // Gradient of x^2 at x = 2.
        let g0 = [4.0];

        let opts = WolfeOptions::default();
        let result =
            wolfe_line_search(objective, gradient, &start, &direction, f0, &g0, &opts).unwrap();

        assert_progress(&result, f0);
    }

    /// Rosenbrock function from `(-1, 1)` along steepest descent.
    #[test]
    fn test_wolfe_rosenbrock() {
        let objective = |p: &[f64]| {
            let a = 1.0 - p[0];
            let b = p[1] - p[0] * p[0];
            a * a + 100.0 * b * b
        };
        let gradient = |p: &[f64], out: &mut [f64]| {
            out[0] = -2.0 * (1.0 - p[0]) - 400.0 * p[0] * (p[1] - p[0] * p[0]);
            out[1] = 200.0 * (p[1] - p[0] * p[0]);
        };

        let start = [-1.0, 1.0];
        let f0 = objective(&start);
        let mut g0 = [0.0; 2];
        gradient(&start, &mut g0);

        // Steepest-descent direction: the negated gradient.
        let direction = [-g0[0], -g0[1]];

        let opts = WolfeOptions::default();
        let result =
            wolfe_line_search(objective, gradient, &start, &direction, f0, &g0, &opts).unwrap();

        assert_progress(&result, f0);
    }

    /// An uphill direction must be rejected before any step is tried.
    #[test]
    fn test_wolfe_not_descent() {
        let objective = |p: &[f64]| p[0] * p[0];
        let gradient = |p: &[f64], out: &mut [f64]| {
            out[0] = 2.0 * p[0];
        };

        let start = [2.0];
        // Points uphill: g0 · d = 4 > 0.
        let direction = [1.0];
        let f0 = objective(&start);
        let g0 = [4.0];

        let opts = WolfeOptions::default();
        let outcome = wolfe_line_search(objective, gradient, &start, &direction, f0, &g0, &opts);

        assert!(outcome.is_err(), "must reject non-descent direction");
        assert!(
            matches!(outcome.unwrap_err(), LineSearchError::NotDescentDirection),
            "error should be NotDescentDirection"
        );
    }

    /// Same quadratic as above, but in single precision.
    #[test]
    fn test_wolfe_f32() {
        let objective = |p: &[f32]| p[0] * p[0];
        let gradient = |p: &[f32], out: &mut [f32]| {
            out[0] = 2.0 * p[0];
        };

        let start = [2.0f32];
        let direction = [-1.0f32];
        let f0 = objective(&start);
        let g0 = [4.0f32];

        let opts = WolfeOptions::default();
        let result =
            wolfe_line_search(objective, gradient, &start, &direction, f0, &g0, &opts).unwrap();

        assert!(result.step > 0.0, "step must be positive");
        assert!(result.f_new < f0, "function must decrease");
    }
}