1use rand::rngs::StdRng;
4use rand::Rng;
5use rand::SeedableRng;
6
7use super::karcher::karcher_mean;
8use super::pairwise::elastic_align_pair;
9use crate::error::FdarError;
10use crate::iter_maybe_parallel;
11use crate::matrix::FdMatrix;
12#[cfg(feature = "parallel")]
13use rayon::iter::ParallelIterator;
14
/// Settings for [`shape_confidence_interval`].
///
/// Use [`Default`] together with struct-update syntax to override only the
/// fields you care about.
#[derive(Clone, Debug, PartialEq)]
pub struct ShapeCiConfig {
    /// Number of bootstrap replicates to draw. Zero is rejected by
    /// [`shape_confidence_interval`].
    pub n_bootstrap: usize,
    /// Coverage level of the band. Must lie strictly between 0 and 1
    /// (e.g. `0.95`).
    pub confidence_level: f64,
    /// Forwarded unchanged to `karcher_mean` and `elastic_align_pair`.
    // NOTE(review): presumably the alignment regularisation weight — confirm
    // against those functions.
    pub lambda: f64,
    /// Iteration limit forwarded to every `karcher_mean` call.
    pub max_iter: usize,
    /// Tolerance forwarded to every `karcher_mean` call.
    // NOTE(review): presumably its convergence threshold — confirm.
    pub tol: f64,
    /// Base seed for resampling; replicate `b` seeds its own generator from
    /// this value offset by `b`.
    pub seed: u64,
}
33
34impl Default for ShapeCiConfig {
35 fn default() -> Self {
36 Self {
37 n_bootstrap: 200,
38 confidence_level: 0.95,
39 lambda: 0.0,
40 max_iter: 15,
41 tol: 1e-3,
42 seed: 42,
43 }
44 }
45}
46
47#[derive(Debug, Clone, PartialEq)]
49#[non_exhaustive]
50pub struct ShapeCiResult {
51 pub mean: Vec<f64>,
53 pub lower_band: Vec<f64>,
55 pub upper_band: Vec<f64>,
57 pub bootstrap_means: FdMatrix,
59}
60
61#[must_use = "expensive computation whose result should not be discarded"]
80pub fn shape_confidence_interval(
81 data: &FdMatrix,
82 argvals: &[f64],
83 config: &ShapeCiConfig,
84) -> Result<ShapeCiResult, FdarError> {
85 let (n, m) = data.shape();
86
87 if argvals.len() != m {
89 return Err(FdarError::InvalidDimension {
90 parameter: "argvals",
91 expected: format!("{m}"),
92 actual: format!("{}", argvals.len()),
93 });
94 }
95 if n < 3 {
96 return Err(FdarError::InvalidDimension {
97 parameter: "data",
98 expected: "at least 3 rows".to_string(),
99 actual: format!("{n} rows"),
100 });
101 }
102 if config.confidence_level <= 0.0 || config.confidence_level >= 1.0 {
103 return Err(FdarError::InvalidParameter {
104 parameter: "confidence_level",
105 message: format!("must be in (0, 1), got {}", config.confidence_level),
106 });
107 }
108 if config.n_bootstrap < 1 {
109 return Err(FdarError::InvalidParameter {
110 parameter: "n_bootstrap",
111 message: format!("must be >= 1, got {}", config.n_bootstrap),
112 });
113 }
114
115 let full_karcher = karcher_mean(data, argvals, config.max_iter, config.tol, config.lambda);
117
118 let boot_means: Vec<Vec<f64>> = iter_maybe_parallel!(0..config.n_bootstrap)
120 .map(|b| {
121 let mut rng = StdRng::seed_from_u64(config.seed + b as u64);
122
123 let indices: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
125
126 let mut boot_data = FdMatrix::zeros(n, m);
128 for (row, &idx) in indices.iter().enumerate() {
129 for j in 0..m {
130 boot_data[(row, j)] = data[(idx, j)];
131 }
132 }
133
134 let boot_karcher = karcher_mean(
136 &boot_data,
137 argvals,
138 config.max_iter,
139 config.tol,
140 config.lambda,
141 );
142
143 let aligned = elastic_align_pair(
145 &full_karcher.mean,
146 &boot_karcher.mean,
147 argvals,
148 config.lambda,
149 );
150
151 aligned.f_aligned
152 })
153 .collect();
154
155 let mut bootstrap_means = FdMatrix::zeros(config.n_bootstrap, m);
157 for (b, bm) in boot_means.iter().enumerate() {
158 for j in 0..m {
159 bootstrap_means[(b, j)] = bm[j];
160 }
161 }
162
163 let alpha = 1.0 - config.confidence_level;
165 let mut lower_band = vec![0.0; m];
166 let mut upper_band = vec![0.0; m];
167
168 for j in 0..m {
169 let mut col_vals: Vec<f64> = (0..config.n_bootstrap)
170 .map(|b| bootstrap_means[(b, j)])
171 .collect();
172 col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
173
174 lower_band[j] = quantile_sorted(&col_vals, alpha / 2.0);
175 upper_band[j] = quantile_sorted(&col_vals, 1.0 - alpha / 2.0);
176 }
177
178 Ok(ShapeCiResult {
179 mean: full_karcher.mean,
180 lower_band,
181 upper_band,
182 bootstrap_means,
183 })
184}
185
186use crate::helpers::quantile_sorted;
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Simulated curves plus their grid: `n` curves on `uniform_grid(m)`,
    /// generated by `sim_fundata` with fixed arguments so every test sees
    /// the same data.
    fn make_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let grid = uniform_grid(m);
        let curves = sim_fundata(
            n,
            &grid,
            3,
            EFunType::Fourier,
            EValType::Exponential,
            Some(99),
        );
        (curves, grid)
    }

    /// Cheap configuration for tests: loose tolerance (`1e-2`), everything
    /// not listed here taken from the defaults.
    fn quick_config(n_bootstrap: usize, confidence_level: f64, max_iter: usize) -> ShapeCiConfig {
        ShapeCiConfig {
            n_bootstrap,
            confidence_level,
            max_iter,
            tol: 1e-2,
            ..ShapeCiConfig::default()
        }
    }

    /// The full-sample mean should sit inside the band at every grid point
    /// (up to a small numerical slack).
    #[test]
    fn shape_ci_band_contains_mean() {
        let (data, t) = make_data(8, 20);
        let result = shape_confidence_interval(&data, &t, &quick_config(30, 0.95, 5)).unwrap();
        for j in 0..t.len() {
            let lower = result.lower_band[j];
            let center = result.mean[j];
            let upper = result.upper_band[j];
            let inside = lower <= center + 1e-6 && center <= upper + 1e-6;
            assert!(inside, "mean[{j}]={} not in [{}, {}]", center, lower, upper);
        }
    }

    /// The band should have strictly positive width at more than half of the
    /// grid points.
    #[test]
    fn shape_ci_band_width_positive() {
        let (data, t) = make_data(8, 20);
        let result = shape_confidence_interval(&data, &t, &quick_config(30, 0.95, 5)).unwrap();
        let m = t.len();
        let mut n_positive = 0usize;
        for j in 0..m {
            if result.upper_band[j] > result.lower_band[j] + 1e-12 {
                n_positive += 1;
            }
        }
        assert!(
            n_positive > m / 2,
            "upper > lower for only {n_positive}/{m} points, expected > {}/{}",
            m / 2,
            m
        );
    }

    /// The stored bootstrap means form an `n_bootstrap x m` matrix.
    #[test]
    fn shape_ci_bootstrap_means_shape() {
        let (data, t) = make_data(6, 20);
        let n_boot = 15;
        let result =
            shape_confidence_interval(&data, &t, &quick_config(n_boot, 0.90, 3)).unwrap();
        let expected_shape = (n_boot, t.len());
        assert_eq!(result.bootstrap_means.shape(), expected_shape);
    }

    /// Fewer than three curves must be rejected with an error.
    #[test]
    fn shape_ci_rejects_too_few_curves() {
        let t = uniform_grid(20);
        let data = FdMatrix::zeros(2, 20);
        let outcome = shape_confidence_interval(&data, &t, &ShapeCiConfig::default());
        assert!(outcome.is_err());
    }
}
271}