datafusion_physical_expr/window/
sliding_aggregate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Physical window expression for aggregate functions evaluated over sliding window frames.
19
20use std::any::Any;
21use std::ops::Range;
22use std::sync::Arc;
23
24use crate::aggregate::AggregateFunctionExpr;
25use crate::window::window_expr::{filter_array, AggregateWindowExpr, WindowFn};
26use crate::window::{
27    PartitionBatches, PartitionWindowAggStates, PlainAggregateWindowExpr, WindowExpr,
28};
29use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
30
31use arrow::array::{ArrayRef, BooleanArray};
32use arrow::datatypes::FieldRef;
33use arrow::record_batch::RecordBatch;
34use datafusion_common::{Result, ScalarValue};
35use datafusion_expr::{Accumulator, WindowFrame};
36
/// A window expr that takes the form of an aggregate function that
/// can be incrementally computed over sliding windows.
///
/// See comments on [`WindowExpr`] for more details.
#[derive(Debug)]
pub struct SlidingAggregateWindowExpr {
    /// Aggregate function this window expression evaluates; also supplies
    /// the output field, the name and the argument expressions.
    aggregate: Arc<AggregateFunctionExpr>,
    /// Expressions returned by [`WindowExpr::partition_by`].
    partition_by: Vec<Arc<dyn PhysicalExpr>>,
    /// Sort expressions returned by [`WindowExpr::order_by`].
    order_by: Vec<PhysicalSortExpr>,
    /// Frame specification returned by [`WindowExpr::get_window_frame`].
    window_frame: Arc<WindowFrame>,
    /// Optional filter expression, exposed through
    /// [`AggregateWindowExpr::filter_expr`].
    filter: Option<Arc<dyn PhysicalExpr>>,
}
49
50impl SlidingAggregateWindowExpr {
51    /// Create a new (sliding) aggregate window function expression.
52    pub fn new(
53        aggregate: Arc<AggregateFunctionExpr>,
54        partition_by: &[Arc<dyn PhysicalExpr>],
55        order_by: &[PhysicalSortExpr],
56        window_frame: Arc<WindowFrame>,
57        filter: Option<Arc<dyn PhysicalExpr>>,
58    ) -> Self {
59        Self {
60            aggregate,
61            partition_by: partition_by.to_vec(),
62            order_by: order_by.to_vec(),
63            window_frame,
64            filter,
65        }
66    }
67
68    /// Get the [AggregateFunctionExpr] of this object.
69    pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
70        &self.aggregate
71    }
72}
73
74/// Incrementally update window function using the fact that batch is
75/// pre-sorted given the sort columns and then per partition point.
76///
77/// Evaluates the peer group (e.g. `SUM` or `MAX` gives the same results
78/// for peers) and concatenate the results.
79impl WindowExpr for SlidingAggregateWindowExpr {
80    /// Return a reference to Any that can be used for downcasting
81    fn as_any(&self) -> &dyn Any {
82        self
83    }
84
85    fn field(&self) -> Result<FieldRef> {
86        Ok(self.aggregate.field())
87    }
88
89    fn name(&self) -> &str {
90        self.aggregate.name()
91    }
92
93    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
94        self.aggregate.expressions()
95    }
96
97    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
98        self.aggregate_evaluate(batch)
99    }
100
101    fn evaluate_stateful(
102        &self,
103        partition_batches: &PartitionBatches,
104        window_agg_state: &mut PartitionWindowAggStates,
105    ) -> Result<()> {
106        self.aggregate_evaluate_stateful(partition_batches, window_agg_state)
107    }
108
109    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
110        &self.partition_by
111    }
112
113    fn order_by(&self) -> &[PhysicalSortExpr] {
114        &self.order_by
115    }
116
117    fn get_window_frame(&self) -> &Arc<WindowFrame> {
118        &self.window_frame
119    }
120
121    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
122        self.aggregate.reverse_expr().map(|reverse_expr| {
123            let reverse_window_frame = self.window_frame.reverse();
124            if reverse_window_frame.is_ever_expanding() {
125                Arc::new(PlainAggregateWindowExpr::new(
126                    Arc::new(reverse_expr),
127                    &self.partition_by.clone(),
128                    &self
129                        .order_by
130                        .iter()
131                        .map(|e| e.reverse())
132                        .collect::<Vec<_>>(),
133                    Arc::new(self.window_frame.reverse()),
134                    self.filter.clone(),
135                )) as _
136            } else {
137                Arc::new(SlidingAggregateWindowExpr::new(
138                    Arc::new(reverse_expr),
139                    &self.partition_by.clone(),
140                    &self
141                        .order_by
142                        .iter()
143                        .map(|e| e.reverse())
144                        .collect::<Vec<_>>(),
145                    Arc::new(self.window_frame.reverse()),
146                    self.filter.clone(),
147                )) as _
148            }
149        })
150    }
151
152    fn uses_bounded_memory(&self) -> bool {
153        !self.window_frame.end_bound.is_unbounded()
154    }
155
156    fn with_new_expressions(
157        &self,
158        args: Vec<Arc<dyn PhysicalExpr>>,
159        partition_bys: Vec<Arc<dyn PhysicalExpr>>,
160        order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
161    ) -> Option<Arc<dyn WindowExpr>> {
162        debug_assert_eq!(self.order_by.len(), order_by_exprs.len());
163
164        let new_order_by = self
165            .order_by
166            .iter()
167            .zip(order_by_exprs)
168            .map(|(req, new_expr)| PhysicalSortExpr {
169                expr: new_expr,
170                options: req.options,
171            })
172            .collect();
173        Some(Arc::new(SlidingAggregateWindowExpr {
174            aggregate: self
175                .aggregate
176                .with_new_expressions(args, vec![])
177                .map(Arc::new)?,
178            partition_by: partition_bys,
179            order_by: new_order_by,
180            window_frame: Arc::clone(&self.window_frame),
181            filter: self.filter.clone(),
182        }))
183    }
184
185    fn create_window_fn(&self) -> Result<WindowFn> {
186        Ok(WindowFn::Aggregate(self.get_accumulator()?))
187    }
188}
189
190impl AggregateWindowExpr for SlidingAggregateWindowExpr {
191    fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
192        self.aggregate.create_sliding_accumulator()
193    }
194
195    fn filter_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
196        self.filter.as_ref()
197    }
198
199    /// Given current range and the last range, calculates the accumulator
200    /// result for the range of interest.
201    fn get_aggregate_result_inside_range(
202        &self,
203        last_range: &Range<usize>,
204        cur_range: &Range<usize>,
205        value_slice: &[ArrayRef],
206        accumulator: &mut Box<dyn Accumulator>,
207        filter_mask: Option<&BooleanArray>,
208    ) -> Result<ScalarValue> {
209        if cur_range.start == cur_range.end {
210            self.aggregate
211                .default_value(self.aggregate.field().data_type())
212        } else {
213            // Accumulate any new rows that have entered the window:
214            let update_bound = cur_range.end - last_range.end;
215            if update_bound > 0 {
216                let slice_mask =
217                    filter_mask.map(|m| m.slice(last_range.end, update_bound));
218                let update: Vec<ArrayRef> = value_slice
219                    .iter()
220                    .map(|v| v.slice(last_range.end, update_bound))
221                    .map(|arr| match &slice_mask {
222                        Some(m) => filter_array(&arr, m),
223                        None => Ok(arr),
224                    })
225                    .collect::<Result<Vec<_>>>()?;
226                accumulator.update_batch(&update)?
227            }
228
229            // Remove rows that have now left the window:
230            let retract_bound = cur_range.start - last_range.start;
231            if retract_bound > 0 {
232                let slice_mask =
233                    filter_mask.map(|m| m.slice(last_range.start, retract_bound));
234                let retract: Vec<ArrayRef> = value_slice
235                    .iter()
236                    .map(|v| v.slice(last_range.start, retract_bound))
237                    .map(|arr| match &slice_mask {
238                        Some(m) => filter_array(&arr, m),
239                        None => Ok(arr),
240                    })
241                    .collect::<Result<Vec<_>>>()?;
242                accumulator.retract_batch(&retract)?
243            }
244            accumulator.evaluate()
245        }
246    }
247
248    fn is_constant_in_partition(&self) -> bool {
249        false
250    }
251}