// SPDX-License-Identifier: Apache-2.0
// File: nodedb_types/approx/tdigest.rs

//! TDigest — approximate percentile estimation (mergeable centroids).

/// Centroid in the t-digest: represents a cluster of values.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
struct Centroid {
    // Weighted mean of all observations merged into this cluster.
    mean: f64,
    // Number of original observations this cluster represents.
    count: u64,
}

/// T-digest approximate quantile estimator.
///
/// Maintains a sorted set of centroids that approximate the data distribution.
/// Accurate at the extremes (p1, p99) and reasonable in the middle.
/// Mergeable across partitions and shards.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct TDigest {
    // Centroid buffer. NOTE(review): only sorted as a side effect of
    // compress(); between compressions insertion order is preserved.
    centroids: Vec<Centroid>,
    // Compression factor: target centroid count after compress();
    // the buffer is allowed to grow to 2x this before compressing.
    max_centroids: usize,
    // Total observations across all centroids (maintained by add/merge).
    total_count: u64,
}

24impl TDigest {
25    pub fn new() -> Self {
26        Self::with_compression(200)
27    }
28
29    pub fn with_compression(max_centroids: usize) -> Self {
30        Self {
31            centroids: Vec::with_capacity(max_centroids),
32            max_centroids: max_centroids.max(10),
33            total_count: 0,
34        }
35    }
36
37    pub fn add(&mut self, value: f64) {
38        if value.is_nan() {
39            return;
40        }
41        self.centroids.push(Centroid {
42            mean: value,
43            count: 1,
44        });
45        self.total_count += 1;
46
47        if self.centroids.len() > self.max_centroids * 2 {
48            self.compress();
49        }
50    }
51
52    pub fn add_batch(&mut self, values: &[f64]) {
53        for &v in values {
54            self.add(v);
55        }
56    }
57
58    /// Estimate the value at a given quantile (0.0 to 1.0).
59    pub fn quantile(&self, q: f64) -> f64 {
60        let q = q.clamp(0.0, 1.0);
61        if self.centroids.is_empty() {
62            return f64::NAN;
63        }
64        self.compress_clone().quantile_sorted(q)
65    }
66
67    pub fn merge(&mut self, other: &TDigest) {
68        self.centroids.extend_from_slice(&other.centroids);
69        self.total_count += other.total_count;
70        if self.centroids.len() > self.max_centroids * 2 {
71            self.compress();
72        }
73    }
74
75    pub fn count(&self) -> u64 {
76        self.total_count
77    }
78
79    /// Add a pre-aggregated centroid (for merge/deserialization).
80    pub fn add_centroid(&mut self, mean: f64, count: u64) {
81        self.centroids.push(Centroid { mean, count });
82        self.total_count += count;
83        if self.centroids.len() > self.max_centroids * 2 {
84            self.compress();
85        }
86    }
87
88    /// Access centroids as (mean, count) pairs for serialization.
89    pub fn centroids(&self) -> Vec<(f64, u64)> {
90        self.centroids.iter().map(|c| (c.mean, c.count)).collect()
91    }
92
93    /// Approximate memory usage in bytes.
94    pub fn memory_bytes(&self) -> usize {
95        std::mem::size_of::<Self>() + self.centroids.capacity() * std::mem::size_of::<Centroid>()
96    }
97
98    fn compress(&mut self) {
99        if self.centroids.len() <= self.max_centroids {
100            return;
101        }
102
103        self.centroids.sort_by(|a, b| {
104            a.mean
105                .partial_cmp(&b.mean)
106                .unwrap_or(std::cmp::Ordering::Equal)
107        });
108
109        let target = self.max_centroids;
110        while self.centroids.len() > target {
111            let mut best_i = 0;
112            let mut best_gap = f64::INFINITY;
113            for i in 0..self.centroids.len() - 1 {
114                let gap = self.centroids[i + 1].mean - self.centroids[i].mean;
115                if gap < best_gap {
116                    best_gap = gap;
117                    best_i = i;
118                }
119            }
120            let a = self.centroids[best_i];
121            let b = self.centroids.remove(best_i + 1);
122            let total = a.count + b.count;
123            self.centroids[best_i] = Centroid {
124                mean: (a.mean * a.count as f64 + b.mean * b.count as f64) / total as f64,
125                count: total,
126            };
127        }
128    }
129
130    fn compress_clone(&self) -> TDigest {
131        let mut clone = self.clone_inner();
132        clone.compress();
133        clone
134    }
135
136    fn clone_inner(&self) -> TDigest {
137        TDigest {
138            centroids: self.centroids.clone(),
139            max_centroids: self.max_centroids,
140            total_count: self.total_count,
141        }
142    }
143
144    fn quantile_sorted(&self, q: f64) -> f64 {
145        if self.centroids.is_empty() {
146            return f64::NAN;
147        }
148        if self.centroids.len() == 1 {
149            return self.centroids[0].mean;
150        }
151
152        let target = q * self.total_count as f64;
153        let mut cumulative = 0.0;
154
155        for c in &self.centroids {
156            cumulative += c.count as f64;
157            if cumulative >= target {
158                return c.mean;
159            }
160        }
161
162        self.centroids.last().map_or(f64::NAN, |c| c.mean)
163    }
164}
165
166impl Default for TDigest {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tdigest_empty() {
        // No samples means no defined quantiles.
        assert!(TDigest::new().quantile(0.5).is_nan());
    }

    #[test]
    fn tdigest_single_value() {
        let mut digest = TDigest::new();
        digest.add(42.0);
        assert!((digest.quantile(0.5) - 42.0).abs() < f64::EPSILON);
    }

    #[test]
    fn tdigest_uniform() {
        // 10k uniform values; compression kicks in many times.
        let mut digest = TDigest::new();
        let values: Vec<f64> = (0..10_000).map(f64::from).collect();
        digest.add_batch(&values);

        let p50 = digest.quantile(0.5);
        assert!(
            (4500.0..5500.0).contains(&p50),
            "p50 expected ~5000, got {p50:.0}"
        );
        let p99 = digest.quantile(0.99);
        assert!(
            (9800.0..10000.0).contains(&p99),
            "p99 expected ~9900, got {p99:.0}"
        );
    }

    #[test]
    fn tdigest_merge() {
        // Split one uniform range across two digests, then merge.
        let mut left = TDigest::new();
        let mut right = TDigest::new();
        for i in 0..10_000 {
            let target = if i < 5000 { &mut left } else { &mut right };
            target.add(i as f64);
        }
        left.merge(&right);

        let p50 = left.quantile(0.5);
        assert!(
            (4000.0..6000.0).contains(&p50),
            "merged p50 expected ~5000, got {p50:.0}"
        );
    }

    #[test]
    fn tdigest_serde_roundtrip_merge_semantics() {
        let mut a = TDigest::new();
        let mut b = TDigest::new();
        let low: Vec<f64> = (0..500).map(f64::from).collect();
        let high: Vec<f64> = (500..1000).map(f64::from).collect();
        a.add_batch(&low);
        b.add_batch(&high);

        // Round-trip `a` through JSON, then merge `b` into both copies.
        let bytes = serde_json::to_vec(&a).expect("serialize TDigest");
        let mut a_prime: TDigest = serde_json::from_slice(&bytes).expect("deserialize TDigest");

        a_prime.merge(&b);
        a.merge(&b);

        let p50_orig = a.quantile(0.5);
        let p50_rt = a_prime.quantile(0.5);
        // Both are approximate; require within 10% of each other.
        let diff = (p50_orig - p50_rt).abs() / p50_orig.abs().max(1.0);
        assert!(
            diff < 0.10,
            "roundtrip diverged: orig p50={p50_orig:.0}, roundtrip p50={p50_rt:.0}"
        );
    }
}