Skip to main content

cartan_optim/
rgd.rs

// ~/cartan/cartan-optim/src/rgd.rs

//! Riemannian Gradient Descent with Armijo line search.
//!
//! ## Algorithm
//!
//! Given iterate x_k, Riemannian gradient g_k = grad f(x_k):
//!   1. d_k = -g_k  (steepest descent)
//!   2. Find step t_k via Armijo backtracking:
//!      f(Retract(x_k, t·d_k)) ≤ f(x_k) + c · t · <g_k, d_k>_x_k
//!   3. x_{k+1} = Retract(x_k, t_k · d_k)
//!
//! ## Convergence
//!
//! For smooth f on a complete Riemannian manifold with Armijo line search,
//! the iterates satisfy lim inf ||grad f(x_k)|| = 0. For geodesically convex f,
//! the method converges to the global minimum.
//!
//! ## References
//!
//! - Absil, Mahony, Sepulchre. "Optimization Algorithms on Matrix Manifolds."
//!   Chapter 4 (line search methods).
//! - Boumal. "An Introduction to Optimization on Smooth Manifolds."
//!   Chapter 4 (gradient descent).
26use cartan_core::{Manifold, Real, Retraction};
27
28use crate::result::OptResult;
29
30/// Configuration for Riemannian Gradient Descent.
31#[derive(Debug, Clone)]
32pub struct RGDConfig {
33    /// Maximum number of iterations.
34    pub max_iters: usize,
35    /// Stop when ||grad f(x)|| < grad_tol.
36    pub grad_tol: Real,
37    /// Initial step size for Armijo backtracking.
38    pub init_step: Real,
39    /// Armijo sufficient decrease constant (typically 1e-4 to 0.5).
40    pub armijo_c: Real,
41    /// Backtracking factor (< 1, typically 0.5).
42    pub armijo_beta: Real,
43    /// Maximum number of backtracking steps per iteration.
44    pub max_ls_iters: usize,
45}
46
47impl Default for RGDConfig {
48    fn default() -> Self {
49        Self {
50            max_iters: 1000,
51            grad_tol: 1e-6,
52            init_step: 1.0,
53            armijo_c: 1e-4,
54            armijo_beta: 0.5,
55            max_ls_iters: 50,
56        }
57    }
58}
59
60/// Run Riemannian Gradient Descent.
61///
62/// # Arguments
63///
64/// - `manifold`: The manifold (must implement `Retraction`).
65/// - `cost`: Cost function f: M → R.
66/// - `rgrad`: Riemannian gradient of f at x (already projected onto T_x M).
67///   Typically: `|x| manifold.project_tangent(x, &euclidean_grad(x))`.
68/// - `x0`: Initial point.
69/// - `config`: Solver parameters.
70///
71/// # Returns
72///
73/// [`OptResult`] with the final point, value, gradient norm, and convergence flag.
74pub fn minimize_rgd<M, F, G>(
75    manifold: &M,
76    cost: F,
77    rgrad: G,
78    x0: M::Point,
79    config: &RGDConfig,
80) -> OptResult<M::Point>
81where
82    M: Manifold + Retraction,
83    F: Fn(&M::Point) -> Real,
84    G: Fn(&M::Point) -> M::Tangent,
85{
86    let mut x = x0;
87    let mut f_x = cost(&x);
88    let mut g = rgrad(&x);
89    let mut g_norm = manifold.norm(&x, &g);
90
91    for iter in 0..config.max_iters {
92        if g_norm < config.grad_tol {
93            return OptResult {
94                point: x,
95                value: f_x,
96                grad_norm: g_norm,
97                iterations: iter,
98                converged: true,
99            };
100        }
101
102        // Steepest descent direction: d = -g.
103        let d = -g.clone();
104
105        // Armijo backtracking line search.
106        // Sufficient decrease: f(Retract(x, t·d)) ≤ f(x) + c·t·<g, d>_x
107        // Since d = -g: <g, d> = -||g||² < 0 (descent direction). ✓
108        let slope = manifold.inner(&x, &g, &d); // = -||g||²
109        let mut t = config.init_step;
110        let f_threshold_base = f_x + config.armijo_c * slope; // at t=1
111
112        for _ in 0..config.max_ls_iters {
113            let x_trial = manifold.retract(&x, &(d.clone() * t));
114            let f_trial = cost(&x_trial);
115            if f_trial <= f_x + config.armijo_c * t * slope {
116                x = x_trial;
117                f_x = f_trial;
118                break;
119            }
120            t *= config.armijo_beta;
121        }
122
123        // Suppress unused variable warning: f_threshold_base is for documentation.
124        let _ = f_threshold_base;
125
126        g = rgrad(&x);
127        g_norm = manifold.norm(&x, &g);
128    }
129
130    OptResult {
131        point: x,
132        value: f_x,
133        grad_norm: g_norm,
134        iterations: config.max_iters,
135        converged: false,
136    }
137}