datafusion_expr/udaf.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`AggregateUDF`]: User Defined Aggregate Functions
19
20use std::any::Any;
21use std::cmp::Ordering;
22use std::fmt::{self, Debug, Formatter, Write};
23use std::hash::{Hash, Hasher};
24use std::sync::Arc;
25use std::vec;
26
27use arrow::datatypes::{DataType, Field, FieldRef};
28
29use datafusion_common::{Result, ScalarValue, Statistics, exec_err, not_impl_err};
30use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
31use datafusion_expr_common::operator::Operator;
32use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
33
34use crate::expr::{
35 AggregateFunction, AggregateFunctionParams, ExprListDisplay, WindowFunctionParams,
36 schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space,
37 schema_name_from_sorts,
38};
39use crate::function::{
40 AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
41};
42use crate::groups_accumulator::GroupsAccumulator;
43use crate::udf_eq::UdfEq;
44use crate::utils::AggregateOrderSensitivity;
45use crate::utils::format_state_name;
46use crate::{Accumulator, Expr, expr_vec_fmt};
47use crate::{Documentation, Signature};
48
/// Logical representation of a user-defined [aggregate function] (UDAF).
///
/// An aggregate function combines the values from multiple input rows
/// into a single output "aggregate" (summary) row. It is different
/// from a scalar function because it is stateful across batches. User
/// defined aggregate functions can be used as normal SQL aggregate
/// functions (`GROUP BY` clause) as well as window functions (`OVER`
/// clause).
///
/// `AggregateUDF` provides DataFusion the information needed to plan and call
/// aggregate functions, including name, type information, and a factory
/// function to create an [`Accumulator`] instance, to perform the actual
/// aggregation.
///
/// For more information, please see [the examples]:
///
/// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]).
///
/// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API
/// access (examples in [`advanced_udaf.rs`]).
///
/// # API Note
/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
/// compatibility with the older API.
///
/// # Cloning, equality and hashing
///
/// `AggregateUDF` is a thin handle around a shared [`AggregateUDFImpl`]:
/// cloning it only clones the inner [`Arc`] (a reference-count bump), and
/// equality / hashing are delegated to the wrapped implementation's
/// `dyn_eq` / `dyn_hash`.
///
/// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process
/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
/// [`Accumulator`]: Accumulator
/// [`create_udaf`]: crate::expr_fn::create_udaf
/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/simple_udaf.rs
/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udaf.rs
// NOTE(review): deriving `PartialOrd` requires `dyn AggregateUDFImpl: PartialOrd`;
// that impl is presumably defined elsewhere in this file — confirm.
#[derive(Debug, Clone, PartialOrd)]
pub struct AggregateUDF {
    /// The shared implementation that most methods on `AggregateUDF` forward to.
    inner: Arc<dyn AggregateUDFImpl>,
}
84
85impl PartialEq for AggregateUDF {
86 fn eq(&self, other: &Self) -> bool {
87 self.inner.dyn_eq(other.inner.as_ref() as &dyn Any)
88 }
89}
90
// Equality is delegated to `DynEq::dyn_eq` on the wrapped implementation
// (see the `PartialEq` impl above). NOTE(review): implementing `Eq` assumes
// every `AggregateUDFImpl`'s `dyn_eq` is a true equivalence relation (in
// particular reflexive) — confirm for new implementations.
impl Eq for AggregateUDF {}
92
93impl Hash for AggregateUDF {
94 fn hash<H: Hasher>(&self, state: &mut H) {
95 self.inner.dyn_hash(state)
96 }
97}
98
99impl fmt::Display for AggregateUDF {
100 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
101 write!(f, "{}", self.name())
102 }
103}
104
/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
///
/// Bundles what an aggregate needs in order to try to determine its result
/// entirely from statistics and arguments, rather than by a runtime
/// computation over the input.
#[derive(Debug)]
pub struct StatisticsArgs<'a> {
    /// The statistics of the aggregate input
    pub statistics: &'a Statistics,
    /// The resolved return type of the aggregate function
    pub return_type: &'a DataType,
    /// Whether the aggregate function is distinct.
    ///
    /// ```sql
    /// SELECT COUNT(DISTINCT column1) FROM t;
    /// ```
    pub is_distinct: bool,
    /// The physical expressions of the arguments the aggregate function takes.
    pub exprs: &'a [Arc<dyn PhysicalExpr>],
}
121
122impl AggregateUDF {
123 /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
124 ///
125 /// Note this is the same as using the `From` impl (`AggregateUDF::from`)
126 pub fn new_from_impl<F>(fun: F) -> AggregateUDF
127 where
128 F: AggregateUDFImpl + 'static,
129 {
130 Self::new_from_shared_impl(Arc::new(fun))
131 }
132
133 /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
134 pub fn new_from_shared_impl(fun: Arc<dyn AggregateUDFImpl>) -> AggregateUDF {
135 Self { inner: fun }
136 }
137
138 /// Return the underlying [`AggregateUDFImpl`] trait object for this function
139 pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
140 &self.inner
141 }
142
143 /// Adds additional names that can be used to invoke this function, in
144 /// addition to `name`
145 ///
146 /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
147 pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
148 Self::new_from_impl(AliasedAggregateUDFImpl::new(
149 Arc::clone(&self.inner),
150 aliases,
151 ))
152 }
153
154 /// Creates an [`Expr`] that calls the aggregate function.
155 ///
156 /// This utility allows using the UDAF without requiring access to
157 /// the registry, such as with the DataFrame API.
158 pub fn call(&self, args: Vec<Expr>) -> Expr {
159 Expr::AggregateFunction(AggregateFunction::new_udf(
160 Arc::new(self.clone()),
161 args,
162 false,
163 None,
164 vec![],
165 None,
166 ))
167 }
168
169 /// Returns this function's name
170 ///
171 /// See [`AggregateUDFImpl::name`] for more details.
172 pub fn name(&self) -> &str {
173 self.inner.name()
174 }
175
176 /// Returns the aliases for this function.
177 pub fn aliases(&self) -> &[String] {
178 self.inner.aliases()
179 }
180
181 /// See [`AggregateUDFImpl::schema_name`] for more details.
182 pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
183 self.inner.schema_name(params)
184 }
185
186 /// Returns a human readable expression.
187 ///
188 /// See [`Expr::human_display`] for details.
189 pub fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
190 self.inner.human_display(params)
191 }
192
193 pub fn window_function_schema_name(
194 &self,
195 params: &WindowFunctionParams,
196 ) -> Result<String> {
197 self.inner.window_function_schema_name(params)
198 }
199
200 /// See [`AggregateUDFImpl::display_name`] for more details.
201 pub fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
202 self.inner.display_name(params)
203 }
204
205 pub fn window_function_display_name(
206 &self,
207 params: &WindowFunctionParams,
208 ) -> Result<String> {
209 self.inner.window_function_display_name(params)
210 }
211
212 pub fn is_nullable(&self) -> bool {
213 self.inner.is_nullable()
214 }
215
216 /// Returns this function's signature (what input types are accepted)
217 ///
218 /// See [`AggregateUDFImpl::signature`] for more details.
219 pub fn signature(&self) -> &Signature {
220 self.inner.signature()
221 }
222
223 /// Return the type of the function given its input types
224 ///
225 /// See [`AggregateUDFImpl::return_type`] for more details.
226 pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
227 self.inner.return_type(args)
228 }
229
230 /// Return the field of the function given its input fields
231 ///
232 /// See [`AggregateUDFImpl::return_field`] for more details.
233 pub fn return_field(&self, args: &[FieldRef]) -> Result<FieldRef> {
234 self.inner.return_field(args)
235 }
236
237 /// Return an accumulator the given aggregate, given its return datatype
238 pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
239 self.inner.accumulator(acc_args)
240 }
241
242 /// Return the fields used to store the intermediate state for this aggregator, given
243 /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
244 /// for more details.
245 ///
246 /// This is used to support multi-phase aggregations
247 pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
248 self.inner.state_fields(args)
249 }
250
251 /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
252 pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
253 self.inner.groups_accumulator_supported(args)
254 }
255
256 /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
257 pub fn create_groups_accumulator(
258 &self,
259 args: AccumulatorArgs,
260 ) -> Result<Box<dyn GroupsAccumulator>> {
261 self.inner.create_groups_accumulator(args)
262 }
263
264 pub fn create_sliding_accumulator(
265 &self,
266 args: AccumulatorArgs,
267 ) -> Result<Box<dyn Accumulator>> {
268 self.inner.create_sliding_accumulator(args)
269 }
270
271 pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
272 self.inner.coerce_types(arg_types)
273 }
274
275 /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details.
276 pub fn with_beneficial_ordering(
277 self,
278 beneficial_ordering: bool,
279 ) -> Result<Option<AggregateUDF>> {
280 self.inner
281 .with_beneficial_ordering(beneficial_ordering)
282 .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf }))
283 }
284
285 /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
286 /// for possible options.
287 pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
288 self.inner.order_sensitivity()
289 }
290
291 /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will
292 /// generate same result with this `AggregateUDF` when iterated in reverse
293 /// order, and `None` if there is no such `AggregateUDF`).
294 pub fn reverse_udf(&self) -> ReversedUDAF {
295 self.inner.reverse_expr()
296 }
297
298 /// Returns this aggregate function's simplification hook, if any.
299 ///
300 /// See [`AggregateUDFImpl::simplify`] for more details.
301 pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
302 self.inner.simplify()
303 }
304
305 /// Rewrite aggregate to have simpler arguments
306 ///
307 /// See [`AggregateUDFImpl::simplify_expr_op_literal`] for more details
308 pub fn simplify_expr_op_literal(
309 &self,
310 agg_function: &AggregateFunction,
311 arg: &Expr,
312 op: Operator,
313 lit: &Expr,
314 arg_is_left: bool,
315 ) -> Result<Option<Expr>> {
316 self.inner
317 .simplify_expr_op_literal(agg_function, arg, op, lit, arg_is_left)
318 }
319
320 /// Returns true if the function is max, false if the function is min
321 /// None in all other cases, used in certain optimizations for
322 /// or aggregate
323 pub fn is_descending(&self) -> Option<bool> {
324 self.inner.is_descending()
325 }
326
327 /// Return the value of this aggregate function if it can be determined
328 /// entirely from statistics and arguments.
329 ///
330 /// See [`AggregateUDFImpl::value_from_stats`] for more details.
331 pub fn value_from_stats(
332 &self,
333 statistics_args: &StatisticsArgs,
334 ) -> Option<ScalarValue> {
335 self.inner.value_from_stats(statistics_args)
336 }
337
338 /// See [`AggregateUDFImpl::default_value`] for more details.
339 pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
340 self.inner.default_value(data_type)
341 }
342
343 /// See [`AggregateUDFImpl::supports_null_handling_clause`] for more details.
344 pub fn supports_null_handling_clause(&self) -> bool {
345 self.inner.supports_null_handling_clause()
346 }
347
348 /// See [`AggregateUDFImpl::supports_within_group_clause`] for more details.
349 pub fn supports_within_group_clause(&self) -> bool {
350 self.inner.supports_within_group_clause()
351 }
352
353 /// Returns the documentation for this Aggregate UDF.
354 ///
355 /// Documentation can be accessed programmatically as well as
356 /// generating publicly facing documentation.
357 pub fn documentation(&self) -> Option<&Documentation> {
358 self.inner.documentation()
359 }
360}
361
362impl<F> From<F> for AggregateUDF
363where
364 F: AggregateUDFImpl + Send + Sync + 'static,
365{
366 fn from(fun: F) -> Self {
367 Self::new_from_impl(fun)
368 }
369}
370
371/// Trait for implementing [`AggregateUDF`].
372///
373/// This trait exposes the full API for implementing user defined aggregate functions and
374/// can be used to implement any function.
375///
376/// See [`advanced_udaf.rs`] for a full example with complete implementation and
377/// [`AggregateUDF`] for other available options.
378///
379/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udaf.rs
380///
381/// # Basic Example
382/// ```
383/// # use std::any::Any;
384/// # use std::sync::{Arc, LazyLock};
385/// # use arrow::datatypes::{DataType, FieldRef};
386/// # use datafusion_common::{DataFusionError, plan_err, Result};
387/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation};
388/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
389/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE;
390/// # use arrow::datatypes::Schema;
391/// # use arrow::datatypes::Field;
392///
393/// #[derive(Debug, Clone, PartialEq, Eq, Hash)]
394/// struct GeoMeanUdf {
395/// signature: Signature,
396/// }
397///
398/// impl GeoMeanUdf {
399/// fn new() -> Self {
400/// Self {
401/// signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
402/// }
403/// }
404/// }
405///
406/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
407/// Documentation::builder(DOC_SECTION_AGGREGATE, "calculates a geometric mean", "geo_mean(2.0)")
408/// .with_argument("arg1", "The Float64 number for the geometric mean")
409/// .build()
410/// });
411///
412/// fn get_doc() -> &'static Documentation {
413/// &DOCUMENTATION
414/// }
415///
416/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
417/// impl AggregateUDFImpl for GeoMeanUdf {
418/// fn name(&self) -> &str { "geo_mean" }
419/// fn signature(&self) -> &Signature { &self.signature }
420/// fn return_type(&self, args: &[DataType]) -> Result<DataType> {
421/// if !matches!(args.get(0), Some(&DataType::Float64)) {
422/// return plan_err!("geo_mean only accepts Float64 arguments");
423/// }
424/// Ok(DataType::Float64)
425/// }
426/// // This is the accumulator factory; DataFusion uses it to create new accumulators.
427/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
428/// fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
429/// Ok(vec![
430/// Arc::new(args.return_field.as_ref().clone().with_name("value")),
431/// Arc::new(Field::new("ordering", DataType::UInt32, true))
432/// ])
433/// }
434/// fn documentation(&self) -> Option<&Documentation> {
435/// Some(get_doc())
436/// }
437/// }
438///
439/// // Create a new AggregateUDF from the implementation
440/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
441///
442/// // Call the function `geo_mean(col)`
443/// let expr = geometric_mean.call(vec![col("a")]);
444/// ```
445pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any {
446 /// Returns this function's name
447 fn name(&self) -> &str;
448
449 /// Returns any aliases (alternate names) for this function.
450 ///
451 /// Note: `aliases` should only include names other than [`Self::name`].
452 /// Defaults to `[]` (no aliases)
453 fn aliases(&self) -> &[String] {
454 &[]
455 }
456
457 /// Returns the name of the column this expression would create
458 ///
459 /// See [`Expr::schema_name`] for details
460 ///
461 /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..]
462 fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
463 udaf_default_schema_name(self, params)
464 }
465
466 /// Returns a human readable expression.
467 ///
468 /// See [`Expr::human_display`] for details.
469 fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
470 udaf_default_human_display(self, params)
471 }
472
473 /// Returns the name of the column this expression would create
474 ///
475 /// See [`Expr::schema_name`] for details
476 ///
477 /// Different from `schema_name` in that it is used for window aggregate function
478 ///
479 /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) [PARTITION BY [..]] [ORDER BY [..]]
480 fn window_function_schema_name(
481 &self,
482 params: &WindowFunctionParams,
483 ) -> Result<String> {
484 udaf_default_window_function_schema_name(self, params)
485 }
486
487 /// Returns the user-defined display name of function, given the arguments
488 ///
489 /// This can be used to customize the output column name generated by this
490 /// function.
491 ///
492 /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]`
493 fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
494 udaf_default_display_name(self, params)
495 }
496
497 /// Returns the user-defined display name of function, given the arguments
498 ///
499 /// This can be used to customize the output column name generated by this
500 /// function.
501 ///
502 /// Different from `display_name` in that it is used for window aggregate function
503 ///
504 /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [partition by [..]] [order_by [..]]`
505 fn window_function_display_name(
506 &self,
507 params: &WindowFunctionParams,
508 ) -> Result<String> {
509 udaf_default_window_function_display_name(self, params)
510 }
511
512 /// Returns the function's [`Signature`] for information about what input
513 /// types are accepted and the function's Volatility.
514 fn signature(&self) -> &Signature;
515
516 /// What [`DataType`] will be returned by this function, given the types of
517 /// the arguments
518 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
519
520 /// What type will be returned by this function, given the arguments?
521 ///
522 /// By default, this function calls [`Self::return_type`] with the
523 /// types of each argument.
524 ///
525 /// # Notes
526 ///
527 /// Most UDFs should implement [`Self::return_type`] and not this
528 /// function as the output type for most functions only depends on the types
529 /// of their inputs (e.g. `sum(f64)` is always `f64`).
530 ///
531 /// This function can be used for more advanced cases such as:
532 ///
533 /// 1. specifying nullability
    /// 2. return types based on the **values** of the arguments (rather than
    ///    their **types**).
