use super::{Step, Subproblem, model_decrease, tau_to_boundary};
use crate::core::math::{
Dot, MatVec, NegInPlace, NormSquared, Scalar, ScaleInPlace, ScaledAdd, VectorLen,
};
/// Configuration for the Steihaug-style truncated conjugate-gradient
/// subproblem solver implemented in this file.
///
/// The only setting is an optional cap on the number of conjugate-gradient
/// iterations. Values compare equal exactly when their caps are equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Steihaug {
    /// Upper bound on the number of iterations of the solver loop.
    /// `None` lets `solve` pick `max(g.vec_len(), 1)`.
    max_iter: Option<usize>,
}
impl Steihaug {
    /// Creates a solver with no explicit iteration cap.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self { max_iter: None }
    }

    /// Returns this configuration with the iteration cap set to `n`.
    ///
    /// The receiver is taken by value and `Steihaug` is `Copy`, so the
    /// original binding is left untouched; only the returned value carries
    /// the new cap.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero.
    #[must_use = "this returns the updated configuration and does not modify the original"]
    pub fn with_max_iter(mut self, n: usize) -> Self {
        assert!(n >= 1, "max_iter must be ≥ 1");
        self.max_iter = Some(n);
        self
    }
}
impl Default for Steihaug {
fn default() -> Self {
Self::new()
}
}
impl<V, M, F> Subproblem<V, M, F> for Steihaug
where
    F: Scalar,
    V: Clone + Dot<F> + NormSquared<F> + ScaledAdd<F> + ScaleInPlace<F> + NegInPlace + VectorLen,
    M: MatVec<V>,
{
    /// Runs conjugate-gradient iterations on the model given by `g` and `b`,
    /// cutting them short at `radius`.
    ///
    /// The iterate starts at zero (a clone of `g` scaled by `F::zero()`) and
    /// the search direction starts at `-g`. The method returns as soon as one
    /// of these conditions holds:
    ///
    /// * the initial residual norm `‖g‖` is already `<=` the tolerance
    ///   `min(0.5, sqrt(‖g‖)) * ‖g‖`: the zero step is returned with
    ///   `hit_boundary == false`;
    /// * `d · (b d) <= 0` for the current direction `d`, or the next
    ///   conjugate-gradient iterate would have norm `>= radius`: the current
    ///   iterate is advanced along `d` by the length `tau_to_boundary`
    ///   returns, and `hit_boundary == true`;
    /// * the residual norm falls strictly below the tolerance:
    ///   `hit_boundary == false`;
    /// * the iteration cap (`max_iter` if set, otherwise
    ///   `max(g.vec_len(), 1)`) is used up: `hit_boundary == false`.
    ///
    /// In every case `predicted_reduction` is `model_decrease(g, b, &d)` for
    /// the returned step `d`.
    ///
    /// # Panics
    ///
    /// Panics if `F::from_f64(0.5)` cannot produce a value.
    fn solve(&self, g: &V, b: &M, radius: F) -> Step<V, F> {
        let iteration_cap = self.max_iter.unwrap_or(g.vec_len().max(1));

        // Common tail of every exit: evaluate the model at the final step
        // and package the result.
        let finish = |step: V, hit_boundary: bool| -> Step<V, F> {
            let predicted_reduction = model_decrease(g, b, &step);
            Step {
                d: step,
                predicted_reduction,
                hit_boundary,
            }
        };

        // Advances `from` along `along` by the length `tau_to_boundary`
        // reports for `radius` (presumably the step to the trust-region
        // boundary — confirm against that helper).
        let push_to_boundary = |mut from: V, along: &V| -> V {
            let tau = tau_to_boundary(&from, along, radius);
            from.scaled_add(tau, along);
            from
        };

        // Zero iterate with the same shape as `g`.
        let mut step = g.clone();
        step.scale_in_place(F::zero());

        // Residual of the model gradient at the current iterate; equals `g`
        // at the zero step.
        let mut residual = g.clone();
        let mut residual_sq = residual.dot(&residual);
        let grad_norm = residual_sq.sqrt();

        // Stopping tolerance: min(0.5, sqrt(‖g‖)) * ‖g‖.
        let half = F::from_f64(0.5).unwrap();
        let root_norm = grad_norm.sqrt();
        let forcing = if root_norm < half { root_norm } else { half };
        let tol = forcing * grad_norm;

        if grad_norm <= tol {
            return finish(step, false);
        }

        let mut direction = residual.clone();
        direction.neg_in_place();

        for _ in 0..iteration_cap {
            let b_direction = b.matvec(&direction);
            let curvature = direction.dot(&b_direction);

            // Non-positive curvature along the direction: stop at the
            // boundary instead of taking a CG step.
            if curvature <= F::zero() {
                return finish(push_to_boundary(step, &direction), true);
            }

            let alpha = residual_sq / curvature;
            let mut trial = step.clone();
            trial.scaled_add(alpha, &direction);

            // The full CG step would reach or leave the ball of size
            // `radius`; discard it and go to the boundary from the current
            // iterate.
            if trial.norm_squared().sqrt() >= radius {
                return finish(push_to_boundary(step, &direction), true);
            }

            step = trial;
            residual.scaled_add(alpha, &b_direction);

            let next_residual_sq = residual.dot(&residual);
            if next_residual_sq.sqrt() < tol {
                return finish(step, false);
            }

            // New direction: -residual + beta * previous direction.
            let beta = next_residual_sq / residual_sq;
            let mut next_direction = residual.clone();
            next_direction.neg_in_place();
            next_direction.scaled_add(beta, &direction);

            direction = next_direction;
            residual_sq = next_residual_sq;
        }

        // Iteration cap exhausted without meeting any other exit condition.
        finish(step, false)
    }
}