// datafusion_functions_aggregate_common/tdigest.rs
use arrow::datatypes::DataType;
33use arrow::datatypes::Float64Type;
34use datafusion_common::ScalarValue;
35use datafusion_common::cast::as_primitive_array;
36use std::cmp::Ordering;
37use std::mem::{size_of, size_of_val};
38
/// Default upper bound on the number of centroids a digest keeps after compression.
pub const DEFAULT_MAX_SIZE: usize = 100;
40
/// Extracts the inner `f64` from a `ScalarValue::Float64(Some(_))`.
///
/// Panics on any other variant: the serialized digest state always stores
/// these fields as non-null `Float64` (see `to_scalar_state`).
macro_rules! cast_scalar_f64 {
    ($value:expr ) => {
        match &$value {
            ScalarValue::Float64(Some(v)) => *v,
            v => panic!("invalid type {}", v),
        }
    };
}
51
/// A single centroid of the digest: a weighted mean summarizing a cluster of
/// collapsed samples.
#[derive(Debug, PartialEq, Clone)]
pub struct Centroid {
    // Mean of the samples collapsed into this centroid.
    mean: f64,
    // Number of samples (possibly fractional after merging) it represents.
    weight: f64,
}

impl Centroid {
    /// Builds a centroid from a mean and a weight.
    pub fn new(mean: f64, weight: f64) -> Self {
        Self { mean, weight }
    }

    /// The centroid's mean.
    #[inline]
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// The centroid's weight.
    #[inline]
    pub fn weight(&self) -> f64 {
        self.weight
    }

    /// Folds an accumulated `sum`/`weight` pair into this centroid and
    /// returns the combined sum (this centroid's mass plus `sum`).
    pub fn add(&mut self, sum: f64, weight: f64) -> f64 {
        let combined_sum = sum + self.weight * self.mean;
        let combined_weight = self.weight + weight;
        self.weight = combined_weight;
        self.mean = combined_sum / combined_weight;
        combined_sum
    }

    /// Total ordering of centroids by mean, via `f64::total_cmp` (handles
    /// NaN and signed zero deterministically).
    pub fn cmp_mean(&self, other: &Self) -> Ordering {
        self.mean.total_cmp(&other.mean)
    }
}

