datafusion_physical_expr/window/
window_expr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::Debug;
20use std::ops::Range;
21use std::sync::Arc;
22
23use crate::PhysicalExpr;
24
25use arrow::array::{new_empty_array, Array, ArrayRef};
26use arrow::compute::kernels::sort::SortColumn;
27use arrow::compute::SortOptions;
28use arrow::datatypes::FieldRef;
29use arrow::record_batch::RecordBatch;
30use datafusion_common::utils::compare_rows;
31use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
32use datafusion_expr::window_state::{
33    PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups,
34};
35use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame, WindowFrameBound};
36use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
37
38use indexmap::IndexMap;
39
40/// Common trait for [window function] implementations
41///
42/// # Aggregate Window Expressions
43///
44/// These expressions take the form
45///
46/// ```text
47/// OVER({ROWS | RANGE| GROUPS} BETWEEN UNBOUNDED PRECEDING AND ...)
48/// ```
49///
50/// For example, cumulative window frames uses `PlainAggregateWindowExpr`.
51///
52/// # Non Aggregate Window Expressions
53///
54/// The expressions have the form
55///
56/// ```text
57/// OVER({ROWS | RANGE| GROUPS} BETWEEN M {PRECEDING| FOLLOWING} AND ...)
58/// ```
59///
60/// For example, sliding window frames use [`SlidingAggregateWindowExpr`].
61///
62/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
63/// [`PlainAggregateWindowExpr`]: crate::window::PlainAggregateWindowExpr
64/// [`SlidingAggregateWindowExpr`]: crate::window::SlidingAggregateWindowExpr
65pub trait WindowExpr: Send + Sync + Debug {
66    /// Returns the window expression as [`Any`] so that it can be
67    /// downcast to a specific implementation.
68    fn as_any(&self) -> &dyn Any;
69
70    /// The field of the final result of this window function.
71    fn field(&self) -> Result<FieldRef>;
72
73    /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default
74    /// implementation returns placeholder text.
75    fn name(&self) -> &str {
76        "WindowExpr: default name"
77    }
78
79    /// Expressions that are passed to the WindowAccumulator.
80    /// Functions which take a single input argument, such as `sum`, return a single [`datafusion_expr::expr::Expr`],
81    /// others (e.g. `cov`) return many.
82    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
83
84    /// Evaluate the window function arguments against the batch and return
85    /// array ref, normally the resulting `Vec` is a single element one.
86    fn evaluate_args(&self, batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
87        self.expressions()
88            .iter()
89            .map(|e| {
90                e.evaluate(batch)
91                    .and_then(|v| v.into_array(batch.num_rows()))
92            })
93            .collect()
94    }
95
96    /// Evaluate the window function values against the batch
97    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
98
99    /// Evaluate the window function against the batch. This function facilitates
100    /// stateful, bounded-memory implementations.
101    fn evaluate_stateful(
102        &self,
103        _partition_batches: &PartitionBatches,
104        _window_agg_state: &mut PartitionWindowAggStates,
105    ) -> Result<()> {
106        internal_err!("evaluate_stateful is not implemented for {}", self.name())
107    }
108
109    /// Expressions that's from the window function's partition by clause, empty if absent
110    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>];
111
112    /// Expressions that's from the window function's order by clause, empty if absent
113    fn order_by(&self) -> &[PhysicalSortExpr];
114
115    /// Get order by columns, empty if absent
116    fn order_by_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
117        self.order_by()
118            .iter()
119            .map(|e| e.evaluate_to_sort_column(batch))
120            .collect()
121    }
122
123    /// Get the window frame of this [WindowExpr].
124    fn get_window_frame(&self) -> &Arc<WindowFrame>;
125
126    /// Return a flag indicating whether this [WindowExpr] can run with
127    /// bounded memory.
128    fn uses_bounded_memory(&self) -> bool;
129
130    /// Get the reverse expression of this [WindowExpr].
131    fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;
132
133    /// Returns all expressions used in the [`WindowExpr`].
134    /// These expressions are (1) function arguments, (2) partition by expressions, (3) order by expressions.
135    fn all_expressions(&self) -> WindowPhysicalExpressions {
136        let args = self.expressions();
137        let partition_by_exprs = self.partition_by().to_vec();
138        let order_by_exprs = self
139            .order_by()
140            .iter()
141            .map(|sort_expr| Arc::clone(&sort_expr.expr))
142            .collect();
143        WindowPhysicalExpressions {
144            args,
145            partition_by_exprs,
146            order_by_exprs,
147        }
148    }
149
150    /// Rewrites [`WindowExpr`], with new expressions given. The argument should be consistent
151    /// with the return value of the [`WindowExpr::all_expressions`] method.
152    /// Returns `Some(Arc<dyn WindowExpr>)` if re-write is supported, otherwise returns `None`.
153    fn with_new_expressions(
154        &self,
155        _args: Vec<Arc<dyn PhysicalExpr>>,
156        _partition_bys: Vec<Arc<dyn PhysicalExpr>>,
157        _order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
158    ) -> Option<Arc<dyn WindowExpr>> {
159        None
160    }
161}
162
/// Stores the physical expressions used inside the `WindowExpr`.
///
/// This is the value returned by `WindowExpr::all_expressions`; its three
/// fields correspond to the three arguments of
/// `WindowExpr::with_new_expressions`.
pub struct WindowPhysicalExpressions {
    /// Window function arguments
    pub args: Vec<Arc<dyn PhysicalExpr>>,
    /// PARTITION BY expressions
    pub partition_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
    /// ORDER BY expressions (the expressions only, without sort options)
    pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
}
172
173/// Extension trait that adds common functionality to [`AggregateWindowExpr`]s
174pub trait AggregateWindowExpr: WindowExpr {
175    /// Get the accumulator for the window expression. Note that distinct
176    /// window expressions may return distinct accumulators; e.g. sliding
177    /// (non-sliding) expressions will return sliding (normal) accumulators.
178    fn get_accumulator(&self) -> Result<Box<dyn Accumulator>>;
179
180    /// Given current range and the last range, calculates the accumulator
181    /// result for the range of interest.
182    fn get_aggregate_result_inside_range(
183        &self,
184        last_range: &Range<usize>,
185        cur_range: &Range<usize>,
186        value_slice: &[ArrayRef],
187        accumulator: &mut Box<dyn Accumulator>,
188    ) -> Result<ScalarValue>;
189
190    /// Indicates whether this window function always produces the same result
191    /// for all rows in the partition.
192    fn is_constant_in_partition(&self) -> bool;
193
194    /// Evaluates the window function against the batch.
195    fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
196        let mut accumulator = self.get_accumulator()?;
197        let mut last_range = Range { start: 0, end: 0 };
198        let sort_options = self.order_by().iter().map(|o| o.options).collect();
199        let mut window_frame_ctx =
200            WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options);
201        self.get_result_column(
202            &mut accumulator,
203            batch,
204            None,
205            &mut last_range,
206            &mut window_frame_ctx,
207            0,
208            false,
209        )
210    }
211
212    /// Statefully evaluates the window function against the batch. Maintains
213    /// state so that it can work incrementally over multiple chunks.
214    fn aggregate_evaluate_stateful(
215        &self,
216        partition_batches: &PartitionBatches,
217        window_agg_state: &mut PartitionWindowAggStates,
218    ) -> Result<()> {
219        let field = self.field()?;
220        let out_type = field.data_type();
221        for (partition_row, partition_batch_state) in partition_batches.iter() {
222            if !window_agg_state.contains_key(partition_row) {
223                let accumulator = self.get_accumulator()?;
224                window_agg_state.insert(
225                    partition_row.clone(),
226                    WindowState {
227                        state: WindowAggState::new(out_type)?,
228                        window_fn: WindowFn::Aggregate(accumulator),
229                    },
230                );
231            };
232            let window_state =
233                window_agg_state.get_mut(partition_row).ok_or_else(|| {
234                    DataFusionError::Execution("Cannot find state".to_string())
235                })?;
236            let accumulator = match &mut window_state.window_fn {
237                WindowFn::Aggregate(accumulator) => accumulator,
238                _ => unreachable!(),
239            };
240            let state = &mut window_state.state;
241            let record_batch = &partition_batch_state.record_batch;
242            let most_recent_row = partition_batch_state.most_recent_row.as_ref();
243
244            // If there is no window state context, initialize it.
245            let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| {
246                let sort_options = self.order_by().iter().map(|o| o.options).collect();
247                WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options)
248            });
249            let out_col = self.get_result_column(
250                accumulator,
251                record_batch,
252                most_recent_row,
253                // Start search from the last range
254                &mut state.window_frame_range,
255                window_frame_ctx,
256                state.last_calculated_index,
257                !partition_batch_state.is_end,
258            )?;
259            state.update(&out_col, partition_batch_state)?;
260        }
261        Ok(())
262    }
263
264    /// Calculates the window expression result for the given record batch.
265    /// Assumes that `record_batch` belongs to a single partition.
266    ///
267    /// # Arguments
268    /// * `accumulator`: The accumulator to use for the calculation.
269    /// * `record_batch`: batch belonging to the current partition (see [`PartitionBatchState`]).
270    /// * `most_recent_row`: the batch that contains the most recent row, if available (see [`PartitionBatchState`]).
271    /// * `last_range`: The last range of rows that were processed (see [`WindowAggState`]).
272    /// * `window_frame_ctx`: Details about the window frame (see [`WindowFrameContext`]).
273    /// * `idx`: The index of the current row in the record batch.
274    /// * `not_end`: is the current row not the end of the partition (see [`PartitionBatchState`]).
275    #[allow(clippy::too_many_arguments)]
276    fn get_result_column(
277        &self,
278        accumulator: &mut Box<dyn Accumulator>,
279        record_batch: &RecordBatch,
280        most_recent_row: Option<&RecordBatch>,
281        last_range: &mut Range<usize>,
282        window_frame_ctx: &mut WindowFrameContext,
283        mut idx: usize,
284        not_end: bool,
285    ) -> Result<ArrayRef> {
286        let values = self.evaluate_args(record_batch)?;
287
288        if self.is_constant_in_partition() {
289            if not_end {
290                let field = self.field()?;
291                let out_type = field.data_type();
292                return Ok(new_empty_array(out_type));
293            }
294            accumulator.update_batch(&values)?;
295            let value = accumulator.evaluate()?;
296            return value.to_array_of_size(record_batch.num_rows());
297        }
298        let order_bys = get_orderby_values(self.order_by_columns(record_batch)?);
299        let most_recent_row_order_bys = most_recent_row
300            .map(|batch| self.order_by_columns(batch))
301            .transpose()?
302            .map(get_orderby_values);
303
304        // We iterate on each row to perform a running calculation.
305        let length = values[0].len();
306        let mut row_wise_results: Vec<ScalarValue> = vec![];
307        let is_causal = self.get_window_frame().is_causal();
308        while idx < length {
309            // Start search from the last_range. This squeezes searched range.
310            let cur_range =
311                window_frame_ctx.calculate_range(&order_bys, last_range, length, idx)?;
312            // Exit if the range is non-causal and extends all the way:
313            if cur_range.end == length
314                && !is_causal
315                && not_end
316                && !is_end_bound_safe(
317                    window_frame_ctx,
318                    &order_bys,
319                    most_recent_row_order_bys.as_deref(),
320                    self.order_by(),
321                    idx,
322                )?
323            {
324                break;
325            }
326            let value = self.get_aggregate_result_inside_range(
327                last_range,
328                &cur_range,
329                &values,
330                accumulator,
331            )?;
332            // Update last range
333            *last_range = cur_range;
334            row_wise_results.push(value);
335            idx += 1;
336        }
337
338        if row_wise_results.is_empty() {
339            let field = self.field()?;
340            let out_type = field.data_type();
341            Ok(new_empty_array(out_type))
342        } else {
343            ScalarValue::iter_to_array(row_wise_results)
344        }
345    }
346}
347
348/// Determines whether the end bound calculation for a window frame context is
349/// safe, meaning that the end bound stays the same, regardless of future data,
350/// based on the current sort expressions and ORDER BY columns. This function
351/// delegates work to specific functions for each frame type.
352///
353/// # Parameters
354///
355/// * `window_frame_ctx`: The context of the window frame being evaluated.
356/// * `order_bys`: A slice of `ArrayRef` representing the ORDER BY columns.
357/// * `most_recent_order_bys`: An optional reference to the most recent ORDER BY
358///   columns.
359/// * `sort_exprs`: Defines the lexicographical ordering in question.
360/// * `idx`: The current index in the window frame.
361///
362/// # Returns
363///
364/// A `Result` which is `Ok(true)` if the end bound is safe, `Ok(false)` otherwise.
365pub(crate) fn is_end_bound_safe(
366    window_frame_ctx: &WindowFrameContext,
367    order_bys: &[ArrayRef],
368    most_recent_order_bys: Option<&[ArrayRef]>,
369    sort_exprs: &[PhysicalSortExpr],
370    idx: usize,
371) -> Result<bool> {
372    if sort_exprs.is_empty() {
373        // Early return if no sort expressions are present:
374        return Ok(false);
375    };
376
377    match window_frame_ctx {
378        WindowFrameContext::Rows(window_frame) => {
379            is_end_bound_safe_for_rows(&window_frame.end_bound)
380        }
381        WindowFrameContext::Range { window_frame, .. } => is_end_bound_safe_for_range(
382            &window_frame.end_bound,
383            &order_bys[0],
384            most_recent_order_bys.map(|items| &items[0]),
385            &sort_exprs[0].options,
386            idx,
387        ),
388        WindowFrameContext::Groups {
389            window_frame,
390            state,
391        } => is_end_bound_safe_for_groups(
392            &window_frame.end_bound,
393            state,
394            &order_bys[0],
395            most_recent_order_bys.map(|items| &items[0]),
396            &sort_exprs[0].options,
397        ),
398    }
399}
400
401/// For row-based window frames, determines whether the end bound calculation
402/// is safe, which is trivially the case for `Preceding` and `CurrentRow` bounds.
403/// For 'Following' bounds, it compares the bound value to zero to ensure that
404/// it doesn't extend beyond the current row.
405///
406/// # Parameters
407///
408/// * `end_bound`: Reference to the window frame bound in question.
409///
410/// # Returns
411///
412/// A `Result` indicating whether the end bound is safe for row-based window frames.
413fn is_end_bound_safe_for_rows(end_bound: &WindowFrameBound) -> Result<bool> {
414    if let WindowFrameBound::Following(value) = end_bound {
415        let zero = ScalarValue::new_zero(&value.data_type());
416        Ok(zero.map(|zero| value.eq(&zero)).unwrap_or(false))
417    } else {
418        Ok(true)
419    }
420}
421
422/// For row-based window frames, determines whether the end bound calculation
423/// is safe by comparing it against specific values (zero, current row). It uses
424/// the `is_row_ahead` helper function to determine if the current row is ahead
425/// of the most recent row based on the ORDER BY column and sorting options.
426///
427/// # Parameters
428///
429/// * `end_bound`: Reference to the window frame bound in question.
430/// * `orderby_col`: Reference to the column used for ordering.
431/// * `most_recent_ob_col`: Optional reference to the most recent order-by column.
432/// * `sort_options`: The sorting options used in the window frame.
433/// * `idx`: The current index in the window frame.
434///
435/// # Returns
436///
437/// A `Result` indicating whether the end bound is safe for range-based window frames.
438fn is_end_bound_safe_for_range(
439    end_bound: &WindowFrameBound,
440    orderby_col: &ArrayRef,
441    most_recent_ob_col: Option<&ArrayRef>,
442    sort_options: &SortOptions,
443    idx: usize,
444) -> Result<bool> {
445    match end_bound {
446        WindowFrameBound::Preceding(value) => {
447            let zero = ScalarValue::new_zero(&value.data_type())?;
448            if value.eq(&zero) {
449                is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
450            } else {
451                Ok(true)
452            }
453        }
454        WindowFrameBound::CurrentRow => {
455            is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
456        }
457        WindowFrameBound::Following(delta) => {
458            let Some(most_recent_ob_col) = most_recent_ob_col else {
459                return Ok(false);
460            };
461            let most_recent_row_value =
462                ScalarValue::try_from_array(most_recent_ob_col, 0)?;
463            let current_row_value = ScalarValue::try_from_array(orderby_col, idx)?;
464
465            if sort_options.descending {
466                current_row_value
467                    .sub(delta)
468                    .map(|value| value > most_recent_row_value)
469            } else {
470                current_row_value
471                    .add(delta)
472                    .map(|value| most_recent_row_value > value)
473            }
474        }
475    }
476}
477
478/// For group-based window frames, determines whether the end bound calculation
479/// is safe by considering the group offset and whether the current row is ahead
480/// of the most recent row in terms of sorting. It checks if the end bound is
481/// within the bounds of the current group based on group end indices.
482///
483/// # Parameters
484///
485/// * `end_bound`: Reference to the window frame bound in question.
486/// * `state`: The state of the window frame for group calculations.
487/// * `orderby_col`: Reference to the column used for ordering.
488/// * `most_recent_ob_col`: Optional reference to the most recent order-by column.
489/// * `sort_options`: The sorting options used in the window frame.
490///
491/// # Returns
492///
493/// A `Result` indicating whether the end bound is safe for group-based window frames.
494fn is_end_bound_safe_for_groups(
495    end_bound: &WindowFrameBound,
496    state: &WindowFrameStateGroups,
497    orderby_col: &ArrayRef,
498    most_recent_ob_col: Option<&ArrayRef>,
499    sort_options: &SortOptions,
500) -> Result<bool> {
501    match end_bound {
502        WindowFrameBound::Preceding(value) => {
503            let zero = ScalarValue::new_zero(&value.data_type())?;
504            if value.eq(&zero) {
505                is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
506            } else {
507                Ok(true)
508            }
509        }
510        WindowFrameBound::CurrentRow => {
511            is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
512        }
513        WindowFrameBound::Following(ScalarValue::UInt64(Some(offset))) => {
514            let delta = state.group_end_indices.len() - state.current_group_idx;
515            if delta == (*offset as usize) + 1 {
516                is_row_ahead(orderby_col, most_recent_ob_col, sort_options)
517            } else {
518                Ok(false)
519            }
520        }
521        _ => Ok(false),
522    }
523}
524
525/// This utility function checks whether `current_cols` is ahead of the `old_cols`
526/// in terms of `sort_options`.
527fn is_row_ahead(
528    old_col: &ArrayRef,
529    current_col: Option<&ArrayRef>,
530    sort_options: &SortOptions,
531) -> Result<bool> {
532    let Some(current_col) = current_col else {
533        return Ok(false);
534    };
535    if old_col.is_empty() || current_col.is_empty() {
536        return Ok(false);
537    }
538    let last_value = ScalarValue::try_from_array(old_col, old_col.len() - 1)?;
539    let current_value = ScalarValue::try_from_array(current_col, 0)?;
540    let cmp = compare_rows(&[current_value], &[last_value], &[*sort_options])?;
541    Ok(cmp.is_gt())
542}
543
544/// Get order by expression results inside `order_by_columns`.
545pub(crate) fn get_orderby_values(order_by_columns: Vec<SortColumn>) -> Vec<ArrayRef> {
546    order_by_columns.into_iter().map(|s| s.values).collect()
547}
548
/// The evaluator held in a [`WindowState`], i.e. the object that actually
/// computes a window function's values for one partition.
#[derive(Debug)]
pub enum WindowFn {
    /// State driven by a boxed [`PartitionEvaluator`].
    Builtin(Box<dyn PartitionEvaluator>),
    /// State driven by a boxed [`Accumulator`]; this is the variant created
    /// by `AggregateWindowExpr::aggregate_evaluate_stateful`.
    Aggregate(Box<dyn Accumulator>),
}
554
/// Key for IndexMap for each unique partition
///
/// For instance, if window frame is `OVER(PARTITION BY a,b)`,
/// PartitionKey would consist of unique `[a,b]` pairs
///
/// Used as the key type of both [`PartitionWindowAggStates`] and
/// [`PartitionBatches`].
pub type PartitionKey = Vec<ScalarValue>;
560
/// Per-partition state of a window expression that is evaluated statefully.
#[derive(Debug)]
pub struct WindowState {
    /// Bookkeeping for the window aggregation (see [`WindowAggState`]).
    pub state: WindowAggState,
    /// The evaluator that produces the window function's values.
    pub window_fn: WindowFn,
}
/// The IndexMap holding one [`WindowState`] per partition, keyed by [`PartitionKey`].
pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
567
/// The IndexMap (i.e. an ordered HashMap) where record batches are separated for each partition.
///
/// Each partition, identified by its [`PartitionKey`], maps to a
/// [`PartitionBatchState`].
pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
570
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::window::window_expr::is_row_ahead;

    use arrow::array::{ArrayRef, Float64Array};
    use arrow::compute::SortOptions;
    use datafusion_common::Result;

    /// With an ascending order, a new row is ahead only if its value is
    /// strictly greater than the last value seen so far; a tie is not ahead.
    #[test]
    fn test_is_row_ahead() -> Result<()> {
        let ascending = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let old_values: ArrayRef =
            Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.]));

        // 11.0 is past the last old value (10.0).
        let strictly_greater: ArrayRef = Arc::new(Float64Array::from(vec![11.0]));
        let ahead = is_row_ahead(&old_values, Some(&strictly_greater), &ascending)?;
        assert!(ahead);

        // 10.0 ties with the last old value, so it is not ahead.
        let tied: ArrayRef = Arc::new(Float64Array::from(vec![10.0]));
        let ahead = is_row_ahead(&old_values, Some(&tied), &ascending)?;
        assert!(!ahead);

        Ok(())
    }
}