use crate::error::{OptimError, Result};
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::ndarray_ext::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, Zero};
use serde::{Deserialize, Serialize};
/// Newton-CG optimizer: hyperparameters plus a step counter.
///
/// Each `step` runs conjugate gradient on `(H + hessian_reg * I) d = -g`, where `H v` is
/// supplied by a caller-provided Hessian-vector-product closure and `g` is the gradient.
/// It then returns `params + learning_rate * d`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NewtonCG<T: Float> {
/// Multiplier applied to the CG search direction before it is added to the parameters.
learning_rate: T,
/// Relative CG stopping threshold, compared against the ratio of the squared residual
/// norm to the squared gradient norm.
cg_tolerance: T,
/// Maximum number of inner CG iterations per `step`.
cg_max_iters: usize,
/// Damping term: every Hessian-vector product is augmented by `hessian_reg * p`.
hessian_reg: T,
/// Step counter advanced by `step` and cleared by `reset`.
step_count: usize,
}
impl<T: Float + ScalarOperand> Default for NewtonCG<T> {
    /// Returns an optimizer with `learning_rate = 1`, `cg_tolerance = 1e-6`,
    /// `cg_max_iters = 100` and `hessian_reg = 1e-6`.
    ///
    /// # Panics
    ///
    /// Panics if `1e-6` cannot be converted into `T`. This conversion always succeeds
    /// for `f32` and `f64`.
    fn default() -> Self {
        let small = T::from(1e-6).expect("1e-6 must be representable in the float type T");
        // The constants below satisfy every check in `new`, so this cannot fail.
        Self::new(T::one(), small, 100, small)
            .expect("default NewtonCG hyperparameters are valid by construction")
    }
}
impl<T: Float + ScalarOperand> NewtonCG<T> {
    /// Creates a Newton-CG optimizer after validating its hyperparameters.
    ///
    /// * `learning_rate`: multiplier applied to the CG search direction in `step`.
    ///   Must be positive.
    /// * `cg_tolerance`: relative stopping threshold for the inner CG loop. The loop stops
    ///   once `||r||^2 <= cg_tolerance * ||g||^2` (squared residual norm vs. squared
    ///   gradient norm). Must be positive.
    /// * `cg_max_iters`: upper bound on inner CG iterations per `step`. Must be non-zero.
    /// * `hessian_reg`: damping `λ` added to every Hessian-vector product, so CG works on
    ///   `(H + λI) d = -g`. Must be non-negative.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidParameter` if any argument violates its constraint.
    /// NaN is rejected for every floating-point argument.
    pub fn new(
        learning_rate: T,
        cg_tolerance: T,
        cg_max_iters: usize,
        hessian_reg: T,
    ) -> Result<Self> {
        // Compare in `T` directly. `NaN <= 0` is false, so NaN needs an explicit check,
        // otherwise it would slip through validation and poison every later step.
        if learning_rate.is_nan() || learning_rate <= T::zero() {
            return Err(OptimError::InvalidParameter(format!(
                "learning_rate must be positive, got {}",
                Self::display_value(learning_rate)
            )));
        }
        if cg_tolerance.is_nan() || cg_tolerance <= T::zero() {
            return Err(OptimError::InvalidParameter(format!(
                "cg_tolerance must be positive, got {}",
                Self::display_value(cg_tolerance)
            )));
        }
        if cg_max_iters == 0 {
            return Err(OptimError::InvalidParameter(
                "cg_max_iters must be positive".to_string(),
            ));
        }
        if hessian_reg.is_nan() || hessian_reg < T::zero() {
            return Err(OptimError::InvalidParameter(format!(
                "hessian_reg must be non-negative, got {}",
                Self::display_value(hessian_reg)
            )));
        }
        Ok(Self {
            learning_rate,
            cg_tolerance,
            cg_max_iters,
            hessian_reg,
            step_count: 0,
        })
    }

