datafusion_functions_aggregate/
approx_percentile_cont_with_weight.rs1use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::hash::Hash;
21use std::mem::size_of_val;
22use std::sync::Arc;
23
24use arrow::compute::{and, filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{array::ArrayRef, datatypes::DataType};
27use datafusion_common::ScalarValue;
28use datafusion_common::{not_impl_err, plan_err, Result};
29use datafusion_expr::expr::{AggregateFunction, Sort};
30use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
31use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
32use datafusion_expr::Volatility::Immutable;
33use datafusion_expr::{
34 Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
35};
36use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest};
37use datafusion_macros::user_doc;
38
39use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
40
41create_func!(
42 ApproxPercentileContWithWeight,
43 approx_percentile_cont_with_weight_udaf
44);
45
46pub fn approx_percentile_cont_with_weight(
48 order_by: Sort,
49 weight: Expr,
50 percentile: Expr,
51 centroids: Option<Expr>,
52) -> Expr {
53 let expr = order_by.expr.clone();
54
55 let args = if let Some(centroids) = centroids {
56 vec![expr, weight, percentile, centroids]
57 } else {
58 vec![expr, weight, percentile]
59 };
60
61 Expr::AggregateFunction(AggregateFunction::new_udf(
62 approx_percentile_cont_with_weight_udaf(),
63 args,
64 false,
65 None,
66 vec![order_by],
67 None,
68 ))
69}
70
71#[user_doc(
73 doc_section(label = "Approximate Functions"),
74 description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
75 syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
76 sql_example = r#"```sql
77> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
78+---------------------------------------------------------------------------------------------+
79| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
80+---------------------------------------------------------------------------------------------+
81| 78.5 |
82+---------------------------------------------------------------------------------------------+
83> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
84+--------------------------------------------------------------------------------------------------+
85| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) |
86+--------------------------------------------------------------------------------------------------+
87| 78.5 |
88+--------------------------------------------------------------------------------------------------+
89```
90An alternative syntax is also supported:
91
92```sql
93> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name;
94+--------------------------------------------------+
95| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) |
96+--------------------------------------------------+
97| 78.5 |
98+--------------------------------------------------+
99```"#,
100 standard_argument(name = "expression", prefix = "The"),
101 argument(
102 name = "weight",
103 description = "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators."
104 ),
105 argument(
106 name = "percentile",
107 description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
108 ),
109 argument(
110 name = "centroids",
111 description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
112 )
113)]
114#[derive(PartialEq, Eq, Hash)]
115pub struct ApproxPercentileContWithWeight {
116 signature: Signature,
117 approx_percentile_cont: ApproxPercentileCont,
118}
119
120impl Debug for ApproxPercentileContWithWeight {
121 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("ApproxPercentileContWithWeight")
123 .field("signature", &self.signature)
124 .finish()
125 }
126}
127
128impl Default for ApproxPercentileContWithWeight {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl ApproxPercentileContWithWeight {
135 pub fn new() -> Self {
137 let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
138 for num in NUMERICS {
140 variants.push(TypeSignature::Exact(vec![
141 num.clone(),
142 num.clone(),
143 DataType::Float64,
144 ]));
145 for int in INTEGERS {
147 variants.push(TypeSignature::Exact(vec![
148 num.clone(),
149 num.clone(),
150 DataType::Float64,
151 int.clone(),
152 ]));
153 }
154 }
155 Self {
156 signature: Signature::one_of(variants, Immutable),
157 approx_percentile_cont: ApproxPercentileCont::new(),
158 }
159 }
160}
161
162impl AggregateUDFImpl for ApproxPercentileContWithWeight {
163 fn as_any(&self) -> &dyn Any {
164 self
165 }
166
167 fn name(&self) -> &str {
168 "approx_percentile_cont_with_weight"
169 }
170
171 fn signature(&self) -> &Signature {
172 &self.signature
173 }
174
175 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
176 if !arg_types[0].is_numeric() {
177 return plan_err!(
178 "approx_percentile_cont_with_weight requires numeric input types"
179 );
180 }
181 if !arg_types[1].is_numeric() {
182 return plan_err!(
183 "approx_percentile_cont_with_weight requires numeric weight input types"
184 );
185 }
186 if arg_types[2] != DataType::Float64 {
187 return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types");
188 }
189 if arg_types.len() == 4 && !arg_types[3].is_integer() {
190 return plan_err!(
191 "approx_percentile_cont_with_weight requires integer centroids input types"
192 );
193 }
194 Ok(arg_types[0].clone())
195 }
196
197 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
198 if acc_args.is_distinct {
199 return not_impl_err!(
200 "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available"
201 );
202 }
203
204 if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 {
205 return plan_err!(
206 "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]"
207 );
208 }
209
210 let sub_args = AccumulatorArgs {
211 exprs: if acc_args.exprs.len() == 4 {
212 &[
213 Arc::clone(&acc_args.exprs[0]), Arc::clone(&acc_args.exprs[2]), Arc::clone(&acc_args.exprs[3]), ]
217 } else {
218 &[
219 Arc::clone(&acc_args.exprs[0]), Arc::clone(&acc_args.exprs[2]), ]
222 },
223 expr_fields: if acc_args.exprs.len() == 4 {
224 &[
225 Arc::clone(&acc_args.expr_fields[0]), Arc::clone(&acc_args.expr_fields[2]), Arc::clone(&acc_args.expr_fields[3]), ]
229 } else {
230 &[
231 Arc::clone(&acc_args.expr_fields[0]), Arc::clone(&acc_args.expr_fields[2]), ]
234 },
235 return_field: acc_args.return_field,
239 schema: acc_args.schema,
240 ignore_nulls: acc_args.ignore_nulls,
241 order_bys: acc_args.order_bys,
242 is_reversed: acc_args.is_reversed,
243 name: acc_args.name,
244 is_distinct: acc_args.is_distinct,
245 };
246 let approx_percentile_cont_accumulator =
247 self.approx_percentile_cont.create_accumulator(sub_args)?;
248 let accumulator = ApproxPercentileWithWeightAccumulator::new(
249 approx_percentile_cont_accumulator,
250 );
251 Ok(Box::new(accumulator))
252 }
253
254 #[allow(rustdoc::private_intra_doc_links)]
255 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
258 self.approx_percentile_cont.state_fields(args)
259 }
260
261 fn supports_null_handling_clause(&self) -> bool {
262 false
263 }
264
265 fn supports_within_group_clause(&self) -> bool {
266 true
267 }
268
269 fn documentation(&self) -> Option<&Documentation> {
270 self.doc()
271 }
272}
273
274#[derive(Debug)]
275pub struct ApproxPercentileWithWeightAccumulator {
276 approx_percentile_cont_accumulator: ApproxPercentileAccumulator,
277}
278
279impl ApproxPercentileWithWeightAccumulator {
280 pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self {
281 Self {
282 approx_percentile_cont_accumulator,
283 }
284 }
285}
286
287impl Accumulator for ApproxPercentileWithWeightAccumulator {
288 fn state(&mut self) -> Result<Vec<ScalarValue>> {
289 self.approx_percentile_cont_accumulator.state()
290 }
291
292 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
293 let mut means = Arc::clone(&values[0]);
294 let mut weights = Arc::clone(&values[1]);
295 match (means.null_count() > 0, weights.null_count() > 0) {
297 (true, true) => {
299 let predicate = and(&is_not_null(&means)?, &is_not_null(&weights)?)?;
300 means = filter(&means, &predicate)?;
301 weights = filter(&weights, &predicate)?;
302 }
303 (false, true) => {
305 let predicate = &is_not_null(&weights)?;
306 means = filter(&means, predicate)?;
307 weights = filter(&weights, predicate)?;
308 }
309 (true, false) => {
310 let predicate = &is_not_null(&means)?;
311 means = filter(&means, predicate)?;
312 weights = filter(&weights, predicate)?;
313 }
314 (false, false) => {}
316 }
317 debug_assert_eq!(
318 means.len(),
319 weights.len(),
320 "invalid number of values in means and weights"
321 );
322 let means_f64 = ApproxPercentileAccumulator::convert_to_float(&means)?;
323 let weights_f64 = ApproxPercentileAccumulator::convert_to_float(&weights)?;
324 let mut digests: Vec<TDigest> = vec![];
325 for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
326 digests.push(TDigest::new_with_centroid(
327 self.approx_percentile_cont_accumulator.max_size(),
328 Centroid::new(*mean, *weight),
329 ))
330 }
331 self.approx_percentile_cont_accumulator
332 .merge_digests(&digests);
333 Ok(())
334 }
335
336 fn evaluate(&mut self) -> Result<ScalarValue> {
337 self.approx_percentile_cont_accumulator.evaluate()
338 }
339
340 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
341 self.approx_percentile_cont_accumulator
342 .merge_batch(states)?;
343
344 Ok(())
345 }
346
347 fn size(&self) -> usize {
348 size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator)
349 + self.approx_percentile_cont_accumulator.size()
350 }
351}