rag_plusplus_core/
/// Running per-dimension summary statistics over a stream of outcome vectors.
///
/// For every dimension the accumulator keeps the running mean, the running sum
/// of squared deviations from that mean (`m2`), and the observed minimum and
/// maximum. Observations are folded in one at a time with Welford's online
/// update, so the raw samples are never stored.
#[derive(Debug, Clone)]
pub struct OutcomeStats {
    /// Number of observations folded in so far.
    count: u64,
    /// Running mean of each dimension.
    mean: Vec<f32>,
    /// Running sum of squared deviations from the mean, per dimension.
    m2: Vec<f32>,
    /// Smallest value seen per dimension (`+inf` until the first update).
    min: Vec<f32>,
    /// Largest value seen per dimension (`-inf` until the first update).
    max: Vec<f32>,
}

impl OutcomeStats {
    /// Creates an empty accumulator that tracks `dim` dimensions.
    #[must_use]
    pub fn new(dim: usize) -> Self {
        Self {
            count: 0,
            mean: vec![0.0; dim],
            m2: vec![0.0; dim],
            min: vec![f32::INFINITY; dim],
            max: vec![f32::NEG_INFINITY; dim],
        }
    }

    /// Folds one observation into the running statistics.
    ///
    /// # Panics
    ///
    /// Panics if `outcome.len()` differs from [`dim`](Self::dim).
    pub fn update(&mut self, outcome: &[f32]) {
        assert_eq!(
            outcome.len(),
            self.dim(),
            "Outcome dimension mismatch: expected {}, got {}",
            self.dim(),
            outcome.len()
        );

        self.count += 1;
        let n = self.count as f32;

        for (i, &x) in outcome.iter().enumerate() {
            // Welford's update: `delta` is measured against the old mean and
            // `delta2` against the new one; their product is the increment of
            // the sum of squared deviations.
            let delta = x - self.mean[i];
            self.mean[i] += delta / n;
            let delta2 = x - self.mean[i];
            self.m2[i] += delta * delta2;

            self.min[i] = self.min[i].min(x);
            self.max[i] = self.max[i].max(x);
        }
    }

    /// Combines two accumulators into a new one that summarises both streams.
    ///
    /// If either side has no observations, a clone of the other side is
    /// returned as-is; the dimensions are not compared in that case.
    ///
    /// # Panics
    ///
    /// Panics if both sides hold observations and their dimensions differ.
    #[must_use]
    pub fn merge(&self, other: &Self) -> Self {
        if self.count == 0 {
            return other.clone();
        }
        if other.count == 0 {
            return self.clone();
        }

        assert_eq!(self.dim(), other.dim(), "Dimension mismatch in merge");

        let count = self.count + other.count;
        let total = count as f32;
        // Both weights are the same for every dimension, so compute them once.
        let other_share = other.count as f32 / total;
        let cross_weight = self.count as f32 * other.count as f32 / total;

        let mean: Vec<f32> = self
            .mean
            .iter()
            .zip(&other.mean)
            .map(|(&a, &b)| a + (b - a) * other_share)
            .collect();

        // Pairwise combination of the squared-deviation sums: the two partial
        // sums plus a correction for the gap between the two means.
        let m2: Vec<f32> = self
            .m2
            .iter()
            .zip(&other.m2)
            .zip(self.mean.iter().zip(&other.mean))
            .map(|((&m2_a, &m2_b), (&mean_a, &mean_b))| {
                let delta = mean_b - mean_a;
                m2_a + m2_b + delta * delta * cross_weight
            })
            .collect();

        let min: Vec<f32> = self
            .min
            .iter()
            .zip(&other.min)
            .map(|(&a, &b)| a.min(b))
            .collect();
        let max: Vec<f32> = self
            .max
            .iter()
            .zip(&other.max)
            .map(|(&a, &b)| a.max(b))
            .collect();

        Self {
            count,
            mean,
            m2,
            min,
            max,
        }
    }

    /// Folds a single scalar observation into a one-dimensional accumulator.
    ///
    /// The value is narrowed to `f32` before it is stored.
    ///
    /// # Panics
    ///
    /// Panics if the accumulator does not track exactly one dimension.
    pub fn update_scalar(&mut self, value: f64) {
        self.update(&[value as f32]);
    }

