datafusion_functions_aggregate/
approx_percentile_cont_with_weight.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::hash::Hash;
21use std::mem::size_of_val;
22use std::sync::Arc;
23
24use arrow::datatypes::FieldRef;
25use arrow::{array::ArrayRef, datatypes::DataType};
26use datafusion_common::ScalarValue;
27use datafusion_common::{not_impl_err, plan_err, Result};
28use datafusion_expr::expr::{AggregateFunction, Sort};
29use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
30use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
31use datafusion_expr::Volatility::Immutable;
32use datafusion_expr::{
33    Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
34};
35use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest};
36use datafusion_macros::user_doc;
37
38use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
39
// Registers this aggregate through the crate's `create_func!` macro.
// NOTE(review): the macro is defined outside this file; it presumably generates
// the `approx_percentile_cont_with_weight_udaf()` accessor called by the
// expression builder below — confirm against the macro definition.
create_func!(
    ApproxPercentileContWithWeight,
    approx_percentile_cont_with_weight_udaf
);
44
45/// Computes the approximate percentile continuous with weight of a set of numbers
46pub fn approx_percentile_cont_with_weight(
47    order_by: Sort,
48    weight: Expr,
49    percentile: Expr,
50    centroids: Option<Expr>,
51) -> Expr {
52    let expr = order_by.expr.clone();
53
54    let args = if let Some(centroids) = centroids {
55        vec![expr, weight, percentile, centroids]
56    } else {
57        vec![expr, weight, percentile]
58    };
59
60    Expr::AggregateFunction(AggregateFunction::new_udf(
61        approx_percentile_cont_with_weight_udaf(),
62        args,
63        false,
64        None,
65        vec![order_by],
66        None,
67    ))
68}
69
70/// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression
71#[user_doc(
72    doc_section(label = "Approximate Functions"),
73    description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
74    syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
75    sql_example = r#"```sql
76> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
77+---------------------------------------------------------------------------------------------+
78| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
79+---------------------------------------------------------------------------------------------+
80| 78.5                                                                                        |
81+---------------------------------------------------------------------------------------------+
82> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
83+--------------------------------------------------------------------------------------------------+
84| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) |
85+--------------------------------------------------------------------------------------------------+
86| 78.5                                                                                             |
87+--------------------------------------------------------------------------------------------------+
88```
89An alternative syntax is also supported:
90
91```sql
92> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name;
93+--------------------------------------------------+
94| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) |
95+--------------------------------------------------+
96| 78.5                                             |
97+--------------------------------------------------+
98```"#,
99    standard_argument(name = "expression", prefix = "The"),
100    argument(
101        name = "weight",
102        description = "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators."
103    ),
104    argument(
105        name = "percentile",
106        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
107    ),
108    argument(
109        name = "centroids",
110        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
111    )
112)]
113#[derive(PartialEq, Eq, Hash)]
114pub struct ApproxPercentileContWithWeight {
115    signature: Signature,
116    approx_percentile_cont: ApproxPercentileCont,
117}
118
119impl Debug for ApproxPercentileContWithWeight {
120    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
121        f.debug_struct("ApproxPercentileContWithWeight")
122            .field("signature", &self.signature)
123            .finish()
124    }
125}
126
127impl Default for ApproxPercentileContWithWeight {
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
133impl ApproxPercentileContWithWeight {
134    /// Create a new [`ApproxPercentileContWithWeight`] aggregate function.
135    pub fn new() -> Self {
136        let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
137        // Accept any numeric value paired with weight and float64 percentile
138        for num in NUMERICS {
139            variants.push(TypeSignature::Exact(vec![
140                num.clone(),
141                num.clone(),
142                DataType::Float64,
143            ]));
144            // Additionally accept an integer number of centroids for T-Digest
145            for int in INTEGERS {
146                variants.push(TypeSignature::Exact(vec![
147                    num.clone(),
148                    num.clone(),
149                    DataType::Float64,
150                    int.clone(),
151                ]));
152            }
153        }
154        Self {
155            signature: Signature::one_of(variants, Immutable),
156            approx_percentile_cont: ApproxPercentileCont::new(),
157        }
158    }
159}
160
161impl AggregateUDFImpl for ApproxPercentileContWithWeight {
162    fn as_any(&self) -> &dyn Any {
163        self
164    }
165
166    fn name(&self) -> &str {
167        "approx_percentile_cont_with_weight"
168    }
169
170    fn signature(&self) -> &Signature {
171        &self.signature
172    }
173
174    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
175        if !arg_types[0].is_numeric() {
176            return plan_err!(
177                "approx_percentile_cont_with_weight requires numeric input types"
178            );
179        }
180        if !arg_types[1].is_numeric() {
181            return plan_err!(
182                "approx_percentile_cont_with_weight requires numeric weight input types"
183            );
184        }
185        if arg_types[2] != DataType::Float64 {
186            return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types");
187        }
188        if arg_types.len() == 4 && !arg_types[3].is_integer() {
189            return plan_err!(
190                "approx_percentile_cont_with_weight requires integer centroids input types"
191            );
192        }
193        Ok(arg_types[0].clone())
194    }
195
196    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
197        if acc_args.is_distinct {
198            return not_impl_err!(
199                "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available"
200            );
201        }
202
203        if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 {
204            return plan_err!(
205                "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]"
206            );
207        }
208
209        let sub_args = AccumulatorArgs {
210            exprs: if acc_args.exprs.len() == 4 {
211                &[
212                    Arc::clone(&acc_args.exprs[0]), // value
213                    Arc::clone(&acc_args.exprs[2]), // percentile
214                    Arc::clone(&acc_args.exprs[3]), // centroids
215                ]
216            } else {
217                &[
218                    Arc::clone(&acc_args.exprs[0]), // value
219                    Arc::clone(&acc_args.exprs[2]), // percentile
220                ]
221            },
222            ..acc_args
223        };
224        let approx_percentile_cont_accumulator =
225            self.approx_percentile_cont.create_accumulator(sub_args)?;
226        let accumulator = ApproxPercentileWithWeightAccumulator::new(
227            approx_percentile_cont_accumulator,
228        );
229        Ok(Box::new(accumulator))
230    }
231
232    #[allow(rustdoc::private_intra_doc_links)]
233    /// See [`TDigest::to_scalar_state()`] for a description of the serialized
234    /// state.
235    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
236        self.approx_percentile_cont.state_fields(args)
237    }
238
239    fn supports_null_handling_clause(&self) -> bool {
240        false
241    }
242
243    fn is_ordered_set_aggregate(&self) -> bool {
244        true
245    }
246
247    fn documentation(&self) -> Option<&Documentation> {
248        self.doc()
249    }
250}
251
/// Accumulator for `approx_percentile_cont_with_weight`.
///
/// Wraps an [`ApproxPercentileAccumulator`]: weighted input rows are turned
/// into digests and merged into the wrapped accumulator, while state
/// serialization, state merging and evaluation are delegated to it.
#[derive(Debug)]
pub struct ApproxPercentileWithWeightAccumulator {
    /// Wrapped unweighted accumulator that holds the digest state.
    approx_percentile_cont_accumulator: ApproxPercentileAccumulator,
}
256
257impl ApproxPercentileWithWeightAccumulator {
258    pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self {
259        Self {
260            approx_percentile_cont_accumulator,
261        }
262    }
263}
264
265impl Accumulator for ApproxPercentileWithWeightAccumulator {
266    fn state(&mut self) -> Result<Vec<ScalarValue>> {
267        self.approx_percentile_cont_accumulator.state()
268    }
269
270    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
271        let means = &values[0];
272        let weights = &values[1];
273        debug_assert_eq!(
274            means.len(),
275            weights.len(),
276            "invalid number of values in means and weights"
277        );
278        let means_f64 = ApproxPercentileAccumulator::convert_to_float(means)?;
279        let weights_f64 = ApproxPercentileAccumulator::convert_to_float(weights)?;
280        let mut digests: Vec<TDigest> = vec![];
281        for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
282            digests.push(TDigest::new_with_centroid(
283                self.approx_percentile_cont_accumulator.max_size(),
284                Centroid::new(*mean, *weight),
285            ))
286        }
287        self.approx_percentile_cont_accumulator
288            .merge_digests(&digests);
289        Ok(())
290    }
291
292    fn evaluate(&mut self) -> Result<ScalarValue> {
293        self.approx_percentile_cont_accumulator.evaluate()
294    }
295
296    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
297        self.approx_percentile_cont_accumulator
298            .merge_batch(states)?;
299
300        Ok(())
301    }
302
303    fn size(&self) -> usize {
304        size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator)
305            + self.approx_percentile_cont_accumulator.size()
306    }
307}