datafusion_functions_aggregate_common/
tdigest.rs1use arrow::datatypes::DataType;
33use arrow::datatypes::Float64Type;
34use datafusion_common::cast::as_primitive_array;
35use datafusion_common::Result;
36use datafusion_common::ScalarValue;
37use std::cmp::Ordering;
38use std::mem::{size_of, size_of_val};
39
/// Default `max_size` for a [`TDigest`]; this is the same value that
/// `TDigest::default()` uses.
pub const DEFAULT_MAX_SIZE: usize = 100;
41
// Extracts the `f64` held by a non-null `ScalarValue::Float64`.
//
// Panics with "invalid type ..." for any other variant, including a null
// `Float64(None)`.
macro_rules! cast_scalar_f64 {
    ($value:expr ) => {{
        let scalar = &$value;
        if let ScalarValue::Float64(Some(inner)) = scalar {
            *inner
        } else {
            panic!("invalid type {:?}", scalar)
        }
    }};
}
52
// Extracts the `u64` held by a non-null `ScalarValue::UInt64`.
//
// Panics with "invalid type ..." for any other variant, including a null
// `UInt64(None)`.
macro_rules! cast_scalar_u64 {
    ($value:expr ) => {{
        let scalar = &$value;
        if let ScalarValue::UInt64(Some(inner)) = scalar {
            *inner
        } else {
            panic!("invalid type {:?}", scalar)
        }
    }};
}
63
64pub trait TryIntoF64 {
68 fn try_as_f64(&self) -> Result<Option<f64>>;
75}
76
77macro_rules! impl_try_ordered_f64 {
79 ($type:ty) => {
80 impl TryIntoF64 for $type {
81 fn try_as_f64(&self) -> Result<Option<f64>> {
82 Ok(Some(*self as f64))
83 }
84 }
85 };
86}
87
88impl_try_ordered_f64!(f64);
89impl_try_ordered_f64!(f32);
90impl_try_ordered_f64!(i64);
91impl_try_ordered_f64!(i32);
92impl_try_ordered_f64!(i16);
93impl_try_ordered_f64!(i8);
94impl_try_ordered_f64!(u64);
95impl_try_ordered_f64!(u32);
96impl_try_ordered_f64!(u16);
97impl_try_ordered_f64!(u8);
98
/// A weighted point of a [`TDigest`]: a mean together with the total weight
/// of the values it stands for.
#[derive(Debug, PartialEq, Clone)]
pub struct Centroid {
    /// Weighted mean of the values folded into this centroid.
    mean: f64,
    /// Total weight of the values folded into this centroid.
    weight: f64,
}

impl Centroid {
    /// Creates a centroid from a mean and a weight.
    pub fn new(mean: f64, weight: f64) -> Self {
        Self { mean, weight }
    }

    /// Returns the mean of this centroid.
    #[inline]
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Returns the weight of this centroid.
    #[inline]
    pub fn weight(&self) -> f64 {
        self.weight
    }

    /// Folds additional mass into this centroid.
    ///
    /// `sum` is the weighted sum (mean × weight) of the values being added
    /// and `weight` is their total weight. The centroid's weight and mean are
    /// updated in place, and the new weighted sum of the whole centroid is
    /// returned.
    pub fn add(&mut self, sum: f64, weight: f64) -> f64 {
        let combined_sum = sum + self.weight * self.mean;
        self.weight += weight;
        self.mean = combined_sum / self.weight;
        combined_sum
    }

    /// Orders two centroids by mean using the IEEE 754 total order, so the
    /// comparison is defined even for NaN and signed zeros.
    pub fn cmp_mean(&self, other: &Self) -> Ordering {
        f64::total_cmp(&self.mean, &other.mean)
    }
}

