Skip to main content

fdars_core/alignment/
bayesian.rs

1//! Bayesian pairwise alignment via pCN MCMC on the Hilbert sphere.
2
3use super::dp_alignment_core;
4use super::srsf::{reparameterize_curve, srsf_single};
5use crate::error::FdarError;
6use crate::helpers::simpsons_weights;
7use crate::matrix::FdMatrix;
8use crate::warping::{
9    exp_map_sphere, gam_to_psi, inner_product_l2, inv_exp_map_sphere, l2_norm_l2, normalize_warp,
10    psi_to_gam,
11};
12
13use rand::prelude::*;
14use rand_distr::StandardNormal;
15
16// ─── Config / Result ─────────────────────────────────────────────────────────
17
/// Tuning knobs for the pCN MCMC sampler used by Bayesian pairwise alignment.
#[derive(Clone, Debug, PartialEq)]
pub struct BayesianAlignConfig {
    /// How many post-burn-in draws of the warping function are kept.
    pub n_samples: usize,
    /// How many initial iterations are run and thrown away before storing draws.
    pub burn_in: usize,
    /// The pCN mixing weight `beta`; must lie strictly between 0 and 1.
    pub step_size: f64,
    /// Variance of the Gaussian noise injected into each tangent-space proposal
    /// (the noise is multiplied by the square root of this value).
    pub proposal_variance: f64,
    /// Seed for the random number generator, so runs can be reproduced.
    pub seed: u64,
}
32
33impl Default for BayesianAlignConfig {
34    fn default() -> Self {
35        Self {
36            n_samples: 1000,
37            burn_in: 200,
38            step_size: 0.1,
39            proposal_variance: 1.0,
40            seed: 42,
41        }
42    }
43}
44
45/// Result of Bayesian pairwise alignment.
46#[derive(Debug, Clone, PartialEq)]
47#[non_exhaustive]
48pub struct BayesianAlignmentResult {
49    /// Posterior warping function samples (n_samples x m), after burn-in.
50    pub posterior_gammas: FdMatrix,
51    /// Pointwise posterior mean warping function (length m).
52    pub posterior_mean_gamma: Vec<f64>,
53    /// Pointwise 2.5% credible band (length m).
54    pub credible_lower: Vec<f64>,
55    /// Pointwise 97.5% credible band (length m).
56    pub credible_upper: Vec<f64>,
57    /// MCMC acceptance rate.
58    pub acceptance_rate: f64,
59    /// f2 aligned to f1 using the posterior mean warping function.
60    pub f_aligned_mean: Vec<f64>,
61}
62
63// ─── Bayesian Alignment ─────────────────────────────────────────────────────
64
65/// Compute the SRSF-based log-likelihood for a warping function.
66///
67/// `log_lik = -0.5 * sum_j(w[j] * (q1[j] - q2_gamma[j])^2)`
68/// where q2_gamma is the SRSF of f2 composed with gamma (with sqrt(gamma') factor).
69fn log_likelihood(q1: &[f64], q2: &[f64], argvals: &[f64], gamma: &[f64], weights: &[f64]) -> f64 {
70    let m = q1.len();
71    let q2_warped = reparameterize_curve(q2, argvals, gamma);
72
73    // Compute gamma' via finite differences
74    let mut gamma_dot = vec![0.0; m];
75    gamma_dot[0] = (gamma[1] - gamma[0]) / (argvals[1] - argvals[0]);
76    for j in 1..(m - 1) {
77        gamma_dot[j] = (gamma[j + 1] - gamma[j - 1]) / (argvals[j + 1] - argvals[j - 1]);
78    }
79    gamma_dot[m - 1] = (gamma[m - 1] - gamma[m - 2]) / (argvals[m - 1] - argvals[m - 2]);
80
81    let mut ll = 0.0;
82    for j in 0..m {
83        let q2g = q2_warped[j] * gamma_dot[j].max(0.0).sqrt();
84        let diff = q1[j] - q2g;
85        ll -= 0.5 * weights[j] * diff * diff;
86    }
87    ll
88}
89
90/// Project a vector onto the tangent plane at a point on the sphere.
91///
92/// Removes the component along `psi_base`: `v - <v, psi_base> * psi_base`
93fn project_to_tangent(v: &[f64], psi_base: &[f64], time: &[f64]) -> Vec<f64> {
94    let ip = inner_product_l2(v, psi_base, time);
95    v.iter()
96        .zip(psi_base.iter())
97        .map(|(&vi, &pi)| vi - ip * pi)
98        .collect()
99}
100
101/// Perform Bayesian pairwise alignment of f2 to f1 via pCN MCMC on the
102/// Hilbert sphere.
103///
104/// Uses a preconditioned Crank-Nicolson (pCN) proposal in the tangent space
105/// of the identity warping function on the Hilbert sphere. The DP-optimal
106/// alignment serves as initialization.
107///
108/// # Arguments
109/// * `f1` — Target curve (length m)
110/// * `f2` — Curve to align (length m)
111/// * `argvals` — Evaluation points (length m)
112/// * `config` — MCMC configuration
113///
114/// # Errors
115/// Returns `FdarError::InvalidDimension` if lengths don't match or m < 2.
116/// Returns `FdarError::InvalidParameter` if config values are out of range.
117#[must_use = "expensive computation whose result should not be discarded"]
118pub fn bayesian_align_pair(
119    f1: &[f64],
120    f2: &[f64],
121    argvals: &[f64],
122    config: &BayesianAlignConfig,
123) -> Result<BayesianAlignmentResult, FdarError> {
124    let m = f1.len();
125
126    // ── Validation ──────────────────────────────────────────────────────
127    if m != f2.len() || m != argvals.len() {
128        return Err(FdarError::InvalidDimension {
129            parameter: "f1/f2/argvals",
130            expected: format!("all length {m}"),
131            actual: format!("f1={}, f2={}, argvals={}", m, f2.len(), argvals.len()),
132        });
133    }
134    if m < 2 {
135        return Err(FdarError::InvalidDimension {
136            parameter: "f1",
137            expected: "length >= 2".to_string(),
138            actual: format!("length {m}"),
139        });
140    }
141    if config.n_samples == 0 {
142        return Err(FdarError::InvalidParameter {
143            parameter: "n_samples",
144            message: "n_samples must be > 0".to_string(),
145        });
146    }
147    if config.step_size <= 0.0 || config.step_size >= 1.0 {
148        return Err(FdarError::InvalidParameter {
149            parameter: "step_size",
150            message: format!("step_size must be in (0, 1), got {}", config.step_size),
151        });
152    }
153
154    let t0 = argvals[0];
155    let t1 = argvals[m - 1];
156    let domain = t1 - t0;
157    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
158    let binsize = 1.0 / (m - 1) as f64;
159
160    // Compute SRSFs
161    let q1 = srsf_single(f1, argvals);
162    let q2 = srsf_single(f2, argvals);
163
164    // Simpson's weights for log-likelihood
165    let weights = simpsons_weights(argvals);
166
167    // Identity warp psi on sphere: constant 1, normalized
168    let psi_id: Vec<f64> = {
169        let raw = vec![1.0; m];
170        let norm = l2_norm_l2(&raw, &time);
171        raw.iter().map(|&v| v / norm).collect()
172    };
173
174    // DP initialization
175    let gamma_dp = dp_alignment_core(&q1, &q2, argvals, 0.0);
176    let gam_01: Vec<f64> = gamma_dp.iter().map(|&g| (g - t0) / domain).collect();
177    let mut psi_curr = gam_to_psi(&gam_01, binsize);
178    let psi_norm = l2_norm_l2(&psi_curr, &time);
179    if psi_norm > 1e-10 {
180        for v in &mut psi_curr {
181            *v /= psi_norm;
182        }
183    }
184
185    // Current tangent vector and log-likelihood
186    let mut v_curr = inv_exp_map_sphere(&psi_id, &psi_curr, &time);
187    let mut ll_curr = log_likelihood(&q1, &q2, argvals, &gamma_dp, &weights);
188
189    let beta = config.step_size;
190    let sqrt_1_beta2 = (1.0 - beta * beta).sqrt();
191    let total_iter = config.n_samples + config.burn_in;
192
193    let mut rng = StdRng::seed_from_u64(config.seed);
194    let mut stored_gammas: Vec<Vec<f64>> = Vec::with_capacity(config.n_samples);
195    let mut n_accepted = 0usize;
196
197    for iter in 0..total_iter {
198        // Generate random tangent vector at identity
199        let xi_raw: Vec<f64> = (0..m)
200            .map(|_| rng.sample::<f64, _>(StandardNormal))
201            .collect();
202        let xi_tangent = project_to_tangent(&xi_raw, &psi_id, &time);
203        let xi_scaled: Vec<f64> = xi_tangent
204            .iter()
205            .map(|&v| v * config.proposal_variance.sqrt())
206            .collect();
207
208        // pCN proposal: v_prop = sqrt(1 - beta^2) * v_curr + beta * xi
209        let v_prop: Vec<f64> = v_curr
210            .iter()
211            .zip(xi_scaled.iter())
212            .map(|(&vc, &xi)| sqrt_1_beta2 * vc + beta * xi)
213            .collect();
214
215        // Map to sphere
216        let psi_prop = exp_map_sphere(&psi_id, &v_prop, &time);
217
218        // Convert to gamma
219        let gam_prop_01 = psi_to_gam(&psi_prop, &time);
220        let mut gamma_prop: Vec<f64> = gam_prop_01.iter().map(|&g| t0 + g * domain).collect();
221        normalize_warp(&mut gamma_prop, argvals);
222
223        // Log-likelihood of proposal
224        let ll_prop = log_likelihood(&q1, &q2, argvals, &gamma_prop, &weights);
225
226        // Accept/reject
227        let log_alpha = ll_prop - ll_curr;
228        let u: f64 = rng.gen();
229        if u.ln() < log_alpha {
230            psi_curr = psi_prop;
231            v_curr = v_prop;
232            ll_curr = ll_prop;
233            n_accepted += 1;
234
235            if iter >= config.burn_in {
236                stored_gammas.push(gamma_prop);
237            }
238        } else if iter >= config.burn_in {
239            // Store current (rejected proposal keeps previous)
240            let gam_curr_01 = psi_to_gam(&psi_curr, &time);
241            let mut gamma_curr: Vec<f64> = gam_curr_01.iter().map(|&g| t0 + g * domain).collect();
242            normalize_warp(&mut gamma_curr, argvals);
243            stored_gammas.push(gamma_curr);
244        }
245    }
246
247    let n_stored = stored_gammas.len();
248    let acceptance_rate = n_accepted as f64 / total_iter as f64;
249
250    // Build posterior gamma matrix
251    let mut posterior_gammas = FdMatrix::zeros(n_stored, m);
252    for (i, gam) in stored_gammas.iter().enumerate() {
253        for j in 0..m {
254            posterior_gammas[(i, j)] = gam[j];
255        }
256    }
257
258    // Pointwise posterior mean
259    let mut posterior_mean_gamma = vec![0.0; m];
260    for j in 0..m {
261        for i in 0..n_stored {
262            posterior_mean_gamma[j] += posterior_gammas[(i, j)];
263        }
264        posterior_mean_gamma[j] /= n_stored as f64;
265    }
266    normalize_warp(&mut posterior_mean_gamma, argvals);
267
268    // Pointwise credible bands (2.5% and 97.5% quantiles)
269    let mut credible_lower = vec![0.0; m];
270    let mut credible_upper = vec![0.0; m];
271    for j in 0..m {
272        let mut col: Vec<f64> = (0..n_stored).map(|i| posterior_gammas[(i, j)]).collect();
273        col.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
274        let idx_lo = ((0.025 * n_stored as f64).floor() as usize).min(n_stored.saturating_sub(1));
275        let idx_hi = ((0.975 * n_stored as f64).ceil() as usize).min(n_stored.saturating_sub(1));
276        credible_lower[j] = col[idx_lo];
277        credible_upper[j] = col[idx_hi];
278    }
279
280    // Align f2 using posterior mean gamma
281    let f_aligned_mean = reparameterize_curve(f2, argvals, &posterior_mean_gamma);
282
283    Ok(BayesianAlignmentResult {
284        posterior_gammas,
285        posterior_mean_gamma,
286        credible_lower,
287        credible_upper,
288        acceptance_rate,
289        f_aligned_mean,
290    })
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `n` evenly spaced points covering [0, 1].
    fn uniform_grid(n: usize) -> Vec<f64> {
        let last = (n - 1) as f64;
        (0..n).map(|i| i as f64 / last).collect()
    }

    /// One period of a sine wave sampled on `grid`, shifted left by `shift`.
    fn sine_wave(grid: &[f64], shift: f64) -> Vec<f64> {
        grid.iter()
            .map(|&ti| (2.0 * PI * (ti + shift)).sin())
            .collect()
    }

    /// Sum of squared pointwise differences between two curves.
    fn squared_error(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| (x - y).powi(2))
            .sum::<f64>()
    }

    #[test]
    fn bayesian_align_identical_curves() {
        let t = uniform_grid(51);
        let f1 = sine_wave(&t, 0.0);
        let f2 = f1.clone();

        let config = BayesianAlignConfig {
            n_samples: 200,
            burn_in: 50,
            step_size: 0.1,
            proposal_variance: 0.5,
            seed: 42,
        };
        let result = bayesian_align_pair(&f1, &f2, &t, &config).unwrap();

        // Aligning a curve to itself should leave the warp near the identity.
        for (j, &tj) in t.iter().enumerate() {
            let gamma_j = result.posterior_mean_gamma[j];
            assert!(
                (gamma_j - tj).abs() < 0.15,
                "posterior mean gamma at j={j} deviates too much from identity: {} vs {}",
                gamma_j,
                tj
            );
        }

        // The chain must actually move.
        assert!(
            result.acceptance_rate > 0.05,
            "acceptance rate too low: {}",
            result.acceptance_rate
        );
    }

    #[test]
    fn bayesian_align_credible_bands_order() {
        let m = 51;
        let t = uniform_grid(m);
        let f1 = sine_wave(&t, 0.0);
        let f2 = sine_wave(&t, 0.05);

        let config = BayesianAlignConfig {
            n_samples: 200,
            burn_in: 50,
            step_size: 0.15,
            proposal_variance: 0.5,
            seed: 7,
        };
        let result = bayesian_align_pair(&f1, &f2, &t, &config).unwrap();

        // The posterior mean must sit inside the credible band at every point.
        for j in 0..m {
            let lower = result.credible_lower[j];
            let mean = result.posterior_mean_gamma[j];
            let upper = result.credible_upper[j];
            assert!(
                lower <= mean + 1e-10,
                "lower > mean at j={j}: {} > {}",
                lower,
                mean
            );
            assert!(
                mean <= upper + 1e-10,
                "mean > upper at j={j}: {} > {}",
                mean,
                upper
            );
        }
    }

    #[test]
    fn bayesian_align_shifted_sine() {
        let t = uniform_grid(51);
        let f1 = sine_wave(&t, 0.0);
        let shift = 0.1;
        let f2 = sine_wave(&t, shift);

        let config = BayesianAlignConfig {
            n_samples: 300,
            burn_in: 100,
            step_size: 0.15,
            proposal_variance: 1.0,
            seed: 99,
        };
        let result = bayesian_align_pair(&f1, &f2, &t, &config).unwrap();

        // Alignment should not make the fit to f1 worse than the raw f2.
        let error_original = squared_error(&f1, &f2);
        let error_aligned = squared_error(&f1, &result.f_aligned_mean);

        assert!(
            error_aligned < error_original + 1e-6,
            "aligned error ({error_aligned:.4}) should be <= original ({error_original:.4})"
        );
    }

    #[test]
    fn bayesian_align_rejects_bad_config() {
        let t = uniform_grid(21);
        let f1: Vec<f64> = t.iter().map(|&ti| ti * ti).collect();
        let f2 = f1.clone();

        let bad_configs = [
            (
                BayesianAlignConfig {
                    n_samples: 0,
                    ..BayesianAlignConfig::default()
                },
                "should reject n_samples=0",
            ),
            (
                BayesianAlignConfig {
                    step_size: 0.0,
                    ..BayesianAlignConfig::default()
                },
                "should reject step_size=0",
            ),
            (
                BayesianAlignConfig {
                    step_size: 1.0,
                    ..BayesianAlignConfig::default()
                },
                "should reject step_size=1",
            ),
        ];

        for (config, reason) in &bad_configs {
            assert!(
                bayesian_align_pair(&f1, &f2, &t, config).is_err(),
                "{reason}"
            );
        }
    }
}