impl Default for Centroid {
    /// A default centroid sits at zero with unit weight.
    fn default() -> Self {
        Self {
            mean: 0.0,
            weight: 1.0,
        }
    }
}
95
/// A t-digest sketch: a compact, mergeable summary of a value distribution
/// used to estimate quantiles in bounded space.
#[derive(Debug, PartialEq, Clone)]
pub struct TDigest {
    // Centroids kept sorted by ascending mean.
    centroids: Vec<Centroid>,
    // Upper bound on the number of centroids retained after compression.
    max_size: usize,
    // Sum of all ingested values.
    sum: f64,
    // Total ingested weight (number of values).
    count: f64,
    // Largest value seen; NAN while the digest is empty.
    max: f64,
    // Smallest value seen; NAN while the digest is empty.
    min: f64,
}
106
107impl TDigest {
108 pub fn new(max_size: usize) -> Self {
109 TDigest {
110 centroids: Vec::new(),
111 max_size,
112 sum: 0.0,
113 count: 0.0,
114 max: f64::NAN,
115 min: f64::NAN,
116 }
117 }
118
119 #[expect(clippy::needless_pass_by_value)]
120 pub fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self {
121 TDigest {
122 centroids: vec![centroid.clone()],
123 max_size,
124 sum: centroid.mean * centroid.weight,
125 count: centroid.weight,
126 max: centroid.mean,
127 min: centroid.mean,
128 }
129 }
130
131 #[inline]
132 pub fn count(&self) -> f64 {
133 self.count
134 }
135
136 #[inline]
137 pub fn max(&self) -> f64 {
138 self.max
139 }
140
141 #[inline]
142 pub fn min(&self) -> f64 {
143 self.min
144 }
145
146 #[inline]
147 pub fn max_size(&self) -> usize {
148 self.max_size
149 }
150
151 pub fn size(&self) -> usize {
153 size_of_val(self) + (size_of::<Centroid>() * self.centroids.capacity())
154 }
155}
156
157impl Default for TDigest {
158 fn default() -> Self {
159 TDigest {
160 centroids: Vec::new(),
161 max_size: 100,
162 sum: 0.0,
163 count: 0.0,
164 max: f64::NAN,
165 min: f64::NAN,
166 }
167 }
168}
169
170impl TDigest {
171 fn k_to_q(k: u64, d: usize) -> f64 {
172 let k_div_d = k as f64 / d as f64;
173 if k_div_d >= 0.5 {
174 let base = 1.0 - k_div_d;
175 1.0 - 2.0 * base * base
176 } else {
177 2.0 * k_div_d * k_div_d
178 }
179 }
180
181 fn clamp(v: f64, lo: f64, hi: f64) -> f64 {
182 if lo.is_nan() || hi.is_nan() {
183 return v;
184 }
185
186 let (min, max) = if lo > hi { (hi, lo) } else { (lo, hi) };
188
189 v.clamp(min, max)
190 }
191
192 pub fn merge_unsorted_f64(&self, unsorted_values: Vec<f64>) -> TDigest {
194 let mut values = unsorted_values;
195 values.sort_by(|a, b| a.total_cmp(b));
196 self.merge_sorted_f64(&values)
197 }
198
    /// Merges an ascending-sorted slice of values into a copy of this digest
    /// and returns the new, re-compressed digest; `self` is not modified.
    ///
    /// `sorted_values` must be sorted under `f64::total_cmp` (verified by a
    /// `debug_assert!` in debug builds only).
    pub fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest {
        #[cfg(debug_assertions)]
        debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest");

        if sorted_values.is_empty() {
            return self.clone();
        }

        let mut result = TDigest::new(self.max_size());
        result.count = self.count() + sorted_values.len() as f64;

        // Safe unwraps: the slice is non-empty (checked above).
        let maybe_min = *sorted_values.first().unwrap();
        let maybe_max = *sorted_values.last().unwrap();

        // An empty digest carries NAN min/max, so only fold the new extremes
        // in when the digest already holds data.
        if self.count() > 0.0 {
            result.min = self.min.min(maybe_min);
            result.max = self.max.max(maybe_max);
        } else {
            result.min = maybe_min;
            result.max = maybe_max;
        }

        let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);

        // `q_limit_times_count` is the cumulative-weight threshold at which
        // the current centroid is frozen and a new one started; `k_to_q`
        // keeps centroids small at the tails. `k_limit` is advanced after
        // each use (post-increment pattern).
        let mut k_limit: u64 = 1;
        let mut q_limit_times_count =
            Self::k_to_q(k_limit, self.max_size) * result.count();
        k_limit += 1;

        // Two-way merge of the existing centroids and the incoming values,
        // both streams ascending.
        let mut iter_centroids = self.centroids.iter().peekable();
        let mut iter_sorted_values = sorted_values.iter().peekable();

        // Start with the smaller head of the two streams. When centroids
        // exist, `sorted_values` is also non-empty here, so peeking it is safe.
        let mut curr: Centroid = if let Some(c) = iter_centroids.peek() {
            let curr = **iter_sorted_values.peek().unwrap();
            if c.mean() < curr {
                iter_centroids.next().unwrap().clone()
            } else {
                Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
            }
        } else {
            Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
        };

        let mut weight_so_far = curr.weight();

        // Mass accumulated since `curr` was started; folded into `curr` when
        // the weight threshold is crossed.
        let mut sums_to_merge = 0_f64;
        let mut weights_to_merge = 0_f64;

        while iter_centroids.peek().is_some() || iter_sorted_values.peek().is_some() {
            // Take the next element in ascending-mean order from whichever
            // stream has the smaller head.
            let next: Centroid = if let Some(c) = iter_centroids.peek() {
                if iter_sorted_values.peek().is_none()
                    || c.mean() < **iter_sorted_values.peek().unwrap()
                {
                    iter_centroids.next().unwrap().clone()
                } else {
                    Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
                }
            } else {
                Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
            };

            let next_sum = next.mean() * next.weight();
            weight_so_far += next.weight();

            if weight_so_far <= q_limit_times_count {
                // Still under the threshold: buffer into the pending mass.
                sums_to_merge += next_sum;
                weights_to_merge += next.weight();
            } else {
                // Threshold crossed: fold the pending mass into `curr`, emit
                // it, advance the quantile limit, and restart from `next`.
                result.sum += curr.add(sums_to_merge, weights_to_merge);
                sums_to_merge = 0_f64;
                weights_to_merge = 0_f64;

                compressed.push(curr.clone());
                q_limit_times_count =
                    Self::k_to_q(k_limit, self.max_size) * result.count();
                k_limit += 1;
                curr = next;
            }
        }

        // Flush the final in-progress centroid.
        result.sum += curr.add(sums_to_merge, weights_to_merge);
        compressed.push(curr);
        compressed.shrink_to_fit();
        compressed.sort_by(|a, b| a.cmp_mean(b));

        result.centroids = compressed;
        result
    }
