// genomicframe_core/stats.rs

//! Shared statistics utilities for genomic data
//!
//! This module provides composable, streaming-friendly statistics
//! that can be computed over large genomic datasets with minimal memory usage.

use crate::parallel::Mergeable;
use std::collections::HashMap;
8
/// Running statistics accumulator (Welford's online algorithm)
///
/// Computes mean, variance, and standard deviation in a single pass
/// with numerically stable updates. Memory: O(1)
#[derive(Debug, Clone, Default)]
pub struct RunningStats {
    count: usize,
    mean: f64,
    m2: f64, // running sum of squared deviations from the current mean
    min: Option<f64>,
    max: Option<f64>,
}

impl RunningStats {
    /// Create an empty accumulator.
    pub fn new() -> Self {
        Self::default()
    }

    /// Fold a single observation into the running statistics.
    pub fn push(&mut self, value: f64) {
        self.count += 1;
        // Welford update: shift the mean, then accumulate m2 from the
        // deviations against both the old and the new mean.
        let before = value - self.mean;
        self.mean += before / self.count as f64;
        let after = value - self.mean;
        self.m2 += before * after;

        self.min = Some(match self.min {
            Some(current) => current.min(value),
            None => value,
        });
        self.max = Some(match self.max {
            Some(current) => current.max(value),
            None => value,
        });
    }

    /// Number of observations pushed so far.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Arithmetic mean, or `None` when no values have been pushed.
    pub fn mean(&self) -> Option<f64> {
        (self.count > 0).then_some(self.mean)
    }

    /// Sample variance (n - 1 denominator); `None` for fewer than two values.
    pub fn variance(&self) -> Option<f64> {
        (self.count > 1).then(|| self.m2 / (self.count - 1) as f64)
    }

    /// Sample standard deviation; `None` for fewer than two values.
    pub fn std_dev(&self) -> Option<f64> {
        Some(self.variance()?.sqrt())
    }

    /// Smallest value seen, if any.
    pub fn min(&self) -> Option<f64> {
        self.min
    }

    /// Largest value seen, if any.
    pub fn max(&self) -> Option<f64> {
        self.max
    }
}
78
79/// Implementation of Mergeable for parallel statistics computation
80///
81/// Uses Chan's parallel variance algorithm to correctly merge
82/// running statistics from multiple threads.
83///
84/// Reference: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
85impl Mergeable for RunningStats {
86    fn merge(&mut self, other: Self) {
87        if other.count == 0 {
88            return; // Nothing to merge
89        }
90        if self.count == 0 {
91            *self = other; // Replace with other
92            return;
93        }
94
95        let total_count = self.count + other.count;
96        let delta = other.mean - self.mean;
97
98        // Merge min/max
99        self.min = match (self.min, other.min) {
100            (Some(a), Some(b)) => Some(a.min(b)),
101            (Some(a), None) => Some(a),
102            (None, Some(b)) => Some(b),
103            (None, None) => None,
104        };
105
106        self.max = match (self.max, other.max) {
107            (Some(a), Some(b)) => Some(a.max(b)),
108            (Some(a), None) => Some(a),
109            (None, Some(b)) => Some(b),
110            (None, None) => None,
111        };
112
113        // Merge mean and m2 using Chan's algorithm
114        self.m2 += other.m2 + delta * delta * (self.count * other.count) as f64 / total_count as f64;
115        self.mean = (self.mean * self.count as f64 + other.mean * other.count as f64)
116            / total_count as f64;
117        self.count = total_count;
118    }
119}
120
/// Accumulator for categorical data (counts by category)
///
/// Memory: O(k) where k is the number of unique categories
#[derive(Debug, Clone)]
pub struct CategoryCounter<T: std::hash::Hash + Eq> {
    counts: HashMap<T, usize>, // per-category tallies
    total: usize,              // sum of all tallies, kept in sync by every mutator
}

/// Manual `Default` impl: `#[derive(Default)]` on a generic struct adds a
/// spurious `T: Default` bound, even though `HashMap::new()` needs none.
/// This impl keeps `CategoryCounter::<NonDefaultType>::default()` usable.
impl<T: std::hash::Hash + Eq> Default for CategoryCounter<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: std::hash::Hash + Eq> CategoryCounter<T> {
    /// Create a new, empty category counter.
    pub fn new() -> Self {
        Self {
            counts: HashMap::new(),
            total: 0,
        }
    }

    /// Increment count for a category by one.
    pub fn increment(&mut self, category: T) {
        *self.counts.entry(category).or_insert(0) += 1;
        self.total += 1;
    }

    /// Increment count for a category by a specific amount.
    pub fn increment_by(&mut self, category: T, amount: usize) {
        *self.counts.entry(category).or_insert(0) += amount;
        self.total += amount;
    }

