datafusion_physical_expr/window/
aggregate.rs1use std::any::Any;
21use std::ops::Range;
22use std::sync::Arc;
23
24use crate::aggregate::AggregateFunctionExpr;
25use crate::window::standard::add_new_ordering_expr_with_partition_by;
26use crate::window::window_expr::{filter_array, AggregateWindowExpr, WindowFn};
27use crate::window::{
28 PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr,
29};
30use crate::{EquivalenceProperties, PhysicalExpr};
31
32use arrow::array::ArrayRef;
33use arrow::array::BooleanArray;
34use arrow::datatypes::FieldRef;
35use arrow::record_batch::RecordBatch;
36use datafusion_common::{DataFusionError, Result, ScalarValue};
37use datafusion_expr::{Accumulator, WindowFrame, WindowFrameBound, WindowFrameUnits};
38use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
39
40#[derive(Debug)]
44pub struct PlainAggregateWindowExpr {
45 aggregate: Arc<AggregateFunctionExpr>,
46 partition_by: Vec<Arc<dyn PhysicalExpr>>,
47 order_by: Vec<PhysicalSortExpr>,
48 window_frame: Arc<WindowFrame>,
49 is_constant_in_partition: bool,
50 filter: Option<Arc<dyn PhysicalExpr>>,
51}
52
53impl PlainAggregateWindowExpr {
54 pub fn new(
56 aggregate: Arc<AggregateFunctionExpr>,
57 partition_by: &[Arc<dyn PhysicalExpr>],
58 order_by: &[PhysicalSortExpr],
59 window_frame: Arc<WindowFrame>,
60 filter: Option<Arc<dyn PhysicalExpr>>,
61 ) -> Self {
62 let is_constant_in_partition =
63 Self::is_window_constant_in_partition(order_by, &window_frame);
64 Self {
65 aggregate,
66 partition_by: partition_by.to_vec(),
67 order_by: order_by.to_vec(),
68 window_frame,
69 is_constant_in_partition,
70 filter,
71 }
72 }
73
74 pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
76 &self.aggregate
77 }
78
79 pub fn add_equal_orderings(
80 &self,
81 eq_properties: &mut EquivalenceProperties,
82 window_expr_index: usize,
83 ) -> Result<()> {
84 if let Some(expr) = self
85 .get_aggregate_expr()
86 .get_result_ordering(window_expr_index)
87 {
88 add_new_ordering_expr_with_partition_by(
89 eq_properties,
90 expr,
91 &self.partition_by,
92 )?;
93 }
94 Ok(())
95 }
96
97 fn is_window_constant_in_partition(
107 order_by: &[PhysicalSortExpr],
108 window_frame: &WindowFrame,
109 ) -> bool {
110 let is_constant_bound = |bound: &WindowFrameBound| match bound {
111 WindowFrameBound::CurrentRow => {
112 window_frame.units == WindowFrameUnits::Range && order_by.is_empty()
113 }
114 _ => bound.is_unbounded(),
115 };
116
117 is_constant_bound(&window_frame.start_bound)
118 && is_constant_bound(&window_frame.end_bound)
119 }
120}
121
122impl WindowExpr for PlainAggregateWindowExpr {
126 fn as_any(&self) -> &dyn Any {
128 self
129 }
130
131 fn field(&self) -> Result<FieldRef> {
132 Ok(self.aggregate.field())
133 }
134
135 fn name(&self) -> &str {
136 self.aggregate.name()
137 }
138
139 fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
140 self.aggregate.expressions()
141 }
142
143 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
144 self.aggregate_evaluate(batch)
145 }
146
147 fn evaluate_stateful(
148 &self,
149 partition_batches: &PartitionBatches,
150 window_agg_state: &mut PartitionWindowAggStates,
151 ) -> Result<()> {
152 self.aggregate_evaluate_stateful(partition_batches, window_agg_state)?;
153
154 for partition_row in partition_batches.keys() {
160 let window_state =
161 window_agg_state.get_mut(partition_row).ok_or_else(|| {
162 DataFusionError::Execution("Cannot find state".to_string())
163 })?;
164 let state = &mut window_state.state;
165 if self.window_frame.start_bound.is_unbounded() {
166 state.window_frame_range.start =
167 state.window_frame_range.end.saturating_sub(1);
168 }
169 }
170 Ok(())
171 }
172
173 fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
174 &self.partition_by
175 }
176
177 fn order_by(&self) -> &[PhysicalSortExpr] {
178 &self.order_by
179 }
180
181 fn get_window_frame(&self) -> &Arc<WindowFrame> {
182 &self.window_frame
183 }
184
185 fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
186 self.aggregate.reverse_expr().map(|reverse_expr| {
187 let reverse_window_frame = self.window_frame.reverse();
188 if reverse_window_frame.is_ever_expanding() {
189 Arc::new(PlainAggregateWindowExpr::new(
190 Arc::new(reverse_expr),
191 &self.partition_by.clone(),
192 &self
193 .order_by
194 .iter()
195 .map(|e| e.reverse())
196 .collect::<Vec<_>>(),
197 Arc::new(self.window_frame.reverse()),
198 self.filter.clone(),
199 )) as _
200 } else {
201 Arc::new(SlidingAggregateWindowExpr::new(
202 Arc::new(reverse_expr),
203 &self.partition_by.clone(),
204 &self
205 .order_by
206 .iter()
207 .map(|e| e.reverse())
208 .collect::<Vec<_>>(),
209 Arc::new(self.window_frame.reverse()),
210 self.filter.clone(),
211 )) as _
212 }
213 })
214 }
215
216 fn uses_bounded_memory(&self) -> bool {
217 !self.window_frame.end_bound.is_unbounded()
218 }
219
220 fn create_window_fn(&self) -> Result<WindowFn> {
221 Ok(WindowFn::Aggregate(self.get_accumulator()?))
222 }
223}
224
225impl AggregateWindowExpr for PlainAggregateWindowExpr {
226 fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
227 self.aggregate.create_accumulator()
228 }
229
230 fn filter_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
231 self.filter.as_ref()
232 }
233
234 fn get_aggregate_result_inside_range(
240 &self,
241 last_range: &Range<usize>,
242 cur_range: &Range<usize>,
243 value_slice: &[ArrayRef],
244 accumulator: &mut Box<dyn Accumulator>,
245 filter_mask: Option<&BooleanArray>,
246 ) -> Result<ScalarValue> {
247 if cur_range.start == cur_range.end {
248 self.aggregate
249 .default_value(self.aggregate.field().data_type())
250 } else {
251 let update_bound = cur_range.end - last_range.end;
253 if update_bound > 0 {
258 let slice_mask =
259 filter_mask.map(|m| m.slice(last_range.end, update_bound));
260 let update: Vec<ArrayRef> = value_slice
261 .iter()
262 .map(|v| v.slice(last_range.end, update_bound))
263 .map(|arr| match &slice_mask {
264 Some(m) => filter_array(&arr, m),
265 None => Ok(arr),
266 })
267 .collect::<Result<Vec<_>>>()?;
268 accumulator.update_batch(&update)?
269 }
270 accumulator.evaluate()
271 }
272 }
273
274 fn is_constant_in_partition(&self) -> bool {
275 self.is_constant_in_partition
276 }
277}