Skip to main content

datafusion_expr/
udaf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`AggregateUDF`]: User Defined Aggregate Functions
19
20use std::any::Any;
21use std::cmp::Ordering;
22use std::fmt::{self, Debug, Formatter, Write};
23use std::hash::{Hash, Hasher};
24use std::sync::Arc;
25use std::vec;
26
27use arrow::datatypes::{DataType, Field, FieldRef};
28
29use datafusion_common::{Result, ScalarValue, Statistics, exec_err, not_impl_err};
30use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
31use datafusion_expr_common::operator::Operator;
32use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
33
34use crate::expr::{
35    AggregateFunction, AggregateFunctionParams, ExprListDisplay, WindowFunctionParams,
36    schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space,
37    schema_name_from_sorts,
38};
39use crate::function::{
40    AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
41};
42use crate::groups_accumulator::GroupsAccumulator;
43use crate::udf_eq::UdfEq;
44use crate::utils::AggregateOrderSensitivity;
45use crate::utils::format_state_name;
46use crate::{Accumulator, Expr, expr_vec_fmt};
47use crate::{Documentation, Signature};
48
/// Logical representation of a user-defined [aggregate function] (UDAF).
///
/// An aggregate function combines the values from multiple input rows
/// into a single output "aggregate" (summary) row. It is different
/// from a scalar function because it is stateful across batches. User
/// defined aggregate functions can be used as normal SQL aggregate
/// functions (`GROUP BY` clause) as well as window functions (`OVER`
/// clause).
///
/// `AggregateUDF` provides DataFusion the information needed to plan and call
/// aggregate functions, including name, type information, and a factory
/// function to create an [`Accumulator`] instance, to perform the actual
/// aggregation.
///
/// `AggregateUDF` is a thin, cheaply cloneable wrapper around a shared
/// [`AggregateUDFImpl`]: its methods forward to that implementation, and its
/// `PartialEq` / `Hash` impls delegate to the implementation's `DynEq` /
/// `DynHash`.
///
/// For more information, please see [the examples]:
///
/// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]).
///
/// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API
///    access (examples in [`advanced_udaf.rs`]).
///
/// # API Note
/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
/// compatibility with the older API.
///
/// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process
/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
/// [`Accumulator`]: Accumulator
/// [`create_udaf`]: crate::expr_fn::create_udaf
/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/simple_udaf.rs
/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udaf.rs
#[derive(Debug, Clone, PartialOrd)]
pub struct AggregateUDF {
    /// The shared implementation that every `AggregateUDF` method forwards to.
    inner: Arc<dyn AggregateUDFImpl>,
}
84
85impl PartialEq for AggregateUDF {
86    fn eq(&self, other: &Self) -> bool {
87        self.inner.dyn_eq(other.inner.as_ref() as &dyn Any)
88    }
89}
90
91impl Eq for AggregateUDF {}
92
93impl Hash for AggregateUDF {
94    fn hash<H: Hasher>(&self, state: &mut H) {
95        self.inner.dyn_hash(state)
96    }
97}
98
99impl fmt::Display for AggregateUDF {
100    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
101        write!(f, "{}", self.name())
102    }
103}
104
/// Arguments passed to [`AggregateUDFImpl::value_from_stats`].
///
/// Bundles the information an aggregate may use to determine its result
/// entirely from input statistics and its arguments.
#[derive(Debug)]
pub struct StatisticsArgs<'a> {
    /// The statistics of the aggregate input.
    pub statistics: &'a Statistics,
    /// The resolved return type of the aggregate function.
    pub return_type: &'a DataType,
    /// Whether the aggregate function is called with `DISTINCT`, e.g.
    ///
    /// ```sql
    /// SELECT COUNT(DISTINCT column1) FROM t;
    /// ```
    pub is_distinct: bool,
    /// The physical expressions of the arguments the aggregate function takes.
    pub exprs: &'a [Arc<dyn PhysicalExpr>],
}
121
122impl AggregateUDF {
123    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
124    ///
125    /// Note this is the same as using the `From` impl (`AggregateUDF::from`)
126    pub fn new_from_impl<F>(fun: F) -> AggregateUDF
127    where
128        F: AggregateUDFImpl + 'static,
129    {
130        Self::new_from_shared_impl(Arc::new(fun))
131    }
132
133    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
134    pub fn new_from_shared_impl(fun: Arc<dyn AggregateUDFImpl>) -> AggregateUDF {
135        Self { inner: fun }
136    }
137
138    /// Return the underlying [`AggregateUDFImpl`] trait object for this function
139    pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
140        &self.inner
141    }
142
143    /// Adds additional names that can be used to invoke this function, in
144    /// addition to `name`
145    ///
146    /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
147    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
148        Self::new_from_impl(AliasedAggregateUDFImpl::new(
149            Arc::clone(&self.inner),
150            aliases,
151        ))
152    }
153
154    /// Creates an [`Expr`] that calls the aggregate function.
155    ///
156    /// This utility allows using the UDAF without requiring access to
157    /// the registry, such as with the DataFrame API.
158    pub fn call(&self, args: Vec<Expr>) -> Expr {
159        Expr::AggregateFunction(AggregateFunction::new_udf(
160            Arc::new(self.clone()),
161            args,
162            false,
163            None,
164            vec![],
165            None,
166        ))
167    }
168
169    /// Returns this function's name
170    ///
171    /// See [`AggregateUDFImpl::name`] for more details.
172    pub fn name(&self) -> &str {
173        self.inner.name()
174    }
175
176    /// Returns the aliases for this function.
177    pub fn aliases(&self) -> &[String] {
178        self.inner.aliases()
179    }
180
181    /// See [`AggregateUDFImpl::schema_name`] for more details.
182    pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
183        self.inner.schema_name(params)
184    }
185
186    /// Returns a human readable expression.
187    ///
188    /// See [`Expr::human_display`] for details.
189    pub fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
190        self.inner.human_display(params)
191    }
192
193    pub fn window_function_schema_name(
194        &self,
195        params: &WindowFunctionParams,
196    ) -> Result<String> {
197        self.inner.window_function_schema_name(params)
198    }
199
200    /// See [`AggregateUDFImpl::display_name`] for more details.
201    pub fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
202        self.inner.display_name(params)
203    }
204
205    pub fn window_function_display_name(
206        &self,
207        params: &WindowFunctionParams,
208    ) -> Result<String> {
209        self.inner.window_function_display_name(params)
210    }
211
212    pub fn is_nullable(&self) -> bool {
213        self.inner.is_nullable()
214    }
215
216    /// Returns this function's signature (what input types are accepted)
217    ///
218    /// See [`AggregateUDFImpl::signature`] for more details.
219    pub fn signature(&self) -> &Signature {
220        self.inner.signature()
221    }
222
223    /// Return the type of the function given its input types
224    ///
225    /// See [`AggregateUDFImpl::return_type`] for more details.
226    pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
227        self.inner.return_type(args)
228    }
229
230    /// Return the field of the function given its input fields
231    ///
232    /// See [`AggregateUDFImpl::return_field`] for more details.
233    pub fn return_field(&self, args: &[FieldRef]) -> Result<FieldRef> {
234        self.inner.return_field(args)
235    }
236
237    /// Return an accumulator the given aggregate, given its return datatype
238    pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
239        self.inner.accumulator(acc_args)
240    }
241
242    /// Return the fields used to store the intermediate state for this aggregator, given
243    /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
244    /// for more details.
245    ///
246    /// This is used to support multi-phase aggregations
247    pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
248        self.inner.state_fields(args)
249    }
250
251    /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
252    pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
253        self.inner.groups_accumulator_supported(args)
254    }
255
256    /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
257    pub fn create_groups_accumulator(
258        &self,
259        args: AccumulatorArgs,
260    ) -> Result<Box<dyn GroupsAccumulator>> {
261        self.inner.create_groups_accumulator(args)
262    }
263
264    pub fn create_sliding_accumulator(
265        &self,
266        args: AccumulatorArgs,
267    ) -> Result<Box<dyn Accumulator>> {
268        self.inner.create_sliding_accumulator(args)
269    }
270
271    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
272        self.inner.coerce_types(arg_types)
273    }
274
275    /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details.
276    pub fn with_beneficial_ordering(
277        self,
278        beneficial_ordering: bool,
279    ) -> Result<Option<AggregateUDF>> {
280        self.inner
281            .with_beneficial_ordering(beneficial_ordering)
282            .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf }))
283    }
284
285    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
286    /// for possible options.
287    pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
288        self.inner.order_sensitivity()
289    }
290
291    /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will
292    /// generate same result with this `AggregateUDF` when iterated in reverse
293    /// order, and `None` if there is no such `AggregateUDF`).
294    pub fn reverse_udf(&self) -> ReversedUDAF {
295        self.inner.reverse_expr()
296    }
297
298    /// Returns this aggregate function's simplification hook, if any.
299    ///
300    /// See [`AggregateUDFImpl::simplify`] for more details.
301    pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
302        self.inner.simplify()
303    }
304
305    /// Rewrite aggregate to have simpler arguments
306    ///
307    /// See  [`AggregateUDFImpl::simplify_expr_op_literal`] for more details
308    pub fn simplify_expr_op_literal(
309        &self,
310        agg_function: &AggregateFunction,
311        arg: &Expr,
312        op: Operator,
313        lit: &Expr,
314        arg_is_left: bool,
315    ) -> Result<Option<Expr>> {
316        self.inner
317            .simplify_expr_op_literal(agg_function, arg, op, lit, arg_is_left)
318    }
319
320    /// Returns true if the function is max, false if the function is min
321    /// None in all other cases, used in certain optimizations for
322    /// or aggregate
323    pub fn is_descending(&self) -> Option<bool> {
324        self.inner.is_descending()
325    }
326
327    /// Return the value of this aggregate function if it can be determined
328    /// entirely from statistics and arguments.
329    ///
330    /// See [`AggregateUDFImpl::value_from_stats`] for more details.
331    pub fn value_from_stats(
332        &self,
333        statistics_args: &StatisticsArgs,
334    ) -> Option<ScalarValue> {
335        self.inner.value_from_stats(statistics_args)
336    }
337
338    /// See [`AggregateUDFImpl::default_value`] for more details.
339    pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
340        self.inner.default_value(data_type)
341    }
342
343    /// See [`AggregateUDFImpl::supports_null_handling_clause`] for more details.
344    pub fn supports_null_handling_clause(&self) -> bool {
345        self.inner.supports_null_handling_clause()
346    }
347
348    /// See [`AggregateUDFImpl::supports_within_group_clause`] for more details.
349    pub fn supports_within_group_clause(&self) -> bool {
350        self.inner.supports_within_group_clause()
351    }
352
353    /// Returns the documentation for this Aggregate UDF.
354    ///
355    /// Documentation can be accessed programmatically as well as
356    /// generating publicly facing documentation.
357    pub fn documentation(&self) -> Option<&Documentation> {
358        self.inner.documentation()
359    }
360}
361
362impl<F> From<F> for AggregateUDF
363where
364    F: AggregateUDFImpl + Send + Sync + 'static,
365{
366    fn from(fun: F) -> Self {
367        Self::new_from_impl(fun)
368    }
369}
370
371/// Trait for implementing [`AggregateUDF`].
372///
373/// This trait exposes the full API for implementing user defined aggregate functions and
374/// can be used to implement any function.
375///
376/// See [`advanced_udaf.rs`] for a full example with complete implementation and
377/// [`AggregateUDF`] for other available options.
378///
379/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udaf.rs
380///
381/// # Basic Example
382/// ```
383/// # use std::any::Any;
384/// # use std::sync::{Arc, LazyLock};
385/// # use arrow::datatypes::{DataType, FieldRef};
386/// # use datafusion_common::{DataFusionError, plan_err, Result};
387/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation};
388/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
389/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE;
390/// # use arrow::datatypes::Schema;
391/// # use arrow::datatypes::Field;
392///
393/// #[derive(Debug, Clone, PartialEq, Eq, Hash)]
394/// struct GeoMeanUdf {
395///   signature: Signature,
396/// }
397///
398/// impl GeoMeanUdf {
399///   fn new() -> Self {
400///     Self {
401///       signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
402///      }
403///   }
404/// }
405///
406/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
407///         Documentation::builder(DOC_SECTION_AGGREGATE, "calculates a geometric mean", "geo_mean(2.0)")
408///             .with_argument("arg1", "The Float64 number for the geometric mean")
409///             .build()
410///     });
411///
412/// fn get_doc() -> &'static Documentation {
413///     &DOCUMENTATION
414/// }
415///
416/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
417/// impl AggregateUDFImpl for GeoMeanUdf {
418///    fn name(&self) -> &str { "geo_mean" }
419///    fn signature(&self) -> &Signature { &self.signature }
420///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
421///      if !matches!(args.get(0), Some(&DataType::Float64)) {
422///        return plan_err!("geo_mean only accepts Float64 arguments");
423///      }
424///      Ok(DataType::Float64)
425///    }
426///    // This is the accumulator factory; DataFusion uses it to create new accumulators.
427///    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
428///    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
429///        Ok(vec![
430///             Arc::new(args.return_field.as_ref().clone().with_name("value")),
431///             Arc::new(Field::new("ordering", DataType::UInt32, true))
432///        ])
433///    }
434///    fn documentation(&self) -> Option<&Documentation> {
435///        Some(get_doc())
436///    }
437/// }
438///
439/// // Create a new AggregateUDF from the implementation
440/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
441///
442/// // Call the function `geo_mean(col)`
443/// let expr = geometric_mean.call(vec![col("a")]);
444/// ```
445pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any {
446    /// Returns this function's name
447    fn name(&self) -> &str;
448
449    /// Returns any aliases (alternate names) for this function.
450    ///
451    /// Note: `aliases` should only include names other than [`Self::name`].
452    /// Defaults to `[]` (no aliases)
453    fn aliases(&self) -> &[String] {
454        &[]
455    }
456
457    /// Returns the name of the column this expression would create
458    ///
459    /// See [`Expr::schema_name`] for details
460    ///
461    /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..]
462    fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
463        udaf_default_schema_name(self, params)
464    }
465
466    /// Returns a human readable expression.
467    ///
468    /// See [`Expr::human_display`] for details.
469    fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
470        udaf_default_human_display(self, params)
471    }
472
473    /// Returns the name of the column this expression would create
474    ///
475    /// See [`Expr::schema_name`] for details
476    ///
477    /// Different from `schema_name` in that it is used for window aggregate function
478    ///
479    /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) [PARTITION BY [..]] [ORDER BY [..]]
480    fn window_function_schema_name(
481        &self,
482        params: &WindowFunctionParams,
483    ) -> Result<String> {
484        udaf_default_window_function_schema_name(self, params)
485    }
486
487    /// Returns the user-defined display name of function, given the arguments
488    ///
489    /// This can be used to customize the output column name generated by this
490    /// function.
491    ///
492    /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]`
493    fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
494        udaf_default_display_name(self, params)
495    }
496
497    /// Returns the user-defined display name of function, given the arguments
498    ///
499    /// This can be used to customize the output column name generated by this
500    /// function.
501    ///
502    /// Different from `display_name` in that it is used for window aggregate function
503    ///
504    /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [partition by [..]] [order_by [..]]`
505    fn window_function_display_name(
506        &self,
507        params: &WindowFunctionParams,
508    ) -> Result<String> {
509        udaf_default_window_function_display_name(self, params)
510    }
511
512    /// Returns the function's [`Signature`] for information about what input
513    /// types are accepted and the function's Volatility.
514    fn signature(&self) -> &Signature;
515
516    /// What [`DataType`] will be returned by this function, given the types of
517    /// the arguments
518    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
519
520    /// What type will be returned by this function, given the arguments?
521    ///
522    /// By default, this function calls [`Self::return_type`] with the
523    /// types of each argument.
524    ///
525    /// # Notes
526    ///
527    /// Most UDFs should implement [`Self::return_type`] and not this
528    /// function as the output type for most functions only depends on the types
529    /// of their inputs (e.g. `sum(f64)` is always `f64`).
530    ///
531    /// This function can be used for more advanced cases such as:
532    ///
533    /// 1. specifying nullability
534    /// 2. return types based on the **values** of the arguments (rather than
535    ///    their **types**.
536    /// 3. return types based on metadata within the fields of the inputs
537    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
538        udaf_default_return_field(self, arg_fields)
539    }
540
541    /// Whether the aggregate function is nullable.
542    ///
543    /// Nullable means that the function could return `null` for any inputs.
544    /// For example, aggregate functions like `COUNT` always return a non null value
545    /// but others like `MIN` will return `NULL` if there is nullable input.
546    /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null`
547    fn is_nullable(&self) -> bool {
548        true
549    }
550
551    /// Return a new [`Accumulator`] that aggregates values for a specific
552    /// group during query execution.
553    ///
554    /// acc_args: [`AccumulatorArgs`] contains information about how the
555    /// aggregate function was called.
556    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
557
558    /// Return the fields used to store the intermediate state of this accumulator.
559    ///
560    /// See [`Accumulator::state`] for background information.
561    ///
562    /// args:  [`StateFieldsArgs`] contains arguments passed to the
563    /// aggregate function's accumulator.
564    ///
565    /// # Notes:
566    ///
567    /// The default implementation returns a single state field named `name`
568    /// with the same type as `value_type`. This is suitable for aggregates such
569    /// as `SUM` or `MIN` where partial state can be combined by applying the
570    /// same aggregate.
571    ///
572    /// For aggregates such as `AVG` where the partial state is more complex
573    /// (e.g. a COUNT and a SUM), this method is used to define the additional
574    /// fields.
575    ///
576    /// The name of the fields must be unique within the query and thus should
577    /// be derived from `name`. See [`format_state_name`] for a utility function
578    /// to generate a unique name.
579    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
580        let fields = vec![
581            args.return_field
582                .as_ref()
583                .clone()
584                .with_name(format_state_name(args.name, "value")),
585        ];
586
587        Ok(fields
588            .into_iter()
589            .map(Arc::new)
590            .chain(args.ordering_fields.to_vec())
591            .collect())
592    }
593
594    /// If the aggregate expression has a specialized
595    /// [`GroupsAccumulator`] implementation. If this returns true,
596    /// `[Self::create_groups_accumulator]` will be called.
597    ///
598    /// # Notes
599    ///
600    /// Even if this function returns true, DataFusion will still use
601    /// [`Self::accumulator`] for certain queries, such as when this aggregate is
602    /// used as a window function or when there no GROUP BY columns in the
603    /// query.
604    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
605        false
606    }
607
608    /// Return a specialized [`GroupsAccumulator`] that manages state
609    /// for all groups.
610    ///
611    /// For maximum performance, a [`GroupsAccumulator`] should be
612    /// implemented in addition to [`Accumulator`].
613    fn create_groups_accumulator(
614        &self,
615        _args: AccumulatorArgs,
616    ) -> Result<Box<dyn GroupsAccumulator>> {
617        not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
618    }
619
620    /// Sliding accumulator is an alternative accumulator that can be used for
621    /// window functions. It has retract method to revert the previous update.
622    ///
623    /// See [retract_batch] for more details.
624    ///
625    /// [retract_batch]: Accumulator::retract_batch
626    fn create_sliding_accumulator(
627        &self,
628        args: AccumulatorArgs,
629    ) -> Result<Box<dyn Accumulator>> {
630        self.accumulator(args)
631    }
632
633    /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is
634    /// satisfied by its input. If this is not the case, UDFs with order
635    /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
636    /// the correct result with possibly more work internally.
637    ///
638    /// # Returns
639    ///
640    /// Returns `Ok(Some(updated_udf))` if the process completes successfully.
641    /// If the expression can benefit from existing input ordering, but does
642    /// not implement the method, returns an error. Order insensitive and hard
643    /// requirement aggregators return `Ok(None)`.
644    fn with_beneficial_ordering(
645        self: Arc<Self>,
646        _beneficial_ordering: bool,
647    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
648        if self.order_sensitivity().is_beneficial() {
649            return exec_err!(
650                "Should implement with satisfied for aggregator :{:?}",
651                self.name()
652            );
653        }
654        Ok(None)
655    }
656
657    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
658    /// for possible options.
659    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
660        // We have hard ordering requirements by default, meaning that order
661        // sensitive UDFs need their input orderings to satisfy their ordering
662        // requirements to generate correct results.
663        AggregateOrderSensitivity::HardRequirement
664    }
665
666    /// Returns an optional hook for simplifying this user-defined aggregate.
667    ///
668    /// Use this hook to apply function-specific rewrites during optimization.
669    /// The default implementation returns `None`.
670    ///
671    /// For example, `percentile_cont(x, 0.0)` and `percentile_cont(x, 1.0)` can
672    /// be rewritten to `MIN(x)` or `MAX(x)` depending on the `ORDER BY`
673    /// direction.
674    ///
675    /// DataFusion already simplifies arguments and performs constant folding
676    /// (for example, `my_add(1, 2) -> 3`). For nested expressions, the optimizer
677    /// runs simplification in multiple passes, so arguments are typically
678    /// simplified before this hook is invoked. As a result, UDF implementations
679    /// usually do not need to handle argument simplification themselves.
680    ///
681    /// See configuration `datafusion.optimizer.max_passes` for details on how many
682    /// optimization passes may be applied.
683    ///
684    /// # Returns
685    ///
686    /// `None` if simplify is not defined.
687    ///
688    /// Or, a closure ([`AggregateFunctionSimplification`]) invoked with:
689    /// * `aggregate_function`: [AggregateFunction] with already simplified
690    ///   arguments
691    /// * `info`: [crate::simplify::SimplifyContext]
692    ///
693    /// The closure returns a simplified [Expr] or an error.
694    ///
695    /// # Notes
696    ///
697    /// The returned expression must have the same schema as the original
698    /// expression, including both the data type and nullability. For example,
699    /// if the original expression is nullable, the returned expression must
700    /// also be nullable, otherwise it may lead to schema verification errors
701    /// later in query planning.
702    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
703        None
704    }
705
706    /// Rewrite the aggregate to have simpler arguments
707    ///
708    /// This query pattern is not common in most real workloads, and most
709    /// aggregate implementations can safely ignore it. This API is included in
710    /// DataFusion because it is important for ClickBench Q29. See backstory
711    /// on <https://github.com/apache/datafusion/issues/15524>
712    ///
713    /// # Rewrite Overview
714    ///
715    /// The idea is to rewrite multiple aggregates with "complex arguments" into
716    /// ones with simpler arguments that can be optimized by common subexpression
717    /// elimination (CSE). At a high level the rewrite looks like
718    ///
719    /// * `Aggregate(SUM(x + 1), SUM(x + 2), ...)`
720    ///
721    /// Into
722    ///
723    /// * `Aggregate(SUM(x) + 1 * COUNT(x), SUM(x) + 2 * COUNT(x), ...)`
724    ///
725    /// While this rewrite may seem worse (slower) than the original as it
726    /// computes *more* aggregate expressions, the common subexpression
727    /// elimination (CSE) can then reduce the number of distinct aggregates the
728    /// query actually needs to compute with a rewrite like
729    ///
730    /// * `Projection(_A + 1*_B, _A + 2*_B)`
731    /// * `  Aggregate(_A = SUM(x), _B = COUNT(x))`
732    ///
733    /// This optimization is extremely important for ClickBench Q29, which has 90
734    /// such expressions for some reason, and so this optimization results in
735    /// only two aggregates being needed. The DataFusion optimizer will invoke
736    /// this method when it detects multiple aggregates in a query that share
737    /// arguments of the form `<arg> <op> <literal>`.
738    ///
739    /// # API
740    ///
741    /// If `agg_function` supports the rewrite, it should return a semantically
742    /// equivalent expression (likely with more aggregate expressions, but
743    /// simpler arguments)
744    ///
745    /// This is only called when:
746    /// 1. There are no "special" aggregate params (filters, null handling, etc)
747    /// 2. Aggregate functions with exactly one [`Expr`] argument
748    /// 3. There are no volatile expressions
749    ///
750    /// Arguments
751    /// * `agg_function`: the original aggregate function detected with complex
752    ///   arguments.
753    /// * `arg`: The common argument shared across multiple aggregates (e.g. `x`
754    ///   in the example above)
755    /// * `op`: the operator between the common argument and the literal (e.g.
756    ///   `+` in `x + 1` or `1 + x`)
757    /// * `lit`: the literal argument (e.g. `1` or `2` in the example above)
758    /// * `arg_is_left`: whether the common argument is on the left or right of
759    ///   the operator (e.g. `true` for `x + 1` and false for `1 + x`)
760    ///
761    /// The default implementation returns `None`, which is what most aggregates
762    /// should do.
763    fn simplify_expr_op_literal(
764        &self,
765        _agg_function: &AggregateFunction,
766        _arg: &Expr,
767        _op: Operator,
768        _lit: &Expr,
769        _arg_is_left: bool,
770    ) -> Result<Option<Expr>> {
771        Ok(None)
772    }
773
774    /// Returns the reverse expression of the aggregate function.
775    fn reverse_expr(&self) -> ReversedUDAF {
776        ReversedUDAF::NotSupported
777    }
778
779    /// Coerce arguments of a function call to types that the function can evaluate.
780    ///
781    /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
782    /// UDAFs should return one of the other variants of `TypeSignature` which handle common
783    /// cases
784    ///
785    /// See the [type coercion module](crate::type_coercion)
786    /// documentation for more details on type coercion
787    ///
788    /// For example, if your function requires a floating point arguments, but the user calls
789    /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
790    /// to ensure the argument was cast to `1::double`
791    ///
792    /// # Parameters
793    /// * `arg_types`: The argument types of the arguments  this function with
794    ///
795    /// # Return value
796    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
797    /// arguments to these specific types.
798    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
799        not_impl_err!("Function {} does not implement coerce_types", self.name())
800    }
801
802    /// If this function is max, return true
803    /// If the function is min, return false
804    /// Otherwise return None (the default)
805    ///
806    ///
807    /// Note: this is used to use special aggregate implementations in certain conditions
808    fn is_descending(&self) -> Option<bool> {
809        None
810    }
811
812    /// Return the value of this aggregate function if it can be determined
813    /// entirely from statistics and arguments.
814    ///
815    /// Using a [`ScalarValue`] rather than a runtime computation can significantly
816    /// improving query performance.
817    ///
818    /// For example, if the minimum value of column `x` is known to be `42` from
819    /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`
820    fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
821        None
822    }
823
824    /// Returns default value of the function given the input is all `null`.
825    ///
826    /// Most of the aggregate function return Null if input is Null,
827    /// while `count` returns 0 if input is Null
828    fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
829        ScalarValue::try_from(data_type)
830    }
831
832    /// If this function supports `[IGNORE NULLS | RESPECT NULLS]` SQL clause,
833    /// return `true`. Otherwise, return `false` which will cause an error to be
834    /// raised during SQL parsing if these clauses are detected for this function.
835    ///
836    /// Functions which implement this as `true` are expected to handle the resulting
837    /// null handling config present in [`AccumulatorArgs`], `ignore_nulls`.
838    fn supports_null_handling_clause(&self) -> bool {
839        false
840    }
841
842    /// If this function supports the `WITHIN GROUP (ORDER BY column [ASC|DESC])`
843    /// SQL syntax, return `true`. Otherwise, return `false` (default) which will
844    /// cause an error when parsing SQL where this syntax is detected for this
845    /// function.
846    ///
847    /// This function should return `true` for ordered-set aggregate functions
848    /// only.
849    ///
850    /// # Ordered-set aggregate functions
851    ///
852    /// Ordered-set aggregate functions allow specifying a sort order that affects
853    /// how the function calculates its result, unlike other aggregate functions
854    /// like `sum` or `count`. For example, `percentile_cont` is an ordered-set
855    /// aggregate function that calculates the exact percentile value from a list
856    /// of values; the output of calculating the `0.75` percentile depends on if
857    /// you're calculating on an ascending or descending list of values.
858    ///
859    /// An example of how an ordered-set aggregate function is called with the
860    /// `WITHIN GROUP` SQL syntax:
861    ///
862    /// ```sql
863    /// -- Ascending
864    /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 ASC) FROM table;
865    /// -- Default ordering is ascending if not explicitly specified
866    /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1) FROM table;
867    /// -- Descending
868    /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 DESC) FROM table;
869    /// ```
870    ///
871    /// This calculates the `0.75` percentile of the column `c1` from `table`,
872    /// according to the specific ordering. The column specified in the `WITHIN GROUP`
873    /// ordering clause is taken as the column to calculate values on; specifying
874    /// the `WITHIN GROUP` clause is optional so these queries are equivalent:
875    ///
876    /// ```sql
877    /// -- If no WITHIN GROUP is specified then default ordering is implementation
878    /// -- dependent; in this case ascending for percentile_cont
879    /// SELECT percentile_cont(c1, 0.75) FROM table;
880    /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 ASC) FROM table;
881    /// ```
882    ///
883    /// Aggregate UDFs can define their default ordering if the function is called
884    /// without the `WITHIN GROUP` clause, though a default of ascending is the
885    /// standard practice.
886    ///
887    /// Ordered-set aggregate function implementations are responsible for handling
888    /// the input sort order themselves (e.g. `percentile_cont` must buffer and
889    /// sort the values internally). That is, DataFusion does not introduce any
890    /// kind of sort into the plan for these functions with this syntax.
891    fn supports_within_group_clause(&self) -> bool {
892        false
893    }
894
895    /// Returns the documentation for this Aggregate UDF.
896    ///
897    /// Documentation can be accessed programmatically as well as
898    /// generating publicly facing documentation.
899    fn documentation(&self) -> Option<&Documentation> {
900        None
901    }
902
903    /// Indicates whether the aggregation function is monotonic as a set
904    /// function. See [`SetMonotonicity`] for details.
905    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
906        SetMonotonicity::NotMonotonic
907    }
908}
909
910impl dyn AggregateUDFImpl {
911    /// Returns `true` if the implementation is of type `T`.
912    ///
913    /// Works correctly when called on `Arc<dyn AggregateUDFImpl>` via auto-deref.
914    pub fn is<T: AggregateUDFImpl>(&self) -> bool {
915        (self as &dyn Any).is::<T>()
916    }
917
918    /// Attempts to downcast to a concrete type `T`, returning `None` if the
919    /// implementation is not of that type.
920    ///
921    /// Works correctly when called on `Arc<dyn AggregateUDFImpl>` via auto-deref,
922    /// unlike `(&arc as &dyn Any).downcast_ref::<T>()` which would attempt to
923    /// downcast the `Arc` itself.
924    pub fn downcast_ref<T: AggregateUDFImpl>(&self) -> Option<&T> {
925        (self as &dyn Any).downcast_ref()
926    }
927}
928
929impl PartialEq for dyn AggregateUDFImpl {
930    fn eq(&self, other: &Self) -> bool {
931        self.dyn_eq(other)
932    }
933}
934
935impl PartialOrd for dyn AggregateUDFImpl {
936    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
937        match self.name().partial_cmp(other.name()) {
938            Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
939            cmp => cmp,
940        }
941        // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
942        .filter(|cmp| *cmp != Ordering::Equal || self == other)
943    }
944}
945
946/// Encapsulates default implementation of [`AggregateUDFImpl::schema_name`].
947pub fn udaf_default_schema_name<F: AggregateUDFImpl + ?Sized>(
948    func: &F,
949    params: &AggregateFunctionParams,
950) -> Result<String> {
951    let AggregateFunctionParams {
952        args,
953        distinct,
954        filter,
955        order_by,
956        null_treatment,
957    } = params;
958
959    // exclude the first function argument(= column) in ordered set aggregate function,
960    // because it is duplicated with the WITHIN GROUP clause in schema name.
961    let args = if func.supports_within_group_clause() && !order_by.is_empty() {
962        &args[1..]
963    } else {
964        &args[..]
965    };
966
967    let mut schema_name = String::new();
968
969    schema_name.write_fmt(format_args!(
970        "{}({}{})",
971        func.name(),
972        if *distinct { "DISTINCT " } else { "" },
973        schema_name_from_exprs_comma_separated_without_space(args)?
974    ))?;
975
976    if let Some(null_treatment) = null_treatment {
977        schema_name.write_fmt(format_args!(" {null_treatment}"))?;
978    }
979
980    if let Some(filter) = filter {
981        schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
982    };
983
984    if !order_by.is_empty() {
985        let clause = match func.supports_within_group_clause() {
986            true => "WITHIN GROUP",
987            false => "ORDER BY",
988        };
989
990        schema_name.write_fmt(format_args!(
991            " {} [{}]",
992            clause,
993            schema_name_from_sorts(order_by)?
994        ))?;
995    };
996
997    Ok(schema_name)
998}
999
1000/// Encapsulates default implementation of [`AggregateUDFImpl::human_display`].
1001pub fn udaf_default_human_display<F: AggregateUDFImpl + ?Sized>(
1002    func: &F,
1003    params: &AggregateFunctionParams,
1004) -> Result<String> {
1005    let AggregateFunctionParams {
1006        args,
1007        distinct,
1008        filter,
1009        order_by,
1010        null_treatment,
1011    } = params;
1012
1013    let mut schema_name = String::new();
1014
1015    schema_name.write_fmt(format_args!(
1016        "{}({}{})",
1017        func.name(),
1018        if *distinct { "DISTINCT " } else { "" },
1019        ExprListDisplay::comma_separated(args.as_slice())
1020    ))?;
1021
1022    if let Some(null_treatment) = null_treatment {
1023        schema_name.write_fmt(format_args!(" {null_treatment}"))?;
1024    }
1025
1026    if let Some(filter) = filter {
1027        schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
1028    };
1029
1030    if !order_by.is_empty() {
1031        schema_name.write_fmt(format_args!(
1032            " ORDER BY [{}]",
1033            schema_name_from_sorts(order_by)?
1034        ))?;
1035    };
1036
1037    Ok(schema_name)
1038}
1039
1040/// Encapsulates default implementation of [`AggregateUDFImpl::window_function_schema_name`].
1041pub fn udaf_default_window_function_schema_name<F: AggregateUDFImpl + ?Sized>(
1042    func: &F,
1043    params: &WindowFunctionParams,
1044) -> Result<String> {
1045    let WindowFunctionParams {
1046        args,
1047        partition_by,
1048        order_by,
1049        window_frame,
1050        filter,
1051        null_treatment,
1052        distinct,
1053    } = params;
1054
1055    let mut schema_name = String::new();
1056
1057    // Inject DISTINCT into the schema name when requested
1058    if *distinct {
1059        schema_name.write_fmt(format_args!(
1060            "{}(DISTINCT {})",
1061            func.name(),
1062            schema_name_from_exprs(args)?
1063        ))?;
1064    } else {
1065        schema_name.write_fmt(format_args!(
1066            "{}({})",
1067            func.name(),
1068            schema_name_from_exprs(args)?
1069        ))?;
1070    }
1071
1072    if let Some(null_treatment) = null_treatment {
1073        schema_name.write_fmt(format_args!(" {null_treatment}"))?;
1074    }
1075
1076    if let Some(filter) = filter {
1077        schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
1078    }
1079
1080    if !partition_by.is_empty() {
1081        schema_name.write_fmt(format_args!(
1082            " PARTITION BY [{}]",
1083            schema_name_from_exprs(partition_by)?
1084        ))?;
1085    }
1086
1087    if !order_by.is_empty() {
1088        schema_name.write_fmt(format_args!(
1089            " ORDER BY [{}]",
1090            schema_name_from_sorts(order_by)?
1091        ))?;
1092    }
1093
1094    schema_name.write_fmt(format_args!(" {window_frame}"))?;
1095
1096    Ok(schema_name)
1097}
1098
1099/// Encapsulates default implementation of [`AggregateUDFImpl::display_name`].
1100pub fn udaf_default_display_name<F: AggregateUDFImpl + ?Sized>(
1101    func: &F,
1102    params: &AggregateFunctionParams,
1103) -> Result<String> {
1104    let AggregateFunctionParams {
1105        args,
1106        distinct,
1107        filter,
1108        order_by,
1109        null_treatment,
1110    } = params;
1111
1112    let mut display_name = String::new();
1113
1114    display_name.write_fmt(format_args!(
1115        "{}({}{})",
1116        func.name(),
1117        if *distinct { "DISTINCT " } else { "" },
1118        expr_vec_fmt!(args)
1119    ))?;
1120
1121    if let Some(nt) = null_treatment {
1122        display_name.write_fmt(format_args!(" {nt}"))?;
1123    }
1124    if let Some(fe) = filter {
1125        display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
1126    }
1127    if !order_by.is_empty() {
1128        display_name.write_fmt(format_args!(
1129            " ORDER BY [{}]",
1130            order_by
1131                .iter()
1132                .map(|o| format!("{o}"))
1133                .collect::<Vec<String>>()
1134                .join(", ")
1135        ))?;
1136    }
1137
1138    Ok(display_name)
1139}
1140
1141/// Encapsulates default implementation of [`AggregateUDFImpl::window_function_display_name`].
1142pub fn udaf_default_window_function_display_name<F: AggregateUDFImpl + ?Sized>(
1143    func: &F,
1144    params: &WindowFunctionParams,
1145) -> Result<String> {
1146    let WindowFunctionParams {
1147        args,
1148        partition_by,
1149        order_by,
1150        window_frame,
1151        filter,
1152        null_treatment,
1153        distinct,
1154    } = params;
1155
1156    let mut display_name = String::new();
1157
1158    if *distinct {
1159        display_name.write_fmt(format_args!(
1160            "{}(DISTINCT {})",
1161            func.name(),
1162            expr_vec_fmt!(args)
1163        ))?;
1164    } else {
1165        display_name.write_fmt(format_args!(
1166            "{}({})",
1167            func.name(),
1168            expr_vec_fmt!(args)
1169        ))?;
1170    }
1171
1172    if let Some(null_treatment) = null_treatment {
1173        display_name.write_fmt(format_args!(" {null_treatment}"))?;
1174    }
1175
1176    if let Some(fe) = filter {
1177        display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
1178    }
1179
1180    if !partition_by.is_empty() {
1181        display_name.write_fmt(format_args!(
1182            " PARTITION BY [{}]",
1183            expr_vec_fmt!(partition_by)
1184        ))?;
1185    }
1186
1187    if !order_by.is_empty() {
1188        display_name
1189            .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
1190    };
1191
1192    display_name.write_fmt(format_args!(
1193        " {} BETWEEN {} AND {}",
1194        window_frame.units, window_frame.start_bound, window_frame.end_bound
1195    ))?;
1196
1197    Ok(display_name)
1198}
1199
1200/// Encapsulates default implementation of [`AggregateUDFImpl::return_field`].
1201pub fn udaf_default_return_field<F: AggregateUDFImpl + ?Sized>(
1202    func: &F,
1203    arg_fields: &[FieldRef],
1204) -> Result<FieldRef> {
1205    let arg_types: Vec<_> = arg_fields.iter().map(|f| f.data_type()).cloned().collect();
1206    let data_type = func.return_type(&arg_types)?;
1207
1208    Ok(Arc::new(Field::new(
1209        func.name(),
1210        data_type,
1211        func.is_nullable(),
1212    )))
1213}
1214
1215pub enum ReversedUDAF {
1216    /// The expression is the same as the original expression, like SUM, COUNT
1217    Identical,
1218    /// The expression does not support reverse calculation
1219    NotSupported,
1220    /// The expression is different from the original expression
1221    Reversed(Arc<AggregateUDF>),
1222}
1223
1224/// AggregateUDF that adds an alias to the underlying function. It is better to
1225/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
1226#[derive(Debug, PartialEq, Eq, Hash)]
1227struct AliasedAggregateUDFImpl {
1228    inner: UdfEq<Arc<dyn AggregateUDFImpl>>,
1229    aliases: Vec<String>,
1230}
1231
1232impl AliasedAggregateUDFImpl {
1233    pub fn new(
1234        inner: Arc<dyn AggregateUDFImpl>,
1235        new_aliases: impl IntoIterator<Item = &'static str>,
1236    ) -> Self {
1237        let mut aliases = inner.aliases().to_vec();
1238        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1239
1240        Self {
1241            inner: inner.into(),
1242            aliases,
1243        }
1244    }
1245}
1246
1247#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method
1248impl AggregateUDFImpl for AliasedAggregateUDFImpl {
1249    fn name(&self) -> &str {
1250        self.inner.name()
1251    }
1252
1253    fn signature(&self) -> &Signature {
1254        self.inner.signature()
1255    }
1256
1257    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1258        self.inner.return_type(arg_types)
1259    }
1260
1261    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1262        self.inner.accumulator(acc_args)
1263    }
1264
1265    fn aliases(&self) -> &[String] {
1266        &self.aliases
1267    }
1268
1269    fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
1270        self.inner.schema_name(params)
1271    }
1272
1273    fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
1274        self.inner.human_display(params)
1275    }
1276
1277    fn window_function_schema_name(
1278        &self,
1279        params: &WindowFunctionParams,
1280    ) -> Result<String> {
1281        self.inner.window_function_schema_name(params)
1282    }
1283
1284    fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
1285        self.inner.display_name(params)
1286    }
1287
1288    fn window_function_display_name(
1289        &self,
1290        params: &WindowFunctionParams,
1291    ) -> Result<String> {
1292        self.inner.window_function_display_name(params)
1293    }
1294
1295    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
1296        self.inner.state_fields(args)
1297    }
1298
1299    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1300        self.inner.groups_accumulator_supported(args)
1301    }
1302
1303    fn create_groups_accumulator(
1304        &self,
1305        args: AccumulatorArgs,
1306    ) -> Result<Box<dyn GroupsAccumulator>> {
1307        self.inner.create_groups_accumulator(args)
1308    }
1309
1310    fn create_sliding_accumulator(
1311        &self,
1312        args: AccumulatorArgs,
1313    ) -> Result<Box<dyn Accumulator>> {
1314        self.inner.accumulator(args)
1315    }
1316
1317    fn with_beneficial_ordering(
1318        self: Arc<Self>,
1319        beneficial_ordering: bool,
1320    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1321        Arc::clone(&self.inner)
1322            .with_beneficial_ordering(beneficial_ordering)
1323            .map(|udf| {
1324                udf.map(|udf| {
1325                    Arc::new(AliasedAggregateUDFImpl {
1326                        inner: udf.into(),
1327                        aliases: self.aliases.clone(),
1328                    }) as Arc<dyn AggregateUDFImpl>
1329                })
1330            })
1331    }
1332
1333    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1334        self.inner.order_sensitivity()
1335    }
1336
1337    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
1338        self.inner.simplify()
1339    }
1340
1341    fn simplify_expr_op_literal(
1342        &self,
1343        agg_function: &AggregateFunction,
1344        arg: &Expr,
1345        op: Operator,
1346        lit: &Expr,
1347        arg_is_left: bool,
1348    ) -> Result<Option<Expr>> {
1349        self.inner
1350            .simplify_expr_op_literal(agg_function, arg, op, lit, arg_is_left)
1351    }
1352
1353    fn reverse_expr(&self) -> ReversedUDAF {
1354        self.inner.reverse_expr()
1355    }
1356
1357    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1358        self.inner.coerce_types(arg_types)
1359    }
1360
1361    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
1362        self.inner.return_field(arg_fields)
1363    }
1364
1365    fn is_nullable(&self) -> bool {
1366        self.inner.is_nullable()
1367    }
1368
1369    fn is_descending(&self) -> Option<bool> {
1370        self.inner.is_descending()
1371    }
1372
1373    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
1374        self.inner.value_from_stats(statistics_args)
1375    }
1376
1377    fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
1378        self.inner.default_value(data_type)
1379    }
1380
1381    fn supports_null_handling_clause(&self) -> bool {
1382        self.inner.supports_null_handling_clause()
1383    }
1384
1385    fn supports_within_group_clause(&self) -> bool {
1386        self.inner.supports_within_group_clause()
1387    }
1388
1389    fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
1390        self.inner.set_monotonicity(data_type)
1391    }
1392
1393    fn documentation(&self) -> Option<&Documentation> {
1394        self.inner.documentation()
1395    }
1396}
1397
/// Indicates whether an aggregation function is monotonic as a set
/// function. A set function is monotonically increasing if its value
/// increases as its argument grows (as a set). Formally, `f` is a
/// monotonically increasing set function if `f(S) >= f(T)` whenever `S`
/// is a superset of `T`, and monotonically decreasing if `f(S) <= f(T)`
/// under the same condition.
///
/// For example `COUNT` and `MAX` are monotonically increasing as their
/// values always increase (or stay the same) as new values are seen. On
/// the other hand, `MIN` is monotonically decreasing as its value always
/// decreases or stays the same as new values are seen.
// `Eq` and `Hash` are derived so the (fieldless) enum can be used as a map or
// set key and compared with `assert_eq!` in `Eq`-bounded contexts.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SetMonotonicity {
    /// Aggregate value increases or stays the same as the input set grows.
    Increasing,
    /// Aggregate value decreases or stays the same as the input set grows.
    Decreasing,
    /// Aggregate value may increase, decrease, or stay the same as the input
    /// set grows.
    NotMonotonic,
}
1418
#[cfg(test)]
mod test {
    use crate::{AggregateUDF, AggregateUDFImpl};
    use arrow::datatypes::{DataType, FieldRef};
    use datafusion_common::Result;
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use datafusion_functions_aggregate_common::accumulator::{
        AccumulatorArgs, StateFieldsArgs,
    };
    use std::cmp::Ordering;
    use std::hash::{DefaultHasher, Hash, Hasher};

