use crate::parallel::Mergeable;
use std::collections::HashMap;
/// Single-pass accumulator for the count, mean, variance and extrema of a
/// stream of `f64` samples.
///
/// The mean and the sum of squared deviations are updated incrementally on
/// every [`push`](RunningStats::push) (Welford's online update), so the
/// samples themselves are never stored.
#[derive(Debug, Clone, Default)]
pub struct RunningStats {
    /// Number of samples pushed so far.
    count: usize,
    /// Running arithmetic mean; `0.0` while no sample has been pushed.
    mean: f64,
    /// Running sum of squared deviations from the current mean.
    m2: f64,
    /// Smallest sample seen so far, `None` while empty.
    min: Option<f64>,
    /// Largest sample seen so far, `None` while empty.
    max: Option<f64>,
}

impl RunningStats {
    /// Creates an empty accumulator (same state as `Default::default()`).
    pub fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
            min: None,
            max: None,
        }
    }

    /// Folds one sample into the running statistics.
    pub fn push(&mut self, value: f64) {
        self.count += 1;

        // Deviation measured against the mean *before* and *after* the update;
        // their product is the increment of the squared-deviation sum.
        let offset_before = value - self.mean;
        self.mean += offset_before / self.count as f64;
        let offset_after = value - self.mean;
        self.m2 += offset_before * offset_after;

        self.min = Some(match self.min {
            Some(lowest) => lowest.min(value),
            None => value,
        });
        self.max = Some(match self.max {
            Some(highest) => highest.max(value),
            None => value,
        });
    }

    /// Number of samples pushed so far.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Arithmetic mean, or `None` when no sample has been pushed.
    pub fn mean(&self) -> Option<f64> {
        match self.count {
            0 => None,
            _ => Some(self.mean),
        }
    }

    /// Sample variance (divisor `count - 1`), or `None` with fewer than two
    /// samples.
    pub fn variance(&self) -> Option<f64> {
        match self.count {
            0 | 1 => None,
            n => Some(self.m2 / (n - 1) as f64),
        }
    }

    /// Square root of [`variance`](RunningStats::variance); `None` under the
    /// same conditions.
    pub fn std_dev(&self) -> Option<f64> {
        self.variance().map(f64::sqrt)
    }

    /// Smallest sample seen, or `None` when empty.
    pub fn min(&self) -> Option<f64> {
        self.min
    }

    /// Largest sample seen, or `None` when empty.
    pub fn max(&self) -> Option<f64> {
        self.max
    }
}
impl Mergeable for RunningStats {
    /// Folds the samples summarized by `other` into `self`, consuming `other`.
    ///
    /// After the call `self` describes the union of both sample sets: counts
    /// are added, the mean becomes the count-weighted mean, the squared
    /// deviation sums are combined with the pairwise (parallel) update
    /// `m2 = m2_a + m2_b + delta^2 * n_a * n_b / (n_a + n_b)`, and the extrema
    /// are the min/max of both sides.
    ///
    /// Merging an empty `other` is a no-op; merging into an empty `self`
    /// simply adopts `other`.
    fn merge(&mut self, other: Self) {
        if other.count == 0 {
            return;
        }
        if self.count == 0 {
            *self = other;
            return;
        }

        // Convert the counts to `f64` *before* multiplying them: the product
        // `n_a * n_b` overflows `usize` long before either count does
        // (e.g. 65536 * 65536 on a 32-bit target).
        let n_self = self.count as f64;
        let n_other = other.count as f64;
        let total_count = self.count + other.count;
        let n_total = total_count as f64;
        let delta = other.mean - self.mean;

        self.min = match (self.min, other.min) {
            (Some(a), Some(b)) => Some(a.min(b)),
            (a, b) => a.or(b),
        };
        self.max = match (self.max, other.max) {
            (Some(a), Some(b)) => Some(a.max(b)),
            (a, b) => a.or(b),
        };

        // `m2` must be combined before `mean` is overwritten; `delta` above
        // was taken from the pre-merge means.
        self.m2 += other.m2 + delta * delta * (n_self * n_other) / n_total;
        self.mean = (self.mean * n_self + other.mean * n_other) / n_total;
        self.count = total_count;
    }
}
/// Tally of how many times each category of type `T` has been counted,
/// together with the grand total across all categories.
#[derive(Debug, Clone)]
pub struct CategoryCounter<T: std::hash::Hash + Eq> {
    /// Per-category tallies.
    counts: HashMap<T, usize>,
    /// Sum of every amount ever added through the increment methods.
    total: usize,
}

