// datafusion_functions_aggregate/approx_percentile_cont_with_weight.rs
use std::any::Any;
use std::fmt::{Debug, Formatter};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::mem::size_of_val;
use std::sync::Arc;

use arrow::compute::{and, filter, is_not_null};
use arrow::datatypes::FieldRef;
use arrow::{array::ArrayRef, datatypes::DataType};
use datafusion_common::ScalarValue;
use datafusion_common::{not_impl_err, plan_err, Result};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::Volatility::Immutable;
use datafusion_expr::{
    Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature,
};
use datafusion_functions_aggregate_common::tdigest::{
    Centroid, TDigest, DEFAULT_MAX_SIZE,
};
use datafusion_macros::user_doc;

use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
40
// Declares the public entry points for this UDAF. Per the macro's name and
// arguments, this presumably generates an expression builder
// `approx_percentile_cont_with_weight(expression, weight, percentile)` and the
// `approx_percentile_cont_with_weight_udaf` accessor for the shared
// `ApproxPercentileContWithWeight` instance. See the `make_udaf_expr_and_func!`
// definition to confirm.
make_udaf_expr_and_func!(
    ApproxPercentileContWithWeight,
    approx_percentile_cont_with_weight,
    expression weight percentile,
    "Computes the approximate percentile continuous with weight of a set of numbers",
    approx_percentile_cont_with_weight_udaf
);
48
#[user_doc(
    doc_section(label = "Approximate Functions"),
    description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
    syntax_example = "approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression)",
    sql_example = r#"```sql
> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+---------------------------------------------------------------------------------------------+
| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
+---------------------------------------------------------------------------------------------+
| 78.5 |
+---------------------------------------------------------------------------------------------+
```"#,
    standard_argument(name = "expression", prefix = "The"),
    argument(
        name = "weight",
        description = "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators."
    ),
    argument(
        name = "percentile",
        description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
    )
)]
/// Weighted variant of `approx_percentile_cont`, exposed to SQL as
/// `approx_percentile_cont_with_weight`.
///
/// The function takes `(value, weight, percentile)` arguments. Building the inner
/// accumulator and describing the intermediate state are both delegated to the
/// wrapped [`ApproxPercentileCont`].
pub struct ApproxPercentileContWithWeight {
    /// Accepted argument types. The value and the weight share one of the
    /// `NUMERICS` types, and the percentile is `Float64`.
    signature: Signature,
    /// Unweighted implementation. It builds the inner accumulator and supplies
    /// the intermediate state fields.
    approx_percentile_cont: ApproxPercentileCont,
}
76
77impl Debug for ApproxPercentileContWithWeight {
78 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("ApproxPercentileContWithWeight")
80 .field("signature", &self.signature)
81 .finish()
82 }
83}
84
85impl Default for ApproxPercentileContWithWeight {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl ApproxPercentileContWithWeight {
92 pub fn new() -> Self {
94 Self {
95 signature: Signature::one_of(
96 NUMERICS
98 .iter()
99 .map(|t| {
100 TypeSignature::Exact(vec![
101 t.clone(),
102 t.clone(),
103 DataType::Float64,
104 ])
105 })
106 .collect(),
107 Immutable,
108 ),
109 approx_percentile_cont: ApproxPercentileCont::new(),
110 }
111 }
112}
113
114impl AggregateUDFImpl for ApproxPercentileContWithWeight {
115 fn as_any(&self) -> &dyn Any {
116 self
117 }
118
119 fn name(&self) -> &str {
120 "approx_percentile_cont_with_weight"
121 }
122
123 fn signature(&self) -> &Signature {
124 &self.signature
125 }
126
127 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
128 if !arg_types[0].is_numeric() {
129 return plan_err!(
130 "approx_percentile_cont_with_weight requires numeric input types"
131 );
132 }
133 if !arg_types[1].is_numeric() {
134 return plan_err!(
135 "approx_percentile_cont_with_weight requires numeric weight input types"
136 );
137 }
138 if arg_types[2] != DataType::Float64 {
139 return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types");
140 }
141 Ok(arg_types[0].clone())
142 }
143
144 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
145 if acc_args.is_distinct {
146 return not_impl_err!(
147 "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available"
148 );
149 }
150
151 if acc_args.exprs.len() != 3 {
152 return plan_err!(
153 "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile"
154 );
155 }
156
157 let sub_args = AccumulatorArgs {
158 exprs: &[
159 Arc::clone(&acc_args.exprs[0]),
160 Arc::clone(&acc_args.exprs[2]),
161 ],
162 ..acc_args
163 };
164 let approx_percentile_cont_accumulator =
165 self.approx_percentile_cont.create_accumulator(sub_args)?;
166 let accumulator = ApproxPercentileWithWeightAccumulator::new(
167 approx_percentile_cont_accumulator,
168 );
169 Ok(Box::new(accumulator))
170 }
171
172 #[allow(rustdoc::private_intra_doc_links)]
173 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
176 self.approx_percentile_cont.state_fields(args)
177 }
178
179 fn supports_null_handling_clause(&self) -> bool {
180 false
181 }
182
183 fn is_ordered_set_aggregate(&self) -> bool {
184 true
185 }
186
187 fn documentation(&self) -> Option<&Documentation> {
188 self.doc()
189 }
190
191 fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
192 let Some(other) = other.as_any().downcast_ref::<Self>() else {
193 return false;
194 };
195 let Self {
196 signature,
197 approx_percentile_cont,
198 } = self;
199 signature == &other.signature
200 && approx_percentile_cont.equals(&other.approx_percentile_cont)
201 }
202
203 fn hash_value(&self) -> u64 {
204 let Self {
205 signature,
206 approx_percentile_cont,
207 } = self;
208 let mut hasher = DefaultHasher::new();
209 std::any::type_name::<Self>().hash(&mut hasher);
210 signature.hash(&mut hasher);
211 hasher.write_u64(approx_percentile_cont.hash_value());
212 hasher.finish()
213 }
214}
215
/// Accumulator for `approx_percentile_cont_with_weight`.
///
/// Each `(value, weight)` row becomes a single-centroid t-digest, which is
/// merged into the wrapped unweighted accumulator. That accumulator also
/// produces the state, merges partial states and evaluates the final result.
#[derive(Debug)]
pub struct ApproxPercentileWithWeightAccumulator {
    /// Unweighted accumulator. It receives the merged digests and handles
    /// state, merging and evaluation.
    approx_percentile_cont_accumulator: ApproxPercentileAccumulator,
}

impl ApproxPercentileWithWeightAccumulator {
    /// Wraps `approx_percentile_cont_accumulator`.
    ///
    /// In this file, the wrapped accumulator is built by
    /// `ApproxPercentileCont::create_accumulator` from the value and
    /// percentile arguments only.
    pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self {
        Self {
            approx_percentile_cont_accumulator,
        }
    }
}
228
229impl Accumulator for ApproxPercentileWithWeightAccumulator {
230 fn state(&mut self) -> Result<Vec<ScalarValue>> {
231 self.approx_percentile_cont_accumulator.state()
232 }
233
234 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
235 let means = &values[0];
236 let weights = &values[1];
237 debug_assert_eq!(
238 means.len(),
239 weights.len(),
240 "invalid number of values in means and weights"
241 );
242 let means_f64 = ApproxPercentileAccumulator::convert_to_float(means)?;
243 let weights_f64 = ApproxPercentileAccumulator::convert_to_float(weights)?;
244 let mut digests: Vec<TDigest> = vec![];
245 for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
246 digests.push(TDigest::new_with_centroid(
247 DEFAULT_MAX_SIZE,
248 Centroid::new(*mean, *weight),
249 ))
250 }
251 self.approx_percentile_cont_accumulator
252 .merge_digests(&digests);
253 Ok(())
254 }
255
256 fn evaluate(&mut self) -> Result<ScalarValue> {
257 self.approx_percentile_cont_accumulator.evaluate()
258 }
259
260 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
261 self.approx_percentile_cont_accumulator
262 .merge_batch(states)?;
263
264 Ok(())
265 }
266
267 fn size(&self) -> usize {
268 size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator)
269 + self.approx_percentile_cont_accumulator.size()
270 }
271}