use super::Optimizer;
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use std::collections::{HashMap, VecDeque};
/// Limited-memory BFGS (L-BFGS) optimizer.
///
/// For every parameter tensor passed to [`Optimizer::step`], the optimizer keeps a
/// bounded history of curvature pairs `(s, y)`: `s` is the change in the parameter
/// values between two consecutive `step` calls and `y` is the gradient change over
/// that same interval. The history is turned into a quasi-Newton search direction
/// with the two-loop recursion implemented by [`LBFGSMemory`].
///
/// Per-parameter state is keyed by the tensor's data pointer (`Tensor::as_ptr`).
#[derive(Debug, Clone)]
pub struct LBFGS {
    /// Base step length used by the line searches and by [`LineSearchMethod::None`].
    learning_rate: f32,
    /// Stored and reported by `state_dict`; `step` does not consult it.
    max_iter: usize,
    /// Stored and reported by `state_dict`; `step` does not consult it.
    max_eval: usize,
    /// `step` leaves the parameter untouched when the gradient's L2 norm is below this.
    tolerance_grad: f32,
    /// Stored and reported by `state_dict`; `step` does not consult it.
    tolerance_change: f32,
    /// Maximum number of `(s, y)` pairs kept for a newly seen parameter.
    history_size: usize,
    /// Step-length selection strategy used by `step`.
    line_search_fn: LineSearchMethod,
    /// Number of `step` calls, including calls that stopped early on `tolerance_grad`.
    step_count: usize,
    /// Curvature history per parameter id.
    memory: HashMap<usize, LBFGSMemory>,
    /// Gradient passed to the previous `step` call, per parameter id.
    prev_grad: HashMap<usize, Tensor<f32>>,
    /// Parameter values at which `prev_grad` was evaluated, per parameter id.
    prev_param: HashMap<usize, Tensor<f32>>,
}

/// Strategy used by [`LBFGS`] to choose the step length along the search direction.
///
/// `Optimizer::step` only receives a parameter and its gradient, never the loss,
/// so none of these strategies evaluate the objective. They are heuristics driven
/// by the learning rate and the directional derivative `g^T d`.
#[derive(Debug, Clone)]
pub enum LineSearchMethod {
    /// Heuristic modelled on a strong Wolfe search. `c1` scales the estimated
    /// decrease; `c2` is currently unused.
    StrongWolfe { c1: f32, c2: f32 },
    /// Starts from the learning rate and multiplies it by `rho` while the
    /// estimated decrease `|alpha * c1 * g^T d|` is negligible.
    Backtracking { c1: f32, rho: f32 },
    /// Uses the learning rate directly as the step length.
    None,
}

impl Default for LineSearchMethod {
    /// Strong-Wolfe-style heuristic with `c1 = 1e-4` and `c2 = 0.9`.
    fn default() -> Self {
        LineSearchMethod::StrongWolfe { c1: 1e-4, c2: 0.9 }
    }
}

/// Bounded history of L-BFGS curvature pairs for a single parameter tensor.
#[derive(Debug, Clone)]
pub struct LBFGSMemory {
    /// Parameter differences `s_i`, oldest first.
    s_history: VecDeque<Tensor<f32>>,
    /// Gradient differences `y_i`, oldest first.
    y_history: VecDeque<Tensor<f32>>,
    /// `rho_i = 1 / (y_i^T s_i)` for each stored pair.
    rho_history: VecDeque<f32>,
    /// Maximum number of pairs retained; the oldest pairs are evicted first.
    max_size: usize,
}

impl LBFGSMemory {
    /// Creates an empty history holding at most `max_size` pairs (at least 1).
    pub fn new(max_size: usize) -> Self {
        Self {
            s_history: VecDeque::new(),
            y_history: VecDeque::new(),
            rho_history: VecDeque::new(),
            max_size: max_size.max(1),
        }
    }