    /// Get the count for a specific category (0 if never seen).
    pub fn get(&self, category: &T) -> usize {
        self.counts.get(category).copied().unwrap_or(0)
    }

    /// Get the total count across all categories.
    pub fn total(&self) -> usize {
        self.total
    }

    /// Get the number of unique categories observed.
    pub fn num_categories(&self) -> usize {
        self.counts.len()
    }

    /// Get the frequency (proportion in 0.0..=1.0) for a category.
    /// Returns 0.0 when no values have been counted.
    pub fn frequency(&self, category: &T) -> f64 {
        if self.total == 0 {
            0.0
        } else {
            self.get(category) as f64 / self.total as f64
        }
    }

    /// Get all categories and their counts.
    pub fn categories(&self) -> &HashMap<T, usize> {
        &self.counts
    }

    /// Iterate over (category, count) pairs in arbitrary order.
    pub fn iter(&self) -> impl Iterator<Item = (&T, &usize)> {
        self.counts.iter()
    }
}

/// Specialized implementation for String to allow efficient &str increments
impl CategoryCounter<String> {
    /// Optimized increment for &str - only allocates when the category is new.
    ///
    /// This is much more efficient than `increment(category.to_string())` because
    /// it does a HashMap lookup with &str (no allocation) and only allocates a
    /// String when the category is seen for the first time.
    pub fn increment_str(&mut self, category: &str) {
        // First try to increment an existing entry (no allocation).
        if let Some(count) = self.counts.get_mut(category) {
            *count += 1;
        } else {
            // Only allocate a String for new categories.
            self.counts.insert(category.to_string(), 1);
        }
        self.total += 1;
    }
}
206
207/// Implementation of Mergeable for parallel statistics computation
208///
209/// Merges category counts from multiple threads by summing counts
210/// for each category.
211impl<T: std::hash::Hash + Eq + Send> Mergeable for CategoryCounter<T> {
212    fn merge(&mut self, other: Self) {
213        for (category, count) in other.counts {
214            self.increment_by(category, count);
215        }
216    }
217}
218
/// Percentile calculator using reservoir sampling for memory efficiency
///
/// Stores up to `capacity` samples exactly; once the buffer is full it
/// switches to reservoir sampling, so percentiles become approximate but
/// memory stays bounded at O(capacity).
#[derive(Debug, Clone)]
pub struct PercentileEstimator {
    samples: Vec<f64>, // retained reservoir, at most `capacity` entries
    capacity: usize,   // maximum number of samples to retain
    total_seen: usize, // count of all values ever pushed, stored or not
}

impl PercentileEstimator {
    /// Create a new percentile estimator with given capacity.
    ///
    /// For exact percentiles, set capacity >= expected number of values.
    /// For approximate percentiles on large datasets, use smaller capacity
    /// (e.g., 10,000). A capacity of 0 stores nothing, so `percentile`
    /// always returns `None`.
    pub fn new(capacity: usize) -> Self {
        Self {
            samples: Vec::with_capacity(capacity),
            capacity,
            total_seen: 0,
        }
    }

    /// Add a value to the estimator.
    pub fn push(&mut self, value: f64) {
        self.total_seen += 1;

        if self.samples.len() < self.capacity {
            // Buffer not yet full: keep every value (percentiles are exact).
            self.samples.push(value);
        } else {
            // Reservoir sampling: replace a random slot with probability
            // capacity / total_seen. The pseudo-randomness comes from
            // hashing the element ordinal with a randomly seeded hasher.
            use std::collections::hash_map::RandomState;
            use std::hash::{BuildHasher, Hash, Hasher};

            let mut hasher = RandomState::new().build_hasher();
            self.total_seen.hash(&mut hasher);
            let random_index = (hasher.finish() as usize) % self.total_seen;

            if random_index < self.capacity {
                self.samples[random_index] = value;
            }
        }
    }

    /// Calculate a percentile for `p` in 0.0..=1.0 (nearest-rank, truncating).
    ///
    /// Returns `None` if no samples have been stored. `p` outside 0.0..=1.0
    /// is clamped into range (a NaN `p` resolves to index 0); previously,
    /// p > 1.0 computed an out-of-bounds index and panicked. Sorting uses
    /// `f64::total_cmp`, so NaN samples (which sort after every number)
    /// no longer panic the comparator.
    pub fn percentile(&mut self, p: f64) -> Option<f64> {
        if self.samples.is_empty() {
            return None;
        }

        // Fix: clamp so a caller passing e.g. 95.0 instead of 0.95 cannot
        // index past the end of the buffer.
        let p = p.clamp(0.0, 1.0);

        // Fix: total_cmp is a total order over f64 (NaN included), unlike
        // the previous partial_cmp().unwrap() which panicked on NaN input.
        self.samples.sort_by(f64::total_cmp);

        let index = (p * (self.samples.len() - 1) as f64) as usize;
        Some(self.samples[index])
    }

