use crate::primitives::Vector;
use super::line_search::{LineSearch, WolfeLineSearch};
use super::{ConvergenceStatus, OptimizationResult, Optimizer};
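
/// Limited-memory BFGS (L-BFGS) optimizer.
///
/// Instead of storing a dense n-by-n inverse Hessian like full BFGS, L-BFGS
/// keeps only the last `m` curvature pairs (s_k, y_k) and reconstructs the
/// search direction with the two-loop recursion, so memory and per-iteration
/// cost are O(m * n).
///
/// A minimal usage sketch on a separable quadratic (marked `ignore` because
/// the surrounding crate paths are not shown here; it relies only on the
/// `Vector` operations already used in this file):
///
/// ```ignore
/// // Minimize f(x) = (x0 - 1)^2 + (x1 + 2)^2 starting at the origin.
/// let mut opt = LBFGS::new(100, 1e-6, 10);
/// let result = opt.minimize(
///     |x| (x[0] - 1.0_f32).powi(2) + (x[1] + 2.0).powi(2),
///     |x| {
///         let mut g = Vector::zeros(2);
///         g[0] = 2.0 * (x[0] - 1.0);
///         g[1] = 2.0 * (x[1] + 2.0);
///         g
///     },
///     Vector::zeros(2),
/// );
/// assert!(result.gradient_norm < 1e-6);
/// ```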
#[derive(Debug, Clone)]
pub struct LBFGS {
    /// Maximum number of outer iterations for `minimize`.
    pub(crate) max_iter: usize,
    /// Convergence tolerance on the gradient norm.
    pub(crate) tol: f32,
    /// Number of curvature pairs retained (the "memory" of L-BFGS).
    pub(crate) m: usize,
    line_search: WolfeLineSearch,
    /// Recent position differences s_k = x_{k+1} - x_k.
    pub(crate) s_history: Vec<Vector<f32>>,
    /// Recent gradient differences y_k = grad_{k+1} - grad_k.
    pub(crate) y_history: Vec<Vector<f32>>,
}

impl LBFGS {
    /// Creates an L-BFGS optimizer that stores up to `m` curvature pairs.
    /// The line-search constants 1e-4 and 0.9 are the textbook
    /// sufficient-decrease and curvature parameters for quasi-Newton
    /// methods, with the search capped at 50 trial steps.
    #[must_use]
    pub fn new(max_iter: usize, tol: f32, m: usize) -> Self {
        Self {
            max_iter,
            tol,
            m,
            line_search: WolfeLineSearch::new(1e-4, 0.9, 50),
            s_history: Vec::with_capacity(m),
            y_history: Vec::with_capacity(m),
        }
    }

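    /// Computes the search direction d = -H_k * grad via the standard
    /// L-BFGS two-loop recursion (Nocedal & Wright, Algorithm 7.4):
    ///
    ///   first loop (newest -> oldest):  alpha_i = rho_i * s_i^T q;  q -= alpha_i * y_i
    ///   initial scaling:                r = gamma * q,  gamma = s^T y / y^T y  (latest pair)
    ///   second loop (oldest -> newest): beta = rho_i * y_i^T r;  r += (alpha_i - beta) * s_i
    ///
    /// Seeding q with -grad (rather than grad) lets the recursion return the
    /// descent direction directly, since every operation is linear in q.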
    fn compute_direction(&self, grad: &Vector<f32>) -> Vector<f32> {
        let n = grad.len();
        let k = self.s_history.len();
        // No curvature history yet: fall back to steepest descent.
        if k == 0 {
            let mut d = Vector::zeros(n);
            for i in 0..n {
                d[i] = -grad[i];
            }
            return d;
        }
        let mut q = Vector::zeros(n);
        for i in 0..n {
            q[i] = -grad[i];
        }
        let mut alpha = vec![0.0; k];
        let mut rho = vec![0.0; k];
        // First loop: newest to oldest pair.
        for i in (0..k).rev() {
            let s = &self.s_history[i];
            let y = &self.y_history[i];
            let mut y_dot_s = 0.0;
            for j in 0..n {
                y_dot_s += y[j] * s[j];
            }
            // Safe: pairs are only stored when y^T s > 1e-10 (see `minimize`).
            rho[i] = 1.0 / y_dot_s;
            let mut s_dot_q = 0.0;
            for j in 0..n {
                s_dot_q += s[j] * q[j];
            }
            alpha[i] = rho[i] * s_dot_q;
            for j in 0..n {
                q[j] -= alpha[i] * y[j];
            }
        }
        // Scale by gamma = s^T y / y^T y, the usual estimate of the inverse
        // Hessian's magnitude along the most recent curvature pair.
        let s_last = &self.s_history[k - 1];
        let y_last = &self.y_history[k - 1];
        let mut s_dot_y = 0.0;
        let mut y_dot_y = 0.0;
        for i in 0..n {
            s_dot_y += s_last[i] * y_last[i];
            y_dot_y += y_last[i] * y_last[i];
        }
        let gamma = s_dot_y / y_dot_y;
        let mut r = Vector::zeros(n);
        for i in 0..n {
            r[i] = gamma * q[i];
        }
        // Second loop: oldest to newest pair.
        for i in 0..k {
            let s = &self.s_history[i];
            let y = &self.y_history[i];
            let mut y_dot_r = 0.0;
            for j in 0..n {
                y_dot_r += y[j] * r[j];
            }
            let beta = rho[i] * y_dot_r;
            for j in 0..n {
                r[j] += s[j] * (alpha[i] - beta);
            }
        }
        r
    }

    /// Euclidean (L2) norm of `v`.
    fn norm(v: &Vector<f32>) -> f32 {
        let mut sum = 0.0;
        for i in 0..v.len() {
            sum += v[i] * v[i];
        }
        sum.sqrt()
    }
}

impl Optimizer for LBFGS {
    /// L-BFGS is a batch method: a single stochastic `step` cannot maintain
    /// the curvature history the algorithm depends on, so this panics
    /// loudly instead of producing a misleading update.
    #[provable_contracts_macros::contract("lbfgs-kernel-v1", equation = "two_loop_recursion")]
    fn step(&mut self, _params: &mut Vector<f32>, _gradients: &Vector<f32>) {
        panic!(
            "L-BFGS does not support stochastic updates (step). Use minimize() for batch optimization."
        )
    }

    fn minimize<F, G>(&mut self, objective: F, gradient: G, x0: Vector<f32>) -> OptimizationResult
    where
        F: Fn(&Vector<f32>) -> f32,
        G: Fn(&Vector<f32>) -> Vector<f32>,
    {
        let start_time = std::time::Instant::now();
        let n = x0.len();
        // Start from a clean curvature history so stale pairs from a
        // previous run cannot corrupt the Hessian approximation.
        self.s_history.clear();
        self.y_history.clear();
        let mut x = x0;
        let mut fx = objective(&x);
        let mut grad = gradient(&x);
        let mut grad_norm = Self::norm(&grad);
        for iter in 0..self.max_iter {
            if grad_norm < self.tol {
                return OptimizationResult {
                    solution: x,
                    objective_value: fx,
                    iterations: iter,
                    status: ConvergenceStatus::Converged,
                    gradient_norm: grad_norm,
                    constraint_violation: 0.0,
                    elapsed_time: start_time.elapsed(),
                };
            }
            let d = self.compute_direction(&grad);
            let alpha = self.line_search.search(&objective, &gradient, &x, &d);
            // A vanishing step size means the line search could not make
            // progress along d; report a stall rather than looping in place.
            if alpha < 1e-12 {
                return OptimizationResult {
                    solution: x,
                    objective_value: fx,
                    iterations: iter,
                    status: ConvergenceStatus::Stalled,
                    gradient_norm: grad_norm,
                    constraint_violation: 0.0,
                    elapsed_time: start_time.elapsed(),
                };
            }
            let mut x_new = Vector::zeros(n);
            for i in 0..n {
                x_new[i] = x[i] + alpha * d[i];
            }
            let fx_new = objective(&x_new);
            let grad_new = gradient(&x_new);
            if !fx_new.is_finite() {
                return OptimizationResult {
                    solution: x,
                    objective_value: fx,
                    iterations: iter,
                    status: ConvergenceStatus::NumericalError,
                    gradient_norm: grad_norm,
                    constraint_violation: 0.0,
                    elapsed_time: start_time.elapsed(),
                };
            }
            // Form the curvature pair s_k = x_{k+1} - x_k, y_k = g_{k+1} - g_k.
            let mut s_k = Vector::zeros(n);
            let mut y_k = Vector::zeros(n);
            for i in 0..n {
                s_k[i] = x_new[i] - x[i];
                y_k[i] = grad_new[i] - grad[i];
            }
            let mut y_dot_s = 0.0;
            for i in 0..n {
                y_dot_s += y_k[i] * s_k[i];
            }
            // Only store pairs satisfying the curvature condition y^T s > 0;
            // this keeps the implicit inverse Hessian positive definite and
            // guards the 1.0 / y_dot_s division in `compute_direction`.
            if y_dot_s > 1e-10 {
                if self.s_history.len() >= self.m {
                    // Evict the oldest pair to honor the memory bound m.
                    self.s_history.remove(0);
                    self.y_history.remove(0);
                }
                self.s_history.push(s_k);
                self.y_history.push(y_k);
            }
            x = x_new;
            fx = fx_new;
            grad = grad_new;
            grad_norm = Self::norm(&grad);
        }
        OptimizationResult {
            solution: x,
            objective_value: fx,
            iterations: self.max_iter,
            status: ConvergenceStatus::MaxIterations,
            gradient_norm: grad_norm,
            constraint_violation: 0.0,
            elapsed_time: start_time.elapsed(),
        }
    }

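    /// Discards the stored curvature pairs, returning the optimizer to a
    /// fresh state. `minimize` also clears the history on entry, so this is
    /// mainly useful when reusing the optimizer across problems.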
    fn reset(&mut self) {
        self.s_history.clear();
        self.y_history.clear();
    }
}

#[cfg(test)]
#[path = "lbfgs_tests.rs"]
mod tests;
#[cfg(test)]
#[path = "tests_lbfgs_contract.rs"]
mod tests_lbfgs_contract;