datafusion_expr/
udaf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`AggregateUDF`]: User Defined Aggregate Functions
19
20use std::any::Any;
21use std::cmp::Ordering;
22use std::fmt::{self, Debug, Formatter, Write};
23use std::hash::{Hash, Hasher};
24use std::sync::Arc;
25use std::vec;
26
27use arrow::datatypes::{DataType, Field, FieldRef};
28
29use datafusion_common::{Result, ScalarValue, Statistics, exec_err, not_impl_err};
30use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
31use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
32
33use crate::expr::{
34    AggregateFunction, AggregateFunctionParams, ExprListDisplay, WindowFunctionParams,
35    schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space,
36    schema_name_from_sorts,
37};
38use crate::function::{
39    AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
40};
41use crate::groups_accumulator::GroupsAccumulator;
42use crate::udf_eq::UdfEq;
43use crate::utils::AggregateOrderSensitivity;
44use crate::utils::format_state_name;
45use crate::{Accumulator, Expr, expr_vec_fmt};
46use crate::{Documentation, Signature};
47
48/// Logical representation of a user-defined [aggregate function] (UDAF).
49///
50/// An aggregate function combines the values from multiple input rows
51/// into a single output "aggregate" (summary) row. It is different
52/// from a scalar function because it is stateful across batches. User
53/// defined aggregate functions can be used as normal SQL aggregate
54/// functions (`GROUP BY` clause) as well as window functions (`OVER`
55/// clause).
56///
57/// `AggregateUDF` provides DataFusion the information needed to plan and call
58/// aggregate functions, including name, type information, and a factory
59/// function to create an [`Accumulator`] instance, to perform the actual
60/// aggregation.
61///
62/// For more information, please see [the examples]:
63///
64/// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]).
65///
66/// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API
67///    access (examples in [`advanced_udaf.rs`]).
68///
69/// # API Note
70/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
71/// compatibility with the older API.
72///
73/// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process
74/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
75/// [`Accumulator`]: Accumulator
76/// [`create_udaf`]: crate::expr_fn::create_udaf
77/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/simple_udaf.rs
78/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udaf.rs
79#[derive(Debug, Clone, PartialOrd)]
80pub struct AggregateUDF {
81    inner: Arc<dyn AggregateUDFImpl>,
82}
83
84impl PartialEq for AggregateUDF {
85    fn eq(&self, other: &Self) -> bool {
86        self.inner.dyn_eq(other.inner.as_any())
87    }
88}
89
90impl Eq for AggregateUDF {}
91
92impl Hash for AggregateUDF {
93    fn hash<H: Hasher>(&self, state: &mut H) {
94        self.inner.dyn_hash(state)
95    }
96}
97
98impl fmt::Display for AggregateUDF {
99    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
100        write!(f, "{}", self.name())
101    }
102}
103
/// Arguments passed to [`AggregateUDFImpl::value_from_stats`].
///
/// Every reference field borrows its data for the lifetime `'a`.
#[derive(Debug)]
pub struct StatisticsArgs<'a> {
    /// The statistics of the aggregate input.
    pub statistics: &'a Statistics,
    /// The resolved return type of the aggregate function.
    pub return_type: &'a DataType,
    /// Whether the aggregate function is distinct.
    ///
    /// For example, this is a distinct aggregate:
    ///
    /// ```sql
    /// SELECT COUNT(DISTINCT column1) FROM t;
    /// ```
    pub is_distinct: bool,
    /// The physical expressions of the arguments the aggregate function takes.
    pub exprs: &'a [Arc<dyn PhysicalExpr>],
}
120
121impl AggregateUDF {
122    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
123    ///
124    /// Note this is the same as using the `From` impl (`AggregateUDF::from`)
125    pub fn new_from_impl<F>(fun: F) -> AggregateUDF
126    where
127        F: AggregateUDFImpl + 'static,
128    {
129        Self::new_from_shared_impl(Arc::new(fun))
130    }
131
132    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
133    pub fn new_from_shared_impl(fun: Arc<dyn AggregateUDFImpl>) -> AggregateUDF {
134        Self { inner: fun }
135    }
136
137    /// Return the underlying [`AggregateUDFImpl`] trait object for this function
138    pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
139        &self.inner
140    }
141
142    /// Adds additional names that can be used to invoke this function, in
143    /// addition to `name`
144    ///
145    /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
146    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
147        Self::new_from_impl(AliasedAggregateUDFImpl::new(
148            Arc::clone(&self.inner),
149            aliases,
150        ))
151    }
152
153    /// Creates an [`Expr`] that calls the aggregate function.
154    ///
155    /// This utility allows using the UDAF without requiring access to
156    /// the registry, such as with the DataFrame API.
157    pub fn call(&self, args: Vec<Expr>) -> Expr {
158        Expr::AggregateFunction(AggregateFunction::new_udf(
159            Arc::new(self.clone()),
160            args,
161            false,
162            None,
163            vec![],
164            None,
165        ))
166    }
167
168    /// Returns this function's name
169    ///
170    /// See [`AggregateUDFImpl::name`] for more details.
171    pub fn name(&self) -> &str {
172        self.inner.name()
173    }
174
175    /// Returns the aliases for this function.
176    pub fn aliases(&self) -> &[String] {
177        self.inner.aliases()
178    }
179
180    /// See [`AggregateUDFImpl::schema_name`] for more details.
181    pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
182        self.inner.schema_name(params)
183    }
184
185    /// Returns a human readable expression.
186    ///
187    /// See [`Expr::human_display`] for details.
188    pub fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
189        self.inner.human_display(params)
190    }
191
192    pub fn window_function_schema_name(
193        &self,
194        params: &WindowFunctionParams,
195    ) -> Result<String> {
196        self.inner.window_function_schema_name(params)
197    }
198
199    /// See [`AggregateUDFImpl::display_name`] for more details.
200    pub fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
201        self.inner.display_name(params)
202    }
203
204    pub fn window_function_display_name(
205        &self,
206        params: &WindowFunctionParams,
207    ) -> Result<String> {
208        self.inner.window_function_display_name(params)
209    }
210
211    pub fn is_nullable(&self) -> bool {
212        self.inner.is_nullable()
213    }
214
215    /// Returns this function's signature (what input types are accepted)
216    ///
217    /// See [`AggregateUDFImpl::signature`] for more details.
218    pub fn signature(&self) -> &Signature {
219        self.inner.signature()
220    }
221
222    /// Return the type of the function given its input types
223    ///
224    /// See [`AggregateUDFImpl::return_type`] for more details.
225    pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
226        self.inner.return_type(args)
227    }
228
229    /// Return the field of the function given its input fields
230    ///
231    /// See [`AggregateUDFImpl::return_field`] for more details.
232    pub fn return_field(&self, args: &[FieldRef]) -> Result<FieldRef> {
233        self.inner.return_field(args)
234    }
235
236    /// Return an accumulator the given aggregate, given its return datatype
237    pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
238        self.inner.accumulator(acc_args)
239    }
240
241    /// Return the fields used to store the intermediate state for this aggregator, given
242    /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
243    /// for more details.
244    ///
245    /// This is used to support multi-phase aggregations
246    pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
247        self.inner.state_fields(args)
248    }
249
250    /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
251    pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
252        self.inner.groups_accumulator_supported(args)
253    }
254
255    /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
256    pub fn create_groups_accumulator(
257        &self,
258        args: AccumulatorArgs,
259    ) -> Result<Box<dyn GroupsAccumulator>> {
260        self.inner.create_groups_accumulator(args)
261    }
262
263    pub fn create_sliding_accumulator(
264        &self,
265        args: AccumulatorArgs,
266    ) -> Result<Box<dyn Accumulator>> {
267        self.inner.create_sliding_accumulator(args)
268    }
269
270    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
271        self.inner.coerce_types(arg_types)
272    }
273
274    /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details.
275    pub fn with_beneficial_ordering(
276        self,
277        beneficial_ordering: bool,
278    ) -> Result<Option<AggregateUDF>> {
279        self.inner
280            .with_beneficial_ordering(beneficial_ordering)
281            .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf }))
282    }
283
284    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
285    /// for possible options.
286    pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
287        self.inner.order_sensitivity()
288    }
289
290    /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will
291    /// generate same result with this `AggregateUDF` when iterated in reverse
292    /// order, and `None` if there is no such `AggregateUDF`).
293    pub fn reverse_udf(&self) -> ReversedUDAF {
294        self.inner.reverse_expr()
295    }
296
297    /// Do the function rewrite
298    ///
299    /// See [`AggregateUDFImpl::simplify`] for more details.
300    pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
301        self.inner.simplify()
302    }
303
304    /// Returns true if the function is max, false if the function is min
305    /// None in all other cases, used in certain optimizations for
306    /// or aggregate
307    pub fn is_descending(&self) -> Option<bool> {
308        self.inner.is_descending()
309    }
310
311    /// Return the value of this aggregate function if it can be determined
312    /// entirely from statistics and arguments.
313    ///
314    /// See [`AggregateUDFImpl::value_from_stats`] for more details.
315    pub fn value_from_stats(
316        &self,
317        statistics_args: &StatisticsArgs,
318    ) -> Option<ScalarValue> {
319        self.inner.value_from_stats(statistics_args)
320    }
321
322    /// See [`AggregateUDFImpl::default_value`] for more details.
323    pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
324        self.inner.default_value(data_type)
325    }
326
327    /// See [`AggregateUDFImpl::supports_null_handling_clause`] for more details.
328    pub fn supports_null_handling_clause(&self) -> bool {
329        self.inner.supports_null_handling_clause()
330    }
331
332    /// See [`AggregateUDFImpl::supports_within_group_clause`] for more details.
333    pub fn supports_within_group_clause(&self) -> bool {
334        self.inner.supports_within_group_clause()
335    }
336
337    /// Returns the documentation for this Aggregate UDF.
338    ///
339    /// Documentation can be accessed programmatically as well as
340    /// generating publicly facing documentation.
341    pub fn documentation(&self) -> Option<&Documentation> {
342        self.inner.documentation()
343    }
344}
345
346impl<F> From<F> for AggregateUDF
347where
348    F: AggregateUDFImpl + Send + Sync + 'static,
349{
350    fn from(fun: F) -> Self {
351        Self::new_from_impl(fun)
352    }
353}
354
355/// Trait for implementing [`AggregateUDF`].
356///
357/// This trait exposes the full API for implementing user defined aggregate functions and
358/// can be used to implement any function.
359///
360/// See [`advanced_udaf.rs`] for a full example with complete implementation and
361/// [`AggregateUDF`] for other available options.
362///
363/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udaf.rs
364///
365/// # Basic Example
366/// ```
367/// # use std::any::Any;
368/// # use std::sync::{Arc, LazyLock};
369/// # use arrow::datatypes::{DataType, FieldRef};
370/// # use datafusion_common::{DataFusionError, plan_err, Result};
371/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation};
372/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
373/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE;
374/// # use arrow::datatypes::Schema;
375/// # use arrow::datatypes::Field;
376///
377/// #[derive(Debug, Clone, PartialEq, Eq, Hash)]
378/// struct GeoMeanUdf {
379///   signature: Signature,
380/// }
381///
382/// impl GeoMeanUdf {
383///   fn new() -> Self {
384///     Self {
385///       signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
386///      }
387///   }
388/// }
389///
390/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
391///         Documentation::builder(DOC_SECTION_AGGREGATE, "calculates a geometric mean", "geo_mean(2.0)")
392///             .with_argument("arg1", "The Float64 number for the geometric mean")
393///             .build()
394///     });
395///
396/// fn get_doc() -> &'static Documentation {
397///     &DOCUMENTATION
398/// }
399///
400/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
401/// impl AggregateUDFImpl for GeoMeanUdf {
402///    fn as_any(&self) -> &dyn Any { self }
403///    fn name(&self) -> &str { "geo_mean" }
404///    fn signature(&self) -> &Signature { &self.signature }
405///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
406///      if !matches!(args.get(0), Some(&DataType::Float64)) {
407///        return plan_err!("geo_mean only accepts Float64 arguments");
408///      }
409///      Ok(DataType::Float64)
410///    }
411///    // This is the accumulator factory; DataFusion uses it to create new accumulators.
412///    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
413///    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
414///        Ok(vec![
415///             Arc::new(args.return_field.as_ref().clone().with_name("value")),
416///             Arc::new(Field::new("ordering", DataType::UInt32, true))
417///        ])
418///    }
419///    fn documentation(&self) -> Option<&Documentation> {
420///        Some(get_doc())
421///    }
422/// }
423///
424/// // Create a new AggregateUDF from the implementation
425/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
426///
427/// // Call the function `geo_mean(col)`
428/// let expr = geometric_mean.call(vec![col("a")]);
429/// ```
430pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync {
    /// Returns this object as an [`Any`] trait object.
    ///
    /// [`AggregateUDF`]'s `PartialEq` impl passes this value to `dyn_eq`
    /// when comparing two functions.
    fn as_any(&self) -> &dyn Any;
433
    /// Returns this function's name.
    ///
    /// This is the text written by [`AggregateUDF`]'s `Display` impl.
    fn name(&self) -> &str;
436
    /// Returns any aliases (alternate names) for this function.
    ///
    /// Note: `aliases` should only include names other than [`Self::name`].
    ///
    /// Defaults to `[]` (no aliases).
    fn aliases(&self) -> &[String] {
        &[]
    }
444
    /// Returns the name of the column this expression would create
    ///
    /// See [`Expr::schema_name`] for details
    ///
    /// Example of schema_name:
    /// `count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..]`
    ///
    /// The default implementation delegates to `udaf_default_schema_name`.
    fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
        udaf_default_schema_name(self, params)
    }
