1use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::mem::size_of_val;
21use std::sync::Arc;
22
23use arrow::array::{Array, RecordBatch};
24use arrow::compute::{filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{
27 array::{
28 ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
29 Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
30 },
31 datatypes::{DataType, Field, Schema},
32};
33use datafusion_common::{
34 downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
35 Result, ScalarValue,
36};
37use datafusion_expr::expr::{AggregateFunction, Sort};
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
40use datafusion_expr::utils::format_state_name;
41use datafusion_expr::{
42 Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature,
43 TypeSignature, Volatility,
44};
45use datafusion_functions_aggregate_common::tdigest::{
46 TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
47};
48use datafusion_macros::user_doc;
49use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
50
// Defines `approx_percentile_cont_udaf()`, which `approx_percentile_cont` below
// passes to `AggregateFunction::new_udf`. The macro is defined elsewhere in this
// crate; presumably it wraps a `ApproxPercentileCont::default()` instance — confirm.
create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
52
53pub fn approx_percentile_cont(
55 order_by: Sort,
56 percentile: Expr,
57 centroids: Option<Expr>,
58) -> Expr {
59 let expr = order_by.expr.clone();
60
61 let args = if let Some(centroids) = centroids {
62 vec![expr, percentile, centroids]
63 } else {
64 vec![expr, percentile]
65 };
66
67 Expr::AggregateFunction(AggregateFunction::new_udf(
68 approx_percentile_cont_udaf(),
69 args,
70 false,
71 None,
72 vec![order_by],
73 None,
74 ))
75}
76
/// Aggregate UDF computing an approximate continuous percentile with a t-digest.
///
/// Supports both `approx_percentile_cont(p [, centroids]) WITHIN GROUP (ORDER BY expr)`
/// and the legacy `approx_percentile_cont(expr, p [, centroids])` call forms.
#[user_doc(
    doc_section(label = "Approximate Functions"),
    description = "Returns the approximate percentile of input values using the t-digest algorithm.",
    syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
    sql_example = r#"```sql
> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+------------------------------------------------------------------+
| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
+------------------------------------------------------------------+
| 65.0 |
+------------------------------------------------------------------+
> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+-----------------------------------------------------------------------+
| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
+-----------------------------------------------------------------------+
| 65.0 |
+-----------------------------------------------------------------------+
```
An alternate syntax is also supported:
```sql
> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name;
+-----------------------------------------------+
| approx_percentile_cont(column_name, 0.75) |
+-----------------------------------------------+
| 65.0 |
+-----------------------------------------------+

> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name;
+----------------------------------------------------------+
| approx_percentile_cont(column_name, 0.75, 100) |
+----------------------------------------------------------+
| 65.0 |
+----------------------------------------------------------+
```
"#,
    standard_argument(name = "expression",),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    ),
    argument(
        name = "centroids",
        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
    )
)]
#[derive(PartialEq, Eq, Hash)]
pub struct ApproxPercentileCont {
    /// Accepted argument-type combinations, built in `ApproxPercentileCont::new`.
    signature: Signature,
}
126
127impl Debug for ApproxPercentileCont {
128 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
129 f.debug_struct("ApproxPercentileCont")
130 .field("name", &self.name())
131 .field("signature", &self.signature)
132 .finish()
133 }
134}
135
136impl Default for ApproxPercentileCont {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142impl ApproxPercentileCont {
143 pub fn new() -> Self {
145 let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
146 for num in NUMERICS {
148 variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
149 for int in INTEGERS {
151 variants.push(TypeSignature::Exact(vec![
152 num.clone(),
153 DataType::Float64,
154 int.clone(),
155 ]))
156 }
157 }
158 Self {
159 signature: Signature::one_of(variants, Volatility::Immutable),
160 }
161 }
162
163 pub(crate) fn create_accumulator(
164 &self,
165 args: AccumulatorArgs,
166 ) -> Result<ApproxPercentileAccumulator> {
167 let percentile = validate_input_percentile_expr(&args.exprs[1])?;
168
169 let is_descending = args
170 .order_bys
171 .first()
172 .map(|sort_expr| sort_expr.options.descending)
173 .unwrap_or(false);
174
175 let percentile = if is_descending {
176 1.0 - percentile
177 } else {
178 percentile
179 };
180
181 let tdigest_max_size = if args.exprs.len() == 3 {
182 Some(validate_input_max_size_expr(&args.exprs[2])?)
183 } else {
184 None
185 };
186
187 let data_type = args.exprs[0].data_type(args.schema)?;
188 let accumulator: ApproxPercentileAccumulator = match data_type {
189 t @ (DataType::UInt8
190 | DataType::UInt16
191 | DataType::UInt32
192 | DataType::UInt64
193 | DataType::Int8
194 | DataType::Int16
195 | DataType::Int32
196 | DataType::Int64
197 | DataType::Float32
198 | DataType::Float64) => {
199 if let Some(max_size) = tdigest_max_size {
200 ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size)
201 }else{
202 ApproxPercentileAccumulator::new(percentile, t)
203
204 }
205 }
206 other => {
207 return not_impl_err!(
208 "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented"
209 )
210 }
211 };
212
213 Ok(accumulator)
214 }
215}
216
217fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
218 let empty_schema = Arc::new(Schema::empty());
219 let batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
220 if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? {
221 Ok(s)
222 } else {
223 internal_err!("Didn't expect ColumnarValue::Array")
224 }
225}
226
227fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
228 let percentile = match get_scalar_value(expr)
229 .map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
230 ScalarValue::Float32(Some(value)) => {
231 value as f64
232 }
233 ScalarValue::Float64(Some(value)) => {
234 value
235 }
236 sv => {
237 return not_impl_err!(
238 "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
239 sv.data_type()
240 )
241 }
242 };
243
244 if !(0.0..=1.0).contains(&percentile) {
246 return plan_err!(
247 "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
248 );
249 }
250 Ok(percentile)
251}
252
253fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
254 let max_size = match get_scalar_value(expr)
255 .map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
256 ScalarValue::UInt8(Some(q)) => q as usize,
257 ScalarValue::UInt16(Some(q)) => q as usize,
258 ScalarValue::UInt32(Some(q)) => q as usize,
259 ScalarValue::UInt64(Some(q)) => q as usize,
260 ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
261 ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
262 ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
263 ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
264 sv => {
265 return not_impl_err!(
266 "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
267 sv.data_type()
268 )
269 },
270 };
271
272 Ok(max_size)
273}
274
275impl AggregateUDFImpl for ApproxPercentileCont {
276 fn as_any(&self) -> &dyn Any {
277 self
278 }
279
280 #[allow(rustdoc::private_intra_doc_links)]
281 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
284 Ok(vec![
285 Field::new(
286 format_state_name(args.name, "max_size"),
287 DataType::UInt64,
288 false,
289 ),
290 Field::new(
291 format_state_name(args.name, "sum"),
292 DataType::Float64,
293 false,
294 ),
295 Field::new(
296 format_state_name(args.name, "count"),
297 DataType::UInt64,
298 false,
299 ),
300 Field::new(
301 format_state_name(args.name, "max"),
302 DataType::Float64,
303 false,
304 ),
305 Field::new(
306 format_state_name(args.name, "min"),
307 DataType::Float64,
308 false,
309 ),
310 Field::new_list(
311 format_state_name(args.name, "centroids"),
312 Field::new_list_field(DataType::Float64, true),
313 false,
314 ),
315 ]
316 .into_iter()
317 .map(Arc::new)
318 .collect())
319 }
320
321 fn name(&self) -> &str {
322 "approx_percentile_cont"
323 }
324
325 fn signature(&self) -> &Signature {
326 &self.signature
327 }
328
329 #[inline]
330 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
331 Ok(Box::new(self.create_accumulator(acc_args)?))
332 }
333
334 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
335 if !arg_types[0].is_numeric() {
336 return plan_err!("approx_percentile_cont requires numeric input types");
337 }
338 if arg_types.len() == 3 && !arg_types[2].is_integer() {
339 return plan_err!(
340 "approx_percentile_cont requires integer centroids input types"
341 );
342 }
343 Ok(arg_types[0].clone())
344 }
345
346 fn supports_null_handling_clause(&self) -> bool {
347 false
348 }
349
350 fn is_ordered_set_aggregate(&self) -> bool {
351 true
352 }
353
354 fn documentation(&self) -> Option<&Documentation> {
355 self.doc()
356 }
357}
358
/// Accumulator for `APPROX_PERCENTILE_CONT`.
///
/// Holds a t-digest sketch of the values seen so far, plus what is needed to
/// turn it back into a result.
#[derive(Debug)]
pub struct ApproxPercentileAccumulator {
    /// Running t-digest of the input values, converted to `f64`.
    digest: TDigest,
    /// Quantile passed to `TDigest::estimate_quantile` in `evaluate`. When the
    /// accumulator is built by `create_accumulator`, this is already mirrored
    /// to `1 - p` for descending orderings.
    percentile: f64,
    /// Type the estimated quantile is cast to in `evaluate`; also the type of
    /// the NULL returned when no values were seen.
    return_type: DataType,
}
365
366impl ApproxPercentileAccumulator {
367 pub fn new(percentile: f64, return_type: DataType) -> Self {
368 Self {
369 digest: TDigest::new(DEFAULT_MAX_SIZE),
370 percentile,
371 return_type,
372 }
373 }
374
375 pub fn new_with_max_size(
376 percentile: f64,
377 return_type: DataType,
378 max_size: usize,
379 ) -> Self {
380 Self {
381 digest: TDigest::new(max_size),
382 percentile,
383 return_type,
384 }
385 }
386
387 pub(crate) fn max_size(&self) -> usize {
389 self.digest.max_size()
390 }
391
392 pub fn merge_digests(&mut self, digests: &[TDigest]) {
394 let digests = digests.iter().chain(std::iter::once(&self.digest));
395 self.digest = TDigest::merge_digests(digests)
396 }
397
398 pub fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
400 match values.data_type() {
401 DataType::Float64 => {
402 let array = downcast_value!(values, Float64Array);
403 Ok(array
404 .values()
405 .iter()
406 .filter_map(|v| v.try_as_f64().transpose())
407 .collect::<Result<Vec<_>>>()?)
408 }
409 DataType::Float32 => {
410 let array = downcast_value!(values, Float32Array);
411 Ok(array
412 .values()
413 .iter()
414 .filter_map(|v| v.try_as_f64().transpose())
415 .collect::<Result<Vec<_>>>()?)
416 }
417 DataType::Int64 => {
418 let array = downcast_value!(values, Int64Array);
419 Ok(array
420 .values()
421 .iter()
422 .filter_map(|v| v.try_as_f64().transpose())
423 .collect::<Result<Vec<_>>>()?)
424 }
425 DataType::Int32 => {
426 let array = downcast_value!(values, Int32Array);
427 Ok(array
428 .values()
429 .iter()
430 .filter_map(|v| v.try_as_f64().transpose())
431 .collect::<Result<Vec<_>>>()?)
432 }
433 DataType::Int16 => {
434 let array = downcast_value!(values, Int16Array);
435 Ok(array
436 .values()
437 .iter()
438 .filter_map(|v| v.try_as_f64().transpose())
439 .collect::<Result<Vec<_>>>()?)
440 }
441 DataType::Int8 => {
442 let array = downcast_value!(values, Int8Array);
443 Ok(array
444 .values()
445 .iter()
446 .filter_map(|v| v.try_as_f64().transpose())
447 .collect::<Result<Vec<_>>>()?)
448 }
449 DataType::UInt64 => {
450 let array = downcast_value!(values, UInt64Array);
451 Ok(array
452 .values()
453 .iter()
454 .filter_map(|v| v.try_as_f64().transpose())
455 .collect::<Result<Vec<_>>>()?)
456 }
457 DataType::UInt32 => {
458 let array = downcast_value!(values, UInt32Array);
459 Ok(array
460 .values()
461 .iter()
462 .filter_map(|v| v.try_as_f64().transpose())
463 .collect::<Result<Vec<_>>>()?)
464 }
465 DataType::UInt16 => {
466 let array = downcast_value!(values, UInt16Array);
467 Ok(array
468 .values()
469 .iter()
470 .filter_map(|v| v.try_as_f64().transpose())
471 .collect::<Result<Vec<_>>>()?)
472 }
473 DataType::UInt8 => {
474 let array = downcast_value!(values, UInt8Array);
475 Ok(array
476 .values()
477 .iter()
478 .filter_map(|v| v.try_as_f64().transpose())
479 .collect::<Result<Vec<_>>>()?)
480 }
481 e => internal_err!(
482 "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
483 ),
484 }
485 }
486}
487
488impl Accumulator for ApproxPercentileAccumulator {
489 fn state(&mut self) -> Result<Vec<ScalarValue>> {
490 Ok(self.digest.to_scalar_state().into_iter().collect())
491 }
492
493 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
494 let mut values = Arc::clone(&values[0]);
496 if values.nulls().is_some() {
497 values = filter(&values, &is_not_null(&values)?)?;
498 }
499 let sorted_values = &arrow::compute::sort(&values, None)?;
500 let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
501 self.digest = self.digest.merge_sorted_f64(&sorted_values);
502 Ok(())
503 }
504
505 fn evaluate(&mut self) -> Result<ScalarValue> {
506 if self.digest.count() == 0 {
507 return ScalarValue::try_from(self.return_type.clone());
508 }
509 let q = self.digest.estimate_quantile(self.percentile);
510
511 Ok(match &self.return_type {
514 DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
515 DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
516 DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
517 DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
518 DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
519 DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
520 DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
521 DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
522 DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
523 DataType::Float64 => ScalarValue::Float64(Some(q)),
524 v => unreachable!("unexpected return type {:?}", v),
525 })
526 }
527
528 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
529 if states.is_empty() {
530 return Ok(());
531 }
532
533 let states = (0..states[0].len())
534 .map(|index| {
535 states
536 .iter()
537 .map(|array| ScalarValue::try_from_array(array, index))
538 .collect::<Result<Vec<_>>>()
539 .map(|state| TDigest::from_scalar_state(&state))
540 })
541 .collect::<Result<Vec<_>>>()?;
542
543 self.merge_digests(&states);
544
545 Ok(())
546 }
547
548 fn size(&self) -> usize {
549 size_of_val(self) + self.digest.size() - size_of_val(&self.digest)
550 + self.return_type.size()
551 - size_of_val(&self.return_type)
552 }
553}
554
#[cfg(test)]
mod tests {
    use arrow::datatypes::DataType;

    use datafusion_functions_aggregate_common::tdigest::TDigest;

    use crate::approx_percentile_cont::ApproxPercentileAccumulator;

    /// Merging whole digests into an accumulator must add up their counts.
    #[test]
    fn test_combine_approx_percentile_accumulator() {
        // Fifty digests, each built from the same 1_000 values.
        let digests: Vec<TDigest> = (1..=50)
            .map(|_| {
                let samples: Vec<_> = (1..=1_000).map(f64::from).collect();
                TDigest::new(100).merge_unsorted_f64(samples)
            })
            .collect();

        let first_merge = TDigest::merge_digests(&digests);
        let second_merge = TDigest::merge_digests(&digests);

        let mut accumulator =
            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

        accumulator.merge_digests(&[first_merge]);
        assert_eq!(accumulator.digest.count(), 50_000);

        accumulator.merge_digests(&[second_merge]);
        assert_eq!(accumulator.digest.count(), 100_000);
    }
}
586}