1use super::dp_alignment_core;
4use super::srsf::{reparameterize_curve, srsf_single};
5use crate::error::FdarError;
6use crate::helpers::simpsons_weights;
7use crate::matrix::FdMatrix;
8use crate::warping::{
9 exp_map_sphere, gam_to_psi, inner_product_l2, inv_exp_map_sphere, l2_norm_l2, normalize_warp,
10 psi_to_gam,
11};
12
13use rand::prelude::*;
14use rand_distr::StandardNormal;
15
/// Tuning parameters for the pairwise Bayesian alignment MCMC sampler.
#[derive(Debug, Clone, PartialEq)]
pub struct BayesianAlignConfig {
    /// Number of posterior warp samples kept after burn-in; must be > 0.
    pub n_samples: usize,
    /// Number of initial iterations discarded before samples are stored.
    pub burn_in: usize,
    /// pCN proposal mixing parameter `beta`; must lie strictly in (0, 1).
    pub step_size: f64,
    /// Variance of the Gaussian innovation in each proposal (its square root
    /// scales the projected noise vector).
    pub proposal_variance: f64,
    /// Seed for the deterministic `StdRng`, making runs reproducible.
    pub seed: u64,
}
32
33impl Default for BayesianAlignConfig {
34 fn default() -> Self {
35 Self {
36 n_samples: 1000,
37 burn_in: 200,
38 step_size: 0.1,
39 proposal_variance: 1.0,
40 seed: 42,
41 }
42 }
43}
44
/// Output of `bayesian_align_pair`: posterior warp samples plus summaries.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct BayesianAlignmentResult {
    /// Stored warping functions, one per row (`n_samples` rows, `m` columns).
    pub posterior_gammas: FdMatrix,
    /// Pointwise mean of the stored warps, re-normalized into a valid warp.
    pub posterior_mean_gamma: Vec<f64>,
    /// Pointwise empirical 2.5% quantile of the warps (95% band, lower edge).
    pub credible_lower: Vec<f64>,
    /// Pointwise empirical 97.5% quantile of the warps (95% band, upper edge).
    pub credible_upper: Vec<f64>,
    /// Fraction of Metropolis-Hastings proposals accepted, burn-in included.
    pub acceptance_rate: f64,
    /// `f2` re-parameterized by the posterior mean warp.
    pub f_aligned_mean: Vec<f64>,
}
62
63fn log_likelihood(q1: &[f64], q2: &[f64], argvals: &[f64], gamma: &[f64], weights: &[f64]) -> f64 {
70 let m = q1.len();
71 let q2_warped = reparameterize_curve(q2, argvals, gamma);
72
73 let mut gamma_dot = vec![0.0; m];
75 gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
76 for j in 1..(m - 1) {
77 gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
78 }
79 gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
80
81 let mut ll = 0.0;
82 for j in 0..m {
83 let q2g = q2_warped[j] * gamma_dot[j].max(0.0).sqrt();
84 let diff = q1[j] - q2g;
85 ll -= 0.5 * weights[j] * diff * diff;
86 }
87 ll
88}
89
90fn project_to_tangent(v: &[f64], psi_base: &[f64], time: &[f64]) -> Vec<f64> {
94 let ip = inner_product_l2(v, psi_base, time);
95 v.iter()
96 .zip(psi_base.iter())
97 .map(|(&vi, &pi)| vi - ip * pi)
98 .collect()
99}
100
101#[must_use = "expensive computation whose result should not be discarded"]
118pub fn bayesian_align_pair(
119 f1: &[f64],
120 f2: &[f64],
121 argvals: &[f64],
122 config: &BayesianAlignConfig,
123) -> Result<BayesianAlignmentResult, FdarError> {
124 let m = f1.len();
125
126 if m != f2.len() || m != argvals.len() {
128 return Err(FdarError::InvalidDimension {
129 parameter: "f1/f2/argvals",
130 expected: format!("all length {m}"),
131 actual: format!("f1={}, f2={}, argvals={}", m, f2.len(), argvals.len()),
132 });
133 }
134 if m < 2 {
135 return Err(FdarError::InvalidDimension {
136 parameter: "f1",
137 expected: "length >= 2".to_string(),
138 actual: format!("length {m}"),
139 });
140 }
141 if config.n_samples == 0 {
142 return Err(FdarError::InvalidParameter {
143 parameter: "n_samples",
144 message: "n_samples must be > 0".to_string(),
145 });
146 }
147 if config.step_size <= 0.0 || config.step_size >= 1.0 {
148 return Err(FdarError::InvalidParameter {
149 parameter: "step_size",
150 message: format!("step_size must be in (0, 1), got {}", config.step_size),
151 });
152 }
153
154 let t0 = argvals[0];
155 let t1 = argvals[m - 1];
156 let domain = t1 - t0;
157 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
158 let binsize = 1.0 / (m - 1) as f64;
159
160 let q1 = srsf_single(f1, argvals);
162 let q2 = srsf_single(f2, argvals);
163
164 let weights = simpsons_weights(argvals);
166
167 let psi_id: Vec<f64> = {
169 let raw = vec![1.0; m];
170 let norm = l2_norm_l2(&raw, &time);
171 raw.iter().map(|&v| v / norm).collect()
172 };
173
174 let gamma_dp = dp_alignment_core(&q1, &q2, argvals, 0.0);
176 let gam_01: Vec<f64> = gamma_dp.iter().map(|&g| (g - t0) / domain).collect();
177 let mut psi_curr = gam_to_psi(&gam_01, binsize);
178 let psi_norm = l2_norm_l2(&psi_curr, &time);
179 if psi_norm > 1e-10 {
180 for v in &mut psi_curr {
181 *v /= psi_norm;
182 }
183 }
184
185 let mut v_curr = inv_exp_map_sphere(&psi_id, &psi_curr, &time);
187 let mut ll_curr = log_likelihood(&q1, &q2, argvals, &gamma_dp, &weights);
188
189 let beta = config.step_size;
190 let sqrt_1_beta2 = (1.0 - beta * beta).sqrt();
191 let total_iter = config.n_samples + config.burn_in;
192
193 let mut rng = StdRng::seed_from_u64(config.seed);
194 let mut stored_gammas: Vec<Vec<f64>> = Vec::with_capacity(config.n_samples);
195 let mut n_accepted = 0usize;
196
197 for iter in 0..total_iter {
198 let xi_raw: Vec<f64> = (0..m)
200 .map(|_| rng.sample::<f64, _>(StandardNormal))
201 .collect();
202 let xi_tangent = project_to_tangent(&xi_raw, &psi_id, &time);
203 let xi_scaled: Vec<f64> = xi_tangent
204 .iter()
205 .map(|&v| v * config.proposal_variance.sqrt())
206 .collect();
207
208 let v_prop: Vec<f64> = v_curr
210 .iter()
211 .zip(xi_scaled.iter())
212 .map(|(&vc, &xi)| sqrt_1_beta2 * vc + beta * xi)
213 .collect();
214
215 let psi_prop = exp_map_sphere(&psi_id, &v_prop, &time);
217
218 let gam_prop_01 = psi_to_gam(&psi_prop, &time);
220 let mut gamma_prop: Vec<f64> = gam_prop_01.iter().map(|&g| t0 + g * domain).collect();
221 normalize_warp(&mut gamma_prop, argvals);
222
223 let ll_prop = log_likelihood(&q1, &q2, argvals, &gamma_prop, &weights);
225
226 let log_alpha = ll_prop - ll_curr;
228 let u: f64 = rng.gen();
229 if u.ln() < log_alpha {
230 psi_curr = psi_prop;
231 v_curr = v_prop;
232 ll_curr = ll_prop;
233 n_accepted += 1;
234
235 if iter >= config.burn_in {
236 stored_gammas.push(gamma_prop);
237 }
238 } else if iter >= config.burn_in {
239 let gam_curr_01 = psi_to_gam(&psi_curr, &time);
241 let mut gamma_curr: Vec<f64> = gam_curr_01.iter().map(|&g| t0 + g * domain).collect();
242 normalize_warp(&mut gamma_curr, argvals);
243 stored_gammas.push(gamma_curr);
244 }
245 }
246
247 let n_stored = stored_gammas.len();
248 let acceptance_rate = n_accepted as f64 / total_iter as f64;
249
250 let mut posterior_gammas = FdMatrix::zeros(n_stored, m);
252 for (i, gam) in stored_gammas.iter().enumerate() {
253 for j in 0..m {
254 posterior_gammas[(i, j)] = gam[j];
255 }
256 }
257
258 let mut posterior_mean_gamma = vec![0.0; m];
260 for j in 0..m {
261 for i in 0..n_stored {
262 posterior_mean_gamma[j] += posterior_gammas[(i, j)];
263 }
264 posterior_mean_gamma[j] /= n_stored as f64;
265 }
266 normalize_warp(&mut posterior_mean_gamma, argvals);
267
268 let mut credible_lower = vec![0.0; m];
270 let mut credible_upper = vec![0.0; m];
271 for j in 0..m {
272 let mut col: Vec<f64> = (0..n_stored).map(|i| posterior_gammas[(i, j)]).collect();
273 col.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
274 let idx_lo = ((0.025 * n_stored as f64).floor() as usize).min(n_stored.saturating_sub(1));
275 let idx_hi = ((0.975 * n_stored as f64).ceil() as usize).min(n_stored.saturating_sub(1));
276 credible_lower[j] = col[idx_lo];
277 credible_upper[j] = col[idx_hi];
278 }
279
280 let f_aligned_mean = reparameterize_curve(f2, argvals, &posterior_mean_gamma);
282
283 Ok(BayesianAlignmentResult {
284 posterior_gammas,
285 posterior_mean_gamma,
286 credible_lower,
287 credible_upper,
288 acceptance_rate,
289 f_aligned_mean,
290 })
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    // Evenly spaced grid on [0, 1] with n points.
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    // With f1 == f2 the posterior should concentrate near the identity warp.
    // The 0.15 tolerance is loose because the chain is short and stochastic
    // (deterministic for the fixed seed, but not tightly concentrated).
    #[test]
    fn bayesian_align_identical_curves() {
        let m = 51;
        let t = uniform_grid(m);
        let f1: Vec<f64> = t.iter().map(|&ti| (2.0 * PI * ti).sin()).collect();
        let f2 = f1.clone();

        let config = BayesianAlignConfig {
            n_samples: 200,
            burn_in: 50,
            step_size: 0.1,
            proposal_variance: 0.5,
            seed: 42,
        };
        let result = bayesian_align_pair(&f1, &f2, &t, &config).unwrap();

        // Posterior mean warp stays close to gamma(t) = t.
        for j in 0..m {
            assert!(
                (result.posterior_mean_gamma[j] - t[j]).abs() < 0.15,
                "posterior mean gamma at j={j} deviates too much from identity: {} vs {}",
                result.posterior_mean_gamma[j],
                t[j]
            );
        }

        // The chain must actually mix — a near-zero acceptance rate would
        // indicate a broken proposal mechanism.
        assert!(
            result.acceptance_rate > 0.05,
            "acceptance rate too low: {}",
            result.acceptance_rate
        );
    }

    // The credible band must bracket the posterior mean pointwise.
    #[test]
    fn bayesian_align_credible_bands_order() {
        let m = 51;
        let t = uniform_grid(m);
        let f1: Vec<f64> = t.iter().map(|&ti| (2.0 * PI * ti).sin()).collect();
        let f2: Vec<f64> = t.iter().map(|&ti| (2.0 * PI * (ti + 0.05)).sin()).collect();

        let config = BayesianAlignConfig {
            n_samples: 200,
            burn_in: 50,
            step_size: 0.15,
            proposal_variance: 0.5,
            seed: 7,
        };
        let result = bayesian_align_pair(&f1, &f2, &t, &config).unwrap();

        // Tolerance 1e-10 covers the mean's post-hoc normalize_warp, which
        // can nudge it marginally outside the raw sample quantiles.
        for j in 0..m {
            assert!(
                result.credible_lower[j] <= result.posterior_mean_gamma[j] + 1e-10,
                "lower > mean at j={j}: {} > {}",
                result.credible_lower[j],
                result.posterior_mean_gamma[j]
            );
            assert!(
                result.posterior_mean_gamma[j] <= result.credible_upper[j] + 1e-10,
                "mean > upper at j={j}: {} > {}",
                result.posterior_mean_gamma[j],
                result.credible_upper[j]
            );
        }
    }

    // Aligning a phase-shifted sine should not increase the L2 mismatch
    // relative to the unaligned pair (the DP initialization alone already
    // guarantees a good starting point).
    #[test]
    fn bayesian_align_shifted_sine() {
        let m = 51;
        let t = uniform_grid(m);
        let f1: Vec<f64> = t.iter().map(|&ti| (2.0 * PI * ti).sin()).collect();
        let shift = 0.1;
        let f2: Vec<f64> = t
            .iter()
            .map(|&ti| (2.0 * PI * (ti + shift)).sin())
            .collect();

        let config = BayesianAlignConfig {
            n_samples: 300,
            burn_in: 100,
            step_size: 0.15,
            proposal_variance: 1.0,
            seed: 99,
        };
        let result = bayesian_align_pair(&f1, &f2, &t, &config).unwrap();

        // Sum-of-squares mismatch before alignment...
        let error_original: f64 = f1
            .iter()
            .zip(f2.iter())
            .map(|(&a, &b)| (a - b).powi(2))
            .sum::<f64>();
        // ...and after warping f2 by the posterior mean gamma.
        let error_aligned: f64 = f1
            .iter()
            .zip(result.f_aligned_mean.iter())
            .map(|(&a, &b)| (a - b).powi(2))
            .sum::<f64>();

        assert!(
            error_aligned < error_original + 1e-6,
            "aligned error ({error_aligned:.4}) should be <= original ({error_original:.4})"
        );
    }

    // Invalid configurations must be rejected up front with Err, not panic.
    #[test]
    fn bayesian_align_rejects_bad_config() {
        let m = 21;
        let t = uniform_grid(m);
        let f1: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
        let f2 = f1.clone();

        // Zero samples: nothing to store.
        let config = BayesianAlignConfig {
            n_samples: 0,
            ..BayesianAlignConfig::default()
        };
        assert!(
            bayesian_align_pair(&f1, &f2, &t, &config).is_err(),
            "should reject n_samples=0"
        );

        // step_size must be strictly inside (0, 1): check both endpoints.
        let config = BayesianAlignConfig {
            step_size: 0.0,
            ..BayesianAlignConfig::default()
        };
        assert!(
            bayesian_align_pair(&f1, &f2, &t, &config).is_err(),
            "should reject step_size=0"
        );

        let config = BayesianAlignConfig {
            step_size: 1.0,
            ..BayesianAlignConfig::default()
        };
        assert!(
            bayesian_align_pair(&f1, &f2, &t, &config).is_err(),
            "should reject step_size=1"
        );
    }
}