use std::fmt::{Debug, Formatter};
use std::mem::{size_of, size_of_val};
use std::sync::Arc;

use arrow::array::{
    ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder,
};
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::{
    array::{Array, ArrayRef, AsArray},
    datatypes::{
        ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type,
        Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type,
    },
};

use arrow::array::ArrowNativeTypeOp;

use datafusion_common::{
    internal_datafusion_err, internal_err, plan_err, DataFusionError, HashSet, Result,
    ScalarValue,
};
use datafusion_expr::expr::{AggregateFunction, Sort};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
    Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
    Volatility,
};
use datafusion_expr::{EmitTo, GroupsAccumulator};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datafusion_functions_aggregate_common::utils::Hashable;
use datafusion_macros::user_doc;

use crate::utils::validate_percentile_expr;
const INTERPOLATION_PRECISION: usize = 1_000_000;

create_func!(PercentileCont, percentile_cont_udaf);
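
/// Creates a `percentile_cont` aggregate expression, i.e.
/// `percentile_cont(percentile) WITHIN GROUP (ORDER BY order_by)`.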
pub fn percentile_cont(order_by: Sort, percentile: Expr) -> Expr {
    let expr = order_by.expr.clone();
    let args = vec![expr, percentile];

    Expr::AggregateFunction(AggregateFunction::new_udf(
        percentile_cont_udaf(),
        args,
        false,
        None,
        vec![order_by],
        None,
    ))
}

#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the exact percentile of input values, interpolating between values if needed.",
    syntax_example = "percentile_cont(percentile) WITHIN GROUP (ORDER BY expression)",
    sql_example = r#"```sql
> SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+-----------------------------------------------------------+
| percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
+-----------------------------------------------------------+
| 45.5                                                       |
+-----------------------------------------------------------+
```

An alternate syntax is also supported:
```sql
> SELECT percentile_cont(column_name, 0.75) FROM table_name;
+------------------------------------+
| percentile_cont(column_name, 0.75) |
+------------------------------------+
| 45.5                               |
+------------------------------------+
```"#,
    standard_argument(name = "expression", prefix = "The"),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    )
)]
#[derive(PartialEq, Eq, Hash)]
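/// The `PERCENTILE_CONT` aggregate function: returns the value at the requested
/// percentile of the ordered input, linearly interpolating between adjacent
/// input values when the percentile falls between two of them.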
pub struct PercentileCont {
    signature: Signature,
    aliases: Vec<String>,
}

impl Debug for PercentileCont {
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        f.debug_struct("PercentileCont")
            .field("name", &self.name())
            .field("signature", &self.signature)
            .finish()
    }
}

impl Default for PercentileCont {
    fn default() -> Self {
        Self::new()
    }
}

impl PercentileCont {
    pub fn new() -> Self {
        let mut variants = Vec::with_capacity(NUMERICS.len());
        // Accept any numeric value type; the percentile argument is always Float64.
        for num in NUMERICS {
            variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
        }
        Self {
            signature: Signature::one_of(variants, Volatility::Immutable)
                .with_parameter_names(vec!["expr".to_string(), "percentile".to_string()])
                .expect("valid parameter names for percentile_cont"),
            aliases: vec![String::from("quantile_cont")],
        }
    }

    fn create_accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?;

        let is_descending = args
            .order_bys
            .first()
            .map(|sort_expr| sort_expr.options.descending)
            .unwrap_or(false);
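
        // For a descending ORDER BY, the p-th percentile of the descending order
        // equals the (1 - p)-th percentile of the ascending order.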
        let percentile = if is_descending {
            1.0 - percentile
        } else {
            percentile
        };

        macro_rules! helper {
            ($t:ty, $dt:expr) => {
                if args.is_distinct {
                    Ok(Box::new(DistinctPercentileContAccumulator::<$t> {
                        data_type: $dt.clone(),
                        distinct_values: HashSet::new(),
                        percentile,
                    }))
                } else {
                    Ok(Box::new(PercentileContAccumulator::<$t> {
                        data_type: $dt.clone(),
                        all_values: vec![],
                        percentile,
                    }))
                }
            };
        }
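
        // Integer inputs are accumulated as Float64 so interpolation can produce
        // fractional results; float and decimal inputs keep their native type.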
        let input_dt = args.exprs[0].data_type(args.schema)?;
        match input_dt {
            DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64 => helper!(Float64Type, DataType::Float64),
            DataType::Float16 => helper!(Float16Type, input_dt),
            DataType::Float32 => helper!(Float32Type, input_dt),
            DataType::Float64 => helper!(Float64Type, input_dt),
            DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt),
            DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt),
            DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt),
            DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt),
            _ => Err(DataFusionError::NotImplemented(format!(
                "PercentileContAccumulator not supported for {} with {}",
                args.name, input_dt,
            ))),
        }
    }
}

impl AggregateUDFImpl for PercentileCont {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "percentile_cont"
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        if !arg_types[0].is_numeric() {
            return plan_err!("percentile_cont requires numeric input types");
        }
        match &arg_types[0] {
            // Floats and decimals are returned as-is; integers are promoted to
            // Float64 so the interpolated result can be fractional.
            DataType::Float16 | DataType::Float32 | DataType::Float64 => {
                Ok(arg_types[0].clone())
            }
            DataType::Decimal32(_, _)
            | DataType::Decimal64(_, _)
            | DataType::Decimal128(_, _)
            | DataType::Decimal256(_, _) => Ok(arg_types[0].clone()),
            DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            | DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64 => Ok(DataType::Float64),
            dt => plan_err!(
                "percentile_cont does not support input type {}, must be numeric",
                dt
            ),
        }
    }
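
    // The intermediate (partial aggregation) state is a single nullable List
    // column holding every value buffered so far; integer inputs are stored as
    // Float64, matching the accumulators' working type.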
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let input_type = args.input_fields[0].data_type().clone();
        let storage_type = match &input_type {
            DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64 => DataType::Float64,
            _ => input_type,
        };

        let field = Field::new_list_field(storage_type, true);
        let state_name = if args.is_distinct {
            "distinct_percentile_cont"
        } else {
            "percentile_cont"
        };

        Ok(vec![Field::new(
            format_state_name(args.name, state_name),
            DataType::List(Arc::new(field)),
            true,
        )
        .into()])
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        self.create_accumulator(acc_args)
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        !args.is_distinct
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        let num_args = args.exprs.len();
        if num_args != 2 {
            return internal_err!(
                "percentile_cont should have 2 args, but found {}",
                num_args
            );
        }

        let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?;

        let is_descending = args
            .order_bys
            .first()
            .map(|sort_expr| sort_expr.options.descending)
            .unwrap_or(false);
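
        // As in `create_accumulator`, a descending sort maps percentile p to
        // 1 - p of the ascending order.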
        let percentile = if is_descending {
            1.0 - percentile
        } else {
            percentile
        };

        macro_rules! helper {
            ($t:ty, $dt:expr) => {
                Ok(Box::new(PercentileContGroupsAccumulator::<$t>::new(
                    $dt, percentile,
                )))
            };
        }
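
        // Integer inputs are widened to Float64; floats and decimals keep their
        // native type.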
        let input_dt = args.exprs[0].data_type(args.schema)?;
        match input_dt {
            DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64 => helper!(Float64Type, DataType::Float64),
            DataType::Float16 => helper!(Float16Type, input_dt),
            DataType::Float32 => helper!(Float32Type, input_dt),
            DataType::Float64 => helper!(Float64Type, input_dt),
            DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt),
            DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt),
            DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt),
            DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt),
            _ => Err(DataFusionError::NotImplemented(format!(
                "PercentileContGroupsAccumulator not supported for {} with {}",
                args.name, input_dt,
            ))),
        }
    }

    fn supports_null_handling_clause(&self) -> bool {
        false
    }

    fn supports_within_group_clause(&self) -> bool {
        true
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
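
/// Accumulator for `percentile_cont` that buffers every non-null input value and
/// computes the interpolated percentile once, when the final value is requested.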
struct PercentileContAccumulator<T: ArrowNumericType> {
    data_type: DataType,
    all_values: Vec<T::Native>,
    percentile: f64,
}

impl<T: ArrowNumericType> Debug for PercentileContAccumulator<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "PercentileContAccumulator({}, percentile={})",
            self.data_type, self.percentile
        )
    }
}

impl<T: ArrowNumericType> Accumulator for PercentileContAccumulator<T> {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        // Pack all buffered values into a single-element ListArray: the partial
        // state is one list row containing every value seen so far.
        let offsets =
            OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));

        let values_array = PrimitiveArray::<T>::new(
            ScalarBuffer::from(std::mem::take(&mut self.all_values)),
            None,
        )
        .with_data_type(self.data_type.clone());

        let list_array = ListArray::new(
            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
            offsets,
            Arc::new(values_array),
            None,
        );

        Ok(vec![ScalarValue::List(Arc::new(list_array))])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // Cast the input to the accumulator's working type if needed
        // (e.g. integer inputs are accumulated as Float64).
        let values = if values[0].data_type() != &self.data_type {
            arrow::compute::cast(&values[0], &self.data_type)?
        } else {
            Arc::clone(&values[0])
        };

        let values = values.as_primitive::<T>();
        self.all_values.reserve(values.len() - values.null_count());
        self.all_values.extend(values.iter().flatten());
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let array = states[0].as_list::<i32>();
        for v in array.iter().flatten() {
            self.update_batch(&[v])?
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let d = std::mem::take(&mut self.all_values);
        let value = calculate_percentile::<T>(d, self.percentile);
        ScalarValue::new_primitive::<T>(value, &self.data_type)
    }

    fn size(&self) -> usize {
        size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
    }
}
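
/// `GroupsAccumulator` for `percentile_cont`: keeps one value buffer per group
/// and interpolates each group's percentile when results are emitted.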
#[derive(Debug)]
struct PercentileContGroupsAccumulator<T: ArrowNumericType + Send> {
    data_type: DataType,
    group_values: Vec<Vec<T::Native>>,
    percentile: f64,
}

impl<T: ArrowNumericType + Send> PercentileContGroupsAccumulator<T> {
    pub fn new(data_type: DataType, percentile: f64) -> Self {
        Self {
            data_type,
            group_values: Vec::new(),
            percentile,
        }
    }
}

impl<T: ArrowNumericType + Send> GroupsAccumulator
    for PercentileContGroupsAccumulator<T>
{
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        // Cast the input to the accumulator's working type if needed
        // (e.g. integer inputs are accumulated as Float64).
        let values_array = if values[0].data_type() != &self.data_type {
            arrow::compute::cast(&values[0], &self.data_type)?
        } else {
            Arc::clone(&values[0])
        };

        let values = values_array.as_primitive::<T>();

        // Ensure there is a buffer for every group, then append each non-null,
        // unfiltered value to its group's buffer.
        self.group_values.resize(total_num_groups, Vec::new());
        accumulate(
            group_indices,
            values,
            opt_filter,
            |group_index, new_value| {
                self.group_values[group_index].push(new_value);
            },
        );

        Ok(())
    }
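
    // Merges partial states: each incoming row is a list of values that gets
    // appended to the corresponding group's buffer.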
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        _opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "one argument to merge_batch");

        let input_group_values = values[0].as_list::<i32>();

        self.group_values.resize(total_num_groups, Vec::new());

        group_indices
            .iter()
            .zip(input_group_values.iter())
            .for_each(|(&group_index, values_opt)| {
                if let Some(values) = values_opt {
                    let values = values.as_primitive::<T>();
                    self.group_values[group_index].extend(values.values().iter());
                }
            });

        Ok(())
    }
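
    // Emits the partial state: one list row per emitted group containing that
    // group's buffered values.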
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let emit_group_values = emit_to.take_needed(&mut self.group_values);

        // Build the list offsets from the per-group value counts.
        let mut offsets = Vec::with_capacity(emit_group_values.len() + 1);
        offsets.push(0);
        let mut cur_len = 0_i32;
        for group_value in &emit_group_values {
            cur_len += group_value.len() as i32;
            offsets.push(cur_len);
        }
        let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));

        let flatten_group_values =
            emit_group_values.into_iter().flatten().collect::<Vec<_>>();
        let group_values_array =
            PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
                .with_data_type(self.data_type.clone());

        let result_list_array = ListArray::new(
            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
            offsets,
            Arc::new(group_values_array),
            None,
        );

        Ok(vec![Arc::new(result_list_array)])
    }
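
    // Computes the final interpolated percentile for each emitted group.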
    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        let emit_group_values = emit_to.take_needed(&mut self.group_values);

        let mut evaluate_result_builder =
            PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
        for values in emit_group_values {
            let value = calculate_percentile::<T>(values, self.percentile);
            evaluate_result_builder.append_option(value);
        }

        Ok(Arc::new(evaluate_result_builder.finish()))
    }
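
    // Converts a batch of raw input values directly into the partial-state
    // layout: each input row becomes a single-element list.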
    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> Result<Vec<ArrayRef>> {
        assert_eq!(values.len(), 1, "one argument to convert_to_state");

        let values_array = if values[0].data_type() != &self.data_type {
            arrow::compute::cast(&values[0], &self.data_type)?
        } else {
            Arc::clone(&values[0])
        };

        let input_array = values_array.as_primitive::<T>();

        let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
            .with_data_type(self.data_type.clone());

        // Offsets 0, 1, 2, ..., len wrap each input value in its own
        // single-element list.
        let offset_end = i32::try_from(input_array.len()).map_err(|e| {
            internal_datafusion_err!(
                "cast array_len to i32 failed in convert_to_state of group percentile_cont, err:{e:?}"
            )
        })?;
        let offsets = (0..=offset_end).collect::<Vec<_>>();
        // SAFETY: the offsets are monotonically increasing and bounded by the
        // values array length, so they form a valid offset buffer.
        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };

        // Nulls and filtered-out rows become null list entries.
        let nulls = filtered_null_mask(opt_filter, input_array);

        let converted_list_array = ListArray::new(
            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
            offsets,
            Arc::new(values),
            nulls,
        );

        Ok(vec![Arc::new(converted_list_array)])
    }

    fn supports_convert_to_state(&self) -> bool {
        true
    }

    fn size(&self) -> usize {
        self.group_values
            .iter()
            .map(|values| values.capacity() * size_of::<T::Native>())
            .sum::<usize>()
            + self.group_values.capacity() * size_of::<Vec<T::Native>>()
    }
}
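
/// Accumulator for `percentile_cont(DISTINCT ...)`: de-duplicates input values in
/// a `HashSet` before computing the interpolated percentile.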
struct DistinctPercentileContAccumulator<T: ArrowNumericType> {
    data_type: DataType,
    distinct_values: HashSet<Hashable<T::Native>>,
    percentile: f64,
}

impl<T: ArrowNumericType> Debug for DistinctPercentileContAccumulator<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "DistinctPercentileContAccumulator({}, percentile={})",
            self.data_type, self.percentile
        )
    }
}

impl<T: ArrowNumericType> Accumulator for DistinctPercentileContAccumulator<T> {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let all_values = self
            .distinct_values
            .iter()
            .map(|x| ScalarValue::new_primitive::<T>(Some(x.0), &self.data_type))
            .collect::<Result<Vec<_>>>()?;

        let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type);
        Ok(vec![ScalarValue::List(arr)])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let values = if values[0].data_type() != &self.data_type {
            arrow::compute::cast(&values[0], &self.data_type)?
        } else {
            Arc::clone(&values[0])
        };

        let array = values.as_primitive::<T>();
        match array.nulls().filter(|x| x.null_count() > 0) {
            Some(n) => {
                for idx in n.valid_indices() {
                    self.distinct_values.insert(Hashable(array.value(idx)));
                }
            }
            None => array.values().iter().for_each(|x| {
                self.distinct_values.insert(Hashable(*x));
            }),
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let array = states[0].as_list::<i32>();
        for v in array.iter().flatten() {
            self.update_batch(&[v])?
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let d = std::mem::take(&mut self.distinct_values)
            .into_iter()
            .map(|v| v.0)
            .collect::<Vec<_>>();
        let value = calculate_percentile::<T>(d, self.percentile);
        ScalarValue::new_primitive::<T>(value, &self.data_type)
    }

    fn size(&self) -> usize {
        size_of_val(self) + self.distinct_values.capacity() * size_of::<T::Native>()
    }
}
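
/// Computes the continuous percentile of the (unsorted) `values`.
///
/// The target rank is `percentile * (len - 1)`; when it falls between two ranks
/// the result is a linear interpolation between the values at the surrounding
/// ranks, which are found with `select_nth_unstable_by` instead of a full sort.
/// Returns `None` for empty input.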
fn calculate_percentile<T: ArrowNumericType>(
    mut values: Vec<T::Native>,
    percentile: f64,
) -> Option<T::Native> {
    let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);

    let len = values.len();
    if len == 0 {
        None
    } else if len == 1 {
        Some(values[0])
    } else if percentile == 0.0 {
        // Percentile 0 is simply the minimum value.
        Some(
            *values
                .iter()
                .min_by(|a, b| cmp(a, b))
                .expect("we checked for len > 0 a few lines above"),
        )
    } else if percentile == 1.0 {
        // Percentile 1 is simply the maximum value.
        Some(
            *values
                .iter()
                .max_by(|a, b| cmp(a, b))
                .expect("we checked for len > 0 a few lines above"),
        )
    } else {
        let index = percentile * ((len - 1) as f64);
        let lower_index = index.floor() as usize;
        let upper_index = index.ceil() as usize;

        if lower_index == upper_index {
            // The target rank is exact; select that element directly.
            let (_, value, _) = values.select_nth_unstable_by(lower_index, cmp);
            Some(*value)
        } else {
            // Select the two elements surrounding the target rank.
            let (_, lower_value, _) = values.select_nth_unstable_by(lower_index, cmp);
            let lower_value = *lower_value;

            let (_, upper_value, _) = values.select_nth_unstable_by(upper_index, cmp);
            let upper_value = *upper_value;

            // Linear interpolation: lower + (upper - lower) * fraction. The
            // fraction is applied as an integer ratio with denominator
            // INTERPOLATION_PRECISION, which keeps the arithmetic in the value's
            // native type (needed for decimals; a close approximation for floats).
            let fraction = index - (lower_index as f64);
            let diff = upper_value.sub_wrapping(lower_value);
            let interpolated = lower_value.add_wrapping(
                diff.mul_wrapping(T::Native::usize_as(
                    (fraction * INTERPOLATION_PRECISION as f64) as usize,
                ))
                .div_wrapping(T::Native::usize_as(INTERPOLATION_PRECISION)),
            );
            Some(interpolated)
        }
    }
}