use crate::error::OptimizeError;
use crate::unconstrained::utils::clip_step;
use crate::unconstrained::Bounds;
use scirs2_core::ndarray::{Array1, ArrayView1};
/// Return type of `zoom_search`: `((alpha, objective at alpha, gradient at alpha),
/// objective evaluations, gradient evaluations)`.
type ZoomSearchResult = ((f64, f64, Array1<f64>), usize, usize);
/// Tuning parameters for [`strong_wolfe_line_search`].
#[derive(Debug, Clone)]
pub struct StrongWolfeOptions {
/// Sufficient-decrease (Armijo) constant: a step is acceptable only if
/// `f(x + alpha*d) <= f0 + c1 * alpha * g0·d`. Must satisfy `0 < c1 < c2 < 1`.
pub c1: f64,
/// Curvature constant: a step is acceptable only if `|g(x + alpha*d)·d| <= c2 * |g0·d|`.
pub c2: f64,
/// First trial step length (before clipping to bounds and clamping to
/// `[min_step, max_step]`).
pub initial_step: f64,
/// Upper limit for trial steps in the bracketing phase.
pub max_step: f64,
/// Lower limit for trial steps in the bracketing phase; also the step reported when no
/// acceptable interval is found.
pub min_step: f64,
/// Iteration cap applied separately to the bracketing phase and to the zoom phase. Each
/// iteration evaluates the objective once, and the zoom phase adds one more evaluation when
/// it stops without an accepted step, so a search can use more than `max_fev` objective
/// evaluations in total.
pub max_fev: usize,
/// The bracketing phase gives up when the next step differs from the previous one by less
/// than this; the zoom phase stops when its bracket is narrower than this.
pub tolerance: f64,
/// Choose zoom trial steps with `safeguarded_interpolation` instead of plain bisection.
pub use_safeguarded_interpolation: bool,
/// After the first bracketing iteration, grow the step by `min(1 + 2|g·d| / |g0·d|, 3)`
/// instead of simply doubling it.
pub use_extrapolation: bool,
}
impl Default for StrongWolfeOptions {
fn default() -> Self {
Self {
c1: 1e-4,
c2: 0.9,
initial_step: 1.0,
max_step: 1e10,
min_step: 1e-12,
max_fev: 100,
tolerance: 1e-10,
use_safeguarded_interpolation: true,
use_extrapolation: true,
}
}
}
/// Outcome of [`strong_wolfe_line_search`].
#[derive(Debug, Clone)]
pub struct StrongWolfeResult {
/// Step length along the search direction.
pub alpha: f64,
/// Objective value at `x + alpha * direction` (the starting value `f0` when no acceptable
/// interval was found).
pub f_new: f64,
/// Gradient at `x + alpha * direction` (the starting gradient `g0` when no acceptable
/// interval was found).
pub g_new: Array1<f64>,
/// Number of objective evaluations performed by the search.
pub nfev: usize,
/// Number of gradient evaluations performed by the search.
pub ngev: usize,
/// Whether the search reports the strong Wolfe conditions as satisfied at `alpha`.
pub success: bool,
/// Human-readable description of how the search ended.
pub message: String,
}
/// Searches along `direction` from `x` for a step length `alpha` that satisfies the strong
/// Wolfe conditions:
///
/// - sufficient decrease: `f(x + alpha*d) <= f0 + c1 * alpha * (g0·d)`
/// - curvature: `|g(x + alpha*d)·d| <= c2 * |g0·d|`
///
/// The search first brackets an acceptable step (`find_interval`) and, when a bracket is
/// found, refines it (`zoom_search`). When `bounds` is given, the initial trial step is
/// limited with `clip_step`; every step is clamped to `[options.min_step, options.max_step]`.
///
/// # Arguments
/// * `fun` / `grad_fun` - objective and its gradient.
/// * `x` - starting point; `f0` and `g0` are the objective value and gradient at `x`.
/// * `direction` - search direction; must be a descent direction (`g0·d < 0`).
/// * `options` - tuning parameters; `bounds` - optional box constraints.
///
/// # Returns
/// A [`StrongWolfeResult`]. `success` is `true` only when the returned step satisfies both
/// conditions. When no acceptable interval is found, `alpha` is `options.min_step` while
/// `f_new`/`g_new` are the starting values `f0`/`g0` (they are not re-evaluated at `alpha`).
///
/// # Errors
/// `OptimizeError::ValueError` when `g0·d` is not negative (including NaN), or when the
/// Wolfe parameters violate `0 < c1 < c2 < 1` (including NaN parameters).
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn strong_wolfe_line_search<F, G, S>(
    fun: &mut F,
    grad_fun: &mut G,
    x: &ArrayView1<f64>,
    f0: f64,
    g0: &ArrayView1<f64>,
    direction: &ArrayView1<f64>,
    options: &StrongWolfeOptions,
    bounds: Option<&Bounds>,
) -> Result<StrongWolfeResult, OptimizeError>
where
    F: FnMut(&ArrayView1<f64>) -> S,
    G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
    S: Into<f64>,
{
    let derphi0 = g0.dot(direction);
    // A NaN slope (e.g. from a NaN gradient) would slip past a plain `>= 0.0` test and
    // then poison every comparison made during the search.
    if derphi0.is_nan() || derphi0 >= 0.0 {
        return Err(OptimizeError::ValueError(
            "Search direction must be a descent direction".to_string(),
        ));
    }
    // Stated positively so that NaN parameters are rejected too.
    let wolfe_params_valid = options.c1 > 0.0 && options.c1 < options.c2 && options.c2 < 1.0;
    if !wolfe_params_valid {
        return Err(OptimizeError::ValueError(
            "Invalid Wolfe parameters: must have 0 < c1 < c2 < 1".to_string(),
        ));
    }

    let mut alpha = options.initial_step;
    if let Some(bounds) = bounds {
        alpha = alpha.min(clip_step(x, direction, alpha, &bounds.lower, &bounds.upper));
    }
    alpha = alpha.min(options.max_step).max(options.min_step);

    let (interval_result, mut nfev, mut ngev) = find_interval(
        fun, grad_fun, x, f0, derphi0, direction, alpha, options, bounds,
    )?;

    match interval_result {
        IntervalResult::Found(alpha, f_alpha, g_alpha) => Ok(StrongWolfeResult {
            alpha,
            f_new: f_alpha,
            g_new: g_alpha,
            nfev,
            ngev,
            success: true,
            message: "Strong Wolfe conditions satisfied in interval search".to_string(),
        }),
        IntervalResult::Bracket(alpha_lo, alpha_hi, f_lo, f_hi, derphi_lo) => {
            let ((alpha, f_new, g_new), zoom_fev, zoom_gev) = zoom_search(
                fun, grad_fun, x, f0, derphi0, direction, alpha_lo, alpha_hi, f_lo, f_hi,
                derphi_lo, options, bounds,
            )?;
            nfev += zoom_fev;
            ngev += zoom_gev;
            // zoom_search falls back to a bracket endpoint (possibly alpha = 0) when it runs
            // out of iterations or its bracket collapses; that point was never checked, so
            // verify both conditions here instead of reporting success unconditionally.
            // The expressions match the acceptance test used inside zoom_search.
            let sufficient_decrease = f_new <= f0 + options.c1 * alpha * derphi0;
            let curvature = g_new.dot(direction).abs() <= -options.c2 * derphi0;
            let success = sufficient_decrease && curvature;
            let message = if success {
                "Strong Wolfe conditions satisfied in zoom phase"
            } else {
                "Zoom phase ended without satisfying the strong Wolfe conditions"
            };
            Ok(StrongWolfeResult {
                alpha,
                f_new,
                g_new,
                nfev,
                ngev,
                success,
                message: message.to_string(),
            })
        }
        // f_new/g_new are the values at the starting point; they were not re-evaluated at
        // `options.min_step`.
        IntervalResult::Failed => Ok(StrongWolfeResult {
            alpha: options.min_step,
            f_new: f0,
            g_new: g0.to_owned(),
            nfev,
            ngev,
            success: false,
            message: "Failed to find acceptable interval".to_string(),
        }),
    }
}
/// Outcome of the bracketing phase (`find_interval`).
#[derive(Debug)]
enum IntervalResult {
    /// A step satisfying both strong Wolfe conditions:
    /// `(alpha, objective at alpha, gradient at alpha)`.
    Found(f64, f64, Array1<f64>),
    /// A bracket for the zoom phase:
    /// `(alpha_lo, alpha_hi, objective at alpha_lo, objective at alpha_hi,
    /// directional derivative at alpha_lo)`. `alpha_lo` may be larger than `alpha_hi`.
    Bracket(f64, f64, f64, f64, f64),
    /// The step stopped growing or the iteration cap was reached.
    Failed,
}
/// Bracketing phase of the strong Wolfe line search.
///
/// Starting from `alpha`, evaluates trial steps of increasing length along `direction`
/// and stops when one of the following happens:
/// - the sufficient-decrease condition fails, or (after the first trial) the objective
///   no longer decreases: returns `Bracket(previous, current, f_prev, f_cur, slope_prev)`;
/// - the trial step satisfies both strong Wolfe conditions: returns `Found`;
/// - the directional derivative at the trial step is non-negative: returns
///   `Bracket(current, previous, f_cur, f_prev, slope_cur)`;
/// - the step changes by less than `options.tolerance`, or `options.max_fev` iterations
///   have been used: returns `Failed`.
///
/// Each trial step is clipped with `clip_step` when `bounds` is given and clamped to
/// `[options.min_step, options.max_step]` before it is evaluated.
///
/// Returns the outcome together with the numbers of objective and gradient evaluations.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
fn find_interval<F, G, S>(
    fun: &mut F,
    grad_fun: &mut G,
    x: &ArrayView1<f64>,
    f0: f64,
    derphi0: f64,
    direction: &ArrayView1<f64>,
    mut alpha: f64,
    options: &StrongWolfeOptions,
    bounds: Option<&Bounds>,
) -> Result<(IntervalResult, usize, usize), OptimizeError>
where
    F: FnMut(&ArrayView1<f64>) -> S,
    G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
    S: Into<f64>,
{
    // Right-hand side of the curvature test; identical on every iteration.
    let curvature_tol = -options.c2 * derphi0;
    let (mut evals_f, mut evals_g) = (0, 0);
    let (mut prev_step, mut prev_value, mut prev_slope) = (0.0, f0, derphi0);

    for iteration in 0..options.max_fev {
        if let Some(b) = bounds {
            alpha = alpha.min(clip_step(x, direction, alpha, &b.lower, &b.upper));
        }
        alpha = alpha.min(options.max_step).max(options.min_step);

        let trial_point = x + alpha * direction;
        let value: f64 = fun(&trial_point.view()).into();
        evals_f += 1;

        let armijo_violated = value > f0 + options.c1 * alpha * derphi0;
        let stopped_decreasing = iteration > 0 && value >= prev_value;
        if armijo_violated || stopped_decreasing {
            let outcome = IntervalResult::Bracket(prev_step, alpha, prev_value, value, prev_slope);
            return Ok((outcome, evals_f, evals_g));
        }

        let gradient = grad_fun(&trial_point.view());
        let slope = gradient.dot(direction);
        evals_g += 1;

        if slope.abs() <= curvature_tol {
            return Ok((IntervalResult::Found(alpha, value, gradient), evals_f, evals_g));
        }
        if slope >= 0.0 {
            // Overshot the minimizer: the current step becomes the `lo` end of the bracket.
            let outcome = IntervalResult::Bracket(alpha, prev_step, value, prev_value, slope);
            return Ok((outcome, evals_f, evals_g));
        }

        prev_step = alpha;
        prev_value = value;
        prev_slope = slope;

        // Doubling on the first trial (or without extrapolation); afterwards a growth
        // factor driven by the remaining slope, capped at 3.
        let growth = if options.use_extrapolation && iteration > 0 {
            (1.0 + 2.0 * slope.abs() / derphi0.abs()).min(3.0)
        } else {
            2.0
        };
        alpha *= growth;
        if alpha > options.max_step {
            alpha = options.max_step;
        }

        if (alpha - prev_step).abs() < options.tolerance {
            break;
        }
    }

    Ok((IntervalResult::Failed, evals_f, evals_g))
}
/// Zoom phase of the strong Wolfe line search.
///
/// Repeatedly picks a trial step between `alpha_lo` and `alpha_hi` (given in either order)
/// and shrinks the bracket until a trial step satisfies both strong Wolfe conditions.
/// `f_lo`/`derphi_lo` are the objective value and directional derivative at `alpha_lo`;
/// `f_hi` is the objective value at `alpha_hi`.
///
/// Trial steps come from `safeguarded_interpolation` when
/// `options.use_safeguarded_interpolation` is set and from bisection otherwise. The loop
/// runs at most `options.max_fev` times and also stops once the bracket is narrower than
/// `options.tolerance`. When it stops without an accepted step, the endpoint with the smaller
/// recorded objective value (`alpha_hi` on ties) is re-evaluated and returned without
/// checking the Wolfe conditions.
///
/// `bounds` is currently not used here.
///
/// Returns `((alpha, objective at alpha, gradient at alpha), objective evaluations,
/// gradient evaluations)`.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
fn zoom_search<F, G, S>(
    fun: &mut F,
    grad_fun: &mut G,
    x: &ArrayView1<f64>,
    f0: f64,
    derphi0: f64,
    direction: &ArrayView1<f64>,
    mut alpha_lo: f64,
    mut alpha_hi: f64,
    mut f_lo: f64,
    mut f_hi: f64,
    mut derphi_lo: f64,
    options: &StrongWolfeOptions,
    bounds: Option<&Bounds>,
) -> Result<ZoomSearchResult, OptimizeError>
where
    F: FnMut(&ArrayView1<f64>) -> S,
    G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
    S: Into<f64>,
{
    // Right-hand side of the curvature test; identical on every iteration.
    let curvature_tol = -options.c2 * derphi0;
    let mut fevals = 0;
    let mut gevals = 0;

    for _ in 0..options.max_fev {
        let trial = if options.use_safeguarded_interpolation {
            safeguarded_interpolation(alpha_lo, alpha_hi, f_lo, f_hi, derphi_lo, derphi0)
        } else {
            0.5 * (alpha_lo + alpha_hi)
        };

        let point = x + trial * direction;
        let value: f64 = fun(&point.view()).into();
        fevals += 1;

        let becomes_upper_end = value > f0 + options.c1 * trial * derphi0 || value >= f_lo;
        if becomes_upper_end {
            alpha_hi = trial;
            f_hi = value;
        } else {
            let gradient = grad_fun(&point.view());
            let slope = gradient.dot(direction);
            gevals += 1;

            if slope.abs() <= curvature_tol {
                return Ok(((trial, value, gradient), fevals, gevals));
            }
            // The objective rises from the trial step towards alpha_hi, so the old
            // alpha_lo becomes the far end of the new bracket (Nocedal & Wright, Alg. 3.6).
            if slope * (alpha_hi - alpha_lo) >= 0.0 {
                alpha_hi = alpha_lo;
                f_hi = f_lo;
            }
            alpha_lo = trial;
            f_lo = value;
            derphi_lo = slope;
        }

        if (alpha_hi - alpha_lo).abs() < options.tolerance {
            break;
        }
    }

    let best = if f_lo < f_hi { alpha_lo } else { alpha_hi };
    let best_point = x + best * direction;
    let best_value: f64 = fun(&best_point.view()).into();
    let best_gradient = grad_fun(&best_point.view());
    Ok(((best, best_value, best_gradient), fevals + 1, gevals + 1))
}
/// Chooses the next trial step inside the zoom bracket spanned by `alpha_lo` and `alpha_hi`.
///
/// The candidate is the minimizer of the quadratic that matches `f_lo` and the directional
/// derivative `derphi_lo` at `alpha_lo`, and `f_hi` at `alpha_hi`. It is accepted only when
/// that quadratic is strictly convex and its minimizer lies in the middle 80% of the bracket
/// (at least 10% of the bracket length away from either end). Otherwise the bracket midpoint
/// is returned.
///
/// `alpha_hi` may be smaller than `alpha_lo` (the zoom phase regularly produces such reversed
/// brackets); the safeguard works in normalized coordinates, so both orientations are
/// handled. A zero-width or non-finite bracket returns the midpoint.
///
/// `_derphi0` is currently unused; it is kept so existing call sites keep compiling.
#[allow(dead_code)]
fn safeguarded_interpolation(
    alpha_lo: f64,
    alpha_hi: f64,
    f_lo: f64,
    f_hi: f64,
    derphi_lo: f64,
    _derphi0: f64,
) -> f64 {
    // Fraction of the bracket at each end inside which interpolated steps are rejected.
    const SAFEGUARD: f64 = 0.1;

    let midpoint = 0.5 * (alpha_lo + alpha_hi);
    let delta = alpha_hi - alpha_lo;
    if delta == 0.0 || !delta.is_finite() {
        return midpoint;
    }

    // With s = (alpha - alpha_lo) / delta the interpolant is
    //   q(s) = f_lo + slope * s + curvature * s^2,   q(0) = f_lo, q'(0) = slope, q(1) = f_hi.
    // This is valid for either sign of delta.
    let slope = derphi_lo * delta;
    let curvature = f_hi - f_lo - slope;
    if curvature > 0.0 {
        // Minimizer of the convex quadratic; NaN/inf fail the range check below.
        let s = -0.5 * slope / curvature;
        if (SAFEGUARD..=1.0 - SAFEGUARD).contains(&s) {
            return alpha_lo + s * delta;
        }
    }
    midpoint
}
/// Returns line-search options tuned for the named optimization method (case-insensitive).
///
/// | method                          | c2  | max_step | min_step | max_fev | tolerance | extrapolation |
/// |---------------------------------|-----|----------|----------|---------|-----------|---------------|
/// | `bfgs`, `lbfgs`, `sr1`, `dfp`   | 0.9 | 1e4      | 1e-12    | 50      | 1e-10     | yes           |
/// | `cg`, `conjugate_gradient`      | 0.1 | 1e4      | 1e-12    | 50      | 1e-10     | yes           |
/// | `newton`                        | 0.5 | 1e6      | 1e-15    | 100     | 1e-12     | no            |
///
/// All of these use `c1 = 1e-4`, `initial_step = 1.0` and safeguarded interpolation. Any
/// other name yields `StrongWolfeOptions::default()`.
#[allow(dead_code)]
pub fn create_strong_wolfe_options_for_method(method: &str) -> StrongWolfeOptions {
    let normalized = method.to_lowercase();
    let (c2, max_step, min_step, max_fev, tolerance, use_extrapolation) =
        match normalized.as_str() {
            "bfgs" | "lbfgs" | "sr1" | "dfp" => (0.9, 1e4, 1e-12, 50, 1e-10, true),
            "cg" | "conjugate_gradient" => (0.1, 1e4, 1e-12, 50, 1e-10, true),
            "newton" => (0.5, 1e6, 1e-15, 100, 1e-12, false),
            _ => return StrongWolfeOptions::default(),
        };

    StrongWolfeOptions {
        c1: 1e-4,
        c2,
        initial_step: 1.0,
        max_step,
        min_step,
        max_fev,
        tolerance,
        use_safeguarded_interpolation: true,
        use_extrapolation,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Evaluates `f0`/`g0` at `start` and runs an unbounded strong Wolfe search along
    /// `direction`, panicking if the search returns an error.
    fn search<F, G>(
        fun: &mut F,
        grad: &mut G,
        start: &Array1<f64>,
        direction: &Array1<f64>,
        options: &StrongWolfeOptions,
    ) -> StrongWolfeResult
    where
        F: FnMut(&ArrayView1<f64>) -> f64,
        G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
    {
        let f0 = fun(&start.view());
        let g0 = grad(&start.view());
        strong_wolfe_line_search(
            fun,
            grad,
            &start.view(),
            f0,
            &g0.view(),
            &direction.view(),
            options,
            None,
        )
        .expect("Operation failed")
    }

    #[test]
    fn test_strong_wolfe_quadratic() {
        // f(p) = p0^2 + p1^2 from (1, 1) along (-1, -1): the exact minimizer is alpha = 1.
        let mut sphere = |p: &ArrayView1<f64>| -> f64 { p[0] * p[0] + p[1] * p[1] };
        let mut sphere_grad = |p: &ArrayView1<f64>| -> Array1<f64> { p.mapv(|v| 2.0 * v) };
        let start = Array1::from_vec(vec![1.0, 1.0]);
        let direction = Array1::from_vec(vec![-1.0, -1.0]);

        let outcome = search(
            &mut sphere,
            &mut sphere_grad,
            &start,
            &direction,
            &StrongWolfeOptions::default(),
        );

        assert!(outcome.success);
        assert!(outcome.alpha > 0.0);
        assert_abs_diff_eq!(outcome.alpha, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_strong_wolfe_rosenbrock() {
        const A: f64 = 1.0;
        const B: f64 = 100.0;
        let mut rosen = |p: &ArrayView1<f64>| -> f64 {
            (A - p[0]).powi(2) + B * (p[1] - p[0].powi(2)).powi(2)
        };
        let mut rosen_grad = |p: &ArrayView1<f64>| -> Array1<f64> {
            let inner = p[1] - p[0].powi(2);
            Array1::from_vec(vec![
                -2.0 * (A - p[0]) - 4.0 * B * p[0] * inner,
                2.0 * B * inner,
            ])
        };
        let start = Array1::from_vec(vec![0.0, 0.0]);
        let f_start = rosen(&start.view());
        // Steepest-descent direction.
        let g_start = rosen_grad(&start.view());
        let direction = -&g_start;

        let outcome = search(
            &mut rosen,
            &mut rosen_grad,
            &start,
            &direction,
            &create_strong_wolfe_options_for_method("bfgs"),
        );

        assert!(outcome.success);
        assert!(outcome.alpha > 0.0);
        assert!(outcome.f_new < f_start);
    }

    #[test]
    fn test_safeguarded_interpolation() {
        let (lo, hi) = (0.0, 1.0);
        let step = safeguarded_interpolation(lo, hi, 1.0, 0.5, -1.0, -1.0);
        let margin = 0.1 * (hi - lo);
        assert!(step >= lo + margin);
        assert!(step <= hi - margin);
    }

    #[test]
    fn test_method_specific_options() {
        for &(method, expected_c2) in [("bfgs", 0.9), ("cg", 0.1), ("newton", 0.5)].iter() {
            assert_eq!(create_strong_wolfe_options_for_method(method).c2, expected_c2);
        }
        assert!(!create_strong_wolfe_options_for_method("newton").use_extrapolation);
    }
}