    /// Signature shared by the test UDAFs: a single `Float64` argument.
    fn float64_unary_signature() -> Signature {
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable)
    }

    /// Minimal UDAF named `"a"`; only its name and signature are exercised.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct AMeanUdf {
        signature: Signature,
    }

    impl AMeanUdf {
        fn new() -> Self {
            Self {
                signature: float64_unary_signature(),
            }
        }
    }

    impl AggregateUDFImpl for AMeanUdf {
        fn name(&self) -> &str {
            "a"
        }
        fn signature(&self) -> &Signature {
            &self.signature
        }
        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }
        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }
        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    /// Minimal UDAF named `"b"`; sorts after [`AMeanUdf`] by name.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct BMeanUdf {
        signature: Signature,
    }

    impl BMeanUdf {
        fn new() -> Self {
            Self {
                signature: float64_unary_signature(),
            }
        }
    }

    impl AggregateUDFImpl for BMeanUdf {
        fn name(&self) -> &str {
            "b"
        }
        fn signature(&self) -> &Signature {
            &self.signature
        }
        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }
        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }
        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_partial_eq() {
        let first = AggregateUDF::from(AMeanUdf::new());
        let second = AggregateUDF::from(AMeanUdf::new());

        let equal = first == second;
        assert!(equal);
        assert_eq!(first, second);
        assert_eq!(hash_of(first), hash_of(second));
    }

    #[test]
    fn test_partial_ord() {
        // Smoke test that PartialOrd for AggregateUDF uses name and signature;
        // it does not try to cover every combination.
        let a_first = AggregateUDF::from(AMeanUdf::new());
        let a_second = AggregateUDF::from(AMeanUdf::new());
        assert_eq!(a_first.partial_cmp(&a_second), Some(Ordering::Equal));

        let b = AggregateUDF::from(BMeanUdf::new());
        assert!(a_first < b);
        assert_ne!(a_first, b);
    }

    /// Hashes `value` with a fresh `DefaultHasher`.
    fn hash_of<T: Hash>(value: T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }
}