cartan_optim/rcg.rs
// ~/cartan/cartan-optim/src/rcg.rs

//! Riemannian Conjugate Gradient (Fletcher-Reeves and Polak-Ribière+).
//!
//! ## Algorithm
//!
//! At iterate x_k with gradient g_k and conjugate direction p_k:
//! 1. Armijo line search: find a step size t_k.
//! 2. x_{k+1} = Retract(x_k, t_k · p_k)
//! 3. g_{k+1} = rgrad(x_{k+1})
//! 4. Transport p_k and g_k from x_k to x_{k+1} via parallel transport.
//! 5. Compute β:
//!    - FR:  β = ||g_{k+1}||² / ||g_k||²
//!    - PR+: β = max(0, <g_{k+1}, g_{k+1} − PT(g_k)>_{x_{k+1}} / ||g_k||²)
//! 6. p_{k+1} = −g_{k+1} + β · PT(p_k)
//!
//! ## References
//!
//! - Absil, Mahony, Sepulchre. "Optimization Algorithms on Matrix Manifolds."
//!   Princeton University Press, 2008. Chapter 8 (Riemannian CG).
//! - Sato. "Riemannian Conjugate Gradient Methods: General Framework and
//!   Specific Algorithms with Convergence Analyses." SIAM J. Optim., 2022.

use cartan_core::{Manifold, ParallelTransport, Real, Retraction};

use crate::result::OptResult;

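/// Square root usable with and without `std`: `f64::sqrt` lives in `std`, so
/// no_std builds fall back to `libm::sqrt`. (Factored out of the two
/// previously duplicated cfg blocks below; assumes `Real` is `f64`, which the
/// original `libm::sqrt` calls already required.)
#[inline]
fn sqrt(x: Real) -> Real {
    #[cfg(feature = "std")]
    {
        x.sqrt()
    }
    #[cfg(not(feature = "std"))]
    {
        libm::sqrt(x)
    }
}
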
/// Which β formula to use for the conjugate direction update.
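///
/// PR+ clamps β at zero, which acts as a built-in restart whenever successive
/// gradients stop being conjugate; FR skips one parallel transport per
/// iteration (see `minimize_rcg` below) but is generally less robust under
/// inexact line search.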
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CgVariant {
    /// Fletcher-Reeves: β = ||g_{k+1}||² / ||g_k||²
    FletcherReeves,
    /// Polak-Ribière+ (clamped): β = max(0, <g_{k+1}, g_{k+1} − PT(g_k)> / ||g_k||²)
    #[default]
    PolakRibiere,
}

/// Configuration for Riemannian Conjugate Gradient.
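///
/// Individual fields can be overridden with struct-update syntax. A minimal
/// sketch (marked `ignore`; it assumes this module is reachable as
/// `cartan_optim::rcg`):
///
/// ```ignore
/// use cartan_optim::rcg::{CgVariant, RCGConfig};
///
/// let cfg = RCGConfig {
///     grad_tol: 1e-9,
///     variant: CgVariant::FletcherReeves,
///     ..Default::default()
/// };
/// ```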
#[derive(Debug, Clone)]
pub struct RCGConfig {
    /// Maximum number of iterations.
    pub max_iters: usize,
    /// Stop when ||grad f(x)|| < grad_tol.
    pub grad_tol: Real,
    /// Initial step size for Armijo backtracking.
    pub init_step: Real,
    /// Armijo sufficient decrease constant.
    pub armijo_c: Real,
    /// Backtracking factor (< 1).
    pub armijo_beta: Real,
    /// Maximum Armijo backtracking steps per iteration.
    pub max_ls_iters: usize,
    /// Fletcher-Reeves or Polak-Ribière+.
    pub variant: CgVariant,
    /// Restart to steepest descent every N iterations (0 = never force a restart).
    ///
    /// An automatic restart still occurs when the conjugate direction is not
    /// a descent direction.
    pub restart_every: usize,
}

impl Default for RCGConfig {
    fn default() -> Self {
        Self {
            max_iters: 1000,
            grad_tol: 1e-6,
            init_step: 1.0,
            armijo_c: 1e-4,
            armijo_beta: 0.5,
            max_ls_iters: 50,
            variant: CgVariant::PolakRibiere,
            restart_every: 0,
        }
    }
}

/// Run Riemannian Conjugate Gradient.
///
/// # Arguments
///
/// - `manifold`: Must implement `Retraction` and `ParallelTransport`.
/// - `cost`: Cost function f: M → R.
/// - `rgrad`: Riemannian gradient (already projected onto T_x M).
/// - `x0`: Initial point.
/// - `config`: Solver parameters.
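///
/// # Example
///
/// A minimal sketch, marked `ignore` because it relies on placeholders:
/// `UnitSphere` stands in for any manifold implementing the required traits,
/// and `cost_fn` / `grad_fn` / `x0` for the problem-specific inputs.
///
/// ```ignore
/// let manifold = UnitSphere::new(8);
/// let result = minimize_rcg(&manifold, cost_fn, grad_fn, x0, &RCGConfig::default());
/// if result.converged {
///     println!("f* = {} after {} iterations", result.value, result.iterations);
/// }
/// ```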
pub fn minimize_rcg<M, F, G>(
    manifold: &M,
    cost: F,
    rgrad: G,
    x0: M::Point,
    config: &RCGConfig,
) -> OptResult<M::Point>
where
    M: Manifold + Retraction + ParallelTransport,
    F: Fn(&M::Point) -> Real,
    G: Fn(&M::Point) -> M::Tangent,
{
    let mut x = x0;
    let mut f_x = cost(&x);
    let mut g = rgrad(&x);
    let mut g_sq = manifold.inner(&x, &g, &g);
    let mut g_norm = sqrt(g_sq);

    // Initial direction: steepest descent.
    let mut p = -g.clone();

    for iter in 0..config.max_iters {
        if g_norm < config.grad_tol {
            return OptResult {
                point: x,
                value: f_x,
                grad_norm: g_norm,
                iterations: iter,
                converged: true,
            };
        }

        // Ensure p is a descent direction; if not, restart.
        if manifold.inner(&x, &g, &p) >= 0.0 {
            p = -g.clone();
        }
        let slope = manifold.inner(&x, &g, &p);

        // Armijo backtracking line search.
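        // Sufficient decrease condition:
        //   f(Retract(x, t·p)) <= f(x) + c · t · <g, p>_x,  with slope <g, p>_x < 0.
        // If no trial step satisfies it within max_ls_iters, the last
        // (smallest) trial step is accepted as-is.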
        let mut t = config.init_step;
        let mut x_new = manifold.retract(&x, &(p.clone() * t));
        let mut f_new = cost(&x_new);
        for _ in 0..config.max_ls_iters {
            if f_new <= f_x + config.armijo_c * t * slope {
                break;
            }
            t *= config.armijo_beta;
            x_new = manifold.retract(&x, &(p.clone() * t));
            f_new = cost(&x_new);
        }

        // Capture state before stepping.
        let x_prev = x.clone();
        let g_prev = g.clone();
        let g_sq_prev = g_sq;
        let p_prev = p.clone();

        // Accept the step.
        x = x_new;
        f_x = f_new;
        g = rgrad(&x);
        g_sq = manifold.inner(&x, &g, &g);
        g_norm = sqrt(g_sq);

        // Forced restart check.
        let force_restart = config.restart_every > 0 && (iter + 1) % config.restart_every == 0;

        // Guard against dividing by a vanishing ||g_k||².
        let beta = if force_restart || g_sq_prev < 1e-30 {
            0.0
        } else {
            match config.variant {
                CgVariant::FletcherReeves => g_sq / g_sq_prev,
                CgVariant::PolakRibiere => {
                    // Transport g_prev from x_prev to x, then compute PR+ β.
                    // If transport fails, substitute g itself: the difference
                    // vanishes, β = 0, and the update degenerates to a
                    // steepest-descent restart.
                    let g_pt = manifold
                        .transport(&x_prev, &x, &g_prev)
                        .unwrap_or_else(|_| g.clone());
                    let diff = g.clone() - g_pt; // g_{k+1} − PT(g_k)
                    let num = manifold.inner(&x, &g, &diff);
                    (num / g_sq_prev).max(0.0)
                }
            }
        };

        // Transport p_prev from x_prev to x and form the new direction.
        // When β = 0 (restart), the transported term is multiplied away, so
        // skip the transport and use the zero tangent directly.
        let p_pt = if beta.abs() < 1e-30 {
            manifold.zero_tangent(&x)
        } else {
            manifold
                .transport(&x_prev, &x, &p_prev)
                .unwrap_or_else(|_| manifold.zero_tangent(&x))
        };

        p = -g.clone() + p_pt * beta;
    }

    // Iteration budget exhausted without meeting the gradient tolerance.
    OptResult {
        point: x,
        value: f_x,
        grad_norm: g_norm,
        iterations: config.max_iters,
        converged: false,
    }
}