datafusion_expr/test/
function_stub.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Aggregate function stubs for tests in expr / optimizer.
//!
//! These are used to avoid a dependency on `datafusion-functions-aggregate`,
//! which lives in a different crate.
21
22use std::any::Any;
23
24use arrow::datatypes::{
25    DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
26};
27
28use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};
29
30use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS};
31use crate::Volatility::Immutable;
32use crate::{
33    expr::AggregateFunction,
34    function::{AccumulatorArgs, StateFieldsArgs},
35    utils::AggregateOrderSensitivity,
36    Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature,
37};
38
/// Defines `pub fn $AGGREGATE_UDF_FN() -> Arc<AggregateUDF>` that hands out a
/// lazily-built, shared `AggregateUDF` wrapping `<$UDAF>::default()`.
///
/// No identifier pasting is needed here, so the expansion does not go through
/// `paste::paste!`; `#[doc = concat!(..)]` works directly in a `macro_rules!` body.
macro_rules! create_func {
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
        #[doc = concat!(
            "Returns the shared [AggregateUDF](crate::AggregateUDF) instance for [`",
            stringify!($UDAF),
            "`]"
        )]
        pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<crate::AggregateUDF> {
            // Built once on first use (the `static` is per expansion, i.e. per
            // UDAF type); every call returns a clone of the same `Arc`.
            static INSTANCE: std::sync::LazyLock<std::sync::Arc<crate::AggregateUDF>> =
                std::sync::LazyLock::new(|| {
                    std::sync::Arc::new(crate::AggregateUDF::from(<$UDAF>::default()))
                });
            std::sync::Arc::clone(&INSTANCE)
        }
    };
}
54
55create_func!(Sum, sum_udaf);
56
57pub fn sum(expr: Expr) -> Expr {
58    Expr::AggregateFunction(AggregateFunction::new_udf(
59        sum_udaf(),
60        vec![expr],
61        false,
62        None,
63        vec![],
64        None,
65    ))
66}
67
68create_func!(Count, count_udaf);
69
70pub fn count(expr: Expr) -> Expr {
71    Expr::AggregateFunction(AggregateFunction::new_udf(
72        count_udaf(),
73        vec![expr],
74        false,
75        None,
76        vec![],
77        None,
78    ))
79}
80
81create_func!(Avg, avg_udaf);
82
83pub fn avg(expr: Expr) -> Expr {
84    Expr::AggregateFunction(AggregateFunction::new_udf(
85        avg_udaf(),
86        vec![expr],
87        false,
88        None,
89        vec![],
90        None,
91    ))
92}
93
94/// Stub `sum` used for optimizer testing
95#[derive(Debug)]
96pub struct Sum {
97    signature: Signature,
98}
99
100impl Sum {
101    pub fn new() -> Self {
102        Self {
103            signature: Signature::user_defined(Immutable),
104        }
105    }
106}
107
108impl Default for Sum {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl AggregateUDFImpl for Sum {
115    fn as_any(&self) -> &dyn Any {
116        self
117    }
118
119    fn name(&self) -> &str {
120        "sum"
121    }
122
123    fn signature(&self) -> &Signature {
124        &self.signature
125    }
126
127    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
128        let [array] = take_function_args(self.name(), arg_types)?;
129
130        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
131        // smallint, int, bigint, real, double precision, decimal, or interval.
132
133        fn coerced_type(data_type: &DataType) -> Result<DataType> {
134            match data_type {
135                DataType::Dictionary(_, v) => coerced_type(v),
136                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
137                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
138                DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
139                    Ok(data_type.clone())
140                }
141                dt if dt.is_signed_integer() => Ok(DataType::Int64),
142                dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
143                dt if dt.is_floating() => Ok(DataType::Float64),
144                _ => exec_err!("Sum not supported for {}", data_type),
145            }
146        }
147
148        Ok(vec![coerced_type(array)?])
149    }
150
151    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
152        match &arg_types[0] {
153            DataType::Int64 => Ok(DataType::Int64),
154            DataType::UInt64 => Ok(DataType::UInt64),
155            DataType::Float64 => Ok(DataType::Float64),
156            DataType::Decimal128(precision, scale) => {
157                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
158                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
159                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
160                Ok(DataType::Decimal128(new_precision, *scale))
161            }
162            DataType::Decimal256(precision, scale) => {
163                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
164                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
165                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
166                Ok(DataType::Decimal256(new_precision, *scale))
167            }
168            other => {
169                exec_err!("[return_type] SUM not supported for {}", other)
170            }
171        }
172    }
173
174    fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
175        unreachable!("stub should not have accumulate()")
176    }
177
178    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
179        unreachable!("stub should not have state_fields()")
180    }
181
182    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
183        false
184    }
185
186    fn create_groups_accumulator(
187        &self,
188        _args: AccumulatorArgs,
189    ) -> Result<Box<dyn GroupsAccumulator>> {
190        unreachable!("stub should not have accumulate()")
191    }
192
193    fn reverse_expr(&self) -> ReversedUDAF {
194        ReversedUDAF::Identical
195    }
196
197    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
198        AggregateOrderSensitivity::Insensitive
199    }
200}
201
202/// Testing stub implementation of COUNT aggregate
203pub struct Count {
204    signature: Signature,
205    aliases: Vec<String>,
206}
207
208impl std::fmt::Debug for Count {
209    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
210        f.debug_struct("Count")
211            .field("name", &self.name())
212            .field("signature", &self.signature)
213            .finish()
214    }
215}
216
217impl Default for Count {
218    fn default() -> Self {
219        Self::new()
220    }
221}
222
223impl Count {
224    pub fn new() -> Self {
225        Self {
226            aliases: vec!["count".to_string()],
227            signature: Signature::variadic_any(Immutable),
228        }
229    }
230}
231
232impl AggregateUDFImpl for Count {
233    fn as_any(&self) -> &dyn Any {
234        self
235    }
236
237    fn name(&self) -> &str {
238        "COUNT"
239    }
240
241    fn signature(&self) -> &Signature {
242        &self.signature
243    }
244
245    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
246        Ok(DataType::Int64)
247    }
248
249    fn is_nullable(&self) -> bool {
250        false
251    }
252
253    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
254        not_impl_err!("no impl for stub")
255    }
256
257    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
258        not_impl_err!("no impl for stub")
259    }
260
261    fn aliases(&self) -> &[String] {
262        &self.aliases
263    }
264
265    fn create_groups_accumulator(
266        &self,
267        _args: AccumulatorArgs,
268    ) -> Result<Box<dyn GroupsAccumulator>> {
269        not_impl_err!("no impl for stub")
270    }
271
272    fn reverse_expr(&self) -> ReversedUDAF {
273        ReversedUDAF::Identical
274    }
275}
276
277create_func!(Min, min_udaf);
278
279pub fn min(expr: Expr) -> Expr {
280    Expr::AggregateFunction(AggregateFunction::new_udf(
281        min_udaf(),
282        vec![expr],
283        false,
284        None,
285        vec![],
286        None,
287    ))
288}
289
290/// Testing stub implementation of Min aggregate
291pub struct Min {
292    signature: Signature,
293}
294
295impl std::fmt::Debug for Min {
296    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
297        f.debug_struct("Min")
298            .field("name", &self.name())
299            .field("signature", &self.signature)
300            .finish()
301    }
302}
303
304impl Default for Min {
305    fn default() -> Self {
306        Self::new()
307    }
308}
309
310impl Min {
311    pub fn new() -> Self {
312        Self {
313            signature: Signature::variadic_any(Immutable),
314        }
315    }
316}
317
318impl AggregateUDFImpl for Min {
319    fn as_any(&self) -> &dyn Any {
320        self
321    }
322
323    fn name(&self) -> &str {
324        "min"
325    }
326
327    fn signature(&self) -> &Signature {
328        &self.signature
329    }
330
331    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
332        Ok(DataType::Int64)
333    }
334
335    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
336        not_impl_err!("no impl for stub")
337    }
338
339    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
340        not_impl_err!("no impl for stub")
341    }
342
343    fn create_groups_accumulator(
344        &self,
345        _args: AccumulatorArgs,
346    ) -> Result<Box<dyn GroupsAccumulator>> {
347        not_impl_err!("no impl for stub")
348    }
349
350    fn reverse_expr(&self) -> ReversedUDAF {
351        ReversedUDAF::Identical
352    }
353    fn is_descending(&self) -> Option<bool> {
354        Some(false)
355    }
356}
357
358create_func!(Max, max_udaf);
359
360pub fn max(expr: Expr) -> Expr {
361    Expr::AggregateFunction(AggregateFunction::new_udf(
362        max_udaf(),
363        vec![expr],
364        false,
365        None,
366        vec![],
367        None,
368    ))
369}
370
371/// Testing stub implementation of MAX aggregate
372pub struct Max {
373    signature: Signature,
374}
375
376impl std::fmt::Debug for Max {
377    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
378        f.debug_struct("Max")
379            .field("name", &self.name())
380            .field("signature", &self.signature)
381            .finish()
382    }
383}
384
385impl Default for Max {
386    fn default() -> Self {
387        Self::new()
388    }
389}
390
391impl Max {
392    pub fn new() -> Self {
393        Self {
394            signature: Signature::variadic_any(Immutable),
395        }
396    }
397}
398
399impl AggregateUDFImpl for Max {
400    fn as_any(&self) -> &dyn Any {
401        self
402    }
403
404    fn name(&self) -> &str {
405        "max"
406    }
407
408    fn signature(&self) -> &Signature {
409        &self.signature
410    }
411
412    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
413        Ok(DataType::Int64)
414    }
415
416    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
417        not_impl_err!("no impl for stub")
418    }
419
420    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
421        not_impl_err!("no impl for stub")
422    }
423
424    fn create_groups_accumulator(
425        &self,
426        _args: AccumulatorArgs,
427    ) -> Result<Box<dyn GroupsAccumulator>> {
428        not_impl_err!("no impl for stub")
429    }
430
431    fn reverse_expr(&self) -> ReversedUDAF {
432        ReversedUDAF::Identical
433    }
434    fn is_descending(&self) -> Option<bool> {
435        Some(true)
436    }
437}
438
439/// Testing stub implementation of avg aggregate
440#[derive(Debug)]
441pub struct Avg {
442    signature: Signature,
443    aliases: Vec<String>,
444}
445
446impl Avg {
447    pub fn new() -> Self {
448        Self {
449            aliases: vec![String::from("mean")],
450            signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable),
451        }
452    }
453}
454
455impl Default for Avg {
456    fn default() -> Self {
457        Self::new()
458    }
459}
460
461impl AggregateUDFImpl for Avg {
462    fn as_any(&self) -> &dyn Any {
463        self
464    }
465
466    fn name(&self) -> &str {
467        "avg"
468    }
469
470    fn signature(&self) -> &Signature {
471        &self.signature
472    }
473
474    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
475        avg_return_type(self.name(), &arg_types[0])
476    }
477
478    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
479        not_impl_err!("no impl for stub")
480    }
481
482    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
483        not_impl_err!("no impl for stub")
484    }
485
486    fn aliases(&self) -> &[String] {
487        &self.aliases
488    }
489
490    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
491        coerce_avg_type(self.name(), arg_types)
492    }
493}