    /// Appends the curvature pair `(s, y)`, evicting the oldest pairs beyond `max_size`.
    ///
    /// # Errors
    ///
    /// Returns [`RusTorchError::InvalidParameters`] and leaves the history unchanged
    /// when `y^T s` is not a finite value greater than `1e-12`. L-BFGS needs
    /// `y^T s > 0` to keep its inverse-Hessian approximation positive definite;
    /// storing a negative-curvature pair could make later directions point uphill.
    pub fn update(&mut self, s: Tensor<f32>, y: Tensor<f32>) -> RusTorchResult<()> {
        let s_y_product = &s * &y;
        let y_dot_s = s_y_product.sum();
        if !y_dot_s.is_finite() || y_dot_s <= 1e-12 {
            return Err(RusTorchError::InvalidParameters {
                operation: "L-BFGS memory update".to_string(),
                message: "Insufficient curvature condition: y^T * s too small".to_string(),
            });
        }
        self.s_history.push_back(s);
        self.y_history.push_back(y);
        self.rho_history.push_back(1.0 / y_dot_s);
        while self.s_history.len() > self.max_size {
            self.s_history.pop_front();
            self.y_history.pop_front();
            self.rho_history.pop_front();
        }
        Ok(())
    }

    /// Number of stored pairs.
    pub fn size(&self) -> usize {
        self.s_history.len()
    }

    /// Whether no pairs are stored.
    pub fn is_empty(&self) -> bool {
        self.s_history.is_empty()
    }

    /// Removes all stored pairs; `max_size` is kept.
    pub fn clear(&mut self) {
        self.s_history.clear();
        self.y_history.clear();
        self.rho_history.clear();
    }

    /// Returns the search direction `-H * grad` computed by the two-loop recursion.
    ///
    /// `H` is the implicit inverse-Hessian approximation built from the stored pairs,
    /// seeded with `gamma * I` where `gamma = s^T y / y^T y` of the newest pair,
    /// clamped to `[1e-8, 1e8]` (or `1.0` when `y^T y <= 1e-12`). With an empty
    /// history the result is `-grad`.
    pub fn compute_search_direction(&self, grad: &Tensor<f32>) -> Tensor<f32> {
        if self.is_empty() {
            return grad.clone() * (-1.0);
        }

        // First loop: newest pair to oldest, so `alphas` ends up ordered newest-first.
        let mut q = grad.clone();
        let mut alphas = Vec::with_capacity(self.size());
        for i in (0..self.size()).rev() {
            let alpha = self.rho_history[i] * (&self.s_history[i] * &q).sum();
            q = &q - &(&self.y_history[i] * alpha);
            alphas.push(alpha);
        }

        let gamma = match (self.s_history.back(), self.y_history.back()) {
            (Some(s_last), Some(y_last)) => {
                let y_dot_y = (y_last * y_last).sum();
                if y_dot_y > 1e-12 {
                    ((s_last * y_last).sum() / y_dot_y).clamp(1e-8, 1e8)
                } else {
                    1.0
                }
            }
            _ => 1.0,
        };

        // Second loop: oldest pair to newest, consuming the alphas in reverse.
        let mut r = &q * gamma;
        for (i, &alpha) in alphas.iter().rev().enumerate() {
            let beta = self.rho_history[i] * (&self.y_history[i] * &r).sum();
            r = &r + &(&self.s_history[i] * (alpha - beta));
        }
        r * (-1.0)
    }
}

