1use numra_core::Scalar;
10
11#[derive(Clone, Debug)]
13pub struct EnsembleStats<S: Scalar> {
14 pub n_samples: usize,
16 pub mean: S,
18 pub std: S,
20 pub variance: S,
22 pub min: S,
24 pub max: S,
26 pub percentiles: Percentiles<S>,
28}
29
30#[derive(Clone, Debug)]
32pub struct Percentiles<S: Scalar> {
33 pub p5: S,
34 pub p25: S,
35 pub p50: S, pub p75: S,
37 pub p95: S,
38}
39
40impl<S: Scalar> EnsembleStats<S> {
41 pub fn from_samples(samples: &[S]) -> Option<Self> {
43 if samples.is_empty() {
44 return None;
45 }
46
47 let n = samples.len();
48 let n_f = S::from_usize(n);
49
50 let sum: S = samples.iter().fold(S::ZERO, |acc, &x| acc + x);
52 let mean = sum / n_f;
53
54 let var_sum: S = samples.iter().fold(S::ZERO, |acc, &x| {
56 let diff = x - mean;
57 acc + diff * diff
58 });
59 let variance = if n > 1 {
60 var_sum / S::from_usize(n - 1) } else {
62 S::ZERO
63 };
64 let std = variance.sqrt();
65
66 let mut min = samples[0];
68 let mut max = samples[0];
69 for &x in samples.iter().skip(1) {
70 if x < min {
71 min = x;
72 }
73 if x > max {
74 max = x;
75 }
76 }
77
78 let mut sorted = samples.to_vec();
80 sorted.sort_by(|a, b| a.to_f64().partial_cmp(&b.to_f64()).unwrap());
82
83 let percentiles = Percentiles {
84 p5: percentile_sorted(&sorted, 5.0),
85 p25: percentile_sorted(&sorted, 25.0),
86 p50: percentile_sorted(&sorted, 50.0),
87 p75: percentile_sorted(&sorted, 75.0),
88 p95: percentile_sorted(&sorted, 95.0),
89 };
90
91 Some(Self {
92 n_samples: n,
93 mean,
94 std,
95 variance,
96 min,
97 max,
98 percentiles,
99 })
100 }
101
102 pub fn standard_error(&self) -> S {
104 self.std / S::from_usize(self.n_samples).sqrt()
105 }
106
107 pub fn confidence_interval(&self, level: S) -> (S, S) {
111 let alpha = (S::ONE - level) / S::from_f64(2.0);
114 let z = normal_quantile(S::ONE - alpha);
115 let margin = z * self.standard_error();
116 (self.mean - margin, self.mean + margin)
117 }
118
119 pub fn iqr(&self) -> S {
121 self.percentiles.p75 - self.percentiles.p25
122 }
123
124 pub fn median(&self) -> S {
126 self.percentiles.p50
127 }
128}
129
130fn percentile_sorted<S: Scalar>(sorted: &[S], p: f64) -> S {
132 let n = sorted.len();
133 if n == 0 {
134 return S::ZERO;
135 }
136 if n == 1 {
137 return sorted[0];
138 }
139
140 let rank = (p / 100.0) * (n - 1) as f64;
142 let lower = rank.floor() as usize;
143 let upper = rank.ceil() as usize;
144
145 if lower == upper {
146 sorted[lower]
147 } else {
148 let frac = S::from_f64(rank - lower as f64);
149 sorted[lower] + frac * (sorted[upper] - sorted[lower])
150 }
151}
152
153fn normal_quantile<S: Scalar>(p: S) -> S {
157 let p_f = p.to_f64();
159 if p_f <= 0.0 || p_f >= 1.0 {
160 return S::ZERO;
161 }
162
163 #[allow(clippy::excessive_precision)]
164 let a = [
165 -3.969683028665376e+01,
166 2.209460984245205e+02,
167 -2.759285104469687e+02,
168 1.383577518672690e+02,
169 -3.066479806614716e+01,
170 2.506628277459239e+00,
171 ];
172 let b = [
173 -5.447609879822406e+01,
174 1.615858368580409e+02,
175 -1.556989798598866e+02,
176 6.680131188771972e+01,
177 -1.328068155288572e+01,
178 ];
179 let c = [
180 -7.784894002430293e-03,
181 -3.223964580411365e-01,
182 -2.400758277161838e+00,
183 -2.549732539343734e+00,
184 4.374664141464968e+00,
185 2.938163982698783e+00,
186 ];
187 let d = [
188 7.784695709041462e-03,
189 3.224671290700398e-01,
190 2.445134137142996e+00,
191 3.754408661907416e+00,
192 ];
193
194 let p_low = 0.02425;
195 let p_high = 1.0 - p_low;
196
197 let q = if p_f < p_low {
198 let q = (-2.0 * p_f.ln()).sqrt();
199 (((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5])
200 / ((((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0)
201 } else if p_f <= p_high {
202 let q = p_f - 0.5;
203 let r = q * q;
204 (((((a[0] * r + a[1]) * r + a[2]) * r + a[3]) * r + a[4]) * r + a[5]) * q
205 / (((((b[0] * r + b[1]) * r + b[2]) * r + b[3]) * r + b[4]) * r + 1.0)
206 } else {
207 let q = (-2.0 * (1.0 - p_f).ln()).sqrt();
208 -(((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5])
209 / ((((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0)
210 };
211
212 S::from_f64(q)
213}
214
215#[derive(Clone, Debug)]
219pub struct RunningStats<S: Scalar> {
220 n: usize,
221 mean: S,
222 m2: S, min: S,
224 max: S,
225}
226
227impl<S: Scalar> RunningStats<S> {
228 pub fn new() -> Self {
230 Self {
231 n: 0,
232 mean: S::ZERO,
233 m2: S::ZERO,
234 min: S::INFINITY,
235 max: S::NEG_INFINITY,
236 }
237 }
238
239 pub fn update(&mut self, value: S) {
241 self.n += 1;
242 let n_f = S::from_usize(self.n);
243
244 let delta = value - self.mean;
245 self.mean += delta / n_f;
246 let delta2 = value - self.mean;
247 self.m2 += delta * delta2;
248
249 if value < self.min {
250 self.min = value;
251 }
252 if value > self.max {
253 self.max = value;
254 }
255 }
256
257 pub fn count(&self) -> usize {
259 self.n
260 }
261
262 pub fn mean(&self) -> S {
264 self.mean
265 }
266
267 pub fn variance(&self) -> S {
269 if self.n < 2 {
270 S::ZERO
271 } else {
272 self.m2 / S::from_usize(self.n - 1)
273 }
274 }
275
276 pub fn std(&self) -> S {
278 self.variance().sqrt()
279 }
280
281 pub fn standard_error(&self) -> S {
283 self.std() / S::from_usize(self.n).sqrt()
284 }
285
286 pub fn min(&self) -> S {
288 self.min
289 }
290
291 pub fn max(&self) -> S {
293 self.max
294 }
295
296 pub fn merge(&mut self, other: &RunningStats<S>) {
298 if other.n == 0 {
299 return;
300 }
301 if self.n == 0 {
302 *self = other.clone();
303 return;
304 }
305
306 let n_a = S::from_usize(self.n);
307 let n_b = S::from_usize(other.n);
308 let n_total = n_a + n_b;
309
310 let delta = other.mean - self.mean;
311 let new_mean = (n_a * self.mean + n_b * other.mean) / n_total;
312
313 let new_m2 = self.m2 + other.m2 + delta * delta * n_a * n_b / n_total;
315
316 self.n += other.n;
317 self.mean = new_mean;
318 self.m2 = new_m2;
319
320 if other.min < self.min {
321 self.min = other.min;
322 }
323 if other.max > self.max {
324 self.max = other.max;
325 }
326 }
327}
328
329impl<S: Scalar> Default for RunningStats<S> {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
335#[inline]
339pub fn mean<S: Scalar>(data: &[S]) -> S {
340 if data.is_empty() {
341 return S::ZERO;
342 }
343 data.iter().fold(S::ZERO, |acc, &x| acc + x) / S::from_usize(data.len())
344}
345
346pub fn std<S: Scalar>(data: &[S]) -> S {
348 variance(data).sqrt()
349}
350
351pub fn variance<S: Scalar>(data: &[S]) -> S {
353 if data.len() < 2 {
354 return S::ZERO;
355 }
356 let m = mean(data);
357 let sum_sq: S = data.iter().fold(S::ZERO, |acc, &x| {
358 let diff = x - m;
359 acc + diff * diff
360 });
361 sum_sq / S::from_usize(data.len() - 1)
362}
363
364pub fn percentile<S: Scalar>(data: &[S], p: f64) -> S {
366 if data.is_empty() {
367 return S::ZERO;
368 }
369 let mut sorted = data.to_vec();
370 sorted.sort_by(|a, b| a.to_f64().partial_cmp(&b.to_f64()).unwrap());
371 percentile_sorted(&sorted, p)
372}
373
374pub fn median<S: Scalar>(data: &[S]) -> S {
376 percentile(data, 50.0)
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    /// The integers 1 through 100 as `f64` samples.
    fn one_to_hundred() -> Vec<f64> {
        (1..=100).map(|i| i as f64).collect()
    }

    /// Feeds every value of `values` into a fresh accumulator.
    fn accumulate(values: &[f64]) -> RunningStats<f64> {
        let mut acc = RunningStats::<f64>::new();
        values.iter().for_each(|&v| acc.update(v));
        acc
    }

    #[test]
    fn test_basic_stats() {
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        assert_close(mean(&data), 3.0, 1e-10);
        assert_close(variance(&data), 2.5, 1e-10);
        assert_close(std(&data), 2.5_f64.sqrt(), 1e-10);
        assert_close(median(&data), 3.0, 1e-10);
    }

    #[test]
    fn test_percentiles() {
        let data = one_to_hundred();

        assert_close(percentile(&data, 50.0), 50.5, 0.5);
        assert_close(percentile(&data, 25.0), 25.0, 1.0);
        assert_close(percentile(&data, 75.0), 75.0, 1.0);
    }

    #[test]
    fn test_ensemble_stats() {
        let data = one_to_hundred();
        let stats = EnsembleStats::from_samples(&data).unwrap();

        assert_eq!(stats.n_samples, 100);
        assert_close(stats.mean, 50.5, 0.01);
        assert_close(stats.min, 1.0, 1e-10);
        assert_close(stats.max, 100.0, 1e-10);
    }

    #[test]
    fn test_running_stats() {
        let rs = accumulate(&[1.0, 2.0, 3.0, 4.0, 5.0]);

        assert_eq!(rs.count(), 5);
        assert_close(rs.mean(), 3.0, 1e-10);
        assert_close(rs.variance(), 2.5, 1e-10);
        assert_close(rs.min(), 1.0, 1e-10);
        assert_close(rs.max(), 5.0, 1e-10);
    }

    #[test]
    fn test_running_stats_merge() {
        let mut rs1 = accumulate(&[1.0, 2.0, 3.0]);
        let rs2 = accumulate(&[4.0, 5.0, 6.0]);

        rs1.merge(&rs2);

        // Merging must agree with the batch statistics of the concatenation.
        let combined: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        assert_eq!(rs1.count(), 6);
        assert_close(rs1.mean(), mean(&combined), 1e-10);
        assert_close(rs1.variance(), variance(&combined), 1e-10);
    }

    #[test]
    fn test_confidence_interval() {
        // 1000 evenly spaced samples covering [-1, 1).
        let data: Vec<f64> = (0..1000)
            .map(|i| (i as f64 / 1000.0 - 0.5) * 2.0)
            .collect();

        let stats = EnsembleStats::from_samples(&data).unwrap();
        let (lo, hi) = stats.confidence_interval(0.95);

        // The interval brackets the mean ...
        assert!(lo < stats.mean);
        assert!(hi > stats.mean);
        // ... and is narrower than the full range of the data.
        assert!(hi - lo < stats.max - stats.min);
    }
}