datafusion_functions_aggregate/
approx_percentile_cont.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::mem::size_of_val;
21use std::sync::Arc;
22
23use arrow::array::Array;
24use arrow::compute::{filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{
27    array::{
28        ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
29        Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
30    },
31    datatypes::{DataType, Field},
32};
33use datafusion_common::{
34    downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, Result,
35    ScalarValue,
36};
37use datafusion_expr::expr::{AggregateFunction, Sort};
38use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
39use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
40use datafusion_expr::utils::format_state_name;
41use datafusion_expr::{
42    Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
43    Volatility,
44};
45use datafusion_functions_aggregate_common::tdigest::{
46    TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
47};
48use datafusion_macros::user_doc;
49use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
50
51use crate::utils::{get_scalar_value, validate_percentile_expr};
52
// Defines `approx_percentile_cont_udaf()`, which `approx_percentile_cont`
// calls to obtain this UDAF.
// NOTE(review): `create_func!` is defined elsewhere in the crate; presumably it
// hands out a single shared `AggregateUDF` wrapping `ApproxPercentileCont` —
// confirm against the macro definition.
create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
54
55/// Computes the approximate percentile continuous of a set of numbers
56pub fn approx_percentile_cont(
57    order_by: Sort,
58    percentile: Expr,
59    centroids: Option<Expr>,
60) -> Expr {
61    let expr = order_by.expr.clone();
62
63    let args = if let Some(centroids) = centroids {
64        vec![expr, percentile, centroids]
65    } else {
66        vec![expr, percentile]
67    };
68
69    Expr::AggregateFunction(AggregateFunction::new_udf(
70        approx_percentile_cont_udaf(),
71        args,
72        false,
73        None,
74        vec![order_by],
75        None,
76    ))
77}
78
/// The `approx_percentile_cont` aggregate UDF: estimates a percentile of a
/// numeric input using a t-digest.
///
/// The user-facing documentation is generated from the `user_doc` attribute
/// below.
#[user_doc(
    doc_section(label = "Approximate Functions"),
    description = "Returns the approximate percentile of input values using the t-digest algorithm.",
    syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
    sql_example = r#"```sql
> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+------------------------------------------------------------------+
| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
+------------------------------------------------------------------+
| 65.0                                                             |
+------------------------------------------------------------------+
> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+-----------------------------------------------------------------------+
| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
+-----------------------------------------------------------------------+
| 65.0                                                                  |
+-----------------------------------------------------------------------+
```
An alternate syntax is also supported:
```sql
> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name;
+-----------------------------------------------+
| approx_percentile_cont(column_name, 0.75)     |
+-----------------------------------------------+
| 65.0                                          |
+-----------------------------------------------+

> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name;
+----------------------------------------------------------+
| approx_percentile_cont(column_name, 0.75, 100)           |
+----------------------------------------------------------+
| 65.0                                                     |
+----------------------------------------------------------+
```
"#,
    standard_argument(name = "expression",),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    ),
    argument(
        name = "centroids",
        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
    )
)]
#[derive(PartialEq, Eq, Hash)]
pub struct ApproxPercentileCont {
    // Accepted argument lists, built in `new`: (numeric, Float64) and
    // (numeric, Float64, integer) for every numeric/integer type combination.
    signature: Signature,
}
128
129impl Debug for ApproxPercentileCont {
130    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
131        f.debug_struct("ApproxPercentileCont")
132            .field("name", &self.name())
133            .field("signature", &self.signature)
134            .finish()
135    }
136}
137
138impl Default for ApproxPercentileCont {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
144impl ApproxPercentileCont {
145    /// Create a new [`ApproxPercentileCont`] aggregate function.
146    pub fn new() -> Self {
147        let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
148        // Accept any numeric value paired with a float64 percentile
149        for num in NUMERICS {
150            variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
151            // Additionally accept an integer number of centroids for T-Digest
152            for int in INTEGERS {
153                variants.push(TypeSignature::Exact(vec![
154                    num.clone(),
155                    DataType::Float64,
156                    int.clone(),
157                ]))
158            }
159        }
160        Self {
161            signature: Signature::one_of(variants, Volatility::Immutable),
162        }
163    }
164
165    pub(crate) fn create_accumulator(
166        &self,
167        args: AccumulatorArgs,
168    ) -> Result<ApproxPercentileAccumulator> {
169        let percentile =
170            validate_percentile_expr(&args.exprs[1], "APPROX_PERCENTILE_CONT")?;
171
172        let is_descending = args
173            .order_bys
174            .first()
175            .map(|sort_expr| sort_expr.options.descending)
176            .unwrap_or(false);
177
178        let percentile = if is_descending {
179            1.0 - percentile
180        } else {
181            percentile
182        };
183
184        let tdigest_max_size = if args.exprs.len() == 3 {
185            Some(validate_input_max_size_expr(&args.exprs[2])?)
186        } else {
187            None
188        };
189
190        let data_type = args.expr_fields[0].data_type();
191        let accumulator: ApproxPercentileAccumulator = match data_type {
192            DataType::UInt8
193            | DataType::UInt16
194            | DataType::UInt32
195            | DataType::UInt64
196            | DataType::Int8
197            | DataType::Int16
198            | DataType::Int32
199            | DataType::Int64
200            | DataType::Float32
201            | DataType::Float64 => {
202                if let Some(max_size) = tdigest_max_size {
203                    ApproxPercentileAccumulator::new_with_max_size(percentile, data_type.clone(), max_size)
204                } else {
205                    ApproxPercentileAccumulator::new(percentile, data_type.clone())
206                }
207            }
208            other => {
209                return not_impl_err!(
210                    "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented"
211                )
212            }
213        };
214
215        Ok(accumulator)
216    }
217}
218
219fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
220    let scalar_value = get_scalar_value(expr).map_err(|_e| {
221        DataFusionError::Plan(
222            "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal"
223                .to_string(),
224        )
225    })?;
226
227    let max_size = match scalar_value {
228        ScalarValue::UInt8(Some(q)) => q as usize,
229        ScalarValue::UInt16(Some(q)) => q as usize,
230        ScalarValue::UInt32(Some(q)) => q as usize,
231        ScalarValue::UInt64(Some(q)) => q as usize,
232        ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
233        ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
234        ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
235        ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
236        sv => {
237            return plan_err!(
238                "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
239                sv.data_type()
240            )
241        },
242    };
243
244    Ok(max_size)
245}
246
247impl AggregateUDFImpl for ApproxPercentileCont {
248    fn as_any(&self) -> &dyn Any {
249        self
250    }
251
252    #[allow(rustdoc::private_intra_doc_links)]
253    /// See [`TDigest::to_scalar_state()`] for a description of the serialized
254    /// state.
255    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
256        Ok(vec![
257            Field::new(
258                format_state_name(args.name, "max_size"),
259                DataType::UInt64,
260                false,
261            ),
262            Field::new(
263                format_state_name(args.name, "sum"),
264                DataType::Float64,
265                false,
266            ),
267            Field::new(
268                format_state_name(args.name, "count"),
269                DataType::UInt64,
270                false,
271            ),
272            Field::new(
273                format_state_name(args.name, "max"),
274                DataType::Float64,
275                false,
276            ),
277            Field::new(
278                format_state_name(args.name, "min"),
279                DataType::Float64,
280                false,
281            ),
282            Field::new_list(
283                format_state_name(args.name, "centroids"),
284                Field::new_list_field(DataType::Float64, true),
285                false,
286            ),
287        ]
288        .into_iter()
289        .map(Arc::new)
290        .collect())
291    }
292
293    fn name(&self) -> &str {
294        "approx_percentile_cont"
295    }
296
297    fn signature(&self) -> &Signature {
298        &self.signature
299    }
300
301    #[inline]
302    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
303        Ok(Box::new(self.create_accumulator(acc_args)?))
304    }
305
306    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
307        if !arg_types[0].is_numeric() {
308            return plan_err!("approx_percentile_cont requires numeric input types");
309        }
310        if arg_types.len() == 3 && !arg_types[2].is_integer() {
311            return plan_err!(
312                "approx_percentile_cont requires integer centroids input types"
313            );
314        }
315        Ok(arg_types[0].clone())
316    }
317
318    fn supports_null_handling_clause(&self) -> bool {
319        false
320    }
321
322    fn supports_within_group_clause(&self) -> bool {
323        true
324    }
325
326    fn documentation(&self) -> Option<&Documentation> {
327        self.doc()
328    }
329}
330
/// Accumulator behind `approx_percentile_cont`.
///
/// Input values are converted to `f64` and folded into a [`TDigest`]. On
/// `evaluate`, the digest's estimate at `percentile` is cast to `return_type`.
#[derive(Debug)]
pub struct ApproxPercentileAccumulator {
    /// Running digest of every value seen so far, including merged partial states.
    digest: TDigest,
    /// Quantile passed to `TDigest::estimate_quantile`. When built by
    /// `ApproxPercentileCont::create_accumulator`, it has already been flipped
    /// to `1.0 - p` for a descending ordering.
    percentile: f64,
    /// Type the estimate is cast to in `evaluate`. `create_accumulator` passes
    /// the input value type here.
    return_type: DataType,
}
337
338impl ApproxPercentileAccumulator {
339    pub fn new(percentile: f64, return_type: DataType) -> Self {
340        Self {
341            digest: TDigest::new(DEFAULT_MAX_SIZE),
342            percentile,
343            return_type,
344        }
345    }
346
347    pub fn new_with_max_size(
348        percentile: f64,
349        return_type: DataType,
350        max_size: usize,
351    ) -> Self {
352        Self {
353            digest: TDigest::new(max_size),
354            percentile,
355            return_type,
356        }
357    }
358
359    // pub(crate) for approx_percentile_cont_with_weight
360    pub(crate) fn max_size(&self) -> usize {
361        self.digest.max_size()
362    }
363
364    // pub(crate) for approx_percentile_cont_with_weight
365    pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
366        let digests = digests.iter().chain(std::iter::once(&self.digest));
367        self.digest = TDigest::merge_digests(digests)
368    }
369
370    // pub(crate) for approx_percentile_cont_with_weight
371    pub(crate) fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
372        debug_assert!(
373            values.null_count() == 0,
374            "convert_to_float assumes nulls have already been filtered out"
375        );
376        match values.data_type() {
377            DataType::Float64 => {
378                let array = downcast_value!(values, Float64Array);
379                Ok(array
380                    .values()
381                    .iter()
382                    .filter_map(|v| v.try_as_f64().transpose())
383                    .collect::<Result<Vec<_>>>()?)
384            }
385            DataType::Float32 => {
386                let array = downcast_value!(values, Float32Array);
387                Ok(array
388                    .values()
389                    .iter()
390                    .filter_map(|v| v.try_as_f64().transpose())
391                    .collect::<Result<Vec<_>>>()?)
392            }
393            DataType::Int64 => {
394                let array = downcast_value!(values, Int64Array);
395                Ok(array
396                    .values()
397                    .iter()
398                    .filter_map(|v| v.try_as_f64().transpose())
399                    .collect::<Result<Vec<_>>>()?)
400            }
401            DataType::Int32 => {
402                let array = downcast_value!(values, Int32Array);
403                Ok(array
404                    .values()
405                    .iter()
406                    .filter_map(|v| v.try_as_f64().transpose())
407                    .collect::<Result<Vec<_>>>()?)
408            }
409            DataType::Int16 => {
410                let array = downcast_value!(values, Int16Array);
411                Ok(array
412                    .values()
413                    .iter()
414                    .filter_map(|v| v.try_as_f64().transpose())
415                    .collect::<Result<Vec<_>>>()?)
416            }
417            DataType::Int8 => {
418                let array = downcast_value!(values, Int8Array);
419                Ok(array
420                    .values()
421                    .iter()
422                    .filter_map(|v| v.try_as_f64().transpose())
423                    .collect::<Result<Vec<_>>>()?)
424            }
425            DataType::UInt64 => {
426                let array = downcast_value!(values, UInt64Array);
427                Ok(array
428                    .values()
429                    .iter()
430                    .filter_map(|v| v.try_as_f64().transpose())
431                    .collect::<Result<Vec<_>>>()?)
432            }
433            DataType::UInt32 => {
434                let array = downcast_value!(values, UInt32Array);
435                Ok(array
436                    .values()
437                    .iter()
438                    .filter_map(|v| v.try_as_f64().transpose())
439                    .collect::<Result<Vec<_>>>()?)
440            }
441            DataType::UInt16 => {
442                let array = downcast_value!(values, UInt16Array);
443                Ok(array
444                    .values()
445                    .iter()
446                    .filter_map(|v| v.try_as_f64().transpose())
447                    .collect::<Result<Vec<_>>>()?)
448            }
449            DataType::UInt8 => {
450                let array = downcast_value!(values, UInt8Array);
451                Ok(array
452                    .values()
453                    .iter()
454                    .filter_map(|v| v.try_as_f64().transpose())
455                    .collect::<Result<Vec<_>>>()?)
456            }
457            e => internal_err!(
458                "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
459            ),
460        }
461    }
462}
463
464impl Accumulator for ApproxPercentileAccumulator {
465    fn state(&mut self) -> Result<Vec<ScalarValue>> {
466        Ok(self.digest.to_scalar_state().into_iter().collect())
467    }
468
469    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
470        // Remove any nulls before computing the percentile
471        let mut values = Arc::clone(&values[0]);
472        if values.null_count() > 0 {
473            values = filter(&values, &is_not_null(&values)?)?;
474        }
475        let sorted_values = &arrow::compute::sort(&values, None)?;
476        let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
477        self.digest = self.digest.merge_sorted_f64(&sorted_values);
478        Ok(())
479    }
480
481    fn evaluate(&mut self) -> Result<ScalarValue> {
482        if self.digest.count() == 0 {
483            return ScalarValue::try_from(self.return_type.clone());
484        }
485        let q = self.digest.estimate_quantile(self.percentile);
486
487        // These acceptable return types MUST match the validation in
488        // ApproxPercentile::create_accumulator.
489        Ok(match &self.return_type {
490            DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
491            DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
492            DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
493            DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
494            DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
495            DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
496            DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
497            DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
498            DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
499            DataType::Float64 => ScalarValue::Float64(Some(q)),
500            v => unreachable!("unexpected return type {}", v),
501        })
502    }
503
504    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
505        if states.is_empty() {
506            return Ok(());
507        }
508
509        let states = (0..states[0].len())
510            .map(|index| {
511                states
512                    .iter()
513                    .map(|array| ScalarValue::try_from_array(array, index))
514                    .collect::<Result<Vec<_>>>()
515                    .map(|state| TDigest::from_scalar_state(&state))
516            })
517            .collect::<Result<Vec<_>>>()?;
518
519        self.merge_digests(&states);
520
521        Ok(())
522    }
523
524    fn size(&self) -> usize {
525        size_of_val(self) + self.digest.size() - size_of_val(&self.digest)
526            + self.return_type.size()
527            - size_of_val(&self.return_type)
528    }
529}
530
#[cfg(test)]
mod tests {
    use arrow::datatypes::DataType;

    use datafusion_functions_aggregate_common::tdigest::TDigest;

    use crate::approx_percentile_cont::ApproxPercentileAccumulator;

    /// Merging pre-built digests into an accumulator must accumulate their counts.
    #[test]
    fn test_combine_approx_percentile_accumulator() {
        // 50 digests, each summarizing the values 1..=1_000 (50_000 values total).
        let digests: Vec<TDigest> = (1..=50)
            .map(|_| {
                let values: Vec<f64> = (1..=1_000).map(f64::from).collect();
                TDigest::new(100).merge_unsorted_f64(values)
            })
            .collect();

        let first = TDigest::merge_digests(&digests);
        let second = TDigest::merge_digests(&digests);

        let mut accumulator =
            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

        accumulator.merge_digests(&[first]);
        assert_eq!(accumulator.digest.count(), 50_000);

        accumulator.merge_digests(&[second]);
        assert_eq!(accumulator.digest.count(), 100_000);
    }
}