// nodedb_types/approx/tdigest.rs

/// A weighted point of the sketch: `count` observations summarised by their `mean`.
#[derive(Debug, Clone, Copy)]
struct Centroid {
    mean: f64,
    count: u64,
}

/// Approximate quantile sketch built from weighted centroids.
///
/// New observations are appended to an unsorted buffer. Once the buffer holds
/// more than `2 * max_centroids` entries it is sorted by mean and the two
/// adjacent centroids with the smallest gap between their means are merged,
/// repeatedly, until only `max_centroids` entries remain.
#[derive(Debug, Clone)]
pub struct TDigest {
    /// Buffered centroids; only guaranteed to be sorted right after a compression.
    centroids: Vec<Centroid>,
    /// Number of centroids left after a compression pass (never below 10).
    max_centroids: usize,
    /// Sum of the counts of every accepted observation/centroid (saturating).
    total_count: u64,
}

impl TDigest {
    /// Compression used by [`TDigest::new`].
    const DEFAULT_MAX_CENTROIDS: usize = 200;
    /// Lower bound applied to the `max_centroids` argument of
    /// [`TDigest::with_compression`].
    const MIN_MAX_CENTROIDS: usize = 10;

    /// Creates an empty digest with the default compression of 200 centroids.
    pub fn new() -> Self {
        Self::with_compression(Self::DEFAULT_MAX_CENTROIDS)
    }

    /// Creates an empty digest that keeps `max_centroids` centroids after each
    /// compression. Values below 10 are raised to 10.
    pub fn with_compression(max_centroids: usize) -> Self {
        // Clamp first so the stored limit and the initial capacity agree.
        let max_centroids = max_centroids.max(Self::MIN_MAX_CENTROIDS);
        Self {
            centroids: Vec::with_capacity(max_centroids),
            max_centroids,
            total_count: 0,
        }
    }

    /// Records one observation. NaN values are ignored.
    pub fn add(&mut self, value: f64) {
        if value.is_nan() {
            return;
        }
        self.push_centroid(value, 1);
    }

    /// Records every value of `values`, in order, via [`TDigest::add`].
    pub fn add_batch(&mut self, values: &[f64]) {
        for &v in values {
            self.add(v);
        }
    }

    /// Estimates the `q`-quantile of the recorded observations.
    ///
    /// `q` is clamped to `[0.0, 1.0]`. The estimate is the mean of the first
    /// centroid, in ascending mean order, whose cumulative count reaches
    /// `q * count()`. Returns NaN when the digest is empty or `q` is NaN.
    ///
    /// The query runs on a sorted, compressed copy; `self` is not modified.
    pub fn quantile(&self, q: f64) -> f64 {
        if q.is_nan() || self.centroids.is_empty() {
            return f64::NAN;
        }
        self.compress_clone().quantile_sorted(q.clamp(0.0, 1.0))
    }

    /// Folds all centroids of `other` into `self`; `other` is left untouched.
    pub fn merge(&mut self, other: &TDigest) {
        self.centroids.extend_from_slice(&other.centroids);
        self.total_count = self.total_count.saturating_add(other.total_count);
        self.compress_if_overfull();
    }

    /// Total number of observations recorded so far.
    pub fn count(&self) -> u64 {
        self.total_count
    }

    /// Inserts a pre-aggregated centroid of `count` observations around `mean`.
    ///
    /// Centroids with a NaN mean or a zero count are ignored: a NaN mean
    /// cannot be ordered (mirroring [`TDigest::add`]) and a zero-weight
    /// centroid carries no observations.
    pub fn add_centroid(&mut self, mean: f64, count: u64) {
        if mean.is_nan() || count == 0 {
            return;
        }
        self.push_centroid(mean, count);
    }

    /// Returns the buffered centroids as `(mean, count)` pairs, in the current
    /// internal buffer order (not necessarily sorted).
    pub fn centroids(&self) -> Vec<(f64, u64)> {
        self.centroids.iter().map(|c| (c.mean, c.count)).collect()
    }

    /// Size of the struct itself plus the allocated capacity of the centroid
    /// buffer, in bytes.
    pub fn memory_bytes(&self) -> usize {
        std::mem::size_of::<Self>() + self.centroids.capacity() * std::mem::size_of::<Centroid>()
    }

    /// Appends an already validated centroid and compresses if the buffer overflowed.
    fn push_centroid(&mut self, mean: f64, count: u64) {
        self.centroids.push(Centroid { mean, count });
        self.total_count = self.total_count.saturating_add(count);
        self.compress_if_overfull();
    }

