1use super::set::apply_stored_warps;
4use super::srsf::{reparameterize_curve, srsf_inverse, srsf_transform};
5use super::{dp_alignment_core, KarcherMeanResult};
6use crate::fdata::mean_1d;
7use crate::helpers::{gradient_uniform, linear_interp};
8use crate::iter_maybe_parallel;
9use crate::matrix::FdMatrix;
10use crate::warping::{
11 exp_map_sphere, gam_to_psi, inv_exp_map_sphere, invert_gamma, l2_norm_l2, psi_to_gam,
12};
13#[cfg(feature = "parallel")]
14use rayon::iter::ParallelIterator;
15
16use super::srsf::srsf_single;
18
19fn karcher_sphere_step(mu: &mut Vec<f64>, psis: &[Vec<f64>], time: &[f64], step_size: f64) -> bool {
25 let m = mu.len();
26 let n = psis.len();
27 let mut vbar = vec![0.0; m];
28 for psi in psis {
29 let v = inv_exp_map_sphere(mu, psi, time);
30 for j in 0..m {
31 vbar[j] += v[j];
32 }
33 }
34 for j in 0..m {
35 vbar[j] /= n as f64;
36 }
37 if l2_norm_l2(&vbar, time) <= 1e-8 {
38 return true;
39 }
40 let scaled: Vec<f64> = vbar.iter().map(|&v| v * step_size).collect();
41 *mu = exp_map_sphere(mu, &scaled, time);
42 false
43}
44
45pub(crate) fn sqrt_mean_inverse(gammas: &FdMatrix, argvals: &[f64]) -> Vec<f64> {
48 let (n, m) = gammas.shape();
49 let t0 = argvals[0];
50 let t1 = argvals[m - 1];
51 let domain = t1 - t0;
52
53 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
54 let binsize = 1.0 / (m - 1) as f64;
55
56 let psis: Vec<Vec<f64>> = (0..n)
57 .map(|i| {
58 let gam_01: Vec<f64> = (0..m).map(|j| (gammas[(i, j)] - t0) / domain).collect();
59 gam_to_psi(&gam_01, binsize)
60 })
61 .collect();
62
63 let mut mu = vec![0.0; m];
64 for psi in &psis {
65 for j in 0..m {
66 mu[j] += psi[j];
67 }
68 }
69 for j in 0..m {
70 mu[j] /= n as f64;
71 }
72
73 for _ in 0..501 {
74 if karcher_sphere_step(&mut mu, &psis, &time, 0.3) {
75 break;
76 }
77 }
78
79 let gam_mu = psi_to_gam(&mu, &time);
80 let gam_inv = invert_gamma(&gam_mu, &time);
81 gam_inv.iter().map(|&g| t0 + g * domain).collect()
82}
83
/// Relative L2 change between two successive template estimates:
/// `||q_old - q_new|| / max(||q_old||, 1e-10)`.
///
/// The difference is taken over the overlapping prefix of the two slices.
/// The denominator is clamped so an all-zero `q_old` does not divide by zero.
fn relative_change(q_old: &[f64], q_new: &[f64]) -> f64 {
    let mut diff_sq = 0.0_f64;
    for (a, b) in q_old.iter().zip(q_new) {
        let d = a - b;
        diff_sq += d * d;
    }

    let mut old_sq = 0.0_f64;
    for v in q_old {
        old_sq += v * v;
    }

    diff_sq.sqrt() / old_sq.sqrt().max(1e-10)
}
98
99pub(super) fn align_srsf_pair(
101 q1: &[f64],
102 q2: &[f64],
103 argvals: &[f64],
104 lambda: f64,
105) -> (Vec<f64>, Vec<f64>) {
106 let gamma = dp_alignment_core(q1, q2, argvals, lambda);
107
108 let q2_warped = reparameterize_curve(q2, argvals, &gamma);
110
111 let m = gamma.len();
113 let mut gamma_dot = vec![0.0; m];
114 gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
115 for j in 1..(m - 1) {
116 gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
117 }
118 gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
119
120 let q2_aligned: Vec<f64> = q2_warped
122 .iter()
123 .zip(gamma_dot.iter())
124 .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
125 .collect();
126
127 (gamma, q2_aligned)
128}
129
130fn accumulate_alignments(
132 results: &[(Vec<f64>, Vec<f64>)],
133 gammas: &mut FdMatrix,
134 m: usize,
135 n: usize,
136) -> Vec<f64> {
137 let mut mu_q_new = vec![0.0; m];
138 for (i, (gamma, q_aligned)) in results.iter().enumerate() {
139 for j in 0..m {
140 gammas[(i, j)] = gamma[j];
141 mu_q_new[j] += q_aligned[j];
142 }
143 }
144 for j in 0..m {
145 mu_q_new[j] /= n as f64;
146 }
147 mu_q_new
148}
149
150fn select_template(srsf_mat: &FdMatrix, data: &FdMatrix, argvals: &[f64]) -> (Vec<f64>, Vec<f64>) {
152 let (n, m) = srsf_mat.shape();
153 let mnq = mean_1d(srsf_mat);
154 let mut min_dist = f64::INFINITY;
155 let mut min_idx = 0;
156 for i in 0..n {
157 let dist_sq: f64 = (0..m).map(|j| (srsf_mat[(i, j)] - mnq[j]).powi(2)).sum();
158 if dist_sq < min_dist {
159 min_dist = dist_sq;
160 min_idx = i;
161 }
162 }
163 let _ = argvals; (srsf_mat.row(min_idx), data.row(min_idx))
165}
166
167fn pre_center_template(
169 data: &FdMatrix,
170 mu_q: &[f64],
171 mu: &[f64],
172 argvals: &[f64],
173 lambda: f64,
174) -> (Vec<f64>, Vec<f64>) {
175 let (n, m) = data.shape();
176 let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
177 .map(|i| {
178 let fi = data.row(i);
179 let qi = srsf_single(&fi, argvals);
180 align_srsf_pair(mu_q, &qi, argvals, lambda)
181 })
182 .collect();
183
184 let mut init_gammas = FdMatrix::zeros(n, m);
185 for (i, (gamma, _)) in align_results.iter().enumerate() {
186 for j in 0..m {
187 init_gammas[(i, j)] = gamma[j];
188 }
189 }
190
191 let gam_inv = sqrt_mean_inverse(&init_gammas, argvals);
192 let mu_new = reparameterize_curve(mu, argvals, &gam_inv);
193 let mu_q_new = srsf_single(&mu_new, argvals);
194 (mu_q_new, mu_new)
195}
196
197fn post_center_results(
199 data: &FdMatrix,
200 mu_q: &[f64],
201 final_gammas: &mut FdMatrix,
202 argvals: &[f64],
203) -> (Vec<f64>, Vec<f64>, FdMatrix) {
204 let (n, m) = data.shape();
205 let gam_inv = sqrt_mean_inverse(final_gammas, argvals);
206 let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
207 let gam_inv_dev = gradient_uniform(&gam_inv, h);
208
209 let mu_q_warped = reparameterize_curve(mu_q, argvals, &gam_inv);
210 let mu_q_centered: Vec<f64> = mu_q_warped
211 .iter()
212 .zip(gam_inv_dev.iter())
213 .map(|(&q, &gd)| q * gd.max(0.0).sqrt())
214 .collect();
215
216 for i in 0..n {
217 let gam_i: Vec<f64> = (0..m).map(|j| final_gammas[(i, j)]).collect();
218 let gam_centered = reparameterize_curve(&gam_i, argvals, &gam_inv);
219 for j in 0..m {
220 final_gammas[(i, j)] = gam_centered[j];
221 }
222 }
223
224 let initial_mean = mean_1d(data);
225 let mu = srsf_inverse(&mu_q_centered, argvals, initial_mean[0]);
226 let final_aligned = apply_stored_warps(data, final_gammas, argvals);
227 (mu, mu_q_centered, final_aligned)
228}
229
/// Keep every `factor`-th sample of `signal` and the matching `argvals`.
///
/// The final sample is always retained, so the domain end point survives.
/// When `factor <= 1` or there are at most two samples, both inputs are
/// returned unchanged.
fn downsample_uniform(signal: &[f64], argvals: &[f64], factor: usize) -> (Vec<f64>, Vec<f64>) {
    let len = signal.len();
    if len <= 2 || factor <= 1 {
        return (signal.to_vec(), argvals.to_vec());
    }

    let mut picks: Vec<usize> = (0..len).step_by(factor).collect();
    // The stride may stop short of the last index; append it in that case.
    if picks.last() != Some(&(len - 1)) {
        picks.push(len - 1);
    }

    let sig: Vec<f64> = picks.iter().map(|&i| signal[i]).collect();
    let arg: Vec<f64> = picks.iter().map(|&i| argvals[i]).collect();
    (sig, arg)
}
249
250fn upsample_to_fine(coarse: &[f64], argvals_coarse: &[f64], argvals_fine: &[f64]) -> Vec<f64> {
252 argvals_fine
253 .iter()
254 .map(|&t| linear_interp(argvals_coarse, coarse, t))
255 .collect()
256}
257
258#[must_use = "expensive computation whose result should not be discarded"]
288pub fn karcher_mean(
289 data: &FdMatrix,
290 argvals: &[f64],
291 max_iter: usize,
292 tol: f64,
293 lambda: f64,
294) -> KarcherMeanResult {
295 let (n, m) = data.shape();
296
297 let srsf_mat = srsf_transform(data, argvals);
298 let (mut mu_q, mu) = select_template(&srsf_mat, data, argvals);
299 let (mu_q_c, mu_c) = pre_center_template(data, &mu_q, &mu, argvals, lambda);
300 mu_q = mu_q_c;
301 let mut mu = mu_c;
302
303 let mut converged = false;
304 let mut n_iter = 0;
305 let mut final_gammas = FdMatrix::zeros(n, m);
306
307 let coarse_factor = if m > 50 && max_iter >= 10 { 4 } else { 1 };
310 let coarse_iters = if coarse_factor > 1 { max_iter / 2 } else { 0 };
311 let fine_iters = max_iter - coarse_iters;
312
313 if coarse_iters > 0 {
315 let (mu_q_coarse, argvals_coarse) = downsample_uniform(&mu_q, argvals, coarse_factor);
316 let m_c = argvals_coarse.len();
317 let mut mu_q_c = mu_q_coarse;
318
319 let data_coarse: Vec<Vec<f64>> = (0..n)
321 .map(|i| {
322 let row = data.row(i);
323 downsample_uniform(&row, argvals, coarse_factor).0
324 })
325 .collect();
326
327 let mut coarse_gammas = FdMatrix::zeros(n, m_c);
328
329 for iter in 0..coarse_iters {
330 n_iter = iter + 1;
331
332 let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
333 .map(|i| {
334 let qi = srsf_single(&data_coarse[i], &argvals_coarse);
335 align_srsf_pair(&mu_q_c, &qi, &argvals_coarse, lambda)
336 })
337 .collect();
338
339 let mu_q_new = accumulate_alignments(&align_results, &mut coarse_gammas, m_c, n);
340
341 let rel = relative_change(&mu_q_c, &mu_q_new);
342 if rel < tol {
343 converged = true;
344 mu_q_c = mu_q_new;
345 break;
346 }
347
348 mu_q_c = mu_q_new;
349 }
350
351 mu_q = upsample_to_fine(&mu_q_c, &argvals_coarse, argvals);
353 mu = srsf_inverse(&mu_q, argvals, mu[0]);
354 }
355
356 if fine_iters > 0 {
358 converged = false; }
360 let fine_start = n_iter;
361 for iter in 0..fine_iters {
362 n_iter = fine_start + iter + 1;
363
364 let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
365 .map(|i| {
366 let fi = data.row(i);
367 let qi = srsf_single(&fi, argvals);
368 align_srsf_pair(&mu_q, &qi, argvals, lambda)
369 })
370 .collect();
371
372 let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
373
374 let rel = relative_change(&mu_q, &mu_q_new);
375 if rel < tol {
376 converged = true;
377 mu_q = mu_q_new;
378 break;
379 }
380
381 mu_q = mu_q_new;
382 mu = srsf_inverse(&mu_q, argvals, mu[0]);
383 }
384
385 if converged && fine_start > 0 {
387 let align_results: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
388 .map(|i| {
389 let fi = data.row(i);
390 let qi = srsf_single(&fi, argvals);
391 align_srsf_pair(&mu_q, &qi, argvals, lambda)
392 })
393 .collect();
394 let mu_q_new = accumulate_alignments(&align_results, &mut final_gammas, m, n);
395 mu_q = mu_q_new;
396 }
397
398 let (mu_final, mu_q_final, final_aligned) =
399 post_center_results(data, &mu_q, &mut final_gammas, argvals);
400
401 KarcherMeanResult {
402 mean: mu_final,
403 mean_srsf: mu_q_final,
404 gammas: final_gammas,
405 aligned_data: final_aligned,
406 n_iter,
407 converged,
408 aligned_srsfs: None,
409 }
410}