impl LBFGS {
    /// Creates an optimizer with `max_iter = 20`, `max_eval = 20`,
    /// `tolerance_grad = 1e-5`, `tolerance_change = 1e-9`, a history of 10 pairs
    /// and the default [`LineSearchMethod`].
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`LBFGS::with_params`].
    pub fn new(learning_rate: f32) -> RusTorchResult<Self> {
        Self::with_params(
            learning_rate,
            20,
            20,
            1e-5,
            1e-9,
            10,
            LineSearchMethod::default(),
        )
    }

    /// Creates an optimizer with every setting given explicitly.
    ///
    /// `history_size` is the number of curvature pairs kept per parameter.
    /// `max_iter`, `max_eval` and `tolerance_change` are stored and reported by
    /// `state_dict`, but `step` does not use them.
    ///
    /// # Errors
    ///
    /// Returns [`RusTorchError::InvalidParameters`] when `learning_rate` is not a
    /// finite positive number, `tolerance_grad` is negative or NaN, or
    /// `history_size` is zero.
    pub fn with_params(
        learning_rate: f32,
        max_iter: usize,
        max_eval: usize,
        tolerance_grad: f32,
        tolerance_change: f32,
        history_size: usize,
        line_search_fn: LineSearchMethod,
    ) -> RusTorchResult<Self> {
        // `!is_finite()` also rejects NaN, which would slip past a plain `<= 0.0` test.
        if !learning_rate.is_finite() || learning_rate <= 0.0 {
            return Err(RusTorchError::InvalidParameters {
                operation: "L-BFGS optimizer".to_string(),
                message: "Learning rate must be positive".to_string(),
            });
        }
        if tolerance_grad.is_nan() || tolerance_grad < 0.0 {
            return Err(RusTorchError::InvalidParameters {
                operation: "L-BFGS optimizer".to_string(),
                message: "Gradient tolerance must be non-negative".to_string(),
            });
        }
        if history_size == 0 {
            return Err(RusTorchError::InvalidParameters {
                operation: "L-BFGS optimizer".to_string(),
                message: "History size must be positive".to_string(),
            });
        }
        Ok(Self {
            learning_rate,
            max_iter,
            max_eval,
            tolerance_grad,
            tolerance_change,
            history_size,
            line_search_fn,
            step_count: 0,
            memory: HashMap::new(),
            prev_grad: HashMap::new(),
            prev_param: HashMap::new(),
        })
    }

    /// Sets the stored `max_iter` value.
    pub fn set_max_iter(&mut self, max_iter: usize) {
        self.max_iter = max_iter;
    }

    /// Sets the gradient-norm threshold below which `step` does nothing.
    ///
    /// # Errors
    ///
    /// Returns [`RusTorchError::InvalidParameters`] for a negative or NaN tolerance.
    pub fn set_tolerance_grad(&mut self, tolerance: f32) -> RusTorchResult<()> {
        if tolerance.is_nan() || tolerance < 0.0 {
            return Err(RusTorchError::InvalidParameters {
                operation: "L-BFGS optimizer".to_string(),
                message: "Gradient tolerance must be non-negative".to_string(),
            });
        }
        self.tolerance_grad = tolerance;
        Ok(())
    }

    /// Discards curvature state.
    ///
    /// `Some(id)` clears the history and previous-step data of that parameter only.
    /// `None` clears every parameter's state and resets `step_count` to zero.
    pub fn reset_memory(&mut self, param_id: Option<usize>) {
        match param_id {
            Some(id) => {
                if let Some(memory) = self.memory.get_mut(&id) {
                    memory.clear();
                }
                self.prev_grad.remove(&id);
                self.prev_param.remove(&id);
            }
            None => {
                self.memory.clear();
                self.prev_grad.clear();
                self.prev_param.clear();
                self.step_count = 0;
            }
        }
    }

    /// Search direction for `param_id`: `-grad` when the parameter has no history yet.
    fn compute_search_direction(&self, param_id: usize, grad: &Tensor<f32>) -> Tensor<f32> {
        match self.memory.get(&param_id) {
            Some(memory) => memory.compute_search_direction(grad),
            None => grad.clone() * (-1.0),
        }
    }

    /// Heuristic backtracking step-length selection (the loss is not available).
    ///
    /// Returns `1e-6` when `direction` is not a descent direction
    /// (`g^T d >= -1e-12`). Otherwise it returns the current `alpha`, starting at the
    /// learning rate, as soon as `alpha > 1e-10` and `|alpha * c1 * g^T d| > 1e-10`.
    /// Until then it multiplies `alpha` by `rho`, up to 25 times, and finally clamps
    /// the result to `[1e-10, 1]`.
    fn backtracking_line_search(
        &self,
        _param: &Tensor<f32>,
        grad: &Tensor<f32>,
        direction: &Tensor<f32>,
        c1: f32,
        rho: f32,
    ) -> f32 {
        const MAX_ITERATIONS: usize = 25;
        let directional_derivative = (grad * direction).sum();
        if directional_derivative >= -1e-12 {
            return 1e-6;
        }
        let mut alpha = self.learning_rate;
        for _ in 0..MAX_ITERATIONS {
            let expected_improvement = (alpha * c1 * directional_derivative).abs();
            if alpha > 1e-10 && expected_improvement > 1e-10 {
                return alpha;
            }
            alpha *= rho;
            if alpha < 1e-10 {
                break;
            }
        }
        alpha.clamp(1e-10, 1.0)
    }

    /// Heuristic step-length selection modelled on a strong Wolfe search.
    ///
    /// The loss is not available, so the Wolfe conditions are not actually tested.
    /// The acceptance check reduces to `|alpha * c1 * g^T d| > 1e-12` with
    /// `alpha > 1e-10`, starting from `min(learning_rate, 1)`. With ordinary inputs
    /// it therefore accepts the first `alpha`. If not, `alpha` grows by 1.6x for the
    /// first half of the 15 iterations and is then bisected. The result is clamped
    /// to `[1e-12, 10]`. Returns `1e-6` for a non-descent direction. `c2` is unused.
    fn strong_wolfe_line_search(
        &self,
        _param: &Tensor<f32>,
        grad: &Tensor<f32>,
        direction: &Tensor<f32>,
        c1: f32,
        _c2: f32,
    ) -> f32 {
        const MAX_ITERATIONS: usize = 15;
        let directional_derivative = (grad * direction).sum();
        if directional_derivative >= -1e-12 {
            return 1e-6;
        }
        let alpha_low = 0.0;
        let mut alpha_high = f32::INFINITY;
        let mut alpha = self.learning_rate.min(1.0);
        for iteration in 0..MAX_ITERATIONS {
            let armijo_threshold = alpha * c1 * directional_derivative;
            if armijo_threshold.abs() > 1e-12 && alpha > 1e-10 {
                return alpha;
            }
            if alpha_high == f32::INFINITY {
                if iteration < MAX_ITERATIONS / 2 {
                    alpha *= 1.6;
                } else {
                    alpha_high = alpha;
                    alpha = (alpha_low + alpha_high) * 0.5;
                }
            } else {
                let new_alpha = (alpha_low + alpha_high) * 0.5;
                if (alpha_high - alpha_low) < 1e-12 * (alpha_high + alpha_low) {
                    alpha = new_alpha;
                    break;
                }
                alpha = new_alpha;
            }
            if alpha < 1e-12 || alpha > 100.0 {
                break;
            }
        }
        alpha.clamp(1e-12, 10.0)
    }

    /// Adds a curvature pair to `param_id`'s history.
    ///
    /// The history is created with capacity `history_size` on first use; an existing
    /// history keeps its original capacity.
    ///
    /// # Errors
    ///
    /// Propagates the curvature-condition error from [`LBFGSMemory::update`].
    fn update_memory(
        &mut self,
        param_id: usize,
        param_change: Tensor<f32>,
        grad_change: Tensor<f32>,
        history_size: usize,
    ) -> RusTorchResult<()> {
        self.memory
            .entry(param_id)
            .or_insert_with(|| LBFGSMemory::new(history_size))
            .update(param_change, grad_change)
    }
}

