use crate::error::OptimizeError;
use crate::unconstrained::{line_search::backtracking_line_search, OptimizeResult};
use scirs2_core::ndarray::{Array1, ArrayView1};
use std::collections::VecDeque;
#[derive(Debug, Clone)]
pub enum Preconditioner {
None,
Diagonal,
LBFGS { memory: usize },
Custom(fn(&Array1<f64>) -> Array1<f64>),
}
#[derive(Debug, Clone)]
pub struct TruncatedNewtonOptions {
pub max_iter: usize,
pub tol: f64,
pub max_cg_iter: usize,
pub cg_tol: f64,
pub adaptive_cg_tol: bool,
pub preconditioner: Preconditioner,
pub trust_radius: Option<f64>,
pub finite_diff_hessian: bool,
pub hessian_fd_eps: f64,
}
impl Default for TruncatedNewtonOptions {
fn default() -> Self {
Self {
max_iter: 1000,
tol: 1e-6,
max_cg_iter: 100,
cg_tol: 0.1,
adaptive_cg_tol: true,
preconditioner: Preconditioner::None,
trust_radius: None,
finite_diff_hessian: true,
hessian_fd_eps: 1e-8,
}
}
}
struct TruncatedNewtonState {
lbfgs_s: VecDeque<Array1<f64>>,
lbfgs_y: VecDeque<Array1<f64>>,
lbfgs_rho: VecDeque<f64>,
diag_hessian: Option<Array1<f64>>,
}
impl TruncatedNewtonState {
fn new() -> Self {
Self {
lbfgs_s: VecDeque::new(),
lbfgs_y: VecDeque::new(),
lbfgs_rho: VecDeque::new(),
diag_hessian: None,
}
}
fn update_lbfgs(&mut self, s: Array1<f64>, y: Array1<f64>, memory: usize) {
let rho = 1.0 / y.dot(&s);
if rho.is_finite() && rho > 0.0 {
self.lbfgs_s.push_back(s);
self.lbfgs_y.push_back(y);
self.lbfgs_rho.push_back(rho);
while self.lbfgs_s.len() > memory {
self.lbfgs_s.pop_front();
self.lbfgs_y.pop_front();
self.lbfgs_rho.pop_front();
}
}
}
#[allow(dead_code)]
fn apply_lbfgs_preconditioner(&self, r: &Array1<f64>) -> Array1<f64> {
if self.lbfgs_s.is_empty() {
return r.clone();
}
let mut q = r.clone();
let mut alpha = Vec::new();
for i in (0..self.lbfgs_s.len()).rev() {
let alpha_i = self.lbfgs_rho[i] * self.lbfgs_s[i].dot(&q);
alpha.push(alpha_i);
q = &q - alpha_i * &self.lbfgs_y[i];
}
let mut z = q;
if let (Some(s_k), Some(y_k)) = (self.lbfgs_s.back(), self.lbfgs_y.back()) {
let gamma = s_k.dot(y_k) / y_k.dot(y_k);
if gamma.is_finite() && gamma > 0.0 {
z = gamma * z;
}
}
alpha.reverse();
for (i, alpha_i) in alpha.iter().enumerate().take(self.lbfgs_s.len()) {
let beta = self.lbfgs_rho[i] * self.lbfgs_y[i].dot(&z);
z = &z + (alpha_i - beta) * &self.lbfgs_s[i];
}
z
}
}
#[allow(dead_code)]
fn hessian_vector_product<F>(
grad_fn: &mut F,
x: &Array1<f64>,
p: &Array1<f64>,
eps: f64,
) -> Result<Array1<f64>, OptimizeError>
where
F: FnMut(&ArrayView1<f64>) -> Array1<f64>,
{
let _n = x.len();
let step = eps * (1.0 + x.dot(x).sqrt());
let x_plus = x + step * p;
let grad_plus = grad_fn(&x_plus.view());
let grad_x = grad_fn(&x.view());
let result = (grad_plus - grad_x) / step;
Ok(result)
}
#[allow(dead_code)]
fn conjugate_gradient_solve<F>(
hessian_vec_fn: &mut F,
grad: &Array1<f64>,
state: &TruncatedNewtonState,
options: &TruncatedNewtonOptions,
preconditioner: &Preconditioner,
) -> Result<Array1<f64>, OptimizeError>
where
F: FnMut(&Array1<f64>) -> Result<Array1<f64>, OptimizeError>,
{
let n = grad.len();
let mut p = Array1::zeros(n);
let mut r = -grad.clone();
let mut z = match preconditioner {
Preconditioner::None => r.clone(),
Preconditioner::Diagonal => {
if let Some(ref diag) = state.diag_hessian {
r.iter()
.zip(diag.iter())
.map(|(&ri, &di)| {
if di.abs() > 1e-12 {
ri / di.max(1e-6)
} else {
ri
}
})
.collect()
} else {
r.clone()
}
}
Preconditioner::LBFGS { .. } => state.apply_lbfgs_preconditioner(&r),
Preconditioner::Custom(precond_fn) => precond_fn(&r),
};
let mut d = z.clone(); let mut rsold = r.dot(&z);
let cg_tol = if options.adaptive_cg_tol {
let grad_norm = grad.mapv(|x| x.abs()).sum();
(options.cg_tol * grad_norm.min(0.5)).max(1e-12)
} else {
options.cg_tol
};
for iter in 0..options.max_cg_iter {
let residual_norm = r.mapv(|x| x.abs()).sum();
if residual_norm < cg_tol {
break;
}
let hd = hessian_vec_fn(&d)?;
let dthd = d.dot(&hd);
if dthd <= 0.0 {
if iter == 0 {
return Ok(-grad.clone());
}
break;
}
let alpha = rsold / dthd;
p = &p + alpha * &d;
r = &r - alpha * &hd;
z = match preconditioner {
Preconditioner::None => r.clone(),
Preconditioner::Diagonal => {
if let Some(ref diag) = state.diag_hessian {
r.iter()
.zip(diag.iter())
.map(|(&ri, &di)| {
if di.abs() > 1e-12 {
ri / di.max(1e-6)
} else {
ri
}
})
.collect()
} else {
r.clone()
}
}
Preconditioner::LBFGS { .. } => state.apply_lbfgs_preconditioner(&r),
Preconditioner::Custom(precond_fn) => precond_fn(&r),
};
let rsnew = r.dot(&z);
if rsnew < 0.0 {
break;
}
let beta = rsnew / rsold;
d = &z + beta * &d;
rsold = rsnew;
}
Ok(p)
}
#[allow(dead_code)]
pub fn minimize_truncated_newton<F, G>(
mut fun: F,
grad: Option<G>,
x0: Array1<f64>,
options: Option<TruncatedNewtonOptions>,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
F: FnMut(&ArrayView1<f64>) -> f64,
G: Fn(&ArrayView1<f64>) -> Array1<f64>,
{
let options = options.unwrap_or_default();
let mut x = x0.clone();
let mut state = TruncatedNewtonState::new();
let mut nfev = 0;
let mut _njev = 0;
let has_grad = grad.is_some();
for iter in 0..options.max_iter {
let f = fun(&x.view());
nfev += 1;
let g = if has_grad {
grad.as_ref().expect("Operation failed")(&x.view())
} else {
let eps = (f64::EPSILON).sqrt();
finite_diff_gradient(&mut fun, &x.view(), eps)
};
let grad_norm = g.mapv(|x| x.abs()).sum();
if grad_norm < options.tol {
return Ok(OptimizeResult {
x,
fun: f,
nit: iter,
func_evals: nfev,
nfev,
jacobian: Some(g),
hessian: None,
success: true,
message: "Optimization terminated successfully.".to_string(),
});
}
if matches!(options.preconditioner, Preconditioner::Diagonal) {
let mut diag_hessian = Array1::ones(x.len());
for i in 0..x.len() {
let mut x_plus = x.clone();
let mut x_minus = x.clone();
let h = options.hessian_fd_eps * (1.0 + x[i].abs());
x_plus[i] += h;
x_minus[i] -= h;
let g_plus = if has_grad {
grad.as_ref().expect("Operation failed")(&x_plus.view())
} else {
let eps = (f64::EPSILON).sqrt();
finite_diff_gradient(&mut fun, &x_plus.view(), eps)
};
let g_minus = if has_grad {
grad.as_ref().expect("Operation failed")(&x_minus.view())
} else {
let eps = (f64::EPSILON).sqrt();
finite_diff_gradient(&mut fun, &x_minus.view(), eps)
};
diag_hessian[i] = ((g_plus[i] - g_minus[i]) / (2.0 * h)).max(1e-6);
}
state.diag_hessian = Some(diag_hessian);
}
let p = if g.mapv(|x: f64| x.abs()).sum() < options.tol {
Array1::zeros(x.len())
} else {
-&g
};
let f = fun(&x.view());
let (step_size, _) = backtracking_line_search(
&mut |x_view| fun(x_view),
&x.view(),
f,
&p.view(),
&g.view(),
1.0,
1e-4,
0.5,
None,
);
nfev += 1;
let x_new = &x + step_size * &p;
if let Preconditioner::LBFGS { memory } = &options.preconditioner {
let s = &x_new - &x;
let g_new = if has_grad {
grad.as_ref().expect("Operation failed")(&x_new.view())
} else {
let eps = (f64::EPSILON).sqrt();
finite_diff_gradient(&mut fun, &x_new.view(), eps)
};
let y = &g_new - &g;
state.update_lbfgs(s, y, *memory);
}
x = x_new;
}
let final_f = fun(&x.view());
let final_g = if has_grad {
grad.as_ref().expect("Operation failed")(&x.view())
} else {
let eps = (f64::EPSILON).sqrt();
finite_diff_gradient(&mut fun, &x.view(), eps)
};
nfev += 1;
Ok(OptimizeResult {
x,
fun: final_f,
nit: options.max_iter,
func_evals: nfev,
nfev,
jacobian: Some(final_g),
hessian: None,
success: false,
message: "Maximum iterations reached.".to_string(),
})
}
#[allow(dead_code)]
fn finite_diff_gradient<F>(fun: &mut F, x: &ArrayView1<f64>, eps: f64) -> Array1<f64>
where
F: FnMut(&ArrayView1<f64>) -> f64,
{
let n = x.len();
let mut grad = Array1::zeros(n);
let f0 = fun(x);
for i in 0..n {
let h = eps * (1.0 + x[i].abs());
let mut x_plus = x.to_owned();
x_plus[i] += h;
let f_plus = fun(&x_plus.view());
grad[i] = (f_plus - f0) / h;
}
grad
}
#[allow(dead_code)]
pub fn minimize_trust_region_newton<F, G>(
mut fun: F,
grad: Option<G>,
x0: Array1<f64>,
options: Option<TruncatedNewtonOptions>,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
F: FnMut(&ArrayView1<f64>) -> f64,
G: Fn(&ArrayView1<f64>) -> Array1<f64>,
{
let mut options = options.unwrap_or_default();
if options.trust_radius.is_none() {
options.trust_radius = Some(1.0);
}
let mut x = x0.clone();
let _state = TruncatedNewtonState::new();
let mut trust_radius = options.trust_radius.expect("Operation failed");
let mut nfev = 0;
let _njev = 0;
let has_grad = grad.is_some();
for iter in 0..options.max_iter {
let f = fun(&x.view());
let g = if has_grad {
grad.as_ref().expect("Operation failed")(&x.view())
} else {
let eps = (f64::EPSILON).sqrt();
finite_diff_gradient(&mut fun, &x.view(), eps)
};
nfev += 1;
let grad_norm = g.mapv(|x| x.abs()).sum();
if grad_norm < options.tol {
return Ok(OptimizeResult {
x,
fun: f,
nit: iter,
func_evals: nfev,
nfev,
jacobian: Some(g),
hessian: None,
success: true,
message: "Optimization terminated successfully.".to_string(),
});
}
let p = if g.mapv(|x: f64| x.abs()).sum() < options.tol {
Array1::zeros(x.len())
} else {
let g_norm = g.mapv(|x: f64| x.powi(2)).sum().sqrt();
let step_length = trust_radius.min(1.0 / g_norm);
-step_length * &g
};
let x_new = &x + &p;
let f_new = fun(&x_new.view());
nfev += 1;
let actual_reduction = f - f_new;
let predicted_reduction = -g.dot(&p);
let ratio = if predicted_reduction.abs() < 1e-12 {
1.0
} else {
actual_reduction / predicted_reduction
};
if ratio < 0.25 {
trust_radius *= 0.25;
} else if ratio > 0.75 && (p.mapv(|x| x.powi(2)).sum().sqrt() - trust_radius).abs() < 1e-6 {
trust_radius = (2.0 * trust_radius).min(1e6);
}
if ratio > 0.1 {
x = x_new;
}
trust_radius = trust_radius.max(1e-12);
}
let final_f = fun(&x.view());
let final_g = if has_grad {
grad.as_ref().expect("Operation failed")(&x.view())
} else {
let eps = (f64::EPSILON).sqrt();
finite_diff_gradient(&mut fun, &x.view(), eps)
};
nfev += 1;
Ok(OptimizeResult {
x,
fun: final_f,
nit: options.max_iter,
func_evals: nfev,
nfev,
jacobian: Some(final_g),
hessian: None,
success: false,
message: "Maximum iterations reached.".to_string(),
})
}
#[allow(dead_code)]
fn solve_trust_region_subproblem<F>(
hessian_vec_fn: &mut F,
grad: &Array1<f64>,
state: &TruncatedNewtonState,
options: &TruncatedNewtonOptions,
trust_radius: f64,
) -> Result<Array1<f64>, OptimizeError>
where
F: FnMut(&Array1<f64>) -> Result<Array1<f64>, OptimizeError>,
{
let n = grad.len();
let mut p = Array1::zeros(n);
let mut r = -grad.clone();
let mut z = match &options.preconditioner {
Preconditioner::None => r.clone(),
Preconditioner::Diagonal => {
if let Some(ref diag) = state.diag_hessian {
r.iter()
.zip(diag.iter())
.map(|(&ri, &di)| {
if di.abs() > 1e-12 {
ri / di.max(1e-6)
} else {
ri
}
})
.collect()
} else {
r.clone()
}
}
Preconditioner::LBFGS { .. } => state.apply_lbfgs_preconditioner(&r),
Preconditioner::Custom(precond_fn) => precond_fn(&r),
};
let mut d = z.clone();
let mut rsold = r.dot(&z);
for _iter in 0..options.max_cg_iter {
let dnorm = d.mapv(|x: f64| x.powi(2)).sum().sqrt();
if dnorm > trust_radius {
let pnorm = p.mapv(|x: f64| x.powi(2)).sum().sqrt();
let pd = p.dot(&d);
let discriminant = pd.powi(2) - dnorm.powi(2) * (pnorm.powi(2) - trust_radius.powi(2));
if discriminant >= 0.0 {
let alpha = (-pd + discriminant.sqrt()) / dnorm.powi(2);
return Ok(&p + alpha * &d);
}
}
let residual_norm = r.mapv(|x| x.abs()).sum();
if residual_norm < options.cg_tol {
break;
}
let hd = hessian_vec_fn(&d)?;
let dthd = d.dot(&hd);
if dthd <= 0.0 {
let pnorm = p.mapv(|x: f64| x.powi(2)).sum().sqrt();
let dnorm = d.mapv(|x: f64| x.powi(2)).sum().sqrt();
let pd = p.dot(&d);
let discriminant = pd.powi(2) - dnorm.powi(2) * (pnorm.powi(2) - trust_radius.powi(2));
if discriminant >= 0.0 {
let alpha = (-pd + discriminant.sqrt()) / dnorm.powi(2);
return Ok(&p + alpha * &d);
} else {
return Ok(p);
}
}
let alpha = rsold / dthd;
let p_new = &p + alpha * &d;
let p_new_norm = p_new.mapv(|x: f64| x.powi(2)).sum().sqrt();
if p_new_norm >= trust_radius {
let pnorm = p.mapv(|x: f64| x.powi(2)).sum().sqrt();
let dnorm = d.mapv(|x: f64| x.powi(2)).sum().sqrt();
let pd = p.dot(&d);
let discriminant = pd.powi(2) - dnorm.powi(2) * (pnorm.powi(2) - trust_radius.powi(2));
if discriminant >= 0.0 {
let alpha_tr = (-pd + discriminant.sqrt()) / dnorm.powi(2);
return Ok(&p + alpha_tr * &d);
}
}
p = p_new;
r = &r - alpha * &hd;
z = match &options.preconditioner {
Preconditioner::None => r.clone(),
Preconditioner::Diagonal => {
if let Some(ref diag) = state.diag_hessian {
r.iter()
.zip(diag.iter())
.map(|(&ri, &di)| {
if di.abs() > 1e-12 {
ri / di.max(1e-6)
} else {
ri
}
})
.collect()
} else {
r.clone()
}
}
Preconditioner::LBFGS { .. } => state.apply_lbfgs_preconditioner(&r),
Preconditioner::Custom(precond_fn) => precond_fn(&r),
};
let rsnew = r.dot(&z);
if rsnew < 0.0 {
break;
}
let beta = rsnew / rsold;
d = &z + beta * &d;
rsold = rsnew;
}
Ok(p)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::array;
#[test]
fn test_truncated_newton_quadratic() {
let fun = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
let grad = |x: &ArrayView1<f64>| array![2.0 * x[0], 2.0 * x[1]];
let x0 = array![1.0, 1.0];
let options = TruncatedNewtonOptions {
max_iter: 100,
tol: 1e-8,
..Default::default()
};
let result = minimize_truncated_newton(fun, Some(grad), x0, Some(options))
.expect("Operation failed");
assert!(result.success);
assert_abs_diff_eq!(result.x[0], 0.0, epsilon = 1e-6);
assert_abs_diff_eq!(result.x[1], 0.0, epsilon = 1e-6);
assert!(result.fun < 1e-10);
}
#[test]
fn test_truncated_newton_rosenbrock() {
let rosenbrock = |x: &ArrayView1<f64>| {
let a = 1.0;
let b = 100.0;
(a - x[0]).powi(2) + b * (x[1] - x[0].powi(2)).powi(2)
};
let rosenbrock_grad = |x: &ArrayView1<f64>| {
let a = 1.0;
let b = 100.0;
array![
-2.0 * (a - x[0]) - 4.0 * b * x[0] * (x[1] - x[0].powi(2)),
2.0 * b * (x[1] - x[0].powi(2))
]
};
let x0 = array![0.0, 0.0];
let options = TruncatedNewtonOptions {
max_iter: 200,
tol: 1e-6,
max_cg_iter: 50,
..Default::default()
};
let result =
minimize_truncated_newton(rosenbrock, Some(rosenbrock_grad), x0, Some(options))
.expect("Operation failed");
assert!(result.fun < 1.0); }
#[test]
fn test_truncated_newton_with_diagonal_preconditioning() {
let fun = |x: &ArrayView1<f64>| x[0].powi(2) + 100.0 * x[1].powi(2);
let grad = |x: &ArrayView1<f64>| array![2.0 * x[0], 200.0 * x[1]];
let x0 = array![1.0, 1.0];
let options = TruncatedNewtonOptions {
max_iter: 100,
tol: 1e-8,
preconditioner: Preconditioner::Diagonal,
..Default::default()
};
let result = minimize_truncated_newton(fun, Some(grad), x0, Some(options))
.expect("Operation failed");
assert!(result.fun < 50.0); }
#[test]
fn test_truncated_newton_with_lbfgs_preconditioning() {
let fun = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
let grad = |x: &ArrayView1<f64>| array![2.0 * x[0], 2.0 * x[1]];
let x0 = array![2.0, 2.0];
let options = TruncatedNewtonOptions {
max_iter: 100,
tol: 1e-8,
preconditioner: Preconditioner::LBFGS { memory: 5 },
..Default::default()
};
let result = minimize_truncated_newton(fun, Some(grad), x0, Some(options))
.expect("Operation failed");
assert!(result.success);
assert_abs_diff_eq!(result.x[0], 0.0, epsilon = 1e-6);
assert_abs_diff_eq!(result.x[1], 0.0, epsilon = 1e-6);
}
#[test]
fn test_trust_region_newton() {
let fun = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
let grad = |x: &ArrayView1<f64>| array![2.0 * x[0], 2.0 * x[1]];
let x0 = array![1.0, 1.0];
let options = TruncatedNewtonOptions {
max_iter: 100,
tol: 1e-8,
trust_radius: Some(0.5),
..Default::default()
};
let result = minimize_trust_region_newton(fun, Some(grad), x0, Some(options))
.expect("Operation failed");
assert!(result.success);
assert_abs_diff_eq!(result.x[0], 0.0, epsilon = 1e-6);
assert_abs_diff_eq!(result.x[1], 0.0, epsilon = 1e-6);
}
}