datafusion_expr/test/
function_stub.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Aggregate function stubs for tests in expr / optimizer.
//!
//! These are used to avoid a dependence on `datafusion-functions-aggregate`, which lives in a different crate.
21
22use std::any::Any;
23
24use arrow::datatypes::{
25    DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
26    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
27    DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
28};
29
30use datafusion_common::plan_err;
31use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};
32
33use crate::type_coercion::aggregates::NUMERICS;
34use crate::Volatility::Immutable;
35use crate::{
36    expr::AggregateFunction,
37    function::{AccumulatorArgs, StateFieldsArgs},
38    utils::AggregateOrderSensitivity,
39    Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature,
40};
41
/// Generates a public accessor function named `$AGGREGATE_UDF_FN` that returns a
/// shared [`AggregateUDF`](crate::AggregateUDF) built from `<$UDAF>::default()`.
///
/// The instance is stored in a function-local `LazyLock` static, so it is
/// constructed on first access and every call hands out a clone of the same `Arc`.
macro_rules! create_func {
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
        paste::paste! {
            #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")]
            pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<crate::AggregateUDF> {
                use std::sync::{Arc, LazyLock};

                // One lazily-initialised instance per generated function, so the
                // UDAF is only ever created once.
                static INSTANCE: LazyLock<Arc<crate::AggregateUDF>> = LazyLock::new(|| {
                    let udaf_impl = <$UDAF>::default();
                    Arc::new(crate::AggregateUDF::from(udaf_impl))
                });

                Arc::clone(&INSTANCE)
            }
        }
    };
}
57
58create_func!(Sum, sum_udaf);
59
60pub fn sum(expr: Expr) -> Expr {
61    Expr::AggregateFunction(AggregateFunction::new_udf(
62        sum_udaf(),
63        vec![expr],
64        false,
65        None,
66        vec![],
67        None,
68    ))
69}
70
71create_func!(Count, count_udaf);
72
73pub fn count(expr: Expr) -> Expr {
74    Expr::AggregateFunction(AggregateFunction::new_udf(
75        count_udaf(),
76        vec![expr],
77        false,
78        None,
79        vec![],
80        None,
81    ))
82}
83
84create_func!(Avg, avg_udaf);
85
86pub fn avg(expr: Expr) -> Expr {
87    Expr::AggregateFunction(AggregateFunction::new_udf(
88        avg_udaf(),
89        vec![expr],
90        false,
91        None,
92        vec![],
93        None,
94    ))
95}
96
97/// Stub `sum` used for optimizer testing
98#[derive(Debug, PartialEq, Eq, Hash)]
99pub struct Sum {
100    signature: Signature,
101}
102
103impl Sum {
104    pub fn new() -> Self {
105        Self {
106            signature: Signature::user_defined(Immutable),
107        }
108    }
109}
110
111impl Default for Sum {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117impl AggregateUDFImpl for Sum {
118    fn as_any(&self) -> &dyn Any {
119        self
120    }
121
122    fn name(&self) -> &str {
123        "sum"
124    }
125
126    fn signature(&self) -> &Signature {
127        &self.signature
128    }
129
130    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
131        let [array] = take_function_args(self.name(), arg_types)?;
132
133        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
134        // smallint, int, bigint, real, double precision, decimal, or interval.
135
136        fn coerced_type(data_type: &DataType) -> Result<DataType> {
137            match data_type {
138                DataType::Dictionary(_, v) => coerced_type(v),
139                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
140                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
141                DataType::Decimal32(_, _)
142                | DataType::Decimal64(_, _)
143                | DataType::Decimal128(_, _)
144                | DataType::Decimal256(_, _) => Ok(data_type.clone()),
145                dt if dt.is_signed_integer() => Ok(DataType::Int64),
146                dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
147                dt if dt.is_floating() => Ok(DataType::Float64),
148                _ => exec_err!("Sum not supported for {data_type}"),
149            }
150        }
151
152        Ok(vec![coerced_type(array)?])
153    }
154
155    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
156        match &arg_types[0] {
157            DataType::Int64 => Ok(DataType::Int64),
158            DataType::UInt64 => Ok(DataType::UInt64),
159            DataType::Float64 => Ok(DataType::Float64),
160            DataType::Decimal32(precision, scale) => {
161                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
162                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
163                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
164                Ok(DataType::Decimal32(new_precision, *scale))
165            }
166            DataType::Decimal64(precision, scale) => {
167                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
168                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
169                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
170                Ok(DataType::Decimal64(new_precision, *scale))
171            }
172            DataType::Decimal128(precision, scale) => {
173                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
174                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
175                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
176                Ok(DataType::Decimal128(new_precision, *scale))
177            }
178            DataType::Decimal256(precision, scale) => {
179                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
180                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
181                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
182                Ok(DataType::Decimal256(new_precision, *scale))
183            }
184            other => {
185                exec_err!("[return_type] SUM not supported for {}", other)
186            }
187        }
188    }
189
190    fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
191        unreachable!("stub should not have accumulate()")
192    }
193
194    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
195        unreachable!("stub should not have state_fields()")
196    }
197
198    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
199        false
200    }
201
202    fn create_groups_accumulator(
203        &self,
204        _args: AccumulatorArgs,
205    ) -> Result<Box<dyn GroupsAccumulator>> {
206        unreachable!("stub should not have accumulate()")
207    }
208
209    fn reverse_expr(&self) -> ReversedUDAF {
210        ReversedUDAF::Identical
211    }
212
213    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
214        AggregateOrderSensitivity::Insensitive
215    }
216}
217
218/// Testing stub implementation of COUNT aggregate
219#[derive(PartialEq, Eq, Hash)]
220pub struct Count {
221    signature: Signature,
222    aliases: Vec<String>,
223}
224
225impl std::fmt::Debug for Count {
226    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
227        f.debug_struct("Count")
228            .field("name", &self.name())
229            .field("signature", &self.signature)
230            .finish()
231    }
232}
233
234impl Default for Count {
235    fn default() -> Self {
236        Self::new()
237    }
238}
239
240impl Count {
241    pub fn new() -> Self {
242        Self {
243            aliases: vec!["count".to_string()],
244            signature: Signature::variadic_any(Immutable),
245        }
246    }
247}
248
249impl AggregateUDFImpl for Count {
250    fn as_any(&self) -> &dyn Any {
251        self
252    }
253
254    fn name(&self) -> &str {
255        "COUNT"
256    }
257
258    fn signature(&self) -> &Signature {
259        &self.signature
260    }
261
262    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
263        Ok(DataType::Int64)
264    }
265
266    fn is_nullable(&self) -> bool {
267        false
268    }
269
270    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
271        not_impl_err!("no impl for stub")
272    }
273
274    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
275        not_impl_err!("no impl for stub")
276    }
277
278    fn aliases(&self) -> &[String] {
279        &self.aliases
280    }
281
282    fn create_groups_accumulator(
283        &self,
284        _args: AccumulatorArgs,
285    ) -> Result<Box<dyn GroupsAccumulator>> {
286        not_impl_err!("no impl for stub")
287    }
288
289    fn reverse_expr(&self) -> ReversedUDAF {
290        ReversedUDAF::Identical
291    }
292}
293
294create_func!(Min, min_udaf);
295
296pub fn min(expr: Expr) -> Expr {
297    Expr::AggregateFunction(AggregateFunction::new_udf(
298        min_udaf(),
299        vec![expr],
300        false,
301        None,
302        vec![],
303        None,
304    ))
305}
306
307/// Testing stub implementation of Min aggregate
308#[derive(PartialEq, Eq, Hash)]
309pub struct Min {
310    signature: Signature,
311}
312
313impl std::fmt::Debug for Min {
314    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
315        f.debug_struct("Min")
316            .field("name", &self.name())
317            .field("signature", &self.signature)
318            .finish()
319    }
320}
321
322impl Default for Min {
323    fn default() -> Self {
324        Self::new()
325    }
326}
327
328impl Min {
329    pub fn new() -> Self {
330        Self {
331            signature: Signature::variadic_any(Immutable),
332        }
333    }
334}
335
336impl AggregateUDFImpl for Min {
337    fn as_any(&self) -> &dyn Any {
338        self
339    }
340
341    fn name(&self) -> &str {
342        "min"
343    }
344
345    fn signature(&self) -> &Signature {
346        &self.signature
347    }
348
349    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
350        Ok(DataType::Int64)
351    }
352
353    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
354        not_impl_err!("no impl for stub")
355    }
356
357    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
358        not_impl_err!("no impl for stub")
359    }
360
361    fn create_groups_accumulator(
362        &self,
363        _args: AccumulatorArgs,
364    ) -> Result<Box<dyn GroupsAccumulator>> {
365        not_impl_err!("no impl for stub")
366    }
367
368    fn reverse_expr(&self) -> ReversedUDAF {
369        ReversedUDAF::Identical
370    }
371    fn is_descending(&self) -> Option<bool> {
372        Some(false)
373    }
374}
375
376create_func!(Max, max_udaf);
377
378pub fn max(expr: Expr) -> Expr {
379    Expr::AggregateFunction(AggregateFunction::new_udf(
380        max_udaf(),
381        vec![expr],
382        false,
383        None,
384        vec![],
385        None,
386    ))
387}
388
389/// Testing stub implementation of MAX aggregate
390#[derive(PartialEq, Eq, Hash)]
391pub struct Max {
392    signature: Signature,
393}
394
395impl std::fmt::Debug for Max {
396    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
397        f.debug_struct("Max")
398            .field("name", &self.name())
399            .field("signature", &self.signature)
400            .finish()
401    }
402}
403
404impl Default for Max {
405    fn default() -> Self {
406        Self::new()
407    }
408}
409
410impl Max {
411    pub fn new() -> Self {
412        Self {
413            signature: Signature::variadic_any(Immutable),
414        }
415    }
416}
417
418impl AggregateUDFImpl for Max {
419    fn as_any(&self) -> &dyn Any {
420        self
421    }
422
423    fn name(&self) -> &str {
424        "max"
425    }
426
427    fn signature(&self) -> &Signature {
428        &self.signature
429    }
430
431    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
432        Ok(DataType::Int64)
433    }
434
435    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
436        not_impl_err!("no impl for stub")
437    }
438
439    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
440        not_impl_err!("no impl for stub")
441    }
442
443    fn create_groups_accumulator(
444        &self,
445        _args: AccumulatorArgs,
446    ) -> Result<Box<dyn GroupsAccumulator>> {
447        not_impl_err!("no impl for stub")
448    }
449
450    fn reverse_expr(&self) -> ReversedUDAF {
451        ReversedUDAF::Identical
452    }
453    fn is_descending(&self) -> Option<bool> {
454        Some(true)
455    }
456}
457
458/// Testing stub implementation of avg aggregate
459#[derive(Debug, PartialEq, Eq, Hash)]
460pub struct Avg {
461    signature: Signature,
462    aliases: Vec<String>,
463}
464
465impl Avg {
466    pub fn new() -> Self {
467        Self {
468            aliases: vec![String::from("mean")],
469            signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable),
470        }
471    }
472}
473
474impl Default for Avg {
475    fn default() -> Self {
476        Self::new()
477    }
478}
479
480impl AggregateUDFImpl for Avg {
481    fn as_any(&self) -> &dyn Any {
482        self
483    }
484
485    fn name(&self) -> &str {
486        "avg"
487    }
488
489    fn signature(&self) -> &Signature {
490        &self.signature
491    }
492
493    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
494        let [args] = take_function_args(self.name(), arg_types)?;
495
496        // Supported types smallint, int, bigint, real, double precision, decimal, or interval
497        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
498        fn coerced_type(data_type: &DataType) -> Result<DataType> {
499            match &data_type {
500                DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
501                DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
502                DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
503                DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
504                d if d.is_numeric() => Ok(DataType::Float64),
505                DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
506                DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
507                _ => {
508                    plan_err!("Avg does not support inputs of type {data_type}.")
509                }
510            }
511        }
512        Ok(vec![coerced_type(args)?])
513    }
514
515    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
516        match &arg_types[0] {
517            DataType::Decimal32(precision, scale) => {
518                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
519                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
520                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
521                let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
522                Ok(DataType::Decimal32(new_precision, new_scale))
523            }
524            DataType::Decimal64(precision, scale) => {
525                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
526                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
527                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
528                let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
529                Ok(DataType::Decimal64(new_precision, new_scale))
530            }
531            DataType::Decimal128(precision, scale) => {
532                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
533                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
534                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
535                let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
536                Ok(DataType::Decimal128(new_precision, new_scale))
537            }
538            DataType::Decimal256(precision, scale) => {
539                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
540                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
541                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
542                let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
543                Ok(DataType::Decimal256(new_precision, new_scale))
544            }
545            DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
546            _ => Ok(DataType::Float64),
547        }
548    }
549
550    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
551        not_impl_err!("no impl for stub")
552    }
553
554    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
555        not_impl_err!("no impl for stub")
556    }
557
558    fn aliases(&self) -> &[String] {
559        &self.aliases
560    }
561}