rag_plusplus_core/
stats.rs

1//! Outcome Statistics with Welford's Online Algorithm
2//!
3//! Provides numerically stable running statistics for outcome dimensions.
4//!
5//! # Invariants
6//!
7//! - INV-002: Variance is always non-negative
8//! - All operations are O(dim) where dim is the outcome dimension
9//!
10//! # Example
11//!
12//! ```
13//! use rag_plusplus_core::OutcomeStats;
14//!
15//! let mut stats = OutcomeStats::new(3);
16//! stats.update(&[0.8, 0.9, 0.7]);
17//! stats.update(&[0.85, 0.88, 0.75]);
18//!
19//! assert_eq!(stats.count(), 2);
20//! assert!(stats.mean().is_some());
21//! ```
22
/// Running per-dimension statistics computed with Welford's online algorithm.
///
/// For every dimension this keeps the running mean, `M2` (the sum of squared
/// deviations from the running mean, from which the variances are derived) and
/// the smallest and largest value observed. A single observation count is
/// shared by all dimensions. Variances derived from `M2` are never negative
/// (INV-002).
#[derive(Debug, Clone)]
pub struct OutcomeStats {
    /// Number of observations folded in so far.
    count: u64,
    /// Running mean per dimension.
    mean: Vec<f32>,
    /// Sum of squared deviations from the running mean, per dimension.
    m2: Vec<f32>,
    /// Smallest observed value per dimension (`+inf` before the first update).
    min: Vec<f32>,
    /// Largest observed value per dimension (`-inf` before the first update).
    max: Vec<f32>,
}

impl OutcomeStats {
    /// Create empty statistics for outcome vectors of length `dim`.
    #[must_use]
    pub fn new(dim: usize) -> Self {
        Self {
            count: 0,
            mean: vec![0.0; dim],
            m2: vec![0.0; dim],
            min: vec![f32::INFINITY; dim],
            max: vec![f32::NEG_INFINITY; dim],
        }
    }

    /// Fold one observation into the statistics (Welford's update).
    ///
    /// # Panics
    ///
    /// Panics if `outcome.len() != self.dim()`.
    pub fn update(&mut self, outcome: &[f32]) {
        assert_eq!(
            outcome.len(),
            self.dim(),
            "Outcome dimension mismatch: expected {}, got {}",
            self.dim(),
            outcome.len()
        );

        self.count += 1;
        let n = self.count as f32;

        for i in 0..self.dim() {
            let x = outcome[i];

            // Welford: the second factor uses the *updated* mean, which avoids
            // the catastrophic cancellation of the naive sum-of-squares formula.
            let delta = x - self.mean[i];
            self.mean[i] += delta / n;
            let delta2 = x - self.mean[i];
            self.m2[i] += delta * delta2;

            self.min[i] = self.min[i].min(x);
            self.max[i] = self.max[i].max(x);
        }
    }

    /// Combine two sets of statistics (pairwise / parallel Welford update).
    ///
    /// Useful for combining statistics computed on disjoint partitions of the
    /// data. If either side has no observations, the other side is returned
    /// unchanged. That means an empty accumulator of *any* dimension (including
    /// the zero-dimensional `Default`) acts as the identity element.
    ///
    /// # Panics
    ///
    /// Panics if both sides hold observations but their dimensions differ.
    #[must_use]
    pub fn merge(&self, other: &Self) -> Self {
        if self.count == 0 {
            return other.clone();
        }
        if other.count == 0 {
            return self.clone();
        }

        assert_eq!(self.dim(), other.dim(), "Dimension mismatch in merge");

        let combined_count = self.count + other.count;
        // Loop-invariant weights of the pairwise update, computed once.
        let mean_weight = other.count as f32 / combined_count as f32;
        let m2_weight = self.count as f32 * other.count as f32 / combined_count as f32;

        let dim = self.dim();
        let mut mean = Vec::with_capacity(dim);
        let mut m2 = Vec::with_capacity(dim);
        let mut min = Vec::with_capacity(dim);
        let mut max = Vec::with_capacity(dim);

        for i in 0..dim {
            let delta = other.mean[i] - self.mean[i];
            mean.push(self.mean[i] + delta * mean_weight);
            m2.push(self.m2[i] + other.m2[i] + delta * delta * m2_weight);
            min.push(self.min[i].min(other.min[i]));
            max.push(self.max[i].max(other.max[i]));
        }

        Self {
            count: combined_count,
            mean,
            m2,
            min,
            max,
        }
    }

    /// Update with a single scalar value (1D convenience method).
    ///
    /// The value is narrowed to `f32` before being stored.
    ///
    /// # Panics
    ///
    /// Panics if `self.dim() != 1`.
    pub fn update_scalar(&mut self, value: f64) {
        self.update(&[value as f32]);
    }

