// datafusion_functions_aggregate_common/tdigest.rs
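//! A t-digest sketch for approximate quantile estimation over streams of
//! `f64` values, following Dunning and Ertl's t-digest algorithm
//! (<https://arxiv.org/abs/1902.04023>).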
use arrow::datatypes::DataType;
use arrow::datatypes::Float64Type;
use datafusion_common::ScalarValue;
use datafusion_common::cast::as_primitive_array;
use std::cmp::Ordering;
use std::mem::{size_of, size_of_val};

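/// Default maximum number of centroids retained by a [`TDigest`].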
pub const DEFAULT_MAX_SIZE: usize = 100;

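// Helper macros that unwrap a non-null `Float64` / `UInt64` `ScalarValue`,
// panicking if the value has any other type.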
macro_rules! cast_scalar_f64 {
    ($value:expr) => {
        match &$value {
            ScalarValue::Float64(Some(v)) => *v,
            v => panic!("invalid type {}", v),
        }
    };
}

macro_rules! cast_scalar_u64 {
    ($value:expr) => {
        match &$value {
            ScalarValue::UInt64(Some(v)) => *v,
            v => panic!("invalid type {}", v),
        }
    };
}

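/// A weighted point in a [`TDigest`]: the mean of the input values assigned
/// to this centroid together with their total weight (count).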
#[derive(Debug, PartialEq, Clone)]
pub struct Centroid {
    mean: f64,
    weight: f64,
}

impl Centroid {
    pub fn new(mean: f64, weight: f64) -> Self {
        Centroid { mean, weight }
    }

    #[inline]
    pub fn mean(&self) -> f64 {
        self.mean
    }

    #[inline]
    pub fn weight(&self) -> f64 {
        self.weight
    }

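    /// Folds an already-weighted sum (`sum`, `weight`) into this centroid,
    /// updating its mean and weight, and returns the combined weighted sum
    /// (including this centroid's previous contribution).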
    pub fn add(&mut self, sum: f64, weight: f64) -> f64 {
        let new_sum = sum + self.weight * self.mean;
        let new_weight = self.weight + weight;
        self.weight = new_weight;
        self.mean = new_sum / new_weight;
        new_sum
    }

    pub fn cmp_mean(&self, other: &Self) -> Ordering {
        self.mean.total_cmp(&other.mean)
    }
}

impl Default for Centroid {
    fn default() -> Self {
        Centroid {
            mean: 0_f64,
            weight: 1_f64,
        }
    }
}

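/// A t-digest sketch: a compressed set of centroids, kept ordered by mean,
/// from which arbitrary quantiles of the inserted values can be estimated.
///
/// `max_size` bounds the number of centroids retained after compression;
/// larger values trade memory for accuracy.
///
/// Illustrative usage sketch (marked `ignore`; see the tests below for
/// executable examples):
///
/// ```ignore
/// let digest = TDigest::new(100).merge_unsorted_f64(vec![1.0, 2.0, 3.0, 4.0]);
/// let median = digest.estimate_quantile(0.5);
/// ```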
#[derive(Debug, PartialEq, Clone)]
pub struct TDigest {
    centroids: Vec<Centroid>,
    max_size: usize,
    sum: f64,
    count: u64,
    max: f64,
    min: f64,
}

impl TDigest {
    pub fn new(max_size: usize) -> Self {
        TDigest {
            centroids: Vec::new(),
            max_size,
            sum: 0_f64,
            count: 0,
            max: f64::NAN,
            min: f64::NAN,
        }
    }

    #[expect(clippy::needless_pass_by_value)]
    pub fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self {
        TDigest {
            centroids: vec![centroid.clone()],
            max_size,
            sum: centroid.mean * centroid.weight,
            count: 1,
            max: centroid.mean,
            min: centroid.mean,
        }
    }

    #[inline]
    pub fn count(&self) -> u64 {
        self.count
    }

    #[inline]
    pub fn max(&self) -> f64 {
        self.max
    }

    #[inline]
    pub fn min(&self) -> f64 {
        self.min
    }

    #[inline]
    pub fn max_size(&self) -> usize {
        self.max_size
    }

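    /// Estimated size of this digest in bytes, including the allocated
    /// (not just occupied) capacity of the centroid buffer.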
    pub fn size(&self) -> usize {
        size_of_val(self) + (size_of::<Centroid>() * self.centroids.capacity())
    }
}

impl Default for TDigest {
    fn default() -> Self {
        TDigest {
            centroids: Vec::new(),
            max_size: 100,
            sum: 0_f64,
            count: 0,
            max: f64::NAN,
            min: f64::NAN,
        }
    }
}

impl TDigest {
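    /// Maps a centroid index `k` (out of `d`) to a quantile limit `q`; the
    /// piecewise-quadratic form allocates more resolution to the tails of the
    /// distribution than to the middle.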
    fn k_to_q(k: u64, d: usize) -> f64 {
        let k_div_d = k as f64 / d as f64;
        if k_div_d >= 0.5 {
            let base = 1.0 - k_div_d;
            1.0 - 2.0 * base * base
        } else {
            2.0 * k_div_d * k_div_d
        }
    }

    fn clamp(v: f64, lo: f64, hi: f64) -> f64 {
        if lo.is_nan() || hi.is_nan() {
            return v;
        }

        let (min, max) = if lo > hi { (hi, lo) } else { (lo, hi) };

        v.clamp(min, max)
    }

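    /// Sorts `unsorted_values` and merges them into a copy of this digest,
    /// returning the new digest.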
    pub fn merge_unsorted_f64(&self, unsorted_values: Vec<f64>) -> TDigest {
        let mut values = unsorted_values;
        values.sort_by(|a, b| a.total_cmp(b));
        self.merge_sorted_f64(&values)
    }

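    /// Merges the already-sorted `sorted_values` into a copy of this digest,
    /// compressing the combined centroids back down towards `max_size`.
    ///
    /// The values must be sorted ascending; this is only checked in debug builds.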
    pub fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest {
        #[cfg(debug_assertions)]
        debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest");

        if sorted_values.is_empty() {
            return self.clone();
        }

        let mut result = TDigest::new(self.max_size());
        result.count = self.count() + sorted_values.len() as u64;

        let maybe_min = *sorted_values.first().unwrap();
        let maybe_max = *sorted_values.last().unwrap();

        if self.count() > 0 {
            result.min = self.min.min(maybe_min);
            result.max = self.max.max(maybe_max);
        } else {
            result.min = maybe_min;
            result.max = maybe_max;
        }

        let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);

        let mut k_limit: u64 = 1;
        let mut q_limit_times_count =
            Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
        k_limit += 1;

        let mut iter_centroids = self.centroids.iter().peekable();
        let mut iter_sorted_values = sorted_values.iter().peekable();

        let mut curr: Centroid = if let Some(c) = iter_centroids.peek() {
            let curr = **iter_sorted_values.peek().unwrap();
            if c.mean() < curr {
                iter_centroids.next().unwrap().clone()
            } else {
                Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
            }
        } else {
            Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
        };

        let mut weight_so_far = curr.weight();

        let mut sums_to_merge = 0_f64;
        let mut weights_to_merge = 0_f64;

        while iter_centroids.peek().is_some() || iter_sorted_values.peek().is_some() {
            let next: Centroid = if let Some(c) = iter_centroids.peek() {
                if iter_sorted_values.peek().is_none()
                    || c.mean() < **iter_sorted_values.peek().unwrap()
                {
                    iter_centroids.next().unwrap().clone()
                } else {
                    Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
                }
            } else {
                Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
            };

            let next_sum = next.mean() * next.weight();
            weight_so_far += next.weight();

            if weight_so_far <= q_limit_times_count {
                sums_to_merge += next_sum;
                weights_to_merge += next.weight();
            } else {
                result.sum += curr.add(sums_to_merge, weights_to_merge);
                sums_to_merge = 0_f64;
                weights_to_merge = 0_f64;

                compressed.push(curr.clone());
                q_limit_times_count =
                    Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
                k_limit += 1;
                curr = next;
            }
        }

        result.sum += curr.add(sums_to_merge, weights_to_merge);
        compressed.push(curr);
        compressed.shrink_to_fit();
        compressed.sort_by(|a, b| a.cmp_mean(b));

        result.centroids = compressed;
        result
    }

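    /// Merges the two adjacent, individually sorted runs
    /// `centroids[first..middle]` and `centroids[middle..last]` in place
    /// (ordered by mean), using a temporary buffer.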
    fn external_merge(
        centroids: &mut [Centroid],
        first: usize,
        middle: usize,
        last: usize,
    ) {
        let mut result: Vec<Centroid> = Vec::with_capacity(centroids.len());

        let mut i = first;
        let mut j = middle;

        while i < middle && j < last {
            match centroids[i].cmp_mean(&centroids[j]) {
                Ordering::Less => {
                    result.push(centroids[i].clone());
                    i += 1;
                }
                Ordering::Greater => {
                    result.push(centroids[j].clone());
                    j += 1;
                }
                Ordering::Equal => {
                    result.push(centroids[i].clone());
                    i += 1;
                }
            }
        }

        while i < middle {
            result.push(centroids[i].clone());
            i += 1;
        }

        while j < last {
            result.push(centroids[j].clone());
            j += 1;
        }

        i = first;
        for centroid in result.into_iter() {
            centroids[i] = centroid;
            i += 1;
        }
    }

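    /// Merges any number of digests into a new digest.
    ///
    /// The centroids of all inputs are concatenated, merge-sorted by mean
    /// block-by-block, and then compressed back down towards the `max_size`
    /// of the first input digest.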
    pub fn merge_digests<'a>(digests: impl IntoIterator<Item = &'a TDigest>) -> TDigest {
        let digests = digests.into_iter().collect::<Vec<_>>();
        let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum();
        if n_centroids == 0 {
            return TDigest::default();
        }

        let max_size = digests.first().unwrap().max_size;
        let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
        let mut starts: Vec<usize> = Vec::with_capacity(digests.len());

        let mut count = 0;
        let mut min = f64::INFINITY;
        let mut max = f64::NEG_INFINITY;

        let mut start: usize = 0;
        for digest in digests.iter() {
            starts.push(start);

            let curr_count = digest.count();
            if curr_count > 0 {
                min = min.min(digest.min);
                max = max.max(digest.max);
                count += curr_count;
                for centroid in &digest.centroids {
                    centroids.push(centroid.clone());
                    start += 1;
                }
            }
        }

        let mut digests_per_block: usize = 1;
        while digests_per_block < starts.len() {
            for i in (0..starts.len()).step_by(digests_per_block * 2) {
                if i + digests_per_block < starts.len() {
                    let first = starts[i];
                    let middle = starts[i + digests_per_block];
                    let last = if i + 2 * digests_per_block < starts.len() {
                        starts[i + 2 * digests_per_block]
                    } else {
                        centroids.len()
                    };

                    debug_assert!(first <= middle && middle <= last);
                    Self::external_merge(&mut centroids, first, middle, last);
                }
            }

            digests_per_block *= 2;
        }

        let mut result = TDigest::new(max_size);
        let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);

        let mut k_limit = 1;
        let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;

        let mut iter_centroids = centroids.iter_mut();
        let mut curr = iter_centroids.next().unwrap();
        let mut weight_so_far = curr.weight();
        let mut sums_to_merge = 0_f64;
        let mut weights_to_merge = 0_f64;

        for centroid in iter_centroids {
            weight_so_far += centroid.weight();

            if weight_so_far <= q_limit_times_count {
                sums_to_merge += centroid.mean() * centroid.weight();
                weights_to_merge += centroid.weight();
            } else {
                result.sum += curr.add(sums_to_merge, weights_to_merge);
                sums_to_merge = 0_f64;
                weights_to_merge = 0_f64;
                compressed.push(curr.clone());
                q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
                k_limit += 1;
                curr = centroid;
            }
        }

        result.sum += curr.add(sums_to_merge, weights_to_merge);
        compressed.push(curr.clone());
        compressed.shrink_to_fit();
        compressed.sort_by(|a, b| a.cmp_mean(b));

        result.count = count;
        result.min = min;
        result.max = max;
        result.centroids = compressed;
        result
    }

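    /// Estimates the value at quantile `q` (in `[0, 1]`) by locating the
    /// centroid containing rank `q * count` and interpolating linearly
    /// between its neighbours. Returns 0.0 for an empty digest.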
    pub fn estimate_quantile(&self, q: f64) -> f64 {
        if self.centroids.is_empty() {
            return 0.0;
        }

        let rank = q * self.count as f64;

        let mut pos: usize;
        let mut t;
        if q > 0.5 {
            if q >= 1.0 {
                return self.max();
            }

            pos = 0;
            t = self.count as f64;

            for (k, centroid) in self.centroids.iter().enumerate().rev() {
                t -= centroid.weight();

                if rank >= t {
                    pos = k;
                    break;
                }
            }
        } else {
            if q <= 0.0 {
                return self.min();
            }

            pos = self.centroids.len() - 1;
            t = 0_f64;

            for (k, centroid) in self.centroids.iter().enumerate() {
                if rank < t + centroid.weight() {
                    pos = k;
                    break;
                }

                t += centroid.weight();
            }
        }

        let mut delta = 0_f64;
        let mut min = self.min;
        let mut max = self.max;

        if self.centroids.len() > 1 {
            if pos == 0 {
                delta = self.centroids[pos + 1].mean() - self.centroids[pos].mean();
                max = self.centroids[pos + 1].mean();
            } else if pos == (self.centroids.len() - 1) {
                delta = self.centroids[pos].mean() - self.centroids[pos - 1].mean();
                min = self.centroids[pos - 1].mean();
            } else {
                delta = (self.centroids[pos + 1].mean() - self.centroids[pos - 1].mean())
                    / 2.0;
                min = self.centroids[pos - 1].mean();
                max = self.centroids[pos + 1].mean();
            }
        }

        let value = self.centroids[pos].mean()
            + ((rank - t) / self.centroids[pos].weight() - 0.5) * delta;

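        // Guard against `min`/`max` still holding non-finite sentinel values
        // (e.g. the +Inf / -Inf initial values in `merge_digests()`); widen
        // them so the clamp below does not clip the interpolated value.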
        if !min.is_finite() && min.is_sign_positive() {
            min = f64::NEG_INFINITY;
        }

        if !max.is_finite() && max.is_sign_negative() {
            max = f64::INFINITY;
        }

        Self::clamp(value, min, max)
    }

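    /// Serializes this digest into a flat state of six `ScalarValue`s:
    /// `[max_size, sum, count, max, min, centroids]`, where `centroids` is a
    /// `Float64` list of interleaved `(mean, weight)` pairs.
    ///
    /// The inverse of [`TDigest::from_scalar_state`].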
    pub fn to_scalar_state(&self) -> Vec<ScalarValue> {
        let centroids: Vec<ScalarValue> = self
            .centroids
            .iter()
            .flat_map(|c| [c.mean(), c.weight()])
            .map(|v| ScalarValue::Float64(Some(v)))
            .collect();

        let arr = ScalarValue::new_list_nullable(&centroids, &DataType::Float64);

        vec![
            ScalarValue::UInt64(Some(self.max_size as u64)),
            ScalarValue::Float64(Some(self.sum)),
            ScalarValue::UInt64(Some(self.count)),
            ScalarValue::Float64(Some(self.max)),
            ScalarValue::Float64(Some(self.min)),
            ScalarValue::List(arr),
        ]
    }

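    /// Rebuilds a digest from the state produced by [`TDigest::to_scalar_state`].
    ///
    /// Panics if the state does not have exactly six values of the expected
    /// types.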
    pub fn from_scalar_state(state: &[ScalarValue]) -> Self {
        assert_eq!(state.len(), 6, "invalid TDigest state");

        let max_size = match &state[0] {
            ScalarValue::UInt64(Some(v)) => *v as usize,
            v => panic!("invalid max_size type {v:?}"),
        };

        let centroids: Vec<_> = match &state[5] {
            ScalarValue::List(arr) => {
                let array = arr.values();

                let f64arr =
                    as_primitive_array::<Float64Type>(array).expect("expected f64 array");
                f64arr
                    .values()
                    .chunks(2)
                    .map(|v| Centroid::new(v[0], v[1]))
                    .collect()
            }
            v => panic!("invalid centroids type {v:?}"),
        };

        let max = cast_scalar_f64!(&state[3]);
        let min = cast_scalar_f64!(&state[4]);

        if min.is_finite() && max.is_finite() {
            assert!(max.total_cmp(&min).is_ge());
        }

        Self {
            max_size,
            sum: cast_scalar_f64!(state[1]),
            count: cast_scalar_u64!(&state[2]),
            max,
            min,
            centroids,
        }
    }
}

#[cfg(debug_assertions)]
fn is_sorted(values: &[f64]) -> bool {
    values.windows(2).all(|w| w[0].total_cmp(&w[1]).is_le())
}

#[cfg(test)]
mod tests {
    use super::*;

    macro_rules! assert_error_bounds {
        ($t:ident, quantile = $quantile:literal, want = $want:literal) => {
            assert_error_bounds!(
                $t,
                quantile = $quantile,
                want = $want,
                allowable_error = 0.01
            )
        };
        ($t:ident, quantile = $quantile:literal, want = $want:literal, allowable_error = $re:literal) => {
            let ans = $t.estimate_quantile($quantile);
            let expected: f64 = $want;
            let percentage: f64 = (expected - ans).abs() / expected;
            assert!(
                percentage < $re,
                "relative error {} is more than {}% (got quantile {}, want {})",
                percentage,
                $re,
                ans,
                expected
            );
        };
    }

    macro_rules! assert_state_roundtrip {
        ($t:ident) => {
            let state = $t.to_scalar_state();
            let other = TDigest::from_scalar_state(&state);
            assert_eq!($t, other);
        };
    }

    #[test]
    fn test_int64_uniform() {
        let values = (1i64..=1000).map(|v| v as f64).collect();

        let t = TDigest::new(100);
        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 0.1, want = 100.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
        assert_error_bounds!(t, quantile = 0.9, want = 900.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_centroid_addition_regression() {
        let vals = vec![1.0, 1.0, 1.0, 2.0, 1.0, 1.0];
        let mut t = TDigest::new(10);

        for v in vals {
            t = t.merge_unsorted_f64(vec![v]);
        }

        assert_error_bounds!(t, quantile = 0.5, want = 1.0);
        assert_error_bounds!(t, quantile = 0.95, want = 2.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_unsorted_against_uniform_distro() {
        let t = TDigest::new(100);
        let values: Vec<_> = (1..=1_000_000).map(f64::from).collect();

        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 1.0, want = 1_000_000.0);
        assert_error_bounds!(t, quantile = 0.99, want = 990_000.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_unsorted_against_skewed_distro() {
        let t = TDigest::new(100);
        let mut values: Vec<_> = (1..=600_000).map(f64::from).collect();
        values.resize(1_000_000, 1_000_000_f64);

        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 0.99, want = 1_000_000.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_digests() {
        let mut digests: Vec<TDigest> = Vec::new();

        for _ in 1..=100 {
            let t = TDigest::new(100);
            let values: Vec<_> = (1..=1_000).map(f64::from).collect();
            let t = t.merge_unsorted_f64(values);
            digests.push(t)
        }

        let t = TDigest::merge_digests(&digests);

        assert_error_bounds!(t, quantile = 1.0, want = 1000.0);
        assert_error_bounds!(t, quantile = 0.99, want = 990.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10.0, allowable_error = 0.2);
        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_size() {
        let t = TDigest::new(10);
        let t = t.merge_unsorted_f64(vec![0.0, 1.0]);

        assert_eq!(t.size(), 96);
    }

    #[test]
    fn test_identical_values_floating_point_precision() {
        let t = TDigest::new(100);
        let values: Vec<_> = (0..215).map(|_| 15.699999988079073_f64).collect();

        let t = t.merge_unsorted_f64(values);

        let result = t.estimate_quantile(0.99);
        assert!((result - 15.699999988079073).abs() < 1e-10);
    }
}
769}