// gosh_linesearch/lib.rs
1// [[file:../linesearch.note::*header][header:1]]
2//! Line search, also called one-dimensional search, refers to an optimization
3//! procedure for univariable functions.
4//!
5//! # Available algorithms
6//!
7//! * MoreThuente
8//! * BackTracking
9//! * BackTrackingArmijo
10//! * BackTrackingWolfe
11//! * BackTrackingStrongWolfe
12//!
13//! # References
14//!
15//! * Sun, W.; Yuan, Y. Optimization Theory and Methods: Nonlinear Programming, 1st
16//! ed.; Springer, 2006.
17//! * Nocedal, J.; Wright, S. Numerical Optimization; Springer Science & Business
18//! Media, 2006.
19//!
20//! # Examples
21//!
22//! ```ignore
23//! use line::linesearch;
24//!
25//! let mut step = 1.0;
26//! let count = linesearch()
27//! .with_initial_step(1.5) // the default is 1.0
28//! .with_algorithm("BackTracking") // the default is MoreThuente
29//! .find(5, |a: f64, out: &mut Output| {
30//! // restore position
31//! x.veccpy(&x_k);
32//! // update position with step along d
33//! x.vecadd(&d_k, a);
34//! // update value and gradient
35//! out.fx = f(x, &mut gx)?;
36//! // update line search gradient
37//! out.gx = gx.vecdot(d);
38//! // update optimal step size
39//! step = a;
40//! // return any user defined data
41//! Ok(())
42//! })?;
43//!
44//! let ls = linesearch()
45//! .with_max_iterations(5) // the default is 10
46//! .with_initial_step(1.5) // the default is 1.0
47//! .with_algorithm("BackTracking") // the default is MoreThuente
48//! .find_iter(|a: f64, out: &mut Output| {
49//! // restore position
50//! x.veccpy(&x_k);
51//! // update position with step along d
52//! x.vecadd(&d_k, a);
53//! // update value and gradient
54//! out.fx = f(x, &mut gx)?;
55//! // update line search gradient
56//! out.gx = gx.vecdot(d);
57//! // update optimal step size
58//! step = a;
59//! // return any user defined data
60//! Ok(())
61//! })?;
62//!
63//! for success in ls {
64//! if success {
65//! //
66//! } else {
67//! //
68//! }
69//! }
70//!```
71
72use crate::common::*;
73// header:1 ends here
74
75// [[file:../linesearch.note::*mods][mods:1]]
76mod backtracking;
77mod morethuente;
78// mods:1 ends here
79
80// [[file:../linesearch.note::*common][common:1]]
/// Crate-internal prelude: re-exports shared imports (from `gut`) so every
/// module can simply `use crate::common::*;`.
pub(crate) mod common {
    pub use gut::prelude::*;
}
84// common:1 ends here
85
86// [[file:../linesearch.note::*algorithm][algorithm:1]]
/// Line search algorithms.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum LineSearchAlgorithm {
    /// MoreThuente method proposed by More and Thuente. This is the default for
    /// regular LBFGS.
    MoreThuente,

    ///
    /// BackTracking method with the Armijo condition.
    ///
    /// The backtracking method finds the step length such that it satisfies
    /// the sufficient decrease (Armijo) condition,
    /// - f(x + a * d) <= f(x) + ftol * a * g(x)^T d,
    ///
    /// where x is the current point, d is the current search direction, and
    /// a is the step length.
    ///
    BackTrackingArmijo,

    /// BackTracking method with strong Wolfe condition.
    ///
    /// The backtracking method finds the step length such that it satisfies
    /// both the Armijo condition (BackTrackingArmijo)
    /// and the following condition,
    /// - |g(x + a * d)^T d| <= gtol * |g(x)^T d|,
    ///
    /// where x is the current point, d is the current search direction, and
    /// a is the step length.
    ///
    BackTrackingStrongWolfe,

    ///
    /// BackTracking method with regular Wolfe condition.
    ///
    /// The backtracking method finds the step length such that it satisfies
    /// both the Armijo condition (BackTrackingArmijo)
    /// and the curvature condition,
    /// - g(x + a * d)^T d >= gtol * g(x)^T d,
    ///
    /// where x is the current point, d is the current search direction, and a
    /// is the step length.
    ///
    BackTrackingWolfe,
}
131
132impl Default for LineSearchAlgorithm {
133 /// The default algorithm (MoreThuente method).
134 fn default() -> Self {
135 LineSearchAlgorithm::MoreThuente
136 }
137}
138// algorithm:1 ends here
139
140// [[file:../linesearch.note::*base][base:1]]
/// Stopping conditions tested by the backtracking line search.
#[derive(Clone, Debug, PartialEq)]
pub enum LineSearchCondition {
    /// The sufficient decrease condition.
    Armijo,
    /// The sufficient decrease condition plus the (regular) curvature condition.
    Wolfe,
    /// The sufficient decrease condition plus the strong curvature condition.
    StrongWolfe,
}
148
/// A trait for performing line search
pub(crate) trait LineSearchFind<E>
where
    E: Fn(f64) -> (f64, f64),
{
    /// Given initial step size and phi function, returns a satisfactory step
    /// size.
    ///
    /// `step` is a positive scalar representing the step size along search
    /// direction. phi is a univariable function of `step` for evaluating the
    /// value and the gradient projected onto search direction.
    ///
    /// On success, returns a count (presumably the number of function
    /// evaluations — confirm against the concrete implementations).
    fn find(&mut self, step: &mut f64, phi: E) -> Result<usize>;
}
162// base:1 ends here
163
164// [[file:../linesearch.note::*builder][builder:1]]
165/// A unified interface to line search methods.
166///
167/// # Examples
168///
169/// ```ignore
170/// use line::linesearch;
171///
172/// let mut step = 1.0;
173/// let count = linesearch()
174/// .with_initial_step(1.5) // the default is 1.0
175/// .with_algorithm("BackTracking") // the default is MoreThuente
176/// .find(5, |a: f64, out: &mut Output| {
177/// // restore position
178/// x.veccpy(&x_k);
179/// // update position with step along d
180/// x.vecadd(&d_k, a);
181/// // update value and gradient
182/// out.fx = f(x, &mut gx)?;
183/// // update line search gradient
184/// out.gx = gx.vecdot(d);
185/// // update optimal step size
186/// step = a;
187/// // return any user defined data
188/// Ok(())
189/// })?;
190///
191/// let ls = linesearch()
192/// .with_max_iterations(5) // the default is 10
193/// .with_initial_step(1.5) // the default is 1.0
194/// .with_algorithm("BackTracking") // the default is MoreThuente
195/// .find_iter(|a: f64, out: &mut Output| {
196/// // restore position
197/// x.veccpy(&x_k);
198/// // update position with step along d
199/// x.vecadd(&d_k, a);
200/// // update value and gradient
201/// out.fx = f(x, &mut gx)?;
202/// // update line search gradient
203/// out.gx = gx.vecdot(d);
204/// // update optimal step size
205/// step = a;
206/// // return any user defined data
207/// Ok(())
208/// })?;
209///
210/// for success in ls {
211/// if success {
212/// //
213/// } else {
214/// //
215/// }
216/// }
217///```
218pub fn linesearch() -> LineSearch {
219 LineSearch::default()
220}
221
/// Builder for configuring and running a line search.
pub struct LineSearch {
    /// The algorithm used to locate a satisfactory step size.
    algorithm: LineSearchAlgorithm,
    /// The step size the search starts from (default: 1.0).
    initial_step: f64,
}
226
227impl Default for LineSearch {
228 fn default() -> Self {
229 LineSearch {
230 algorithm: LineSearchAlgorithm::default(),
231 initial_step: 1.0,
232 }
233 }
234}
235
236impl LineSearch {
237 /// Set initial step size when performing line search. The default is 1.0.
238 pub fn with_initial_step(mut self, stp: f64) -> Self {
239 assert!(
240 stp.is_sign_positive(),
241 "line search initial step should be a positive float!"
242 );
243
244 self.initial_step = stp;
245 self
246 }
247
248 /// Set line search algorithm. The default is MoreThuente algorithm.
249 pub fn with_algorithm(mut self, s: &str) -> Self {
250 self.algorithm = match s {
251 "MoreThuente" => LineSearchAlgorithm::MoreThuente,
252 "BackTracking" | "BackTrackingWolfe" => LineSearchAlgorithm::BackTrackingWolfe,
253 "BackTrackingStrongWolfe" => LineSearchAlgorithm::BackTrackingWolfe,
254 "BackTrackingArmijo" => LineSearchAlgorithm::BackTrackingArmijo,
255 _ => unimplemented!(),
256 };
257
258 self
259 }
260}
261// builder:1 ends here
262
263// [[file:../linesearch.note::57376052][57376052]]
264use std::fmt::Debug;
265
/// Core trait implementing the line search computation, advanced one step at
/// a time.
pub trait LineSearchFindNext {
    /// Performs a single line search step. The returned boolean indicates
    /// whether the current position satisfies the line search conditions.
    ///
    /// `stp` is updated in place with the trial step size. `phi` evaluates the
    /// function value and the gradient projected onto the search direction at
    /// a given step size.
    fn find_next<E>(&self, stp: &mut f64, phi: E) -> Result<bool>
    where
        E: FnMut(f64) -> Result<(f64, f64)>;
}
273
/// Input is the trial step size used when performing line search.
pub type Input = f64;
276
/// Function value and projected gradient evaluated at a trial step size.
pub struct Output {
    /// The value of function at step `x`
    pub fx: f64,
    /// The gradient of function at step `x`
    pub gx: f64,
}

