1use numra_core::{NumraError, Scalar};
13
/// Reasons a Wolfe line search can fail to produce a step.
///
/// The variants carry no payload, so the type also derives `PartialEq`/`Eq`
/// to let callers and tests compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LineSearchError {
    /// The directional derivative at the starting point was not negative.
    NotDescentDirection,
    /// The zoom bracket became narrower than the collapse tolerance.
    BracketCollapsed,
    /// The iteration budget ran out before an acceptable step was found.
    MaxIterations,
}

impl std::fmt::Display for LineSearchError {
    /// Writes the fixed, human-readable message for the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NotDescentDirection => "search direction is not a descent direction",
            Self::BracketCollapsed => "zoom bracket collapsed",
            Self::MaxIterations => "max line search iterations reached",
        };
        f.write_str(message)
    }
}

/// Allows `LineSearchError` to be used wherever a standard error is expected
/// (for example behind `Box<dyn std::error::Error>`).
impl std::error::Error for LineSearchError {}
36
37impl From<LineSearchError> for NumraError {
38 fn from(e: LineSearchError) -> Self {
39 NumraError::LineSearch(e.to_string())
40 }
41}
42
43#[derive(Debug, Clone)]
45pub struct WolfeOptions<S: Scalar> {
46 pub c1: S,
48 pub c2: S,
50 pub max_step: S,
52 pub max_iter: usize,
54}
55
56impl<S: Scalar> Default for WolfeOptions<S> {
57 fn default() -> Self {
58 Self {
59 c1: S::from_f64(1e-4),
60 c2: S::from_f64(0.9),
61 max_step: S::from_f64(1e20),
62 max_iter: 40,
63 }
64 }
65}
66
67#[derive(Debug, Clone)]
69pub struct LineSearchResult<S: Scalar> {
70 pub step: S,
72 pub f_new: S,
74 pub n_eval: usize,
76}
77
78fn dot<S: Scalar>(a: &[S], b: &[S]) -> S {
80 a.iter()
81 .zip(b.iter())
82 .map(|(&ai, &bi)| ai * bi)
83 .fold(S::ZERO, |acc, x| acc + x)
84}
85
/// Absolute bracket width below which `zoom` stops and reports
/// `LineSearchError::BracketCollapsed`; converted to the working scalar type
/// with `S::from_f64` at the point of use.
const BRACKET_COLLAPSE_TOL: f64 = 1.0e-16;
88
89#[allow(clippy::too_many_arguments)]
93fn zoom<S, F, G>(
94 f: &F,
95 grad: &G,
96 x: &[S],
97 d: &[S],
98 f0: S,
99 dg0: S,
100 mut alpha_lo: S,
101 mut f_lo: S,
102 mut alpha_hi: S,
103 opts: &WolfeOptions<S>,
104 n_eval: &mut usize,
105) -> Result<LineSearchResult<S>, LineSearchError>
106where
107 S: Scalar,
108 F: Fn(&[S]) -> S,
109 G: Fn(&[S], &mut [S]),
110{
111 let n = x.len();
112 let mut x_trial = vec![S::ZERO; n];
113 let mut g_trial = vec![S::ZERO; n];
114
115 for _ in 0..opts.max_iter {
116 if (alpha_hi - alpha_lo).abs() < S::from_f64(BRACKET_COLLAPSE_TOL) {
117 return Err(LineSearchError::BracketCollapsed);
118 }
119
120 let alpha_j = (alpha_lo + alpha_hi) / S::TWO;
121
122 for i in 0..n {
124 x_trial[i] = x[i] + alpha_j * d[i];
125 }
126 let f_j = f(&x_trial);
127 *n_eval += 1;
128
129 if f_j > f0 + opts.c1 * alpha_j * dg0 || f_j >= f_lo {
130 alpha_hi = alpha_j;
132 } else {
133 grad(&x_trial, &mut g_trial);
135 let dg_j = dot(&g_trial, d);
136
137 if dg_j.abs() <= -opts.c2 * dg0 {
138 return Ok(LineSearchResult {
140 step: alpha_j,
141 f_new: f_j,
142 n_eval: *n_eval,
143 });
144 }
145
146 if dg_j * (alpha_hi - alpha_lo) >= S::ZERO {
147 alpha_hi = alpha_lo;
148 }
149
150 alpha_lo = alpha_j;
151 f_lo = f_j;
152 }
153 }
154
155 Err(LineSearchError::MaxIterations)
157}
158
159pub fn wolfe_line_search<S, F, G>(
174 f: F,
175 grad: G,
176 x: &[S],
177 d: &[S],
178 f0: S,
179 g0: &[S],
180 opts: &WolfeOptions<S>,
181) -> Result<LineSearchResult<S>, LineSearchError>
182where
183 S: Scalar,
184 F: Fn(&[S]) -> S,
185 G: Fn(&[S], &mut [S]),
186{
187 let dg0 = dot(g0, d);
188 if dg0 >= S::ZERO {
189 return Err(LineSearchError::NotDescentDirection);
190 }
191
192 let n = x.len();
193 let mut x_trial = vec![S::ZERO; n];
194 let mut g_trial = vec![S::ZERO; n];
195
196 let mut alpha_prev = S::ZERO;
197 let mut f_prev = f0;
198 let mut alpha = S::ONE;
199 let mut n_eval: usize = 0;
200
201 for i in 1..=opts.max_iter {
202 if alpha > opts.max_step {
204 alpha = opts.max_step;
205 }
206
207 for j in 0..n {
209 x_trial[j] = x[j] + alpha * d[j];
210 }
211 let f_alpha = f(&x_trial);
212 n_eval += 1;
213
214 if f_alpha > f0 + opts.c1 * alpha * dg0 || (i > 1 && f_alpha >= f_prev) {
216 return zoom(
217 &f,
218 &grad,
219 x,
220 d,
221 f0,
222 dg0,
223 alpha_prev,
224 f_prev,
225 alpha,
226 opts,
227 &mut n_eval,
228 );
229 }
230
231 grad(&x_trial, &mut g_trial);
233 let dg_alpha = dot(&g_trial, d);
234
235 if dg_alpha.abs() <= -opts.c2 * dg0 {
236 return Ok(LineSearchResult {
237 step: alpha,
238 f_new: f_alpha,
239 n_eval,
240 });
241 }
242
243 if dg_alpha >= S::ZERO {
245 return zoom(
246 &f,
247 &grad,
248 x,
249 d,
250 f0,
251 dg0,
252 alpha,
253 f_alpha,
254 alpha_prev,
255 opts,
256 &mut n_eval,
257 );
258 }
259
260 alpha_prev = alpha;
262 f_prev = f_alpha;
263 alpha *= S::TWO;
264 }
265
266 Err(LineSearchError::MaxIterations)
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `res` satisfies both strong-Wolfe conditions for the
    /// search that started at `x` along `d`.
    ///
    /// The checks use the same expressions as the implementation, so they
    /// hold exactly for any `Ok` result produced by `wolfe_line_search`.
    fn assert_strong_wolfe<S, G>(
        grad: G,
        x: &[S],
        d: &[S],
        f0: S,
        g0: &[S],
        opts: &WolfeOptions<S>,
        res: &LineSearchResult<S>,
    ) where
        S: Scalar,
        G: Fn(&[S], &mut [S]),
    {
        let dg0 = dot(g0, d);
        assert!(
            res.f_new <= f0 + opts.c1 * res.step * dg0,
            "sufficient-decrease condition violated"
        );

        let x_new: Vec<S> = x
            .iter()
            .zip(d.iter())
            .map(|(&xi, &di)| xi + res.step * di)
            .collect();
        let mut g_new = vec![S::ZERO; x.len()];
        grad(&x_new, &mut g_new);
        assert!(
            dot(&g_new, d).abs() <= -opts.c2 * dg0,
            "strong curvature condition violated"
        );
    }

    #[test]
    fn test_wolfe_quadratic() {
        // f(x) = x^2, searched from x = 2 along the negative axis.
        let f = |x: &[f64]| x[0] * x[0];
        let grad = |x: &[f64], g: &mut [f64]| {
            g[0] = 2.0 * x[0];
        };

        let x = [2.0];
        let d = [-1.0];
        let f0 = f(&x);
        let g0 = [4.0];

        let opts = WolfeOptions::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts).unwrap();

        assert!(res.step > 0.0, "step must be positive");
        assert!(
            res.f_new < f0,
            "function must decrease: f_new={} vs f0={}",
            res.f_new,
            f0
        );
        assert_strong_wolfe(grad, &x, &d, f0, &g0, &opts, &res);
    }

    #[test]
    fn test_wolfe_rosenbrock() {
        // Rosenbrock function, searched along steepest descent from (-1, 1).
        let f = |x: &[f64]| {
            let a = 1.0 - x[0];
            let b = x[1] - x[0] * x[0];
            a * a + 100.0 * b * b
        };
        let grad = |x: &[f64], g: &mut [f64]| {
            g[0] = -2.0 * (1.0 - x[0]) - 400.0 * x[0] * (x[1] - x[0] * x[0]);
            g[1] = 200.0 * (x[1] - x[0] * x[0]);
        };

        let x = [-1.0, 1.0];
        let f0 = f(&x);
        let mut g0 = [0.0; 2];
        grad(&x, &mut g0);

        let d = [-g0[0], -g0[1]];

        let opts = WolfeOptions::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts).unwrap();

        assert!(res.step > 0.0, "step must be positive");
        assert!(
            res.f_new < f0,
            "function must decrease: f_new={} vs f0={}",
            res.f_new,
            f0
        );
        assert_strong_wolfe(grad, &x, &d, f0, &g0, &opts, &res);
    }

    #[test]
    fn test_wolfe_not_descent() {
        let f = |x: &[f64]| x[0] * x[0];
        let grad = |x: &[f64], g: &mut [f64]| {
            g[0] = 2.0 * x[0];
        };

        let x = [2.0];
        // Points uphill: g0 . d = 4 > 0.
        let d = [1.0];
        let f0 = f(&x);
        let g0 = [4.0];

        let opts = WolfeOptions::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts);

        assert!(res.is_err(), "must reject non-descent direction");
        assert!(
            matches!(res.unwrap_err(), LineSearchError::NotDescentDirection),
            "error should be NotDescentDirection"
        );
    }

    #[test]
    fn test_wolfe_f32() {
        // Same quadratic as above, exercised with single-precision scalars.
        let f = |x: &[f32]| x[0] * x[0];
        let grad = |x: &[f32], g: &mut [f32]| {
            g[0] = 2.0 * x[0];
        };

        let x = [2.0f32];
        let d = [-1.0f32];
        let f0 = f(&x);
        let g0 = [4.0f32];

        let opts = WolfeOptions::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts).unwrap();

        assert!(res.step > 0.0, "step must be positive");
        assert!(res.f_new < f0, "function must decrease");
        assert_strong_wolfe(grad, &x, &d, f0, &g0, &opts, &res);
    }
}