datafusion_functions_aggregate/
approx_percentile_cont_with_weight.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::hash::Hash;
21use std::mem::size_of_val;
22use std::sync::Arc;
23
24use arrow::compute::{and, filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{array::ArrayRef, datatypes::DataType};
27use datafusion_common::ScalarValue;
28use datafusion_common::{Result, not_impl_err, plan_err};
29use datafusion_expr::Volatility::Immutable;
30use datafusion_expr::expr::{AggregateFunction, Sort};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
33use datafusion_expr::{
34    Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
35};
36use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest};
37use datafusion_macros::user_doc;
38
39use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
40
// Invokes the crate's `create_func!` macro for `ApproxPercentileContWithWeight`.
// NOTE(review): presumably this defines `approx_percentile_cont_with_weight_udaf()`,
// which the expression builder in this file passes to `AggregateFunction::new_udf`;
// confirm against the macro definition.
create_func!(
    ApproxPercentileContWithWeight,
    approx_percentile_cont_with_weight_udaf
);
45
46/// Computes the approximate percentile continuous with weight of a set of numbers
47pub fn approx_percentile_cont_with_weight(
48    order_by: Sort,
49    weight: Expr,
50    percentile: Expr,
51    centroids: Option<Expr>,
52) -> Expr {
53    let expr = order_by.expr.clone();
54
55    let args = if let Some(centroids) = centroids {
56        vec![expr, weight, percentile, centroids]
57    } else {
58        vec![expr, weight, percentile]
59    };
60
61    Expr::AggregateFunction(AggregateFunction::new_udf(
62        approx_percentile_cont_with_weight_udaf(),
63        args,
64        false,
65        None,
66        vec![order_by],
67        None,
68    ))
69}
70
71/// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression
72#[user_doc(
73    doc_section(label = "Approximate Functions"),
74    description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
75    syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
76    sql_example = r#"```sql
77> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
78+---------------------------------------------------------------------------------------------+
79| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
80+---------------------------------------------------------------------------------------------+
81| 78.5                                                                                        |
82+---------------------------------------------------------------------------------------------+
83> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
84+--------------------------------------------------------------------------------------------------+
85| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) |
86+--------------------------------------------------------------------------------------------------+
87| 78.5                                                                                             |
88+--------------------------------------------------------------------------------------------------+
89```
90An alternative syntax is also supported:
91
92```sql
93> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name;
94+--------------------------------------------------+
95| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) |
96+--------------------------------------------------+
97| 78.5                                             |
98+--------------------------------------------------+
99```"#,
100    standard_argument(name = "expression", prefix = "The"),
101    argument(
102        name = "weight",
103        description = "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators."
104    ),
105    argument(
106        name = "percentile",
107        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
108    ),
109    argument(
110        name = "centroids",
111        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
112    )
113)]
114#[derive(PartialEq, Eq, Hash)]
115pub struct ApproxPercentileContWithWeight {
116    signature: Signature,
117    approx_percentile_cont: ApproxPercentileCont,
118}
119
120impl Debug for ApproxPercentileContWithWeight {
121    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("ApproxPercentileContWithWeight")
123            .field("signature", &self.signature)
124            .finish()
125    }
126}
127
128impl Default for ApproxPercentileContWithWeight {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl ApproxPercentileContWithWeight {
135    /// Create a new [`ApproxPercentileContWithWeight`] aggregate function.
136    pub fn new() -> Self {
137        let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
138        // Accept any numeric value paired with weight and float64 percentile
139        for num in NUMERICS {
140            variants.push(TypeSignature::Exact(vec![
141                num.clone(),
142                num.clone(),
143                DataType::Float64,
144            ]));
145            // Additionally accept an integer number of centroids for T-Digest
146            for int in INTEGERS {
147                variants.push(TypeSignature::Exact(vec![
148                    num.clone(),
149                    num.clone(),
150                    DataType::Float64,
151                    int.clone(),
152                ]));
153            }
154        }
155        Self {
156            signature: Signature::one_of(variants, Immutable),
157            approx_percentile_cont: ApproxPercentileCont::new(),
158        }
159    }
160}
161
162impl AggregateUDFImpl for ApproxPercentileContWithWeight {
163    fn as_any(&self) -> &dyn Any {
164        self
165    }
166
167    fn name(&self) -> &str {
168        "approx_percentile_cont_with_weight"
169    }
170
171    fn signature(&self) -> &Signature {
172        &self.signature
173    }
174
175    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
176        if !arg_types[0].is_numeric() {
177            return plan_err!(
178                "approx_percentile_cont_with_weight requires numeric input types"
179            );
180        }
181        if !arg_types[1].is_numeric() {
182            return plan_err!(
183                "approx_percentile_cont_with_weight requires numeric weight input types"
184            );
185        }
186        if arg_types[2] != DataType::Float64 {
187            return plan_err!(
188                "approx_percentile_cont_with_weight requires float64 percentile input types"
189            );
190        }
191        if arg_types.len() == 4 && !arg_types[3].is_integer() {
192            return plan_err!(
193                "approx_percentile_cont_with_weight requires integer centroids input types"
194            );
195        }
196        Ok(arg_types[0].clone())
197    }
198
199    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
200        if acc_args.is_distinct {
201            return not_impl_err!(
202                "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available"
203            );
204        }
205
206        if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 {
207            return plan_err!(
208                "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]"
209            );
210        }
211
212        let sub_args = AccumulatorArgs {
213            exprs: if acc_args.exprs.len() == 4 {
214                &[
215                    Arc::clone(&acc_args.exprs[0]), // value
216                    Arc::clone(&acc_args.exprs[2]), // percentile
217                    Arc::clone(&acc_args.exprs[3]), // centroids
218                ]
219            } else {
220                &[
221                    Arc::clone(&acc_args.exprs[0]), // value
222                    Arc::clone(&acc_args.exprs[2]), // percentile
223                ]
224            },
225            expr_fields: if acc_args.exprs.len() == 4 {
226                &[
227                    Arc::clone(&acc_args.expr_fields[0]), // value
228                    Arc::clone(&acc_args.expr_fields[2]), // percentile
229                    Arc::clone(&acc_args.expr_fields[3]), // centroids
230                ]
231            } else {
232                &[
233                    Arc::clone(&acc_args.expr_fields[0]), // value
234                    Arc::clone(&acc_args.expr_fields[2]), // percentile
235                ]
236            },
237            // Unchanged below; we list each field explicitly in case we ever add more
238            // fields to AccumulatorArgs making it easier to see if changes are also
239            // needed here.
240            return_field: acc_args.return_field,
241            schema: acc_args.schema,
242            ignore_nulls: acc_args.ignore_nulls,
243            order_bys: acc_args.order_bys,
244            is_reversed: acc_args.is_reversed,
245            name: acc_args.name,
246            is_distinct: acc_args.is_distinct,
247        };
248        let approx_percentile_cont_accumulator =
249            self.approx_percentile_cont.create_accumulator(&sub_args)?;
250        let accumulator = ApproxPercentileWithWeightAccumulator::new(
251            approx_percentile_cont_accumulator,
252        );
253        Ok(Box::new(accumulator))
254    }
255
256    /// See [`TDigest::to_scalar_state()`] for a description of the serialized
257    /// state.
258    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
259        self.approx_percentile_cont.state_fields(args)
260    }
261
262    fn supports_within_group_clause(&self) -> bool {
263        true
264    }
265
266    fn documentation(&self) -> Option<&Documentation> {
267        self.doc()
268    }
269}
270
/// Accumulator for `approx_percentile_cont_with_weight`.
///
/// Turns each `(value, weight)` row into a weighted centroid and merges it into
/// the wrapped [`ApproxPercentileAccumulator`], which also provides the state,
/// merge and evaluation logic.
#[derive(Debug)]
pub struct ApproxPercentileWithWeightAccumulator {
    /// Inner accumulator that receives the weighted digests and that `state`,
    /// `merge_batch`, `evaluate` and `size` delegate to.
    approx_percentile_cont_accumulator: ApproxPercentileAccumulator,
}
275
276impl ApproxPercentileWithWeightAccumulator {
277    pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self {
278        Self {
279            approx_percentile_cont_accumulator,
280        }
281    }
282}
283
284impl Accumulator for ApproxPercentileWithWeightAccumulator {
285    fn state(&mut self) -> Result<Vec<ScalarValue>> {
286        self.approx_percentile_cont_accumulator.state()
287    }
288
289    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
290        let mut means = Arc::clone(&values[0]);
291        let mut weights = Arc::clone(&values[1]);
292        // If nulls are present in either array, need to filter those rows out in both arrays
293        match (means.null_count() > 0, weights.null_count() > 0) {
294            // Both have nulls
295            (true, true) => {
296                let predicate = and(&is_not_null(&means)?, &is_not_null(&weights)?)?;
297                means = filter(&means, &predicate)?;
298                weights = filter(&weights, &predicate)?;
299            }
300            // Only one has nulls
301            (false, true) => {
302                let predicate = &is_not_null(&weights)?;
303                means = filter(&means, predicate)?;
304                weights = filter(&weights, predicate)?;
305            }
306            (true, false) => {
307                let predicate = &is_not_null(&means)?;
308                means = filter(&means, predicate)?;
309                weights = filter(&weights, predicate)?;
310            }
311            // No nulls
312            (false, false) => {}
313        }
314        debug_assert_eq!(
315            means.len(),
316            weights.len(),
317            "invalid number of values in means and weights"
318        );
319        let means_f64 = ApproxPercentileAccumulator::convert_to_float(&means)?;
320        let weights_f64 = ApproxPercentileAccumulator::convert_to_float(&weights)?;
321        let mut digests: Vec<TDigest> = vec![];
322        for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
323            digests.push(TDigest::new_with_centroid(
324                self.approx_percentile_cont_accumulator.max_size(),
325                Centroid::new(*mean, *weight),
326            ))
327        }
328        self.approx_percentile_cont_accumulator
329            .merge_digests(&digests);
330        Ok(())
331    }
332
333    fn evaluate(&mut self) -> Result<ScalarValue> {
334        self.approx_percentile_cont_accumulator.evaluate()
335    }
336
337    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
338        self.approx_percentile_cont_accumulator
339            .merge_batch(states)?;
340
341        Ok(())
342    }
343
344    fn size(&self) -> usize {
345        size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator)
346            + self.approx_percentile_cont_accumulator.size()
347    }
348}