use crate::core::constraint::BoxConstraints;
use crate::core::math::{
AddDiagonalVectorInPlace, BoxAffineScaling, Dot, GramMatrix, LinearSolveSpd, MatTransposeVec,
MaxDiagonal, NegInPlace, NormSquared, ScaledAdd,
};
use crate::core::problem::{Jacobian, Problem, Residual};
use crate::core::solver::Solver;
use crate::core::state::BasicState;
use crate::core::termination::TerminationReason;
/// Bound-constrained nonlinear least-squares solver in the trust-region-reflective family.
///
/// Minimises `½‖r(x)‖²` subject to the box bounds exposed by [`BoxConstraints`]. It combines
/// Levenberg–Marquardt-style diagonal damping with an affine scaling computed by the
/// `BoxAffineScaling` helpers. The `cl_*` names suggest Coleman–Li scaling — confirm.
/// Configure it through the builder methods chained on `Trf::new()`.
///
/// `V` is the parameter/residual vector type and `M` is the Jacobian matrix type.
#[derive(Debug, Clone)]
pub struct Trf<V, M> {
    /// Convergence threshold on the scaled gradient (KKT) measure; `0.0` disables the test.
    tol_grad: f64,
    /// Factor seeding the initial damping: `mu = tau · max(1, largest diagonal entry)`.
    tau: f64,
    /// Margin used when projecting the starting point strictly inside the bounds.
    rstep: f64,
    /// Step-back factor applied to bound-limited steps; always in `(0, 1)`.
    theta: f64,
    /// Number of failed damped solves tolerated per iteration before reporting failure.
    max_inner_attempts: u32,
    /// Current damping parameter; `None` until `Solver::init` has run.
    mu: Option<f64>,
    /// Growth factor applied to `mu` after a rejected step or a failed solve.
    nu: f64,
    /// Residual at the current iterate, carried between calls to avoid re-evaluation.
    r_cache: Option<V>,
    /// Jacobian at the current iterate; cleared whenever a step is accepted.
    j_cache: Option<M>,
}
impl<V, M> Default for Trf<V, M> {
fn default() -> Self {
Self::new()
}
}
impl<V, M> Trf<V, M> {
    /// Creates a solver with the default tuning parameters.
    ///
    /// The defaults are `tol_grad = 1e-8`, `tau = 1e-3`, `rstep = 1e-10`, `theta = 0.99995`
    /// and `max_inner_attempts = 50`.
    ///
    /// The damping parameter `mu` and the residual/Jacobian caches start empty. They are
    /// populated by `Solver::init`; `nu` starts at `2.0`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            tol_grad: 1e-8,
            tau: 1e-3,
            rstep: 1e-10,
            theta: 0.99995,
            max_inner_attempts: 50,
            mu: None,
            nu: 2.0,
            r_cache: None,
            j_cache: None,
        }
    }

    /// Sets the tolerance for the gradient-based convergence test.
    ///
    /// An iteration reports `SolverConverged` once the scaled KKT measure of the gradient is
    /// at most `tol`. Passing `0.0` disables this test.
    ///
    /// # Panics
    /// Panics if `tol` is negative or NaN.
    #[must_use]
    pub fn tol_grad(mut self, tol: f64) -> Self {
        assert!(tol >= 0.0, "tol_grad must be ≥ 0");
        self.tol_grad = tol;
        self
    }

    /// Sets the factor used to seed the damping parameter.
    ///
    /// At the starting point, `mu = tau · max(1, largest diagonal entry of JᵀJ + diag(c))`.
    ///
    /// # Panics
    /// Panics unless `tau` is finite and strictly positive. An infinite `tau` would make the
    /// initial damping infinite and every damped system degenerate.
    #[must_use]
    pub fn tau(mut self, tau: f64) -> Self {
        assert!(tau > 0.0 && tau.is_finite(), "tau must be > 0 and finite");
        self.tau = tau;
        self
    }

    /// Sets the margin passed to `project_strictly_inside` in `Solver::init`.
    ///
    /// That call moves the starting point strictly inside the bounds. The exact meaning of
    /// the margin (absolute vs. relative) is defined by `BoxAffineScaling` — confirm there.
    ///
    /// # Panics
    /// Panics unless `rstep` is finite and strictly positive.
    #[must_use]
    pub fn rstep(mut self, rstep: f64) -> Self {
        assert!(rstep > 0.0 && rstep.is_finite(), "rstep must be > 0 and finite");
        self.rstep = rstep;
        self
    }

    /// Sets how much of the maximal feasible step a bound-limited step may cover.
    ///
    /// The maximal feasible step is the distance to the nearest bound along the step.
    ///
    /// # Panics
    /// Panics unless `0 < theta < 1` (NaN is rejected).
    #[must_use]
    pub fn theta(mut self, theta: f64) -> Self {
        assert!(
            theta > 0.0 && theta < 1.0,
            "theta must be in (0, 1), got {theta}"
        );
        self.theta = theta;
        self
    }

    /// Sets how many failed solves of the damped linear system are tolerated per iteration.
    ///
    /// `mu` is increased after each failure. The iteration reports `SolverFailed` after `n`
    /// failures.
    ///
    /// # Panics
    /// Panics if `n` is zero.
    #[must_use]
    pub fn max_inner_attempts(mut self, n: u32) -> Self {
        assert!(n > 0, "max_inner_attempts must be > 0");
        self.max_inner_attempts = n;
        self
    }
}
/// [`Solver`] implementation: a damped, affinely scaled Gauss–Newton iteration for
/// box-constrained nonlinear least squares with cost `½‖r(x)‖²`.
///
/// # Step computation
/// Each iteration forms `g = Jᵀr` and solves `(JᵀJ + diag(c + mu·d)) h = -g`. Here `d` and
/// `c` are the diagonals filled in by `compute_cl_scaling`. They are presumably the
/// Coleman–Li scaling vector and its associated `C` diagonal — confirm against
/// `BoxAffineScaling`. The step is then shortened so the iterate stays inside the box.
///
/// # Damping update (Nielsen's rule)
/// - On acceptance: `mu *= max(1/3, 1 - (2ρ - 1)³)` and `nu = 2`.
/// - On rejection: `mu *= nu` and `nu *= 2`.
impl<P, V, M> Solver<P, BasicState<V>> for Trf<V, M>
where
    P: Residual<Param = V, Output = V> + Jacobian<Jacobian = M> + BoxConstraints<Param = V>,
    V: ScaledAdd<f64> + NormSquared + NegInPlace + Dot + BoxAffineScaling + Clone,
    M: GramMatrix
        + MatTransposeVec<V>
        + LinearSolveSpd<V>
        + AddDiagonalVectorInPlace<V>
        + MaxDiagonal
        + Clone,
{
    type Error = <P as Residual>::Error;

    /// Prepares the solver and the state for iteration.
    ///
    /// - Moves the starting point strictly inside the bounds, using `rstep` as the margin.
    /// - Evaluates the residual and Jacobian there and records the cost `½‖r‖²`.
    /// - Seeds the damping as `mu = tau · max(1, largest diagonal entry of JᵀJ + diag(c))`
    ///   and resets `nu` to 2.
    /// - Caches the residual and Jacobian so the first `next_iter` reuses them.
    ///
    /// # Errors
    /// Propagates any error from evaluating the residual/Jacobian.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicState<V>,
    ) -> Result<BasicState<V>, Self::Error> {
        state.param.project_strictly_inside(
            problem.inner().lower(),
            problem.inner().upper(),
            self.rstep,
        );
        let (r, j) = problem.residual_and_jacobian(&state.param)?;
        state.cost = Some(0.5 * r.norm_squared());
        let g = j.mat_transpose_vec(&r);
        let mut d_sq = state.param.clone();
        let mut c_diag = state.param.clone();
        state.param.compute_cl_scaling(
            &g,
            problem.inner().lower(),
            problem.inner().upper(),
            &mut d_sq,
            &mut c_diag,
        );
        let mut a = j.gram();
        a.add_diagonal_vector_in_place(&c_diag);
        // Clamp to 1 so a tiny (or zero) diagonal still yields a usable initial damping.
        let max_diag = a.max_diagonal().max(1.0);
        self.mu = Some(self.tau * max_diag);
        self.nu = 2.0;
        self.r_cache = Some(r);
        self.j_cache = Some(j);
        Ok(state)
    }

    /// Performs one damped, scaled step and updates the damping parameter.
    ///
    /// Returns:
    /// - `Some(SolverConverged)` when `tol_grad > 0` and the scaled KKT measure of the
    ///   gradient is at most `tol_grad`.
    /// - `Some(SolverFailed)` when the damped system cannot be solved within
    ///   `max_inner_attempts`, or when `mu` has become non-finite.
    /// - `None` otherwise.
    ///
    /// # Errors
    /// Propagates any error from evaluating the residual/Jacobian.
    ///
    /// # Panics
    /// Panics if `Solver::init` has not been run (`mu` or the state's cost is unset).
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicState<V>,
    ) -> Result<(BasicState<V>, Option<TerminationReason>), Self::Error> {
        // Reuse evaluations left by `init` or the previous iteration when available.
        let r = match self.r_cache.take() {
            Some(r) => r,
            None => problem.residual(&state.param)?,
        };
        let j = match self.j_cache.take() {
            Some(j) => j,
            None => problem.jacobian(&state.param)?,
        };
        let g = j.mat_transpose_vec(&r);
        let mut d_sq = state.param.clone();
        let mut c_diag = state.param.clone();
        state.param.compute_cl_scaling(
            &g,
            problem.inner().lower(),
            problem.inner().upper(),
            &mut d_sq,
            &mut c_diag,
        );
        if self.tol_grad > 0.0 && g.cl_kkt_inf_norm(&d_sq) <= self.tol_grad {
            self.r_cache = Some(r);
            self.j_cache = Some(j);
            return Ok((state, Some(TerminationReason::SolverConverged)));
        }
        let mut neg_g = g.clone();
        neg_g.neg_in_place();
        let m = j.gram();
        let mut mu = self
            .mu
            .expect("mu not set: Solver::init must run before next_iter");
        let mut nu = self.nu;

        // Solve (JᵀJ + diag(c + mu·d)) h = -g, raising mu after every failed solve.
        let mut attempts: u32 = 0;
        let h = loop {
            let mut a_damped = m.clone();
            let mut damping_vec = c_diag.clone();
            damping_vec.scaled_add(mu, &d_sq);
            a_damped.add_diagonal_vector_in_place(&damping_vec);
            match a_damped.solve_spd(&neg_g) {
                Ok(step) => break step,
                Err(_) => {
                    attempts += 1;
                    if attempts >= self.max_inner_attempts || !mu.is_finite() {
                        self.mu = Some(mu);
                        self.nu = nu;
                        self.r_cache = Some(r);
                        self.j_cache = Some(j);
                        return Ok((state, Some(TerminationReason::SolverFailed)));
                    }
                    mu *= nu;
                    nu *= 2.0;
                }
            }
        };

        // Fraction-to-boundary rule: cover at most `theta` of the distance to the nearest
        // bound. Since alpha <= theta * tau_max < tau_max, the trial point never lands on a
        // bound, even when tau_max is exactly 1. Unbounded directions (tau_max = ∞) take the
        // full step.
        let tau_max =
            state
                .param
                .max_feasible_step(&h, problem.inner().lower(), problem.inner().upper());
        let alpha = (self.theta * tau_max).min(1.0);
        let mut x_trial = state.param.clone();
        x_trial.scaled_add(alpha, &h);
        let r_trial = problem.residual(&x_trial)?;
        let f_trial = 0.5 * r_trial.norm_squared();
        let prev_cost = state
            .cost
            .expect("cost not set: Solver::init must run before next_iter");

        // Predicted reduction of the model q(s) = gᵀs + ½ sᵀ(JᵀJ + diag(c)) s at s = αh.
        // Substituting the damped system hᵀ(JᵀJ + diag(c))h = -hᵀg - mu·hᵀdiag(d)h gives the
        // closed form below. This assumes solve_spd is exact and
        // weighted_norm_squared(w) = Σ wᵢhᵢ².
        let h_t_g = h.dot(&g);
        let dh_norm_sq = h.weighted_norm_squared(&d_sq);
        let predicted =
            -alpha * (1.0 - 0.5 * alpha) * h_t_g + 0.5 * alpha * alpha * mu * dh_norm_sq;
        // The actual reduction subtracts ½ sᵀ diag(c) s so it is comparable with the model.
        let half_s_t_c_s = 0.5 * alpha * alpha * h.weighted_norm_squared(&c_diag);
        let actual = prev_cost - f_trial - half_s_t_c_s;
        // A non-positive (or NaN) prediction cannot justify a step: treat it as rejected.
        let rho = if predicted > 0.0 {
            actual / predicted
        } else {
            0.0
        };
        if rho > 0.0 {
            state.param = x_trial;
            state.cost = Some(f_trial);
            let factor = 1.0 - (2.0 * rho - 1.0).powi(3);
            mu *= factor.max(1.0 / 3.0);
            nu = 2.0;
            self.r_cache = Some(r_trial);
            // The Jacobian must be re-evaluated at the new iterate.
            self.j_cache = None;
        } else {
            mu *= nu;
            nu *= 2.0;
            self.r_cache = Some(r);
            self.j_cache = Some(j);
        }
        self.mu = Some(mu);
        self.nu = nu;
        if !mu.is_finite() {
            // Consecutive rejections have overflowed the damping (or produced 0·∞ = NaN).
            // Every further damped system would contain inf/NaN entries, so no progress is
            // possible.
            return Ok((state, Some(TerminationReason::SolverFailed)));
        }
        Ok((state, None))
    }
}