// oxihuman_core/t_digest.rs
#![allow(dead_code)]
/// A single centroid in the digest: a weighted mean summarizing a cluster
/// of nearby samples.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct Centroid {
    /// Mean of all samples merged into this centroid.
    pub mean: f64,
    /// Total weight (e.g. sample count) merged into this centroid.
    pub weight: f64,
}

/// A t-digest: a compact, mergeable sketch for streaming quantile
/// estimation. Centroids near the distribution tails stay small so extreme
/// quantiles remain accurate; `compression` bounds the sketch size.
#[allow(dead_code)]
pub struct TDigest {
    /// Centroids, kept sorted by `mean` (ascending).
    centroids: Vec<Centroid>,
    /// Accuracy/size trade-off; larger keeps more centroids. Clamped to >= 10.
    compression: f64,
    /// Sum of the weights of all accepted samples.
    total_weight: f64,
}

impl TDigest {
    /// Creates an empty digest. A `compression` below 10.0 (or NaN) is
    /// clamped to 10.0 so the size bound stays sane.
    #[allow(dead_code)]
    pub fn new(compression: f64) -> Self {
        Self {
            centroids: Vec::new(),
            // f64::max returns the non-NaN operand, so NaN also becomes 10.0.
            compression: compression.max(10.0),
            total_weight: 0.0,
        }
    }

    /// Adds a single observation with weight 1.
    #[allow(dead_code)]
    pub fn add(&mut self, value: f64) {
        self.add_weighted(value, 1.0);
    }

    /// Adds `value` with the given `weight`.
    ///
    /// Non-finite values/weights and non-positive weights are ignored: a NaN
    /// mean would break the sorted-centroid invariant that `partition_point`
    /// relies on, and a zero weight would add a centroid that contributes
    /// nothing but can later cause a division by zero during interpolation.
    #[allow(dead_code)]
    pub fn add_weighted(&mut self, value: f64, weight: f64) {
        if !value.is_finite() || !weight.is_finite() || weight <= 0.0 {
            return;
        }
        self.total_weight += weight;
        // Index of the first centroid whose mean is >= value.
        let idx = self.centroids.partition_point(|c| c.mean < value);

        // Try to absorb the sample into the neighboring centroid when the
        // size limit for that quantile region permits it.
        let max_w = self.max_weight_for(idx);
        if let Some(c) = self.centroids.get_mut(idx) {
            if c.mean == value || c.weight + weight <= max_w {
                c.mean = (c.mean * c.weight + value * weight) / (c.weight + weight);
                c.weight += weight;
                return;
            }
        }
        // Otherwise start a new centroid at the insertion point, then
        // shrink the digest if it has grown past its budget.
        self.centroids.insert(idx, Centroid { mean: value, weight });
        self.compress();
    }

    /// Maximum weight a centroid at position `idx` may hold, per the
    /// t-digest size function 4·N·q·(1-q)/compression, where q is the
    /// quantile at that centroid's midpoint.
    fn max_weight_for(&self, idx: usize) -> f64 {
        let q = if self.total_weight > 0.0 {
            let cumulative: f64 = self.centroids[..idx].iter().map(|c| c.weight).sum();
            (cumulative + self.centroids.get(idx).map_or(0.0, |c| c.weight / 2.0))
                / self.total_weight
        } else {
            0.5
        };
        4.0 * self.total_weight * q * (1.0 - q) / self.compression
    }

