Skip to main content

datafusion_physical_expr/
aggregate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
/// Crate-internal re-exports of the groups-accumulator helpers defined in
/// `datafusion_functions_aggregate_common`.
pub(crate) mod groups_accumulator {
    // Keeps `NullState` reachable under the nested `accumulate` path as well.
    // The lint expectation documents that this nested re-export is currently
    // unused inside the crate.
    #[expect(unused_imports)]
    pub(crate) mod accumulate {
        pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
    }
    pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{
        GroupsAccumulatorAdapter, accumulate::NullState,
    };
}
/// Crate-internal re-export of `StatsType` from
/// `datafusion_functions_aggregate_common`.
pub(crate) mod stats {
    pub use datafusion_functions_aggregate_common::stats::StatsType;
}
/// Public re-exports of aggregate utility items from
/// `datafusion_functions_aggregate_common::utils`.
pub mod utils {
    pub use datafusion_functions_aggregate_common::utils::{
        DecimalAverager, Hashable, get_accum_scalar_values_as_arrays, get_sort_options,
        ordering_fields,
    };
}
36
37use std::fmt::Debug;
38use std::sync::Arc;
39
40use crate::expressions::Column;
41use crate::physical_expr::create_physical_sort_exprs;
42use crate::planner::{create_physical_expr, create_physical_exprs};
43
44use arrow::compute::SortOptions;
45use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef};
46use datafusion_common::metadata::FieldMetadata;
47use datafusion_common::{
48    DFSchema, Result, ScalarValue, assert_or_internal_err, internal_err, not_impl_err,
49};
50use datafusion_expr::execution_props::ExecutionProps;
51use datafusion_expr::expr::{
52    AggregateFunction, AggregateFunctionParams, NullTreatment, physical_name,
53};
54use datafusion_expr::{AggregateUDF, Expr, ReversedUDAF, SetMonotonicity};
55use datafusion_expr_common::accumulator::Accumulator;
56use datafusion_expr_common::groups_accumulator::GroupsAccumulator;
57use datafusion_expr_common::type_coercion::aggregates::check_arg_count;
58use datafusion_functions_aggregate_common::accumulator::{
59    AccumulatorArgs, StateFieldsArgs,
60};
61use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
62use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
63use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
64
/// Validated human-readable display text for an aggregate expression.
///
/// Values are created through `AggregateHumanDisplay::try_new`, which
/// enforces the invariants described on the fields.
#[derive(Debug, Clone)]
struct AggregateHumanDisplay {
    /// Display text of the aggregate expression; `try_new` never stores an
    /// empty string here.
    expression: String,
    /// Optional alias; when present, `try_new` guarantees it is non-empty and
    /// equal to the aggregate's output name.
    alias: Option<String>,
}
70
71impl AggregateHumanDisplay {
72    fn try_new(
73        expression: Option<String>,
74        alias: Option<String>,
75        name: &str,
76    ) -> Result<Option<Self>> {
77        let alias = alias.filter(|alias| !alias.is_empty());
78        let Some(expression) = expression else {
79            if alias.is_some() {
80                return internal_err!(
81                    "AggregateExprBuilder::human_display must be provided when human_display_alias is set"
82                );
83            }
84            return Ok(None);
85        };
86
87        if expression.is_empty() {
88            if alias.is_some() {
89                return internal_err!(
90                    "AggregateExprBuilder::human_display must be non-empty when human_display_alias is set"
91                );
92            }
93            return Ok(None);
94        }
95
96        if let Some(alias) = alias.as_deref()
97            && alias != name
98        {
99            return internal_err!(
100                "aggregate human_display_alias must match aggregate name `{name}`: {alias}"
101            );
102        }
103
104        Ok(Some(Self { expression, alias }))
105    }
106
107    fn expression(&self) -> &str {
108        &self.expression
109    }
110
111    fn alias(&self) -> Option<&str> {
112        self.alias.as_deref()
113    }
114}
115
/// Builder for physical [`AggregateFunctionExpr`]
///
/// `AggregateFunctionExpr` contains the information necessary to call
/// an aggregate expression.
#[derive(Debug, Clone)]
pub struct AggregateExprBuilder {
    /// The aggregate function (UDAF) to evaluate
    fun: Arc<AggregateUDF>,
    /// Physical expressions of the aggregate function
    args: Vec<Arc<dyn PhysicalExpr>>,
    /// Output column name; `build` fails if this is still `None`
    alias: Option<String>,
    /// Optional metadata that `build` adds to the aggregate's return field
    output_metadata: Option<FieldMetadata>,
    /// A human readable name
    human_display: Option<String>,
    /// Optional visible output alias for `human_display`.
    human_display_alias: Option<String>,
    /// Arrow Schema for the aggregate function
    schema: SchemaRef,
    /// The physical order by expressions
    order_bys: Vec<PhysicalSortExpr>,
    /// Whether to ignore null values
    ignore_nulls: bool,
    /// Whether is distinct aggregate function
    is_distinct: bool,
    /// Whether the expression is reversed
    is_reversed: bool,
}
142
143impl AggregateExprBuilder {
144    pub fn new(fun: Arc<AggregateUDF>, args: Vec<Arc<dyn PhysicalExpr>>) -> Self {
145        Self {
146            fun,
147            args,
148            alias: None,
149            output_metadata: None,
150            human_display: None,
151            human_display_alias: None,
152            schema: Arc::new(Schema::empty()),
153            order_bys: vec![],
154            ignore_nulls: false,
155            is_distinct: false,
156            is_reversed: false,
157        }
158    }
159
160    /// Constructs an `AggregateFunctionExpr` from the builder
161    ///
162    /// Note that an [`Self::alias`] must be provided before calling this method.
163    ///
164    /// # Example: Create an [`AggregateUDF`]
165    ///
166    /// In the following example, [`AggregateFunctionExpr`] will be built using [`AggregateExprBuilder`]
167    /// which provides a build function. Full example could be accessed from the source file.
168    ///
169    /// ```
170    /// # use std::any::Any;
171    /// # use std::sync::Arc;
172    /// # use arrow::datatypes::{DataType, FieldRef};
173    /// # use datafusion_common::{Result, ScalarValue};
174    /// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility, Expr};
175    /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
176    /// # use arrow::datatypes::Field;
177    /// #
178    /// # #[derive(Debug, Clone, PartialEq, Eq, Hash)]
179    /// # struct FirstValueUdf {
180    /// #     signature: Signature,
181    /// # }
182    /// #
183    /// # impl FirstValueUdf {
184    /// #     fn new() -> Self {
185    /// #         Self {
186    /// #             signature: Signature::any(1, Volatility::Immutable),
187    /// #         }
188    /// #     }
189    /// # }
190    /// #
191    /// # impl AggregateUDFImpl for FirstValueUdf {
192    /// #     fn name(&self) -> &str {
193    /// #         unimplemented!()
194    /// #     }
195    /// #
196    /// #     fn signature(&self) -> &Signature {
197    /// #         unimplemented!()
198    /// #     }
199    /// #
200    /// #     fn return_type(&self, args: &[DataType]) -> Result<DataType> {
201    /// #         unimplemented!()
202    /// #     }
203    /// #
204    /// #     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
205    /// #         unimplemented!()
206    /// #         }
207    /// #
208    /// #     fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
209    /// #         unimplemented!()
210    /// #     }
211    /// #
212    /// #     fn documentation(&self) -> Option<&Documentation> {
213    /// #         unimplemented!()
214    /// #     }
215    /// # }
216    /// #
217    /// # let first_value = AggregateUDF::from(FirstValueUdf::new());
218    /// # let expr = first_value.call(vec![col("a")]);
219    /// #
220    /// # use datafusion_physical_expr::expressions::Column;
221    /// # use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
222    /// # use datafusion_physical_expr::aggregate::AggregateExprBuilder;
223    /// # use datafusion_physical_expr::expressions::PhysicalSortExpr;
224    /// # use datafusion_physical_expr::PhysicalSortRequirement;
225    /// #
226    /// fn build_aggregate_expr() -> Result<()> {
227    ///     let args = vec![Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>];
228    ///     let order_by = vec![PhysicalSortExpr {
229    ///         expr: Arc::new(Column::new("x", 1)) as Arc<dyn PhysicalExpr>,
230    ///         options: Default::default(),
231    ///     }];
232    ///
233    ///     let first_value = AggregateUDF::from(FirstValueUdf::new());
234    ///
235    ///     let aggregate_expr = AggregateExprBuilder::new(
236    ///         Arc::new(first_value),
237    ///         args
238    ///     )
239    ///     .order_by(order_by)
240    ///     .alias("first_a_by_x")
241    ///     .ignore_nulls()
242    ///     .build()?;
243    ///
244    ///     Ok(())
245    /// }
246    /// ```
247    ///
248    /// This creates a physical expression equivalent to SQL:
249    /// `first_value(a ORDER BY x) IGNORE NULLS AS first_a_by_x`
250    pub fn build(self) -> Result<AggregateFunctionExpr> {
251        let Self {
252            fun,
253            args,
254            alias,
255            output_metadata,
256            human_display,
257            human_display_alias,
258            schema,
259            order_bys,
260            ignore_nulls,
261            is_distinct,
262            is_reversed,
263        } = self;
264        assert_or_internal_err!(!args.is_empty(), "args should not be empty");
265
266        let ordering_types = order_bys
267            .iter()
268            .map(|e| e.expr.data_type(&schema))
269            .collect::<Result<Vec<_>>>()?;
270
271        let ordering_fields = utils::ordering_fields(&order_bys, &ordering_types);
272
273        let input_exprs_fields = args
274            .iter()
275            .map(|arg| arg.return_field(&schema))
276            .collect::<Result<Vec<_>>>()?;
277
278        check_arg_count(
279            fun.name(),
280            &input_exprs_fields,
281            &fun.signature().type_signature,
282        )?;
283
284        let mut return_field = fun.return_field(&input_exprs_fields)?;
285        if let Some(output_metadata) = output_metadata {
286            return_field = output_metadata.add_to_field_ref(return_field);
287        }
288        let is_nullable = fun.is_nullable();
289        let name = match alias {
290            None => {
291                return internal_err!(
292                    "AggregateExprBuilder::alias must be provided prior to calling build"
293                );
294            }
295            Some(alias) => alias,
296        };
297
298        let human_display =
299            AggregateHumanDisplay::try_new(human_display, human_display_alias, &name)?;
300
301        let arg_fields = args
302            .iter()
303            .map(|e| e.return_field(schema.as_ref()))
304            .collect::<Result<Vec<_>>>()?;
305
306        Ok(AggregateFunctionExpr {
307            fun: Arc::unwrap_or_clone(fun),
308            args,
309            arg_fields,
310            return_field,
311            name,
312            human_display,
313            schema: Arc::unwrap_or_clone(schema),
314            order_bys,
315            ignore_nulls,
316            ordering_fields,
317            is_distinct,
318            input_fields: input_exprs_fields,
319            is_reversed,
320            is_nullable,
321        })
322    }
323
324    pub fn alias(mut self, alias: impl Into<String>) -> Self {
325        self.alias = Some(alias.into());
326        self
327    }
328
329    fn output_metadata(mut self, metadata: Option<FieldMetadata>) -> Self {
330        self.output_metadata = metadata;
331        self
332    }
333
334    pub fn human_display(mut self, name: impl Into<String>) -> Self {
335        let name = name.into();
336        self.human_display = (!name.is_empty()).then_some(name);
337        if self.human_display.is_none() {
338            self.human_display_alias = None;
339        }
340        self
341    }
342
343    #[doc(hidden)]
344    pub fn human_display_alias(mut self, alias: impl Into<String>) -> Self {
345        let alias = alias.into();
346        self.human_display_alias = (!alias.is_empty()).then_some(alias);
347        self
348    }
349
350    pub fn schema(mut self, schema: SchemaRef) -> Self {
351        self.schema = schema;
352        self
353    }
354
355    pub fn order_by(mut self, order_bys: Vec<PhysicalSortExpr>) -> Self {
356        self.order_bys = order_bys;
357        self
358    }
359
360    pub fn reversed(mut self) -> Self {
361        self.is_reversed = true;
362        self
363    }
364
365    pub fn with_reversed(mut self, is_reversed: bool) -> Self {
366        self.is_reversed = is_reversed;
367        self
368    }
369
370    pub fn distinct(mut self) -> Self {
371        self.is_distinct = true;
372        self
373    }
374
375    pub fn with_distinct(mut self, is_distinct: bool) -> Self {
376        self.is_distinct = is_distinct;
377        self
378    }
379
380    pub fn ignore_nulls(mut self) -> Self {
381        self.ignore_nulls = true;
382        self
383    }
384
385    pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self {
386        self.ignore_nulls = ignore_nulls;
387        self
388    }
389}
390
/// Display text gathered while lowering a logical aggregate, before it is
/// handed to [`AggregateExprBuilder`].
#[derive(Debug, Clone)]
struct LoweredAggregateHumanDisplay {
    /// Text forwarded to `AggregateExprBuilder::human_display`.
    expression: String,
    /// Alias forwarded to `AggregateExprBuilder::human_display_alias`, if any.
    alias: Option<String>,
}
396
/// Result of lowering a logical aggregate expression into physical aggregate
/// planning pieces.
///
/// Produced by [`LoweredAggregateBuilder::build`].
#[derive(Debug, Clone)]
pub struct LoweredAggregate {
    /// Physical aggregate expression that can be used by an aggregate execution
    /// plan.
    pub aggregate: Arc<AggregateFunctionExpr>,
    /// Optional physical filter expression for `FILTER (WHERE ...)`.
    pub filter: Option<Arc<dyn PhysicalExpr>>,
    /// Physical ordering expressions from aggregate `ORDER BY`.
    ///
    /// A copy of these expressions is also passed to the builder of
    /// `aggregate`.
    pub order_bys: Vec<PhysicalSortExpr>,
}
409
/// Builder for converting a logical aggregate [`Expr`] into physical aggregate
/// planning pieces.
///
/// This builder handles the logical-to-physical work needed for aggregate
/// planning: unwrapping aggregate aliases, choosing the output name, preserving
/// user-facing display text, lowering aggregate arguments, lowering the optional
/// filter, and lowering aggregate `ORDER BY` expressions.
pub struct LoweredAggregateBuilder<'a> {
    /// Logical expression to lower; may be wrapped in one or more aliases.
    expr: &'a Expr,
    /// Output name override set through `with_name`.
    name: Option<String>,
    /// Display text override set through `with_human_display`.
    human_display: Option<LoweredAggregateHumanDisplay>,
    /// Explicit output field metadata; when present it is used instead of any
    /// alias metadata.
    output_metadata: Option<FieldMetadata>,
    /// Whether metadata of the outermost alias is copied to the output field.
    /// Starts as `true` and is cleared by `with_human_display`.
    preserve_alias_metadata: bool,
    /// Schema used to resolve logical expressions such as columns.
    logical_input_schema: &'a DFSchema,
    /// Input schema handed to the physical aggregate expression.
    physical_input_schema: &'a Schema,
    /// Execution properties passed to the physical expression planner.
    execution_props: &'a ExecutionProps,
}
427
428impl<'a> LoweredAggregateBuilder<'a> {
429    /// Create a builder for lowering `expr`.
430    ///
431    /// `logical_input_schema` is used to resolve logical expressions such as
432    /// columns, while `physical_input_schema` is the input schema used by the
433    /// physical aggregate expression.
434    pub fn new(
435        expr: &'a Expr,
436        logical_input_schema: &'a DFSchema,
437        physical_input_schema: &'a Schema,
438        execution_props: &'a ExecutionProps,
439    ) -> Self {
440        Self {
441            expr,
442            name: None,
443            human_display: None,
444            output_metadata: None,
445            preserve_alias_metadata: true,
446            logical_input_schema,
447            physical_input_schema,
448            execution_props,
449        }
450    }
451
452    /// Override the output column name for the aggregate.
453    ///
454    /// If this is not set, the builder uses the alias from `expr` when present,
455    /// or derives the physical name from the aggregate expression.
456    pub fn with_name(mut self, name: impl Into<String>) -> Self {
457        self.name = Some(name.into());
458        self
459    }
460
461    /// Override the human-readable display text for the aggregate.
462    ///
463    /// This is useful when a caller has already computed the exact display text
464    /// it wants to preserve. When this override is used, aliases with metadata
465    /// are still unwrapped for planning, but alias metadata is not copied to the
466    /// aggregate output field.
467    pub fn with_human_display(mut self, human_display: impl Into<String>) -> Self {
468        self.human_display = Some(LoweredAggregateHumanDisplay {
469            expression: human_display.into(),
470            alias: None,
471        });
472        self.preserve_alias_metadata = false;
473        self
474    }
475
476    /// Lower the logical aggregate expression into physical aggregate pieces.
477    pub fn build(self) -> Result<LoweredAggregate> {
478        let Self {
479            expr,
480            name,
481            human_display,
482            output_metadata,
483            preserve_alias_metadata,
484            logical_input_schema,
485            physical_input_schema,
486            execution_props,
487        } = self;
488
489        let (name, human_display, output_metadata, expr) = lower_aggregate_display(
490            expr,
491            name,
492            human_display,
493            output_metadata,
494            preserve_alias_metadata,
495        );
496
497        let Expr::AggregateFunction(AggregateFunction {
498            func,
499            params:
500                AggregateFunctionParams {
501                    args,
502                    distinct,
503                    filter,
504                    order_by,
505                    null_treatment,
506                },
507        }) = &expr
508        else {
509            return internal_err!("Invalid aggregate expression '{expr:?}'");
510        };
511
512        let name = if let Some(name) = name {
513            name
514        } else {
515            physical_name(&expr)?
516        };
517
518        let physical_args =
519            create_physical_exprs(args, logical_input_schema, execution_props)?;
520        let filter = filter
521            .as_ref()
522            .map(|filter| {
523                create_physical_expr(filter, logical_input_schema, execution_props)
524            })
525            .transpose()?;
526        let order_bys =
527            create_physical_sort_exprs(order_by, logical_input_schema, execution_props)?;
528        let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls)
529            == NullTreatment::IgnoreNulls;
530
531        let mut builder = AggregateExprBuilder::new(func.to_owned(), physical_args)
532            .order_by(order_bys.clone())
533            .schema(Arc::new(physical_input_schema.to_owned()))
534            .alias(name)
535            .output_metadata(output_metadata)
536            .with_ignore_nulls(ignore_nulls)
537            .with_distinct(*distinct);
538
539        if let Some(human_display) = human_display {
540            builder = builder.human_display(human_display.expression);
541            if let Some(alias) = human_display.alias {
542                builder = builder.human_display_alias(alias);
543            }
544        }
545
546        Ok(LoweredAggregate {
547            aggregate: Arc::new(builder.build()?),
548            filter,
549            order_bys,
550        })
551    }
552}
553
554fn lower_aggregate_display(
555    expr: &Expr,
556    name: Option<String>,
557    human_display: Option<LoweredAggregateHumanDisplay>,
558    output_metadata: Option<FieldMetadata>,
559    preserve_alias_metadata: bool,
560) -> (
561    Option<String>,
562    Option<LoweredAggregateHumanDisplay>,
563    Option<FieldMetadata>,
564    Expr,
565) {
566    let mut expr = expr.clone();
567    let mut alias_name = None;
568    let mut alias_metadata = None;
569    while let Expr::Alias(alias) = expr {
570        if alias_name.is_none() {
571            alias_name = Some(alias.name);
572            alias_metadata = alias.metadata;
573        }
574        expr = *alias.expr;
575    }
576
577    let output_metadata = if preserve_alias_metadata {
578        output_metadata.or(alias_metadata)
579    } else {
580        output_metadata
581    };
582
583    if human_display.is_some() {
584        return (name.or(alias_name), human_display, output_metadata, expr);
585    }
586
587    match &expr {
588        Expr::AggregateFunction(_) => {
589            if let Some(alias_name) = alias_name {
590                let name = name.unwrap_or(alias_name);
591                let expression = expr.human_display().to_string();
592                let human_display = if expression.is_empty() || expression == name {
593                    LoweredAggregateHumanDisplay {
594                        expression: name.clone(),
595                        alias: None,
596                    }
597                } else {
598                    LoweredAggregateHumanDisplay {
599                        expression,
600                        alias: Some(name.clone()),
601                    }
602                };
603
604                return (Some(name), Some(human_display), output_metadata, expr);
605            }
606
607            let name = name.unwrap_or_else(|| expr.schema_name().to_string());
608            let human_display = LoweredAggregateHumanDisplay {
609                expression: expr.human_display().to_string(),
610                alias: None,
611            };
612
613            (Some(name), Some(human_display), output_metadata, expr)
614        }
615        _ => (name.or(alias_name), None, output_metadata, expr),
616    }
617}
618
/// Physical aggregate expression of a UDAF.
///
/// Instances are constructed via [`AggregateExprBuilder`].
#[derive(Debug, Clone)]
pub struct AggregateFunctionExpr {
    /// The aggregate function (UDAF) being evaluated
    fun: AggregateUDF,
    /// Physical argument expressions of the aggregate
    args: Vec<Arc<dyn PhysicalExpr>>,
    /// Fields corresponding to args (same order & length)
    arg_fields: Vec<FieldRef>,
    /// Output / return field of this aggregate
    return_field: FieldRef,
    /// Output column name that this expression creates
    name: String,
    /// Simplified name for `tree` explain.
    human_display: Option<AggregateHumanDisplay>,
    /// Input schema the aggregate was built against
    schema: Schema,
    /// The physical order by expressions
    order_bys: Vec<PhysicalSortExpr>,
    /// Whether to ignore null values
    ignore_nulls: bool,
    /// Fields used for order sensitive aggregation functions
    ordering_fields: Vec<FieldRef>,
    /// Whether this is a distinct aggregate
    is_distinct: bool,
    /// Whether the expression is reversed
    is_reversed: bool,
    /// Fields of the argument expressions, passed to the UDF when asking for
    /// its state fields
    input_fields: Vec<FieldRef>,
    /// Nullability reported by the UDF when the expression was built
    is_nullable: bool,
}
646
647impl AggregateFunctionExpr {
    /// Returns the [`AggregateUDF`] evaluated by this `AggregateFunctionExpr`.
    pub fn fun(&self) -> &AggregateUDF {
        &self.fun
    }
652
653    /// expressions that are passed to the Accumulator.
654    /// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many.
655    pub fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
656        self.args.clone()
657    }
658
    /// Output column name of this aggregate, such as `"MIN(c2)"`.
    ///
    /// This is also the name given to the field returned by [`Self::field`].
    pub fn name(&self) -> &str {
        &self.name
    }
