1use rand::rngs::StdRng;
4use rand::Rng;
5use rand::SeedableRng;
6
7use super::karcher::karcher_mean;
8use super::pairwise::elastic_align_pair;
9use crate::error::FdarError;
10use crate::iter_maybe_parallel;
11use crate::matrix::FdMatrix;
12#[cfg(feature = "parallel")]
13use rayon::iter::ParallelIterator;
14
/// Settings for [`shape_confidence_interval`].
///
/// The defaults come from the `Default` implementation:
/// `n_bootstrap = 200`, `confidence_level = 0.95`, `lambda = 0.0`,
/// `max_iter = 15`, `tol = 1e-3`, `seed = 42`.
#[derive(Debug, Clone, PartialEq)]
pub struct ShapeCiConfig {
    /// Number of bootstrap replicates to draw. `shape_confidence_interval`
    /// rejects values below 1.
    pub n_bootstrap: usize,
    /// Coverage level of the band. `shape_confidence_interval` requires it to
    /// lie strictly between 0 and 1; with `alpha = 1 - confidence_level` the
    /// band is built from the `alpha / 2` and `1 - alpha / 2` quantiles.
    pub confidence_level: f64,
    /// Forwarded unchanged to `karcher_mean` and `elastic_align_pair`.
    /// NOTE(review): presumably the alignment penalty weight — confirm
    /// against those functions.
    pub lambda: f64,
    /// Forwarded unchanged to `karcher_mean`.
    /// NOTE(review): presumably the cap on Karcher-mean iterations — confirm.
    pub max_iter: usize,
    /// Forwarded unchanged to `karcher_mean`.
    /// NOTE(review): presumably its convergence tolerance — confirm.
    pub tol: f64,
    /// Base RNG seed. Bootstrap replicate `b` seeds its own generator with
    /// this value plus `b`.
    pub seed: u64,
}
33
34impl Default for ShapeCiConfig {
35 fn default() -> Self {
36 Self {
37 n_bootstrap: 200,
38 confidence_level: 0.95,
39 lambda: 0.0,
40 max_iter: 15,
41 tol: 1e-3,
42 seed: 42,
43 }
44 }
45}
46
/// Output of [`shape_confidence_interval`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ShapeCiResult {
    /// Mean curve returned by `karcher_mean` for the full (non-resampled) data.
    pub mean: Vec<f64>,
    /// Pointwise lower limit of the band: for each grid point, the
    /// `alpha / 2` quantile of the aligned bootstrap means, where
    /// `alpha = 1 - confidence_level`.
    pub lower_band: Vec<f64>,
    /// Pointwise upper limit of the band: for each grid point, the
    /// `1 - alpha / 2` quantile of the aligned bootstrap means.
    pub upper_band: Vec<f64>,
    /// Aligned bootstrap means, one replicate per row
    /// (`n_bootstrap` rows, one column per grid point).
    pub bootstrap_means: FdMatrix,
}
60
61#[must_use = "expensive computation whose result should not be discarded"]
80pub fn shape_confidence_interval(
81 data: &FdMatrix,
82 argvals: &[f64],
83 config: &ShapeCiConfig,
84) -> Result<ShapeCiResult, FdarError> {
85 let (n, m) = data.shape();
86
87 if argvals.len() != m {
89 return Err(FdarError::InvalidDimension {
90 parameter: "argvals",
91 expected: format!("{m}"),
92 actual: format!("{}", argvals.len()),
93 });
94 }
95 if n < 3 {
96 return Err(FdarError::InvalidDimension {
97 parameter: "data",
98 expected: "at least 3 rows".to_string(),
99 actual: format!("{n} rows"),
100 });
101 }
102 if config.confidence_level <= 0.0 || config.confidence_level >= 1.0 {
103 return Err(FdarError::InvalidParameter {
104 parameter: "confidence_level",
105 message: format!("must be in (0, 1), got {}", config.confidence_level),
106 });
107 }
108 if config.n_bootstrap < 1 {
109 return Err(FdarError::InvalidParameter {
110 parameter: "n_bootstrap",
111 message: format!("must be >= 1, got {}", config.n_bootstrap),
112 });
113 }
114
115 let full_karcher = karcher_mean(data, argvals, config.max_iter, config.tol, config.lambda);
117
118 let boot_means: Vec<Vec<f64>> = iter_maybe_parallel!(0..config.n_bootstrap)
120 .map(|b| {
121 let mut rng = StdRng::seed_from_u64(config.seed + b as u64);
122
123 let indices: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
125
126 let mut boot_data = FdMatrix::zeros(n, m);
128 for (row, &idx) in indices.iter().enumerate() {
129 for j in 0..m {
130 boot_data[(row, j)] = data[(idx, j)];
131 }
132 }
133
134 let boot_karcher = karcher_mean(
136 &boot_data,
137 argvals,
138 config.max_iter,
139 config.tol,
140 config.lambda,
141 );
142
143 let aligned = elastic_align_pair(
145 &full_karcher.mean,
146 &boot_karcher.mean,
147 argvals,
148 config.lambda,
149 );
150
151 aligned.f_aligned
152 })
153 .collect();
154
155 let mut bootstrap_means = FdMatrix::zeros(config.n_bootstrap, m);
157 for (b, bm) in boot_means.iter().enumerate() {
158 for j in 0..m {
159 bootstrap_means[(b, j)] = bm[j];
160 }
161 }
162
163 let alpha = 1.0 - config.confidence_level;
165 let mut lower_band = vec![0.0; m];
166 let mut upper_band = vec![0.0; m];
167
168 for j in 0..m {
169 let mut col_vals: Vec<f64> = (0..config.n_bootstrap)
170 .map(|b| bootstrap_means[(b, j)])
171 .collect();
172 col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
173
174 lower_band[j] = quantile_sorted(&col_vals, alpha / 2.0);
175 upper_band[j] = quantile_sorted(&col_vals, 1.0 - alpha / 2.0);
176 }
177
178 Ok(ShapeCiResult {
179 mean: full_karcher.mean,
180 lower_band,
181 upper_band,
182 bootstrap_means,
183 })
184}
185
/// Returns the `p`-quantile of an ascending-sorted slice.
///
/// The quantile sits at fractional position `p * (len - 1)` and is linearly
/// interpolated between the two neighbouring elements. An empty slice yields
/// `0.0`, a one-element slice yields that element, and a position whose upper
/// neighbour would lie past the end falls back to the last element.
fn quantile_sorted(sorted: &[f64], p: f64) -> f64 {
    let last = match sorted.len() {
        0 => return 0.0,
        1 => return sorted[0],
        len => len - 1,
    };

    let position = p * last as f64;
    let below = position.floor() as usize;
    let above = position.ceil() as usize;

    // Exact hit on an element, or the upper neighbour is out of range.
    if below == above || above > last {
        return sorted[below.min(last)];
    }

    let weight = position - below as f64;
    sorted[below] * (1.0 - weight) + sorted[above] * weight
}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Builds `n` simulated curves on the grid returned by `uniform_grid(m)`,
    /// returning the curves together with that grid.
    fn make_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let grid = uniform_grid(m);
        let curves = sim_fundata(
            n,
            &grid,
            3,
            EFunType::Fourier,
            EValType::Exponential,
            Some(99),
        );
        (curves, grid)
    }

    /// Cheap configuration shared by the tests: loose tolerance, everything
    /// not passed in is left at its default.
    fn quick_config(n_bootstrap: usize, confidence_level: f64, max_iter: usize) -> ShapeCiConfig {
        ShapeCiConfig {
            n_bootstrap,
            confidence_level,
            max_iter,
            tol: 1e-2,
            ..Default::default()
        }
    }

    /// The full-sample mean should lie inside the band at every grid point
    /// (up to a small numerical slack).
    #[test]
    fn shape_ci_band_contains_mean() {
        let (curves, grid) = make_data(8, 20);
        let result = shape_confidence_interval(&curves, &grid, &quick_config(30, 0.95, 5)).unwrap();
        for j in 0..grid.len() {
            let above_lower = result.lower_band[j] <= result.mean[j] + 1e-6;
            let below_upper = result.mean[j] <= result.upper_band[j] + 1e-6;
            assert!(
                above_lower && below_upper,
                "mean[{j}]={} not in [{}, {}]",
                result.mean[j],
                result.lower_band[j],
                result.upper_band[j],
            );
        }
    }

    /// The band should have strictly positive width at more than half of the
    /// grid points.
    #[test]
    fn shape_ci_band_width_positive() {
        let (curves, grid) = make_data(8, 20);
        let result = shape_confidence_interval(&curves, &grid, &quick_config(30, 0.95, 5)).unwrap();
        let m = grid.len();
        let n_positive = result
            .upper_band
            .iter()
            .zip(&result.lower_band)
            .filter(|&(&upper, &lower)| upper > lower + 1e-12)
            .count();
        assert!(
            n_positive > m / 2,
            "upper > lower for only {n_positive}/{m} points, expected > {}/{}",
            m / 2,
            m
        );
    }

    /// The matrix of bootstrap means has one row per replicate and one column
    /// per grid point.
    #[test]
    fn shape_ci_bootstrap_means_shape() {
        let n_boot = 15;
        let (curves, grid) = make_data(6, 20);
        let result =
            shape_confidence_interval(&curves, &grid, &quick_config(n_boot, 0.90, 3)).unwrap();
        assert_eq!(result.bootstrap_means.shape(), (n_boot, grid.len()));
    }

    /// Fewer than three curves must be rejected with an error.
    #[test]
    fn shape_ci_rejects_too_few_curves() {
        let grid = uniform_grid(20);
        let two_curves = FdMatrix::zeros(2, 20);
        let outcome = shape_confidence_interval(&two_curves, &grid, &ShapeCiConfig::default());
        assert!(outcome.is_err());
    }
}