datafusion_expr/udaf.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`AggregateUDF`]: User Defined Aggregate Functions
19
20use std::any::Any;
21use std::cmp::Ordering;
22use std::fmt::{self, Debug, Formatter, Write};
23use std::hash::{Hash, Hasher};
24use std::sync::Arc;
25use std::vec;
26
27use arrow::datatypes::{DataType, Field, FieldRef};
28
29use datafusion_common::{Result, ScalarValue, Statistics, exec_err, not_impl_err};
30use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
31use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
32
33use crate::expr::{
34 AggregateFunction, AggregateFunctionParams, ExprListDisplay, WindowFunctionParams,
35 schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space,
36 schema_name_from_sorts,
37};
38use crate::function::{
39 AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
40};
41use crate::groups_accumulator::GroupsAccumulator;
42use crate::udf_eq::UdfEq;
43use crate::utils::AggregateOrderSensitivity;
44use crate::utils::format_state_name;
45use crate::{Accumulator, Expr, expr_vec_fmt};
46use crate::{Documentation, Signature};
47
48/// Logical representation of a user-defined [aggregate function] (UDAF).
49///
50/// An aggregate function combines the values from multiple input rows
51/// into a single output "aggregate" (summary) row. It is different
52/// from a scalar function because it is stateful across batches. User
53/// defined aggregate functions can be used as normal SQL aggregate
54/// functions (`GROUP BY` clause) as well as window functions (`OVER`
55/// clause).
56///
57/// `AggregateUDF` provides DataFusion the information needed to plan and call
58/// aggregate functions, including name, type information, and a factory
59/// function to create an [`Accumulator`] instance, to perform the actual
60/// aggregation.
61///
62/// For more information, please see [the examples]:
63///
64/// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]).
65///
66/// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API
67/// access (examples in [`advanced_udaf.rs`]).
68///
69/// # API Note
70/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
71/// compatibility with the older API.
72///
73/// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process
74/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
75/// [`Accumulator`]: Accumulator
76/// [`create_udaf`]: crate::expr_fn::create_udaf
77/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/simple_udaf.rs
78/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udaf.rs
79#[derive(Debug, Clone, PartialOrd)]
80pub struct AggregateUDF {
81 inner: Arc<dyn AggregateUDFImpl>,
82}
83
84impl PartialEq for AggregateUDF {
85 fn eq(&self, other: &Self) -> bool {
86 self.inner.dyn_eq(other.inner.as_any())
87 }
88}
89
90impl Eq for AggregateUDF {}
91
92impl Hash for AggregateUDF {
93 fn hash<H: Hasher>(&self, state: &mut H) {
94 self.inner.dyn_hash(state)
95 }
96}
97
98impl fmt::Display for AggregateUDF {
99 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
100 write!(f, "{}", self.name())
101 }
102}
103
104/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
105#[derive(Debug)]
106pub struct StatisticsArgs<'a> {
107 /// The statistics of the aggregate input
108 pub statistics: &'a Statistics,
109 /// The resolved return type of the aggregate function
110 pub return_type: &'a DataType,
111 /// Whether the aggregate function is distinct.
112 ///
113 /// ```sql
114 /// SELECT COUNT(DISTINCT column1) FROM t;
115 /// ```
116 pub is_distinct: bool,
117 /// The physical expression of arguments the aggregate function takes.
118 pub exprs: &'a [Arc<dyn PhysicalExpr>],
119}
120
121impl AggregateUDF {
122 /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
123 ///
124 /// Note this is the same as using the `From` impl (`AggregateUDF::from`)
125 pub fn new_from_impl<F>(fun: F) -> AggregateUDF
126 where
127 F: AggregateUDFImpl + 'static,
128 {
129 Self::new_from_shared_impl(Arc::new(fun))
130 }
131
132 /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
133 pub fn new_from_shared_impl(fun: Arc<dyn AggregateUDFImpl>) -> AggregateUDF {
134 Self { inner: fun }
135 }
136
137 /// Return the underlying [`AggregateUDFImpl`] trait object for this function
138 pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
139 &self.inner
140 }
141
142 /// Adds additional names that can be used to invoke this function, in
143 /// addition to `name`
144 ///
145 /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
146 pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
147 Self::new_from_impl(AliasedAggregateUDFImpl::new(
148 Arc::clone(&self.inner),
149 aliases,
150 ))
151 }
152
153 /// Creates an [`Expr`] that calls the aggregate function.
154 ///
155 /// This utility allows using the UDAF without requiring access to
156 /// the registry, such as with the DataFrame API.
157 pub fn call(&self, args: Vec<Expr>) -> Expr {
158 Expr::AggregateFunction(AggregateFunction::new_udf(
159 Arc::new(self.clone()),
160 args,
161 false,
162 None,
163 vec![],
164 None,
165 ))
166 }
167
168 /// Returns this function's name
169 ///
170 /// See [`AggregateUDFImpl::name`] for more details.
171 pub fn name(&self) -> &str {
172 self.inner.name()
173 }
174
175 /// Returns the aliases for this function.
176 pub fn aliases(&self) -> &[String] {
177 self.inner.aliases()
178 }
179
180 /// See [`AggregateUDFImpl::schema_name`] for more details.
181 pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
182 self.inner.schema_name(params)
183 }
184
185 /// Returns a human readable expression.
186 ///
187 /// See [`Expr::human_display`] for details.
188 pub fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
189 self.inner.human_display(params)
190 }
191
192 pub fn window_function_schema_name(
193 &self,
194 params: &WindowFunctionParams,
195 ) -> Result<String> {
196 self.inner.window_function_schema_name(params)
197 }
198
199 /// See [`AggregateUDFImpl::display_name`] for more details.
200 pub fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
201 self.inner.display_name(params)
202 }
203
204 pub fn window_function_display_name(
205 &self,
206 params: &WindowFunctionParams,
207 ) -> Result<String> {
208 self.inner.window_function_display_name(params)
209 }
210
211 pub fn is_nullable(&self) -> bool {
212 self.inner.is_nullable()
213 }
214
215 /// Returns this function's signature (what input types are accepted)
216 ///
217 /// See [`AggregateUDFImpl::signature`] for more details.
218 pub fn signature(&self) -> &Signature {
219 self.inner.signature()
220 }
221
222 /// Return the type of the function given its input types
223 ///
224 /// See [`AggregateUDFImpl::return_type`] for more details.
225 pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
226 self.inner.return_type(args)
227 }
228
229 /// Return the field of the function given its input fields
230 ///
231 /// See [`AggregateUDFImpl::return_field`] for more details.
232 pub fn return_field(&self, args: &[FieldRef]) -> Result<FieldRef> {
233 self.inner.return_field(args)
234 }
235
236 /// Return an accumulator the given aggregate, given its return datatype
237 pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
238 self.inner.accumulator(acc_args)
239 }
240
241 /// Return the fields used to store the intermediate state for this aggregator, given
242 /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
243 /// for more details.
244 ///
245 /// This is used to support multi-phase aggregations
246 pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
247 self.inner.state_fields(args)
248 }
249
250 /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
251 pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
252 self.inner.groups_accumulator_supported(args)
253 }
254
255 /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
256 pub fn create_groups_accumulator(
257 &self,
258 args: AccumulatorArgs,
259 ) -> Result<Box<dyn GroupsAccumulator>> {
260 self.inner.create_groups_accumulator(args)
261 }
262
263 pub fn create_sliding_accumulator(
264 &self,
265 args: AccumulatorArgs,
266 ) -> Result<Box<dyn Accumulator>> {
267 self.inner.create_sliding_accumulator(args)
268 }
269
270 pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
271 self.inner.coerce_types(arg_types)
272 }
273
274 /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details.
275 pub fn with_beneficial_ordering(
276 self,
277 beneficial_ordering: bool,
278 ) -> Result<Option<AggregateUDF>> {
279 self.inner
280 .with_beneficial_ordering(beneficial_ordering)
281 .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf }))
282 }
283
284 /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
285 /// for possible options.
286 pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
287 self.inner.order_sensitivity()
288 }
289
290 /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will
291 /// generate same result with this `AggregateUDF` when iterated in reverse
292 /// order, and `None` if there is no such `AggregateUDF`).
293 pub fn reverse_udf(&self) -> ReversedUDAF {
294 self.inner.reverse_expr()
295 }
296
297 /// Do the function rewrite
298 ///
299 /// See [`AggregateUDFImpl::simplify`] for more details.
300 pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
301 self.inner.simplify()
302 }
303
304 /// Returns true if the function is max, false if the function is min
305 /// None in all other cases, used in certain optimizations for
306 /// or aggregate
307 pub fn is_descending(&self) -> Option<bool> {
308 self.inner.is_descending()
309 }
310
311 /// Return the value of this aggregate function if it can be determined
312 /// entirely from statistics and arguments.
313 ///
314 /// See [`AggregateUDFImpl::value_from_stats`] for more details.
315 pub fn value_from_stats(
316 &self,
317 statistics_args: &StatisticsArgs,
318 ) -> Option<ScalarValue> {
319 self.inner.value_from_stats(statistics_args)
320 }
321
322 /// See [`AggregateUDFImpl::default_value`] for more details.
323 pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
324 self.inner.default_value(data_type)
325 }
326
327 /// See [`AggregateUDFImpl::supports_null_handling_clause`] for more details.
328 pub fn supports_null_handling_clause(&self) -> bool {
329 self.inner.supports_null_handling_clause()
330 }
331
332 /// See [`AggregateUDFImpl::supports_within_group_clause`] for more details.
333 pub fn supports_within_group_clause(&self) -> bool {
334 self.inner.supports_within_group_clause()
335 }
336
337 /// Returns the documentation for this Aggregate UDF.
338 ///
339 /// Documentation can be accessed programmatically as well as
340 /// generating publicly facing documentation.
341 pub fn documentation(&self) -> Option<&Documentation> {
342 self.inner.documentation()
343 }
344}
345
346impl<F> From<F> for AggregateUDF
347where
348 F: AggregateUDFImpl + Send + Sync + 'static,
349{
350 fn from(fun: F) -> Self {
351 Self::new_from_impl(fun)
352 }
353}
354
355/// Trait for implementing [`AggregateUDF`].
356///
357/// This trait exposes the full API for implementing user defined aggregate functions and
358/// can be used to implement any function.
359///
360/// See [`advanced_udaf.rs`] for a full example with complete implementation and
361/// [`AggregateUDF`] for other available options.
362///
363/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udaf.rs
364///
365/// # Basic Example
366/// ```
367/// # use std::any::Any;
368/// # use std::sync::{Arc, LazyLock};
369/// # use arrow::datatypes::{DataType, FieldRef};
370/// # use datafusion_common::{DataFusionError, plan_err, Result};
371/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation};
372/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
373/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE;
374/// # use arrow::datatypes::Schema;
375/// # use arrow::datatypes::Field;
376///
377/// #[derive(Debug, Clone, PartialEq, Eq, Hash)]
378/// struct GeoMeanUdf {
379/// signature: Signature,
380/// }
381///
382/// impl GeoMeanUdf {
383/// fn new() -> Self {
384/// Self {
385/// signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
386/// }
387/// }
388/// }
389///
390/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
391/// Documentation::builder(DOC_SECTION_AGGREGATE, "calculates a geometric mean", "geo_mean(2.0)")
392/// .with_argument("arg1", "The Float64 number for the geometric mean")
393/// .build()
394/// });
395///
396/// fn get_doc() -> &'static Documentation {
397/// &DOCUMENTATION
398/// }
399///
400/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
401/// impl AggregateUDFImpl for GeoMeanUdf {
402/// fn as_any(&self) -> &dyn Any { self }
403/// fn name(&self) -> &str { "geo_mean" }
404/// fn signature(&self) -> &Signature { &self.signature }
405/// fn return_type(&self, args: &[DataType]) -> Result<DataType> {
406/// if !matches!(args.get(0), Some(&DataType::Float64)) {
407/// return plan_err!("geo_mean only accepts Float64 arguments");
408/// }
409/// Ok(DataType::Float64)
410/// }
411/// // This is the accumulator factory; DataFusion uses it to create new accumulators.
412/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
413/// fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
414/// Ok(vec![
415/// Arc::new(args.return_field.as_ref().clone().with_name("value")),
416/// Arc::new(Field::new("ordering", DataType::UInt32, true))
417/// ])
418/// }
419/// fn documentation(&self) -> Option<&Documentation> {
420/// Some(get_doc())
421/// }
422/// }
423///
424/// // Create a new AggregateUDF from the implementation
425/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
426///
427/// // Call the function `geo_mean(col)`
428/// let expr = geometric_mean.call(vec![col("a")]);
429/// ```
430pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync {
431 /// Returns this object as an [`Any`] trait object
432 fn as_any(&self) -> &dyn Any;
433
434 /// Returns this function's name
435 fn name(&self) -> &str;
436
437 /// Returns any aliases (alternate names) for this function.
438 ///
439 /// Note: `aliases` should only include names other than [`Self::name`].
440 /// Defaults to `[]` (no aliases)
441 fn aliases(&self) -> &[String] {
442 &[]
443 }
444
445 /// Returns the name of the column this expression would create
446 ///
447 /// See [`Expr::schema_name`] for details
448 ///
449 /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..]
450 fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
451 udaf_default_schema_name(self, params)
452 }
453
454 /// Returns a human readable expression.
455 ///
456 /// See [`Expr::human_display`] for details.
457 fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
458 udaf_default_human_display(self, params)
459 }
460
461 /// Returns the name of the column this expression would create
462 ///
463 /// See [`Expr::schema_name`] for details
464 ///
465 /// Different from `schema_name` in that it is used for window aggregate function
466 ///
467 /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) [PARTITION BY [..]] [ORDER BY [..]]
468 fn window_function_schema_name(
469 &self,
470 params: &WindowFunctionParams,
471 ) -> Result<String> {
472 udaf_default_window_function_schema_name(self, params)
473 }
474
475 /// Returns the user-defined display name of function, given the arguments
476 ///
477 /// This can be used to customize the output column name generated by this
478 /// function.
479 ///
480 /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]`
481 fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
482 udaf_default_display_name(self, params)
483 }
484
485 /// Returns the user-defined display name of function, given the arguments
486 ///
487 /// This can be used to customize the output column name generated by this
488 /// function.
489 ///
490 /// Different from `display_name` in that it is used for window aggregate function
491 ///
492 /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [partition by [..]] [order_by [..]]`
493 fn window_function_display_name(
494 &self,
495 params: &WindowFunctionParams,
496 ) -> Result<String> {
497 udaf_default_window_function_display_name(self, params)
498 }
499
500 /// Returns the function's [`Signature`] for information about what input
501 /// types are accepted and the function's Volatility.
502 fn signature(&self) -> &Signature;
503
504 /// What [`DataType`] will be returned by this function, given the types of
505 /// the arguments
506 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
507
508 /// What type will be returned by this function, given the arguments?
509 ///
510 /// By default, this function calls [`Self::return_type`] with the
511 /// types of each argument.
512 ///
513 /// # Notes
514 ///
515 /// Most UDFs should implement [`Self::return_type`] and not this
516 /// function as the output type for most functions only depends on the types
517 /// of their inputs (e.g. `sum(f64)` is always `f64`).
518 ///
519 /// This function can be used for more advanced cases such as:
520 ///
521 /// 1. specifying nullability
522 /// 2. return types based on the **values** of the arguments (rather than
523 /// their **types**.
524 /// 3. return types based on metadata within the fields of the inputs
525 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
526 udaf_default_return_field(self, arg_fields)
527 }
528
529 /// Whether the aggregate function is nullable.
530 ///
531 /// Nullable means that the function could return `null` for any inputs.
532 /// For example, aggregate functions like `COUNT` always return a non null value
533 /// but others like `MIN` will return `NULL` if there is nullable input.
534 /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null`
535 fn is_nullable(&self) -> bool {
536 true
537 }
538
539 /// Return a new [`Accumulator`] that aggregates values for a specific
540 /// group during query execution.
541 ///
542 /// acc_args: [`AccumulatorArgs`] contains information about how the
543 /// aggregate function was called.
544 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
545
546 /// Return the fields used to store the intermediate state of this accumulator.
547 ///
548 /// See [`Accumulator::state`] for background information.
549 ///
550 /// args: [`StateFieldsArgs`] contains arguments passed to the
551 /// aggregate function's accumulator.
552 ///
553 /// # Notes:
554 ///
555 /// The default implementation returns a single state field named `name`
556 /// with the same type as `value_type`. This is suitable for aggregates such
557 /// as `SUM` or `MIN` where partial state can be combined by applying the
558 /// same aggregate.
559 ///
560 /// For aggregates such as `AVG` where the partial state is more complex
561 /// (e.g. a COUNT and a SUM), this method is used to define the additional
562 /// fields.
563 ///
564 /// The name of the fields must be unique within the query and thus should
565 /// be derived from `name`. See [`format_state_name`] for a utility function
566 /// to generate a unique name.
567 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
568 let fields = vec![
569 args.return_field
570 .as_ref()
571 .clone()
572 .with_name(format_state_name(args.name, "value")),
573 ];
574
575 Ok(fields
576 .into_iter()
577 .map(Arc::new)
578 .chain(args.ordering_fields.to_vec())
579 .collect())
580 }
581
582 /// If the aggregate expression has a specialized
583 /// [`GroupsAccumulator`] implementation. If this returns true,
584 /// `[Self::create_groups_accumulator]` will be called.
585 ///
586 /// # Notes
587 ///
588 /// Even if this function returns true, DataFusion will still use
589 /// [`Self::accumulator`] for certain queries, such as when this aggregate is
590 /// used as a window function or when there no GROUP BY columns in the
591 /// query.
592 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
593 false
594 }
595
596 /// Return a specialized [`GroupsAccumulator`] that manages state
597 /// for all groups.
598 ///
599 /// For maximum performance, a [`GroupsAccumulator`] should be
600 /// implemented in addition to [`Accumulator`].
601 fn create_groups_accumulator(
602 &self,
603 _args: AccumulatorArgs,
604 ) -> Result<Box<dyn GroupsAccumulator>> {
605 not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
606 }
607
608 /// Sliding accumulator is an alternative accumulator that can be used for
609 /// window functions. It has retract method to revert the previous update.
610 ///
611 /// See [retract_batch] for more details.
612 ///
613 /// [retract_batch]: Accumulator::retract_batch
614 fn create_sliding_accumulator(
615 &self,
616 args: AccumulatorArgs,
617 ) -> Result<Box<dyn Accumulator>> {
618 self.accumulator(args)
619 }
620
621 /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is
622 /// satisfied by its input. If this is not the case, UDFs with order
623 /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
624 /// the correct result with possibly more work internally.
625 ///
626 /// # Returns
627 ///
628 /// Returns `Ok(Some(updated_udf))` if the process completes successfully.
629 /// If the expression can benefit from existing input ordering, but does
630 /// not implement the method, returns an error. Order insensitive and hard
631 /// requirement aggregators return `Ok(None)`.
632 fn with_beneficial_ordering(
633 self: Arc<Self>,
634 _beneficial_ordering: bool,
635 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
636 if self.order_sensitivity().is_beneficial() {
637 return exec_err!(
638 "Should implement with satisfied for aggregator :{:?}",
639 self.name()
640 );
641 }
642 Ok(None)
643 }
644
645 /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
646 /// for possible options.
647 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
648 // We have hard ordering requirements by default, meaning that order
649 // sensitive UDFs need their input orderings to satisfy their ordering
650 // requirements to generate correct results.
651 AggregateOrderSensitivity::HardRequirement
652 }
653
654 /// Optionally apply per-UDaF simplification / rewrite rules.
655 ///
656 /// This can be used to apply function specific simplification rules during
657 /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
658 /// implementation does nothing.
659 ///
660 /// Note that DataFusion handles simplifying arguments and "constant
661 /// folding" (replacing a function call with constant arguments such as
662 /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
663 /// optimizations manually for specific UDFs.
664 ///
665 /// # Returns
666 ///
667 /// [None] if simplify is not defined or,
668 ///
669 /// Or, a closure with two arguments:
670 /// * 'aggregate_function': [AggregateFunction] for which simplified has been invoked
671 /// * 'info': [crate::simplify::SimplifyInfo]
672 ///
673 /// closure returns simplified [Expr] or an error.
674 ///
675 /// # Notes
676 ///
677 /// The returned expression must have the same schema as the original
678 /// expression, including both the data type and nullability. For example,
679 /// if the original expression is nullable, the returned expression must
680 /// also be nullable, otherwise it may lead to schema verification errors
681 /// later in query planning.
682 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
683 None
684 }
685
686 /// Returns the reverse expression of the aggregate function.
687 fn reverse_expr(&self) -> ReversedUDAF {
688 ReversedUDAF::NotSupported
689 }
690
691 /// Coerce arguments of a function call to types that the function can evaluate.
692 ///
693 /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
694 /// UDAFs should return one of the other variants of `TypeSignature` which handle common
695 /// cases
696 ///
697 /// See the [type coercion module](crate::type_coercion)
698 /// documentation for more details on type coercion
699 ///
700 /// For example, if your function requires a floating point arguments, but the user calls
701 /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
702 /// to ensure the argument was cast to `1::double`
703 ///
704 /// # Parameters
705 /// * `arg_types`: The argument types of the arguments this function with
706 ///
707 /// # Return value
708 /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
709 /// arguments to these specific types.
710 fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
711 not_impl_err!("Function {} does not implement coerce_types", self.name())
712 }
713
714 /// If this function is max, return true
715 /// If the function is min, return false
716 /// Otherwise return None (the default)
717 ///
718 ///
719 /// Note: this is used to use special aggregate implementations in certain conditions
720 fn is_descending(&self) -> Option<bool> {
721 None
722 }
723
724 /// Return the value of this aggregate function if it can be determined
725 /// entirely from statistics and arguments.
726 ///
727 /// Using a [`ScalarValue`] rather than a runtime computation can significantly
728 /// improving query performance.
729 ///
730 /// For example, if the minimum value of column `x` is known to be `42` from
731 /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`
732 fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
733 None
734 }
735
736 /// Returns default value of the function given the input is all `null`.
737 ///
738 /// Most of the aggregate function return Null if input is Null,
739 /// while `count` returns 0 if input is Null
740 fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
741 ScalarValue::try_from(data_type)
742 }
743
744 /// If this function supports `[IGNORE NULLS | RESPECT NULLS]` SQL clause,
745 /// return `true`. Otherwise, return `false` which will cause an error to be
746 /// raised during SQL parsing if these clauses are detected for this function.
747 ///
748 /// Functions which implement this as `true` are expected to handle the resulting
749 /// null handling config present in [`AccumulatorArgs`], `ignore_nulls`.
750 fn supports_null_handling_clause(&self) -> bool {
751 false
752 }
753
754 /// If this function supports the `WITHIN GROUP (ORDER BY column [ASC|DESC])`
755 /// SQL syntax, return `true`. Otherwise, return `false` (default) which will
756 /// cause an error when parsing SQL where this syntax is detected for this
757 /// function.
758 ///
759 /// This function should return `true` for ordered-set aggregate functions
760 /// only.
761 ///
762 /// # Ordered-set aggregate functions
763 ///
764 /// Ordered-set aggregate functions allow specifying a sort order that affects
765 /// how the function calculates its result, unlike other aggregate functions
766 /// like `sum` or `count`. For example, `percentile_cont` is an ordered-set
767 /// aggregate function that calculates the exact percentile value from a list
768 /// of values; the output of calculating the `0.75` percentile depends on if
769 /// you're calculating on an ascending or descending list of values.
770 ///
771 /// An example of how an ordered-set aggregate function is called with the
772 /// `WITHIN GROUP` SQL syntax:
773 ///
774 /// ```sql
775 /// -- Ascending
776 /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 ASC) FROM table;
777 /// -- Default ordering is ascending if not explicitly specified
778 /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1) FROM table;
779 /// -- Descending
780 /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 DESC) FROM table;
781 /// ```
782 ///
783 /// This calculates the `0.75` percentile of the column `c1` from `table`,
784 /// according to the specific ordering. The column specified in the `WITHIN GROUP`
785 /// ordering clause is taken as the column to calculate values on; specifying
786 /// the `WITHIN GROUP` clause is optional so these queries are equivalent:
787 ///
788 /// ```sql
789 /// -- If no WITHIN GROUP is specified then default ordering is implementation
790 /// -- dependent; in this case ascending for percentile_cont
791 /// SELECT percentile_cont(c1, 0.75) FROM table;
792 /// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 ASC) FROM table;
793 /// ```
794 ///
795 /// Aggregate UDFs can define their default ordering if the function is called
796 /// without the `WITHIN GROUP` clause, though a default of ascending is the
797 /// standard practice.
798 ///
799 /// Ordered-set aggregate function implementations are responsible for handling
800 /// the input sort order themselves (e.g. `percentile_cont` must buffer and
801 /// sort the values internally). That is, DataFusion does not introduce any
802 /// kind of sort into the plan for these functions with this syntax.
803 fn supports_within_group_clause(&self) -> bool {
804 false
805 }
806
807 /// Returns the documentation for this Aggregate UDF.
808 ///
809 /// Documentation can be accessed programmatically as well as
810 /// generating publicly facing documentation.
811 fn documentation(&self) -> Option<&Documentation> {
812 None
813 }
814
815 /// Indicates whether the aggregation function is monotonic as a set
816 /// function. See [`SetMonotonicity`] for details.
817 fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
818 SetMonotonicity::NotMonotonic
819 }
820}
821
822impl PartialEq for dyn AggregateUDFImpl {
823 fn eq(&self, other: &Self) -> bool {
824 self.dyn_eq(other.as_any())
825 }
826}
827
828impl PartialOrd for dyn AggregateUDFImpl {
829 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
830 match self.name().partial_cmp(other.name()) {
831 Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
832 cmp => cmp,
833 }
834 // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
835 .filter(|cmp| *cmp != Ordering::Equal || self == other)
836 }
837}
838
839/// Encapsulates default implementation of [`AggregateUDFImpl::schema_name`].
840pub fn udaf_default_schema_name<F: AggregateUDFImpl + ?Sized>(
841 func: &F,
842 params: &AggregateFunctionParams,
843) -> Result<String> {
844 let AggregateFunctionParams {
845 args,
846 distinct,
847 filter,
848 order_by,
849 null_treatment,
850 } = params;
851
852 // exclude the first function argument(= column) in ordered set aggregate function,
853 // because it is duplicated with the WITHIN GROUP clause in schema name.
854 let args = if func.supports_within_group_clause() && !order_by.is_empty() {
855 &args[1..]
856 } else {
857 &args[..]
858 };
859
860 let mut schema_name = String::new();
861
862 schema_name.write_fmt(format_args!(
863 "{}({}{})",
864 func.name(),
865 if *distinct { "DISTINCT " } else { "" },
866 schema_name_from_exprs_comma_separated_without_space(args)?
867 ))?;
868
869 if let Some(null_treatment) = null_treatment {
870 schema_name.write_fmt(format_args!(" {null_treatment}"))?;
871 }
872
873 if let Some(filter) = filter {
874 schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
875 };
876
877 if !order_by.is_empty() {
878 let clause = match func.supports_within_group_clause() {
879 true => "WITHIN GROUP",
880 false => "ORDER BY",
881 };
882
883 schema_name.write_fmt(format_args!(
884 " {} [{}]",
885 clause,
886 schema_name_from_sorts(order_by)?
887 ))?;
888 };
889
890 Ok(schema_name)
891}
892
893/// Encapsulates default implementation of [`AggregateUDFImpl::human_display`].
894pub fn udaf_default_human_display<F: AggregateUDFImpl + ?Sized>(
895 func: &F,
896 params: &AggregateFunctionParams,
897) -> Result<String> {
898 let AggregateFunctionParams {
899 args,
900 distinct,
901 filter,
902 order_by,
903 null_treatment,
904 } = params;
905
906 let mut schema_name = String::new();
907
908 schema_name.write_fmt(format_args!(
909 "{}({}{})",
910 func.name(),
911 if *distinct { "DISTINCT " } else { "" },
912 ExprListDisplay::comma_separated(args.as_slice())
913 ))?;
914
915 if let Some(null_treatment) = null_treatment {
916 schema_name.write_fmt(format_args!(" {null_treatment}"))?;
917 }
918
919 if let Some(filter) = filter {
920 schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
921 };
922
923 if !order_by.is_empty() {
924 schema_name.write_fmt(format_args!(
925 " ORDER BY [{}]",
926 schema_name_from_sorts(order_by)?
927 ))?;
928 };
929
930 Ok(schema_name)
931}
932
933/// Encapsulates default implementation of [`AggregateUDFImpl::window_function_schema_name`].
934pub fn udaf_default_window_function_schema_name<F: AggregateUDFImpl + ?Sized>(
935 func: &F,
936 params: &WindowFunctionParams,
937) -> Result<String> {
938 let WindowFunctionParams {
939 args,
940 partition_by,
941 order_by,
942 window_frame,
943 filter,
944 null_treatment,
945 distinct,
946 } = params;
947
948 let mut schema_name = String::new();
949
950 // Inject DISTINCT into the schema name when requested
951 if *distinct {
952 schema_name.write_fmt(format_args!(
953 "{}(DISTINCT {})",
954 func.name(),
955 schema_name_from_exprs(args)?
956 ))?;
957 } else {
958 schema_name.write_fmt(format_args!(
959 "{}({})",
960 func.name(),
961 schema_name_from_exprs(args)?
962 ))?;
963 }
964
965 if let Some(null_treatment) = null_treatment {
966 schema_name.write_fmt(format_args!(" {null_treatment}"))?;
967 }
968
969 if let Some(filter) = filter {
970 schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
971 }
972
973 if !partition_by.is_empty() {
974 schema_name.write_fmt(format_args!(
975 " PARTITION BY [{}]",
976 schema_name_from_exprs(partition_by)?
977 ))?;
978 }
979
980 if !order_by.is_empty() {
981 schema_name.write_fmt(format_args!(
982 " ORDER BY [{}]",
983 schema_name_from_sorts(order_by)?
984 ))?;
985 }
986
987 schema_name.write_fmt(format_args!(" {window_frame}"))?;
988
989 Ok(schema_name)
990}
991
992/// Encapsulates default implementation of [`AggregateUDFImpl::display_name`].
993pub fn udaf_default_display_name<F: AggregateUDFImpl + ?Sized>(
994 func: &F,
995 params: &AggregateFunctionParams,
996) -> Result<String> {
997 let AggregateFunctionParams {
998 args,
999 distinct,
1000 filter,
1001 order_by,
1002 null_treatment,
1003 } = params;
1004
1005 let mut display_name = String::new();
1006
1007 display_name.write_fmt(format_args!(
1008 "{}({}{})",
1009 func.name(),
1010 if *distinct { "DISTINCT " } else { "" },
1011 expr_vec_fmt!(args)
1012 ))?;
1013
1014 if let Some(nt) = null_treatment {
1015 display_name.write_fmt(format_args!(" {nt}"))?;
1016 }
1017 if let Some(fe) = filter {
1018 display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
1019 }
1020 if !order_by.is_empty() {
1021 display_name.write_fmt(format_args!(
1022 " ORDER BY [{}]",
1023 order_by
1024 .iter()
1025 .map(|o| format!("{o}"))
1026 .collect::<Vec<String>>()
1027 .join(", ")
1028 ))?;
1029 }
1030
1031 Ok(display_name)
1032}
1033
1034/// Encapsulates default implementation of [`AggregateUDFImpl::window_function_display_name`].
1035pub fn udaf_default_window_function_display_name<F: AggregateUDFImpl + ?Sized>(
1036 func: &F,
1037 params: &WindowFunctionParams,
1038) -> Result<String> {
1039 let WindowFunctionParams {
1040 args,
1041 partition_by,
1042 order_by,
1043 window_frame,
1044 filter,
1045 null_treatment,
1046 distinct,
1047 } = params;
1048
1049 let mut display_name = String::new();
1050
1051 if *distinct {
1052 display_name.write_fmt(format_args!(
1053 "{}(DISTINCT {})",
1054 func.name(),
1055 expr_vec_fmt!(args)
1056 ))?;
1057 } else {
1058 display_name.write_fmt(format_args!(
1059 "{}({})",
1060 func.name(),
1061 expr_vec_fmt!(args)
1062 ))?;
1063 }
1064
1065 if let Some(null_treatment) = null_treatment {
1066 display_name.write_fmt(format_args!(" {null_treatment}"))?;
1067 }
1068
1069 if let Some(fe) = filter {
1070 display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
1071 }
1072
1073 if !partition_by.is_empty() {
1074 display_name.write_fmt(format_args!(
1075 " PARTITION BY [{}]",
1076 expr_vec_fmt!(partition_by)
1077 ))?;
1078 }
1079
1080 if !order_by.is_empty() {
1081 display_name
1082 .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
1083 };
1084
1085 display_name.write_fmt(format_args!(
1086 " {} BETWEEN {} AND {}",
1087 window_frame.units, window_frame.start_bound, window_frame.end_bound
1088 ))?;
1089
1090 Ok(display_name)
1091}
1092
1093/// Encapsulates default implementation of [`AggregateUDFImpl::return_field`].
1094pub fn udaf_default_return_field<F: AggregateUDFImpl + ?Sized>(
1095 func: &F,
1096 arg_fields: &[FieldRef],
1097) -> Result<FieldRef> {
1098 let arg_types: Vec<_> = arg_fields.iter().map(|f| f.data_type()).cloned().collect();
1099 let data_type = func.return_type(&arg_types)?;
1100
1101 Ok(Arc::new(Field::new(
1102 func.name(),
1103 data_type,
1104 func.is_nullable(),
1105 )))
1106}
1107
/// Describes the reverse of an aggregate expression, as reported by
/// [`AggregateUDFImpl::reverse_expr`].
pub enum ReversedUDAF {
    /// The expression is the same as the original expression, like SUM, COUNT
    Identical,
    /// The expression does not support reverse calculation
    NotSupported,
    /// The expression is different from the original expression
    Reversed(Arc<AggregateUDF>),
}
1116
/// AggregateUDF that adds an alias to the underlying function. It is better to
/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
#[derive(Debug, PartialEq, Eq, Hash)]
struct AliasedAggregateUDFImpl {
    /// The wrapped implementation that the trait methods delegate to.
    inner: UdfEq<Arc<dyn AggregateUDFImpl>>,
    /// The aliases reported by this wrapper: the wrapped function's own
    /// aliases followed by the additional ones supplied at construction.
    aliases: Vec<String>,
}
1124
1125impl AliasedAggregateUDFImpl {
1126 pub fn new(
1127 inner: Arc<dyn AggregateUDFImpl>,
1128 new_aliases: impl IntoIterator<Item = &'static str>,
1129 ) -> Self {
1130 let mut aliases = inner.aliases().to_vec();
1131 aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1132
1133 Self {
1134 inner: inner.into(),
1135 aliases,
1136 }
1137 }
1138}
1139
1140#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method
1141impl AggregateUDFImpl for AliasedAggregateUDFImpl {
1142 fn as_any(&self) -> &dyn Any {
1143 self
1144 }
1145
1146 fn name(&self) -> &str {
1147 self.inner.name()
1148 }
1149
1150 fn signature(&self) -> &Signature {
1151 self.inner.signature()
1152 }
1153
1154 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1155 self.inner.return_type(arg_types)
1156 }
1157
1158 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1159 self.inner.accumulator(acc_args)
1160 }
1161
1162 fn aliases(&self) -> &[String] {
1163 &self.aliases
1164 }
1165
1166 fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
1167 self.inner.schema_name(params)
1168 }
1169
1170 fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
1171 self.inner.human_display(params)
1172 }
1173
1174 fn window_function_schema_name(
1175 &self,
1176 params: &WindowFunctionParams,
1177 ) -> Result<String> {
1178 self.inner.window_function_schema_name(params)
1179 }
1180
1181 fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
1182 self.inner.display_name(params)
1183 }
1184
1185 fn window_function_display_name(
1186 &self,
1187 params: &WindowFunctionParams,
1188 ) -> Result<String> {
1189 self.inner.window_function_display_name(params)
1190 }
1191
1192 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
1193 self.inner.state_fields(args)
1194 }
1195
1196 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1197 self.inner.groups_accumulator_supported(args)
1198 }
1199
1200 fn create_groups_accumulator(
1201 &self,
1202 args: AccumulatorArgs,
1203 ) -> Result<Box<dyn GroupsAccumulator>> {
1204 self.inner.create_groups_accumulator(args)
1205 }
1206
1207 fn create_sliding_accumulator(
1208 &self,
1209 args: AccumulatorArgs,
1210 ) -> Result<Box<dyn Accumulator>> {
1211 self.inner.accumulator(args)
1212 }
1213
1214 fn with_beneficial_ordering(
1215 self: Arc<Self>,
1216 beneficial_ordering: bool,
1217 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1218 Arc::clone(&self.inner)
1219 .with_beneficial_ordering(beneficial_ordering)
1220 .map(|udf| {
1221 udf.map(|udf| {
1222 Arc::new(AliasedAggregateUDFImpl {
1223 inner: udf.into(),
1224 aliases: self.aliases.clone(),
1225 }) as Arc<dyn AggregateUDFImpl>
1226 })
1227 })
1228 }
1229
1230 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1231 self.inner.order_sensitivity()
1232 }
1233
1234 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
1235 self.inner.simplify()
1236 }
1237
1238 fn reverse_expr(&self) -> ReversedUDAF {
1239 self.inner.reverse_expr()
1240 }
1241
1242 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1243 self.inner.coerce_types(arg_types)
1244 }
1245
1246 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
1247 self.inner.return_field(arg_fields)
1248 }
1249
1250 fn is_nullable(&self) -> bool {
1251 self.inner.is_nullable()
1252 }
1253
1254 fn is_descending(&self) -> Option<bool> {
1255 self.inner.is_descending()
1256 }
1257
1258 fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
1259 self.inner.value_from_stats(statistics_args)
1260 }
1261
1262 fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
1263 self.inner.default_value(data_type)
1264 }
1265
1266 fn supports_null_handling_clause(&self) -> bool {
1267 self.inner.supports_null_handling_clause()
1268 }
1269
1270 fn supports_within_group_clause(&self) -> bool {
1271 self.inner.supports_within_group_clause()
1272 }
1273
1274 fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
1275 self.inner.set_monotonicity(data_type)
1276 }
1277
1278 fn documentation(&self) -> Option<&Documentation> {
1279 self.inner.documentation()
1280 }
1281}
1282
/// Indicates whether an aggregation function is monotonic as a set
/// function. A set function is monotonically increasing if its value
/// increases as its argument grows (as a set). Formally, `f` is a
/// monotonically increasing set function if `f(S) >= f(T)` whenever `S`
/// is a superset of `T`. Symmetrically, `f` is a monotonically decreasing
/// set function if `f(S) <= f(T)` whenever `S` is a superset of `T`.
///
/// For example `COUNT` and `MAX` are monotonically increasing as their
/// values always increase (or stay the same) as new values are seen. On
/// the other hand, `MIN` is monotonically decreasing as its value always
/// decreases or stays the same as new values are seen.
#[derive(Debug, Clone, PartialEq)]
pub enum SetMonotonicity {
    /// Aggregate value increases or stays the same as the input set grows.
    Increasing,
    /// Aggregate value decreases or stays the same as the input set grows.
    Decreasing,
    /// Aggregate value may increase, decrease, or stay the same as the input
    /// set grows.
    NotMonotonic,
}
1303
#[cfg(test)]
mod test {
    use crate::{AggregateUDF, AggregateUDFImpl};
    use arrow::datatypes::{DataType, FieldRef};
    use datafusion_common::Result;
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use datafusion_functions_aggregate_common::accumulator::{
        AccumulatorArgs, StateFieldsArgs,
    };
    use std::any::Any;
    use std::cmp::Ordering;
    use std::hash::{DefaultHasher, Hash, Hasher};

