Skip to main content

oxihuman_core/
t_digest.rs

// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

//! t-digest sketch for quantile approximation.
5
6#![allow(dead_code)]
7
/// A centroid in the t-digest: a cluster of samples summarized by their
/// weighted mean and their total weight.
#[derive(Debug, Clone)]
pub struct Centroid {
    /// Weighted mean of the samples merged into this centroid.
    pub mean: f64,
    /// Total weight of the samples merged into this centroid.
    pub weight: f64,
}

/// t-digest data structure for quantile estimation.
///
/// Centroids are kept sorted by mean. A centroid's weight is bounded by
/// `4 * N * q * (1 - q) / compression`, where `N` is the total weight and `q`
/// the centroid's quantile position, so centroids near the tails stay small.
/// The exact minimum and maximum of all accepted samples are tracked
/// separately from the centroids.
#[derive(Debug, Clone)]
pub struct TDigest {
    /// Centroids in ascending order of `mean`.
    centroids: Vec<Centroid>,
    /// Compression parameter (clamped to at least 10.0 by `new`).
    compression: f64,
    /// Sum of the weights of all accepted samples.
    total_weight: f64,
    /// Smallest accepted sample value (`+inf` while empty).
    min_value: f64,
    /// Largest accepted sample value (`-inf` while empty).
    max_value: f64,
}

impl TDigest {
    /// Creates an empty digest.
    ///
    /// `compression` is clamped to at least 10.0. Larger values allow more
    /// centroids before compression kicks in (see `compress`).
    pub fn new(compression: f64) -> Self {
        Self {
            centroids: Vec::new(),
            compression: compression.max(10.0),
            total_weight: 0.0,
            min_value: f64::INFINITY,
            max_value: f64::NEG_INFINITY,
        }
    }

    /// Adds a single sample with weight 1.0.
    ///
    /// Non-finite values (NaN, ±inf) are ignored.
    pub fn add(&mut self, value: f64) {
        self.add_weighted(value, 1.0);
    }

    /// Adds `value` with the given `weight`.
    ///
    /// Samples whose value is not finite, or whose weight is not a finite
    /// positive number, are ignored. A NaN value would break the sorted order
    /// of the centroids (it compares false against everything). Zero or
    /// negative weights would corrupt the weighted means and the quantile
    /// interpolation.
    pub fn add_weighted(&mut self, value: f64, weight: f64) {
        if !value.is_finite() || !weight.is_finite() || weight <= 0.0 {
            return;
        }
        self.total_weight += weight;
        self.min_value = self.min_value.min(value);
        self.max_value = self.max_value.max(value);

        // First centroid whose mean is >= value; inserting here keeps the list sorted.
        let idx = self.centroids.partition_point(|c| c.mean < value);

        let max_w = self.max_weight_for(idx);
        if let Some(c) = self.centroids.get_mut(idx) {
            // Merging pulls this centroid's mean down toward `value`. `value` is
            // above every earlier centroid's mean, so the ordering is preserved.
            if c.mean == value || c.weight + weight <= max_w {
                c.mean = (c.mean * c.weight + value * weight) / (c.weight + weight);
                c.weight += weight;
                return;
            }
        }
        self.centroids.insert(
            idx,
            Centroid {
                mean: value,
                weight,
            },
        );
        self.compress();
    }

    /// Maximum weight the centroid at `idx` may reach.
    ///
    /// Evaluates `4 * N * q * (1 - q) / compression` at the quantile `q` of
    /// that centroid's midpoint. Falls back to `q = 0.5` when the digest has
    /// no weight.
    fn max_weight_for(&self, idx: usize) -> f64 {
        let q = if self.total_weight > 0.0 {
            let cumulative: f64 = self.centroids[..idx].iter().map(|c| c.weight).sum();
            let half = self.centroids.get(idx).map_or(0.0, |c| c.weight / 2.0);
            (cumulative + half) / self.total_weight
        } else {
            0.5
        };
        4.0 * self.total_weight * q * (1.0 - q) / self.compression
    }

    /// Merges adjacent centroids once their count exceeds
    /// `ceil(compression * PI / 2)`.
    ///
    /// Uses the same `4 * N * q * (1 - q) / compression` size bound (never
    /// below 1.0).
    fn compress(&mut self) {
        let max_centroids = (self.compression * std::f64::consts::PI / 2.0).ceil() as usize;
        if self.centroids.len() <= max_centroids {
            return;
        }
        // `add_weighted` already keeps the centroids sorted; this is a cheap
        // safety net. `total_cmp` is a true total order, unlike
        // `partial_cmp(..).unwrap_or(Equal)`.
        self.centroids.sort_by(|a, b| a.mean.total_cmp(&b.mean));
        let total = self.total_weight;
        let mut merged: Vec<Centroid> = Vec::with_capacity(self.centroids.len());
        let mut cumulative_w = 0.0f64;
        for c in self.centroids.drain(..) {
            if let Some(last) = merged.last_mut() {
                // Quantile at the midpoint of the last emitted centroid.
                let q = ((cumulative_w - last.weight / 2.0) / total).clamp(0.0, 1.0);
                // Allow at least unit weight so compression can always make progress.
                let limit = (4.0 * total * q * (1.0 - q) / self.compression).max(1.0);
                if last.weight + c.weight <= limit {
                    last.mean =
                        (last.mean * last.weight + c.mean * c.weight) / (last.weight + c.weight);
                    last.weight += c.weight;
                    cumulative_w += c.weight;
                    continue;
                }
            }
            cumulative_w += c.weight;
            merged.push(c);
        }
        self.centroids = merged;
    }

