use cartan_core::{Connection, Manifold, Real, Retraction};
use crate::result::OptResult;
/// Tunable parameters for [`minimize_rtr`].
#[derive(Debug, Clone)]
pub struct RTRConfig {
    /// Upper bound on the number of outer trust-region iterations.
    pub max_iters: usize,
    /// The solver reports convergence once the gradient norm drops below this.
    pub grad_tol: Real,
    /// Trust-region radius used for the first iteration.
    pub delta_init: Real,
    /// Cap applied whenever the trust-region radius is enlarged.
    pub delta_max: Real,
    /// A trial step is accepted only when its actual-to-model decrease ratio
    /// is strictly greater than this value.
    pub rho_min: Real,
    /// Upper bound on inner conjugate-gradient iterations per subproblem.
    pub max_cg_iters: usize,
    /// Relative residual tolerance for the inner solver; it is multiplied by
    /// the gradient norm to obtain the absolute stopping threshold.
    pub cg_tol: Real,
}
impl Default for RTRConfig {
fn default() -> Self {
Self {
max_iters: 500,
grad_tol: 1e-6,
delta_init: 1.0,
delta_max: 8.0,
rho_min: 0.1,
max_cg_iters: 50,
cg_tol: 0.1,
}
}
}
/// Approximately solves the trust-region subproblem at `x` with a truncated
/// conjugate-gradient iteration.
///
/// * `g` - gradient at `x`; the search starts along `-g`.
/// * `hess` - Hessian-vector product acting on tangents at `x`.
/// * `delta` - trust-region radius.
/// * `max_cg` - maximum number of CG iterations.
/// * `cg_tol` - relative tolerance; the loop stops once the residual norm is
///   below `cg_tol * ||g||`.
///
/// Returns the current iterate when the residual test passes or the iteration
/// budget runs out. If non-positive curvature is met, or the next iterate
/// would reach or leave the region, the result of `boundary_step` from the
/// last interior iterate along the current direction is returned instead.
fn solve_trs<M>(
    manifold: &M,
    x: &M::Point,
    g: &M::Tangent,
    hess: &dyn Fn(&M::Tangent) -> M::Tangent,
    delta: Real,
    max_cg: usize,
    cg_tol: Real,
) -> M::Tangent
where
    M: Manifold,
{
    // Absolute residual threshold derived from the initial gradient norm.
    let residual_target = cg_tol * manifold.norm(x, g);

    let mut step = manifold.zero_tangent(x);
    let mut residual = g.clone();
    let mut direction = -g.clone();

    let mut sweeps_left = max_cg;
    while sweeps_left > 0 {
        sweeps_left -= 1;

        let h_dir = hess(&direction);
        let curvature = manifold.inner(x, &direction, &h_dir);
        if curvature <= 0.0 {
            // The model does not curve upward along this direction, so follow
            // it all the way to the trust-region boundary.
            return boundary_step(manifold, x, &step, &direction, delta);
        }

        let residual_sq = manifold.inner(x, &residual, &residual);
        let step_len = residual_sq / curvature;

        let candidate = step.clone() + direction.clone() * step_len;
        if manifold.norm(x, &candidate) >= delta {
            // The full CG step would leave the region; stop on the boundary.
            return boundary_step(manifold, x, &step, &direction, delta);
        }
        step = candidate;

        let next_residual = residual.clone() + h_dir * step_len;
        if manifold.norm(x, &next_residual) < residual_target {
            return step;
        }

        // Standard CG direction update using the ratio of squared residuals.
        let momentum = manifold.inner(x, &next_residual, &next_residual) / residual_sq;
        direction = -next_residual.clone() + direction * momentum;
        residual = next_residual;
    }

    step
}
/// Moves from `eta` along `p` until the trust-region boundary is reached.
///
/// Solves `||eta + tau * p||^2 = delta^2` for `tau`, takes the larger root,
/// and returns `eta + tau * p`. All inner products are taken at `x`.
///
/// `eta` is returned unchanged when `<p, p>` is below `1e-30` or when the
/// quadratic has a negative discriminant.
fn boundary_step<M>(
    manifold: &M,
    x: &M::Point,
    eta: &M::Tangent,
    p: &M::Tangent,
    delta: Real,
) -> M::Tangent
where
    M: Manifold,
{
    let eta_sq = manifold.inner(x, eta, eta);
    // Coefficients of `quad_a * tau^2 + 2 * half_b * tau + quad_c = 0`.
    let half_b = manifold.inner(x, eta, p);
    let quad_a = manifold.inner(x, p, p);
    let quad_c = eta_sq - delta * delta;

    let disc = half_b * half_b - quad_a * quad_c;

    // A vanishing direction or a quadratic with no real root: nothing to add.
    if quad_a < 1e-30 || disc < 0.0 {
        return eta.clone();
    }

    // `sqrt` comes from std when available and from `libm` otherwise.
    #[cfg(feature = "std")]
    let root = disc.sqrt();
    #[cfg(not(feature = "std"))]
    let root = libm::sqrt(disc);

    let tau = (-half_b + root) / quad_a;
    eta.clone() + p.clone() * tau
}
/// Minimizes `cost` over `manifold` with a Riemannian trust-region method.
///
/// Each outer iteration:
/// 1. stops with `converged: true` if the gradient norm is below
///    `config.grad_tol`;
/// 2. asks `solve_trs` for a step `eta` within the current radius, using a
///    Riemannian Hessian-vector product built from `ehvp`;
/// 3. retracts along `eta` and forms `rho`, the ratio of actual to model
///    decrease;
/// 4. accepts the trial point when `rho > config.rho_min`;
/// 5. shrinks the radius by a factor of 4 when `rho < 0.25` (or `rho` is NaN),
///    and doubles it (capped at `config.delta_max`) when `rho > 0.75` and the
///    step lies on the trust-region boundary.
///
/// # Parameters
/// * `cost` - objective evaluated at a point.
/// * `rgrad` - gradient at a point, as a tangent vector.
/// * `ehvp` - Euclidean Hessian-vector product at a point.
/// * `x0` - starting point.
/// * `config` - iteration limits, tolerances and radius settings.
///
/// # Returns
/// An [`OptResult`] holding the final point, its cost, its gradient norm, the
/// number of outer iterations performed, and whether the gradient tolerance
/// was met.
pub fn minimize_rtr<M, F, G, H>(
    manifold: &M,
    cost: F,
    rgrad: G,
    ehvp: H,
    x0: M::Point,
    config: &RTRConfig,
) -> OptResult<M::Point>
where
    M: Manifold + Retraction + Connection,
    F: Fn(&M::Point) -> Real,
    G: Fn(&M::Point) -> M::Tangent,
    H: Fn(&M::Point, &M::Tangent) -> M::Tangent,
{
    let mut x = x0;
    let mut f_x = cost(&x);
    let mut g = rgrad(&x);
    let mut g_norm = manifold.norm(&x, &g);
    let mut delta = config.delta_init;

    for iter in 0..config.max_iters {
        if g_norm < config.grad_tol {
            return OptResult {
                point: x,
                value: f_x,
                grad_norm: g_norm,
                iterations: iter,
                converged: true,
            };
        }

        // Riemannian Hessian-vector product at the current point. A failed
        // product deliberately falls back to the zero tangent.
        let hess_riem = |v: &M::Tangent| -> M::Tangent {
            manifold
                .riemannian_hessian_vector_product(&x, &g, v, &|w| ehvp(&x, w))
                .unwrap_or_else(|_| manifold.zero_tangent(&x))
        };

        let eta = solve_trs(
            manifold,
            &x,
            &g,
            &hess_riem,
            delta,
            config.max_cg_iters,
            config.cg_tol,
        );

        // `eta` is a tangent at the current `x`, so its norm must be taken
        // here, before `x` is possibly replaced by the trial point below.
        let eta_norm = manifold.norm(&x, &eta);

        // Decrease predicted by the quadratic model: -<g, eta> - 0.5 <H eta, eta>.
        let h_eta = hess_riem(&eta);
        let model_decrease =
            -manifold.inner(&x, &g, &eta) - 0.5 * manifold.inner(&x, &h_eta, &eta);

        let x_new = manifold.retract(&x, &eta);
        let f_new = cost(&x_new);
        let actual_decrease = f_x - f_new;

        // A vanishing model decrease would make the ratio meaningless; treat
        // such a step as fully successful.
        let rho = if model_decrease.abs() < 1e-30 {
            1.0
        } else {
            actual_decrease / model_decrease
        };

        if rho > config.rho_min {
            x = x_new;
            f_x = f_new;
            g = rgrad(&x);
            g_norm = manifold.norm(&x, &g);
        }

        // A NaN ratio (e.g. the cost is undefined at the trial point) counts
        // as a poor step: without shrinking, the same failing step would be
        // recomputed on every remaining iteration.
        if rho.is_nan() || rho < 0.25 {
            delta *= 0.25;
        } else if rho > 0.75 && (delta - eta_norm).abs() < 1e-10 * delta {
            delta = (2.0 * delta).min(config.delta_max);
        }
    }

    // The tolerance is only tested at the top of each iteration, so a step
    // accepted on the final iteration may already satisfy it.
    OptResult {
        point: x,
        value: f_x,
        grad_norm: g_norm,
        iterations: config.max_iters,
        converged: g_norm < config.grad_tol,
    }
}