use super::line_search::{backtracking_line_search, wolfe_line_search, LineSearchMethod};
use super::utilities::{dot_product, tensor_add, tensor_norm, tensor_scalar_mul, tensor_sub};
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// Configuration for the BFGS-family optimizers in this module.
#[derive(Debug, Clone)]
pub struct BFGSParams {
    /// Scale of the initial inverse-Hessian approximation H0 = scale * I;
    /// it sets the length of the first (steepest-descent) direction.
    pub initial_hessian_scale: f32,
    /// Maximum number of outer iterations.
    pub max_iter: usize,
    /// Convergence threshold on the gradient norm.
    pub tolerance: f32,
    /// Line-search method used to choose the step size along each direction.
    pub line_search: LineSearchMethod,
}
impl Default for BFGSParams {
fn default() -> Self {
Self {
initial_hessian_scale: 1.0,
max_iter: 1000,
tolerance: 1e-6,
line_search: LineSearchMethod::Wolfe,
}
}
}
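/// Minimizes `objective` from the starting point `x0` with L-BFGS
/// (limited-memory BFGS), storing at most `m` curvature pairs (default 10)
/// to represent the inverse Hessian implicitly. Returns the final iterate
/// together with the objective value recorded at the start of each iteration.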
pub fn lbfgs_optimizer<F, G>(
objective: F,
gradient: G,
x0: &Tensor,
m: Option<usize>,
params: Option<BFGSParams>,
) -> TorshResult<(Tensor, Vec<f32>)>
where
F: Fn(&Tensor) -> TorshResult<f32>,
G: Fn(&Tensor) -> TorshResult<Tensor>,
{
let params = params.unwrap_or_default();
let m = m.unwrap_or(10);
let mut x = x0.clone();
let mut objective_values = Vec::new();
let mut s_history: Vec<Tensor> = Vec::with_capacity(m);
let mut y_history: Vec<Tensor> = Vec::with_capacity(m);
let mut rho_history: Vec<f32> = Vec::with_capacity(m);
    // Gradient at the current iterate, carried across iterations so that
    // each point is differentiated exactly once.
    let mut grad = gradient(&x)?;
for iter in 0..params.max_iter {
let f_val = objective(&x)?;
objective_values.push(f_val);
        let grad_norm = tensor_norm(&grad)?;
if grad_norm < params.tolerance {
break;
}
        // Search direction: scaled steepest descent while no curvature pairs
        // are stored (H0 = initial_hessian_scale * I), the two-loop recursion
        // once history is available.
        let p = if s_history.is_empty() {
            tensor_scalar_mul(&grad, -params.initial_hessian_scale)?
        } else {
            lbfgs_two_loop_recursion(&grad, &s_history, &y_history, &rho_history)?
        };
let alpha = match params.line_search {
LineSearchMethod::Wolfe => wolfe_line_search(&objective, &gradient, &x, &p, None)?,
LineSearchMethod::Backtracking => {
backtracking_line_search(&objective, &gradient, &x, &p, None)?
}
_ => {
return Err(TorshError::InvalidArgument(
"Unsupported line search method for L-BFGS".to_string(),
))
}
};
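        // Step to x_{k+1} = x_k + α·p and evaluate the gradient there.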
let x_new = tensor_add(&x, &tensor_scalar_mul(&p, alpha)?)?;
let grad_new = gradient(&x_new)?;
        let s = tensor_scalar_mul(&p, alpha)?;
        let y = tensor_sub(&grad_new, &grad)?;
        let sy = dot_product(&y, &s)?;
        // Store the pair only when the curvature condition y·s > 0 holds;
        // otherwise rho = 1/(y·s) is undefined (or negative) and the implicit
        // inverse-Hessian approximation would lose positive-definiteness.
        if sy > 1e-10 {
            if s_history.len() == m {
                // Evict the oldest pair to respect the memory limit m.
                s_history.remove(0);
                y_history.remove(0);
                rho_history.remove(0);
            }
            s_history.push(s);
            y_history.push(y);
            rho_history.push(1.0 / sy);
        }
x = x_new;
        grad = grad_new;
if iter % 100 == 0 {
println!(
"Iteration {}: f = {:.6e}, |∇f| = {:.6e}, α = {:.6e}",
iter, f_val, grad_norm, alpha
);
}
}
Ok((x, objective_values))
}
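/// Two-loop recursion (Nocedal & Wright, Algorithm 7.4): computes the product
/// of the implicit inverse-Hessian approximation with the gradient from the
/// stored (s, y, ρ) history, without ever forming a matrix.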
fn lbfgs_two_loop_recursion(
grad: &Tensor,
s_history: &[Tensor],
y_history: &[Tensor],
rho_history: &[f32],
) -> TorshResult<Tensor> {
let m = s_history.len();
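    // Seed q with -∇f: the recursion is linear in q, so it returns the
    // descent direction -H·∇f directly, with no final negation needed.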
let mut q = tensor_scalar_mul(grad, -1.0)?;
let mut alpha = vec![0.0; m];
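    // First loop: traverse the history from newest to oldest.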
for i in (0..m).rev() {
alpha[i] = rho_history[i] * dot_product(&s_history[i], &q)?;
q = tensor_sub(&q, &tensor_scalar_mul(&y_history[i], alpha[i])?)?;
}
let mut r = q;
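    // Apply the standard initial scaling H0 = γ·I with γ = (s·y)/(y·y)
    // computed from the most recent curvature pair.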
if !s_history.is_empty() {
let last_idx = m - 1;
let sy = dot_product(&s_history[last_idx], &y_history[last_idx])?;
let yy = dot_product(&y_history[last_idx], &y_history[last_idx])?;
let gamma = sy / yy;
r = tensor_scalar_mul(&r, gamma)?;
}
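    // Second loop: traverse the history from oldest to newest.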
for i in 0..m {
let beta = rho_history[i] * dot_product(&y_history[i], &r)?;
r = tensor_add(&r, &tensor_scalar_mul(&s_history[i], alpha[i] - beta)?)?;
}
Ok(r)
}
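
// A hedged usage sketch: minimizing the convex quadratic f(x) = 0.5·xᵀx,
// whose gradient is x and whose unique minimizer is the origin. The
// `Tensor::from_vec` constructor below is an assumption about the
// torsh_tensor API and may need to be swapped for the crate's actual
// tensor-creation routine.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lbfgs_minimizes_quadratic() -> TorshResult<()> {
        // f(x) = 0.5·xᵀx, expressed with the module's own dot_product helper.
        let objective = |x: &Tensor| -> TorshResult<f32> { Ok(0.5 * dot_product(x, x)?) };
        // For this objective the gradient is simply x itself.
        let gradient = |x: &Tensor| -> TorshResult<Tensor> { Ok(x.clone()) };
        // Assumed constructor; replace with the real torsh_tensor API as needed.
        let x0 = Tensor::from_vec(vec![1.0_f32, -2.0, 3.0], &[3])?;
        let (x_min, history) = lbfgs_optimizer(objective, gradient, &x0, Some(5), None)?;
        // The iterate should approach the origin, and the final recorded
        // objective value should improve on the initial one.
        assert!(tensor_norm(&x_min)? < 1e-4);
        assert!(history.last().copied().unwrap_or(f32::MAX) < history[0]);
        Ok(())
    }
}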