use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::VecDeque;
use std::fmt::Debug;
use crate::error::Result;
use crate::optimizers::Optimizer;
/// Limited-memory BFGS (L-BFGS) optimizer.
///
/// Parameters and gradients of any dimensionality are flattened in logical (row-major)
/// order. The search direction comes from the two-loop recursion over the last
/// `history_size` curvature pairs `(s_k, y_k)`. Here `s_k = x_{k+1} - x_k` is the step
/// actually applied and `y_k = g_{k+1} - g_k` is the gradient change it produced.
///
/// No line search is performed. Every step uses `learning_rate`, except the first step
/// after construction or [`LBFGS::reset`], which is damped to
/// `learning_rate / (1 + ||g||)` because no curvature information exists yet.
#[derive(Debug, Clone)]
pub struct LBFGS<A: Float + ScalarOperand + Debug> {
    /// Step length applied to each search direction.
    learning_rate: A,
    /// Maximum number of curvature pairs kept. `0` disables the pair history.
    history_size: usize,
    /// `step` returns the parameters unchanged when the gradient's L2 norm is at or below this.
    tolerance_grad: A,
    // `c1`, `c2` and `max_ls` are accepted by the constructors but nothing reads them
    // (no line search is implemented). They are presumably reserved for a future line search.
    #[allow(dead_code)]
    c1: A,
    #[allow(dead_code)]
    c2: A,
    #[allow(dead_code)]
    max_ls: usize,
    /// Gradient differences `y_k`, oldest first.
    old_dirs: VecDeque<Array1<A>>,
    /// Parameter steps `s_k`, oldest first (parallel to `old_dirs`).
    old_stps: VecDeque<Array1<A>>,
    /// `1 / (y_k . s_k)` for each stored pair (parallel to `old_dirs`).
    ro: VecDeque<A>,
    /// Flattened gradient seen by the last call that moved the parameters.
    prev_grad: Option<Array1<A>>,
    /// Flattened step applied by that same call. The next call pairs it with the new
    /// gradient to form `(s, y)`.
    prev_step: Option<Array1<A>>,
    /// Scale `gamma = (s.y) / (y.y)` of the initial inverse-Hessian approximation `gamma * I`.
    h_diag: A,
    /// Number of calls that moved the parameters. Calls skipped by `tolerance_grad` are
    /// not counted.
    n_iter: usize,
    /// Scratch buffer for the two-loop recursion. Its length is `history_size`, which is
    /// always at least the number of stored pairs.
    alpha: Vec<A>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> LBFGS<A> {
    /// Creates an optimizer with the given learning rate and these defaults:
    /// `history_size = 100`, `tolerance_grad = 1e-7`, `c1 = 1e-4`, `c2 = 0.9`, `max_ls = 25`.
    pub fn new(learning_rate: A) -> Self {
        Self::new_with_config(
            learning_rate,
            100,
            Self::cast(1e-7),
            Self::cast(1e-4),
            Self::cast(0.9),
            25,
        )
    }

    /// Creates an optimizer with every setting given explicitly.
    ///
    /// `c1`, `c2` and `max_ls` are stored but not currently used.
    pub fn new_with_config(
        learning_rate: A,
        history_size: usize,
        tolerance_grad: A,
        c1: A,
        c2: A,
        max_ls: usize,
    ) -> Self {
        Self {
            learning_rate,
            history_size,
            tolerance_grad,
            c1,
            c2,
            max_ls,
            old_dirs: VecDeque::with_capacity(history_size),
            old_stps: VecDeque::with_capacity(history_size),
            ro: VecDeque::with_capacity(history_size),
            prev_grad: None,
            prev_step: None,
            h_diag: A::one(),
            n_iter: 0,
            alpha: vec![A::zero(); history_size],
        }
    }

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Sets the learning rate used by subsequent steps.
    pub fn set_lr(&mut self, lr: A) {
        self.learning_rate = lr;
    }

    /// Discards all curvature history and iteration state. The configuration is kept.
    pub fn reset(&mut self) {
        self.old_dirs.clear();
        self.old_stps.clear();
        self.ro.clear();
        self.prev_grad = None;
        self.prev_step = None;
        self.h_diag = A::one();
        self.n_iter = 0;
        self.alpha.fill(A::zero());
    }

    /// Converts an `f64` constant into `A`.
    ///
    /// # Panics
    /// Panics if `value` cannot be represented in `A`. This never happens for finite
    /// constants with `f32`/`f64`.
    fn cast(value: f64) -> A {
        A::from(value).expect("f64 constant must be representable in the optimizer's float type")
    }

    /// Returns the L-BFGS search direction `-H * gradient`, using the two-loop recursion
    /// over the stored curvature pairs.
    ///
    /// On the very first iteration this is simply `-gradient`.
    fn compute_direction(&mut self, gradient: &Array1<A>) -> Array1<A> {
        let mut q = gradient.mapv(|x| -x);
        if self.n_iter == 0 {
            return q;
        }
        let num_old = self.old_dirs.len();
        // `update_history` never stores more than `history_size == alpha.len()` pairs.
        debug_assert!(num_old <= self.alpha.len());

        // First loop: newest pair to oldest.
        for i in (0..num_old).rev() {
            let alpha_i = self.old_stps[i].dot(&q) * self.ro[i];
            self.alpha[i] = alpha_i;
            q.scaled_add(-alpha_i, &self.old_dirs[i]);
        }

        // Apply the initial inverse-Hessian approximation `gamma * I`.
        let mut r = q * self.h_diag;

        // Second loop: oldest pair to newest.
        for i in 0..num_old {
            let beta = self.old_dirs[i].dot(&r) * self.ro[i];
            r.scaled_add(self.alpha[i] - beta, &self.old_stps[i]);
        }
        r
    }

    /// Records the curvature pair `(s, y)` when it satisfies `y.s > 1e-10`.
    ///
    /// This is the BFGS curvature condition. A NaN product fails the comparison and is
    /// rejected too. When the pair is accepted:
    /// - `h_diag` is updated to `(y.s) / (y.y)`.
    /// - If `history_size > 0`, the pair is stored, evicting the oldest pairs so that
    ///   at most `history_size` remain.
    fn update_history(&mut self, y: Array1<A>, s: Array1<A>) {
        let ys = y.dot(&s);
        if ys > Self::cast(1e-10) {
            let yy = y.dot(&y);
            if yy > A::zero() {
                self.h_diag = ys / yy;
            }
            // With a zero-sized history nothing is stored. Pushing anyway would overflow
            // the zero-length `alpha` scratch buffer in `compute_direction`.
            if self.history_size > 0 {
                while self.old_dirs.len() >= self.history_size {
                    self.old_dirs.pop_front();
                    self.old_stps.pop_front();
                    self.ro.pop_front();
                }
                self.old_dirs.push_back(y);
                self.old_stps.push_back(s);
                self.ro.push_back(A::one() / ys);
            }
        }
    }
}

impl<A, D> Optimizer<A, D> for LBFGS<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    /// Performs one L-BFGS update and returns the new parameters, with the same shape as
    /// `params`.
    ///
    /// If the gradient's L2 norm is at or below `tolerance_grad`, `params` is returned
    /// unchanged and no internal state is modified.
    ///
    /// # Panics
    /// Panics if `params` and `gradients` have different numbers of elements.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        assert_eq!(
            params.len(),
            gradients.len(),
            "LBFGS::step: params and gradients must have the same number of elements"
        );

        // Flatten in logical (row-major) order. Unlike reshaping, this works for any memory
        // layout (e.g. Fortran-order or strided views materialised by the caller).
        let params_flat: Array1<A> = params.iter().copied().collect();
        let gradients_flat: Array1<A> = gradients.iter().copied().collect();

        let grad_norm = gradients_flat.dot(&gradients_flat).sqrt();
        if grad_norm <= self.tolerance_grad {
            return Ok(params.clone());
        }

        // Pair the step applied last time with the gradient change it produced. Using the
        // stored step (rather than recomputing it) keeps `s` exact even if the learning
        // rate changed between calls.
        if let (Some(prev_grad), Some(prev_step)) = (self.prev_grad.take(), self.prev_step.take())
        {
            let y = &gradients_flat - &prev_grad;
            self.update_history(y, prev_step);
        }

        let direction = self.compute_direction(&gradients_flat);
        // Damp the very first step, since no curvature information exists yet.
        let step_size = if self.n_iter == 0 {
            self.learning_rate / (A::one() + grad_norm)
        } else {
            self.learning_rate
        };
        let step = direction * step_size;
        let new_params_flat = &params_flat + &step;

        self.prev_grad = Some(gradients_flat);
        self.prev_step = Some(step);
        self.n_iter += 1;

        let new_params = new_params_flat
            .into_shape_with_order(params.raw_dim())
            .expect("a freshly built 1-D array with params.len() elements reshapes to params' shape");
        Ok(new_params)
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array1, Array2};

    #[test]
    fn test_lbfgs_basic_creation() {
        let optimizer: LBFGS<f64> = LBFGS::new(1.0);
        assert_abs_diff_eq!(optimizer.learning_rate(), 1.0);
        assert_eq!(optimizer.history_size, 100);
        assert_abs_diff_eq!(optimizer.tolerance_grad, 1e-7);
    }

    #[test]
    fn test_lbfgs_convergence() {
        let mut optimizer: LBFGS<f64> = LBFGS::new(0.1);
        let mut params = Array1::from_vec(vec![10.0]);
        for _ in 0..50 {
            let gradients = Array1::from_vec(vec![2.0 * params[0]]);
            params = optimizer.step(&params, &gradients).expect("unwrap failed");
        }
        assert!(params[0].abs() < 0.1);
    }

    #[test]
    fn test_lbfgs_2d() {
        let mut optimizer: LBFGS<f64> = LBFGS::new(0.1);
        let mut params = Array1::from_vec(vec![5.0, 3.0]);
        for _ in 0..50 {
            let gradients = Array1::from_vec(vec![2.0 * params[0], 2.0 * params[1]]);
            params = optimizer.step(&params, &gradients).expect("unwrap failed");
        }
        assert!(params[0].abs() < 0.1);
        assert!(params[1].abs() < 0.1);
    }

    #[test]
    fn test_lbfgs_reset() {
        let mut optimizer: LBFGS<f64> = LBFGS::new(0.1);
        let mut params = Array1::from_vec(vec![1.0]);
        let gradients = Array1::from_vec(vec![2.0]);
        params = optimizer.step(&params, &gradients).expect("unwrap failed");
        let gradients2 = Array1::from_vec(vec![1.5]);
        params = optimizer.step(&params, &gradients2).expect("unwrap failed");
        let gradients3 = Array1::from_vec(vec![1.0]);
        let _ = optimizer.step(&params, &gradients3).expect("unwrap failed");
        assert!(!optimizer.old_dirs.is_empty());
        assert!(optimizer.n_iter > 0);
        optimizer.reset();
        assert!(optimizer.old_dirs.is_empty());
        assert!(optimizer.old_stps.is_empty());
        assert!(optimizer.ro.is_empty());
        assert!(optimizer.prev_grad.is_none());
        assert_eq!(optimizer.n_iter, 0);
    }

    /// A gradient below `tolerance_grad` must leave both the parameters and the
    /// optimizer state untouched.
    #[test]
    fn test_lbfgs_zero_gradient_is_noop() {
        let mut optimizer: LBFGS<f64> = LBFGS::new(0.1);
        let params = Array1::from_vec(vec![1.0, -2.0]);
        let gradients = Array1::<f64>::zeros(2);
        let out = optimizer
            .step(&params, &gradients)
            .expect("step should succeed");
        assert_eq!(out, params);
        assert_eq!(optimizer.n_iter, 0);
        assert!(optimizer.prev_grad.is_none());
    }

    /// Multi-dimensional parameters keep their shape, and the first step is the damped
    /// gradient step `x - lr / (1 + ||g||) * g`.
    #[test]
    fn test_lbfgs_preserves_shape_on_first_step() {
        let mut optimizer: LBFGS<f64> = LBFGS::new(0.1);
        let params =
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("valid shape");
        let gradients = params.mapv(|x| 2.0 * x);
        let out = optimizer
            .step(&params, &gradients)
            .expect("step should succeed");
        assert_eq!(out.shape(), &[2, 2]);

        let grad_norm = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
        let scale = 0.1 / (1.0 + grad_norm);
        for ((&o, &p), &g) in out.iter().zip(params.iter()).zip(gradients.iter()) {
            assert_abs_diff_eq!(o, p - scale * g, epsilon = 1e-12);
        }
    }
}