    /// Get the median (50th percentile).
    pub fn median(&mut self) -> Option<f64> {
        self.percentile(0.5)
    }

    /// Get the total number of values seen (not just stored).
    pub fn total_seen(&self) -> usize {
        self.total_seen
    }
}
290
#[cfg(test)]
mod tests {
    use super::*;

    // Single-pass mean/min/max/variance over three values.
    #[test]
    fn test_running_stats() {
        let mut stats = RunningStats::new();

        stats.push(10.0);
        stats.push(20.0);
        stats.push(30.0);

        assert_eq!(stats.count(), 3);
        assert_eq!(stats.mean(), Some(20.0));
        assert_eq!(stats.min(), Some(10.0));
        assert_eq!(stats.max(), Some(30.0));

        // Variance should be 100.0
        let var = stats.variance().unwrap();
        assert!((var - 100.0).abs() < 1e-10);

        // Std dev should be 10.0
        let std = stats.std_dev().unwrap();
        assert!((std - 10.0).abs() < 1e-10);
    }

    // Counts, unique-category cardinality, and frequency proportions.
    #[test]
    fn test_category_counter() {
        let mut counter = CategoryCounter::new();

        counter.increment("A");
        counter.increment("B");
        counter.increment("A");
        counter.increment("C");

        assert_eq!(counter.total(), 4);
        assert_eq!(counter.num_categories(), 3);
        assert_eq!(counter.get(&"A"), 2);
        assert_eq!(counter.get(&"B"), 1);
        assert_eq!(counter.frequency(&"A"), 0.5);
    }

    // With capacity >= value count the estimator keeps every sample,
    // so percentiles are exact (modulo the nearest-rank index rule).
    #[test]
    fn test_percentile_estimator() {
        let mut estimator = PercentileEstimator::new(100);

        for i in 1..=100 {
            estimator.push(i as f64);
        }

        assert_eq!(estimator.total_seen(), 100);

        let median = estimator.median().unwrap();
        assert!((median - 50.5).abs() < 1.0); // Approximate

        let p95 = estimator.percentile(0.95).unwrap();
        assert!(p95 > 90.0 && p95 <= 100.0);
    }

    // Chan's merge must reproduce the statistics of the concatenated stream.
    #[test]
    fn test_running_stats_merge() {
        // Create two stats objects
        let mut stats1 = RunningStats::new();
        stats1.push(10.0);
        stats1.push(20.0);
        stats1.push(30.0);

        let mut stats2 = RunningStats::new();
        stats2.push(40.0);
        stats2.push(50.0);

        // Merge stats2 into stats1
        stats1.merge(stats2);

        // Should have combined statistics
        assert_eq!(stats1.count(), 5);
        assert_eq!(stats1.mean(), Some(30.0)); // (10+20+30+40+50)/5 = 30
        assert_eq!(stats1.min(), Some(10.0));
        assert_eq!(stats1.max(), Some(50.0));

        // Variance of [10, 20, 30, 40, 50] should be 250
        let var = stats1.variance().unwrap();
        assert!((var - 250.0).abs() < 1e-10);
    }

    // Merging an empty accumulator must be a no-op.
    #[test]
    fn test_running_stats_merge_empty() {
        let mut stats1 = RunningStats::new();
        stats1.push(10.0);

        let stats2 = RunningStats::new(); // Empty

        stats1.merge(stats2);
        assert_eq!(stats1.count(), 1);
        assert_eq!(stats1.mean(), Some(10.0));
    }

    // Per-category counts are summed across merged counters.
    #[test]
    fn test_category_counter_merge() {
        let mut counter1 = CategoryCounter::new();
        counter1.increment("A");
        counter1.increment("B");
        counter1.increment("A");

        let mut counter2 = CategoryCounter::new();
        counter2.increment("B");
        counter2.increment("C");
        counter2.increment("C");

        counter1.merge(counter2);

        assert_eq!(counter1.total(), 6);
        assert_eq!(counter1.get(&"A"), 2);
        assert_eq!(counter1.get(&"B"), 2);
        assert_eq!(counter1.get(&"C"), 2);
        assert_eq!(counter1.num_categories(), 3);
    }

    // merge_all is presumably a provided method on the Mergeable trait
    // (defined in crate::parallel) that folds a collection of accumulators
    // into one — TODO confirm against the trait definition.
    #[test]
    fn test_mergeable_merge_all() {
        let mut stats1 = RunningStats::new();
        stats1.push(10.0);

        let mut stats2 = RunningStats::new();
        stats2.push(20.0);

        let mut stats3 = RunningStats::new();
        stats3.push(30.0);

        let merged = RunningStats::merge_all(vec![stats1, stats2, stats3]).unwrap();

        assert_eq!(merged.count(), 3);
        assert_eq!(merged.mean(), Some(20.0));
    }
}
425}