1use crate::core::DocId;
11
12use super::{AggregationResult, Aggregator, AggregatorFactory, MetricResult};
13use crate::segment::reader::SegmentReader;
14
/// One cluster of the digest: the weighted mean of the samples folded into it
/// and the number of samples (`weight`) it stands for.
#[derive(Clone, Copy, Debug)]
struct Centroid {
    mean: f64,
    weight: f64,
}

/// Mergeable sketch for approximate percentiles, modelled on a t-digest.
///
/// New samples are appended as unit-weight centroids. `compress` sorts the
/// centroids by mean and greedily folds neighbours together, allowing larger
/// clusters near the median than in the tails.
#[derive(Clone)]
pub struct TDigest {
    /// Sorted by mean right after `compress`; samples added since then are
    /// appended (unsorted) at the end.
    centroids: Vec<Centroid>,
    /// Size/accuracy knob: larger values keep more, smaller centroids.
    compression: f64,
    /// Sum of the weights of all centroids, i.e. the number of accepted samples.
    total_weight: f64,
}

impl TDigest {
    /// Serialized header: compression (8 bytes) + total weight (8) + centroid count (4).
    const HEADER_LEN: usize = 20;
    /// Serialized centroid: mean (8 bytes) + weight (8).
    const CENTROID_LEN: usize = 16;
    /// `add` compresses once more than `BUFFER_FACTOR * compression` centroids are held.
    const BUFFER_FACTOR: f64 = 5.0;
    /// Factor in the centroid size limit
    /// `SIZE_SCALE * total_weight * sqrt(q * (1 - q)) / compression`.
    ///
    /// Because the limit is proportional to `total_weight`, a compressed digest
    /// holds roughly `pi / SIZE_SCALE * compression` centroids (up to about twice
    /// that with greedy packing) no matter how many samples were added, which
    /// stays below the `BUFFER_FACTOR * compression` threshold used by `add`.
    const SIZE_SCALE: f64 = 2.0;

    /// Creates an empty digest with the given compression.
    ///
    /// Values below `1.0` (and NaN) are treated as `1.0` when sizing centroids.
    pub fn new(compression: f64) -> Self {
        Self {
            centroids: Vec::new(),
            compression,
            total_weight: 0.0,
        }
    }

    /// Records one sample.
    ///
    /// NaN and infinite samples are ignored: they have no usable rank and would
    /// poison the mean of every centroid they get merged into.
    pub fn add(&mut self, value: f64) {
        if !value.is_finite() {
            return;
        }
        self.centroids.push(Centroid {
            mean: value,
            weight: 1.0,
        });
        self.total_weight += 1.0;

        if self.centroids.len() > self.buffer_limit() {
            self.compress();
        }
    }

    /// Folds all centroids of `other` into `self` and recompresses.
    pub fn merge(&mut self, other: &TDigest) {
        self.centroids.extend_from_slice(&other.centroids);
        self.total_weight += other.total_weight;
        self.compress();
    }

    /// Compression actually used for sizing; guards against values < 1 and NaN.
    fn effective_compression(&self) -> f64 {
        if self.compression >= 1.0 {
            self.compression
        } else {
            1.0
        }
    }

    /// Number of centroids `add` tolerates before triggering `compress`.
    fn buffer_limit(&self) -> usize {
        // Float-to-int `as` saturates, so a huge compression simply disables
        // the automatic compression instead of wrapping around.
        (self.effective_compression() * Self::BUFFER_FACTOR) as usize
    }

    /// Sorts the centroids by mean and greedily merges neighbours while the
    /// combined weight stays within the size limit for their quantile.
    fn compress(&mut self) {
        if self.centroids.is_empty() {
            return;
        }

        // `total_cmp` is a total order, so this cannot panic even if a
        // non-finite mean ever slipped in.
        self.centroids.sort_by(|a, b| a.mean.total_cmp(&b.mean));

        // The limit must scale with the total weight; a limit that is
        // independent of it lets the centroid count grow linearly with the
        // input and makes `add` recompress on every call.
        let scale = Self::SIZE_SCALE * self.total_weight / self.effective_compression();

        let mut merged: Vec<Centroid> = Vec::new();
        let mut weight_so_far = 0.0f64;
        let mut current = self.centroids[0];

        for &next in &self.centroids[1..] {
            // Quantile at the right edge of the cluster being built, kept away
            // from 0 and 1 so the tails still get a non-zero limit.
            let q = ((weight_so_far + current.weight) / self.total_weight).clamp(0.001, 0.999);
            // `max(1.0)` also absorbs a NaN limit.
            let limit = (scale * (q * (1.0 - q)).sqrt()).max(1.0);

            let combined = current.weight + next.weight;
            if combined <= limit {
                current.mean = (current.mean * current.weight + next.mean * next.weight) / combined;
                current.weight = combined;
            } else {
                weight_so_far += current.weight;
                merged.push(current);
                current = next;
            }
        }
        merged.push(current);
        self.centroids = merged;
    }

