// ~/cartan/cartan-optim/src/rcg.rs

//! Riemannian Conjugate Gradient (Fletcher-Reeves and Polak-Ribière).
//!
//! ## Algorithm
//!
//! At iterate x_k with gradient g_k and conjugate direction p_k:
//!   1. Armijo line search: find t_k
//!   2. x_{k+1} = Retract(x_k, t_k · p_k)
//!   3. g_{k+1} = rgrad(x_{k+1})
//!   4. Transport p_k and g_k from x_k to x_{k+1} via parallel transport.
//!   5. Compute β:
//!      FR:  β = ||g_{k+1}||² / ||g_k||²
//!      PR+: β = max(0, <g_{k+1}, g_{k+1} − PT(g_k)>_{x_{k+1}} / ||g_k||²)
//!   6. p_{k+1} = −g_{k+1} + β · PT(p_k)
//!
//! ## References
//!
//! - Absil, Mahony, Sepulchre. "Optimization Algorithms on Matrix Manifolds."
//!   Chapter 8 (Riemannian CG).
//! - Sato. "Riemannian Conjugate Gradient Methods." SIAM J. Optim. 2022.

use cartan_core::{Manifold, ParallelTransport, Real, Retraction};

use crate::result::OptResult;

/// Selects the β formula used when updating the conjugate direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CgVariant {
    /// Fletcher-Reeves: β = ||g_{k+1}||² / ||g_k||²
    FletcherReeves,
    /// Polak-Ribière+ (clamped): β = max(0, <g_{k+1}, g_{k+1}−PT(g_k)> / ||g_k||²)
    ///
    /// This is the default; the non-negativity clamp gives an automatic
    /// restart whenever consecutive gradients are nearly parallel.
    #[default]
    PolakRibiere,
}

37/// Configuration for Riemannian Conjugate Gradient.
38#[derive(Debug, Clone)]
39pub struct RCGConfig {
40    /// Maximum number of iterations.
41    pub max_iters: usize,
42    /// Stop when ||grad f(x)|| < grad_tol.
43    pub grad_tol: Real,
44    /// Initial step size for Armijo backtracking.
45    pub init_step: Real,
46    /// Armijo sufficient decrease constant.
47    pub armijo_c: Real,
48    /// Backtracking factor (< 1).
49    pub armijo_beta: Real,
50    /// Maximum Armijo backtracking steps per iteration.
51    pub max_ls_iters: usize,
52    /// Fletcher-Reeves or Polak-Ribière+.
53    pub variant: CgVariant,
54    /// Restart to steepest descent every N iterations (0 = never force restart).
55    ///
56    /// Automatic restart still occurs when the conjugate direction is not
57    /// a descent direction.
58    pub restart_every: usize,
59}
60
61impl Default for RCGConfig {
62    fn default() -> Self {
63        Self {
64            max_iters: 1000,
65            grad_tol: 1e-6,
66            init_step: 1.0,
67            armijo_c: 1e-4,
68            armijo_beta: 0.5,
69            max_ls_iters: 50,
70            variant: CgVariant::PolakRibiere,
71            restart_every: 0,
72        }
73    }
74}
75
76/// Run Riemannian Conjugate Gradient.
77///
78/// # Arguments
79///
80/// - `manifold`: Must implement `Retraction` and `ParallelTransport`.
81/// - `cost`: Cost function f: M → R.
82/// - `rgrad`: Riemannian gradient (already projected onto T_x M).
83/// - `x0`: Initial point.
84/// - `config`: Solver parameters.
85pub fn minimize_rcg<M, F, G>(
86    manifold: &M,
87    cost: F,
88    rgrad: G,
89    x0: M::Point,
90    config: &RCGConfig,
91) -> OptResult<M::Point>
92where
93    M: Manifold + Retraction + ParallelTransport,
94    F: Fn(&M::Point) -> Real,
95    G: Fn(&M::Point) -> M::Tangent,
96{
97    let mut x = x0;
98    let mut f_x = cost(&x);
99    let mut g = rgrad(&x);
100    let mut g_sq = manifold.inner(&x, &g, &g);
101    let mut g_norm = {
102        #[cfg(feature = "std")]
103        {
104            g_sq.sqrt()
105        }
106        #[cfg(not(feature = "std"))]
107        {
108            libm::sqrt(g_sq)
109        }
110    };
111
112    // Initial direction: steepest descent.
113    let mut p = -g.clone();
114
115    for iter in 0..config.max_iters {
116        if g_norm < config.grad_tol {
117            return OptResult {
118                point: x,
119                value: f_x,
120                grad_norm: g_norm,
121                iterations: iter,
122                converged: true,
123            };
124        }
125
126        // Ensure p is a descent direction; if not, restart.
127        if manifold.inner(&x, &g, &p) >= 0.0 {
128            p = -g.clone();
129        }
130        let slope = manifold.inner(&x, &g, &p);
131
132        // Armijo backtracking line search.
133        let mut t = config.init_step;
134        let mut x_new = manifold.retract(&x, &(p.clone() * t));
135        let mut f_new = cost(&x_new);
136        for _ in 0..config.max_ls_iters {
137            if f_new <= f_x + config.armijo_c * t * slope {
138                break;
139            }
140            t *= config.armijo_beta;
141            x_new = manifold.retract(&x, &(p.clone() * t));
142            f_new = cost(&x_new);
143        }
144
145        // Capture state before stepping.
146        let x_prev = x.clone();
147        let g_prev = g.clone();
148        let g_sq_prev = g_sq;
149        let p_prev = p.clone();
150
151        // Accept step.
152        x = x_new;
153        f_x = f_new;
154        g = rgrad(&x);
155        g_sq = manifold.inner(&x, &g, &g);
156        g_norm = {
157            #[cfg(feature = "std")]
158            {
159                g_sq.sqrt()
160            }
161            #[cfg(not(feature = "std"))]
162            {
163                libm::sqrt(g_sq)
164            }
165        };
166
167        // Forced restart check.
168        let force_restart = config.restart_every > 0 && (iter + 1) % config.restart_every == 0;
169
170        let beta = if force_restart || g_sq_prev < 1e-30 {
171            0.0
172        } else {
173            match config.variant {
174                CgVariant::FletcherReeves => g_sq / g_sq_prev,
175                CgVariant::PolakRibiere => {
176                    // Transport g_prev from x_prev to x, compute PR+ β.
177                    let g_pt = manifold
178                        .transport(&x_prev, &x, &g_prev)
179                        .unwrap_or_else(|_| g.clone());
180                    let diff = g.clone() - g_pt; // g_{k+1} - PT(g_k)
181                    let num = manifold.inner(&x, &g, &diff);
182                    (num / g_sq_prev).max(0.0)
183                }
184            }
185        };
186
187        // Transport p_prev from x_prev to x and form new direction.
188        let p_pt = if beta.abs() < 1e-30 {
189            manifold.zero_tangent(&x)
190        } else {
191            manifold
192                .transport(&x_prev, &x, &p_prev)
193                .unwrap_or_else(|_| manifold.zero_tangent(&x))
194        };
195
196        p = -g.clone() + p_pt * beta;
197    }
198
199    OptResult {
200        point: x,
201        value: f_x,
202        grad_norm: g_norm,
203        iterations: config.max_iters,
204        converged: false,
205    }
206}