// Written by hand instead of `#[derive(Default)]`: the derive would add a
// `T: Default` bound, although an empty counter never needs a `T` value.
impl<T: std::hash::Hash + Eq> Default for CategoryCounter<T> {
    /// Creates an empty counter.
    fn default() -> Self {
        Self {
            counts: HashMap::new(),
            total: 0,
        }
    }
}

impl<T: std::hash::Hash + Eq> CategoryCounter<T> {
    /// Creates an empty counter.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds one occurrence of `category`.
    pub fn increment(&mut self, category: T) {
        self.increment_by(category, 1);
    }

    /// Adds `amount` occurrences of `category`.
    ///
    /// The category entry is created even when `amount` is zero, so it is
    /// then reported by [`num_categories`](CategoryCounter::num_categories).
    pub fn increment_by(&mut self, category: T, amount: usize) {
        *self.counts.entry(category).or_insert(0) += amount;
        self.total += amount;
    }

    /// Tally recorded for `category`; `0` if it was never counted.
    pub fn get(&self, category: &T) -> usize {
        self.counts.get(category).copied().unwrap_or(0)
    }

    /// Sum of all tallies.
    pub fn total(&self) -> usize {
        self.total
    }

    /// Number of distinct categories that have an entry.
    pub fn num_categories(&self) -> usize {
        self.counts.len()
    }

    /// Share of the total that belongs to `category`, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` for an empty counter instead of dividing by zero.
    pub fn frequency(&self, category: &T) -> f64 {
        if self.total == 0 {
            0.0
        } else {
            self.get(category) as f64 / self.total as f64
        }
    }

    /// Read-only view of the underlying category → tally map.
    pub fn categories(&self) -> &HashMap<T, usize> {
        &self.counts
    }

    /// Iterates over `(category, tally)` pairs in the map's arbitrary order.
    pub fn iter(&self) -> impl Iterator<Item = (&T, &usize)> {
        self.counts.iter()
    }
}

impl CategoryCounter<String> {
    /// Adds one occurrence of `category` given as a string slice.
    ///
    /// An owned `String` key is only allocated the first time a category is
    /// seen; later hits update the existing entry through the borrowed key.
    pub fn increment_str(&mut self, category: &str) {
        if let Some(count) = self.counts.get_mut(category) {
            *count += 1;
        } else {
            self.counts.insert(category.to_string(), 1);
        }
        self.total += 1;
    }
}
impl<T: std::hash::Hash + Eq + Send> Mergeable for CategoryCounter<T> {
    /// Adds every tally held by `other` to `self`, consuming `other`.
    ///
    /// Each `(category, count)` pair goes through `increment_by`, so both the
    /// per-category tallies and the running total are updated.
    fn merge(&mut self, other: Self) {
        other
            .counts
            .into_iter()
            .for_each(|(category, count)| self.increment_by(category, count));
    }
}
/// Approximate percentile calculator that keeps at most `capacity` samples.
///
/// The first `capacity` values are stored verbatim; afterwards each new value
/// replaces a randomly chosen stored sample with probability
/// `capacity / total_seen` (reservoir sampling), so memory use stays bounded.
#[derive(Debug, Clone)]
pub struct PercentileEstimator {
    /// Retained samples; never longer than `capacity`.
    samples: Vec<f64>,
    /// Maximum number of samples to retain.
    capacity: usize,
    /// Number of values passed to `push`, retained or not.
    total_seen: usize,
}

impl PercentileEstimator {
    /// Creates an estimator that retains at most `capacity` samples.
    ///
    /// Storage for `capacity` samples is reserved up front.
    pub fn new(capacity: usize) -> Self {
        Self {
            samples: Vec::with_capacity(capacity),
            capacity,
            total_seen: 0,
        }
    }

    /// Offers one value to the reservoir.
    ///
    /// While the reservoir is not full the value is always kept. Once full,
    /// an index is drawn from `0..total_seen` and the value overwrites that
    /// slot only if the index falls inside the reservoir.
    pub fn push(&mut self, value: f64) {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hash, Hasher};

        self.total_seen += 1;
        if self.samples.len() < self.capacity {
            self.samples.push(value);
            return;
        }

