datafusion_functions_aggregate/
approx_percentile_cont_with_weight.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::hash::Hash;
21use std::mem::size_of_val;
22use std::sync::Arc;
23
24use arrow::compute::{and, filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{array::ArrayRef, datatypes::DataType};
27use datafusion_common::ScalarValue;
28use datafusion_common::{not_impl_err, plan_err, Result};
29use datafusion_expr::expr::{AggregateFunction, Sort};
30use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
31use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
32use datafusion_expr::Volatility::Immutable;
33use datafusion_expr::{
34    Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
35};
36use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest};
37use datafusion_macros::user_doc;
38
39use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
40
// Presumably generates the `approx_percentile_cont_with_weight_udaf()` accessor that
// `approx_percentile_cont_with_weight` calls to obtain the UDAF instance.
// NOTE(review): `create_func!` is defined elsewhere in the crate; its exact expansion
// (e.g. whether the instance is cached) is not visible here — confirm before relying on it.
create_func!(
    ApproxPercentileContWithWeight,
    approx_percentile_cont_with_weight_udaf
);
45
46/// Computes the approximate percentile continuous with weight of a set of numbers
47pub fn approx_percentile_cont_with_weight(
48    order_by: Sort,
49    weight: Expr,
50    percentile: Expr,
51    centroids: Option<Expr>,
52) -> Expr {
53    let expr = order_by.expr.clone();
54
55    let args = if let Some(centroids) = centroids {
56        vec![expr, weight, percentile, centroids]
57    } else {
58        vec![expr, weight, percentile]
59    };
60
61    Expr::AggregateFunction(AggregateFunction::new_udf(
62        approx_percentile_cont_with_weight_udaf(),
63        args,
64        false,
65        None,
66        vec![order_by],
67        None,
68    ))
69}
70
71/// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression
72#[user_doc(
73    doc_section(label = "Approximate Functions"),
74    description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
75    syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
76    sql_example = r#"```sql
77> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
78+---------------------------------------------------------------------------------------------+
79| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
80+---------------------------------------------------------------------------------------------+
81| 78.5                                                                                        |
82+---------------------------------------------------------------------------------------------+
83> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
84+--------------------------------------------------------------------------------------------------+
85| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) |
86+--------------------------------------------------------------------------------------------------+
87| 78.5                                                                                             |
88+--------------------------------------------------------------------------------------------------+
89```
90An alternative syntax is also supported:
91
92```sql
93> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name;
94+--------------------------------------------------+
95| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) |
96+--------------------------------------------------+
97| 78.5                                             |
98+--------------------------------------------------+
99```"#,
100    standard_argument(name = "expression", prefix = "The"),
101    argument(
102        name = "weight",
103        description = "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators."
104    ),
105    argument(
106        name = "percentile",
107        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
108    ),
109    argument(
110        name = "centroids",
111        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
112    )
113)]
114#[derive(PartialEq, Eq, Hash)]
115pub struct ApproxPercentileContWithWeight {
116    signature: Signature,
117    approx_percentile_cont: ApproxPercentileCont,
118}
119
120impl Debug for ApproxPercentileContWithWeight {
121    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("ApproxPercentileContWithWeight")
123            .field("signature", &self.signature)
124            .finish()
125    }
126}
127
128impl Default for ApproxPercentileContWithWeight {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl ApproxPercentileContWithWeight {
135    /// Create a new [`ApproxPercentileContWithWeight`] aggregate function.
136    pub fn new() -> Self {
137        let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
138        // Accept any numeric value paired with weight and float64 percentile
139        for num in NUMERICS {
140            variants.push(TypeSignature::Exact(vec![
141                num.clone(),
142                num.clone(),
143                DataType::Float64,
144            ]));
145            // Additionally accept an integer number of centroids for T-Digest
146            for int in INTEGERS {
147                variants.push(TypeSignature::Exact(vec![
148                    num.clone(),
149                    num.clone(),
150                    DataType::Float64,
151                    int.clone(),
152                ]));
153            }
154        }
155        Self {
156            signature: Signature::one_of(variants, Immutable),
157            approx_percentile_cont: ApproxPercentileCont::new(),
158        }
159    }
160}
161
162impl AggregateUDFImpl for ApproxPercentileContWithWeight {
163    fn as_any(&self) -> &dyn Any {
164        self
165    }
166
167    fn name(&self) -> &str {
168        "approx_percentile_cont_with_weight"
169    }
170
171    fn signature(&self) -> &Signature {
172        &self.signature
173    }
174
175    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
176        if !arg_types[0].is_numeric() {
177            return plan_err!(
178                "approx_percentile_cont_with_weight requires numeric input types"
179            );
180        }
181        if !arg_types[1].is_numeric() {
182            return plan_err!(
183                "approx_percentile_cont_with_weight requires numeric weight input types"
184            );
185        }
186        if arg_types[2] != DataType::Float64 {
187            return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types");
188        }
189        if arg_types.len() == 4 && !arg_types[3].is_integer() {
190            return plan_err!(
191                "approx_percentile_cont_with_weight requires integer centroids input types"
192            );
193        }
194        Ok(arg_types[0].clone())
195    }
196
197    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
198        if acc_args.is_distinct {
199            return not_impl_err!(
200                "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available"
201            );
202        }
203
204        if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 {
205            return plan_err!(
206                "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]"
207            );
208        }
209
210        let sub_args = AccumulatorArgs {
211            exprs: if acc_args.exprs.len() == 4 {
212                &[
213                    Arc::clone(&acc_args.exprs[0]), // value
214                    Arc::clone(&acc_args.exprs[2]), // percentile
215                    Arc::clone(&acc_args.exprs[3]), // centroids
216                ]
217            } else {
218                &[
219                    Arc::clone(&acc_args.exprs[0]), // value
220                    Arc::clone(&acc_args.exprs[2]), // percentile
221                ]
222            },
223            expr_fields: if acc_args.exprs.len() == 4 {
224                &[
225                    Arc::clone(&acc_args.expr_fields[0]), // value
226                    Arc::clone(&acc_args.expr_fields[2]), // percentile
227                    Arc::clone(&acc_args.expr_fields[3]), // centroids
228                ]
229            } else {
230                &[
231                    Arc::clone(&acc_args.expr_fields[0]), // value
232                    Arc::clone(&acc_args.expr_fields[2]), // percentile
233                ]
234            },
235            // Unchanged below; we list each field explicitly in case we ever add more
236            // fields to AccumulatorArgs making it easier to see if changes are also
237            // needed here.
238            return_field: acc_args.return_field,
239            schema: acc_args.schema,
240            ignore_nulls: acc_args.ignore_nulls,
241            order_bys: acc_args.order_bys,
242            is_reversed: acc_args.is_reversed,
243            name: acc_args.name,
244            is_distinct: acc_args.is_distinct,
245        };
246        let approx_percentile_cont_accumulator =
247            self.approx_percentile_cont.create_accumulator(sub_args)?;
248        let accumulator = ApproxPercentileWithWeightAccumulator::new(
249            approx_percentile_cont_accumulator,
250        );
251        Ok(Box::new(accumulator))
252    }
253
254    #[allow(rustdoc::private_intra_doc_links)]
255    /// See [`TDigest::to_scalar_state()`] for a description of the serialized
256    /// state.
257    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
258        self.approx_percentile_cont.state_fields(args)
259    }
260
261    fn supports_null_handling_clause(&self) -> bool {
262        false
263    }
264
265    fn supports_within_group_clause(&self) -> bool {
266        true
267    }
268
269    fn documentation(&self) -> Option<&Documentation> {
270        self.doc()
271    }
272}
273
274#[derive(Debug)]
275pub struct ApproxPercentileWithWeightAccumulator {
276    approx_percentile_cont_accumulator: ApproxPercentileAccumulator,
277}
278
279impl ApproxPercentileWithWeightAccumulator {
280    pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self {
281        Self {
282            approx_percentile_cont_accumulator,
283        }
284    }
285}
286
287impl Accumulator for ApproxPercentileWithWeightAccumulator {
288    fn state(&mut self) -> Result<Vec<ScalarValue>> {
289        self.approx_percentile_cont_accumulator.state()
290    }
291
292    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
293        let mut means = Arc::clone(&values[0]);
294        let mut weights = Arc::clone(&values[1]);
295        // If nulls are present in either array, need to filter those rows out in both arrays
296        match (means.null_count() > 0, weights.null_count() > 0) {
297            // Both have nulls
298            (true, true) => {
299                let predicate = and(&is_not_null(&means)?, &is_not_null(&weights)?)?;
300                means = filter(&means, &predicate)?;
301                weights = filter(&weights, &predicate)?;
302            }
303            // Only one has nulls
304            (false, true) => {
305                let predicate = &is_not_null(&weights)?;
306                means = filter(&means, predicate)?;
307                weights = filter(&weights, predicate)?;
308            }
309            (true, false) => {
310                let predicate = &is_not_null(&means)?;
311                means = filter(&means, predicate)?;
312                weights = filter(&weights, predicate)?;
313            }
314            // No nulls
315            (false, false) => {}
316        }
317        debug_assert_eq!(
318            means.len(),
319            weights.len(),
320            "invalid number of values in means and weights"
321        );
322        let means_f64 = ApproxPercentileAccumulator::convert_to_float(&means)?;
323        let weights_f64 = ApproxPercentileAccumulator::convert_to_float(&weights)?;
324        let mut digests: Vec<TDigest> = vec![];
325        for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
326            digests.push(TDigest::new_with_centroid(
327                self.approx_percentile_cont_accumulator.max_size(),
328                Centroid::new(*mean, *weight),
329            ))
330        }
331        self.approx_percentile_cont_accumulator
332            .merge_digests(&digests);
333        Ok(())
334    }
335
336    fn evaluate(&mut self) -> Result<ScalarValue> {
337        self.approx_percentile_cont_accumulator.evaluate()
338    }
339
340    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
341        self.approx_percentile_cont_accumulator
342            .merge_batch(states)?;
343
344        Ok(())
345    }
346
347    fn size(&self) -> usize {
348        size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator)
349            + self.approx_percentile_cont_accumulator.size()
350    }
351}