Skip to main content

cartan_optim/
rtr.rs

1// ~/cartan/cartan-optim/src/rtr.rs
2
3//! Riemannian Trust Region (RTR) with Steihaug-Toint truncated CG subproblem.
4//!
5//! ## Algorithm (Absil-Baker-Gallivan 2007)
6//!
7//! At iterate x_k with gradient g_k = grad f(x_k) and trust radius Δ_k:
8//!
9//!   1. **Subproblem**: solve (approximately) for η_k ∈ T_{x_k}M:
10//!      min m_k(η) = f(x_k) + <g_k, η> + ½ <Hess f(x_k)\[η\], η>
11//!      subject to ||η||_{x_k} ≤ Δ_k
12//!      via Steihaug-Toint truncated CG.
13//!
14//!   2. **Ratio**: ρ_k = (f(x_k) − f(Retract(x_k, η_k))) / (m_k(0) − m_k(η_k))
15//!
16//!   3. **Accept/reject**:
17//!      if ρ_k > ρ_min: x_{k+1} = Retract(x_k, η_k)
18//!      else:           x_{k+1} = x_k
19//!
20//!   4. **Update Δ**:
21//!      ρ < 0.25 → Δ_{k+1} = Δ_k / 4
22//!      ρ > 0.75 and ||η_k|| ≈ Δ_k → Δ_{k+1} = min(2Δ_k, Δ_max)
23//!      else → Δ_{k+1} = Δ_k
24//!
25//! ## References
26//!
27//! - Absil, Baker, Gallivan. "Trust-Region Methods on Riemannian Manifolds."
28//!   Found. Comput. Math. 2007.
29//! - Absil, Mahony, Sepulchre. "Optimization Algorithms on Matrix Manifolds."
30//!   Chapter 7 (trust region).
31
32use cartan_core::{Connection, Manifold, Real, Retraction};
33
34use crate::result::OptResult;
35
36/// Configuration for Riemannian Trust Region.
37#[derive(Debug, Clone)]
38pub struct RTRConfig {
39    /// Maximum number of outer iterations.
40    pub max_iters: usize,
41    /// Stop when ||grad f(x)|| < grad_tol.
42    pub grad_tol: Real,
43    /// Initial trust radius.
44    pub delta_init: Real,
45    /// Maximum trust radius.
46    pub delta_max: Real,
47    /// Minimum acceptable ratio ρ for step acceptance.
48    pub rho_min: Real,
49    /// Maximum number of inner CG iterations per outer step.
50    pub max_cg_iters: usize,
51    /// CG convergence tolerance (relative to ||g||).
52    pub cg_tol: Real,
53}
54
55impl Default for RTRConfig {
56    fn default() -> Self {
57        Self {
58            max_iters: 500,
59            grad_tol: 1e-6,
60            delta_init: 1.0,
61            delta_max: 8.0,
62            rho_min: 0.1,
63            max_cg_iters: 50,
64            cg_tol: 0.1,
65        }
66    }
67}
68
69/// Solve the trust-region subproblem via Steihaug-Toint truncated CG.
70///
71/// Minimize m(η) = <g, η> + ½ <H[η], η>  s.t. ||η||_M ≤ Δ
72///
73/// Returns the step η and whether the boundary was hit.
74fn solve_trs<M>(
75    manifold: &M,
76    x: &M::Point,
77    g: &M::Tangent,
78    hess: &dyn Fn(&M::Tangent) -> M::Tangent,
79    delta: Real,
80    max_cg: usize,
81    cg_tol: Real,
82) -> M::Tangent
83where
84    M: Manifold,
85{
86    let g_norm = manifold.norm(x, g);
87    let tol = cg_tol * g_norm;
88
89    // η_0 = 0, r_0 = g, p_0 = -r_0 = -g
90    let mut eta = manifold.zero_tangent(x);
91    let mut r = g.clone();
92    let mut p = -g.clone();
93
94    for _ in 0..max_cg {
95        // κ = <p, H[p]>
96        let hp = hess(&p);
97        let kappa = manifold.inner(x, &p, &hp);
98
99        // Negative curvature: move to boundary in direction p.
100        if kappa <= 0.0 {
101            return boundary_step(manifold, x, &eta, &p, delta);
102        }
103
104        let r_sq = manifold.inner(x, &r, &r);
105        let alpha = r_sq / kappa;
106
107        // Proposed step: η_new = η + α p
108        let eta_new = eta.clone() + p.clone() * alpha;
109
110        // Boundary hit: step exceeds trust radius.
111        if manifold.norm(x, &eta_new) >= delta {
112            return boundary_step(manifold, x, &eta, &p, delta);
113        }
114
115        eta = eta_new;
116        let r_new = r.clone() + hp * alpha;
117
118        if manifold.norm(x, &r_new) < tol {
119            return eta;
120        }
121
122        let r_sq_new = manifold.inner(x, &r_new, &r_new);
123        let beta = r_sq_new / r_sq;
124        p = -r_new.clone() + p * beta;
125        r = r_new;
126    }
127
128    eta
129}
130
131/// Find τ ≥ 0 such that ||η + τ p|| = Δ, then return η + τ p.
132///
133/// This is the boundary intercept: solve ||η||² + 2τ<η,p> + τ²||p||² = Δ².
134fn boundary_step<M>(
135    manifold: &M,
136    x: &M::Point,
137    eta: &M::Tangent,
138    p: &M::Tangent,
139    delta: Real,
140) -> M::Tangent
141where
142    M: Manifold,
143{
144    let eta_sq = manifold.inner(x, eta, eta);
145    let ep = manifold.inner(x, eta, p);
146    let p_sq = manifold.inner(x, p, p);
147
148    if p_sq < 1e-30 {
149        return eta.clone(); // p ≈ 0, can't move to boundary
150    }
151
152    // Solve: p_sq τ² + 2 ep τ + (eta_sq − Δ²) = 0
153    let discriminant = ep * ep - p_sq * (eta_sq - delta * delta);
154    if discriminant < 0.0 {
155        return eta.clone();
156    }
157    let sqrt_disc = {
158        #[cfg(feature = "std")]
159        {
160            discriminant.sqrt()
161        }
162        #[cfg(not(feature = "std"))]
163        {
164            libm::sqrt(discriminant)
165        }
166    };
167    let tau = (-ep + sqrt_disc) / p_sq;
168    eta.clone() + p.clone() * tau
169}
170
171/// Run Riemannian Trust Region.
172///
173/// # Arguments
174///
175/// - `manifold`: Must implement `Retraction` and `Connection`.
176/// - `cost`: Cost function.
177/// - `rgrad`: Riemannian gradient.
178/// - `ehvp`: Euclidean Hessian-vector product. Given a tangent direction `v`,
179///   returns the ambient Euclidean HVP `D²f(x)[v]` at the current `x`.
180///   The `Connection` impl converts this to the Riemannian HVP.
181/// - `x0`: Initial point.
182/// - `config`: Solver parameters.
183pub fn minimize_rtr<M, F, G, H>(
184    manifold: &M,
185    cost: F,
186    rgrad: G,
187    ehvp: H,
188    x0: M::Point,
189    config: &RTRConfig,
190) -> OptResult<M::Point>
191where
192    M: Manifold + Retraction + Connection,
193    F: Fn(&M::Point) -> Real,
194    G: Fn(&M::Point) -> M::Tangent,
195    H: Fn(&M::Point, &M::Tangent) -> M::Tangent,
196{
197    let mut x = x0;
198    let mut f_x = cost(&x);
199    let mut g = rgrad(&x);
200    let mut g_norm = manifold.norm(&x, &g);
201    let mut delta = config.delta_init;
202
203    for iter in 0..config.max_iters {
204        if g_norm < config.grad_tol {
205            return OptResult {
206                point: x,
207                value: f_x,
208                grad_norm: g_norm,
209                iterations: iter,
210                converged: true,
211            };
212        }
213
214        // Build Riemannian HVP: H_riem[v] = Connection::riemannian_hessian_vector_product(...)
215        let hess_riem = |v: &M::Tangent| -> M::Tangent {
216            manifold
217                .riemannian_hessian_vector_product(&x, &g, v, &|w| ehvp(&x, w))
218                .unwrap_or_else(|_| manifold.zero_tangent(&x))
219        };
220
221        // Solve the trust-region subproblem.
222        let eta = solve_trs(
223            manifold,
224            &x,
225            &g,
226            &hess_riem,
227            delta,
228            config.max_cg_iters,
229            config.cg_tol,
230        );
231
232        // Model decrease: m(0) - m(eta) = -<g, eta> - ½<H[eta], eta>
233        let h_eta = hess_riem(&eta);
234        let model_decrease = -manifold.inner(&x, &g, &eta) - 0.5 * manifold.inner(&x, &h_eta, &eta);
235
236        // Actual decrease: f(x) - f(Retract(x, eta))
237        let x_new = manifold.retract(&x, &eta);
238        let f_new = cost(&x_new);
239        let actual_decrease = f_x - f_new;
240
241        // Ratio ρ = actual / model.
242        let rho = if model_decrease.abs() < 1e-30 {
243            1.0 // model predicts ~0, accept any step
244        } else {
245            actual_decrease / model_decrease
246        };
247
248        // Accept or reject.
249        if rho > config.rho_min {
250            x = x_new;
251            f_x = f_new;
252            g = rgrad(&x);
253            g_norm = manifold.norm(&x, &g);
254        }
255
256        // Update trust radius.
257        let eta_norm = manifold.norm(&x, &eta);
258        if rho < 0.25 {
259            delta *= 0.25;
260        } else if rho > 0.75 && (delta - eta_norm).abs() < 1e-10 * delta {
261            delta = (2.0 * delta).min(config.delta_max);
262        }
263    }
264
265    OptResult {
266        point: x,
267        value: f_x,
268        grad_norm: g_norm,
269        iterations: config.max_iters,
270        converged: false,
271    }
272}