        // A freshly keyed `RandomState` hasher serves as the source of
        // randomness (presumably to avoid an external RNG dependency).
        let mut hasher = RandomState::new().build_hasher();
        self.total_seen.hash(&mut hasher);
        let random_index = (hasher.finish() as usize) % self.total_seen;
        if random_index < self.capacity {
            self.samples[random_index] = value;
        }
    }

    /// Returns the retained sample at fraction `p` of the sorted reservoir,
    /// or `None` when nothing has been retained.
    ///
    /// `p` is a fraction in `[0.0, 1.0]` (`0.5` is the median); values outside
    /// that range are clamped to it and a NaN `p` is treated as `0.0`. The
    /// rank `p * (len - 1)` is truncated toward zero; no interpolation is
    /// performed.
    ///
    /// Samples are ordered with `f64::total_cmp`, so NaN samples do not cause
    /// a panic; they sort to the extremes according to their sign bit.
    ///
    /// Takes `&mut self` because the reservoir is sorted in place.
    pub fn percentile(&mut self, p: f64) -> Option<f64> {
        if self.samples.is_empty() {
            return None;
        }
        self.samples.sort_unstable_by(f64::total_cmp);

        let last = self.samples.len() - 1;
        // The float-to-int cast saturates and maps NaN to 0; the extra `min`
        // guards the upper bound against rounding.
        let index = ((p.clamp(0.0, 1.0) * last as f64) as usize).min(last);
        Some(self.samples[index])
    }

    /// Shorthand for `percentile(0.5)`.
    pub fn median(&mut self) -> Option<f64> {
        self.percentile(0.5)
    }

    /// Number of values offered via `push`, including those not retained.
    pub fn total_seen(&self) -> usize {
        self.total_seen
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `RunningStats` by pushing `values` in order.
    fn stats_of(values: &[f64]) -> RunningStats {
        let mut stats = RunningStats::new();
        for &value in values {
            stats.push(value);
        }
        stats
    }

    /// Builds a counter by incrementing once per label, in order.
    fn counter_of(labels: &[&'static str]) -> CategoryCounter<&'static str> {
        let mut counter = CategoryCounter::new();
        for &label in labels {
            counter.increment(label);
        }
        counter
    }

    /// Asserts that two floats agree to within `1e-10`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-10);
    }

    #[test]
    fn test_running_stats() {
        let stats = stats_of(&[10.0, 20.0, 30.0]);

        assert_eq!(stats.count(), 3);
        assert_eq!(stats.mean(), Some(20.0));
        assert_eq!(stats.min(), Some(10.0));
        assert_eq!(stats.max(), Some(30.0));
        assert_close(stats.variance().unwrap(), 100.0);
        assert_close(stats.std_dev().unwrap(), 10.0);
    }

    #[test]
    fn test_category_counter() {
        let counter = counter_of(&["A", "B", "A", "C"]);

        assert_eq!(counter.total(), 4);
        assert_eq!(counter.num_categories(), 3);
        assert_eq!(counter.get(&"A"), 2);
        assert_eq!(counter.get(&"B"), 1);
        assert_eq!(counter.frequency(&"A"), 0.5);
    }

    #[test]
    fn test_percentile_estimator() {
        // Capacity equals the number of pushes, so every sample is retained.
        let mut estimator = PercentileEstimator::new(100);
        (1..=100)
            .map(|i| i as f64)
            .for_each(|sample| estimator.push(sample));

        assert_eq!(estimator.total_seen(), 100);

        let median = estimator.median().unwrap();
        assert!((median - 50.5).abs() < 1.0);

        let p95 = estimator.percentile(0.95).unwrap();
        assert!(p95 > 90.0 && p95 <= 100.0);
    }

    #[test]
    fn test_running_stats_merge() {
        let mut merged = stats_of(&[10.0, 20.0, 30.0]);
        merged.merge(stats_of(&[40.0, 50.0]));

        assert_eq!(merged.count(), 5);
        assert_eq!(merged.mean(), Some(30.0));
        assert_eq!(merged.min(), Some(10.0));
        assert_eq!(merged.max(), Some(50.0));
        assert_close(merged.variance().unwrap(), 250.0);
    }

    #[test]
    fn test_running_stats_merge_empty() {
        let mut stats = stats_of(&[10.0]);
        stats.merge(RunningStats::new());

        assert_eq!(stats.count(), 1);
        assert_eq!(stats.mean(), Some(10.0));
    }

    #[test]
    fn test_category_counter_merge() {
        let mut merged = counter_of(&["A", "B", "A"]);
        merged.merge(counter_of(&["B", "C", "C"]));

        assert_eq!(merged.total(), 6);
        for label in ["A", "B", "C"].iter() {
            assert_eq!(merged.get(label), 2);
        }
        assert_eq!(merged.num_categories(), 3);
    }

    #[test]
    fn test_mergeable_merge_all() {
        let parts = vec![stats_of(&[10.0]), stats_of(&[20.0]), stats_of(&[30.0])];
        let merged = RunningStats::merge_all(parts).unwrap();

        assert_eq!(merged.count(), 3);
        assert_eq!(merged.mean(), Some(20.0));
    }
}