Skip to main content

numra_sde/
stats.rs

1//! Statistics utilities for ensemble SDE results.
2//!
3//! Provides tools for computing statistics from Monte Carlo simulations.
4//!
5//! Author: Moussa Leblouba
6//! Date: 3 February 2026
7//! Modified: 2 May 2026
8
9use numra_core::Scalar;
10
11/// Statistics computed from an ensemble of trajectories.
12#[derive(Clone, Debug)]
13pub struct EnsembleStats<S: Scalar> {
14    /// Number of samples
15    pub n_samples: usize,
16    /// Sample mean
17    pub mean: S,
18    /// Sample standard deviation
19    pub std: S,
20    /// Sample variance
21    pub variance: S,
22    /// Minimum value
23    pub min: S,
24    /// Maximum value
25    pub max: S,
26    /// Percentiles (5th, 25th, 50th, 75th, 95th)
27    pub percentiles: Percentiles<S>,
28}
29
30/// Common percentiles.
31#[derive(Clone, Debug)]
32pub struct Percentiles<S: Scalar> {
33    pub p5: S,
34    pub p25: S,
35    pub p50: S, // Median
36    pub p75: S,
37    pub p95: S,
38}
39
40impl<S: Scalar> EnsembleStats<S> {
41    /// Compute statistics from a vector of samples.
42    pub fn from_samples(samples: &[S]) -> Option<Self> {
43        if samples.is_empty() {
44            return None;
45        }
46
47        let n = samples.len();
48        let n_f = S::from_usize(n);
49
50        // Mean
51        let sum: S = samples.iter().fold(S::ZERO, |acc, &x| acc + x);
52        let mean = sum / n_f;
53
54        // Variance and std
55        let var_sum: S = samples.iter().fold(S::ZERO, |acc, &x| {
56            let diff = x - mean;
57            acc + diff * diff
58        });
59        let variance = if n > 1 {
60            var_sum / S::from_usize(n - 1) // Bessel correction
61        } else {
62            S::ZERO
63        };
64        let std = variance.sqrt();
65
66        // Min and max
67        let mut min = samples[0];
68        let mut max = samples[0];
69        for &x in samples.iter().skip(1) {
70            if x < min {
71                min = x;
72            }
73            if x > max {
74                max = x;
75            }
76        }
77
78        // Percentiles (need sorted data)
79        let mut sorted = samples.to_vec();
80        // Sort by converting to f64 for comparison
81        sorted.sort_by(|a, b| a.to_f64().partial_cmp(&b.to_f64()).unwrap());
82
83        let percentiles = Percentiles {
84            p5: percentile_sorted(&sorted, 5.0),
85            p25: percentile_sorted(&sorted, 25.0),
86            p50: percentile_sorted(&sorted, 50.0),
87            p75: percentile_sorted(&sorted, 75.0),
88            p95: percentile_sorted(&sorted, 95.0),
89        };
90
91        Some(Self {
92            n_samples: n,
93            mean,
94            std,
95            variance,
96            min,
97            max,
98            percentiles,
99        })
100    }
101
102    /// Compute standard error of the mean.
103    pub fn standard_error(&self) -> S {
104        self.std / S::from_usize(self.n_samples).sqrt()
105    }
106
107    /// Compute confidence interval for the mean at given level (e.g., 0.95 for 95%).
108    ///
109    /// Uses normal approximation (valid for large samples).
110    pub fn confidence_interval(&self, level: S) -> (S, S) {
111        // z-score for two-tailed test
112        // For 95%: z ≈ 1.96, for 99%: z ≈ 2.576
113        let alpha = (S::ONE - level) / S::from_f64(2.0);
114        let z = normal_quantile(S::ONE - alpha);
115        let margin = z * self.standard_error();
116        (self.mean - margin, self.mean + margin)
117    }
118
119    /// Interquartile range (IQR = Q3 - Q1).
120    pub fn iqr(&self) -> S {
121        self.percentiles.p75 - self.percentiles.p25
122    }
123
124    /// Median (50th percentile).
125    pub fn median(&self) -> S {
126        self.percentiles.p50
127    }
128}
129
130/// Compute percentile from sorted data.
131fn percentile_sorted<S: Scalar>(sorted: &[S], p: f64) -> S {
132    let n = sorted.len();
133    if n == 0 {
134        return S::ZERO;
135    }
136    if n == 1 {
137        return sorted[0];
138    }
139
140    // Linear interpolation method
141    let rank = (p / 100.0) * (n - 1) as f64;
142    let lower = rank.floor() as usize;
143    let upper = rank.ceil() as usize;
144
145    if lower == upper {
146        sorted[lower]
147    } else {
148        let frac = S::from_f64(rank - lower as f64);
149        sorted[lower] + frac * (sorted[upper] - sorted[lower])
150    }
151}
152
153/// Approximate inverse normal CDF (quantile function).
154///
155/// Uses Abramowitz and Stegun approximation.
156fn normal_quantile<S: Scalar>(p: S) -> S {
157    // Rational approximation for 0 < p < 1
158    let p_f = p.to_f64();
159    if p_f <= 0.0 || p_f >= 1.0 {
160        return S::ZERO;
161    }
162
163    #[allow(clippy::excessive_precision)]
164    let a = [
165        -3.969683028665376e+01,
166        2.209460984245205e+02,
167        -2.759285104469687e+02,
168        1.383577518672690e+02,
169        -3.066479806614716e+01,
170        2.506628277459239e+00,
171    ];
172    let b = [
173        -5.447609879822406e+01,
174        1.615858368580409e+02,
175        -1.556989798598866e+02,
176        6.680131188771972e+01,
177        -1.328068155288572e+01,
178    ];
179    let c = [
180        -7.784894002430293e-03,
181        -3.223964580411365e-01,
182        -2.400758277161838e+00,
183        -2.549732539343734e+00,
184        4.374664141464968e+00,
185        2.938163982698783e+00,
186    ];
187    let d = [
188        7.784695709041462e-03,
189        3.224671290700398e-01,
190        2.445134137142996e+00,
191        3.754408661907416e+00,
192    ];
193
194    let p_low = 0.02425;
195    let p_high = 1.0 - p_low;
196
197    let q = if p_f < p_low {
198        let q = (-2.0 * p_f.ln()).sqrt();
199        (((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5])
200            / ((((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0)
201    } else if p_f <= p_high {
202        let q = p_f - 0.5;
203        let r = q * q;
204        (((((a[0] * r + a[1]) * r + a[2]) * r + a[3]) * r + a[4]) * r + a[5]) * q
205            / (((((b[0] * r + b[1]) * r + b[2]) * r + b[3]) * r + b[4]) * r + 1.0)
206    } else {
207        let q = (-2.0 * (1.0 - p_f).ln()).sqrt();
208        -(((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5])
209            / ((((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0)
210    };
211
212    S::from_f64(q)
213}
214
215/// Running statistics using Welford's online algorithm.
216///
217/// Memory-efficient for streaming data or very large ensembles.
218#[derive(Clone, Debug)]
219pub struct RunningStats<S: Scalar> {
220    n: usize,
221    mean: S,
222    m2: S, // Sum of squared deviations
223    min: S,
224    max: S,
225}
226
227impl<S: Scalar> RunningStats<S> {
228    /// Create a new running statistics accumulator.
229    pub fn new() -> Self {
230        Self {
231            n: 0,
232            mean: S::ZERO,
233            m2: S::ZERO,
234            min: S::INFINITY,
235            max: S::NEG_INFINITY,
236        }
237    }
238
239    /// Update statistics with a new value (Welford's algorithm).
240    pub fn update(&mut self, value: S) {
241        self.n += 1;
242        let n_f = S::from_usize(self.n);
243
244        let delta = value - self.mean;
245        self.mean += delta / n_f;
246        let delta2 = value - self.mean;
247        self.m2 += delta * delta2;
248
249        if value < self.min {
250            self.min = value;
251        }
252        if value > self.max {
253            self.max = value;
254        }
255    }
256
257    /// Number of samples seen.
258    pub fn count(&self) -> usize {
259        self.n
260    }
261
262    /// Current mean estimate.
263    pub fn mean(&self) -> S {
264        self.mean
265    }
266
267    /// Current variance estimate (sample variance with Bessel correction).
268    pub fn variance(&self) -> S {
269        if self.n < 2 {
270            S::ZERO
271        } else {
272            self.m2 / S::from_usize(self.n - 1)
273        }
274    }
275
276    /// Current standard deviation estimate.
277    pub fn std(&self) -> S {
278        self.variance().sqrt()
279    }
280
281    /// Standard error of the mean.
282    pub fn standard_error(&self) -> S {
283        self.std() / S::from_usize(self.n).sqrt()
284    }
285
286    /// Minimum value seen.
287    pub fn min(&self) -> S {
288        self.min
289    }
290
291    /// Maximum value seen.
292    pub fn max(&self) -> S {
293        self.max
294    }
295
296    /// Merge another RunningStats into this one (parallel reduction).
297    pub fn merge(&mut self, other: &RunningStats<S>) {
298        if other.n == 0 {
299            return;
300        }
301        if self.n == 0 {
302            *self = other.clone();
303            return;
304        }
305
306        let n_a = S::from_usize(self.n);
307        let n_b = S::from_usize(other.n);
308        let n_total = n_a + n_b;
309
310        let delta = other.mean - self.mean;
311        let new_mean = (n_a * self.mean + n_b * other.mean) / n_total;
312
313        // Chan's parallel algorithm for M2
314        let new_m2 = self.m2 + other.m2 + delta * delta * n_a * n_b / n_total;
315
316        self.n += other.n;
317        self.mean = new_mean;
318        self.m2 = new_m2;
319
320        if other.min < self.min {
321            self.min = other.min;
322        }
323        if other.max > self.max {
324            self.max = other.max;
325        }
326    }
327}
328
329impl<S: Scalar> Default for RunningStats<S> {
330    fn default() -> Self {
331        Self::new()
332    }
333}
334
335// Helper functions exported for convenience
336
337/// Compute mean of a slice.
338#[inline]
339pub fn mean<S: Scalar>(data: &[S]) -> S {
340    if data.is_empty() {
341        return S::ZERO;
342    }
343    data.iter().fold(S::ZERO, |acc, &x| acc + x) / S::from_usize(data.len())
344}
345
346/// Compute sample standard deviation.
347pub fn std<S: Scalar>(data: &[S]) -> S {
348    variance(data).sqrt()
349}
350
351/// Compute sample variance.
352pub fn variance<S: Scalar>(data: &[S]) -> S {
353    if data.len() < 2 {
354        return S::ZERO;
355    }
356    let m = mean(data);
357    let sum_sq: S = data.iter().fold(S::ZERO, |acc, &x| {
358        let diff = x - m;
359        acc + diff * diff
360    });
361    sum_sq / S::from_usize(data.len() - 1)
362}
363
364/// Compute percentile (0-100 scale).
365pub fn percentile<S: Scalar>(data: &[S], p: f64) -> S {
366    if data.is_empty() {
367        return S::ZERO;
368    }
369    let mut sorted = data.to_vec();
370    sorted.sort_by(|a, b| a.to_f64().partial_cmp(&b.to_f64()).unwrap());
371    percentile_sorted(&sorted, p)
372}
373
374/// Compute median.
375pub fn median<S: Scalar>(data: &[S]) -> S {
376    percentile(data, 50.0)
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `a` and `b` differ by strictly less than `tol`.
    fn close(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Feed every value of `values` into a fresh accumulator.
    fn accumulate(values: &[f64]) -> RunningStats<f64> {
        let mut acc = RunningStats::<f64>::new();
        for &v in values {
            acc.update(v);
        }
        acc
    }

    #[test]
    fn test_basic_stats() {
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        assert!(close(mean(&data), 3.0, 1e-10));
        assert!(close(variance(&data), 2.5, 1e-10)); // Sample variance
        assert!(close(std(&data), 2.5_f64.sqrt(), 1e-10));
        assert!(close(median(&data), 3.0, 1e-10));
    }

    #[test]
    fn test_percentiles() {
        let data: Vec<f64> = (1..=100).map(f64::from).collect();

        assert!(close(percentile(&data, 50.0), 50.5, 0.5));
        assert!(close(percentile(&data, 25.0), 25.0, 1.0));
        assert!(close(percentile(&data, 75.0), 75.0, 1.0));
    }

    #[test]
    fn test_ensemble_stats() {
        let data: Vec<f64> = (1..=100).map(f64::from).collect();
        let stats = EnsembleStats::from_samples(&data).unwrap();

        assert_eq!(stats.n_samples, 100);
        assert!(close(stats.mean, 50.5, 0.01));
        assert!(close(stats.min, 1.0, 1e-10));
        assert!(close(stats.max, 100.0, 1e-10));
    }

    #[test]
    fn test_running_stats() {
        let rs = accumulate(&[1.0, 2.0, 3.0, 4.0, 5.0]);

        assert_eq!(rs.count(), 5);
        assert!(close(rs.mean(), 3.0, 1e-10));
        assert!(close(rs.variance(), 2.5, 1e-10));
        assert!(close(rs.min(), 1.0, 1e-10));
        assert!(close(rs.max(), 5.0, 1e-10));
    }

    #[test]
    fn test_running_stats_merge() {
        let mut rs1 = accumulate(&[1.0, 2.0, 3.0]);
        let rs2 = accumulate(&[4.0, 5.0, 6.0]);

        rs1.merge(&rs2);

        // The merged accumulator must agree with statistics of the combined data.
        let combined: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        assert_eq!(rs1.count(), 6);
        assert!(close(rs1.mean(), mean(&combined), 1e-10));
        assert!(close(rs1.variance(), variance(&combined), 1e-10));
    }

    #[test]
    fn test_confidence_interval() {
        // 1000 evenly spaced values in [-1, 1). Not a normal sample, but
        // enough to sanity-check the shape of the interval.
        let data: Vec<f64> = (0..1000)
            .map(|i| (i as f64 / 1000.0 - 0.5) * 2.0)
            .collect();

        let stats = EnsembleStats::from_samples(&data).unwrap();
        let (lo, hi) = stats.confidence_interval(0.95);

        // Interval should contain mean
        assert!(lo < stats.mean);
        assert!(hi > stats.mean);
        // Interval should be narrower than total range
        assert!(hi - lo < stats.max - stats.min);
    }
}
472}