Skip to main content

fdars_core/alignment/
shape_ci.rs

1//! Bootstrap confidence intervals for curve shapes in the elastic metric.
2
3use rand::rngs::StdRng;
4use rand::Rng;
5use rand::SeedableRng;
6
7use super::karcher::karcher_mean;
8use super::pairwise::elastic_align_pair;
9use crate::error::FdarError;
10use crate::iter_maybe_parallel;
11use crate::matrix::FdMatrix;
12#[cfg(feature = "parallel")]
13use rayon::iter::ParallelIterator;
14
15// ─── Types ──────────────────────────────────────────────────────────────────
16
/// Settings controlling [`shape_confidence_interval`].
///
/// Start from [`ShapeCiConfig::default`] and override individual fields
/// with struct-update syntax.
#[derive(Clone, Debug, PartialEq)]
pub struct ShapeCiConfig {
    /// How many bootstrap replicates to draw; must be at least 1.
    pub n_bootstrap: usize,
    /// Nominal coverage of the pointwise band, strictly between 0 and 1
    /// (e.g., `0.95` for a 95% band).
    pub confidence_level: f64,
    /// Penalty weight forwarded to both the Karcher-mean computation and the
    /// pairwise elastic alignment.
    pub lambda: f64,
    /// Iteration cap passed to every Karcher-mean computation.
    pub max_iter: usize,
    /// Convergence tolerance passed to every Karcher-mean computation.
    pub tol: f64,
    /// Base random seed; bootstrap replicate `b` seeds its own generator
    /// from this value offset by `b`.
    pub seed: u64,
}
33
34impl Default for ShapeCiConfig {
35    fn default() -> Self {
36        Self {
37            n_bootstrap: 200,
38            confidence_level: 0.95,
39            lambda: 0.0,
40            max_iter: 15,
41            tol: 1e-3,
42            seed: 42,
43        }
44    }
45}
46
47/// Result of shape bootstrap confidence interval computation.
48#[derive(Debug, Clone, PartialEq)]
49#[non_exhaustive]
50pub struct ShapeCiResult {
51    /// Karcher mean of the full sample.
52    pub mean: Vec<f64>,
53    /// Pointwise lower confidence band (length m).
54    pub lower_band: Vec<f64>,
55    /// Pointwise upper confidence band (length m).
56    pub upper_band: Vec<f64>,
57    /// Bootstrap Karcher means (n_bootstrap x m).
58    pub bootstrap_means: FdMatrix,
59}
60
61// ─── Public API ─────────────────────────────────────────────────────────────
62
63/// Compute bootstrap confidence intervals for the elastic Karcher mean.
64///
65/// Resamples the input curves with replacement, computes the Karcher mean
66/// of each bootstrap sample, aligns each bootstrap mean to the full-sample
67/// mean, and derives pointwise confidence bands from the empirical quantiles.
68///
69/// # Arguments
70/// * `data`    - Functional data matrix (n x m).
71/// * `argvals` - Evaluation points (length m).
72/// * `config`  - Bootstrap configuration.
73///
74/// # Errors
75/// Returns [`FdarError::InvalidDimension`] if `n < 3` or `argvals` length
76/// does not match `m`.
77/// Returns [`FdarError::InvalidParameter`] if `confidence_level` is not in
78/// `(0, 1)` or `n_bootstrap < 1`.
79#[must_use = "expensive computation whose result should not be discarded"]
80pub fn shape_confidence_interval(
81    data: &FdMatrix,
82    argvals: &[f64],
83    config: &ShapeCiConfig,
84) -> Result<ShapeCiResult, FdarError> {
85    let (n, m) = data.shape();
86
87    // ── Validation ──
88    if argvals.len() != m {
89        return Err(FdarError::InvalidDimension {
90            parameter: "argvals",
91            expected: format!("{m}"),
92            actual: format!("{}", argvals.len()),
93        });
94    }
95    if n < 3 {
96        return Err(FdarError::InvalidDimension {
97            parameter: "data",
98            expected: "at least 3 rows".to_string(),
99            actual: format!("{n} rows"),
100        });
101    }
102    if config.confidence_level <= 0.0 || config.confidence_level >= 1.0 {
103        return Err(FdarError::InvalidParameter {
104            parameter: "confidence_level",
105            message: format!("must be in (0, 1), got {}", config.confidence_level),
106        });
107    }
108    if config.n_bootstrap < 1 {
109        return Err(FdarError::InvalidParameter {
110            parameter: "n_bootstrap",
111            message: format!("must be >= 1, got {}", config.n_bootstrap),
112        });
113    }
114
115    // ── Full-sample Karcher mean ──
116    let full_karcher = karcher_mean(data, argvals, config.max_iter, config.tol, config.lambda);
117
118    // ── Bootstrap loop ──
119    let boot_means: Vec<Vec<f64>> = iter_maybe_parallel!(0..config.n_bootstrap)
120        .map(|b| {
121            let mut rng = StdRng::seed_from_u64(config.seed + b as u64);
122
123            // Resample n indices with replacement
124            let indices: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
125
126            // Build bootstrap matrix
127            let mut boot_data = FdMatrix::zeros(n, m);
128            for (row, &idx) in indices.iter().enumerate() {
129                for j in 0..m {
130                    boot_data[(row, j)] = data[(idx, j)];
131                }
132            }
133
134            // Compute bootstrap Karcher mean
135            let boot_karcher = karcher_mean(
136                &boot_data,
137                argvals,
138                config.max_iter,
139                config.tol,
140                config.lambda,
141            );
142
143            // Align bootstrap mean to full-sample mean
144            let aligned = elastic_align_pair(
145                &full_karcher.mean,
146                &boot_karcher.mean,
147                argvals,
148                config.lambda,
149            );
150
151            aligned.f_aligned
152        })
153        .collect();
154
155    // ── Build bootstrap_means matrix ──
156    let mut bootstrap_means = FdMatrix::zeros(config.n_bootstrap, m);
157    for (b, bm) in boot_means.iter().enumerate() {
158        for j in 0..m {
159            bootstrap_means[(b, j)] = bm[j];
160        }
161    }
162
163    // ── Pointwise confidence bands ──
164    let alpha = 1.0 - config.confidence_level;
165    let mut lower_band = vec![0.0; m];
166    let mut upper_band = vec![0.0; m];
167
168    for j in 0..m {
169        let mut col_vals: Vec<f64> = (0..config.n_bootstrap)
170            .map(|b| bootstrap_means[(b, j)])
171            .collect();
172        col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
173
174        lower_band[j] = quantile_sorted(&col_vals, alpha / 2.0);
175        upper_band[j] = quantile_sorted(&col_vals, 1.0 - alpha / 2.0);
176    }
177
178    Ok(ShapeCiResult {
179        mean: full_karcher.mean,
180        lower_band,
181        upper_band,
182        bootstrap_means,
183    })
184}
185
186use crate::helpers::quantile_sorted;
187
188// ─── Tests ──────────────────────────────────────────────────────────────────
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Simulate `n` curves on a grid of `m` points; returns the curves
    /// together with the grid they were evaluated on.
    fn make_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let grid = uniform_grid(m);
        let curves = sim_fundata(
            n,
            &grid,
            3,
            EFunType::Fourier,
            EValType::Exponential,
            Some(99),
        );
        (curves, grid)
    }

    /// Cheap configuration for tests: loose tolerance (`1e-2`), everything
    /// not listed taken from the defaults.
    fn quick_config(n_bootstrap: usize, confidence_level: f64, max_iter: usize) -> ShapeCiConfig {
        ShapeCiConfig {
            n_bootstrap,
            confidence_level,
            max_iter,
            tol: 1e-2,
            ..Default::default()
        }
    }

    /// The full-sample mean should lie inside the band (up to 1e-6 slack)
    /// at every evaluation point.
    #[test]
    fn shape_ci_band_contains_mean() {
        let (curves, grid) = make_data(8, 20);
        let result =
            shape_confidence_interval(&curves, &grid, &quick_config(30, 0.95, 5)).unwrap();

        for j in 0..grid.len() {
            let lo = result.lower_band[j];
            let center = result.mean[j];
            let hi = result.upper_band[j];
            assert!(
                lo <= center + 1e-6 && center <= hi + 1e-6,
                "mean[{j}]={} not in [{}, {}]",
                center,
                lo,
                hi,
            );
        }
    }

    /// The band should have strictly positive width at more than half of
    /// the evaluation points.
    #[test]
    fn shape_ci_band_width_positive() {
        let (curves, grid) = make_data(8, 20);
        let result =
            shape_confidence_interval(&curves, &grid, &quick_config(30, 0.95, 5)).unwrap();

        let m = grid.len();
        let n_positive = result
            .upper_band
            .iter()
            .zip(&result.lower_band)
            .filter(|&(&hi, &lo)| hi > lo + 1e-12)
            .count();
        assert!(
            n_positive > m / 2,
            "upper > lower for only {n_positive}/{m} points, expected > {}/{}",
            m / 2,
            m
        );
    }

    /// `bootstrap_means` has one row per replicate and one column per
    /// evaluation point.
    #[test]
    fn shape_ci_bootstrap_means_shape() {
        let (curves, grid) = make_data(6, 20);
        let n_boot = 15;
        let result =
            shape_confidence_interval(&curves, &grid, &quick_config(n_boot, 0.90, 3)).unwrap();
        assert_eq!(result.bootstrap_means.shape(), (n_boot, grid.len()));
    }

    /// Fewer than three curves must be rejected.
    #[test]
    fn shape_ci_rejects_too_few_curves() {
        let grid = uniform_grid(20);
        let two_curves = FdMatrix::zeros(2, 20);
        let outcome = shape_confidence_interval(&two_curves, &grid, &ShapeCiConfig::default());
        assert!(outcome.is_err());
    }
}