use crate::DType;
use numr::autograd::Var;
use numr::error::Result as NumrResult;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::optimize::error::{OptimizeError, OptimizeResult};
use crate::optimize::minimize::traits::trust_region::{TrustRegionOptions, TrustRegionResult};
use super::helpers::{gradient_from_fn, hvp_from_fn};
use super::utils::{tensor_dot, tensor_norm};
/// Outcome of solving one trust-region subproblem.
///
/// Produced by a [`SubproblemSolver`] and consumed by `trust_region_loop`, which
/// uses it to form the trial point and to adapt the trust radius.
pub struct SubproblemResult<R>
where
    R: Runtime<DType = DType>,
{
    /// Proposed step; `trust_region_loop` adds it to the current iterate to form
    /// the trial point.
    pub step: Tensor<R>,
    /// Whether the solver reports that the step reached the trust-region boundary.
    /// `trust_region_loop` only enlarges the radius when this is `true`.
    pub hits_boundary: bool,
    /// Decrease of the objective predicted by the solver for `step`.
    /// `trust_region_loop` divides the actual decrease by this value and treats
    /// non-positive values as a failed subproblem.
    pub predicted_reduction: f64,
}
/// Strategy for computing a step inside the current trust region.
///
/// `R` is the tensor runtime, `C` the client type used for tensor operations and
/// `F` the objective function type; implementors choose how (and whether) to use
/// `client` and `f`.
pub trait SubproblemSolver<R, C, F>
where
    R: Runtime<DType = DType>,
{
    /// Computes a step for the subproblem at the iterate `x`.
    ///
    /// # Arguments
    ///
    /// * `client` - client handed through from the outer iteration.
    /// * `f` - objective function handed through from the outer iteration.
    /// * `x` - current iterate.
    /// * `grad` - gradient supplied by the caller for `x`.
    /// * `trust_radius` - current trust-region radius.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports when it cannot produce
    /// a step.
    fn solve(
        &self,
        client: &C,
        f: &F,
        x: &Tensor<R>,
        grad: &Tensor<R>,
        trust_radius: f64,
    ) -> OptimizeResult<SubproblemResult<R>>;
}
/// Minimizes `f` starting from `x0` with a generic trust-region iteration.
///
/// Every iteration asks `solver` for a step within the current trust radius and
/// compares the decrease actually obtained with the decrease the solver predicted
/// (`rho = actual / predicted`). Then:
///
/// * the radius is multiplied by `0.25` when `rho < 0.25` or `rho` is NaN,
/// * the radius is doubled (capped at `options.max_trust_radius`) when `rho > 0.75`
///   and the solver reported that the step hit the boundary,
/// * the trial point is accepted only when `rho > options.eta`.
///
/// A non-positive or NaN predicted reduction is treated as a failed subproblem: the
/// radius is shrunk and the objective is *not* evaluated at the trial point.
///
/// The iteration stops when the gradient norm drops below `options.gtol`
/// (`converged: true`), when the radius falls below `1e-15`, or after
/// `options.max_iter` iterations. In the last case the final iterate is still tested
/// against `options.gtol`, so a run that converges on its last permitted step is
/// reported as converged.
///
/// `nfev` in the result counts calls to `gradient_from_fn` (one per evaluated point).
///
/// # Errors
///
/// * `OptimizeError::InvalidInput` when `x0` is empty, `initial_trust_radius` is not
///   strictly positive, `max_trust_radius` does not exceed `initial_trust_radius`,
///   `eta` lies outside `[0, 0.25)`, or any of these options is NaN.
/// * `OptimizeError::NumericalError` when a tensor operation (norm, addition) fails.
/// * Any error returned by `gradient_from_fn` or by `solver`.
pub fn trust_region_loop<R, C, F, S>(
    client: &C,
    f: F,
    x0: &Tensor<R>,
    options: &TrustRegionOptions,
    solver: &S,
) -> OptimizeResult<TrustRegionResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
    F: Fn(&Var<R>, &C) -> NumrResult<Var<R>>,
    S: SubproblemSolver<R, C, F>,
{
    // Radius below which the iteration is considered stalled.
    const MIN_TRUST_RADIUS: f64 = 1e-15;
    // Factor applied to the radius after a poor or invalid step.
    const SHRINK_FACTOR: f64 = 0.25;

    if x0.numel() == 0 {
        return Err(OptimizeError::InvalidInput {
            context: "trust_region: empty initial guess".to_string(),
        });
    }
    // The explicit NaN tests matter: every ordered comparison with NaN is false, so
    // a NaN option would otherwise pass validation and poison the radius updates.
    if options.initial_trust_radius.is_nan() || options.initial_trust_radius <= 0.0 {
        return Err(OptimizeError::InvalidInput {
            context: "trust_region: initial_trust_radius must be positive".to_string(),
        });
    }
    if options.max_trust_radius.is_nan()
        || options.max_trust_radius <= options.initial_trust_radius
    {
        return Err(OptimizeError::InvalidInput {
            context: "trust_region: max_trust_radius must exceed initial_trust_radius".to_string(),
        });
    }
    // `contains` is false for NaN, so NaN is rejected here as well.
    if !(0.0..0.25).contains(&options.eta) {
        return Err(OptimizeError::InvalidInput {
            context: "trust_region: eta must be in [0, 0.25)".to_string(),
        });
    }

    let mut x = x0.clone();
    let mut delta = options.initial_trust_radius;
    let mut nfev = 0;
    let (mut fx, mut grad) = gradient_from_fn(client, &f, &x)?;
    nfev += 1;

    for iter in 0..options.max_iter {
        let grad_norm = tensor_norm(client, &grad).map_err(|e| OptimizeError::NumericalError {
            message: format!("trust_region: grad norm - {}", e),
        })?;
        if grad_norm < options.gtol {
            return Ok(TrustRegionResult {
                x,
                fun: fx,
                grad,
                iterations: iter,
                converged: true,
                trust_radius: delta,
                nfev,
            });
        }

        let sub_result = solver.solve(client, &f, &x, &grad, delta)?;
        let predicted_reduction = sub_result.predicted_reduction;

        // The model promises no decrease (or is numerically broken): shrink the
        // region and retry without spending an objective evaluation on a step that
        // would be discarded anyway.
        if predicted_reduction.is_nan() || predicted_reduction <= 0.0 {
            delta *= SHRINK_FACTOR;
            if delta < MIN_TRUST_RADIUS {
                return Ok(TrustRegionResult {
                    x,
                    fun: fx,
                    grad,
                    iterations: iter + 1,
                    converged: false,
                    trust_radius: delta,
                    nfev,
                });
            }
            continue;
        }

        let x_new = client
            .add(&x, &sub_result.step)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("trust_region: x + step - {}", e),
            })?;
        let (fx_new, grad_new) = gradient_from_fn(client, &f, &x_new)?;
        nfev += 1;

        let actual_reduction = fx - fx_new;
        // Guard the ratio against a vanishing denominator: when both reductions are
        // negligible the model is considered accurate, otherwise the step is poor.
        let rho = if predicted_reduction.abs() < 1e-30 {
            if actual_reduction.abs() < 1e-30 {
                1.0
            } else {
                0.0
            }
        } else {
            actual_reduction / predicted_reduction
        };

        // A NaN ratio (e.g. the objective is NaN at the trial point) must count as a
        // poor step; otherwise no branch fires and the same step is retried forever.
        if rho.is_nan() || rho < 0.25 {
            delta *= SHRINK_FACTOR;
        } else if rho > 0.75 && sub_result.hits_boundary {
            delta = (2.0 * delta).min(options.max_trust_radius);
        }

        // NaN compares false here, so an invalid trial point is never accepted.
        if rho > options.eta {
            x = x_new;
            fx = fx_new;
            grad = grad_new;
        }

        if delta < MIN_TRUST_RADIUS {
            return Ok(TrustRegionResult {
                x,
                fun: fx,
                grad,
                iterations: iter + 1,
                converged: false,
                trust_radius: delta,
                nfev,
            });
        }
    }

    // The loop tests convergence only at the start of an iteration, so the iterate
    // produced by the last iteration (or `x0` when `max_iter` is zero) is tested here.
    let final_grad_norm = tensor_norm(client, &grad).map_err(|e| OptimizeError::NumericalError {
        message: format!("trust_region: grad norm - {}", e),
    })?;
    Ok(TrustRegionResult {
        x,
        fun: fx,
        grad,
        iterations: options.max_iter,
        converged: final_grad_norm < options.gtol,
        trust_radius: delta,
        nfev,
    })
}
/// Evaluates the decrease predicted for a step: `-(grad·step + 0.5 * step·h_step)`.
///
/// # Arguments
///
/// * `client` - client used to evaluate the dot products.
/// * `grad` - gradient tensor.
/// * `step` - candidate step.
/// * `h_step` - tensor dotted with `step` for the second-order term; callers are
///   expected to pass the Hessian applied to `step`.
///
/// # Errors
///
/// Returns `OptimizeError::NumericalError` when either dot product fails.
pub fn compute_predicted_reduction<R, C>(
    client: &C,
    grad: &Tensor<R>,
    step: &Tensor<R>,
    h_step: &Tensor<R>,
) -> OptimizeResult<f64>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    // Both inner products share the same error mapping.
    let dot = |lhs: &Tensor<R>, rhs: &Tensor<R>| {
        tensor_dot(client, lhs, rhs).map_err(|e| OptimizeError::NumericalError {
            message: format!("trust_region: dot - {}", e),
        })
    };

    let linear_term = dot(grad, step)?;
    let curvature_term = dot(step, h_step)?;

    Ok(-(linear_term + 0.5 * curvature_term))
}
/// Returns `lambda + (ratio - 1) * lambda / ratio`, bounded below by `1e-15`.
///
/// Because `f64::max` returns the non-NaN operand, a NaN intermediate result
/// (for example `lambda == 0.0` with `ratio == 0.0`) also yields `1e-15`.
pub fn secular_newton_update(lambda: f64, ratio: f64) -> f64 {
    // Smallest value the updated multiplier may take.
    const LAMBDA_FLOOR: f64 = 1e-15;

    let correction = (ratio - 1.0) * lambda / ratio;
    f64::max(lambda + correction, LAMBDA_FLOOR)
}
/// Returns the second element of the pair produced by `hvp_from_fn(client, f, x, v)`,
/// discarding the first.
///
/// Judging by the helper's name, the returned tensor is the Hessian-vector product
/// of `f` at `x` in direction `v`; see `hvp_from_fn` for the exact contract.
///
/// # Errors
///
/// Propagates any error returned by `hvp_from_fn`.
pub fn compute_hvp_for_subproblem<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
    v: &Tensor<R>,
) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
    F: Fn(&Var<R>, &C) -> NumrResult<Var<R>>,
{
    // Only the second element of the pair is needed by subproblem solvers.
    Ok(hvp_from_fn(client, f, x, v)?.1)
}