Skip to main content

datafusion_expr/test/
function_stub.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Aggregate function stubs for tests in expr / optimizer.
//!
//! These are used to avoid a dependency on `datafusion-functions-aggregate`, which lives in a different crate.
21
22use arrow::datatypes::{
23    DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION,
24    DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
25    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType, FieldRef,
26};
27
28use datafusion_common::plan_err;
29use datafusion_common::{Result, exec_err, not_impl_err, utils::take_function_args};
30
31use crate::Volatility::Immutable;
32use crate::{
33    Accumulator, AggregateUDFImpl, Coercion, Expr, GroupsAccumulator, ReversedUDAF,
34    Signature, TypeSignature, TypeSignatureClass,
35    expr::AggregateFunction,
36    function::{AccumulatorArgs, StateFieldsArgs},
37    utils::AggregateOrderSensitivity,
38};
39use datafusion_common::types::{NativeType, logical_float64};
40
/// Generates a public, zero-argument function named `$AGGREGATE_UDF_FN` that
/// returns a shared `Arc<crate::AggregateUDF>` wrapping `<$UDAF>::default()`.
///
/// The wrapped value is built lazily on the first call and kept in a
/// function-local `static`, so every call of one generated function hands out
/// a clone of the same `Arc`.
macro_rules! create_func {
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
        #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")]
        pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<crate::AggregateUDF> {
            // Process-wide instance for this accessor; the closure runs at most once.
            static INSTANCE: std::sync::LazyLock<std::sync::Arc<crate::AggregateUDF>> =
                std::sync::LazyLock::new(|| {
                    let udaf = <$UDAF>::default();
                    std::sync::Arc::new(crate::AggregateUDF::from(udaf))
                });

            // Dereference the lazy cell to the stored `Arc`, then bump its refcount.
            let shared: &std::sync::Arc<crate::AggregateUDF> = &INSTANCE;
            std::sync::Arc::clone(shared)
        }
    };
}
54
55create_func!(Sum, sum_udaf);
56
57pub fn sum(expr: Expr) -> Expr {
58    Expr::AggregateFunction(AggregateFunction::new_udf(
59        sum_udaf(),
60        vec![expr],
61        false,
62        None,
63        vec![],
64        None,
65    ))
66}
67
68create_func!(Count, count_udaf);
69
70pub fn count(expr: Expr) -> Expr {
71    Expr::AggregateFunction(AggregateFunction::new_udf(
72        count_udaf(),
73        vec![expr],
74        false,
75        None,
76        vec![],
77        None,
78    ))
79}
80
81create_func!(Avg, avg_udaf);
82
83pub fn avg(expr: Expr) -> Expr {
84    Expr::AggregateFunction(AggregateFunction::new_udf(
85        avg_udaf(),
86        vec![expr],
87        false,
88        None,
89        vec![],
90        None,
91    ))
92}
93
94/// Stub `sum` used for optimizer testing
95#[derive(Debug, PartialEq, Eq, Hash)]
96pub struct Sum {
97    signature: Signature,
98}
99
100impl Sum {
101    pub fn new() -> Self {
102        Self {
103            signature: Signature::user_defined(Immutable),
104        }
105    }
106}
107
108impl Default for Sum {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl AggregateUDFImpl for Sum {
115    fn name(&self) -> &str {
116        "sum"
117    }
118
119    fn signature(&self) -> &Signature {
120        &self.signature
121    }
122
123    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
124        let [array] = take_function_args(self.name(), arg_types)?;
125
126        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
127        // smallint, int, bigint, real, double precision, decimal, or interval.
128
129        fn coerced_type(data_type: &DataType) -> Result<DataType> {
130            match data_type {
131                DataType::Dictionary(_, v) => coerced_type(v),
132                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
133                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
134                DataType::Decimal32(_, _)
135                | DataType::Decimal64(_, _)
136                | DataType::Decimal128(_, _)
137                | DataType::Decimal256(_, _) => Ok(data_type.clone()),
138                dt if dt.is_signed_integer() => Ok(DataType::Int64),
139                dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
140                dt if dt.is_floating() => Ok(DataType::Float64),
141                _ => exec_err!("Sum not supported for {data_type}"),
142            }
143        }
144
145        Ok(vec![coerced_type(array)?])
146    }
147
148    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
149        match &arg_types[0] {
150            DataType::Int64 => Ok(DataType::Int64),
151            DataType::UInt64 => Ok(DataType::UInt64),
152            DataType::Float64 => Ok(DataType::Float64),
153            DataType::Decimal32(precision, scale) => {
154                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
155                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
156                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
157                Ok(DataType::Decimal32(new_precision, *scale))
158            }
159            DataType::Decimal64(precision, scale) => {
160                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
161                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
162                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
163                Ok(DataType::Decimal64(new_precision, *scale))
164            }
165            DataType::Decimal128(precision, scale) => {
166                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
167                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
168                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
169                Ok(DataType::Decimal128(new_precision, *scale))
170            }
171            DataType::Decimal256(precision, scale) => {
172                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
173                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
174                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
175                Ok(DataType::Decimal256(new_precision, *scale))
176            }
177            other => {
178                exec_err!("[return_type] SUM not supported for {}", other)
179            }
180        }
181    }
182
183    fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
184        unreachable!("stub should not have accumulate()")
185    }
186
187    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
188        unreachable!("stub should not have state_fields()")
189    }
190
191    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
192        false
193    }
194
195    fn create_groups_accumulator(
196        &self,
197        _args: AccumulatorArgs,
198    ) -> Result<Box<dyn GroupsAccumulator>> {
199        unreachable!("stub should not have accumulate()")
200    }
201
202    fn reverse_expr(&self) -> ReversedUDAF {
203        ReversedUDAF::Identical
204    }
205
206    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
207        AggregateOrderSensitivity::Insensitive
208    }
209}
210
211/// Testing stub implementation of COUNT aggregate
212#[derive(PartialEq, Eq, Hash)]
213pub struct Count {
214    signature: Signature,
215    aliases: Vec<String>,
216}
217
218impl std::fmt::Debug for Count {
219    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
220        f.debug_struct("Count")
221            .field("name", &self.name())
222            .field("signature", &self.signature)
223            .finish()
224    }
225}
226
227impl Default for Count {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233impl Count {
234    pub fn new() -> Self {
235        Self {
236            aliases: vec!["count".to_string()],
237            signature: Signature::variadic_any(Immutable),
238        }
239    }
240}
241
242impl AggregateUDFImpl for Count {
243    fn name(&self) -> &str {
244        "COUNT"
245    }
246
247    fn signature(&self) -> &Signature {
248        &self.signature
249    }
250
251    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
252        Ok(DataType::Int64)
253    }
254
255    fn is_nullable(&self) -> bool {
256        false
257    }
258
259    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
260        not_impl_err!("no impl for stub")
261    }
262
263    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
264        not_impl_err!("no impl for stub")
265    }
266
267    fn aliases(&self) -> &[String] {
268        &self.aliases
269    }
270
271    fn create_groups_accumulator(
272        &self,
273        _args: AccumulatorArgs,
274    ) -> Result<Box<dyn GroupsAccumulator>> {
275        not_impl_err!("no impl for stub")
276    }
277
278    fn reverse_expr(&self) -> ReversedUDAF {
279        ReversedUDAF::Identical
280    }
281}
282
283create_func!(Min, min_udaf);
284
285pub fn min(expr: Expr) -> Expr {
286    Expr::AggregateFunction(AggregateFunction::new_udf(
287        min_udaf(),
288        vec![expr],
289        false,
290        None,
291        vec![],
292        None,
293    ))
294}
295
296/// Testing stub implementation of Min aggregate
297#[derive(PartialEq, Eq, Hash)]
298pub struct Min {
299    signature: Signature,
300}
301
302impl std::fmt::Debug for Min {
303    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
304        f.debug_struct("Min")
305            .field("name", &self.name())
306            .field("signature", &self.signature)
307            .finish()
308    }
309}
310
311impl Default for Min {
312    fn default() -> Self {
313        Self::new()
314    }
315}
316
317impl Min {
318    pub fn new() -> Self {
319        Self {
320            signature: Signature::variadic_any(Immutable),
321        }
322    }
323}
324
325impl AggregateUDFImpl for Min {
326    fn name(&self) -> &str {
327        "min"
328    }
329
330    fn signature(&self) -> &Signature {
331        &self.signature
332    }
333
334    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
335        Ok(DataType::Int64)
336    }
337
338    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
339        not_impl_err!("no impl for stub")
340    }
341
342    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
343        not_impl_err!("no impl for stub")
344    }
345
346    fn create_groups_accumulator(
347        &self,
348        _args: AccumulatorArgs,
349    ) -> Result<Box<dyn GroupsAccumulator>> {
350        not_impl_err!("no impl for stub")
351    }
352
353    fn reverse_expr(&self) -> ReversedUDAF {
354        ReversedUDAF::Identical
355    }
356    fn is_descending(&self) -> Option<bool> {
357        Some(false)
358    }
359}
360
361create_func!(Max, max_udaf);
362
363pub fn max(expr: Expr) -> Expr {
364    Expr::AggregateFunction(AggregateFunction::new_udf(
365        max_udaf(),
366        vec![expr],
367        false,
368        None,
369        vec![],
370        None,
371    ))
372}
373
374/// Testing stub implementation of MAX aggregate
375#[derive(PartialEq, Eq, Hash)]
376pub struct Max {
377    signature: Signature,
378}
379
380impl std::fmt::Debug for Max {
381    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
382        f.debug_struct("Max")
383            .field("name", &self.name())
384            .field("signature", &self.signature)
385            .finish()
386    }
387}
388
389impl Default for Max {
390    fn default() -> Self {
391        Self::new()
392    }
393}
394
395impl Max {
396    pub fn new() -> Self {
397        Self {
398            signature: Signature::variadic_any(Immutable),
399        }
400    }
401}
402
403impl AggregateUDFImpl for Max {
404    fn name(&self) -> &str {
405        "max"
406    }
407
408    fn signature(&self) -> &Signature {
409        &self.signature
410    }
411
412    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
413        Ok(DataType::Int64)
414    }
415
416    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
417        not_impl_err!("no impl for stub")
418    }
419
420    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
421        not_impl_err!("no impl for stub")
422    }
423
424    fn create_groups_accumulator(
425        &self,
426        _args: AccumulatorArgs,
427    ) -> Result<Box<dyn GroupsAccumulator>> {
428        not_impl_err!("no impl for stub")
429    }
430
431    fn reverse_expr(&self) -> ReversedUDAF {
432        ReversedUDAF::Identical
433    }
434    fn is_descending(&self) -> Option<bool> {
435        Some(true)
436    }
437}
438
439/// Testing stub implementation of avg aggregate
440#[derive(Debug, PartialEq, Eq, Hash)]
441pub struct Avg {
442    signature: Signature,
443    aliases: Vec<String>,
444}
445
446impl Avg {
447    pub fn new() -> Self {
448        let signature = Signature::one_of(
449            vec![
450                TypeSignature::Coercible(vec![Coercion::new_exact(
451                    TypeSignatureClass::Decimal,
452                )]),
453                TypeSignature::Coercible(vec![Coercion::new_implicit(
454                    TypeSignatureClass::Native(logical_float64()),
455                    vec![TypeSignatureClass::Integer, TypeSignatureClass::Float],
456                    NativeType::Float64,
457                )]),
458            ],
459            Immutable,
460        );
461        Self {
462            aliases: vec![String::from("mean")],
463            signature,
464        }
465    }
466}
467
468impl Default for Avg {
469    fn default() -> Self {
470        Self::new()
471    }
472}
473
474impl AggregateUDFImpl for Avg {
475    fn name(&self) -> &str {
476        "avg"
477    }
478
479    fn signature(&self) -> &Signature {
480        &self.signature
481    }
482
483    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
484        let [args] = take_function_args(self.name(), arg_types)?;
485
486        // Supported types smallint, int, bigint, real, double precision, decimal, or interval
487        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
488        fn coerced_type(data_type: &DataType) -> Result<DataType> {
489            match &data_type {
490                DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
491                DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
492                DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
493                DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
494                d if d.is_numeric() => Ok(DataType::Float64),
495                DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
496                DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
497                _ => {
498                    plan_err!("Avg does not support inputs of type {data_type}.")
499                }
500            }
501        }
502        Ok(vec![coerced_type(args)?])
503    }
504
505    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
506        match &arg_types[0] {
507            DataType::Decimal32(precision, scale) => {
508                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
509                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
510                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
511                let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
512                Ok(DataType::Decimal32(new_precision, new_scale))
513            }
514            DataType::Decimal64(precision, scale) => {
515                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
516                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
517                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
518                let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
519                Ok(DataType::Decimal64(new_precision, new_scale))
520            }
521            DataType::Decimal128(precision, scale) => {
522                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
523                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
524                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
525                let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
526                Ok(DataType::Decimal128(new_precision, new_scale))
527            }
528            DataType::Decimal256(precision, scale) => {
529                // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
530                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
531                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
532                let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
533                Ok(DataType::Decimal256(new_precision, new_scale))
534            }
535            DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
536            _ => Ok(DataType::Float64),
537        }
538    }
539
540    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
541        not_impl_err!("no impl for stub")
542    }
543
544    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
545        not_impl_err!("no impl for stub")
546    }
547
548    fn aliases(&self) -> &[String] {
549        &self.aliases
550    }
551}