453
    /// Returns a human readable expression.
    ///
    /// See [`Expr::human_display`] for details.
    ///
    /// The default implementation delegates to `udaf_default_human_display`.
    fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
        udaf_default_human_display(self, params)
    }
460
    /// Returns the name of the column this expression would create
    ///
    /// See [`Expr::schema_name`] for details
    ///
    /// Different from `schema_name` in that it is used for window aggregate functions
    ///
    /// Example of schema_name:
    /// `count(DISTINCT column1) FILTER (WHERE column2 > 10) [PARTITION BY [..]] [ORDER BY [..]]`
    ///
    /// The default implementation delegates to
    /// `udaf_default_window_function_schema_name`.
    fn window_function_schema_name(
        &self,
        params: &WindowFunctionParams,
    ) -> Result<String> {
        udaf_default_window_function_schema_name(self, params)
    }
474
    /// Returns the user-defined display name of the function, given the arguments
    ///
    /// This can be used to customize the output column name generated by this
    /// function.
    ///
    /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]`
    /// (produced by `udaf_default_display_name`).
    fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
        udaf_default_display_name(self, params)
    }
484
    /// Returns the user-defined display name of the function, given the arguments
    ///
    /// This can be used to customize the output column name generated by this
    /// function.
    ///
    /// Different from `display_name` in that it is used for window aggregate functions
    ///
    /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [partition by [..]] [order_by [..]]`
    /// (produced by `udaf_default_window_function_display_name`).
    fn window_function_display_name(
        &self,
        params: &WindowFunctionParams,
    ) -> Result<String> {
        udaf_default_window_function_display_name(self, params)
    }