    /// Number of observations.
    #[must_use]
    pub const fn count(&self) -> u64 {
        self.count
    }

    /// Mean of the first dimension, as `f64` (intended for 1D stats).
    ///
    /// `None` if there are no observations or the dimension is zero.
    #[must_use]
    pub fn mean_scalar(&self) -> Option<f64> {
        self.mean().and_then(|m| m.first()).map(|&v| f64::from(v))
    }

    /// Population variance of the first dimension, as `f64` (intended for 1D stats).
    ///
    /// `None` if there are fewer than two observations or the dimension is zero.
    #[must_use]
    pub fn variance_scalar(&self) -> Option<f64> {
        self.variance().and_then(|v| v.first().copied()).map(f64::from)
    }

    /// Population standard deviation of the first dimension, as `f64` (intended for 1D stats).
    ///
    /// `None` if there are fewer than two observations or the dimension is zero.
    #[must_use]
    pub fn std_scalar(&self) -> Option<f64> {
        self.std().and_then(|s| s.first().copied()).map(f64::from)
    }

    /// Dimension of outcome vectors.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.mean.len()
    }

    /// Current mean estimate (`None` if no observations).
    #[must_use]
    pub fn mean(&self) -> Option<&[f32]> {
        if self.count > 0 {
            Some(&self.mean)
        } else {
            None
        }
    }

    /// Population variance, `M2 / n` (`None` if fewer than 2 observations).
    #[must_use]
    pub fn variance(&self) -> Option<Vec<f32>> {
        if self.count < 2 {
            return None;
        }
        Some(self.scaled_m2(self.count as f32))
    }

    /// Population standard deviation (`None` if fewer than 2 observations).
    #[must_use]
    pub fn std(&self) -> Option<Vec<f32>> {
        self.variance().map(|v| v.iter().map(|x| x.sqrt()).collect())
    }

    /// Sample variance with Bessel's correction, `M2 / (n - 1)`
    /// (`None` if fewer than 2 observations).
    #[must_use]
    pub fn sample_variance(&self) -> Option<Vec<f32>> {
        if self.count < 2 {
            return None;
        }
        Some(self.scaled_m2((self.count - 1) as f32))
    }

    /// `M2 / divisor` per dimension, with round-off negatives clamped to zero.
    fn scaled_m2(&self, divisor: f32) -> Vec<f32> {
        self.m2
            .iter()
            .map(|&m2| {
                let v = m2 / divisor;
                // Enforce INV-002 against floating-point round-off. A NaN
                // (from NaN observations) compares false and is passed through
                // rather than hidden.
                if v < 0.0 {
                    0.0
                } else {
                    v
                }
            })
            .collect()
    }

    /// Minimum observed values (`None` if no observations).
    #[must_use]
    pub fn min(&self) -> Option<&[f32]> {
        if self.count > 0 {
            Some(&self.min)
        } else {
            None
        }
    }

    /// Maximum observed values (`None` if no observations).
    #[must_use]
    pub fn max(&self) -> Option<&[f32]> {
        if self.count > 0 {
            Some(&self.max)
        } else {
            None
        }
    }

    /// Two-sided confidence interval for the mean of each dimension.
    ///
    /// The interval is `mean ± t * s / sqrt(n)`, where:
    /// - `s` is the Bessel-corrected sample standard deviation;
    /// - `t` is the `(1 + confidence) / 2` quantile of Student's t
    ///   distribution with `n - 1` degrees of freedom.
    ///
    /// `t` is exact for 1 and 2 degrees of freedom. Otherwise it comes from
    /// the Cornish–Fisher expansion, which approaches the normal quantile as
    /// the sample grows.
    ///
    /// Returns `(lower, upper)` bounds. Returns `None` when there are fewer
    /// than two observations, or when `confidence` is not strictly between 0
    /// and 1 (this includes NaN).
    #[must_use]
    pub fn confidence_interval(&self, confidence: f32) -> Option<(Vec<f32>, Vec<f32>)> {
        let confidence_in_range = confidence > 0.0 && confidence < 1.0;
        if self.count < 2 || !confidence_in_range {
            return None;
        }

        let p = 0.5 + f64::from(confidence) / 2.0;
        let t_val = student_t_quantile(p, self.count - 1) as f32;
        let sqrt_n = (self.count as f32).sqrt();
        let sample_variance = self.sample_variance()?;

        let (lower, upper): (Vec<f32>, Vec<f32>) = self
            .mean
            .iter()
            .zip(sample_variance)
            .map(|(&m, var)| {
                let half_width = t_val * (var.sqrt() / sqrt_n);
                (m - half_width, m + half_width)
            })
            .unzip();

        Some((lower, upper))
    }
}

