datafusion_physical_expr/window/aggregate.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Physical expressions for aggregate window functions.

use std::any::Any;
use std::ops::Range;
use std::sync::Arc;

use crate::aggregate::AggregateFunctionExpr;
use crate::window::standard::add_new_ordering_expr_with_partition_by;
use crate::window::window_expr::AggregateWindowExpr;
use crate::window::{
    PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr,
};
use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr};

use arrow::array::Array;
use arrow::array::ArrayRef;
use arrow::datatypes::FieldRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{Accumulator, WindowFrame};
use datafusion_physical_expr_common::sort_expr::LexOrdering;

/// A window expr that takes the form of an aggregate function.
///
/// See comments on [`WindowExpr`] for more details.
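///
/// A `PlainAggregateWindowExpr` is used when the window frame never shrinks
/// from the front, i.e. its start bound is `UNBOUNDED PRECEDING`. For example,
/// `SUM(amount) OVER (PARTITION BY region ORDER BY ts)` with the default frame
/// (`RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`) typically maps to this
/// expression, while frames whose start bound moves (e.g. `ROWS BETWEEN 1
/// PRECEDING AND CURRENT ROW`) are handled by [`SlidingAggregateWindowExpr`].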
#[derive(Debug)]
pub struct PlainAggregateWindowExpr {
    aggregate: Arc<AggregateFunctionExpr>,
    partition_by: Vec<Arc<dyn PhysicalExpr>>,
    order_by: LexOrdering,
    window_frame: Arc<WindowFrame>,
}

impl PlainAggregateWindowExpr {
    /// Create a new aggregate window function expression
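    ///
    /// A construction sketch (illustrative only; the aggregate expression,
    /// the partitioning and ordering expressions, and the window frame are
    /// assumed to be built elsewhere by the planner):
    ///
    /// ```ignore
    /// let window_expr = PlainAggregateWindowExpr::new(
    ///     sum_expr,               // Arc<AggregateFunctionExpr>, e.g. SUM(amount)
    ///     &partition_by,          // PARTITION BY expressions
    ///     &order_by,              // ORDER BY expressions (LexOrdering)
    ///     Arc::new(window_frame), // e.g. RANGE UNBOUNDED PRECEDING..CURRENT ROW
    /// );
    /// ```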
    pub fn new(
        aggregate: Arc<AggregateFunctionExpr>,
        partition_by: &[Arc<dyn PhysicalExpr>],
        order_by: &LexOrdering,
        window_frame: Arc<WindowFrame>,
    ) -> Self {
        Self {
            aggregate,
            partition_by: partition_by.to_vec(),
            order_by: order_by.clone(),
            window_frame,
        }
    }

    /// Get the aggregate expression wrapped by this window expression
    pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
        &self.aggregate
    }

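    /// If the aggregate's result is known to be ordered (see
    /// `AggregateFunctionExpr::get_result_ordering`), register that ordering
    /// with the equivalence properties, prefixed by the partition keys, so
    /// downstream operators can take advantage of it.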
    pub fn add_equal_orderings(
        &self,
        eq_properties: &mut EquivalenceProperties,
        window_expr_index: usize,
    ) {
        if let Some(expr) = self
            .get_aggregate_expr()
            .get_result_ordering(window_expr_index)
        {
            add_new_ordering_expr_with_partition_by(
                eq_properties,
                expr,
                &self.partition_by,
            );
        }
    }
}

/// Peer-based evaluation, relying on the fact that the batch is pre-sorted on
/// the sort columns. For each partition point we evaluate the peer group
/// (e.g. SUM or MAX gives the same result for all peers) and concatenate the
/// results.
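///
/// For example, with `ORDER BY b` and a `RANGE` frame, all rows sharing the
/// same value of `b` are peers and receive the same aggregate result.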
impl WindowExpr for PlainAggregateWindowExpr {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn field(&self) -> Result<FieldRef> {
        Ok(self.aggregate.field())
    }

    fn name(&self) -> &str {
        self.aggregate.name()
    }

    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.aggregate.expressions()
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.aggregate_evaluate(batch)
    }

    fn evaluate_stateful(
        &self,
        partition_batches: &PartitionBatches,
        window_agg_state: &mut PartitionWindowAggStates,
    ) -> Result<()> {
        self.aggregate_evaluate_stateful(partition_batches, window_agg_state)?;

        // Update window frame range for each partition. As we know that
        // non-sliding aggregations will never call `retract_batch`, this value
        // can safely increase, and we can remove "old" parts of the state.
        // This enables us to run queries involving UNBOUNDED PRECEDING frames
        // using bounded memory for suitable aggregations.
        for partition_row in partition_batches.keys() {
            let window_state =
                window_agg_state.get_mut(partition_row).ok_or_else(|| {
                    DataFusionError::Execution("Cannot find state".to_string())
                })?;
            let state = &mut window_state.state;
            if self.window_frame.start_bound.is_unbounded() {
                state.window_frame_range.start =
                    state.window_frame_range.end.saturating_sub(1);
            }
        }
        Ok(())
    }

    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
        &self.partition_by
    }

    fn order_by(&self) -> &LexOrdering {
        self.order_by.as_ref()
    }

    fn get_window_frame(&self) -> &Arc<WindowFrame> {
        &self.window_frame
    }

    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
        self.aggregate.reverse_expr().map(|reverse_expr| {
            let reverse_window_frame = self.window_frame.reverse();
            // If the reversed frame still only ever grows (its start bound is
            // unbounded), the plain (non-sliding) variant suffices; otherwise
            // fall back to the sliding variant, which can retract rows.
            if reverse_window_frame.is_ever_expanding() {
                Arc::new(PlainAggregateWindowExpr::new(
                    Arc::new(reverse_expr),
                    &self.partition_by,
                    reverse_order_bys(self.order_by.as_ref()).as_ref(),
                    Arc::new(reverse_window_frame),
                )) as _
            } else {
                Arc::new(SlidingAggregateWindowExpr::new(
                    Arc::new(reverse_expr),
                    &self.partition_by,
                    reverse_order_bys(self.order_by.as_ref()).as_ref(),
                    Arc::new(reverse_window_frame),
                )) as _
            }
        })
    }

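    // A plain aggregate never retracts rows and can prune old state (see
    // `evaluate_stateful` above), so memory stays bounded as long as the end
    // bound is not also unbounded; an unbounded end bound requires seeing the
    // entire partition before any result can be emitted.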
    fn uses_bounded_memory(&self) -> bool {
        !self.window_frame.end_bound.is_unbounded()
    }
}

impl AggregateWindowExpr for PlainAggregateWindowExpr {
    fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        self.aggregate.create_accumulator()
    }

    /// For a given range, calculate the accumulation result inside the range on
    /// `value_slice` and update the accumulator state.
    // We assume that `cur_range` contains `last_range` and that their start
    // points are the same. In other words, if `last_range` is `Range { start: a, end: b }`
    // and `cur_range` is `Range { start: a1, end: b1 }`, then a1 = a and b1 >= b.
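    //
    // For example, if `last_range` is `0..3` and `cur_range` is `0..5`, only
    // rows 3 and 4 are new; they are the only rows fed to `update_batch` below.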
    fn get_aggregate_result_inside_range(
        &self,
        last_range: &Range<usize>,
        cur_range: &Range<usize>,
        value_slice: &[ArrayRef],
        accumulator: &mut Box<dyn Accumulator>,
    ) -> Result<ScalarValue> {
        if cur_range.start == cur_range.end {
            self.aggregate
                .default_value(self.aggregate.field().data_type())
        } else {
            // Accumulate any new rows that have entered the window:
            let update_bound = cur_range.end - last_range.end;
            // A non-sliding aggregation only processes new data; it never has
            // to deal with expiring data, because its starting point is always
            // the same (the beginning of the table/frame). Hence, we do not
            // call `retract_batch`.
            if update_bound > 0 {
                let update: Vec<ArrayRef> = value_slice
                    .iter()
                    .map(|v| v.slice(last_range.end, update_bound))
                    .collect();
                accumulator.update_batch(&update)?
            }
            accumulator.evaluate()
        }
    }
}