1use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::mem::size_of_val;
21use std::sync::Arc;
22
23use arrow::array::Array;
24use arrow::compute::{filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{
27 array::{
28 ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
29 Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
30 },
31 datatypes::{DataType, Field},
32};
33use datafusion_common::{
34 downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, Result,
35 ScalarValue,
36};
37use datafusion_expr::expr::{AggregateFunction, Sort};
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
40use datafusion_expr::utils::format_state_name;
41use datafusion_expr::{
42 Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
43 Volatility,
44};
45use datafusion_functions_aggregate_common::tdigest::{
46 TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
47};
48use datafusion_macros::user_doc;
49use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
50
51use crate::utils::{get_scalar_value, validate_percentile_expr};
52
53create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
54
55pub fn approx_percentile_cont(
57 order_by: Sort,
58 percentile: Expr,
59 centroids: Option<Expr>,
60) -> Expr {
61 let expr = order_by.expr.clone();
62
63 let args = if let Some(centroids) = centroids {
64 vec![expr, percentile, centroids]
65 } else {
66 vec![expr, percentile]
67 };
68
69 Expr::AggregateFunction(AggregateFunction::new_udf(
70 approx_percentile_cont_udaf(),
71 args,
72 false,
73 None,
74 vec![order_by],
75 None,
76 ))
77}
78
// Aggregate UDF backing the SQL function `approx_percentile_cont`.
// The `#[user_doc]` attribute carries the user-facing documentation; it
// presumably generates the `doc()` method that `documentation()` returns
// further down in this file.
#[user_doc(
    doc_section(label = "Approximate Functions"),
    description = "Returns the approximate percentile of input values using the t-digest algorithm.",
    syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
    sql_example = r#"```sql
> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+------------------------------------------------------------------+
| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
+------------------------------------------------------------------+
| 65.0 |
+------------------------------------------------------------------+
> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+-----------------------------------------------------------------------+
| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
+-----------------------------------------------------------------------+
| 65.0 |
+-----------------------------------------------------------------------+
```
An alternate syntax is also supported:
```sql
> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name;
+-----------------------------------------------+
| approx_percentile_cont(column_name, 0.75) |
+-----------------------------------------------+
| 65.0 |
+-----------------------------------------------+

> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name;
+----------------------------------------------------------+
| approx_percentile_cont(column_name, 0.75, 100) |
+----------------------------------------------------------+
| 65.0 |
+----------------------------------------------------------+
```
"#,
    standard_argument(name = "expression",),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    ),
    argument(
        name = "centroids",
        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
    )
)]
#[derive(PartialEq, Eq, Hash)]
pub struct ApproxPercentileCont {
    // Accepted argument types, built in `ApproxPercentileCont::new`:
    // (numeric, Float64) or (numeric, Float64, integer).
    signature: Signature,
}
128
129impl Debug for ApproxPercentileCont {
130 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
131 f.debug_struct("ApproxPercentileCont")
132 .field("name", &self.name())
133 .field("signature", &self.signature)
134 .finish()
135 }
136}
137
138impl Default for ApproxPercentileCont {
139 fn default() -> Self {
140 Self::new()
141 }
142}
143
144impl ApproxPercentileCont {
145 pub fn new() -> Self {
147 let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
148 for num in NUMERICS {
150 variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
151 for int in INTEGERS {
153 variants.push(TypeSignature::Exact(vec![
154 num.clone(),
155 DataType::Float64,
156 int.clone(),
157 ]))
158 }
159 }
160 Self {
161 signature: Signature::one_of(variants, Volatility::Immutable),
162 }
163 }
164
165 pub(crate) fn create_accumulator(
166 &self,
167 args: AccumulatorArgs,
168 ) -> Result<ApproxPercentileAccumulator> {
169 let percentile =
170 validate_percentile_expr(&args.exprs[1], "APPROX_PERCENTILE_CONT")?;
171
172 let is_descending = args
173 .order_bys
174 .first()
175 .map(|sort_expr| sort_expr.options.descending)
176 .unwrap_or(false);
177
178 let percentile = if is_descending {
179 1.0 - percentile
180 } else {
181 percentile
182 };
183
184 let tdigest_max_size = if args.exprs.len() == 3 {
185 Some(validate_input_max_size_expr(&args.exprs[2])?)
186 } else {
187 None
188 };
189
190 let data_type = args.expr_fields[0].data_type();
191 let accumulator: ApproxPercentileAccumulator = match data_type {
192 DataType::UInt8
193 | DataType::UInt16
194 | DataType::UInt32
195 | DataType::UInt64
196 | DataType::Int8
197 | DataType::Int16
198 | DataType::Int32
199 | DataType::Int64
200 | DataType::Float32
201 | DataType::Float64 => {
202 if let Some(max_size) = tdigest_max_size {
203 ApproxPercentileAccumulator::new_with_max_size(percentile, data_type.clone(), max_size)
204 } else {
205 ApproxPercentileAccumulator::new(percentile, data_type.clone())
206 }
207 }
208 other => {
209 return not_impl_err!(
210 "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented"
211 )
212 }
213 };
214
215 Ok(accumulator)
216 }
217}
218
219fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
220 let scalar_value = get_scalar_value(expr).map_err(|_e| {
221 DataFusionError::Plan(
222 "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal"
223 .to_string(),
224 )
225 })?;
226
227 let max_size = match scalar_value {
228 ScalarValue::UInt8(Some(q)) => q as usize,
229 ScalarValue::UInt16(Some(q)) => q as usize,
230 ScalarValue::UInt32(Some(q)) => q as usize,
231 ScalarValue::UInt64(Some(q)) => q as usize,
232 ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
233 ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
234 ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
235 ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
236 sv => {
237 return plan_err!(
238 "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
239 sv.data_type()
240 )
241 },
242 };
243
244 Ok(max_size)
245}
246
247impl AggregateUDFImpl for ApproxPercentileCont {
248 fn as_any(&self) -> &dyn Any {
249 self
250 }
251
252 #[allow(rustdoc::private_intra_doc_links)]
253 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
256 Ok(vec![
257 Field::new(
258 format_state_name(args.name, "max_size"),
259 DataType::UInt64,
260 false,
261 ),
262 Field::new(
263 format_state_name(args.name, "sum"),
264 DataType::Float64,
265 false,
266 ),
267 Field::new(
268 format_state_name(args.name, "count"),
269 DataType::UInt64,
270 false,
271 ),
272 Field::new(
273 format_state_name(args.name, "max"),
274 DataType::Float64,
275 false,
276 ),
277 Field::new(
278 format_state_name(args.name, "min"),
279 DataType::Float64,
280 false,
281 ),
282 Field::new_list(
283 format_state_name(args.name, "centroids"),
284 Field::new_list_field(DataType::Float64, true),
285 false,
286 ),
287 ]
288 .into_iter()
289 .map(Arc::new)
290 .collect())
291 }
292
293 fn name(&self) -> &str {
294 "approx_percentile_cont"
295 }
296
297 fn signature(&self) -> &Signature {
298 &self.signature
299 }
300
301 #[inline]
302 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
303 Ok(Box::new(self.create_accumulator(acc_args)?))
304 }
305
306 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
307 if !arg_types[0].is_numeric() {
308 return plan_err!("approx_percentile_cont requires numeric input types");
309 }
310 if arg_types.len() == 3 && !arg_types[2].is_integer() {
311 return plan_err!(
312 "approx_percentile_cont requires integer centroids input types"
313 );
314 }
315 Ok(arg_types[0].clone())
316 }
317
318 fn supports_null_handling_clause(&self) -> bool {
319 false
320 }
321
322 fn supports_within_group_clause(&self) -> bool {
323 true
324 }
325
326 fn documentation(&self) -> Option<&Documentation> {
327 self.doc()
328 }
329}
330
/// Accumulator for `approx_percentile_cont`.
///
/// Input values are folded into a [`TDigest`], and the configured percentile
/// is estimated from that digest in `evaluate`.
#[derive(Debug)]
pub struct ApproxPercentileAccumulator {
    /// Digest of the values seen so far, including any merged partial states.
    digest: TDigest,
    /// Percentile passed to `TDigest::estimate_quantile` when evaluating.
    percentile: f64,
    /// Type of the value produced by `evaluate`; the `f64` estimate is cast to it.
    return_type: DataType,
}
337
338impl ApproxPercentileAccumulator {
339 pub fn new(percentile: f64, return_type: DataType) -> Self {
340 Self {
341 digest: TDigest::new(DEFAULT_MAX_SIZE),
342 percentile,
343 return_type,
344 }
345 }
346
347 pub fn new_with_max_size(
348 percentile: f64,
349 return_type: DataType,
350 max_size: usize,
351 ) -> Self {
352 Self {
353 digest: TDigest::new(max_size),
354 percentile,
355 return_type,
356 }
357 }
358
359 pub(crate) fn max_size(&self) -> usize {
361 self.digest.max_size()
362 }
363
364 pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
366 let digests = digests.iter().chain(std::iter::once(&self.digest));
367 self.digest = TDigest::merge_digests(digests)
368 }
369
370 pub(crate) fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
372 debug_assert!(
373 values.null_count() == 0,
374 "convert_to_float assumes nulls have already been filtered out"
375 );
376 match values.data_type() {
377 DataType::Float64 => {
378 let array = downcast_value!(values, Float64Array);
379 Ok(array
380 .values()
381 .iter()
382 .filter_map(|v| v.try_as_f64().transpose())
383 .collect::<Result<Vec<_>>>()?)
384 }
385 DataType::Float32 => {
386 let array = downcast_value!(values, Float32Array);
387 Ok(array
388 .values()
389 .iter()
390 .filter_map(|v| v.try_as_f64().transpose())
391 .collect::<Result<Vec<_>>>()?)
392 }
393 DataType::Int64 => {
394 let array = downcast_value!(values, Int64Array);
395 Ok(array
396 .values()
397 .iter()
398 .filter_map(|v| v.try_as_f64().transpose())
399 .collect::<Result<Vec<_>>>()?)
400 }
401 DataType::Int32 => {
402 let array = downcast_value!(values, Int32Array);
403 Ok(array
404 .values()
405 .iter()
406 .filter_map(|v| v.try_as_f64().transpose())
407 .collect::<Result<Vec<_>>>()?)
408 }
409 DataType::Int16 => {
410 let array = downcast_value!(values, Int16Array);
411 Ok(array
412 .values()
413 .iter()
414 .filter_map(|v| v.try_as_f64().transpose())
415 .collect::<Result<Vec<_>>>()?)
416 }
417 DataType::Int8 => {
418 let array = downcast_value!(values, Int8Array);
419 Ok(array
420 .values()
421 .iter()
422 .filter_map(|v| v.try_as_f64().transpose())
423 .collect::<Result<Vec<_>>>()?)
424 }
425 DataType::UInt64 => {
426 let array = downcast_value!(values, UInt64Array);
427 Ok(array
428 .values()
429 .iter()
430 .filter_map(|v| v.try_as_f64().transpose())
431 .collect::<Result<Vec<_>>>()?)
432 }
433 DataType::UInt32 => {
434 let array = downcast_value!(values, UInt32Array);
435 Ok(array
436 .values()
437 .iter()
438 .filter_map(|v| v.try_as_f64().transpose())
439 .collect::<Result<Vec<_>>>()?)
440 }
441 DataType::UInt16 => {
442 let array = downcast_value!(values, UInt16Array);
443 Ok(array
444 .values()
445 .iter()
446 .filter_map(|v| v.try_as_f64().transpose())
447 .collect::<Result<Vec<_>>>()?)
448 }
449 DataType::UInt8 => {
450 let array = downcast_value!(values, UInt8Array);
451 Ok(array
452 .values()
453 .iter()
454 .filter_map(|v| v.try_as_f64().transpose())
455 .collect::<Result<Vec<_>>>()?)
456 }
457 e => internal_err!(
458 "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
459 ),
460 }
461 }
462}
463
464impl Accumulator for ApproxPercentileAccumulator {
465 fn state(&mut self) -> Result<Vec<ScalarValue>> {
466 Ok(self.digest.to_scalar_state().into_iter().collect())
467 }
468
469 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
470 let mut values = Arc::clone(&values[0]);
472 if values.null_count() > 0 {
473 values = filter(&values, &is_not_null(&values)?)?;
474 }
475 let sorted_values = &arrow::compute::sort(&values, None)?;
476 let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
477 self.digest = self.digest.merge_sorted_f64(&sorted_values);
478 Ok(())
479 }
480
481 fn evaluate(&mut self) -> Result<ScalarValue> {
482 if self.digest.count() == 0 {
483 return ScalarValue::try_from(self.return_type.clone());
484 }
485 let q = self.digest.estimate_quantile(self.percentile);
486
487 Ok(match &self.return_type {
490 DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
491 DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
492 DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
493 DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
494 DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
495 DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
496 DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
497 DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
498 DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
499 DataType::Float64 => ScalarValue::Float64(Some(q)),
500 v => unreachable!("unexpected return type {}", v),
501 })
502 }
503
504 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
505 if states.is_empty() {
506 return Ok(());
507 }
508
509 let states = (0..states[0].len())
510 .map(|index| {
511 states
512 .iter()
513 .map(|array| ScalarValue::try_from_array(array, index))
514 .collect::<Result<Vec<_>>>()
515 .map(|state| TDigest::from_scalar_state(&state))
516 })
517 .collect::<Result<Vec<_>>>()?;
518
519 self.merge_digests(&states);
520
521 Ok(())
522 }
523
524 fn size(&self) -> usize {
525 size_of_val(self) + self.digest.size() - size_of_val(&self.digest)
526 + self.return_type.size()
527 - size_of_val(&self.return_type)
528 }
529}
530
#[cfg(test)]
mod tests {
    use arrow::datatypes::DataType;

    use datafusion_functions_aggregate_common::tdigest::TDigest;

    use crate::approx_percentile_cont::ApproxPercentileAccumulator;

    /// Merging combined digests into the accumulator keeps a running total of
    /// every value counted so far.
    #[test]
    fn test_combine_approx_percentile_accumulator() {
        // Fifty identical digests, each built from the values 1..=1000.
        let digests: Vec<TDigest> = std::iter::repeat_with(|| {
            let values: Vec<_> = (1..=1_000).map(f64::from).collect();
            TDigest::new(100).merge_unsorted_f64(values)
        })
        .take(50)
        .collect();

        let first = TDigest::merge_digests(&digests);
        let second = TDigest::merge_digests(&digests);

        let mut accumulator =
            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

        accumulator.merge_digests(&[first]);
        assert_eq!(accumulator.digest.count(), 50_000);

        accumulator.merge_digests(&[second]);
        assert_eq!(accumulator.digest.count(), 100_000);
    }
}
562}