Skip to main content

fdars_core/
warping.rs

1//! Warping function utilities and Hilbert sphere geometry.
2//!
3//! This module provides operations on warping (reparameterization) functions,
4//! including their Hilbert sphere representation via `ψ(t) = √γ'(t)`.
5//!
6//! Key capabilities:
7//! - [`gam_to_psi`] / [`psi_to_gam`] — Convert between warping functions and sphere
8//! - [`exp_map_sphere`] / [`inv_exp_map_sphere`] — Riemannian exponential / log maps
9//! - [`normalize_warp`] / [`invert_gamma`] — Warp normalization and inversion
10//! - [`phase_distance`] — Geodesic distance from a warp to the identity
11
12use crate::helpers::{cumulative_trapz, gradient_uniform, linear_interp, trapz};
13
/// Coerce `gamma` in place into a valid warping of the domain spanned by `argvals`.
///
/// With `n = gamma.len()`, the first and last samples are pinned to `argvals[0]`
/// and `argvals[n - 1]`, and every sample in between is forced to be
/// non-decreasing and no larger than the right endpoint.
///
/// Capping at the right endpoint matters: a plain running-maximum pass would
/// propagate an interior overshoot (e.g. `[0.3, 2.0, 0.7]` on `[0, 0.5, 1]`) all
/// the way to the last sample and silently undo the boundary fix.
///
/// An empty `gamma` is left untouched.
///
/// # Panics
/// Panics if `gamma` is non-empty and `argvals` has fewer elements than `gamma`.
pub fn normalize_warp(gamma: &mut [f64], argvals: &[f64]) {
    let n = gamma.len();
    if n == 0 {
        return;
    }

    let lo = argvals[0];
    let hi = argvals[n - 1];

    // Fix boundaries
    gamma[0] = lo;
    gamma[n - 1] = hi;

    // Enforce monotonicity while keeping every sample at or below the right
    // endpoint. Plain comparisons (rather than `min`/`max`) leave NaN samples
    // untouched.
    let mut prev = lo;
    for g in gamma.iter_mut().skip(1) {
        if *g > hi {
            *g = hi;
        }
        if *g < prev {
            *g = prev;
        }
        prev = *g;
    }
}
32
33/// Convert warping function to Hilbert sphere representation: ψ = √γ'.
34pub fn gam_to_psi(gam: &[f64], h: f64) -> Vec<f64> {
35    gradient_uniform(gam, h)
36        .iter()
37        .map(|&g| g.max(0.0).sqrt())
38        .collect()
39}
40
41/// Convert ψ back to warping function: γ = cumtrapz(ψ²), normalized to \[0,1\].
42pub fn psi_to_gam(psi: &[f64], time: &[f64]) -> Vec<f64> {
43    let psi_sq: Vec<f64> = psi.iter().map(|&p| p * p).collect();
44    let gam = cumulative_trapz(&psi_sq, time);
45    let min_val = gam.iter().cloned().fold(f64::INFINITY, f64::min);
46    let max_val = gam.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
47    let range = (max_val - min_val).max(1e-10);
48    gam.iter().map(|&v| (v - min_val) / range).collect()
49}
50
51/// L2 inner product: ∫ψ₁·ψ₂ dt via trapezoidal rule.
52pub fn inner_product_l2(psi1: &[f64], psi2: &[f64], time: &[f64]) -> f64 {
53    let prod: Vec<f64> = psi1.iter().zip(psi2.iter()).map(|(&a, &b)| a * b).collect();
54    trapz(&prod, time)
55}
56
57/// L2 norm: √(∫ψ² dt).
58pub fn l2_norm_l2(psi: &[f64], time: &[f64]) -> f64 {
59    inner_product_l2(psi, psi, time).max(0.0).sqrt()
60}
61
62/// Inverse exponential (log) map on the Hilbert sphere.
63/// Returns tangent vector at `mu` pointing toward `psi`.
64pub fn inv_exp_map_sphere(mu: &[f64], psi: &[f64], time: &[f64]) -> Vec<f64> {
65    let ip = inner_product_l2(mu, psi, time).clamp(-1.0, 1.0);
66    let theta = ip.acos();
67    if theta < 1e-10 {
68        vec![0.0; mu.len()]
69    } else {
70        let coeff = theta / theta.sin();
71        let cos_theta = theta.cos();
72        mu.iter()
73            .zip(psi.iter())
74            .map(|(&m, &p)| coeff * (p - cos_theta * m))
75            .collect()
76    }
77}
78
79/// Exponential map on the Hilbert sphere.
80/// Moves from `psi` along tangent vector `v`.
81pub fn exp_map_sphere(psi: &[f64], v: &[f64], time: &[f64]) -> Vec<f64> {
82    let v_norm = l2_norm_l2(v, time);
83    if v_norm < 1e-10 {
84        psi.to_vec()
85    } else {
86        let cos_n = v_norm.cos();
87        let sin_n = v_norm.sin();
88        psi.iter()
89            .zip(v.iter())
90            .map(|(&p, &vi)| cos_n * p + sin_n * vi / v_norm)
91            .collect()
92    }
93}
94
95/// Invert a warping function: find γ⁻¹ such that γ⁻¹(γ(t)) = t.
96/// `gam` and `time` are both on \[0,1\].
97pub fn invert_gamma(gam: &[f64], time: &[f64]) -> Vec<f64> {
98    let n = time.len();
99    let mut gam_inv: Vec<f64> = time.iter().map(|&t| linear_interp(gam, time, t)).collect();
100    gam_inv[0] = time[0];
101    gam_inv[n - 1] = time[n - 1];
102    gam_inv
103}
104
105/// Geodesic distance from a warping function to the identity on the Hilbert sphere.
106///
107/// Computes `arccos(⟨ψ/‖ψ‖, 1/‖1‖⟩_L2)` where `ψ = √γ'`.
108///
109/// # Arguments
110/// * `gamma` — Warping function values (length m)
111/// * `argvals` — Evaluation points (length m)
112///
113/// # Returns
114/// Geodesic distance (≥ 0). Returns 0 for the identity warp.
115pub fn phase_distance(gamma: &[f64], argvals: &[f64]) -> f64 {
116    let m = gamma.len();
117    if m < 2 {
118        return 0.0;
119    }
120
121    let t0 = argvals[0];
122    let t1 = argvals[m - 1];
123    let domain = t1 - t0;
124
125    // Work on [0,1] internally
126    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
127    let binsize = 1.0 / (m - 1) as f64;
128
129    // Convert gamma to [0,1] and compute psi
130    let gam_01: Vec<f64> = (0..m).map(|j| (gamma[j] - t0) / domain).collect();
131    let psi = gam_to_psi(&gam_01, binsize);
132
133    // Normalize psi to unit sphere
134    let psi_norm = l2_norm_l2(&psi, &time);
135    if psi_norm < 1e-10 {
136        return 0.0;
137    }
138    let psi_unit: Vec<f64> = psi.iter().map(|&p| p / psi_norm).collect();
139
140    // Identity warp psi = constant 1, normalized
141    let id_raw = vec![1.0; m];
142    let id_norm = l2_norm_l2(&id_raw, &time);
143    let id_unit: Vec<f64> = id_raw.iter().map(|&v| v / id_norm).collect();
144
145    // Geodesic distance = arccos(inner product)
146    let ip = inner_product_l2(&psi_unit, &id_unit, &time).clamp(-1.0, 1.0);
147    ip.acos()
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Evenly spaced grid of `m` points from 0 to 1 inclusive.
    fn uniform_grid(m: usize) -> Vec<f64> {
        let mut grid = Vec::with_capacity(m);
        for i in 0..m {
            grid.push(i as f64 / (m - 1) as f64);
        }
        grid
    }

    /// Scale `raw` to unit L2 norm over `time`.
    fn to_unit(raw: &[f64], time: &[f64]) -> Vec<f64> {
        let norm = l2_norm_l2(raw, time);
        raw.iter().map(|&v| v / norm).collect()
    }

    /// γ → ψ → γ should reproduce the identity warp up to discretization error.
    #[test]
    fn test_gam_psi_round_trip() {
        let m = 101;
        let time = uniform_grid(m);
        let h = 1.0 / (m - 1) as f64;

        // The identity warp is γ(t) = t, so the grid itself serves as the warp.
        let psi = gam_to_psi(&time, h);
        let gam_recovered = psi_to_gam(&psi, &time);

        for (j, &expected) in time.iter().enumerate() {
            let got = gam_recovered[j];
            assert!(
                (got - expected).abs() < 0.02,
                "Round trip failed at j={j}: got {}, expected {}",
                got,
                expected
            );
        }
    }

    /// A constant input must come out with pinned endpoints and no decreases.
    #[test]
    fn test_normalize_warp_properties() {
        let n = 20;
        let t = uniform_grid(n);
        let mut gamma = vec![0.1; n];
        normalize_warp(&mut gamma, &t);

        assert_eq!(gamma[0], t[0]);
        assert_eq!(gamma[n - 1], t[n - 1]);
        for pair in gamma.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
    }

    /// The inverse of the identity warp is the identity warp.
    #[test]
    fn test_invert_gamma_identity() {
        let m = 50;
        let time = uniform_grid(m);
        let inv = invert_gamma(&time, &time);
        for (j, &expected) in time.iter().enumerate() {
            assert!(
                (inv[j] - expected).abs() < 1e-12,
                "Inverting identity should give identity at j={j}"
            );
        }
    }

    /// exp(log(ψ₂)) taken at ψ₁ should land back on ψ₂.
    #[test]
    fn test_sphere_round_trip() {
        let m = 21;
        let time = uniform_grid(m);

        // Two unit vectors on the sphere: a constant and a perturbed constant.
        let flat = vec![1.0; m];
        let psi1 = to_unit(&flat, &time);

        let wavy: Vec<f64> = time
            .iter()
            .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
            .collect();
        let psi2 = to_unit(&wavy, &time);

        let tangent = inv_exp_map_sphere(&psi1, &psi2, &time);
        let recovered = exp_map_sphere(&psi1, &tangent, &time);

        let squared_diff: Vec<f64> = psi2
            .iter()
            .zip(recovered.iter())
            .map(|(&a, &b)| (a - b).powi(2))
            .collect();
        let l2_err = trapz(&squared_diff, &time).max(0.0).sqrt();
        assert!(
            l2_err < 1e-12,
            "Sphere round-trip error = {l2_err:.2e}, expected < 1e-12"
        );
    }

    /// The identity warp sits at (numerically) zero distance from itself.
    #[test]
    fn test_phase_distance_identity_zero() {
        let grid = uniform_grid(101);
        let d = phase_distance(&grid, &grid);
        assert!(
            d < 1e-6,
            "Phase distance of identity warp should be ~0, got {d}"
        );
    }

    /// A quadratic warp is measurably away from the identity.
    #[test]
    fn test_phase_distance_nonidentity_positive() {
        let grid = uniform_grid(101);
        let quadratic: Vec<f64> = grid.iter().map(|&ti| ti * ti).collect();
        let d = phase_distance(&quadratic, &grid);
        assert!(
            d > 0.01,
            "Phase distance of non-identity warp should be > 0, got {d}"
        );
    }
}