1use arrow::datatypes::{
23 DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION,
24 DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
25 DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType, FieldRef,
26};
27
28use datafusion_common::plan_err;
29use datafusion_common::{Result, exec_err, not_impl_err, utils::take_function_args};
30
31use crate::Volatility::Immutable;
32use crate::{
33 Accumulator, AggregateUDFImpl, Coercion, Expr, GroupsAccumulator, ReversedUDAF,
34 Signature, TypeSignature, TypeSignatureClass,
35 expr::AggregateFunction,
36 function::{AccumulatorArgs, StateFieldsArgs},
37 utils::AggregateOrderSensitivity,
38};
39use datafusion_common::types::{NativeType, logical_float64};
40
/// Defines a public accessor `$AGGREGATE_UDF_FN()` that returns a shared
/// [`AggregateUDF`](crate::AggregateUDF) wrapping `<$UDAF>::default()`.
///
/// The UDF is built lazily on the first call. Every call then returns a new
/// `Arc` handle to that same instance.
macro_rules! create_func {
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
        #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")]
        pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<crate::AggregateUDF> {
            use std::sync::{Arc, LazyLock};

            // One instance per generated accessor, constructed on first use.
            static SHARED_UDF: LazyLock<Arc<crate::AggregateUDF>> = LazyLock::new(|| {
                let udaf = <$UDAF>::default();
                Arc::new(crate::AggregateUDF::from(udaf))
            });

            Arc::clone(&SHARED_UDF)
        }
    };
}
54
55create_func!(Sum, sum_udaf);
56
57pub fn sum(expr: Expr) -> Expr {
58 Expr::AggregateFunction(AggregateFunction::new_udf(
59 sum_udaf(),
60 vec![expr],
61 false,
62 None,
63 vec![],
64 None,
65 ))
66}
67
68create_func!(Count, count_udaf);
69
70pub fn count(expr: Expr) -> Expr {
71 Expr::AggregateFunction(AggregateFunction::new_udf(
72 count_udaf(),
73 vec![expr],
74 false,
75 None,
76 vec![],
77 None,
78 ))
79}
80
81create_func!(Avg, avg_udaf);
82
83pub fn avg(expr: Expr) -> Expr {
84 Expr::AggregateFunction(AggregateFunction::new_udf(
85 avg_udaf(),
86 vec![expr],
87 false,
88 None,
89 vec![],
90 None,
91 ))
92}
93
94#[derive(Debug, PartialEq, Eq, Hash)]
96pub struct Sum {
97 signature: Signature,
98}
99
100impl Sum {
101 pub fn new() -> Self {
102 Self {
103 signature: Signature::user_defined(Immutable),
104 }
105 }
106}
107
108impl Default for Sum {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl AggregateUDFImpl for Sum {
115 fn name(&self) -> &str {
116 "sum"
117 }
118
119 fn signature(&self) -> &Signature {
120 &self.signature
121 }
122
123 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
124 let [array] = take_function_args(self.name(), arg_types)?;
125
126 fn coerced_type(data_type: &DataType) -> Result<DataType> {
130 match data_type {
131 DataType::Dictionary(_, v) => coerced_type(v),
132 DataType::Decimal32(_, _)
135 | DataType::Decimal64(_, _)
136 | DataType::Decimal128(_, _)
137 | DataType::Decimal256(_, _) => Ok(data_type.clone()),
138 dt if dt.is_signed_integer() => Ok(DataType::Int64),
139 dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
140 dt if dt.is_floating() => Ok(DataType::Float64),
141 _ => exec_err!("Sum not supported for {data_type}"),
142 }
143 }
144
145 Ok(vec![coerced_type(array)?])
146 }
147
148 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
149 match &arg_types[0] {
150 DataType::Int64 => Ok(DataType::Int64),
151 DataType::UInt64 => Ok(DataType::UInt64),
152 DataType::Float64 => Ok(DataType::Float64),
153 DataType::Decimal32(precision, scale) => {
154 let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
157 Ok(DataType::Decimal32(new_precision, *scale))
158 }
159 DataType::Decimal64(precision, scale) => {
160 let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
163 Ok(DataType::Decimal64(new_precision, *scale))
164 }
165 DataType::Decimal128(precision, scale) => {
166 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
169 Ok(DataType::Decimal128(new_precision, *scale))
170 }
171 DataType::Decimal256(precision, scale) => {
172 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
175 Ok(DataType::Decimal256(new_precision, *scale))
176 }
177 other => {
178 exec_err!("[return_type] SUM not supported for {}", other)
179 }
180 }
181 }
182
183 fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
184 unreachable!("stub should not have accumulate()")
185 }
186
187 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
188 unreachable!("stub should not have state_fields()")
189 }
190
191 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
192 false
193 }
194
195 fn create_groups_accumulator(
196 &self,
197 _args: AccumulatorArgs,
198 ) -> Result<Box<dyn GroupsAccumulator>> {
199 unreachable!("stub should not have accumulate()")
200 }
201
202 fn reverse_expr(&self) -> ReversedUDAF {
203 ReversedUDAF::Identical
204 }
205
206 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
207 AggregateOrderSensitivity::Insensitive
208 }
209}
210
211#[derive(PartialEq, Eq, Hash)]
213pub struct Count {
214 signature: Signature,
215 aliases: Vec<String>,
216}
217
218impl std::fmt::Debug for Count {
219 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
220 f.debug_struct("Count")
221 .field("name", &self.name())
222 .field("signature", &self.signature)
223 .finish()
224 }
225}
226
227impl Default for Count {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
233impl Count {
234 pub fn new() -> Self {
235 Self {
236 aliases: vec!["count".to_string()],
237 signature: Signature::variadic_any(Immutable),
238 }
239 }
240}
241
242impl AggregateUDFImpl for Count {
243 fn name(&self) -> &str {
244 "COUNT"
245 }
246
247 fn signature(&self) -> &Signature {
248 &self.signature
249 }
250
251 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
252 Ok(DataType::Int64)
253 }
254
255 fn is_nullable(&self) -> bool {
256 false
257 }
258
259 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
260 not_impl_err!("no impl for stub")
261 }
262
263 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
264 not_impl_err!("no impl for stub")
265 }
266
267 fn aliases(&self) -> &[String] {
268 &self.aliases
269 }
270
271 fn create_groups_accumulator(
272 &self,
273 _args: AccumulatorArgs,
274 ) -> Result<Box<dyn GroupsAccumulator>> {
275 not_impl_err!("no impl for stub")
276 }
277
278 fn reverse_expr(&self) -> ReversedUDAF {
279 ReversedUDAF::Identical
280 }
281}
282
283create_func!(Min, min_udaf);
284
285pub fn min(expr: Expr) -> Expr {
286 Expr::AggregateFunction(AggregateFunction::new_udf(
287 min_udaf(),
288 vec![expr],
289 false,
290 None,
291 vec![],
292 None,
293 ))
294}
295
296#[derive(PartialEq, Eq, Hash)]
298pub struct Min {
299 signature: Signature,
300}
301
302impl std::fmt::Debug for Min {
303 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
304 f.debug_struct("Min")
305 .field("name", &self.name())
306 .field("signature", &self.signature)
307 .finish()
308 }
309}
310
311impl Default for Min {
312 fn default() -> Self {
313 Self::new()
314 }
315}
316
317impl Min {
318 pub fn new() -> Self {
319 Self {
320 signature: Signature::variadic_any(Immutable),
321 }
322 }
323}
324
325impl AggregateUDFImpl for Min {
326 fn name(&self) -> &str {
327 "min"
328 }
329
330 fn signature(&self) -> &Signature {
331 &self.signature
332 }
333
334 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
335 Ok(DataType::Int64)
336 }
337
338 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
339 not_impl_err!("no impl for stub")
340 }
341
342 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
343 not_impl_err!("no impl for stub")
344 }
345
346 fn create_groups_accumulator(
347 &self,
348 _args: AccumulatorArgs,
349 ) -> Result<Box<dyn GroupsAccumulator>> {
350 not_impl_err!("no impl for stub")
351 }
352
353 fn reverse_expr(&self) -> ReversedUDAF {
354 ReversedUDAF::Identical
355 }
356 fn is_descending(&self) -> Option<bool> {
357 Some(false)
358 }
359}
360
361create_func!(Max, max_udaf);
362
363pub fn max(expr: Expr) -> Expr {
364 Expr::AggregateFunction(AggregateFunction::new_udf(
365 max_udaf(),
366 vec![expr],
367 false,
368 None,
369 vec![],
370 None,
371 ))
372}
373
374#[derive(PartialEq, Eq, Hash)]
376pub struct Max {
377 signature: Signature,
378}
379
380impl std::fmt::Debug for Max {
381 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
382 f.debug_struct("Max")
383 .field("name", &self.name())
384 .field("signature", &self.signature)
385 .finish()
386 }
387}
388
389impl Default for Max {
390 fn default() -> Self {
391 Self::new()
392 }
393}
394
395impl Max {
396 pub fn new() -> Self {
397 Self {
398 signature: Signature::variadic_any(Immutable),
399 }
400 }
401}
402
403impl AggregateUDFImpl for Max {
404 fn name(&self) -> &str {
405 "max"
406 }
407
408 fn signature(&self) -> &Signature {
409 &self.signature
410 }
411
412 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
413 Ok(DataType::Int64)
414 }
415
416 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
417 not_impl_err!("no impl for stub")
418 }
419
420 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
421 not_impl_err!("no impl for stub")
422 }
423
424 fn create_groups_accumulator(
425 &self,
426 _args: AccumulatorArgs,
427 ) -> Result<Box<dyn GroupsAccumulator>> {
428 not_impl_err!("no impl for stub")
429 }
430
431 fn reverse_expr(&self) -> ReversedUDAF {
432 ReversedUDAF::Identical
433 }
434 fn is_descending(&self) -> Option<bool> {
435 Some(true)
436 }
437}
438
439#[derive(Debug, PartialEq, Eq, Hash)]
441pub struct Avg {
442 signature: Signature,
443 aliases: Vec<String>,
444}
445
446impl Avg {
447 pub fn new() -> Self {
448 let signature = Signature::one_of(
449 vec![
450 TypeSignature::Coercible(vec![Coercion::new_exact(
451 TypeSignatureClass::Decimal,
452 )]),
453 TypeSignature::Coercible(vec![Coercion::new_implicit(
454 TypeSignatureClass::Native(logical_float64()),
455 vec![TypeSignatureClass::Integer, TypeSignatureClass::Float],
456 NativeType::Float64,
457 )]),
458 ],
459 Immutable,
460 );
461 Self {
462 aliases: vec![String::from("mean")],
463 signature,
464 }
465 }
466}
467
468impl Default for Avg {
469 fn default() -> Self {
470 Self::new()
471 }
472}
473
474impl AggregateUDFImpl for Avg {
475 fn name(&self) -> &str {
476 "avg"
477 }
478
479 fn signature(&self) -> &Signature {
480 &self.signature
481 }
482
483 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
484 let [args] = take_function_args(self.name(), arg_types)?;
485
486 fn coerced_type(data_type: &DataType) -> Result<DataType> {
489 match &data_type {
490 DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
491 DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
492 DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
493 DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
494 d if d.is_numeric() => Ok(DataType::Float64),
495 DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
496 DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
497 _ => {
498 plan_err!("Avg does not support inputs of type {data_type}.")
499 }
500 }
501 }
502 Ok(vec![coerced_type(args)?])
503 }
504
505 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
506 match &arg_types[0] {
507 DataType::Decimal32(precision, scale) => {
508 let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
511 let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
512 Ok(DataType::Decimal32(new_precision, new_scale))
513 }
514 DataType::Decimal64(precision, scale) => {
515 let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
518 let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
519 Ok(DataType::Decimal64(new_precision, new_scale))
520 }
521 DataType::Decimal128(precision, scale) => {
522 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
525 let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
526 Ok(DataType::Decimal128(new_precision, new_scale))
527 }
528 DataType::Decimal256(precision, scale) => {
529 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
532 let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
533 Ok(DataType::Decimal256(new_precision, new_scale))
534 }
535 DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
536 _ => Ok(DataType::Float64),
537 }
538 }
539
540 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
541 not_impl_err!("no impl for stub")
542 }
543
544 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
545 not_impl_err!("no impl for stub")
546 }
547
548 fn aliases(&self) -> &[String] {
549 &self.aliases
550 }
551}