use crate::DType;
use numr::autograd::Var;
use numr::error::Result as NumrResult;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::optimize::error::{OptimizeError, OptimizeResult};
use crate::optimize::minimize::traits::trust_region::{TrustRegionOptions, TrustRegionResult};
use super::trust_region_base::{
SubproblemResult, SubproblemSolver, compute_hvp_for_subproblem, compute_predicted_reduction,
trust_region_loop,
};
use super::utils::{tensor_dot, tensor_norm};
/// Stateless [`SubproblemSolver`] that approximates each trust-region step with
/// Steihaug–Toint truncated conjugate gradient (see [`steihaug_toint_cg`]).
struct SteihaugCG;

impl<R, C, F> SubproblemSolver<R, C, F> for SteihaugCG
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
    F: Fn(&Var<R>, &C) -> NumrResult<Var<R>>,
{
    /// Solves the subproblem at `x` (with gradient `grad`) inside a ball of
    /// radius `trust_radius`. The solver carries no state, so this simply
    /// forwards every argument to [`steihaug_toint_cg`].
    fn solve(
        &self,
        client: &C,
        f: &F,
        x: &Tensor<R>,
        grad: &Tensor<R>,
        trust_radius: f64,
    ) -> OptimizeResult<SubproblemResult<R>> {
        steihaug_toint_cg::<R, C, F>(client, f, x, grad, trust_radius)
    }
}
/// Minimises the scalar objective `f` starting from `x0` with the trust-region
/// Newton-CG method.
///
/// The outer iteration (step acceptance, radius updates, convergence checks
/// driven by `options`) is delegated to [`trust_region_loop`]. Each
/// trust-region subproblem is solved approximately by Steihaug–Toint truncated
/// conjugate gradient via the [`SteihaugCG`] solver.
///
/// # Errors
///
/// Propagates any [`OptimizeError`] produced by `trust_region_loop` or by the
/// subproblem solver.
pub fn trust_ncg_impl<R, C, F>(
    client: &C,
    f: F,
    x0: &Tensor<R>,
    options: &TrustRegionOptions,
) -> OptimizeResult<TrustRegionResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
    F: Fn(&Var<R>, &C) -> NumrResult<Var<R>>,
{
    let subproblem_solver = &SteihaugCG;
    trust_region_loop(client, f, x0, options, subproblem_solver)
}
/// Approximately solves the trust-region subproblem with the Steihaug–Toint
/// truncated conjugate-gradient (CG) method.
///
/// Runs CG on the Newton system `H p = -grad` starting from `p = 0`. For a
/// positive-definite `H` this minimises the quadratic model
/// `gradᵀp + ½ pᵀHp`.
///
/// Products `H·v` come from [`compute_hvp_for_subproblem`] applied to `f` at
/// `x`. Presumably this is the Hessian of `f` at `x`; that helper is not
/// visible here.
///
/// Iteration stops at the first of:
///
/// * the residual norm drops below `min(0.5, √‖grad‖)·‖grad‖`, giving an
///   interior step with `hits_boundary = false`;
/// * non-positive curvature `dᵀHd ≤ 0` is detected, so the step follows the
///   current search direction to the trust-region boundary;
/// * the next iterate would satisfy `‖p‖ ≥ trust_radius`, so the step stops
///   on the boundary along the current search direction;
/// * `min(n, 200)` iterations are exhausted, so the current iterate is
///   returned and `hits_boundary` is set when `‖p‖` is within 1% of
///   `trust_radius`.
///
/// An exactly zero gradient returns the zero step immediately. Otherwise the
/// first search direction would be the zero vector and the boundary
/// computation would divide by zero.
///
/// # Errors
///
/// Returns [`OptimizeError::NumericalError`] when a tensor operation or
/// reduction fails. Errors from the HVP, predicted-reduction and boundary
/// helpers are propagated unchanged.
fn steihaug_toint_cg<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
    grad: &Tensor<R>,
    trust_radius: f64,
) -> OptimizeResult<SubproblemResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
    F: Fn(&Var<R>, &C) -> NumrResult<Var<R>>,
{
    let n = grad.numel();
    // In exact arithmetic CG finishes within `n` steps. The fixed cap bounds
    // the number of Hessian-vector products on large problems.
    let max_cg_iter = n.min(200);
    let device = grad.device();
    let dtype = grad.dtype();
    let grad_norm = tensor_norm(client, grad).map_err(|e| OptimizeError::NumericalError {
        message: format!("steihaug: grad norm - {}", e),
    })?;
    let mut p = Tensor::<R>::zeros(grad.shape(), dtype, device);
    if grad_norm == 0.0 {
        // With g = 0, CG started at p = 0 has a zero search direction, so it
        // cannot make progress. Walking along it would also make
        // `move_to_boundary` divide by dᵀd = 0. Return the zero step.
        // H·0 = 0, so the model needs no Hessian-vector product.
        let hp = p.clone();
        let predicted_reduction = compute_predicted_reduction(client, grad, &p, &hp)?;
        return Ok(SubproblemResult {
            step: p,
            hits_boundary: false,
            predicted_reduction,
        });
    }
    // Inexact-Newton forcing term: relative residual tolerance min(0.5, √‖g‖).
    let tol = 0.5_f64.min(grad_norm.sqrt()) * grad_norm;
    let neg_g = client
        .mul_scalar(grad, -1.0)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("steihaug: neg_g - {}", e),
        })?;
    // `r` tracks the residual -(grad + H p). With p = 0 it equals -grad, which
    // is also the first search direction.
    let mut r = neg_g.clone();
    let mut d = neg_g;
    let mut r_dot_r = tensor_dot(client, &r, &r).map_err(|e| OptimizeError::NumericalError {
        message: format!("steihaug: dot - {}", e),
    })?;
    for _ in 0..max_cg_iter {
        if r_dot_r.sqrt() < tol {
            return finish_subproblem(client, f, x, grad, p, false);
        }
        let hd = compute_hvp_for_subproblem(client, f, x, &d)?;
        let d_dot_hd = tensor_dot(client, &d, &hd).map_err(|e| OptimizeError::NumericalError {
            message: format!("steihaug: dot - {}", e),
        })?;
        if d_dot_hd <= 0.0 {
            // The model does not curve upward along `d`: follow it to the boundary.
            let p_boundary = move_to_boundary(client, &p, &d, trust_radius)?;
            return finish_subproblem(client, f, x, grad, p_boundary, true);
        }
        let alpha = r_dot_r / d_dot_hd;
        let alpha_d = client
            .mul_scalar(&d, alpha)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("steihaug: alpha_d - {}", e),
            })?;
        let p_new = client
            .add(&p, &alpha_d)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("steihaug: p_new - {}", e),
            })?;
        let p_new_norm =
            tensor_norm(client, &p_new).map_err(|e| OptimizeError::NumericalError {
                message: format!("steihaug: p_new norm - {}", e),
            })?;
        if p_new_norm >= trust_radius {
            // The full CG step leaves the region: stop where `p + τd` crosses it.
            let p_boundary = move_to_boundary(client, &p, &d, trust_radius)?;
            return finish_subproblem(client, f, x, grad, p_boundary, true);
        }
        p = p_new;
        let alpha_hd = client
            .mul_scalar(&hd, alpha)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("steihaug: alpha_hd - {}", e),
            })?;
        r = client
            .sub(&r, &alpha_hd)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("steihaug: r update - {}", e),
            })?;
        let r_dot_r_new =
            tensor_dot(client, &r, &r).map_err(|e| OptimizeError::NumericalError {
                message: format!("steihaug: dot - {}", e),
            })?;
        // r_dot_r > 0 here: otherwise the tolerance test above would have returned.
        let beta = r_dot_r_new / r_dot_r;
        r_dot_r = r_dot_r_new;
        let beta_d = client
            .mul_scalar(&d, beta)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("steihaug: beta_d - {}", e),
            })?;
        d = client
            .add(&r, &beta_d)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("steihaug: d update - {}", e),
            })?;
    }
    let p_norm = tensor_norm(client, &p).map_err(|e| OptimizeError::NumericalError {
        message: format!("steihaug: final p norm - {}", e),
    })?;
    // The iteration cap was reached without a termination test firing.
    // Report a boundary hit only if `p` lies (nearly) on the boundary.
    let hits_boundary = (p_norm - trust_radius).abs() / trust_radius < 0.01;
    finish_subproblem(client, f, x, grad, p, hits_boundary)
}

