datafusion_expr/test/
function_stub.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! Aggregate function stubs for tests in expr / optimizer.
//!
//! These are used to avoid a dependency on `datafusion-functions-aggregate`,
//! which lives in a different crate.
21
22use std::any::Any;
23
24use arrow::datatypes::{
25    DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
26};
27
28use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};
29
30use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS};
31use crate::Volatility::Immutable;
32use crate::{
33    expr::AggregateFunction,
34    function::{AccumulatorArgs, StateFieldsArgs},
35    utils::AggregateOrderSensitivity,
36    Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature,
37};
38
/// Generates `pub fn $AGGREGATE_UDF_FN() -> Arc<crate::AggregateUDF>`.
///
/// The generated accessor lazily builds one shared `crate::AggregateUDF` from
/// `<$UDAF>::default()` on its first call. Every call returns an `Arc` clone of
/// that same instance.
///
/// Requirements: `$UDAF` must implement `Default`, and `crate::AggregateUDF`
/// must be constructible from it via `AggregateUDF::from`.
///
/// No identifier concatenation is needed here, so the expansion is written
/// directly rather than through `paste::paste!`.
macro_rules! create_func {
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
        #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")]
        pub fn $AGGREGATE_UDF_FN() -> ::std::sync::Arc<crate::AggregateUDF> {
            // One instance per generated accessor. The UDAF is constructed on the
            // first call only; later calls just bump the `Arc` reference count.
            static INSTANCE: ::std::sync::LazyLock<::std::sync::Arc<crate::AggregateUDF>> =
                ::std::sync::LazyLock::new(|| {
                    ::std::sync::Arc::new(crate::AggregateUDF::from(<$UDAF>::default()))
                });
            ::std::sync::Arc::clone(&INSTANCE)
        }
    };
}
54
55create_func!(Sum, sum_udaf);
56
57pub fn sum(expr: Expr) -> Expr {
58    Expr::AggregateFunction(AggregateFunction::new_udf(
59        sum_udaf(),
60        vec![expr],
61        false,
62        None,
63        vec![],
64        None,
65    ))
66}
67
68create_func!(Count, count_udaf);
69
70pub fn count(expr: Expr) -> Expr {
71    Expr::AggregateFunction(AggregateFunction::new_udf(
72        count_udaf(),
73        vec![expr],
74        false,
75        None,
76        vec![],
77        None,
78    ))
79}
80
81create_func!(Avg, avg_udaf);
82
83pub fn avg(expr: Expr) -> Expr {
84    Expr::AggregateFunction(AggregateFunction::new_udf(
85        avg_udaf(),
86        vec![expr],
87        false,
88        None,
89        vec![],
90        None,
91    ))
92}
93
94/// Stub `sum` used for optimizer testing
95#[derive(Debug, PartialEq, Eq, Hash)]
96pub struct Sum {
97    signature: Signature,
98}
99
100impl Sum {
101    pub fn new() -> Self {
102        Self {
103            signature: Signature::user_defined(Immutable),
104        }
105    }
106}
107
108impl Default for Sum {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl AggregateUDFImpl for Sum {
115    fn as_any(&self) -> &dyn Any {
116        self
117    }
118
119    fn name(&self) -> &str {
120        "sum"
121    }
122
123    fn signature(&self) -> &Signature {
124        &self.signature
125    }
126
127    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
128        let [array] = take_function_args(self.name(), arg_types)?;
129
130        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
131        // smallint, int, bigint, real, double precision, decimal, or interval.
132
133        fn coerced_type(data_type: &DataType) -> Result<DataType> {
134            match data_type {
135                DataType::Dictionary(_, v) => coerced_type(v),
136                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
137                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
138                DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
139                    Ok(data_type.clone())
140                }
141                dt if dt.is_signed_integer() => Ok(DataType::Int64),
142                dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
143                dt if dt.is_floating() => Ok(DataType::Float64),
144                _ => exec_err!("Sum not supported for {}", data_type),
145            }
146        }
147
148        Ok(vec![coerced_type(array)?])
149    }
150
151    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
152        match &arg_types[0] {
153            DataType::Int64 => Ok(DataType::Int64),
154            DataType::UInt64 => Ok(DataType::UInt64),
155            DataType::Float64 => Ok(DataType::Float64),
156            DataType::Decimal128(precision, scale) => {
157                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
158                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
159                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
160                Ok(DataType::Decimal128(new_precision, *scale))
161            }
162            DataType::Decimal256(precision, scale) => {
163                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
164                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
165                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
166                Ok(DataType::Decimal256(new_precision, *scale))
167            }
168            other => {
169                exec_err!("[return_type] SUM not supported for {}", other)
170            }
171        }
172    }
173
174    fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
175        unreachable!("stub should not have accumulate()")
176    }
177
178    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
179        unreachable!("stub should not have state_fields()")
180    }
181
182    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
183        false
184    }
185
186    fn create_groups_accumulator(
187        &self,
188        _args: AccumulatorArgs,
189    ) -> Result<Box<dyn GroupsAccumulator>> {
190        unreachable!("stub should not have accumulate()")
191    }
192
193    fn reverse_expr(&self) -> ReversedUDAF {
194        ReversedUDAF::Identical
195    }
196
197    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
198        AggregateOrderSensitivity::Insensitive
199    }
200}
201
202/// Testing stub implementation of COUNT aggregate
203#[derive(PartialEq, Eq, Hash)]
204pub struct Count {
205    signature: Signature,
206    aliases: Vec<String>,
207}
208
209impl std::fmt::Debug for Count {
210    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
211        f.debug_struct("Count")
212            .field("name", &self.name())
213            .field("signature", &self.signature)
214            .finish()
215    }
216}
217
218impl Default for Count {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
224impl Count {
225    pub fn new() -> Self {
226        Self {
227            aliases: vec!["count".to_string()],
228            signature: Signature::variadic_any(Immutable),
229        }
230    }
231}
232
233impl AggregateUDFImpl for Count {
234    fn as_any(&self) -> &dyn Any {
235        self
236    }
237
238    fn name(&self) -> &str {
239        "COUNT"
240    }
241
242    fn signature(&self) -> &Signature {
243        &self.signature
244    }
245
246    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
247        Ok(DataType::Int64)
248    }
249
250    fn is_nullable(&self) -> bool {
251        false
252    }
253
254    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
255        not_impl_err!("no impl for stub")
256    }
257
258    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
259        not_impl_err!("no impl for stub")
260    }
261
262    fn aliases(&self) -> &[String] {
263        &self.aliases
264    }
265
266    fn create_groups_accumulator(
267        &self,
268        _args: AccumulatorArgs,
269    ) -> Result<Box<dyn GroupsAccumulator>> {
270        not_impl_err!("no impl for stub")
271    }
272
273    fn reverse_expr(&self) -> ReversedUDAF {
274        ReversedUDAF::Identical
275    }
276}
277
278create_func!(Min, min_udaf);
279
280pub fn min(expr: Expr) -> Expr {
281    Expr::AggregateFunction(AggregateFunction::new_udf(
282        min_udaf(),
283        vec![expr],
284        false,
285        None,
286        vec![],
287        None,
288    ))
289}
290
291/// Testing stub implementation of Min aggregate
292#[derive(PartialEq, Eq, Hash)]
293pub struct Min {
294    signature: Signature,
295}
296
297impl std::fmt::Debug for Min {
298    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
299        f.debug_struct("Min")
300            .field("name", &self.name())
301            .field("signature", &self.signature)
302            .finish()
303    }
304}
305
306impl Default for Min {
307    fn default() -> Self {
308        Self::new()
309    }
310}
311
312impl Min {
313    pub fn new() -> Self {
314        Self {
315            signature: Signature::variadic_any(Immutable),
316        }
317    }
318}
319
320impl AggregateUDFImpl for Min {
321    fn as_any(&self) -> &dyn Any {
322        self
323    }
324
325    fn name(&self) -> &str {
326        "min"
327    }
328
329    fn signature(&self) -> &Signature {
330        &self.signature
331    }
332
333    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
334        Ok(DataType::Int64)
335    }
336
337    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
338        not_impl_err!("no impl for stub")
339    }
340
341    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
342        not_impl_err!("no impl for stub")
343    }
344
345    fn create_groups_accumulator(
346        &self,
347        _args: AccumulatorArgs,
348    ) -> Result<Box<dyn GroupsAccumulator>> {
349        not_impl_err!("no impl for stub")
350    }
351
352    fn reverse_expr(&self) -> ReversedUDAF {
353        ReversedUDAF::Identical
354    }
355    fn is_descending(&self) -> Option<bool> {
356        Some(false)
357    }
358}
359
360create_func!(Max, max_udaf);
361
362pub fn max(expr: Expr) -> Expr {
363    Expr::AggregateFunction(AggregateFunction::new_udf(
364        max_udaf(),
365        vec![expr],
366        false,
367        None,
368        vec![],
369        None,
370    ))
371}
372
373/// Testing stub implementation of MAX aggregate
374#[derive(PartialEq, Eq, Hash)]
375pub struct Max {
376    signature: Signature,
377}
378
379impl std::fmt::Debug for Max {
380    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
381        f.debug_struct("Max")
382            .field("name", &self.name())
383            .field("signature", &self.signature)
384            .finish()
385    }
386}
387
388impl Default for Max {
389    fn default() -> Self {
390        Self::new()
391    }
392}
393
394impl Max {
395    pub fn new() -> Self {
396        Self {
397            signature: Signature::variadic_any(Immutable),
398        }
399    }
400}
401
402impl AggregateUDFImpl for Max {
403    fn as_any(&self) -> &dyn Any {
404        self
405    }
406
407    fn name(&self) -> &str {
408        "max"
409    }
410
411    fn signature(&self) -> &Signature {
412        &self.signature
413    }
414
415    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
416        Ok(DataType::Int64)
417    }
418
419    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
420        not_impl_err!("no impl for stub")
421    }
422
423    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
424        not_impl_err!("no impl for stub")
425    }
426
427    fn create_groups_accumulator(
428        &self,
429        _args: AccumulatorArgs,
430    ) -> Result<Box<dyn GroupsAccumulator>> {
431        not_impl_err!("no impl for stub")
432    }
433
434    fn reverse_expr(&self) -> ReversedUDAF {
435        ReversedUDAF::Identical
436    }
437    fn is_descending(&self) -> Option<bool> {
438        Some(true)
439    }
440}
441
442/// Testing stub implementation of avg aggregate
443#[derive(Debug, PartialEq, Eq, Hash)]
444pub struct Avg {
445    signature: Signature,
446    aliases: Vec<String>,
447}
448
449impl Avg {
450    pub fn new() -> Self {
451        Self {
452            aliases: vec![String::from("mean")],
453            signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable),
454        }
455    }
456}
457
458impl Default for Avg {
459    fn default() -> Self {
460        Self::new()
461    }
462}
463
464impl AggregateUDFImpl for Avg {
465    fn as_any(&self) -> &dyn Any {
466        self
467    }
468
469    fn name(&self) -> &str {
470        "avg"
471    }
472
473    fn signature(&self) -> &Signature {
474        &self.signature
475    }
476
477    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
478        avg_return_type(self.name(), &arg_types[0])
479    }
480
481    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
482        not_impl_err!("no impl for stub")
483    }
484
485    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
486        not_impl_err!("no impl for stub")
487    }
488
489    fn aliases(&self) -> &[String] {
490        &self.aliases
491    }
492
493    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
494        coerce_avg_type(self.name(), arg_types)
495    }
496}