1use std::any::Any;
23
24use arrow::datatypes::{
25 DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
26};
27
28use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};
29
30use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS};
31use crate::Volatility::Immutable;
32use crate::{
33 expr::AggregateFunction,
34 function::{AccumulatorArgs, StateFieldsArgs},
35 utils::AggregateOrderSensitivity,
36 Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature,
37};
38
/// Defines a public function `$AGGREGATE_UDF_FN()` that returns a shared
/// [`AggregateUDF`](crate::AggregateUDF) wrapping `<$UDAF>::default()`.
///
/// The `AggregateUDF` is built lazily on the first call and cached in a
/// function-local `static`; every call returns an `Arc::clone` of that single
/// instance.
macro_rules! create_func {
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
        #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")]
        pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<crate::AggregateUDF> {
            static INSTANCE: std::sync::LazyLock<std::sync::Arc<crate::AggregateUDF>> =
                std::sync::LazyLock::new(|| {
                    std::sync::Arc::new(crate::AggregateUDF::from(<$UDAF>::default()))
                });
            std::sync::Arc::clone(&INSTANCE)
        }
    };
}
54
55create_func!(Sum, sum_udaf);
56
57pub fn sum(expr: Expr) -> Expr {
58 Expr::AggregateFunction(AggregateFunction::new_udf(
59 sum_udaf(),
60 vec![expr],
61 false,
62 None,
63 vec![],
64 None,
65 ))
66}
67
68create_func!(Count, count_udaf);
69
70pub fn count(expr: Expr) -> Expr {
71 Expr::AggregateFunction(AggregateFunction::new_udf(
72 count_udaf(),
73 vec![expr],
74 false,
75 None,
76 vec![],
77 None,
78 ))
79}
80
81create_func!(Avg, avg_udaf);
82
83pub fn avg(expr: Expr) -> Expr {
84 Expr::AggregateFunction(AggregateFunction::new_udf(
85 avg_udaf(),
86 vec![expr],
87 false,
88 None,
89 vec![],
90 None,
91 ))
92}
93
94#[derive(Debug)]
96pub struct Sum {
97 signature: Signature,
98}
99
100impl Sum {
101 pub fn new() -> Self {
102 Self {
103 signature: Signature::user_defined(Immutable),
104 }
105 }
106}
107
108impl Default for Sum {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl AggregateUDFImpl for Sum {
115 fn as_any(&self) -> &dyn Any {
116 self
117 }
118
119 fn name(&self) -> &str {
120 "sum"
121 }
122
123 fn signature(&self) -> &Signature {
124 &self.signature
125 }
126
127 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
128 let [array] = take_function_args(self.name(), arg_types)?;
129
130 fn coerced_type(data_type: &DataType) -> Result<DataType> {
134 match data_type {
135 DataType::Dictionary(_, v) => coerced_type(v),
136 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
139 Ok(data_type.clone())
140 }
141 dt if dt.is_signed_integer() => Ok(DataType::Int64),
142 dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
143 dt if dt.is_floating() => Ok(DataType::Float64),
144 _ => exec_err!("Sum not supported for {}", data_type),
145 }
146 }
147
148 Ok(vec![coerced_type(array)?])
149 }
150
151 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
152 match &arg_types[0] {
153 DataType::Int64 => Ok(DataType::Int64),
154 DataType::UInt64 => Ok(DataType::UInt64),
155 DataType::Float64 => Ok(DataType::Float64),
156 DataType::Decimal128(precision, scale) => {
157 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
160 Ok(DataType::Decimal128(new_precision, *scale))
161 }
162 DataType::Decimal256(precision, scale) => {
163 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
166 Ok(DataType::Decimal256(new_precision, *scale))
167 }
168 other => {
169 exec_err!("[return_type] SUM not supported for {}", other)
170 }
171 }
172 }
173
174 fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
175 unreachable!("stub should not have accumulate()")
176 }
177
178 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
179 unreachable!("stub should not have state_fields()")
180 }
181
182 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
183 false
184 }
185
186 fn create_groups_accumulator(
187 &self,
188 _args: AccumulatorArgs,
189 ) -> Result<Box<dyn GroupsAccumulator>> {
190 unreachable!("stub should not have accumulate()")
191 }
192
193 fn reverse_expr(&self) -> ReversedUDAF {
194 ReversedUDAF::Identical
195 }
196
197 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
198 AggregateOrderSensitivity::Insensitive
199 }
200}
201
202pub struct Count {
204 signature: Signature,
205 aliases: Vec<String>,
206}
207
208impl std::fmt::Debug for Count {
209 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
210 f.debug_struct("Count")
211 .field("name", &self.name())
212 .field("signature", &self.signature)
213 .finish()
214 }
215}
216
217impl Default for Count {
218 fn default() -> Self {
219 Self::new()
220 }
221}
222
223impl Count {
224 pub fn new() -> Self {
225 Self {
226 aliases: vec!["count".to_string()],
227 signature: Signature::variadic_any(Immutable),
228 }
229 }
230}
231
232impl AggregateUDFImpl for Count {
233 fn as_any(&self) -> &dyn Any {
234 self
235 }
236
237 fn name(&self) -> &str {
238 "COUNT"
239 }
240
241 fn signature(&self) -> &Signature {
242 &self.signature
243 }
244
245 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
246 Ok(DataType::Int64)
247 }
248
249 fn is_nullable(&self) -> bool {
250 false
251 }
252
253 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
254 not_impl_err!("no impl for stub")
255 }
256
257 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
258 not_impl_err!("no impl for stub")
259 }
260
261 fn aliases(&self) -> &[String] {
262 &self.aliases
263 }
264
265 fn create_groups_accumulator(
266 &self,
267 _args: AccumulatorArgs,
268 ) -> Result<Box<dyn GroupsAccumulator>> {
269 not_impl_err!("no impl for stub")
270 }
271
272 fn reverse_expr(&self) -> ReversedUDAF {
273 ReversedUDAF::Identical
274 }
275}
276
277create_func!(Min, min_udaf);
278
279pub fn min(expr: Expr) -> Expr {
280 Expr::AggregateFunction(AggregateFunction::new_udf(
281 min_udaf(),
282 vec![expr],
283 false,
284 None,
285 vec![],
286 None,
287 ))
288}
289
290pub struct Min {
292 signature: Signature,
293}
294
295impl std::fmt::Debug for Min {
296 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
297 f.debug_struct("Min")
298 .field("name", &self.name())
299 .field("signature", &self.signature)
300 .finish()
301 }
302}
303
304impl Default for Min {
305 fn default() -> Self {
306 Self::new()
307 }
308}
309
310impl Min {
311 pub fn new() -> Self {
312 Self {
313 signature: Signature::variadic_any(Immutable),
314 }
315 }
316}
317
318impl AggregateUDFImpl for Min {
319 fn as_any(&self) -> &dyn Any {
320 self
321 }
322
323 fn name(&self) -> &str {
324 "min"
325 }
326
327 fn signature(&self) -> &Signature {
328 &self.signature
329 }
330
331 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
332 Ok(DataType::Int64)
333 }
334
335 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
336 not_impl_err!("no impl for stub")
337 }
338
339 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
340 not_impl_err!("no impl for stub")
341 }
342
343 fn create_groups_accumulator(
344 &self,
345 _args: AccumulatorArgs,
346 ) -> Result<Box<dyn GroupsAccumulator>> {
347 not_impl_err!("no impl for stub")
348 }
349
350 fn reverse_expr(&self) -> ReversedUDAF {
351 ReversedUDAF::Identical
352 }
353 fn is_descending(&self) -> Option<bool> {
354 Some(false)
355 }
356}
357
358create_func!(Max, max_udaf);
359
360pub fn max(expr: Expr) -> Expr {
361 Expr::AggregateFunction(AggregateFunction::new_udf(
362 max_udaf(),
363 vec![expr],
364 false,
365 None,
366 vec![],
367 None,
368 ))
369}
370
371pub struct Max {
373 signature: Signature,
374}
375
376impl std::fmt::Debug for Max {
377 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
378 f.debug_struct("Max")
379 .field("name", &self.name())
380 .field("signature", &self.signature)
381 .finish()
382 }
383}
384
385impl Default for Max {
386 fn default() -> Self {
387 Self::new()
388 }
389}
390
391impl Max {
392 pub fn new() -> Self {
393 Self {
394 signature: Signature::variadic_any(Immutable),
395 }
396 }
397}
398
399impl AggregateUDFImpl for Max {
400 fn as_any(&self) -> &dyn Any {
401 self
402 }
403
404 fn name(&self) -> &str {
405 "max"
406 }
407
408 fn signature(&self) -> &Signature {
409 &self.signature
410 }
411
412 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
413 Ok(DataType::Int64)
414 }
415
416 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
417 not_impl_err!("no impl for stub")
418 }
419
420 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
421 not_impl_err!("no impl for stub")
422 }
423
424 fn create_groups_accumulator(
425 &self,
426 _args: AccumulatorArgs,
427 ) -> Result<Box<dyn GroupsAccumulator>> {
428 not_impl_err!("no impl for stub")
429 }
430
431 fn reverse_expr(&self) -> ReversedUDAF {
432 ReversedUDAF::Identical
433 }
434 fn is_descending(&self) -> Option<bool> {
435 Some(true)
436 }
437}
438
439#[derive(Debug)]
441pub struct Avg {
442 signature: Signature,
443 aliases: Vec<String>,
444}
445
446impl Avg {
447 pub fn new() -> Self {
448 Self {
449 aliases: vec![String::from("mean")],
450 signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable),
451 }
452 }
453}
454
455impl Default for Avg {
456 fn default() -> Self {
457 Self::new()
458 }
459}
460
461impl AggregateUDFImpl for Avg {
462 fn as_any(&self) -> &dyn Any {
463 self
464 }
465
466 fn name(&self) -> &str {
467 "avg"
468 }
469
470 fn signature(&self) -> &Signature {
471 &self.signature
472 }
473
474 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
475 avg_return_type(self.name(), &arg_types[0])
476 }
477
478 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
479 not_impl_err!("no impl for stub")
480 }
481
482 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
483 not_impl_err!("no impl for stub")
484 }
485
486 fn aliases(&self) -> &[String] {
487 &self.aliases
488 }
489
490 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
491 coerce_avg_type(self.name(), arg_types)
492 }
493}