536 /// 3. return types based on metadata within the fields of the inputs
537 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
538 udaf_default_return_field(self, arg_fields)
539 }
540
541 /// Whether the aggregate function is nullable.
542 ///
543 /// Nullable means that the function could return `null` for any inputs.
544 /// For example, aggregate functions like `COUNT` always return a non null value
545 /// but others like `MIN` will return `NULL` if there is nullable input.
546 /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null`
547 fn is_nullable(&self) -> bool {
548 true
549 }
550
551 /// Return a new [`Accumulator`] that aggregates values for a specific
552 /// group during query execution.
553 ///
554 /// acc_args: [`AccumulatorArgs`] contains information about how the
555 /// aggregate function was called.
556 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
557
558 /// Return the fields used to store the intermediate state of this accumulator.
559 ///
560 /// See [`Accumulator::state`] for background information.
561 ///
562 /// args: [`StateFieldsArgs`] contains arguments passed to the
563 /// aggregate function's accumulator.
564 ///
565 /// # Notes:
566 ///
    /// The default implementation returns a single state field, named from
    /// `name` via [`format_state_name`] and with the same type as the return
    /// field, followed by any ordering fields. This is suitable for aggregates such
569 /// as `SUM` or `MIN` where partial state can be combined by applying the
570 /// same aggregate.
571 ///
572 /// For aggregates such as `AVG` where the partial state is more complex
573 /// (e.g. a COUNT and a SUM), this method is used to define the additional
574 /// fields.
575 ///
576 /// The name of the fields must be unique within the query and thus should
577 /// be derived from `name`. See [`format_state_name`] for a utility function
578 /// to generate a unique name.
579 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
580 let fields = vec![
581 args.return_field
582 .as_ref()
583 .clone()
584 .with_name(format_state_name(args.name, "value")),
585 ];
586
587 Ok(fields
588 .into_iter()
589 .map(Arc::new)
590 .chain(args.ordering_fields.to_vec())
591 .collect())
592 }
593
594 /// If the aggregate expression has a specialized
595 /// [`GroupsAccumulator`] implementation. If this returns true,
    /// [`Self::create_groups_accumulator`] will be called.
597 ///
598 /// # Notes
599 ///
600 /// Even if this function returns true, DataFusion will still use
601 /// [`Self::accumulator`] for certain queries, such as when this aggregate is
    /// used as a window function or when there are no GROUP BY columns in the
603 /// query.
604 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
605 false
606 }
607
608 /// Return a specialized [`GroupsAccumulator`] that manages state
609 /// for all groups.
610 ///
611 /// For maximum performance, a [`GroupsAccumulator`] should be
612 /// implemented in addition to [`Accumulator`].
613 fn create_groups_accumulator(
614 &self,
615 _args: AccumulatorArgs,
616 ) -> Result<Box<dyn GroupsAccumulator>> {
617 not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
618 }
619
620 /// Sliding accumulator is an alternative accumulator that can be used for
621 /// window functions. It has retract method to revert the previous update.
622 ///
623 /// See [retract_batch] for more details.
624 ///
625 /// [retract_batch]: Accumulator::retract_batch
626 fn create_sliding_accumulator(
627 &self,
628 args: AccumulatorArgs,
629 ) -> Result<Box<dyn Accumulator>> {
630 self.accumulator(args)
631 }
632
    /// Sets the indicator of whether the ordering requirements of the
    /// AggregateUDFImpl are satisfied by its input. If this is not the case, UDFs with order
635 /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
636 /// the correct result with possibly more work internally.
637 ///
638 /// # Returns
639 ///
640 /// Returns `Ok(Some(updated_udf))` if the process completes successfully.
641 /// If the expression can benefit from existing input ordering, but does
642 /// not implement the method, returns an error. Order insensitive and hard
643 /// requirement aggregators return `Ok(None)`.
644 fn with_beneficial_ordering(
645 self: Arc<Self>,
646 _beneficial_ordering: bool,
647 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
648 if self.order_sensitivity().is_beneficial() {
649 return exec_err!(
650 "Should implement with satisfied for aggregator :{:?}",
651 self.name()
652 );
653 }
654 Ok(None)
655 }
656
657 /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
658 /// for possible options.
659 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
660 // We have hard ordering requirements by default, meaning that order
661 // sensitive UDFs need their input orderings to satisfy their ordering
662 // requirements to generate correct results.
663 AggregateOrderSensitivity::HardRequirement
664 }
665
666 /// Returns an optional hook for simplifying this user-defined aggregate.
667 ///
668 /// Use this hook to apply function-specific rewrites during optimization.
669 /// The default implementation returns `None`.
670 ///
671 /// For example, `percentile_cont(x, 0.0)` and `percentile_cont(x, 1.0)` can
672 /// be rewritten to `MIN(x)` or `MAX(x)` depending on the `ORDER BY`
673 /// direction.
674 ///
675 /// DataFusion already simplifies arguments and performs constant folding
676 /// (for example, `my_add(1, 2) -> 3`). For nested expressions, the optimizer
677 /// runs simplification in multiple passes, so arguments are typically
678 /// simplified before this hook is invoked. As a result, UDF implementations
679 /// usually do not need to handle argument simplification themselves.
680 ///
681 /// See configuration `datafusion.optimizer.max_passes` for details on how many
682 /// optimization passes may be applied.
683 ///
684 /// # Returns
685 ///
686 /// `None` if simplify is not defined.
687 ///
688 /// Or, a closure ([`AggregateFunctionSimplification`]) invoked with:
689 /// * `aggregate_function`: [AggregateFunction] with already simplified
690 /// arguments
691 /// * `info`: [crate::simplify::SimplifyContext]
692 ///
693 /// The closure returns a simplified [Expr] or an error.
694 ///
695 /// # Notes
696 ///
697 /// The returned expression must have the same schema as the original
698 /// expression, including both the data type and nullability. For example,
699 /// if the original expression is nullable, the returned expression must
700 /// also be nullable, otherwise it may lead to schema verification errors
701 /// later in query planning.
702 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
703 None
704 }
705
706 /// Rewrite the aggregate to have simpler arguments
707 ///
708 /// This query pattern is not common in most real workloads, and most
709 /// aggregate implementations can safely ignore it. This API is included in
710 /// DataFusion because it is important for ClickBench Q29. See backstory
711 /// on <https://github.com/apache/datafusion/issues/15524>
712 ///
713 /// # Rewrite Overview
714 ///
715 /// The idea is to rewrite multiple aggregates with "complex arguments" into
716 /// ones with simpler arguments that can be optimized by common subexpression
717 /// elimination (CSE). At a high level the rewrite looks like
718 ///
719 /// * `Aggregate(SUM(x + 1), SUM(x + 2), ...)`
720 ///
721 /// Into
722 ///
723 /// * `Aggregate(SUM(x) + 1 * COUNT(x), SUM(x) + 2 * COUNT(x), ...)`
724 ///
725 /// While this rewrite may seem worse (slower) than the original as it
726 /// computes *more* aggregate expressions, the common subexpression
727 /// elimination (CSE) can then reduce the number of distinct aggregates the
728 /// query actually needs to compute with a rewrite like
729 ///
730 /// * `Projection(_A + 1*_B, _A + 2*_B)`
731 /// * ` Aggregate(_A = SUM(x), _B = COUNT(x))`
732 ///
733 /// This optimization is extremely important for ClickBench Q29, which has 90
734 /// such expressions for some reason, and so this optimization results in
735 /// only two aggregates being needed. The DataFusion optimizer will invoke
736 /// this method when it detects multiple aggregates in a query that share
737 /// arguments of the form `<arg> <op> <literal>`.
738 ///
739 /// # API
740 ///
741 /// If `agg_function` supports the rewrite, it should return a semantically
742 /// equivalent expression (likely with more aggregate expressions, but
743 /// simpler arguments)
744 ///
745 /// This is only called when:
746 /// 1. There are no "special" aggregate params (filters, null handling, etc)
747 /// 2. Aggregate functions with exactly one [`Expr`] argument
748 /// 3. There are no volatile expressions
749 ///
750 /// Arguments
751 /// * `agg_function`: the original aggregate function detected with complex
752 /// arguments.
753 /// * `arg`: The common argument shared across multiple aggregates (e.g. `x`
754 /// in the example above)
755 /// * `op`: the operator between the common argument and the literal (e.g.
756 /// `+` in `x + 1` or `1 + x`)
757 /// * `lit`: the literal argument (e.g. `1` or `2` in the example above)
758 /// * `arg_is_left`: whether the common argument is on the left or right of
759 /// the operator (e.g. `true` for `x + 1` and false for `1 + x`)
760 ///
761 /// The default implementation returns `None`, which is what most aggregates
762 /// should do.
763 fn simplify_expr_op_literal(
764 &self,
765 _agg_function: &AggregateFunction,
766 _arg: &Expr,
767 _op: Operator,
768 _lit: &Expr,
769 _arg_is_left: bool,
770 ) -> Result<Option<Expr>> {
771 Ok(None)
772 }
773
774 /// Returns the reverse expression of the aggregate function.
775 fn reverse_expr(&self) -> ReversedUDAF {
776 ReversedUDAF::NotSupported
777 }
778
779 /// Coerce arguments of a function call to types that the function can evaluate.
780 ///
781 /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
782 /// UDAFs should return one of the other variants of `TypeSignature` which handle common
783 /// cases
784 ///
785 /// See the [type coercion module](crate::type_coercion)
786 /// documentation for more details on type coercion
787 ///
    /// For example, if your function requires floating point arguments, but the user calls
789 /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
790 /// to ensure the argument was cast to `1::double`
791 ///
792 /// # Parameters
    /// * `arg_types`: The types of the arguments this function was called with
794 ///
795 /// # Return value
796 /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
797 /// arguments to these specific types.
798 fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
799 not_impl_err!("Function {} does not implement coerce_types", self.name())
800 }
801
802 /// If this function is max, return true
803 /// If the function is min, return false
804 /// Otherwise return None (the default)
805 ///
806 ///
807 /// Note: this is used to use special aggregate implementations in certain conditions
808 fn is_descending(&self) -> Option<bool> {
809 None
810 }
811
812 /// Return the value of this aggregate function if it can be determined
813 /// entirely from statistics and arguments.
814 ///
815 /// Using a [`ScalarValue`] rather than a runtime computation can significantly
    /// improve query performance.
817 ///
818 /// For example, if the minimum value of column `x` is known to be `42` from
819 /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`
820 fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
821 None
822 }
823
824 /// Returns default value of the function given the input is all `null`.
825 ///
    /// Most aggregate functions return `NULL` if the input is `NULL`,
    /// while `count` returns 0 if the input is `NULL`.
