datafusion_physical_expr/window/
window_expr.rs1use std::any::Any;
19use std::fmt::Debug;
20use std::ops::Range;
21use std::sync::Arc;
22
23use crate::PhysicalExpr;
24
25use arrow::array::BooleanArray;
26use arrow::array::{new_empty_array, Array, ArrayRef};
27use arrow::compute::filter as arrow_filter;
28use arrow::compute::kernels::sort::SortColumn;
29use arrow::compute::SortOptions;
30use arrow::datatypes::FieldRef;
31use arrow::record_batch::RecordBatch;
32use datafusion_common::cast::as_boolean_array;
33use datafusion_common::utils::compare_rows;
34use datafusion_common::{
35 arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
36};
37use datafusion_expr::window_state::{
38 PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups,
39};
40use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame, WindowFrameBound};
41use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
42
43use indexmap::IndexMap;
44
45pub trait WindowExpr: Send + Sync + Debug {
71 fn as_any(&self) -> &dyn Any;
74
75 fn field(&self) -> Result<FieldRef>;
77
78 fn name(&self) -> &str {
81 "WindowExpr: default name"
82 }
83
84 fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
88
89 fn evaluate_args(&self, batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
92 self.expressions()
93 .iter()
94 .map(|e| {
95 e.evaluate(batch)
96 .and_then(|v| v.into_array(batch.num_rows()))
97 })
98 .collect()
99 }
100
101 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
103
104 fn evaluate_stateful(
107 &self,
108 _partition_batches: &PartitionBatches,
109 _window_agg_state: &mut PartitionWindowAggStates,
110 ) -> Result<()> {
111 internal_err!("evaluate_stateful is not implemented for {}", self.name())
112 }
113
114 fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>];
116
117 fn order_by(&self) -> &[PhysicalSortExpr];
119
120 fn order_by_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
122 self.order_by()
123 .iter()
124 .map(|e| e.evaluate_to_sort_column(batch))
125 .collect()
126 }
127
128 fn get_window_frame(&self) -> &Arc<WindowFrame>;
130
131 fn uses_bounded_memory(&self) -> bool;
134
135 fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;
137
138 fn create_window_fn(&self) -> Result<WindowFn>;
143
144 fn all_expressions(&self) -> WindowPhysicalExpressions {
147 let args = self.expressions();
148 let partition_by_exprs = self.partition_by().to_vec();
149 let order_by_exprs = self
150 .order_by()
151 .iter()
152 .map(|sort_expr| Arc::clone(&sort_expr.expr))
153 .collect();
154 WindowPhysicalExpressions {
155 args,
156 partition_by_exprs,
157 order_by_exprs,
158 }
159 }
160
161 fn with_new_expressions(
165 &self,
166 _args: Vec<Arc<dyn PhysicalExpr>>,
167 _partition_bys: Vec<Arc<dyn PhysicalExpr>>,
168 _order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
169 ) -> Option<Arc<dyn WindowExpr>> {
170 None
171 }
172}
173
174pub struct WindowPhysicalExpressions {
176 pub args: Vec<Arc<dyn PhysicalExpr>>,
178 pub partition_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
180 pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
182}
183
184pub trait AggregateWindowExpr: WindowExpr {
186 fn get_accumulator(&self) -> Result<Box<dyn Accumulator>>;
190
191 fn filter_expr(&self) -> Option<&Arc<dyn PhysicalExpr>>;
193
194 fn get_aggregate_result_inside_range(
197 &self,
198 last_range: &Range<usize>,
199 cur_range: &Range<usize>,
200 value_slice: &[ArrayRef],
201 accumulator: &mut Box<dyn Accumulator>,
202 filter_mask: Option<&BooleanArray>,
203 ) -> Result<ScalarValue>;
204
205 fn is_constant_in_partition(&self) -> bool;
208
209 fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
211 let mut accumulator = self.get_accumulator()?;
212 let mut last_range = Range { start: 0, end: 0 };
213 let sort_options = self.order_by().iter().map(|o| o.options).collect();
214 let mut window_frame_ctx =
215 WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options);
216 self.get_result_column(
217 &mut accumulator,
218 batch,
219 None,
220 &mut last_range,
221 &mut window_frame_ctx,
222 0,
223 false,
224 )
225 }
226
227 fn aggregate_evaluate_stateful(
230 &self,
231 partition_batches: &PartitionBatches,
232 window_agg_state: &mut PartitionWindowAggStates,
233 ) -> Result<()> {
234 let field = self.field()?;
235 let out_type = field.data_type();
236 for (partition_row, partition_batch_state) in partition_batches.iter() {
237 if !window_agg_state.contains_key(partition_row) {
238 let accumulator = self.get_accumulator()?;
239 window_agg_state.insert(
240 partition_row.clone(),
241 WindowState {
242 state: WindowAggState::new(out_type)?,
243 window_fn: WindowFn::Aggregate(accumulator),
244 },
245 );
246 };
247 let window_state =
248 window_agg_state.get_mut(partition_row).ok_or_else(|| {
249 DataFusionError::Execution("Cannot find state".to_string())
250 })?;
251 let accumulator = match &mut window_state.window_fn {
252 WindowFn::Aggregate(accumulator) => accumulator,
253 _ => unreachable!(),
254 };
255 let state = &mut window_state.state;
256 let record_batch = &partition_batch_state.record_batch;
257 let most_recent_row = partition_batch_state.most_recent_row.as_ref();
258
259 let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| {
261 let sort_options = self.order_by().iter().map(|o| o.options).collect();
262 WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options)
263 });
264 let out_col = self.get_result_column(
265 accumulator,
266 record_batch,
267 most_recent_row,
268 &mut state.window_frame_range,
270 window_frame_ctx,
271 state.last_calculated_index,
272 !partition_batch_state.is_end,
273 )?;
274 state.update(&out_col, partition_batch_state)?;
275 }
276 Ok(())
277 }
278
279 #[allow(clippy::too_many_arguments)]
291 fn get_result_column(
292 &self,
293 accumulator: &mut Box<dyn Accumulator>,
294 record_batch: &RecordBatch,
295 most_recent_row: Option<&RecordBatch>,
296 last_range: &mut Range<usize>,
297 window_frame_ctx: &mut WindowFrameContext,
298 mut idx: usize,
299 not_end: bool,
300 ) -> Result<ArrayRef> {
301 let values = self.evaluate_args(record_batch)?;
302
303 let filter_mask_arr: Option<ArrayRef> = match self.filter_expr() {
305 Some(expr) => {
306 let value = expr.evaluate(record_batch)?;
307 Some(value.into_array(record_batch.num_rows())?)
308 }
309 None => None,
310 };
311
312 let filter_mask: Option<&BooleanArray> = match filter_mask_arr.as_deref() {
314 Some(arr) => Some(as_boolean_array(arr)?),
315 None => None,
316 };
317
318 if self.is_constant_in_partition() {
319 if not_end {
320 let field = self.field()?;
321 let out_type = field.data_type();
322 return Ok(new_empty_array(out_type));
323 }
324 let values = if let Some(mask) = filter_mask {
325 filter_arrays(&values, mask)?
327 } else {
328 values
329 };
330 accumulator.update_batch(&values)?;
331 let value = accumulator.evaluate()?;
332 return value.to_array_of_size(record_batch.num_rows());
333 }
334 let order_bys = get_orderby_values(self.order_by_columns(record_batch)?);
335 let most_recent_row_order_bys = most_recent_row
336 .map(|batch| self.order_by_columns(batch))
337 .transpose()?
338 .map(get_orderby_values);
339
340 let length = values[0].len();
342 let mut row_wise_results: Vec<ScalarValue> = vec![];
343 let is_causal = self.get_window_frame().is_causal();
344 while idx < length {
345 let cur_range =
347 window_frame_ctx.calculate_range(&order_bys, last_range, length, idx)?;
348 if cur_range.end == length
350 && !is_causal
351 && not_end
352 && !is_end_bound_safe(
353 window_frame_ctx,
354 &order_bys,
355 most_recent_row_order_bys.as_deref(),
356 self.order_by(),
357 idx,
358 )?
359 {
360 break;
361 }
362 let value = self.get_aggregate_result_inside_range(
363 last_range,
364 &cur_range,
365 &values,
366 accumulator,
367 filter_mask,
368 )?;
369 *last_range = cur_range;
371 row_wise_results.push(value);
372 idx += 1;
373 }
374
375 if row_wise_results.is_empty() {
376 let field = self.field()?;
377 let out_type = field.data_type();
378 Ok(new_empty_array(out_type))
379 } else {
380 ScalarValue::iter_to_array(row_wise_results)
381 }
382 }
383}
384
385pub(crate) fn filter_array(array: &ArrayRef, mask: &BooleanArray) -> Result<ArrayRef> {
387 arrow_filter(array.as_ref(), mask)
388 .map(|a| a as ArrayRef)
389 .map_err(|e| arrow_datafusion_err!(e))
390}
391
392pub(crate) fn filter_arrays(
394 arrays: &[ArrayRef],
395 mask: &BooleanArray,
396) -> Result<Vec<ArrayRef>> {
397 arrays.iter().map(|arr| filter_array(arr, mask)).collect()
398}
399
400pub(crate) fn is_end_bound_safe(
418 window_frame_ctx: &WindowFrameContext,
419 order_bys: &[ArrayRef],
420 most_recent_order_bys: Option<&[ArrayRef]>,
421 sort_exprs: &[PhysicalSortExpr],
422 idx: usize,
423) -> Result<bool> {
424 if sort_exprs.is_empty() {
425 return Ok(false);
427 };
428
429 match window_frame_ctx {
430 WindowFrameContext::Rows(window_frame) => {
431 is_end_bound_safe_for_rows(&window_frame.end_bound)
432 }
433 WindowFrameContext::Range { window_frame, .. } => is_end_bound_safe_for_range(
434 &window_frame.end_bound,
435 &order_bys[0],
436 most_recent_order_bys.map(|items| &items[0]),
437 &sort_exprs[0].options,
438 idx,
439 ),
440 WindowFrameContext::Groups {
441 window_frame,
442 state,
443 } => is_end_bound_safe_for_groups(
444 &window_frame.end_bound,
445 state,
446 &order_bys[0],
447 most_recent_order_bys.map(|items| &items[0]),
448 &sort_exprs[0].options,
449 ),
450 }
451}
452
453fn is_end_bound_safe_for_rows(end_bound: &WindowFrameBound) -> Result<bool> {
466 if let WindowFrameBound::Following(value) = end_bound {
467 let zero = ScalarValue::new_zero(&value.data_type());
468 Ok(zero.map(|zero| value.eq(&zero)).unwrap_or(false))
469 } else {
470 Ok(true)
471 }
472}
473
474fn is_end_bound_safe_for_range(
491 end_bound: &WindowFrameBound,
492 orderby_col: &ArrayRef,
493 most_recent_ob_col: Option<&ArrayRef>,
494 sort_options: &SortOptions,
495 idx: usize,
496) -> Result<bool> {
497 match end_bound {
498 WindowFrameBound::Preceding(value) => {
499 let zero = ScalarValue::new_zero(&value.data_type())?;
500 if value.eq(&zero) {
501 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
502 } else {
503 Ok(true)
504 }
505 }
506 WindowFrameBound::CurrentRow => {
507 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
508 }
509 WindowFrameBound::Following(delta) => {
510 let Some(most_recent_ob_col) = most_recent_ob_col else {
511 return Ok(false);
512 };
513 let most_recent_row_value =
514 ScalarValue::try_from_array(most_recent_ob_col, 0)?;
515 let current_row_value = ScalarValue::try_from_array(orderby_col, idx)?;
516
517 if sort_options.descending {
518 current_row_value
519 .sub(delta)
520 .map(|value| value > most_recent_row_value)
521 } else {
522 current_row_value
523 .add(delta)
524 .map(|value| most_recent_row_value > value)
525 }
526 }
527 }
528}
529
530fn is_end_bound_safe_for_groups(
547 end_bound: &WindowFrameBound,
548 state: &WindowFrameStateGroups,
549 orderby_col: &ArrayRef,
550 most_recent_ob_col: Option<&ArrayRef>,
551 sort_options: &SortOptions,
552) -> Result<bool> {
553 match end_bound {
554 WindowFrameBound::Preceding(value) => {
555 let zero = ScalarValue::new_zero(&value.data_type())?;
556 if value.eq(&zero) {
557 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
558 } else {
559 Ok(true)
560 }
561 }
562 WindowFrameBound::CurrentRow => {
563 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
564 }
565 WindowFrameBound::Following(ScalarValue::UInt64(Some(offset))) => {
566 let delta = state.group_end_indices.len() - state.current_group_idx;
567 if delta == (*offset as usize) + 1 {
568 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
569 } else {
570 Ok(false)
571 }
572 }
573 _ => Ok(false),
574 }
575}
576
577fn is_row_ahead(
580 old_col: &ArrayRef,
581 current_col: Option<&ArrayRef>,
582 sort_options: &SortOptions,
583) -> Result<bool> {
584 let Some(current_col) = current_col else {
585 return Ok(false);
586 };
587 if old_col.is_empty() || current_col.is_empty() {
588 return Ok(false);
589 }
590 let last_value = ScalarValue::try_from_array(old_col, old_col.len() - 1)?;
591 let current_value = ScalarValue::try_from_array(current_col, 0)?;
592 let cmp = compare_rows(&[current_value], &[last_value], &[*sort_options])?;
593 Ok(cmp.is_gt())
594}
595
596pub(crate) fn get_orderby_values(order_by_columns: Vec<SortColumn>) -> Vec<ArrayRef> {
598 order_by_columns.into_iter().map(|s| s.values).collect()
599}
600
601#[derive(Debug)]
602pub enum WindowFn {
603 Builtin(Box<dyn PartitionEvaluator>),
604 Aggregate(Box<dyn Accumulator>),
605}
606
607pub type PartitionKey = Vec<ScalarValue>;
612
613#[derive(Debug)]
614pub struct WindowState {
615 pub state: WindowAggState,
616 pub window_fn: WindowFn,
617}
618pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
619
620pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
622
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::window::window_expr::is_row_ahead;

    use arrow::array::{ArrayRef, Float64Array};
    use arrow::compute::SortOptions;
    use datafusion_common::Result;

    /// Ascending order with nulls last, shared by all cases below.
    const ASC_NULLS_LAST: SortOptions = SortOptions {
        descending: false,
        nulls_first: false,
    };

    /// Ascending ordering column whose last value is `10.0`.
    fn old_values() -> ArrayRef {
        Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.]))
    }

    /// A new row is ahead only when it sorts strictly after the last old row.
    #[test]
    fn test_is_row_ahead() -> Result<()> {
        let old_values = old_values();

        let new_values1: ArrayRef = Arc::new(Float64Array::from(vec![11.0]));
        let new_values2: ArrayRef = Arc::new(Float64Array::from(vec![10.0]));

        assert!(is_row_ahead(
            &old_values,
            Some(&new_values1),
            &ASC_NULLS_LAST
        )?);
        // An equal value is not strictly ahead.
        assert!(!is_row_ahead(
            &old_values,
            Some(&new_values2),
            &ASC_NULLS_LAST
        )?);

        Ok(())
    }

    /// Without a most recent row nothing can be ahead.
    #[test]
    fn test_is_row_ahead_without_current_column() -> Result<()> {
        assert!(!is_row_ahead(&old_values(), None, &ASC_NULLS_LAST)?);

        Ok(())
    }

    /// An empty column on either side yields `false` instead of an error or a
    /// panic.
    #[test]
    fn test_is_row_ahead_with_empty_columns() -> Result<()> {
        let empty: ArrayRef = Arc::new(Float64Array::from(Vec::<f64>::new()));
        let new_values: ArrayRef = Arc::new(Float64Array::from(vec![11.0]));

        assert!(!is_row_ahead(&empty, Some(&new_values), &ASC_NULLS_LAST)?);
        assert!(!is_row_ahead(&old_values(), Some(&empty), &ASC_NULLS_LAST)?);

        Ok(())
    }
}