/// Defaults to NaN so that an `Output` left untouched by the user callback is
/// detected instead of silently propagating bogus values.
impl Default for Output {
    fn default() -> Self {
        // Use the `f64::NAN` associated constant; the `std::f64::NAN` module
        // constant is soft-deprecated.
        Self {
            fx: f64::NAN,
            gx: f64::NAN,
        }
    }
}
292
/// Progress report yielded on each step of the line search iterator.
///
/// `T` is user defined data returned from the evaluation closure.
#[derive(Clone, Debug)]
pub struct Progress<T>
where
    T: Debug + Clone,
{
    /// current step
    pub step: f64,
    /// Indicates line search done or not
    pub done: bool,
    /// The data returned from user defined closure for function evaluation
    pub data: T,
}
305
/// Adapter around the user callback for function evaluation.
///
/// T is user defined data returned by the callback.
pub struct LineSearchEval<E, T>
where
    E: FnMut(Input, &mut Output) -> Result<T>,
    T: Debug + Clone,
{
    /// User callback: fills in value/gradient for a trial step size.
    eval_fn: E,
    /// Data returned by the most recent callback invocation, if any.
    user_data: Option<T>,
}
315
316impl<E, T> LineSearchEval<E, T>
317where
318 E: FnMut(Input, &mut Output) -> Result<T>,
319 T: Debug + Clone,
320{
321 pub fn new(f: E) -> Self {
322 Self {
323 eval_fn: f,
324 user_data: None,
325 }
326 }
327
328 /// 调用回调函数, 同时保留用户自定义进度数据
329 pub fn call(&mut self, x: f64) -> Result<(f64, f64)> {
330 let mut out = Output::default();
331 self.user_data = (self.eval_fn)(x, &mut out)?.into();
332 Ok((out.fx, out.gx))
333 }
334}
335
/// Iterator driving a line search algorithm `A` one step at a time.
pub struct LineSearchIter<A, E, T>
where
    A: LineSearchFindNext,
    E: FnMut(Input, &mut Output) -> Result<T>,
    T: Debug + Clone,
{
    /// Current trial step size.
    step: f64,
    /// Wrapped user evaluation callback.
    eval: LineSearchEval<E, T>,
    /// The algorithm, stored in an `Option` so `next()` can temporarily move
    /// it out while mutably borrowing `eval`.
    algo: Option<A>,
}
346
347impl<A, E, T> Iterator for LineSearchIter<A, E, T>
348where
349 A: LineSearchFindNext,
350 E: FnMut(Input, &mut Output) -> Result<T>,
351 T: Debug + Clone,
352{
353 type Item = Progress<T>;
354
355 /// Iterate over current line search step along searching direction. Return
356 /// user defined progress data.
357 fn next(&mut self) -> Option<Self::Item> {
358 let mut step = self.step;
359 let mut algo = self.algo.take();
360 let done = algo
361 .as_mut()
362 .unwrap()
363 .find_next(&mut step, |stp| {
364 let out = self.eval.call(stp)?;
365 Ok(out)
366 })
367 .ok()?;
368 self.step = step;
369 self.algo = algo;
370
371 Progress {
372 // 关键数据1: 当前步长
373 step,
374 // 关键数据2: 完成与否
375 done,
376 // 用户数据: 用户定义的重要数据
377 data: self.eval.user_data.take().expect("no user data"),
378 }
379 .into()
380 }
381}
382// 57376052 ends here
383
384// [[file:../linesearch.note::*api][api:1]]
385impl LineSearch {
386 /// Perform line search with a callback function `phi` to evaluate function
387 /// value and gradient projected onto search direction. This is the iterator
388 /// version of `find` method.
389 fn find_iter<E, T>(&self, phi: E) -> impl Iterator<Item = Progress<T>>
390 where
391 E: FnMut(Input, &mut Output) -> Result<T>,
392 T: Debug + Clone,
393 {
394 use self::LineSearchAlgorithm as lsa;
395
396 let mut bt_iter = None;
397 let mut mt_iter = None;
398 match self.algorithm {
399 lsa::MoreThuente => {
400 let iter = LineSearchIter {
401 step: self.initial_step,
402 eval: crate::LineSearchEval::new(phi),
403 algo: crate::morethuente::MoreThuente::default().into(),
404 };
405 mt_iter = iter.into();
406 }
407 other => {
408 let condition = match other {
409 lsa::BackTrackingWolfe => LineSearchCondition::Wolfe,
410 lsa::BackTrackingStrongWolfe => LineSearchCondition::StrongWolfe,
411 lsa::BackTrackingArmijo => LineSearchCondition::Armijo,
412 _ => todo!(),
413 };
414 let mut ls = LineSearchIter {
415 step: self.initial_step,
416 eval: crate::LineSearchEval::new(phi),
417 algo: crate::backtracking::BackTracking::default()
418 .set_condition(condition)
419 .into(),
420 };
421 bt_iter = ls.into();
422 }
423 }
424 bt_iter.into_iter().flatten().chain(mt_iter.into_iter().flatten())
425 }
426
427 /// Perform line search with a callback function `phi` to evaluate function
428 /// value and gradient projected onto search direction within `m` iterations
429 ///
430 /// # Return
431 ///
432 /// Return success or not within line search iteration.
433 ///
434 pub fn find<E, T>(&self, m: usize, phi: E) -> bool
435 where
436 E: FnMut(Input, &mut Output) -> Result<T>,
437 T: Debug + Clone,
438 {
439 for x in self.find_iter(phi).take(m) {
440 if x.done {
441 return true;
442 }
443 }
444 warn!("ls: optimal step not found!");
445 false
446 }
447}
448// api:1 ends here
449
450// [[file:../linesearch.note::*test][test:1]]
#[test]
fn test_ls_iter() -> Result<()> {
    // Drive a backtracking line search with a stub evaluator that reports a
    // constant value and projected gradient; this only checks that the
    // iterator can be constructed and polled without panicking.
    // FIX: removed the unused `step` variable and underscored the unused
    // closure parameter to silence `unused_variables` warnings.
    let ls = linesearch()
        .with_initial_step(1.5) // the default is 1.0
        .with_algorithm("BackTracking") // the default is MoreThuente
        .find_iter(|_a: f64, out: &mut Output| {
            // stub: constant function value
            out.fx = 0.1;
            // stub: constant line search gradient
            out.gx = 0.1;
            Ok(())
        });

    // Poll at most five steps, printing each progress item.
    for x in ls.take(5) {
        dbg!(x);
    }

    Ok(())
}
479// test:1 ends here