Skip to main content

datafusion_functions_aggregate/
approx_percentile_cont_with_weight.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::Debug;
20use std::hash::Hash;
21use std::mem::size_of_val;
22use std::sync::Arc;
23
24use arrow::compute::{and, filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{array::ArrayRef, datatypes::DataType};
27use datafusion_common::ScalarValue;
28use datafusion_common::{Result, not_impl_err, plan_err};
29use datafusion_expr::Volatility::Immutable;
30use datafusion_expr::expr::{AggregateFunction, Sort};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
33use datafusion_expr::{
34    Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
35};
36use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest};
37use datafusion_macros::user_doc;
38
39use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
40
41create_func!(
42    ApproxPercentileContWithWeight,
43    approx_percentile_cont_with_weight_udaf
44);
45
46/// Computes the approximate percentile continuous with weight of a set of numbers
47pub fn approx_percentile_cont_with_weight(
48    order_by: Sort,
49    weight: Expr,
50    percentile: Expr,
51    centroids: Option<Expr>,
52) -> Expr {
53    let expr = order_by.expr.clone();
54
55    let args = if let Some(centroids) = centroids {
56        vec![expr, weight, percentile, centroids]
57    } else {
58        vec![expr, weight, percentile]
59    };
60
61    Expr::AggregateFunction(AggregateFunction::new_udf(
62        approx_percentile_cont_with_weight_udaf(),
63        args,
64        false,
65        None,
66        vec![order_by],
67        None,
68    ))
69}
70
71/// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression
72#[user_doc(
73    doc_section(label = "Approximate Functions"),
74    description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
75    syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
76    sql_example = r#"```sql
77> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
78+---------------------------------------------------------------------------------------------+
79| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
80+---------------------------------------------------------------------------------------------+
81| 78.5                                                                                        |
82+---------------------------------------------------------------------------------------------+
83> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
84+--------------------------------------------------------------------------------------------------+
85| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) |
86+--------------------------------------------------------------------------------------------------+
87| 78.5                                                                                             |
88+--------------------------------------------------------------------------------------------------+
89```
90An alternative syntax is also supported:
91
92```sql
93> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name;
94+--------------------------------------------------+
95| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) |
96+--------------------------------------------------+
97| 78.5                                             |
98+--------------------------------------------------+
99```"#,
100    standard_argument(name = "expression", prefix = "The"),
101    argument(
102        name = "weight",
103        description = "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators."
104    ),
105    argument(
106        name = "percentile",
107        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
108    ),
109    argument(
110        name = "centroids",
111        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
112    )
113)]
114#[derive(PartialEq, Eq, Hash, Debug)]
115pub struct ApproxPercentileContWithWeight {
116    signature: Signature,
117    approx_percentile_cont: ApproxPercentileCont,
118}
119
120impl Default for ApproxPercentileContWithWeight {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl ApproxPercentileContWithWeight {
127    /// Create a new [`ApproxPercentileContWithWeight`] aggregate function.
128    pub fn new() -> Self {
129        let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
130        // Accept any numeric value paired with weight and float64 percentile
131        for num in NUMERICS {
132            variants.push(TypeSignature::Exact(vec![
133                num.clone(),
134                num.clone(),
135                DataType::Float64,
136            ]));
137            // Additionally accept an integer number of centroids for T-Digest
138            for int in INTEGERS {
139                variants.push(TypeSignature::Exact(vec![
140                    num.clone(),
141                    num.clone(),
142                    DataType::Float64,
143                    int.clone(),
144                ]));
145            }
146        }
147        Self {
148            signature: Signature::one_of(variants, Immutable),
149            approx_percentile_cont: ApproxPercentileCont::new(),
150        }
151    }
152}
153
154impl AggregateUDFImpl for ApproxPercentileContWithWeight {
155    fn as_any(&self) -> &dyn Any {
156        self
157    }
158
159    fn name(&self) -> &str {
160        "approx_percentile_cont_with_weight"
161    }
162
163    fn signature(&self) -> &Signature {
164        &self.signature
165    }
166
167    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
168        if !arg_types[0].is_numeric() {
169            return plan_err!(
170                "approx_percentile_cont_with_weight requires numeric input types"
171            );
172        }
173        if !arg_types[1].is_numeric() {
174            return plan_err!(
175                "approx_percentile_cont_with_weight requires numeric weight input types"
176            );
177        }
178        if arg_types[2] != DataType::Float64 {
179            return plan_err!(
180                "approx_percentile_cont_with_weight requires float64 percentile input types"
181            );
182        }
183        if arg_types.len() == 4 && !arg_types[3].is_integer() {
184            return plan_err!(
185                "approx_percentile_cont_with_weight requires integer centroids input types"
186            );
187        }
188        Ok(arg_types[0].clone())
189    }
190
191    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
192        if acc_args.is_distinct {
193            return not_impl_err!(
194                "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available"
195            );
196        }
197
198        if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 {
199            return plan_err!(
200                "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]"
201            );
202        }
203
204        let sub_args = AccumulatorArgs {
205            exprs: if acc_args.exprs.len() == 4 {
206                &[
207                    Arc::clone(&acc_args.exprs[0]), // value
208                    Arc::clone(&acc_args.exprs[2]), // percentile
209                    Arc::clone(&acc_args.exprs[3]), // centroids
210                ]
211            } else {
212                &[
213                    Arc::clone(&acc_args.exprs[0]), // value
214                    Arc::clone(&acc_args.exprs[2]), // percentile
215                ]
216            },
217            expr_fields: if acc_args.exprs.len() == 4 {
218                &[
219                    Arc::clone(&acc_args.expr_fields[0]), // value
220                    Arc::clone(&acc_args.expr_fields[2]), // percentile
221                    Arc::clone(&acc_args.expr_fields[3]), // centroids
222                ]
223            } else {
224                &[
225                    Arc::clone(&acc_args.expr_fields[0]), // value
226                    Arc::clone(&acc_args.expr_fields[2]), // percentile
227                ]
228            },
229            // Unchanged below; we list each field explicitly in case we ever add more
230            // fields to AccumulatorArgs making it easier to see if changes are also
231            // needed here.
232            return_field: acc_args.return_field,
233            schema: acc_args.schema,
234            ignore_nulls: acc_args.ignore_nulls,
235            order_bys: acc_args.order_bys,
236            is_reversed: acc_args.is_reversed,
237            name: acc_args.name,
238            is_distinct: acc_args.is_distinct,
239        };
240        let approx_percentile_cont_accumulator =
241            self.approx_percentile_cont.create_accumulator(&sub_args)?;
242        let accumulator = ApproxPercentileWithWeightAccumulator::new(
243            approx_percentile_cont_accumulator,
244        );
245        Ok(Box::new(accumulator))
246    }
247
248    /// See [`TDigest::to_scalar_state()`] for a description of the serialized
249    /// state.
250    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
251        self.approx_percentile_cont.state_fields(args)
252    }
253
254    fn supports_within_group_clause(&self) -> bool {
255        true
256    }
257
258    fn documentation(&self) -> Option<&Documentation> {
259        self.doc()
260    }
261}
262
263#[derive(Debug)]
264pub struct ApproxPercentileWithWeightAccumulator {
265    approx_percentile_cont_accumulator: ApproxPercentileAccumulator,
266}
267
268impl ApproxPercentileWithWeightAccumulator {
269    pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self {
270        Self {
271            approx_percentile_cont_accumulator,
272        }
273    }
274}
275
276impl Accumulator for ApproxPercentileWithWeightAccumulator {
277    fn state(&mut self) -> Result<Vec<ScalarValue>> {
278        self.approx_percentile_cont_accumulator.state()
279    }
280
281    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
282        let mut means = Arc::clone(&values[0]);
283        let mut weights = Arc::clone(&values[1]);
284        // If nulls are present in either array, need to filter those rows out in both arrays
285        match (means.null_count() > 0, weights.null_count() > 0) {
286            // Both have nulls
287            (true, true) => {
288                let predicate = and(&is_not_null(&means)?, &is_not_null(&weights)?)?;
289                means = filter(&means, &predicate)?;
290                weights = filter(&weights, &predicate)?;
291            }
292            // Only one has nulls
293            (false, true) => {
294                let predicate = &is_not_null(&weights)?;
295                means = filter(&means, predicate)?;
296                weights = filter(&weights, predicate)?;
297            }
298            (true, false) => {
299                let predicate = &is_not_null(&means)?;
300                means = filter(&means, predicate)?;
301                weights = filter(&weights, predicate)?;
302            }
303            // No nulls
304            (false, false) => {}
305        }
306        debug_assert_eq!(
307            means.len(),
308            weights.len(),
309            "invalid number of values in means and weights"
310        );
311        let means_f64 = ApproxPercentileAccumulator::convert_to_float(&means)?;
312        let weights_f64 = ApproxPercentileAccumulator::convert_to_float(&weights)?;
313        let mut digests: Vec<TDigest> = vec![];
314        for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
315            digests.push(TDigest::new_with_centroid(
316                self.approx_percentile_cont_accumulator.max_size(),
317                Centroid::new(*mean, *weight),
318            ))
319        }
320        self.approx_percentile_cont_accumulator
321            .merge_digests(&digests);
322        Ok(())
323    }
324
325    fn evaluate(&mut self) -> Result<ScalarValue> {
326        self.approx_percentile_cont_accumulator.evaluate()
327    }
328
329    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
330        self.approx_percentile_cont_accumulator
331            .merge_batch(states)?;
332
333        Ok(())
334    }
335
336    fn size(&self) -> usize {
337        size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator)
338            + self.approx_percentile_cont_accumulator.size()
339    }
340}