// nodedb_types/approx/tdigest.rs
tdigest.rs1#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
7struct Centroid {
8 mean: f64,
9 count: u64,
10}
12#[derive(Debug, serde::Serialize, serde::Deserialize)]
18pub struct TDigest {
19 centroids: Vec<Centroid>,
20 max_centroids: usize,
21 total_count: u64,
22}
24impl TDigest {
25 pub fn new() -> Self {
26 Self::with_compression(200)
27 }
28
29 pub fn with_compression(max_centroids: usize) -> Self {
30 Self {
31 centroids: Vec::with_capacity(max_centroids),
32 max_centroids: max_centroids.max(10),
33 total_count: 0,
34 }
35 }
36
37 pub fn add(&mut self, value: f64) {
38 if value.is_nan() {
39 return;
40 }
41 self.centroids.push(Centroid {
42 mean: value,
43 count: 1,
44 });
45 self.total_count += 1;
46
47 if self.centroids.len() > self.max_centroids * 2 {
48 self.compress();
49 }
50 }
51
52 pub fn add_batch(&mut self, values: &[f64]) {
53 for &v in values {
54 self.add(v);
55 }
56 }
57
58 pub fn quantile(&self, q: f64) -> f64 {
60 let q = q.clamp(0.0, 1.0);
61 if self.centroids.is_empty() {
62 return f64::NAN;
63 }
64 self.compress_clone().quantile_sorted(q)
65 }
66
67 pub fn merge(&mut self, other: &TDigest) {
68 self.centroids.extend_from_slice(&other.centroids);
69 self.total_count += other.total_count;
70 if self.centroids.len() > self.max_centroids * 2 {
71 self.compress();
72 }
73 }
74
75 pub fn count(&self) -> u64 {
76 self.total_count
77 }
78
79 pub fn add_centroid(&mut self, mean: f64, count: u64) {
81 self.centroids.push(Centroid { mean, count });
82 self.total_count += count;
83 if self.centroids.len() > self.max_centroids * 2 {
84 self.compress();
85 }
86 }
87
88 pub fn centroids(&self) -> Vec<(f64, u64)> {
90 self.centroids.iter().map(|c| (c.mean, c.count)).collect()
91 }
92
93 pub fn memory_bytes(&self) -> usize {
95 std::mem::size_of::<Self>() + self.centroids.capacity() * std::mem::size_of::<Centroid>()
96 }
97
98 fn compress(&mut self) {
99 if self.centroids.len() <= self.max_centroids {
100 return;
101 }
102
103 self.centroids.sort_by(|a, b| {
104 a.mean
105 .partial_cmp(&b.mean)
106 .unwrap_or(std::cmp::Ordering::Equal)
107 });
108
109 let target = self.max_centroids;
110 while self.centroids.len() > target {
111 let mut best_i = 0;
112 let mut best_gap = f64::INFINITY;
113 for i in 0..self.centroids.len() - 1 {
114 let gap = self.centroids[i + 1].mean - self.centroids[i].mean;
115 if gap < best_gap {
116 best_gap = gap;
117 best_i = i;
118 }
119 }
120 let a = self.centroids[best_i];
121 let b = self.centroids.remove(best_i + 1);
122 let total = a.count + b.count;
123 self.centroids[best_i] = Centroid {
124 mean: (a.mean * a.count as f64 + b.mean * b.count as f64) / total as f64,
125 count: total,
126 };
127 }
128 }
129
130 fn compress_clone(&self) -> TDigest {
131 let mut clone = self.clone_inner();
132 clone.compress();
133 clone
134 }
135
136 fn clone_inner(&self) -> TDigest {
137 TDigest {
138 centroids: self.centroids.clone(),
139 max_centroids: self.max_centroids,
140 total_count: self.total_count,
141 }
142 }
143
144 fn quantile_sorted(&self, q: f64) -> f64 {
145 if self.centroids.is_empty() {
146 return f64::NAN;
147 }
148 if self.centroids.len() == 1 {
149 return self.centroids[0].mean;
150 }
151
152 let target = q * self.total_count as f64;
153 let mut cumulative = 0.0;
154
155 for c in &self.centroids {
156 cumulative += c.count as f64;
157 if cumulative >= target {
158 return c.mean;
159 }
160 }
161
162 self.centroids.last().map_or(f64::NAN, |c| c.mean)
163 }
164}
166impl Default for TDigest {
167 fn default() -> Self {
168 Self::new()
169 }
170}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tdigest_empty() {
        // No samples: every quantile is undefined.
        assert!(TDigest::new().quantile(0.5).is_nan());
    }

    #[test]
    fn tdigest_single_value() {
        let mut td = TDigest::new();
        td.add(42.0);
        assert!((td.quantile(0.5) - 42.0).abs() < f64::EPSILON);
    }

    #[test]
    fn tdigest_uniform() {
        let mut td = TDigest::new();
        let samples: Vec<f64> = (0..10_000).map(|i| i as f64).collect();
        td.add_batch(&samples);

        let p50 = td.quantile(0.5);
        assert!(
            (4500.0..5500.0).contains(&p50),
            "p50 expected ~5000, got {p50:.0}"
        );
        let p99 = td.quantile(0.99);
        assert!(
            (9800.0..10000.0).contains(&p99),
            "p99 expected ~9900, got {p99:.0}"
        );
    }

    #[test]
    fn tdigest_merge() {
        let (mut a, mut b) = (TDigest::new(), TDigest::new());
        (0..5000).for_each(|i| a.add(i as f64));
        (5000..10000).for_each(|i| b.add(i as f64));

        a.merge(&b);
        let p50 = a.quantile(0.5);
        assert!(
            (4000.0..6000.0).contains(&p50),
            "merged p50 expected ~5000, got {p50:.0}"
        );
    }

    #[test]
    fn tdigest_serde_roundtrip_merge_semantics() {
        let (mut a, mut b) = (TDigest::new(), TDigest::new());
        (0..500).for_each(|i| a.add(i as f64));
        (500..1000).for_each(|i| b.add(i as f64));

        // Round-trip `a` through JSON, then merge `b` into both copies;
        // the resulting quantiles should agree.
        let bytes = serde_json::to_vec(&a).expect("serialize TDigest");
        let mut a_prime: TDigest = serde_json::from_slice(&bytes).expect("deserialize TDigest");

        a_prime.merge(&b);
        a.merge(&b);

        let p50_orig = a.quantile(0.5);
        let p50_rt = a_prime.quantile(0.5);
        let diff = (p50_orig - p50_rt).abs() / p50_orig.abs().max(1.0);
        assert!(
            diff < 0.10,
            "roundtrip diverged: orig p50={p50_orig:.0}, roundtrip p50={p50_rt:.0}"
        );
    }
}
251}