impl Optimizer for LBFGS {
    /// Performs one L-BFGS update of `param` in place.
    ///
    /// 1. Counts the call. If the gradient's L2 norm is below `tolerance_grad`, it
    ///    returns without touching `param` or the stored state.
    /// 2. If this parameter was seen before, it records the pair
    ///    `s = param - prev_param`, `y = grad - prev_grad`. Both differences span the
    ///    same interval, as L-BFGS requires. Pairs failing the curvature check are
    ///    skipped, and the existing history stays in use.
    /// 3. It computes the search direction, chooses `alpha` with the configured line
    ///    search, and writes `param + alpha * direction` into `param`.
    /// 4. It remembers the pre-update `param` and `grad` for the next call.
    fn step(&mut self, param: &Tensor<f32>, grad: &Tensor<f32>) {
        let param_id = param.as_ptr() as usize;
        self.step_count += 1;

        let grad_norm = (grad * grad).sum().sqrt();
        if grad_norm < self.tolerance_grad {
            return;
        }

        let curvature_pair = match (
            self.prev_param.get(&param_id),
            self.prev_grad.get(&param_id),
        ) {
            (Some(prev_param), Some(prev_grad)) => Some((param - prev_param, grad - prev_grad)),
            _ => None,
        };
        if let Some((param_change, grad_change)) = curvature_pair {
            let history_size = self.history_size;
            // Best effort: a pair violating the curvature condition is simply dropped.
            let _ = self.update_memory(param_id, param_change, grad_change, history_size);
        }

        let search_direction = self.compute_search_direction(param_id, grad);
        let alpha = match &self.line_search_fn {
            LineSearchMethod::Backtracking { c1, rho } => {
                self.backtracking_line_search(param, grad, &search_direction, *c1, *rho)
            }
            LineSearchMethod::StrongWolfe { c1, c2 } => {
                self.strong_wolfe_line_search(param, grad, &search_direction, *c1, *c2)
            }
            LineSearchMethod::None => self.learning_rate,
        };

        let step = &search_direction * alpha;
        let new_param = param + &step;
        // Store the point where `grad` was evaluated before `param` is overwritten.
        self.prev_param.insert(param_id, param.clone());
        self.prev_grad.insert(param_id, grad.clone());
        param.copy_from(&new_param);
    }

