1use std::any::Any;
23
24use arrow::datatypes::{
25 DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
26};
27
28use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};
29
30use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS};
31use crate::Volatility::Immutable;
32use crate::{
33 expr::AggregateFunction,
34 function::{AccumulatorArgs, StateFieldsArgs},
35 utils::AggregateOrderSensitivity,
36 Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature,
37};
38
/// Generates a public, zero-argument function named `$getter` that returns a
/// shared [`AggregateUDF`](crate::AggregateUDF) built from
/// `<$udaf_ty>::default()`.
///
/// The UDAF lives in a function-local `LazyLock`, so it is constructed on the
/// first call only; every call returns a fresh `Arc` handle to that instance.
macro_rules! create_func {
    ($udaf_ty:ty, $getter:ident) => {
        paste::paste! {
            #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($udaf_ty), "`]")]
            pub fn $getter() -> std::sync::Arc<crate::AggregateUDF> {
                // Initialized lazily on first access, then reused by later calls.
                static SHARED: std::sync::LazyLock<std::sync::Arc<crate::AggregateUDF>> =
                    std::sync::LazyLock::new(|| {
                        let udaf = crate::AggregateUDF::from(<$udaf_ty>::default());
                        std::sync::Arc::new(udaf)
                    });
                std::sync::Arc::clone(&SHARED)
            }
        }
    };
}
54
55create_func!(Sum, sum_udaf);
56
57pub fn sum(expr: Expr) -> Expr {
58 Expr::AggregateFunction(AggregateFunction::new_udf(
59 sum_udaf(),
60 vec![expr],
61 false,
62 None,
63 vec![],
64 None,
65 ))
66}
67
68create_func!(Count, count_udaf);
69
70pub fn count(expr: Expr) -> Expr {
71 Expr::AggregateFunction(AggregateFunction::new_udf(
72 count_udaf(),
73 vec![expr],
74 false,
75 None,
76 vec![],
77 None,
78 ))
79}
80
81create_func!(Avg, avg_udaf);
82
83pub fn avg(expr: Expr) -> Expr {
84 Expr::AggregateFunction(AggregateFunction::new_udf(
85 avg_udaf(),
86 vec![expr],
87 false,
88 None,
89 vec![],
90 None,
91 ))
92}
93
94#[derive(Debug, PartialEq, Eq, Hash)]
96pub struct Sum {
97 signature: Signature,
98}
99
100impl Sum {
101 pub fn new() -> Self {
102 Self {
103 signature: Signature::user_defined(Immutable),
104 }
105 }
106}
107
108impl Default for Sum {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl AggregateUDFImpl for Sum {
115 fn as_any(&self) -> &dyn Any {
116 self
117 }
118
119 fn name(&self) -> &str {
120 "sum"
121 }
122
123 fn signature(&self) -> &Signature {
124 &self.signature
125 }
126
127 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
128 let [array] = take_function_args(self.name(), arg_types)?;
129
130 fn coerced_type(data_type: &DataType) -> Result<DataType> {
134 match data_type {
135 DataType::Dictionary(_, v) => coerced_type(v),
136 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
139 Ok(data_type.clone())
140 }
141 dt if dt.is_signed_integer() => Ok(DataType::Int64),
142 dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
143 dt if dt.is_floating() => Ok(DataType::Float64),
144 _ => exec_err!("Sum not supported for {}", data_type),
145 }
146 }
147
148 Ok(vec![coerced_type(array)?])
149 }
150
151 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
152 match &arg_types[0] {
153 DataType::Int64 => Ok(DataType::Int64),
154 DataType::UInt64 => Ok(DataType::UInt64),
155 DataType::Float64 => Ok(DataType::Float64),
156 DataType::Decimal128(precision, scale) => {
157 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
160 Ok(DataType::Decimal128(new_precision, *scale))
161 }
162 DataType::Decimal256(precision, scale) => {
163 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
166 Ok(DataType::Decimal256(new_precision, *scale))
167 }
168 other => {
169 exec_err!("[return_type] SUM not supported for {}", other)
170 }
171 }
172 }
173
174 fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
175 unreachable!("stub should not have accumulate()")
176 }
177
178 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
179 unreachable!("stub should not have state_fields()")
180 }
181
182 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
183 false
184 }
185
186 fn create_groups_accumulator(
187 &self,
188 _args: AccumulatorArgs,
189 ) -> Result<Box<dyn GroupsAccumulator>> {
190 unreachable!("stub should not have accumulate()")
191 }
192
193 fn reverse_expr(&self) -> ReversedUDAF {
194 ReversedUDAF::Identical
195 }
196
197 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
198 AggregateOrderSensitivity::Insensitive
199 }
200}
201
202#[derive(PartialEq, Eq, Hash)]
204pub struct Count {
205 signature: Signature,
206 aliases: Vec<String>,
207}
208
209impl std::fmt::Debug for Count {
210 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
211 f.debug_struct("Count")
212 .field("name", &self.name())
213 .field("signature", &self.signature)
214 .finish()
215 }
216}
217
218impl Default for Count {
219 fn default() -> Self {
220 Self::new()
221 }
222}
223
224impl Count {
225 pub fn new() -> Self {
226 Self {
227 aliases: vec!["count".to_string()],
228 signature: Signature::variadic_any(Immutable),
229 }
230 }
231}
232
233impl AggregateUDFImpl for Count {
234 fn as_any(&self) -> &dyn Any {
235 self
236 }
237
238 fn name(&self) -> &str {
239 "COUNT"
240 }
241
242 fn signature(&self) -> &Signature {
243 &self.signature
244 }
245
246 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
247 Ok(DataType::Int64)
248 }
249
250 fn is_nullable(&self) -> bool {
251 false
252 }
253
254 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
255 not_impl_err!("no impl for stub")
256 }
257
258 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
259 not_impl_err!("no impl for stub")
260 }
261
262 fn aliases(&self) -> &[String] {
263 &self.aliases
264 }
265
266 fn create_groups_accumulator(
267 &self,
268 _args: AccumulatorArgs,
269 ) -> Result<Box<dyn GroupsAccumulator>> {
270 not_impl_err!("no impl for stub")
271 }
272
273 fn reverse_expr(&self) -> ReversedUDAF {
274 ReversedUDAF::Identical
275 }
276}
277
278create_func!(Min, min_udaf);
279
280pub fn min(expr: Expr) -> Expr {
281 Expr::AggregateFunction(AggregateFunction::new_udf(
282 min_udaf(),
283 vec![expr],
284 false,
285 None,
286 vec![],
287 None,
288 ))
289}
290
291#[derive(PartialEq, Eq, Hash)]
293pub struct Min {
294 signature: Signature,
295}
296
297impl std::fmt::Debug for Min {
298 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
299 f.debug_struct("Min")
300 .field("name", &self.name())
301 .field("signature", &self.signature)
302 .finish()
303 }
304}
305
306impl Default for Min {
307 fn default() -> Self {
308 Self::new()
309 }
310}
311
312impl Min {
313 pub fn new() -> Self {
314 Self {
315 signature: Signature::variadic_any(Immutable),
316 }
317 }
318}
319
320impl AggregateUDFImpl for Min {
321 fn as_any(&self) -> &dyn Any {
322 self
323 }
324
325 fn name(&self) -> &str {
326 "min"
327 }
328
329 fn signature(&self) -> &Signature {
330 &self.signature
331 }
332
333 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
334 Ok(DataType::Int64)
335 }
336
337 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
338 not_impl_err!("no impl for stub")
339 }
340
341 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
342 not_impl_err!("no impl for stub")
343 }
344
345 fn create_groups_accumulator(
346 &self,
347 _args: AccumulatorArgs,
348 ) -> Result<Box<dyn GroupsAccumulator>> {
349 not_impl_err!("no impl for stub")
350 }
351
352 fn reverse_expr(&self) -> ReversedUDAF {
353 ReversedUDAF::Identical
354 }
355 fn is_descending(&self) -> Option<bool> {
356 Some(false)
357 }
358}
359
360create_func!(Max, max_udaf);
361
362pub fn max(expr: Expr) -> Expr {
363 Expr::AggregateFunction(AggregateFunction::new_udf(
364 max_udaf(),
365 vec![expr],
366 false,
367 None,
368 vec![],
369 None,
370 ))
371}
372
373#[derive(PartialEq, Eq, Hash)]
375pub struct Max {
376 signature: Signature,
377}
378
379impl std::fmt::Debug for Max {
380 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
381 f.debug_struct("Max")
382 .field("name", &self.name())
383 .field("signature", &self.signature)
384 .finish()
385 }
386}
387
388impl Default for Max {
389 fn default() -> Self {
390 Self::new()
391 }
392}
393
394impl Max {
395 pub fn new() -> Self {
396 Self {
397 signature: Signature::variadic_any(Immutable),
398 }
399 }
400}
401
402impl AggregateUDFImpl for Max {
403 fn as_any(&self) -> &dyn Any {
404 self
405 }
406
407 fn name(&self) -> &str {
408 "max"
409 }
410
411 fn signature(&self) -> &Signature {
412 &self.signature
413 }
414
415 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
416 Ok(DataType::Int64)
417 }
418
419 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
420 not_impl_err!("no impl for stub")
421 }
422
423 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
424 not_impl_err!("no impl for stub")
425 }
426
427 fn create_groups_accumulator(
428 &self,
429 _args: AccumulatorArgs,
430 ) -> Result<Box<dyn GroupsAccumulator>> {
431 not_impl_err!("no impl for stub")
432 }
433
434 fn reverse_expr(&self) -> ReversedUDAF {
435 ReversedUDAF::Identical
436 }
437 fn is_descending(&self) -> Option<bool> {
438 Some(true)
439 }
440}
441
442#[derive(Debug, PartialEq, Eq, Hash)]
444pub struct Avg {
445 signature: Signature,
446 aliases: Vec<String>,
447}
448
449impl Avg {
450 pub fn new() -> Self {
451 Self {
452 aliases: vec![String::from("mean")],
453 signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable),
454 }
455 }
456}
457
458impl Default for Avg {
459 fn default() -> Self {
460 Self::new()
461 }
462}
463
464impl AggregateUDFImpl for Avg {
465 fn as_any(&self) -> &dyn Any {
466 self
467 }
468
469 fn name(&self) -> &str {
470 "avg"
471 }
472
473 fn signature(&self) -> &Signature {
474 &self.signature
475 }
476
477 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
478 avg_return_type(self.name(), &arg_types[0])
479 }
480
481 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
482 not_impl_err!("no impl for stub")
483 }
484
485 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
486 not_impl_err!("no impl for stub")
487 }
488
489 fn aliases(&self) -> &[String] {
490 &self.aliases
491 }
492
493 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
494 coerce_avg_type(self.name(), arg_types)
495 }
496}