    /// Estimates the value at quantile `q` in `[0, 1]`.
    ///
    /// - `q` outside `[0, 1]` is clamped.
    /// - Returns NaN when the digest is empty or `q` is NaN.
    /// - Between centroid midpoints the result is linearly interpolated.
    /// - Before the first midpoint it returns the first centroid's mean;
    ///   past the last midpoint it returns the last centroid's mean.
    pub fn quantile(&self, q: f64) -> f64 {
        if self.centroids.is_empty() || q.is_nan() {
            return f64::NAN;
        }
        let target = q.clamp(0.0, 1.0) * self.total_weight;
        let mut cumulative = 0.0;
        for (i, c) in self.centroids.iter().enumerate() {
            let mid = cumulative + c.weight / 2.0;
            if target <= mid {
                if i == 0 {
                    return c.mean;
                }
                // Interpolate between the previous centroid's midpoint and this one's.
                let prev = &self.centroids[i - 1];
                let prev_mid = cumulative - prev.weight / 2.0;
                let frac = (target - prev_mid) / (mid - prev_mid);
                return prev.mean + frac * (c.mean - prev.mean);
            }
            cumulative += c.weight;
        }
        self.centroids.last().map_or(f64::NAN, |c| c.mean)
    }

    /// Total weight of all accepted samples.
    pub fn count(&self) -> f64 {
        self.total_weight
    }

    /// Number of centroids currently stored.
    pub fn centroid_count(&self) -> usize {
        self.centroids.len()
    }

    /// Exact smallest sample value added so far, or NaN if the digest is empty.
    pub fn min(&self) -> f64 {
        if self.centroids.is_empty() {
            f64::NAN
        } else {
            self.min_value
        }
    }

    /// Exact largest sample value added so far, or NaN if the digest is empty.
    pub fn max(&self) -> f64 {
        if self.centroids.is_empty() {
            f64::NAN
        } else {
            self.max_value
        }
    }

    /// Merges every centroid of `other` into this digest.
    ///
    /// Also carries over `other`'s exact extremes. An empty `other` leaves
    /// `self` unchanged, because its extremes are `+inf`/`-inf`.
    pub fn merge(&mut self, other: &TDigest) {
        for c in &other.centroids {
            self.add_weighted(c.mean, c.weight);
        }
        self.min_value = self.min_value.min(other.min_value);
        self.max_value = self.max_value.max(other.max_value);
    }
}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a digest with the given compression and feeds it every sample in `samples`.
    fn digest_from(compression: f64, samples: impl IntoIterator<Item = f64>) -> TDigest {
        let mut digest = TDigest::new(compression);
        samples.into_iter().for_each(|s| digest.add(s));
        digest
    }

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_empty_quantile() {
        assert!(TDigest::new(100.0).quantile(0.5).is_nan());
    }

    #[test]
    fn test_single_value() {
        let digest = digest_from(100.0, [5.0]);
        assert!(close(digest.quantile(0.5), 5.0, 1e-6));
    }

    #[test]
    fn test_median_uniform() {
        let digest = digest_from(100.0, (1..=100).map(f64::from));
        let m = digest.quantile(0.5);
        assert!(m > 40.0 && m < 60.0, "median={m}");
    }

    #[test]
    fn test_min_max() {
        let digest = digest_from(100.0, [3.0, 1.0, 5.0, 2.0, 4.0]);
        assert!(close(digest.min(), 1.0, 1.0));
        assert!(close(digest.max(), 5.0, 1.0));
    }

    #[test]
    fn test_count() {
        let digest = digest_from(50.0, [1.0; 10]);
        assert!(close(digest.count(), 10.0, 1e-6));
    }

    #[test]
    fn test_compression_clamp() {
        assert!(TDigest::new(5.0).compression >= 10.0);
    }

    #[test]
    fn test_merge() {
        let mut lower_half = digest_from(100.0, (1..=50).map(f64::from));
        let upper_half = digest_from(100.0, (51..=100).map(f64::from));
        lower_half.merge(&upper_half);
        let m = lower_half.quantile(0.5);
        assert!(m > 40.0 && m < 60.0, "merged median={m}");
    }

    #[test]
    fn test_quantile_zero_one() {
        let digest = digest_from(100.0, (1..=10).map(f64::from));
        let (q0, q50, q100) = (
            digest.quantile(0.0),
            digest.quantile(0.5),
            digest.quantile(1.0),
        );
        assert!(q0 <= q50);
        assert!(q50 <= q100);
    }

    #[test]
    fn test_centroid_count_bounded() {
        let digest = digest_from(50.0, (0..1000).map(f64::from));
        assert!(digest.centroid_count() < 500);
    }

    #[test]
    fn test_add_weighted() {
        let mut digest = TDigest::new(100.0);
        digest.add_weighted(10.0, 5.0);
        assert!(close(digest.count(), 5.0, 1e-6));
    }
}
254}