Skip to main content

gam_problem/
riemannian_retraction.rs

1use ndarray::{ArrayView1, ArrayViewMut1, s};
2
/// Full turn in radians (`2π`); the period of a circle coordinate and the
/// scale used for the circle metric weight.
const TWO_PI: f64 = std::f64::consts::TAU;
4
/// An in-place update of a point by a tangent step.
///
/// Implementors in this file overwrite `base` with the point reached from
/// `base` along `tangent`; nothing is returned.
pub trait Retraction {
    /// Moves `base` along `tangent`, writing the result back into `base`.
    ///
    /// The implementations in this file panic (via assertions) when the
    /// lengths of `base` and `tangent` do not match what they expect.
    fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>);
}
8
9#[derive(Clone, Copy, Debug, Default, PartialEq)]
10pub struct EuclideanRetraction;
11
12impl Retraction for EuclideanRetraction {
13    fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
14        assert_eq!(base.len(), tangent.len());
15        let manifold = gam_geometry::EuclideanManifold::new(base.len());
16        let next = gam_geometry::RiemannianManifold::exp_map(&manifold, base.view(), tangent)
17            .expect("Euclidean retraction dimensions were prevalidated");
18        for axis in 0..base.len() {
19            base[axis] = next[axis];
20        }
21    }
22}
23
24#[derive(Clone, Copy, Debug, Default, PartialEq)]
25pub struct CircleRetraction;
26
27impl Retraction for CircleRetraction {
28    fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
29        assert_eq!(base.len(), 1);
30        assert_eq!(tangent.len(), 1);
31        let manifold = gam_geometry::CircleManifold::new();
32        let next = gam_geometry::RiemannianManifold::exp_map(&manifold, base.view(), tangent)
33            .expect("Circle retraction dimensions were prevalidated");
34        base[0] = next[0];
35    }
36}
37
38#[derive(Clone, Copy, Debug, PartialEq)]
39pub struct SphereRetraction {
40    pub dim: usize,
41}
42
43impl Retraction for SphereRetraction {
44    fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
45        assert_eq!(base.len(), self.dim);
46        assert_eq!(tangent.len(), self.dim);
47        assert!(
48            self.dim >= 2,
49            "SphereRetraction ambient dim must be at least 2"
50        );
51        let manifold = gam_geometry::SphereManifold::new(self.dim - 1);
52        let next = gam_geometry::RiemannianManifold::exp_map(&manifold, base.view(), tangent)
53            .expect("Sphere retraction dimensions were prevalidated");
54        for axis in 0..self.dim {
55            base[axis] = next[axis];
56        }
57    }
58}
59
60#[derive(Clone, Debug, PartialEq)]
61pub struct ProductRetraction {
62    pub parts: Vec<RetractionKind>,
63}
64
65impl ProductRetraction {
66    pub fn ambient_dim(&self) -> usize {
67        self.parts.iter().map(RetractionKind::ambient_dim).sum()
68    }
69
70    pub fn is_all_euclidean(&self) -> bool {
71        self.parts.iter().all(RetractionKind::is_euclidean)
72    }
73}
74
75impl Retraction for ProductRetraction {
76    fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
77        assert_eq!(base.len(), tangent.len());
78        let mut offset = 0_usize;
79        for part in &self.parts {
80            let dim = part.ambient_dim();
81            let mut base_part = base.slice_mut(s![offset..offset + dim]);
82            let tangent_part = tangent.slice(s![offset..offset + dim]);
83            part.retract(&mut base_part, tangent_part);
84            offset += dim;
85        }
86        assert_eq!(offset, base.len());
87    }
88}
89
90#[derive(Clone, Debug, PartialEq)]
91pub enum RetractionKind {
92    Euclidean { dim: usize },
93    Circle,
94    Sphere { dim: usize },
95    Product(ProductRetraction),
96}
97
98impl RetractionKind {
99    pub fn euclidean(dim: usize) -> Self {
100        Self::Euclidean { dim }
101    }
102
103    pub fn ambient_dim(&self) -> usize {
104        match self {
105            Self::Euclidean { dim } | Self::Sphere { dim } => *dim,
106            Self::Circle => 1,
107            Self::Product(product) => product.ambient_dim(),
108        }
109    }
110
111    pub fn is_euclidean(&self) -> bool {
112        match self {
113            Self::Euclidean { .. } => true,
114            Self::Circle | Self::Sphere { .. } => false,
115            Self::Product(product) => product.is_all_euclidean(),
116        }
117    }
118
119    pub fn metric_weights(&self) -> Vec<f64> {
120        match self {
121            Self::Euclidean { dim } => vec![1.0; *dim],
122            Self::Circle => vec![1.0 / (TWO_PI * TWO_PI)],
123            Self::Sphere { dim } => {
124                let weight = 1.0 / (std::f64::consts::PI * std::f64::consts::PI);
125                vec![weight; *dim]
126            }
127            Self::Product(product) => {
128                let mut out = Vec::with_capacity(product.ambient_dim());
129                for part in &product.parts {
130                    out.extend(part.metric_weights());
131                }
132                out
133            }
134        }
135    }
136
137    /// Per-ambient-axis periodicity, mirroring
138    /// [`gam_terms::latent::LatentManifold::axis_periods`]. A `Circle`
139    /// retraction wraps modulo `2π`; an embedded `Sphere` retraction is smooth
140    /// with no cut and is reported non-periodic.
141    pub fn axis_periods(&self) -> Vec<Option<f64>> {
142        match self {
143            Self::Euclidean { dim } => vec![None; *dim],
144            Self::Circle => vec![Some(TWO_PI)],
145            Self::Sphere { dim } => vec![None; *dim],
146            Self::Product(product) => {
147                let mut out = Vec::with_capacity(product.ambient_dim());
148                for part in &product.parts {
149                    out.extend(part.axis_periods());
150                }
151                out
152            }
153        }
154    }
155}
156
157impl Retraction for RetractionKind {
158    fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
159        match self {
160            Self::Euclidean { .. } => EuclideanRetraction.retract(base, tangent),
161            Self::Circle => CircleRetraction.retract(base, tangent),
162            Self::Sphere { dim } => SphereRetraction { dim: *dim }.retract(base, tangent),
163            Self::Product(product) => product.retract(base, tangent),
164        }
165    }
166}
167
168#[derive(Clone, Debug, Default, PartialEq)]
169pub struct LatentRetractionRegistry {
170    block: Option<RetractionKind>,
171}
172
173impl LatentRetractionRegistry {
174    pub fn all_euclidean() -> Self {
175        Self { block: None }
176    }
177
178    pub fn new(block: RetractionKind) -> Self {
179        if block.is_euclidean() {
180            Self::all_euclidean()
181        } else {
182            Self { block: Some(block) }
183        }
184    }
185
186    pub fn is_all_euclidean(&self) -> bool {
187        self.block.is_none()
188    }
189
190    pub(crate) fn ambient_dim(&self, fallback_dim: usize) -> usize {
191        self.block
192            .as_ref()
193            .map_or(fallback_dim, RetractionKind::ambient_dim)
194    }
195
196    pub fn metric_weights(&self, fallback_dim: usize) -> Vec<f64> {
197        self.block
198            .as_ref()
199            .map_or_else(|| vec![1.0; fallback_dim], RetractionKind::metric_weights)
200    }
201
202    /// Per-ambient-axis periodicity for the override retraction, falling back
203    /// to all-non-periodic (`None`) of length `fallback_dim` when no override
204    /// is installed.
205    pub fn axis_periods(&self, fallback_dim: usize) -> Vec<Option<f64>> {
206        self.block
207            .as_ref()
208            .map_or_else(|| vec![None; fallback_dim], RetractionKind::axis_periods)
209    }
210
211    pub fn validate_dim(&self, latent_dim: usize, context: &str) -> Result<(), String> {
212        let dim = self.ambient_dim(latent_dim);
213        if dim != latent_dim {
214            return Err(format!(
215                "{context} retraction ambient dimension {dim} does not match latent d={latent_dim}"
216            ));
217        }
218        Ok(())
219    }
220
221    pub fn retract(&self, base: &mut ArrayViewMut1<f64>, tangent: ArrayView1<f64>) {
222        assert_eq!(base.len(), tangent.len());
223        if let Some(block) = self.block.as_ref() {
224            block.retract(base, tangent);
225        } else {
226            for (value, delta) in base.iter_mut().zip(tangent.iter()) {
227                *value += *delta;
228            }
229        }
230    }
231}