use crate::DType;
use numr::autograd::Var;
use numr::error::Result as NumrResult;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::optimize::error::{OptimizeError, OptimizeResult};
use crate::optimize::minimize::traits::newton_cg::{NewtonCGOptions, NewtonCGResult};
use super::helpers::{gradient_from_fn, hvp_from_fn};
use super::utils::tensor_norm;
/// Minimizes the scalar objective `f`, starting from `x0`, with a Newton-CG
/// (truncated Newton) iteration.
///
/// Each outer iteration:
/// 1. returns with `converged: true` if the gradient norm is below
///    `options.g_tol`;
/// 2. obtains a search direction from [`cg_solve_hvp`] using the inner
///    tolerance `options.cg_tol * grad_norm` and at most
///    `options.max_cg_iter` inner iterations (default: `min(n, 200)` where
///    `n` is the number of elements of `x0`);
/// 3. picks a step with [`backtracking_line_search`];
/// 4. returns with `converged: true` if the step norm is below
///    `options.x_tol` or the absolute change in the objective is below
///    `options.f_tol`.
///
/// If `options.max_iter` iterations are used up, `converged` reflects whether
/// the gradient norm at the last iterate is below `options.g_tol`.
///
/// # Result fields
/// * `iterations` - on the early-return paths, the 1-based index of the outer
///   iteration in which the stopping test fired; otherwise `options.max_iter`.
/// * `nfev` - the initial evaluation plus every line-search evaluation.
/// * `ngrad` - number of calls to `evaluate_with_gradient`.
/// * `nhvp` - number of Hessian-vector products reported by the CG solver.
///
/// # Errors
/// * `OptimizeError::InvalidInput` if `x0` has no elements.
/// * `OptimizeError::NumericalError` if a tensor operation fails.
/// * Any error propagated from the gradient, HVP or line-search helpers.
pub fn newton_cg_impl<R, C, F>(
    client: &C,
    f: F,
    x0: &Tensor<R>,
    options: &NewtonCGOptions,
) -> OptimizeResult<NewtonCGResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
    F: Fn(&Var<R>, &C) -> NumrResult<Var<R>>,
{
    let n = x0.numel();
    if n == 0 {
        return Err(OptimizeError::InvalidInput {
            context: "newton_cg: empty initial guess".to_string(),
        });
    }

    let device = x0.device();
    let dtype = x0.dtype();
    // Default inner-iteration cap: the problem size, but never more than 200.
    let max_cg_iter = options.max_cg_iter.unwrap_or_else(|| n.min(200));

    let mut x = x0.clone();
    let mut nfev = 0;
    let mut ngrad = 0;
    let mut nhvp = 0;

    let (mut fx, mut grad) = evaluate_with_gradient(client, &f, &x)?;
    nfev += 1;
    ngrad += 1;

    for iter in 0..options.max_iter {
        let grad_norm = tensor_norm(client, &grad).map_err(|e| OptimizeError::NumericalError {
            message: format!("newton_cg: grad norm - {}", e),
        })?;
        if grad_norm < options.g_tol {
            return Ok(NewtonCGResult {
                x,
                fun: fx,
                iterations: iter + 1,
                nfev,
                ngrad,
                nhvp,
                converged: true,
                grad_norm,
            });
        }

        // The inner solve is only as accurate as the current gradient warrants.
        let cg_tol = options.cg_tol * grad_norm;
        let (p, cg_hvp_count) =
            cg_solve_hvp(client, &f, &x, &grad, max_cg_iter, cg_tol, device, dtype)?;
        nhvp += cg_hvp_count;

        let (x_new, fx_new, ls_evals) = backtracking_line_search(client, &f, &x, &p, fx, &grad)?;
        nfev += ls_evals;

        let step = client
            .sub(&x_new, &x)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("newton_cg: step diff - {}", e),
            })?;
        let step_norm = tensor_norm(client, &step).map_err(|e| OptimizeError::NumericalError {
            message: format!("newton_cg: step norm - {}", e),
        })?;
        let f_decrease = (fx - fx_new).abs();

        // The gradient at the accepted point is needed whether we stop here
        // (to report its norm) or continue (as the next iteration's gradient).
        let (_, new_grad) = evaluate_with_gradient(client, &f, &x_new)?;
        ngrad += 1;

        if step_norm < options.x_tol || f_decrease < options.f_tol {
            let final_grad_norm =
                tensor_norm(client, &new_grad).map_err(|e| OptimizeError::NumericalError {
                    message: format!("newton_cg: final grad norm - {}", e),
                })?;
            return Ok(NewtonCGResult {
                x: x_new,
                fun: fx_new,
                iterations: iter + 1,
                nfev,
                ngrad,
                nhvp,
                converged: true,
                grad_norm: final_grad_norm,
            });
        }

        x = x_new;
        fx = fx_new;
        grad = new_grad;
    }

    let final_grad_norm =
        tensor_norm(client, &grad).map_err(|e| OptimizeError::NumericalError {
            message: format!("newton_cg: final grad norm - {}", e),
        })?;
    Ok(NewtonCGResult {
        x,
        fun: fx,
        iterations: options.max_iter,
        nfev,
        ngrad,
        nhvp,
        // The gradient test only runs at the top of the loop, so the iterate
        // produced by the very last step has not been checked yet.
        converged: final_grad_norm < options.g_tol,
        grad_norm: final_grad_norm,
    })
}
/// Evaluates the objective and its gradient at `x`.
///
/// Thin wrapper around `gradient_from_fn`; the returned pair is used by the
/// callers in this file as `(objective value, gradient tensor)`.
///
/// # Errors
/// Propagates whatever error `gradient_from_fn` reports.
fn evaluate_with_gradient<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
) -> OptimizeResult<(f64, Tensor<R>)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
    F: Fn(&Var<R>, &C) -> NumrResult<Var<R>>,
{
    let (value, gradient) = gradient_from_fn(client, f, x)?;
    Ok((value, gradient))
}
/// Returns the second component of `hvp_from_fn(client, f, x, v)`, which this
/// file uses as the Hessian-vector product of `f` at `x` applied to `v`.
///
/// The first component (named `_fx` in the original code, presumably the
/// objective value - confirm against `hvp_from_fn`) is discarded.
///
/// # Errors
/// Propagates whatever error `hvp_from_fn` reports.
fn compute_hvp<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
    v: &Tensor<R>,
) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
    F: Fn(&Var<R>, &C) -> NumrResult<Var<R>>,
{
    let value_and_product = hvp_from_fn(client, f, x, v)?;
    Ok(value_and_product.1)
}
/// Approximately solves `H p = -g` by conjugate gradients, where `H` is only
/// accessed through the products returned by [`compute_hvp`] at the point `x`
/// (assumed to be Hessian-vector products of `f` - see `hvp_from_fn`).
///
/// The iteration starts from `p = 0` and stops when
/// * the residual norm drops below `tol`,
/// * `max_iter` iterations have been performed, or
/// * the curvature `d . (H d)` along the current direction is non-positive or
///   NaN. If that happens on the very first direction, `-g` (steepest
///   descent) is returned; otherwise the `p` accumulated so far is returned.
///
/// # Returns
/// The search direction and the number of Hessian-vector products computed.
///
/// # Errors
/// `OptimizeError::NumericalError` if a tensor operation fails, plus any
/// error propagated from [`compute_hvp`] or [`tensor_dot`].
// The solver state is passed explicitly instead of being bundled in a struct.
#[allow(clippy::too_many_arguments)]
fn cg_solve_hvp<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
    g: &Tensor<R>,
    max_iter: usize,
    tol: f64,
    device: &R::Device,
    dtype: numr::dtype::DType,
) -> OptimizeResult<(Tensor<R>, usize)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
    F: Fn(&Var<R>, &C) -> NumrResult<Var<R>>,
{
    /// Converts a failed tensor operation into a `NumericalError` whose
    /// message names the CG step that failed.
    fn cg_step<T, E: std::fmt::Display>(step: &str, outcome: Result<T, E>) -> OptimizeResult<T> {
        outcome.map_err(|e| OptimizeError::NumericalError {
            message: format!("cg: {} - {}", step, e),
        })
    }

    let mut p = Tensor::<R>::zeros(g.shape(), dtype, device);
    let neg_g = cg_step("neg_g", client.mul_scalar(g, -1.0))?;

    // With p = 0 the initial residual and first direction are both -g.
    let mut r = neg_g.clone();
    let mut d = r.clone();
    let mut r_dot_r = tensor_dot(client, &r, &r)?;
    let mut hvp_count = 0;

    for _ in 0..max_iter {
        if r_dot_r.sqrt() < tol {
            break;
        }

        let hd = compute_hvp(client, f, x, &d)?;
        hvp_count += 1;

        let d_dot_hd = tensor_dot(client, &d, &hd)?;
        // A NaN curvature must be rejected explicitly: `NaN <= 0.0` is false,
        // so without this check a NaN step length would poison `p` and the
        // loop would keep spending Hessian-vector products until `max_iter`.
        if d_dot_hd.is_nan() || d_dot_hd <= 0.0 {
            if hvp_count == 1 {
                // No usable curvature information at all: steepest descent.
                return Ok((neg_g, hvp_count));
            }
            break;
        }

        let alpha = r_dot_r / d_dot_hd;

        let alpha_d = cg_step("alpha_d", client.mul_scalar(&d, alpha))?;
        p = cg_step("p update", client.add(&p, &alpha_d))?;

        let alpha_hd = cg_step("alpha_hd", client.mul_scalar(&hd, alpha))?;
        r = cg_step("r update", client.sub(&r, &alpha_hd))?;

        let r_dot_r_new = tensor_dot(client, &r, &r)?;
        let beta = r_dot_r_new / r_dot_r;
        r_dot_r = r_dot_r_new;

        let beta_d = cg_step("beta_d", client.mul_scalar(&d, beta))?;
        d = cg_step("d update", client.add(&r, &beta_d))?;
    }

    Ok((p, hvp_count))
}
/// Computes a dot product as: element-wise `a * b`, summed along axis 0, read
/// back as an `f64` scalar.
///
/// NOTE(review): only axis 0 is reduced, so this appears to expect
/// one-dimensional inputs - confirm how `item` behaves for higher-rank tensors.
///
/// # Errors
/// `OptimizeError::NumericalError` if the multiplication, the reduction or the
/// scalar extraction fails.
fn tensor_dot<R, C>(client: &C, a: &Tensor<R>, b: &Tensor<R>) -> OptimizeResult<f64>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    client
        .mul(a, b)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("tensor_dot: mul - {}", e),
        })
        .and_then(|elementwise| {
            client
                .sum(&elementwise, &[0], false)
                .map_err(|e| OptimizeError::NumericalError {
                    message: format!("tensor_dot: sum - {}", e),
                })
        })
        .and_then(|reduced| {
            reduced
                .item::<f64>()
                .map_err(|e| OptimizeError::NumericalError {
                    message: format!("tensor_dot: scalar extraction - {}", e),
                })
        })
}
/// Backtracking line search along `p` using the sufficient-decrease test
/// `f(x + a p) <= fx + c * a * (grad . p)` with `c = 1e-4`.
///
/// The step `a` starts at 1 and is halved after every rejected trial, for at
/// most 20 trials. If none is accepted, the objective is evaluated once more
/// at the step reached after the last halving and that point is returned
/// without testing the decrease condition.
///
/// # Returns
/// `(x_new, f(x_new), number of objective evaluations performed)`.
///
/// # Errors
/// `OptimizeError::NumericalError` if a tensor operation, an objective
/// evaluation or a scalar extraction fails, plus any error from [`tensor_dot`].
fn backtracking_line_search<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
    p: &Tensor<R>,
    fx: f64,
    grad: &Tensor<R>,
) -> OptimizeResult<(Tensor<R>, f64, usize)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
    F: Fn(&Var<R>, &C) -> NumrResult<Var<R>>,
{
    const SUFFICIENT_DECREASE: f64 = 1e-4;
    const SHRINK_FACTOR: f64 = 0.5;
    const MAX_TRIALS: usize = 20;

    // Builds the trial point `x + step * p` and evaluates the objective there.
    // `tag` is spliced into the error messages so the fallback evaluation
    // after the loop stays distinguishable from the regular trials.
    let probe = |step: f64, tag: &str| -> OptimizeResult<(Tensor<R>, f64)> {
        let scaled = client
            .mul_scalar(p, step)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("line_search: {}alpha_p - {}", tag, e),
            })?;
        let trial = client
            .add(x, &scaled)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("line_search: {}x_new - {}", tag, e),
            })?;
        // The trial point is wrapped without gradient tracking: only the
        // objective value is needed here.
        let trial_var = Var::new(trial.clone(), false);
        let loss = f(&trial_var, client).map_err(|e| OptimizeError::NumericalError {
            message: format!("line_search: {}f eval - {}", tag, e),
        })?;
        let value: f64 =
            loss.tensor()
                .item::<f64>()
                .map_err(|e| OptimizeError::NumericalError {
                    message: format!("line_search: scalar extraction - {}", e),
                })?;
        Ok((trial, value))
    };

    let slope = tensor_dot(client, grad, p)?;
    let mut step = 1.0_f64;
    let mut evals = 0usize;

    while evals < MAX_TRIALS {
        let (candidate, value) = probe(step, "")?;
        evals += 1;
        if value <= fx + SUFFICIENT_DECREASE * step * slope {
            return Ok((candidate, value, evals));
        }
        step *= SHRINK_FACTOR;
    }

    // Every trial was rejected: fall back to the smallest step reached.
    let (candidate, value) = probe(step, "final ")?;
    Ok((candidate, value, evals + 1))
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::autograd::{var_mul, var_sub, var_sum};
    use numr::runtime::Runtime;
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    /// Creates a CPU device together with its default client.
    fn setup() -> (CpuDevice, CpuClient) {
        let cpu = CpuDevice::new();
        let default_client = CpuRuntime::default_client(&cpu);
        (cpu, default_client)
    }

    /// Minimizes `sum(x_i^2)` from `[1, 2, 3]`; the minimum value is 0.
    #[test]
    fn test_newton_cg_quadratic() {
        let (device, client) = setup();
        let start = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0], &[3], &device);
        let options = NewtonCGOptions::default();

        let outcome = newton_cg_impl(
            &client,
            |point, ops| {
                let squared = var_mul(point, point, ops)?;
                var_sum(&squared, &[0], false, ops)
            },
            &start,
            &options,
        )
        .unwrap();

        assert!(outcome.converged);
        assert!(outcome.fun < 1e-10);
        assert!(outcome.grad_norm < 1e-6);
    }

    /// Minimizes `sum((x_i - 1)^2)` from the origin; the minimizer is all ones.
    #[test]
    fn test_newton_cg_shifted_quadratic() {
        let (device, client) = setup();
        let start = Tensor::<CpuRuntime>::from_slice(&[0.0f64, 0.0, 0.0], &[3], &device);
        let options = NewtonCGOptions::default();

        let outcome = newton_cg_impl(
            &client,
            |point, ops| {
                let target = Var::new(
                    Tensor::<CpuRuntime>::from_slice(&[1.0f64, 1.0, 1.0], &[3], &device),
                    false,
                );
                let offset = var_sub(point, &target, ops)?;
                let offset_sq = var_mul(&offset, &offset, ops)?;
                var_sum(&offset_sq, &[0], false, ops)
            },
            &start,
            &options,
        )
        .unwrap();

        assert!(outcome.converged);
        assert!(outcome.fun < 1e-10);

        let coords: Vec<f64> = outcome.x.to_vec();
        assert!(coords.iter().all(|coord| (coord - 1.0).abs() < 1e-5));
    }
}