datafusion_functions_aggregate/approx_percentile_cont.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::Debug;
20use std::mem::size_of_val;
21use std::sync::Arc;
22
23use arrow::array::{Array, Float16Array};
24use arrow::compute::{filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{
27    array::{
28        ArrayRef, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array,
29        Int64Array, UInt8Array, UInt16Array, UInt32Array, UInt64Array,
30    },
31    datatypes::{DataType, Field},
32};
33use datafusion_common::{
34    DataFusionError, Result, ScalarValue, downcast_value, internal_err, not_impl_err,
35    plan_err,
36};
37use datafusion_expr::expr::{AggregateFunction, Sort};
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
40use datafusion_expr::utils::format_state_name;
41use datafusion_expr::{
42    Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
43    Volatility,
44};
45use datafusion_functions_aggregate_common::tdigest::{DEFAULT_MAX_SIZE, TDigest};
46use datafusion_macros::user_doc;
47use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
48
49use crate::utils::{get_scalar_value, validate_percentile_expr};
50
51create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
52
53/// Computes the approximate percentile continuous of a set of numbers
54pub fn approx_percentile_cont(
55    order_by: Sort,
56    percentile: Expr,
57    centroids: Option<Expr>,
58) -> Expr {
59    let expr = order_by.expr.clone();
60
61    let args = if let Some(centroids) = centroids {
62        vec![expr, percentile, centroids]
63    } else {
64        vec![expr, percentile]
65    };
66
67    Expr::AggregateFunction(AggregateFunction::new_udf(
68        approx_percentile_cont_udaf(),
69        args,
70        false,
71        None,
72        vec![order_by],
73        None,
74    ))
75}
76
77#[user_doc(
78    doc_section(label = "Approximate Functions"),
79    description = "Returns the approximate percentile of input values using the t-digest algorithm.",
80    syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
81    sql_example = r#"```sql
82> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
83+------------------------------------------------------------------+
84| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
85+------------------------------------------------------------------+
86| 65.0                                                             |
87+------------------------------------------------------------------+
88> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
89+-----------------------------------------------------------------------+
90| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
91+-----------------------------------------------------------------------+
92| 65.0                                                                  |
93+-----------------------------------------------------------------------+
94```
95An alternate syntax is also supported:
96```sql
97> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name;
98+-----------------------------------------------+
99| approx_percentile_cont(column_name, 0.75)     |
100+-----------------------------------------------+
101| 65.0                                          |
102+-----------------------------------------------+
103
104> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name;
105+----------------------------------------------------------+
106| approx_percentile_cont(column_name, 0.75, 100)           |
107+----------------------------------------------------------+
108| 65.0                                                     |
109+----------------------------------------------------------+
110```
111"#,
112    standard_argument(name = "expression",),
113    argument(
114        name = "percentile",
115        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
116    ),
117    argument(
118        name = "centroids",
119        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
120    )
121)]
122#[derive(Debug, PartialEq, Eq, Hash)]
123pub struct ApproxPercentileCont {
124    signature: Signature,
125}
126
127impl Default for ApproxPercentileCont {
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
133impl ApproxPercentileCont {
134    /// Create a new [`ApproxPercentileCont`] aggregate function.
135    pub fn new() -> Self {
136        let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
137        // Accept any numeric value paired with a float64 percentile
138        for num in NUMERICS {
139            variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
140            // Additionally accept an integer number of centroids for T-Digest
141            for int in INTEGERS {
142                variants.push(TypeSignature::Exact(vec![
143                    num.clone(),
144                    DataType::Float64,
145                    int.clone(),
146                ]))
147            }
148        }
149        Self {
150            signature: Signature::one_of(variants, Volatility::Immutable),
151        }
152    }
153
154    pub(crate) fn create_accumulator(
155        &self,
156        args: &AccumulatorArgs,
157    ) -> Result<ApproxPercentileAccumulator> {
158        let percentile =
159            validate_percentile_expr(&args.exprs[1], "APPROX_PERCENTILE_CONT")?;
160
161        let is_descending = args
162            .order_bys
163            .first()
164            .map(|sort_expr| sort_expr.options.descending)
165            .unwrap_or(false);
166
167        let percentile = if is_descending {
168            1.0 - percentile
169        } else {
170            percentile
171        };
172
173        let tdigest_max_size = if args.exprs.len() == 3 {
174            Some(validate_input_max_size_expr(&args.exprs[2])?)
175        } else {
176            None
177        };
178
179        let data_type = args.expr_fields[0].data_type();
180        let accumulator: ApproxPercentileAccumulator = match data_type {
181            DataType::UInt8
182            | DataType::UInt16
183            | DataType::UInt32
184            | DataType::UInt64
185            | DataType::Int8
186            | DataType::Int16
187            | DataType::Int32
188            | DataType::Int64
189            | DataType::Float16
190            | DataType::Float32
191            | DataType::Float64 => {
192                if let Some(max_size) = tdigest_max_size {
193                    ApproxPercentileAccumulator::new_with_max_size(
194                        percentile,
195                        data_type.clone(),
196                        max_size,
197                    )
198                } else {
199                    ApproxPercentileAccumulator::new(percentile, data_type.clone())
200                }
201            }
202            other => {
203                return not_impl_err!(
204                    "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented"
205                );
206            }
207        };
208
209        Ok(accumulator)
210    }
211}
212
213fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
214    let scalar_value = get_scalar_value(expr).map_err(|_e| {
215        DataFusionError::Plan(
216            "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal"
217                .to_string(),
218        )
219    })?;
220
221    let max_size = match scalar_value {
222        ScalarValue::UInt8(Some(q)) => q as usize,
223        ScalarValue::UInt16(Some(q)) => q as usize,
224        ScalarValue::UInt32(Some(q)) => q as usize,
225        ScalarValue::UInt64(Some(q)) => q as usize,
226        ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
227        ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
228        ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
229        ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
230        sv => {
231            return plan_err!(
232                "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
233                sv.data_type()
234            );
235        }
236    };
237
238    Ok(max_size)
239}
240
241impl AggregateUDFImpl for ApproxPercentileCont {
242    fn as_any(&self) -> &dyn Any {
243        self
244    }
245
246    /// See [`TDigest::to_scalar_state()`] for a description of the serialized
247    /// state.
248    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
249        Ok(vec![
250            Field::new(
251                format_state_name(args.name, "max_size"),
252                DataType::UInt64,
253                false,
254            ),
255            Field::new(
256                format_state_name(args.name, "sum"),
257                DataType::Float64,
258                false,
259            ),
260            Field::new(
261                format_state_name(args.name, "count"),
262                DataType::UInt64,
263                false,
264            ),
265            Field::new(
266                format_state_name(args.name, "max"),
267                DataType::Float64,
268                false,
269            ),
270            Field::new(
271                format_state_name(args.name, "min"),
272                DataType::Float64,
273                false,
274            ),
275            Field::new_list(
276                format_state_name(args.name, "centroids"),
277                Field::new_list_field(DataType::Float64, true),
278                false,
279            ),
280        ]
281        .into_iter()
282        .map(Arc::new)
283        .collect())
284    }
285
286    fn name(&self) -> &str {
287        "approx_percentile_cont"
288    }
289
290    fn signature(&self) -> &Signature {
291        &self.signature
292    }
293
294    #[inline]
295    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
296        Ok(Box::new(self.create_accumulator(&acc_args)?))
297    }
298
299    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
300        if !arg_types[0].is_numeric() {
301            return plan_err!("approx_percentile_cont requires numeric input types");
302        }
303        if arg_types.len() == 3 && !arg_types[2].is_integer() {
304            return plan_err!(
305                "approx_percentile_cont requires integer centroids input types"
306            );
307        }
308        Ok(arg_types[0].clone())
309    }
310
311    fn supports_within_group_clause(&self) -> bool {
312        true
313    }
314
315    fn documentation(&self) -> Option<&Documentation> {
316        self.doc()
317    }
318}
319
320#[derive(Debug)]
321pub struct ApproxPercentileAccumulator {
322    digest: TDigest,
323    percentile: f64,
324    return_type: DataType,
325}
326
327impl ApproxPercentileAccumulator {
328    pub fn new(percentile: f64, return_type: DataType) -> Self {
329        Self {
330            digest: TDigest::new(DEFAULT_MAX_SIZE),
331            percentile,
332            return_type,
333        }
334    }
335
336    pub fn new_with_max_size(
337        percentile: f64,
338        return_type: DataType,
339        max_size: usize,
340    ) -> Self {
341        Self {
342            digest: TDigest::new(max_size),
343            percentile,
344            return_type,
345        }
346    }
347
348    // pub(crate) for approx_percentile_cont_with_weight
349    pub(crate) fn max_size(&self) -> usize {
350        self.digest.max_size()
351    }
352
353    // pub(crate) for approx_percentile_cont_with_weight
354    pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
355        let digests = digests.iter().chain(std::iter::once(&self.digest));
356        self.digest = TDigest::merge_digests(digests)
357    }
358
359    // pub(crate) for approx_percentile_cont_with_weight
360    pub(crate) fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
361        debug_assert!(
362            values.null_count() == 0,
363            "convert_to_float assumes nulls have already been filtered out"
364        );
365        match values.data_type() {
366            DataType::Float64 => {
367                let array = downcast_value!(values, Float64Array);
368                Ok(array.values().iter().copied().collect::<Vec<_>>())
369            }
370            DataType::Float32 => {
371                let array = downcast_value!(values, Float32Array);
372                Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
373            }
374            DataType::Float16 => {
375                let array = downcast_value!(values, Float16Array);
376                Ok(array
377                    .values()
378                    .iter()
379                    .map(|v| v.to_f64())
380                    .collect::<Vec<_>>())
381            }
382            DataType::Int64 => {
383                let array = downcast_value!(values, Int64Array);
384                Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
385            }
386            DataType::Int32 => {
387                let array = downcast_value!(values, Int32Array);
388                Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
389            }
390            DataType::Int16 => {
391                let array = downcast_value!(values, Int16Array);
392                Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
393            }
394            DataType::Int8 => {
395                let array = downcast_value!(values, Int8Array);
396                Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
397            }
398            DataType::UInt64 => {
399                let array = downcast_value!(values, UInt64Array);
400                Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
401            }
402            DataType::UInt32 => {
403                let array = downcast_value!(values, UInt32Array);
404                Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
405            }
406            DataType::UInt16 => {
407                let array = downcast_value!(values, UInt16Array);
408                Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
409            }
410            DataType::UInt8 => {
411                let array = downcast_value!(values, UInt8Array);
412                Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
413            }
414            e => internal_err!(
415                "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
416            ),
417        }
418    }
419}
420
421impl Accumulator for ApproxPercentileAccumulator {
422    fn state(&mut self) -> Result<Vec<ScalarValue>> {
423        Ok(self.digest.to_scalar_state().into_iter().collect())
424    }
425
426    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
427        // Remove any nulls before computing the percentile
428        let mut values = Arc::clone(&values[0]);
429        if values.null_count() > 0 {
430            values = filter(&values, &is_not_null(&values)?)?;
431        }
432        let sorted_values = &arrow::compute::sort(&values, None)?;
433        let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
434        self.digest = self.digest.merge_sorted_f64(&sorted_values);
435        Ok(())
436    }
437
438    fn evaluate(&mut self) -> Result<ScalarValue> {
439        if self.digest.count() == 0 {
440            return ScalarValue::try_from(self.return_type.clone());
441        }
442        let q = self.digest.estimate_quantile(self.percentile);
443
444        // These acceptable return types MUST match the validation in
445        // ApproxPercentile::create_accumulator.
446        Ok(match &self.return_type {
447            DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
448            DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
449            DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
450            DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
451            DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
452            DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
453            DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
454            DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
455            DataType::Float16 => ScalarValue::Float16(Some(half::f16::from_f64(q))),
456            DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
457            DataType::Float64 => ScalarValue::Float64(Some(q)),
458            v => unreachable!("unexpected return type {}", v),
459        })
460    }
461
462    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
463        if states.is_empty() {
464            return Ok(());
465        }
466
467        let states = (0..states[0].len())
468            .map(|index| {
469                states
470                    .iter()
471                    .map(|array| ScalarValue::try_from_array(array, index))
472                    .collect::<Result<Vec<_>>>()
473                    .map(|state| TDigest::from_scalar_state(&state))
474            })
475            .collect::<Result<Vec<_>>>()?;
476
477        self.merge_digests(&states);
478
479        Ok(())
480    }
481
482    fn size(&self) -> usize {
483        size_of_val(self) + self.digest.size() - size_of_val(&self.digest)
484            + self.return_type.size()
485            - size_of_val(&self.return_type)
486    }
487}
488
#[cfg(test)]
mod tests {
    use arrow::datatypes::DataType;

    use datafusion_functions_aggregate_common::tdigest::TDigest;

    use crate::approx_percentile_cont::ApproxPercentileAccumulator;

    #[test]
    fn test_combine_approx_percentile_accumulator() {
        // 50 digests, each built from the 1_000 values 1..=1_000
        // (50_000 values across all of them).
        let digests: Vec<TDigest> = (0..50)
            .map(|_| {
                let values: Vec<_> = (1..=1_000).map(f64::from).collect();
                TDigest::new(100).merge_unsorted_f64(values)
            })
            .collect();

        // Two independent merges of the same inputs, 50_000 values each.
        let first = TDigest::merge_digests(&digests);
        let second = TDigest::merge_digests(&digests);

        let mut accumulator =
            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

        accumulator.merge_digests(&[first]);
        assert_eq!(accumulator.digest.count(), 50_000);

        accumulator.merge_digests(&[second]);
        assert_eq!(accumulator.digest.count(), 100_000);
    }
}