499
    /// Returns the function's [`Signature`], which describes what input
    /// types are accepted and the function's Volatility.
    fn signature(&self) -> &Signature;
503
    /// What [`DataType`] will be returned by this function, given the types of
    /// the arguments
    ///
    /// # Parameters
    /// * `arg_types`: the types of the arguments, in call order
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
507
    /// What type will be returned by this function, given the arguments?
    ///
    /// By default, this function calls [`Self::return_type`] with the
    /// types of each argument (via `udaf_default_return_field`).
    ///
    /// # Notes
    ///
    /// Most UDFs should implement [`Self::return_type`] and not this
    /// function as the output type for most functions only depends on the types
    /// of their inputs (e.g. `sum(f64)` is always `f64`).
    ///
    /// This function can be used for more advanced cases such as:
    ///
    /// 1. specifying nullability
    /// 2. return types based on the **values** of the arguments (rather than
    ///    their **types**).
    /// 3. return types based on metadata within the fields of the inputs
    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
        udaf_default_return_field(self, arg_fields)
    }
528
    /// Whether the aggregate function is nullable.
    ///
    /// Nullable means that the function could return `null` for any inputs.
    /// For example, aggregate functions like `COUNT` always return a non null value
    /// but others like `MIN` will return `NULL` if there is nullable input.
    ///
    /// Note that if the function is declared as *not* nullable, you must make
    /// sure that [`AggregateUDFImpl::default_value`] is non-null.
    ///
    /// Defaults to `true`.
    fn is_nullable(&self) -> bool {
        true
    }
