1use crate::helpers::{cumulative_trapz, gradient_uniform, linear_interp, trapz};
13use crate::smoothing::nadaraya_watson;
14
/// Project `gamma` in place onto the set of valid warping functions on `argvals`.
///
/// After the call (assuming `argvals` is ascending):
/// * `gamma[0] == argvals[0]` and `gamma[n - 1] == argvals[n - 1]`,
/// * every entry lies in `[argvals[0], argvals[n - 1]]`,
/// * `gamma` is non-decreasing.
///
/// Entries above the right endpoint are clipped to it, and any entry smaller than
/// its (already corrected) predecessor is raised to that predecessor. Because
/// `f64::min` returns the non-NaN operand, a NaN entry ends up at the right endpoint.
///
/// # Panics
///
/// Panics if `argvals` is shorter than `gamma` (it is indexed at `gamma.len() - 1`).
pub fn normalize_warp(gamma: &mut [f64], argvals: &[f64]) {
    let n = gamma.len();
    if n == 0 {
        return;
    }

    let t0 = argvals[0];
    let t1 = argvals[n - 1];

    gamma[0] = t0;
    for i in 1..n {
        // Clip overshoot past the right endpoint *before* enforcing monotonicity.
        // Otherwise one large interior value is carried forward and drags the
        // final sample past `t1`.
        gamma[i] = gamma[i].min(t1).max(gamma[i - 1]);
    }
    // Pin the right endpoint last so the monotone pass cannot move it.
    gamma[n - 1] = t1;
}
33
34pub fn gam_to_psi(gam: &[f64], h: f64) -> Vec<f64> {
36 gradient_uniform(gam, h)
37 .iter()
38 .map(|&g| g.max(0.0).sqrt())
39 .collect()
40}
41
42pub fn gam_to_psi_smooth(gam: &[f64], h: f64) -> Vec<f64> {
52 let m = gam.len();
53 if m < 3 {
54 return gam_to_psi(gam, h);
55 }
56
57 let time: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
58
59 let bandwidth = 2.0 * h;
62 let gam_smooth = nadaraya_watson(&time, gam, &time, bandwidth, "gaussian");
63
64 gradient_uniform(&gam_smooth, h)
65 .iter()
66 .map(|&g| g.max(0.0).sqrt())
67 .collect()
68}
69
70pub fn psi_to_gam(psi: &[f64], time: &[f64]) -> Vec<f64> {
72 let psi_sq: Vec<f64> = psi.iter().map(|&p| p * p).collect();
73 let gam = cumulative_trapz(&psi_sq, time);
74 let min_val = gam.iter().cloned().fold(f64::INFINITY, f64::min);
75 let max_val = gam.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
76 let range = (max_val - min_val).max(1e-10);
77 gam.iter().map(|&v| (v - min_val) / range).collect()
78}
79
80pub fn inner_product_l2(psi1: &[f64], psi2: &[f64], time: &[f64]) -> f64 {
82 let prod: Vec<f64> = psi1.iter().zip(psi2.iter()).map(|(&a, &b)| a * b).collect();
83 trapz(&prod, time)
84}
85
86pub fn l2_norm_l2(psi: &[f64], time: &[f64]) -> f64 {
88 inner_product_l2(psi, psi, time).max(0.0).sqrt()
89}
90
91pub fn inv_exp_map_sphere(mu: &[f64], psi: &[f64], time: &[f64]) -> Vec<f64> {
94 let ip = inner_product_l2(mu, psi, time).clamp(-1.0, 1.0);
95 let theta = ip.acos();
96 if theta < 1e-10 {
97 vec![0.0; mu.len()]
98 } else {
99 let coeff = theta / theta.sin();
100 let cos_theta = theta.cos();
101 mu.iter()
102 .zip(psi.iter())
103 .map(|(&m, &p)| coeff * (p - cos_theta * m))
104 .collect()
105 }
106}
107
108pub fn exp_map_sphere(psi: &[f64], v: &[f64], time: &[f64]) -> Vec<f64> {
111 let v_norm = l2_norm_l2(v, time);
112 if v_norm < 1e-10 {
113 psi.to_vec()
114 } else {
115 let cos_n = v_norm.cos();
116 let sin_n = v_norm.sin();
117 psi.iter()
118 .zip(v.iter())
119 .map(|(&p, &vi)| cos_n * p + sin_n * vi / v_norm)
120 .collect()
121 }
122}
123
124pub fn invert_gamma(gam: &[f64], time: &[f64]) -> Vec<f64> {
127 let n = time.len();
128 let mut gam_inv: Vec<f64> = time.iter().map(|&t| linear_interp(gam, time, t)).collect();
129 gam_inv[0] = time[0];
130 gam_inv[n - 1] = time[n - 1];
131 gam_inv
132}
133
134pub fn phase_distance(gamma: &[f64], argvals: &[f64]) -> f64 {
145 let m = gamma.len();
146 if m < 2 {
147 return 0.0;
148 }
149
150 let t0 = argvals[0];
151 let t1 = argvals[m - 1];
152 let domain = t1 - t0;
153
154 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
156 let binsize = 1.0 / (m - 1) as f64;
157
158 let gam_01: Vec<f64> = (0..m).map(|j| (gamma[j] - t0) / domain).collect();
160 let psi = gam_to_psi(&gam_01, binsize);
161
162 let psi_norm = l2_norm_l2(&psi, &time);
164 if psi_norm < 1e-10 {
165 return 0.0;
166 }
167 let psi_unit: Vec<f64> = psi.iter().map(|&p| p / psi_norm).collect();
168
169 let id_raw = vec![1.0; m];
171 let id_norm = l2_norm_l2(&id_raw, &time);
172 let id_unit: Vec<f64> = id_raw.iter().map(|&v| v / id_norm).collect();
173
174 let ip = inner_product_l2(&psi_unit, &id_unit, &time).clamp(-1.0, 1.0);
176 ip.acos()
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// `m` evenly spaced points spanning `[0, 1]`, endpoints included.
    fn uniform_grid(m: usize) -> Vec<f64> {
        let last = (m - 1) as f64;
        (0..m).map(|i| i as f64 / last).collect()
    }

    /// Divide `raw` by its L2 norm on `time` so it lies on the unit sphere.
    fn to_unit_sphere(raw: &[f64], time: &[f64]) -> Vec<f64> {
        let norm = l2_norm_l2(raw, time);
        raw.iter().map(|&v| v / norm).collect()
    }

    #[test]
    fn test_gam_psi_round_trip() {
        let m = 101;
        let time = uniform_grid(m);
        let h = 1.0 / (m - 1) as f64;

        // The identity warp should survive gam -> psi -> gam up to discretisation error.
        let psi = gam_to_psi(&time, h);
        let gam_recovered = psi_to_gam(&psi, &time);

        for (j, &expected) in time.iter().enumerate() {
            let got = gam_recovered[j];
            assert!(
                (got - expected).abs() < 0.02,
                "Round trip failed at j={j}: got {}, expected {}",
                got,
                expected
            );
        }
    }

    #[test]
    fn test_normalize_warp_properties() {
        let t = uniform_grid(20);
        let mut gamma = vec![0.1; 20];
        normalize_warp(&mut gamma, &t);

        // Endpoints pinned to the grid, values non-decreasing in between.
        assert_eq!(gamma[0], t[0]);
        assert_eq!(gamma[19], t[19]);
        assert!(gamma.windows(2).all(|pair| pair[1] >= pair[0]));
    }

    #[test]
    fn test_invert_gamma_identity() {
        let m = 50;
        let time = uniform_grid(m);
        let inv = invert_gamma(&time, &time);
        for (j, &t) in time.iter().enumerate() {
            assert!(
                (inv[j] - t).abs() < 1e-12,
                "Inverting identity should give identity at j={j}"
            );
        }
    }

    #[test]
    fn test_sphere_round_trip() {
        let m = 21;
        let time = uniform_grid(m);

        let ones = vec![1.0; m];
        let psi1 = to_unit_sphere(&ones, &time);

        let bumped: Vec<f64> = time
            .iter()
            .map(|&t| 1.0 + 0.3 * (2.0 * std::f64::consts::PI * t).sin())
            .collect();
        let psi2 = to_unit_sphere(&bumped, &time);

        // Log map then exp map at the same base point should land back on psi2.
        let v = inv_exp_map_sphere(&psi1, &psi2, &time);
        let recovered = exp_map_sphere(&psi1, &v, &time);

        let sq_err: Vec<f64> = psi2
            .iter()
            .zip(&recovered)
            .map(|(&a, &b)| (a - b).powi(2))
            .collect();
        let l2_err = trapz(&sq_err, &time).max(0.0).sqrt();
        assert!(
            l2_err < 1e-12,
            "Sphere round-trip error = {l2_err:.2e}, expected < 1e-12"
        );
    }

    #[test]
    fn test_phase_distance_identity_zero() {
        let t = uniform_grid(101);
        let d = phase_distance(&t, &t);
        assert!(
            d < 1e-6,
            "Phase distance of identity warp should be ~0, got {d}"
        );
    }

    #[test]
    fn test_phase_distance_nonidentity_positive() {
        let t = uniform_grid(101);
        let gamma: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
        let d = phase_distance(&gamma, &t);
        assert!(
            d > 0.01,
            "Phase distance of non-identity warp should be > 0, got {d}"
        );
    }
}