1use std::collections::HashMap;
19use std::fmt::Debug;
20use std::mem::{size_of, size_of_val};
21use std::sync::Arc;
22
23use arrow::array::{
24 ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder,
25};
26use arrow::buffer::{OffsetBuffer, ScalarBuffer};
27use arrow::{
28 array::{Array, ArrayRef, AsArray},
29 datatypes::{DataType, Field, FieldRef, Float16Type, Float32Type, Float64Type},
30};
31
32use num_traits::AsPrimitive;
33
34use arrow::array::ArrowNativeTypeOp;
35use datafusion_common::internal_err;
36use datafusion_common::types::{NativeType, logical_float64};
37use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator;
38
39use crate::min_max::{max_udaf, min_udaf};
40use datafusion_common::{
41 Result, ScalarValue, internal_datafusion_err, utils::take_function_args,
42};
43use datafusion_expr::utils::format_state_name;
44use datafusion_expr::{
45 Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, Signature,
46 TypeSignatureClass, Volatility,
47};
48use datafusion_expr::{EmitTo, GroupsAccumulator};
49use datafusion_expr::{
50 expr::{AggregateFunction, Sort},
51 function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
52 simplify::SimplifyContext,
53};
54use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
55use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
56use datafusion_functions_aggregate_common::utils::{GenericDistinctBuffer, Hashable};
57use datafusion_macros::user_doc;
58
59use crate::utils::validate_percentile_expr;
60
/// Number of discrete steps used to quantize the interpolation weight in
/// `calculate_percentile`.
///
/// The fractional part of the target index is truncated to a multiple of
/// `1 / INTERPOLATION_PRECISION` (six decimal places) before it is used as
/// the weight between the two neighbouring values.
const INTERPOLATION_PRECISION: f64 = 1e6;
75
76create_func!(PercentileCont, percentile_cont_udaf);
77
78pub fn percentile_cont(order_by: Sort, percentile: Expr) -> Expr {
80 let expr = order_by.expr.clone();
81 let args = vec![expr, percentile];
82
83 Expr::AggregateFunction(AggregateFunction::new_udf(
84 percentile_cont_udaf(),
85 args,
86 false,
87 None,
88 vec![order_by],
89 None,
90 ))
91}
92
93#[user_doc(
94 doc_section(label = "General Functions"),
95 description = "Returns the exact percentile of input values, interpolating between values if needed.",
96 syntax_example = "percentile_cont(percentile) WITHIN GROUP (ORDER BY expression)",
97 sql_example = r#"```sql
98> SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
99+----------------------------------------------------------+
100| percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
101+----------------------------------------------------------+
102| 45.5 |
103+----------------------------------------------------------+
104```
105
106An alternate syntax is also supported:
107```sql
108> SELECT percentile_cont(column_name, 0.75) FROM table_name;
109+---------------------------------------+
110| percentile_cont(column_name, 0.75) |
111+---------------------------------------+
112| 45.5 |
113+---------------------------------------+
114```"#,
115 standard_argument(name = "expression", prefix = "The"),
116 argument(
117 name = "percentile",
118 description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
119 )
120)]
121#[derive(PartialEq, Eq, Hash, Debug)]
129pub struct PercentileCont {
130 signature: Signature,
131 aliases: Vec<String>,
132}
133
134impl Default for PercentileCont {
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
140impl PercentileCont {
141 pub fn new() -> Self {
142 Self {
143 signature: Signature::coercible(
144 vec![
145 Coercion::new_implicit(
146 TypeSignatureClass::Float,
147 vec![TypeSignatureClass::Numeric],
148 NativeType::Float64,
149 ),
150 Coercion::new_implicit(
151 TypeSignatureClass::Native(logical_float64()),
152 vec![TypeSignatureClass::Numeric],
153 NativeType::Float64,
154 ),
155 ],
156 Volatility::Immutable,
157 )
158 .with_parameter_names(vec!["expr", "percentile"])
159 .unwrap(),
160 aliases: vec![String::from("quantile_cont")],
161 }
162 }
163}
164
165impl AggregateUDFImpl for PercentileCont {
166 fn as_any(&self) -> &dyn std::any::Any {
167 self
168 }
169
170 fn name(&self) -> &str {
171 "percentile_cont"
172 }
173
174 fn aliases(&self) -> &[String] {
175 &self.aliases
176 }
177
178 fn signature(&self) -> &Signature {
179 &self.signature
180 }
181
182 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
183 match &arg_types[0] {
184 DataType::Null => Ok(DataType::Float64),
185 dt => Ok(dt.clone()),
186 }
187 }
188
189 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
190 let input_type = args.input_fields[0].data_type().clone();
191 if input_type.is_null() {
192 return Ok(vec![
193 Field::new(
194 format_state_name(args.name, self.name()),
195 DataType::Null,
196 true,
197 )
198 .into(),
199 ]);
200 }
201
202 let field = Field::new_list_field(input_type, true);
203 let state_name = if args.is_distinct {
204 "distinct_percentile_cont"
205 } else {
206 "percentile_cont"
207 };
208
209 Ok(vec![
210 Field::new(
211 format_state_name(args.name, state_name),
212 DataType::List(Arc::new(field)),
213 true,
214 )
215 .into(),
216 ])
217 }
218
219 fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
220 let percentile = get_percentile(&args)?;
221
222 let input_dt = args.expr_fields[0].data_type();
223 if input_dt.is_null() {
224 return Ok(Box::new(NoopAccumulator::new(ScalarValue::Float64(None))));
225 }
226
227 if args.is_distinct {
228 match input_dt {
229 DataType::Float16 => Ok(Box::new(DistinctPercentileContAccumulator::<
230 Float16Type,
231 >::new(percentile))),
232 DataType::Float32 => Ok(Box::new(DistinctPercentileContAccumulator::<
233 Float32Type,
234 >::new(percentile))),
235 DataType::Float64 => Ok(Box::new(DistinctPercentileContAccumulator::<
236 Float64Type,
237 >::new(percentile))),
238 dt => internal_err!("Unsupported datatype for percentile cont: {dt}"),
239 }
240 } else {
241 match input_dt {
242 DataType::Float16 => Ok(Box::new(
243 PercentileContAccumulator::<Float16Type>::new(percentile),
244 )),
245 DataType::Float32 => Ok(Box::new(
246 PercentileContAccumulator::<Float32Type>::new(percentile),
247 )),
248 DataType::Float64 => Ok(Box::new(
249 PercentileContAccumulator::<Float64Type>::new(percentile),
250 )),
251 dt => internal_err!("Unsupported datatype for percentile cont: {dt}"),
252 }
253 }
254 }
255
256 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
257 !args.is_distinct && !args.expr_fields[0].data_type().is_null()
258 }
259
260 fn create_groups_accumulator(
261 &self,
262 args: AccumulatorArgs,
263 ) -> Result<Box<dyn GroupsAccumulator>> {
264 let percentile = get_percentile(&args)?;
265
266 let input_dt = args.expr_fields[0].data_type();
267 match input_dt {
268 DataType::Float16 => Ok(Box::new(PercentileContGroupsAccumulator::<
269 Float16Type,
270 >::new(percentile))),
271 DataType::Float32 => Ok(Box::new(PercentileContGroupsAccumulator::<
272 Float32Type,
273 >::new(percentile))),
274 DataType::Float64 => Ok(Box::new(PercentileContGroupsAccumulator::<
275 Float64Type,
276 >::new(percentile))),
277 dt => internal_err!("Unsupported datatype for percentile cont: {dt}"),
278 }
279 }
280
281 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
282 Some(Box::new(|aggregate_function, info| {
283 simplify_percentile_cont_aggregate(aggregate_function, info)
284 }))
285 }
286
287 fn supports_within_group_clause(&self) -> bool {
288 true
289 }
290
291 fn documentation(&self) -> Option<&Documentation> {
292 self.doc()
293 }
294}
295
296fn get_percentile(args: &AccumulatorArgs) -> Result<f64> {
297 let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?;
298
299 let is_descending = args
300 .order_bys
301 .first()
302 .map(|sort_expr| sort_expr.options.descending)
303 .unwrap_or(false);
304
305 let percentile = if is_descending {
306 1.0 - percentile
307 } else {
308 percentile
309 };
310
311 Ok(percentile)
312}
313
314fn simplify_percentile_cont_aggregate(
315 aggregate_function: AggregateFunction,
316 info: &SimplifyContext,
317) -> Result<Expr> {
318 enum PercentileRewriteTarget {
319 Min,
320 Max,
321 }
322
323 let params = &aggregate_function.params;
324 let [value, percentile] = take_function_args("percentile_cont", ¶ms.args)?;
325 let input_type = info.get_data_type(value)?;
329 if input_type.is_null() {
330 return Ok(Expr::AggregateFunction(aggregate_function));
331 }
332
333 let is_descending = params
334 .order_by
335 .first()
336 .map(|sort| !sort.asc)
337 .unwrap_or(false);
338
339 let rewrite_target = match percentile {
340 Expr::Literal(ScalarValue::Float64(Some(0.0)), _) => {
341 if is_descending {
342 PercentileRewriteTarget::Max
343 } else {
344 PercentileRewriteTarget::Min
345 }
346 }
347 Expr::Literal(ScalarValue::Float64(Some(1.0)), _) => {
348 if is_descending {
349 PercentileRewriteTarget::Min
350 } else {
351 PercentileRewriteTarget::Max
352 }
353 }
354 _ => return Ok(Expr::AggregateFunction(aggregate_function)),
355 };
356
357 let udaf = match rewrite_target {
358 PercentileRewriteTarget::Min => min_udaf(),
359 PercentileRewriteTarget::Max => max_udaf(),
360 };
361
362 let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
363 udaf,
364 vec![value.clone()],
365 params.distinct,
366 params.filter.clone(),
367 vec![],
368 params.null_treatment,
369 ));
370 Ok(rewritten)
371}
372
373#[derive(Debug)]
381struct PercentileContAccumulator<T: ArrowNumericType + Debug> {
382 all_values: Vec<T::Native>,
383 percentile: f64,
384}
385
386impl<T: ArrowNumericType + Debug> PercentileContAccumulator<T> {
387 fn new(percentile: f64) -> Self {
388 Self {
389 all_values: vec![],
390 percentile,
391 }
392 }
393}
394
395impl<T> Accumulator for PercentileContAccumulator<T>
396where
397 T: ArrowNumericType + Debug,
398 T::Native: Copy + AsPrimitive<f64>,
399 f64: AsPrimitive<T::Native>,
400{
401 fn state(&mut self) -> Result<Vec<ScalarValue>> {
402 let offsets =
406 OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
407
408 let values_array = PrimitiveArray::<T>::new(
410 ScalarBuffer::from(std::mem::take(&mut self.all_values)),
411 None,
412 );
413
414 let list_array = ListArray::new(
416 Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
417 offsets,
418 Arc::new(values_array),
419 None,
420 );
421
422 Ok(vec![ScalarValue::List(Arc::new(list_array))])
423 }
424
425 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
426 let values = values[0].as_primitive::<T>();
427 self.all_values.reserve(values.len() - values.null_count());
428 self.all_values.extend(values.iter().flatten());
429 Ok(())
430 }
431
432 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
433 let array = states[0].as_list::<i32>();
434 self.update_batch(&[array.value(0)])?;
435 Ok(())
436 }
437
438 fn evaluate(&mut self) -> Result<ScalarValue> {
439 let value = calculate_percentile::<T>(&mut self.all_values, self.percentile);
440 ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
441 }
442
443 fn size(&self) -> usize {
444 size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
445 }
446
447 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
448 let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
449 for i in 0..values[0].len() {
450 let v = ScalarValue::try_from_array(&values[0], i)?;
451 if !v.is_null() {
452 *to_remove.entry(v).or_default() += 1;
453 }
454 }
455
456 let mut i = 0;
457 while i < self.all_values.len() {
458 let k =
459 ScalarValue::new_primitive::<T>(Some(self.all_values[i]), &T::DATA_TYPE)?;
460 if let Some(count) = to_remove.get_mut(&k)
461 && *count > 0
462 {
463 self.all_values.swap_remove(i);
464 *count -= 1;
465 if *count == 0 {
466 to_remove.remove(&k);
467 if to_remove.is_empty() {
468 break;
469 }
470 }
471 } else {
472 i += 1;
473 }
474 }
475 Ok(())
476 }
477
478 fn supports_retract_batch(&self) -> bool {
479 true
480 }
481}
482
483#[derive(Debug)]
490struct PercentileContGroupsAccumulator<T: ArrowNumericType + Send> {
491 group_values: Vec<Vec<T::Native>>,
492 percentile: f64,
493}
494
495impl<T: ArrowNumericType + Send> PercentileContGroupsAccumulator<T> {
496 fn new(percentile: f64) -> Self {
497 Self {
498 group_values: vec![],
499 percentile,
500 }
501 }
502}
503
504impl<T> GroupsAccumulator for PercentileContGroupsAccumulator<T>
505where
506 T: ArrowNumericType + Send,
507 T::Native: Copy + AsPrimitive<f64>,
508 f64: AsPrimitive<T::Native>,
509{
510 fn update_batch(
511 &mut self,
512 values: &[ArrayRef],
513 group_indices: &[usize],
514 opt_filter: Option<&BooleanArray>,
515 total_num_groups: usize,
516 ) -> Result<()> {
517 let values = values[0].as_primitive::<T>();
521
522 self.group_values.resize(total_num_groups, Vec::new());
524 accumulate(
525 group_indices,
526 values,
527 opt_filter,
528 |group_index, new_value| {
529 self.group_values[group_index].push(new_value);
530 },
531 );
532
533 Ok(())
534 }
535
536 fn merge_batch(
537 &mut self,
538 values: &[ArrayRef],
539 group_indices: &[usize],
540 _opt_filter: Option<&BooleanArray>,
542 total_num_groups: usize,
543 ) -> Result<()> {
544 assert_eq!(values.len(), 1, "one argument to merge_batch");
545
546 let input_group_values = values[0].as_list::<i32>();
547
548 self.group_values.resize(total_num_groups, Vec::new());
550
551 group_indices
553 .iter()
554 .zip(input_group_values.iter())
555 .for_each(|(&group_index, values_opt)| {
556 if let Some(values) = values_opt {
557 let values = values.as_primitive::<T>();
558 self.group_values[group_index].extend(values.values().iter());
559 }
560 });
561
562 Ok(())
563 }
564
565 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
566 let emit_group_values = emit_to.take_needed(&mut self.group_values);
568
569 let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
571 offsets.push(0);
572 let mut cur_len = 0_i32;
573 for group_value in &emit_group_values {
574 cur_len += group_value.len() as i32;
575 offsets.push(cur_len);
576 }
577 let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
578
579 let flatten_group_values =
581 emit_group_values.into_iter().flatten().collect::<Vec<_>>();
582 let group_values_array =
583 PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None);
584
585 let result_list_array = ListArray::new(
587 Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
588 offsets,
589 Arc::new(group_values_array),
590 None,
591 );
592
593 Ok(vec![Arc::new(result_list_array)])
594 }
595
596 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
597 let mut emit_group_values = emit_to.take_needed(&mut self.group_values);
599
600 let mut evaluate_result_builder =
602 PrimitiveBuilder::<T>::with_capacity(emit_group_values.len());
603 for values in &mut emit_group_values {
604 let value = calculate_percentile::<T>(values.as_mut_slice(), self.percentile);
605 evaluate_result_builder.append_option(value);
606 }
607
608 Ok(Arc::new(evaluate_result_builder.finish()))
609 }
610
611 fn convert_to_state(
612 &self,
613 values: &[ArrayRef],
614 opt_filter: Option<&BooleanArray>,
615 ) -> Result<Vec<ArrayRef>> {
616 assert_eq!(values.len(), 1, "one argument to merge_batch");
617
618 let input_array = values[0].as_primitive::<T>();
619
620 let values = PrimitiveArray::<T>::new(input_array.values().clone(), None);
629
630 let offset_end = i32::try_from(input_array.len()).map_err(|e| {
632 internal_datafusion_err!(
633 "cast array_len to i32 failed in convert_to_state of group percentile_cont, err:{e:?}"
634 )
635 })?;
636 let offsets = (0..=offset_end).collect::<Vec<_>>();
637 let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
644
645 let nulls = filtered_null_mask(opt_filter, input_array);
647
648 let converted_list_array = ListArray::new(
649 Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
650 offsets,
651 Arc::new(values),
652 nulls,
653 );
654
655 Ok(vec![Arc::new(converted_list_array)])
656 }
657
658 fn supports_convert_to_state(&self) -> bool {
659 true
660 }
661
662 fn size(&self) -> usize {
663 self.group_values
664 .iter()
665 .map(|values| values.capacity() * size_of::<T::Native>())
666 .sum::<usize>()
667 + self.group_values.capacity() * size_of::<Vec<T::Native>>()
669 }
670}
671
672#[derive(Debug)]
673struct DistinctPercentileContAccumulator<T: ArrowNumericType> {
674 distinct_values: GenericDistinctBuffer<T>,
675 percentile: f64,
676}
677
678impl<T: ArrowNumericType + Debug> DistinctPercentileContAccumulator<T> {
679 fn new(percentile: f64) -> Self {
680 Self {
681 distinct_values: GenericDistinctBuffer::new(T::DATA_TYPE),
682 percentile,
683 }
684 }
685}
686
687impl<T> Accumulator for DistinctPercentileContAccumulator<T>
688where
689 T: ArrowNumericType + Debug,
690 T::Native: Copy + AsPrimitive<f64>,
691 f64: AsPrimitive<T::Native>,
692{
693 fn state(&mut self) -> Result<Vec<ScalarValue>> {
694 self.distinct_values.state()
695 }
696
697 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
698 self.distinct_values.update_batch(values)
699 }
700
701 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
702 self.distinct_values.merge_batch(states)
703 }
704
705 fn evaluate(&mut self) -> Result<ScalarValue> {
706 let mut values: Vec<T::Native> =
707 self.distinct_values.values.iter().map(|v| v.0).collect();
708 let value = calculate_percentile::<T>(&mut values, self.percentile);
709 ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
710 }
711
712 fn size(&self) -> usize {
713 size_of_val(self) + self.distinct_values.size()
714 }
715
716 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
717 if values.is_empty() {
718 return Ok(());
719 }
720
721 let arr = values[0].as_primitive::<T>();
722 for value in arr.iter().flatten() {
723 self.distinct_values.values.remove(&Hashable(value));
724 }
725 Ok(())
726 }
727
728 fn supports_retract_batch(&self) -> bool {
729 true
730 }
731}
732
733fn calculate_percentile<T: ArrowNumericType>(
745 values: &mut [T::Native],
746 percentile: f64,
747) -> Option<T::Native>
748where
749 T::Native: Copy + AsPrimitive<f64>,
750 f64: AsPrimitive<T::Native>,
751{
752 let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
753
754 let len = values.len();
755 if len == 0 {
756 None
757 } else if len == 1 {
758 Some(values[0])
759 } else if percentile == 0.0 {
760 Some(
762 *values
763 .iter()
764 .min_by(|a, b| cmp(a, b))
765 .expect("we checked for len > 0 a few lines above"),
766 )
767 } else if percentile == 1.0 {
768 Some(
770 *values
771 .iter()
772 .max_by(|a, b| cmp(a, b))
773 .expect("we checked for len > 0 a few lines above"),
774 )
775 } else {
776 let index = percentile * ((len - 1) as f64);
778 let lower_index = index.floor() as usize;
779 let upper_index = index.ceil() as usize;
780
781 if lower_index == upper_index {
782 let (_, value, _) = values.select_nth_unstable_by(lower_index, cmp);
784 Some(*value)
785 } else {
786 let (_, lower_value, _) = values.select_nth_unstable_by(lower_index, cmp);
789 let lower_value = *lower_value;
790
791 let (_, upper_value, _) = values.select_nth_unstable_by(upper_index, cmp);
793 let upper_value = *upper_value;
794
795 let fraction = index - (lower_index as f64);
802 let scaled = (fraction * INTERPOLATION_PRECISION) as usize;
803 let weight = scaled as f64 / INTERPOLATION_PRECISION;
804
805 let lower_f: f64 = lower_value.as_();
806 let upper_f: f64 = upper_value.as_();
807 let interpolated_f = lower_f + (upper_f - lower_f) * weight;
808 Some(interpolated_f.as_())
809 }
810 }
811}
812
#[cfg(test)]
mod tests {
    use super::calculate_percentile;
    use arrow::datatypes::Float16Type;
    use half::f16;

    /// The median of `0` and `65504` (the largest finite `f16`) must be
    /// finite and close to `32752`.
    #[test]
    fn f16_interpolation_does_not_overflow_to_nan() {
        let mut values = [f16::from_f32(0.0), f16::from_f32(65504.0)];

        let median = calculate_percentile::<Float16Type>(&mut values, 0.5)
            .expect("non-empty input");
        let result_f = median.to_f32();

        assert!(
            !result_f.is_nan(),
            "expected non-NaN result, got {result_f}"
        );
        assert!(
            (result_f - 32752.0).abs() < 1.0,
            "unexpected result {result_f}"
        );
    }
}