Skip to main content

datafusion_physical_plan/windows/
bounded_window_agg_exec.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Stream and channel implementations for window function expressions.
19//! The executor given here uses bounded memory (does not maintain all
20//! the input data seen so far), which makes it appropriate when processing
21//! infinite inputs.
22
23use std::any::Any;
24use std::cmp::{Ordering, min};
25use std::collections::VecDeque;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::task::{Context, Poll};
29
30use super::utils::create_schema;
31use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
32use crate::windows::{
33    calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs,
34    window_equivalence_properties,
35};
36use crate::{
37    ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
38    ExecutionPlanProperties, InputOrderMode, PlanProperties, RecordBatchStream,
39    SendableRecordBatchStream, Statistics, WindowExpr, check_if_same_properties,
40};
41
42use arrow::compute::take_record_batch;
43use arrow::{
44    array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder},
45    compute::{concat, concat_batches, sort_to_indices, take_arrays},
46    datatypes::SchemaRef,
47    record_batch::RecordBatch,
48};
49use datafusion_common::hash_utils::create_hashes;
50use datafusion_common::stats::Precision;
51use datafusion_common::utils::{
52    evaluate_partition_ranges, get_at_indices, get_row_at_idx,
53};
54use datafusion_common::{
55    HashMap, Result, arrow_datafusion_err, exec_datafusion_err, exec_err,
56};
57use datafusion_execution::TaskContext;
58use datafusion_expr::ColumnarValue;
59use datafusion_expr::window_state::{PartitionBatchState, WindowAggState};
60use datafusion_physical_expr::window::{
61    PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState,
62};
63use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
64use datafusion_physical_expr_common::sort_expr::{
65    OrderingRequirements, PhysicalSortExpr,
66};
67
68use ahash::RandomState;
69use futures::stream::Stream;
70use futures::{StreamExt, ready};
71use hashbrown::hash_table::HashTable;
72use indexmap::IndexMap;
73use log::debug;
74
75/// Window execution plan
76#[derive(Debug, Clone)]
77pub struct BoundedWindowAggExec {
78    /// Input plan
79    input: Arc<dyn ExecutionPlan>,
80    /// Window function expression
81    window_expr: Vec<Arc<dyn WindowExpr>>,
82    /// Schema after the window is run
83    schema: SchemaRef,
84    /// Execution metrics
85    metrics: ExecutionPlanMetricsSet,
86    /// Describes how the input is ordered relative to the partition keys
87    pub input_order_mode: InputOrderMode,
88    /// Partition by indices that define ordering
89    // For example, if input ordering is ORDER BY a, b and window expression
90    // contains PARTITION BY b, a; `ordered_partition_by_indices` would be 1, 0.
91    // Similarly, if window expression contains PARTITION BY a, b; then
92    // `ordered_partition_by_indices` would be 0, 1.
93    // See `get_ordered_partition_by_indices` for more details.
94    ordered_partition_by_indices: Vec<usize>,
95    /// Cache holding plan properties like equivalences, output partitioning etc.
96    cache: Arc<PlanProperties>,
97    /// If `can_rerepartition` is false, partition_keys is always empty.
98    can_repartition: bool,
99}
100
101impl BoundedWindowAggExec {
102    /// Create a new execution plan for window aggregates
103    pub fn try_new(
104        window_expr: Vec<Arc<dyn WindowExpr>>,
105        input: Arc<dyn ExecutionPlan>,
106        input_order_mode: InputOrderMode,
107        can_repartition: bool,
108    ) -> Result<Self> {
109        let schema = create_schema(&input.schema(), &window_expr)?;
110        let schema = Arc::new(schema);
111        let partition_by_exprs = window_expr[0].partition_by();
112        let ordered_partition_by_indices = match &input_order_mode {
113            InputOrderMode::Sorted => {
114                let indices = get_ordered_partition_by_indices(
115                    window_expr[0].partition_by(),
116                    &input,
117                )?;
118                if indices.len() == partition_by_exprs.len() {
119                    indices
120                } else {
121                    (0..partition_by_exprs.len()).collect::<Vec<_>>()
122                }
123            }
124            InputOrderMode::PartiallySorted(ordered_indices) => ordered_indices.clone(),
125            InputOrderMode::Linear => {
126                vec![]
127            }
128        };
129        let cache = Self::compute_properties(&input, &schema, &window_expr)?;
130        Ok(Self {
131            input,
132            window_expr,
133            schema,
134            metrics: ExecutionPlanMetricsSet::new(),
135            input_order_mode,
136            ordered_partition_by_indices,
137            cache: Arc::new(cache),
138            can_repartition,
139        })
140    }
141
142    /// Window expressions
143    pub fn window_expr(&self) -> &[Arc<dyn WindowExpr>] {
144        &self.window_expr
145    }
146
147    /// Input plan
148    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
149        &self.input
150    }
151
152    /// Return the output sort order of partition keys: For example
153    /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a
154    // We are sure that partition by columns are always at the beginning of sort_keys
155    // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely
156    // to calculate partition separation points
157    pub fn partition_by_sort_keys(&self) -> Result<Vec<PhysicalSortExpr>> {
158        let partition_by = self.window_expr()[0].partition_by();
159        get_partition_by_sort_exprs(
160            &self.input,
161            partition_by,
162            &self.ordered_partition_by_indices,
163        )
164    }
165
166    /// Initializes the appropriate [`PartitionSearcher`] implementation from
167    /// the state.
168    fn get_search_algo(&self) -> Result<Box<dyn PartitionSearcher>> {
169        let partition_by_sort_keys = self.partition_by_sort_keys()?;
170        let ordered_partition_by_indices = self.ordered_partition_by_indices.clone();
171        let input_schema = self.input().schema();
172        Ok(match &self.input_order_mode {
173            InputOrderMode::Sorted => {
174                // In Sorted mode, all partition by columns should be ordered.
175                if self.window_expr()[0].partition_by().len()
176                    != ordered_partition_by_indices.len()
177                {
178                    return exec_err!(
179                        "All partition by columns should have an ordering in Sorted mode."
180                    );
181                }
182                Box::new(SortedSearch {
183                    partition_by_sort_keys,
184                    ordered_partition_by_indices,
185                    input_schema,
186                })
187            }
188            InputOrderMode::Linear | InputOrderMode::PartiallySorted(_) => Box::new(
189                LinearSearch::new(ordered_partition_by_indices, input_schema),
190            ),
191        })
192    }
193
194    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
195    fn compute_properties(
196        input: &Arc<dyn ExecutionPlan>,
197        schema: &SchemaRef,
198        window_exprs: &[Arc<dyn WindowExpr>],
199    ) -> Result<PlanProperties> {
200        // Calculate equivalence properties:
201        let eq_properties = window_equivalence_properties(schema, input, window_exprs)?;
202
203        // As we can have repartitioning using the partition keys, this can
204        // be either one or more than one, depending on the presence of
205        // repartitioning.
206        let output_partitioning = input.output_partitioning().clone();
207
208        // Construct properties cache
209        Ok(PlanProperties::new(
210            eq_properties,
211            output_partitioning,
212            // TODO: Emission type and boundedness information can be enhanced here
213            input.pipeline_behavior(),
214            input.boundedness(),
215        ))
216    }
217
218    pub fn partition_keys(&self) -> Vec<Arc<dyn PhysicalExpr>> {
219        if !self.can_repartition {
220            vec![]
221        } else {
222            let all_partition_keys = self
223                .window_expr()
224                .iter()
225                .map(|expr| expr.partition_by().to_vec())
226                .collect::<Vec<_>>();
227
228            all_partition_keys
229                .into_iter()
230                .min_by_key(|s| s.len())
231                .unwrap_or_else(Vec::new)
232        }
233    }
234
235    fn statistics_helper(&self, statistics: Statistics) -> Result<Statistics> {
236        let win_cols = self.window_expr.len();
237        let input_cols = self.input.schema().fields().len();
238        // TODO stats: some windowing function will maintain invariants such as min, max...
239        let mut column_statistics = Vec::with_capacity(win_cols + input_cols);
240        // copy stats of the input to the beginning of the schema.
241        column_statistics.extend(statistics.column_statistics);
242        for _ in 0..win_cols {
243            column_statistics.push(ColumnStatistics::new_unknown())
244        }
245        Ok(Statistics {
246            num_rows: statistics.num_rows,
247            column_statistics,
248            total_byte_size: Precision::Absent,
249        })
250    }
251
252    fn with_new_children_and_same_properties(
253        &self,
254        mut children: Vec<Arc<dyn ExecutionPlan>>,
255    ) -> Self {
256        Self {
257            input: children.swap_remove(0),
258            metrics: ExecutionPlanMetricsSet::new(),
259            ..Self::clone(self)
260        }
261    }
262}
263
264impl DisplayAs for BoundedWindowAggExec {
265    fn fmt_as(
266        &self,
267        t: DisplayFormatType,
268        f: &mut std::fmt::Formatter,
269    ) -> std::fmt::Result {
270        match t {
271            DisplayFormatType::Default | DisplayFormatType::Verbose => {
272                write!(f, "BoundedWindowAggExec: ")?;
273                let g: Vec<String> = self
274                    .window_expr
275                    .iter()
276                    .map(|e| {
277                        let field = match e.field() {
278                            Ok(f) => f.to_string(),
279                            Err(e) => format!("{e:?}"),
280                        };
281                        format!(
282                            "{}: {}, frame: {}",
283                            e.name().to_owned(),
284                            field,
285                            e.get_window_frame()
286                        )
287                    })
288                    .collect();
289                let mode = &self.input_order_mode;
290                write!(f, "wdw=[{}], mode=[{:?}]", g.join(", "), mode)?;
291            }
292            DisplayFormatType::TreeRender => {
293                let g: Vec<String> = self
294                    .window_expr
295                    .iter()
296                    .map(|e| e.name().to_owned().to_string())
297                    .collect();
298                writeln!(f, "select_list={}", g.join(", "))?;
299
300                let mode = &self.input_order_mode;
301                writeln!(f, "mode={mode:?}")?;
302            }
303        }
304        Ok(())
305    }
306}
307
308impl ExecutionPlan for BoundedWindowAggExec {
309    fn name(&self) -> &'static str {
310        "BoundedWindowAggExec"
311    }
312
313    /// Return a reference to Any that can be used for downcasting
314    fn as_any(&self) -> &dyn Any {
315        self
316    }
317
318    fn properties(&self) -> &Arc<PlanProperties> {
319        &self.cache
320    }
321
322    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
323        vec![&self.input]
324    }
325
326    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
327        let partition_bys = self.window_expr()[0].partition_by();
328        let order_keys = self.window_expr()[0].order_by();
329        let partition_bys = self
330            .ordered_partition_by_indices
331            .iter()
332            .map(|idx| &partition_bys[*idx]);
333        vec![calc_requirements(partition_bys, order_keys)]
334    }
335
336    fn required_input_distribution(&self) -> Vec<Distribution> {
337        if self.partition_keys().is_empty() {
338            debug!("No partition defined for BoundedWindowAggExec!!!");
339            vec![Distribution::SinglePartition]
340        } else {
341            vec![Distribution::HashPartitioned(self.partition_keys().clone())]
342        }
343    }
344
345    fn maintains_input_order(&self) -> Vec<bool> {
346        vec![true]
347    }
348
349    fn with_new_children(
350        self: Arc<Self>,
351        children: Vec<Arc<dyn ExecutionPlan>>,
352    ) -> Result<Arc<dyn ExecutionPlan>> {
353        check_if_same_properties!(self, children);
354        Ok(Arc::new(BoundedWindowAggExec::try_new(
355            self.window_expr.clone(),
356            Arc::clone(&children[0]),
357            self.input_order_mode.clone(),
358            self.can_repartition,
359        )?))
360    }
361
362    fn execute(
363        &self,
364        partition: usize,
365        context: Arc<TaskContext>,
366    ) -> Result<SendableRecordBatchStream> {
367        let input = self.input.execute(partition, context)?;
368        let search_mode = self.get_search_algo()?;
369        let stream = Box::pin(BoundedWindowAggStream::new(
370            Arc::clone(&self.schema),
371            self.window_expr.clone(),
372            input,
373            BaselineMetrics::new(&self.metrics, partition),
374            search_mode,
375        )?);
376        Ok(stream)
377    }
378
379    fn metrics(&self) -> Option<MetricsSet> {
380        Some(self.metrics.clone_inner())
381    }
382
383    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
384        let input_stat = self.input.partition_statistics(partition)?;
385        self.statistics_helper(input_stat)
386    }
387}
388
389/// Trait that specifies how we search for (or calculate) partitions. It has two
390/// implementations: [`SortedSearch`] and [`LinearSearch`].
391trait PartitionSearcher: Send {
392    /// This method constructs output columns using the result of each window expression
393    /// (each entry in the output vector comes from a window expression).
394    /// Executor when producing output concatenates `input_buffer` (corresponding section), and
395    /// result of this function to generate output `RecordBatch`. `input_buffer` is used to determine
396    /// which sections of the window expression results should be used to generate output.
397    /// `partition_buffers` contains corresponding section of the `RecordBatch` for each partition.
398    /// `window_agg_states` stores per partition state for each window expression.
399    /// None case means that no result is generated
400    /// `Some(Vec<ArrayRef>)` is the result of each window expression.
401    fn calculate_out_columns(
402        &mut self,
403        input_buffer: &RecordBatch,
404        window_agg_states: &[PartitionWindowAggStates],
405        partition_buffers: &mut PartitionBatches,
406        window_expr: &[Arc<dyn WindowExpr>],
407    ) -> Result<Option<Vec<ArrayRef>>>;
408
409    /// Determine whether `[InputOrderMode]` is `[InputOrderMode::Linear]` or not.
410    fn is_mode_linear(&self) -> bool {
411        false
412    }
413
414    // Constructs corresponding batches for each partition for the record_batch.
415    fn evaluate_partition_batches(
416        &mut self,
417        record_batch: &RecordBatch,
418        window_expr: &[Arc<dyn WindowExpr>],
419    ) -> Result<Vec<(PartitionKey, RecordBatch)>>;
420
421    /// Prunes the state.
422    fn prune(&mut self, _n_out: usize) {}
423
424    /// Marks the partition as done if we are sure that corresponding partition
425    /// cannot receive any more values.
426    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches);
427
428    /// Updates `input_buffer` and `partition_buffers` with the new `record_batch`.
429    fn update_partition_batch(
430        &mut self,
431        input_buffer: &mut RecordBatch,
432        record_batch: RecordBatch,
433        window_expr: &[Arc<dyn WindowExpr>],
434        partition_buffers: &mut PartitionBatches,
435    ) -> Result<()> {
436        if record_batch.num_rows() == 0 {
437            return Ok(());
438        }
439        let partition_batches =
440            self.evaluate_partition_batches(&record_batch, window_expr)?;
441        for (partition_row, partition_batch) in partition_batches {
442            if let Some(partition_batch_state) = partition_buffers.get_mut(&partition_row)
443            {
444                partition_batch_state.extend(&partition_batch)?
445            } else {
446                let options = RecordBatchOptions::new()
447                    .with_row_count(Some(partition_batch.num_rows()));
448                // Use input_schema for the buffer schema, not `record_batch.schema()`
449                // as it may not have the "correct" schema in terms of output
450                // nullability constraints. For details, see the following issue:
451                // https://github.com/apache/datafusion/issues/9320
452                let partition_batch = RecordBatch::try_new_with_options(
453                    Arc::clone(self.input_schema()),
454                    partition_batch.columns().to_vec(),
455                    &options,
456                )?;
457                let partition_batch_state =
458                    PartitionBatchState::new_with_batch(partition_batch);
459                partition_buffers.insert(partition_row, partition_batch_state);
460            }
461        }
462
463        if self.is_mode_linear() {
464            // In `Linear` mode, it is guaranteed that the first ORDER BY column
465            // is sorted across partitions. Note that only the first ORDER BY
466            // column is guaranteed to be ordered. As a counter example, consider
467            // the case, `PARTITION BY b, ORDER BY a, c` when the input is sorted
468            // by `[a, b, c]`. In this case, `BoundedWindowAggExec` mode will be
469            // `Linear`. However, we cannot guarantee that the last row of the
470            // input data will be the "last" data in terms of the ordering requirement
471            // `[a, c]` -- it will be the "last" data in terms of `[a, b, c]`.
472            // Hence, only column `a` should be used as a guarantee of the "last"
473            // data across partitions. For other modes (`Sorted`, `PartiallySorted`),
474            // we do not need to keep track of the most recent row guarantee across
475            // partitions. Since leading ordering separates partitions, guaranteed
476            // by the most recent row, already prune the previous partitions completely.
477            let last_row = get_last_row_batch(&record_batch)?;
478            for (_, partition_batch) in partition_buffers.iter_mut() {
479                partition_batch.set_most_recent_row(last_row.clone());
480            }
481        }
482        self.mark_partition_end(partition_buffers);
483
484        *input_buffer = if input_buffer.num_rows() == 0 {
485            record_batch
486        } else {
487            concat_batches(self.input_schema(), [input_buffer, &record_batch])?
488        };
489
490        Ok(())
491    }
492
493    fn input_schema(&self) -> &SchemaRef;
494}
495
496/// This object encapsulates the algorithm state for a simple linear scan
497/// algorithm for computing partitions.
498pub struct LinearSearch {
499    /// Keeps the hash of input buffer calculated from PARTITION BY columns.
500    /// Its length is equal to the `input_buffer` length.
501    input_buffer_hashes: VecDeque<u64>,
502    /// Used during hash value calculation.
503    random_state: RandomState,
504    /// Input ordering and partition by key ordering need not be the same, so
505    /// this vector stores the mapping between them. For instance, if the input
506    /// is ordered by a, b and the window expression contains a PARTITION BY b, a
507    /// clause, this attribute stores [1, 0].
508    ordered_partition_by_indices: Vec<usize>,
509    /// We use this [`HashTable`] to calculate unique partitions for each new
510    /// RecordBatch. First entry in the tuple is the hash value, the second
511    /// entry is the unique ID for each partition (increments from 0 to n).
512    row_map_batch: HashTable<(u64, usize)>,
513    /// We use this [`HashTable`] to calculate the output columns that we can
514    /// produce at each cycle. First entry in the tuple is the hash value, the
515    /// second entry is the unique ID for each partition (increments from 0 to n).
516    /// The third entry stores how many new outputs are calculated for the
517    /// corresponding partition.
518    row_map_out: HashTable<(u64, usize, usize)>,
519    input_schema: SchemaRef,
520}
521
522impl PartitionSearcher for LinearSearch {
523    /// This method constructs output columns using the result of each window expression.
524    // Assume input buffer is         |      Partition Buffers would be (Where each partition and its data is separated)
525    // a, 2                           |      a, 2
526    // b, 2                           |      a, 2
527    // a, 2                           |      a, 2
528    // b, 2                           |
529    // a, 2                           |      b, 2
530    // b, 2                           |      b, 2
531    // b, 2                           |      b, 2
532    //                                |      b, 2
533    // Also assume we happen to calculate 2 new values for a, and 3 for b (To be calculate missing values we may need to consider future values).
534    // Partition buffers effectively will be
535    // a, 2, 1
536    // a, 2, 2
537    // a, 2, (missing)
538    //
539    // b, 2, 1
540    // b, 2, 2
541    // b, 2, 3
542    // b, 2, (missing)
543    // When partition buffers are mapped back to the original record batch. Result becomes
544    // a, 2, 1
545    // b, 2, 1
546    // a, 2, 2
547    // b, 2, 2
548    // a, 2, (missing)
549    // b, 2, 3
550    // b, 2, (missing)
551    // This function calculates the column result of window expression(s) (First 4 entry of 3rd column in the above section.)
552    // 1
553    // 1
554    // 2
555    // 2
556    // Above section corresponds to calculated result which can be emitted without breaking input buffer ordering.
557    fn calculate_out_columns(
558        &mut self,
559        input_buffer: &RecordBatch,
560        window_agg_states: &[PartitionWindowAggStates],
561        partition_buffers: &mut PartitionBatches,
562        window_expr: &[Arc<dyn WindowExpr>],
563    ) -> Result<Option<Vec<ArrayRef>>> {
564        let partition_output_indices = self.calc_partition_output_indices(
565            input_buffer,
566            window_agg_states,
567            window_expr,
568        )?;
569
570        let n_window_col = window_agg_states.len();
571        let mut new_columns = vec![vec![]; n_window_col];
572        // Size of all_indices can be at most input_buffer.num_rows():
573        let mut all_indices = UInt32Builder::with_capacity(input_buffer.num_rows());
574        for (row, indices) in partition_output_indices {
575            let length = indices.len();
576            for (idx, window_agg_state) in window_agg_states.iter().enumerate() {
577                let partition = &window_agg_state[&row];
578                let values = Arc::clone(&partition.state.out_col.slice(0, length));
579                new_columns[idx].push(values);
580            }
581            let partition_batch_state = &mut partition_buffers[&row];
582            // Store how many rows are generated for each partition
583            partition_batch_state.n_out_row = length;
584            // For each row keep corresponding index in the input record batch
585            all_indices.append_slice(&indices);
586        }
587        let all_indices = all_indices.finish();
588        if all_indices.is_empty() {
589            // We couldn't generate any new value, return early:
590            return Ok(None);
591        }
592
593        // Concatenate results for each column by converting `Vec<Vec<ArrayRef>>`
594        // to Vec<ArrayRef> where inner `Vec<ArrayRef>`s are converted to `ArrayRef`s.
595        let new_columns = new_columns
596            .iter()
597            .map(|items| {
598                concat(&items.iter().map(|e| e.as_ref()).collect::<Vec<_>>())
599                    .map_err(|e| arrow_datafusion_err!(e))
600            })
601            .collect::<Result<Vec<_>>>()?;
602        // We should emit columns according to row index ordering.
603        let sorted_indices = sort_to_indices(&all_indices, None, None)?;
604        // Construct new column according to row ordering. This fixes ordering
605        take_arrays(&new_columns, &sorted_indices, None)
606            .map(Some)
607            .map_err(|e| arrow_datafusion_err!(e))
608    }
609
610    fn evaluate_partition_batches(
611        &mut self,
612        record_batch: &RecordBatch,
613        window_expr: &[Arc<dyn WindowExpr>],
614    ) -> Result<Vec<(PartitionKey, RecordBatch)>> {
615        let partition_bys =
616            evaluate_partition_by_column_values(record_batch, window_expr)?;
617        // NOTE: In Linear or PartiallySorted modes, we are sure that
618        //       `partition_bys` are not empty.
619        // Calculate indices for each partition and construct a new record
620        // batch from the rows at these indices for each partition:
621        self.get_per_partition_indices(&partition_bys, record_batch)?
622            .into_iter()
623            .map(|(row, indices)| {
624                let mut new_indices = UInt32Builder::with_capacity(indices.len());
625                new_indices.append_slice(&indices);
626                let indices = new_indices.finish();
627                Ok((row, take_record_batch(record_batch, &indices)?))
628            })
629            .collect()
630    }
631
632    fn prune(&mut self, n_out: usize) {
633        // Delete hashes for the rows that are outputted.
634        self.input_buffer_hashes.drain(0..n_out);
635    }
636
637    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches) {
638        // We should be in the `PartiallySorted` case, otherwise we can not
639        // tell when we are at the end of a given partition.
640        if !self.ordered_partition_by_indices.is_empty()
641            && let Some((last_row, _)) = partition_buffers.last()
642        {
643            let last_sorted_cols = self
644                .ordered_partition_by_indices
645                .iter()
646                .map(|idx| last_row[*idx].clone())
647                .collect::<Vec<_>>();
648            for (row, partition_batch_state) in partition_buffers.iter_mut() {
649                let sorted_cols = self
650                    .ordered_partition_by_indices
651                    .iter()
652                    .map(|idx| &row[*idx]);
653                // All the partitions other than `last_sorted_cols` are done.
654                // We are sure that we will no longer receive values for these
655                // partitions (arrival of a new value would violate ordering).
656                partition_batch_state.is_end = !sorted_cols.eq(&last_sorted_cols);
657            }
658        }
659    }
660
661    fn is_mode_linear(&self) -> bool {
662        self.ordered_partition_by_indices.is_empty()
663    }
664
665    fn input_schema(&self) -> &SchemaRef {
666        &self.input_schema
667    }
668}
669
670impl LinearSearch {
671    /// Initialize a new [`LinearSearch`] partition searcher.
672    fn new(ordered_partition_by_indices: Vec<usize>, input_schema: SchemaRef) -> Self {
673        LinearSearch {
674            input_buffer_hashes: VecDeque::new(),
675            random_state: Default::default(),
676            ordered_partition_by_indices,
677            row_map_batch: HashTable::with_capacity(256),
678            row_map_out: HashTable::with_capacity(256),
679            input_schema,
680        }
681    }
682
683    /// Calculate indices of each partition (according to PARTITION BY expression)
684    /// `columns` contain partition by expression results.
685    fn get_per_partition_indices(
686        &mut self,
687        columns: &[ArrayRef],
688        batch: &RecordBatch,
689    ) -> Result<Vec<(PartitionKey, Vec<u32>)>> {
690        let mut batch_hashes = vec![0; batch.num_rows()];
691        create_hashes(columns, &self.random_state, &mut batch_hashes)?;
692        self.input_buffer_hashes.extend(&batch_hashes);
693        // reset row_map for new calculation
694        self.row_map_batch.clear();
695        // res stores PartitionKey and row indices (indices where these partition occurs in the `batch`) for each partition.
696        let mut result: Vec<(PartitionKey, Vec<u32>)> = vec![];
697        for (hash, row_idx) in batch_hashes.into_iter().zip(0u32..) {
698            let entry = self.row_map_batch.find_mut(hash, |(_, group_idx)| {
699                // We can safely get the first index of the partition indices
700                // since partition indices has one element during initialization.
701                let row = get_row_at_idx(columns, row_idx as usize).unwrap();
702                // Handle hash collusions with an equality check:
703                row.eq(&result[*group_idx].0)
704            });
705            if let Some((_, group_idx)) = entry {
706                result[*group_idx].1.push(row_idx)
707            } else {
708                self.row_map_batch.insert_unique(
709                    hash,
710                    (hash, result.len()),
711                    |(hash, _)| *hash,
712                );
713                let row = get_row_at_idx(columns, row_idx as usize)?;
714                // This is a new partition its only index is row_idx for now.
715                result.push((row, vec![row_idx]));
716            }
717        }
718        Ok(result)
719    }
720
721    /// Calculates partition keys and result indices for each partition.
722    /// The return value is a vector of tuples where the first entry stores
723    /// the partition key (unique for each partition) and the second entry
724    /// stores indices of the rows for which the partition is constructed.
725    fn calc_partition_output_indices(
726        &mut self,
727        input_buffer: &RecordBatch,
728        window_agg_states: &[PartitionWindowAggStates],
729        window_expr: &[Arc<dyn WindowExpr>],
730    ) -> Result<Vec<(PartitionKey, Vec<u32>)>> {
731        let partition_by_columns =
732            evaluate_partition_by_column_values(input_buffer, window_expr)?;
733        // Reset the row_map state:
734        self.row_map_out.clear();
735        let mut partition_indices: Vec<(PartitionKey, Vec<u32>)> = vec![];
736        for (hash, row_idx) in self.input_buffer_hashes.iter().zip(0u32..) {
737            let entry = self.row_map_out.find_mut(*hash, |(_, group_idx, _)| {
738                let row =
739                    get_row_at_idx(&partition_by_columns, row_idx as usize).unwrap();
740                row == partition_indices[*group_idx].0
741            });
742            if let Some((_, group_idx, n_out)) = entry {
743                let (_, indices) = &mut partition_indices[*group_idx];
744                if indices.len() >= *n_out {
745                    break;
746                }
747                indices.push(row_idx);
748            } else {
749                let row = get_row_at_idx(&partition_by_columns, row_idx as usize)?;
750                let min_out = window_agg_states
751                    .iter()
752                    .map(|window_agg_state| {
753                        window_agg_state
754                            .get(&row)
755                            .map(|partition| partition.state.out_col.len())
756                            .unwrap_or(0)
757                    })
758                    .min()
759                    .unwrap_or(0);
760                if min_out == 0 {
761                    break;
762                }
763                self.row_map_out.insert_unique(
764                    *hash,
765                    (*hash, partition_indices.len(), min_out),
766                    |(hash, _, _)| *hash,
767                );
768                partition_indices.push((row, vec![row_idx]));
769            }
770        }
771        Ok(partition_indices)
772    }
773}
774
/// This object encapsulates the algorithm state for sorted searching
/// when computing partitions.
///
/// Partitions are detected as contiguous row ranges of each incoming batch
/// (see `evaluate_partition_batches`), relying on the input being ordered
/// on the PARTITION BY columns.
pub struct SortedSearch {
    /// Stores partition by columns and their ordering information
    partition_by_sort_keys: Vec<PhysicalSortExpr>,
    /// Input ordering and partition by key ordering need not be the same, so
    /// this vector stores the mapping between them. For instance, if the input
    /// is ordered by a, b and the window expression contains a PARTITION BY b, a
    /// clause, this attribute stores [1, 0].
    ordered_partition_by_indices: Vec<usize>,
    /// Schema of the input, returned by `PartitionSearcher::input_schema`.
    input_schema: SchemaRef,
}
787
788impl PartitionSearcher for SortedSearch {
789    /// This method constructs new output columns using the result of each window expression.
790    fn calculate_out_columns(
791        &mut self,
792        _input_buffer: &RecordBatch,
793        window_agg_states: &[PartitionWindowAggStates],
794        partition_buffers: &mut PartitionBatches,
795        _window_expr: &[Arc<dyn WindowExpr>],
796    ) -> Result<Option<Vec<ArrayRef>>> {
797        let n_out = self.calculate_n_out_row(window_agg_states, partition_buffers);
798        if n_out == 0 {
799            Ok(None)
800        } else {
801            window_agg_states
802                .iter()
803                .map(|map| get_aggregate_result_out_column(map, n_out).map(Some))
804                .collect()
805        }
806    }
807
808    fn evaluate_partition_batches(
809        &mut self,
810        record_batch: &RecordBatch,
811        _window_expr: &[Arc<dyn WindowExpr>],
812    ) -> Result<Vec<(PartitionKey, RecordBatch)>> {
813        let num_rows = record_batch.num_rows();
814        // Calculate result of partition by column expressions
815        let partition_columns = self
816            .partition_by_sort_keys
817            .iter()
818            .map(|elem| elem.evaluate_to_sort_column(record_batch))
819            .collect::<Result<Vec<_>>>()?;
820        // Reorder `partition_columns` such that its ordering matches input ordering.
821        let partition_columns_ordered =
822            get_at_indices(&partition_columns, &self.ordered_partition_by_indices)?;
823        let partition_points =
824            evaluate_partition_ranges(num_rows, &partition_columns_ordered)?;
825        let partition_bys = partition_columns
826            .into_iter()
827            .map(|arr| arr.values)
828            .collect::<Vec<ArrayRef>>();
829
830        partition_points
831            .iter()
832            .map(|range| {
833                let row = get_row_at_idx(&partition_bys, range.start)?;
834                let len = range.end - range.start;
835                let slice = record_batch.slice(range.start, len);
836                Ok((row, slice))
837            })
838            .collect::<Result<Vec<_>>>()
839    }
840
841    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches) {
842        // In Sorted case. We can mark all partitions besides last partition as ended.
843        // We are sure that those partitions will never receive any values.
844        // (Otherwise ordering invariant is violated.)
845        let n_partitions = partition_buffers.len();
846        for (idx, (_, partition_batch_state)) in partition_buffers.iter_mut().enumerate()
847        {
848            partition_batch_state.is_end |= idx < n_partitions - 1;
849        }
850    }
851
852    fn input_schema(&self) -> &SchemaRef {
853        &self.input_schema
854    }
855}
856
857impl SortedSearch {
858    /// Calculates how many rows we can output.
859    fn calculate_n_out_row(
860        &mut self,
861        window_agg_states: &[PartitionWindowAggStates],
862        partition_buffers: &mut PartitionBatches,
863    ) -> usize {
864        // Different window aggregators may produce results at different rates.
865        // We produce the overall batch result only as fast as the slowest one.
866        let mut counts = vec![];
867        let out_col_counts = window_agg_states.iter().map(|window_agg_state| {
868            // Store how many elements are generated for the current
869            // window expression:
870            let mut cur_window_expr_out_result_len = 0;
871            // We iterate over `window_agg_state`, which is an IndexMap.
872            // Iterations follow the insertion order, hence we preserve
873            // sorting when partition columns are sorted.
874            let mut per_partition_out_results = HashMap::new();
875            for (row, WindowState { state, .. }) in window_agg_state.iter() {
876                cur_window_expr_out_result_len += state.out_col.len();
877                let count = per_partition_out_results.entry(row).or_insert(0);
878                if *count < state.out_col.len() {
879                    *count = state.out_col.len();
880                }
881                // If we do not generate all results for the current
882                // partition, we do not generate results for next
883                // partition --  otherwise we will lose input ordering.
884                if state.n_row_result_missing > 0 {
885                    break;
886                }
887            }
888            counts.push(per_partition_out_results);
889            cur_window_expr_out_result_len
890        });
891        argmin(out_col_counts).map_or(0, |(min_idx, minima)| {
892            let mut slowest_partition = counts.swap_remove(min_idx);
893            for (partition_key, partition_batch) in partition_buffers.iter_mut() {
894                if let Some(count) = slowest_partition.remove(partition_key) {
895                    partition_batch.n_out_row = count;
896                }
897            }
898            minima
899        })
900    }
901}
902
903/// Calculates partition by expression results for each window expression
904/// on `record_batch`.
905fn evaluate_partition_by_column_values(
906    record_batch: &RecordBatch,
907    window_expr: &[Arc<dyn WindowExpr>],
908) -> Result<Vec<ArrayRef>> {
909    window_expr[0]
910        .partition_by()
911        .iter()
912        .map(|item| match item.evaluate(record_batch)? {
913            ColumnarValue::Array(array) => Ok(array),
914            ColumnarValue::Scalar(scalar) => {
915                scalar.to_array_of_size(record_batch.num_rows())
916            }
917        })
918        .collect()
919}
920
/// Stream for the bounded window aggregation plan.
///
/// Buffers input rows until the results of all window expressions are
/// available for them, then emits those rows with the results appended as
/// extra columns.
pub struct BoundedWindowAggStream {
    /// Schema of the emitted batches: the buffered input columns followed by
    /// one column per window expression.
    schema: SchemaRef,
    /// Upstream stream providing the input batches.
    input: SendableRecordBatchStream,
    /// The record batch executor receives as input (i.e. the columns needed
    /// while calculating aggregation results).
    input_buffer: RecordBatch,
    /// We separate `input_buffer` based on partitions (as
    /// determined by PARTITION BY columns) and store them per partition
    /// in `partition_batches`. We use this variable when calculating results
    /// for each window expression. This enables us to use the same batch for
    /// different window expressions without copying.
    // Note that we could keep record batches for each window expression in
    // `PartitionWindowAggStates`. However, this would use more memory (as
    // many times as the number of window expressions).
    partition_buffers: PartitionBatches,
    /// An executor can run multiple window expressions if the PARTITION BY
    /// and ORDER BY sections are same. We keep state of the each window
    /// expression inside `window_agg_states`.
    window_agg_states: Vec<PartitionWindowAggStates>,
    /// Set once the input stream is exhausted; subsequent polls return `None`.
    finished: bool,
    /// Window expressions evaluated by this stream (one entry per element of
    /// `window_agg_states`).
    window_expr: Vec<Arc<dyn WindowExpr>>,
    /// Baseline metrics of this stream; its `elapsed_compute` timer measures
    /// the work done while polling.
    baseline_metrics: BaselineMetrics,
    /// Search mode for partition columns. This determines the algorithm with
    /// which we group each partition.
    search_mode: Box<dyn PartitionSearcher>,
}
948
impl BoundedWindowAggStream {
    /// Prunes sections of the state that are no longer needed when calculating
    /// results (as determined by window frame boundaries and number of results generated).
    // For instance, if the first `n` (not necessarily the same as `n_out`) elements are no
    // longer needed to calculate window expression results (they are outside the window
    // frame boundary), we retract the first `n` elements from the corresponding partition
    // in `self.partition_buffers`.
    // Likewise, if `n_out` rows have been emitted, we can remove the first `n_out` rows
    // from `self.input_buffer`.
    //
    // The call order below matters: `prune_out_columns` reads each partition's
    // `n_out_row`, which `prune_partition_batches` resets to 0.
    fn prune_state(&mut self, n_out: usize) -> Result<()> {
        // Prune `self.window_agg_states`:
        self.prune_out_columns();
        // Prune `self.partition_buffers`:
        self.prune_partition_batches();
        // Prune `self.input_buffer`:
        self.prune_input_batch(n_out)?;
        // Prune internal state of search algorithm.
        self.search_mode.prune(n_out);
        Ok(())
    }
}
969
970impl Stream for BoundedWindowAggStream {
971    type Item = Result<RecordBatch>;
972
973    fn poll_next(
974        mut self: Pin<&mut Self>,
975        cx: &mut Context<'_>,
976    ) -> Poll<Option<Self::Item>> {
977        let poll = self.poll_next_inner(cx);
978        self.baseline_metrics.record_poll(poll)
979    }
980}
981
982impl BoundedWindowAggStream {
983    /// Create a new BoundedWindowAggStream
984    fn new(
985        schema: SchemaRef,
986        window_expr: Vec<Arc<dyn WindowExpr>>,
987        input: SendableRecordBatchStream,
988        baseline_metrics: BaselineMetrics,
989        search_mode: Box<dyn PartitionSearcher>,
990    ) -> Result<Self> {
991        let state = window_expr.iter().map(|_| IndexMap::new()).collect();
992        let empty_batch = RecordBatch::new_empty(Arc::clone(&schema));
993        Ok(Self {
994            schema,
995            input,
996            input_buffer: empty_batch,
997            partition_buffers: IndexMap::new(),
998            window_agg_states: state,
999            finished: false,
1000            window_expr,
1001            baseline_metrics,
1002            search_mode,
1003        })
1004    }
1005
1006    fn compute_aggregates(&mut self) -> Result<Option<RecordBatch>> {
1007        // calculate window cols
1008        for (cur_window_expr, state) in
1009            self.window_expr.iter().zip(&mut self.window_agg_states)
1010        {
1011            cur_window_expr.evaluate_stateful(&self.partition_buffers, state)?;
1012        }
1013
1014        let schema = Arc::clone(&self.schema);
1015        let window_expr_out = self.search_mode.calculate_out_columns(
1016            &self.input_buffer,
1017            &self.window_agg_states,
1018            &mut self.partition_buffers,
1019            &self.window_expr,
1020        )?;
1021        if let Some(window_expr_out) = window_expr_out {
1022            let n_out = window_expr_out[0].len();
1023            // right append new columns to corresponding section in the original input buffer.
1024            let columns_to_show = self
1025                .input_buffer
1026                .columns()
1027                .iter()
1028                .map(|elem| elem.slice(0, n_out))
1029                .chain(window_expr_out)
1030                .collect::<Vec<_>>();
1031            let n_generated = columns_to_show[0].len();
1032            self.prune_state(n_generated)?;
1033            Ok(Some(RecordBatch::try_new(schema, columns_to_show)?))
1034        } else {
1035            Ok(None)
1036        }
1037    }
1038
1039    #[inline]
1040    fn poll_next_inner(
1041        &mut self,
1042        cx: &mut Context<'_>,
1043    ) -> Poll<Option<Result<RecordBatch>>> {
1044        if self.finished {
1045            return Poll::Ready(None);
1046        }
1047
1048        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
1049        match ready!(self.input.poll_next_unpin(cx)) {
1050            Some(Ok(batch)) => {
1051                // Start the timer for compute time within this operator. It will be
1052                // stopped when dropped.
1053                let _timer = elapsed_compute.timer();
1054
1055                self.search_mode.update_partition_batch(
1056                    &mut self.input_buffer,
1057                    batch,
1058                    &self.window_expr,
1059                    &mut self.partition_buffers,
1060                )?;
1061                if let Some(batch) = self.compute_aggregates()? {
1062                    return Poll::Ready(Some(Ok(batch)));
1063                }
1064                self.poll_next_inner(cx)
1065            }
1066            Some(Err(e)) => Poll::Ready(Some(Err(e))),
1067            None => {
1068                let _timer = elapsed_compute.timer();
1069
1070                self.finished = true;
1071                for (_, partition_batch_state) in self.partition_buffers.iter_mut() {
1072                    partition_batch_state.is_end = true;
1073                }
1074                if let Some(batch) = self.compute_aggregates()? {
1075                    return Poll::Ready(Some(Ok(batch)));
1076                }
1077                Poll::Ready(None)
1078            }
1079        }
1080    }
1081
1082    /// Prunes the sections of the record batch (for each partition)
1083    /// that we no longer need to calculate the window function result.
1084    fn prune_partition_batches(&mut self) {
1085        // Remove partitions which we know already ended (is_end flag is true).
1086        // Since the retain method preserves insertion order, we still have
1087        // ordering in between partitions after removal.
1088        self.partition_buffers
1089            .retain(|_, partition_batch_state| !partition_batch_state.is_end);
1090
1091        // The data in `self.partition_batches` is used by all window expressions.
1092        // Therefore, when removing from `self.partition_batches`, we need to remove
1093        // from the earliest range boundary among all window expressions. Variable
1094        // `n_prune_each_partition` fill the earliest range boundary information for
1095        // each partition. This way, we can delete the no-longer-needed sections from
1096        // `self.partition_batches`.
1097        // For instance, if window frame one uses [10, 20] and window frame two uses
1098        // [5, 15]; we only prune the first 5 elements from the corresponding record
1099        // batch in `self.partition_batches`.
1100
1101        // Calculate how many elements to prune for each partition batch
1102        let mut n_prune_each_partition = HashMap::new();
1103        for window_agg_state in self.window_agg_states.iter_mut() {
1104            window_agg_state.retain(|_, WindowState { state, .. }| !state.is_end);
1105            for (partition_row, WindowState { state: value, .. }) in window_agg_state {
1106                let n_prune =
1107                    min(value.window_frame_range.start, value.last_calculated_index);
1108                if let Some(current) = n_prune_each_partition.get_mut(partition_row) {
1109                    if n_prune < *current {
1110                        *current = n_prune;
1111                    }
1112                } else {
1113                    n_prune_each_partition.insert(partition_row.clone(), n_prune);
1114                }
1115            }
1116        }
1117
1118        // Retract no longer needed parts during window calculations from partition batch:
1119        for (partition_row, n_prune) in n_prune_each_partition.iter() {
1120            let pb_state = &mut self.partition_buffers[partition_row];
1121
1122            let batch = &pb_state.record_batch;
1123            pb_state.record_batch = batch.slice(*n_prune, batch.num_rows() - n_prune);
1124            pb_state.n_out_row = 0;
1125
1126            // Update state indices since we have pruned some rows from the beginning:
1127            for window_agg_state in self.window_agg_states.iter_mut() {
1128                window_agg_state[partition_row].state.prune_state(*n_prune);
1129            }
1130        }
1131    }
1132
1133    /// Prunes the section of the input batch whose aggregate results
1134    /// are calculated and emitted.
1135    fn prune_input_batch(&mut self, n_out: usize) -> Result<()> {
1136        // Prune first n_out rows from the input_buffer
1137        let n_to_keep = self.input_buffer.num_rows() - n_out;
1138        let batch_to_keep = self
1139            .input_buffer
1140            .columns()
1141            .iter()
1142            .map(|elem| elem.slice(n_out, n_to_keep))
1143            .collect::<Vec<_>>();
1144        self.input_buffer = RecordBatch::try_new_with_options(
1145            self.input_buffer.schema(),
1146            batch_to_keep,
1147            &RecordBatchOptions::new().with_row_count(Some(n_to_keep)),
1148        )?;
1149        Ok(())
1150    }
1151
1152    /// Prunes emitted parts from WindowAggState `out_col` field.
1153    fn prune_out_columns(&mut self) {
1154        // We store generated columns for each window expression in the `out_col`
1155        // field of `WindowAggState`. Given how many rows are emitted, we remove
1156        // these sections from state.
1157        for partition_window_agg_states in self.window_agg_states.iter_mut() {
1158            // Remove `n_out` entries from the `out_col` field of `WindowAggState`.
1159            // `n_out` is stored in `self.partition_buffers` for each partition.
1160            // If `is_end` is set, directly remove them; this shrinks the hash map.
1161            partition_window_agg_states
1162                .retain(|_, partition_batch_state| !partition_batch_state.state.is_end);
1163            for (
1164                partition_key,
1165                WindowState {
1166                    state: WindowAggState { out_col, .. },
1167                    ..
1168                },
1169            ) in partition_window_agg_states
1170            {
1171                let partition_batch = &mut self.partition_buffers[partition_key];
1172                let n_to_del = partition_batch.n_out_row;
1173                let n_to_keep = out_col.len() - n_to_del;
1174                *out_col = out_col.slice(n_to_del, n_to_keep);
1175            }
1176        }
1177    }
1178}
1179
impl RecordBatchStream for BoundedWindowAggStream {
    /// Get the schema
    ///
    /// This is the schema of the emitted batches: the buffered input columns
    /// followed by one column per window expression.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
1186
/// Returns the position and value of the smallest item of `data`, or `None`
/// when `data` is empty.
///
/// Ties resolve to the earliest item. When two items cannot be compared
/// (e.g. a NaN is involved) the current best is kept.
fn argmin<T: PartialOrd>(data: impl Iterator<Item = T>) -> Option<(usize, T)> {
    data.enumerate().reduce(|best, candidate| {
        // Only a strictly smaller candidate displaces the current best.
        if best.1.partial_cmp(&candidate.1) == Some(Ordering::Greater) {
            candidate
        } else {
            best
        }
    })
}
1192
1193/// Calculates the section we can show results for expression
1194fn get_aggregate_result_out_column(
1195    partition_window_agg_states: &PartitionWindowAggStates,
1196    len_to_show: usize,
1197) -> Result<ArrayRef> {
1198    let mut result = None;
1199    let mut running_length = 0;
1200    let mut batches_to_concat = vec![];
1201    // We assume that iteration order is according to insertion order
1202    for (
1203        _,
1204        WindowState {
1205            state: WindowAggState { out_col, .. },
1206            ..
1207        },
1208    ) in partition_window_agg_states
1209    {
1210        if running_length < len_to_show {
1211            let n_to_use = min(len_to_show - running_length, out_col.len());
1212            let slice_to_use = if n_to_use == out_col.len() {
1213                // avoid slice when the entire column is used
1214                Arc::clone(out_col)
1215            } else {
1216                out_col.slice(0, n_to_use)
1217            };
1218            batches_to_concat.push(slice_to_use);
1219            running_length += n_to_use;
1220        } else {
1221            break;
1222        }
1223    }
1224
1225    if !batches_to_concat.is_empty() {
1226        let array_refs: Vec<&dyn Array> =
1227            batches_to_concat.iter().map(|a| a.as_ref()).collect();
1228        result = Some(concat(&array_refs)?);
1229    }
1230
1231    if running_length != len_to_show {
1232        return exec_err!(
1233            "Generated row number should be {len_to_show}, it is {running_length}"
1234        );
1235    }
1236    result.ok_or_else(|| exec_datafusion_err!("Should contain something"))
1237}
1238
1239/// Constructs a batch from the last row of batch in the argument.
1240pub(crate) fn get_last_row_batch(batch: &RecordBatch) -> Result<RecordBatch> {
1241    if batch.num_rows() == 0 {
1242        return exec_err!("Latest batch should have at least 1 row");
1243    }
1244    Ok(batch.slice(batch.num_rows() - 1, 1))
1245}
1246
1247#[cfg(test)]
1248mod tests {
1249    use std::pin::Pin;
1250    use std::sync::Arc;
1251    use std::task::{Context, Poll};
1252    use std::time::Duration;
1253
1254    use crate::common::collect;
1255    use crate::expressions::PhysicalSortExpr;
1256    use crate::projection::{ProjectionExec, ProjectionExpr};
1257    use crate::streaming::{PartitionStream, StreamingTableExec};
1258    use crate::test::TestMemoryExec;
1259    use crate::windows::{
1260        BoundedWindowAggExec, InputOrderMode, create_udwf_window_expr, create_window_expr,
1261    };
1262    use crate::{ExecutionPlan, displayable, execute_stream};
1263
1264    use arrow::array::{
1265        RecordBatch,
1266        builder::{Int64Builder, UInt64Builder},
1267    };
1268    use arrow::compute::SortOptions;
1269    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1270    use datafusion_common::test_util::batches_to_string;
1271    use datafusion_common::{Result, ScalarValue, exec_datafusion_err};
1272    use datafusion_execution::config::SessionConfig;
1273    use datafusion_execution::{
1274        RecordBatchStream, SendableRecordBatchStream, TaskContext,
1275    };
1276    use datafusion_expr::{
1277        WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
1278    };
1279    use datafusion_functions_aggregate::count::count_udaf;
1280    use datafusion_functions_window::nth_value::last_value_udwf;
1281    use datafusion_functions_window::nth_value::nth_value_udwf;
1282    use datafusion_physical_expr::expressions::{Column, Literal, col};
1283    use datafusion_physical_expr::window::StandardWindowExpr;
1284    use datafusion_physical_expr::{LexOrdering, PhysicalExpr};
1285
1286    use futures::future::Shared;
1287    use futures::{FutureExt, Stream, StreamExt, pin_mut, ready};
1288    use insta::assert_snapshot;
1289    use itertools::Itertools;
1290    use tokio::time::timeout;
1291
1292    #[derive(Debug, Clone)]
1293    struct TestStreamPartition {
1294        schema: SchemaRef,
1295        batches: Vec<RecordBatch>,
1296        idx: usize,
1297        state: PolingState,
1298        sleep_duration: Duration,
1299        send_exit: bool,
1300    }
1301
1302    impl PartitionStream for TestStreamPartition {
1303        fn schema(&self) -> &SchemaRef {
1304            &self.schema
1305        }
1306
1307        fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
1308            // We create an iterator from the record batches and map them into Ok values,
1309            // converting the iterator into a futures::stream::Stream
1310            Box::pin(self.clone())
1311        }
1312    }
1313
1314    impl Stream for TestStreamPartition {
1315        type Item = Result<RecordBatch>;
1316
1317        fn poll_next(
1318            mut self: Pin<&mut Self>,
1319            cx: &mut Context<'_>,
1320        ) -> Poll<Option<Self::Item>> {
1321            self.poll_next_inner(cx)
1322        }
1323    }
1324
1325    #[derive(Debug, Clone)]
1326    enum PolingState {
1327        Sleep(Shared<futures::future::BoxFuture<'static, ()>>),
1328        BatchReturn,
1329    }
1330
1331    impl TestStreamPartition {
1332        fn poll_next_inner(
1333            self: &mut Pin<&mut Self>,
1334            cx: &mut Context<'_>,
1335        ) -> Poll<Option<Result<RecordBatch>>> {
1336            loop {
1337                match &mut self.state {
1338                    PolingState::BatchReturn => {
1339                        // Wait for self.sleep_duration before sending any new data
1340                        let f = tokio::time::sleep(self.sleep_duration).boxed().shared();
1341                        self.state = PolingState::Sleep(f);
1342                        let input_batch = if let Some(batch) =
1343                            self.batches.clone().get(self.idx)
1344                        {
1345                            batch.clone()
1346                        } else if self.send_exit {
1347                            // Send None to signal end of data
1348                            return Poll::Ready(None);
1349                        } else {
1350                            // Go to sleep mode
1351                            let f =
1352                                tokio::time::sleep(self.sleep_duration).boxed().shared();
1353                            self.state = PolingState::Sleep(f);
1354                            continue;
1355                        };
1356                        self.idx += 1;
1357                        return Poll::Ready(Some(Ok(input_batch)));
1358                    }
1359                    PolingState::Sleep(future) => {
1360                        pin_mut!(future);
1361                        ready!(future.poll_unpin(cx));
1362                        self.state = PolingState::BatchReturn;
1363                    }
1364                }
1365            }
1366        }
1367    }
1368
1369    impl RecordBatchStream for TestStreamPartition {
1370        fn schema(&self) -> SchemaRef {
1371            Arc::clone(&self.schema)
1372        }
1373    }
1374
1375    fn bounded_window_exec_pb_latent_range(
1376        input: Arc<dyn ExecutionPlan>,
1377        n_future_range: usize,
1378        hash: &str,
1379        order_by: &str,
1380    ) -> Result<Arc<dyn ExecutionPlan>> {
1381        let schema = input.schema();
1382        let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf());
1383        let col_expr =
1384            Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc<dyn PhysicalExpr>;
1385        let args = vec![col_expr];
1386        let partitionby_exprs = vec![col(hash, &schema)?];
1387        let orderby_exprs = vec![PhysicalSortExpr {
1388            expr: col(order_by, &schema)?,
1389            options: SortOptions::default(),
1390        }];
1391        let window_frame = WindowFrame::new_bounds(
1392            WindowFrameUnits::Range,
1393            WindowFrameBound::CurrentRow,
1394            WindowFrameBound::Following(ScalarValue::UInt64(Some(n_future_range as u64))),
1395        );
1396        let fn_name = format!(
1397            "{window_fn}({args:?}) PARTITION BY: [{partitionby_exprs:?}], ORDER BY: [{orderby_exprs:?}]"
1398        );
1399        let input_order_mode = InputOrderMode::Linear;
1400        Ok(Arc::new(BoundedWindowAggExec::try_new(
1401            vec![create_window_expr(
1402                &window_fn,
1403                fn_name,
1404                &args,
1405                &partitionby_exprs,
1406                &orderby_exprs,
1407                Arc::new(window_frame),
1408                input.schema(),
1409                false,
1410                false,
1411                None,
1412            )?],
1413            input,
1414            input_order_mode,
1415            true,
1416        )?))
1417    }
1418
1419    fn projection_exec(input: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
1420        let schema = input.schema();
1421        let exprs = input
1422            .schema()
1423            .fields
1424            .iter()
1425            .enumerate()
1426            .map(|(idx, field)| {
1427                let name = if field.name().len() > 20 {
1428                    format!("col_{idx}")
1429                } else {
1430                    field.name().clone()
1431                };
1432                let expr = col(field.name(), &schema).unwrap();
1433                (expr, name)
1434            })
1435            .collect::<Vec<_>>();
1436        let proj_exprs: Vec<ProjectionExpr> = exprs
1437            .into_iter()
1438            .map(|(expr, alias)| ProjectionExpr { expr, alias })
1439            .collect();
1440        Ok(Arc::new(ProjectionExec::try_new(proj_exprs, input)?))
1441    }
1442
1443    fn task_context_helper() -> TaskContext {
1444        let task_ctx = TaskContext::default();
1445        // Create session context with config
1446        let session_config = SessionConfig::new()
1447            .with_batch_size(1)
1448            .with_target_partitions(2)
1449            .with_round_robin_repartition(false);
1450        task_ctx.with_session_config(session_config)
1451    }
1452
1453    fn task_context() -> Arc<TaskContext> {
1454        Arc::new(task_context_helper())
1455    }
1456
1457    pub async fn collect_stream(
1458        mut stream: SendableRecordBatchStream,
1459        results: &mut Vec<RecordBatch>,
1460    ) -> Result<()> {
1461        while let Some(item) = stream.next().await {
1462            results.push(item?);
1463        }
1464        Ok(())
1465    }
1466
1467    /// Execute the [ExecutionPlan] and collect the results in memory
1468    pub async fn collect_with_timeout(
1469        plan: Arc<dyn ExecutionPlan>,
1470        context: Arc<TaskContext>,
1471        timeout_duration: Duration,
1472    ) -> Result<Vec<RecordBatch>> {
1473        let stream = execute_stream(plan, context)?;
1474        let mut results = vec![];
1475
1476        // Execute the asynchronous operation with a timeout
1477        if timeout(timeout_duration, collect_stream(stream, &mut results))
1478            .await
1479            .is_ok()
1480        {
1481            return Err(exec_datafusion_err!("shouldn't have completed"));
1482        };
1483
1484        Ok(results)
1485    }
1486
1487    fn test_schema() -> SchemaRef {
1488        Arc::new(Schema::new(vec![
1489            Field::new("sn", DataType::UInt64, true),
1490            Field::new("hash", DataType::Int64, true),
1491        ]))
1492    }
1493
1494    fn schema_orders(schema: &SchemaRef) -> Result<Vec<LexOrdering>> {
1495        let orderings = vec![
1496            [PhysicalSortExpr {
1497                expr: col("sn", schema)?,
1498                options: SortOptions {
1499                    descending: false,
1500                    nulls_first: false,
1501                },
1502            }]
1503            .into(),
1504        ];
1505        Ok(orderings)
1506    }
1507
1508    fn is_integer_division_safe(lhs: usize, rhs: usize) -> bool {
1509        let res = lhs / rhs;
1510        res * rhs == lhs
1511    }
1512    fn generate_batches(
1513        schema: &SchemaRef,
1514        n_row: usize,
1515        n_chunk: usize,
1516    ) -> Result<Vec<RecordBatch>> {
1517        let mut batches = vec![];
1518        assert!(n_row > 0);
1519        assert!(n_chunk > 0);
1520        assert!(is_integer_division_safe(n_row, n_chunk));
1521        let hash_replicate = 4;
1522
1523        let chunks = (0..n_row)
1524            .chunks(n_chunk)
1525            .into_iter()
1526            .map(|elem| elem.into_iter().collect::<Vec<_>>())
1527            .collect::<Vec<_>>();
1528
1529        // Send 2 RecordBatches at the source
1530        for sn_values in chunks {
1531            let mut sn1_array = UInt64Builder::with_capacity(sn_values.len());
1532            let mut hash_array = Int64Builder::with_capacity(sn_values.len());
1533
1534            for sn in sn_values {
1535                sn1_array.append_value(sn as u64);
1536                let hash_value = (2 - (sn / hash_replicate)) as i64;
1537                hash_array.append_value(hash_value);
1538            }
1539
1540            let batch = RecordBatch::try_new(
1541                Arc::clone(schema),
1542                vec![Arc::new(sn1_array.finish()), Arc::new(hash_array.finish())],
1543            )?;
1544            batches.push(batch);
1545        }
1546        Ok(batches)
1547    }
1548
1549    fn generate_never_ending_source(
1550        n_rows: usize,
1551        chunk_length: usize,
1552        n_partition: usize,
1553        is_infinite: bool,
1554        send_exit: bool,
1555        per_batch_wait_duration_in_millis: u64,
1556    ) -> Result<Arc<dyn ExecutionPlan>> {
1557        assert!(n_partition > 0);
1558
1559        // We use same hash value in the table. This makes sure that
1560        // After hashing computation will continue in only in one of the output partitions
1561        // In this case, data flow should still continue
1562        let schema = test_schema();
1563        let orderings = schema_orders(&schema)?;
1564
1565        // Source waits per_batch_wait_duration_in_millis ms before sending other batch
1566        let per_batch_wait_duration =
1567            Duration::from_millis(per_batch_wait_duration_in_millis);
1568
1569        let batches = generate_batches(&schema, n_rows, chunk_length)?;
1570
1571        // Source has 2 partitions
1572        let partitions = vec![
1573            Arc::new(TestStreamPartition {
1574                schema: Arc::clone(&schema),
1575                batches,
1576                idx: 0,
1577                state: PolingState::BatchReturn,
1578                sleep_duration: per_batch_wait_duration,
1579                send_exit,
1580            }) as _;
1581            n_partition
1582        ];
1583        let source = Arc::new(StreamingTableExec::try_new(
1584            Arc::clone(&schema),
1585            partitions,
1586            None,
1587            orderings,
1588            is_infinite,
1589            None,
1590        )?) as _;
1591        Ok(source)
1592    }
1593
1594    // Tests NTH_VALUE(negative index) with memoize feature
1595    // To be able to trigger memoize feature for NTH_VALUE we need to
1596    // - feed BoundedWindowAggExec with batch stream data.
1597    // - Window frame should contain UNBOUNDED PRECEDING.
1598    // It hard to ensure these conditions are met, from the sql query.
1599    #[tokio::test]
1600    async fn test_window_nth_value_bounded_memoize() -> Result<()> {
1601        let config = SessionConfig::new().with_target_partitions(1);
1602        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
1603
1604        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1605        // Create a new batch of data to insert into the table
1606        let batch = RecordBatch::try_new(
1607            Arc::clone(&schema),
1608            vec![Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3]))],
1609        )?;
1610
1611        let memory_exec = TestMemoryExec::try_new_exec(
1612            &[vec![batch.clone(), batch.clone(), batch.clone()]],
1613            Arc::clone(&schema),
1614            None,
1615        )?;
1616        let col_a = col("a", &schema)?;
1617        let nth_value_func1 = create_udwf_window_expr(
1618            &nth_value_udwf(),
1619            &[
1620                Arc::clone(&col_a),
1621                Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
1622            ],
1623            &schema,
1624            "nth_value(-1)".to_string(),
1625            false,
1626        )?
1627        .reverse_expr()
1628        .unwrap();
1629        let nth_value_func2 = create_udwf_window_expr(
1630            &nth_value_udwf(),
1631            &[
1632                Arc::clone(&col_a),
1633                Arc::new(Literal::new(ScalarValue::Int32(Some(2)))),
1634            ],
1635            &schema,
1636            "nth_value(-2)".to_string(),
1637            false,
1638        )?
1639        .reverse_expr()
1640        .unwrap();
1641
1642        let last_value_func = create_udwf_window_expr(
1643            &last_value_udwf(),
1644            &[Arc::clone(&col_a)],
1645            &schema,
1646            "last".to_string(),
1647            false,
1648        )?;
1649
1650        let window_exprs = vec![
1651            // LAST_VALUE(a)
1652            Arc::new(StandardWindowExpr::new(
1653                last_value_func,
1654                &[],
1655                &[],
1656                Arc::new(WindowFrame::new_bounds(
1657                    WindowFrameUnits::Rows,
1658                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1659                    WindowFrameBound::CurrentRow,
1660                )),
1661            )) as _,
1662            // NTH_VALUE(a, -1)
1663            Arc::new(StandardWindowExpr::new(
1664                nth_value_func1,
1665                &[],
1666                &[],
1667                Arc::new(WindowFrame::new_bounds(
1668                    WindowFrameUnits::Rows,
1669                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1670                    WindowFrameBound::CurrentRow,
1671                )),
1672            )) as _,
1673            // NTH_VALUE(a, -2)
1674            Arc::new(StandardWindowExpr::new(
1675                nth_value_func2,
1676                &[],
1677                &[],
1678                Arc::new(WindowFrame::new_bounds(
1679                    WindowFrameUnits::Rows,
1680                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1681                    WindowFrameBound::CurrentRow,
1682                )),
1683            )) as _,
1684        ];
1685        let physical_plan = BoundedWindowAggExec::try_new(
1686            window_exprs,
1687            memory_exec,
1688            InputOrderMode::Sorted,
1689            true,
1690        )
1691        .map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)?;
1692
1693        let batches = collect(physical_plan.execute(0, task_ctx)?).await?;
1694
1695        // Get string representation of the plan
1696        assert_snapshot!(displayable(physical_plan.as_ref()).indent(true), @r#"
1697        BoundedWindowAggExec: wdw=[last: Field { "last": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, nth_value(-1): Field { "nth_value(-1)": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, nth_value(-2): Field { "nth_value(-2)": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
1698          DataSourceExec: partitions=1, partition_sizes=[3]
1699        "#);
1700
1701        assert_snapshot!(batches_to_string(&batches), @r"
1702        +---+------+---------------+---------------+
1703        | a | last | nth_value(-1) | nth_value(-2) |
1704        +---+------+---------------+---------------+
1705        | 1 | 1    | 1             |               |
1706        | 2 | 2    | 2             | 1             |
1707        | 3 | 3    | 3             | 2             |
1708        | 1 | 1    | 1             | 3             |
1709        | 2 | 2    | 2             | 1             |
1710        | 3 | 3    | 3             | 2             |
1711        | 1 | 1    | 1             | 3             |
1712        | 2 | 2    | 2             | 1             |
1713        | 3 | 3    | 3             | 2             |
1714        +---+------+---------------+---------------+
1715        ");
1716        Ok(())
1717    }
1718
1719    // This test, tests whether most recent row guarantee by the input batch of the `BoundedWindowAggExec`
1720    // helps `BoundedWindowAggExec` to generate low latency result in the `Linear` mode.
1721    // Input data generated at the source is
1722    //       "+----+------+",
1723    //       "| sn | hash |",
1724    //       "+----+------+",
1725    //       "| 0  | 2    |",
1726    //       "| 1  | 2    |",
1727    //       "| 2  | 2    |",
1728    //       "| 3  | 2    |",
1729    //       "| 4  | 1    |",
1730    //       "| 5  | 1    |",
1731    //       "| 6  | 1    |",
1732    //       "| 7  | 1    |",
1733    //       "| 8  | 0    |",
1734    //       "| 9  | 0    |",
1735    //       "+----+------+",
1736    //
1737    // Effectively following query is run on this data
1738    //
1739    //   SELECT *, count(*) OVER(PARTITION BY duplicated_hash ORDER BY sn RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)
1740    //   FROM test;
1741    //
1742    // partition `duplicated_hash=2` receives following data from the input
1743    //
1744    //       "+----+------+",
1745    //       "| sn | hash |",
1746    //       "+----+------+",
1747    //       "| 0  | 2    |",
1748    //       "| 1  | 2    |",
1749    //       "| 2  | 2    |",
1750    //       "| 3  | 2    |",
1751    //       "+----+------+",
1752    // normally `BoundedWindowExec` can only generate following result from the input above
1753    //
1754    //       "+----+------+---------+",
1755    //       "| sn | hash |  count  |",
1756    //       "+----+------+---------+",
1757    //       "| 0  | 2    |  2      |",
1758    //       "| 1  | 2    |  2      |",
1759    //       "| 2  | 2    |<not yet>|",
1760    //       "| 3  | 2    |<not yet>|",
1761    //       "+----+------+---------+",
1762    // where result of last 2 row is missing. Since window frame end is not may change with future data
1763    // since window frame end is determined by 1 following (To generate result for row=3[where sn=2] we
1764    // need to received sn=4 to make sure window frame end bound won't change with future data).
1765    //
1766    // With the ability of different partitions to use global ordering at the input (where most up-to date
1767    //   row is
1768    //      "| 9  | 0    |",
1769    //   )
1770    //
1771    // `BoundedWindowExec` should be able to generate following result in the test
1772    //
1773    //       "+----+------+-------+",
1774    //       "| sn | hash | col_2 |",
1775    //       "+----+------+-------+",
1776    //       "| 0  | 2    | 2     |",
1777    //       "| 1  | 2    | 2     |",
1778    //       "| 2  | 2    | 2     |",
1779    //       "| 3  | 2    | 1     |",
1780    //       "| 4  | 1    | 2     |",
1781    //       "| 5  | 1    | 2     |",
1782    //       "| 6  | 1    | 2     |",
1783    //       "| 7  | 1    | 1     |",
1784    //       "+----+------+-------+",
1785    //
1786    // where result for all rows except last 2 is calculated (To calculate result for row 9 where sn=8
1787    //   we need to receive sn=10 value to calculate it result.).
1788    // In this test, out aim is to test for which portion of the input data `BoundedWindowExec` can generate
1789    // a result. To test this behaviour, we generated the data at the source infinitely (no `None` signal
1790    //    is sent to output from source). After, row:
1791    //
1792    //       "| 9  | 0    |",
1793    //
1794    // is sent. Source stops sending data to output. We collect, result emitted by the `BoundedWindowExec` at the
1795    // end of the pipeline with a timeout (Since no `None` is sent from source. Collection never ends otherwise).
1796    #[tokio::test]
1797    async fn bounded_window_exec_linear_mode_range_information() -> Result<()> {
1798        let n_rows = 10;
1799        let chunk_length = 2;
1800        let n_future_range = 1;
1801
1802        let timeout_duration = Duration::from_millis(2000);
1803
1804        let source =
1805            generate_never_ending_source(n_rows, chunk_length, 1, true, false, 5)?;
1806
1807        let window =
1808            bounded_window_exec_pb_latent_range(source, n_future_range, "hash", "sn")?;
1809
1810        let plan = projection_exec(window)?;
1811
1812        // Get string representation of the plan
1813        assert_snapshot!(displayable(plan.as_ref()).indent(true), @r#"
1814        ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: "sn", index: 0 }]) PARTITION BY: [[Column { name: "hash", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: "sn", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]@2 as col_2]
1815          BoundedWindowAggExec: wdw=[count([Column { name: "sn", index: 0 }]) PARTITION BY: [[Column { name: "hash", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: "sn", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]: Field { "count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]": Int64 }, frame: RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING], mode=[Linear]
1816            StreamingTableExec: partition_sizes=1, projection=[sn, hash], infinite_source=true, output_ordering=[sn@0 ASC NULLS LAST]
1817        "#);
1818
1819        let task_ctx = task_context();
1820        let batches = collect_with_timeout(plan, task_ctx, timeout_duration).await?;
1821
1822        assert_snapshot!(batches_to_string(&batches), @r"
1823        +----+------+-------+
1824        | sn | hash | col_2 |
1825        +----+------+-------+
1826        | 0  | 2    | 2     |
1827        | 1  | 2    | 2     |
1828        | 2  | 2    | 2     |
1829        | 3  | 2    | 1     |
1830        | 4  | 1    | 2     |
1831        | 5  | 1    | 2     |
1832        | 6  | 1    | 2     |
1833        | 7  | 1    | 1     |
1834        +----+------+-------+
1835        ");
1836
1837        Ok(())
1838    }
1839}