// datafusion_physical_plan/windows/bounded_window_agg_exec.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Stream and channel implementations for window function expressions.
//! The executor given here uses bounded memory (it does not keep all of
//! the input data seen so far), which makes it suitable for processing
//! infinite inputs.

23use std::any::Any;
24use std::cmp::{min, Ordering};
25use std::collections::VecDeque;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::task::{Context, Poll};
29
30use super::utils::create_schema;
31use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
32use crate::windows::{
33    calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs,
34    window_equivalence_properties,
35};
36use crate::{
37    ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
38    ExecutionPlanProperties, InputOrderMode, PlanProperties, RecordBatchStream,
39    SendableRecordBatchStream, Statistics, WindowExpr,
40};
41
42use arrow::compute::take_record_batch;
43use arrow::{
44    array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder},
45    compute::{concat, concat_batches, sort_to_indices, take_arrays},
46    datatypes::SchemaRef,
47    record_batch::RecordBatch,
48};
49use datafusion_common::hash_utils::create_hashes;
50use datafusion_common::stats::Precision;
51use datafusion_common::utils::{
52    evaluate_partition_ranges, get_at_indices, get_row_at_idx,
53};
54use datafusion_common::{
55    arrow_datafusion_err, exec_err, DataFusionError, HashMap, Result,
56};
57use datafusion_execution::TaskContext;
58use datafusion_expr::window_state::{PartitionBatchState, WindowAggState};
59use datafusion_expr::ColumnarValue;
60use datafusion_physical_expr::window::{
61    PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState,
62};
63use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
64use datafusion_physical_expr_common::sort_expr::{
65    OrderingRequirements, PhysicalSortExpr,
66};
67
68use ahash::RandomState;
69use futures::stream::Stream;
70use futures::{ready, StreamExt};
71use hashbrown::hash_table::HashTable;
72use indexmap::IndexMap;
73use log::debug;
74
75/// Window execution plan
76#[derive(Debug, Clone)]
77pub struct BoundedWindowAggExec {
78    /// Input plan
79    input: Arc<dyn ExecutionPlan>,
80    /// Window function expression
81    window_expr: Vec<Arc<dyn WindowExpr>>,
82    /// Schema after the window is run
83    schema: SchemaRef,
84    /// Execution metrics
85    metrics: ExecutionPlanMetricsSet,
86    /// Describes how the input is ordered relative to the partition keys
87    pub input_order_mode: InputOrderMode,
88    /// Partition by indices that define ordering
89    // For example, if input ordering is ORDER BY a, b and window expression
90    // contains PARTITION BY b, a; `ordered_partition_by_indices` would be 1, 0.
91    // Similarly, if window expression contains PARTITION BY a, b; then
92    // `ordered_partition_by_indices` would be 0, 1.
93    // See `get_ordered_partition_by_indices` for more details.
94    ordered_partition_by_indices: Vec<usize>,
95    /// Cache holding plan properties like equivalences, output partitioning etc.
96    cache: PlanProperties,
97    /// If `can_rerepartition` is false, partition_keys is always empty.
98    can_repartition: bool,
99}
100
101impl BoundedWindowAggExec {
102    /// Create a new execution plan for window aggregates
103    pub fn try_new(
104        window_expr: Vec<Arc<dyn WindowExpr>>,
105        input: Arc<dyn ExecutionPlan>,
106        input_order_mode: InputOrderMode,
107        can_repartition: bool,
108    ) -> Result<Self> {
109        let schema = create_schema(&input.schema(), &window_expr)?;
110        let schema = Arc::new(schema);
111        let partition_by_exprs = window_expr[0].partition_by();
112        let ordered_partition_by_indices = match &input_order_mode {
113            InputOrderMode::Sorted => {
114                let indices = get_ordered_partition_by_indices(
115                    window_expr[0].partition_by(),
116                    &input,
117                )?;
118                if indices.len() == partition_by_exprs.len() {
119                    indices
120                } else {
121                    (0..partition_by_exprs.len()).collect::<Vec<_>>()
122                }
123            }
124            InputOrderMode::PartiallySorted(ordered_indices) => ordered_indices.clone(),
125            InputOrderMode::Linear => {
126                vec![]
127            }
128        };
129        let cache = Self::compute_properties(&input, &schema, &window_expr)?;
130        Ok(Self {
131            input,
132            window_expr,
133            schema,
134            metrics: ExecutionPlanMetricsSet::new(),
135            input_order_mode,
136            ordered_partition_by_indices,
137            cache,
138            can_repartition,
139        })
140    }
141
142    /// Window expressions
143    pub fn window_expr(&self) -> &[Arc<dyn WindowExpr>] {
144        &self.window_expr
145    }
146
147    /// Input plan
148    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
149        &self.input
150    }
151
152    /// Return the output sort order of partition keys: For example
153    /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a
154    // We are sure that partition by columns are always at the beginning of sort_keys
155    // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely
156    // to calculate partition separation points
157    pub fn partition_by_sort_keys(&self) -> Result<Vec<PhysicalSortExpr>> {
158        let partition_by = self.window_expr()[0].partition_by();
159        get_partition_by_sort_exprs(
160            &self.input,
161            partition_by,
162            &self.ordered_partition_by_indices,
163        )
164    }
165
166    /// Initializes the appropriate [`PartitionSearcher`] implementation from
167    /// the state.
168    fn get_search_algo(&self) -> Result<Box<dyn PartitionSearcher>> {
169        let partition_by_sort_keys = self.partition_by_sort_keys()?;
170        let ordered_partition_by_indices = self.ordered_partition_by_indices.clone();
171        let input_schema = self.input().schema();
172        Ok(match &self.input_order_mode {
173            InputOrderMode::Sorted => {
174                // In Sorted mode, all partition by columns should be ordered.
175                if self.window_expr()[0].partition_by().len()
176                    != ordered_partition_by_indices.len()
177                {
178                    return exec_err!("All partition by columns should have an ordering in Sorted mode.");
179                }
180                Box::new(SortedSearch {
181                    partition_by_sort_keys,
182                    ordered_partition_by_indices,
183                    input_schema,
184                })
185            }
186            InputOrderMode::Linear | InputOrderMode::PartiallySorted(_) => Box::new(
187                LinearSearch::new(ordered_partition_by_indices, input_schema),
188            ),
189        })
190    }
191
192    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
193    fn compute_properties(
194        input: &Arc<dyn ExecutionPlan>,
195        schema: &SchemaRef,
196        window_exprs: &[Arc<dyn WindowExpr>],
197    ) -> Result<PlanProperties> {
198        // Calculate equivalence properties:
199        let eq_properties = window_equivalence_properties(schema, input, window_exprs)?;
200
201        // As we can have repartitioning using the partition keys, this can
202        // be either one or more than one, depending on the presence of
203        // repartitioning.
204        let output_partitioning = input.output_partitioning().clone();
205
206        // Construct properties cache
207        Ok(PlanProperties::new(
208            eq_properties,
209            output_partitioning,
210            // TODO: Emission type and boundedness information can be enhanced here
211            input.pipeline_behavior(),
212            input.boundedness(),
213        ))
214    }
215
216    pub fn partition_keys(&self) -> Vec<Arc<dyn PhysicalExpr>> {
217        if !self.can_repartition {
218            vec![]
219        } else {
220            let all_partition_keys = self
221                .window_expr()
222                .iter()
223                .map(|expr| expr.partition_by().to_vec())
224                .collect::<Vec<_>>();
225
226            all_partition_keys
227                .into_iter()
228                .min_by_key(|s| s.len())
229                .unwrap_or_else(Vec::new)
230        }
231    }
232
233    fn statistics_helper(&self, statistics: Statistics) -> Result<Statistics> {
234        let win_cols = self.window_expr.len();
235        let input_cols = self.input.schema().fields().len();
236        // TODO stats: some windowing function will maintain invariants such as min, max...
237        let mut column_statistics = Vec::with_capacity(win_cols + input_cols);
238        // copy stats of the input to the beginning of the schema.
239        column_statistics.extend(statistics.column_statistics);
240        for _ in 0..win_cols {
241            column_statistics.push(ColumnStatistics::new_unknown())
242        }
243        Ok(Statistics {
244            num_rows: statistics.num_rows,
245            column_statistics,
246            total_byte_size: Precision::Absent,
247        })
248    }
249}
250
251impl DisplayAs for BoundedWindowAggExec {
252    fn fmt_as(
253        &self,
254        t: DisplayFormatType,
255        f: &mut std::fmt::Formatter,
256    ) -> std::fmt::Result {
257        match t {
258            DisplayFormatType::Default | DisplayFormatType::Verbose => {
259                write!(f, "BoundedWindowAggExec: ")?;
260                let g: Vec<String> = self
261                    .window_expr
262                    .iter()
263                    .map(|e| {
264                        let field = match e.field() {
265                            Ok(f) => f.to_string(),
266                            Err(e) => format!("{e:?}"),
267                        };
268                        format!(
269                            "{}: {}, frame: {}",
270                            e.name().to_owned(),
271                            field,
272                            e.get_window_frame()
273                        )
274                    })
275                    .collect();
276                let mode = &self.input_order_mode;
277                write!(f, "wdw=[{}], mode=[{:?}]", g.join(", "), mode)?;
278            }
279            DisplayFormatType::TreeRender => {
280                let g: Vec<String> = self
281                    .window_expr
282                    .iter()
283                    .map(|e| e.name().to_owned().to_string())
284                    .collect();
285                writeln!(f, "select_list={}", g.join(", "))?;
286
287                let mode = &self.input_order_mode;
288                writeln!(f, "mode={mode:?}")?;
289            }
290        }
291        Ok(())
292    }
293}
294
295impl ExecutionPlan for BoundedWindowAggExec {
296    fn name(&self) -> &'static str {
297        "BoundedWindowAggExec"
298    }
299
300    /// Return a reference to Any that can be used for downcasting
301    fn as_any(&self) -> &dyn Any {
302        self
303    }
304
305    fn properties(&self) -> &PlanProperties {
306        &self.cache
307    }
308
309    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
310        vec![&self.input]
311    }
312
313    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
314        let partition_bys = self.window_expr()[0].partition_by();
315        let order_keys = self.window_expr()[0].order_by();
316        let partition_bys = self
317            .ordered_partition_by_indices
318            .iter()
319            .map(|idx| &partition_bys[*idx]);
320        vec![calc_requirements(partition_bys, order_keys)]
321    }
322
323    fn required_input_distribution(&self) -> Vec<Distribution> {
324        if self.partition_keys().is_empty() {
325            debug!("No partition defined for BoundedWindowAggExec!!!");
326            vec![Distribution::SinglePartition]
327        } else {
328            vec![Distribution::HashPartitioned(self.partition_keys().clone())]
329        }
330    }
331
332    fn maintains_input_order(&self) -> Vec<bool> {
333        vec![true]
334    }
335
336    fn with_new_children(
337        self: Arc<Self>,
338        children: Vec<Arc<dyn ExecutionPlan>>,
339    ) -> Result<Arc<dyn ExecutionPlan>> {
340        Ok(Arc::new(BoundedWindowAggExec::try_new(
341            self.window_expr.clone(),
342            Arc::clone(&children[0]),
343            self.input_order_mode.clone(),
344            self.can_repartition,
345        )?))
346    }
347
348    fn execute(
349        &self,
350        partition: usize,
351        context: Arc<TaskContext>,
352    ) -> Result<SendableRecordBatchStream> {
353        let input = self.input.execute(partition, context)?;
354        let search_mode = self.get_search_algo()?;
355        let stream = Box::pin(BoundedWindowAggStream::new(
356            Arc::clone(&self.schema),
357            self.window_expr.clone(),
358            input,
359            BaselineMetrics::new(&self.metrics, partition),
360            search_mode,
361        )?);
362        Ok(stream)
363    }
364
365    fn metrics(&self) -> Option<MetricsSet> {
366        Some(self.metrics.clone_inner())
367    }
368
369    fn statistics(&self) -> Result<Statistics> {
370        self.partition_statistics(None)
371    }
372
373    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
374        let input_stat = self.input.partition_statistics(partition)?;
375        self.statistics_helper(input_stat)
376    }
377}
378
379/// Trait that specifies how we search for (or calculate) partitions. It has two
380/// implementations: [`SortedSearch`] and [`LinearSearch`].
381trait PartitionSearcher: Send {
382    /// This method constructs output columns using the result of each window expression
383    /// (each entry in the output vector comes from a window expression).
384    /// Executor when producing output concatenates `input_buffer` (corresponding section), and
385    /// result of this function to generate output `RecordBatch`. `input_buffer` is used to determine
386    /// which sections of the window expression results should be used to generate output.
387    /// `partition_buffers` contains corresponding section of the `RecordBatch` for each partition.
388    /// `window_agg_states` stores per partition state for each window expression.
389    /// None case means that no result is generated
390    /// `Some(Vec<ArrayRef>)` is the result of each window expression.
391    fn calculate_out_columns(
392        &mut self,
393        input_buffer: &RecordBatch,
394        window_agg_states: &[PartitionWindowAggStates],
395        partition_buffers: &mut PartitionBatches,
396        window_expr: &[Arc<dyn WindowExpr>],
397    ) -> Result<Option<Vec<ArrayRef>>>;
398
399    /// Determine whether `[InputOrderMode]` is `[InputOrderMode::Linear]` or not.
400    fn is_mode_linear(&self) -> bool {
401        false
402    }
403
404    // Constructs corresponding batches for each partition for the record_batch.
405    fn evaluate_partition_batches(
406        &mut self,
407        record_batch: &RecordBatch,
408        window_expr: &[Arc<dyn WindowExpr>],
409    ) -> Result<Vec<(PartitionKey, RecordBatch)>>;
410
411    /// Prunes the state.
412    fn prune(&mut self, _n_out: usize) {}
413
414    /// Marks the partition as done if we are sure that corresponding partition
415    /// cannot receive any more values.
416    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches);
417
418    /// Updates `input_buffer` and `partition_buffers` with the new `record_batch`.
419    fn update_partition_batch(
420        &mut self,
421        input_buffer: &mut RecordBatch,
422        record_batch: RecordBatch,
423        window_expr: &[Arc<dyn WindowExpr>],
424        partition_buffers: &mut PartitionBatches,
425    ) -> Result<()> {
426        if record_batch.num_rows() == 0 {
427            return Ok(());
428        }
429        let partition_batches =
430            self.evaluate_partition_batches(&record_batch, window_expr)?;
431        for (partition_row, partition_batch) in partition_batches {
432            let partition_batch_state = partition_buffers
433                .entry(partition_row)
434                // Use input_schema for the buffer schema, not `record_batch.schema()`
435                // as it may not have the "correct" schema in terms of output
436                // nullability constraints. For details, see the following issue:
437                // https://github.com/apache/datafusion/issues/9320
438                .or_insert_with(|| {
439                    PartitionBatchState::new(Arc::clone(self.input_schema()))
440                });
441            partition_batch_state.extend(&partition_batch)?;
442        }
443
444        if self.is_mode_linear() {
445            // In `Linear` mode, it is guaranteed that the first ORDER BY column
446            // is sorted across partitions. Note that only the first ORDER BY
447            // column is guaranteed to be ordered. As a counter example, consider
448            // the case, `PARTITION BY b, ORDER BY a, c` when the input is sorted
449            // by `[a, b, c]`. In this case, `BoundedWindowAggExec` mode will be
450            // `Linear`. However, we cannot guarantee that the last row of the
451            // input data will be the "last" data in terms of the ordering requirement
452            // `[a, c]` -- it will be the "last" data in terms of `[a, b, c]`.
453            // Hence, only column `a` should be used as a guarantee of the "last"
454            // data across partitions. For other modes (`Sorted`, `PartiallySorted`),
455            // we do not need to keep track of the most recent row guarantee across
456            // partitions. Since leading ordering separates partitions, guaranteed
457            // by the most recent row, already prune the previous partitions completely.
458            let last_row = get_last_row_batch(&record_batch)?;
459            for (_, partition_batch) in partition_buffers.iter_mut() {
460                partition_batch.set_most_recent_row(last_row.clone());
461            }
462        }
463        self.mark_partition_end(partition_buffers);
464
465        *input_buffer = if input_buffer.num_rows() == 0 {
466            record_batch
467        } else {
468            concat_batches(self.input_schema(), [input_buffer, &record_batch])?
469        };
470
471        Ok(())
472    }
473
474    fn input_schema(&self) -> &SchemaRef;
475}
476
477/// This object encapsulates the algorithm state for a simple linear scan
478/// algorithm for computing partitions.
479pub struct LinearSearch {
480    /// Keeps the hash of input buffer calculated from PARTITION BY columns.
481    /// Its length is equal to the `input_buffer` length.
482    input_buffer_hashes: VecDeque<u64>,
483    /// Used during hash value calculation.
484    random_state: RandomState,
485    /// Input ordering and partition by key ordering need not be the same, so
486    /// this vector stores the mapping between them. For instance, if the input
487    /// is ordered by a, b and the window expression contains a PARTITION BY b, a
488    /// clause, this attribute stores [1, 0].
489    ordered_partition_by_indices: Vec<usize>,
490    /// We use this [`HashTable`] to calculate unique partitions for each new
491    /// RecordBatch. First entry in the tuple is the hash value, the second
492    /// entry is the unique ID for each partition (increments from 0 to n).
493    row_map_batch: HashTable<(u64, usize)>,
494    /// We use this [`HashTable`] to calculate the output columns that we can
495    /// produce at each cycle. First entry in the tuple is the hash value, the
496    /// second entry is the unique ID for each partition (increments from 0 to n).
497    /// The third entry stores how many new outputs are calculated for the
498    /// corresponding partition.
499    row_map_out: HashTable<(u64, usize, usize)>,
500    input_schema: SchemaRef,
501}
502
503impl PartitionSearcher for LinearSearch {
504    /// This method constructs output columns using the result of each window expression.
505    // Assume input buffer is         |      Partition Buffers would be (Where each partition and its data is separated)
506    // a, 2                           |      a, 2
507    // b, 2                           |      a, 2
508    // a, 2                           |      a, 2
509    // b, 2                           |
510    // a, 2                           |      b, 2
511    // b, 2                           |      b, 2
512    // b, 2                           |      b, 2
513    //                                |      b, 2
514    // Also assume we happen to calculate 2 new values for a, and 3 for b (To be calculate missing values we may need to consider future values).
515    // Partition buffers effectively will be
516    // a, 2, 1
517    // a, 2, 2
518    // a, 2, (missing)
519    //
520    // b, 2, 1
521    // b, 2, 2
522    // b, 2, 3
523    // b, 2, (missing)
524    // When partition buffers are mapped back to the original record batch. Result becomes
525    // a, 2, 1
526    // b, 2, 1
527    // a, 2, 2
528    // b, 2, 2
529    // a, 2, (missing)
530    // b, 2, 3
531    // b, 2, (missing)
532    // This function calculates the column result of window expression(s) (First 4 entry of 3rd column in the above section.)
533    // 1
534    // 1
535    // 2
536    // 2
537    // Above section corresponds to calculated result which can be emitted without breaking input buffer ordering.
538    fn calculate_out_columns(
539        &mut self,
540        input_buffer: &RecordBatch,
541        window_agg_states: &[PartitionWindowAggStates],
542        partition_buffers: &mut PartitionBatches,
543        window_expr: &[Arc<dyn WindowExpr>],
544    ) -> Result<Option<Vec<ArrayRef>>> {
545        let partition_output_indices = self.calc_partition_output_indices(
546            input_buffer,
547            window_agg_states,
548            window_expr,
549        )?;
550
551        let n_window_col = window_agg_states.len();
552        let mut new_columns = vec![vec![]; n_window_col];
553        // Size of all_indices can be at most input_buffer.num_rows():
554        let mut all_indices = UInt32Builder::with_capacity(input_buffer.num_rows());
555        for (row, indices) in partition_output_indices {
556            let length = indices.len();
557            for (idx, window_agg_state) in window_agg_states.iter().enumerate() {
558                let partition = &window_agg_state[&row];
559                let values = Arc::clone(&partition.state.out_col.slice(0, length));
560                new_columns[idx].push(values);
561            }
562            let partition_batch_state = &mut partition_buffers[&row];
563            // Store how many rows are generated for each partition
564            partition_batch_state.n_out_row = length;
565            // For each row keep corresponding index in the input record batch
566            all_indices.append_slice(&indices);
567        }
568        let all_indices = all_indices.finish();
569        if all_indices.is_empty() {
570            // We couldn't generate any new value, return early:
571            return Ok(None);
572        }
573
574        // Concatenate results for each column by converting `Vec<Vec<ArrayRef>>`
575        // to Vec<ArrayRef> where inner `Vec<ArrayRef>`s are converted to `ArrayRef`s.
576        let new_columns = new_columns
577            .iter()
578            .map(|items| {
579                concat(&items.iter().map(|e| e.as_ref()).collect::<Vec<_>>())
580                    .map_err(|e| arrow_datafusion_err!(e))
581            })
582            .collect::<Result<Vec<_>>>()?;
583        // We should emit columns according to row index ordering.
584        let sorted_indices = sort_to_indices(&all_indices, None, None)?;
585        // Construct new column according to row ordering. This fixes ordering
586        take_arrays(&new_columns, &sorted_indices, None)
587            .map(Some)
588            .map_err(|e| arrow_datafusion_err!(e))
589    }
590
591    fn evaluate_partition_batches(
592        &mut self,
593        record_batch: &RecordBatch,
594        window_expr: &[Arc<dyn WindowExpr>],
595    ) -> Result<Vec<(PartitionKey, RecordBatch)>> {
596        let partition_bys =
597            evaluate_partition_by_column_values(record_batch, window_expr)?;
598        // NOTE: In Linear or PartiallySorted modes, we are sure that
599        //       `partition_bys` are not empty.
600        // Calculate indices for each partition and construct a new record
601        // batch from the rows at these indices for each partition:
602        self.get_per_partition_indices(&partition_bys, record_batch)?
603            .into_iter()
604            .map(|(row, indices)| {
605                let mut new_indices = UInt32Builder::with_capacity(indices.len());
606                new_indices.append_slice(&indices);
607                let indices = new_indices.finish();
608                Ok((row, take_record_batch(record_batch, &indices)?))
609            })
610            .collect()
611    }
612
613    fn prune(&mut self, n_out: usize) {
614        // Delete hashes for the rows that are outputted.
615        self.input_buffer_hashes.drain(0..n_out);
616    }
617
618    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches) {
619        // We should be in the `PartiallySorted` case, otherwise we can not
620        // tell when we are at the end of a given partition.
621        if !self.ordered_partition_by_indices.is_empty() {
622            if let Some((last_row, _)) = partition_buffers.last() {
623                let last_sorted_cols = self
624                    .ordered_partition_by_indices
625                    .iter()
626                    .map(|idx| last_row[*idx].clone())
627                    .collect::<Vec<_>>();
628                for (row, partition_batch_state) in partition_buffers.iter_mut() {
629                    let sorted_cols = self
630                        .ordered_partition_by_indices
631                        .iter()
632                        .map(|idx| &row[*idx]);
633                    // All the partitions other than `last_sorted_cols` are done.
634                    // We are sure that we will no longer receive values for these
635                    // partitions (arrival of a new value would violate ordering).
636                    partition_batch_state.is_end = !sorted_cols.eq(&last_sorted_cols);
637                }
638            }
639        }
640    }
641
642    fn is_mode_linear(&self) -> bool {
643        self.ordered_partition_by_indices.is_empty()
644    }
645
646    fn input_schema(&self) -> &SchemaRef {
647        &self.input_schema
648    }
649}
650
651impl LinearSearch {
652    /// Initialize a new [`LinearSearch`] partition searcher.
653    fn new(ordered_partition_by_indices: Vec<usize>, input_schema: SchemaRef) -> Self {
654        LinearSearch {
655            input_buffer_hashes: VecDeque::new(),
656            random_state: Default::default(),
657            ordered_partition_by_indices,
658            row_map_batch: HashTable::with_capacity(256),
659            row_map_out: HashTable::with_capacity(256),
660            input_schema,
661        }
662    }
663
664    /// Calculate indices of each partition (according to PARTITION BY expression)
665    /// `columns` contain partition by expression results.
666    fn get_per_partition_indices(
667        &mut self,
668        columns: &[ArrayRef],
669        batch: &RecordBatch,
670    ) -> Result<Vec<(PartitionKey, Vec<u32>)>> {
671        let mut batch_hashes = vec![0; batch.num_rows()];
672        create_hashes(columns, &self.random_state, &mut batch_hashes)?;
673        self.input_buffer_hashes.extend(&batch_hashes);
674        // reset row_map for new calculation
675        self.row_map_batch.clear();
676        // res stores PartitionKey and row indices (indices where these partition occurs in the `batch`) for each partition.
677        let mut result: Vec<(PartitionKey, Vec<u32>)> = vec![];
678        for (hash, row_idx) in batch_hashes.into_iter().zip(0u32..) {
679            let entry = self.row_map_batch.find_mut(hash, |(_, group_idx)| {
680                // We can safely get the first index of the partition indices
681                // since partition indices has one element during initialization.
682                let row = get_row_at_idx(columns, row_idx as usize).unwrap();
683                // Handle hash collusions with an equality check:
684                row.eq(&result[*group_idx].0)
685            });
686            if let Some((_, group_idx)) = entry {
687                result[*group_idx].1.push(row_idx)
688            } else {
689                self.row_map_batch.insert_unique(
690                    hash,
691                    (hash, result.len()),
692                    |(hash, _)| *hash,
693                );
694                let row = get_row_at_idx(columns, row_idx as usize)?;
695                // This is a new partition its only index is row_idx for now.
696                result.push((row, vec![row_idx]));
697            }
698        }
699        Ok(result)
700    }
701
702    /// Calculates partition keys and result indices for each partition.
703    /// The return value is a vector of tuples where the first entry stores
704    /// the partition key (unique for each partition) and the second entry
705    /// stores indices of the rows for which the partition is constructed.
706    fn calc_partition_output_indices(
707        &mut self,
708        input_buffer: &RecordBatch,
709        window_agg_states: &[PartitionWindowAggStates],
710        window_expr: &[Arc<dyn WindowExpr>],
711    ) -> Result<Vec<(PartitionKey, Vec<u32>)>> {
712        let partition_by_columns =
713            evaluate_partition_by_column_values(input_buffer, window_expr)?;
714        // Reset the row_map state:
715        self.row_map_out.clear();
716        let mut partition_indices: Vec<(PartitionKey, Vec<u32>)> = vec![];
717        for (hash, row_idx) in self.input_buffer_hashes.iter().zip(0u32..) {
718            let entry = self.row_map_out.find_mut(*hash, |(_, group_idx, _)| {
719                let row =
720                    get_row_at_idx(&partition_by_columns, row_idx as usize).unwrap();
721                row == partition_indices[*group_idx].0
722            });
723            if let Some((_, group_idx, n_out)) = entry {
724                let (_, indices) = &mut partition_indices[*group_idx];
725                if indices.len() >= *n_out {
726                    break;
727                }
728                indices.push(row_idx);
729            } else {
730                let row = get_row_at_idx(&partition_by_columns, row_idx as usize)?;
731                let min_out = window_agg_states
732                    .iter()
733                    .map(|window_agg_state| {
734                        window_agg_state
735                            .get(&row)
736                            .map(|partition| partition.state.out_col.len())
737                            .unwrap_or(0)
738                    })
739                    .min()
740                    .unwrap_or(0);
741                if min_out == 0 {
742                    break;
743                }
744                self.row_map_out.insert_unique(
745                    *hash,
746                    (*hash, partition_indices.len(), min_out),
747                    |(hash, _, _)| *hash,
748                );
749                partition_indices.push((row, vec![row_idx]));
750            }
751        }
752        Ok(partition_indices)
753    }
754}
755
/// This object encapsulates the algorithm state for sorted searching
/// when computing partitions.
///
/// Partition boundaries are found from the sorted PARTITION BY columns, so each
/// partition of an incoming batch is a contiguous range of rows.
pub struct SortedSearch {
    /// Stores partition by columns and their ordering information
    partition_by_sort_keys: Vec<PhysicalSortExpr>,
    /// Input ordering and partition by key ordering need not be the same, so
    /// this vector stores the mapping between them. For instance, if the input
    /// is ordered by a, b and the window expression contains a PARTITION BY b, a
    /// clause, this attribute stores [1, 0].
    ordered_partition_by_indices: Vec<usize>,
    /// Schema handed out by `PartitionSearcher::input_schema`.
    input_schema: SchemaRef,
}
768
769impl PartitionSearcher for SortedSearch {
770    /// This method constructs new output columns using the result of each window expression.
771    fn calculate_out_columns(
772        &mut self,
773        _input_buffer: &RecordBatch,
774        window_agg_states: &[PartitionWindowAggStates],
775        partition_buffers: &mut PartitionBatches,
776        _window_expr: &[Arc<dyn WindowExpr>],
777    ) -> Result<Option<Vec<ArrayRef>>> {
778        let n_out = self.calculate_n_out_row(window_agg_states, partition_buffers);
779        if n_out == 0 {
780            Ok(None)
781        } else {
782            window_agg_states
783                .iter()
784                .map(|map| get_aggregate_result_out_column(map, n_out).map(Some))
785                .collect()
786        }
787    }
788
789    fn evaluate_partition_batches(
790        &mut self,
791        record_batch: &RecordBatch,
792        _window_expr: &[Arc<dyn WindowExpr>],
793    ) -> Result<Vec<(PartitionKey, RecordBatch)>> {
794        let num_rows = record_batch.num_rows();
795        // Calculate result of partition by column expressions
796        let partition_columns = self
797            .partition_by_sort_keys
798            .iter()
799            .map(|elem| elem.evaluate_to_sort_column(record_batch))
800            .collect::<Result<Vec<_>>>()?;
801        // Reorder `partition_columns` such that its ordering matches input ordering.
802        let partition_columns_ordered =
803            get_at_indices(&partition_columns, &self.ordered_partition_by_indices)?;
804        let partition_points =
805            evaluate_partition_ranges(num_rows, &partition_columns_ordered)?;
806        let partition_bys = partition_columns
807            .into_iter()
808            .map(|arr| arr.values)
809            .collect::<Vec<ArrayRef>>();
810
811        partition_points
812            .iter()
813            .map(|range| {
814                let row = get_row_at_idx(&partition_bys, range.start)?;
815                let len = range.end - range.start;
816                let slice = record_batch.slice(range.start, len);
817                Ok((row, slice))
818            })
819            .collect::<Result<Vec<_>>>()
820    }
821
822    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches) {
823        // In Sorted case. We can mark all partitions besides last partition as ended.
824        // We are sure that those partitions will never receive any values.
825        // (Otherwise ordering invariant is violated.)
826        let n_partitions = partition_buffers.len();
827        for (idx, (_, partition_batch_state)) in partition_buffers.iter_mut().enumerate()
828        {
829            partition_batch_state.is_end |= idx < n_partitions - 1;
830        }
831    }
832
833    fn input_schema(&self) -> &SchemaRef {
834        &self.input_schema
835    }
836}
837
838impl SortedSearch {
839    /// Calculates how many rows we can output.
840    fn calculate_n_out_row(
841        &mut self,
842        window_agg_states: &[PartitionWindowAggStates],
843        partition_buffers: &mut PartitionBatches,
844    ) -> usize {
845        // Different window aggregators may produce results at different rates.
846        // We produce the overall batch result only as fast as the slowest one.
847        let mut counts = vec![];
848        let out_col_counts = window_agg_states.iter().map(|window_agg_state| {
849            // Store how many elements are generated for the current
850            // window expression:
851            let mut cur_window_expr_out_result_len = 0;
852            // We iterate over `window_agg_state`, which is an IndexMap.
853            // Iterations follow the insertion order, hence we preserve
854            // sorting when partition columns are sorted.
855            let mut per_partition_out_results = HashMap::new();
856            for (row, WindowState { state, .. }) in window_agg_state.iter() {
857                cur_window_expr_out_result_len += state.out_col.len();
858                let count = per_partition_out_results.entry(row).or_insert(0);
859                if *count < state.out_col.len() {
860                    *count = state.out_col.len();
861                }
862                // If we do not generate all results for the current
863                // partition, we do not generate results for next
864                // partition --  otherwise we will lose input ordering.
865                if state.n_row_result_missing > 0 {
866                    break;
867                }
868            }
869            counts.push(per_partition_out_results);
870            cur_window_expr_out_result_len
871        });
872        argmin(out_col_counts).map_or(0, |(min_idx, minima)| {
873            for (row, count) in counts.swap_remove(min_idx).into_iter() {
874                let partition_batch = &mut partition_buffers[row];
875                partition_batch.n_out_row = count;
876            }
877            minima
878        })
879    }
880}
881
882/// Calculates partition by expression results for each window expression
883/// on `record_batch`.
884fn evaluate_partition_by_column_values(
885    record_batch: &RecordBatch,
886    window_expr: &[Arc<dyn WindowExpr>],
887) -> Result<Vec<ArrayRef>> {
888    window_expr[0]
889        .partition_by()
890        .iter()
891        .map(|item| match item.evaluate(record_batch)? {
892            ColumnarValue::Array(array) => Ok(array),
893            ColumnarValue::Scalar(scalar) => {
894                scalar.to_array_of_size(record_batch.num_rows())
895            }
896        })
897        .collect()
898}
899
/// Stream for the bounded window aggregation plan.
pub struct BoundedWindowAggStream {
    /// Schema of the emitted batches: the input columns followed by one column
    /// per window expression.
    schema: SchemaRef,
    /// Upstream stream whose batches feed the window computation.
    input: SendableRecordBatchStream,
    /// Buffered input rows whose window results have not been emitted yet
    /// (i.e. the columns needed while calculating aggregation results).
    // The emitted prefix of this buffer is removed after each output batch.
    input_buffer: RecordBatch,
    /// We separate `input_buffer` based on partitions (as
    /// determined by PARTITION BY columns) and store them per partition
    /// in `partition_buffers`. We use this variable when calculating results
    /// for each window expression. This enables us to use the same batch for
    /// different window expressions without copying.
    // Note that we could keep record batches for each window expression in
    // `PartitionWindowAggStates`. However, this would use more memory (as
    // many times as the number of window expressions).
    partition_buffers: PartitionBatches,
    /// An executor can run multiple window expressions if the PARTITION BY
    /// and ORDER BY sections are the same. We keep the state of each window
    /// expression inside `window_agg_states`.
    window_agg_states: Vec<PartitionWindowAggStates>,
    /// Set once `input` is exhausted; afterwards polling yields `None`.
    finished: bool,
    /// Window expressions evaluated by this stream, in the same order as
    /// `window_agg_states`.
    window_expr: Vec<Arc<dyn WindowExpr>>,
    /// Records poll results and compute time of this stream.
    baseline_metrics: BaselineMetrics,
    /// Search mode for partition columns. This determines the algorithm with
    /// which we group each partition.
    search_mode: Box<dyn PartitionSearcher>,
}
927
928impl BoundedWindowAggStream {
929    /// Prunes sections of the state that are no longer needed when calculating
930    /// results (as determined by window frame boundaries and number of results generated).
931    // For instance, if first `n` (not necessarily same with `n_out`) elements are no longer needed to
932    // calculate window expression result (outside the window frame boundary) we retract first `n` elements
933    // from `self.partition_batches` in corresponding partition.
934    // For instance, if `n_out` number of rows are calculated, we can remove
935    // first `n_out` rows from `self.input_buffer`.
936    fn prune_state(&mut self, n_out: usize) -> Result<()> {
937        // Prune `self.window_agg_states`:
938        self.prune_out_columns();
939        // Prune `self.partition_batches`:
940        self.prune_partition_batches();
941        // Prune `self.input_buffer`:
942        self.prune_input_batch(n_out)?;
943        // Prune internal state of search algorithm.
944        self.search_mode.prune(n_out);
945        Ok(())
946    }
947}
948
949impl Stream for BoundedWindowAggStream {
950    type Item = Result<RecordBatch>;
951
952    fn poll_next(
953        mut self: Pin<&mut Self>,
954        cx: &mut Context<'_>,
955    ) -> Poll<Option<Self::Item>> {
956        let poll = self.poll_next_inner(cx);
957        self.baseline_metrics.record_poll(poll)
958    }
959}
960
961impl BoundedWindowAggStream {
962    /// Create a new BoundedWindowAggStream
963    fn new(
964        schema: SchemaRef,
965        window_expr: Vec<Arc<dyn WindowExpr>>,
966        input: SendableRecordBatchStream,
967        baseline_metrics: BaselineMetrics,
968        search_mode: Box<dyn PartitionSearcher>,
969    ) -> Result<Self> {
970        let state = window_expr.iter().map(|_| IndexMap::new()).collect();
971        let empty_batch = RecordBatch::new_empty(Arc::clone(&schema));
972        Ok(Self {
973            schema,
974            input,
975            input_buffer: empty_batch,
976            partition_buffers: IndexMap::new(),
977            window_agg_states: state,
978            finished: false,
979            window_expr,
980            baseline_metrics,
981            search_mode,
982        })
983    }
984
985    fn compute_aggregates(&mut self) -> Result<Option<RecordBatch>> {
986        // calculate window cols
987        for (cur_window_expr, state) in
988            self.window_expr.iter().zip(&mut self.window_agg_states)
989        {
990            cur_window_expr.evaluate_stateful(&self.partition_buffers, state)?;
991        }
992
993        let schema = Arc::clone(&self.schema);
994        let window_expr_out = self.search_mode.calculate_out_columns(
995            &self.input_buffer,
996            &self.window_agg_states,
997            &mut self.partition_buffers,
998            &self.window_expr,
999        )?;
1000        if let Some(window_expr_out) = window_expr_out {
1001            let n_out = window_expr_out[0].len();
1002            // right append new columns to corresponding section in the original input buffer.
1003            let columns_to_show = self
1004                .input_buffer
1005                .columns()
1006                .iter()
1007                .map(|elem| elem.slice(0, n_out))
1008                .chain(window_expr_out)
1009                .collect::<Vec<_>>();
1010            let n_generated = columns_to_show[0].len();
1011            self.prune_state(n_generated)?;
1012            Ok(Some(RecordBatch::try_new(schema, columns_to_show)?))
1013        } else {
1014            Ok(None)
1015        }
1016    }
1017
1018    #[inline]
1019    fn poll_next_inner(
1020        &mut self,
1021        cx: &mut Context<'_>,
1022    ) -> Poll<Option<Result<RecordBatch>>> {
1023        if self.finished {
1024            return Poll::Ready(None);
1025        }
1026
1027        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
1028        match ready!(self.input.poll_next_unpin(cx)) {
1029            Some(Ok(batch)) => {
1030                // Start the timer for compute time within this operator. It will be
1031                // stopped when dropped.
1032                let _timer = elapsed_compute.timer();
1033
1034                self.search_mode.update_partition_batch(
1035                    &mut self.input_buffer,
1036                    batch,
1037                    &self.window_expr,
1038                    &mut self.partition_buffers,
1039                )?;
1040                if let Some(batch) = self.compute_aggregates()? {
1041                    return Poll::Ready(Some(Ok(batch)));
1042                }
1043                self.poll_next_inner(cx)
1044            }
1045            Some(Err(e)) => Poll::Ready(Some(Err(e))),
1046            None => {
1047                let _timer = elapsed_compute.timer();
1048
1049                self.finished = true;
1050                for (_, partition_batch_state) in self.partition_buffers.iter_mut() {
1051                    partition_batch_state.is_end = true;
1052                }
1053                if let Some(batch) = self.compute_aggregates()? {
1054                    return Poll::Ready(Some(Ok(batch)));
1055                }
1056                Poll::Ready(None)
1057            }
1058        }
1059    }
1060
1061    /// Prunes the sections of the record batch (for each partition)
1062    /// that we no longer need to calculate the window function result.
1063    fn prune_partition_batches(&mut self) {
1064        // Remove partitions which we know already ended (is_end flag is true).
1065        // Since the retain method preserves insertion order, we still have
1066        // ordering in between partitions after removal.
1067        self.partition_buffers
1068            .retain(|_, partition_batch_state| !partition_batch_state.is_end);
1069
1070        // The data in `self.partition_batches` is used by all window expressions.
1071        // Therefore, when removing from `self.partition_batches`, we need to remove
1072        // from the earliest range boundary among all window expressions. Variable
1073        // `n_prune_each_partition` fill the earliest range boundary information for
1074        // each partition. This way, we can delete the no-longer-needed sections from
1075        // `self.partition_batches`.
1076        // For instance, if window frame one uses [10, 20] and window frame two uses
1077        // [5, 15]; we only prune the first 5 elements from the corresponding record
1078        // batch in `self.partition_batches`.
1079
1080        // Calculate how many elements to prune for each partition batch
1081        let mut n_prune_each_partition = HashMap::new();
1082        for window_agg_state in self.window_agg_states.iter_mut() {
1083            window_agg_state.retain(|_, WindowState { state, .. }| !state.is_end);
1084            for (partition_row, WindowState { state: value, .. }) in window_agg_state {
1085                let n_prune =
1086                    min(value.window_frame_range.start, value.last_calculated_index);
1087                if let Some(current) = n_prune_each_partition.get_mut(partition_row) {
1088                    if n_prune < *current {
1089                        *current = n_prune;
1090                    }
1091                } else {
1092                    n_prune_each_partition.insert(partition_row.clone(), n_prune);
1093                }
1094            }
1095        }
1096
1097        // Retract no longer needed parts during window calculations from partition batch:
1098        for (partition_row, n_prune) in n_prune_each_partition.iter() {
1099            let pb_state = &mut self.partition_buffers[partition_row];
1100
1101            let batch = &pb_state.record_batch;
1102            pb_state.record_batch = batch.slice(*n_prune, batch.num_rows() - n_prune);
1103            pb_state.n_out_row = 0;
1104
1105            // Update state indices since we have pruned some rows from the beginning:
1106            for window_agg_state in self.window_agg_states.iter_mut() {
1107                window_agg_state[partition_row].state.prune_state(*n_prune);
1108            }
1109        }
1110    }
1111
1112    /// Prunes the section of the input batch whose aggregate results
1113    /// are calculated and emitted.
1114    fn prune_input_batch(&mut self, n_out: usize) -> Result<()> {
1115        // Prune first n_out rows from the input_buffer
1116        let n_to_keep = self.input_buffer.num_rows() - n_out;
1117        let batch_to_keep = self
1118            .input_buffer
1119            .columns()
1120            .iter()
1121            .map(|elem| elem.slice(n_out, n_to_keep))
1122            .collect::<Vec<_>>();
1123        self.input_buffer = RecordBatch::try_new_with_options(
1124            self.input_buffer.schema(),
1125            batch_to_keep,
1126            &RecordBatchOptions::new().with_row_count(Some(n_to_keep)),
1127        )?;
1128        Ok(())
1129    }
1130
1131    /// Prunes emitted parts from WindowAggState `out_col` field.
1132    fn prune_out_columns(&mut self) {
1133        // We store generated columns for each window expression in the `out_col`
1134        // field of `WindowAggState`. Given how many rows are emitted, we remove
1135        // these sections from state.
1136        for partition_window_agg_states in self.window_agg_states.iter_mut() {
1137            // Remove `n_out` entries from the `out_col` field of `WindowAggState`.
1138            // `n_out` is stored in `self.partition_buffers` for each partition.
1139            // If `is_end` is set, directly remove them; this shrinks the hash map.
1140            partition_window_agg_states
1141                .retain(|_, partition_batch_state| !partition_batch_state.state.is_end);
1142            for (
1143                partition_key,
1144                WindowState {
1145                    state: WindowAggState { out_col, .. },
1146                    ..
1147                },
1148            ) in partition_window_agg_states
1149            {
1150                let partition_batch = &mut self.partition_buffers[partition_key];
1151                let n_to_del = partition_batch.n_out_row;
1152                let n_to_keep = out_col.len() - n_to_del;
1153                *out_col = out_col.slice(n_to_del, n_to_keep);
1154            }
1155        }
1156    }
1157}
1158
impl RecordBatchStream for BoundedWindowAggStream {
    /// Returns the schema of the batches produced by this stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
1165
/// Returns the index and value of the smallest item produced by `data`, or
/// `None` when `data` is empty.
///
/// Ties keep the earliest item. An item that cannot be compared with the
/// current best (e.g. `NaN`) is treated as equal to it and never replaces it.
fn argmin<T: PartialOrd>(data: impl Iterator<Item = T>) -> Option<(usize, T)> {
    let mut best: Option<(usize, T)> = None;
    for (idx, value) in data.enumerate() {
        let replace = match &best {
            None => true,
            Some((_, current)) => current.partial_cmp(&value) == Some(Ordering::Greater),
        };
        if replace {
            best = Some((idx, value));
        }
    }
    best
}
1171
1172/// Calculates the section we can show results for expression
1173fn get_aggregate_result_out_column(
1174    partition_window_agg_states: &PartitionWindowAggStates,
1175    len_to_show: usize,
1176) -> Result<ArrayRef> {
1177    let mut result = None;
1178    let mut running_length = 0;
1179    // We assume that iteration order is according to insertion order
1180    for (
1181        _,
1182        WindowState {
1183            state: WindowAggState { out_col, .. },
1184            ..
1185        },
1186    ) in partition_window_agg_states
1187    {
1188        if running_length < len_to_show {
1189            let n_to_use = min(len_to_show - running_length, out_col.len());
1190            let slice_to_use = out_col.slice(0, n_to_use);
1191            result = Some(match result {
1192                Some(arr) => concat(&[&arr, &slice_to_use])?,
1193                None => slice_to_use,
1194            });
1195            running_length += n_to_use;
1196        } else {
1197            break;
1198        }
1199    }
1200    if running_length != len_to_show {
1201        return exec_err!(
1202            "Generated row number should be {len_to_show}, it is {running_length}"
1203        );
1204    }
1205    result
1206        .ok_or_else(|| DataFusionError::Execution("Should contain something".to_string()))
1207}
1208
1209/// Constructs a batch from the last row of batch in the argument.
1210pub(crate) fn get_last_row_batch(batch: &RecordBatch) -> Result<RecordBatch> {
1211    if batch.num_rows() == 0 {
1212        return exec_err!("Latest batch should have at least 1 row");
1213    }
1214    Ok(batch.slice(batch.num_rows() - 1, 1))
1215}
1216
1217#[cfg(test)]
1218mod tests {
1219    use std::pin::Pin;
1220    use std::sync::Arc;
1221    use std::task::{Context, Poll};
1222    use std::time::Duration;
1223
1224    use crate::common::collect;
1225    use crate::expressions::PhysicalSortExpr;
1226    use crate::projection::{ProjectionExec, ProjectionExpr};
1227    use crate::streaming::{PartitionStream, StreamingTableExec};
1228    use crate::test::TestMemoryExec;
1229    use crate::windows::{
1230        create_udwf_window_expr, create_window_expr, BoundedWindowAggExec, InputOrderMode,
1231    };
1232    use crate::{execute_stream, get_plan_string, ExecutionPlan};
1233
1234    use arrow::array::{
1235        builder::{Int64Builder, UInt64Builder},
1236        RecordBatch,
1237    };
1238    use arrow::compute::SortOptions;
1239    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1240    use datafusion_common::test_util::batches_to_string;
1241    use datafusion_common::{exec_datafusion_err, Result, ScalarValue};
1242    use datafusion_execution::config::SessionConfig;
1243    use datafusion_execution::{
1244        RecordBatchStream, SendableRecordBatchStream, TaskContext,
1245    };
1246    use datafusion_expr::{
1247        WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
1248    };
1249    use datafusion_functions_aggregate::count::count_udaf;
1250    use datafusion_functions_window::nth_value::last_value_udwf;
1251    use datafusion_functions_window::nth_value::nth_value_udwf;
1252    use datafusion_physical_expr::expressions::{col, Column, Literal};
1253    use datafusion_physical_expr::window::StandardWindowExpr;
1254    use datafusion_physical_expr::{LexOrdering, PhysicalExpr};
1255
1256    use futures::future::Shared;
1257    use futures::{pin_mut, ready, FutureExt, Stream, StreamExt};
1258    use insta::assert_snapshot;
1259    use itertools::Itertools;
1260    use tokio::time::timeout;
1261
1262    #[derive(Debug, Clone)]
1263    struct TestStreamPartition {
1264        schema: SchemaRef,
1265        batches: Vec<RecordBatch>,
1266        idx: usize,
1267        state: PolingState,
1268        sleep_duration: Duration,
1269        send_exit: bool,
1270    }
1271
1272    impl PartitionStream for TestStreamPartition {
1273        fn schema(&self) -> &SchemaRef {
1274            &self.schema
1275        }
1276
1277        fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
1278            // We create an iterator from the record batches and map them into Ok values,
1279            // converting the iterator into a futures::stream::Stream
1280            Box::pin(self.clone())
1281        }
1282    }
1283
1284    impl Stream for TestStreamPartition {
1285        type Item = Result<RecordBatch>;
1286
1287        fn poll_next(
1288            mut self: Pin<&mut Self>,
1289            cx: &mut Context<'_>,
1290        ) -> Poll<Option<Self::Item>> {
1291            self.poll_next_inner(cx)
1292        }
1293    }
1294
1295    #[derive(Debug, Clone)]
1296    enum PolingState {
1297        Sleep(Shared<futures::future::BoxFuture<'static, ()>>),
1298        BatchReturn,
1299    }
1300
1301    impl TestStreamPartition {
1302        fn poll_next_inner(
1303            self: &mut Pin<&mut Self>,
1304            cx: &mut Context<'_>,
1305        ) -> Poll<Option<Result<RecordBatch>>> {
1306            loop {
1307                match &mut self.state {
1308                    PolingState::BatchReturn => {
1309                        // Wait for self.sleep_duration before sending any new data
1310                        let f = tokio::time::sleep(self.sleep_duration).boxed().shared();
1311                        self.state = PolingState::Sleep(f);
1312                        let input_batch = if let Some(batch) =
1313                            self.batches.clone().get(self.idx)
1314                        {
1315                            batch.clone()
1316                        } else if self.send_exit {
1317                            // Send None to signal end of data
1318                            return Poll::Ready(None);
1319                        } else {
1320                            // Go to sleep mode
1321                            let f =
1322                                tokio::time::sleep(self.sleep_duration).boxed().shared();
1323                            self.state = PolingState::Sleep(f);
1324                            continue;
1325                        };
1326                        self.idx += 1;
1327                        return Poll::Ready(Some(Ok(input_batch)));
1328                    }
1329                    PolingState::Sleep(future) => {
1330                        pin_mut!(future);
1331                        ready!(future.poll_unpin(cx));
1332                        self.state = PolingState::BatchReturn;
1333                    }
1334                }
1335            }
1336        }
1337    }
1338
1339    impl RecordBatchStream for TestStreamPartition {
1340        fn schema(&self) -> SchemaRef {
1341            Arc::clone(&self.schema)
1342        }
1343    }
1344
1345    fn bounded_window_exec_pb_latent_range(
1346        input: Arc<dyn ExecutionPlan>,
1347        n_future_range: usize,
1348        hash: &str,
1349        order_by: &str,
1350    ) -> Result<Arc<dyn ExecutionPlan>> {
1351        let schema = input.schema();
1352        let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf());
1353        let col_expr =
1354            Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc<dyn PhysicalExpr>;
1355        let args = vec![col_expr];
1356        let partitionby_exprs = vec![col(hash, &schema)?];
1357        let orderby_exprs = vec![PhysicalSortExpr {
1358            expr: col(order_by, &schema)?,
1359            options: SortOptions::default(),
1360        }];
1361        let window_frame = WindowFrame::new_bounds(
1362            WindowFrameUnits::Range,
1363            WindowFrameBound::CurrentRow,
1364            WindowFrameBound::Following(ScalarValue::UInt64(Some(n_future_range as u64))),
1365        );
1366        let fn_name = format!(
1367            "{window_fn}({args:?}) PARTITION BY: [{partitionby_exprs:?}], ORDER BY: [{orderby_exprs:?}]"
1368        );
1369        let input_order_mode = InputOrderMode::Linear;
1370        Ok(Arc::new(BoundedWindowAggExec::try_new(
1371            vec![create_window_expr(
1372                &window_fn,
1373                fn_name,
1374                &args,
1375                &partitionby_exprs,
1376                &orderby_exprs,
1377                Arc::new(window_frame),
1378                &input.schema(),
1379                false,
1380                false,
1381                None,
1382            )?],
1383            input,
1384            input_order_mode,
1385            true,
1386        )?))
1387    }
1388
1389    fn projection_exec(input: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
1390        let schema = input.schema();
1391        let exprs = input
1392            .schema()
1393            .fields
1394            .iter()
1395            .enumerate()
1396            .map(|(idx, field)| {
1397                let name = if field.name().len() > 20 {
1398                    format!("col_{idx}")
1399                } else {
1400                    field.name().clone()
1401                };
1402                let expr = col(field.name(), &schema).unwrap();
1403                (expr, name)
1404            })
1405            .collect::<Vec<_>>();
1406        let proj_exprs: Vec<ProjectionExpr> = exprs
1407            .into_iter()
1408            .map(|(expr, alias)| ProjectionExpr { expr, alias })
1409            .collect();
1410        Ok(Arc::new(ProjectionExec::try_new(proj_exprs, input)?))
1411    }
1412
1413    fn task_context_helper() -> TaskContext {
1414        let task_ctx = TaskContext::default();
1415        // Create session context with config
1416        let session_config = SessionConfig::new()
1417            .with_batch_size(1)
1418            .with_target_partitions(2)
1419            .with_round_robin_repartition(false);
1420        task_ctx.with_session_config(session_config)
1421    }
1422
1423    fn task_context() -> Arc<TaskContext> {
1424        Arc::new(task_context_helper())
1425    }
1426
1427    pub async fn collect_stream(
1428        mut stream: SendableRecordBatchStream,
1429        results: &mut Vec<RecordBatch>,
1430    ) -> Result<()> {
1431        while let Some(item) = stream.next().await {
1432            results.push(item?);
1433        }
1434        Ok(())
1435    }
1436
1437    /// Execute the [ExecutionPlan] and collect the results in memory
1438    pub async fn collect_with_timeout(
1439        plan: Arc<dyn ExecutionPlan>,
1440        context: Arc<TaskContext>,
1441        timeout_duration: Duration,
1442    ) -> Result<Vec<RecordBatch>> {
1443        let stream = execute_stream(plan, context)?;
1444        let mut results = vec![];
1445
1446        // Execute the asynchronous operation with a timeout
1447        if timeout(timeout_duration, collect_stream(stream, &mut results))
1448            .await
1449            .is_ok()
1450        {
1451            return Err(exec_datafusion_err!("shouldn't have completed"));
1452        };
1453
1454        Ok(results)
1455    }
1456
1457    /// Execute the [ExecutionPlan] and collect the results in memory
1458    #[allow(dead_code)]
1459    pub async fn collect_bonafide(
1460        plan: Arc<dyn ExecutionPlan>,
1461        context: Arc<TaskContext>,
1462    ) -> Result<Vec<RecordBatch>> {
1463        let stream = execute_stream(plan, context)?;
1464        let mut results = vec![];
1465
1466        collect_stream(stream, &mut results).await?;
1467
1468        Ok(results)
1469    }
1470
1471    fn test_schema() -> SchemaRef {
1472        Arc::new(Schema::new(vec![
1473            Field::new("sn", DataType::UInt64, true),
1474            Field::new("hash", DataType::Int64, true),
1475        ]))
1476    }
1477
1478    fn schema_orders(schema: &SchemaRef) -> Result<Vec<LexOrdering>> {
1479        let orderings = vec![[PhysicalSortExpr {
1480            expr: col("sn", schema)?,
1481            options: SortOptions {
1482                descending: false,
1483                nulls_first: false,
1484            },
1485        }]
1486        .into()];
1487        Ok(orderings)
1488    }
1489
1490    fn is_integer_division_safe(lhs: usize, rhs: usize) -> bool {
1491        let res = lhs / rhs;
1492        res * rhs == lhs
1493    }
1494    fn generate_batches(
1495        schema: &SchemaRef,
1496        n_row: usize,
1497        n_chunk: usize,
1498    ) -> Result<Vec<RecordBatch>> {
1499        let mut batches = vec![];
1500        assert!(n_row > 0);
1501        assert!(n_chunk > 0);
1502        assert!(is_integer_division_safe(n_row, n_chunk));
1503        let hash_replicate = 4;
1504
1505        let chunks = (0..n_row)
1506            .chunks(n_chunk)
1507            .into_iter()
1508            .map(|elem| elem.into_iter().collect::<Vec<_>>())
1509            .collect::<Vec<_>>();
1510
1511        // Send 2 RecordBatches at the source
1512        for sn_values in chunks {
1513            let mut sn1_array = UInt64Builder::with_capacity(sn_values.len());
1514            let mut hash_array = Int64Builder::with_capacity(sn_values.len());
1515
1516            for sn in sn_values {
1517                sn1_array.append_value(sn as u64);
1518                let hash_value = (2 - (sn / hash_replicate)) as i64;
1519                hash_array.append_value(hash_value);
1520            }
1521
1522            let batch = RecordBatch::try_new(
1523                Arc::clone(schema),
1524                vec![Arc::new(sn1_array.finish()), Arc::new(hash_array.finish())],
1525            )?;
1526            batches.push(batch);
1527        }
1528        Ok(batches)
1529    }
1530
1531    fn generate_never_ending_source(
1532        n_rows: usize,
1533        chunk_length: usize,
1534        n_partition: usize,
1535        is_infinite: bool,
1536        send_exit: bool,
1537        per_batch_wait_duration_in_millis: u64,
1538    ) -> Result<Arc<dyn ExecutionPlan>> {
1539        assert!(n_partition > 0);
1540
1541        // We use same hash value in the table. This makes sure that
1542        // After hashing computation will continue in only in one of the output partitions
1543        // In this case, data flow should still continue
1544        let schema = test_schema();
1545        let orderings = schema_orders(&schema)?;
1546
1547        // Source waits per_batch_wait_duration_in_millis ms before sending other batch
1548        let per_batch_wait_duration =
1549            Duration::from_millis(per_batch_wait_duration_in_millis);
1550
1551        let batches = generate_batches(&schema, n_rows, chunk_length)?;
1552
1553        // Source has 2 partitions
1554        let partitions = vec![
1555            Arc::new(TestStreamPartition {
1556                schema: Arc::clone(&schema),
1557                batches,
1558                idx: 0,
1559                state: PolingState::BatchReturn,
1560                sleep_duration: per_batch_wait_duration,
1561                send_exit,
1562            }) as _;
1563            n_partition
1564        ];
1565        let source = Arc::new(StreamingTableExec::try_new(
1566            Arc::clone(&schema),
1567            partitions,
1568            None,
1569            orderings,
1570            is_infinite,
1571            None,
1572        )?) as _;
1573        Ok(source)
1574    }
1575
1576    // Tests NTH_VALUE(negative index) with memoize feature
1577    // To be able to trigger memoize feature for NTH_VALUE we need to
1578    // - feed BoundedWindowAggExec with batch stream data.
1579    // - Window frame should contain UNBOUNDED PRECEDING.
1580    // It hard to ensure these conditions are met, from the sql query.
1581    #[tokio::test]
1582    async fn test_window_nth_value_bounded_memoize() -> Result<()> {
1583        let config = SessionConfig::new().with_target_partitions(1);
1584        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
1585
1586        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1587        // Create a new batch of data to insert into the table
1588        let batch = RecordBatch::try_new(
1589            Arc::clone(&schema),
1590            vec![Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3]))],
1591        )?;
1592
1593        let memory_exec = TestMemoryExec::try_new_exec(
1594            &[vec![batch.clone(), batch.clone(), batch.clone()]],
1595            Arc::clone(&schema),
1596            None,
1597        )?;
1598        let col_a = col("a", &schema)?;
1599        let nth_value_func1 = create_udwf_window_expr(
1600            &nth_value_udwf(),
1601            &[
1602                Arc::clone(&col_a),
1603                Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
1604            ],
1605            &schema,
1606            "nth_value(-1)".to_string(),
1607            false,
1608        )?
1609        .reverse_expr()
1610        .unwrap();
1611        let nth_value_func2 = create_udwf_window_expr(
1612            &nth_value_udwf(),
1613            &[
1614                Arc::clone(&col_a),
1615                Arc::new(Literal::new(ScalarValue::Int32(Some(2)))),
1616            ],
1617            &schema,
1618            "nth_value(-2)".to_string(),
1619            false,
1620        )?
1621        .reverse_expr()
1622        .unwrap();
1623
1624        let last_value_func = create_udwf_window_expr(
1625            &last_value_udwf(),
1626            &[Arc::clone(&col_a)],
1627            &schema,
1628            "last".to_string(),
1629            false,
1630        )?;
1631
1632        let window_exprs = vec![
1633            // LAST_VALUE(a)
1634            Arc::new(StandardWindowExpr::new(
1635                last_value_func,
1636                &[],
1637                &[],
1638                Arc::new(WindowFrame::new_bounds(
1639                    WindowFrameUnits::Rows,
1640                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1641                    WindowFrameBound::CurrentRow,
1642                )),
1643            )) as _,
1644            // NTH_VALUE(a, -1)
1645            Arc::new(StandardWindowExpr::new(
1646                nth_value_func1,
1647                &[],
1648                &[],
1649                Arc::new(WindowFrame::new_bounds(
1650                    WindowFrameUnits::Rows,
1651                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1652                    WindowFrameBound::CurrentRow,
1653                )),
1654            )) as _,
1655            // NTH_VALUE(a, -2)
1656            Arc::new(StandardWindowExpr::new(
1657                nth_value_func2,
1658                &[],
1659                &[],
1660                Arc::new(WindowFrame::new_bounds(
1661                    WindowFrameUnits::Rows,
1662                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1663                    WindowFrameBound::CurrentRow,
1664                )),
1665            )) as _,
1666        ];
1667        let physical_plan = BoundedWindowAggExec::try_new(
1668            window_exprs,
1669            memory_exec,
1670            InputOrderMode::Sorted,
1671            true,
1672        )
1673        .map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)?;
1674
1675        let batches = collect(physical_plan.execute(0, task_ctx)?).await?;
1676
1677        let expected = vec![
1678            "BoundedWindowAggExec: wdw=[last: Field { name: \"last\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, nth_value(-1): Field { name: \"nth_value(-1)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, nth_value(-2): Field { name: \"nth_value(-2)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]",
1679            "  DataSourceExec: partitions=1, partition_sizes=[3]",
1680        ];
1681        // Get string representation of the plan
1682        let actual = get_plan_string(&physical_plan);
1683        assert_eq!(
1684            expected, actual,
1685            "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
1686        );
1687
1688        assert_snapshot!(batches_to_string(&batches), @r#"
1689            +---+------+---------------+---------------+
1690            | a | last | nth_value(-1) | nth_value(-2) |
1691            +---+------+---------------+---------------+
1692            | 1 | 1    | 1             |               |
1693            | 2 | 2    | 2             | 1             |
1694            | 3 | 3    | 3             | 2             |
1695            | 1 | 1    | 1             | 3             |
1696            | 2 | 2    | 2             | 1             |
1697            | 3 | 3    | 3             | 2             |
1698            | 1 | 1    | 1             | 3             |
1699            | 2 | 2    | 2             | 1             |
1700            | 3 | 3    | 3             | 2             |
1701            +---+------+---------------+---------------+
1702            "#);
1703        Ok(())
1704    }
1705
1706    // This test, tests whether most recent row guarantee by the input batch of the `BoundedWindowAggExec`
1707    // helps `BoundedWindowAggExec` to generate low latency result in the `Linear` mode.
1708    // Input data generated at the source is
1709    //       "+----+------+",
1710    //       "| sn | hash |",
1711    //       "+----+------+",
1712    //       "| 0  | 2    |",
1713    //       "| 1  | 2    |",
1714    //       "| 2  | 2    |",
1715    //       "| 3  | 2    |",
1716    //       "| 4  | 1    |",
1717    //       "| 5  | 1    |",
1718    //       "| 6  | 1    |",
1719    //       "| 7  | 1    |",
1720    //       "| 8  | 0    |",
1721    //       "| 9  | 0    |",
1722    //       "+----+------+",
1723    //
1724    // Effectively following query is run on this data
1725    //
1726    //   SELECT *, count(*) OVER(PARTITION BY duplicated_hash ORDER BY sn RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)
1727    //   FROM test;
1728    //
1729    // partition `duplicated_hash=2` receives following data from the input
1730    //
1731    //       "+----+------+",
1732    //       "| sn | hash |",
1733    //       "+----+------+",
1734    //       "| 0  | 2    |",
1735    //       "| 1  | 2    |",
1736    //       "| 2  | 2    |",
1737    //       "| 3  | 2    |",
1738    //       "+----+------+",
1739    // normally `BoundedWindowExec` can only generate following result from the input above
1740    //
1741    //       "+----+------+---------+",
1742    //       "| sn | hash |  count  |",
1743    //       "+----+------+---------+",
1744    //       "| 0  | 2    |  2      |",
1745    //       "| 1  | 2    |  2      |",
1746    //       "| 2  | 2    |<not yet>|",
1747    //       "| 3  | 2    |<not yet>|",
1748    //       "+----+------+---------+",
1749    // where result of last 2 row is missing. Since window frame end is not may change with future data
1750    // since window frame end is determined by 1 following (To generate result for row=3[where sn=2] we
1751    // need to received sn=4 to make sure window frame end bound won't change with future data).
1752    //
1753    // With the ability of different partitions to use global ordering at the input (where most up-to date
1754    //   row is
1755    //      "| 9  | 0    |",
1756    //   )
1757    //
1758    // `BoundedWindowExec` should be able to generate following result in the test
1759    //
1760    //       "+----+------+-------+",
1761    //       "| sn | hash | col_2 |",
1762    //       "+----+------+-------+",
1763    //       "| 0  | 2    | 2     |",
1764    //       "| 1  | 2    | 2     |",
1765    //       "| 2  | 2    | 2     |",
1766    //       "| 3  | 2    | 1     |",
1767    //       "| 4  | 1    | 2     |",
1768    //       "| 5  | 1    | 2     |",
1769    //       "| 6  | 1    | 2     |",
1770    //       "| 7  | 1    | 1     |",
1771    //       "+----+------+-------+",
1772    //
1773    // where result for all rows except last 2 is calculated (To calculate result for row 9 where sn=8
1774    //   we need to receive sn=10 value to calculate it result.).
1775    // In this test, out aim is to test for which portion of the input data `BoundedWindowExec` can generate
1776    // a result. To test this behaviour, we generated the data at the source infinitely (no `None` signal
1777    //    is sent to output from source). After, row:
1778    //
1779    //       "| 9  | 0    |",
1780    //
1781    // is sent. Source stops sending data to output. We collect, result emitted by the `BoundedWindowExec` at the
1782    // end of the pipeline with a timeout (Since no `None` is sent from source. Collection never ends otherwise).
1783    #[tokio::test]
1784    async fn bounded_window_exec_linear_mode_range_information() -> Result<()> {
1785        let n_rows = 10;
1786        let chunk_length = 2;
1787        let n_future_range = 1;
1788
1789        let timeout_duration = Duration::from_millis(2000);
1790
1791        let source =
1792            generate_never_ending_source(n_rows, chunk_length, 1, true, false, 5)?;
1793
1794        let window =
1795            bounded_window_exec_pb_latent_range(source, n_future_range, "hash", "sn")?;
1796
1797        let plan = projection_exec(window)?;
1798
1799        let expected_plan = vec![
1800            "ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]@2 as col_2]",
1801            "  BoundedWindowAggExec: wdw=[count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]: Field { name: \"count([Column { name: \\\"sn\\\", index: 0 }]) PARTITION BY: [[Column { name: \\\"hash\\\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \\\"sn\\\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING], mode=[Linear]",
1802            "    StreamingTableExec: partition_sizes=1, projection=[sn, hash], infinite_source=true, output_ordering=[sn@0 ASC NULLS LAST]",
1803        ];
1804
1805        // Get string representation of the plan
1806        let actual = get_plan_string(&plan);
1807        assert_eq!(
1808            expected_plan, actual,
1809            "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected_plan:#?}\nactual:\n\n{actual:#?}\n\n"
1810        );
1811
1812        let task_ctx = task_context();
1813        let batches = collect_with_timeout(plan, task_ctx, timeout_duration).await?;
1814
1815        assert_snapshot!(batches_to_string(&batches), @r#"
1816            +----+------+-------+
1817            | sn | hash | col_2 |
1818            +----+------+-------+
1819            | 0  | 2    | 2     |
1820            | 1  | 2    | 2     |
1821            | 2  | 2    | 2     |
1822            | 3  | 2    | 1     |
1823            | 4  | 1    | 2     |
1824            | 5  | 1    | 2     |
1825            | 6  | 1    | 2     |
1826            | 7  | 1    | 1     |
1827            +----+------+-------+
1828            "#);
1829
1830        Ok(())
1831    }
1832}