// composable_indexes/aggregation/stddev.rs
use num_traits::ToPrimitive;

use crate::{
    Index, ShallowClone,
    core::{Insert, Remove, Seal, Update},
};

/// Streaming sample standard deviation over the values in a collection.
///
/// Keeps a running mean and a running sum of squared deviations from that mean, updated with
/// Welford's online algorithm. Every insert, remove, and update is therefore O(1) and needs no
/// access to the other values. Values whose conversion to `f64` fails are not tracked.
#[derive(Clone, Copy, Debug)]
pub struct StdDev<T> {
    /// Mean of the values currently tracked (`0.0` when none are).
    mean: f64,
    /// Sum of squared deviations of the tracked values from `mean`.
    sum_sq_diff: f64,
    /// How many values are currently tracked.
    count: u64,
    /// Ties the index to the collection's value type without storing any `T`.
    _phantom: core::marker::PhantomData<T>,
}

30impl<T> Default for StdDev<T> {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl<T> StdDev<T> {
37 pub fn new() -> Self {
38 StdDev {
39 mean: 0.0,
40 sum_sq_diff: 0.0,
41 count: 0,
42 _phantom: core::marker::PhantomData,
43 }
44 }
45}
46
47impl<T> Index<T> for StdDev<T>
48where
49 T: ToPrimitive + Copy + 'static,
50{
51 #[inline]
52 fn insert(&mut self, _seal: Seal, op: &Insert<T>) {
53 if let Some(x) = op.new.to_f64() {
54 self.count += 1;
55 let k = self.count;
56
57 let old_mean = self.mean;
60 self.mean = old_mean + (x - old_mean) / k as f64;
61
62 self.sum_sq_diff += (x - old_mean) * (x - self.mean);
64 }
65 }
66
67 #[inline]
68 fn remove(&mut self, _seal: Seal, op: &Remove<T>) {
69 if let Some(x) = op.existing.to_f64() {
70 let n = self.count;
71
72 if n <= 1 {
73 self.mean = 0.0;
75 self.sum_sq_diff = 0.0;
76 self.count = 0;
77 return;
78 }
79
80 let old_mean = self.mean;
83 self.mean = (n as f64 * old_mean - x) / (n - 1) as f64;
84
85 self.sum_sq_diff -= (x - old_mean) * (x - self.mean);
87
88 self.sum_sq_diff = self.sum_sq_diff.max(0.0);
90
91 self.count = n - 1;
92 }
93 }
94
95 #[inline]
96 fn update(&mut self, _seal: Seal, op: &Update<T>) {
97 if let (Some(old_val), Some(new_val)) = (op.existing.to_f64(), op.new.to_f64()) {
98 let n = self.count;
100
101 if n == 0 {
102 return;
103 }
104
105 if n == 1 {
106 self.mean = new_val;
108 self.sum_sq_diff = 0.0;
109 return;
110 }
111
112 let old_mean = self.mean;
114 let mean_without_old = (n as f64 * old_mean - old_val) / (n - 1) as f64;
115 let sum_sq_diff_without_old =
116 self.sum_sq_diff - (old_val - old_mean) * (old_val - mean_without_old);
117
118 let new_mean = mean_without_old + (new_val - mean_without_old) / n as f64;
120 let new_sum_sq_diff =
121 sum_sq_diff_without_old + (new_val - mean_without_old) * (new_val - new_mean);
122
123 self.mean = new_mean;
124 self.sum_sq_diff = new_sum_sq_diff.max(0.0);
125 }
126 }
127}
128
129impl<T> StdDev<T> {
130 #[inline]
131 pub fn get(&self) -> f64 {
132 if self.count < 2 {
133 return 0.0;
134 }
135 (self.sum_sq_diff / (self.count - 1) as f64).sqrt()
137 }
138}
139
140impl<T: Clone> ShallowClone for StdDev<T> {}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Collection;

    /// Reference implementation: sample standard deviation (`n - 1` denominator) of `values`.
    ///
    /// Returns `0.0` when there are fewer than two values.
    fn naive_sample_std_dev(values: &[f64]) -> f64 {
        let len = values.len();
        if len < 2 {
            return 0.0;
        }
        let mean = values.iter().sum::<f64>() / len as f64;
        let sum_sq: f64 = values
            .iter()
            .map(|&v| {
                let d = v - mean;
                d * d
            })
            .sum();
        (sum_sq / (len - 1) as f64).sqrt()
    }

    /// Asserts that two floating-point results agree to within `1e-10`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-10);
    }

    #[test]
    fn test_std_dev_basic() {
        let mut db = Collection::new(StdDev::<f64>::new());

        // Recompute the expected value from scratch over the collection's current contents.
        let reference = |c: &Collection<f64, _>| -> f64 {
            let values: Vec<f64> = c.iter().into_iter().map(|(_, &v)| v).collect();
            naive_sample_std_dev(&values)
        };

        // With fewer than two values both sides are exactly 0.0.
        assert_eq!(db.query(|ix| ix.get()), reference(&db));
        let _k1 = db.insert(5.0);
        assert_eq!(db.query(|ix| ix.get()), reference(&db));

        // With two or more values, compare within a tolerance.
        let k2 = db.insert(10.0);
        assert_close(db.query(|ix| ix.get()), reference(&db));
        let k3 = db.insert(15.0);
        assert_close(db.query(|ix| ix.get()), reference(&db));

        db.delete_by_key(k3);
        assert_close(db.query(|ix| ix.get()), reference(&db));

        // Back down to a single value: exact 0.0 again.
        db.delete_by_key(k2);
        assert_eq!(db.query(|ix| ix.get()), reference(&db));
    }
}