538
    /// Return a new [`Accumulator`] that aggregates values for a specific
    /// group during query execution.
    ///
    /// # Parameters
    /// * `acc_args`: [`AccumulatorArgs`] contains information about how the
    ///   aggregate function was called.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
545
546    /// Return the fields used to store the intermediate state of this accumulator.
547    ///
548    /// See [`Accumulator::state`] for background information.
549    ///
550    /// args:  [`StateFieldsArgs`] contains arguments passed to the
551    /// aggregate function's accumulator.
552    ///
553    /// # Notes:
554    ///
555    /// The default implementation returns a single state field named `name`
556    /// with the same type as `value_type`. This is suitable for aggregates such
557    /// as `SUM` or `MIN` where partial state can be combined by applying the
558    /// same aggregate.
559    ///
560    /// For aggregates such as `AVG` where the partial state is more complex
561    /// (e.g. a COUNT and a SUM), this method is used to define the additional
562    /// fields.
563    ///
564    /// The name of the fields must be unique within the query and thus should
565    /// be derived from `name`. See [`format_state_name`] for a utility function
566    /// to generate a unique name.
567    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
568        let fields = vec![
569            args.return_field
570                .as_ref()
571                .clone()
572                .with_name(format_state_name(args.name, "value")),
573        ];
574
575        Ok(fields
576            .into_iter()
577            .map(Arc::new)
578            .chain(args.ordering_fields.to_vec())
579            .collect())
580    }
581
    /// Whether the aggregate expression has a specialized
    /// [`GroupsAccumulator`] implementation. If this returns true,
    /// [`Self::create_groups_accumulator`] will be called.
    ///
    /// Defaults to `false`.
    ///
    /// # Notes
    ///
    /// Even if this function returns true, DataFusion will still use
    /// [`Self::accumulator`] for certain queries, such as when this aggregate is
    /// used as a window function or when there are no GROUP BY columns in the
    /// query.
    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
        false
    }
595
    /// Return a specialized [`GroupsAccumulator`] that manages state
    /// for all groups.
    ///
    /// For maximum performance, a [`GroupsAccumulator`] should be
    /// implemented in addition to [`Accumulator`].
    ///
    /// The default implementation returns a "not implemented" error.
    fn create_groups_accumulator(
        &self,
        _args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
    }
607
    /// Sliding accumulator is an alternative accumulator that can be used for
    /// window functions. It has a retract method to revert the previous update.
    ///
    /// See [retract_batch] for more details.
    ///
    /// The default implementation returns the result of [`Self::accumulator`].
    ///
    /// [retract_batch]: Accumulator::retract_batch
    fn create_sliding_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        self.accumulator(args)
    }
620
621    /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is
622    /// satisfied by its input. If this is not the case, UDFs with order
623    /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
624    /// the correct result with possibly more work internally.
625    ///
626    /// # Returns
627    ///
628    /// Returns `Ok(Some(updated_udf))` if the process completes successfully.
629    /// If the expression can benefit from existing input ordering, but does
630    /// not implement the method, returns an error. Order insensitive and hard
631    /// requirement aggregators return `Ok(None)`.
632    fn with_beneficial_ordering(
633        self: Arc<Self>,
634        _beneficial_ordering: bool,
635    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
636        if self.order_sensitivity().is_beneficial() {
637            return exec_err!(
638                "Should implement with satisfied for aggregator :{:?}",
639                self.name()
640            );
641        }
642        Ok(None)
643    }
644
    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
    /// for possible options.
    ///
    /// Defaults to [`AggregateOrderSensitivity::HardRequirement`].
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        // We have hard ordering requirements by default, meaning that order
        // sensitive UDFs need their input orderings to satisfy their ordering
        // requirements to generate correct results.
        AggregateOrderSensitivity::HardRequirement
    }
653
    /// Optionally apply per-UDAF simplification / rewrite rules.
    ///
    /// This can be used to apply function specific simplification rules during
    /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
    /// implementation does nothing.
    ///
    /// Note that DataFusion handles simplifying arguments and "constant
    /// folding" (replacing a function call with constant arguments such as
    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
    /// optimizations manually for specific UDFs.
    ///
    /// # Returns
    ///
    /// [None] if simplify is not defined, or
    ///
    /// a closure with two arguments:
    /// * 'aggregate_function': [AggregateFunction] for which simplify has been invoked
    /// * 'info': [crate::simplify::SimplifyInfo]
    ///
    /// The closure returns the simplified [Expr] or an error.
    ///
    /// # Notes
    ///
    /// The returned expression must have the same schema as the original
    /// expression, including both the data type and nullability. For example,
    /// if the original expression is nullable, the returned expression must
    /// also be nullable, otherwise it may lead to schema verification errors
    /// later in query planning.
    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
        None
    }
685
    /// Returns the reverse expression of the aggregate function.
    ///
    /// Defaults to [`ReversedUDAF::NotSupported`].
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::NotSupported
    }
