datafusion_functions_aggregate/
approx_percentile_cont_with_weight.rs1use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::hash::Hash;
21use std::mem::size_of_val;
22use std::sync::Arc;
23
24use arrow::compute::{and, filter, is_not_null};
25use arrow::datatypes::FieldRef;
26use arrow::{array::ArrayRef, datatypes::DataType};
27use datafusion_common::ScalarValue;
28use datafusion_common::{Result, not_impl_err, plan_err};
29use datafusion_expr::Volatility::Immutable;
30use datafusion_expr::expr::{AggregateFunction, Sort};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
33use datafusion_expr::{
34 Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
35};
36use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest};
37use datafusion_macros::user_doc;
38
39use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
40
// Registers `ApproxPercentileContWithWeight` through the crate's
// `create_func!` macro. The second argument names the accessor function;
// NOTE(review): presumably this is where `approx_percentile_cont_with_weight_udaf()`
// (called by the expression builder in this file) gets defined — confirm
// against the macro definition.
create_func!(
    ApproxPercentileContWithWeight,
    approx_percentile_cont_with_weight_udaf
);
45
46pub fn approx_percentile_cont_with_weight(
48 order_by: Sort,
49 weight: Expr,
50 percentile: Expr,
51 centroids: Option<Expr>,
52) -> Expr {
53 let expr = order_by.expr.clone();
54
55 let args = if let Some(centroids) = centroids {
56 vec![expr, weight, percentile, centroids]
57 } else {
58 vec![expr, weight, percentile]
59 };
60
61 Expr::AggregateFunction(AggregateFunction::new_udf(
62 approx_percentile_cont_with_weight_udaf(),
63 args,
64 false,
65 None,
66 vec![order_by],
67 None,
68 ))
69}
70
71#[user_doc(
73 doc_section(label = "Approximate Functions"),
74 description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
75 syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
76 sql_example = r#"```sql
77> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
78+---------------------------------------------------------------------------------------------+
79| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
80+---------------------------------------------------------------------------------------------+
81| 78.5 |
82+---------------------------------------------------------------------------------------------+
83> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
84+--------------------------------------------------------------------------------------------------+
85| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) |
86+--------------------------------------------------------------------------------------------------+
87| 78.5 |
88+--------------------------------------------------------------------------------------------------+
89```
90An alternative syntax is also supported:
91
92```sql
93> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name;
94+--------------------------------------------------+
95| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) |
96+--------------------------------------------------+
97| 78.5 |
98+--------------------------------------------------+
99```"#,
100 standard_argument(name = "expression", prefix = "The"),
101 argument(
102 name = "weight",
103 description = "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators."
104 ),
105 argument(
106 name = "percentile",
107 description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
108 ),
109 argument(
110 name = "centroids",
111 description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
112 )
113)]
114#[derive(PartialEq, Eq, Hash)]
115pub struct ApproxPercentileContWithWeight {
116 signature: Signature,
117 approx_percentile_cont: ApproxPercentileCont,
118}
119
120impl Debug for ApproxPercentileContWithWeight {
121 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("ApproxPercentileContWithWeight")
123 .field("signature", &self.signature)
124 .finish()
125 }
126}
127
128impl Default for ApproxPercentileContWithWeight {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl ApproxPercentileContWithWeight {
135 pub fn new() -> Self {
137 let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
138 for num in NUMERICS {
140 variants.push(TypeSignature::Exact(vec![
141 num.clone(),
142 num.clone(),
143 DataType::Float64,
144 ]));
145 for int in INTEGERS {
147 variants.push(TypeSignature::Exact(vec![
148 num.clone(),
149 num.clone(),
150 DataType::Float64,
151 int.clone(),
152 ]));
153 }
154 }
155 Self {
156 signature: Signature::one_of(variants, Immutable),
157 approx_percentile_cont: ApproxPercentileCont::new(),
158 }
159 }
160}
161
162impl AggregateUDFImpl for ApproxPercentileContWithWeight {
163 fn as_any(&self) -> &dyn Any {
164 self
165 }
166
167 fn name(&self) -> &str {
168 "approx_percentile_cont_with_weight"
169 }
170
171 fn signature(&self) -> &Signature {
172 &self.signature
173 }
174
175 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
176 if !arg_types[0].is_numeric() {
177 return plan_err!(
178 "approx_percentile_cont_with_weight requires numeric input types"
179 );
180 }
181 if !arg_types[1].is_numeric() {
182 return plan_err!(
183 "approx_percentile_cont_with_weight requires numeric weight input types"
184 );
185 }
186 if arg_types[2] != DataType::Float64 {
187 return plan_err!(
188 "approx_percentile_cont_with_weight requires float64 percentile input types"
189 );
190 }
191 if arg_types.len() == 4 && !arg_types[3].is_integer() {
192 return plan_err!(
193 "approx_percentile_cont_with_weight requires integer centroids input types"
194 );
195 }
196 Ok(arg_types[0].clone())
197 }
198
199 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
200 if acc_args.is_distinct {
201 return not_impl_err!(
202 "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available"
203 );
204 }
205
206 if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 {
207 return plan_err!(
208 "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]"
209 );
210 }
211
212 let sub_args = AccumulatorArgs {
213 exprs: if acc_args.exprs.len() == 4 {
214 &[
215 Arc::clone(&acc_args.exprs[0]), Arc::clone(&acc_args.exprs[2]), Arc::clone(&acc_args.exprs[3]), ]
219 } else {
220 &[
221 Arc::clone(&acc_args.exprs[0]), Arc::clone(&acc_args.exprs[2]), ]
224 },
225 expr_fields: if acc_args.exprs.len() == 4 {
226 &[
227 Arc::clone(&acc_args.expr_fields[0]), Arc::clone(&acc_args.expr_fields[2]), Arc::clone(&acc_args.expr_fields[3]), ]
231 } else {
232 &[
233 Arc::clone(&acc_args.expr_fields[0]), Arc::clone(&acc_args.expr_fields[2]), ]
236 },
237 return_field: acc_args.return_field,
241 schema: acc_args.schema,
242 ignore_nulls: acc_args.ignore_nulls,
243 order_bys: acc_args.order_bys,
244 is_reversed: acc_args.is_reversed,
245 name: acc_args.name,
246 is_distinct: acc_args.is_distinct,
247 };
248 let approx_percentile_cont_accumulator =
249 self.approx_percentile_cont.create_accumulator(&sub_args)?;
250 let accumulator = ApproxPercentileWithWeightAccumulator::new(
251 approx_percentile_cont_accumulator,
252 );
253 Ok(Box::new(accumulator))
254 }
255
256 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
259 self.approx_percentile_cont.state_fields(args)
260 }
261
262 fn supports_within_group_clause(&self) -> bool {
263 true
264 }
265
266 fn documentation(&self) -> Option<&Documentation> {
267 self.doc()
268 }
269}
270
/// Accumulator behind `approx_percentile_cont_with_weight`.
///
/// A thin wrapper around [`ApproxPercentileAccumulator`]: it turns incoming
/// `(value, weight)` rows into digests and merges them into the wrapped
/// accumulator, which handles state, merging and evaluation.
#[derive(Debug)]
pub struct ApproxPercentileWithWeightAccumulator {
    /// Wrapped accumulator that all state, merge and evaluate calls are
    /// forwarded to.
    approx_percentile_cont_accumulator: ApproxPercentileAccumulator,
}
275
276impl ApproxPercentileWithWeightAccumulator {
277 pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self {
278 Self {
279 approx_percentile_cont_accumulator,
280 }
281 }
282}
283
284impl Accumulator for ApproxPercentileWithWeightAccumulator {
285 fn state(&mut self) -> Result<Vec<ScalarValue>> {
286 self.approx_percentile_cont_accumulator.state()
287 }
288
289 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
290 let mut means = Arc::clone(&values[0]);
291 let mut weights = Arc::clone(&values[1]);
292 match (means.null_count() > 0, weights.null_count() > 0) {
294 (true, true) => {
296 let predicate = and(&is_not_null(&means)?, &is_not_null(&weights)?)?;
297 means = filter(&means, &predicate)?;
298 weights = filter(&weights, &predicate)?;
299 }
300 (false, true) => {
302 let predicate = &is_not_null(&weights)?;
303 means = filter(&means, predicate)?;
304 weights = filter(&weights, predicate)?;
305 }
306 (true, false) => {
307 let predicate = &is_not_null(&means)?;
308 means = filter(&means, predicate)?;
309 weights = filter(&weights, predicate)?;
310 }
311 (false, false) => {}
313 }
314 debug_assert_eq!(
315 means.len(),
316 weights.len(),
317 "invalid number of values in means and weights"
318 );
319 let means_f64 = ApproxPercentileAccumulator::convert_to_float(&means)?;
320 let weights_f64 = ApproxPercentileAccumulator::convert_to_float(&weights)?;
321 let mut digests: Vec<TDigest> = vec![];
322 for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
323 digests.push(TDigest::new_with_centroid(
324 self.approx_percentile_cont_accumulator.max_size(),
325 Centroid::new(*mean, *weight),
326 ))
327 }
328 self.approx_percentile_cont_accumulator
329 .merge_digests(&digests);
330 Ok(())
331 }
332
333 fn evaluate(&mut self) -> Result<ScalarValue> {
334 self.approx_percentile_cont_accumulator.evaluate()
335 }
336
337 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
338 self.approx_percentile_cont_accumulator
339 .merge_batch(states)?;
340
341 Ok(())
342 }
343
344 fn size(&self) -> usize {
345 size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator)
346 + self.approx_percentile_cont_accumulator.size()
347 }
348}