// gam_problem/riemannian_retraction.rs

1use ndarray::{ArrayView1, ArrayViewMut1, s};
2
/// One full turn, `2π`: the period reported for a circle coordinate.
/// `TAU` is bit-for-bit equal to `PI * 2.0`.
const TWO_PI: f64 = std::f64::consts::TAU;
4
5pub trait Retraction {
6    fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>);
7}
8
9#[derive(Clone, Copy, Debug, Default, PartialEq)]
10pub struct EuclideanRetraction;
11
12impl Retraction for EuclideanRetraction {
13    fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
14        assert_eq!(base.len(), tangent.len());
15        let manifold = gam_geometry::EuclideanManifold::new(base.len());
16        let next = gam_geometry::RiemannianManifold::exp_map(&manifold, base.view(), tangent)
17            .expect("Euclidean retraction dimensions were prevalidated");
18        for axis in 0..base.len() {
19            base[axis] = next[axis];
20        }
21    }
22}
23
24#[derive(Clone, Copy, Debug, Default, PartialEq)]
25pub struct CircleRetraction;
26
27impl Retraction for CircleRetraction {
28    fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
29        assert_eq!(base.len(), 1);
30        assert_eq!(tangent.len(), 1);
31        let manifold = gam_geometry::CircleManifold::new();
32        let next = gam_geometry::RiemannianManifold::exp_map(&manifold, base.view(), tangent)
33            .expect("Circle retraction dimensions were prevalidated");
34        base[0] = next[0];
35    }
36}
37
38#[derive(Clone, Copy, Debug, PartialEq)]
39pub struct SphereRetraction {
40    pub dim: usize,
41}
42
43impl Retraction for SphereRetraction {
44    fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
45        assert_eq!(base.len(), self.dim);
46        assert_eq!(tangent.len(), self.dim);
47        assert!(
48            self.dim >= 2,
49            "SphereRetraction ambient dim must be at least 2"
50        );
51        let manifold = gam_geometry::SphereManifold::new(self.dim - 1);
52        let next = gam_geometry::RiemannianManifold::exp_map(&manifold, base.view(), tangent)
53            .expect("Sphere retraction dimensions were prevalidated");
54        for axis in 0..self.dim {
55            base[axis] = next[axis];
56        }
57    }
58}
59
60#[derive(Clone, Debug, PartialEq)]
61pub struct ProductRetraction {
62    pub parts: Vec<RetractionKind>,
63}
64
65impl ProductRetraction {
66    pub fn ambient_dim(&self) -> usize {
67        self.parts.iter().map(RetractionKind::ambient_dim).sum()
68    }
69
70    pub fn is_all_euclidean(&self) -> bool {
71        self.parts.iter().all(RetractionKind::is_euclidean)
72    }
73}
74
75impl Retraction for ProductRetraction {
76    fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
77        assert_eq!(base.len(), tangent.len());
78        let mut offset = 0_usize;
79        for part in &self.parts {
80            let dim = part.ambient_dim();
81            let mut base_part = base.slice_mut(s![offset..offset + dim]);
82            let tangent_part = tangent.slice(s![offset..offset + dim]);
83            part.retract(&mut base_part, tangent_part);
84            offset += dim;
85        }
86        assert_eq!(offset, base.len());
87    }
88}
89
90#[derive(Clone, Debug, PartialEq)]
91pub enum RetractionKind {
92    Euclidean { dim: usize },
93    Circle,
94    Sphere { dim: usize },
95    Product(ProductRetraction),
96}
97
98impl RetractionKind {
99    pub fn euclidean(dim: usize) -> Self {
100        Self::Euclidean { dim }
101    }
102
103    pub fn ambient_dim(&self) -> usize {
104        match self {
105            Self::Euclidean { dim } | Self::Sphere { dim } => *dim,
106            Self::Circle => 1,
107            Self::Product(product) => product.ambient_dim(),
108        }
109    }
110
111    pub fn is_euclidean(&self) -> bool {
112        match self {
113            Self::Euclidean { .. } => true,
114            Self::Circle | Self::Sphere { .. } => false,
115            Self::Product(product) => product.is_all_euclidean(),
116        }
117    }
118
119    pub fn metric_weights(&self) -> Vec<f64> {
120        match self {
121            Self::Euclidean { dim } => vec![1.0; *dim],
122            Self::Circle => vec![1.0 / (TWO_PI * TWO_PI)],
123            Self::Sphere { dim } => {
124                let weight = 1.0 / (std::f64::consts::PI * std::f64::consts::PI);
125                vec![weight; *dim]
126            }
127            Self::Product(product) => {
128                let mut out = Vec::with_capacity(product.ambient_dim());
129                for part in &product.parts {
130                    out.extend(part.metric_weights());
131                }
132                out
133            }
134        }
135    }
136
137    /// Per-ambient-axis periodicity, mirroring
138    /// [`gam_terms::latent::LatentManifold::axis_periods`]. A `Circle`
139    /// retraction wraps modulo `2π`; an embedded `Sphere` retraction is smooth
140    /// with no cut and is reported non-periodic.
141    pub fn axis_periods(&self) -> Vec<Option<f64>> {
142        match self {
143            Self::Euclidean { dim } => vec![None; *dim],
144            Self::Circle => vec![Some(TWO_PI)],
145            Self::Sphere { dim } => vec![None; *dim],
146            Self::Product(product) => {
147                let mut out = Vec::with_capacity(product.ambient_dim());
148                for part in &product.parts {
149                    out.extend(part.axis_periods());
150                }
151                out
152            }
153        }
154    }
155}
156
157impl Retraction for RetractionKind {
158    fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
159        match self {
160            Self::Euclidean { .. } => EuclideanRetraction.retract(base, tangent),
161            Self::Circle => CircleRetraction.retract(base, tangent),
162            Self::Sphere { dim } => SphereRetraction { dim: *dim }.retract(base, tangent),
163            Self::Product(product) => product.retract(base, tangent),
164        }
165    }
166}
167
168#[derive(Clone, Debug, Default, PartialEq)]
169pub struct LatentRetractionRegistry {
170    block: Option<RetractionKind>,
171}
172
173impl LatentRetractionRegistry {
174    pub fn all_euclidean() -> Self {
175        Self { block: None }
176    }
177
178    pub fn new(block: RetractionKind) -> Self {
179        if block.is_euclidean() {
180            Self::all_euclidean()
181        } else {
182            Self { block: Some(block) }
183        }
184    }
185
186    pub fn is_all_euclidean(&self) -> bool {
187        self.block.is_none()
188    }
189
190    pub(crate) fn ambient_dim(&self, fallback_dim: usize) -> usize {
191        self.block
192            .as_ref()
193            .map_or(fallback_dim, RetractionKind::ambient_dim)
194    }
195
196    pub fn metric_weights(&self, fallback_dim: usize) -> Vec<f64> {
197        self.block
198            .as_ref()
199            .map_or_else(|| vec![1.0; fallback_dim], RetractionKind::metric_weights)
200    }
201
202    /// Per-ambient-axis periodicity for the override retraction, falling back
203    /// to all-non-periodic (`None`) of length `fallback_dim` when no override
204    /// is installed.
205    pub fn axis_periods(&self, fallback_dim: usize) -> Vec<Option<f64>> {
206        self.block
207            .as_ref()
208            .map_or_else(|| vec![None; fallback_dim], RetractionKind::axis_periods)
209    }
210
211    pub fn validate_dim(&self, latent_dim: usize, context: &str) -> Result<(), String> {
212        let dim = self.ambient_dim(latent_dim);
213        if dim != latent_dim {
214            return Err(format!(
215                "{context} retraction ambient dimension {dim} does not match latent d={latent_dim}"
216            ));
217        }
218        Ok(())
219    }
220
221    pub fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
222        assert_eq!(base.len(), tangent.len());
223        if let Some(block) = self.block.as_ref() {
224            block.retract(base, tangent);
225        } else {
226            for (value, delta) in base.iter_mut().zip(tangent.iter()) {
227                *value += *delta;
228            }
229        }
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    // ── RetractionKind::ambient_dim ───────────────────────────────────────────

    #[test]
    fn euclidean_ambient_dim() {
        assert_eq!(RetractionKind::euclidean(4).ambient_dim(), 4);
        assert_eq!(RetractionKind::euclidean(0).ambient_dim(), 0);
    }

    #[test]
    fn circle_ambient_dim_is_one() {
        assert_eq!(RetractionKind::Circle.ambient_dim(), 1);
    }

    #[test]
    fn sphere_ambient_dim() {
        assert_eq!(RetractionKind::Sphere { dim: 3 }.ambient_dim(), 3);
    }

    #[test]
    fn product_ambient_dim_is_sum() {
        let product = ProductRetraction {
            parts: vec![
                RetractionKind::euclidean(2),
                RetractionKind::Circle,
                RetractionKind::Sphere { dim: 3 },
            ],
        };
        assert_eq!(product.ambient_dim(), 6); // 2 + 1 + 3
    }

    #[test]
    fn empty_product_is_zero_dim_and_euclidean() {
        let product = ProductRetraction { parts: Vec::new() };
        assert_eq!(product.ambient_dim(), 0);
        assert!(product.is_all_euclidean());
    }

    // ── RetractionKind::is_euclidean ──────────────────────────────────────────

    #[test]
    fn euclidean_is_euclidean() {
        assert!(RetractionKind::euclidean(5).is_euclidean());
    }

    #[test]
    fn circle_is_not_euclidean() {
        assert!(!RetractionKind::Circle.is_euclidean());
    }

    #[test]
    fn sphere_is_not_euclidean() {
        assert!(!RetractionKind::Sphere { dim: 3 }.is_euclidean());
    }

    #[test]
    fn all_euclidean_product_is_euclidean() {
        let product = ProductRetraction {
            parts: vec![RetractionKind::euclidean(2), RetractionKind::euclidean(3)],
        };
        assert!(RetractionKind::Product(product).is_euclidean());
    }

    #[test]
    fn mixed_product_is_not_euclidean() {
        let product = ProductRetraction {
            parts: vec![RetractionKind::euclidean(2), RetractionKind::Circle],
        };
        assert!(!RetractionKind::Product(product).is_euclidean());
    }

    // ── RetractionKind::metric_weights ────────────────────────────────────────

    #[test]
    fn euclidean_metric_weights_are_all_one() {
        let w = RetractionKind::euclidean(3).metric_weights();
        assert_eq!(w, vec![1.0, 1.0, 1.0]);
    }

    #[test]
    fn circle_metric_weight_is_inv_twopi_sq() {
        let w = RetractionKind::Circle.metric_weights();
        let expected = 1.0 / (TWO_PI * TWO_PI);
        assert_eq!(w.len(), 1);
        assert!((w[0] - expected).abs() < 1e-15);
    }

    #[test]
    fn sphere_metric_weights_are_inv_pi_sq() {
        let w = RetractionKind::Sphere { dim: 3 }.metric_weights();
        let expected = 1.0 / (std::f64::consts::PI * std::f64::consts::PI);
        assert_eq!(w.len(), 3);
        assert!(w.iter().all(|&x| (x - expected).abs() < 1e-15));
    }

    #[test]
    fn product_metric_weights_concatenate_parts_in_order() {
        let product = RetractionKind::Product(ProductRetraction {
            parts: vec![
                RetractionKind::euclidean(2),
                RetractionKind::Circle,
                RetractionKind::Sphere { dim: 2 },
            ],
        });
        let circle = 1.0 / (TWO_PI * TWO_PI);
        let sphere = 1.0 / (std::f64::consts::PI * std::f64::consts::PI);
        let w = product.metric_weights();
        let expected = [1.0, 1.0, circle, sphere, sphere];
        assert_eq!(w.len(), expected.len());
        for (got, want) in w.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-15, "got {got}, want {want}");
        }
    }

    // ── RetractionKind::axis_periods ─────────────────────────────────────────

    #[test]
    fn euclidean_axis_periods_all_none() {
        let p = RetractionKind::euclidean(3).axis_periods();
        assert_eq!(p, vec![None, None, None]);
    }

    #[test]
    fn circle_axis_period_is_two_pi() {
        let p = RetractionKind::Circle.axis_periods();
        assert_eq!(p.len(), 1);
        assert!((p[0].unwrap() - TWO_PI).abs() < 1e-15);
    }

    #[test]
    fn sphere_axis_periods_all_none() {
        let p = RetractionKind::Sphere { dim: 3 }.axis_periods();
        assert_eq!(p, vec![None, None, None]);
    }

    #[test]
    fn product_axis_periods_concatenate_parts_in_order() {
        let product = RetractionKind::Product(ProductRetraction {
            parts: vec![
                RetractionKind::Circle,
                RetractionKind::euclidean(1),
                RetractionKind::Circle,
            ],
        });
        assert_eq!(product.axis_periods(), vec![Some(TWO_PI), None, Some(TWO_PI)]);
    }

    // ── ProductRetraction::retract ────────────────────────────────────────────

    #[test]
    #[should_panic]
    fn product_retract_rejects_length_mismatch() {
        // Parts cover 2 axes but the coordinate vector has 3.
        let product = ProductRetraction {
            parts: vec![RetractionKind::euclidean(2)],
        };
        let mut base = ndarray::array![0.0, 0.0, 0.0];
        let tangent = ndarray::array![1.0, 1.0, 1.0];
        product.retract(&mut base.view_mut(), tangent.view());
    }

    // ── LatentRetractionRegistry ──────────────────────────────────────────────

    #[test]
    fn all_euclidean_registry_reports_is_all_euclidean() {
        let r = LatentRetractionRegistry::all_euclidean();
        assert!(r.is_all_euclidean());
    }

    #[test]
    fn default_registry_is_all_euclidean() {
        assert_eq!(
            LatentRetractionRegistry::default(),
            LatentRetractionRegistry::all_euclidean()
        );
    }

    #[test]
    fn circle_registry_is_not_all_euclidean() {
        let r = LatentRetractionRegistry::new(RetractionKind::Circle);
        assert!(!r.is_all_euclidean());
    }

    #[test]
    fn euclidean_registry_collapses_to_all_euclidean() {
        // Constructing with a Euclidean kind must collapse to all-Euclidean.
        let r = LatentRetractionRegistry::new(RetractionKind::euclidean(3));
        assert!(r.is_all_euclidean());
    }

    #[test]
    fn all_euclidean_registry_uses_fallback_dim() {
        let r = LatentRetractionRegistry::all_euclidean();
        assert_eq!(r.metric_weights(4), vec![1.0; 4]);
        assert_eq!(r.axis_periods(2), vec![None, None]);
        assert!(r.validate_dim(7, "ctx").is_ok());
    }

    #[test]
    fn circle_registry_ignores_fallback_dim() {
        let r = LatentRetractionRegistry::new(RetractionKind::Circle);
        assert_eq!(r.metric_weights(5).len(), 1);
        assert_eq!(r.axis_periods(5), vec![Some(TWO_PI)]);
    }

    #[test]
    fn registry_validate_dim_ok_when_matching() {
        let r = LatentRetractionRegistry::new(RetractionKind::Circle);
        assert!(r.validate_dim(1, "ctx").is_ok());
    }

    #[test]
    fn registry_validate_dim_err_when_mismatched() {
        let r = LatentRetractionRegistry::new(RetractionKind::Circle);
        let e = r.validate_dim(3, "ctx").unwrap_err();
        assert!(e.contains("ctx"), "error should mention context: {e}");
    }

    #[test]
    fn registry_euclidean_retract_adds_tangent() {
        let r = LatentRetractionRegistry::all_euclidean();
        let mut base = ndarray::array![1.0, 2.0, 3.0];
        let tangent = ndarray::array![0.1, -0.2, 0.5];
        r.retract(&mut base.view_mut(), tangent.view());
        assert!((base[0] - 1.1).abs() < 1e-15);
        assert!((base[1] - 1.8).abs() < 1e-15);
        assert!((base[2] - 3.5).abs() < 1e-15);
    }
}