datafusion_physical_expr/window/
aggregate.rs1use std::any::Any;
21use std::ops::Range;
22use std::sync::Arc;
23
24use crate::aggregate::AggregateFunctionExpr;
25use crate::window::standard::add_new_ordering_expr_with_partition_by;
26use crate::window::window_expr::AggregateWindowExpr;
27use crate::window::{
28 PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr,
29};
30use crate::{EquivalenceProperties, PhysicalExpr};
31
32use arrow::array::Array;
33use arrow::array::ArrayRef;
34use arrow::datatypes::FieldRef;
35use arrow::record_batch::RecordBatch;
36use datafusion_common::{DataFusionError, Result, ScalarValue};
37use datafusion_expr::{Accumulator, WindowFrame, WindowFrameBound, WindowFrameUnits};
38use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
39
40#[derive(Debug)]
44pub struct PlainAggregateWindowExpr {
45 aggregate: Arc<AggregateFunctionExpr>,
46 partition_by: Vec<Arc<dyn PhysicalExpr>>,
47 order_by: Vec<PhysicalSortExpr>,
48 window_frame: Arc<WindowFrame>,
49 is_constant_in_partition: bool,
50}
51
52impl PlainAggregateWindowExpr {
53 pub fn new(
55 aggregate: Arc<AggregateFunctionExpr>,
56 partition_by: &[Arc<dyn PhysicalExpr>],
57 order_by: &[PhysicalSortExpr],
58 window_frame: Arc<WindowFrame>,
59 ) -> Self {
60 let is_constant_in_partition =
61 Self::is_window_constant_in_partition(order_by, &window_frame);
62 Self {
63 aggregate,
64 partition_by: partition_by.to_vec(),
65 order_by: order_by.to_vec(),
66 window_frame,
67 is_constant_in_partition,
68 }
69 }
70
71 pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
73 &self.aggregate
74 }
75
76 pub fn add_equal_orderings(
77 &self,
78 eq_properties: &mut EquivalenceProperties,
79 window_expr_index: usize,
80 ) -> Result<()> {
81 if let Some(expr) = self
82 .get_aggregate_expr()
83 .get_result_ordering(window_expr_index)
84 {
85 add_new_ordering_expr_with_partition_by(
86 eq_properties,
87 expr,
88 &self.partition_by,
89 )?;
90 }
91 Ok(())
92 }
93
94 fn is_window_constant_in_partition(
104 order_by: &[PhysicalSortExpr],
105 window_frame: &WindowFrame,
106 ) -> bool {
107 let is_constant_bound = |bound: &WindowFrameBound| match bound {
108 WindowFrameBound::CurrentRow => {
109 window_frame.units == WindowFrameUnits::Range && order_by.is_empty()
110 }
111 _ => bound.is_unbounded(),
112 };
113
114 is_constant_bound(&window_frame.start_bound)
115 && is_constant_bound(&window_frame.end_bound)
116 }
117}
118
119impl WindowExpr for PlainAggregateWindowExpr {
123 fn as_any(&self) -> &dyn Any {
125 self
126 }
127
128 fn field(&self) -> Result<FieldRef> {
129 Ok(self.aggregate.field())
130 }
131
132 fn name(&self) -> &str {
133 self.aggregate.name()
134 }
135
136 fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
137 self.aggregate.expressions()
138 }
139
140 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
141 self.aggregate_evaluate(batch)
142 }
143
144 fn evaluate_stateful(
145 &self,
146 partition_batches: &PartitionBatches,
147 window_agg_state: &mut PartitionWindowAggStates,
148 ) -> Result<()> {
149 self.aggregate_evaluate_stateful(partition_batches, window_agg_state)?;
150
151 for partition_row in partition_batches.keys() {
157 let window_state =
158 window_agg_state.get_mut(partition_row).ok_or_else(|| {
159 DataFusionError::Execution("Cannot find state".to_string())
160 })?;
161 let state = &mut window_state.state;
162 if self.window_frame.start_bound.is_unbounded() {
163 state.window_frame_range.start =
164 state.window_frame_range.end.saturating_sub(1);
165 }
166 }
167 Ok(())
168 }
169
170 fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
171 &self.partition_by
172 }
173
174 fn order_by(&self) -> &[PhysicalSortExpr] {
175 &self.order_by
176 }
177
178 fn get_window_frame(&self) -> &Arc<WindowFrame> {
179 &self.window_frame
180 }
181
182 fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
183 self.aggregate.reverse_expr().map(|reverse_expr| {
184 let reverse_window_frame = self.window_frame.reverse();
185 if reverse_window_frame.is_ever_expanding() {
186 Arc::new(PlainAggregateWindowExpr::new(
187 Arc::new(reverse_expr),
188 &self.partition_by.clone(),
189 &self
190 .order_by
191 .iter()
192 .map(|e| e.reverse())
193 .collect::<Vec<_>>(),
194 Arc::new(self.window_frame.reverse()),
195 )) as _
196 } else {
197 Arc::new(SlidingAggregateWindowExpr::new(
198 Arc::new(reverse_expr),
199 &self.partition_by.clone(),
200 &self
201 .order_by
202 .iter()
203 .map(|e| e.reverse())
204 .collect::<Vec<_>>(),
205 Arc::new(self.window_frame.reverse()),
206 )) as _
207 }
208 })
209 }
210
211 fn uses_bounded_memory(&self) -> bool {
212 !self.window_frame.end_bound.is_unbounded()
213 }
214}
215
216impl AggregateWindowExpr for PlainAggregateWindowExpr {
217 fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
218 self.aggregate.create_accumulator()
219 }
220
221 fn get_aggregate_result_inside_range(
227 &self,
228 last_range: &Range<usize>,
229 cur_range: &Range<usize>,
230 value_slice: &[ArrayRef],
231 accumulator: &mut Box<dyn Accumulator>,
232 ) -> Result<ScalarValue> {
233 if cur_range.start == cur_range.end {
234 self.aggregate
235 .default_value(self.aggregate.field().data_type())
236 } else {
237 let update_bound = cur_range.end - last_range.end;
239 if update_bound > 0 {
244 let update: Vec<ArrayRef> = value_slice
245 .iter()
246 .map(|v| v.slice(last_range.end, update_bound))
247 .collect();
248 accumulator.update_batch(&update)?
249 }
250 accumulator.evaluate()
251 }
252 }
253
254 fn is_constant_in_partition(&self) -> bool {
255 self.is_constant_in_partition
256 }
257}