Skip to main content

datafusion_functions_aggregate/
percentile_cont.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::fmt::Debug;
20use std::mem::{size_of, size_of_val};
21use std::sync::Arc;
22
23use arrow::array::{
24    ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder,
25};
26use arrow::buffer::{OffsetBuffer, ScalarBuffer};
27use arrow::{
28    array::{Array, ArrayRef, AsArray},
29    datatypes::{DataType, Field, FieldRef, Float16Type, Float32Type, Float64Type},
30};
31
32use num_traits::AsPrimitive;
33
34use arrow::array::ArrowNativeTypeOp;
35use datafusion_common::internal_err;
36use datafusion_common::types::{NativeType, logical_float64};
37use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator;
38
39use crate::min_max::{max_udaf, min_udaf};
40use datafusion_common::{
41    Result, ScalarValue, internal_datafusion_err, utils::take_function_args,
42};
43use datafusion_expr::utils::format_state_name;
44use datafusion_expr::{
45    Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, Signature,
46    TypeSignatureClass, Volatility,
47};
48use datafusion_expr::{EmitTo, GroupsAccumulator};
49use datafusion_expr::{
50    expr::{AggregateFunction, Sort},
51    function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
52    simplify::SimplifyContext,
53};
54use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
55use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
56use datafusion_functions_aggregate_common::utils::{GenericDistinctBuffer, Hashable};
57use datafusion_macros::user_doc;
58
59use crate::utils::validate_percentile_expr;
60
/// Resolution of the linear-interpolation weight used by `calculate_percentile`.
///
/// When the requested percentile falls between two ranks, the result is
/// `lower + (upper - lower) * weight`, where `weight` is the fractional part of
/// the rank quantized to a multiple of `1 / INTERPOLATION_PRECISION`, i.e. to
/// six decimal places. This gives stable, reproducible results while being
/// precise enough for statistical use.
///
/// The interpolation arithmetic itself is performed in `f64` and the result is
/// cast back to the input's native float type, so narrow types such as Float16
/// cannot overflow in intermediate computations.
const INTERPOLATION_PRECISION: f64 = 1_000_000.0;
75
// Generates the `percentile_cont_udaf()` accessor used by `percentile_cont` to
// obtain `PercentileCont` as an aggregate UDF. `create_func!` is a crate-local
// macro (presumably defined at the crate root — not visible in this file).
create_func!(PercentileCont, percentile_cont_udaf);
77
78/// Computes the exact percentile continuous of a set of numbers
79pub fn percentile_cont(order_by: Sort, percentile: Expr) -> Expr {
80    let expr = order_by.expr.clone();
81    let args = vec![expr, percentile];
82
83    Expr::AggregateFunction(AggregateFunction::new_udf(
84        percentile_cont_udaf(),
85        args,
86        false,
87        None,
88        vec![order_by],
89        None,
90    ))
91}
92
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the exact percentile of input values, interpolating between values if needed.",
    syntax_example = "percentile_cont(percentile) WITHIN GROUP (ORDER BY expression)",
    sql_example = r#"```sql
> SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+----------------------------------------------------------+
| percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
+----------------------------------------------------------+
| 45.5                                                     |
+----------------------------------------------------------+
```

An alternate syntax is also supported:
```sql
> SELECT percentile_cont(column_name, 0.75) FROM table_name;
+---------------------------------------+
| percentile_cont(column_name, 0.75)    |
+---------------------------------------+
| 45.5                                  |
+---------------------------------------+
```"#,
    standard_argument(name = "expression", prefix = "The"),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    )
)]
/// `PERCENTILE_CONT` aggregate function (also available as `quantile_cont`).
///
/// Computes an exact percentile. Every input value is buffered in memory, and the
/// function interpolates linearly between the two nearest ranks when the requested
/// percentile falls between them. If an approximation is sufficient,
/// `APPROX_PERCENTILE_CONT` is a much more memory-efficient alternative.
///
/// The `DISTINCT` variant buffers only distinct values. Its memory use is still high
/// for high-cardinality input, but lower when there are few distinct values.
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct PercentileCont {
    /// Coercion rules for the `(expr, percentile)` arguments.
    signature: Signature,
    /// Alternative SQL names for this function.
    aliases: Vec<String>,
}
133
134impl Default for PercentileCont {
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
140impl PercentileCont {
141    pub fn new() -> Self {
142        Self {
143            signature: Signature::coercible(
144                vec![
145                    Coercion::new_implicit(
146                        TypeSignatureClass::Float,
147                        vec![TypeSignatureClass::Numeric],
148                        NativeType::Float64,
149                    ),
150                    Coercion::new_implicit(
151                        TypeSignatureClass::Native(logical_float64()),
152                        vec![TypeSignatureClass::Numeric],
153                        NativeType::Float64,
154                    ),
155                ],
156                Volatility::Immutable,
157            )
158            .with_parameter_names(vec!["expr", "percentile"])
159            .unwrap(),
160            aliases: vec![String::from("quantile_cont")],
161        }
162    }
163}
164
165impl AggregateUDFImpl for PercentileCont {
166    fn as_any(&self) -> &dyn std::any::Any {
167        self
168    }
169
170    fn name(&self) -> &str {
171        "percentile_cont"
172    }
173
174    fn aliases(&self) -> &[String] {
175        &self.aliases
176    }
177
178    fn signature(&self) -> &Signature {
179        &self.signature
180    }
181
182    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
183        match &arg_types[0] {
184            DataType::Null => Ok(DataType::Float64),
185            dt => Ok(dt.clone()),
186        }
187    }
188
189    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
190        let input_type = args.input_fields[0].data_type().clone();
191        if input_type.is_null() {
192            return Ok(vec![
193                Field::new(
194                    format_state_name(args.name, self.name()),
195                    DataType::Null,
196                    true,
197                )
198                .into(),
199            ]);
200        }
201
202        let field = Field::new_list_field(input_type, true);
203        let state_name = if args.is_distinct {
204            "distinct_percentile_cont"
205        } else {
206            "percentile_cont"
207        };
208
209        Ok(vec![
210            Field::new(
211                format_state_name(args.name, state_name),
212                DataType::List(Arc::new(field)),
213                true,
214            )
215            .into(),
216        ])
217    }
218
219    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
220        let percentile = get_percentile(&args)?;
221
222        let input_dt = args.expr_fields[0].data_type();
223        if input_dt.is_null() {
224            return Ok(Box::new(NoopAccumulator::new(ScalarValue::Float64(None))));
225        }
226
227        if args.is_distinct {
228            match input_dt {
229                DataType::Float16 => Ok(Box::new(DistinctPercentileContAccumulator::<
230                    Float16Type,
231                >::new(percentile))),
232                DataType::Float32 => Ok(Box::new(DistinctPercentileContAccumulator::<
233                    Float32Type,
234                >::new(percentile))),
235                DataType::Float64 => Ok(Box::new(DistinctPercentileContAccumulator::<
236                    Float64Type,
237                >::new(percentile))),
238                dt => internal_err!("Unsupported datatype for percentile cont: {dt}"),
239            }
240        } else {
241            match input_dt {
242                DataType::Float16 => Ok(Box::new(
243                    PercentileContAccumulator::<Float16Type>::new(percentile),
244                )),
245                DataType::Float32 => Ok(Box::new(
246                    PercentileContAccumulator::<Float32Type>::new(percentile),
247                )),
248                DataType::Float64 => Ok(Box::new(
249                    PercentileContAccumulator::<Float64Type>::new(percentile),
250                )),
251                dt => internal_err!("Unsupported datatype for percentile cont: {dt}"),
252            }
253        }
254    }
255
256    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
257        !args.is_distinct && !args.expr_fields[0].data_type().is_null()
258    }
259
260    fn create_groups_accumulator(
261        &self,
262        args: AccumulatorArgs,
263    ) -> Result<Box<dyn GroupsAccumulator>> {
264        let percentile = get_percentile(&args)?;
265
266        let input_dt = args.expr_fields[0].data_type();
267        match input_dt {
268            DataType::Float16 => Ok(Box::new(PercentileContGroupsAccumulator::<
269                Float16Type,
270            >::new(percentile))),
271            DataType::Float32 => Ok(Box::new(PercentileContGroupsAccumulator::<
272                Float32Type,
273            >::new(percentile))),
274            DataType::Float64 => Ok(Box::new(PercentileContGroupsAccumulator::<
275                Float64Type,
276            >::new(percentile))),
277            dt => internal_err!("Unsupported datatype for percentile cont: {dt}"),
278        }
279    }
280
281    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
282        Some(Box::new(|aggregate_function, info| {
283            simplify_percentile_cont_aggregate(aggregate_function, info)
284        }))
285    }
286
287    fn supports_within_group_clause(&self) -> bool {
288        true
289    }
290
291    fn documentation(&self) -> Option<&Documentation> {
292        self.doc()
293    }
294}
295
296fn get_percentile(args: &AccumulatorArgs) -> Result<f64> {
297    let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?;
298
299    let is_descending = args
300        .order_bys
301        .first()
302        .map(|sort_expr| sort_expr.options.descending)
303        .unwrap_or(false);
304
305    let percentile = if is_descending {
306        1.0 - percentile
307    } else {
308        percentile
309    };
310
311    Ok(percentile)
312}
313
314fn simplify_percentile_cont_aggregate(
315    aggregate_function: AggregateFunction,
316    info: &SimplifyContext,
317) -> Result<Expr> {
318    enum PercentileRewriteTarget {
319        Min,
320        Max,
321    }
322
323    let params = &aggregate_function.params;
324    let [value, percentile] = take_function_args("percentile_cont", &params.args)?;
325    //
326    // For simplicity we don't bother with null types (otherwise we'd need to
327    // cast the return type)
328    let input_type = info.get_data_type(value)?;
329    if input_type.is_null() {
330        return Ok(Expr::AggregateFunction(aggregate_function));
331    }
332
333    let is_descending = params
334        .order_by
335        .first()
336        .map(|sort| !sort.asc)
337        .unwrap_or(false);
338
339    let rewrite_target = match percentile {
340        Expr::Literal(ScalarValue::Float64(Some(0.0)), _) => {
341            if is_descending {
342                PercentileRewriteTarget::Max
343            } else {
344                PercentileRewriteTarget::Min
345            }
346        }
347        Expr::Literal(ScalarValue::Float64(Some(1.0)), _) => {
348            if is_descending {
349                PercentileRewriteTarget::Min
350            } else {
351                PercentileRewriteTarget::Max
352            }
353        }
354        _ => return Ok(Expr::AggregateFunction(aggregate_function)),
355    };
356
357    let udaf = match rewrite_target {
358        PercentileRewriteTarget::Min => min_udaf(),
359        PercentileRewriteTarget::Max => max_udaf(),
360    };
361
362    let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
363        udaf,
364        vec![value.clone()],
365        params.distinct,
366        params.filter.clone(),
367        vec![],
368        params.null_treatment,
369    ));
370    Ok(rewritten)
371}
372
373/// The percentile_cont accumulator accumulates the raw input values
374/// as native types.
375///
376/// The intermediate state is represented as a List of scalar values updated by
377/// `merge_batch` and a `Vec` of native values that are converted to scalar values
378/// in the final evaluation step so that we avoid expensive conversions and
379/// allocations during `update_batch`.
380#[derive(Debug)]
381struct PercentileContAccumulator<T: ArrowNumericType + Debug> {
382    all_values: Vec<T::Native>,
383    percentile: f64,
384}
385
386impl<T: ArrowNumericType + Debug> PercentileContAccumulator<T> {
387    fn new(percentile: f64) -> Self {
388        Self {
389            all_values: vec![],
390            percentile,
391        }
392    }
393}
394
395impl<T> Accumulator for PercentileContAccumulator<T>
396where
397    T: ArrowNumericType + Debug,
398    T::Native: Copy + AsPrimitive<f64>,
399    f64: AsPrimitive<T::Native>,
400{
401    fn state(&mut self) -> Result<Vec<ScalarValue>> {
402        // Convert `all_values` to `ListArray` and return a single List ScalarValue
403
404        // Build offsets
405        let offsets =
406            OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
407
408        // Build inner array
409        let values_array = PrimitiveArray::<T>::new(
410            ScalarBuffer::from(std::mem::take(&mut self.all_values)),
411            None,
412        );
413
414        // Build the result list array
415        let list_array = ListArray::new(
416            Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
417            offsets,
418            Arc::new(values_array),
419            None,
420        );
421
422        Ok(vec![ScalarValue::List(Arc::new(list_array))])
423    }
424
425    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
426        let values = values[0].as_primitive::<T>();
427        self.all_values.reserve(values.len() - values.null_count());
428        self.all_values.extend(values.iter().flatten());
429        Ok(())
430    }
431
432    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
433        let array = states[0].as_list::<i32>();
434        self.update_batch(&[array.value(0)])?;
435        Ok(())
436    }
437
438    fn evaluate(&mut self) -> Result<ScalarValue> {
439        let value = calculate_percentile::<T>(&mut self.all_values, self.percentile);
440        ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
441    }
442
443    fn size(&self) -> usize {
444        size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
445    }
446
447    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
448        let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
449        for i in 0..values[0].len() {
450            let v = ScalarValue::try_from_array(&values[0], i)?;
451            if !v.is_null() {
452                *to_remove.entry(v).or_default() += 1;
453            }
454        }
455
456        let mut i = 0;
457        while i < self.all_values.len() {
458            let k =
459                ScalarValue::new_primitive::<T>(Some(self.all_values[i]), &T::DATA_TYPE)?;
460            if let Some(count) = to_remove.get_mut(&k)
461                && *count > 0
462            {
463                self.all_values.swap_remove(i);
464                *count -= 1;
465                if *count == 0 {
466                    to_remove.remove(&k);
467                    if to_remove.is_empty() {
468                        break;
469                    }
470                }
471            } else {
472                i += 1;
473            }
474        }
475        Ok(())
476    }
477
478    fn supports_retract_batch(&self) -> bool {
479        true
480    }
481}
482
483/// The percentile_cont groups accumulator accumulates the raw input values
484///
485/// For calculating the exact percentile of groups, we need to store all values
486/// of groups before final evaluation.
487/// So values in each group will be stored in a `Vec<T>`, and the total group values
488/// will be actually organized as a `Vec<Vec<T>>`.
489#[derive(Debug)]
490struct PercentileContGroupsAccumulator<T: ArrowNumericType + Send> {
491    group_values: Vec<Vec<T::Native>>,
492    percentile: f64,
493}
494
495impl<T: ArrowNumericType + Send> PercentileContGroupsAccumulator<T> {
496    fn new(percentile: f64) -> Self {
497        Self {
498            group_values: vec![],
499            percentile,
500        }
501    }
502}
503
504impl<T> GroupsAccumulator for PercentileContGroupsAccumulator<T>
505where
506    T: ArrowNumericType + Send,
507    T::Native: Copy + AsPrimitive<f64>,
508    f64: AsPrimitive<T::Native>,
509{
510    fn update_batch(
511        &mut self,
512        values: &[ArrayRef],
513        group_indices: &[usize],
514        opt_filter: Option<&BooleanArray>,
515        total_num_groups: usize,
516    ) -> Result<()> {
517        // For ordered-set aggregates, we only care about the ORDER BY column (first element)
518        // The percentile parameter is already stored in self.percentile
519
520        let values = values[0].as_primitive::<T>();
521
522        // Push the `not nulls + not filtered` row into its group
523        self.group_values.resize(total_num_groups, Vec::new());
524        accumulate(
525            group_indices,
526            values,
527            opt_filter,
528            |group_index, new_value| {
529                self.group_values[group_index].push(new_value);
530            },
531        );
532
533        Ok(())
534    }
535
536    fn merge_batch(
537        &mut self,
538        values: &[ArrayRef],
539        group_indices: &[usize],
540        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
541        _opt_filter: Option<&BooleanArray>,
542        total_num_groups: usize,
543    ) -> Result<()> {
544        assert_eq!(values.len(), 1, "one argument to merge_batch");
545
546        let input_group_values = values[0].as_list::<i32>();
547
548        // Ensure group values big enough
549        self.group_values.resize(total_num_groups, Vec::new());
550
551        // Extend values to related groups
552        group_indices
553            .iter()
554            .zip(input_group_values.iter())
555            .for_each(|(&group_index, values_opt)| {
556                if let Some(values) = values_opt {
557                    let values = values.as_primitive::<T>();
558                    self.group_values[group_index].extend(values.values().iter());
559                }
560            });
561
562        Ok(())
563    }
564
565    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
566        // Emit values
567        let emit_group_values = emit_to.take_needed(&mut self.group_values);
568
569        // Build offsets
570        let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
571        offsets.push(0);
572        let mut cur_len = 0_i32;
573        for group_value in &emit_group_values {
574            cur_len += group_value.len() as i32;
575            offsets.push(cur_len);
576        }
577        let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
578
579        // Build inner array
580        let flatten_group_values =
581            emit_group_values.into_iter().flatten().collect::<Vec<_>>();
582        let group_values_array =
583            PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None);
584
585        // Build the result list array
586        let result_list_array = ListArray::new(
587            Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
588            offsets,
589            Arc::new(group_values_array),
590            None,
591        );
592
593        Ok(vec![Arc::new(result_list_array)])
594    }
595
596    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
597        // Emit values
598        let mut emit_group_values = emit_to.take_needed(&mut self.group_values);
599
600        // Calculate percentile for each group
601        let mut evaluate_result_builder =
602            PrimitiveBuilder::<T>::with_capacity(emit_group_values.len());
603        for values in &mut emit_group_values {
604            let value = calculate_percentile::<T>(values.as_mut_slice(), self.percentile);
605            evaluate_result_builder.append_option(value);
606        }
607
608        Ok(Arc::new(evaluate_result_builder.finish()))
609    }
610
611    fn convert_to_state(
612        &self,
613        values: &[ArrayRef],
614        opt_filter: Option<&BooleanArray>,
615    ) -> Result<Vec<ArrayRef>> {
616        assert_eq!(values.len(), 1, "one argument to merge_batch");
617
618        let input_array = values[0].as_primitive::<T>();
619
620        // Directly convert the input array to states, each row will be
621        // seen as a respective group.
622        // For detail, the `input_array` will be converted to a `ListArray`.
623        // And if row is `not null + not filtered`, it will be converted to a list
624        // with only one element; otherwise, this row in `ListArray` will be set
625        // to null.
626
627        // Reuse values buffer in `input_array` to build `values` in `ListArray`
628        let values = PrimitiveArray::<T>::new(input_array.values().clone(), None);
629
630        // `offsets` in `ListArray`, each row as a list element
631        let offset_end = i32::try_from(input_array.len()).map_err(|e| {
632            internal_datafusion_err!(
633                "cast array_len to i32 failed in convert_to_state of group percentile_cont, err:{e:?}"
634            )
635        })?;
636        let offsets = (0..=offset_end).collect::<Vec<_>>();
637        // Safety: The offsets vector is constructed as a sequential range from 0 to input_array.len(),
638        // which guarantees all OffsetBuffer invariants:
639        // 1. Offsets are monotonically increasing (each element is prev + 1)
640        // 2. No offset exceeds the values array length (max offset = input_array.len())
641        // 3. First offset is 0 and last offset equals the total length
642        // Therefore new_unchecked is safe to use here.
643        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
644
645        // `nulls` for converted `ListArray`
646        let nulls = filtered_null_mask(opt_filter, input_array);
647
648        let converted_list_array = ListArray::new(
649            Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
650            offsets,
651            Arc::new(values),
652            nulls,
653        );
654
655        Ok(vec![Arc::new(converted_list_array)])
656    }
657
658    fn supports_convert_to_state(&self) -> bool {
659        true
660    }
661
662    fn size(&self) -> usize {
663        self.group_values
664            .iter()
665            .map(|values| values.capacity() * size_of::<T::Native>())
666            .sum::<usize>()
667            // account for size of self.group_values too
668            + self.group_values.capacity() * size_of::<Vec<T::Native>>()
669    }
670}
671
672#[derive(Debug)]
673struct DistinctPercentileContAccumulator<T: ArrowNumericType> {
674    distinct_values: GenericDistinctBuffer<T>,
675    percentile: f64,
676}
677
678impl<T: ArrowNumericType + Debug> DistinctPercentileContAccumulator<T> {
679    fn new(percentile: f64) -> Self {
680        Self {
681            distinct_values: GenericDistinctBuffer::new(T::DATA_TYPE),
682            percentile,
683        }
684    }
685}
686
687impl<T> Accumulator for DistinctPercentileContAccumulator<T>
688where
689    T: ArrowNumericType + Debug,
690    T::Native: Copy + AsPrimitive<f64>,
691    f64: AsPrimitive<T::Native>,
692{
693    fn state(&mut self) -> Result<Vec<ScalarValue>> {
694        self.distinct_values.state()
695    }
696
697    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
698        self.distinct_values.update_batch(values)
699    }
700
701    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
702        self.distinct_values.merge_batch(states)
703    }
704
705    fn evaluate(&mut self) -> Result<ScalarValue> {
706        let mut values: Vec<T::Native> =
707            self.distinct_values.values.iter().map(|v| v.0).collect();
708        let value = calculate_percentile::<T>(&mut values, self.percentile);
709        ScalarValue::new_primitive::<T>(value, &T::DATA_TYPE)
710    }
711
712    fn size(&self) -> usize {
713        size_of_val(self) + self.distinct_values.size()
714    }
715
716    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
717        if values.is_empty() {
718            return Ok(());
719        }
720
721        let arr = values[0].as_primitive::<T>();
722        for value in arr.iter().flatten() {
723            self.distinct_values.values.remove(&Hashable(value));
724        }
725        Ok(())
726    }
727
728    fn supports_retract_batch(&self) -> bool {
729        true
730    }
731}
732
733/// Calculate the percentile value for a given set of values.
734/// This function performs an exact calculation by sorting all values.
735///
736/// The percentile is calculated using linear interpolation between closest ranks.
737/// For percentile p and n values:
738/// - If p * (n-1) is an integer, return the value at that position
739/// - Otherwise, interpolate between the two closest values
740///
741/// Note: This function takes a mutable slice and sorts it in place, but does not
742/// consume the data. This is important for window frame queries where evaluate()
743/// may be called multiple times on the same accumulator state.
744fn calculate_percentile<T: ArrowNumericType>(
745    values: &mut [T::Native],
746    percentile: f64,
747) -> Option<T::Native>
748where
749    T::Native: Copy + AsPrimitive<f64>,
750    f64: AsPrimitive<T::Native>,
751{
752    let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
753
754    let len = values.len();
755    if len == 0 {
756        None
757    } else if len == 1 {
758        Some(values[0])
759    } else if percentile == 0.0 {
760        // Get minimum value
761        Some(
762            *values
763                .iter()
764                .min_by(|a, b| cmp(a, b))
765                .expect("we checked for len > 0 a few lines above"),
766        )
767    } else if percentile == 1.0 {
768        // Get maximum value
769        Some(
770            *values
771                .iter()
772                .max_by(|a, b| cmp(a, b))
773                .expect("we checked for len > 0 a few lines above"),
774        )
775    } else {
776        // Calculate the index using the formula: p * (n - 1)
777        let index = percentile * ((len - 1) as f64);
778        let lower_index = index.floor() as usize;
779        let upper_index = index.ceil() as usize;
780
781        if lower_index == upper_index {
782            // Exact index, return the value at that position
783            let (_, value, _) = values.select_nth_unstable_by(lower_index, cmp);
784            Some(*value)
785        } else {
786            // Need to interpolate between two values
787            // First, partition at lower_index to get the lower value
788            let (_, lower_value, _) = values.select_nth_unstable_by(lower_index, cmp);
789            let lower_value = *lower_value;
790
791            // Then partition at upper_index to get the upper value
792            let (_, upper_value, _) = values.select_nth_unstable_by(upper_index, cmp);
793            let upper_value = *upper_value;
794
795            // Linear interpolation.
796            // We compute a quantized interpolation weight using `INTERPOLATION_PRECISION` because:
797            // 1. Both values come from the input data, so (upper - lower) is bounded by the value range
798            // 2. fraction is between 0 and 1; quantizing it provides stable, predictable results
799            // 3. The result is guaranteed to be between lower_value and upper_value (modulo cast rounding)
800            // 4. Arithmetic is performed in f64 and cast back to avoid overflowing Float16 intermediates
801            let fraction = index - (lower_index as f64);
802            let scaled = (fraction * INTERPOLATION_PRECISION) as usize;
803            let weight = scaled as f64 / INTERPOLATION_PRECISION;
804
805            let lower_f: f64 = lower_value.as_();
806            let upper_f: f64 = upper_value.as_();
807            let interpolated_f = lower_f + (upper_f - lower_f) * weight;
808            Some(interpolated_f.as_())
809        }
810    }
811}
812
#[cfg(test)]
mod tests {
    use super::calculate_percentile;
    use arrow::datatypes::Float16Type;
    use half::f16;

    /// Regression test for <https://github.com/apache/datafusion/issues/18945>.
    ///
    /// Interpolating between 0 and the largest finite f16 (65504) used to overflow
    /// intermediate f16 arithmetic and produce NaN. The 0.5 percentile must now be
    /// finite and close to the midpoint, 32752.
    #[test]
    fn f16_interpolation_does_not_overflow_to_nan() {
        let mut values = vec![f16::from_f32(0.0), f16::from_f32(65504.0)];

        let result_f = calculate_percentile::<Float16Type>(&mut values, 0.5)
            .map(f16::to_f32)
            .expect("non-empty input");

        assert!(
            !result_f.is_nan(),
            "expected non-NaN result, got {result_f}"
        );
        assert!(
            (result_f - 32752.0).abs() < 1.0,
            "unexpected result {result_f}"
        );
    }
}