1use ndarray::{ArrayView1, ArrayViewMut1, s};
2
/// One full turn in radians (2π), taken from the standard library's `TAU`.
const TWO_PI: f64 = std::f64::consts::TAU;
4
/// A retraction: moves a base point along a tangent vector, updating the
/// point in place.
pub trait Retraction {
 /// Overwrites `base` with the point reached from `base` along `tangent`.
 ///
 /// The implementations in this file panic when the lengths of `base` and
 /// `tangent` do not match what the retraction expects.
 fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>);
}
8
/// Retraction for Euclidean coordinates; accepts any length as long as the
/// base point and the tangent agree.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct EuclideanRetraction;
11
12impl Retraction for EuclideanRetraction {
13 fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
14 assert_eq!(base.len(), tangent.len());
15 let manifold = gam_geometry::EuclideanManifold::new(base.len());
16 let next = gam_geometry::RiemannianManifold::exp_map(&manifold, base.view(), tangent)
17 .expect("Euclidean retraction dimensions were prevalidated");
18 for axis in 0..base.len() {
19 base[axis] = next[axis];
20 }
21 }
22}
23
/// Retraction for a single circle coordinate; `retract` requires `base` and
/// `tangent` to each have length 1.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct CircleRetraction;
26
27impl Retraction for CircleRetraction {
28 fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
29 assert_eq!(base.len(), 1);
30 assert_eq!(tangent.len(), 1);
31 let manifold = gam_geometry::CircleManifold::new();
32 let next = gam_geometry::RiemannianManifold::exp_map(&manifold, base.view(), tangent)
33 .expect("Circle retraction dimensions were prevalidated");
34 base[0] = next[0];
35 }
36}
37
/// Retraction on a sphere expressed in ambient coordinates.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SphereRetraction {
 /// Number of ambient coordinates. `retract` requires `base` and `tangent`
 /// to have exactly this length and panics unless it is at least 2.
 pub dim: usize,
}
42
43impl Retraction for SphereRetraction {
44 fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
45 assert_eq!(base.len(), self.dim);
46 assert_eq!(tangent.len(), self.dim);
47 assert!(
48 self.dim >= 2,
49 "SphereRetraction ambient dim must be at least 2"
50 );
51 let manifold = gam_geometry::SphereManifold::new(self.dim - 1);
52 let next = gam_geometry::RiemannianManifold::exp_map(&manifold, base.view(), tangent)
53 .expect("Sphere retraction dimensions were prevalidated");
54 for axis in 0..self.dim {
55 base[axis] = next[axis];
56 }
57 }
58}
59
/// Retraction over a concatenation of component retractions.
///
/// Each part acts on a contiguous slice of the coordinate vector, taken in
/// order, whose length is that part's ambient dimension.
#[derive(Clone, Debug, PartialEq)]
pub struct ProductRetraction {
 /// Component retractions, in coordinate order.
 pub parts: Vec<RetractionKind>,
}
64
65impl ProductRetraction {
66 pub fn ambient_dim(&self) -> usize {
67 self.parts.iter().map(RetractionKind::ambient_dim).sum()
68 }
69
70 pub fn is_all_euclidean(&self) -> bool {
71 self.parts.iter().all(RetractionKind::is_euclidean)
72 }
73}
74
75impl Retraction for ProductRetraction {
76 fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
77 assert_eq!(base.len(), tangent.len());
78 let mut offset = 0_usize;
79 for part in &self.parts {
80 let dim = part.ambient_dim();
81 let mut base_part = base.slice_mut(s![offset..offset + dim]);
82 let tangent_part = tangent.slice(s![offset..offset + dim]);
83 part.retract(&mut base_part, tangent_part);
84 offset += dim;
85 }
86 assert_eq!(offset, base.len());
87 }
88}
89
/// Closed set of retractions that can be stored and dispatched by value.
#[derive(Clone, Debug, PartialEq)]
pub enum RetractionKind {
 /// Euclidean block of `dim` coordinates.
 Euclidean { dim: usize },
 /// A single circle coordinate.
 Circle,
 /// Sphere block with `dim` ambient coordinates.
 Sphere { dim: usize },
 /// Concatenation of nested retractions.
 Product(ProductRetraction),
}
97
98impl RetractionKind {
99 pub fn euclidean(dim: usize) -> Self {
100 Self::Euclidean { dim }
101 }
102
103 pub fn ambient_dim(&self) -> usize {
104 match self {
105 Self::Euclidean { dim } | Self::Sphere { dim } => *dim,
106 Self::Circle => 1,
107 Self::Product(product) => product.ambient_dim(),
108 }
109 }
110
111 pub fn is_euclidean(&self) -> bool {
112 match self {
113 Self::Euclidean { .. } => true,
114 Self::Circle | Self::Sphere { .. } => false,
115 Self::Product(product) => product.is_all_euclidean(),
116 }
117 }
118
119 pub fn metric_weights(&self) -> Vec<f64> {
120 match self {
121 Self::Euclidean { dim } => vec![1.0; *dim],
122 Self::Circle => vec![1.0 / (TWO_PI * TWO_PI)],
123 Self::Sphere { dim } => {
124 let weight = 1.0 / (std::f64::consts::PI * std::f64::consts::PI);
125 vec![weight; *dim]
126 }
127 Self::Product(product) => {
128 let mut out = Vec::with_capacity(product.ambient_dim());
129 for part in &product.parts {
130 out.extend(part.metric_weights());
131 }
132 out
133 }
134 }
135 }
136
137 pub fn axis_periods(&self) -> Vec<Option<f64>> {
142 match self {
143 Self::Euclidean { dim } => vec![None; *dim],
144 Self::Circle => vec![Some(TWO_PI)],
145 Self::Sphere { dim } => vec![None; *dim],
146 Self::Product(product) => {
147 let mut out = Vec::with_capacity(product.ambient_dim());
148 for part in &product.parts {
149 out.extend(part.axis_periods());
150 }
151 out
152 }
153 }
154 }
155}
156
157impl Retraction for RetractionKind {
158 fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
159 match self {
160 Self::Euclidean { .. } => EuclideanRetraction.retract(base, tangent),
161 Self::Circle => CircleRetraction.retract(base, tangent),
162 Self::Sphere { dim } => SphereRetraction { dim: *dim }.retract(base, tangent),
163 Self::Product(product) => product.retract(base, tangent),
164 }
165 }
166}
167
/// Holds the optional retraction for a latent coordinate block.
///
/// `block` is `None` when every coordinate is Euclidean; in that case the
/// dimension is not stored and callers pass it in as a fallback argument.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct LatentRetractionRegistry {
 block: Option<RetractionKind>,
}
172
173impl LatentRetractionRegistry {
174 pub fn all_euclidean() -> Self {
175 Self { block: None }
176 }
177
178 pub fn new(block: RetractionKind) -> Self {
179 if block.is_euclidean() {
180 Self::all_euclidean()
181 } else {
182 Self { block: Some(block) }
183 }
184 }
185
186 pub fn is_all_euclidean(&self) -> bool {
187 self.block.is_none()
188 }
189
190 pub(crate) fn ambient_dim(&self, fallback_dim: usize) -> usize {
191 self.block
192 .as_ref()
193 .map_or(fallback_dim, RetractionKind::ambient_dim)
194 }
195
196 pub fn metric_weights(&self, fallback_dim: usize) -> Vec<f64> {
197 self.block
198 .as_ref()
199 .map_or_else(|| vec![1.0; fallback_dim], RetractionKind::metric_weights)
200 }
201
202 pub fn axis_periods(&self, fallback_dim: usize) -> Vec<Option<f64>> {
206 self.block
207 .as_ref()
208 .map_or_else(|| vec![None; fallback_dim], RetractionKind::axis_periods)
209 }
210
211 pub fn validate_dim(&self, latent_dim: usize, context: &str) -> Result<(), String> {
212 let dim = self.ambient_dim(latent_dim);
213 if dim != latent_dim {
214 return Err(format!(
215 "{context} retraction ambient dimension {dim} does not match latent d={latent_dim}"
216 ));
217 }
218 Ok(())
219 }
220
221 pub fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
222 assert_eq!(base.len(), tangent.len());
223 if let Some(block) = self.block.as_ref() {
224 block.retract(base, tangent);
225 } else {
226 for (value, delta) in base.iter_mut().zip(tangent.iter()) {
227 *value += *delta;
228 }
229 }
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point comparisons below.
    const TOL: f64 = 1e-15;

    /// Asserts that `actual` is within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    /// Wraps `parts` in a `ProductRetraction`.
    fn product_of(parts: Vec<RetractionKind>) -> ProductRetraction {
        ProductRetraction { parts }
    }

    // ---- ambient_dim ----

    #[test]
    fn euclidean_ambient_dim() {
        for dim in [4_usize, 0] {
            assert_eq!(RetractionKind::euclidean(dim).ambient_dim(), dim);
        }
    }

    #[test]
    fn circle_ambient_dim_is_one() {
        let circle = RetractionKind::Circle;
        assert_eq!(circle.ambient_dim(), 1);
    }

    #[test]
    fn sphere_ambient_dim() {
        let sphere = RetractionKind::Sphere { dim: 3 };
        assert_eq!(sphere.ambient_dim(), 3);
    }

    #[test]
    fn product_ambient_dim_is_sum() {
        let product = product_of(vec![
            RetractionKind::euclidean(2),
            RetractionKind::Circle,
            RetractionKind::Sphere { dim: 3 },
        ]);
        assert_eq!(product.ambient_dim(), 6);
    }

    // ---- is_euclidean ----

    #[test]
    fn euclidean_is_euclidean() {
        let kind = RetractionKind::euclidean(5);
        assert!(kind.is_euclidean());
    }

    #[test]
    fn circle_is_not_euclidean() {
        let kind = RetractionKind::Circle;
        assert!(!kind.is_euclidean());
    }

    #[test]
    fn sphere_is_not_euclidean() {
        let kind = RetractionKind::Sphere { dim: 3 };
        assert!(!kind.is_euclidean());
    }

    #[test]
    fn all_euclidean_product_is_euclidean() {
        let flat = product_of(vec![
            RetractionKind::euclidean(2),
            RetractionKind::euclidean(3),
        ]);
        assert!(RetractionKind::Product(flat).is_euclidean());
    }

    #[test]
    fn mixed_product_is_not_euclidean() {
        let mixed = product_of(vec![RetractionKind::euclidean(2), RetractionKind::Circle]);
        assert!(!RetractionKind::Product(mixed).is_euclidean());
    }

    // ---- metric_weights ----

    #[test]
    fn euclidean_metric_weights_are_all_one() {
        let weights = RetractionKind::euclidean(3).metric_weights();
        assert_eq!(weights, vec![1.0, 1.0, 1.0]);
    }

    #[test]
    fn circle_metric_weight_is_inv_twopi_sq() {
        let weights = RetractionKind::Circle.metric_weights();
        let expected = 1.0 / (TWO_PI * TWO_PI);
        assert_eq!(weights.len(), 1);
        assert_close(weights[0], expected);
    }

    // ---- axis_periods ----

    #[test]
    fn euclidean_axis_periods_all_none() {
        let periods = RetractionKind::euclidean(3).axis_periods();
        assert_eq!(periods, vec![None, None, None]);
    }

    #[test]
    fn circle_axis_period_is_two_pi() {
        let periods = RetractionKind::Circle.axis_periods();
        assert_eq!(periods.len(), 1);
        assert_close(periods[0].unwrap(), TWO_PI);
    }

    #[test]
    fn sphere_axis_periods_all_none() {
        let periods = RetractionKind::Sphere { dim: 3 }.axis_periods();
        assert_eq!(periods, vec![None, None, None]);
    }

    // ---- LatentRetractionRegistry ----

    #[test]
    fn all_euclidean_registry_reports_is_all_euclidean() {
        let registry = LatentRetractionRegistry::all_euclidean();
        assert!(registry.is_all_euclidean());
    }

    #[test]
    fn circle_registry_is_not_all_euclidean() {
        let registry = LatentRetractionRegistry::new(RetractionKind::Circle);
        assert!(!registry.is_all_euclidean());
    }

    #[test]
    fn euclidean_registry_collapses_to_all_euclidean() {
        let registry = LatentRetractionRegistry::new(RetractionKind::euclidean(3));
        assert!(registry.is_all_euclidean());
    }

    #[test]
    fn registry_validate_dim_ok_when_matching() {
        let registry = LatentRetractionRegistry::new(RetractionKind::Circle);
        let verdict = registry.validate_dim(1, "ctx");
        assert!(verdict.is_ok());
    }

    #[test]
    fn registry_validate_dim_err_when_mismatched() {
        let registry = LatentRetractionRegistry::new(RetractionKind::Circle);
        let e = registry.validate_dim(3, "ctx").unwrap_err();
        assert!(e.contains("ctx"), "error should mention context: {e}");
    }

    #[test]
    fn registry_euclidean_retract_adds_tangent() {
        let registry = LatentRetractionRegistry::all_euclidean();
        let mut point = ndarray::array![1.0, 2.0, 3.0];
        let step = ndarray::array![0.1, -0.2, 0.5];
        registry.retract(&mut point.view_mut(), step.view());
        for (axis, expected) in [1.1, 1.8, 3.5].into_iter().enumerate() {
            assert_close(point[axis], expected);
        }
    }
}