datafusion_functions_aggregate/percentile_cont.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::{Debug, Formatter};
19use std::mem::{size_of, size_of_val};
20use std::sync::Arc;
21
22use arrow::array::{
23    ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder,
24};
25use arrow::buffer::{OffsetBuffer, ScalarBuffer};
26use arrow::{
27    array::{Array, ArrayRef, AsArray},
28    datatypes::{
29        ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type,
30        Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type,
31    },
32};
33
34use arrow::array::ArrowNativeTypeOp;
35
36use datafusion_common::{
37    internal_datafusion_err, internal_err, plan_err, DataFusionError, HashSet, Result,
38    ScalarValue,
39};
40use datafusion_expr::expr::{AggregateFunction, Sort};
41use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
42use datafusion_expr::type_coercion::aggregates::NUMERICS;
43use datafusion_expr::utils::format_state_name;
44use datafusion_expr::{
45    Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
46    Volatility,
47};
48use datafusion_expr::{EmitTo, GroupsAccumulator};
49use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
50use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
51use datafusion_functions_aggregate_common::utils::Hashable;
52use datafusion_macros::user_doc;
53
54use crate::utils::validate_percentile_expr;
55
56/// Precision multiplier for linear interpolation calculations.
57///
58/// This value of 1,000,000 was chosen to balance precision with overflow safety:
59/// - Provides 6 decimal places of precision for the fractional component
60/// - Small enough to avoid overflow when multiplied with typical numeric values
61/// - Sufficient precision for most statistical applications
62///
63/// The interpolation formula: `lower + (upper - lower) * fraction`
64/// is computed as: `lower + ((upper - lower) * (fraction * PRECISION)) / PRECISION`
65/// to avoid floating-point operations on integer types while maintaining precision.
66const INTERPOLATION_PRECISION: usize = 1_000_000;
67
68create_func!(PercentileCont, percentile_cont_udaf);
69
70/// Computes the exact percentile continuous of a set of numbers
71pub fn percentile_cont(order_by: Sort, percentile: Expr) -> Expr {
72    let expr = order_by.expr.clone();
73    let args = vec![expr, percentile];
74
75    Expr::AggregateFunction(AggregateFunction::new_udf(
76        percentile_cont_udaf(),
77        args,
78        false,
79        None,
80        vec![order_by],
81        None,
82    ))
83}
84
85#[user_doc(
86    doc_section(label = "General Functions"),
87    description = "Returns the exact percentile of input values, interpolating between values if needed.",
88    syntax_example = "percentile_cont(percentile) WITHIN GROUP (ORDER BY expression)",
89    sql_example = r#"```sql
90> SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
91+----------------------------------------------------------+
92| percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
93+----------------------------------------------------------+
94| 45.5                                                     |
95+----------------------------------------------------------+
96```
97
98An alternate syntax is also supported:
99```sql
100> SELECT percentile_cont(column_name, 0.75) FROM table_name;
101+---------------------------------------+
102| percentile_cont(column_name, 0.75)    |
103+---------------------------------------+
104| 45.5                                  |
105+---------------------------------------+
106```"#,
107    standard_argument(name = "expression", prefix = "The"),
108    argument(
109        name = "percentile",
110        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
111    )
112)]
113/// PERCENTILE_CONT aggregate expression. This uses an exact calculation and stores all values
114/// in memory before computing the result. If an approximation is sufficient then
115/// APPROX_PERCENTILE_CONT provides a much more efficient solution.
116///
117/// If using the distinct variation, the memory usage will be similarly high if the
118/// cardinality is high as it stores all distinct values in memory before computing the
119/// result, but if cardinality is low then memory usage will also be lower.
120#[derive(PartialEq, Eq, Hash)]
121pub struct PercentileCont {
122    signature: Signature,
123    aliases: Vec<String>,
124}
125
126impl Debug for PercentileCont {
127    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
128        f.debug_struct("PercentileCont")
129            .field("name", &self.name())
130            .field("signature", &self.signature)
131            .finish()
132    }
133}
134
135impl Default for PercentileCont {
136    fn default() -> Self {
137        Self::new()
138    }
139}
140
141impl PercentileCont {
142    pub fn new() -> Self {
143        let mut variants = Vec::with_capacity(NUMERICS.len());
144        // Accept any numeric value paired with a float64 percentile
145        for num in NUMERICS {
146            variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
147        }
148        Self {
149            signature: Signature::one_of(variants, Volatility::Immutable)
150                .with_parameter_names(vec!["expr".to_string(), "percentile".to_string()])
151                .expect("valid parameter names for percentile_cont"),
152            aliases: vec![String::from("quantile_cont")],
153        }
154    }
155
156    fn create_accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
157        let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?;
158
159        let is_descending = args
160            .order_bys
161            .first()
162            .map(|sort_expr| sort_expr.options.descending)
163            .unwrap_or(false);
164
165        let percentile = if is_descending {
166            1.0 - percentile
167        } else {
168            percentile
169        };
170
171        macro_rules! helper {
172            ($t:ty, $dt:expr) => {
173                if args.is_distinct {
174                    Ok(Box::new(DistinctPercentileContAccumulator::<$t> {
175                        data_type: $dt.clone(),
176                        distinct_values: HashSet::new(),
177                        percentile,
178                    }))
179                } else {
180                    Ok(Box::new(PercentileContAccumulator::<$t> {
181                        data_type: $dt.clone(),
182                        all_values: vec![],
183                        percentile,
184                    }))
185                }
186            };
187        }
188
189        let input_dt = args.exprs[0].data_type(args.schema)?;
190        match input_dt {
191            // For integer types, use Float64 internally since percentile_cont returns Float64
192            DataType::Int8
193            | DataType::Int16
194            | DataType::Int32
195            | DataType::Int64
196            | DataType::UInt8
197            | DataType::UInt16
198            | DataType::UInt32
199            | DataType::UInt64 => helper!(Float64Type, DataType::Float64),
200            DataType::Float16 => helper!(Float16Type, input_dt),
201            DataType::Float32 => helper!(Float32Type, input_dt),
202            DataType::Float64 => helper!(Float64Type, input_dt),
203            DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt),
204            DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt),
205            DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt),
206            DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt),
207            _ => Err(DataFusionError::NotImplemented(format!(
208                "PercentileContAccumulator not supported for {} with {}",
209                args.name, input_dt,
210            ))),
211        }
212    }
213}
214
215impl AggregateUDFImpl for PercentileCont {
216    fn as_any(&self) -> &dyn std::any::Any {
217        self
218    }
219
220    fn name(&self) -> &str {
221        "percentile_cont"
222    }
223
224    fn aliases(&self) -> &[String] {
225        &self.aliases
226    }
227
228    fn signature(&self) -> &Signature {
229        &self.signature
230    }
231
232    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
233        if !arg_types[0].is_numeric() {
234            return plan_err!("percentile_cont requires numeric input types");
235        }
236        // PERCENTILE_CONT performs linear interpolation and should return a float type
237        // For integer inputs, return Float64 (matching PostgreSQL/DuckDB behavior)
238        // For float inputs, preserve the float type
239        match &arg_types[0] {
240            DataType::Float16 | DataType::Float32 | DataType::Float64 => {
241                Ok(arg_types[0].clone())
242            }
243            DataType::Decimal32(_, _)
244            | DataType::Decimal64(_, _)
245            | DataType::Decimal128(_, _)
246            | DataType::Decimal256(_, _) => Ok(arg_types[0].clone()),
247            DataType::UInt8
248            | DataType::UInt16
249            | DataType::UInt32
250            | DataType::UInt64
251            | DataType::Int8
252            | DataType::Int16
253            | DataType::Int32
254            | DataType::Int64 => Ok(DataType::Float64),
255            // Shouldn't happen due to signature check, but just in case
256            dt => plan_err!(
257                "percentile_cont does not support input type {}, must be numeric",
258                dt
259            ),
260        }
261    }
262
263    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
264        //Intermediate state is a list of the elements we have collected so far
265        let input_type = args.input_fields[0].data_type().clone();
266        // For integer types, we store as Float64 internally
267        let storage_type = match &input_type {
268            DataType::Int8
269            | DataType::Int16
270            | DataType::Int32
271            | DataType::Int64
272            | DataType::UInt8
273            | DataType::UInt16
274            | DataType::UInt32
275            | DataType::UInt64 => DataType::Float64,
276            _ => input_type,
277        };
278
279        let field = Field::new_list_field(storage_type, true);
280        let state_name = if args.is_distinct {
281            "distinct_percentile_cont"
282        } else {
283            "percentile_cont"
284        };
285
286        Ok(vec![Field::new(
287            format_state_name(args.name, state_name),
288            DataType::List(Arc::new(field)),
289            true,
290        )
291        .into()])
292    }
293
294    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
295        self.create_accumulator(acc_args)
296    }
297
298    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
299        !args.is_distinct
300    }
301
302    fn create_groups_accumulator(
303        &self,
304        args: AccumulatorArgs,
305    ) -> Result<Box<dyn GroupsAccumulator>> {
306        let num_args = args.exprs.len();
307        if num_args != 2 {
308            return internal_err!(
309                "percentile_cont should have 2 args, but found num args:{}",
310                args.exprs.len()
311            );
312        }
313
314        let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?;
315
316        let is_descending = args
317            .order_bys
318            .first()
319            .map(|sort_expr| sort_expr.options.descending)
320            .unwrap_or(false);
321
322        let percentile = if is_descending {
323            1.0 - percentile
324        } else {
325            percentile
326        };
327
328        macro_rules! helper {
329            ($t:ty, $dt:expr) => {
330                Ok(Box::new(PercentileContGroupsAccumulator::<$t>::new(
331                    $dt, percentile,
332                )))
333            };
334        }
335
336        let input_dt = args.exprs[0].data_type(args.schema)?;
337        match input_dt {
338            // For integer types, use Float64 internally since percentile_cont returns Float64
339            DataType::Int8
340            | DataType::Int16
341            | DataType::Int32
342            | DataType::Int64
343            | DataType::UInt8
344            | DataType::UInt16
345            | DataType::UInt32
346            | DataType::UInt64 => helper!(Float64Type, DataType::Float64),
347            DataType::Float16 => helper!(Float16Type, input_dt),
348            DataType::Float32 => helper!(Float32Type, input_dt),
349            DataType::Float64 => helper!(Float64Type, input_dt),
350            DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt),
351            DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt),
352            DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt),
353            DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt),
354            _ => Err(DataFusionError::NotImplemented(format!(
355                "PercentileContGroupsAccumulator not supported for {} with {}",
356                args.name, input_dt,
357            ))),
358        }
359    }
360
361    fn supports_null_handling_clause(&self) -> bool {
362        false
363    }
364
365    fn supports_within_group_clause(&self) -> bool {
366        true
367    }
368
369    fn documentation(&self) -> Option<&Documentation> {
370        self.doc()
371    }
372}
373
374/// The percentile_cont accumulator accumulates the raw input values
375/// as native types.
376///
377/// The intermediate state is represented as a List of scalar values updated by
378/// `merge_batch` and a `Vec` of native values that are converted to scalar values
379/// in the final evaluation step so that we avoid expensive conversions and
380/// allocations during `update_batch`.
381struct PercentileContAccumulator<T: ArrowNumericType> {
382    data_type: DataType,
383    all_values: Vec<T::Native>,
384    percentile: f64,
385}
386
387impl<T: ArrowNumericType> Debug for PercentileContAccumulator<T> {
388    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
389        write!(
390            f,
391            "PercentileContAccumulator({}, percentile={})",
392            self.data_type, self.percentile
393        )
394    }
395}
396
397impl<T: ArrowNumericType> Accumulator for PercentileContAccumulator<T> {
398    fn state(&mut self) -> Result<Vec<ScalarValue>> {
399        // Convert `all_values` to `ListArray` and return a single List ScalarValue
400
401        // Build offsets
402        let offsets =
403            OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
404
405        // Build inner array
406        let values_array = PrimitiveArray::<T>::new(
407            ScalarBuffer::from(std::mem::take(&mut self.all_values)),
408            None,
409        )
410        .with_data_type(self.data_type.clone());
411
412        // Build the result list array
413        let list_array = ListArray::new(
414            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
415            offsets,
416            Arc::new(values_array),
417            None,
418        );
419
420        Ok(vec![ScalarValue::List(Arc::new(list_array))])
421    }
422
423    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
424        // Cast to target type if needed (e.g., integer to Float64)
425        let values = if values[0].data_type() != &self.data_type {
426            arrow::compute::cast(&values[0], &self.data_type)?
427        } else {
428            Arc::clone(&values[0])
429        };
430
431        let values = values.as_primitive::<T>();
432        self.all_values.reserve(values.len() - values.null_count());
433        self.all_values.extend(values.iter().flatten());
434        Ok(())
435    }
436
437    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
438        let array = states[0].as_list::<i32>();
439        for v in array.iter().flatten() {
440            self.update_batch(&[v])?
441        }
442        Ok(())
443    }
444
445    fn evaluate(&mut self) -> Result<ScalarValue> {
446        let d = std::mem::take(&mut self.all_values);
447        let value = calculate_percentile::<T>(d, self.percentile);
448        ScalarValue::new_primitive::<T>(value, &self.data_type)
449    }
450
451    fn size(&self) -> usize {
452        size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
453    }
454}
455
456/// The percentile_cont groups accumulator accumulates the raw input values
457///
458/// For calculating the exact percentile of groups, we need to store all values
459/// of groups before final evaluation.
460/// So values in each group will be stored in a `Vec<T>`, and the total group values
461/// will be actually organized as a `Vec<Vec<T>>`.
462#[derive(Debug)]
463struct PercentileContGroupsAccumulator<T: ArrowNumericType + Send> {
464    data_type: DataType,
465    group_values: Vec<Vec<T::Native>>,
466    percentile: f64,
467}
468
469impl<T: ArrowNumericType + Send> PercentileContGroupsAccumulator<T> {
470    pub fn new(data_type: DataType, percentile: f64) -> Self {
471        Self {
472            data_type,
473            group_values: Vec::new(),
474            percentile,
475        }
476    }
477}
478
479impl<T: ArrowNumericType + Send> GroupsAccumulator
480    for PercentileContGroupsAccumulator<T>
481{
482    fn update_batch(
483        &mut self,
484        values: &[ArrayRef],
485        group_indices: &[usize],
486        opt_filter: Option<&BooleanArray>,
487        total_num_groups: usize,
488    ) -> Result<()> {
489        // For ordered-set aggregates, we only care about the ORDER BY column (first element)
490        // The percentile parameter is already stored in self.percentile
491
492        // Cast to target type if needed (e.g., integer to Float64)
493        let values_array = if values[0].data_type() != &self.data_type {
494            arrow::compute::cast(&values[0], &self.data_type)?
495        } else {
496            Arc::clone(&values[0])
497        };
498
499        let values = values_array.as_primitive::<T>();
500
501        // Push the `not nulls + not filtered` row into its group
502        self.group_values.resize(total_num_groups, Vec::new());
503        accumulate(
504            group_indices,
505            values,
506            opt_filter,
507            |group_index, new_value| {
508                self.group_values[group_index].push(new_value);
509            },
510        );
511
512        Ok(())
513    }
514
515    fn merge_batch(
516        &mut self,
517        values: &[ArrayRef],
518        group_indices: &[usize],
519        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
520        _opt_filter: Option<&BooleanArray>,
521        total_num_groups: usize,
522    ) -> Result<()> {
523        assert_eq!(values.len(), 1, "one argument to merge_batch");
524
525        let input_group_values = values[0].as_list::<i32>();
526
527        // Ensure group values big enough
528        self.group_values.resize(total_num_groups, Vec::new());
529
530        // Extend values to related groups
531        group_indices
532            .iter()
533            .zip(input_group_values.iter())
534            .for_each(|(&group_index, values_opt)| {
535                if let Some(values) = values_opt {
536                    let values = values.as_primitive::<T>();
537                    self.group_values[group_index].extend(values.values().iter());
538                }
539            });
540
541        Ok(())
542    }
543
544    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
545        // Emit values
546        let emit_group_values = emit_to.take_needed(&mut self.group_values);
547
548        // Build offsets
549        let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
550        offsets.push(0);
551        let mut cur_len = 0_i32;
552        for group_value in &emit_group_values {
553            cur_len += group_value.len() as i32;
554            offsets.push(cur_len);
555        }
556        let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
557
558        // Build inner array
559        let flatten_group_values =
560            emit_group_values.into_iter().flatten().collect::<Vec<_>>();
561        let group_values_array =
562            PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
563                .with_data_type(self.data_type.clone());
564
565        // Build the result list array
566        let result_list_array = ListArray::new(
567            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
568            offsets,
569            Arc::new(group_values_array),
570            None,
571        );
572
573        Ok(vec![Arc::new(result_list_array)])
574    }
575
576    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
577        // Emit values
578        let emit_group_values = emit_to.take_needed(&mut self.group_values);
579
580        // Calculate percentile for each group
581        let mut evaluate_result_builder =
582            PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
583        for values in emit_group_values {
584            let value = calculate_percentile::<T>(values, self.percentile);
585            evaluate_result_builder.append_option(value);
586        }
587
588        Ok(Arc::new(evaluate_result_builder.finish()))
589    }
590
591    fn convert_to_state(
592        &self,
593        values: &[ArrayRef],
594        opt_filter: Option<&BooleanArray>,
595    ) -> Result<Vec<ArrayRef>> {
596        assert_eq!(values.len(), 1, "one argument to merge_batch");
597
598        // Cast to target type if needed (e.g., integer to Float64)
599        let values_array = if values[0].data_type() != &self.data_type {
600            arrow::compute::cast(&values[0], &self.data_type)?
601        } else {
602            Arc::clone(&values[0])
603        };
604
605        let input_array = values_array.as_primitive::<T>();
606
607        // Directly convert the input array to states, each row will be
608        // seen as a respective group.
609        // For detail, the `input_array` will be converted to a `ListArray`.
610        // And if row is `not null + not filtered`, it will be converted to a list
611        // with only one element; otherwise, this row in `ListArray` will be set
612        // to null.
613
614        // Reuse values buffer in `input_array` to build `values` in `ListArray`
615        let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
616            .with_data_type(self.data_type.clone());
617
618        // `offsets` in `ListArray`, each row as a list element
619        let offset_end = i32::try_from(input_array.len()).map_err(|e| {
620            internal_datafusion_err!(
621                "cast array_len to i32 failed in convert_to_state of group percentile_cont, err:{e:?}"
622            )
623        })?;
624        let offsets = (0..=offset_end).collect::<Vec<_>>();
625        // Safety: The offsets vector is constructed as a sequential range from 0 to input_array.len(),
626        // which guarantees all OffsetBuffer invariants:
627        // 1. Offsets are monotonically increasing (each element is prev + 1)
628        // 2. No offset exceeds the values array length (max offset = input_array.len())
629        // 3. First offset is 0 and last offset equals the total length
630        // Therefore new_unchecked is safe to use here.
631        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
632
633        // `nulls` for converted `ListArray`
634        let nulls = filtered_null_mask(opt_filter, input_array);
635
636        let converted_list_array = ListArray::new(
637            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
638            offsets,
639            Arc::new(values),
640            nulls,
641        );
642
643        Ok(vec![Arc::new(converted_list_array)])
644    }
645
646    fn supports_convert_to_state(&self) -> bool {
647        true
648    }
649
650    fn size(&self) -> usize {
651        self.group_values
652            .iter()
653            .map(|values| values.capacity() * size_of::<T::Native>())
654            .sum::<usize>()
655            // account for size of self.group_values too
656            + self.group_values.capacity() * size_of::<Vec<T::Native>>()
657    }
658}
659
660/// The distinct percentile_cont accumulator accumulates the raw input values
661/// using a HashSet to eliminate duplicates.
662///
663/// The intermediate state is represented as a List of scalar values updated by
664/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values
665/// in the final evaluation step so that we avoid expensive conversions and
666/// allocations during `update_batch`.
667struct DistinctPercentileContAccumulator<T: ArrowNumericType> {
668    data_type: DataType,
669    distinct_values: HashSet<Hashable<T::Native>>,
670    percentile: f64,
671}
672
673impl<T: ArrowNumericType> Debug for DistinctPercentileContAccumulator<T> {
674    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
675        write!(
676            f,
677            "DistinctPercentileContAccumulator({}, percentile={})",
678            self.data_type, self.percentile
679        )
680    }
681}
682
683impl<T: ArrowNumericType> Accumulator for DistinctPercentileContAccumulator<T> {
684    fn state(&mut self) -> Result<Vec<ScalarValue>> {
685        let all_values = self
686            .distinct_values
687            .iter()
688            .map(|x| ScalarValue::new_primitive::<T>(Some(x.0), &self.data_type))
689            .collect::<Result<Vec<_>>>()?;
690
691        let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type);
692        Ok(vec![ScalarValue::List(arr)])
693    }
694
695    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
696        if values.is_empty() {
697            return Ok(());
698        }
699
700        // Cast to target type if needed (e.g., integer to Float64)
701        let values = if values[0].data_type() != &self.data_type {
702            arrow::compute::cast(&values[0], &self.data_type)?
703        } else {
704            Arc::clone(&values[0])
705        };
706
707        let array = values.as_primitive::<T>();
708        match array.nulls().filter(|x| x.null_count() > 0) {
709            Some(n) => {
710                for idx in n.valid_indices() {
711                    self.distinct_values.insert(Hashable(array.value(idx)));
712                }
713            }
714            None => array.values().iter().for_each(|x| {
715                self.distinct_values.insert(Hashable(*x));
716            }),
717        }
718        Ok(())
719    }
720
721    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
722        let array = states[0].as_list::<i32>();
723        for v in array.iter().flatten() {
724            self.update_batch(&[v])?
725        }
726        Ok(())
727    }
728
729    fn evaluate(&mut self) -> Result<ScalarValue> {
730        let d = std::mem::take(&mut self.distinct_values)
731            .into_iter()
732            .map(|v| v.0)
733            .collect::<Vec<_>>();
734        let value = calculate_percentile::<T>(d, self.percentile);
735        ScalarValue::new_primitive::<T>(value, &self.data_type)
736    }
737
738    fn size(&self) -> usize {
739        size_of_val(self) + self.distinct_values.capacity() * size_of::<T::Native>()
740    }
741}
742
743/// Calculate the percentile value for a given set of values.
744/// This function performs an exact calculation by sorting all values.
745///
746/// The percentile is calculated using linear interpolation between closest ranks.
747/// For percentile p and n values:
748/// - If p * (n-1) is an integer, return the value at that position
749/// - Otherwise, interpolate between the two closest values
750fn calculate_percentile<T: ArrowNumericType>(
751    mut values: Vec<T::Native>,
752    percentile: f64,
753) -> Option<T::Native> {
754    let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
755
756    let len = values.len();
757    if len == 0 {
758        None
759    } else if len == 1 {
760        Some(values[0])
761    } else if percentile == 0.0 {
762        // Get minimum value
763        Some(
764            *values
765                .iter()
766                .min_by(|a, b| cmp(a, b))
767                .expect("we checked for len > 0 a few lines above"),
768        )
769    } else if percentile == 1.0 {
770        // Get maximum value
771        Some(
772            *values
773                .iter()
774                .max_by(|a, b| cmp(a, b))
775                .expect("we checked for len > 0 a few lines above"),
776        )
777    } else {
778        // Calculate the index using the formula: p * (n - 1)
779        let index = percentile * ((len - 1) as f64);
780        let lower_index = index.floor() as usize;
781        let upper_index = index.ceil() as usize;
782
783        if lower_index == upper_index {
784            // Exact index, return the value at that position
785            let (_, value, _) = values.select_nth_unstable_by(lower_index, cmp);
786            Some(*value)
787        } else {
788            // Need to interpolate between two values
789            // First, partition at lower_index to get the lower value
790            let (_, lower_value, _) = values.select_nth_unstable_by(lower_index, cmp);
791            let lower_value = *lower_value;
792
793            // Then partition at upper_index to get the upper value
794            let (_, upper_value, _) = values.select_nth_unstable_by(upper_index, cmp);
795            let upper_value = *upper_value;
796
797            // Linear interpolation using wrapping arithmetic
798            // We use wrapping operations here (matching the approach in median.rs) because:
799            // 1. Both values come from the input data, so diff is bounded by the value range
800            // 2. fraction is between 0 and 1, and INTERPOLATION_PRECISION is small enough
801            //    to prevent overflow when combined with typical numeric ranges
802            // 3. The result is guaranteed to be between lower_value and upper_value
803            // 4. For floating-point types, wrapping ops behave the same as standard ops
804            let fraction = index - (lower_index as f64);
805            let diff = upper_value.sub_wrapping(lower_value);
806            let interpolated = lower_value.add_wrapping(
807                diff.mul_wrapping(T::Native::usize_as(
808                    (fraction * INTERPOLATION_PRECISION as f64) as usize,
809                ))
810                .div_wrapping(T::Native::usize_as(INTERPOLATION_PRECISION)),
811            );
812            Some(interpolated)
813        }
814    }
815}