Skip to main content

nodedb_types/approx/
tdigest.rs

1//! TDigest — approximate percentile estimation (mergeable centroids).
2
/// Centroid in the t-digest: a cluster of values summarised by its weighted mean.
#[derive(Debug, Clone, Copy)]
struct Centroid {
    /// Weighted mean of every value folded into this cluster.
    mean: f64,
    /// Number of values folded into this cluster.
    count: u64,
}

/// T-digest style approximate quantile estimator.
///
/// Incoming values are buffered as single-point centroids. Once the buffer
/// holds more than twice `max_centroids` entries it is sorted by mean and the
/// closest neighbouring centroids are merged until `max_centroids` remain.
/// Digests built independently (e.g. per partition or shard) can be combined
/// with [`TDigest::merge`].
#[derive(Debug)]
pub struct TDigest {
    /// Centroid buffer; it is NOT kept sorted between compression passes.
    centroids: Vec<Centroid>,
    /// Number of centroids left after a compression pass (always >= 10).
    max_centroids: usize,
    /// Sum of the counts of all centroids added so far.
    total_count: u64,
}

impl TDigest {
    /// Upper bound on the centroid slots reserved up front, so an oversized
    /// compression setting cannot trigger a huge (or overflowing) allocation.
    const MAX_PREALLOC: usize = 4096;

    /// Creates a digest that keeps 200 centroids after each compression pass.
    pub fn new() -> Self {
        Self::with_compression(200)
    }

    /// Creates a digest that keeps `max_centroids` centroids after each
    /// compression pass. Values below 10 are raised to 10.
    pub fn with_compression(max_centroids: usize) -> Self {
        // Clamp before allocating so the reservation matches the limit that
        // is actually enforced.
        let max_centroids = max_centroids.max(10);
        Self {
            centroids: Vec::with_capacity(max_centroids.min(Self::MAX_PREALLOC)),
            max_centroids,
            total_count: 0,
        }
    }

    /// Records a single observation. NaN values are ignored.
    pub fn add(&mut self, value: f64) {
        if value.is_nan() {
            return;
        }
        self.push_centroid(Centroid {
            mean: value,
            count: 1,
        });
    }

    /// Records every value of `values` via [`TDigest::add`].
    pub fn add_batch(&mut self, values: &[f64]) {
        for &v in values {
            self.add(v);
        }
    }

    /// Estimate the value at a given quantile (0.0 to 1.0).
    ///
    /// `q` is clamped to `[0.0, 1.0]`. Returns NaN when the digest is empty.
    /// Otherwise the result is the mean of the first centroid, in ascending
    /// order of mean, whose cumulative count reaches `q * count()`.
    pub fn quantile(&self, q: f64) -> f64 {
        if self.centroids.is_empty() {
            return f64::NAN;
        }
        self.compress_clone().quantile_sorted(q.clamp(0.0, 1.0))
    }

    /// Folds all centroids of `other` into `self`; `other` is left untouched.
    pub fn merge(&mut self, other: &TDigest) {
        self.centroids.extend_from_slice(&other.centroids);
        self.total_count += other.total_count;
        self.compress_if_over_budget();
    }

    /// Total number of observations recorded (sum of all centroid counts).
    pub fn count(&self) -> u64 {
        self.total_count
    }

    /// Add a pre-aggregated centroid (for merge/deserialization).
    ///
    /// Centroids with a NaN mean or a zero count are ignored: they carry no
    /// information, and merging two zero-count centroids would produce a NaN
    /// mean (0/0).
    pub fn add_centroid(&mut self, mean: f64, count: u64) {
        if mean.is_nan() || count == 0 {
            return;
        }
        self.push_centroid(Centroid { mean, count });
    }

    /// Access centroids as (mean, count) pairs for serialization.
    ///
    /// The pairs are returned in buffer order, which is not necessarily
    /// sorted by mean.
    pub fn centroids(&self) -> Vec<(f64, u64)> {
        self.centroids.iter().map(|c| (c.mean, c.count)).collect()
    }

    /// Approximate memory usage in bytes (struct plus reserved centroid slots).
    pub fn memory_bytes(&self) -> usize {
        std::mem::size_of::<Self>() + self.centroids.capacity() * std::mem::size_of::<Centroid>()
    }

    /// Appends one centroid, updates the running count and compresses the
    /// buffer if it has outgrown its budget.
    fn push_centroid(&mut self, centroid: Centroid) {
        self.centroids.push(centroid);
        self.total_count += centroid.count;
        self.compress_if_over_budget();
    }