287
288 fn external_merge(
289 centroids: &mut [Centroid],
290 first: usize,
291 middle: usize,
292 last: usize,
293 ) {
294 let mut result: Vec<Centroid> = Vec::with_capacity(centroids.len());
295
296 let mut i = first;
297 let mut j = middle;
298
299 while i < middle && j < last {
300 match centroids[i].cmp_mean(¢roids[j]) {
301 Ordering::Less => {
302 result.push(centroids[i].clone());
303 i += 1;
304 }
305 Ordering::Greater => {
306 result.push(centroids[j].clone());
307 j += 1;
308 }
309 Ordering::Equal => {
310 result.push(centroids[i].clone());
311 i += 1;
312 }
313 }
314 }
315
316 while i < middle {
317 result.push(centroids[i].clone());
318 i += 1;
319 }
320
321 while j < last {
322 result.push(centroids[j].clone());
323 j += 1;
324 }
325
326 i = first;
327 for centroid in result.into_iter() {
328 centroids[i] = centroid;
329 i += 1;
330 }
331 }
332
333 pub fn merge_digests<'a>(digests: impl IntoIterator<Item = &'a TDigest>) -> TDigest {
335 let digests = digests.into_iter().collect::<Vec<_>>();
336 let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum();
337 if n_centroids == 0 {
338 return TDigest::default();
339 }
340
341 let max_size = digests.first().unwrap().max_size;
342 let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
343 let mut starts: Vec<usize> = Vec::with_capacity(digests.len());
344
345 let mut count = 0.0;
346 let mut min = f64::INFINITY;
347 let mut max = f64::NEG_INFINITY;
348
349 let mut start: usize = 0;
350 for digest in digests.iter() {
351 starts.push(start);
352
353 let curr_count = digest.count();
354 if curr_count > 0.0 {
355 min = min.min(digest.min);
356 max = max.max(digest.max);
357 count += curr_count;
358 for centroid in &digest.centroids {
359 centroids.push(centroid.clone());
360 start += 1;
361 }
362 }
363 }
364
365 if centroids.is_empty() {
367 return TDigest::default();
368 }
369
370 let mut digests_per_block: usize = 1;
371 while digests_per_block < starts.len() {
372 for i in (0..starts.len()).step_by(digests_per_block * 2) {
373 if i + digests_per_block < starts.len() {
374 let first = starts[i];
375 let middle = starts[i + digests_per_block];
376 let last = if i + 2 * digests_per_block < starts.len() {
377 starts[i + 2 * digests_per_block]
378 } else {
379 centroids.len()
380 };
381
382 debug_assert!(first <= middle && middle <= last);
383 Self::external_merge(&mut centroids, first, middle, last);
384 }
385 }
386
387 digests_per_block *= 2;
388 }
389
390 let mut result = TDigest::new(max_size);
391 let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);
392
393 let mut k_limit = 1;
394 let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count;
395
396 let mut iter_centroids = centroids.iter_mut();
397 let mut curr = iter_centroids.next().unwrap();
398 let mut weight_so_far = curr.weight();
399 let mut sums_to_merge = 0_f64;
400 let mut weights_to_merge = 0_f64;
401
402 for centroid in iter_centroids {
403 weight_so_far += centroid.weight();
404
405 if weight_so_far <= q_limit_times_count {
406 sums_to_merge += centroid.mean() * centroid.weight();
407 weights_to_merge += centroid.weight();
408 } else {
409 result.sum += curr.add(sums_to_merge, weights_to_merge);
410 sums_to_merge = 0_f64;
411 weights_to_merge = 0_f64;
412 compressed.push(curr.clone());
413 q_limit_times_count = Self::k_to_q(k_limit, max_size) * count;
414 k_limit += 1;
415 curr = centroid;
416 }
417 }
418
419 result.sum += curr.add(sums_to_merge, weights_to_merge);
420 compressed.push(curr.clone());
421 compressed.shrink_to_fit();
422 compressed.sort_by(|a, b| a.cmp_mean(b));
423
424 result.count = count;
425 result.min = min;
426 result.max = max;
427 result.centroids = compressed;
428 result
429 }
430
    /// Estimates the value at quantile `q` by locating the centroid whose
    /// cumulative-weight interval contains the target rank and linearly
    /// interpolating against its neighbours.
    ///
    /// Returns 0.0 for an empty digest, and the exact recorded min/max for
    /// `q <= 0.0` / `q >= 1.0`.
    pub fn estimate_quantile(&self, q: f64) -> f64 {
        if self.centroids.is_empty() {
            return 0.0;
        }

        // Target cumulative weight ("rank") for the requested quantile.
        let rank = q * self.count;

        let mut pos: usize;
        let mut t;
        if q > 0.5 {
            if q >= 1.0 {
                return self.max();
            }

            // High quantile: scan from the top for the centroid whose
            // cumulative-weight interval contains `rank`. `t` ends up as the
            // cumulative weight below centroid `pos`.
            pos = 0;
            t = self.count;

            for (k, centroid) in self.centroids.iter().enumerate().rev() {
                t -= centroid.weight();

                if rank >= t {
                    pos = k;
                    break;
                }
            }
        } else {
            if q <= 0.0 {
                return self.min();
            }

            // Low quantile: scan from the bottom for the containing centroid.
            pos = self.centroids.len() - 1;
            t = 0_f64;

            for (k, centroid) in self.centroids.iter().enumerate() {
                if rank < t + centroid.weight() {
                    pos = k;
                    break;
                }

                t += centroid.weight();
            }
        }

        // Interpolation window around the chosen centroid: `delta` is the
        // spacing to its neighbour(s); `min`/`max` bound the final result.
        let mut delta = 0_f64;
        let mut min = self.min;
        let mut max = self.max;

        if self.centroids.len() > 1 {
            if pos == 0 {
                delta = self.centroids[pos + 1].mean() - self.centroids[pos].mean();
                max = self.centroids[pos + 1].mean();
            } else if pos == (self.centroids.len() - 1) {
                delta = self.centroids[pos].mean() - self.centroids[pos - 1].mean();
                min = self.centroids[pos - 1].mean();
            } else {
                delta = (self.centroids[pos + 1].mean() - self.centroids[pos - 1].mean())
                    / 2.0;
                min = self.centroids[pos - 1].mean();
                max = self.centroids[pos + 1].mean();
            }
        }

        // Linear interpolation within centroid `pos`'s weight interval.
        let value = self.centroids[pos].mean()
            + ((rank - t) / self.centroids[pos].weight() - 0.5) * delta;

        // If a bound carries the wrong-signed infinity, widen it so the
        // clamp below cannot invert the interval.
        if !min.is_finite() && min.is_sign_positive() {
            min = f64::NEG_INFINITY;
        }

        if !max.is_finite() && max.is_sign_negative() {
            max = f64::INFINITY;
        }

        Self::clamp(value, min, max)
    }