690
    /// Coerce arguments of a function call to types that the function can evaluate.
    ///
    /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
    /// UDAFs should return one of the other variants of `TypeSignature` which handle common
    /// cases
    ///
    /// See the [type coercion module](crate::type_coercion)
    /// documentation for more details on type coercion
    ///
    /// For example, if your function requires a floating point argument, but the user calls
    /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
    /// to ensure the argument is cast to `1::double`
    ///
    /// # Parameters
    /// * `arg_types`: The types of the arguments this function was called with
    ///
    /// # Return value
    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
    /// arguments to these specific types.
    ///
    /// The default implementation returns a "not implemented" error.
    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
        not_impl_err!("Function {} does not implement coerce_types", self.name())
    }
713
    /// If this function is max, return `Some(true)`.
    /// If the function is min, return `Some(false)`.
    /// Otherwise return `None` (the default).
    ///
    /// Note: this is used to select special aggregate implementations in certain conditions
    fn is_descending(&self) -> Option<bool> {
        None
    }
723
    /// Return the value of this aggregate function if it can be determined
    /// entirely from statistics and arguments.
    ///
    /// Using a [`ScalarValue`] rather than a runtime computation can significantly
    /// improve query performance.
    ///
    /// For example, if the minimum value of column `x` is known to be `42` from
    /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`
    ///
    /// Defaults to `None`.
    fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
        None
    }
735
    /// Returns the default value of the function when the input is all `null`.
    ///
    /// Most aggregate functions return Null if the input is Null,
    /// while `count` returns 0 if the input is Null.
    ///
    /// The default implementation converts `data_type` with
    /// `ScalarValue::try_from`.
    fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
        ScalarValue::try_from(data_type)
    }
743
    /// If this function supports the `[IGNORE NULLS | RESPECT NULLS]` SQL clause,
    /// return `true`. Otherwise, return `false` (the default), which will cause an
    /// error to be raised during SQL parsing if these clauses are detected for
    /// this function.
    ///
    /// Functions which implement this as `true` are expected to handle the resulting
    /// null handling config present in [`AccumulatorArgs`], `ignore_nulls`.
    fn supports_null_handling_clause(&self) -> bool {
        false
    }
753
    /// If this function supports the `WITHIN GROUP (ORDER BY column [ASC|DESC])`
    /// SQL syntax, return `true`. Otherwise, return `false` (default) which will
    /// cause an error when parsing SQL where this syntax is detected for this
    /// function.
    ///
    /// This function should return `true` for ordered-set aggregate functions
    /// only.
    ///
    /// # Ordered-set aggregate functions
    ///
    /// Ordered-set aggregate functions allow specifying a sort order that affects
    /// how the function calculates its result, unlike other aggregate functions
    /// like `sum` or `count`. For example, `percentile_cont` is an ordered-set
    /// aggregate function that calculates the exact percentile value from a list
    /// of values; the output of calculating the `0.75` percentile depends on
    /// whether you're calculating on an ascending or descending list of values.
    ///
    /// An example of how an ordered-set aggregate function is called with the
    /// `WITHIN GROUP` SQL syntax:
    ///
    /// ```sql
    /// -- Ascending
    /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 ASC) FROM table;
    /// -- Default ordering is ascending if not explicitly specified
    /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1) FROM table;
    /// -- Descending
    /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 DESC) FROM table;
    /// ```
    ///
    /// This calculates the `0.75` percentile of the column `c1` from `table`,
    /// according to the specific ordering. The column specified in the `WITHIN GROUP`
    /// ordering clause is taken as the column to calculate values on; specifying
    /// the `WITHIN GROUP` clause is optional so these queries are equivalent:
    ///
    /// ```sql
    /// -- If no WITHIN GROUP is specified then default ordering is implementation
    /// -- dependent; in this case ascending for percentile_cont
    /// SELECT percentile_cont(c1, 0.75) FROM table;
    /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 ASC) FROM table;
    /// ```
    ///
    /// Aggregate UDFs can define their default ordering if the function is called
    /// without the `WITHIN GROUP` clause, though a default of ascending is the
    /// standard practice.
    ///
    /// Ordered-set aggregate function implementations are responsible for handling
    /// the input sort order themselves (e.g. `percentile_cont` must buffer and
    /// sort the values internally). That is, DataFusion does not introduce any
    /// kind of sort into the plan for these functions with this syntax.
    fn supports_within_group_clause(&self) -> bool {
        false
    }
806
    /// Returns the documentation for this Aggregate UDF, if any.
    ///
    /// The documentation can be accessed programmatically and is also used
    /// to generate publicly facing documentation. The default implementation
    /// returns `None`.
    fn documentation(&self) -> Option<&Documentation> {
        None
    }
814
    /// Indicates whether the aggregation function is monotonic as a set
    /// function. See [`SetMonotonicity`] for details.
    ///
    /// The default implementation ignores the data type and returns
    /// [`SetMonotonicity::NotMonotonic`].
    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
        SetMonotonicity::NotMonotonic
    }
