datafusion_physical_expr/window/
sliding_aggregate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Physical exec for aggregate window function expressions.
19
20use std::any::Any;
21use std::ops::Range;
22use std::sync::Arc;
23
24use crate::aggregate::AggregateFunctionExpr;
25use crate::window::window_expr::AggregateWindowExpr;
26use crate::window::{
27    PartitionBatches, PartitionWindowAggStates, PlainAggregateWindowExpr, WindowExpr,
28};
29use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
30
31use arrow::array::{Array, ArrayRef};
32use arrow::datatypes::FieldRef;
33use arrow::record_batch::RecordBatch;
34use datafusion_common::{Result, ScalarValue};
35use datafusion_expr::{Accumulator, WindowFrame};
36
37/// A window expr that takes the form of an aggregate function that
38/// can be incrementally computed over sliding windows.
39///
40/// See comments on [`WindowExpr`] for more details.
41#[derive(Debug)]
42pub struct SlidingAggregateWindowExpr {
43    aggregate: Arc<AggregateFunctionExpr>,
44    partition_by: Vec<Arc<dyn PhysicalExpr>>,
45    order_by: Vec<PhysicalSortExpr>,
46    window_frame: Arc<WindowFrame>,
47}
48
49impl SlidingAggregateWindowExpr {
50    /// Create a new (sliding) aggregate window function expression.
51    pub fn new(
52        aggregate: Arc<AggregateFunctionExpr>,
53        partition_by: &[Arc<dyn PhysicalExpr>],
54        order_by: &[PhysicalSortExpr],
55        window_frame: Arc<WindowFrame>,
56    ) -> Self {
57        Self {
58            aggregate,
59            partition_by: partition_by.to_vec(),
60            order_by: order_by.to_vec(),
61            window_frame,
62        }
63    }
64
65    /// Get the [AggregateFunctionExpr] of this object.
66    pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
67        &self.aggregate
68    }
69}
70
71/// Incrementally update window function using the fact that batch is
72/// pre-sorted given the sort columns and then per partition point.
73///
74/// Evaluates the peer group (e.g. `SUM` or `MAX` gives the same results
75/// for peers) and concatenate the results.
76impl WindowExpr for SlidingAggregateWindowExpr {
77    /// Return a reference to Any that can be used for downcasting
78    fn as_any(&self) -> &dyn Any {
79        self
80    }
81
82    fn field(&self) -> Result<FieldRef> {
83        Ok(self.aggregate.field())
84    }
85
86    fn name(&self) -> &str {
87        self.aggregate.name()
88    }
89
90    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
91        self.aggregate.expressions()
92    }
93
94    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
95        self.aggregate_evaluate(batch)
96    }
97
98    fn evaluate_stateful(
99        &self,
100        partition_batches: &PartitionBatches,
101        window_agg_state: &mut PartitionWindowAggStates,
102    ) -> Result<()> {
103        self.aggregate_evaluate_stateful(partition_batches, window_agg_state)
104    }
105
106    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
107        &self.partition_by
108    }
109
110    fn order_by(&self) -> &[PhysicalSortExpr] {
111        &self.order_by
112    }
113
114    fn get_window_frame(&self) -> &Arc<WindowFrame> {
115        &self.window_frame
116    }
117
118    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
119        self.aggregate.reverse_expr().map(|reverse_expr| {
120            let reverse_window_frame = self.window_frame.reverse();
121            if reverse_window_frame.is_ever_expanding() {
122                Arc::new(PlainAggregateWindowExpr::new(
123                    Arc::new(reverse_expr),
124                    &self.partition_by.clone(),
125                    &self
126                        .order_by
127                        .iter()
128                        .map(|e| e.reverse())
129                        .collect::<Vec<_>>(),
130                    Arc::new(self.window_frame.reverse()),
131                )) as _
132            } else {
133                Arc::new(SlidingAggregateWindowExpr::new(
134                    Arc::new(reverse_expr),
135                    &self.partition_by.clone(),
136                    &self
137                        .order_by
138                        .iter()
139                        .map(|e| e.reverse())
140                        .collect::<Vec<_>>(),
141                    Arc::new(self.window_frame.reverse()),
142                )) as _
143            }
144        })
145    }
146
147    fn uses_bounded_memory(&self) -> bool {
148        !self.window_frame.end_bound.is_unbounded()
149    }
150
151    fn with_new_expressions(
152        &self,
153        args: Vec<Arc<dyn PhysicalExpr>>,
154        partition_bys: Vec<Arc<dyn PhysicalExpr>>,
155        order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
156    ) -> Option<Arc<dyn WindowExpr>> {
157        debug_assert_eq!(self.order_by.len(), order_by_exprs.len());
158
159        let new_order_by = self
160            .order_by
161            .iter()
162            .zip(order_by_exprs)
163            .map(|(req, new_expr)| PhysicalSortExpr {
164                expr: new_expr,
165                options: req.options,
166            })
167            .collect();
168        Some(Arc::new(SlidingAggregateWindowExpr {
169            aggregate: self
170                .aggregate
171                .with_new_expressions(args, vec![])
172                .map(Arc::new)?,
173            partition_by: partition_bys,
174            order_by: new_order_by,
175            window_frame: Arc::clone(&self.window_frame),
176        }))
177    }
178}
179
180impl AggregateWindowExpr for SlidingAggregateWindowExpr {
181    fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
182        self.aggregate.create_sliding_accumulator()
183    }
184
185    /// Given current range and the last range, calculates the accumulator
186    /// result for the range of interest.
187    fn get_aggregate_result_inside_range(
188        &self,
189        last_range: &Range<usize>,
190        cur_range: &Range<usize>,
191        value_slice: &[ArrayRef],
192        accumulator: &mut Box<dyn Accumulator>,
193    ) -> Result<ScalarValue> {
194        if cur_range.start == cur_range.end {
195            self.aggregate
196                .default_value(self.aggregate.field().data_type())
197        } else {
198            // Accumulate any new rows that have entered the window:
199            let update_bound = cur_range.end - last_range.end;
200            if update_bound > 0 {
201                let update: Vec<ArrayRef> = value_slice
202                    .iter()
203                    .map(|v| v.slice(last_range.end, update_bound))
204                    .collect();
205                accumulator.update_batch(&update)?
206            }
207
208            // Remove rows that have now left the window:
209            let retract_bound = cur_range.start - last_range.start;
210            if retract_bound > 0 {
211                let retract: Vec<ArrayRef> = value_slice
212                    .iter()
213                    .map(|v| v.slice(last_range.start, retract_bound))
214                    .collect();
215                accumulator.retract_batch(&retract)?
216            }
217            accumulator.evaluate()
218        }
219    }
220
221    fn is_constant_in_partition(&self) -> bool {
222        false
223    }
224}