use crate::StrError;
/// Default Armijo sufficient-decrease coefficient `c1` used by `LineSearcher::new`.
const DEFAULT_C1: f64 = 1.0e-4;

/// Default factor by which the trial step `alpha` is multiplied after a rejection.
const DEFAULT_RHO: f64 = 0.5;

/// Default lower bound on `alpha`; the search errors out once `alpha` drops below it.
const DEFAULT_MIN_ALPHA: f64 = 1.0e-20;

/// Default maximum number of trial steps (objective evaluations) per search.
const DEFAULT_MAX_ITERATIONS: usize = 20;
/// Parameters of a backtracking line search using the Armijo condition.
///
/// Admissible ranges (checked before every search):
/// `0 < c1 < 1`, `0 < rho < 1`, `0 < min_alpha < 1`, `max_iterations >= 1`.
#[derive(Debug, Clone, Copy)]
pub struct LineSearcher {
    /// Sufficient-decrease coefficient: a step `alpha` is accepted when
    /// `f(x + alpha * p) <= fx + c1 * alpha * slope`.
    pub c1: f64,

    /// Contraction factor applied to `alpha` after each rejected step.
    pub rho: f64,

    /// Smallest admissible step; falling below it aborts the search.
    pub min_alpha: f64,

    /// Maximum number of trial steps (one objective evaluation each).
    pub max_iterations: usize,
}
impl LineSearcher {
    /// Creates a line searcher populated with the module defaults
    /// (`DEFAULT_C1`, `DEFAULT_RHO`, `DEFAULT_MIN_ALPHA`, `DEFAULT_MAX_ITERATIONS`).
    pub fn new() -> Self {
        LineSearcher {
            c1: DEFAULT_C1,
            rho: DEFAULT_RHO,
            min_alpha: DEFAULT_MIN_ALPHA,
            max_iterations: DEFAULT_MAX_ITERATIONS,
        }
    }

    /// Checks that all parameters lie in their admissible ranges.
    ///
    /// Each floating-point check explicitly rejects NaN: every comparison
    /// involving NaN is `false`, so a guard like `x <= a || x >= b` on its own
    /// would let a NaN parameter pass and later disable the step-size guard.
    fn validate_params(&self) -> Result<(), StrError> {
        if self.c1.is_nan() || self.c1 <= 0.0 || self.c1 >= 1.0 {
            return Err("c1 must satisfy 0 < c1 < 1");
        }
        if self.rho.is_nan() || self.rho <= 0.0 || self.rho >= 1.0 {
            return Err("rho must satisfy 0 < rho < 1");
        }
        if self.min_alpha.is_nan() || self.min_alpha <= 0.0 {
            return Err("min_alpha must be > 0");
        }
        // the first trial step is alpha = 1, so min_alpha must stay below it
        if self.min_alpha >= 1.0 {
            return Err("min_alpha must be < 1");
        }
        if self.max_iterations == 0 {
            return Err("max_iterations must be ≥ 1");
        }
        Ok(())
    }

    /// Performs a backtracking line search with the Armijo (sufficient decrease) condition.
    ///
    /// Starting from `alpha = 1`, the trial point `x + alpha * p` is accepted as soon as
    /// `f(x + alpha * p) <= fx + c1 * alpha * slope`; otherwise `alpha` is multiplied by
    /// `rho` and the test is repeated.
    ///
    /// # Input
    ///
    /// * `x` -- current point
    /// * `p` -- search direction
    /// * `fx` -- objective value at the current point, `f(x)`
    /// * `slope` -- directional derivative of `f` at `x` along `p`; must be strictly negative
    /// * `args` -- extra data forwarded to every call of `f`
    /// * `f` -- objective function, called as `f(point, args)`
    ///
    /// # Output
    ///
    /// `(alpha, n_evals)`: the accepted step size and the number of calls made to `f`.
    ///
    /// # Errors
    ///
    /// * a parameter outside its admissible range (see the struct documentation)
    /// * `slope` not strictly negative, including a NaN slope
    /// * any error returned by `f`, propagated unchanged
    /// * `"step size too small"` once `alpha` drops below `min_alpha`
    /// * `"line search failed to converge"` after `max_iterations` rejected steps
    pub fn search<F, A>(
        &self,
        x: f64,
        p: f64,
        fx: f64,
        slope: f64,
        args: &mut A,
        mut f: F,
    ) -> Result<(f64, usize), StrError>
    where
        F: FnMut(f64, &mut A) -> Result<f64, StrError>,
    {
        self.validate_params()?;
        // `slope >= 0.0` alone is false for NaN, so test NaN explicitly
        if slope.is_nan() || slope >= 0.0 {
            return Err("direction must be a descent direction (slope < 0)");
        }
        let mut alpha = 1.0;
        for n_iter in 0..self.max_iterations {
            // Armijo threshold for the current trial step
            let target = fx + self.c1 * alpha * slope;
            let f_new = f(x + alpha * p, args)?;
            // a NaN objective value fails this comparison and is treated as a rejection
            if f_new <= target {
                return Ok((alpha, n_iter + 1));
            }
            alpha *= self.rho;
            if alpha < self.min_alpha {
                return Err("step size too small");
            }
        }
        Err("line search failed to converge")
    }
}