820}
821
822impl PartialEq for dyn AggregateUDFImpl {
823    fn eq(&self, other: &Self) -> bool {
824        self.dyn_eq(other.as_any())
825    }
826}
827
828impl PartialOrd for dyn AggregateUDFImpl {
829    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
830        match self.name().partial_cmp(other.name()) {
831            Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
832            cmp => cmp,
833        }
834        // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
835        .filter(|cmp| *cmp != Ordering::Equal || self == other)
836    }
837}
838
839/// Encapsulates default implementation of [`AggregateUDFImpl::schema_name`].
840pub fn udaf_default_schema_name<F: AggregateUDFImpl + ?Sized>(
841    func: &F,
842    params: &AggregateFunctionParams,
843) -> Result<String> {
844    let AggregateFunctionParams {
845        args,
846        distinct,
847        filter,
848        order_by,
849        null_treatment,
850    } = params;
851
852    // exclude the first function argument(= column) in ordered set aggregate function,
853    // because it is duplicated with the WITHIN GROUP clause in schema name.
854    let args = if func.supports_within_group_clause() && !order_by.is_empty() {
855        &args[1..]
856    } else {
857        &args[..]
858    };
859
860    let mut schema_name = String::new();
861
862    schema_name.write_fmt(format_args!(
863        "{}({}{})",
864        func.name(),
865        if *distinct { "DISTINCT " } else { "" },
866        schema_name_from_exprs_comma_separated_without_space(args)?
867    ))?;
868
869    if let Some(null_treatment) = null_treatment {
870        schema_name.write_fmt(format_args!(" {null_treatment}"))?;
871    }
872
873    if let Some(filter) = filter {
874        schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
875    };
876
877    if !order_by.is_empty() {
878        let clause = match func.supports_within_group_clause() {
879            true => "WITHIN GROUP",
880            false => "ORDER BY",
881        };
882
883        schema_name.write_fmt(format_args!(
884            " {} [{}]",
885            clause,
886            schema_name_from_sorts(order_by)?
887        ))?;
888    };
889
890    Ok(schema_name)
891}
892
893/// Encapsulates default implementation of [`AggregateUDFImpl::human_display`].
894pub fn udaf_default_human_display<F: AggregateUDFImpl + ?Sized>(
895    func: &F,
896    params: &AggregateFunctionParams,
897) -> Result<String> {
898    let AggregateFunctionParams {
899        args,
900        distinct,
901        filter,
902        order_by,
903        null_treatment,
904    } = params;
905
906    let mut schema_name = String::new();
907
908    schema_name.write_fmt(format_args!(
909        "{}({}{})",
910        func.name(),
911        if *distinct { "DISTINCT " } else { "" },
912        ExprListDisplay::comma_separated(args.as_slice())
913    ))?;
914
915    if let Some(null_treatment) = null_treatment {
916        schema_name.write_fmt(format_args!(" {null_treatment}"))?;
917    }
918
919    if let Some(filter) = filter {
920        schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
921    };
922
923    if !order_by.is_empty() {
924        schema_name.write_fmt(format_args!(
925            " ORDER BY [{}]",
926            schema_name_from_sorts(order_by)?
927        ))?;
928    };
929
930    Ok(schema_name)
931}
932
933/// Encapsulates default implementation of [`AggregateUDFImpl::window_function_schema_name`].
934pub fn udaf_default_window_function_schema_name<F: AggregateUDFImpl + ?Sized>(
935    func: &F,
936    params: &WindowFunctionParams,
937) -> Result<String> {
938    let WindowFunctionParams {
939        args,
940        partition_by,
941        order_by,
942        window_frame,
943        filter,
944        null_treatment,
945        distinct,
946    } = params;
947
948    let mut schema_name = String::new();
949
950    // Inject DISTINCT into the schema name when requested
951    if *distinct {
952        schema_name.write_fmt(format_args!(
953            "{}(DISTINCT {})",
954            func.name(),
955            schema_name_from_exprs(args)?
956        ))?;
957    } else {
958        schema_name.write_fmt(format_args!(
959            "{}({})",
960            func.name(),
961            schema_name_from_exprs(args)?
962        ))?;
963    }
964
965    if let Some(null_treatment) = null_treatment {
966        schema_name.write_fmt(format_args!(" {null_treatment}"))?;
967    }
968
969    if let Some(filter) = filter {
970        schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
971    }
972
973    if !partition_by.is_empty() {
974        schema_name.write_fmt(format_args!(
975            " PARTITION BY [{}]",
976            schema_name_from_exprs(partition_by)?
977        ))?;
978    }
979
980    if !order_by.is_empty() {
981        schema_name.write_fmt(format_args!(
982            " ORDER BY [{}]",
983            schema_name_from_sorts(order_by)?
984        ))?;
985    }
986
987    schema_name.write_fmt(format_args!(" {window_frame}"))?;
988
989    Ok(schema_name)
990}
991
992/// Encapsulates default implementation of [`AggregateUDFImpl::display_name`].
993pub fn udaf_default_display_name<F: AggregateUDFImpl + ?Sized>(
994    func: &F,
995    params: &AggregateFunctionParams,
996) -> Result<String> {
997    let AggregateFunctionParams {
998        args,
999        distinct,
1000        filter,
1001        order_by,
1002        null_treatment,
1003    } = params;
1004
1005    let mut display_name = String::new();
1006
1007    display_name.write_fmt(format_args!(
1008        "{}({}{})",
1009        func.name(),
1010        if *distinct { "DISTINCT " } else { "" },
1011        expr_vec_fmt!(args)
1012    ))?;
1013
1014    if let Some(nt) = null_treatment {
1015        display_name.write_fmt(format_args!(" {nt}"))?;
1016    }
1017    if let Some(fe) = filter {
1018        display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
1019    }
1020    if !order_by.is_empty() {
1021        display_name.write_fmt(format_args!(
1022            " ORDER BY [{}]",
1023            order_by
1024                .iter()
1025                .map(|o| format!("{o}"))
1026                .collect::<Vec<String>>()
1027                .join(", ")
1028        ))?;
1029    }
1030
1031    Ok(display_name)
1032}
1033
1034/// Encapsulates default implementation of [`AggregateUDFImpl::window_function_display_name`].
1035pub fn udaf_default_window_function_display_name<F: AggregateUDFImpl + ?Sized>(
1036    func: &F,
1037    params: &WindowFunctionParams,
1038) -> Result<String> {
1039    let WindowFunctionParams {
1040        args,
1041        partition_by,
1042        order_by,
1043        window_frame,
1044        filter,
1045        null_treatment,
1046        distinct,
1047    } = params;
1048
1049    let mut display_name = String::new();
1050
1051    if *distinct {
1052        display_name.write_fmt(format_args!(
1053            "{}(DISTINCT {})",
1054            func.name(),
1055            expr_vec_fmt!(args)
1056        ))?;
1057    } else {
1058        display_name.write_fmt(format_args!(
1059            "{}({})",
1060            func.name(),
1061            expr_vec_fmt!(args)
1062        ))?;
1063    }
1064
1065    if let Some(null_treatment) = null_treatment {
1066        display_name.write_fmt(format_args!(" {null_treatment}"))?;
1067    }
1068
1069    if let Some(fe) = filter {
1070        display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
1071    }
1072
1073    if !partition_by.is_empty() {
1074        display_name.write_fmt(format_args!(
1075            " PARTITION BY [{}]",
1076            expr_vec_fmt!(partition_by)
1077        ))?;
1078    }
1079
1080    if !order_by.is_empty() {
1081        display_name
1082            .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
1083    };
1084
1085    display_name.write_fmt(format_args!(
1086        " {} BETWEEN {} AND {}",
1087        window_frame.units, window_frame.start_bound, window_frame.end_bound
1088    ))?;
1089
1090    Ok(display_name)
1091}
1092
1093/// Encapsulates default implementation of [`AggregateUDFImpl::return_field`].
1094pub fn udaf_default_return_field<F: AggregateUDFImpl + ?Sized>(
1095    func: &F,
1096    arg_fields: &[FieldRef],
1097) -> Result<FieldRef> {
1098    let arg_types: Vec<_> = arg_fields.iter().map(|f| f.data_type()).cloned().collect();
1099    let data_type = func.return_type(&arg_types)?;
1100
1101    Ok(Arc::new(Field::new(
1102        func.name(),
1103        data_type,
1104        func.is_nullable(),
1105    )))
1106}
1107
/// Result of [`AggregateUDFImpl::reverse_expr`]: describes the reversed form
/// of an aggregate expression, if there is one.
pub enum ReversedUDAF {
    /// The expression is the same as the original expression, like SUM, COUNT
    Identical,
    /// The expression does not support reverse calculation
    NotSupported,
    /// The expression is different from the original expression
    Reversed(Arc<AggregateUDF>),
}
1116
1117/// AggregateUDF that adds an alias to the underlying function. It is better to
1118/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
1119#[derive(Debug, PartialEq, Eq, Hash)]
1120struct AliasedAggregateUDFImpl {
1121    inner: UdfEq<Arc<dyn AggregateUDFImpl>>,
1122    aliases: Vec<String>,
1123}
1124
1125impl AliasedAggregateUDFImpl {
1126    pub fn new(
1127        inner: Arc<dyn AggregateUDFImpl>,
1128        new_aliases: impl IntoIterator<Item = &'static str>,
1129    ) -> Self {
1130        let mut aliases = inner.aliases().to_vec();
1131        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1132
1133        Self {
1134            inner: inner.into(),
1135            aliases,
1136        }
1137    }
1138}
1139
1140#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method
1141impl AggregateUDFImpl for AliasedAggregateUDFImpl {
1142    fn as_any(&self) -> &dyn Any {
1143        self
1144    }
1145
1146    fn name(&self) -> &str {
1147        self.inner.name()
1148    }
1149
1150    fn signature(&self) -> &Signature {
1151        self.inner.signature()
1152    }
1153
1154    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1155        self.inner.return_type(arg_types)
1156    }
1157
1158    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1159        self.inner.accumulator(acc_args)
1160    }
1161
1162    fn aliases(&self) -> &[String] {
1163        &self.aliases
1164    }
1165
1166    fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
1167        self.inner.schema_name(params)
1168    }
1169
1170    fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
1171        self.inner.human_display(params)
1172    }
1173
1174    fn window_function_schema_name(
1175        &self,
1176        params: &WindowFunctionParams,
1177    ) -> Result<String> {
1178        self.inner.window_function_schema_name(params)
1179    }
1180
1181    fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
1182        self.inner.display_name(params)
1183    }
1184
1185    fn window_function_display_name(
1186        &self,
1187        params: &WindowFunctionParams,
1188    ) -> Result<String> {
1189        self.inner.window_function_display_name(params)
1190    }
1191
1192    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
1193        self.inner.state_fields(args)
1194    }
1195
1196    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1197        self.inner.groups_accumulator_supported(args)
1198    }
1199
1200    fn create_groups_accumulator(
1201        &self,
1202        args: AccumulatorArgs,
1203    ) -> Result<Box<dyn GroupsAccumulator>> {
1204        self.inner.create_groups_accumulator(args)
1205    }
1206
1207    fn create_sliding_accumulator(
1208        &self,
1209        args: AccumulatorArgs,
1210    ) -> Result<Box<dyn Accumulator>> {
1211        self.inner.accumulator(args)
1212    }
1213
1214    fn with_beneficial_ordering(
1215        self: Arc<Self>,
1216        beneficial_ordering: bool,
1217    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1218        Arc::clone(&self.inner)
1219            .with_beneficial_ordering(beneficial_ordering)
1220            .map(|udf| {
1221                udf.map(|udf| {
1222                    Arc::new(AliasedAggregateUDFImpl {
1223                        inner: udf.into(),
1224                        aliases: self.aliases.clone(),
1225                    }) as Arc<dyn AggregateUDFImpl>
1226                })
1227            })
1228    }
1229
1230    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1231        self.inner.order_sensitivity()
1232    }
1233
1234    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
1235        self.inner.simplify()
1236    }
1237
1238    fn reverse_expr(&self) -> ReversedUDAF {
1239        self.inner.reverse_expr()
1240    }
1241
1242    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1243        self.inner.coerce_types(arg_types)
1244    }
1245
1246    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
1247        self.inner.return_field(arg_fields)
1248    }
1249
1250    fn is_nullable(&self) -> bool {
1251        self.inner.is_nullable()
1252    }
1253
1254    fn is_descending(&self) -> Option<bool> {
1255        self.inner.is_descending()
1256    }
1257
1258    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
1259        self.inner.value_from_stats(statistics_args)
1260    }
1261
1262    fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
1263        self.inner.default_value(data_type)
1264    }
1265
1266    fn supports_null_handling_clause(&self) -> bool {
1267        self.inner.supports_null_handling_clause()
1268    }
1269
1270    fn supports_within_group_clause(&self) -> bool {
1271        self.inner.supports_within_group_clause()
1272    }
1273
1274    fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
1275        self.inner.set_monotonicity(data_type)
1276    }
1277
1278    fn documentation(&self) -> Option<&Documentation> {
1279        self.inner.documentation()
1280    }
1281}
1282
/// Indicates whether an aggregation function is monotonic as a set
/// function. A set function is monotonically increasing if its value
/// never decreases as its argument grows (as a set). Formally, `f` is a
/// monotonically increasing set function if `f(S) >= f(T)` whenever `S`
/// is a superset of `T`.
///
/// For example `COUNT` and `MAX` are monotonically increasing as their
/// values always increase (or stay the same) as new values are seen. On
/// the other hand, `MIN` is monotonically decreasing as its value always
/// decreases or stays the same as new values are seen.
#[derive(Debug, Clone, PartialEq)]
pub enum SetMonotonicity {
    /// Aggregate value increases or stays the same as the input set grows.
    Increasing,
    /// Aggregate value decreases or stays the same as the input set grows.
    Decreasing,
    /// Aggregate value may increase, decrease, or stay the same as the input
    /// set grows.
    NotMonotonic,
}
1303
#[cfg(test)]
mod test {
    use crate::{AggregateUDF, AggregateUDFImpl};
    use arrow::datatypes::{DataType, FieldRef};
    use datafusion_common::Result;
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use datafusion_functions_aggregate_common::accumulator::{
        AccumulatorArgs, StateFieldsArgs,
    };
    use std::any::Any;
    use std::cmp::Ordering;
    use std::hash::{DefaultHasher, Hash, Hasher};