/// Packages a finished CG `step` as a [`SubproblemResult`].
///
/// Evaluates `H·step` with [`compute_hvp_for_subproblem`] and the model's
/// predicted reduction with [`compute_predicted_reduction`].
fn finish_subproblem<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
    grad: &Tensor<R>,
    step: Tensor<R>,
    hits_boundary: bool,
) -> OptimizeResult<SubproblemResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
    F: Fn(&Var<R>, &C) -> NumrResult<Var<R>>,
{
    let hp = compute_hvp_for_subproblem(client, f, x, &step)?;
    let predicted_reduction = compute_predicted_reduction(client, grad, &step, &hp)?;
    Ok(SubproblemResult {
        step,
        hits_boundary,
        predicted_reduction,
    })
}
/// Returns `p + τ·d`, where `τ` is the larger root of
/// `‖p + τ·d‖ = trust_radius`.
///
/// This solves `(dᵀd)·τ² + 2(pᵀd)·τ + (pᵀp − Δ²) = 0` with `Δ = trust_radius`.
/// When `p` lies strictly inside the region the larger root is positive, so
/// the result is the point where the ray from `p` along `d` leaves the ball.
///
/// `τ = 0` (i.e. `p` is returned unchanged) in two cases:
/// * `d` is the zero vector, which cannot reach the boundary;
/// * no real root exists (only possible when `p` is already outside).
///
/// # Errors
///
/// Returns [`OptimizeError::NumericalError`] if a dot product or tensor
/// operation fails.
fn move_to_boundary<R, C>(
    client: &C,
    p: &Tensor<R>,
    d: &Tensor<R>,
    trust_radius: f64,
) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    let p_dot_p = tensor_dot(client, p, p).map_err(|e| OptimizeError::NumericalError {
        message: format!("move_to_boundary: dot - {}", e),
    })?;
    let p_dot_d = tensor_dot(client, p, d).map_err(|e| OptimizeError::NumericalError {
        message: format!("move_to_boundary: dot - {}", e),
    })?;
    let d_dot_d = tensor_dot(client, d, d).map_err(|e| OptimizeError::NumericalError {
        message: format!("move_to_boundary: dot - {}", e),
    })?;
    let delta_sq = trust_radius * trust_radius;
    // Quadratic a·τ² + 2b·τ + c = 0 with a = dᵀd, b = pᵀd, c = pᵀp − Δ².
    let c = p_dot_p - delta_sq;
    let discriminant = p_dot_d * p_dot_d - d_dot_d * c;
    let tau = if d_dot_d <= 0.0 || discriminant < 0.0 {
        // A zero direction would divide by zero. A negative discriminant has
        // no real root. In either case, stay at `p`.
        0.0
    } else {
        let sqrt_disc = discriminant.sqrt();
        if p_dot_d > 0.0 {
            // Algebraically equal to (-b + √D)/a. This form avoids subtracting
            // two nearly equal numbers when b > 0.
            -c / (p_dot_d + sqrt_disc)
        } else {
            (-p_dot_d + sqrt_disc) / d_dot_d
        }
    };
    let tau_d = client
        .mul_scalar(d, tau)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("move_to_boundary: tau_d - {}", e),
        })?;
    client
        .add(p, &tau_d)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("move_to_boundary: p + tau_d - {}", e),
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::autograd::{var_mul, var_sub, var_sum};
    use numr::runtime::Runtime;
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    /// Creates a CPU device together with its default client.
    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        (device, client)
    }

    /// Runs trust-NCG with default options on `f(x) = Σ xᵢ²` from `start`,
    /// panicking if the optimizer returns an error.
    fn minimize_sum_of_squares(
        client: &CpuClient,
        start: &Tensor<CpuRuntime>,
    ) -> TrustRegionResult<CpuRuntime> {
        trust_ncg_impl(
            client,
            |x_var, c| {
                let squared = var_mul(x_var, x_var, c)?;
                var_sum(&squared, &[0], false, c)
            },
            start,
            &TrustRegionOptions::default(),
        )
        .unwrap()
    }

    #[test]
    fn test_trust_ncg_quadratic() {
        let (device, client) = setup();
        let start = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0], &[3], &device);
        let outcome = minimize_sum_of_squares(&client, &start);
        assert!(outcome.converged);
        assert!(outcome.fun < 1e-10);
    }

    #[test]
    fn test_trust_ncg_shifted_quadratic() {
        let (device, client) = setup();
        let start = Tensor::<CpuRuntime>::from_slice(&[0.0f64, 0.0], &[2], &device);
        // f(x) = Σ (xᵢ − 1)², minimised at x = (1, 1).
        let outcome = trust_ncg_impl(
            &client,
            |x_var, c| {
                let target = Var::new(
                    Tensor::<CpuRuntime>::from_slice(&[1.0f64, 1.0], &[2], &device),
                    false,
                );
                let offset = var_sub(x_var, &target, c)?;
                let offset_sq = var_mul(&offset, &offset, c)?;
                var_sum(&offset_sq, &[0], false, c)
            },
            &start,
            &TrustRegionOptions::default(),
        )
        .unwrap();
        assert!(outcome.converged);
        assert!(outcome.fun < 1e-10);
        let solution: Vec<f64> = outcome.x.to_vec();
        assert!((solution[0] - 1.0).abs() < 1e-4);
        assert!((solution[1] - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_trust_ncg_sphere() {
        let (device, client) = setup();
        // Start at (-5, -4, …, 4).
        let coords: Vec<f64> = (-5i32..5).map(f64::from).collect();
        let start = Tensor::<CpuRuntime>::from_slice(&coords, &[10], &device);
        let outcome = minimize_sum_of_squares(&client, &start);
        assert!(outcome.converged);
        assert!(outcome.fun < 1e-10);
    }
}