use numr::dtype::DType;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::optimize::error::{OptimizeError, OptimizeResult};
use crate::optimize::minimize::MinimizeOptions;
use super::helpers::{TensorMinimizeResult, backtracking_line_search_tensor};
use super::utils::{SINGULAR_THRESHOLD, finite_difference_gradient, tensor_norm};
pub fn bfgs_impl<R, C, F>(
client: &C,
f: F,
x0: &Tensor<R>,
options: &MinimizeOptions,
) -> OptimizeResult<TensorMinimizeResult<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
F: Fn(&Tensor<R>) -> Result<f64>,
{
let n = x0.shape()[0];
if n == 0 {
return Err(OptimizeError::InvalidInput {
context: "bfgs: empty initial guess".to_string(),
});
}
let mut x = x0.clone();
let mut fx = f(&x).map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: initial evaluation - {}", e),
})?;
let mut nfev = 1;
let mut grad = finite_difference_gradient(client, &f, &x, fx, options.eps).map_err(|e| {
OptimizeError::NumericalError {
message: format!("bfgs: gradient - {}", e),
}
})?;
nfev += n;
let mut h_inv = create_identity_matrix::<R, C>(client, n)?;
for iter in 0..options.max_iter {
let grad_norm = tensor_norm(client, &grad).map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: grad norm - {}", e),
})?;
if grad_norm < options.g_tol {
return Ok(TensorMinimizeResult {
x,
fun: fx,
iterations: iter + 1,
nfev,
converged: true,
});
}
let grad_col = grad
.reshape(&[n, 1])
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: grad reshape - {}", e),
})?;
let h_grad =
client
.matmul(&h_inv, &grad_col)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: h_inv @ grad - {}", e),
})?;
let h_grad_flat = h_grad
.reshape(&[n])
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: h_grad reshape - {}", e),
})?;
let p =
client
.mul_scalar(&h_grad_flat, -1.0)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: negate direction - {}", e),
})?;
let (x_new, fx_new, evals) =
backtracking_line_search_tensor(client, &f, &x, &p, fx, &grad)?;
nfev += evals;
let s = client
.sub(&x_new, &x)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: s = x_new - x - {}", e),
})?;
let s_norm = tensor_norm(client, &s).map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: s norm - {}", e),
})?;
if s_norm < options.x_tol || (fx - fx_new).abs() < options.f_tol {
return Ok(TensorMinimizeResult {
x: x_new,
fun: fx_new,
iterations: iter + 1,
nfev,
converged: true,
});
}
let grad_new = finite_difference_gradient(client, &f, &x_new, fx_new, options.eps)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: new gradient - {}", e),
})?;
nfev += n;
let y = client
.sub(&grad_new, &grad)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: y = grad_new - grad - {}", e),
})?;
let ys = tensor_inner_product(client, &y, &s)?;
if ys.abs() > SINGULAR_THRESHOLD {
let rho = 1.0 / ys;
let s_col = s
.reshape(&[n, 1])
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: s reshape - {}", e),
})?;
let y_col = y
.reshape(&[n, 1])
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: y reshape - {}", e),
})?;
let s_row = s
.reshape(&[1, n])
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: s row reshape - {}", e),
})?;
let y_row = y
.reshape(&[1, n])
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: y row reshape - {}", e),
})?;
let s_st =
client
.matmul(&s_col, &s_row)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: s @ s.T - {}", e),
})?;
let h_y = client
.matmul(&h_inv, &y_col)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: H_inv @ y - {}", e),
})?;
let y_row_h = y_row.clone();
let yhy =
{
let yt_h = client.matmul(&y_row_h, &h_inv).map_err(|e| {
OptimizeError::NumericalError {
message: format!("bfgs: y.T @ H_inv - {}", e),
}
})?;
let yt_h_y = client.matmul(&yt_h, &y_col).map_err(|e| {
OptimizeError::NumericalError {
message: format!("bfgs: y.T @ H_inv @ y - {}", e),
}
})?;
let vals: Vec<f64> = yt_h_y.to_vec();
vals[0]
};
let h_y_row = h_y
.reshape(&[1, n])
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: h_y row reshape - {}", e),
})?;
let s_hyt =
client
.matmul(&s_col, &h_y_row)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: s @ h_y.T - {}", e),
})?;
let hy_st = client
.matmul(&h_y, &s_row)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: h_y @ s.T - {}", e),
})?;
let coeff1 = rho * (1.0 + rho * yhy);
let term1 =
client
.mul_scalar(&s_st, coeff1)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: term1 - {}", e),
})?;
let sum_outer =
client
.add(&s_hyt, &hy_st)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: s_hyt + hy_st - {}", e),
})?;
let term2 =
client
.mul_scalar(&sum_outer, rho)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: term2 - {}", e),
})?;
let h_plus_term1 =
client
.add(&h_inv, &term1)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: H + term1 - {}", e),
})?;
h_inv =
client
.sub(&h_plus_term1, &term2)
.map_err(|e| OptimizeError::NumericalError {
message: format!("bfgs: H + term1 - term2 - {}", e),
})?;
}
x = x_new;
fx = fx_new;
grad = grad_new;
}
Ok(TensorMinimizeResult {
x,
fun: fx,
iterations: options.max_iter,
nfev,
converged: false,
})
}
/// Builds the `n x n` identity matrix in `DType::F64` through `client.eye`.
///
/// # Errors
/// A failure reported by the backend is returned as
/// `OptimizeError::NumericalError`.
fn create_identity_matrix<R, C>(client: &C, n: usize) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    match client.eye(n, None, DType::F64) {
        Ok(identity) => Ok(identity),
        Err(e) => Err(OptimizeError::NumericalError {
            message: format!("bfgs: create identity - {}", e),
        }),
    }
}
/// Computes the inner product of `a` and `b` and returns it as an `f64`.
///
/// `a` is reshaped to a `[1, n]` row and `b` to an `[n, 1]` column, where `n`
/// is `a.shape()[0]`; the two are multiplied with `client.matmul` and the first
/// element of the product is read back through `to_vec`.
///
/// # Errors
/// A failed reshape or matrix product is returned as
/// `OptimizeError::NumericalError`.
fn tensor_inner_product<R, C>(client: &C, a: &Tensor<R>, b: &Tensor<R>) -> OptimizeResult<f64>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    let len = a.shape()[0];

    // Left operand as a row vector.
    let lhs = match a.reshape(&[1, len]) {
        Ok(row) => row,
        Err(e) => {
            return Err(OptimizeError::NumericalError {
                message: format!("inner_product: a reshape - {}", e),
            });
        }
    };

    // Right operand as a column vector.
    let rhs = match b.reshape(&[len, 1]) {
        Ok(col) => col,
        Err(e) => {
            return Err(OptimizeError::NumericalError {
                message: format!("inner_product: b reshape - {}", e),
            });
        }
    };

    // Row times column leaves a single-element product.
    let product = match client.matmul(&lhs, &rhs) {
        Ok(prod) => prod,
        Err(e) => {
            return Err(OptimizeError::NumericalError {
                message: format!("inner_product: matmul - {}", e),
            });
        }
    };

    let elements: Vec<f64> = product.to_vec();
    Ok(elements[0])
}