    /// Test-only UDAF named `"a"`; only its name and signature are exercised.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct AMeanUdf {
        signature: Signature,
    }

    impl AMeanUdf {
        fn new() -> Self {
            let signature =
                Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable);
            Self { signature }
        }
    }

    impl AggregateUDFImpl for AMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "a"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }

        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    /// Test-only UDAF named `"b"`; identical to [`AMeanUdf`] apart from its
    /// type and name.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct BMeanUdf {
        signature: Signature,
    }

    impl BMeanUdf {
        fn new() -> Self {
            let signature =
                Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable);
            Self { signature }
        }
    }

    impl AggregateUDFImpl for BMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "b"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }

        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    /// Hashes `value` with a fresh [`DefaultHasher`] and returns the digest.
    fn hash<T: Hash>(value: T) -> u64 {
        let mut state = DefaultHasher::new();
        value.hash(&mut state);
        state.finish()
    }

    #[test]
    fn test_partial_eq() {
        let a1 = AggregateUDF::from(AMeanUdf::new());
        let a2 = AggregateUDF::from(AMeanUdf::new());

        assert!(a1 == a2);
        assert_eq!(a1, a2);
        // Equal values must also hash identically.
        assert_eq!(hash(a1), hash(a2));
    }

    #[test]
    fn test_partial_ord() {
        // Test validates that partial ord is defined for AggregateUDF using the name and signature,
        // not intended to exhaustively test all possibilities
        let a1 = AggregateUDF::from(AMeanUdf::new());
        let a2 = AggregateUDF::from(AMeanUdf::new());
        assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal));

        let b1 = AggregateUDF::from(BMeanUdf::new());
        assert!(a1 < b1);
        assert!(a1 != b1);
    }
}