511
512 pub fn to_scalar_state(&self) -> Vec<ScalarValue> {
547 let centroids: Vec<ScalarValue> = self
549 .centroids
550 .iter()
551 .flat_map(|c| [c.mean(), c.weight()])
552 .map(|v| ScalarValue::Float64(Some(v)))
553 .collect();
554
555 let arr = ScalarValue::new_list_nullable(¢roids, &DataType::Float64);
556
557 vec![
558 ScalarValue::UInt64(Some(self.max_size as u64)),
559 ScalarValue::Float64(Some(self.sum)),
560 ScalarValue::Float64(Some(self.count)),
561 ScalarValue::Float64(Some(self.max)),
562 ScalarValue::Float64(Some(self.min)),
563 ScalarValue::List(arr),
564 ]
565 }
566
    /// Reconstructs a digest from the scalar state produced by
    /// `to_scalar_state`.
    ///
    /// # Panics
    ///
    /// Panics when `state` does not contain exactly 6 elements, when an
    /// element has an unexpected type, or (for finite extremes) when
    /// `max < min`.
    pub fn from_scalar_state(state: &[ScalarValue]) -> Self {
        assert_eq!(state.len(), 6, "invalid TDigest state");

        let max_size = match &state[0] {
            ScalarValue::UInt64(Some(v)) => *v as usize,
            v => panic!("invalid max_size type {v:?}"),
        };

        // Centroids are stored as a flat f64 list of interleaved
        // (mean, weight) pairs; rebuild them pairwise.
        let centroids: Vec<_> = match &state[5] {
            ScalarValue::List(arr) => {
                let array = arr.values();

                let f64arr =
                    as_primitive_array::<Float64Type>(array).expect("expected f64 array");
                f64arr
                    .values()
                    .chunks(2)
                    .map(|v| Centroid::new(v[0], v[1]))
                    .collect()
            }
            v => panic!("invalid centroids type {v:?}"),
        };

        let max = cast_scalar_f64!(&state[3]);
        let min = cast_scalar_f64!(&state[4]);

        // NAN extremes (empty digest) skip the ordering sanity check.
        if min.is_finite() && max.is_finite() {
            assert!(max.total_cmp(&min).is_ge());
        }

        Self {
            max_size,
            sum: cast_scalar_f64!(state[1]),
            count: cast_scalar_f64!(state[2]),
            max,
            min,
            centroids,
        }
    }
