1use std::any::Any;
19use std::fmt::Debug;
20use std::ops::Range;
21use std::sync::Arc;
22
23use crate::PhysicalExpr;
24
25use arrow::array::BooleanArray;
26use arrow::array::{Array, ArrayRef, new_empty_array};
27use arrow::compute::SortOptions;
28use arrow::compute::filter as arrow_filter;
29use arrow::compute::kernels::sort::SortColumn;
30use arrow::datatypes::FieldRef;
31use arrow::record_batch::RecordBatch;
32use datafusion_common::cast::as_boolean_array;
33use datafusion_common::utils::compare_rows;
34use datafusion_common::{
35 Result, ScalarValue, arrow_datafusion_err, exec_datafusion_err, internal_err,
36};
37use datafusion_expr::window_state::{
38 PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups,
39};
40use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame, WindowFrameBound};
41use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
42
43use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
44use indexmap::IndexMap;
45
46pub trait WindowExpr: Send + Sync + Debug {
72 fn as_any(&self) -> &dyn Any;
75
76 fn field(&self) -> Result<FieldRef>;
78
79 fn name(&self) -> &str {
82 "WindowExpr: default name"
83 }
84
85 fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
89
90 fn evaluate_args(&self, batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
93 evaluate_expressions_to_arrays(&self.expressions(), batch)
94 }
95
96 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
98
99 fn evaluate_stateful(
102 &self,
103 _partition_batches: &PartitionBatches,
104 _window_agg_state: &mut PartitionWindowAggStates,
105 ) -> Result<()> {
106 internal_err!("evaluate_stateful is not implemented for {}", self.name())
107 }
108
109 fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>];
111
112 fn order_by(&self) -> &[PhysicalSortExpr];
114
115 fn order_by_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
117 self.order_by()
118 .iter()
119 .map(|e| e.evaluate_to_sort_column(batch))
120 .collect()
121 }
122
123 fn get_window_frame(&self) -> &Arc<WindowFrame>;
125
126 fn uses_bounded_memory(&self) -> bool;
129
130 fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;
132
133 fn create_window_fn(&self) -> Result<WindowFn>;
138
139 fn all_expressions(&self) -> WindowPhysicalExpressions {
142 let args = self.expressions();
143 let partition_by_exprs = self.partition_by().to_vec();
144 let order_by_exprs = self
145 .order_by()
146 .iter()
147 .map(|sort_expr| Arc::clone(&sort_expr.expr))
148 .collect();
149 WindowPhysicalExpressions {
150 args,
151 partition_by_exprs,
152 order_by_exprs,
153 }
154 }
155
156 fn with_new_expressions(
160 &self,
161 _args: Vec<Arc<dyn PhysicalExpr>>,
162 _partition_bys: Vec<Arc<dyn PhysicalExpr>>,
163 _order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
164 ) -> Option<Arc<dyn WindowExpr>> {
165 None
166 }
167}
168
169pub struct WindowPhysicalExpressions {
171 pub args: Vec<Arc<dyn PhysicalExpr>>,
173 pub partition_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
175 pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
177}
178
179pub trait AggregateWindowExpr: WindowExpr {
181 fn get_accumulator(&self) -> Result<Box<dyn Accumulator>>;
185
186 fn filter_expr(&self) -> Option<&Arc<dyn PhysicalExpr>>;
188
189 fn get_aggregate_result_inside_range(
192 &self,
193 last_range: &Range<usize>,
194 cur_range: &Range<usize>,
195 value_slice: &[ArrayRef],
196 accumulator: &mut Box<dyn Accumulator>,
197 filter_mask: Option<&BooleanArray>,
198 ) -> Result<ScalarValue>;
199
200 fn is_constant_in_partition(&self) -> bool;
203
204 fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
206 let mut accumulator = self.get_accumulator()?;
207 let mut last_range = Range { start: 0, end: 0 };
208 let sort_options = self.order_by().iter().map(|o| o.options).collect();
209 let mut window_frame_ctx =
210 WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options);
211 self.get_result_column(
212 &mut accumulator,
213 batch,
214 None,
215 &mut last_range,
216 &mut window_frame_ctx,
217 0,
218 false,
219 )
220 }
221
222 fn aggregate_evaluate_stateful(
225 &self,
226 partition_batches: &PartitionBatches,
227 window_agg_state: &mut PartitionWindowAggStates,
228 ) -> Result<()> {
229 let field = self.field()?;
230 let out_type = field.data_type();
231 for (partition_row, partition_batch_state) in partition_batches.iter() {
232 if !window_agg_state.contains_key(partition_row) {
233 let accumulator = self.get_accumulator()?;
234 window_agg_state.insert(
235 partition_row.clone(),
236 WindowState {
237 state: WindowAggState::new(out_type)?,
238 window_fn: WindowFn::Aggregate(accumulator),
239 },
240 );
241 };
242 let window_state = window_agg_state
243 .get_mut(partition_row)
244 .ok_or_else(|| exec_datafusion_err!("Cannot find state"))?;
245 let accumulator = match &mut window_state.window_fn {
246 WindowFn::Aggregate(accumulator) => accumulator,
247 _ => unreachable!(),
248 };
249 let state = &mut window_state.state;
250 let record_batch = &partition_batch_state.record_batch;
251 let most_recent_row = partition_batch_state.most_recent_row.as_ref();
252
253 let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| {
255 let sort_options = self.order_by().iter().map(|o| o.options).collect();
256 WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options)
257 });
258 let out_col = self.get_result_column(
259 accumulator,
260 record_batch,
261 most_recent_row,
262 &mut state.window_frame_range,
264 window_frame_ctx,
265 state.last_calculated_index,
266 !partition_batch_state.is_end,
267 )?;
268 state.update(&out_col, partition_batch_state)?;
269 }
270 Ok(())
271 }
272
273 #[expect(clippy::too_many_arguments)]
285 fn get_result_column(
286 &self,
287 accumulator: &mut Box<dyn Accumulator>,
288 record_batch: &RecordBatch,
289 most_recent_row: Option<&RecordBatch>,
290 last_range: &mut Range<usize>,
291 window_frame_ctx: &mut WindowFrameContext,
292 mut idx: usize,
293 not_end: bool,
294 ) -> Result<ArrayRef> {
295 let values = self.evaluate_args(record_batch)?;
296
297 let filter_mask_arr: Option<ArrayRef> = match self.filter_expr() {
299 Some(expr) => {
300 let value = expr.evaluate(record_batch)?;
301 Some(value.into_array(record_batch.num_rows())?)
302 }
303 None => None,
304 };
305
306 let filter_mask: Option<&BooleanArray> = match filter_mask_arr.as_deref() {
308 Some(arr) => Some(as_boolean_array(arr)?),
309 None => None,
310 };
311
312 if self.is_constant_in_partition() {
313 if not_end {
314 let field = self.field()?;
315 let out_type = field.data_type();
316 return Ok(new_empty_array(out_type));
317 }
318 let values = if let Some(mask) = filter_mask {
319 filter_arrays(&values, mask)?
321 } else {
322 values
323 };
324 accumulator.update_batch(&values)?;
325 let value = accumulator.evaluate()?;
326 return value.to_array_of_size(record_batch.num_rows());
327 }
328 let order_bys = get_orderby_values(self.order_by_columns(record_batch)?);
329 let most_recent_row_order_bys = most_recent_row
330 .map(|batch| self.order_by_columns(batch))
331 .transpose()?
332 .map(get_orderby_values);
333
334 let length = values[0].len();
336 let mut row_wise_results: Vec<ScalarValue> = vec![];
337 let is_causal = self.get_window_frame().is_causal();
338 while idx < length {
339 let cur_range =
341 window_frame_ctx.calculate_range(&order_bys, last_range, length, idx)?;
342 if cur_range.end == length
344 && !is_causal
345 && not_end
346 && !is_end_bound_safe(
347 window_frame_ctx,
348 &order_bys,
349 most_recent_row_order_bys.as_deref(),
350 self.order_by(),
351 idx,
352 )?
353 {
354 break;
355 }
356 let value = self.get_aggregate_result_inside_range(
357 last_range,
358 &cur_range,
359 &values,
360 accumulator,
361 filter_mask,
362 )?;
363 *last_range = cur_range;
365 row_wise_results.push(value);
366 idx += 1;
367 }
368
369 if row_wise_results.is_empty() {
370 let field = self.field()?;
371 let out_type = field.data_type();
372 Ok(new_empty_array(out_type))
373 } else {
374 ScalarValue::iter_to_array(row_wise_results)
375 }
376 }
377}
378
379pub(crate) fn filter_array(array: &ArrayRef, mask: &BooleanArray) -> Result<ArrayRef> {
381 arrow_filter(array.as_ref(), mask)
382 .map(|a| a as ArrayRef)
383 .map_err(|e| arrow_datafusion_err!(e))
384}
385
386pub(crate) fn filter_arrays(
388 arrays: &[ArrayRef],
389 mask: &BooleanArray,
390) -> Result<Vec<ArrayRef>> {
391 arrays.iter().map(|arr| filter_array(arr, mask)).collect()
392}
393
394pub(crate) fn is_end_bound_safe(
412 window_frame_ctx: &WindowFrameContext,
413 order_bys: &[ArrayRef],
414 most_recent_order_bys: Option<&[ArrayRef]>,
415 sort_exprs: &[PhysicalSortExpr],
416 idx: usize,
417) -> Result<bool> {
418 if sort_exprs.is_empty() {
419 return Ok(false);
421 };
422
423 match window_frame_ctx {
424 WindowFrameContext::Rows(window_frame) => {
425 is_end_bound_safe_for_rows(&window_frame.end_bound)
426 }
427 WindowFrameContext::Range { window_frame, .. } => is_end_bound_safe_for_range(
428 &window_frame.end_bound,
429 &order_bys[0],
430 most_recent_order_bys.map(|items| &items[0]),
431 &sort_exprs[0].options,
432 idx,
433 ),
434 WindowFrameContext::Groups {
435 window_frame,
436 state,
437 } => is_end_bound_safe_for_groups(
438 &window_frame.end_bound,
439 state,
440 &order_bys[0],
441 most_recent_order_bys.map(|items| &items[0]),
442 &sort_exprs[0].options,
443 ),
444 }
445}
446
447fn is_end_bound_safe_for_rows(end_bound: &WindowFrameBound) -> Result<bool> {
460 if let WindowFrameBound::Following(value) = end_bound {
461 let zero = ScalarValue::new_zero(&value.data_type());
462 Ok(zero.map(|zero| value.eq(&zero)).unwrap_or(false))
463 } else {
464 Ok(true)
465 }
466}
467
468fn is_end_bound_safe_for_range(
485 end_bound: &WindowFrameBound,
486 orderby_col: &ArrayRef,
487 most_recent_ob_col: Option<&ArrayRef>,
488 sort_options: &SortOptions,
489 idx: usize,
490) -> Result<bool> {
491 match end_bound {
492 WindowFrameBound::Preceding(value) => {
493 let zero = ScalarValue::new_zero(&value.data_type())?;
494 if value.eq(&zero) {
495 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
496 } else {
497 Ok(true)
498 }
499 }
500 WindowFrameBound::CurrentRow => {
501 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
502 }
503 WindowFrameBound::Following(delta) => {
504 let Some(most_recent_ob_col) = most_recent_ob_col else {
505 return Ok(false);
506 };
507 let most_recent_row_value =
508 ScalarValue::try_from_array(most_recent_ob_col, 0)?;
509 let current_row_value = ScalarValue::try_from_array(orderby_col, idx)?;
510
511 if sort_options.descending {
512 current_row_value
513 .sub(delta)
514 .map(|value| value > most_recent_row_value)
515 } else {
516 current_row_value
517 .add(delta)
518 .map(|value| most_recent_row_value > value)
519 }
520 }
521 }
522}
523
524fn is_end_bound_safe_for_groups(
541 end_bound: &WindowFrameBound,
542 state: &WindowFrameStateGroups,
543 orderby_col: &ArrayRef,
544 most_recent_ob_col: Option<&ArrayRef>,
545 sort_options: &SortOptions,
546) -> Result<bool> {
547 match end_bound {
548 WindowFrameBound::Preceding(value) => {
549 let zero = ScalarValue::new_zero(&value.data_type())?;
550 if value.eq(&zero) {
551 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
552 } else {
553 Ok(true)
554 }
555 }
556 WindowFrameBound::CurrentRow => {
557 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
558 }
559 WindowFrameBound::Following(ScalarValue::UInt64(Some(offset))) => {
560 let delta = state.group_end_indices.len() - state.current_group_idx;
561 if delta == (*offset as usize) + 1 {
562 is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
563 } else {
564 Ok(false)
565 }
566 }
567 _ => Ok(false),
568 }
569}
570
571fn is_row_ahead(
574 old_col: &ArrayRef,
575 current_col: Option<&ArrayRef>,
576 sort_options: &SortOptions,
577) -> Result<bool> {
578 let Some(current_col) = current_col else {
579 return Ok(false);
580 };
581 if old_col.is_empty() || current_col.is_empty() {
582 return Ok(false);
583 }
584 let last_value = ScalarValue::try_from_array(old_col, old_col.len() - 1)?;
585 let current_value = ScalarValue::try_from_array(current_col, 0)?;
586 let cmp = compare_rows(&[current_value], &[last_value], &[*sort_options])?;
587 Ok(cmp.is_gt())
588}
589
590pub(crate) fn get_orderby_values(order_by_columns: Vec<SortColumn>) -> Vec<ArrayRef> {
592 order_by_columns.into_iter().map(|s| s.values).collect()
593}
594
/// The evaluator stored inside a [`WindowState`]: either a partition
/// evaluator or an aggregate accumulator.
#[derive(Debug)]
pub enum WindowFn {
    /// Evaluation is performed by a [`PartitionEvaluator`].
    Builtin(Box<dyn PartitionEvaluator>),
    /// Evaluation is performed by an [`Accumulator`]; this is the variant
    /// created by `AggregateWindowExpr::aggregate_evaluate_stateful`.
    Aggregate(Box<dyn Accumulator>),
}
600
/// Key that identifies a partition in [`PartitionBatches`] and
/// [`PartitionWindowAggStates`].
// NOTE(review): presumably one scalar per PARTITION BY expression — confirm
// against the code that builds these keys.
pub type PartitionKey = Vec<ScalarValue>;
606
/// Per-partition state kept for one window expression between stateful
/// evaluations.
#[derive(Debug)]
pub struct WindowState {
    /// Evaluation progress for the partition (frame context, last frame
    /// range, last calculated index, ...).
    pub state: WindowAggState,
    /// The evaluator that produces the window function's values.
    pub window_fn: WindowFn,
}
/// Insertion-ordered map from a partition's key to its [`WindowState`].
pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
613
/// Insertion-ordered map from a partition's key to the [`PartitionBatchState`]
/// holding that partition's buffered input.
pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
616
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::window::window_expr::is_row_ahead;

    use arrow::array::{ArrayRef, Float64Array};
    use arrow::compute::SortOptions;
    use datafusion_common::Result;

    /// `is_row_ahead` must only report `true` when the incoming value sorts
    /// strictly after the last value that was already seen.
    #[test]
    fn test_is_row_ahead() -> Result<()> {
        let ascending_nulls_last = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let seen: ArrayRef = Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.]));
        let strictly_after: ArrayRef = Arc::new(Float64Array::from(vec![11.0]));
        let tie_with_last: ArrayRef = Arc::new(Float64Array::from(vec![10.0]));

        // 11.0 comes after the last seen value (10.0).
        let ahead = is_row_ahead(&seen, Some(&strictly_after), &ascending_nulls_last)?;
        assert!(ahead);

        // A value equal to the last seen one is not ahead.
        let ahead = is_row_ahead(&seen, Some(&tie_with_last), &ascending_nulls_last)?;
        assert!(!ahead);

        Ok(())
    }
}