datafusion_physical_expr/window/aggregate.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Physical exec for aggregate window function expressions.

use std::any::Any;
use std::ops::Range;
use std::sync::Arc;

use crate::aggregate::AggregateFunctionExpr;
use crate::window::standard::add_new_ordering_expr_with_partition_by;
use crate::window::window_expr::{filter_array, AggregateWindowExpr, WindowFn};
use crate::window::{
    PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr,
};
use crate::{EquivalenceProperties, PhysicalExpr};

use arrow::array::ArrayRef;
use arrow::array::BooleanArray;
use arrow::datatypes::FieldRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{Accumulator, WindowFrame, WindowFrameBound, WindowFrameUnits};
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;

/// A window expr that takes the form of an aggregate function.
///
/// See comments on [`WindowExpr`] for more details.
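///
/// For example (illustrative only), an expression such as
/// `SUM(amount) OVER (PARTITION BY account ORDER BY ts ROWS BETWEEN UNBOUNDED
/// PRECEDING AND CURRENT ROW)` can be evaluated this way: its frame start never
/// moves, so the accumulator only receives new rows and never has to retract
/// old ones (contrast with [`SlidingAggregateWindowExpr`]).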
#[derive(Debug)]
pub struct PlainAggregateWindowExpr {
    aggregate: Arc<AggregateFunctionExpr>,
    partition_by: Vec<Arc<dyn PhysicalExpr>>,
    order_by: Vec<PhysicalSortExpr>,
    window_frame: Arc<WindowFrame>,
    is_constant_in_partition: bool,
    filter: Option<Arc<dyn PhysicalExpr>>,
}

impl PlainAggregateWindowExpr {
    /// Create a new aggregate window function expression
    pub fn new(
        aggregate: Arc<AggregateFunctionExpr>,
        partition_by: &[Arc<dyn PhysicalExpr>],
        order_by: &[PhysicalSortExpr],
        window_frame: Arc<WindowFrame>,
        filter: Option<Arc<dyn PhysicalExpr>>,
    ) -> Self {
        let is_constant_in_partition =
            Self::is_window_constant_in_partition(order_by, &window_frame);
        Self {
            aggregate,
            partition_by: partition_by.to_vec(),
            order_by: order_by.to_vec(),
            window_frame,
            is_constant_in_partition,
            filter,
        }
    }

    /// Return the underlying aggregate expression of this window expression.
    pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
        &self.aggregate
    }

    pub fn add_equal_orderings(
        &self,
        eq_properties: &mut EquivalenceProperties,
        window_expr_index: usize,
    ) -> Result<()> {
        if let Some(expr) = self
            .get_aggregate_expr()
            .get_result_ordering(window_expr_index)
        {
            add_new_ordering_expr_with_partition_by(
                eq_properties,
                expr,
                &self.partition_by,
            )?;
        }
        Ok(())
    }

    // Returns `true` if every row in the partition has the same window frame,
    // which lets us skip recomputing the frame bounds and the aggregate for
    // every row, since the result is identical across the partition.
    //
    // This occurs when both bounds fall under either condition below:
    //  1. The bound is unbounded (`Preceding` or `Following`).
    //  2. The bound is `CurrentRow` while using `Range` units with no ORDER BY
    //     clause. This results in an invalid range specification; following
    //     PostgreSQL's convention, we interpret it as the entire partition
    //     being used for the current window frame.
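    //
    // Illustrative examples (standard SQL frame syntax):
    //  - `ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING`: constant,
    //    both bounds are unbounded.
    //  - `RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING` with no ORDER BY:
    //    constant, per the PostgreSQL convention above.
    //  - `ROWS BETWEEN 1 PRECEDING AND CURRENT ROW`: not constant, since the
    //    frame moves with every row.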
    fn is_window_constant_in_partition(
        order_by: &[PhysicalSortExpr],
        window_frame: &WindowFrame,
    ) -> bool {
        let is_constant_bound = |bound: &WindowFrameBound| match bound {
            WindowFrameBound::CurrentRow => {
                window_frame.units == WindowFrameUnits::Range && order_by.is_empty()
            }
            _ => bound.is_unbounded(),
        };

        is_constant_bound(&window_frame.start_bound)
            && is_constant_bound(&window_frame.end_bound)
    }
}

/// Peer-based evaluation: the input batch is pre-sorted on the sort columns, so
/// for each partition point we evaluate the peer group once (e.g. `SUM` or `MAX`
/// gives the same result for all peers) and concatenate the results.
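///
/// For example (illustrative), with `SUM(v) OVER (ORDER BY k)`, which defaults
/// to `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`, rows sharing the same
/// `k` are peers and all receive the same running sum.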
impl WindowExpr for PlainAggregateWindowExpr {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn field(&self) -> Result<FieldRef> {
        Ok(self.aggregate.field())
    }

    fn name(&self) -> &str {
        self.aggregate.name()
    }

    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.aggregate.expressions()
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.aggregate_evaluate(batch)
    }

    fn evaluate_stateful(
        &self,
        partition_batches: &PartitionBatches,
        window_agg_state: &mut PartitionWindowAggStates,
    ) -> Result<()> {
        self.aggregate_evaluate_stateful(partition_batches, window_agg_state)?;

        // Update window frame range for each partition. As we know that
        // non-sliding aggregations will never call `retract_batch`, this value
        // can safely increase, and we can remove "old" parts of the state.
        // This enables us to run queries involving UNBOUNDED PRECEDING frames
        // using bounded memory for suitable aggregations.
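        //
        // For example (illustrative): for `SUM(x) OVER (ORDER BY ts ROWS
        // BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)`, rows already folded
        // into the accumulator no longer need to be kept, so earlier parts of
        // the per-partition state can be pruned as the frame end advances.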
        for partition_row in partition_batches.keys() {
            let window_state =
                window_agg_state.get_mut(partition_row).ok_or_else(|| {
                    DataFusionError::Execution("Cannot find state".to_string())
                })?;
            let state = &mut window_state.state;
            if self.window_frame.start_bound.is_unbounded() {
                state.window_frame_range.start =
                    state.window_frame_range.end.saturating_sub(1);
            }
        }
        Ok(())
    }

    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
        &self.partition_by
    }

    fn order_by(&self) -> &[PhysicalSortExpr] {
        &self.order_by
    }

    fn get_window_frame(&self) -> &Arc<WindowFrame> {
        &self.window_frame
    }

    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
        self.aggregate.reverse_expr().map(|reverse_expr| {
            let reverse_window_frame = Arc::new(self.window_frame.reverse());
            let reverse_order_by = self
                .order_by
                .iter()
                .map(|e| e.reverse())
                .collect::<Vec<_>>();
            if reverse_window_frame.is_ever_expanding() {
                Arc::new(PlainAggregateWindowExpr::new(
                    Arc::new(reverse_expr),
                    &self.partition_by,
                    &reverse_order_by,
                    reverse_window_frame,
                    self.filter.clone(),
                )) as _
            } else {
                Arc::new(SlidingAggregateWindowExpr::new(
                    Arc::new(reverse_expr),
                    &self.partition_by,
                    &reverse_order_by,
                    reverse_window_frame,
                    self.filter.clone(),
                )) as _
            }
        })
    }

    fn uses_bounded_memory(&self) -> bool {
        !self.window_frame.end_bound.is_unbounded()
    }

    fn create_window_fn(&self) -> Result<WindowFn> {
        Ok(WindowFn::Aggregate(self.get_accumulator()?))
    }
}

impl AggregateWindowExpr for PlainAggregateWindowExpr {
    fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        self.aggregate.create_accumulator()
    }

    fn filter_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
        self.filter.as_ref()
    }

    /// For a given range, calculate the accumulation result inside the range on
    /// `value_slice` and update the accumulator state.
    // We assume that `cur_range` contains `last_range` and that their start
    // points are the same. That is, if `last_range` is `Range { start: a, end: b }`
    // and `cur_range` is `Range { start: a1, end: b1 }`, it is guaranteed that
    // a1 = a and b1 >= b.
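    //
    // For example (illustrative): if `last_range` is `0..3` and `cur_range` is
    // `0..5`, only the two newly entered rows at indices 3 and 4 are fed to
    // `update_batch`, and `evaluate` then reflects the whole `0..5` range.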
    fn get_aggregate_result_inside_range(
        &self,
        last_range: &Range<usize>,
        cur_range: &Range<usize>,
        value_slice: &[ArrayRef],
        accumulator: &mut Box<dyn Accumulator>,
        filter_mask: Option<&BooleanArray>,
    ) -> Result<ScalarValue> {
        if cur_range.start == cur_range.end {
            self.aggregate
                .default_value(self.aggregate.field().data_type())
        } else {
            // Accumulate any new rows that have entered the window:
            let update_bound = cur_range.end - last_range.end;
            // A non-sliding aggregation only processes new data; it never
            // deals with expiring data because its starting point is always
            // the same (i.e. the beginning of the table/frame). Hence, we do
            // not call `retract_batch`.
            if update_bound > 0 {
                let slice_mask =
                    filter_mask.map(|m| m.slice(last_range.end, update_bound));
                let update: Vec<ArrayRef> = value_slice
                    .iter()
                    .map(|v| v.slice(last_range.end, update_bound))
                    .map(|arr| match &slice_mask {
                        Some(m) => filter_array(&arr, m),
                        None => Ok(arr),
                    })
                    .collect::<Result<Vec<_>>>()?;
                accumulator.update_batch(&update)?
            }
            accumulator.evaluate()
        }
    }

    fn is_constant_in_partition(&self) -> bool {
        self.is_constant_in_partition
    }
}