    /// Runs a compression pass once the buffer exceeds twice `max_centroids`.
    fn compress_if_over_budget(&mut self) {
        // saturating_mul: a very large compression setting must not overflow.
        if self.centroids.len() > self.max_centroids.saturating_mul(2) {
            self.compress();
        }
    }

    /// Sorts the centroids by mean and merges nearest neighbours until at
    /// most `max_centroids` remain.
    fn compress(&mut self) {
        // Sort unconditionally, even when nothing has to be merged:
        // `quantile_sorted` depends on ascending order, and a small digest
        // filled in arbitrary order would otherwise be walked unsorted.
        // The stable sort keeps equal means in insertion order.
        self.centroids.sort_by(|a, b| a.mean.total_cmp(&b.mean));

        let target = self.max_centroids;
        while self.centroids.len() > target {
            // Find the first adjacent pair with the smallest gap between means.
            let mut best_i = 0;
            let mut best_gap = f64::INFINITY;
            for (i, pair) in self.centroids.windows(2).enumerate() {
                let gap = pair[1].mean - pair[0].mean;
                if gap < best_gap {
                    best_gap = gap;
                    best_i = i;
                }
            }

            // Replace the pair by a single centroid at their weighted mean.
            // `target >= 10`, so at least two centroids exist here.
            let a = self.centroids[best_i];
            let b = self.centroids.remove(best_i + 1);
            let total = a.count + b.count;
            self.centroids[best_i] = Centroid {
                mean: (a.mean * a.count as f64 + b.mean * b.count as f64) / total as f64,
                count: total,
            };
        }
    }

    /// Returns a sorted, compressed copy; `self` is left untouched so that
    /// `quantile` can take `&self`.
    fn compress_clone(&self) -> TDigest {
        let mut clone = self.clone_inner();
        clone.compress();
        clone
    }

    /// Field-by-field copy of the digest.
    fn clone_inner(&self) -> TDigest {
        TDigest {
            centroids: self.centroids.clone(),
            max_centroids: self.max_centroids,
            total_count: self.total_count,
        }
    }

    /// Quantile lookup; requires `self.centroids` to be sorted by mean.
    fn quantile_sorted(&self, q: f64) -> f64 {
        if self.centroids.is_empty() {
            return f64::NAN;
        }
        if self.centroids.len() == 1 {
            return self.centroids[0].mean;
        }

        // Rank (in number of observations) that the answer must cover.
        let target = q * self.total_count as f64;
        let mut cumulative = 0.0;

        for c in &self.centroids {
            cumulative += c.count as f64;
            if cumulative >= target {
                return c.mean;
            }
        }

        // Only reached if no cumulative count met the target; fall back to
        // the largest centroid.
        self.centroids.last().map_or(f64::NAN, |c| c.mean)
    }
}
163
164impl Default for TDigest {
165    fn default() -> Self {
166        Self::new()
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty digest has no quantiles: the estimate is NaN.
    #[test]
    fn tdigest_empty() {
        let digest = TDigest::new();
        let median = digest.quantile(0.5);
        assert!(median.is_nan());
    }

    /// With a single observation, every quantile is that observation.
    #[test]
    fn tdigest_single_value() {
        let mut digest = TDigest::new();
        digest.add(42.0);
        let median = digest.quantile(0.5);
        assert!((median - 42.0).abs() < f64::EPSILON);
    }

    /// Median and p99 of 0..10_000 land close to the exact values.
    #[test]
    fn tdigest_uniform() {
        let mut digest = TDigest::new();
        (0..10_000).for_each(|i| digest.add(i as f64));

        let p50 = digest.quantile(0.5);
        assert!(
            p50 >= 4500.0 && p50 < 5500.0,
            "p50 expected ~5000, got {p50:.0}"
        );

        let p99 = digest.quantile(0.99);
        assert!(
            p99 >= 9800.0 && p99 < 10000.0,
            "p99 expected ~9900, got {p99:.0}"
        );
    }

    /// Merging two half-range digests approximates the full-range median.
    #[test]
    fn tdigest_merge() {
        let mut lower = TDigest::new();
        (0..5000).for_each(|i| lower.add(i as f64));

        let mut upper = TDigest::new();
        (5000..10000).for_each(|i| upper.add(i as f64));

        lower.merge(&upper);

        let p50 = lower.quantile(0.5);
        assert!(
            p50 >= 4000.0 && p50 < 6000.0,
            "merged p50 expected ~5000, got {p50:.0}"
        );
    }
}