    /// Serializes the digest as little-endian
    /// `compression: f64, total_weight: f64, count: u32, (mean: f64, weight: f64) * count`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf =
            Vec::with_capacity(Self::HEADER_LEN + self.centroids.len() * Self::CENTROID_LEN);
        buf.extend_from_slice(&self.compression.to_le_bytes());
        buf.extend_from_slice(&self.total_weight.to_le_bytes());
        // NOTE(review): the count is stored as u32; a digest with more than
        // u32::MAX centroids would be truncated here.
        buf.extend_from_slice(&(self.centroids.len() as u32).to_le_bytes());
        for c in &self.centroids {
            buf.extend_from_slice(&c.mean.to_le_bytes());
            buf.extend_from_slice(&c.weight.to_le_bytes());
        }
        buf
    }

    /// Decodes a digest written by [`TDigest::to_bytes`].
    ///
    /// Returns `None` when the buffer is too short for the declared centroid
    /// count, or when the payload is not a plausible digest: NaN compression,
    /// non-finite means, non-finite or non-positive weights, or a total weight
    /// that does not match the sum of the centroid weights. Bytes after the
    /// last declared centroid are ignored.
    pub fn from_bytes(data: &[u8]) -> Option<Self> {
        let header = data.get(..Self::HEADER_LEN)?;
        let compression = Self::read_f64(&header[0..8]);
        let total_weight = Self::read_f64(&header[8..16]);
        let mut raw_count = [0u8; 4];
        raw_count.copy_from_slice(&header[16..20]);
        let count = u32::from_le_bytes(raw_count) as usize;

        // Checked arithmetic: `count` is untrusted and `count * 16` can
        // overflow `usize` on 32-bit targets.
        let body_len = count.checked_mul(Self::CENTROID_LEN)?;
        let body_end = Self::HEADER_LEN.checked_add(body_len)?;
        let body = data.get(Self::HEADER_LEN..body_end)?;

        if compression.is_nan() || !total_weight.is_finite() || total_weight < 0.0 {
            return None;
        }

        let mut centroids = Vec::with_capacity(count);
        let mut weight_sum = 0.0f64;
        for record in body.chunks_exact(Self::CENTROID_LEN) {
            let mean = Self::read_f64(&record[..8]);
            let weight = Self::read_f64(&record[8..]);
            if !mean.is_finite() || !weight.is_finite() || weight <= 0.0 {
                return None;
            }
            weight_sum += weight;
            centroids.push(Centroid { mean, weight });
        }

        // `compress` sizes centroids from `total_weight`, so it has to agree
        // with what the centroids actually carry.
        let tolerance = 1e-6 * total_weight.max(1.0);
        if (weight_sum - total_weight).abs() > tolerance {
            return None;
        }

        Some(Self {
            centroids,
            compression,
            total_weight,
        })
    }

    /// Reads a little-endian `f64` from exactly eight bytes.
    fn read_f64(bytes: &[u8]) -> f64 {
        let mut raw = [0u8; 8];
        raw.copy_from_slice(bytes);
        f64::from_le_bytes(raw)
    }

    /// Whether the centroids are in non-decreasing order of mean.
    fn is_sorted_by_mean(centroids: &[Centroid]) -> bool {
        centroids.windows(2).all(|pair| pair[0].mean <= pair[1].mean)
    }

    /// Estimates the value at percentile `p` (a percentage, `0.0..=100.0`).
    ///
    /// Returns `0.0` for an empty digest. Percentages below 0 yield the
    /// smallest centroid mean and percentages above 100 the largest. The
    /// estimate interpolates linearly between the means of the centroid that
    /// reaches the target rank and the one before it.
    pub fn percentile(&self, p: f64) -> f64 {
        // Samples added since the last `compress` are appended unsorted, while
        // the rank walk below needs ascending means. Fall back to a sorted
        // copy in that case.
        let sorted_copy;
        let centroids: &[Centroid] = if Self::is_sorted_by_mean(&self.centroids) {
            &self.centroids
        } else {
            let mut copy = self.centroids.clone();
            copy.sort_by(|a, b| a.mean.total_cmp(&b.mean));
            sorted_copy = copy;
            &sorted_copy
        };

        let (first, last) = match (centroids.first(), centroids.last()) {
            (Some(first), Some(last)) => (first, last),
            _ => return 0.0,
        };
        if centroids.len() == 1 {
            return first.mean;
        }

        let target = p / 100.0 * self.total_weight;
        let mut cumulative = 0.0;
        let mut previous: Option<&Centroid> = None;

        for c in centroids {
            if cumulative + c.weight >= target {
                return match previous {
                    None => c.mean,
                    Some(prev) => {
                        let frac = if c.weight > 0.0 {
                            (target - cumulative) / c.weight
                        } else {
                            0.0
                        };
                        prev.mean + frac * (c.mean - prev.mean)
                    }
                };
            }
            cumulative += c.weight;
            previous = Some(c);
        }

        last.mean
    }
}
171
172pub struct PercentilesAggFactory {
175 pub field_name: String,
176 pub percents: Vec<f64>,
177 pub compression: f64,
178}
179
180impl AggregatorFactory for PercentilesAggFactory {
181 fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
182 let field_id = reader
183 .header()
184 .fields
185 .iter()
186 .find(|f| f.field_name == self.field_name)
187 .map(|f| f.field_id);
188
189 let col = super::bucket::OwnedColumn::new(field_id, reader);
190
191 Box::new(PercentilesCollector {
192 digest: TDigest::new(self.compression),
193 col,
194 percents: self.percents.clone(),
195 })
196 }
197
198 fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
199 let mut merged = TDigest::new(self.compression);
200 let mut has_data = false;
201
202 for r in &results {
203 if let AggregationResult::Metric(m) = r {
204 if let Some(ref bytes) = m.merge_state {
205 if let Some(segment_digest) = TDigest::from_bytes(bytes) {
206 merged.merge(&segment_digest);
207 has_data = true;
208 }
209 }
210 }
211 }
212
213 if !has_data {
214 return AggregationResult::Metric(MetricResult::single(None));
215 }
216
217 merged.compress();
218 let mut result = MetricResult::single(None);
219 result.extra.insert("count".into(), merged.total_weight);
220 for &p in &self.percents {
221 let val = merged.percentile(p);
222 result.extra.insert(format!("{p}"), val);
223 }
224 AggregationResult::Metric(result)
225 }
226}
227
228struct PercentilesCollector {
229 digest: TDigest,
230 col: Option<super::bucket::OwnedColumn>,
231 percents: Vec<f64>,
232}
233
234unsafe impl Send for PercentilesCollector {}
235
236impl Aggregator for PercentilesCollector {
237 fn collect(&mut self, doc_id: DocId) {
238 if let Some(v) = self
239 .col
240 .as_ref()
241 .and_then(|c| c.numeric_value(doc_id.as_u32()))
242 {
243 self.digest.add(v);
244 }
245 }
246
247 fn finish(self: Box<Self>) -> AggregationResult {
248 if self.digest.total_weight == 0.0 {
249 return AggregationResult::Metric(MetricResult::single(None));
250 }
251
252 let mut digest = self.digest;
254 digest.compress();
255
256 let mut result = MetricResult::single(None);
257 result.extra.insert("count".into(), digest.total_weight);
258 for &p in &self.percents {
259 let val = digest.percentile(p);
260 result.extra.insert(format!("{p}"), val);
261 }
262 result.merge_state = Some(digest.to_bytes());
263 AggregationResult::Metric(result)
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a digest with compression 100 holding every integer in `values`.
    fn digest_of(values: std::ops::Range<u32>) -> TDigest {
        let mut digest = TDigest::new(100.0);
        for v in values {
            digest.add(f64::from(v));
        }
        digest
    }

    /// Asserts that `actual` lies within 500 of `expected`; on failure the
    /// message is `label: actual`.
    fn assert_near(label: &str, actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 500.0, "{label}: {actual}");
    }

    #[test]
    fn tdigest_empty() {
        assert_eq!(TDigest::new(100.0).percentile(50.0), 0.0);
    }

    #[test]
    fn tdigest_single_value() {
        let mut digest = TDigest::new(100.0);
        digest.add(42.0);
        for &p in &[50.0, 1.0, 99.0] {
            assert_eq!(digest.percentile(p), 42.0);
        }
    }

    #[test]
    fn tdigest_uniform_distribution() {
        let mut digest = digest_of(0..10_000);
        digest.compress();

        assert_near("p50", digest.percentile(50.0), 5000.0);
        assert_near("p99", digest.percentile(99.0), 9900.0);
        assert_near("p1", digest.percentile(1.0), 100.0);
    }

    #[test]
    fn tdigest_merge() {
        let mut lower = digest_of(0..5_000);
        let upper = digest_of(5_000..10_000);
        lower.merge(&upper);

        assert_near("merged p50", lower.percentile(50.0), 5000.0);
    }
}