    /// Performs one Newton-CG update and returns the new parameters.
    ///
    /// `hvp_fn(v)` is expected to return the Hessian-vector product `H v` (same length as
    /// `v`) for the objective at `params`. The result is
    /// `params + learning_rate * d`, where `d` is the direction from the inner CG solve.
    /// The step counter advances only when the update succeeds.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::DimensionMismatch` in two cases:
    /// * `grads` and `params` differ in length.
    /// * `hvp_fn` returns a vector of the wrong length.
    pub fn step<F>(
        &mut self,
        params: ArrayView1<T>,
        grads: ArrayView1<T>,
        hvp_fn: F,
    ) -> Result<Array1<T>>
    where
        F: Fn(&[T]) -> Vec<T>,
    {
        let n = params.len();
        if grads.len() != n {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected gradient size {}, got {}",
                n,
                grads.len()
            )));
        }
        let direction = self.conjugate_gradient(&grads, hvp_fn)?;
        // Count only steps that actually produced an update.
        self.step_count += 1;
        Ok(params.to_owned() + &(direction * self.learning_rate))
    }

    /// Approximately solves `(H + hessian_reg * I) d = -g` by conjugate gradient,
    /// starting from `d = 0`.
    ///
    /// Iteration stops at the first of these conditions:
    /// * the squared residual drops to `cg_tolerance * ||g||^2`;
    /// * `cg_max_iters` iterations have run;
    /// * a search direction `p` with non-positive (or numerically negligible) curvature
    ///   `p^T (H + λI) p` is found.
    ///
    /// In the curvature case the iterate accumulated so far is returned. If this happens on
    /// the first iteration, the steepest-descent direction `-g` is returned instead.
    /// Curvature is checked *before* `d` is updated: stepping along a negative-curvature
    /// direction with a negative `alpha` can produce an ascent direction.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::DimensionMismatch` if `hvp_fn` returns a vector whose length
    /// differs from `grads.len()`.
    fn conjugate_gradient<F>(&self, grads: &ArrayView1<T>, hvp_fn: F) -> Result<Array1<T>>
    where
        F: Fn(&[T]) -> Vec<T>,
    {
        let n = grads.len();
        let mut d = Array1::<T>::zeros(n);
        // With d = 0 the residual of the system is b - A·0 = -g.
        let mut r = grads.mapv(|x| -x);
        let mut p = r.clone();
        let mut r_norm_sq = Self::dot(&r, &r);
        let stop_threshold = self.cg_tolerance * r_norm_sq;

        for cg_iter in 0..self.cg_max_iters {
            // `<=` (not `<`) so an all-zero gradient, whose threshold is 0, stops at once.
            if r_norm_sq <= stop_threshold {
                break;
            }

            let ap_vec = hvp_fn(&p.to_vec());
            if ap_vec.len() != n {
                return Err(OptimError::DimensionMismatch(format!(
                    "Hessian-vector product returned wrong size: expected {}, got {}",
                    n,
                    ap_vec.len()
                )));
            }
            let ap_reg = Array1::from_vec(ap_vec) + &p.mapv(|x| x * self.hessian_reg);
            let p_dot_ap = Self::dot(&p, &ap_reg);

            // Scale-relative curvature test: an absolute cutoff would stall whenever the
            // gradient itself is small in magnitude.
            let curvature_floor = T::epsilon() * Self::dot(&p, &p);
            if p_dot_ap.is_nan() || p_dot_ap <= curvature_floor {
                if cg_iter == 0 {
                    // No usable curvature information yet: fall back to steepest descent.
                    return Ok(grads.mapv(|x| -x));
                }
                break;
            }

            let alpha = r_norm_sq / p_dot_ap;
            for (di, &pi) in d.iter_mut().zip(p.iter()) {
                *di = *di + alpha * pi;
            }
            for (ri, &api) in r.iter_mut().zip(ap_reg.iter()) {
                *ri = *ri - alpha * api;
            }

            let r_norm_sq_new = Self::dot(&r, &r);
            let beta = r_norm_sq_new / r_norm_sq;
            r_norm_sq = r_norm_sq_new;
            for (pi, &ri) in p.iter_mut().zip(r.iter()) {
                *pi = ri + beta * *pi;
            }
        }
        Ok(d)
    }

    /// Sum of element-wise products of two equal-length vectors.
    fn dot(a: &Array1<T>, b: &Array1<T>) -> T {
        a.iter()
            .zip(b.iter())
            .fold(T::zero(), |acc, (&x, &y)| acc + x * y)
    }

    /// Converts a scalar to `f64` for error messages. Yields NaN if the cast is impossible.
    fn display_value(x: T) -> f64 {
        x.to_f64().unwrap_or(f64::NAN)
    }

    /// Number of successful `step` calls since construction or the last `reset`.
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Clears the step counter. Hyperparameters are left unchanged.
    pub fn reset(&mut self) {
        self.step_count = 0;
    }

    /// Returns the current learning rate.
    pub fn get_learning_rate(&self) -> T {
        self.learning_rate
    }

    /// Replaces the learning rate.
    ///
    /// Unlike `new`, the value is not validated.
    pub fn set_learning_rate(&mut self, learning_rate: T) {
        self.learning_rate = learning_rate;
    }
}
// Unit tests for `NewtonCG`: construction, parameter validation, single- and multi-step
// behavior on quadratics with a constant Hessian, and the counter/learning-rate accessors.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use scirs2_core::ndarray_ext::array;
// A default-constructed optimizer starts with no steps taken.
#[test]
fn test_newton_cg_creation() {
let optimizer = NewtonCG::<f32>::default();
assert_eq!(optimizer.step_count(), 0);
}
// `new` with valid arguments succeeds and stores the given learning rate.
#[test]
fn test_newton_cg_custom_creation() {
let optimizer = NewtonCG::<f32>::new(0.5, 1e-8, 50, 1e-5).expect("unwrap failed");
assert_eq!(optimizer.step_count(), 0);
assert_relative_eq!(optimizer.get_learning_rate(), 0.5);
}
// Each of the four constructor constraints is rejected individually:
// non-positive learning_rate, non-positive cg_tolerance, zero cg_max_iters,
// and negative hessian_reg.
#[test]
fn test_newton_cg_invalid_params() {
assert!(NewtonCG::<f32>::new(-0.1, 1e-6, 100, 1e-6).is_err());
assert!(NewtonCG::<f32>::new(1.0, -1e-6, 100, 1e-6).is_err());
assert!(NewtonCG::<f32>::new(1.0, 1e-6, 0, 1e-6).is_err());
assert!(NewtonCG::<f32>::new(1.0, 1e-6, 100, -1e-6).is_err());
}
// Gradient 2x - b with Hessian 2I matches the quadratic x·x - b·x, whose minimizer is
// b / 2 = (0.5, 0.5). A single step with learning_rate 1 and no damping should land there.
#[test]
fn test_newton_cg_quadratic_function() {
let mut optimizer = NewtonCG::<f64>::new(1.0, 1e-8, 50, 0.0).expect("unwrap failed");
let mut params = array![2.0, 2.0];
let b = array![1.0, 1.0];
let hvp_fn = |v: &[f64]| -> Vec<f64> { vec![2.0 * v[0], 2.0 * v[1]] };
let grads = array![
2.0 * params[0] - b[0], 2.0 * params[1] - b[1] ];
params = optimizer
.step(params.view(), grads.view(), hvp_fn)
.expect("unwrap failed");
assert_relative_eq!(params[0], 0.5, epsilon = 0.1);
assert_relative_eq!(params[1], 0.5, epsilon = 0.1);
}
// Repeated steps on x^2 + y^2 (gradient 2x, Hessian 2I) must drive both coordinates
// close to the origin.
#[test]
fn test_newton_cg_convergence() {
let mut optimizer = NewtonCG::<f64>::new(1.0, 1e-8, 100, 0.0).expect("unwrap failed");
let mut params = array![5.0, 5.0];
let hvp_fn = |v: &[f64]| -> Vec<f64> { vec![2.0 * v[0], 2.0 * v[1]] };
for _ in 0..10 {
let grads = array![2.0 * params[0], 2.0 * params[1]];
params = optimizer
.step(params.view(), grads.view(), hvp_fn)
.expect("unwrap failed");
}
assert!(
params[0].abs() < 0.01,
"Failed to converge, got x = {}",
params[0]
);
assert!(
params[1].abs() < 0.01,
"Failed to converge, got y = {}",
params[1]
);
}
// With an identity Hessian, one step advances the counter to 1 and `reset` clears it.
// The resulting parameters are not checked here.
#[test]
fn test_newton_cg_reset() {
let mut optimizer = NewtonCG::<f32>::default();
let params = array![1.0, 2.0, 3.0];
let grads = array![0.1, 0.2, 0.3];
let hvp_fn = |v: &[f32]| -> Vec<f32> { v.to_vec() };
optimizer
.step(params.view(), grads.view(), hvp_fn)
.expect("unwrap failed");
assert_eq!(optimizer.step_count(), 1);
optimizer.reset();
assert_eq!(optimizer.step_count(), 0);
}
// The default learning rate is 1.0, and the setter overwrites it.
#[test]
fn test_newton_cg_learning_rate() {
let mut optimizer = NewtonCG::<f32>::default();
assert_relative_eq!(optimizer.get_learning_rate(), 1.0);
optimizer.set_learning_rate(0.5);
assert_relative_eq!(optimizer.get_learning_rate(), 0.5);
}
}