    /// Returns the number of observations folded in so far.
    #[must_use]
    pub const fn count(&self) -> u64 {
        self.count
    }

    /// Returns the mean of the first dimension, widened to `f64`.
    ///
    /// Returns `None` when there are no observations or when the accumulator
    /// tracks zero dimensions.
    #[must_use]
    pub fn mean_scalar(&self) -> Option<f64> {
        // `first()` rather than `[0]`: a zero-dimensional accumulator with
        // observations has an empty mean vector and must not panic here.
        self.mean()?.first().copied().map(f64::from)
    }

    /// Returns the population variance of the first dimension, widened to `f64`.
    ///
    /// Returns `None` when [`variance`](Self::variance) does, or when the
    /// accumulator tracks zero dimensions.
    #[must_use]
    pub fn variance_scalar(&self) -> Option<f64> {
        self.variance()?.first().copied().map(f64::from)
    }

    /// Returns the population standard deviation of the first dimension,
    /// widened to `f64`.
    ///
    /// Returns `None` when [`std`](Self::std) does, or when the accumulator
    /// tracks zero dimensions.
    #[must_use]
    pub fn std_scalar(&self) -> Option<f64> {
        self.std()?.first().copied().map(f64::from)
    }

    /// Returns the number of dimensions tracked by this accumulator.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.mean.len()
    }

    /// Returns the per-dimension mean, or `None` before the first update.
    #[must_use]
    pub fn mean(&self) -> Option<&[f32]> {
        if self.count > 0 {
            Some(&self.mean)
        } else {
            None
        }
    }

    /// Returns the per-dimension population variance (`m2 / count`).
    ///
    /// Returns `None` when fewer than two observations have been seen.
    #[must_use]
    pub fn variance(&self) -> Option<Vec<f32>> {
        if self.count < 2 {
            return None;
        }
        let n = self.count as f32;
        Some(self.m2.iter().map(|&m2| m2 / n).collect())
    }

    /// Returns the per-dimension population standard deviation.
    ///
    /// Returns `None` when fewer than two observations have been seen.
    #[must_use]
    pub fn std(&self) -> Option<Vec<f32>> {
        self.variance()
            .map(|variances| variances.into_iter().map(f32::sqrt).collect())
    }

    /// Returns the per-dimension sample variance (`m2 / (count - 1)`).
    ///
    /// Returns `None` when fewer than two observations have been seen.
    #[must_use]
    pub fn sample_variance(&self) -> Option<Vec<f32>> {
        if self.count < 2 {
            return None;
        }
        let degrees_of_freedom = (self.count - 1) as f32;
        Some(self.m2.iter().map(|&m2| m2 / degrees_of_freedom).collect())
    }

    /// Returns the per-dimension minimum, or `None` before the first update.
    #[must_use]
    pub fn min(&self) -> Option<&[f32]> {
        if self.count > 0 {
            Some(&self.min)
        } else {
            None
        }
    }

    /// Returns the per-dimension maximum, or `None` before the first update.
    #[must_use]
    pub fn max(&self) -> Option<&[f32]> {
        if self.count > 0 {
            Some(&self.max)
        } else {
            None
        }
    }

    /// Returns approximate `(lower, upper)` bounds of a confidence interval
    /// around the mean of each dimension.
    ///
    /// The half-width is `critical * std / sqrt(count)`:
    ///
    /// * with fewer than 30 observations `critical` is `2.0 + 1.0 / sqrt(count)`
    ///   and `confidence` is not consulted;
    /// * otherwise `critical` is 1.645, 1.96 or 2.576 when `confidence` lies
    ///   within 0.01 of 0.90, 0.95 or 0.99 respectively, and 1.96 for any
    ///   other value.
    ///
    /// Returns `None` when fewer than two observations have been seen.
    ///
    /// NOTE(review): the standard error is derived from the population
    /// standard deviation ([`std`](Self::std)) rather than the sample one, and
    /// the small-sample branch ignores `confidence`. Both are kept as they
    /// were so existing results do not shift; confirm whether that is intended.
    #[must_use]
    pub fn confidence_interval(&self, confidence: f32) -> Option<(Vec<f32>, Vec<f32>)> {
        const Z_90: f32 = 1.645;
        const Z_95: f32 = 1.96;
        const Z_99: f32 = 2.576;
        const TOLERANCE: f32 = 0.01;

        // `std()` is `None` for fewer than two observations.
        let std = self.std()?;
        let root_n = (self.count as f32).sqrt();

        let critical = if self.count < 30 {
            2.0 + 1.0 / root_n
        } else if (confidence - 0.90).abs() < TOLERANCE {
            Z_90
        } else if (confidence - 0.99).abs() < TOLERANCE {
            Z_99
        } else {
            // Covers 0.95 as well as every unrecognised level.
            Z_95
        };

        let margins: Vec<f32> = std.iter().map(|&s| critical * (s / root_n)).collect();

        let lower: Vec<f32> = self
            .mean
            .iter()
            .zip(&margins)
            .map(|(&m, &margin)| m - margin)
            .collect();
        let upper: Vec<f32> = self
            .mean
            .iter()
            .zip(&margins)
            .map(|(&m, &margin)| m + margin)
            .collect();

        Some((lower, upper))
    }
}
264
265impl Default for OutcomeStats {
266 fn default() -> Self {
267 Self::new(0)
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built accumulator reports no observations and no statistics.
    #[test]
    fn test_empty_stats() {
        let empty = OutcomeStats::new(3);

        assert_eq!(empty.count(), 0);
        assert!(empty.mean().is_none());
        assert!(empty.variance().is_none());
    }

    /// One observation yields a mean equal to that observation and no variance.
    #[test]
    fn test_single_update() {
        let mut acc = OutcomeStats::new(3);
        acc.update(&[1.0, 2.0, 3.0]);

        assert_eq!(acc.count(), 1);
        assert_eq!(acc.mean(), Some(&[1.0_f32, 2.0, 3.0][..]));
        // Variance is only reported from two observations onwards.
        assert!(acc.variance().is_none());
    }

    /// The running mean tracks the arithmetic mean across several updates.
    #[test]
    fn test_multiple_updates() {
        let rows: [[f32; 2]; 3] = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut acc = OutcomeStats::new(2);
        for row in &rows {
            acc.update(row);
        }

        assert_eq!(acc.count(), 3);
        let averages = acc.mean().unwrap();
        assert!((averages[0] - 3.0).abs() < 1e-6);
        assert!((averages[1] - 4.0).abs() < 1e-6);
    }

    /// Merging two accumulators gives the count and mean of the joint stream.
    #[test]
    fn test_merge() {
        let mut left = OutcomeStats::new(2);
        left.update(&[1.0, 2.0]);
        left.update(&[2.0, 3.0]);

        let mut right = OutcomeStats::new(2);
        right.update(&[3.0, 4.0]);
        right.update(&[4.0, 5.0]);

        let combined = left.merge(&right);
        assert_eq!(combined.count(), 4);

        let averages = combined.mean().unwrap();
        assert!((averages[0] - 2.5).abs() < 1e-6);
        assert!((averages[1] - 3.5).abs() < 1e-6);
    }

    /// Large offsets with tiny increments must not produce a negative variance.
    #[test]
    fn test_numerical_stability() {
        let base = 1e9_f32;
        let mut acc = OutcomeStats::new(1);

        for step in 0..1000_i32 {
            let sample = base + (step as f32) * 0.001;
            acc.update(&[sample]);
        }

        let average = acc.mean().unwrap()[0];
        assert!((average - base).abs() < 1.0);

        let spread = acc.variance().unwrap()[0];
        assert!(spread >= 0.0);
    }

    /// Minimum and maximum are tracked independently for each dimension.
    #[test]
    fn test_min_max() {
        let rows: [[f32; 2]; 3] = [[1.0, 5.0], [3.0, 2.0], [2.0, 8.0]];
        let mut acc = OutcomeStats::new(2);
        for row in &rows {
            acc.update(row);
        }

        assert_eq!(acc.min(), Some(&[1.0_f32, 2.0][..]));
        assert_eq!(acc.max(), Some(&[3.0_f32, 8.0][..]));
    }
}