datafusion_functions_aggregate/
approx_percentile_cont_with_weight.rs
1use std::any::Any;
19use std::fmt::{Debug, Formatter};
20use std::hash::Hash;
21use std::mem::size_of_val;
22use std::sync::Arc;
23
24use arrow::datatypes::FieldRef;
25use arrow::{array::ArrayRef, datatypes::DataType};
26use datafusion_common::ScalarValue;
27use datafusion_common::{not_impl_err, plan_err, Result};
28use datafusion_expr::expr::{AggregateFunction, Sort};
29use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
30use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
31use datafusion_expr::Volatility::Immutable;
32use datafusion_expr::{
33 Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
34};
35use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest};
36use datafusion_macros::user_doc;
37
38use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
39
// Invokes the crate's `create_func!` macro for `ApproxPercentileContWithWeight`.
// The second argument names `approx_percentile_cont_with_weight_udaf`, which the
// expression builder in this file calls to obtain the UDAF handle.
// NOTE(review): the macro body is not visible here — presumably it exposes a
// shared instance of the UDAF; confirm against the macro definition.
create_func!(
    ApproxPercentileContWithWeight,
    approx_percentile_cont_with_weight_udaf
);
44
45pub fn approx_percentile_cont_with_weight(
47 order_by: Sort,
48 weight: Expr,
49 percentile: Expr,
50 centroids: Option<Expr>,
51) -> Expr {
52 let expr = order_by.expr.clone();
53
54 let args = if let Some(centroids) = centroids {
55 vec![expr, weight, percentile, centroids]
56 } else {
57 vec![expr, weight, percentile]
58 };
59
60 Expr::AggregateFunction(AggregateFunction::new_udf(
61 approx_percentile_cont_with_weight_udaf(),
62 args,
63 false,
64 None,
65 vec![order_by],
66 None,
67 ))
68}
69
/// `approx_percentile_cont_with_weight` aggregate function definition.
///
/// Holds its own [`Signature`] and wraps an [`ApproxPercentileCont`] instance
/// that the trait implementation delegates to for accumulator creation and
/// state-field layout.
#[user_doc(
    doc_section(label = "Approximate Functions"),
    description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
    syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
    sql_example = r#"```sql
> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+---------------------------------------------------------------------------------------------+
| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
+---------------------------------------------------------------------------------------------+
| 78.5 |
+---------------------------------------------------------------------------------------------+
> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+--------------------------------------------------------------------------------------------------+
| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) |
+--------------------------------------------------------------------------------------------------+
| 78.5 |
+--------------------------------------------------------------------------------------------------+
```
An alternative syntax is also supported:

```sql
> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name;
+--------------------------------------------------+
| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) |
+--------------------------------------------------+
| 78.5 |
+--------------------------------------------------+
```"#,
    standard_argument(name = "expression", prefix = "The"),
    argument(
        name = "weight",
        description = "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators."
    ),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    ),
    argument(
        name = "centroids",
        description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
    )
)]
#[derive(PartialEq, Eq, Hash)]
pub struct ApproxPercentileContWithWeight {
    /// Accepted argument-type combinations, built in `new`.
    signature: Signature,
    /// Unweighted percentile UDAF that accumulator creation and state fields
    /// are delegated to.
    approx_percentile_cont: ApproxPercentileCont,
}
118
119impl Debug for ApproxPercentileContWithWeight {
120 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
121 f.debug_struct("ApproxPercentileContWithWeight")
122 .field("signature", &self.signature)
123 .finish()
124 }
125}
126
127impl Default for ApproxPercentileContWithWeight {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133impl ApproxPercentileContWithWeight {
134 pub fn new() -> Self {
136 let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
137 for num in NUMERICS {
139 variants.push(TypeSignature::Exact(vec![
140 num.clone(),
141 num.clone(),
142 DataType::Float64,
143 ]));
144 for int in INTEGERS {
146 variants.push(TypeSignature::Exact(vec![
147 num.clone(),
148 num.clone(),
149 DataType::Float64,
150 int.clone(),
151 ]));
152 }
153 }
154 Self {
155 signature: Signature::one_of(variants, Immutable),
156 approx_percentile_cont: ApproxPercentileCont::new(),
157 }
158 }
159}
160
161impl AggregateUDFImpl for ApproxPercentileContWithWeight {
162 fn as_any(&self) -> &dyn Any {
163 self
164 }
165
166 fn name(&self) -> &str {
167 "approx_percentile_cont_with_weight"
168 }
169
170 fn signature(&self) -> &Signature {
171 &self.signature
172 }
173
174 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
175 if !arg_types[0].is_numeric() {
176 return plan_err!(
177 "approx_percentile_cont_with_weight requires numeric input types"
178 );
179 }
180 if !arg_types[1].is_numeric() {
181 return plan_err!(
182 "approx_percentile_cont_with_weight requires numeric weight input types"
183 );
184 }
185 if arg_types[2] != DataType::Float64 {
186 return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types");
187 }
188 if arg_types.len() == 4 && !arg_types[3].is_integer() {
189 return plan_err!(
190 "approx_percentile_cont_with_weight requires integer centroids input types"
191 );
192 }
193 Ok(arg_types[0].clone())
194 }
195
196 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
197 if acc_args.is_distinct {
198 return not_impl_err!(
199 "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available"
200 );
201 }
202
203 if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 {
204 return plan_err!(
205 "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]"
206 );
207 }
208
209 let sub_args = AccumulatorArgs {
210 exprs: if acc_args.exprs.len() == 4 {
211 &[
212 Arc::clone(&acc_args.exprs[0]), Arc::clone(&acc_args.exprs[2]), Arc::clone(&acc_args.exprs[3]), ]
216 } else {
217 &[
218 Arc::clone(&acc_args.exprs[0]), Arc::clone(&acc_args.exprs[2]), ]
221 },
222 ..acc_args
223 };
224 let approx_percentile_cont_accumulator =
225 self.approx_percentile_cont.create_accumulator(sub_args)?;
226 let accumulator = ApproxPercentileWithWeightAccumulator::new(
227 approx_percentile_cont_accumulator,
228 );
229 Ok(Box::new(accumulator))
230 }
231
232 #[allow(rustdoc::private_intra_doc_links)]
233 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
236 self.approx_percentile_cont.state_fields(args)
237 }
238
239 fn supports_null_handling_clause(&self) -> bool {
240 false
241 }
242
243 fn is_ordered_set_aggregate(&self) -> bool {
244 true
245 }
246
247 fn documentation(&self) -> Option<&Documentation> {
248 self.doc()
249 }
250}
251
/// Accumulator for `approx_percentile_cont_with_weight`.
///
/// Wraps an [`ApproxPercentileAccumulator`]; state, evaluation and merging are
/// delegated to it, while batch updates feed it weighted centroids.
#[derive(Debug)]
pub struct ApproxPercentileWithWeightAccumulator {
    /// Inner accumulator that owns the digest state.
    approx_percentile_cont_accumulator: ApproxPercentileAccumulator,
}
256
257impl ApproxPercentileWithWeightAccumulator {
258 pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self {
259 Self {
260 approx_percentile_cont_accumulator,
261 }
262 }
263}
264
265impl Accumulator for ApproxPercentileWithWeightAccumulator {
266 fn state(&mut self) -> Result<Vec<ScalarValue>> {
267 self.approx_percentile_cont_accumulator.state()
268 }
269
270 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
271 let means = &values[0];
272 let weights = &values[1];
273 debug_assert_eq!(
274 means.len(),
275 weights.len(),
276 "invalid number of values in means and weights"
277 );
278 let means_f64 = ApproxPercentileAccumulator::convert_to_float(means)?;
279 let weights_f64 = ApproxPercentileAccumulator::convert_to_float(weights)?;
280 let mut digests: Vec<TDigest> = vec![];
281 for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
282 digests.push(TDigest::new_with_centroid(
283 self.approx_percentile_cont_accumulator.max_size(),
284 Centroid::new(*mean, *weight),
285 ))
286 }
287 self.approx_percentile_cont_accumulator
288 .merge_digests(&digests);
289 Ok(())
290 }
291
292 fn evaluate(&mut self) -> Result<ScalarValue> {
293 self.approx_percentile_cont_accumulator.evaluate()
294 }
295
296 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
297 self.approx_percentile_cont_accumulator
298 .merge_batch(states)?;
299
300 Ok(())
301 }
302
303 fn size(&self) -> usize {
304 size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator)
305 + self.approx_percentile_cont_accumulator.size()
306 }
307}