impl Default for Centroid {
    /// A centroid with mean `0.0` and weight `1.0`.
    fn default() -> Self {
        Self::new(0_f64, 1_f64)
    }
}
142
143#[derive(Debug, PartialEq, Clone)]
145pub struct TDigest {
146 centroids: Vec<Centroid>,
147 max_size: usize,
148 sum: f64,
149 count: u64,
150 max: f64,
151 min: f64,
152}
153
154impl TDigest {
155 pub fn new(max_size: usize) -> Self {
156 TDigest {
157 centroids: Vec::new(),
158 max_size,
159 sum: 0_f64,
160 count: 0,
161 max: f64::NAN,
162 min: f64::NAN,
163 }
164 }
165
166 pub fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self {
167 TDigest {
168 centroids: vec![centroid.clone()],
169 max_size,
170 sum: centroid.mean * centroid.weight,
171 count: 1,
172 max: centroid.mean,
173 min: centroid.mean,
174 }
175 }
176
177 #[inline]
178 pub fn count(&self) -> u64 {
179 self.count
180 }
181
182 #[inline]
183 pub fn max(&self) -> f64 {
184 self.max
185 }
186
187 #[inline]
188 pub fn min(&self) -> f64 {
189 self.min
190 }
191
192 #[inline]
193 pub fn max_size(&self) -> usize {
194 self.max_size
195 }
196
197 pub fn size(&self) -> usize {
199 size_of_val(self) + (size_of::<Centroid>() * self.centroids.capacity())
200 }
201}
202
203impl Default for TDigest {
204 fn default() -> Self {
205 TDigest {
206 centroids: Vec::new(),
207 max_size: 100,
208 sum: 0_f64,
209 count: 0,
210 max: f64::NAN,
211 min: f64::NAN,
212 }
213 }
214}
215
216impl TDigest {
/// Scale function mapping the index `k` (out of `d`) to a quantile limit.
///
/// With `x = k / d` the result is `2x²` for `x < 0.5` and `1 - 2(1 - x)²`
/// otherwise, so the limits are spaced more tightly near 0 and 1 than around
/// the median.
fn k_to_q(k: u64, d: usize) -> f64 {
    let ratio = k as f64 / d as f64;
    if ratio >= 0.5 {
        let complement = 1.0 - ratio;
        return 1.0 - 2.0 * complement * complement;
    }
    2.0 * ratio * ratio
}
226
/// Clamps `v` into `[lo, hi]`, returning `v` unchanged when either bound is
/// `NaN` (which `f64::clamp` would reject by panicking).
fn clamp(v: f64, lo: f64, hi: f64) -> f64 {
    match (lo.is_nan(), hi.is_nan()) {
        (false, false) => v.clamp(lo, hi),
        _ => v,
    }
}
233
234 pub fn merge_unsorted_f64(&self, unsorted_values: Vec<f64>) -> TDigest {
236 let mut values = unsorted_values;
237 values.sort_by(|a, b| a.total_cmp(b));
238 self.merge_sorted_f64(&values)
239 }
240
241 pub fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest {
242 #[cfg(debug_assertions)]
243 debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest");
244
245 if sorted_values.is_empty() {
246 return self.clone();
247 }
248
249 let mut result = TDigest::new(self.max_size());
250 result.count = self.count() + sorted_values.len() as u64;
251
252 let maybe_min = *sorted_values.first().unwrap();
253 let maybe_max = *sorted_values.last().unwrap();
254
255 if self.count() > 0 {
256 result.min = self.min.min(maybe_min);
257 result.max = self.max.max(maybe_max);
258 } else {
259 result.min = maybe_min;
260 result.max = maybe_max;
261 }
262
263 let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);
264
265 let mut k_limit: u64 = 1;
266 let mut q_limit_times_count =
267 Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
268 k_limit += 1;
269
270 let mut iter_centroids = self.centroids.iter().peekable();
271 let mut iter_sorted_values = sorted_values.iter().peekable();
272
273 let mut curr: Centroid = if let Some(c) = iter_centroids.peek() {
274 let curr = **iter_sorted_values.peek().unwrap();
275 if c.mean() < curr {
276 iter_centroids.next().unwrap().clone()
277 } else {
278 Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
279 }
280 } else {
281 Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
282 };
283
284 let mut weight_so_far = curr.weight();
285
286 let mut sums_to_merge = 0_f64;
287 let mut weights_to_merge = 0_f64;
288
289 while iter_centroids.peek().is_some() || iter_sorted_values.peek().is_some() {
290 let next: Centroid = if let Some(c) = iter_centroids.peek() {
291 if iter_sorted_values.peek().is_none()
292 || c.mean() < **iter_sorted_values.peek().unwrap()
293 {
294 iter_centroids.next().unwrap().clone()
295 } else {
296 Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
297 }
298 } else {
299 Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
300 };
301
302 let next_sum = next.mean() * next.weight();
303 weight_so_far += next.weight();
304
305 if weight_so_far <= q_limit_times_count {
306 sums_to_merge += next_sum;
307 weights_to_merge += next.weight();
308 } else {
309 result.sum += curr.add(sums_to_merge, weights_to_merge);
310 sums_to_merge = 0_f64;
311 weights_to_merge = 0_f64;
312
313 compressed.push(curr.clone());
314 q_limit_times_count =
315 Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
316 k_limit += 1;
317 curr = next;
318 }
319 }
320
321 result.sum += curr.add(sums_to_merge, weights_to_merge);
322 compressed.push(curr);
323 compressed.shrink_to_fit();
324 compressed.sort_by(|a, b| a.cmp_mean(b));
325
326 result.centroids = compressed;
327 result
328 }
329
330 fn external_merge(
331 centroids: &mut [Centroid],
332 first: usize,
333 middle: usize,
334 last: usize,
335 ) {
336 let mut result: Vec<Centroid> = Vec::with_capacity(centroids.len());
337
338 let mut i = first;
339 let mut j = middle;
340
341 while i < middle && j < last {
342 match centroids[i].cmp_mean(¢roids[j]) {
343 Ordering::Less => {
344 result.push(centroids[i].clone());
345 i += 1;
346 }
347 Ordering::Greater => {
348 result.push(centroids[j].clone());
349 j += 1;
350 }
351 Ordering::Equal => {
352 result.push(centroids[i].clone());
353 i += 1;
354 }
355 }
356 }
357
358 while i < middle {
359 result.push(centroids[i].clone());
360 i += 1;
361 }
362
363 while j < last {
364 result.push(centroids[j].clone());
365 j += 1;
366 }
367
368 i = first;
369 for centroid in result.into_iter() {
370 centroids[i] = centroid;
371 i += 1;
372 }
373 }
374
375 pub fn merge_digests<'a>(digests: impl IntoIterator<Item = &'a TDigest>) -> TDigest {
377 let digests = digests.into_iter().collect::<Vec<_>>();
378 let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum();
379 if n_centroids == 0 {
380 return TDigest::default();
381 }
382
383 let max_size = digests.first().unwrap().max_size;
384 let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
385 let mut starts: Vec<usize> = Vec::with_capacity(digests.len());
386
387 let mut count = 0;
388 let mut min = f64::INFINITY;
389 let mut max = f64::NEG_INFINITY;
390
391 let mut start: usize = 0;
392 for digest in digests.iter() {
393 starts.push(start);
394
395 let curr_count = digest.count();
396 if curr_count > 0 {
397 min = min.min(digest.min);
398 max = max.max(digest.max);
399 count += curr_count;
400 for centroid in &digest.centroids {
401 centroids.push(centroid.clone());
402 start += 1;
403 }
404 }
405 }
406
407 let mut digests_per_block: usize = 1;
408 while digests_per_block < starts.len() {
409 for i in (0..starts.len()).step_by(digests_per_block * 2) {
410 if i + digests_per_block < starts.len() {
411 let first = starts[i];
412 let middle = starts[i + digests_per_block];
413 let last = if i + 2 * digests_per_block < starts.len() {
414 starts[i + 2 * digests_per_block]
415 } else {
416 centroids.len()
417 };
418
419 debug_assert!(first <= middle && middle <= last);
420 Self::external_merge(&mut centroids, first, middle, last);
421 }
422 }
423
424 digests_per_block *= 2;
425 }
426
427 let mut result = TDigest::new(max_size);
428 let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);
429
430 let mut k_limit = 1;
431 let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
432
433 let mut iter_centroids = centroids.iter_mut();
434 let mut curr = iter_centroids.next().unwrap();
435 let mut weight_so_far = curr.weight();
436 let mut sums_to_merge = 0_f64;
437 let mut weights_to_merge = 0_f64;
438
439 for centroid in iter_centroids {
440 weight_so_far += centroid.weight();
441
442 if weight_so_far <= q_limit_times_count {
443 sums_to_merge += centroid.mean() * centroid.weight();
444 weights_to_merge += centroid.weight();
445 } else {
446 result.sum += curr.add(sums_to_merge, weights_to_merge);
447 sums_to_merge = 0_f64;
448 weights_to_merge = 0_f64;
449 compressed.push(curr.clone());
450 q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
451 k_limit += 1;
452 curr = centroid;
453 }
454 }
455
456 result.sum += curr.add(sums_to_merge, weights_to_merge);
457 compressed.push(curr.clone());
458 compressed.shrink_to_fit();
459 compressed.sort_by(|a, b| a.cmp_mean(b));
460
461 result.count = count;
462 result.min = min;
463 result.max = max;
464 result.centroids = compressed;
465 result
466 }
467
468 pub fn estimate_quantile(&self, q: f64) -> f64 {
470 if self.centroids.is_empty() {
471 return 0.0;
472 }
473
474 let rank = q * self.count as f64;
475
476 let mut pos: usize;
477 let mut t;
478 if q > 0.5 {
479 if q >= 1.0 {
480 return self.max();
481 }
482
483 pos = 0;
484 t = self.count as f64;
485
486 for (k, centroid) in self.centroids.iter().enumerate().rev() {
487 t -= centroid.weight();
488
489 if rank >= t {
490 pos = k;
491 break;
492 }
493 }
494 } else {
495 if q <= 0.0 {
496 return self.min();
497 }
498
499 pos = self.centroids.len() - 1;
500 t = 0_f64;
501
502 for (k, centroid) in self.centroids.iter().enumerate() {
503 if rank < t + centroid.weight() {
504 pos = k;
505 break;
506 }
507
508 t += centroid.weight();
509 }
510 }
511
512 let mut delta = 0_f64;
513 let mut min = self.min;
514 let mut max = self.max;
515
516 if self.centroids.len() > 1 {
517 if pos == 0 {
518 delta = self.centroids[pos + 1].mean() - self.centroids[pos].mean();
519 max = self.centroids[pos + 1].mean();
520 } else if pos == (self.centroids.len() - 1) {
521 delta = self.centroids[pos].mean() - self.centroids[pos - 1].mean();
522 min = self.centroids[pos - 1].mean();
523 } else {
524 delta = (self.centroids[pos + 1].mean() - self.centroids[pos - 1].mean())
525 / 2.0;
526 min = self.centroids[pos - 1].mean();
527 max = self.centroids[pos + 1].mean();
528 }
529 }
530
531 let value = self.centroids[pos].mean()
532 + ((rank - t) / self.centroids[pos].weight() - 0.5) * delta;
533
534 if !min.is_finite() && min.is_sign_positive() {
539 min = f64::NEG_INFINITY;
540 }
541
542 if !max.is_finite() && max.is_sign_negative() {
543 max = f64::INFINITY;
544 }
545
546 Self::clamp(value, min, max)
547 }
548
549 pub fn to_scalar_state(&self) -> Vec<ScalarValue> {
585 let centroids: Vec<ScalarValue> = self
587 .centroids
588 .iter()
589 .flat_map(|c| [c.mean(), c.weight()])
590 .map(|v| ScalarValue::Float64(Some(v)))
591 .collect();
592
593 let arr = ScalarValue::new_list_nullable(¢roids, &DataType::Float64);
594
595 vec![
596 ScalarValue::UInt64(Some(self.max_size as u64)),
597 ScalarValue::Float64(Some(self.sum)),
598 ScalarValue::UInt64(Some(self.count)),
599 ScalarValue::Float64(Some(self.max)),
600 ScalarValue::Float64(Some(self.min)),
601 ScalarValue::List(arr),
602 ]
603 }
604
605 pub fn from_scalar_state(state: &[ScalarValue]) -> Self {
614 assert_eq!(state.len(), 6, "invalid TDigest state");
615
616 let max_size = match &state[0] {
617 ScalarValue::UInt64(Some(v)) => *v as usize,
618 v => panic!("invalid max_size type {v:?}"),
619 };
620
621 let centroids: Vec<_> = match &state[5] {
622 ScalarValue::List(arr) => {
623 let array = arr.values();
624
625 let f64arr =
626 as_primitive_array::<Float64Type>(array).expect("expected f64 array");
627 f64arr
628 .values()
629 .chunks(2)
630 .map(|v| Centroid::new(v[0], v[1]))
631 .collect()
632 }
633 v => panic!("invalid centroids type {v:?}"),
634 };
635
636 let max = cast_scalar_f64!(&state[3]);
637 let min = cast_scalar_f64!(&state[4]);
638
639 if min.is_finite() && max.is_finite() {
640 assert!(max.total_cmp(&min).is_ge());
641 }
642
643 Self {
644 max_size,
645 sum: cast_scalar_f64!(state[1]),
646 count: cast_scalar_u64!(&state[2]),
647 max,
648 min,
649 centroids,
650 }
651 }
652}
653
/// Returns `true` when `values` is in non-decreasing order under the
/// IEEE 754 total order. Only compiled when debug assertions are enabled.
#[cfg(debug_assertions)]
fn is_sorted(values: &[f64]) -> bool {
    values
        .iter()
        .zip(values.iter().skip(1))
        .all(|(left, right)| left.total_cmp(right).is_le())
}
658
#[cfg(test)]
mod tests {
    use super::*;