614}
615
/// Debug-build helper: true when `values` is non-decreasing under
/// `f64::total_cmp` (slices shorter than two elements are trivially sorted).
#[cfg(debug_assertions)]
fn is_sorted(values: &[f64]) -> bool {
    values
        .iter()
        .zip(values.iter().skip(1))
        .all(|(a, b)| a.total_cmp(b) != Ordering::Greater)
}
620
#[cfg(test)]
mod tests {
    use super::*;

    // Asserts that `estimate_quantile($quantile)` is within a relative error
    // (default 1%) of `$want`.
    macro_rules! assert_error_bounds {
        ($t:ident, quantile = $quantile:literal, want = $want:literal) => {
            assert_error_bounds!(
                $t,
                quantile = $quantile,
                want = $want,
                allowable_error = 0.01
            )
        };
        ($t:ident, quantile = $quantile:literal, want = $want:literal, allowable_error = $re:literal) => {
            let ans = $t.estimate_quantile($quantile);
            let expected: f64 = $want;
            let percentage: f64 = (expected - ans).abs() / expected;
            assert!(
                percentage < $re,
                "relative error {} is more than {}% (got quantile {}, want {})",
                percentage,
                $re,
                ans,
                expected
            );
        };
    }

    // Asserts that a digest survives a to/from scalar-state round trip
    // unchanged.
    macro_rules! assert_state_roundtrip {
        ($t:ident) => {
            let state = $t.to_scalar_state();
            let other = TDigest::from_scalar_state(&state);
            assert_eq!($t, other);
        };
    }

    #[test]
    fn test_int64_uniform() {
        let values = (1i64..=1000).map(|v| v as f64).collect();

        let t = TDigest::new(100);
        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 0.1, want = 100.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
        assert_error_bounds!(t, quantile = 0.9, want = 900.0);
        assert_state_roundtrip!(t);
    }

    // Merging one value at a time exercises Centroid::add repeatedly.
    #[test]
    fn test_centroid_addition_regression() {
        let vals = vec![1.0, 1.0, 1.0, 2.0, 1.0, 1.0];
        let mut t = TDigest::new(10);

        for v in vals {
            t = t.merge_unsorted_f64(vec![v]);
        }

        assert_error_bounds!(t, quantile = 0.5, want = 1.0);
        assert_error_bounds!(t, quantile = 0.95, want = 2.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_unsorted_against_uniform_distro() {
        let t = TDigest::new(100);
        let values: Vec<_> = (1..=1_000_000).map(f64::from).collect();

        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 1.0, want = 1_000_000.0);
        assert_error_bounds!(t, quantile = 0.99, want = 990_000.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_unsorted_against_skewed_distro() {
        let t = TDigest::new(100);
        let mut values: Vec<_> = (1..=600_000).map(f64::from).collect();
        // 40% of the input is a single repeated value at the top end.
        values.resize(1_000_000, 1_000_000_f64);

        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 0.99, want = 1_000_000.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_digests() {
        let mut digests: Vec<TDigest> = Vec::new();

        for _ in 1..=100 {
            let t = TDigest::new(100);
            let values: Vec<_> = (1..=1_000).map(f64::from).collect();
            let t = t.merge_unsorted_f64(values);
            digests.push(t)
        }

        let t = TDigest::merge_digests(&digests);

        assert_error_bounds!(t, quantile = 1.0, want = 1000.0);
        assert_error_bounds!(t, quantile = 0.99, want = 990.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10.0, allowable_error = 0.2);
        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
        assert_state_roundtrip!(t);
    }

    // 96 = 64 bytes for the struct + 2 centroids * 16 bytes after
    // shrink_to_fit.
    #[test]
    fn test_size() {
        let t = TDigest::new(10);
        let t = t.merge_unsorted_f64(vec![0.0, 1.0]);

        assert_eq!(t.size(), 96);
    }

    // Regression: many identical values must not drift due to floating-point
    // accumulation in the centroid means.
    #[test]
    fn test_identical_values_floating_point_precision() {
        let t = TDigest::new(100);
        let values: Vec<_> = (0..215).map(|_| 15.699999988079073_f64).collect();

        let t = t.merge_unsorted_f64(values);

        let result = t.estimate_quantile(0.99);
        assert!((result - 15.699999988079073).abs() < 1e-10);
    }
}
763}