    /// Re-merges adjacent centroids once the digest exceeds its size budget,
    /// keeping each merged centroid's weight within the size function.
    fn compress(&mut self) {
        let max_centroids = (self.compression * std::f64::consts::PI / 2.0).ceil() as usize;
        if self.centroids.len() <= max_centroids {
            return;
        }
        self.centroids.sort_by(|a, b| {
            a.mean
                .partial_cmp(&b.mean)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let total = self.total_weight;
        let mut merged: Vec<Centroid> = Vec::new();
        let mut cumulative_w = 0.0f64;
        for c in self.centroids.drain(..) {
            if let Some(last) = merged.last_mut() {
                // Quantile at the midpoint of the trailing merged centroid.
                let q = (cumulative_w - last.weight / 2.0) / total;
                let q = q.clamp(0.0, 1.0);
                let limit = 4.0 * total * q * (1.0 - q) / self.compression;
                // Always allow at least unit weight so merging can make
                // progress even at the tails, where the limit approaches 0.
                let limit = limit.max(1.0);
                if last.weight + c.weight <= limit {
                    last.mean =
                        (last.mean * last.weight + c.mean * c.weight) / (last.weight + c.weight);
                    last.weight += c.weight;
                    cumulative_w += c.weight;
                    continue;
                }
            }
            cumulative_w += c.weight;
            merged.push(c);
        }
        self.centroids = merged;
    }

    /// Estimates the value at quantile `q` by linear interpolation between
    /// centroid midpoints. Returns NaN for an empty digest; a `q` outside
    /// [0, 1] saturates at the extreme centroids' means.
    #[allow(dead_code)]
    pub fn quantile(&self, q: f64) -> f64 {
        if self.centroids.is_empty() {
            return f64::NAN;
        }
        let target = q * self.total_weight;
        let mut cumulative = 0.0;
        for (i, c) in self.centroids.iter().enumerate() {
            let lower = cumulative;
            let upper = cumulative + c.weight;
            let mid = (lower + upper) / 2.0;
            if target <= mid {
                if i == 0 {
                    return c.mean;
                }
                // Interpolate between the midpoints of centroids i-1 and i.
                let prev = &self.centroids[i - 1];
                let prev_mid = cumulative - prev.weight / 2.0;
                let frac = (target - prev_mid) / (mid - prev_mid);
                return prev.mean + frac * (c.mean - prev.mean);
            }
            cumulative += c.weight;
        }
        self.centroids.last().map_or(f64::NAN, |c| c.mean)
    }

    /// Total weight (sample count) accepted so far.
    #[allow(dead_code)]
    pub fn count(&self) -> f64 {
        self.total_weight
    }

    /// Number of centroids currently retained.
    #[allow(dead_code)]
    pub fn centroid_count(&self) -> usize {
        self.centroids.len()
    }

    /// Mean of the smallest centroid (approximate minimum; NaN if empty).
    #[allow(dead_code)]
    pub fn min(&self) -> f64 {
        self.centroids.first().map_or(f64::NAN, |c| c.mean)
    }

    /// Mean of the largest centroid (approximate maximum; NaN if empty).
    #[allow(dead_code)]
    pub fn max(&self) -> f64 {
        self.centroids.last().map_or(f64::NAN, |c| c.mean)
    }

    /// Merges another digest into this one by re-inserting its centroids.
    #[allow(dead_code)]
    pub fn merge(&mut self, other: &TDigest) {
        for c in &other.centroids {
            self.add_weighted(c.mean, c.weight);
        }
    }
}
161
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_quantile() {
        // A digest with no samples cannot answer any quantile query.
        let td = TDigest::new(100.0);
        assert!(td.quantile(0.5).is_nan());
    }

    #[test]
    fn test_single_value() {
        let mut digest = TDigest::new(100.0);
        digest.add(5.0);
        let median = digest.quantile(0.5);
        assert!((median - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_median_uniform() {
        // Uniform 1..=100 should have its median near 50.
        let mut digest = TDigest::new(100.0);
        (1..=100).for_each(|v| digest.add(f64::from(v)));
        let m = digest.quantile(0.5);
        assert!(m > 40.0 && m < 60.0, "median={m}");
    }

    #[test]
    fn test_min_max() {
        let mut digest = TDigest::new(100.0);
        let samples = [3.0, 1.0, 5.0, 2.0, 4.0];
        for &s in &samples {
            digest.add(s);
        }
        assert!((digest.min() - 1.0).abs() < 1.0);
        assert!((digest.max() - 5.0).abs() < 1.0);
    }

    #[test]
    fn test_count() {
        let mut digest = TDigest::new(50.0);
        (0..10).for_each(|_| digest.add(1.0));
        assert!((digest.count() - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_compression_clamp() {
        // Too-small compression is raised to the 10.0 floor.
        let digest = TDigest::new(5.0);
        assert!(digest.compression >= 10.0);
    }

    #[test]
    fn test_merge() {
        // Merging two half-range digests should recover the full-range median.
        let mut lower_half = TDigest::new(100.0);
        let mut upper_half = TDigest::new(100.0);
        (1..=50).for_each(|v| lower_half.add(f64::from(v)));
        (51..=100).for_each(|v| upper_half.add(f64::from(v)));
        lower_half.merge(&upper_half);
        let m = lower_half.quantile(0.5);
        assert!(m > 40.0 && m < 60.0, "merged median={m}");
    }

    #[test]
    fn test_quantile_zero_one() {
        let mut digest = TDigest::new(100.0);
        (1..=10).for_each(|v| digest.add(f64::from(v)));
        // Quantile estimates must be monotone in q.
        assert!(digest.quantile(0.0) <= digest.quantile(0.5));
        assert!(digest.quantile(0.5) <= digest.quantile(1.0));
    }

    #[test]
    fn test_centroid_count_bounded() {
        // 1000 inserts must compress well below one centroid per sample.
        let mut digest = TDigest::new(50.0);
        (0..1000).for_each(|v| digest.add(f64::from(v)));
        assert!(digest.centroid_count() < 500);
    }

    #[test]
    fn test_add_weighted() {
        let mut digest = TDigest::new(100.0);
        digest.add_weighted(10.0, 5.0);
        assert!((digest.count() - 5.0).abs() < 1e-6);
    }
}