datafusion_functions_aggregate/
approx_percentile_cont_with_weight.rs1use std::any::Any;
19use std::fmt::Debug;
20use std::hash::Hash;
21use std::mem::size_of_val;
22use std::sync::Arc;
23
24use arrow::compute::{and, filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{array::ArrayRef, datatypes::DataType};
27use datafusion_common::ScalarValue;
28use datafusion_common::{Result, not_impl_err, plan_err};
29use datafusion_expr::Volatility::Immutable;
30use datafusion_expr::expr::{AggregateFunction, Sort};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
33use datafusion_expr::{
34 Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
35};
36use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest};
37use datafusion_macros::user_doc;
38
39use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
40
// Invokes the crate's `create_func!` macro (defined outside this file) for the
// `ApproxPercentileContWithWeight` type. The second argument names the
// `approx_percentile_cont_with_weight_udaf()` function that
// `approx_percentile_cont_with_weight` below calls to obtain the UDAF handle.
// NOTE(review): presumably the macro generates that accessor — confirm against
// the macro definition.
create_func!(
    ApproxPercentileContWithWeight,
    approx_percentile_cont_with_weight_udaf
);
45
46pub fn approx_percentile_cont_with_weight(
48 order_by: Sort,
49 weight: Expr,
50 percentile: Expr,
51 centroids: Option<Expr>,
52) -> Expr {
53 let expr = order_by.expr.clone();
54
55 let args = if let Some(centroids) = centroids {
56 vec![expr, weight, percentile, centroids]
57 } else {
58 vec![expr, weight, percentile]
59 };
60
61 Expr::AggregateFunction(AggregateFunction::new_udf(
62 approx_percentile_cont_with_weight_udaf(),
63 args,
64 false,
65 None,
66 vec![order_by],
67 None,
68 ))
69}
70
71#[user_doc(
73 doc_section(label = "Approximate Functions"),
74 description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
75 syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
76 sql_example = r#"```sql
77> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
78+---------------------------------------------------------------------------------------------+
79| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
80+---------------------------------------------------------------------------------------------+
81| 78.5 |
82+---------------------------------------------------------------------------------------------+
83> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
84+--------------------------------------------------------------------------------------------------+
85| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) |
86+--------------------------------------------------------------------------------------------------+
87| 78.5 |
88+--------------------------------------------------------------------------------------------------+
89```
90An alternative syntax is also supported:
91
92```sql
93> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name;
94+--------------------------------------------------+
95| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) |
96+--------------------------------------------------+
97| 78.5 |
98+--------------------------------------------------+
99```"#,
100 standard_argument(name = "expression", prefix = "The"),
101 argument(
102 name = "weight",
103 description = "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators."
104 ),
105 argument(
106 name = "percentile",
107 description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
108 ),
109 argument(
110 name = "centroids",
111 description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
112 )
113)]
114#[derive(PartialEq, Eq, Hash, Debug)]
115pub struct ApproxPercentileContWithWeight {
116 signature: Signature,
117 approx_percentile_cont: ApproxPercentileCont,
118}
119
// `Default` is provided purely as a convenience: it forwards to
// `ApproxPercentileContWithWeight::new`, so both construct the same value.
impl Default for ApproxPercentileContWithWeight {
    fn default() -> Self {
        Self::new()
    }
}
125
126impl ApproxPercentileContWithWeight {
127 pub fn new() -> Self {
129 let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
130 for num in NUMERICS {
132 variants.push(TypeSignature::Exact(vec![
133 num.clone(),
134 num.clone(),
135 DataType::Float64,
136 ]));
137 for int in INTEGERS {
139 variants.push(TypeSignature::Exact(vec![
140 num.clone(),
141 num.clone(),
142 DataType::Float64,
143 int.clone(),
144 ]));
145 }
146 }
147 Self {
148 signature: Signature::one_of(variants, Immutable),
149 approx_percentile_cont: ApproxPercentileCont::new(),
150 }
151 }
152}
153
154impl AggregateUDFImpl for ApproxPercentileContWithWeight {
155 fn as_any(&self) -> &dyn Any {
156 self
157 }
158
159 fn name(&self) -> &str {
160 "approx_percentile_cont_with_weight"
161 }
162
163 fn signature(&self) -> &Signature {
164 &self.signature
165 }
166
167 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
168 if !arg_types[0].is_numeric() {
169 return plan_err!(
170 "approx_percentile_cont_with_weight requires numeric input types"
171 );
172 }
173 if !arg_types[1].is_numeric() {
174 return plan_err!(
175 "approx_percentile_cont_with_weight requires numeric weight input types"
176 );
177 }
178 if arg_types[2] != DataType::Float64 {
179 return plan_err!(
180 "approx_percentile_cont_with_weight requires float64 percentile input types"
181 );
182 }
183 if arg_types.len() == 4 && !arg_types[3].is_integer() {
184 return plan_err!(
185 "approx_percentile_cont_with_weight requires integer centroids input types"
186 );
187 }
188 Ok(arg_types[0].clone())
189 }
190
191 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
192 if acc_args.is_distinct {
193 return not_impl_err!(
194 "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available"
195 );
196 }
197
198 if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 {
199 return plan_err!(
200 "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]"
201 );
202 }
203
204 let sub_args = AccumulatorArgs {
205 exprs: if acc_args.exprs.len() == 4 {
206 &[
207 Arc::clone(&acc_args.exprs[0]), Arc::clone(&acc_args.exprs[2]), Arc::clone(&acc_args.exprs[3]), ]
211 } else {
212 &[
213 Arc::clone(&acc_args.exprs[0]), Arc::clone(&acc_args.exprs[2]), ]
216 },
217 expr_fields: if acc_args.exprs.len() == 4 {
218 &[
219 Arc::clone(&acc_args.expr_fields[0]), Arc::clone(&acc_args.expr_fields[2]), Arc::clone(&acc_args.expr_fields[3]), ]
223 } else {
224 &[
225 Arc::clone(&acc_args.expr_fields[0]), Arc::clone(&acc_args.expr_fields[2]), ]
228 },
229 return_field: acc_args.return_field,
233 schema: acc_args.schema,
234 ignore_nulls: acc_args.ignore_nulls,
235 order_bys: acc_args.order_bys,
236 is_reversed: acc_args.is_reversed,
237 name: acc_args.name,
238 is_distinct: acc_args.is_distinct,
239 };
240 let approx_percentile_cont_accumulator =
241 self.approx_percentile_cont.create_accumulator(&sub_args)?;
242 let accumulator = ApproxPercentileWithWeightAccumulator::new(
243 approx_percentile_cont_accumulator,
244 );
245 Ok(Box::new(accumulator))
246 }
247
248 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
251 self.approx_percentile_cont.state_fields(args)
252 }
253
254 fn supports_within_group_clause(&self) -> bool {
255 true
256 }
257
258 fn documentation(&self) -> Option<&Documentation> {
259 self.doc()
260 }
261}
262
/// Accumulator for `approx_percentile_cont_with_weight`.
///
/// It wraps an `ApproxPercentileAccumulator`: state, evaluation and merging
/// are delegated to it, while `update_batch` feeds it weighted values.
#[derive(Debug)]
pub struct ApproxPercentileWithWeightAccumulator {
    // The wrapped accumulator that holds the state.
    approx_percentile_cont_accumulator: ApproxPercentileAccumulator,
}
267
impl ApproxPercentileWithWeightAccumulator {
    /// Wraps the given `ApproxPercentileAccumulator`, taking ownership of it.
    pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self {
        Self {
            approx_percentile_cont_accumulator,
        }
    }
}
275
276impl Accumulator for ApproxPercentileWithWeightAccumulator {
277 fn state(&mut self) -> Result<Vec<ScalarValue>> {
278 self.approx_percentile_cont_accumulator.state()
279 }
280
281 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
282 let mut means = Arc::clone(&values[0]);
283 let mut weights = Arc::clone(&values[1]);
284 match (means.null_count() > 0, weights.null_count() > 0) {
286 (true, true) => {
288 let predicate = and(&is_not_null(&means)?, &is_not_null(&weights)?)?;
289 means = filter(&means, &predicate)?;
290 weights = filter(&weights, &predicate)?;
291 }
292 (false, true) => {
294 let predicate = &is_not_null(&weights)?;
295 means = filter(&means, predicate)?;
296 weights = filter(&weights, predicate)?;
297 }
298 (true, false) => {
299 let predicate = &is_not_null(&means)?;
300 means = filter(&means, predicate)?;
301 weights = filter(&weights, predicate)?;
302 }
303 (false, false) => {}
305 }
306 debug_assert_eq!(
307 means.len(),
308 weights.len(),
309 "invalid number of values in means and weights"
310 );
311 let means_f64 = ApproxPercentileAccumulator::convert_to_float(&means)?;
312 let weights_f64 = ApproxPercentileAccumulator::convert_to_float(&weights)?;
313 let mut digests: Vec<TDigest> = vec![];
314 for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
315 digests.push(TDigest::new_with_centroid(
316 self.approx_percentile_cont_accumulator.max_size(),
317 Centroid::new(*mean, *weight),
318 ))
319 }
320 self.approx_percentile_cont_accumulator
321 .merge_digests(&digests);
322 Ok(())
323 }
324
325 fn evaluate(&mut self) -> Result<ScalarValue> {
326 self.approx_percentile_cont_accumulator.evaluate()
327 }
328
329 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
330 self.approx_percentile_cont_accumulator
331 .merge_batch(states)?;
332
333 Ok(())
334 }
335
336 fn size(&self) -> usize {
337 size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator)
338 + self.approx_percentile_cont_accumulator.size()
339 }
340}