    /// Compresses once the buffer holds more than twice the centroid limit.
    fn compress_if_overfull(&mut self) {
        if self.centroids.len() > self.max_centroids.saturating_mul(2) {
            self.compress();
        }
    }

    /// Sorts the buffer by ascending mean. `total_cmp` is a total order, so
    /// the comparator stays consistent for every possible `f64`.
    fn sort_centroids(&mut self) {
        self.centroids.sort_by(|a, b| a.mean.total_cmp(&b.mean));
    }

    /// Reduces the buffer to at most `max_centroids` entries; a no-op (the
    /// buffer is not even sorted) when it is already small enough.
    fn compress(&mut self) {
        if self.centroids.len() <= self.max_centroids {
            return;
        }
        self.sort_centroids();
        self.merge_closest();
    }

    /// Repeatedly merges the adjacent pair with the smallest mean gap until at
    /// most `max_centroids` centroids remain. Requires a sorted buffer.
    fn merge_closest(&mut self) {
        while self.centroids.len() > self.max_centroids {
            let mut best_i = 0;
            let mut best_gap = f64::INFINITY;
            for (i, pair) in self.centroids.windows(2).enumerate() {
                let (lo, hi) = (pair[0].mean, pair[1].mean);
                // Equal means (including equal infinities, where `hi - lo`
                // would be NaN) are the closest possible pair.
                let gap = if hi == lo { 0.0 } else { hi - lo };
                if gap < best_gap {
                    best_gap = gap;
                    best_i = i;
                }
            }

            let a = self.centroids[best_i];
            let b = self.centroids.remove(best_i + 1);
            let total = a.count.saturating_add(b.count);
            // Keep identical means exact; this also avoids `inf - inf` style
            // NaNs when both centroids sit on the same infinity.
            let mean = if a.mean == b.mean {
                a.mean
            } else {
                (a.mean * a.count as f64 + b.mean * b.count as f64) / total as f64
            };
            self.centroids[best_i] = Centroid { mean, count: total };
        }
    }

    /// Returns a sorted copy reduced to at most `max_centroids` centroids.
    ///
    /// The copy is sorted unconditionally: `compress` skips small buffers
    /// entirely, which would leave them in insertion order.
    fn compress_clone(&self) -> TDigest {
        let mut snapshot = self.clone();
        snapshot.sort_centroids();
        snapshot.merge_closest();
        snapshot
    }

    /// Nearest-rank lookup over a buffer that is already sorted by mean.
    /// Returns NaN for an empty buffer.
    fn quantile_sorted(&self, q: f64) -> f64 {
        let target = q * self.total_count as f64;
        let mut cumulative = 0.0;

        for c in &self.centroids {
            cumulative += c.count as f64;
            if cumulative >= target {
                return c.mean;
            }
        }

        // Only reachable if rounding kept `cumulative` below `target`.
        self.centroids.last().map_or(f64::NAN, |c| c.mean)
    }
}
163
164impl Default for TDigest {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds every integer of `values` to `digest`, converted to `f64`.
    fn feed(digest: &mut TDigest, values: std::ops::Range<i32>) {
        values.for_each(|v| digest.add(v as f64));
    }

    /// An empty digest has no quantiles: the answer is NaN.
    #[test]
    fn tdigest_empty() {
        let digest = TDigest::new();
        let median = digest.quantile(0.5);
        assert!(median.is_nan());
    }

    /// With a single observation every quantile is that observation.
    #[test]
    fn tdigest_single_value() {
        let mut digest = TDigest::new();
        digest.add(42.0);
        let error = (digest.quantile(0.5) - 42.0).abs();
        assert!(error < f64::EPSILON);
    }

    /// The integers 0..10_000 should give p50 near 5000 and p99 near 9900.
    #[test]
    fn tdigest_uniform() {
        let mut digest = TDigest::new();
        feed(&mut digest, 0..10_000);

        let p50 = digest.quantile(0.5);
        assert!(
            (4500.0..5500.0).contains(&p50),
            "p50 expected ~5000, got {p50:.0}"
        );

        let p99 = digest.quantile(0.99);
        assert!(
            (9800.0..10000.0).contains(&p99),
            "p99 expected ~9900, got {p99:.0}"
        );
    }

    /// Merging two digests that cover adjacent halves of 0..10_000 should
    /// still place the median near 5000.
    #[test]
    fn tdigest_merge() {
        let mut lower = TDigest::new();
        feed(&mut lower, 0..5000);

        let mut upper = TDigest::new();
        feed(&mut upper, 5000..10000);

        lower.merge(&upper);

        let p50 = lower.quantile(0.5);
        assert!(
            (4000.0..6000.0).contains(&p50),
            "merged p50 expected ~5000, got {p50:.0}"
        );
    }
}