// datafusion_functions_aggregate_common/tdigest.rs

use arrow::datatypes::DataType;
use arrow::datatypes::Float64Type;
use datafusion_common::cast::as_primitive_array;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use std::cmp::Ordering;
use std::mem::{size_of, size_of_val};

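/// Default maximum number of centroids retained by a [`TDigest`] when no
/// explicit size is configured.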
pub const DEFAULT_MAX_SIZE: usize = 100;

macro_rules! cast_scalar_f64 {
    ($value:expr ) => {
        match &$value {
            ScalarValue::Float64(Some(v)) => *v,
            v => panic!("invalid type {:?}", v),
        }
    };
}

macro_rules! cast_scalar_u64 {
    ($value:expr ) => {
        match &$value {
            ScalarValue::UInt64(Some(v)) => *v,
            v => panic!("invalid type {:?}", v),
        }
    };
}

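/// Fallible conversion of a native numeric value into the `f64` representation
/// stored in the digest; `Ok(None)` represents a null (missing) value.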
pub trait TryIntoF64 {
    fn try_as_f64(&self) -> Result<Option<f64>>;
}

macro_rules! impl_try_ordered_f64 {
    ($type:ty) => {
        impl TryIntoF64 for $type {
            fn try_as_f64(&self) -> Result<Option<f64>> {
                Ok(Some(*self as f64))
            }
        }
    };
}

impl_try_ordered_f64!(f64);
impl_try_ordered_f64!(f32);
impl_try_ordered_f64!(i64);
impl_try_ordered_f64!(i32);
impl_try_ordered_f64!(i16);
impl_try_ordered_f64!(i8);
impl_try_ordered_f64!(u64);
impl_try_ordered_f64!(u32);
impl_try_ordered_f64!(u16);
impl_try_ordered_f64!(u8);

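/// A single centroid in the digest: the weighted mean of a cluster of input
/// values. Centroids are ordered by their mean.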
#[derive(Debug, PartialEq, Clone)]
pub struct Centroid {
    mean: f64,
    weight: f64,
}

impl PartialOrd for Centroid {
    fn partial_cmp(&self, other: &Centroid) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Eq for Centroid {}

impl Ord for Centroid {
    fn cmp(&self, other: &Centroid) -> Ordering {
        self.mean.total_cmp(&other.mean)
    }
}

impl Centroid {
    pub fn new(mean: f64, weight: f64) -> Self {
        Centroid { mean, weight }
    }

    #[inline]
    pub fn mean(&self) -> f64 {
        self.mean
    }

    #[inline]
    pub fn weight(&self) -> f64 {
        self.weight
    }

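    /// Folds an already-accumulated `sum` and `weight` into this centroid,
    /// updating its weighted mean, and returns the new combined sum
    /// (`sum + old_weight * old_mean`).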
    pub fn add(&mut self, sum: f64, weight: f64) -> f64 {
        let new_sum = sum + self.weight * self.mean;
        let new_weight = self.weight + weight;
        self.weight = new_weight;
        self.mean = new_sum / new_weight;
        new_sum
    }
}

impl Default for Centroid {
    fn default() -> Self {
        Centroid {
            mean: 0_f64,
            weight: 1_f64,
        }
    }
}

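/// A t-digest data sketch for computing approximate quantiles over a stream of
/// `f64` values.
///
/// The digest keeps at most `max_size` centroids; merging new values or other
/// digests compresses neighbouring centroids so that memory stays bounded while
/// quantile estimates near the tails remain accurate.
///
/// Illustrative usage (a sketch against the API in this file; the crate path in
/// the `use` line is an assumption, hence the `ignore` fence):
///
/// ```ignore
/// use datafusion_functions_aggregate_common::tdigest::TDigest;
///
/// // Fold the values 1..=1000 into a digest with the default compression.
/// let t = TDigest::new(100);
/// let t = t.merge_unsorted_f64((1..=1_000).map(f64::from).collect());
///
/// // For this uniform input the estimated median should land near 500.
/// assert!((t.estimate_quantile(0.5) - 500.0).abs() < 10.0);
/// ```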
#[derive(Debug, PartialEq, Clone)]
pub struct TDigest {
    centroids: Vec<Centroid>,
    max_size: usize,
    sum: f64,
    count: u64,
    max: f64,
    min: f64,
}

impl TDigest {
    pub fn new(max_size: usize) -> Self {
        TDigest {
            centroids: Vec::new(),
            max_size,
            sum: 0_f64,
            count: 0,
            max: f64::NAN,
            min: f64::NAN,
        }
    }

    pub fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self {
        TDigest {
            centroids: vec![centroid.clone()],
            max_size,
            sum: centroid.mean * centroid.weight,
            count: 1,
            max: centroid.mean,
            min: centroid.mean,
        }
    }

    #[inline]
    pub fn count(&self) -> u64 {
        self.count
    }

    #[inline]
    pub fn max(&self) -> f64 {
        self.max
    }

    #[inline]
    pub fn min(&self) -> f64 {
        self.min
    }

    #[inline]
    pub fn max_size(&self) -> usize {
        self.max_size
    }

    pub fn size(&self) -> usize {
        size_of_val(self) + (size_of::<Centroid>() * self.centroids.capacity())
    }
}

impl Default for TDigest {
    fn default() -> Self {
        TDigest {
            centroids: Vec::new(),
            max_size: 100,
            sum: 0_f64,
            count: 0,
            max: f64::NAN,
            min: f64::NAN,
        }
    }
}

impl TDigest {
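    /// Maps the centroid index ratio `k / d` to a quantile limit in `[0, 1]`
    /// using a piecewise quadratic, which concentrates centroid resolution at
    /// both tails of the distribution.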
    fn k_to_q(k: u64, d: usize) -> f64 {
        let k_div_d = k as f64 / d as f64;
        if k_div_d >= 0.5 {
            let base = 1.0 - k_div_d;
            1.0 - 2.0 * base * base
        } else {
            2.0 * k_div_d * k_div_d
        }
    }

    fn clamp(v: f64, lo: f64, hi: f64) -> f64 {
        if lo.is_nan() || hi.is_nan() {
            return v;
        }
        v.clamp(lo, hi)
    }

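    /// Sorts `unsorted_values` using `f64::total_cmp` and merges them into a
    /// new digest; `self` is not modified.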
    pub fn merge_unsorted_f64(&self, unsorted_values: Vec<f64>) -> TDigest {
        let mut values = unsorted_values;
        values.sort_by(|a, b| a.total_cmp(b));
        self.merge_sorted_f64(&values)
    }

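    /// Merges a slice of already-sorted values into a new digest, compressing
    /// the combined centroids back down to at most `max_size`. Debug builds
    /// assert that the input really is sorted.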
    pub fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest {
        #[cfg(debug_assertions)]
        debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest");

        if sorted_values.is_empty() {
            return self.clone();
        }

        let mut result = TDigest::new(self.max_size());
        result.count = self.count() + sorted_values.len() as u64;

        let maybe_min = *sorted_values.first().unwrap();
        let maybe_max = *sorted_values.last().unwrap();

        if self.count() > 0 {
            result.min = self.min.min(maybe_min);
            result.max = self.max.max(maybe_max);
        } else {
            result.min = maybe_min;
            result.max = maybe_max;
        }

        let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);

        let mut k_limit: u64 = 1;
        let mut q_limit_times_count =
            Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
        k_limit += 1;

        let mut iter_centroids = self.centroids.iter().peekable();
        let mut iter_sorted_values = sorted_values.iter().peekable();

        let mut curr: Centroid = if let Some(c) = iter_centroids.peek() {
            let curr = **iter_sorted_values.peek().unwrap();
            if c.mean() < curr {
                iter_centroids.next().unwrap().clone()
            } else {
                Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
            }
        } else {
            Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
        };

        let mut weight_so_far = curr.weight();

        let mut sums_to_merge = 0_f64;
        let mut weights_to_merge = 0_f64;

        while iter_centroids.peek().is_some() || iter_sorted_values.peek().is_some() {
            let next: Centroid = if let Some(c) = iter_centroids.peek() {
                if iter_sorted_values.peek().is_none()
                    || c.mean() < **iter_sorted_values.peek().unwrap()
                {
                    iter_centroids.next().unwrap().clone()
                } else {
                    Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
                }
            } else {
                Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
            };

            let next_sum = next.mean() * next.weight();
            weight_so_far += next.weight();

            if weight_so_far <= q_limit_times_count {
                sums_to_merge += next_sum;
                weights_to_merge += next.weight();
            } else {
                result.sum += curr.add(sums_to_merge, weights_to_merge);
                sums_to_merge = 0_f64;
                weights_to_merge = 0_f64;

                compressed.push(curr.clone());
                q_limit_times_count =
                    Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
                k_limit += 1;
                curr = next;
            }
        }

        result.sum += curr.add(sums_to_merge, weights_to_merge);
        compressed.push(curr);
        compressed.shrink_to_fit();
        compressed.sort();

        result.centroids = compressed;
        result
    }

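    /// Merges the two sorted runs `centroids[first..middle]` and
    /// `centroids[middle..last]` in place (a classic merge step, stable with
    /// respect to equal means).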
    fn external_merge(
        centroids: &mut [Centroid],
        first: usize,
        middle: usize,
        last: usize,
    ) {
        let mut result: Vec<Centroid> = Vec::with_capacity(centroids.len());

        let mut i = first;
        let mut j = middle;

        while i < middle && j < last {
            match centroids[i].cmp(&centroids[j]) {
                Ordering::Less => {
                    result.push(centroids[i].clone());
                    i += 1;
                }
                Ordering::Greater => {
                    result.push(centroids[j].clone());
                    j += 1;
                }
                Ordering::Equal => {
                    result.push(centroids[i].clone());
                    i += 1;
                }
            }
        }

        while i < middle {
            result.push(centroids[i].clone());
            i += 1;
        }

        while j < last {
            result.push(centroids[j].clone());
            j += 1;
        }

        i = first;
        for centroid in result.into_iter() {
            centroids[i] = centroid;
            i += 1;
        }
    }

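    /// Merges a collection of digests into a single digest: the centroids of
    /// all inputs are combined with a bottom-up merge sort and then compressed
    /// back down to the `max_size` of the first digest.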
    pub fn merge_digests<'a>(digests: impl IntoIterator<Item = &'a TDigest>) -> TDigest {
        let digests = digests.into_iter().collect::<Vec<_>>();
        let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum();
        if n_centroids == 0 {
            return TDigest::default();
        }

        let max_size = digests.first().unwrap().max_size;
        let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
        let mut starts: Vec<usize> = Vec::with_capacity(digests.len());

        let mut count = 0;
        let mut min = f64::INFINITY;
        let mut max = f64::NEG_INFINITY;

        let mut start: usize = 0;
        for digest in digests.iter() {
            starts.push(start);

            let curr_count = digest.count();
            if curr_count > 0 {
                min = min.min(digest.min);
                max = max.max(digest.max);
                count += curr_count;
                for centroid in &digest.centroids {
                    centroids.push(centroid.clone());
                    start += 1;
                }
            }
        }

        let mut digests_per_block: usize = 1;
        while digests_per_block < starts.len() {
            for i in (0..starts.len()).step_by(digests_per_block * 2) {
                if i + digests_per_block < starts.len() {
                    let first = starts[i];
                    let middle = starts[i + digests_per_block];
                    let last = if i + 2 * digests_per_block < starts.len() {
                        starts[i + 2 * digests_per_block]
                    } else {
                        centroids.len()
                    };

                    debug_assert!(first <= middle && middle <= last);
                    Self::external_merge(&mut centroids, first, middle, last);
                }
            }

            digests_per_block *= 2;
        }

        let mut result = TDigest::new(max_size);
        let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);

        let mut k_limit = 1;
        let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;

        let mut iter_centroids = centroids.iter_mut();
        let mut curr = iter_centroids.next().unwrap();
        let mut weight_so_far = curr.weight();
        let mut sums_to_merge = 0_f64;
        let mut weights_to_merge = 0_f64;

        for centroid in iter_centroids {
            weight_so_far += centroid.weight();

            if weight_so_far <= q_limit_times_count {
                sums_to_merge += centroid.mean() * centroid.weight();
                weights_to_merge += centroid.weight();
            } else {
                result.sum += curr.add(sums_to_merge, weights_to_merge);
                sums_to_merge = 0_f64;
                weights_to_merge = 0_f64;
                compressed.push(curr.clone());
                q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
                k_limit += 1;
                curr = centroid;
            }
        }

        result.sum += curr.add(sums_to_merge, weights_to_merge);
        compressed.push(curr.clone());
        compressed.shrink_to_fit();
        compressed.sort();

        result.count = count;
        result.min = min;
        result.max = max;
        result.centroids = compressed;
        result
    }

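    /// Estimates the value at quantile `q` (`0.0..=1.0`) by locating the
    /// centroid that covers the requested rank and interpolating linearly
    /// between its neighbours; `q <= 0.0` and `q >= 1.0` return the tracked
    /// `min` and `max` respectively. Returns `0.0` for an empty digest.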
    pub fn estimate_quantile(&self, q: f64) -> f64 {
        if self.centroids.is_empty() {
            return 0.0;
        }

        let rank = q * self.count as f64;

        let mut pos: usize;
        let mut t;
        if q > 0.5 {
            if q >= 1.0 {
                return self.max();
            }

            pos = 0;
            t = self.count as f64;

            for (k, centroid) in self.centroids.iter().enumerate().rev() {
                t -= centroid.weight();

                if rank >= t {
                    pos = k;
                    break;
                }
            }
        } else {
            if q <= 0.0 {
                return self.min();
            }

            pos = self.centroids.len() - 1;
            t = 0_f64;

            for (k, centroid) in self.centroids.iter().enumerate() {
                if rank < t + centroid.weight() {
                    pos = k;
                    break;
                }

                t += centroid.weight();
            }
        }

        let mut delta = 0_f64;
        let mut min = self.min;
        let mut max = self.max;

        if self.centroids.len() > 1 {
            if pos == 0 {
                delta = self.centroids[pos + 1].mean() - self.centroids[pos].mean();
                max = self.centroids[pos + 1].mean();
            } else if pos == (self.centroids.len() - 1) {
                delta = self.centroids[pos].mean() - self.centroids[pos - 1].mean();
                min = self.centroids[pos - 1].mean();
            } else {
                delta = (self.centroids[pos + 1].mean() - self.centroids[pos - 1].mean())
                    / 2.0;
                min = self.centroids[pos - 1].mean();
                max = self.centroids[pos + 1].mean();
            }
        }

        let value = self.centroids[pos].mean()
            + ((rank - t) / self.centroids[pos].weight() - 0.5) * delta;

        if !min.is_finite() && min.is_sign_positive() {
            min = f64::NEG_INFINITY;
        }

        if !max.is_finite() && max.is_sign_negative() {
            max = f64::INFINITY;
        }

        Self::clamp(value, min, max)
    }

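    /// Serialises this digest into the flat `ScalarValue` state used by the
    /// aggregate framework: `[max_size, sum, count, max, min, centroids]`,
    /// where `centroids` is a list of interleaved `(mean, weight)` pairs.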
    pub fn to_scalar_state(&self) -> Vec<ScalarValue> {
        let centroids: Vec<ScalarValue> = self
            .centroids
            .iter()
            .flat_map(|c| [c.mean(), c.weight()])
            .map(|v| ScalarValue::Float64(Some(v)))
            .collect();

        let arr = ScalarValue::new_list_nullable(&centroids, &DataType::Float64);

        vec![
            ScalarValue::UInt64(Some(self.max_size as u64)),
            ScalarValue::Float64(Some(self.sum)),
            ScalarValue::UInt64(Some(self.count)),
            ScalarValue::Float64(Some(self.max)),
            ScalarValue::Float64(Some(self.min)),
            ScalarValue::List(arr),
        ]
    }

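    /// Reconstructs a digest from the state produced by
    /// [`Self::to_scalar_state`]. Panics if the state does not have the
    /// expected length or types.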
    pub fn from_scalar_state(state: &[ScalarValue]) -> Self {
        assert_eq!(state.len(), 6, "invalid TDigest state");

        let max_size = match &state[0] {
            ScalarValue::UInt64(Some(v)) => *v as usize,
            v => panic!("invalid max_size type {v:?}"),
        };

        let centroids: Vec<_> = match &state[5] {
            ScalarValue::List(arr) => {
                let array = arr.values();

                let f64arr =
                    as_primitive_array::<Float64Type>(array).expect("expected f64 array");
                f64arr
                    .values()
                    .chunks(2)
                    .map(|v| Centroid::new(v[0], v[1]))
                    .collect()
            }
            v => panic!("invalid centroids type {v:?}"),
        };

        let max = cast_scalar_f64!(&state[3]);
        let min = cast_scalar_f64!(&state[4]);

        if min.is_finite() && max.is_finite() {
            assert!(max.total_cmp(&min).is_ge());
        }

        Self {
            max_size,
            sum: cast_scalar_f64!(state[1]),
            count: cast_scalar_u64!(&state[2]),
            max,
            min,
            centroids,
        }
    }
}

#[cfg(debug_assertions)]
fn is_sorted(values: &[f64]) -> bool {
    values.windows(2).all(|w| w[0].total_cmp(&w[1]).is_le())
}

#[cfg(test)]
mod tests {
    use super::*;

    macro_rules! assert_error_bounds {
        ($t:ident, quantile = $quantile:literal, want = $want:literal) => {
            assert_error_bounds!(
                $t,
                quantile = $quantile,
                want = $want,
                allowable_error = 0.01
            )
        };
        ($t:ident, quantile = $quantile:literal, want = $want:literal, allowable_error = $re:literal) => {
            let ans = $t.estimate_quantile($quantile);
            let expected: f64 = $want;
            let percentage: f64 = (expected - ans).abs() / expected;
            assert!(
                percentage < $re,
                "relative error {} is more than {}% (got quantile {}, want {})",
                percentage,
                $re,
                ans,
                expected
            );
        };
    }

    macro_rules! assert_state_roundtrip {
        ($t:ident) => {
            let state = $t.to_scalar_state();
            let other = TDigest::from_scalar_state(&state);
            assert_eq!($t, other);
        };
    }

    #[test]
    fn test_int64_uniform() {
        let values = (1i64..=1000).map(|v| v as f64).collect();

        let t = TDigest::new(100);
        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 0.1, want = 100.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
        assert_error_bounds!(t, quantile = 0.9, want = 900.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_centroid_addition_regression() {
        let vals = vec![1.0, 1.0, 1.0, 2.0, 1.0, 1.0];
        let mut t = TDigest::new(10);

        for v in vals {
            t = t.merge_unsorted_f64(vec![v]);
        }

        assert_error_bounds!(t, quantile = 0.5, want = 1.0);
        assert_error_bounds!(t, quantile = 0.95, want = 2.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_unsorted_against_uniform_distro() {
        let t = TDigest::new(100);
        let values: Vec<_> = (1..=1_000_000).map(f64::from).collect();

        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 1.0, want = 1_000_000.0);
        assert_error_bounds!(t, quantile = 0.99, want = 990_000.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_unsorted_against_skewed_distro() {
        let t = TDigest::new(100);
        let mut values: Vec<_> = (1..=600_000).map(f64::from).collect();
        values.resize(1_000_000, 1_000_000_f64);

        let t = t.merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 0.99, want = 1_000_000.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_digests() {
        let mut digests: Vec<TDigest> = Vec::new();

        for _ in 1..=100 {
            let t = TDigest::new(100);
            let values: Vec<_> = (1..=1_000).map(f64::from).collect();
            let t = t.merge_unsorted_f64(values);
            digests.push(t)
        }

        let t = TDigest::merge_digests(&digests);

        assert_error_bounds!(t, quantile = 1.0, want = 1000.0);
        assert_error_bounds!(t, quantile = 0.99, want = 990.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10.0, allowable_error = 0.2);
        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_size() {
        let t = TDigest::new(10);
        let t = t.merge_unsorted_f64(vec![0.0, 1.0]);

        assert_eq!(t.size(), 96);
    }
}