/// Inverse of the standard normal CDF (probit), using Acklam's rational
/// approximation (relative error about 1.15e-9).
///
/// `p` must lie strictly inside `(0, 1)`.
fn standard_normal_quantile(p: f64) -> f64 {
    const A: [f64; 6] = [
        -3.969_683_028_665_376e1,
        2.209_460_984_245_205e2,
        -2.759_285_104_469_687e2,
        1.383_577_518_672_690e2,
        -3.066_479_806_614_716e1,
        2.506_628_277_459_239e0,
    ];
    const B: [f64; 5] = [
        -5.447_609_879_822_406e1,
        1.615_858_368_580_409e2,
        -1.556_989_798_598_866e2,
        6.680_131_188_771_972e1,
        -1.328_068_155_288_572e1,
    ];
    const C: [f64; 6] = [
        -7.784_894_002_430_293e-3,
        -3.223_964_580_411_365e-1,
        -2.400_758_277_161_838e0,
        -2.549_732_539_343_734e0,
        4.374_664_141_464_968e0,
        2.938_163_982_698_783e0,
    ];
    const D: [f64; 4] = [
        7.784_695_709_041_462e-3,
        3.224_671_290_700_398e-1,
        2.445_134_137_142_996e0,
        3.754_408_661_907_416e0,
    ];
    const P_LOW: f64 = 0.02425;

    if p < P_LOW {
        // Lower tail.
        let q = (-2.0 * p.ln()).sqrt();
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    } else if p <= 1.0 - P_LOW {
        // Central region.
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        // Upper tail, by symmetry of the normal distribution.
        -standard_normal_quantile(1.0 - p)
    }
}

