datafusion_physical_expr/window/
aggregate.rs1use std::any::Any;
21use std::ops::Range;
22use std::sync::Arc;
23
24use crate::aggregate::AggregateFunctionExpr;
25use crate::window::standard::add_new_ordering_expr_with_partition_by;
26use crate::window::window_expr::AggregateWindowExpr;
27use crate::window::{
28 PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr,
29};
30use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr};
31
32use arrow::array::Array;
33use arrow::array::ArrayRef;
34use arrow::datatypes::FieldRef;
35use arrow::record_batch::RecordBatch;
36use datafusion_common::{DataFusionError, Result, ScalarValue};
37use datafusion_expr::{Accumulator, WindowFrame};
38use datafusion_physical_expr_common::sort_expr::LexOrdering;
39
/// Window expression that evaluates an aggregate function over a window frame
/// by only ever *adding* newly entered rows to its accumulator; rows are never
/// retracted from the accumulator.
///
/// `get_reverse_expr` produces this type when the reversed frame is
/// ever-expanding, and a `SlidingAggregateWindowExpr` otherwise.
#[derive(Debug)]
pub struct PlainAggregateWindowExpr {
    /// Aggregate function evaluated over each row's window frame.
    aggregate: Arc<AggregateFunctionExpr>,
    /// PARTITION BY expressions of the window.
    partition_by: Vec<Arc<dyn PhysicalExpr>>,
    /// ORDER BY ordering of the window.
    order_by: LexOrdering,
    /// Frame specification; its `start_bound` / `end_bound` are consulted when
    /// evaluating statefully and when reporting memory boundedness.
    window_frame: Arc<WindowFrame>,
}
50
51impl PlainAggregateWindowExpr {
52 pub fn new(
54 aggregate: Arc<AggregateFunctionExpr>,
55 partition_by: &[Arc<dyn PhysicalExpr>],
56 order_by: &LexOrdering,
57 window_frame: Arc<WindowFrame>,
58 ) -> Self {
59 Self {
60 aggregate,
61 partition_by: partition_by.to_vec(),
62 order_by: order_by.clone(),
63 window_frame,
64 }
65 }
66
67 pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr {
69 &self.aggregate
70 }
71
72 pub fn add_equal_orderings(
73 &self,
74 eq_properties: &mut EquivalenceProperties,
75 window_expr_index: usize,
76 ) {
77 if let Some(expr) = self
78 .get_aggregate_expr()
79 .get_result_ordering(window_expr_index)
80 {
81 add_new_ordering_expr_with_partition_by(
82 eq_properties,
83 expr,
84 &self.partition_by,
85 );
86 }
87 }
88}
89
90impl WindowExpr for PlainAggregateWindowExpr {
94 fn as_any(&self) -> &dyn Any {
96 self
97 }
98
99 fn field(&self) -> Result<FieldRef> {
100 Ok(self.aggregate.field())
101 }
102
103 fn name(&self) -> &str {
104 self.aggregate.name()
105 }
106
107 fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
108 self.aggregate.expressions()
109 }
110
111 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
112 self.aggregate_evaluate(batch)
113 }
114
115 fn evaluate_stateful(
116 &self,
117 partition_batches: &PartitionBatches,
118 window_agg_state: &mut PartitionWindowAggStates,
119 ) -> Result<()> {
120 self.aggregate_evaluate_stateful(partition_batches, window_agg_state)?;
121
122 for partition_row in partition_batches.keys() {
128 let window_state =
129 window_agg_state.get_mut(partition_row).ok_or_else(|| {
130 DataFusionError::Execution("Cannot find state".to_string())
131 })?;
132 let state = &mut window_state.state;
133 if self.window_frame.start_bound.is_unbounded() {
134 state.window_frame_range.start =
135 state.window_frame_range.end.saturating_sub(1);
136 }
137 }
138 Ok(())
139 }
140
141 fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
142 &self.partition_by
143 }
144
145 fn order_by(&self) -> &LexOrdering {
146 self.order_by.as_ref()
147 }
148
149 fn get_window_frame(&self) -> &Arc<WindowFrame> {
150 &self.window_frame
151 }
152
153 fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
154 self.aggregate.reverse_expr().map(|reverse_expr| {
155 let reverse_window_frame = self.window_frame.reverse();
156 if reverse_window_frame.is_ever_expanding() {
157 Arc::new(PlainAggregateWindowExpr::new(
158 Arc::new(reverse_expr),
159 &self.partition_by.clone(),
160 reverse_order_bys(self.order_by.as_ref()).as_ref(),
161 Arc::new(self.window_frame.reverse()),
162 )) as _
163 } else {
164 Arc::new(SlidingAggregateWindowExpr::new(
165 Arc::new(reverse_expr),
166 &self.partition_by.clone(),
167 reverse_order_bys(self.order_by.as_ref()).as_ref(),
168 Arc::new(self.window_frame.reverse()),
169 )) as _
170 }
171 })
172 }
173
174 fn uses_bounded_memory(&self) -> bool {
175 !self.window_frame.end_bound.is_unbounded()
176 }
177}
178
179impl AggregateWindowExpr for PlainAggregateWindowExpr {
180 fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
181 self.aggregate.create_accumulator()
182 }
183
184 fn get_aggregate_result_inside_range(
190 &self,
191 last_range: &Range<usize>,
192 cur_range: &Range<usize>,
193 value_slice: &[ArrayRef],
194 accumulator: &mut Box<dyn Accumulator>,
195 ) -> Result<ScalarValue> {
196 if cur_range.start == cur_range.end {
197 self.aggregate
198 .default_value(self.aggregate.field().data_type())
199 } else {
200 let update_bound = cur_range.end - last_range.end;
202 if update_bound > 0 {
207 let update: Vec<ArrayRef> = value_slice
208 .iter()
209 .map(|v| v.slice(last_range.end, update_bound))
210 .collect();
211 accumulator.update_batch(&update)?
212 }
213 accumulator.evaluate()
214 }
215 }
216}