828 fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
829 ScalarValue::try_from(data_type)
830 }
831
/// Returns `true` if this function supports the `[IGNORE NULLS | RESPECT NULLS]`
/// SQL clause. Returning `false` (the default) causes an error to be raised
/// during SQL parsing if either clause is used with this function.
///
/// Functions that return `true` are expected to honor the resulting null
/// handling setting, which is exposed as `ignore_nulls` in [`AccumulatorArgs`].
fn supports_null_handling_clause(&self) -> bool {
    false
}
841
/// If this function supports the `WITHIN GROUP (ORDER BY column [ASC|DESC])`
/// SQL syntax, return `true`. Otherwise, return `false` (the default), which
/// causes an error when SQL using this syntax with this function is parsed.
///
/// This function should return `true` for ordered-set aggregate functions
/// only.
///
/// # Ordered-set aggregate functions
///
/// Ordered-set aggregate functions allow specifying a sort order that affects
/// how the function calculates its result, unlike other aggregate functions
/// such as `sum` or `count`. For example, `percentile_cont` is an ordered-set
/// aggregate function that calculates the exact percentile value from a list
/// of values; the result for the `0.75` percentile depends on whether the
/// values are ordered ascending or descending.
///
/// An example of how an ordered-set aggregate function is called with the
/// `WITHIN GROUP` SQL syntax:
///
/// ```sql
/// -- Ascending
/// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 ASC) FROM table;
/// -- Default ordering is ascending if not explicitly specified
/// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1) FROM table;
/// -- Descending
/// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 DESC) FROM table;
/// ```
///
/// This calculates the `0.75` percentile of the column `c1` from `table`,
/// according to the specified ordering. The column specified in the `WITHIN GROUP`
/// ordering clause is taken as the column to calculate values on. Specifying
/// the `WITHIN GROUP` clause is optional, so these queries are equivalent:
///
/// ```sql
/// -- If no WITHIN GROUP is specified then default ordering is implementation
/// -- dependent; in this case ascending for percentile_cont
/// SELECT percentile_cont(c1, 0.75) FROM table;
/// SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY c1 ASC) FROM table;
/// ```
///
/// Aggregate UDFs can define their own default ordering for calls without
/// the `WITHIN GROUP` clause, though ascending is the standard practice.
///
/// Ordered-set aggregate function implementations are responsible for handling
/// the input sort order themselves (e.g. `percentile_cont` must buffer and
/// sort the values internally). That is, DataFusion does not introduce any
/// kind of sort into the plan for these functions with this syntax.
///
/// When this returns `true` and an ordering is present, the default schema
/// name (`udaf_default_schema_name`) renders the ordering as
/// `WITHIN GROUP [...]` and leaves out the first argument.
fn supports_within_group_clause(&self) -> bool {
    false
}
894
/// Returns the documentation for this Aggregate UDF, or `None` (the default)
/// if no documentation is provided.
///
/// Documentation can be accessed programmatically as well as used to
/// generate publicly facing documentation.
fn documentation(&self) -> Option<&Documentation> {
    None
}
902
/// Indicates whether the aggregation function is monotonic as a set
/// function for the given `data_type`. See [`SetMonotonicity`] for details.
///
/// The default, [`SetMonotonicity::NotMonotonic`], makes no monotonicity
/// claim.
fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
    SetMonotonicity::NotMonotonic
}
908}
909
910impl dyn AggregateUDFImpl {
911 /// Returns `true` if the implementation is of type `T`.
912 ///
913 /// Works correctly when called on `Arc<dyn AggregateUDFImpl>` via auto-deref.
914 pub fn is<T: AggregateUDFImpl>(&self) -> bool {
915 (self as &dyn Any).is::<T>()
916 }
917
918 /// Attempts to downcast to a concrete type `T`, returning `None` if the
919 /// implementation is not of that type.
920 ///
921 /// Works correctly when called on `Arc<dyn AggregateUDFImpl>` via auto-deref,
922 /// unlike `(&arc as &dyn Any).downcast_ref::<T>()` which would attempt to
923 /// downcast the `Arc` itself.
924 pub fn downcast_ref<T: AggregateUDFImpl>(&self) -> Option<&T> {
925 (self as &dyn Any).downcast_ref()
926 }
927}
928
929impl PartialEq for dyn AggregateUDFImpl {
930 fn eq(&self, other: &Self) -> bool {
931 self.dyn_eq(other)
932 }
933}
934
935impl PartialOrd for dyn AggregateUDFImpl {
936 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
937 match self.name().partial_cmp(other.name()) {
938 Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
939 cmp => cmp,
940 }
941 // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
942 .filter(|cmp| *cmp != Ordering::Equal || self == other)
943 }
944}
945
946/// Encapsulates default implementation of [`AggregateUDFImpl::schema_name`].
947pub fn udaf_default_schema_name<F: AggregateUDFImpl + ?Sized>(
948 func: &F,
949 params: &AggregateFunctionParams,
950) -> Result<String> {
951 let AggregateFunctionParams {
952 args,
953 distinct,
954 filter,
955 order_by,
956 null_treatment,
957 } = params;
958
959 // exclude the first function argument(= column) in ordered set aggregate function,
960 // because it is duplicated with the WITHIN GROUP clause in schema name.
961 let args = if func.supports_within_group_clause() && !order_by.is_empty() {
962 &args[1..]
963 } else {
964 &args[..]
965 };
966
967 let mut schema_name = String::new();
968
969 schema_name.write_fmt(format_args!(
970 "{}({}{})",
971 func.name(),
972 if *distinct { "DISTINCT " } else { "" },
973 schema_name_from_exprs_comma_separated_without_space(args)?
974 ))?;
975
976 if let Some(null_treatment) = null_treatment {
977 schema_name.write_fmt(format_args!(" {null_treatment}"))?;
978 }
979
980 if let Some(filter) = filter {
981 schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
982 };
983
984 if !order_by.is_empty() {
985 let clause = match func.supports_within_group_clause() {
986 true => "WITHIN GROUP",
987 false => "ORDER BY",
988 };
989
990 schema_name.write_fmt(format_args!(
991 " {} [{}]",
992 clause,
993 schema_name_from_sorts(order_by)?
994 ))?;
995 };
996
997 Ok(schema_name)
998}
999
1000/// Encapsulates default implementation of [`AggregateUDFImpl::human_display`].
1001pub fn udaf_default_human_display<F: AggregateUDFImpl + ?Sized>(
1002 func: &F,
1003 params: &AggregateFunctionParams,
1004) -> Result<String> {
1005 let AggregateFunctionParams {
1006 args,
1007 distinct,
1008 filter,
1009 order_by,
1010 null_treatment,
1011 } = params;
1012
1013 let mut schema_name = String::new();
1014
1015 schema_name.write_fmt(format_args!(
1016 "{}({}{})",
1017 func.name(),
1018 if *distinct { "DISTINCT " } else { "" },
1019 ExprListDisplay::comma_separated(args.as_slice())
1020 ))?;
1021
1022 if let Some(null_treatment) = null_treatment {
1023 schema_name.write_fmt(format_args!(" {null_treatment}"))?;
1024 }
1025
1026 if let Some(filter) = filter {
1027 schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
1028 };
1029
1030 if !order_by.is_empty() {
1031 schema_name.write_fmt(format_args!(
1032 " ORDER BY [{}]",
1033 schema_name_from_sorts(order_by)?
1034 ))?;
1035 };
1036
1037 Ok(schema_name)
1038}
1039
1040/// Encapsulates default implementation of [`AggregateUDFImpl::window_function_schema_name`].
1041pub fn udaf_default_window_function_schema_name<F: AggregateUDFImpl + ?Sized>(
1042 func: &F,
1043 params: &WindowFunctionParams,
1044) -> Result<String> {
1045 let WindowFunctionParams {
1046 args,
1047 partition_by,
1048 order_by,
1049 window_frame,
1050 filter,
1051 null_treatment,
1052 distinct,
1053 } = params;
1054
1055 let mut schema_name = String::new();
1056
1057 // Inject DISTINCT into the schema name when requested
1058 if *distinct {
1059 schema_name.write_fmt(format_args!(
1060 "{}(DISTINCT {})",
1061 func.name(),
1062 schema_name_from_exprs(args)?
1063 ))?;
1064 } else {
1065 schema_name.write_fmt(format_args!(
1066 "{}({})",
1067 func.name(),
1068 schema_name_from_exprs(args)?
1069 ))?;
1070 }
1071
1072 if let Some(null_treatment) = null_treatment {
1073 schema_name.write_fmt(format_args!(" {null_treatment}"))?;
1074 }
1075
1076 if let Some(filter) = filter {
1077 schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
1078 }
1079
1080 if !partition_by.is_empty() {
1081 schema_name.write_fmt(format_args!(
1082 " PARTITION BY [{}]",
1083 schema_name_from_exprs(partition_by)?
1084 ))?;
1085 }
1086
1087 if !order_by.is_empty() {
1088 schema_name.write_fmt(format_args!(
1089 " ORDER BY [{}]",
1090 schema_name_from_sorts(order_by)?
1091 ))?;
1092 }
1093
1094 schema_name.write_fmt(format_args!(" {window_frame}"))?;
1095
1096 Ok(schema_name)
1097}
1098
1099/// Encapsulates default implementation of [`AggregateUDFImpl::display_name`].
1100pub fn udaf_default_display_name<F: AggregateUDFImpl + ?Sized>(
1101 func: &F,
1102 params: &AggregateFunctionParams,
1103) -> Result<String> {
1104 let AggregateFunctionParams {
1105 args,
1106 distinct,
1107 filter,
1108 order_by,
1109 null_treatment,
1110 } = params;
1111
1112 let mut display_name = String::new();
1113
1114 display_name.write_fmt(format_args!(
1115 "{}({}{})",
1116 func.name(),
1117 if *distinct { "DISTINCT " } else { "" },
1118 expr_vec_fmt!(args)
1119 ))?;
1120
1121 if let Some(nt) = null_treatment {
1122 display_name.write_fmt(format_args!(" {nt}"))?;
1123 }
1124 if let Some(fe) = filter {
1125 display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
1126 }
1127 if !order_by.is_empty() {
1128 display_name.write_fmt(format_args!(
1129 " ORDER BY [{}]",
1130 order_by
1131 .iter()
1132 .map(|o| format!("{o}"))
1133 .collect::<Vec<String>>()
1134 .join(", ")
1135 ))?;
1136 }
1137
1138 Ok(display_name)
1139}
1140
1141/// Encapsulates default implementation of [`AggregateUDFImpl::window_function_display_name`].
1142pub fn udaf_default_window_function_display_name<F: AggregateUDFImpl + ?Sized>(
1143 func: &F,
1144 params: &WindowFunctionParams,
1145) -> Result<String> {
1146 let WindowFunctionParams {
1147 args,
1148 partition_by,
1149 order_by,
1150 window_frame,
1151 filter,
1152 null_treatment,
1153 distinct,
1154 } = params;
1155
1156 let mut display_name = String::new();
1157
1158 if *distinct {
1159 display_name.write_fmt(format_args!(
1160 "{}(DISTINCT {})",
1161 func.name(),
1162 expr_vec_fmt!(args)
1163 ))?;
1164 } else {
1165 display_name.write_fmt(format_args!(
1166 "{}({})",
1167 func.name(),
1168 expr_vec_fmt!(args)
1169 ))?;
1170 }
1171
1172 if let Some(null_treatment) = null_treatment {
1173 display_name.write_fmt(format_args!(" {null_treatment}"))?;
1174 }
1175
1176 if let Some(fe) = filter {
1177 display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
1178 }
1179
1180 if !partition_by.is_empty() {
1181 display_name.write_fmt(format_args!(
1182 " PARTITION BY [{}]",
1183 expr_vec_fmt!(partition_by)
1184 ))?;
1185 }
1186
1187 if !order_by.is_empty() {
1188 display_name
1189 .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
1190 };
1191
1192 display_name.write_fmt(format_args!(
1193 " {} BETWEEN {} AND {}",
1194 window_frame.units, window_frame.start_bound, window_frame.end_bound
1195 ))?;
1196
1197 Ok(display_name)
1198}
1199
1200/// Encapsulates default implementation of [`AggregateUDFImpl::return_field`].
1201pub fn udaf_default_return_field<F: AggregateUDFImpl + ?Sized>(
1202 func: &F,
1203 arg_fields: &[FieldRef],
1204) -> Result<FieldRef> {
1205 let arg_types: Vec<_> = arg_fields.iter().map(|f| f.data_type()).cloned().collect();
1206 let data_type = func.return_type(&arg_types)?;
1207
1208 Ok(Arc::new(Field::new(
1209 func.name(),
1210 data_type,
1211 func.is_nullable(),
1212 )))
1213}
1214
/// Result of asking an aggregate function for its reversed form; returned by
/// [`AggregateUDFImpl::reverse_expr`].
pub enum ReversedUDAF {
    /// The expression is the same as the original expression, like SUM, COUNT
    Identical,
    /// The expression does not support reverse calculation
    NotSupported,
    /// The expression is different from the original expression; presumably
    /// the contained [`AggregateUDF`] is the reversed function to use instead.
    Reversed(Arc<AggregateUDF>),
}
1223
/// AggregateUDF that adds an alias to the underlying function. It is better to
/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
///
/// Equality, hashing, and `Debug` are derived from both fields.
#[derive(Debug, PartialEq, Eq, Hash)]
struct AliasedAggregateUDFImpl {
    /// The wrapped implementation; compared and hashed through `UdfEq`.
    inner: UdfEq<Arc<dyn AggregateUDFImpl>>,
    /// The wrapped function's own aliases followed by the added ones.
    aliases: Vec<String>,
}
1231
1232impl AliasedAggregateUDFImpl {
1233 pub fn new(
1234 inner: Arc<dyn AggregateUDFImpl>,
1235 new_aliases: impl IntoIterator<Item = &'static str>,
1236 ) -> Self {
1237 let mut aliases = inner.aliases().to_vec();
1238 aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1239
1240 Self {
1241 inner: inner.into(),
1242 aliases,
1243 }
1244 }
1245}
1246
1247#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method
1248impl AggregateUDFImpl for AliasedAggregateUDFImpl {
1249 fn name(&self) -> &str {
1250 self.inner.name()
1251 }
1252
1253 fn signature(&self) -> &Signature {
1254 self.inner.signature()
1255 }
1256
1257 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1258 self.inner.return_type(arg_types)
1259 }
1260
1261 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1262 self.inner.accumulator(acc_args)
1263 }
1264
1265 fn aliases(&self) -> &[String] {
1266 &self.aliases
1267 }
1268
1269 fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
1270 self.inner.schema_name(params)
1271 }
1272
1273 fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
1274 self.inner.human_display(params)
1275 }
1276
1277 fn window_function_schema_name(
1278 &self,
1279 params: &WindowFunctionParams,
1280 ) -> Result<String> {
1281 self.inner.window_function_schema_name(params)
1282 }
1283
1284 fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
1285 self.inner.display_name(params)
1286 }
1287
1288 fn window_function_display_name(
1289 &self,
1290 params: &WindowFunctionParams,
1291 ) -> Result<String> {
1292 self.inner.window_function_display_name(params)
1293 }
1294
1295 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
1296 self.inner.state_fields(args)
1297 }
1298
1299 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1300 self.inner.groups_accumulator_supported(args)
1301 }
1302
1303 fn create_groups_accumulator(
1304 &self,
1305 args: AccumulatorArgs,
1306 ) -> Result<Box<dyn GroupsAccumulator>> {
1307 self.inner.create_groups_accumulator(args)
1308 }
1309
1310 fn create_sliding_accumulator(
1311 &self,
1312 args: AccumulatorArgs,
1313 ) -> Result<Box<dyn Accumulator>> {
1314 self.inner.accumulator(args)
1315 }
1316
1317 fn with_beneficial_ordering(
1318 self: Arc<Self>,
1319 beneficial_ordering: bool,
1320 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1321 Arc::clone(&self.inner)
1322 .with_beneficial_ordering(beneficial_ordering)
1323 .map(|udf| {
1324 udf.map(|udf| {
1325 Arc::new(AliasedAggregateUDFImpl {
1326 inner: udf.into(),
1327 aliases: self.aliases.clone(),
1328 }) as Arc<dyn AggregateUDFImpl>
1329 })
1330 })
1331 }
1332
1333 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1334 self.inner.order_sensitivity()
1335 }
1336
1337 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
1338 self.inner.simplify()
1339 }
1340
1341 fn simplify_expr_op_literal(
1342 &self,
1343 agg_function: &AggregateFunction,
1344 arg: &Expr,
1345 op: Operator,
1346 lit: &Expr,
1347 arg_is_left: bool,
1348 ) -> Result<Option<Expr>> {
1349 self.inner
1350 .simplify_expr_op_literal(agg_function, arg, op, lit, arg_is_left)
1351 }
1352
1353 fn reverse_expr(&self) -> ReversedUDAF {
1354 self.inner.reverse_expr()
1355 }
1356
1357 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1358 self.inner.coerce_types(arg_types)
1359 }
1360
1361 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
1362 self.inner.return_field(arg_fields)
1363 }
1364
1365 fn is_nullable(&self) -> bool {
1366 self.inner.is_nullable()
1367 }
1368
1369 fn is_descending(&self) -> Option<bool> {
1370 self.inner.is_descending()
1371 }
1372
1373 fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
1374 self.inner.value_from_stats(statistics_args)
1375 }
1376
1377 fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
1378 self.inner.default_value(data_type)
1379 }
1380
1381 fn supports_null_handling_clause(&self) -> bool {
1382 self.inner.supports_null_handling_clause()
1383 }
1384
1385 fn supports_within_group_clause(&self) -> bool {
1386 self.inner.supports_within_group_clause()
1387 }
1388
1389 fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
1390 self.inner.set_monotonicity(data_type)
1391 }
1392
1393 fn documentation(&self) -> Option<&Documentation> {
1394 self.inner.documentation()
1395 }
1396}
1397
/// Indicates whether an aggregation function is monotonic as a set
/// function. A set function is monotonically increasing if its value
/// increases as its argument grows (as a set). Formally, `f` is a
/// monotonically increasing set function if `f(S) >= f(T)` whenever `S`
/// is a superset of `T`.
///
/// For example `COUNT` and `MAX` are monotonically increasing as their
/// values always increase (or stay the same) as new values are seen. On
/// the other hand, `MIN` is monotonically decreasing as its value always
/// decreases or stays the same as new values are seen.
///
/// This is a fieldless enum, so it also derives `Eq` and `Hash`. That lets
/// values be used as keys in hashed collections and compared with total
/// equality.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SetMonotonicity {
    /// Aggregate value increases or stays the same as the input set grows.
    Increasing,
    /// Aggregate value decreases or stays the same as the input set grows.
    Decreasing,
    /// Aggregate value may increase, decrease, or stay the same as the input
    /// set grows.
    NotMonotonic,
}
1418
#[cfg(test)]
mod test {
    use crate::{AggregateUDF, AggregateUDFImpl};
    use arrow::datatypes::{DataType, FieldRef};
    use datafusion_common::Result;
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use datafusion_functions_aggregate_common::accumulator::{
        AccumulatorArgs, StateFieldsArgs,
    };
    use std::cmp::Ordering;
    use std::hash::{DefaultHasher, Hash, Hasher};

