1use std::any::Any;
19use std::fmt::Debug;
20use std::mem::size_of_val;
21use std::sync::Arc;
22
23use arrow::array::{Array, Float16Array};
24use arrow::compute::{filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{
27 array::{
28 ArrayRef, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array,
29 Int64Array, UInt8Array, UInt16Array, UInt32Array, UInt64Array,
30 },
31 datatypes::{DataType, Field},
32};
33use datafusion_common::{
34 DataFusionError, Result, ScalarValue, downcast_value, internal_err, not_impl_err,
35 plan_err,
36};
37use datafusion_expr::expr::{AggregateFunction, Sort};
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
40use datafusion_expr::utils::format_state_name;
41use datafusion_expr::{
42 Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
43 Volatility,
44};
45use datafusion_functions_aggregate_common::tdigest::{DEFAULT_MAX_SIZE, TDigest};
46use datafusion_macros::user_doc;
47use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
48
49use crate::utils::{get_scalar_value, validate_percentile_expr};
50
// Wires `ApproxPercentileCont` up through the crate's `create_func!` macro.
// NOTE(review): `approx_percentile_cont_udaf()`, which the expression builder
// in this file calls with no arguments, is presumably generated by this
// invocation — confirm against the macro definition (not visible here).
create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
52
53pub fn approx_percentile_cont(
55 order_by: Sort,
56 percentile: Expr,
57 centroids: Option<Expr>,
58) -> Expr {
59 let expr = order_by.expr.clone();
60
61 let args = if let Some(centroids) = centroids {
62 vec![expr, percentile, centroids]
63 } else {
64 vec![expr, percentile]
65 };
66
67 Expr::AggregateFunction(AggregateFunction::new_udf(
68 approx_percentile_cont_udaf(),
69 args,
70 false,
71 None,
72 vec![order_by],
73 None,
74 ))
75}
76
77#[user_doc(
78 doc_section(label = "Approximate Functions"),
79 description = "Returns the approximate percentile of input values using the t-digest algorithm.",
80 syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
81 sql_example = r#"```sql
82> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
83+------------------------------------------------------------------+
84| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
85+------------------------------------------------------------------+
86| 65.0 |
87+------------------------------------------------------------------+
88> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
89+-----------------------------------------------------------------------+
90| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
91+-----------------------------------------------------------------------+
92| 65.0 |
93+-----------------------------------------------------------------------+
94```
95An alternate syntax is also supported:
96```sql
97> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name;
98+-----------------------------------------------+
99| approx_percentile_cont(column_name, 0.75) |
100+-----------------------------------------------+
101| 65.0 |
102+-----------------------------------------------+
103
104> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name;
105+----------------------------------------------------------+
106| approx_percentile_cont(column_name, 0.75, 100) |
107+----------------------------------------------------------+
108| 65.0 |
109+----------------------------------------------------------+
110```
111"#,
112 standard_argument(name = "expression",),
113 argument(
114 name = "percentile",
115 description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
116 ),
117 argument(
118 name = "centroids",
119 description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
120 )
121)]
122#[derive(Debug, PartialEq, Eq, Hash)]
123pub struct ApproxPercentileCont {
124 signature: Signature,
125}
126
127impl Default for ApproxPercentileCont {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133impl ApproxPercentileCont {
134 pub fn new() -> Self {
136 let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
137 for num in NUMERICS {
139 variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
140 for int in INTEGERS {
142 variants.push(TypeSignature::Exact(vec![
143 num.clone(),
144 DataType::Float64,
145 int.clone(),
146 ]))
147 }
148 }
149 Self {
150 signature: Signature::one_of(variants, Volatility::Immutable),
151 }
152 }
153
154 pub(crate) fn create_accumulator(
155 &self,
156 args: &AccumulatorArgs,
157 ) -> Result<ApproxPercentileAccumulator> {
158 let percentile =
159 validate_percentile_expr(&args.exprs[1], "APPROX_PERCENTILE_CONT")?;
160
161 let is_descending = args
162 .order_bys
163 .first()
164 .map(|sort_expr| sort_expr.options.descending)
165 .unwrap_or(false);
166
167 let percentile = if is_descending {
168 1.0 - percentile
169 } else {
170 percentile
171 };
172
173 let tdigest_max_size = if args.exprs.len() == 3 {
174 Some(validate_input_max_size_expr(&args.exprs[2])?)
175 } else {
176 None
177 };
178
179 let data_type = args.expr_fields[0].data_type();
180 let accumulator: ApproxPercentileAccumulator = match data_type {
181 DataType::UInt8
182 | DataType::UInt16
183 | DataType::UInt32
184 | DataType::UInt64
185 | DataType::Int8
186 | DataType::Int16
187 | DataType::Int32
188 | DataType::Int64
189 | DataType::Float16
190 | DataType::Float32
191 | DataType::Float64 => {
192 if let Some(max_size) = tdigest_max_size {
193 ApproxPercentileAccumulator::new_with_max_size(
194 percentile,
195 data_type.clone(),
196 max_size,
197 )
198 } else {
199 ApproxPercentileAccumulator::new(percentile, data_type.clone())
200 }
201 }
202 other => {
203 return not_impl_err!(
204 "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented"
205 );
206 }
207 };
208
209 Ok(accumulator)
210 }
211}
212
213fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
214 let scalar_value = get_scalar_value(expr).map_err(|_e| {
215 DataFusionError::Plan(
216 "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal"
217 .to_string(),
218 )
219 })?;
220
221 let max_size = match scalar_value {
222 ScalarValue::UInt8(Some(q)) => q as usize,
223 ScalarValue::UInt16(Some(q)) => q as usize,
224 ScalarValue::UInt32(Some(q)) => q as usize,
225 ScalarValue::UInt64(Some(q)) => q as usize,
226 ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
227 ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
228 ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
229 ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
230 sv => {
231 return plan_err!(
232 "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
233 sv.data_type()
234 );
235 }
236 };
237
238 Ok(max_size)
239}
240
241impl AggregateUDFImpl for ApproxPercentileCont {
242 fn as_any(&self) -> &dyn Any {
243 self
244 }
245
246 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
249 Ok(vec![
250 Field::new(
251 format_state_name(args.name, "max_size"),
252 DataType::UInt64,
253 false,
254 ),
255 Field::new(
256 format_state_name(args.name, "sum"),
257 DataType::Float64,
258 false,
259 ),
260 Field::new(
261 format_state_name(args.name, "count"),
262 DataType::UInt64,
263 false,
264 ),
265 Field::new(
266 format_state_name(args.name, "max"),
267 DataType::Float64,
268 false,
269 ),
270 Field::new(
271 format_state_name(args.name, "min"),
272 DataType::Float64,
273 false,
274 ),
275 Field::new_list(
276 format_state_name(args.name, "centroids"),
277 Field::new_list_field(DataType::Float64, true),
278 false,
279 ),
280 ]
281 .into_iter()
282 .map(Arc::new)
283 .collect())
284 }
285
286 fn name(&self) -> &str {
287 "approx_percentile_cont"
288 }
289
290 fn signature(&self) -> &Signature {
291 &self.signature
292 }
293
294 #[inline]
295 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
296 Ok(Box::new(self.create_accumulator(&acc_args)?))
297 }
298
299 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
300 if !arg_types[0].is_numeric() {
301 return plan_err!("approx_percentile_cont requires numeric input types");
302 }
303 if arg_types.len() == 3 && !arg_types[2].is_integer() {
304 return plan_err!(
305 "approx_percentile_cont requires integer centroids input types"
306 );
307 }
308 Ok(arg_types[0].clone())
309 }
310
311 fn supports_within_group_clause(&self) -> bool {
312 true
313 }
314
315 fn documentation(&self) -> Option<&Documentation> {
316 self.doc()
317 }
318}
319
/// Running state of one `approx_percentile_cont` aggregation.
#[derive(Debug)]
pub struct ApproxPercentileAccumulator {
    /// t-digest summarising every value merged into this accumulator so far.
    digest: TDigest,
    /// Quantile passed to `TDigest::estimate_quantile` when evaluating.
    percentile: f64,
    /// Data type of the scalar that `evaluate` produces.
    return_type: DataType,
}
326
327impl ApproxPercentileAccumulator {
328 pub fn new(percentile: f64, return_type: DataType) -> Self {
329 Self {
330 digest: TDigest::new(DEFAULT_MAX_SIZE),
331 percentile,
332 return_type,
333 }
334 }
335
336 pub fn new_with_max_size(
337 percentile: f64,
338 return_type: DataType,
339 max_size: usize,
340 ) -> Self {
341 Self {
342 digest: TDigest::new(max_size),
343 percentile,
344 return_type,
345 }
346 }
347
348 pub(crate) fn max_size(&self) -> usize {
350 self.digest.max_size()
351 }
352
353 pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
355 let digests = digests.iter().chain(std::iter::once(&self.digest));
356 self.digest = TDigest::merge_digests(digests)
357 }
358
359 pub(crate) fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
361 debug_assert!(
362 values.null_count() == 0,
363 "convert_to_float assumes nulls have already been filtered out"
364 );
365 match values.data_type() {
366 DataType::Float64 => {
367 let array = downcast_value!(values, Float64Array);
368 Ok(array.values().iter().copied().collect::<Vec<_>>())
369 }
370 DataType::Float32 => {
371 let array = downcast_value!(values, Float32Array);
372 Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
373 }
374 DataType::Float16 => {
375 let array = downcast_value!(values, Float16Array);
376 Ok(array
377 .values()
378 .iter()
379 .map(|v| v.to_f64())
380 .collect::<Vec<_>>())
381 }
382 DataType::Int64 => {
383 let array = downcast_value!(values, Int64Array);
384 Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
385 }
386 DataType::Int32 => {
387 let array = downcast_value!(values, Int32Array);
388 Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
389 }
390 DataType::Int16 => {
391 let array = downcast_value!(values, Int16Array);
392 Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
393 }
394 DataType::Int8 => {
395 let array = downcast_value!(values, Int8Array);
396 Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
397 }
398 DataType::UInt64 => {
399 let array = downcast_value!(values, UInt64Array);
400 Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
401 }
402 DataType::UInt32 => {
403 let array = downcast_value!(values, UInt32Array);
404 Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
405 }
406 DataType::UInt16 => {
407 let array = downcast_value!(values, UInt16Array);
408 Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
409 }
410 DataType::UInt8 => {
411 let array = downcast_value!(values, UInt8Array);
412 Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
413 }
414 e => internal_err!(
415 "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
416 ),
417 }
418 }
419}
420
421impl Accumulator for ApproxPercentileAccumulator {
422 fn state(&mut self) -> Result<Vec<ScalarValue>> {
423 Ok(self.digest.to_scalar_state().into_iter().collect())
424 }
425
426 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
427 let mut values = Arc::clone(&values[0]);
429 if values.null_count() > 0 {
430 values = filter(&values, &is_not_null(&values)?)?;
431 }
432 let sorted_values = &arrow::compute::sort(&values, None)?;
433 let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
434 self.digest = self.digest.merge_sorted_f64(&sorted_values);
435 Ok(())
436 }
437
438 fn evaluate(&mut self) -> Result<ScalarValue> {
439 if self.digest.count() == 0 {
440 return ScalarValue::try_from(self.return_type.clone());
441 }
442 let q = self.digest.estimate_quantile(self.percentile);
443
444 Ok(match &self.return_type {
447 DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
448 DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
449 DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
450 DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
451 DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
452 DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
453 DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
454 DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
455 DataType::Float16 => ScalarValue::Float16(Some(half::f16::from_f64(q))),
456 DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
457 DataType::Float64 => ScalarValue::Float64(Some(q)),
458 v => unreachable!("unexpected return type {}", v),
459 })
460 }
461
462 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
463 if states.is_empty() {
464 return Ok(());
465 }
466
467 let states = (0..states[0].len())
468 .map(|index| {
469 states
470 .iter()
471 .map(|array| ScalarValue::try_from_array(array, index))
472 .collect::<Result<Vec<_>>>()
473 .map(|state| TDigest::from_scalar_state(&state))
474 })
475 .collect::<Result<Vec<_>>>()?;
476
477 self.merge_digests(&states);
478
479 Ok(())
480 }
481
482 fn size(&self) -> usize {
483 size_of_val(self) + self.digest.size() - size_of_val(&self.digest)
484 + self.return_type.size()
485 - size_of_val(&self.return_type)
486 }
487}
488
#[cfg(test)]
mod tests {
    use arrow::datatypes::DataType;

    use datafusion_functions_aggregate_common::tdigest::TDigest;

    use crate::approx_percentile_cont::ApproxPercentileAccumulator;

    /// Merging pre-built digests into an accumulator adds up their counts.
    #[test]
    fn test_combine_approx_percentile_accumulator() {
        // Fifty digests, each fed the values 1..=1000.
        let digests: Vec<TDigest> = (0..50)
            .map(|_| {
                let values: Vec<f64> = (1..=1_000).map(f64::from).collect();
                TDigest::new(100).merge_unsorted_f64(values)
            })
            .collect();

        let first = TDigest::merge_digests(&digests);
        let second = TDigest::merge_digests(&digests);

        let mut accumulator =
            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

        accumulator.merge_digests(&[first]);
        assert_eq!(accumulator.digest.count(), 50_000);
        accumulator.merge_digests(&[second]);
        assert_eq!(accumulator.digest.count(), 100_000);
    }
}