datafusion_functions_aggregate/
approx_percentile_cont_with_weight.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
use std::any::Any;
use std::fmt::{Debug, Formatter};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::mem::size_of_val;
use std::sync::Arc;

use arrow::compute::{and, filter, is_not_null};
use arrow::datatypes::FieldRef;
use arrow::{array::ArrayRef, datatypes::DataType};
use datafusion_common::ScalarValue;
use datafusion_common::{not_impl_err, plan_err, Result};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::Volatility::Immutable;
use datafusion_expr::{
    Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature,
};
use datafusion_functions_aggregate_common::tdigest::{
    Centroid, TDigest, DEFAULT_MAX_SIZE,
};
use datafusion_macros::user_doc;

use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
40
// Registers this UDAF through the crate's `make_udaf_expr_and_func!` macro.
// The invocation names the expression helper `approx_percentile_cont_with_weight`,
// which takes the arguments (expression, weight, percentile). It also names the
// `approx_percentile_cont_with_weight_udaf` accessor.
// Note the argument order differs from the SQL surface syntax, where `expression`
// appears in `WITHIN GROUP (ORDER BY ...)` after `(weight, percentile)`.
// NOTE(review): the exact items generated depend on the macro definition; confirm there.
make_udaf_expr_and_func!(
    ApproxPercentileContWithWeight,
    approx_percentile_cont_with_weight,
    expression weight percentile,
    "Computes the approximate percentile continuous with weight of a set of numbers",
    approx_percentile_cont_with_weight_udaf
);
48
/// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression
///
/// Weighted variant of [`ApproxPercentileCont`]. It declares its own
/// three-argument `(value, weight, percentile)` signature. It delegates state
/// layout and accumulator construction to the wrapped unweighted implementation.
#[user_doc(
    doc_section(label = "Approximate Functions"),
    description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
    syntax_example = "approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression)",
    sql_example = r#"```sql
> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+---------------------------------------------------------------------------------------------+
| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
+---------------------------------------------------------------------------------------------+
| 78.5                                                                                        |
+---------------------------------------------------------------------------------------------+
```"#,
    standard_argument(name = "expression", prefix = "The"),
    argument(
        name = "weight",
        description = "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators."
    ),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    )
)]
pub struct ApproxPercentileContWithWeight {
    /// Accepted argument types: one exact `(T, T, Float64)` signature per numeric `T`.
    signature: Signature,
    /// Unweighted implementation used for state fields and the inner accumulator.
    approx_percentile_cont: ApproxPercentileCont,
}
76
77impl Debug for ApproxPercentileContWithWeight {
78    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("ApproxPercentileContWithWeight")
80            .field("signature", &self.signature)
81            .finish()
82    }
83}
84
85impl Default for ApproxPercentileContWithWeight {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl ApproxPercentileContWithWeight {
92    /// Create a new [`ApproxPercentileContWithWeight`] aggregate function.
93    pub fn new() -> Self {
94        Self {
95            signature: Signature::one_of(
96                // Accept any numeric value paired with a float64 percentile
97                NUMERICS
98                    .iter()
99                    .map(|t| {
100                        TypeSignature::Exact(vec![
101                            t.clone(),
102                            t.clone(),
103                            DataType::Float64,
104                        ])
105                    })
106                    .collect(),
107                Immutable,
108            ),
109            approx_percentile_cont: ApproxPercentileCont::new(),
110        }
111    }
112}
113
114impl AggregateUDFImpl for ApproxPercentileContWithWeight {
115    fn as_any(&self) -> &dyn Any {
116        self
117    }
118
119    fn name(&self) -> &str {
120        "approx_percentile_cont_with_weight"
121    }
122
123    fn signature(&self) -> &Signature {
124        &self.signature
125    }
126
127    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
128        if !arg_types[0].is_numeric() {
129            return plan_err!(
130                "approx_percentile_cont_with_weight requires numeric input types"
131            );
132        }
133        if !arg_types[1].is_numeric() {
134            return plan_err!(
135                "approx_percentile_cont_with_weight requires numeric weight input types"
136            );
137        }
138        if arg_types[2] != DataType::Float64 {
139            return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types");
140        }
141        Ok(arg_types[0].clone())
142    }
143
144    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
145        if acc_args.is_distinct {
146            return not_impl_err!(
147                "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available"
148            );
149        }
150
151        if acc_args.exprs.len() != 3 {
152            return plan_err!(
153                "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile"
154            );
155        }
156
157        let sub_args = AccumulatorArgs {
158            exprs: &[
159                Arc::clone(&acc_args.exprs[0]),
160                Arc::clone(&acc_args.exprs[2]),
161            ],
162            ..acc_args
163        };
164        let approx_percentile_cont_accumulator =
165            self.approx_percentile_cont.create_accumulator(sub_args)?;
166        let accumulator = ApproxPercentileWithWeightAccumulator::new(
167            approx_percentile_cont_accumulator,
168        );
169        Ok(Box::new(accumulator))
170    }
171
172    #[allow(rustdoc::private_intra_doc_links)]
173    /// See [`TDigest::to_scalar_state()`] for a description of the serialized
174    /// state.
175    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
176        self.approx_percentile_cont.state_fields(args)
177    }
178
179    fn supports_null_handling_clause(&self) -> bool {
180        false
181    }
182
183    fn is_ordered_set_aggregate(&self) -> bool {
184        true
185    }
186
187    fn documentation(&self) -> Option<&Documentation> {
188        self.doc()
189    }
190
191    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
192        let Some(other) = other.as_any().downcast_ref::<Self>() else {
193            return false;
194        };
195        let Self {
196            signature,
197            approx_percentile_cont,
198        } = self;
199        signature == &other.signature
200            && approx_percentile_cont.equals(&other.approx_percentile_cont)
201    }
202
203    fn hash_value(&self) -> u64 {
204        let Self {
205            signature,
206            approx_percentile_cont,
207        } = self;
208        let mut hasher = DefaultHasher::new();
209        std::any::type_name::<Self>().hash(&mut hasher);
210        signature.hash(&mut hasher);
211        hasher.write_u64(approx_percentile_cont.hash_value());
212        hasher.finish()
213    }
214}
215
216#[derive(Debug)]
217pub struct ApproxPercentileWithWeightAccumulator {
218    approx_percentile_cont_accumulator: ApproxPercentileAccumulator,
219}
220
221impl ApproxPercentileWithWeightAccumulator {
222    pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self {
223        Self {
224            approx_percentile_cont_accumulator,
225        }
226    }
227}
228
229impl Accumulator for ApproxPercentileWithWeightAccumulator {
230    fn state(&mut self) -> Result<Vec<ScalarValue>> {
231        self.approx_percentile_cont_accumulator.state()
232    }
233
234    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
235        let means = &values[0];
236        let weights = &values[1];
237        debug_assert_eq!(
238            means.len(),
239            weights.len(),
240            "invalid number of values in means and weights"
241        );
242        let means_f64 = ApproxPercentileAccumulator::convert_to_float(means)?;
243        let weights_f64 = ApproxPercentileAccumulator::convert_to_float(weights)?;
244        let mut digests: Vec<TDigest> = vec![];
245        for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
246            digests.push(TDigest::new_with_centroid(
247                DEFAULT_MAX_SIZE,
248                Centroid::new(*mean, *weight),
249            ))
250        }
251        self.approx_percentile_cont_accumulator
252            .merge_digests(&digests);
253        Ok(())
254    }
255
256    fn evaluate(&mut self) -> Result<ScalarValue> {
257        self.approx_percentile_cont_accumulator.evaluate()
258    }
259
260    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
261        self.approx_percentile_cont_accumulator
262            .merge_batch(states)?;
263
264        Ok(())
265    }
266
267    fn size(&self) -> usize {
268        size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator)
269            + self.approx_percentile_cont_accumulator.size()
270    }
271}