1use std::fmt::Debug;
19use std::mem::{size_of, size_of_val};
20use std::sync::Arc;
21
22use arrow::array::{
23 ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder,
24};
25use arrow::buffer::{OffsetBuffer, ScalarBuffer};
26use arrow::{
27 array::{Array, ArrayRef, AsArray},
28 datatypes::{
29 ArrowNativeType, DataType, Field, FieldRef, Float16Type, Float32Type, Float64Type,
30 },
31};
32
33use arrow::array::ArrowNativeTypeOp;
34use datafusion_common::internal_err;
35use datafusion_common::types::{NativeType, logical_float64};
36use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator;
37
38use crate::min_max::{max_udaf, min_udaf};
39use datafusion_common::{
40 Result, ScalarValue, internal_datafusion_err, utils::take_function_args,
41};
42use datafusion_expr::utils::format_state_name;
43use datafusion_expr::{
44 Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, Signature,
45 TypeSignatureClass, Volatility,
46};
47use datafusion_expr::{EmitTo, GroupsAccumulator};
48use datafusion_expr::{
49 expr::{AggregateFunction, Sort},
50 function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
51 simplify::SimplifyInfo,
52};
53use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
54use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
55use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
56use datafusion_macros::user_doc;
57
58use crate::utils::validate_percentile_expr;
59
/// Scale factor used to express the interpolation fraction as an integer so
/// `calculate_percentile` can apply it with generic native-type arithmetic
/// (fraction is multiplied by this, then divided back out).
const INTERPOLATION_PRECISION: usize = 1_000_000;
71
// `create_func!` (project macro) generates the `percentile_cont_udaf()`
// singleton accessor for the [`PercentileCont`] UDAF — presumably returning a
// shared `AggregateUDF`; see its use in `percentile_cont` below.
create_func!(PercentileCont, percentile_cont_udaf);
73
74pub fn percentile_cont(order_by: Sort, percentile: Expr) -> Expr {
76 let expr = order_by.expr.clone();
77 let args = vec![expr, percentile];
78
79 Expr::AggregateFunction(AggregateFunction::new_udf(
80 percentile_cont_udaf(),
81 args,
82 false,
83 None,
84 vec![order_by],
85 None,
86 ))
87}
88
/// `percentile_cont` aggregate UDF: continuous (interpolated) percentile,
/// supporting the SQL `WITHIN GROUP (ORDER BY …)` syntax.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the exact percentile of input values, interpolating between values if needed.",
    syntax_example = "percentile_cont(percentile) WITHIN GROUP (ORDER BY expression)",
    sql_example = r#"```sql
> SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+----------------------------------------------------------+
| percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
+----------------------------------------------------------+
| 45.5 |
+----------------------------------------------------------+
```

An alternate syntax is also supported:
```sql
> SELECT percentile_cont(column_name, 0.75) FROM table_name;
+---------------------------------------+
| percentile_cont(column_name, 0.75) |
+---------------------------------------+
| 45.5 |
+---------------------------------------+
```"#,
    standard_argument(name = "expression", prefix = "The"),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    )
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct PercentileCont {
    /// Coercible signature: (float value, Float64 percentile) — see `new`.
    signature: Signature,
    /// Alternate names (`quantile_cont`) this UDAF is registered under.
    aliases: Vec<String>,
}
129
130impl Default for PercentileCont {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136impl PercentileCont {
137 pub fn new() -> Self {
138 Self {
139 signature: Signature::coercible(
140 vec![
141 Coercion::new_implicit(
142 TypeSignatureClass::Float,
143 vec![TypeSignatureClass::Numeric],
144 NativeType::Float64,
145 ),
146 Coercion::new_implicit(
147 TypeSignatureClass::Native(logical_float64()),
148 vec![TypeSignatureClass::Numeric],
149 NativeType::Float64,
150 ),
151 ],
152 Volatility::Immutable,
153 )
154 .with_parameter_names(vec!["expr", "percentile"])
155 .unwrap(),
156 aliases: vec![String::from("quantile_cont")],
157 }
158 }
159}
160
impl AggregateUDFImpl for PercentileCont {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "percentile_cont"
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Result type mirrors the input type; an all-null (Null-typed) input
    /// yields Float64 so the NULL result has a concrete type.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        match &arg_types[0] {
            DataType::Null => Ok(DataType::Float64),
            dt => Ok(dt.clone()),
        }
    }

    /// Intermediate state is a single `List<input_type>` column holding the
    /// buffered values; Null-typed input keeps a Null placeholder field.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let input_type = args.input_fields[0].data_type().clone();
        if input_type.is_null() {
            return Ok(vec![
                Field::new(
                    format_state_name(args.name, self.name()),
                    DataType::Null,
                    true,
                )
                .into(),
            ]);
        }

        let field = Field::new_list_field(input_type, true);
        // DISTINCT aggregation gets its own state name so the two variants
        // are distinguishable in plans/state schemas.
        let state_name = if args.is_distinct {
            "distinct_percentile_cont"
        } else {
            "percentile_cont"
        };

        Ok(vec![
            Field::new(
                format_state_name(args.name, state_name),
                DataType::List(Arc::new(field)),
                true,
            )
            .into(),
        ])
    }

    /// Creates the row accumulator, dispatching on the (already coerced)
    /// float input width and on DISTINCT-ness. Null-typed input
    /// short-circuits to a no-op accumulator that always yields NULL.
    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let percentile = get_percentile(&args)?;

        let input_dt = args.expr_fields[0].data_type();
        if input_dt.is_null() {
            return Ok(Box::new(NoopAccumulator::new(ScalarValue::Float64(None))));
        }

        if args.is_distinct {
            match input_dt {
                DataType::Float16 => Ok(Box::new(DistinctPercentileContAccumulator::<
                    Float16Type,
                >::new(percentile))),
                DataType::Float32 => Ok(Box::new(DistinctPercentileContAccumulator::<
                    Float32Type,
                >::new(percentile))),
                DataType::Float64 => Ok(Box::new(DistinctPercentileContAccumulator::<
                    Float64Type,
                >::new(percentile))),
                dt => internal_err!("Unsupported datatype for percentile cont: {dt}"),
            }
        } else {
            match input_dt {
                DataType::Float16 => Ok(Box::new(
                    PercentileContAccumulator::<Float16Type>::new(percentile),
                )),
                DataType::Float32 => Ok(Box::new(
                    PercentileContAccumulator::<Float32Type>::new(percentile),
                )),
                DataType::Float64 => Ok(Box::new(
                    PercentileContAccumulator::<Float64Type>::new(percentile),
                )),
                dt => internal_err!("Unsupported datatype for percentile cont: {dt}"),
            }
        }
    }

    /// Grouped accumulation is only implemented for the non-distinct,
    /// non-Null-typed case; everything else falls back to the row accumulator.
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        !args.is_distinct && !args.expr_fields[0].data_type().is_null()
    }

    /// Creates the grouped accumulator, dispatching on float input width.
    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        let percentile = get_percentile(&args)?;

        let input_dt = args.expr_fields[0].data_type();
        match input_dt {
            DataType::Float16 => Ok(Box::new(PercentileContGroupsAccumulator::<
                Float16Type,
            >::new(percentile))),
            DataType::Float32 => Ok(Box::new(PercentileContGroupsAccumulator::<
                Float32Type,
            >::new(percentile))),
            DataType::Float64 => Ok(Box::new(PercentileContGroupsAccumulator::<
                Float64Type,
            >::new(percentile))),
            dt => internal_err!("Unsupported datatype for percentile cont: {dt}"),
        }
    }

    /// Hooks up the rewrite of percentile 0.0 / 1.0 calls into min/max — see
    /// `simplify_percentile_cont_aggregate`.
    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
        Some(Box::new(|aggregate_function, info| {
            simplify_percentile_cont_aggregate(aggregate_function, info)
        }))
    }

    fn supports_within_group_clause(&self) -> bool {
        true
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
291
292fn get_percentile(args: &AccumulatorArgs) -> Result<f64> {
293 let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?;
294
295 let is_descending = args
296 .order_bys
297 .first()
298 .map(|sort_expr| sort_expr.options.descending)
299 .unwrap_or(false);
300
301 let percentile = if is_descending {
302 1.0 - percentile
303 } else {
304 percentile
305 };
306
307 Ok(percentile)
308}
309
/// Rewrites `percentile_cont(col, 0.0)` into `min(col)` and
/// `percentile_cont(col, 1.0)` into `max(col)` (swapped when the ORDER BY is
/// descending), which have much cheaper streaming implementations. Any other
/// call is returned unchanged.
fn simplify_percentile_cont_aggregate(
    aggregate_function: AggregateFunction,
    info: &dyn SimplifyInfo,
) -> Result<Expr> {
    // Which UDAF the call collapses to, if any.
    enum PercentileRewriteTarget {
        Min,
        Max,
    }

    let params = &aggregate_function.params;
    let [value, percentile] = take_function_args("percentile_cont", &params.args)?;
    // Null-typed input returns Float64 NULL (see `return_type`); min/max
    // behave differently there, so skip the rewrite.
    let input_type = info.get_data_type(value)?;
    if input_type.is_null() {
        return Ok(Expr::AggregateFunction(aggregate_function));
    }

    // A descending sort flips which end of the data 0.0 / 1.0 refer to.
    let is_descending = params
        .order_by
        .first()
        .map(|sort| !sort.asc)
        .unwrap_or(false);

    let rewrite_target = match percentile {
        Expr::Literal(ScalarValue::Float64(Some(0.0)), _) => {
            if is_descending {
                PercentileRewriteTarget::Max
            } else {
                PercentileRewriteTarget::Min
            }
        }
        Expr::Literal(ScalarValue::Float64(Some(1.0)), _) => {
            if is_descending {
                PercentileRewriteTarget::Min
            } else {
                PercentileRewriteTarget::Max
            }
        }
        // Non-literal or interior percentiles cannot be simplified.
        _ => return Ok(Expr::AggregateFunction(aggregate_function)),
    };

    let udaf = match rewrite_target {
        PercentileRewriteTarget::Min => min_udaf(),
        PercentileRewriteTarget::Max => max_udaf(),
    };

    // min/max are order-insensitive, so the ORDER BY is dropped; DISTINCT,
    // FILTER and null treatment are carried over unchanged.
    let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
        udaf,
        vec![value.clone()],
        params.distinct,
        params.filter.clone(),
        vec![],
        params.null_treatment,
    ));
    Ok(rewritten)
}
368
/// Row accumulator for non-DISTINCT `percentile_cont` without grouping:
/// buffers every non-null input value and computes the percentile once at
/// evaluation time.
#[derive(Debug)]
struct PercentileContAccumulator<T: ArrowNumericType + Debug> {
    /// All non-null values seen so far; drained by `state`/`evaluate`.
    all_values: Vec<T::Native>,
    /// Target percentile in [0, 1], already adjusted for sort direction.
    percentile: f64,
}
381
382impl<T: ArrowNumericType + Debug> PercentileContAccumulator<T> {
383 fn new(percentile: f64) -> Self {
384 Self {
385 all_values: vec![],
386 percentile,
387 }
388 }
389}
390
391impl<T: ArrowNumericType + Debug> Accumulator for PercentileContAccumulator<T> {
392 fn state(&mut self) -> Result<Vec<ScalarValue>> {
393 let offsets =
397 OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
398
399 let values_array = PrimitiveArray::<T>::new(
401 ScalarBuffer::from(std::mem::take(&mut self.all_values)),
402 None,
403 );
404
405 let list_array = ListArray::new(
407 Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
408 offsets,
409 Arc::new(values_array),
410 None,
411 );
412
413 Ok(vec![ScalarValue::List(Arc::new(list_array))])
414 }
415
416 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
417 let values = values[0].as_primitive::<T>();
418 self.all_values.reserve(values.len() - values.null_count());
419 self.all_values.extend(values.iter().flatten());
420 Ok(())
421 }
422
423 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
424 let array = states[0].as_list::<i32>();
425 self.update_batch(&[array.value(0)])?;
426 Ok(())
427 }
428
429 fn evaluate(&mut self) -> Result<ScalarValue> {
430 let d = std::mem::take(&mut self.all_values);
431 let value = calculate_percentile::<T>(d, self.percentile);
432 ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
433 }
434
435 fn size(&self) -> usize {
436 size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
437 }
438}
439
/// Per-group state for the grouped (non-distinct) `percentile_cont`
/// implementation.
#[derive(Debug)]
struct PercentileContGroupsAccumulator<T: ArrowNumericType + Send> {
    /// One value buffer per group index.
    group_values: Vec<Vec<T::Native>>,
    /// Target percentile in [0, 1], already adjusted for sort direction.
    percentile: f64,
}
451
452impl<T: ArrowNumericType + Send> PercentileContGroupsAccumulator<T> {
453 fn new(percentile: f64) -> Self {
454 Self {
455 group_values: vec![],
456 percentile,
457 }
458 }
459}
460
461impl<T: ArrowNumericType + Send> GroupsAccumulator
462 for PercentileContGroupsAccumulator<T>
463{
464 fn update_batch(
465 &mut self,
466 values: &[ArrayRef],
467 group_indices: &[usize],
468 opt_filter: Option<&BooleanArray>,
469 total_num_groups: usize,
470 ) -> Result<()> {
471 let values = values[0].as_primitive::<T>();
475
476 self.group_values.resize(total_num_groups, Vec::new());
478 accumulate(
479 group_indices,
480 values,
481 opt_filter,
482 |group_index, new_value| {
483 self.group_values[group_index].push(new_value);
484 },
485 );
486
487 Ok(())
488 }
489
490 fn merge_batch(
491 &mut self,
492 values: &[ArrayRef],
493 group_indices: &[usize],
494 _opt_filter: Option<&BooleanArray>,
496 total_num_groups: usize,
497 ) -> Result<()> {
498 assert_eq!(values.len(), 1, "one argument to merge_batch");
499
500 let input_group_values = values[0].as_list::<i32>();
501
502 self.group_values.resize(total_num_groups, Vec::new());
504
505 group_indices
507 .iter()
508 .zip(input_group_values.iter())
509 .for_each(|(&group_index, values_opt)| {
510 if let Some(values) = values_opt {
511 let values = values.as_primitive::<T>();
512 self.group_values[group_index].extend(values.values().iter());
513 }
514 });
515
516 Ok(())
517 }
518
519 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
520 let emit_group_values = emit_to.take_needed(&mut self.group_values);
522
523 let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
525 offsets.push(0);
526 let mut cur_len = 0_i32;
527 for group_value in &emit_group_values {
528 cur_len += group_value.len() as i32;
529 offsets.push(cur_len);
530 }
531 let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
532
533 let flatten_group_values =
535 emit_group_values.into_iter().flatten().collect::<Vec<_>>();
536 let group_values_array =
537 PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None);
538
539 let result_list_array = ListArray::new(
541 Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
542 offsets,
543 Arc::new(group_values_array),
544 None,
545 );
546
547 Ok(vec![Arc::new(result_list_array)])
548 }
549
550 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
551 let emit_group_values = emit_to.take_needed(&mut self.group_values);
553
554 let mut evaluate_result_builder =
556 PrimitiveBuilder::<T>::with_capacity(emit_group_values.len());
557 for values in emit_group_values {
558 let value = calculate_percentile::<T>(values, self.percentile);
559 evaluate_result_builder.append_option(value);
560 }
561
562 Ok(Arc::new(evaluate_result_builder.finish()))
563 }
564
565 fn convert_to_state(
566 &self,
567 values: &[ArrayRef],
568 opt_filter: Option<&BooleanArray>,
569 ) -> Result<Vec<ArrayRef>> {
570 assert_eq!(values.len(), 1, "one argument to merge_batch");
571
572 let input_array = values[0].as_primitive::<T>();
573
574 let values = PrimitiveArray::<T>::new(input_array.values().clone(), None);
583
584 let offset_end = i32::try_from(input_array.len()).map_err(|e| {
586 internal_datafusion_err!(
587 "cast array_len to i32 failed in convert_to_state of group percentile_cont, err:{e:?}"
588 )
589 })?;
590 let offsets = (0..=offset_end).collect::<Vec<_>>();
591 let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
598
599 let nulls = filtered_null_mask(opt_filter, input_array);
601
602 let converted_list_array = ListArray::new(
603 Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
604 offsets,
605 Arc::new(values),
606 nulls,
607 );
608
609 Ok(vec![Arc::new(converted_list_array)])
610 }
611
612 fn supports_convert_to_state(&self) -> bool {
613 true
614 }
615
616 fn size(&self) -> usize {
617 self.group_values
618 .iter()
619 .map(|values| values.capacity() * size_of::<T::Native>())
620 .sum::<usize>()
621 + self.group_values.capacity() * size_of::<Vec<T::Native>>()
623 }
624}
625
/// Row accumulator for `percentile_cont(DISTINCT …)`: deduplicates values in
/// a [`GenericDistinctBuffer`] before computing the percentile.
#[derive(Debug)]
struct DistinctPercentileContAccumulator<T: ArrowNumericType> {
    /// Set of distinct values observed so far.
    distinct_values: GenericDistinctBuffer<T>,
    /// Target percentile in [0, 1], already adjusted for sort direction.
    percentile: f64,
}
631
632impl<T: ArrowNumericType + Debug> DistinctPercentileContAccumulator<T> {
633 fn new(percentile: f64) -> Self {
634 Self {
635 distinct_values: GenericDistinctBuffer::new(T::DATA_TYPE),
636 percentile,
637 }
638 }
639}
640
641impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumulator<T> {
642 fn state(&mut self) -> Result<Vec<ScalarValue>> {
643 self.distinct_values.state()
644 }
645
646 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
647 self.distinct_values.update_batch(values)
648 }
649
650 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
651 self.distinct_values.merge_batch(states)
652 }
653
654 fn evaluate(&mut self) -> Result<ScalarValue> {
655 let d = std::mem::take(&mut self.distinct_values.values)
656 .into_iter()
657 .map(|v| v.0)
658 .collect::<Vec<_>>();
659 let value = calculate_percentile::<T>(d, self.percentile);
660 ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
661 }
662
663 fn size(&self) -> usize {
664 size_of_val(self) + self.distinct_values.size()
665 }
666}
667
/// Computes the `percentile` (in [0, 1]) of `values` with linear
/// interpolation between the two nearest ranks, consuming the input buffer.
///
/// Returns `None` for an empty input. Rank selection uses
/// `select_nth_unstable_by` (expected O(n)) rather than fully sorting.
fn calculate_percentile<T: ArrowNumericType>(
    mut values: Vec<T::Native>,
    percentile: f64,
) -> Option<T::Native> {
    // Total order from arrow's native comparison (handles float ordering).
    let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);

    let len = values.len();
    if len == 0 {
        None
    } else if len == 1 {
        Some(values[0])
    } else if percentile == 0.0 {
        // 0th percentile is the minimum; skips the interpolation path.
        Some(
            *values
                .iter()
                .min_by(|a, b| cmp(a, b))
                .expect("we checked for len > 0 a few lines above"),
        )
    } else if percentile == 1.0 {
        // 100th percentile is the maximum.
        Some(
            *values
                .iter()
                .max_by(|a, b| cmp(a, b))
                .expect("we checked for len > 0 a few lines above"),
        )
    } else {
        // Fractional rank of the requested percentile in the sorted order.
        let index = percentile * ((len - 1) as f64);
        let lower_index = index.floor() as usize;
        let upper_index = index.ceil() as usize;

        if lower_index == upper_index {
            // Exact rank — no interpolation needed.
            let (_, value, _) = values.select_nth_unstable_by(lower_index, cmp);
            Some(*value)
        } else {
            // Select the lower-rank element first ...
            let (_, lower_value, _) = values.select_nth_unstable_by(lower_index, cmp);
            let lower_value = *lower_value;

            // ... after which everything right of `lower_index` is >= it, so
            // selecting `upper_index` on the partially reordered buffer is
            // still correct.
            let (_, upper_value, _) = values.select_nth_unstable_by(upper_index, cmp);
            let upper_value = *upper_value;

            // Interpolate: lower + (upper - lower) * fraction. The fraction
            // is scaled by INTERPOLATION_PRECISION to an integer so the same
            // generic native-type arithmetic applies; wrapping ops follow
            // arrow's overflow semantics. NOTE(review): only float types are
            // dispatched to this in this file (see `accumulator`), where the
            // wrapping ops are plain float mul/div.
            let fraction = index - (lower_index as f64);
            let diff = upper_value.sub_wrapping(lower_value);
            let interpolated = lower_value.add_wrapping(
                diff.mul_wrapping(T::Native::usize_as(
                    (fraction * INTERPOLATION_PRECISION as f64) as usize,
                ))
                .div_wrapping(T::Native::usize_as(INTERPOLATION_PRECISION)),
            );
            Some(interpolated)
        }
    }
}