composable_indexes/aggregation/stddev.rs

1use num_traits::ToPrimitive;
2
3use crate::{
4    Index, ShallowClone,
5    core::{Insert, Remove, Seal, Update},
6};
7
/// Sample standard deviation aggregation index using Welford's algorithm.
///
/// Welford's algorithm updates the variance incrementally by maintaining:
/// - `mean`: the running mean of the samples (M in the algorithm)
/// - `sum_sq_diff`: the sum of squared differences from the mean (S in the algorithm)
/// - `count`: the number of samples folded in (n in the algorithm)
///
/// No samples are stored, so the index needs only O(1) space regardless of how many
/// values the collection holds.
///
/// The reported value is the *sample* standard deviation, `sqrt(S / (n - 1))`
/// (Bessel's correction). It is 0.0 while fewer than two samples are present.
///
/// Values whose `ToPrimitive::to_f64` conversion fails are not counted.
///
/// # Numerical accuracy
///
/// Welford's insertion step is numerically stable. Removals and updates run the
/// recurrence in reverse, and that can accumulate rounding error. This matters over
/// long sequences of deletions, or when the samples are large relative to their
/// spread. The accumulated `sum_sq_diff` is clamped at zero so rounding can never
/// produce a negative variance.
#[derive(Debug, Clone, Copy)]
pub struct StdDev<T> {
    mean: f64,
    sum_sq_diff: f64,
    count: u64,
    _phantom: core::marker::PhantomData<T>,
}
29
30impl<T> Default for StdDev<T> {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl<T> StdDev<T> {
37    pub fn new() -> Self {
38        StdDev {
39            mean: 0.0,
40            sum_sq_diff: 0.0,
41            count: 0,
42            _phantom: core::marker::PhantomData,
43        }
44    }
45}
46
47impl<T> Index<T> for StdDev<T>
48where
49    T: ToPrimitive + Copy + 'static,
50{
51    #[inline]
52    fn insert(&mut self, _seal: Seal, op: &Insert<T>) {
53        if let Some(x) = op.new.to_f64() {
54            self.count += 1;
55            let k = self.count;
56
57            // Adding a sample xₖ:
58            // M_new = M_old + (xₖ - M_old) / k
59            let old_mean = self.mean;
60            self.mean = old_mean + (x - old_mean) / k as f64;
61
62            // S_new = S_old + (xₖ - M_old) * (xₖ - M_new)
63            self.sum_sq_diff += (x - old_mean) * (x - self.mean);
64        }
65    }
66
67    #[inline]
68    fn remove(&mut self, _seal: Seal, op: &Remove<T>) {
69        if let Some(x) = op.existing.to_f64() {
70            let n = self.count;
71
72            if n <= 1 {
73                // Reset to initial state if removing last element
74                self.mean = 0.0;
75                self.sum_sq_diff = 0.0;
76                self.count = 0;
77                return;
78            }
79
80            // Removing a sample xⱼ:
81            // M_new = (n * M_old - xⱼ) / (n - 1)
82            let old_mean = self.mean;
83            self.mean = (n as f64 * old_mean - x) / (n - 1) as f64;
84
85            // S_new = S_old - (xⱼ - M_old) * (xⱼ - M_new)
86            self.sum_sq_diff -= (x - old_mean) * (x - self.mean);
87
88            // float precision safety: ensure count doesn't go negative
89            self.sum_sq_diff = self.sum_sq_diff.max(0.0);
90
91            self.count = n - 1;
92        }
93    }
94
95    #[inline]
96    fn update(&mut self, _seal: Seal, op: &Update<T>) {
97        if let (Some(old_val), Some(new_val)) = (op.existing.to_f64(), op.new.to_f64()) {
98            // For update, we remove the old value and insert the new one
99            let n = self.count;
100
101            if n == 0 {
102                return;
103            }
104
105            if n == 1 {
106                // Special case: single element, just update the mean
107                self.mean = new_val;
108                self.sum_sq_diff = 0.0;
109                return;
110            }
111
112            // Remove old value
113            let old_mean = self.mean;
114            let mean_without_old = (n as f64 * old_mean - old_val) / (n - 1) as f64;
115            let sum_sq_diff_without_old =
116                self.sum_sq_diff - (old_val - old_mean) * (old_val - mean_without_old);
117
118            // Add new value
119            let new_mean = mean_without_old + (new_val - mean_without_old) / n as f64;
120            let new_sum_sq_diff =
121                sum_sq_diff_without_old + (new_val - mean_without_old) * (new_val - new_mean);
122
123            self.mean = new_mean;
124            self.sum_sq_diff = new_sum_sq_diff.max(0.0);
125        }
126    }
127}
128
129impl<T> StdDev<T> {
130    #[inline]
131    pub fn get(&self) -> f64 {
132        if self.count < 2 {
133            return 0.0;
134        }
135        // Standard deviation: σ = √(S / (n - 1))
136        (self.sum_sq_diff / (self.count - 1) as f64).sqrt()
137    }
138}
139
140impl<T: Clone> ShallowClone for StdDev<T> {}
141
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_std_dev_basic() {
        use crate::core::Collection;

        let mut db = Collection::new(StdDev::<f64>::new());

        // Two-pass reference: gather every stored value, then compute the sample
        // standard deviation directly (0.0 for fewer than two values).
        let reference_std_dev = |coll: &Collection<f64, _>| -> f64 {
            let samples: Vec<f64> = coll.iter().into_iter().map(|(_, &v)| v).collect();
            let len = samples.len();
            if len < 2 {
                return 0.0;
            }
            let avg = samples.iter().sum::<f64>() / len as f64;
            let squared_dev = samples
                .iter()
                .map(|&s| (s - avg) * (s - avg))
                .sum::<f64>();
            (squared_dev / (len - 1) as f64).sqrt()
        };

        // Compares the index against the reference, allowing for rounding differences
        // between the incremental and two-pass computations.
        macro_rules! assert_close_to_reference {
            () => {{
                let want = reference_std_dev(&db);
                let got = db.query(|ix| ix.get());
                assert!((got - want).abs() < 1e-10);
            }};
        }

        // Empty collection: both sides report exactly 0.0.
        assert_eq!(db.query(|ix| ix.get()), reference_std_dev(&db));

        // [5.0]: a single sample still reports exactly 0.0.
        let _k1 = db.insert(5.0);
        assert_eq!(db.query(|ix| ix.get()), reference_std_dev(&db));

        // [5.0, 10.0]
        let k2 = db.insert(10.0);
        assert_close_to_reference!();

        // [5.0, 10.0, 15.0]
        let k3 = db.insert(15.0);
        assert_close_to_reference!();

        // Drop 15.0 -> [5.0, 10.0]
        db.delete_by_key(k3);
        assert_close_to_reference!();

        // Drop 10.0 -> [5.0]
        db.delete_by_key(k2);
        assert_eq!(db.query(|ix| ix.get()), reference_std_dev(&db));
    }
}
200}