impl Default for LineSearcher {
    /// Equivalent to [`LineSearcher::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Runs a backtracking line search with the default [`LineSearcher`] settings and
/// returns only the accepted step size `alpha`.
///
/// See [`LineSearcher::search`] for the meaning of the arguments and the possible errors.
pub fn line_search<F, A>(
    x: f64,
    p: f64,
    fx: f64,
    slope: f64,
    args: &mut A,
    f: F,
) -> Result<f64, StrError>
where
    F: FnMut(f64, &mut A) -> Result<f64, StrError>,
{
    let (step, _n_evals) = LineSearcher::new().search(x, p, fx, slope, args, f)?;
    Ok(step)
}
/// Runs a backtracking line search with the default [`LineSearcher`] settings.
///
/// Returns `(alpha, n_evals)`: the accepted step size and how many times `f` was called.
/// See [`LineSearcher::search`] for the meaning of the arguments and the possible errors.
pub fn line_search_with_stats<F, A>(
    x: f64,
    p: f64,
    fx: f64,
    slope: f64,
    args: &mut A,
    f: F,
) -> Result<(f64, usize), StrError>
where
    F: FnMut(f64, &mut A) -> Result<f64, StrError>,
{
    LineSearcher::new().search(x, p, fx, slope, args, f)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Most tests use f(x) = (x - c)^4 + (x - c)^2, whose derivative is
    // f'(x) = 4 (x - c)^3 + 2 (x - c). The `fx` and `slope` inputs below are
    // f(x) and f'(x) * p evaluated by hand at the starting point.

    #[test]
    fn line_search_non_quadratic() {
        struct Args {}
        let args = &mut Args {};
        let f = |x: f64, _: &mut Args| {
            let d = x - 2.0;
            Ok(d.powi(4) + d.powi(2))
        };
        let x = 0.0;
        let fx = 20.0; // f(0) = 16 + 4
        let p = 1.0;
        let slope = -36.0; // f'(0) * p = (-32 - 4) * 1
        let alpha = line_search(x, p, fx, slope, args, f).unwrap();
        let x_new = x + alpha * p;
        assert_eq!(alpha, 1.0);
        assert!(x_new > 0.0 && x_new < 4.0);
    }

    #[test]
    fn line_search_with_stats_works() {
        struct Args {}
        let args = &mut Args {};
        let f = |x: f64, _: &mut Args| {
            let d = x - 3.0;
            Ok(d.powi(4) + d.powi(2))
        };
        let x = 0.0;
        let fx = 90.0; // f(0) = 81 + 9
        let p = 1.0;
        let slope = -114.0; // f'(0) * p = (-108 - 6) * 1
        let (alpha, n_evals) = line_search_with_stats(x, p, fx, slope, args, f).unwrap();
        assert_eq!(n_evals, 1);
        let x_new = x + alpha * p;
        assert!(x_new > 0.0 && x_new < 5.0);
    }

    #[test]
    fn line_search_backtracks_until_armijo_holds() {
        // f(x) = x^2 from x = 1 along p = -4 (overshoots the minimum at 0);
        // fx = f(1) = 1 and slope = f'(1) * p = 2 * (-4) = -8.
        // alpha = 1.0  -> x_new = -3, f = 9 (rejected)
        // alpha = 0.5  -> x_new = -1, f = 1 (rejected: needs f <= 1 - 4e-4)
        // alpha = 0.25 -> x_new =  0, f = 0 (accepted)
        struct Args {}
        let args = &mut Args {};
        let f = |x: f64, _: &mut Args| Ok(x * x);
        let (alpha, n_evals) = line_search_with_stats(1.0, -4.0, 1.0, -8.0, args, f).unwrap();
        assert_eq!(alpha, 0.25);
        assert_eq!(n_evals, 3);
    }

    #[test]
    fn line_search_captures_not_descent_direction() {
        struct Args {}
        let args = &mut Args {};
        let f = |x: f64, _: &mut Args| {
            let d = x - 2.0;
            Ok(d.powi(4) + d.powi(2))
        };
        // search() returns before ever calling `f` here, so call it once directly
        let _ = f(0.0, args);
        let x = 0.0;
        let fx = 20.0;
        let p = -1.0;
        let slope = 36.0; // f'(0) * p = -36 * -1 (ascent direction)
        let result = line_search(x, p, fx, slope, args, f);
        assert_eq!(result.err(), Some("direction must be a descent direction (slope < 0)"));
    }

    #[test]
    fn line_search_custom_parameters() {
        struct Args {}
        let args = &mut Args {};
        let f = |x: f64, _: &mut Args| {
            let d = x - 2.0;
            Ok(d.powi(4) + d.powi(2))
        };
        let x = 0.0;
        let fx = 20.0;
        let p = 1.0;
        let slope = -36.0;
        let mut searcher = LineSearcher::new();
        searcher.c1 = 1e-3;
        searcher.rho = 0.7;
        searcher.max_iterations = 30;
        let (alpha, _) = searcher.search(x, p, fx, slope, args, f).unwrap();
        let x_new = x + alpha * p;
        assert!(x_new > 0.0 && x_new < 4.0);
    }

    #[test]
    fn line_search_with_args() {
        struct Args {
            target: f64,
        }
        let args = &mut Args { target: 5.0 };
        let f = |x: f64, a: &mut Args| {
            let d = x - a.target;
            Ok(d.powi(4) + d.powi(2))
        };
        let x = 0.0;
        let fx = 625.0 + 25.0; // f(0) with c = 5
        let p = 1.0;
        let slope = -510.0; // f'(0) * p = (-500 - 10) * 1
        let alpha = line_search(x, p, fx, slope, args, f).unwrap();
        let x_new = x + alpha * p;
        assert!(x_new > 0.0 && x_new < 10.0);
    }

    #[test]
    fn line_search_stops_too_small_alpha() {
        struct Args {}
        let args = &mut Args {};
        let f = |_: f64, _: &mut Args| Ok(f64::MAX);
        let x = 0.0;
        let fx = 1.0;
        let p = 1.0;
        let slope = -1.0;
        let mut searcher = LineSearcher::new();
        searcher.min_alpha = 0.1;
        searcher.rho = 0.5;
        searcher.max_iterations = 10;
        // alpha: 1 -> 0.5 -> 0.25 -> 0.125 -> 0.0625 < 0.1 within 10 iterations
        let result = searcher.search(x, p, fx, slope, args, f);
        assert_eq!(result.err(), Some("step size too small"));
    }

    #[test]
    fn line_search_convergence_limits() {
        struct Args {}
        let args = &mut Args {};
        let f = |_: f64, _: &mut Args| Ok(f64::MAX);
        let x = 0.0;
        let fx = 1.0;
        let p = 1.0;
        let slope = -1.0;
        let mut searcher = LineSearcher::new();
        searcher.max_iterations = 1;
        let result = searcher.search(x, p, fx, slope, args, f);
        assert_eq!(result.err(), Some("line search failed to converge"));
    }

    #[test]
    fn line_search_treats_nan_as_rejection() {
        struct Args {}
        let args = &mut Args {};
        // every trial value is NaN, so `f_new <= target` never holds; with the
        // defaults alpha only shrinks to 0.5^20 (about 9.5e-7, far above min_alpha)
        // before the 20 iterations are exhausted
        let f = |_: f64, _: &mut Args| Ok(f64::NAN);
        let result = line_search(0.0, 1.0, 1.0, -1.0, args, f);
        assert_eq!(result.err(), Some("line search failed to converge"));
    }

    #[test]
    fn line_search_multiple_calls() {
        struct Args {
            count: usize,
        }
        fn f(x: f64, a: &mut Args) -> Result<f64, StrError> {
            a.count += 1;
            let d = x - 2.0;
            Ok(d.powi(4) + d.powi(2))
        }
        let searcher = LineSearcher::new();
        let args = &mut Args { count: 0 };
        let (alpha1, n_evals1) = searcher.search(0.0, 1.0, 20.0, -36.0, args, f).unwrap();
        assert_eq!(n_evals1, 1);
        assert_eq!(args.count, 1);
        assert!(alpha1 > 0.0 && alpha1 <= 1.0);
        // mirror image: from x = 4 along p = -1, f(4) = 20 and f'(4) * p = 36 * -1
        let (alpha2, n_evals2) = searcher.search(4.0, -1.0, 20.0, -36.0, args, f).unwrap();
        assert_eq!(n_evals2, 1);
        assert_eq!(args.count, 2);
        assert!(alpha2 > 0.0 && alpha2 <= 1.0);
    }

    #[test]
    fn line_search_captures_zero_slope() {
        struct Args {}
        let args = &mut Args {};
        let f = |x: f64, _: &mut Args| Ok(x.powi(4));
        // search() returns before ever calling `f` here, so call it once directly
        let _ = f(0.0, args);
        let x = 0.0;
        let fx = 0.0;
        let p = 1.0;
        let slope = 0.0;
        let result = line_search(x, p, fx, slope, args, f);
        assert_eq!(result.err(), Some("direction must be a descent direction (slope < 0)"));
    }

    #[test]
    fn line_search_validate_params_works() {
        struct Args {}
        let args = &mut Args {};
        let f = |x: f64, _: &mut Args| Ok(x);
        // every search below fails validation before calling `f`, so call it once directly
        let _ = f(0.0, args);
        let mut searcher = LineSearcher::new();
        searcher.c1 = 0.0;
        assert_eq!(
            searcher.search(0.0, 1.0, 1.0, -1.0, args, f).err(),
            Some("c1 must satisfy 0 < c1 < 1")
        );
        searcher.c1 = 1.0;
        assert_eq!(
            searcher.search(0.0, 1.0, 1.0, -1.0, args, f).err(),
            Some("c1 must satisfy 0 < c1 < 1")
        );
        searcher.c1 = DEFAULT_C1;
        searcher.rho = 0.0;
        assert_eq!(
            searcher.search(0.0, 1.0, 1.0, -1.0, args, f).err(),
            Some("rho must satisfy 0 < rho < 1")
        );
        searcher.rho = 1.0;
        assert_eq!(
            searcher.search(0.0, 1.0, 1.0, -1.0, args, f).err(),
            Some("rho must satisfy 0 < rho < 1")
        );
        searcher.rho = DEFAULT_RHO;
        searcher.min_alpha = 0.0;
        assert_eq!(
            searcher.search(0.0, 1.0, 1.0, -1.0, args, f).err(),
            Some("min_alpha must be > 0")
        );
        searcher.min_alpha = 1.0;
        assert_eq!(
            searcher.search(0.0, 1.0, 1.0, -1.0, args, f).err(),
            Some("min_alpha must be < 1")
        );
        searcher.min_alpha = DEFAULT_MIN_ALPHA;
        searcher.max_iterations = 0;
        assert_eq!(
            searcher.search(0.0, 1.0, 1.0, -1.0, args, f).err(),
            Some("max_iterations must be ≥ 1")
        );
    }

    #[test]
    fn line_search_exponential() {
        struct Args {}
        let args = &mut Args {};
        let f = |x: f64, _: &mut Args| Ok(f64::exp(-x));
        let x = 0.0;
        let fx = 1.0; // exp(0)
        let p = 1.0;
        let slope = -1.0; // -exp(0) * p
        let alpha = line_search(x, p, fx, slope, args, f).unwrap();
        let x_new = x + alpha * p;
        assert!(x_new > 0.0);
        assert!(x_new < 5.0);
    }

    #[test]
    fn line_search_captures_f_error() {
        struct Args {}
        let args = &mut Args {};
        let f = |_: f64, _: &mut Args| Err("f evaluation failed");
        let result = line_search(0.0, 1.0, 1.0, -1.0, args, f);
        assert_eq!(result.err(), Some("f evaluation failed"));
    }
}