use num_traits::Float;
use crate::convergence::{dot, norm, ConvergenceParams};
use crate::line_search::{backtracking_armijo, ArmijoParams};
use crate::objective::Objective;
use crate::result::{LbfgsDiagnostics, OptimResult, SolverDiagnostics, TerminationReason};
/// Configuration for the L-BFGS solver.
#[derive(Debug, Clone)]
pub struct LbfgsConfig<F> {
    /// Number of (s, y) curvature pairs retained for the two-loop recursion.
    /// A value of 0 is rejected by `lbfgs` as a configuration error.
    pub memory: usize,
    /// Stopping criteria: gradient-norm, step-size, and function-change
    /// tolerances plus the iteration budget.
    pub convergence: ConvergenceParams<F>,
    /// Parameters for the backtracking Armijo line search.
    pub line_search: ArmijoParams<F>,
}
impl Default for LbfgsConfig<f64> {
fn default() -> Self {
LbfgsConfig {
memory: 10,
convergence: ConvergenceParams::default(),
line_search: ArmijoParams::default(),
}
}
}
impl Default for LbfgsConfig<f32> {
fn default() -> Self {
LbfgsConfig {
memory: 10,
convergence: ConvergenceParams::default(),
line_search: ArmijoParams::default(),
}
}
}
/// Minimizes `obj` with the L-BFGS quasi-Newton method, starting from `x0`.
///
/// Each iteration builds a search direction from at most `config.memory`
/// stored curvature pairs via the two-loop recursion, then chooses a step
/// length by backtracking Armijo line search.
///
/// Termination reasons produced here:
/// - `NumericalError` — invalid config (`memory == 0` or `max_iter == 0`,
///   reported with NaN value/gradient), or a non-finite value/gradient.
/// - `GradientNorm` / `StepSize` / `FunctionChange` — the corresponding
///   tolerance in `config.convergence` was satisfied.
/// - `LineSearchFailed` — the Armijo search found no acceptable step.
/// - `MaxIterations` — the iteration budget was exhausted.
pub fn lbfgs<F: Float, O: Objective<F>>(
    obj: &mut O,
    x0: &[F],
    config: &LbfgsConfig<F>,
) -> OptimResult<F> {
    let n = x0.len();
    let mut diag = LbfgsDiagnostics::default();
    // Degenerate configuration: nothing was evaluated, so the value and
    // gradient are reported as NaN.
    if config.memory == 0 || config.convergence.max_iter == 0 {
        return OptimResult {
            x: x0.to_vec(),
            value: F::nan(),
            gradient: vec![F::nan(); n],
            gradient_norm: F::nan(),
            iterations: 0,
            func_evals: 0,
            termination: TerminationReason::NumericalError,
            diagnostics: SolverDiagnostics::Lbfgs(diag),
        };
    }
    let mut x = x0.to_vec();
    let (mut f_val, mut grad) = obj.eval_grad(&x);
    let mut func_evals = 1usize;
    let mut grad_norm = norm(&grad);
    // Bail out if the starting point already yields non-finite values.
    if !grad_norm.is_finite() || !f_val.is_finite() {
        return OptimResult {
            x,
            value: f_val,
            gradient: grad,
            gradient_norm: grad_norm,
            iterations: 0,
            func_evals,
            termination: TerminationReason::NumericalError,
            diagnostics: SolverDiagnostics::Lbfgs(diag),
        };
    }
    // Already converged at x0: report success without iterating.
    if grad_norm < config.convergence.grad_tol {
        return OptimResult {
            x,
            value: f_val,
            gradient: grad,
            gradient_norm: grad_norm,
            iterations: 0,
            func_evals,
            termination: TerminationReason::GradientNorm,
            diagnostics: SolverDiagnostics::Lbfgs(diag),
        };
    }
    // Histories are kept oldest-first: s_k = x_{k+1} - x_k,
    // y_k = g_{k+1} - g_k, rho_k = 1 / (s_k . y_k).
    let m = config.memory;
    let mut s_hist: Vec<Vec<F>> = Vec::with_capacity(m);
    let mut y_hist: Vec<Vec<F>> = Vec::with_capacity(m);
    let mut rho_hist: Vec<F> = Vec::with_capacity(m);
    for iter in 0..config.convergence.max_iter {
        // Search direction d ≈ -H_k * grad via the two-loop recursion.
        let (d, gamma_clamped) = two_loop_recursion(&grad, &s_hist, &y_hist, &rho_hist);
        if gamma_clamped {
            diag.gamma_clamp_hits += 1;
        }
        let ls = match backtracking_armijo(obj, &x, &d, f_val, &grad, &config.line_search) {
            Some(ls) => ls,
            None => {
                // No acceptable step length; return the best iterate so far.
                return OptimResult {
                    x,
                    value: f_val,
                    gradient: grad,
                    gradient_norm: grad_norm,
                    iterations: iter,
                    func_evals,
                    termination: TerminationReason::LineSearchFailed,
                    diagnostics: SolverDiagnostics::Lbfgs(diag),
                };
            }
        };
        func_evals += ls.evals;
        // The first evaluation is the full trial step; the rest are backtracks.
        diag.line_search_backtracks += ls.evals.saturating_sub(1);
        // Take the step and form the new curvature pair in one pass.
        let mut s = vec![F::zero(); n];
        let mut y = vec![F::zero(); n];
        for i in 0..n {
            s[i] = ls.alpha * d[i];
            y[i] = ls.gradient[i] - grad[i];
            x[i] = x[i] + s[i];
        }
        let f_prev = f_val;
        f_val = ls.value;
        grad = ls.gradient;
        grad_norm = norm(&grad);
        // Accept the pair only when the curvature condition
        // s.y > eps * |s| * |y| holds (cs_scale = sqrt(ss * yy) = |s| * |y|);
        // otherwise the quasi-Newton update would lose positive definiteness.
        let sy = dot(&s, &y);
        let ss = dot(&s, &s);
        let yy = dot(&y, &y);
        let cs_scale = (ss * yy).sqrt();
        if sy > F::epsilon() * cs_scale {
            if s_hist.len() == m {
                // Memory full: evict the oldest pair.
                s_hist.remove(0);
                y_hist.remove(0);
                rho_hist.remove(0);
                diag.pairs_evicted_by_memory += 1;
            }
            rho_hist.push(F::one() / sy);
            s_hist.push(s);
            y_hist.push(y);
            diag.pairs_accepted += 1;
        } else {
            diag.pairs_curvature_rejected += 1;
        }
        // Post-step checks, in priority order: numerical failure first,
        // then gradient norm, step size, and relative function change.
        if !grad_norm.is_finite() || !f_val.is_finite() {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::NumericalError,
                diagnostics: SolverDiagnostics::Lbfgs(diag),
            };
        }
        if grad_norm < config.convergence.grad_tol {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::GradientNorm,
                diagnostics: SolverDiagnostics::Lbfgs(diag),
            };
        }
        let step_norm = norm_step(ls.alpha, &d);
        if step_norm < config.convergence.step_tol {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::StepSize,
                diagnostics: SolverDiagnostics::Lbfgs(diag),
            };
        }
        // Relative function-change test; disabled when func_tol <= 0.
        if config.convergence.func_tol > F::zero()
            && (f_prev - f_val).abs() < config.convergence.func_tol * (F::one() + f_val.abs())
        {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::FunctionChange,
                diagnostics: SolverDiagnostics::Lbfgs(diag),
            };
        }
    }
    // Iteration budget exhausted without meeting any tolerance.
    OptimResult {
        x,
        value: f_val,
        gradient: grad,
        gradient_norm: grad_norm,
        iterations: config.convergence.max_iter,
        func_evals,
        termination: TerminationReason::MaxIterations,
        diagnostics: SolverDiagnostics::Lbfgs(diag),
    }
}
/// Computes the L-BFGS search direction -H·grad using the standard two-loop
/// recursion over the stored (s, y) history, with the initial Hessian
/// approximation scaled by gamma = s·y / y·y from the most recent pair.
///
/// Returns the direction together with a flag reporting whether gamma fell
/// outside [1e-3, 1e3] and had to be clamped (or replaced by 1 when
/// non-finite).
fn two_loop_recursion<F: Float>(
    grad: &[F],
    s_hist: &[Vec<F>],
    y_hist: &[Vec<F>],
    rho_hist: &[F],
) -> (Vec<F>, bool) {
    let pairs = s_hist.len();
    let mut clamped = false;
    let mut q: Vec<F> = grad.to_vec();
    let mut alphas = vec![F::zero(); pairs];
    // First loop, newest to oldest: q <- q - alpha_i * y_i.
    for i in (0..pairs).rev() {
        let a = rho_hist[i] * dot(&s_hist[i], &q);
        alphas[i] = a;
        for (qj, yj) in q.iter_mut().zip(y_hist[i].iter()) {
            *qj = *qj - a * *yj;
        }
    }
    // Scale by the (clamped) Barzilai-Borwein factor from the newest pair.
    let mut r = q;
    if pairs > 0 {
        let sy = dot(&s_hist[pairs - 1], &y_hist[pairs - 1]);
        let yy = dot(&y_hist[pairs - 1], &y_hist[pairs - 1]);
        if yy > F::epsilon() {
            let raw = sy / yy;
            let lo = F::from(1e-3).unwrap();
            let hi = F::from(1e3).unwrap();
            let gamma = if raw.is_finite() {
                if raw < lo || raw > hi {
                    clamped = true;
                }
                raw.max(lo).min(hi)
            } else {
                clamped = true;
                F::one()
            };
            for v in r.iter_mut() {
                *v = *v * gamma;
            }
        }
    }
    // Second loop, oldest to newest: r <- r + (alpha_i - beta_i) * s_i.
    for i in 0..pairs {
        let beta = rho_hist[i] * dot(&y_hist[i], &r);
        for (rj, sj) in r.iter_mut().zip(s_hist[i].iter()) {
            *rj = *rj + (alphas[i] - beta) * *sj;
        }
    }
    // Negate to turn the quasi-Newton product into a descent direction.
    for v in r.iter_mut() {
        *v = F::zero() - *v;
    }
    (r, clamped)
}
/// Euclidean norm of the step `alpha * d`, accumulated term by term without
/// materializing the scaled vector.
fn norm_step<F: Float>(alpha: F, d: &[F]) -> F {
    d.iter()
        .fold(F::zero(), |acc, &di| {
            let step = alpha * di;
            acc + step * step
        })
        .sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Classic 2-D Rosenbrock function; unique minimum at (1, 1) with f = 0.
    struct Rosenbrock;

    impl Objective<f64> for Rosenbrock {
        fn dim(&self) -> usize {
            2
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            let lin = 1.0 - x[0];
            let quad = x[1] - x[0] * x[0];
            let value = lin * lin + 100.0 * quad * quad;
            let gradient = vec![-2.0 * lin - 400.0 * x[0] * quad, 200.0 * quad];
            (value, gradient)
        }
    }

    #[test]
    fn lbfgs_rosenbrock() {
        // From the origin the solver should reach the minimizer (1, 1).
        let mut obj = Rosenbrock;
        let result = lbfgs(&mut obj, &[0.0, 0.0], &LbfgsConfig::default());
        assert_eq!(result.termination, TerminationReason::GradientNorm);
        assert!(
            (result.x[0] - 1.0).abs() < 1e-6,
            "x[0] = {}, expected 1.0",
            result.x[0]
        );
        assert!(
            (result.x[1] - 1.0).abs() < 1e-6,
            "x[1] = {}, expected 1.0",
            result.x[1]
        );
        assert!(result.gradient_norm < 1e-8);
    }

    #[test]
    fn lbfgs_already_converged() {
        // Starting exactly at the minimizer must terminate before iterating.
        let mut obj = Rosenbrock;
        let result = lbfgs(&mut obj, &[1.0, 1.0], &LbfgsConfig::default());
        assert_eq!(result.termination, TerminationReason::GradientNorm);
        assert_eq!(result.iterations, 0);
    }
}