Skip to main content

fdars_core/alignment/
shape_ci.rs

1//! Bootstrap confidence intervals for curve shapes in the elastic metric.
2
3use rand::rngs::StdRng;
4use rand::Rng;
5use rand::SeedableRng;
6
7use super::karcher::karcher_mean;
8use super::pairwise::elastic_align_pair;
9use crate::error::FdarError;
10use crate::iter_maybe_parallel;
11use crate::matrix::FdMatrix;
12#[cfg(feature = "parallel")]
13use rayon::iter::ParallelIterator;
14
15// ─── Types ──────────────────────────────────────────────────────────────────
16
/// Tuning parameters for [`shape_confidence_interval`].
///
/// Use [`ShapeCiConfig::default`] together with struct-update syntax to
/// override only the fields you care about.
#[derive(Clone, Debug, PartialEq)]
pub struct ShapeCiConfig {
    /// How many bootstrap resamples to draw; must be at least 1.
    pub n_bootstrap: usize,
    /// Nominal coverage of the pointwise band, strictly between 0 and 1
    /// (e.g., `0.95` for a 95% interval).
    pub confidence_level: f64,
    /// Roughness penalty forwarded to the Karcher mean and to the pairwise
    /// elastic alignment.
    pub lambda: f64,
    /// Iteration cap forwarded to each Karcher mean computation.
    pub max_iter: usize,
    /// Convergence tolerance forwarded to each Karcher mean computation.
    pub tol: f64,
    /// Base random seed; each bootstrap replicate seeds its own generator
    /// from this value plus the replicate index.
    pub seed: u64,
}
33
34impl Default for ShapeCiConfig {
35    fn default() -> Self {
36        Self {
37            n_bootstrap: 200,
38            confidence_level: 0.95,
39            lambda: 0.0,
40            max_iter: 15,
41            tol: 1e-3,
42            seed: 42,
43        }
44    }
45}
46
47/// Result of shape bootstrap confidence interval computation.
48#[derive(Debug, Clone, PartialEq)]
49#[non_exhaustive]
50pub struct ShapeCiResult {
51    /// Karcher mean of the full sample.
52    pub mean: Vec<f64>,
53    /// Pointwise lower confidence band (length m).
54    pub lower_band: Vec<f64>,
55    /// Pointwise upper confidence band (length m).
56    pub upper_band: Vec<f64>,
57    /// Bootstrap Karcher means (n_bootstrap x m).
58    pub bootstrap_means: FdMatrix,
59}
60
61// ─── Public API ─────────────────────────────────────────────────────────────
62
63/// Compute bootstrap confidence intervals for the elastic Karcher mean.
64///
65/// Resamples the input curves with replacement, computes the Karcher mean
66/// of each bootstrap sample, aligns each bootstrap mean to the full-sample
67/// mean, and derives pointwise confidence bands from the empirical quantiles.
68///
69/// # Arguments
70/// * `data`    - Functional data matrix (n x m).
71/// * `argvals` - Evaluation points (length m).
72/// * `config`  - Bootstrap configuration.
73///
74/// # Errors
75/// Returns [`FdarError::InvalidDimension`] if `n < 3` or `argvals` length
76/// does not match `m`.
77/// Returns [`FdarError::InvalidParameter`] if `confidence_level` is not in
78/// `(0, 1)` or `n_bootstrap < 1`.
79#[must_use = "expensive computation whose result should not be discarded"]
80pub fn shape_confidence_interval(
81    data: &FdMatrix,
82    argvals: &[f64],
83    config: &ShapeCiConfig,
84) -> Result<ShapeCiResult, FdarError> {
85    let (n, m) = data.shape();
86
87    // ── Validation ──
88    if argvals.len() != m {
89        return Err(FdarError::InvalidDimension {
90            parameter: "argvals",
91            expected: format!("{m}"),
92            actual: format!("{}", argvals.len()),
93        });
94    }
95    if n < 3 {
96        return Err(FdarError::InvalidDimension {
97            parameter: "data",
98            expected: "at least 3 rows".to_string(),
99            actual: format!("{n} rows"),
100        });
101    }
102    if config.confidence_level <= 0.0 || config.confidence_level >= 1.0 {
103        return Err(FdarError::InvalidParameter {
104            parameter: "confidence_level",
105            message: format!("must be in (0, 1), got {}", config.confidence_level),
106        });
107    }
108    if config.n_bootstrap < 1 {
109        return Err(FdarError::InvalidParameter {
110            parameter: "n_bootstrap",
111            message: format!("must be >= 1, got {}", config.n_bootstrap),
112        });
113    }
114
115    // ── Full-sample Karcher mean ──
116    let full_karcher = karcher_mean(data, argvals, config.max_iter, config.tol, config.lambda);
117
118    // ── Bootstrap loop ──
119    let boot_means: Vec<Vec<f64>> = iter_maybe_parallel!(0..config.n_bootstrap)
120        .map(|b| {
121            let mut rng = StdRng::seed_from_u64(config.seed + b as u64);
122
123            // Resample n indices with replacement
124            let indices: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
125
126            // Build bootstrap matrix
127            let mut boot_data = FdMatrix::zeros(n, m);
128            for (row, &idx) in indices.iter().enumerate() {
129                for j in 0..m {
130                    boot_data[(row, j)] = data[(idx, j)];
131                }
132            }
133
134            // Compute bootstrap Karcher mean
135            let boot_karcher = karcher_mean(
136                &boot_data,
137                argvals,
138                config.max_iter,
139                config.tol,
140                config.lambda,
141            );
142
143            // Align bootstrap mean to full-sample mean
144            let aligned = elastic_align_pair(
145                &full_karcher.mean,
146                &boot_karcher.mean,
147                argvals,
148                config.lambda,
149            );
150
151            aligned.f_aligned
152        })
153        .collect();
154
155    // ── Build bootstrap_means matrix ──
156    let mut bootstrap_means = FdMatrix::zeros(config.n_bootstrap, m);
157    for (b, bm) in boot_means.iter().enumerate() {
158        for j in 0..m {
159            bootstrap_means[(b, j)] = bm[j];
160        }
161    }
162
163    // ── Pointwise confidence bands ──
164    let alpha = 1.0 - config.confidence_level;
165    let mut lower_band = vec![0.0; m];
166    let mut upper_band = vec![0.0; m];
167
168    for j in 0..m {
169        let mut col_vals: Vec<f64> = (0..config.n_bootstrap)
170            .map(|b| bootstrap_means[(b, j)])
171            .collect();
172        col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
173
174        lower_band[j] = quantile_sorted(&col_vals, alpha / 2.0);
175        upper_band[j] = quantile_sorted(&col_vals, 1.0 - alpha / 2.0);
176    }
177
178    Ok(ShapeCiResult {
179        mean: full_karcher.mean,
180        lower_band,
181        upper_band,
182        bootstrap_means,
183    })
184}
185
/// Linearly interpolated quantile of an ascending-sorted slice.
///
/// The fractional position is `p * (len - 1)`; when it falls between two
/// elements the result is their weighted average. An empty slice yields
/// `0.0` and a single-element slice yields that element regardless of `p`.
fn quantile_sorted(sorted: &[f64], p: f64) -> f64 {
    match sorted {
        [] => 0.0,
        [only] => *only,
        _ => {
            let last = sorted.len() - 1;
            let pos = p * last as f64;
            let below = pos.floor() as usize;
            let above = pos.ceil() as usize;

            if below == above || above > last {
                // Exact hit, or a position past the end: take one element,
                // clamped to the last valid index.
                sorted[below.min(last)]
            } else {
                let weight = pos - below as f64;
                sorted[below] * (1.0 - weight) + sorted[above] * weight
            }
        }
    }
}
205
206// ─── Tests ──────────────────────────────────────────────────────────────────
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Simulate `n` curves on a grid of `m` points with a fixed simulation seed.
    fn make_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let data = sim_fundata(n, &t, 3, EFunType::Fourier, EValType::Exponential, Some(99));
        (data, t)
    }

    /// Cheap settings (loose tolerance, few iterations) to keep tests fast.
    fn fast_config(n_bootstrap: usize, confidence_level: f64, max_iter: usize) -> ShapeCiConfig {
        ShapeCiConfig {
            n_bootstrap,
            confidence_level,
            max_iter,
            tol: 1e-2,
            ..Default::default()
        }
    }

    #[test]
    fn shape_ci_band_contains_mean() {
        let (data, t) = make_data(8, 20);
        let config = fast_config(30, 0.95, 5);
        let result = shape_confidence_interval(&data, &t, &config).unwrap();
        let m = t.len();
        for j in 0..m {
            assert!(
                result.lower_band[j] <= result.mean[j] + 1e-6
                    && result.mean[j] <= result.upper_band[j] + 1e-6,
                "mean[{j}]={} not in [{}, {}]",
                result.mean[j],
                result.lower_band[j],
                result.upper_band[j],
            );
        }
    }

    #[test]
    fn shape_ci_band_width_positive() {
        let (data, t) = make_data(8, 20);
        let config = fast_config(30, 0.95, 5);
        let result = shape_confidence_interval(&data, &t, &config).unwrap();
        let m = t.len();
        let n_positive = (0..m)
            .filter(|&j| result.upper_band[j] > result.lower_band[j] + 1e-12)
            .count();
        assert!(
            n_positive > m / 2,
            "upper > lower for only {n_positive}/{m} points, expected > {}/{}",
            m / 2,
            m
        );
    }

    #[test]
    fn shape_ci_bootstrap_means_shape() {
        let (data, t) = make_data(6, 20);
        let n_boot = 15;
        let config = fast_config(n_boot, 0.90, 3);
        let result = shape_confidence_interval(&data, &t, &config).unwrap();
        assert_eq!(result.bootstrap_means.shape(), (n_boot, t.len()));
    }

    #[test]
    fn shape_ci_rejects_too_few_curves() {
        let t = uniform_grid(20);
        let data = FdMatrix::zeros(2, 20);
        let config = ShapeCiConfig::default();
        assert!(matches!(
            shape_confidence_interval(&data, &t, &config),
            Err(FdarError::InvalidDimension { .. })
        ));
    }

    #[test]
    fn shape_ci_rejects_argvals_length_mismatch() {
        let (data, t) = make_data(4, 20);
        let config = ShapeCiConfig::default();
        // Drop one evaluation point so the grid no longer matches the data.
        let short = &t[..t.len() - 1];
        assert!(matches!(
            shape_confidence_interval(&data, short, &config),
            Err(FdarError::InvalidDimension {
                parameter: "argvals",
                ..
            })
        ));
    }

    #[test]
    fn shape_ci_rejects_invalid_confidence_level() {
        let (data, t) = make_data(4, 20);
        // Both endpoints are excluded, as is anything outside the interval.
        for level in [0.0, 1.0, -0.5, 1.5] {
            let config = fast_config(5, level, 3);
            assert!(
                matches!(
                    shape_confidence_interval(&data, &t, &config),
                    Err(FdarError::InvalidParameter {
                        parameter: "confidence_level",
                        ..
                    })
                ),
                "confidence_level {level} should be rejected"
            );
        }
    }

    #[test]
    fn shape_ci_rejects_zero_bootstrap() {
        let (data, t) = make_data(4, 20);
        let config = fast_config(0, 0.95, 3);
        assert!(matches!(
            shape_confidence_interval(&data, &t, &config),
            Err(FdarError::InvalidParameter {
                parameter: "n_bootstrap",
                ..
            })
        ));
    }

    #[test]
    fn quantile_sorted_degenerate_inputs() {
        assert_eq!(quantile_sorted(&[], 0.5), 0.0);
        assert_eq!(quantile_sorted(&[7.0], 0.1), 7.0);
        assert_eq!(quantile_sorted(&[7.0], 0.9), 7.0);
    }

    #[test]
    fn quantile_sorted_endpoints_and_interpolation() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(quantile_sorted(&xs, 0.0), 1.0);
        assert_eq!(quantile_sorted(&xs, 0.5), 3.0);
        assert_eq!(quantile_sorted(&xs, 1.0), 5.0);
        // Position 1.5 lies halfway between the 2nd and 3rd elements.
        assert_eq!(quantile_sorted(&[1.0, 2.0, 3.0, 4.0], 0.5), 2.5);
        assert_eq!(quantile_sorted(&[0.0, 10.0], 0.25), 2.5);
    }
}
289}