    // Asserts that `estimate_quantile` on `$t` lies within a relative error
    // of the wanted value. The allowed relative error defaults to 0.01.
    macro_rules! assert_error_bounds {
        ($t:ident, quantile = $quantile:literal, want = $want:literal, allowable_error = $re:literal) => {{
            let ans = $t.estimate_quantile($quantile);
            let expected: f64 = $want;
            let percentage: f64 = (expected - ans).abs() / expected;
            assert!(
                percentage < $re,
                "relative error {} is more than {}% (got quantile {}, want {})",
                percentage,
                $re,
                ans,
                expected
            );
        }};
        ($t:ident, quantile = $quantile:literal, want = $want:literal) => {
            assert_error_bounds!(
                $t,
                quantile = $quantile,
                want = $want,
                allowable_error = 0.01
            )
        };
    }

    // Asserts that serialising `$t` and reading it back yields an equal digest.
    macro_rules! assert_state_roundtrip {
        ($t:ident) => {{
            let restored = TDigest::from_scalar_state(&$t.to_scalar_state());
            assert_eq!($t, restored);
        }};
    }

    #[test]
    fn test_int64_uniform() {
        let values: Vec<f64> = (1i64..=1000).map(|v| v as f64).collect();

        let t = TDigest::new(100).merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 0.1, want = 100.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
        assert_error_bounds!(t, quantile = 0.9, want = 900.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_centroid_addition_regression() {
        // Values are merged one at a time, each as its own batch.
        let vals = vec![1.0, 1.0, 1.0, 2.0, 1.0, 1.0];

        let t = vals
            .into_iter()
            .fold(TDigest::new(10), |acc, v| acc.merge_unsorted_f64(vec![v]));

        assert_error_bounds!(t, quantile = 0.5, want = 1.0);
        assert_error_bounds!(t, quantile = 0.95, want = 2.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_unsorted_against_uniform_distro() {
        let values: Vec<_> = (1..=1_000_000).map(f64::from).collect();

        let t = TDigest::new(100).merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 1.0, want = 1_000_000.0);
        assert_error_bounds!(t, quantile = 0.99, want = 990_000.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_unsorted_against_skewed_distro() {
        // 600_000 increasing values followed by 400_000 copies of 1_000_000.
        let mut values: Vec<_> = (1..=600_000).map(f64::from).collect();
        values.resize(1_000_000, 1_000_000_f64);

        let t = TDigest::new(100).merge_unsorted_f64(values);

        assert_error_bounds!(t, quantile = 0.99, want = 1_000_000.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_merge_digests() {
        // One hundred identical digests, each built from 1..=1000.
        let digests: Vec<TDigest> = (1..=100)
            .map(|_| {
                let values: Vec<_> = (1..=1_000).map(f64::from).collect();
                TDigest::new(100).merge_unsorted_f64(values)
            })
            .collect();

        let t = TDigest::merge_digests(&digests);

        assert_error_bounds!(t, quantile = 1.0, want = 1000.0);
        assert_error_bounds!(t, quantile = 0.99, want = 990.0);
        assert_error_bounds!(t, quantile = 0.01, want = 10.0, allowable_error = 0.2);
        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
        assert_state_roundtrip!(t);
    }

    #[test]
    fn test_size() {
        let t = TDigest::new(10).merge_unsorted_f64(vec![0.0, 1.0]);

        assert_eq!(t.size(), 96);
    }
}