    /// Declares a stub UDAF type `$udf` whose `name()` returns `$name`.
    ///
    /// Only the name and signature are meaningful; the remaining required
    /// methods are `unimplemented!()` because the tests never call them.
    macro_rules! stub_mean_udaf {
        ($udf:ident, $name:literal) => {
            #[derive(Debug, Clone, PartialEq, Eq, Hash)]
            struct $udf {
                signature: Signature,
            }

            impl $udf {
                fn new() -> Self {
                    let signature = Signature::uniform(
                        1,
                        vec![DataType::Float64],
                        Volatility::Immutable,
                    );
                    Self { signature }
                }
            }

            impl AggregateUDFImpl for $udf {
                fn as_any(&self) -> &dyn Any {
                    self
                }
                fn name(&self) -> &str {
                    $name
                }
                fn signature(&self) -> &Signature {
                    &self.signature
                }
                fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
                    unimplemented!()
                }
                fn accumulator(
                    &self,
                    _acc_args: AccumulatorArgs,
                ) -> Result<Box<dyn Accumulator>> {
                    unimplemented!()
                }
                fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
                    unimplemented!()
                }
            }
        };
    }

    stub_mean_udaf!(AMeanUdf, "a");
    stub_mean_udaf!(BMeanUdf, "b");

    /// Hashes `value` with a fresh `DefaultHasher`.
    fn hash<T: Hash>(value: T) -> u64 {
        let mut state = DefaultHasher::new();
        value.hash(&mut state);
        state.finish()
    }

    #[test]
    fn test_partial_eq() {
        let first = AggregateUDF::from(AMeanUdf::new());
        let second = AggregateUDF::from(AMeanUdf::new());
        let eq = first == second;
        assert!(eq);
        assert_eq!(first, second);
        assert_eq!(hash(first), hash(second));
    }

    #[test]
    fn test_partial_ord() {
        // Test validates that partial ord is defined for AggregateUDF using the name and signature,
        // not intended to exhaustively test all possibilities
        let a1 = AggregateUDF::from(AMeanUdf::new());
        let a2 = AggregateUDF::from(AMeanUdf::new());
        assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal));

        let b1 = AggregateUDF::from(BMeanUdf::new());
        assert!(a1 < b1);
        assert!(!(a1 == b1));
    }
}