datafusion_physical_expr/window/
window_expr.rs1use std::any::Any;
19use std::fmt::Debug;
20use std::ops::Range;
21use std::sync::Arc;
22
23use crate::PhysicalExpr;
24
25use arrow::array::{new_empty_array, Array, ArrayRef};
26use arrow::compute::kernels::sort::SortColumn;
27use arrow::compute::SortOptions;
28use arrow::datatypes::FieldRef;
29use arrow::record_batch::RecordBatch;
30use datafusion_common::utils::compare_rows;
31use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
32use datafusion_expr::window_state::{
33 PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups,
34};
35use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame, WindowFrameBound};
36use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
37
38use indexmap::IndexMap;
39
40pub trait WindowExpr: Send + Sync + Debug {
66 fn as_any(&self) -> &dyn Any;
69
70 fn field(&self) -> Result<FieldRef>;
72
73 fn name(&self) -> &str {
76 "WindowExpr: default name"
77 }
78
79 fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
83
84 fn evaluate_args(&self, batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
87 self.expressions()
88 .iter()
89 .map(|e| {
90 e.evaluate(batch)
91 .and_then(|v| v.into_array(batch.num_rows()))
92 })
93 .collect()
94 }
95
96 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
98
99 fn evaluate_stateful(
102 &self,
103 _partition_batches: &PartitionBatches,
104 _window_agg_state: &mut PartitionWindowAggStates,
105 ) -> Result<()> {
106 internal_err!("evaluate_stateful is not implemented for {}", self.name())
107 }
108
109 fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>];
111
112 fn order_by(&self) -> &[PhysicalSortExpr];
114
115 fn order_by_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
117 self.order_by()
118 .iter()
119 .map(|e| e.evaluate_to_sort_column(batch))
120 .collect()
121 }
122
123 fn get_window_frame(&self) -> &Arc<WindowFrame>;
125
126 fn uses_bounded_memory(&self) -> bool;
129
130 fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;
132
133 fn all_expressions(&self) -> WindowPhysicalExpressions {
136 let args = self.expressions();
137 let partition_by_exprs = self.partition_by().to_vec();
138 let order_by_exprs = self
139 .order_by()
140 .iter()
141 .map(|sort_expr| Arc::clone(&sort_expr.expr))
142 .collect();
143 WindowPhysicalExpressions {
144 args,
145 partition_by_exprs,
146 order_by_exprs,
147 }
148 }
149
150 fn with_new_expressions(
154 &self,
155 _args: Vec<Arc<dyn PhysicalExpr>>,
156 _partition_bys: Vec<Arc<dyn PhysicalExpr>>,
157 _order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
158 ) -> Option<Arc<dyn WindowExpr>> {
159 None
160 }
161}
162
163pub struct WindowPhysicalExpressions {
165 pub args: Vec<Arc<dyn PhysicalExpr>>,
167 pub partition_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
169 pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
171}
172
173pub trait AggregateWindowExpr: WindowExpr {
175 fn get_accumulator(&self) -> Result<Box<dyn Accumulator>>;
179
180 fn get_aggregate_result_inside_range(
183 &self,
184 last_range: &Range<usize>,
185 cur_range: &Range<usize>,
186 value_slice: &[ArrayRef],
187 accumulator: &mut Box<dyn Accumulator>,
188 ) -> Result<ScalarValue>;
189
190 fn is_constant_in_partition(&self) -> bool;
193
194 fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
196 let mut accumulator = self.get_accumulator()?;
197 let mut last_range = Range { start: 0, end: 0 };
198 let sort_options = self.order_by().iter().map(|o| o.options).collect();
199 let mut window_frame_ctx =
200 WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options);
201 self.get_result_column(
202 &mut accumulator,
203 batch,
204 None,
205 &mut last_range,
206 &mut window_frame_ctx,
207 0,
208 false,
209 )
210 }
211
212 fn aggregate_evaluate_stateful(
215 &self,
216 partition_batches: &PartitionBatches,
217 window_agg_state: &mut PartitionWindowAggStates,
218 ) -> Result<()> {
219 let field = self.field()?;
220 let out_type = field.data_type();
221 for (partition_row, partition_batch_state) in partition_batches.iter() {
222 if !window_agg_state.contains_key(partition_row) {
223 let accumulator = self.get_accumulator()?;
224 window_agg_state.insert(
225 partition_row.clone(),
226 WindowState {
227 state: WindowAggState::new(out_type)?,
228 window_fn: WindowFn::Aggregate(accumulator),
229 },
230 );
231 };
232 let window_state =
233 window_agg_state.get_mut(partition_row).ok_or_else(|| {
234 DataFusionError::Execution("Cannot find state".to_string())
235 })?;
236 let accumulator = match &mut window_state.window_fn {
237 WindowFn::Aggregate(accumulator) => accumulator,
238 _ => unreachable!(),
239 };
240 let state = &mut window_state.state;
241 let record_batch = &partition_batch_state.record_batch;
242 let most_recent_row = partition_batch_state.most_recent_row.as_ref();
243
244 let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| {
246 let sort_options = self.order_by().iter().map(|o| o.options).collect();
247 WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options)
248 });
249 let out_col = self.get_result_column(
250 accumulator,
251 record_batch,
252 most_recent_row,
253 &mut state.window_frame_range,
255 window_frame_ctx,
256 state.last_calculated_index,
257 !partition_batch_state.is_end,
258 )?;
259 state.update(&out_col, partition_batch_state)?;
260 }
261 Ok(())
262 }
263
264 #[allow(clippy::too_many_arguments)]
276 fn get_result_column(
277 &self,
278 accumulator: &mut Box<dyn Accumulator>,
279 record_batch: &RecordBatch,
280 most_recent_row: Option<&RecordBatch>,
281 last_range: &mut Range<usize>,
282 window_frame_ctx: &mut WindowFrameContext,
283 mut idx: usize,
284 not_end: bool,
285 ) -> Result<ArrayRef> {
286 let values = self.evaluate_args(record_batch)?;
287
288 if self.is_constant_in_partition() {
289 if not_end {
290 let field = self.field()?;
291 let out_type = field.data_type();
292 return Ok(new_empty_array(out_type));
293 }
294 accumulator.update_batch(&values)?;
295 let value = accumulator.evaluate()?;
296 return value.to_array_of_size(record_batch.num_rows());
297 }
298 let order_bys = get_orderby_values(self.order_by_columns(record_batch)?);
299 let most_recent_row_order_bys = most_recent_row
300 .map(|batch| self.order_by_columns(batch))
301 .transpose()?
302 .map(get_orderby_values);
303
304 let length = values[0].len();
306 let mut row_wise_results: Vec<ScalarValue> = vec![];
307 let is_causal = self.get_window_frame().is_causal();
308 while idx < length {
309 let cur_range =
311 window_frame_ctx.calculate_range(&order_bys, last_range, length, idx)?;
312 if cur_range.end == length
314 && !is_causal
315 && not_end
316 && !is_end_bound_safe(
317 window_frame_ctx,
318 &order_bys,
319 most_recent_row_order_bys.as_deref(),
320 self.order_by(),
321 idx,
322 )?
323 {
324 break;
325 }
326 let value = self.get_aggregate_result_inside_range(
327 last_range,
328 &cur_range,
329 &values,
330 accumulator,
331 )?;
332 *last_range = cur_range;
334 row_wise_results.push(value);
335 idx += 1;
336 }
337
338 if row_wise_results.is_empty() {
339 let field = self.field()?;
340 let out_type = field.data_type();
341 Ok(new_empty_array(out_type))
342 } else {
343 ScalarValue::iter_to_array(row_wise_results)
344 }
345 }
346}
347
348pub(crate) fn is_end_bound_safe(
366 window_frame_ctx: &WindowFrameContext,
367 order_bys: &[ArrayRef],
368 most_recent_order_bys: Option<&[ArrayRef]>,
369 sort_exprs: &[PhysicalSortExpr],
370 idx: usize,
371) -> Result<bool> {
372 if sort_exprs.is_empty() {
373 return Ok(false);
375 };
376
377 match window_frame_ctx {
378 WindowFrameContext::Rows(window_frame) => {
379 is_end_bound_safe_for_rows(&window_frame.end_bound)
380 }
381 WindowFrameContext::Range { window_frame, .. } => is_end_bound_safe_for_range(
382 &window_frame.end_bound,
383 &order_bys[0],
384 most_recent_order_bys.map(|items| &items[0]),
385 &sort_exprs[0].options,
386 idx,
387 ),
388 WindowFrameContext::Groups {
389 window_frame,
390 state,
391 } => is_end_bound_safe_for_groups(
392 &window_frame.end_bound,
393 state,
394 &order_bys[0],
395 most_recent_order_bys.map(|items| &items[0]),
396 &sort_exprs[0].options,
397 ),
398 }
399}
400
401fn is_end_bound_safe_for_rows(end_bound: &WindowFrameBound) -> Result<bool> {
414 if let WindowFrameBound::Following(value) = end_bound {
415 let zero = ScalarValue::new_zero(&value.data_type());
416 Ok(zero.map(|zero| value.eq(&zero)).unwrap_or(false))
417 } else {
418 Ok(true)
419 }
420}
421
422fn is_end_bound_safe_for_range(
439 end_bound: &WindowFrameBound,
440 orderby_col: &ArrayRef,
441 most_recent_ob_col: Option<&ArrayRef>,
442 sort_options: &SortOptions,
443 idx: usize,
444) -> Result<bool> {
445 match end_bound {
446 WindowFrameBound::Preceding(value) => {
447 let zero = ScalarValue::new_zero(&value.data_type())?;
448 if value.eq(&zero) {
449 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
450 } else {
451 Ok(true)
452 }
453 }
454 WindowFrameBound::CurrentRow => {
455 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
456 }
457 WindowFrameBound::Following(delta) => {
458 let Some(most_recent_ob_col) = most_recent_ob_col else {
459 return Ok(false);
460 };
461 let most_recent_row_value =
462 ScalarValue::try_from_array(most_recent_ob_col, 0)?;
463 let current_row_value = ScalarValue::try_from_array(orderby_col, idx)?;
464
465 if sort_options.descending {
466 current_row_value
467 .sub(delta)
468 .map(|value| value > most_recent_row_value)
469 } else {
470 current_row_value
471 .add(delta)
472 .map(|value| most_recent_row_value > value)
473 }
474 }
475 }
476}
477
478fn is_end_bound_safe_for_groups(
495 end_bound: &WindowFrameBound,
496 state: &WindowFrameStateGroups,
497 orderby_col: &ArrayRef,
498 most_recent_ob_col: Option<&ArrayRef>,
499 sort_options: &SortOptions,
500) -> Result<bool> {
501 match end_bound {
502 WindowFrameBound::Preceding(value) => {
503 let zero = ScalarValue::new_zero(&value.data_type())?;
504 if value.eq(&zero) {
505 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
506 } else {
507 Ok(true)
508 }
509 }
510 WindowFrameBound::CurrentRow => {
511 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
512 }
513 WindowFrameBound::Following(ScalarValue::UInt64(Some(offset))) => {
514 let delta = state.group_end_indices.len() - state.current_group_idx;
515 if delta == (*offset as usize) + 1 {
516 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
517 } else {
518 Ok(false)
519 }
520 }
521 _ => Ok(false),
522 }
523}
524
525fn is_row_ahead(
528 old_col: &ArrayRef,
529 current_col: Option<&ArrayRef>,
530 sort_options: &SortOptions,
531) -> Result<bool> {
532 let Some(current_col) = current_col else {
533 return Ok(false);
534 };
535 if old_col.is_empty() || current_col.is_empty() {
536 return Ok(false);
537 }
538 let last_value = ScalarValue::try_from_array(old_col, old_col.len() - 1)?;
539 let current_value = ScalarValue::try_from_array(current_col, 0)?;
540 let cmp = compare_rows(&[current_value], &[last_value], &[*sort_options])?;
541 Ok(cmp.is_gt())
542}
543
544pub(crate) fn get_orderby_values(order_by_columns: Vec<SortColumn>) -> Vec<ArrayRef> {
546 order_by_columns.into_iter().map(|s| s.values).collect()
547}
548
549#[derive(Debug)]
550pub enum WindowFn {
551 Builtin(Box<dyn PartitionEvaluator>),
552 Aggregate(Box<dyn Accumulator>),
553}
554
555pub type PartitionKey = Vec<ScalarValue>;
560
561#[derive(Debug)]
562pub struct WindowState {
563 pub state: WindowAggState,
564 pub window_fn: WindowFn,
565}
566pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
567
568pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
570
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::window::window_expr::is_row_ahead;

    use arrow::array::{ArrayRef, Float64Array};
    use arrow::compute::SortOptions;
    use datafusion_common::Result;

    /// `is_row_ahead` must report a strictly greater incoming value as ahead
    /// and a tie with the last seen value as not ahead (ascending order).
    #[test]
    fn test_is_row_ahead() -> Result<()> {
        let ascending = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let seen_so_far: ArrayRef =
            Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.]));

        // Strictly greater than the last seen value (10.0).
        let strictly_greater: ArrayRef = Arc::new(Float64Array::from(vec![11.0]));
        // Equal to the last seen value.
        let tie: ArrayRef = Arc::new(Float64Array::from(vec![10.0]));

        let ahead = is_row_ahead(&seen_so_far, Some(&strictly_greater), &ascending)?;
        assert!(ahead);

        let ahead_on_tie = is_row_ahead(&seen_so_far, Some(&tie), &ascending)?;
        assert!(!ahead_on_tie);

        Ok(())
    }
}