1use std::any::Any;
23
24use arrow::datatypes::{
25 DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
26 DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
27 DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
28};
29
30use datafusion_common::plan_err;
31use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};
32
33use crate::type_coercion::aggregates::NUMERICS;
34use crate::Volatility::Immutable;
35use crate::{
36 expr::AggregateFunction,
37 function::{AccumulatorArgs, StateFieldsArgs},
38 utils::AggregateOrderSensitivity,
39 Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature,
40};
41
/// Generates a `pub fn $udaf_fn() -> Arc<AggregateUDF>` accessor for the stub
/// type `$udaf`.
///
/// The generated function builds `AggregateUDF::from(<$udaf>::default())` the
/// first time it is called, keeps it in a function-local `LazyLock` static, and
/// returns a clone of that single `Arc` on every call.
macro_rules! create_func {
    ($udaf:ty, $udaf_fn:ident) => {
        paste::paste! {
            #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($udaf), "`]")]
            pub fn $udaf_fn() -> std::sync::Arc<crate::AggregateUDF> {
                static SHARED_UDAF: std::sync::LazyLock<std::sync::Arc<crate::AggregateUDF>> =
                    std::sync::LazyLock::new(|| {
                        let udaf = crate::AggregateUDF::from(<$udaf>::default());
                        std::sync::Arc::new(udaf)
                    });
                std::sync::Arc::clone(&SHARED_UDAF)
            }
        }
    };
}
57
58create_func!(Sum, sum_udaf);
59
60pub fn sum(expr: Expr) -> Expr {
61 Expr::AggregateFunction(AggregateFunction::new_udf(
62 sum_udaf(),
63 vec![expr],
64 false,
65 None,
66 vec![],
67 None,
68 ))
69}
70
71create_func!(Count, count_udaf);
72
73pub fn count(expr: Expr) -> Expr {
74 Expr::AggregateFunction(AggregateFunction::new_udf(
75 count_udaf(),
76 vec![expr],
77 false,
78 None,
79 vec![],
80 None,
81 ))
82}
83
84create_func!(Avg, avg_udaf);
85
86pub fn avg(expr: Expr) -> Expr {
87 Expr::AggregateFunction(AggregateFunction::new_udf(
88 avg_udaf(),
89 vec![expr],
90 false,
91 None,
92 vec![],
93 None,
94 ))
95}
96
97#[derive(Debug, PartialEq, Eq, Hash)]
99pub struct Sum {
100 signature: Signature,
101}
102
103impl Sum {
104 pub fn new() -> Self {
105 Self {
106 signature: Signature::user_defined(Immutable),
107 }
108 }
109}
110
111impl Default for Sum {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117impl AggregateUDFImpl for Sum {
118 fn as_any(&self) -> &dyn Any {
119 self
120 }
121
122 fn name(&self) -> &str {
123 "sum"
124 }
125
126 fn signature(&self) -> &Signature {
127 &self.signature
128 }
129
130 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
131 let [array] = take_function_args(self.name(), arg_types)?;
132
133 fn coerced_type(data_type: &DataType) -> Result<DataType> {
137 match data_type {
138 DataType::Dictionary(_, v) => coerced_type(v),
139 DataType::Decimal32(_, _)
142 | DataType::Decimal64(_, _)
143 | DataType::Decimal128(_, _)
144 | DataType::Decimal256(_, _) => Ok(data_type.clone()),
145 dt if dt.is_signed_integer() => Ok(DataType::Int64),
146 dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
147 dt if dt.is_floating() => Ok(DataType::Float64),
148 _ => exec_err!("Sum not supported for {data_type}"),
149 }
150 }
151
152 Ok(vec![coerced_type(array)?])
153 }
154
155 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
156 match &arg_types[0] {
157 DataType::Int64 => Ok(DataType::Int64),
158 DataType::UInt64 => Ok(DataType::UInt64),
159 DataType::Float64 => Ok(DataType::Float64),
160 DataType::Decimal32(precision, scale) => {
161 let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
164 Ok(DataType::Decimal32(new_precision, *scale))
165 }
166 DataType::Decimal64(precision, scale) => {
167 let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
170 Ok(DataType::Decimal64(new_precision, *scale))
171 }
172 DataType::Decimal128(precision, scale) => {
173 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
176 Ok(DataType::Decimal128(new_precision, *scale))
177 }
178 DataType::Decimal256(precision, scale) => {
179 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
182 Ok(DataType::Decimal256(new_precision, *scale))
183 }
184 other => {
185 exec_err!("[return_type] SUM not supported for {}", other)
186 }
187 }
188 }
189
190 fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
191 unreachable!("stub should not have accumulate()")
192 }
193
194 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
195 unreachable!("stub should not have state_fields()")
196 }
197
198 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
199 false
200 }
201
202 fn create_groups_accumulator(
203 &self,
204 _args: AccumulatorArgs,
205 ) -> Result<Box<dyn GroupsAccumulator>> {
206 unreachable!("stub should not have accumulate()")
207 }
208
209 fn reverse_expr(&self) -> ReversedUDAF {
210 ReversedUDAF::Identical
211 }
212
213 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
214 AggregateOrderSensitivity::Insensitive
215 }
216}
217
218#[derive(PartialEq, Eq, Hash)]
220pub struct Count {
221 signature: Signature,
222 aliases: Vec<String>,
223}
224
225impl std::fmt::Debug for Count {
226 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
227 f.debug_struct("Count")
228 .field("name", &self.name())
229 .field("signature", &self.signature)
230 .finish()
231 }
232}
233
234impl Default for Count {
235 fn default() -> Self {
236 Self::new()
237 }
238}
239
240impl Count {
241 pub fn new() -> Self {
242 Self {
243 aliases: vec!["count".to_string()],
244 signature: Signature::variadic_any(Immutable),
245 }
246 }
247}
248
249impl AggregateUDFImpl for Count {
250 fn as_any(&self) -> &dyn Any {
251 self
252 }
253
254 fn name(&self) -> &str {
255 "COUNT"
256 }
257
258 fn signature(&self) -> &Signature {
259 &self.signature
260 }
261
262 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
263 Ok(DataType::Int64)
264 }
265
266 fn is_nullable(&self) -> bool {
267 false
268 }
269
270 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
271 not_impl_err!("no impl for stub")
272 }
273
274 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
275 not_impl_err!("no impl for stub")
276 }
277
278 fn aliases(&self) -> &[String] {
279 &self.aliases
280 }
281
282 fn create_groups_accumulator(
283 &self,
284 _args: AccumulatorArgs,
285 ) -> Result<Box<dyn GroupsAccumulator>> {
286 not_impl_err!("no impl for stub")
287 }
288
289 fn reverse_expr(&self) -> ReversedUDAF {
290 ReversedUDAF::Identical
291 }
292}
293
294create_func!(Min, min_udaf);
295
296pub fn min(expr: Expr) -> Expr {
297 Expr::AggregateFunction(AggregateFunction::new_udf(
298 min_udaf(),
299 vec![expr],
300 false,
301 None,
302 vec![],
303 None,
304 ))
305}
306
307#[derive(PartialEq, Eq, Hash)]
309pub struct Min {
310 signature: Signature,
311}
312
313impl std::fmt::Debug for Min {
314 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
315 f.debug_struct("Min")
316 .field("name", &self.name())
317 .field("signature", &self.signature)
318 .finish()
319 }
320}
321
322impl Default for Min {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
328impl Min {
329 pub fn new() -> Self {
330 Self {
331 signature: Signature::variadic_any(Immutable),
332 }
333 }
334}
335
336impl AggregateUDFImpl for Min {
337 fn as_any(&self) -> &dyn Any {
338 self
339 }
340
341 fn name(&self) -> &str {
342 "min"
343 }
344
345 fn signature(&self) -> &Signature {
346 &self.signature
347 }
348
349 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
350 Ok(DataType::Int64)
351 }
352
353 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
354 not_impl_err!("no impl for stub")
355 }
356
357 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
358 not_impl_err!("no impl for stub")
359 }
360
361 fn create_groups_accumulator(
362 &self,
363 _args: AccumulatorArgs,
364 ) -> Result<Box<dyn GroupsAccumulator>> {
365 not_impl_err!("no impl for stub")
366 }
367
368 fn reverse_expr(&self) -> ReversedUDAF {
369 ReversedUDAF::Identical
370 }
371 fn is_descending(&self) -> Option<bool> {
372 Some(false)
373 }
374}
375
376create_func!(Max, max_udaf);
377
378pub fn max(expr: Expr) -> Expr {
379 Expr::AggregateFunction(AggregateFunction::new_udf(
380 max_udaf(),
381 vec![expr],
382 false,
383 None,
384 vec![],
385 None,
386 ))
387}
388
389#[derive(PartialEq, Eq, Hash)]
391pub struct Max {
392 signature: Signature,
393}
394
395impl std::fmt::Debug for Max {
396 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
397 f.debug_struct("Max")
398 .field("name", &self.name())
399 .field("signature", &self.signature)
400 .finish()
401 }
402}
403
404impl Default for Max {
405 fn default() -> Self {
406 Self::new()
407 }
408}
409
410impl Max {
411 pub fn new() -> Self {
412 Self {
413 signature: Signature::variadic_any(Immutable),
414 }
415 }
416}
417
418impl AggregateUDFImpl for Max {
419 fn as_any(&self) -> &dyn Any {
420 self
421 }
422
423 fn name(&self) -> &str {
424 "max"
425 }
426
427 fn signature(&self) -> &Signature {
428 &self.signature
429 }
430
431 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
432 Ok(DataType::Int64)
433 }
434
435 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
436 not_impl_err!("no impl for stub")
437 }
438
439 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
440 not_impl_err!("no impl for stub")
441 }
442
443 fn create_groups_accumulator(
444 &self,
445 _args: AccumulatorArgs,
446 ) -> Result<Box<dyn GroupsAccumulator>> {
447 not_impl_err!("no impl for stub")
448 }
449
450 fn reverse_expr(&self) -> ReversedUDAF {
451 ReversedUDAF::Identical
452 }
453 fn is_descending(&self) -> Option<bool> {
454 Some(true)
455 }
456}
457
458#[derive(Debug, PartialEq, Eq, Hash)]
460pub struct Avg {
461 signature: Signature,
462 aliases: Vec<String>,
463}
464
465impl Avg {
466 pub fn new() -> Self {
467 Self {
468 aliases: vec![String::from("mean")],
469 signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable),
470 }
471 }
472}
473
474impl Default for Avg {
475 fn default() -> Self {
476 Self::new()
477 }
478}
479
480impl AggregateUDFImpl for Avg {
481 fn as_any(&self) -> &dyn Any {
482 self
483 }
484
485 fn name(&self) -> &str {
486 "avg"
487 }
488
489 fn signature(&self) -> &Signature {
490 &self.signature
491 }
492
493 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
494 let [args] = take_function_args(self.name(), arg_types)?;
495
496 fn coerced_type(data_type: &DataType) -> Result<DataType> {
499 match &data_type {
500 DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
501 DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
502 DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
503 DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
504 d if d.is_numeric() => Ok(DataType::Float64),
505 DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
506 DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
507 _ => {
508 plan_err!("Avg does not support inputs of type {data_type}.")
509 }
510 }
511 }
512 Ok(vec![coerced_type(args)?])
513 }
514
515 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
516 match &arg_types[0] {
517 DataType::Decimal32(precision, scale) => {
518 let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
521 let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
522 Ok(DataType::Decimal32(new_precision, new_scale))
523 }
524 DataType::Decimal64(precision, scale) => {
525 let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
528 let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
529 Ok(DataType::Decimal64(new_precision, new_scale))
530 }
531 DataType::Decimal128(precision, scale) => {
532 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
535 let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
536 Ok(DataType::Decimal128(new_precision, new_scale))
537 }
538 DataType::Decimal256(precision, scale) => {
539 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
542 let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
543 Ok(DataType::Decimal256(new_precision, new_scale))
544 }
545 DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
546 _ => Ok(DataType::Float64),
547 }
548 }
549
550 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
551 not_impl_err!("no impl for stub")
552 }
553
554 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
555 not_impl_err!("no impl for stub")
556 }
557
558 fn aliases(&self) -> &[String] {
559 &self.aliases
560 }
561}