663
664    /// Simplified name for `tree` explain.
665    pub fn human_display(&self) -> Option<&str> {
666        self.human_display
667            .as_ref()
668            .map(AggregateHumanDisplay::expression)
669    }
670
671    #[doc(hidden)]
672    pub fn human_display_alias(&self) -> Option<&str> {
673        self.human_display
674            .as_ref()
675            .and_then(AggregateHumanDisplay::alias)
676    }
677
678    fn return_field_metadata(&self) -> Option<FieldMetadata> {
679        let metadata = FieldMetadata::from(self.return_field.as_ref());
680        (!metadata.is_empty()).then_some(metadata)
681    }
682
    /// Returns `true` if this is a distinct aggregation.
    pub fn is_distinct(&self) -> bool {
        self.is_distinct
    }
687
    /// Returns `true` if the aggregation ignores null values.
    pub fn ignore_nulls(&self) -> bool {
        self.ignore_nulls
    }
692
    /// Returns `true` if the aggregation is reversed.
    pub fn is_reversed(&self) -> bool {
        self.is_reversed
    }
697
    /// Returns the nullability flag recorded for this aggregation when it was
    /// built.
    pub fn is_nullable(&self) -> bool {
        self.is_nullable
    }
702
703    /// the field of the final result of this aggregation.
704    pub fn field(&self) -> FieldRef {
705        self.return_field
706            .as_ref()
707            .clone()
708            .with_name(&self.name)
709            .into()
710    }
711
712    /// the accumulator used to accumulate values from the expressions.
713    /// the accumulator expects the same number of arguments as `expressions` and must
714    /// return states with the same description as `state_fields`
715    pub fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
716        let acc_args = AccumulatorArgs {
717            return_field: Arc::clone(&self.return_field),
718            schema: &self.schema,
719            expr_fields: &self.arg_fields,
720            ignore_nulls: self.ignore_nulls,
721            order_bys: self.order_bys.as_ref(),
722            is_distinct: self.is_distinct,
723            name: &self.name,
724            is_reversed: self.is_reversed,
725            exprs: &self.args,
726        };
727
728        self.fun.accumulator(acc_args)
729    }
730
731    /// the field of the final result of this aggregation.
732    pub fn state_fields(&self) -> Result<Vec<FieldRef>> {
733        let args = StateFieldsArgs {
734            name: &self.name,
735            input_fields: &self.input_fields,
736            return_field: Arc::clone(&self.return_field),
737            ordering_fields: &self.ordering_fields,
738            is_distinct: self.is_distinct,
739        };
740
741        self.fun.state_fields(args)
742    }
743
744    /// Returns the ORDER BY expressions for the aggregate function.
745    pub fn order_bys(&self) -> &[PhysicalSortExpr] {
746        if self.order_sensitivity().is_insensitive() {
747            &[]
748        } else {
749            &self.order_bys
750        }
751    }
752
753    /// Indicates whether aggregator can produce the correct result with any
754    /// arbitrary input ordering. By default, we assume that aggregate expressions
755    /// are order insensitive.
756    pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
757        if self.order_bys.is_empty() {
758            AggregateOrderSensitivity::Insensitive
759        } else {
760            // If there is an ORDER BY clause, use the sensitivity of the implementation:
761            self.fun.order_sensitivity()
762        }
763    }
764
765    /// Sets the indicator whether ordering requirements of the aggregator is
766    /// satisfied by its input. If this is not the case, aggregators with order
767    /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
768    /// the correct result with possibly more work internally.
769    ///
770    /// # Returns
771    ///
772    /// Returns `Ok(Some(updated_expr))` if the process completes successfully.
773    /// If the expression can benefit from existing input ordering, but does
774    /// not implement the method, returns an error. Order insensitive and hard
775    /// requirement aggregators return `Ok(None)`.
776    pub fn with_beneficial_ordering(
777        self: Arc<Self>,
778        beneficial_ordering: bool,
779    ) -> Result<Option<AggregateFunctionExpr>> {
780        let Some(updated_fn) = self
781            .fun
782            .clone()
783            .with_beneficial_ordering(beneficial_ordering)?
784        else {
785            return Ok(None);
786        };
787
788        let mut builder =
789            AggregateExprBuilder::new(Arc::new(updated_fn), self.args.to_vec())
790                .order_by(self.order_bys.clone())
791                .schema(Arc::new(self.schema.clone()))
792                .alias(self.name().to_string())
793                .output_metadata(self.return_field_metadata())
794                .with_ignore_nulls(self.ignore_nulls)
795                .with_distinct(self.is_distinct)
796                .with_reversed(self.is_reversed);
797        if let Some(human_display) = self.human_display() {
798            builder = builder.human_display(human_display);
799        }
800        if let Some(alias) = self.human_display_alias() {
801            builder = builder.human_display_alias(alias);
802        }
803        builder.build().map(Some)
804    }
805
806    /// Creates accumulator implementation that supports retract
807    pub fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
808        let args = AccumulatorArgs {
809            return_field: Arc::clone(&self.return_field),
810            schema: &self.schema,
811            expr_fields: &self.arg_fields,
812            ignore_nulls: self.ignore_nulls,
813            order_bys: self.order_bys.as_ref(),
814            is_distinct: self.is_distinct,
815            name: &self.name,
816            is_reversed: self.is_reversed,
817            exprs: &self.args,
818        };
819
820        let accumulator = self.fun.create_sliding_accumulator(args)?;
821
822        // Accumulators that have window frame startings different
823        // than `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, need to
824        // implement retract_batch method in order to run correctly
825        // currently in DataFusion.
826        //
827        // If this `retract_batches` is not present, there is no way
828        // to calculate result correctly. For example, the query
829        //
830        // ```sql
831        // SELECT
832        //  SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a
833        // FROM
834        //  t
835        // ```
836        //
837        // 1. First sum value will be the sum of rows between `[0, 1)`,
838        //
839        // 2. Second sum value will be the sum of rows between `[0, 2)`
840        //
841        // 3. Third sum value will be the sum of rows between `[1, 3)`, etc.
842        //
843        // Since the accumulator keeps the running sum:
844        //
845        // 1. First sum we add to the state sum value between `[0, 1)`
846        //
847        // 2. Second sum we add to the state sum value between `[1, 2)`
848        // (`[0, 1)` is already in the state sum, hence running sum will
849        // cover `[0, 2)` range)
850        //
851        // 3. Third sum we add to the state sum value between `[2, 3)`
852        // (`[0, 2)` is already in the state sum).  Also we need to
853        // retract values between `[0, 1)` by this way we can obtain sum
854        // between [1, 3) which is indeed the appropriate range.
855        //
856        // When we use `UNBOUNDED PRECEDING` in the query starting
857        // index will always be 0 for the desired range, and hence the
858        // `retract_batch` method will not be called. In this case
859        // having retract_batch is not a requirement.
860        //
861        // This approach is a a bit different than window function
862        // approach. In window function (when they use a window frame)
863        // they get all the desired range during evaluation.
864        if !accumulator.supports_retract_batch() {
865            return not_impl_err!(
866                "Aggregate can not be used as a sliding accumulator because \
867                     `retract_batch` is not implemented: {}",
868                self.name
869            );
870        }
871        Ok(accumulator)
872    }
873
874    /// If the aggregate expression has a specialized
875    /// [`GroupsAccumulator`] implementation. If this returns true,
876    /// `[Self::create_groups_accumulator`] will be called.
877    pub fn groups_accumulator_supported(&self) -> bool {
878        let args = AccumulatorArgs {
879            return_field: Arc::clone(&self.return_field),
880            schema: &self.schema,
881            expr_fields: &self.arg_fields,
882            ignore_nulls: self.ignore_nulls,
883            order_bys: self.order_bys.as_ref(),
884            is_distinct: self.is_distinct,
885            name: &self.name,
886            is_reversed: self.is_reversed,
887            exprs: &self.args,
888        };
889        self.fun.groups_accumulator_supported(args)
890    }
891
892    /// Return a specialized [`GroupsAccumulator`] that manages state
893    /// for all groups.
894    ///
895    /// For maximum performance, a [`GroupsAccumulator`] should be
896    /// implemented in addition to [`Accumulator`].
897    pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
898        let args = AccumulatorArgs {
899            return_field: Arc::clone(&self.return_field),
900            schema: &self.schema,
901            expr_fields: &self.arg_fields,
902            ignore_nulls: self.ignore_nulls,
903            order_bys: self.order_bys.as_ref(),
904            is_distinct: self.is_distinct,
905            name: &self.name,
906            is_reversed: self.is_reversed,
907            exprs: &self.args,
908        };
909        self.fun.create_groups_accumulator(args)
910    }
911
912    /// Construct an expression that calculates the aggregate in reverse.
913    /// Typically the "reverse" expression is itself (e.g. SUM, COUNT).
914    /// For aggregates that do not support calculation in reverse,
915    /// returns None (which is the default value).
916    pub fn reverse_expr(&self) -> Option<AggregateFunctionExpr> {
917        match self.fun.reverse_udf() {
918            ReversedUDAF::NotSupported => None,
919            ReversedUDAF::Identical => Some(self.clone()),
920            ReversedUDAF::Reversed(reverse_udf) => {
921                let was_aliased = self.human_display_alias().is_some();
922                let mut name = self.name().to_string();
923                let mut human_display = self.human_display.clone();
924                // Reversing display follows two paths:
925                // - aliased display keeps the output `name` unchanged and rewrites only
926                //   the lowered expression in `human_display`.
927                // - non-aliased display rewrites the canonical `name`, and rewrites
928                //   `human_display` only when present.
929                // If the function is changed, we need to reverse order_by clause as well
930                // i.e. First(a order by b asc null first) -> Last(a order by b desc null last)
931                if !was_aliased && self.fun().name() != reverse_udf.name() {
932                    replace_order_by_clause(&mut name);
933                }
934                if !was_aliased {
935                    replace_fn_name_clause(
936                        &mut name,
937                        self.fun.name(),
938                        reverse_udf.name(),
939                    );
940                }
941
942                if let Some(human_display) = human_display.as_mut() {
943                    if self.fun().name() != reverse_udf.name() {
944                        replace_order_by_clause(&mut human_display.expression);
945                    }
946                    replace_fn_name_clause(
947                        &mut human_display.expression,
948                        self.fun.name(),
949                        reverse_udf.name(),
950                    );
951                }
952
953                let mut builder =
954                    AggregateExprBuilder::new(reverse_udf, self.args.to_vec())
955                        .order_by(self.order_bys.iter().map(|e| e.reverse()).collect())
956                        .schema(Arc::new(self.schema.clone()))
957                        .alias(name)
958                        .output_metadata(self.return_field_metadata())
959                        .with_ignore_nulls(self.ignore_nulls)
960                        .with_distinct(self.is_distinct)
961                        .with_reversed(!self.is_reversed);
962                if let Some(human_display) = human_display {
963                    builder = builder.human_display(human_display.expression);
964                    if let Some(alias) = human_display.alias {
965                        builder = builder.human_display_alias(alias);
966                    }
967                }
968                builder.build().ok()
969            }
970        }
971    }
972
973    /// Returns all expressions used in the [`AggregateFunctionExpr`].
974    /// These expressions are  (1)function arguments, (2) order by expressions.
975    pub fn all_expressions(&self) -> AggregatePhysicalExpressions {
976        let args = self.expressions();
977        let order_by_exprs = self
978            .order_bys()
979            .iter()
980            .map(|sort_expr| Arc::clone(&sort_expr.expr))
981            .collect();
982        AggregatePhysicalExpressions {
983            args,
984            order_by_exprs,
985        }
986    }
987
988    /// Rewrites [`AggregateFunctionExpr`], with new expressions given. The argument should be consistent
989    /// with the return value of the [`AggregateFunctionExpr::all_expressions`] method.
990    /// Returns `Some(Arc<dyn AggregateExpr>)` if re-write is supported, otherwise returns `None`.
991    pub fn with_new_expressions(
992        &self,
993        args: Vec<Arc<dyn PhysicalExpr>>,
994        order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
995    ) -> Option<AggregateFunctionExpr> {
996        if args.len() != self.args.len()
997            || (self.order_sensitivity() != AggregateOrderSensitivity::Insensitive
998                && order_by_exprs.len() != self.order_bys.len())
999        {
1000            return None;
1001        }
1002
1003        let new_order_bys = self
1004            .order_bys
1005            .iter()
1006            .zip(order_by_exprs)
1007            .map(|(req, new_expr)| PhysicalSortExpr {
1008                expr: new_expr,
1009                options: req.options,
1010            })
1011            .collect();
1012
1013        Some(AggregateFunctionExpr {
1014            fun: self.fun.clone(),
1015            args,
1016            // TODO: need to align arg_fields here with new args
1017            //       https://github.com/apache/datafusion/issues/18149
1018            arg_fields: self.arg_fields.clone(),
1019            return_field: Arc::clone(&self.return_field),
1020            name: self.name.clone(),
1021            // TODO: Human name should be updated after re-write to not mislead
1022            human_display: self.human_display.clone(),
1023            schema: self.schema.clone(),
1024            order_bys: new_order_bys,
1025            ignore_nulls: self.ignore_nulls,
1026            ordering_fields: self.ordering_fields.clone(),
1027            is_distinct: self.is_distinct,
1028            is_reversed: false,
1029            input_fields: self.input_fields.clone(),
1030            is_nullable: self.is_nullable,
1031        })
1032    }
1033
1034    /// If this function is max, return (output_field, true)
1035    /// if the function is min, return (output_field, false)
1036    /// otherwise return None (the default)
1037    ///
1038    /// output_field is the name of the column produced by this aggregate
1039    ///
1040    /// Note: this is used to use special aggregate implementations in certain conditions
1041    pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> {
1042        self.fun.is_descending().map(|flag| (self.field(), flag))
1043    }
1044
1045    /// Returns default value of the function given the input is Null
1046    /// Most of the aggregate function return Null if input is Null,
1047    /// while `count` returns 0 if input is Null
1048    pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
1049        self.fun.default_value(data_type)
1050    }
1051
1052    /// Indicates whether the aggregation function is monotonic as a set
1053    /// function. See [`SetMonotonicity`] for details.
1054    pub fn set_monotonicity(&self) -> SetMonotonicity {
1055        let field = self.field();
1056        let data_type = field.data_type();
1057        self.fun.inner().set_monotonicity(data_type)
1058    }
1059
1060    /// Returns `PhysicalSortExpr` based on the set monotonicity of the function.
1061    pub fn get_result_ordering(&self, aggr_func_idx: usize) -> Option<PhysicalSortExpr> {
1062        // If the aggregate expressions are set-monotonic, the output data is
1063        // naturally ordered with it per group or partition.
1064        let monotonicity = self.set_monotonicity();
1065        if monotonicity == SetMonotonicity::NotMonotonic {
1066            return None;
1067        }
1068        let expr = Arc::new(Column::new(self.name(), aggr_func_idx));
1069        let options =
1070            SortOptions::new(monotonicity == SetMonotonicity::Decreasing, false);
1071        Some(PhysicalSortExpr { expr, options })
1072    }
1073}
1074
1075/// Stores the physical expressions used inside the `AggregateExpr`.
1076pub struct AggregatePhysicalExpressions {
1077    /// Aggregate function arguments
1078    pub args: Vec<Arc<dyn PhysicalExpr>>,
1079    /// Order by expressions
1080    pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
1081}
1082
1083impl PartialEq for AggregateFunctionExpr {
1084    fn eq(&self, other: &Self) -> bool {
1085        self.name == other.name
1086            && self.return_field == other.return_field
1087            && self.fun == other.fun
1088            && self.args.len() == other.args.len()
1089            && self
1090                .args
1091                .iter()
1092                .zip(other.args.iter())
1093                .all(|(this_arg, other_arg)| this_arg.eq(other_arg))
1094    }
1095}
1096
/// Flips, in place, the sort direction of every key in the first
/// `ORDER BY [...]` clause found in `order_by`.
///
/// Each direction phrase is replaced by its opposite (`ASC` <-> `DESC` and
/// `NULLS FIRST` <-> `NULLS LAST`), e.g.
/// `ORDER BY [a ASC NULLS LAST, b DESC NULLS FIRST]` becomes
/// `ORDER BY [a DESC NULLS FIRST, b ASC NULLS LAST]`.
///
/// A direction phrase is only recognized when it terminates a sort key, i.e.
/// when it is immediately followed by `,` or by the closing `]`. Every key
/// that carries such a phrase is flipped, not just the last one.
///
/// The string is left untouched when it contains no `ORDER BY [` or no `]`
/// after it. The clause is taken to end at the first `]` that follows
/// `ORDER BY [`.
fn replace_order_by_clause(order_by: &mut String) {
    const CLAUSE_OPEN: &str = "ORDER BY [";
    const FLIPS: [(&str, &str); 4] = [
        (" DESC NULLS FIRST", " ASC NULLS LAST"),
        (" ASC NULLS FIRST", " DESC NULLS LAST"),
        (" DESC NULLS LAST", " ASC NULLS FIRST"),
        (" ASC NULLS LAST", " DESC NULLS FIRST"),
    ];

    let Some(clause_start) = order_by.find(CLAUSE_OPEN) else {
        return;
    };
    let keys_start = clause_start + CLAUSE_OPEN.len();
    let Some(keys_len) = order_by[keys_start..].find(']') else {
        return;
    };
    let keys_end = keys_start + keys_len;

    // Rebuild the text between the brackets in a single left-to-right pass so
    // that a phrase which was just flipped is never flipped back.
    let mut rest = &order_by[keys_start..keys_end];
    let mut flipped = String::with_capacity(rest.len());
    'scan: while let Some(ch) = rest.chars().next() {
        for (from, to) in FLIPS {
            if let Some(tail) = rest.strip_prefix(from) {
                // Only a phrase that closes a sort key counts.
                if tail.is_empty() || tail.starts_with(',') {
                    flipped.push_str(to);
                    rest = tail;
                    continue 'scan;
                }
            }
        }
        flipped.push(ch);
        rest = &rest[ch.len_utf8()..];
    }

    order_by.replace_range(keys_start..keys_end, &flipped);
}
1121
/// Swaps the leading function name of `aggr_name` from `fn_name_old` to
/// `fn_name_new`, in place.
///
/// Nothing happens unless `aggr_name` starts with `fn_name_old`; text after
/// the prefix is kept as is.
fn replace_fn_name_clause(aggr_name: &mut String, fn_name_old: &str, fn_name_new: &str) {
    if aggr_name.starts_with(fn_name_old) {
        aggr_name.replace_range(..fn_name_old.len(), fn_name_new);
    }
}
1127
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashMap;

    use arrow::datatypes::Field;
    use datafusion_common::metadata::FieldMetadata;
    use datafusion_expr::{col, test::function_stub::sum};

    /// Metadata key attached to the aliased aggregate in these tests.
    const METADATA_KEY: &str = "some_key";
    /// Metadata value stored under [`METADATA_KEY`].
    const METADATA_VALUE: &str = "some_value";

    /// Builds a schema with a single nullable `Int64` column named `column1`,
    /// returned both as an arrow `Schema` and as the matching `DFSchema`.
    fn aggregate_test_schema() -> Result<(Schema, DFSchema)> {
        let column = Field::new("column1", DataType::Int64, true);
        let schema = Schema::new(vec![column]);
        let logical_schema = DFSchema::try_from(schema.clone())?;
        Ok((schema, logical_schema))
    }

    /// Field metadata holding the single test key/value pair.
    fn test_metadata() -> FieldMetadata {
        let mut entries = HashMap::new();
        entries.insert(METADATA_KEY.to_string(), METADATA_VALUE.to_string());
        FieldMetadata::from(entries)
    }

    /// `sum(column1)` aliased as `agg`, with the test metadata on the alias.
    fn aggregate_alias_with_metadata() -> Expr {
        let aggregate = sum(col("column1"));
        aggregate.alias_with_metadata("agg", Some(test_metadata()))
    }

    /// Lowering an aliased aggregate keeps the alias as name and display
    /// alias, and carries the alias metadata onto the output field.
    #[test]
    fn lowered_aggregate_builder_unwraps_alias_with_metadata() -> Result<()> {
        let (schema, logical_schema) = aggregate_test_schema()?;
        let expr = aggregate_alias_with_metadata();
        let props = ExecutionProps::new();

        let lowered =
            LoweredAggregateBuilder::new(&expr, &logical_schema, &schema, &props)
                .build()?;
        let aggregate = &lowered.aggregate;

        assert_eq!(aggregate.name(), "agg");
        assert_eq!(aggregate.human_display_alias(), Some("agg"));

        let field = aggregate.field();
        assert_eq!(
            field.metadata().get(METADATA_KEY).map(String::as_str),
            Some(METADATA_VALUE)
        );

        Ok(())
    }

    /// With an explicit human display override, the name is still the alias
    /// but neither the display alias nor the alias metadata is recorded.
    #[test]
    fn lowered_aggregate_builder_display_override_skips_alias_metadata() -> Result<()> {
        let (schema, logical_schema) = aggregate_test_schema()?;
        let expr = aggregate_alias_with_metadata();
        let props = ExecutionProps::new();

        let lowered =
            LoweredAggregateBuilder::new(&expr, &logical_schema, &schema, &props)
                .with_human_display(expr.human_display().to_string())
                .build()?;
        let aggregate = &lowered.aggregate;

        assert_eq!(aggregate.name(), "agg");
        assert_eq!(aggregate.human_display_alias(), None);

        let field = aggregate.field();
        assert!(!field.metadata().contains_key(METADATA_KEY));

        Ok(())
    }
}