    /// Signature shared by the test UDAFs: one immutable `Float64` argument.
    fn unary_float64_signature() -> Signature {
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable)
    }

    /// Minimal UDAF named `"a"`; only its name and signature are meaningful.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct AMeanUdf {
        signature: Signature,
    }

    impl AMeanUdf {
        fn new() -> Self {
            let signature = unary_float64_signature();
            Self { signature }
        }
    }

    impl AggregateUDFImpl for AMeanUdf {
        fn name(&self) -> &str {
            "a"
        }
        fn signature(&self) -> &Signature {
            &self.signature
        }
        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }
        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }
        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    /// Minimal UDAF named `"b"`; sorts after [`AMeanUdf`] by name.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct BMeanUdf {
        signature: Signature,
    }

    impl BMeanUdf {
        fn new() -> Self {
            let signature = unary_float64_signature();
            Self { signature }
        }
    }

    impl AggregateUDFImpl for BMeanUdf {
        fn name(&self) -> &str {
            "b"
        }
        fn signature(&self) -> &Signature {
            &self.signature
        }
        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }
        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }
        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_partial_eq() {
        let left = AggregateUDF::from(AMeanUdf::new());
        let right = AggregateUDF::from(AMeanUdf::new());

        // Exercise the `==` operator directly as well as through `assert_eq!`.
        let equal = left == right;
        assert!(equal);
        assert_eq!(left, right);
        assert_eq!(hash(left), hash(right));
    }

    #[test]
    fn test_partial_ord() {
        // Smoke test that partial ord for AggregateUDF is defined via the
        // name and signature; not an exhaustive check of all possibilities.
        let a_left = AggregateUDF::from(AMeanUdf::new());
        let a_right = AggregateUDF::from(AMeanUdf::new());
        assert_eq!(a_left.partial_cmp(&a_right), Some(Ordering::Equal));

        let b = AggregateUDF::from(BMeanUdf::new());
        assert!(a_left < b);
        assert_ne!(a_left, b);
    }

    /// Hashes `value` with a freshly created `DefaultHasher`.
    fn hash<T: Hash>(value: T) -> u64 {
        let mut state = DefaultHasher::new();
        value.hash(&mut state);
        state.finish()
    }
}