    /// Current learning rate.
    fn learning_rate(&self) -> f32 {
        self.learning_rate
    }

    /// Replaces the learning rate (not validated).
    fn set_learning_rate(&mut self, lr: f32) {
        self.learning_rate = lr;
    }

    /// Scalar configuration and `step_count`, with integers stored as `f32`.
    /// Per-parameter curvature history is not included.
    fn state_dict(&self) -> HashMap<String, f32> {
        let mut state = HashMap::new();
        state.insert("learning_rate".to_string(), self.learning_rate);
        state.insert("max_iter".to_string(), self.max_iter as f32);
        state.insert("max_eval".to_string(), self.max_eval as f32);
        state.insert("tolerance_grad".to_string(), self.tolerance_grad);
        state.insert("tolerance_change".to_string(), self.tolerance_change);
        state.insert("history_size".to_string(), self.history_size as f32);
        state.insert("step_count".to_string(), self.step_count as f32);
        state
    }

    /// Restores any keys present in `state`, as produced by `state_dict`.
    ///
    /// Missing keys leave the current values untouched. A restored `history_size`
    /// applies to parameters whose history is created afterwards.
    fn load_state_dict(&mut self, state: HashMap<String, f32>) {
        if let Some(&lr) = state.get("learning_rate") {
            self.learning_rate = lr;
        }
        if let Some(&max_iter) = state.get("max_iter") {
            self.max_iter = max_iter as usize;
        }
        if let Some(&max_eval) = state.get("max_eval") {
            self.max_eval = max_eval as usize;
        }
        if let Some(&tolerance_grad) = state.get("tolerance_grad") {
            self.tolerance_grad = tolerance_grad;
        }
        if let Some(&tolerance_change) = state.get("tolerance_change") {
            self.tolerance_change = tolerance_change;
        }
        if let Some(&history_size) = state.get("history_size") {
            // Keep the `history_size >= 1` invariant enforced by `with_params`.
            self.history_size = (history_size as usize).max(1);
        }
        if let Some(&step_count) = state.get("step_count") {
            self.step_count = step_count as usize;
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the L-BFGS optimizer and its curvature memory.
    use super::*;
    use crate::tensor::Tensor;

    #[test]
    fn test_lbfgs_creation() {
        let optimizer = LBFGS::new(0.1).unwrap();
        assert_eq!(optimizer.learning_rate(), 0.1);
        assert_eq!(optimizer.step_count, 0);
    }

    #[test]
    fn test_lbfgs_with_params() {
        let optimizer = LBFGS::with_params(
            0.01,
            50,
            100,
            1e-6,
            1e-10,
            15,
            LineSearchMethod::Backtracking { c1: 1e-4, rho: 0.5 },
        )
        .unwrap();
        assert_eq!(optimizer.learning_rate(), 0.01);
        assert_eq!(optimizer.max_iter, 50);
        assert_eq!(optimizer.max_eval, 100);
        assert_eq!(optimizer.tolerance_grad, 1e-6);
    }

    #[test]
    fn test_lbfgs_parameter_validation() {
        assert!(LBFGS::new(-0.1).is_err());
        assert!(LBFGS::new(0.0).is_err());
        let result = LBFGS::with_params(0.1, 10, 10, -1e-5, 1e-9, 5, LineSearchMethod::None);
        assert!(result.is_err());
        let result = LBFGS::with_params(0.1, 10, 10, 1e-5, 1e-9, 0, LineSearchMethod::None);
        assert!(result.is_err());
    }

    #[test]
    fn test_lbfgs_step() {
        let mut optimizer = LBFGS::new(0.1).unwrap();
        let param = Tensor::<f32>::ones(&[3, 3]);
        let grad = Tensor::<f32>::ones(&[3, 3]) * 0.1;
        let initial_param = param.clone();
        optimizer.step(&param, &grad);
        assert_eq!(optimizer.step_count, 1);
        let updated_data = param.data.as_slice().unwrap();
        let initial_data = initial_param.data.as_slice().unwrap();
        assert_ne!(updated_data[0], initial_data[0]);
        let grad2 = Tensor::<f32>::ones(&[3, 3]) * 0.05;
        optimizer.step(&param, &grad2);
        assert_eq!(optimizer.step_count, 2);
    }

    #[test]
    fn test_lbfgs_memory_management() {
        let mut optimizer = LBFGS::new(0.1).unwrap();
        let param = Tensor::<f32>::ones(&[2, 2]);
        let param_id = param.as_ptr() as usize;
        // Every pair below has y^T s > 0, so all five updates must be accepted and
        // the history must be trimmed to the requested capacity of 3.
        for i in 0..5 {
            let param_change = Tensor::<f32>::ones(&[2, 2]) * (i as f32 * 0.1 + 0.01);
            let grad_change = Tensor::<f32>::ones(&[2, 2]) * (i as f32 * 0.05 + 0.01);
            optimizer
                .update_memory(param_id, param_change, grad_change, 3)
                .expect("pairs with positive curvature must be accepted");
        }
        let memory = optimizer
            .memory
            .get(&param_id)
            .expect("update_memory should create the parameter's history");
        assert_eq!(memory.size(), 3);
    }

    #[test]
    fn test_lbfgs_memory_reset() {
        let mut optimizer = LBFGS::new(0.1).unwrap();
        let param = Tensor::<f32>::ones(&[2, 2]);
        let grad = Tensor::<f32>::ones(&[2, 2]) * 0.1;
        optimizer.step(&param, &grad);
        assert_eq!(optimizer.step_count, 1);
        optimizer.reset_memory(None);
        assert_eq!(optimizer.step_count, 0);
        assert!(optimizer.memory.is_empty());
        assert!(optimizer.prev_grad.is_empty());
    }

    #[test]
    fn test_lbfgs_search_direction() {
        let optimizer = LBFGS::new(0.1).unwrap();
        let grad = Tensor::<f32>::ones(&[2, 2]);
        // An id with no history falls back to steepest descent (-grad).
        let param_id = 12345;
        let direction = optimizer.compute_search_direction(param_id, &grad);
        let expected = grad.clone() * (-1.0);
        let dir_data = direction.data.as_slice().unwrap();
        let exp_data = expected.data.as_slice().unwrap();
        for (d, e) in dir_data.iter().zip(exp_data.iter()) {
            assert!((d - e).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lbfgs_state_dict() {
        let optimizer =
            LBFGS::with_params(0.05, 25, 50, 1e-4, 1e-8, 8, LineSearchMethod::None).unwrap();
        let state = optimizer.state_dict();
        assert_eq!(state["learning_rate"], 0.05);
        assert_eq!(state["max_iter"], 25.0);
        assert_eq!(state["tolerance_grad"], 1e-4);
    }

    #[test]
    fn test_lbfgs_convergence_check() {
        let mut optimizer = LBFGS::new(0.1).unwrap();
        optimizer.set_tolerance_grad(1e-2).unwrap();
        let param = Tensor::<f32>::ones(&[2, 2]);
        // Gradient norm is 2e-3, below the 1e-2 tolerance, so `step` must be a no-op.
        let small_grad = Tensor::<f32>::ones(&[2, 2]) * 1e-3;
        let initial_param = param.clone();
        optimizer.step(&param, &small_grad);
        let updated_data = param.data.as_slice().unwrap();
        let initial_data = initial_param.data.as_slice().unwrap();
        for (u, i) in updated_data.iter().zip(initial_data.iter()) {
            assert!((u - i).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lbfgs_line_search_methods() {
        let param = Tensor::<f32>::ones(&[2, 2]);
        let grad = Tensor::<f32>::ones(&[2, 2]) * 0.1;
        let direction = grad.clone() * (-1.0);
        let optimizer1 = LBFGS::with_params(
            0.1,
            10,
            10,
            1e-5,
            1e-9,
            5,
            LineSearchMethod::Backtracking { c1: 1e-4, rho: 0.5 },
        )
        .unwrap();
        let alpha1 = optimizer1.backtracking_line_search(&param, &grad, &direction, 1e-4, 0.5);
        assert!(alpha1 > 0.0);
        assert!(alpha1 <= 1.0);
        let optimizer2 = LBFGS::with_params(
            0.1,
            10,
            10,
            1e-5,
            1e-9,
            5,
            LineSearchMethod::StrongWolfe { c1: 1e-4, c2: 0.9 },
        )
        .unwrap();
        let alpha2 = optimizer2.strong_wolfe_line_search(&param, &grad, &direction, 1e-4, 0.9);
        assert!(alpha2 > 0.0);
        assert!(alpha2 <= 10.0);
    }
}