1use cartan_core::{Connection, Manifold, Real, Retraction};
33
34use crate::result::OptResult;
35
36#[derive(Debug, Clone)]
38pub struct RTRConfig {
39 pub max_iters: usize,
41 pub grad_tol: Real,
43 pub delta_init: Real,
45 pub delta_max: Real,
47 pub rho_min: Real,
49 pub max_cg_iters: usize,
51 pub cg_tol: Real,
53}
54
55impl Default for RTRConfig {
56 fn default() -> Self {
57 Self {
58 max_iters: 500,
59 grad_tol: 1e-6,
60 delta_init: 1.0,
61 delta_max: 8.0,
62 rho_min: 0.1,
63 max_cg_iters: 50,
64 cg_tol: 0.1,
65 }
66 }
67}
68
69fn solve_trs<M>(
75 manifold: &M,
76 x: &M::Point,
77 g: &M::Tangent,
78 hess: &dyn Fn(&M::Tangent) -> M::Tangent,
79 delta: Real,
80 max_cg: usize,
81 cg_tol: Real,
82) -> M::Tangent
83where
84 M: Manifold,
85{
86 let g_norm = manifold.norm(x, g);
87 let tol = cg_tol * g_norm;
88
89 let mut eta = manifold.zero_tangent(x);
91 let mut r = g.clone();
92 let mut p = -g.clone();
93
94 for _ in 0..max_cg {
95 let hp = hess(&p);
97 let kappa = manifold.inner(x, &p, &hp);
98
99 if kappa <= 0.0 {
101 return boundary_step(manifold, x, &eta, &p, delta);
102 }
103
104 let r_sq = manifold.inner(x, &r, &r);
105 let alpha = r_sq / kappa;
106
107 let eta_new = eta.clone() + p.clone() * alpha;
109
110 if manifold.norm(x, &eta_new) >= delta {
112 return boundary_step(manifold, x, &eta, &p, delta);
113 }
114
115 eta = eta_new;
116 let r_new = r.clone() + hp * alpha;
117
118 if manifold.norm(x, &r_new) < tol {
119 return eta;
120 }
121
122 let r_sq_new = manifold.inner(x, &r_new, &r_new);
123 let beta = r_sq_new / r_sq;
124 p = -r_new.clone() + p * beta;
125 r = r_new;
126 }
127
128 eta
129}
130
131fn boundary_step<M>(
135 manifold: &M,
136 x: &M::Point,
137 eta: &M::Tangent,
138 p: &M::Tangent,
139 delta: Real,
140) -> M::Tangent
141where
142 M: Manifold,
143{
144 let eta_sq = manifold.inner(x, eta, eta);
145 let ep = manifold.inner(x, eta, p);
146 let p_sq = manifold.inner(x, p, p);
147
148 if p_sq < 1e-30 {
149 return eta.clone(); }
151
152 let discriminant = ep * ep - p_sq * (eta_sq - delta * delta);
154 if discriminant < 0.0 {
155 return eta.clone();
156 }
157 let sqrt_disc = {
158 #[cfg(feature = "std")]
159 {
160 discriminant.sqrt()
161 }
162 #[cfg(not(feature = "std"))]
163 {
164 libm::sqrt(discriminant)
165 }
166 };
167 let tau = (-ep + sqrt_disc) / p_sq;
168 eta.clone() + p.clone() * tau
169}
170
171pub fn minimize_rtr<M, F, G, H>(
184 manifold: &M,
185 cost: F,
186 rgrad: G,
187 ehvp: H,
188 x0: M::Point,
189 config: &RTRConfig,
190) -> OptResult<M::Point>
191where
192 M: Manifold + Retraction + Connection,
193 F: Fn(&M::Point) -> Real,
194 G: Fn(&M::Point) -> M::Tangent,
195 H: Fn(&M::Point, &M::Tangent) -> M::Tangent,
196{
197 let mut x = x0;
198 let mut f_x = cost(&x);
199 let mut g = rgrad(&x);
200 let mut g_norm = manifold.norm(&x, &g);
201 let mut delta = config.delta_init;
202
203 for iter in 0..config.max_iters {
204 if g_norm < config.grad_tol {
205 return OptResult {
206 point: x,
207 value: f_x,
208 grad_norm: g_norm,
209 iterations: iter,
210 converged: true,
211 };
212 }
213
214 let hess_riem = |v: &M::Tangent| -> M::Tangent {
216 manifold
217 .riemannian_hessian_vector_product(&x, &g, v, &|w| ehvp(&x, w))
218 .unwrap_or_else(|_| manifold.zero_tangent(&x))
219 };
220
221 let eta = solve_trs(
223 manifold,
224 &x,
225 &g,
226 &hess_riem,
227 delta,
228 config.max_cg_iters,
229 config.cg_tol,
230 );
231
232 let h_eta = hess_riem(&eta);
234 let model_decrease = -manifold.inner(&x, &g, &eta) - 0.5 * manifold.inner(&x, &h_eta, &eta);
235
236 let x_new = manifold.retract(&x, &eta);
238 let f_new = cost(&x_new);
239 let actual_decrease = f_x - f_new;
240
241 let rho = if model_decrease.abs() < 1e-30 {
243 1.0 } else {
245 actual_decrease / model_decrease
246 };
247
248 if rho > config.rho_min {
250 x = x_new;
251 f_x = f_new;
252 g = rgrad(&x);
253 g_norm = manifold.norm(&x, &g);
254 }
255
256 let eta_norm = manifold.norm(&x, &eta);
258 if rho < 0.25 {
259 delta *= 0.25;
260 } else if rho > 0.75 && (delta - eta_norm).abs() < 1e-10 * delta {
261 delta = (2.0 * delta).min(config.delta_max);
262 }
263 }
264
265 OptResult {
266 point: x,
267 value: f_x,
268 grad_norm: g_norm,
269 iterations: config.max_iters,
270 converged: false,
271 }
272}