/// Quantile of Student's t distribution with `df` degrees of freedom.
///
/// `p` must lie strictly inside `(0, 1)` and `df` must be at least 1.
/// - `df == 1` (Cauchy) and `df == 2` use closed forms.
/// - Larger `df` use the four-term Cornish–Fisher expansion around the normal
///   quantile (Abramowitz & Stegun 26.7.5). This is within roughly 1% of the
///   exact value for `df >= 3` at confidence levels up to 99%.
fn student_t_quantile(p: f64, df: u64) -> f64 {
    debug_assert!(df >= 1, "Student's t needs at least one degree of freedom");
    match df {
        1 => (std::f64::consts::PI * (p - 0.5)).tan(),
        2 => (2.0 * p - 1.0) / (2.0 * p * (1.0 - p)).sqrt(),
        _ => {
            let z = standard_normal_quantile(p);
            let z2 = z * z;
            let z3 = z2 * z;
            let z5 = z3 * z2;
            let z7 = z5 * z2;
            let z9 = z7 * z2;
            let g1 = (z3 + z) / 4.0;
            let g2 = (5.0 * z5 + 16.0 * z3 + 3.0 * z) / 96.0;
            let g3 = (3.0 * z7 + 19.0 * z5 + 17.0 * z3 - 15.0 * z) / 384.0;
            let g4 = (79.0 * z9 + 776.0 * z7 + 1482.0 * z5 - 1920.0 * z3 - 945.0 * z) / 92160.0;
            let nu = df as f64;
            // z + g1/nu + g2/nu^2 + g3/nu^3 + g4/nu^4, in Horner form.
            z + (g1 + (g2 + (g3 + g4 / nu) / nu) / nu) / nu
        }
    }
}
264
265impl Default for OutcomeStats {
266    fn default() -> Self {
267        Self::new(0)
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_stats() {
        let stats = OutcomeStats::new(3);
        assert_eq!(stats.count(), 0);
        assert!(stats.mean().is_none());
        assert!(stats.variance().is_none());
    }

    #[test]
    fn test_single_update() {
        let mut stats = OutcomeStats::new(3);
        stats.update(&[1.0, 2.0, 3.0]);

        assert_eq!(stats.count(), 1);
        assert_eq!(stats.mean(), Some([1.0, 2.0, 3.0].as_slice()));
        assert!(stats.variance().is_none()); // Need 2+ observations
    }

    #[test]
    fn test_multiple_updates() {
        let mut stats = OutcomeStats::new(2);
        stats.update(&[1.0, 2.0]);
        stats.update(&[3.0, 4.0]);
        stats.update(&[5.0, 6.0]);

        assert_eq!(stats.count(), 3);
        let mean = stats.mean().unwrap();
        assert!((mean[0] - 3.0).abs() < 1e-6);
        assert!((mean[1] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_variance_and_sample_variance() {
        let mut stats = OutcomeStats::new(1);
        for &x in &[1.0_f32, 2.0, 3.0, 4.0] {
            stats.update(&[x]);
        }

        // Deviations from the mean 2.5 give M2 = 5.
        assert!((stats.variance().unwrap()[0] - 1.25).abs() < 1e-6);
        assert!((stats.sample_variance().unwrap()[0] - 5.0 / 3.0).abs() < 1e-6);
        assert!((stats.std().unwrap()[0] - 1.25_f32.sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_merge() {
        let mut stats1 = OutcomeStats::new(2);
        stats1.update(&[1.0, 2.0]);
        stats1.update(&[2.0, 3.0]);

        let mut stats2 = OutcomeStats::new(2);
        stats2.update(&[3.0, 4.0]);
        stats2.update(&[4.0, 5.0]);

        let merged = stats1.merge(&stats2);
        assert_eq!(merged.count(), 4);

        let mean = merged.mean().unwrap();
        assert!((mean[0] - 2.5).abs() < 1e-6);
        assert!((mean[1] - 3.5).abs() < 1e-6);
    }

    #[test]
    fn test_merge_matches_sequential() {
        let data: [[f32; 2]; 4] = [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let mut left = OutcomeStats::new(2);
        let mut right = OutcomeStats::new(2);
        let mut all = OutcomeStats::new(2);
        for (i, row) in data.iter().enumerate() {
            if i < 2 {
                left.update(row);
            } else {
                right.update(row);
            }
            all.update(row);
        }

        let merged = left.merge(&right);
        let merged_var = merged.variance().unwrap();
        let all_var = all.variance().unwrap();
        for d in 0..2 {
            assert!((merged_var[d] - all_var[d]).abs() < 1e-5);
        }
        assert_eq!(merged.min(), all.min());
        assert_eq!(merged.max(), all.max());
    }

    #[test]
    fn test_merge_with_empty_is_identity() {
        let mut stats = OutcomeStats::new(2);
        stats.update(&[1.0, 2.0]);

        let merged = OutcomeStats::default().merge(&stats);
        assert_eq!(merged.count(), 1);
        assert_eq!(merged.mean(), stats.mean());

        let merged = stats.merge(&OutcomeStats::new(2));
        assert_eq!(merged.count(), 1);
        assert_eq!(merged.mean(), stats.mean());
    }

    #[test]
    fn test_numerical_stability() {
        // The data sit on a large offset. There the naive formula
        // E[x^2] - E[x]^2 loses every significant digit in f32 (catastrophic
        // cancellation), whereas Welford's update should not.
        // The spread of 2 is exactly representable at this magnitude: f32
        // spacing near 1e6 is 0.0625. So the observations genuinely differ.
        // Alternating base and base + 2 has mean base + 1 and population
        // variance exactly 1.
        let mut stats = OutcomeStats::new(1);
        let base = 1e6_f32;

        for i in 0..1000 {
            let offset = if i % 2 == 0 { 0.0 } else { 2.0 };
            stats.update(&[base + offset]);
        }

        let mean = stats.mean().unwrap()[0];
        assert!((mean - (base + 1.0)).abs() < 1.0);

        let var = stats.variance().unwrap()[0];
        assert!(var >= 0.0); // Variance must be non-negative (INV-002)
        assert!((var - 1.0).abs() < 0.1, "variance {var} should be close to 1");
    }

    #[test]
    fn test_min_max() {
        let mut stats = OutcomeStats::new(2);
        stats.update(&[1.0, 5.0]);
        stats.update(&[3.0, 2.0]);
        stats.update(&[2.0, 8.0]);

        assert_eq!(stats.min(), Some([1.0, 2.0].as_slice()));
        assert_eq!(stats.max(), Some([3.0, 8.0].as_slice()));
    }

    #[test]
    fn test_scalar_helpers() {
        let mut stats = OutcomeStats::new(1);
        stats.update_scalar(2.0);
        stats.update_scalar(4.0);

        assert_eq!(stats.mean_scalar(), Some(3.0));
        assert_eq!(stats.variance_scalar(), Some(1.0));
        assert_eq!(stats.std_scalar(), Some(1.0));
    }

    #[test]
    fn test_confidence_interval_brackets_mean() {
        let mut stats = OutcomeStats::new(1);
        assert!(stats.confidence_interval(0.95).is_none());
        stats.update(&[1.0]);
        assert!(stats.confidence_interval(0.95).is_none()); // Need 2+ observations

        for &x in &[2.0_f32, 3.0, 4.0, 5.0] {
            stats.update(&[x]);
        }
        let (lower, upper) = stats.confidence_interval(0.95).unwrap();
        let mean = stats.mean().unwrap()[0];
        assert!(lower[0] < mean && mean < upper[0]);
        // The interval is symmetric about the mean.
        assert!(((mean - lower[0]) - (upper[0] - mean)).abs() < 1e-5);
    }
}