datafusion_functions_aggregate/
approx_percentile_cont.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::mem::size_of_val;
21use std::sync::Arc;
22
23use arrow::array::{Array, RecordBatch};
24use arrow::compute::{filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{
27    array::{
28        ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
29        Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
30    },
31    datatypes::{DataType, Field, Schema},
32};
33use datafusion_common::{
34    downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
35    Result, ScalarValue,
36};
37use datafusion_expr::expr::{AggregateFunction, Sort};
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
40use datafusion_expr::utils::format_state_name;
41use datafusion_expr::{
42    Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature,
43    TypeSignature, Volatility,
44};
45use datafusion_functions_aggregate_common::tdigest::{
46    TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
47};
48use datafusion_macros::user_doc;
49use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
50
51create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
52
53/// Computes the approximate percentile continuous of a set of numbers
54pub fn approx_percentile_cont(
55    order_by: Sort,
56    percentile: Expr,
57    centroids: Option<Expr>,
58) -> Expr {
59    let expr = order_by.expr.clone();
60
61    let args = if let Some(centroids) = centroids {
62        vec![expr, percentile, centroids]
63    } else {
64        vec![expr, percentile]
65    };
66
67    Expr::AggregateFunction(AggregateFunction::new_udf(
68        approx_percentile_cont_udaf(),
69        args,
70        false,
71        None,
72        vec![order_by],
73        None,
74    ))
75}
76
77#[user_doc(
78    doc_section(label = "Approximate Functions"),
79    description = "Returns the approximate percentile of input values using the t-digest algorithm.",
80    syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
81    sql_example = r#"```sql
82> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
83+------------------------------------------------------------------+
84| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
85+------------------------------------------------------------------+
86| 65.0                                                             |
87+------------------------------------------------------------------+
88> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
89+-----------------------------------------------------------------------+
90| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
91+-----------------------------------------------------------------------+
92| 65.0                                                                  |
93+-----------------------------------------------------------------------+
94```
95An alternate syntax is also supported:
96```sql
97> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name;
98+-----------------------------------------------+
99| approx_percentile_cont(column_name, 0.75)     |
100+-----------------------------------------------+
101| 65.0                                          |
102+-----------------------------------------------+
103
104> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name;
105+----------------------------------------------------------+
106| approx_percentile_cont(column_name, 0.75, 100)           |
107+----------------------------------------------------------+
108| 65.0                                                     |
109+----------------------------------------------------------+
110```
111"#,
112    standard_argument(name = "expression",),
113    argument(
114        name = "percentile",
115        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
116    ),
117    argument(
118        name = "centroids",
119        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
120    )
121)]
122#[derive(PartialEq, Eq, Hash)]
123pub struct ApproxPercentileCont {
124    signature: Signature,
125}
126
127impl Debug for ApproxPercentileCont {
128    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
129        f.debug_struct("ApproxPercentileCont")
130            .field("name", &self.name())
131            .field("signature", &self.signature)
132            .finish()
133    }
134}
135
136impl Default for ApproxPercentileCont {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142impl ApproxPercentileCont {
143    /// Create a new [`ApproxPercentileCont`] aggregate function.
144    pub fn new() -> Self {
145        let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
146        // Accept any numeric value paired with a float64 percentile
147        for num in NUMERICS {
148            variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
149            // Additionally accept an integer number of centroids for T-Digest
150            for int in INTEGERS {
151                variants.push(TypeSignature::Exact(vec![
152                    num.clone(),
153                    DataType::Float64,
154                    int.clone(),
155                ]))
156            }
157        }
158        Self {
159            signature: Signature::one_of(variants, Volatility::Immutable),
160        }
161    }
162
163    pub(crate) fn create_accumulator(
164        &self,
165        args: AccumulatorArgs,
166    ) -> Result<ApproxPercentileAccumulator> {
167        let percentile = validate_input_percentile_expr(&args.exprs[1])?;
168
169        let is_descending = args
170            .order_bys
171            .first()
172            .map(|sort_expr| sort_expr.options.descending)
173            .unwrap_or(false);
174
175        let percentile = if is_descending {
176            1.0 - percentile
177        } else {
178            percentile
179        };
180
181        let tdigest_max_size = if args.exprs.len() == 3 {
182            Some(validate_input_max_size_expr(&args.exprs[2])?)
183        } else {
184            None
185        };
186
187        let data_type = args.exprs[0].data_type(args.schema)?;
188        let accumulator: ApproxPercentileAccumulator = match data_type {
189            t @ (DataType::UInt8
190            | DataType::UInt16
191            | DataType::UInt32
192            | DataType::UInt64
193            | DataType::Int8
194            | DataType::Int16
195            | DataType::Int32
196            | DataType::Int64
197            | DataType::Float32
198            | DataType::Float64) => {
199                if let Some(max_size) = tdigest_max_size {
200                    ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size)
201                }else{
202                    ApproxPercentileAccumulator::new(percentile, t)
203
204                }
205            }
206            other => {
207                return not_impl_err!(
208                    "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented"
209                )
210            }
211        };
212
213        Ok(accumulator)
214    }
215}
216
217fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
218    let empty_schema = Arc::new(Schema::empty());
219    let batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
220    if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? {
221        Ok(s)
222    } else {
223        internal_err!("Didn't expect ColumnarValue::Array")
224    }
225}
226
227fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
228    let percentile = match get_scalar_value(expr)
229        .map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
230        ScalarValue::Float32(Some(value)) => {
231            value as f64
232        }
233        ScalarValue::Float64(Some(value)) => {
234            value
235        }
236        sv => {
237            return not_impl_err!(
238                "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
239                sv.data_type()
240            )
241        }
242    };
243
244    // Ensure the percentile is between 0 and 1.
245    if !(0.0..=1.0).contains(&percentile) {
246        return plan_err!(
247            "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
248        );
249    }
250    Ok(percentile)
251}
252
253fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
254    let max_size = match get_scalar_value(expr)
255        .map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
256        ScalarValue::UInt8(Some(q)) => q as usize,
257        ScalarValue::UInt16(Some(q)) => q as usize,
258        ScalarValue::UInt32(Some(q)) => q as usize,
259        ScalarValue::UInt64(Some(q)) => q as usize,
260        ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
261        ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
262        ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
263        ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
264        sv => {
265            return not_impl_err!(
266                "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
267                sv.data_type()
268            )
269        },
270    };
271
272    Ok(max_size)
273}
274
275impl AggregateUDFImpl for ApproxPercentileCont {
276    fn as_any(&self) -> &dyn Any {
277        self
278    }
279
280    #[allow(rustdoc::private_intra_doc_links)]
281    /// See [`TDigest::to_scalar_state()`] for a description of the serialized
282    /// state.
283    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
284        Ok(vec![
285            Field::new(
286                format_state_name(args.name, "max_size"),
287                DataType::UInt64,
288                false,
289            ),
290            Field::new(
291                format_state_name(args.name, "sum"),
292                DataType::Float64,
293                false,
294            ),
295            Field::new(
296                format_state_name(args.name, "count"),
297                DataType::UInt64,
298                false,
299            ),
300            Field::new(
301                format_state_name(args.name, "max"),
302                DataType::Float64,
303                false,
304            ),
305            Field::new(
306                format_state_name(args.name, "min"),
307                DataType::Float64,
308                false,
309            ),
310            Field::new_list(
311                format_state_name(args.name, "centroids"),
312                Field::new_list_field(DataType::Float64, true),
313                false,
314            ),
315        ]
316        .into_iter()
317        .map(Arc::new)
318        .collect())
319    }
320
321    fn name(&self) -> &str {
322        "approx_percentile_cont"
323    }
324
325    fn signature(&self) -> &Signature {
326        &self.signature
327    }
328
329    #[inline]
330    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
331        Ok(Box::new(self.create_accumulator(acc_args)?))
332    }
333
334    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
335        if !arg_types[0].is_numeric() {
336            return plan_err!("approx_percentile_cont requires numeric input types");
337        }
338        if arg_types.len() == 3 && !arg_types[2].is_integer() {
339            return plan_err!(
340                "approx_percentile_cont requires integer centroids input types"
341            );
342        }
343        Ok(arg_types[0].clone())
344    }
345
346    fn supports_null_handling_clause(&self) -> bool {
347        false
348    }
349
350    fn is_ordered_set_aggregate(&self) -> bool {
351        true
352    }
353
354    fn documentation(&self) -> Option<&Documentation> {
355        self.doc()
356    }
357}
358
359#[derive(Debug)]
360pub struct ApproxPercentileAccumulator {
361    digest: TDigest,
362    percentile: f64,
363    return_type: DataType,
364}
365
366impl ApproxPercentileAccumulator {
367    pub fn new(percentile: f64, return_type: DataType) -> Self {
368        Self {
369            digest: TDigest::new(DEFAULT_MAX_SIZE),
370            percentile,
371            return_type,
372        }
373    }
374
375    pub fn new_with_max_size(
376        percentile: f64,
377        return_type: DataType,
378        max_size: usize,
379    ) -> Self {
380        Self {
381            digest: TDigest::new(max_size),
382            percentile,
383            return_type,
384        }
385    }
386
387    // public for approx_percentile_cont_with_weight
388    pub(crate) fn max_size(&self) -> usize {
389        self.digest.max_size()
390    }
391
392    // public for approx_percentile_cont_with_weight
393    pub fn merge_digests(&mut self, digests: &[TDigest]) {
394        let digests = digests.iter().chain(std::iter::once(&self.digest));
395        self.digest = TDigest::merge_digests(digests)
396    }
397
398    // public for approx_percentile_cont_with_weight
399    pub fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
400        match values.data_type() {
401            DataType::Float64 => {
402                let array = downcast_value!(values, Float64Array);
403                Ok(array
404                    .values()
405                    .iter()
406                    .filter_map(|v| v.try_as_f64().transpose())
407                    .collect::<Result<Vec<_>>>()?)
408            }
409            DataType::Float32 => {
410                let array = downcast_value!(values, Float32Array);
411                Ok(array
412                    .values()
413                    .iter()
414                    .filter_map(|v| v.try_as_f64().transpose())
415                    .collect::<Result<Vec<_>>>()?)
416            }
417            DataType::Int64 => {
418                let array = downcast_value!(values, Int64Array);
419                Ok(array
420                    .values()
421                    .iter()
422                    .filter_map(|v| v.try_as_f64().transpose())
423                    .collect::<Result<Vec<_>>>()?)
424            }
425            DataType::Int32 => {
426                let array = downcast_value!(values, Int32Array);
427                Ok(array
428                    .values()
429                    .iter()
430                    .filter_map(|v| v.try_as_f64().transpose())
431                    .collect::<Result<Vec<_>>>()?)
432            }
433            DataType::Int16 => {
434                let array = downcast_value!(values, Int16Array);
435                Ok(array
436                    .values()
437                    .iter()
438                    .filter_map(|v| v.try_as_f64().transpose())
439                    .collect::<Result<Vec<_>>>()?)
440            }
441            DataType::Int8 => {
442                let array = downcast_value!(values, Int8Array);
443                Ok(array
444                    .values()
445                    .iter()
446                    .filter_map(|v| v.try_as_f64().transpose())
447                    .collect::<Result<Vec<_>>>()?)
448            }
449            DataType::UInt64 => {
450                let array = downcast_value!(values, UInt64Array);
451                Ok(array
452                    .values()
453                    .iter()
454                    .filter_map(|v| v.try_as_f64().transpose())
455                    .collect::<Result<Vec<_>>>()?)
456            }
457            DataType::UInt32 => {
458                let array = downcast_value!(values, UInt32Array);
459                Ok(array
460                    .values()
461                    .iter()
462                    .filter_map(|v| v.try_as_f64().transpose())
463                    .collect::<Result<Vec<_>>>()?)
464            }
465            DataType::UInt16 => {
466                let array = downcast_value!(values, UInt16Array);
467                Ok(array
468                    .values()
469                    .iter()
470                    .filter_map(|v| v.try_as_f64().transpose())
471                    .collect::<Result<Vec<_>>>()?)
472            }
473            DataType::UInt8 => {
474                let array = downcast_value!(values, UInt8Array);
475                Ok(array
476                    .values()
477                    .iter()
478                    .filter_map(|v| v.try_as_f64().transpose())
479                    .collect::<Result<Vec<_>>>()?)
480            }
481            e => internal_err!(
482                "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
483            ),
484        }
485    }
486}
487
488impl Accumulator for ApproxPercentileAccumulator {
489    fn state(&mut self) -> Result<Vec<ScalarValue>> {
490        Ok(self.digest.to_scalar_state().into_iter().collect())
491    }
492
493    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
494        // Remove any nulls before computing the percentile
495        let mut values = Arc::clone(&values[0]);
496        if values.nulls().is_some() {
497            values = filter(&values, &is_not_null(&values)?)?;
498        }
499        let sorted_values = &arrow::compute::sort(&values, None)?;
500        let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
501        self.digest = self.digest.merge_sorted_f64(&sorted_values);
502        Ok(())
503    }
504
505    fn evaluate(&mut self) -> Result<ScalarValue> {
506        if self.digest.count() == 0 {
507            return ScalarValue::try_from(self.return_type.clone());
508        }
509        let q = self.digest.estimate_quantile(self.percentile);
510
511        // These acceptable return types MUST match the validation in
512        // ApproxPercentile::create_accumulator.
513        Ok(match &self.return_type {
514            DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
515            DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
516            DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
517            DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
518            DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
519            DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
520            DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
521            DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
522            DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
523            DataType::Float64 => ScalarValue::Float64(Some(q)),
524            v => unreachable!("unexpected return type {:?}", v),
525        })
526    }
527
528    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
529        if states.is_empty() {
530            return Ok(());
531        }
532
533        let states = (0..states[0].len())
534            .map(|index| {
535                states
536                    .iter()
537                    .map(|array| ScalarValue::try_from_array(array, index))
538                    .collect::<Result<Vec<_>>>()
539                    .map(|state| TDigest::from_scalar_state(&state))
540            })
541            .collect::<Result<Vec<_>>>()?;
542
543        self.merge_digests(&states);
544
545        Ok(())
546    }
547
548    fn size(&self) -> usize {
549        size_of_val(self) + self.digest.size() - size_of_val(&self.digest)
550            + self.return_type.size()
551            - size_of_val(&self.return_type)
552    }
553}
554
#[cfg(test)]
mod tests {
    use arrow::datatypes::DataType;

    use datafusion_functions_aggregate_common::tdigest::TDigest;

    use crate::approx_percentile_cont::ApproxPercentileAccumulator;

    /// Merging pre-built digests into an accumulator adds their counts to the
    /// values the accumulator already holds.
    #[test]
    fn test_combine_approx_percentile_accumulator() {
        // Fifty digests, each built from the values 1..=1_000.
        let digests: Vec<TDigest> = (0..50)
            .map(|_| {
                let values: Vec<f64> = (1..=1_000).map(f64::from).collect();
                TDigest::new(100).merge_unsorted_f64(values)
            })
            .collect();

        // Two independent merges of those digests, each covering 50_000 values.
        let first = TDigest::merge_digests(&digests);
        let second = TDigest::merge_digests(&digests);

        let mut accumulator =
            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

        accumulator.merge_digests(&[first]);
        assert_eq!(accumulator.digest.count(), 50_000);

        accumulator.merge_digests(&[second]);
        assert_eq!(accumulator.digest.count(), 100_000);
    }
}