Skip to main content

datafusion_physical_plan/windows/
bounded_window_agg_exec.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Stream and channel implementations for window function expressions.
19//! The executor given here uses bounded memory (does not maintain all
20//! the input data seen so far), which makes it appropriate when processing
21//! infinite inputs.
22
23use std::any::Any;
24use std::cmp::{Ordering, min};
25use std::collections::VecDeque;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::task::{Context, Poll};
29
30use super::utils::create_schema;
31use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
32use crate::windows::{
33    calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs,
34    window_equivalence_properties,
35};
36use crate::{
37    ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
38    ExecutionPlanProperties, InputOrderMode, PlanProperties, RecordBatchStream,
39    SendableRecordBatchStream, Statistics, WindowExpr,
40};
41
42use arrow::compute::take_record_batch;
43use arrow::{
44    array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder},
45    compute::{concat, concat_batches, sort_to_indices, take_arrays},
46    datatypes::SchemaRef,
47    record_batch::RecordBatch,
48};
49use datafusion_common::hash_utils::create_hashes;
50use datafusion_common::stats::Precision;
51use datafusion_common::utils::{
52    evaluate_partition_ranges, get_at_indices, get_row_at_idx,
53};
54use datafusion_common::{
55    HashMap, Result, arrow_datafusion_err, exec_datafusion_err, exec_err,
56};
57use datafusion_execution::TaskContext;
58use datafusion_expr::ColumnarValue;
59use datafusion_expr::window_state::{PartitionBatchState, WindowAggState};
60use datafusion_physical_expr::window::{
61    PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState,
62};
63use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
64use datafusion_physical_expr_common::sort_expr::{
65    OrderingRequirements, PhysicalSortExpr,
66};
67
68use ahash::RandomState;
69use futures::stream::Stream;
70use futures::{StreamExt, ready};
71use hashbrown::hash_table::HashTable;
72use indexmap::IndexMap;
73use log::debug;
74
/// Window execution plan.
///
/// Evaluates a set of window expressions over its input using the bounded
/// memory strategy described in the module documentation.
#[derive(Debug, Clone)]
pub struct BoundedWindowAggExec {
    /// Input plan
    input: Arc<dyn ExecutionPlan>,
    /// Window function expressions evaluated by this plan
    window_expr: Vec<Arc<dyn WindowExpr>>,
    /// Schema after the window is run
    schema: SchemaRef,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Describes how the input is ordered relative to the partition keys
    pub input_order_mode: InputOrderMode,
    /// Partition by indices that define ordering.
    ///
    /// For example, if the input ordering is `ORDER BY a, b` and the window
    /// expression contains `PARTITION BY b, a`, then
    /// `ordered_partition_by_indices` would be `[1, 0]`. Similarly, if the
    /// window expression contains `PARTITION BY a, b`, then
    /// `ordered_partition_by_indices` would be `[0, 1]`.
    /// See `get_ordered_partition_by_indices` for more details.
    ordered_partition_by_indices: Vec<usize>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
    /// If `can_repartition` is false, `partition_keys()` always returns an
    /// empty vector.
    can_repartition: bool,
}
100
101impl BoundedWindowAggExec {
102    /// Create a new execution plan for window aggregates
103    pub fn try_new(
104        window_expr: Vec<Arc<dyn WindowExpr>>,
105        input: Arc<dyn ExecutionPlan>,
106        input_order_mode: InputOrderMode,
107        can_repartition: bool,
108    ) -> Result<Self> {
109        let schema = create_schema(&input.schema(), &window_expr)?;
110        let schema = Arc::new(schema);
111        let partition_by_exprs = window_expr[0].partition_by();
112        let ordered_partition_by_indices = match &input_order_mode {
113            InputOrderMode::Sorted => {
114                let indices = get_ordered_partition_by_indices(
115                    window_expr[0].partition_by(),
116                    &input,
117                )?;
118                if indices.len() == partition_by_exprs.len() {
119                    indices
120                } else {
121                    (0..partition_by_exprs.len()).collect::<Vec<_>>()
122                }
123            }
124            InputOrderMode::PartiallySorted(ordered_indices) => ordered_indices.clone(),
125            InputOrderMode::Linear => {
126                vec![]
127            }
128        };
129        let cache = Self::compute_properties(&input, &schema, &window_expr)?;
130        Ok(Self {
131            input,
132            window_expr,
133            schema,
134            metrics: ExecutionPlanMetricsSet::new(),
135            input_order_mode,
136            ordered_partition_by_indices,
137            cache,
138            can_repartition,
139        })
140    }
141
142    /// Window expressions
143    pub fn window_expr(&self) -> &[Arc<dyn WindowExpr>] {
144        &self.window_expr
145    }
146
147    /// Input plan
148    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
149        &self.input
150    }
151
152    /// Return the output sort order of partition keys: For example
153    /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a
154    // We are sure that partition by columns are always at the beginning of sort_keys
155    // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely
156    // to calculate partition separation points
157    pub fn partition_by_sort_keys(&self) -> Result<Vec<PhysicalSortExpr>> {
158        let partition_by = self.window_expr()[0].partition_by();
159        get_partition_by_sort_exprs(
160            &self.input,
161            partition_by,
162            &self.ordered_partition_by_indices,
163        )
164    }
165
166    /// Initializes the appropriate [`PartitionSearcher`] implementation from
167    /// the state.
168    fn get_search_algo(&self) -> Result<Box<dyn PartitionSearcher>> {
169        let partition_by_sort_keys = self.partition_by_sort_keys()?;
170        let ordered_partition_by_indices = self.ordered_partition_by_indices.clone();
171        let input_schema = self.input().schema();
172        Ok(match &self.input_order_mode {
173            InputOrderMode::Sorted => {
174                // In Sorted mode, all partition by columns should be ordered.
175                if self.window_expr()[0].partition_by().len()
176                    != ordered_partition_by_indices.len()
177                {
178                    return exec_err!(
179                        "All partition by columns should have an ordering in Sorted mode."
180                    );
181                }
182                Box::new(SortedSearch {
183                    partition_by_sort_keys,
184                    ordered_partition_by_indices,
185                    input_schema,
186                })
187            }
188            InputOrderMode::Linear | InputOrderMode::PartiallySorted(_) => Box::new(
189                LinearSearch::new(ordered_partition_by_indices, input_schema),
190            ),
191        })
192    }
193
194    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
195    fn compute_properties(
196        input: &Arc<dyn ExecutionPlan>,
197        schema: &SchemaRef,
198        window_exprs: &[Arc<dyn WindowExpr>],
199    ) -> Result<PlanProperties> {
200        // Calculate equivalence properties:
201        let eq_properties = window_equivalence_properties(schema, input, window_exprs)?;
202
203        // As we can have repartitioning using the partition keys, this can
204        // be either one or more than one, depending on the presence of
205        // repartitioning.
206        let output_partitioning = input.output_partitioning().clone();
207
208        // Construct properties cache
209        Ok(PlanProperties::new(
210            eq_properties,
211            output_partitioning,
212            // TODO: Emission type and boundedness information can be enhanced here
213            input.pipeline_behavior(),
214            input.boundedness(),
215        ))
216    }
217
218    pub fn partition_keys(&self) -> Vec<Arc<dyn PhysicalExpr>> {
219        if !self.can_repartition {
220            vec![]
221        } else {
222            let all_partition_keys = self
223                .window_expr()
224                .iter()
225                .map(|expr| expr.partition_by().to_vec())
226                .collect::<Vec<_>>();
227
228            all_partition_keys
229                .into_iter()
230                .min_by_key(|s| s.len())
231                .unwrap_or_else(Vec::new)
232        }
233    }
234
235    fn statistics_helper(&self, statistics: Statistics) -> Result<Statistics> {
236        let win_cols = self.window_expr.len();
237        let input_cols = self.input.schema().fields().len();
238        // TODO stats: some windowing function will maintain invariants such as min, max...
239        let mut column_statistics = Vec::with_capacity(win_cols + input_cols);
240        // copy stats of the input to the beginning of the schema.
241        column_statistics.extend(statistics.column_statistics);
242        for _ in 0..win_cols {
243            column_statistics.push(ColumnStatistics::new_unknown())
244        }
245        Ok(Statistics {
246            num_rows: statistics.num_rows,
247            column_statistics,
248            total_byte_size: Precision::Absent,
249        })
250    }
251}
252
253impl DisplayAs for BoundedWindowAggExec {
254    fn fmt_as(
255        &self,
256        t: DisplayFormatType,
257        f: &mut std::fmt::Formatter,
258    ) -> std::fmt::Result {
259        match t {
260            DisplayFormatType::Default | DisplayFormatType::Verbose => {
261                write!(f, "BoundedWindowAggExec: ")?;
262                let g: Vec<String> = self
263                    .window_expr
264                    .iter()
265                    .map(|e| {
266                        let field = match e.field() {
267                            Ok(f) => f.to_string(),
268                            Err(e) => format!("{e:?}"),
269                        };
270                        format!(
271                            "{}: {}, frame: {}",
272                            e.name().to_owned(),
273                            field,
274                            e.get_window_frame()
275                        )
276                    })
277                    .collect();
278                let mode = &self.input_order_mode;
279                write!(f, "wdw=[{}], mode=[{:?}]", g.join(", "), mode)?;
280            }
281            DisplayFormatType::TreeRender => {
282                let g: Vec<String> = self
283                    .window_expr
284                    .iter()
285                    .map(|e| e.name().to_owned().to_string())
286                    .collect();
287                writeln!(f, "select_list={}", g.join(", "))?;
288
289                let mode = &self.input_order_mode;
290                writeln!(f, "mode={mode:?}")?;
291            }
292        }
293        Ok(())
294    }
295}
296
297impl ExecutionPlan for BoundedWindowAggExec {
298    fn name(&self) -> &'static str {
299        "BoundedWindowAggExec"
300    }
301
302    /// Return a reference to Any that can be used for downcasting
303    fn as_any(&self) -> &dyn Any {
304        self
305    }
306
307    fn properties(&self) -> &PlanProperties {
308        &self.cache
309    }
310
311    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
312        vec![&self.input]
313    }
314
315    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
316        let partition_bys = self.window_expr()[0].partition_by();
317        let order_keys = self.window_expr()[0].order_by();
318        let partition_bys = self
319            .ordered_partition_by_indices
320            .iter()
321            .map(|idx| &partition_bys[*idx]);
322        vec![calc_requirements(partition_bys, order_keys)]
323    }
324
325    fn required_input_distribution(&self) -> Vec<Distribution> {
326        if self.partition_keys().is_empty() {
327            debug!("No partition defined for BoundedWindowAggExec!!!");
328            vec![Distribution::SinglePartition]
329        } else {
330            vec![Distribution::HashPartitioned(self.partition_keys().clone())]
331        }
332    }
333
334    fn maintains_input_order(&self) -> Vec<bool> {
335        vec![true]
336    }
337
338    fn with_new_children(
339        self: Arc<Self>,
340        children: Vec<Arc<dyn ExecutionPlan>>,
341    ) -> Result<Arc<dyn ExecutionPlan>> {
342        Ok(Arc::new(BoundedWindowAggExec::try_new(
343            self.window_expr.clone(),
344            Arc::clone(&children[0]),
345            self.input_order_mode.clone(),
346            self.can_repartition,
347        )?))
348    }
349
350    fn execute(
351        &self,
352        partition: usize,
353        context: Arc<TaskContext>,
354    ) -> Result<SendableRecordBatchStream> {
355        let input = self.input.execute(partition, context)?;
356        let search_mode = self.get_search_algo()?;
357        let stream = Box::pin(BoundedWindowAggStream::new(
358            Arc::clone(&self.schema),
359            self.window_expr.clone(),
360            input,
361            BaselineMetrics::new(&self.metrics, partition),
362            search_mode,
363        )?);
364        Ok(stream)
365    }
366
367    fn metrics(&self) -> Option<MetricsSet> {
368        Some(self.metrics.clone_inner())
369    }
370
371    fn statistics(&self) -> Result<Statistics> {
372        self.partition_statistics(None)
373    }
374
375    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
376        let input_stat = self.input.partition_statistics(partition)?;
377        self.statistics_helper(input_stat)
378    }
379}
380
381/// Trait that specifies how we search for (or calculate) partitions. It has two
382/// implementations: [`SortedSearch`] and [`LinearSearch`].
383trait PartitionSearcher: Send {
384    /// This method constructs output columns using the result of each window expression
385    /// (each entry in the output vector comes from a window expression).
386    /// Executor when producing output concatenates `input_buffer` (corresponding section), and
387    /// result of this function to generate output `RecordBatch`. `input_buffer` is used to determine
388    /// which sections of the window expression results should be used to generate output.
389    /// `partition_buffers` contains corresponding section of the `RecordBatch` for each partition.
390    /// `window_agg_states` stores per partition state for each window expression.
391    /// None case means that no result is generated
392    /// `Some(Vec<ArrayRef>)` is the result of each window expression.
393    fn calculate_out_columns(
394        &mut self,
395        input_buffer: &RecordBatch,
396        window_agg_states: &[PartitionWindowAggStates],
397        partition_buffers: &mut PartitionBatches,
398        window_expr: &[Arc<dyn WindowExpr>],
399    ) -> Result<Option<Vec<ArrayRef>>>;
400
401    /// Determine whether `[InputOrderMode]` is `[InputOrderMode::Linear]` or not.
402    fn is_mode_linear(&self) -> bool {
403        false
404    }
405
406    // Constructs corresponding batches for each partition for the record_batch.
407    fn evaluate_partition_batches(
408        &mut self,
409        record_batch: &RecordBatch,
410        window_expr: &[Arc<dyn WindowExpr>],
411    ) -> Result<Vec<(PartitionKey, RecordBatch)>>;
412
413    /// Prunes the state.
414    fn prune(&mut self, _n_out: usize) {}
415
416    /// Marks the partition as done if we are sure that corresponding partition
417    /// cannot receive any more values.
418    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches);
419
420    /// Updates `input_buffer` and `partition_buffers` with the new `record_batch`.
421    fn update_partition_batch(
422        &mut self,
423        input_buffer: &mut RecordBatch,
424        record_batch: RecordBatch,
425        window_expr: &[Arc<dyn WindowExpr>],
426        partition_buffers: &mut PartitionBatches,
427    ) -> Result<()> {
428        if record_batch.num_rows() == 0 {
429            return Ok(());
430        }
431        let partition_batches =
432            self.evaluate_partition_batches(&record_batch, window_expr)?;
433        for (partition_row, partition_batch) in partition_batches {
434            if let Some(partition_batch_state) = partition_buffers.get_mut(&partition_row)
435            {
436                partition_batch_state.extend(&partition_batch)?
437            } else {
438                let options = RecordBatchOptions::new()
439                    .with_row_count(Some(partition_batch.num_rows()));
440                // Use input_schema for the buffer schema, not `record_batch.schema()`
441                // as it may not have the "correct" schema in terms of output
442                // nullability constraints. For details, see the following issue:
443                // https://github.com/apache/datafusion/issues/9320
444                let partition_batch = RecordBatch::try_new_with_options(
445                    Arc::clone(self.input_schema()),
446                    partition_batch.columns().to_vec(),
447                    &options,
448                )?;
449                let partition_batch_state =
450                    PartitionBatchState::new_with_batch(partition_batch);
451                partition_buffers.insert(partition_row, partition_batch_state);
452            }
453        }
454
455        if self.is_mode_linear() {
456            // In `Linear` mode, it is guaranteed that the first ORDER BY column
457            // is sorted across partitions. Note that only the first ORDER BY
458            // column is guaranteed to be ordered. As a counter example, consider
459            // the case, `PARTITION BY b, ORDER BY a, c` when the input is sorted
460            // by `[a, b, c]`. In this case, `BoundedWindowAggExec` mode will be
461            // `Linear`. However, we cannot guarantee that the last row of the
462            // input data will be the "last" data in terms of the ordering requirement
463            // `[a, c]` -- it will be the "last" data in terms of `[a, b, c]`.
464            // Hence, only column `a` should be used as a guarantee of the "last"
465            // data across partitions. For other modes (`Sorted`, `PartiallySorted`),
466            // we do not need to keep track of the most recent row guarantee across
467            // partitions. Since leading ordering separates partitions, guaranteed
468            // by the most recent row, already prune the previous partitions completely.
469            let last_row = get_last_row_batch(&record_batch)?;
470            for (_, partition_batch) in partition_buffers.iter_mut() {
471                partition_batch.set_most_recent_row(last_row.clone());
472            }
473        }
474        self.mark_partition_end(partition_buffers);
475
476        *input_buffer = if input_buffer.num_rows() == 0 {
477            record_batch
478        } else {
479            concat_batches(self.input_schema(), [input_buffer, &record_batch])?
480        };
481
482        Ok(())
483    }
484
485    fn input_schema(&self) -> &SchemaRef;
486}
487
/// This object encapsulates the algorithm state for a simple linear scan
/// algorithm for computing partitions. It is the searcher used for the
/// `Linear` and `PartiallySorted` input order modes.
pub struct LinearSearch {
    /// Keeps the hash of input buffer calculated from PARTITION BY columns.
    /// Its length is equal to the `input_buffer` length.
    input_buffer_hashes: VecDeque<u64>,
    /// Used during hash value calculation.
    random_state: RandomState,
    /// Input ordering and partition by key ordering need not be the same, so
    /// this vector stores the mapping between them. For instance, if the input
    /// is ordered by a, b and the window expression contains a PARTITION BY b, a
    /// clause, this attribute stores [1, 0]. It is empty in `Linear` mode.
    ordered_partition_by_indices: Vec<usize>,
    /// We use this [`HashTable`] to calculate unique partitions for each new
    /// RecordBatch. First entry in the tuple is the hash value, the second
    /// entry is the unique ID for each partition (increments from 0 to n).
    /// It is cleared at the start of every per-batch calculation.
    row_map_batch: HashTable<(u64, usize)>,
    /// We use this [`HashTable`] to calculate the output columns that we can
    /// produce at each cycle. First entry in the tuple is the hash value, the
    /// second entry is the unique ID for each partition (increments from 0 to n).
    /// The third entry stores how many new outputs are calculated for the
    /// corresponding partition.
    row_map_out: HashTable<(u64, usize, usize)>,
    /// Schema of the input, handed out through `input_schema()`.
    input_schema: SchemaRef,
}
513
514impl PartitionSearcher for LinearSearch {
515    /// This method constructs output columns using the result of each window expression.
516    // Assume input buffer is         |      Partition Buffers would be (Where each partition and its data is separated)
517    // a, 2                           |      a, 2
518    // b, 2                           |      a, 2
519    // a, 2                           |      a, 2
520    // b, 2                           |
521    // a, 2                           |      b, 2
522    // b, 2                           |      b, 2
523    // b, 2                           |      b, 2
524    //                                |      b, 2
525    // Also assume we happen to calculate 2 new values for a, and 3 for b (To be calculate missing values we may need to consider future values).
526    // Partition buffers effectively will be
527    // a, 2, 1
528    // a, 2, 2
529    // a, 2, (missing)
530    //
531    // b, 2, 1
532    // b, 2, 2
533    // b, 2, 3
534    // b, 2, (missing)
535    // When partition buffers are mapped back to the original record batch. Result becomes
536    // a, 2, 1
537    // b, 2, 1
538    // a, 2, 2
539    // b, 2, 2
540    // a, 2, (missing)
541    // b, 2, 3
542    // b, 2, (missing)
543    // This function calculates the column result of window expression(s) (First 4 entry of 3rd column in the above section.)
544    // 1
545    // 1
546    // 2
547    // 2
548    // Above section corresponds to calculated result which can be emitted without breaking input buffer ordering.
549    fn calculate_out_columns(
550        &mut self,
551        input_buffer: &RecordBatch,
552        window_agg_states: &[PartitionWindowAggStates],
553        partition_buffers: &mut PartitionBatches,
554        window_expr: &[Arc<dyn WindowExpr>],
555    ) -> Result<Option<Vec<ArrayRef>>> {
556        let partition_output_indices = self.calc_partition_output_indices(
557            input_buffer,
558            window_agg_states,
559            window_expr,
560        )?;
561
562        let n_window_col = window_agg_states.len();
563        let mut new_columns = vec![vec![]; n_window_col];
564        // Size of all_indices can be at most input_buffer.num_rows():
565        let mut all_indices = UInt32Builder::with_capacity(input_buffer.num_rows());
566        for (row, indices) in partition_output_indices {
567            let length = indices.len();
568            for (idx, window_agg_state) in window_agg_states.iter().enumerate() {
569                let partition = &window_agg_state[&row];
570                let values = Arc::clone(&partition.state.out_col.slice(0, length));
571                new_columns[idx].push(values);
572            }
573            let partition_batch_state = &mut partition_buffers[&row];
574            // Store how many rows are generated for each partition
575            partition_batch_state.n_out_row = length;
576            // For each row keep corresponding index in the input record batch
577            all_indices.append_slice(&indices);
578        }
579        let all_indices = all_indices.finish();
580        if all_indices.is_empty() {
581            // We couldn't generate any new value, return early:
582            return Ok(None);
583        }
584
585        // Concatenate results for each column by converting `Vec<Vec<ArrayRef>>`
586        // to Vec<ArrayRef> where inner `Vec<ArrayRef>`s are converted to `ArrayRef`s.
587        let new_columns = new_columns
588            .iter()
589            .map(|items| {
590                concat(&items.iter().map(|e| e.as_ref()).collect::<Vec<_>>())
591                    .map_err(|e| arrow_datafusion_err!(e))
592            })
593            .collect::<Result<Vec<_>>>()?;
594        // We should emit columns according to row index ordering.
595        let sorted_indices = sort_to_indices(&all_indices, None, None)?;
596        // Construct new column according to row ordering. This fixes ordering
597        take_arrays(&new_columns, &sorted_indices, None)
598            .map(Some)
599            .map_err(|e| arrow_datafusion_err!(e))
600    }
601
602    fn evaluate_partition_batches(
603        &mut self,
604        record_batch: &RecordBatch,
605        window_expr: &[Arc<dyn WindowExpr>],
606    ) -> Result<Vec<(PartitionKey, RecordBatch)>> {
607        let partition_bys =
608            evaluate_partition_by_column_values(record_batch, window_expr)?;
609        // NOTE: In Linear or PartiallySorted modes, we are sure that
610        //       `partition_bys` are not empty.
611        // Calculate indices for each partition and construct a new record
612        // batch from the rows at these indices for each partition:
613        self.get_per_partition_indices(&partition_bys, record_batch)?
614            .into_iter()
615            .map(|(row, indices)| {
616                let mut new_indices = UInt32Builder::with_capacity(indices.len());
617                new_indices.append_slice(&indices);
618                let indices = new_indices.finish();
619                Ok((row, take_record_batch(record_batch, &indices)?))
620            })
621            .collect()
622    }
623
624    fn prune(&mut self, n_out: usize) {
625        // Delete hashes for the rows that are outputted.
626        self.input_buffer_hashes.drain(0..n_out);
627    }
628
629    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches) {
630        // We should be in the `PartiallySorted` case, otherwise we can not
631        // tell when we are at the end of a given partition.
632        if !self.ordered_partition_by_indices.is_empty()
633            && let Some((last_row, _)) = partition_buffers.last()
634        {
635            let last_sorted_cols = self
636                .ordered_partition_by_indices
637                .iter()
638                .map(|idx| last_row[*idx].clone())
639                .collect::<Vec<_>>();
640            for (row, partition_batch_state) in partition_buffers.iter_mut() {
641                let sorted_cols = self
642                    .ordered_partition_by_indices
643                    .iter()
644                    .map(|idx| &row[*idx]);
645                // All the partitions other than `last_sorted_cols` are done.
646                // We are sure that we will no longer receive values for these
647                // partitions (arrival of a new value would violate ordering).
648                partition_batch_state.is_end = !sorted_cols.eq(&last_sorted_cols);
649            }
650        }
651    }
652
653    fn is_mode_linear(&self) -> bool {
654        self.ordered_partition_by_indices.is_empty()
655    }
656
657    fn input_schema(&self) -> &SchemaRef {
658        &self.input_schema
659    }
660}
661
662impl LinearSearch {
663    /// Initialize a new [`LinearSearch`] partition searcher.
664    fn new(ordered_partition_by_indices: Vec<usize>, input_schema: SchemaRef) -> Self {
665        LinearSearch {
666            input_buffer_hashes: VecDeque::new(),
667            random_state: Default::default(),
668            ordered_partition_by_indices,
669            row_map_batch: HashTable::with_capacity(256),
670            row_map_out: HashTable::with_capacity(256),
671            input_schema,
672        }
673    }
674
675    /// Calculate indices of each partition (according to PARTITION BY expression)
676    /// `columns` contain partition by expression results.
677    fn get_per_partition_indices(
678        &mut self,
679        columns: &[ArrayRef],
680        batch: &RecordBatch,
681    ) -> Result<Vec<(PartitionKey, Vec<u32>)>> {
682        let mut batch_hashes = vec![0; batch.num_rows()];
683        create_hashes(columns, &self.random_state, &mut batch_hashes)?;
684        self.input_buffer_hashes.extend(&batch_hashes);
685        // reset row_map for new calculation
686        self.row_map_batch.clear();
687        // res stores PartitionKey and row indices (indices where these partition occurs in the `batch`) for each partition.
688        let mut result: Vec<(PartitionKey, Vec<u32>)> = vec![];
689        for (hash, row_idx) in batch_hashes.into_iter().zip(0u32..) {
690            let entry = self.row_map_batch.find_mut(hash, |(_, group_idx)| {
691                // We can safely get the first index of the partition indices
692                // since partition indices has one element during initialization.
693                let row = get_row_at_idx(columns, row_idx as usize).unwrap();
694                // Handle hash collusions with an equality check:
695                row.eq(&result[*group_idx].0)
696            });
697            if let Some((_, group_idx)) = entry {
698                result[*group_idx].1.push(row_idx)
699            } else {
700                self.row_map_batch.insert_unique(
701                    hash,
702                    (hash, result.len()),
703                    |(hash, _)| *hash,
704                );
705                let row = get_row_at_idx(columns, row_idx as usize)?;
706                // This is a new partition its only index is row_idx for now.
707                result.push((row, vec![row_idx]));
708            }
709        }
710        Ok(result)
711    }
712
713    /// Calculates partition keys and result indices for each partition.
714    /// The return value is a vector of tuples where the first entry stores
715    /// the partition key (unique for each partition) and the second entry
716    /// stores indices of the rows for which the partition is constructed.
717    fn calc_partition_output_indices(
718        &mut self,
719        input_buffer: &RecordBatch,
720        window_agg_states: &[PartitionWindowAggStates],
721        window_expr: &[Arc<dyn WindowExpr>],
722    ) -> Result<Vec<(PartitionKey, Vec<u32>)>> {
723        let partition_by_columns =
724            evaluate_partition_by_column_values(input_buffer, window_expr)?;
725        // Reset the row_map state:
726        self.row_map_out.clear();
727        let mut partition_indices: Vec<(PartitionKey, Vec<u32>)> = vec![];
728        for (hash, row_idx) in self.input_buffer_hashes.iter().zip(0u32..) {
729            let entry = self.row_map_out.find_mut(*hash, |(_, group_idx, _)| {
730                let row =
731                    get_row_at_idx(&partition_by_columns, row_idx as usize).unwrap();
732                row == partition_indices[*group_idx].0
733            });
734            if let Some((_, group_idx, n_out)) = entry {
735                let (_, indices) = &mut partition_indices[*group_idx];
736                if indices.len() >= *n_out {
737                    break;
738                }
739                indices.push(row_idx);
740            } else {
741                let row = get_row_at_idx(&partition_by_columns, row_idx as usize)?;
742                let min_out = window_agg_states
743                    .iter()
744                    .map(|window_agg_state| {
745                        window_agg_state
746                            .get(&row)
747                            .map(|partition| partition.state.out_col.len())
748                            .unwrap_or(0)
749                    })
750                    .min()
751                    .unwrap_or(0);
752                if min_out == 0 {
753                    break;
754                }
755                self.row_map_out.insert_unique(
756                    *hash,
757                    (*hash, partition_indices.len(), min_out),
758                    |(hash, _, _)| *hash,
759                );
760                partition_indices.push((row, vec![row_idx]));
761            }
762        }
763        Ok(partition_indices)
764    }
765}
766
/// This object encapsulates the algorithm state for sorted searching
/// when computing partitions.
pub struct SortedSearch {
    /// Stores partition by columns and their ordering information
    partition_by_sort_keys: Vec<PhysicalSortExpr>,
    /// Input ordering and partition by key ordering need not be the same, so
    /// this vector stores the mapping between them. For instance, if the input
    /// is ordered by a, b and the window expression contains a PARTITION BY b, a
    /// clause, this attribute stores [1, 0].
    ordered_partition_by_indices: Vec<usize>,
    /// Schema handed out by `input_schema()`.
    input_schema: SchemaRef,
}
779
780impl PartitionSearcher for SortedSearch {
781    /// This method constructs new output columns using the result of each window expression.
782    fn calculate_out_columns(
783        &mut self,
784        _input_buffer: &RecordBatch,
785        window_agg_states: &[PartitionWindowAggStates],
786        partition_buffers: &mut PartitionBatches,
787        _window_expr: &[Arc<dyn WindowExpr>],
788    ) -> Result<Option<Vec<ArrayRef>>> {
789        let n_out = self.calculate_n_out_row(window_agg_states, partition_buffers);
790        if n_out == 0 {
791            Ok(None)
792        } else {
793            window_agg_states
794                .iter()
795                .map(|map| get_aggregate_result_out_column(map, n_out).map(Some))
796                .collect()
797        }
798    }
799
800    fn evaluate_partition_batches(
801        &mut self,
802        record_batch: &RecordBatch,
803        _window_expr: &[Arc<dyn WindowExpr>],
804    ) -> Result<Vec<(PartitionKey, RecordBatch)>> {
805        let num_rows = record_batch.num_rows();
806        // Calculate result of partition by column expressions
807        let partition_columns = self
808            .partition_by_sort_keys
809            .iter()
810            .map(|elem| elem.evaluate_to_sort_column(record_batch))
811            .collect::<Result<Vec<_>>>()?;
812        // Reorder `partition_columns` such that its ordering matches input ordering.
813        let partition_columns_ordered =
814            get_at_indices(&partition_columns, &self.ordered_partition_by_indices)?;
815        let partition_points =
816            evaluate_partition_ranges(num_rows, &partition_columns_ordered)?;
817        let partition_bys = partition_columns
818            .into_iter()
819            .map(|arr| arr.values)
820            .collect::<Vec<ArrayRef>>();
821
822        partition_points
823            .iter()
824            .map(|range| {
825                let row = get_row_at_idx(&partition_bys, range.start)?;
826                let len = range.end - range.start;
827                let slice = record_batch.slice(range.start, len);
828                Ok((row, slice))
829            })
830            .collect::<Result<Vec<_>>>()
831    }
832
833    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches) {
834        // In Sorted case. We can mark all partitions besides last partition as ended.
835        // We are sure that those partitions will never receive any values.
836        // (Otherwise ordering invariant is violated.)
837        let n_partitions = partition_buffers.len();
838        for (idx, (_, partition_batch_state)) in partition_buffers.iter_mut().enumerate()
839        {
840            partition_batch_state.is_end |= idx < n_partitions - 1;
841        }
842    }
843
844    fn input_schema(&self) -> &SchemaRef {
845        &self.input_schema
846    }
847}
848
849impl SortedSearch {
850    /// Calculates how many rows we can output.
851    fn calculate_n_out_row(
852        &mut self,
853        window_agg_states: &[PartitionWindowAggStates],
854        partition_buffers: &mut PartitionBatches,
855    ) -> usize {
856        // Different window aggregators may produce results at different rates.
857        // We produce the overall batch result only as fast as the slowest one.
858        let mut counts = vec![];
859        let out_col_counts = window_agg_states.iter().map(|window_agg_state| {
860            // Store how many elements are generated for the current
861            // window expression:
862            let mut cur_window_expr_out_result_len = 0;
863            // We iterate over `window_agg_state`, which is an IndexMap.
864            // Iterations follow the insertion order, hence we preserve
865            // sorting when partition columns are sorted.
866            let mut per_partition_out_results = HashMap::new();
867            for (row, WindowState { state, .. }) in window_agg_state.iter() {
868                cur_window_expr_out_result_len += state.out_col.len();
869                let count = per_partition_out_results.entry(row).or_insert(0);
870                if *count < state.out_col.len() {
871                    *count = state.out_col.len();
872                }
873                // If we do not generate all results for the current
874                // partition, we do not generate results for next
875                // partition --  otherwise we will lose input ordering.
876                if state.n_row_result_missing > 0 {
877                    break;
878                }
879            }
880            counts.push(per_partition_out_results);
881            cur_window_expr_out_result_len
882        });
883        argmin(out_col_counts).map_or(0, |(min_idx, minima)| {
884            let mut slowest_partition = counts.swap_remove(min_idx);
885            for (partition_key, partition_batch) in partition_buffers.iter_mut() {
886                if let Some(count) = slowest_partition.remove(partition_key) {
887                    partition_batch.n_out_row = count;
888                }
889            }
890            minima
891        })
892    }
893}
894
895/// Calculates partition by expression results for each window expression
896/// on `record_batch`.
897fn evaluate_partition_by_column_values(
898    record_batch: &RecordBatch,
899    window_expr: &[Arc<dyn WindowExpr>],
900) -> Result<Vec<ArrayRef>> {
901    window_expr[0]
902        .partition_by()
903        .iter()
904        .map(|item| match item.evaluate(record_batch)? {
905            ColumnarValue::Array(array) => Ok(array),
906            ColumnarValue::Scalar(scalar) => {
907                scalar.to_array_of_size(record_batch.num_rows())
908            }
909        })
910        .collect()
911}
912
/// Stream for the bounded window aggregation plan.
pub struct BoundedWindowAggStream {
    /// Schema of the record batches this stream emits.
    schema: SchemaRef,
    /// Upstream stream that is polled for new input batches.
    input: SendableRecordBatchStream,
    /// The record batch executor receives as input (i.e. the columns needed
    /// while calculating aggregation results).
    input_buffer: RecordBatch,
    /// We separate `input_buffer` based on partitions (as
    /// determined by PARTITION BY columns) and store them per partition
    /// in `partition_buffers`. We use this variable when calculating results
    /// for each window expression. This enables us to use the same batch for
    /// different window expressions without copying.
    // Note that we could keep record batches for each window expression in
    // `PartitionWindowAggStates`. However, this would use more memory (as
    // many times as the number of window expressions).
    partition_buffers: PartitionBatches,
    /// An executor can run multiple window expressions if the PARTITION BY
    /// and ORDER BY sections are same. We keep state of the each window
    /// expression inside `window_agg_states`.
    window_agg_states: Vec<PartitionWindowAggStates>,
    /// Set once the input stream is exhausted; later polls return `None`
    /// without touching the input again.
    finished: bool,
    /// Window expressions evaluated by this stream. The expression at index
    /// `i` is evaluated against the state at `window_agg_states[i]`.
    window_expr: Vec<Arc<dyn WindowExpr>>,
    /// Metrics of this stream: poll results and elapsed compute time.
    baseline_metrics: BaselineMetrics,
    /// Search mode for partition columns. This determines the algorithm with
    /// which we group each partition.
    search_mode: Box<dyn PartitionSearcher>,
}
940
impl BoundedWindowAggStream {
    /// Prunes sections of the state that are no longer needed when calculating
    /// results (as determined by window frame boundaries and number of results generated).
    ///
    /// `n_out` is the number of rows that have just been emitted.
    // For instance, if the first `n` (not necessarily the same as `n_out`) elements
    // are no longer needed to calculate window expression results (they fall outside
    // the window frame boundary), we retract the first `n` elements from the
    // corresponding partition in `self.partition_buffers`.
    // Likewise, once `n_out` rows are calculated, we can remove the first
    // `n_out` rows from `self.input_buffer`.
    //
    // The call order matters: `prune_out_columns` reads the per-partition
    // `n_out_row` counters that `prune_partition_batches` resets to zero.
    fn prune_state(&mut self, n_out: usize) -> Result<()> {
        // Prune `self.window_agg_states`:
        self.prune_out_columns();
        // Prune `self.partition_buffers`:
        self.prune_partition_batches();
        // Prune `self.input_buffer`:
        self.prune_input_batch(n_out)?;
        // Prune internal state of search algorithm.
        self.search_mode.prune(n_out);
        Ok(())
    }
}
961
962impl Stream for BoundedWindowAggStream {
963    type Item = Result<RecordBatch>;
964
965    fn poll_next(
966        mut self: Pin<&mut Self>,
967        cx: &mut Context<'_>,
968    ) -> Poll<Option<Self::Item>> {
969        let poll = self.poll_next_inner(cx);
970        self.baseline_metrics.record_poll(poll)
971    }
972}
973
974impl BoundedWindowAggStream {
975    /// Create a new BoundedWindowAggStream
976    fn new(
977        schema: SchemaRef,
978        window_expr: Vec<Arc<dyn WindowExpr>>,
979        input: SendableRecordBatchStream,
980        baseline_metrics: BaselineMetrics,
981        search_mode: Box<dyn PartitionSearcher>,
982    ) -> Result<Self> {
983        let state = window_expr.iter().map(|_| IndexMap::new()).collect();
984        let empty_batch = RecordBatch::new_empty(Arc::clone(&schema));
985        Ok(Self {
986            schema,
987            input,
988            input_buffer: empty_batch,
989            partition_buffers: IndexMap::new(),
990            window_agg_states: state,
991            finished: false,
992            window_expr,
993            baseline_metrics,
994            search_mode,
995        })
996    }
997
998    fn compute_aggregates(&mut self) -> Result<Option<RecordBatch>> {
999        // calculate window cols
1000        for (cur_window_expr, state) in
1001            self.window_expr.iter().zip(&mut self.window_agg_states)
1002        {
1003            cur_window_expr.evaluate_stateful(&self.partition_buffers, state)?;
1004        }
1005
1006        let schema = Arc::clone(&self.schema);
1007        let window_expr_out = self.search_mode.calculate_out_columns(
1008            &self.input_buffer,
1009            &self.window_agg_states,
1010            &mut self.partition_buffers,
1011            &self.window_expr,
1012        )?;
1013        if let Some(window_expr_out) = window_expr_out {
1014            let n_out = window_expr_out[0].len();
1015            // right append new columns to corresponding section in the original input buffer.
1016            let columns_to_show = self
1017                .input_buffer
1018                .columns()
1019                .iter()
1020                .map(|elem| elem.slice(0, n_out))
1021                .chain(window_expr_out)
1022                .collect::<Vec<_>>();
1023            let n_generated = columns_to_show[0].len();
1024            self.prune_state(n_generated)?;
1025            Ok(Some(RecordBatch::try_new(schema, columns_to_show)?))
1026        } else {
1027            Ok(None)
1028        }
1029    }
1030
1031    #[inline]
1032    fn poll_next_inner(
1033        &mut self,
1034        cx: &mut Context<'_>,
1035    ) -> Poll<Option<Result<RecordBatch>>> {
1036        if self.finished {
1037            return Poll::Ready(None);
1038        }
1039
1040        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
1041        match ready!(self.input.poll_next_unpin(cx)) {
1042            Some(Ok(batch)) => {
1043                // Start the timer for compute time within this operator. It will be
1044                // stopped when dropped.
1045                let _timer = elapsed_compute.timer();
1046
1047                self.search_mode.update_partition_batch(
1048                    &mut self.input_buffer,
1049                    batch,
1050                    &self.window_expr,
1051                    &mut self.partition_buffers,
1052                )?;
1053                if let Some(batch) = self.compute_aggregates()? {
1054                    return Poll::Ready(Some(Ok(batch)));
1055                }
1056                self.poll_next_inner(cx)
1057            }
1058            Some(Err(e)) => Poll::Ready(Some(Err(e))),
1059            None => {
1060                let _timer = elapsed_compute.timer();
1061
1062                self.finished = true;
1063                for (_, partition_batch_state) in self.partition_buffers.iter_mut() {
1064                    partition_batch_state.is_end = true;
1065                }
1066                if let Some(batch) = self.compute_aggregates()? {
1067                    return Poll::Ready(Some(Ok(batch)));
1068                }
1069                Poll::Ready(None)
1070            }
1071        }
1072    }
1073
1074    /// Prunes the sections of the record batch (for each partition)
1075    /// that we no longer need to calculate the window function result.
1076    fn prune_partition_batches(&mut self) {
1077        // Remove partitions which we know already ended (is_end flag is true).
1078        // Since the retain method preserves insertion order, we still have
1079        // ordering in between partitions after removal.
1080        self.partition_buffers
1081            .retain(|_, partition_batch_state| !partition_batch_state.is_end);
1082
1083        // The data in `self.partition_batches` is used by all window expressions.
1084        // Therefore, when removing from `self.partition_batches`, we need to remove
1085        // from the earliest range boundary among all window expressions. Variable
1086        // `n_prune_each_partition` fill the earliest range boundary information for
1087        // each partition. This way, we can delete the no-longer-needed sections from
1088        // `self.partition_batches`.
1089        // For instance, if window frame one uses [10, 20] and window frame two uses
1090        // [5, 15]; we only prune the first 5 elements from the corresponding record
1091        // batch in `self.partition_batches`.
1092
1093        // Calculate how many elements to prune for each partition batch
1094        let mut n_prune_each_partition = HashMap::new();
1095        for window_agg_state in self.window_agg_states.iter_mut() {
1096            window_agg_state.retain(|_, WindowState { state, .. }| !state.is_end);
1097            for (partition_row, WindowState { state: value, .. }) in window_agg_state {
1098                let n_prune =
1099                    min(value.window_frame_range.start, value.last_calculated_index);
1100                if let Some(current) = n_prune_each_partition.get_mut(partition_row) {
1101                    if n_prune < *current {
1102                        *current = n_prune;
1103                    }
1104                } else {
1105                    n_prune_each_partition.insert(partition_row.clone(), n_prune);
1106                }
1107            }
1108        }
1109
1110        // Retract no longer needed parts during window calculations from partition batch:
1111        for (partition_row, n_prune) in n_prune_each_partition.iter() {
1112            let pb_state = &mut self.partition_buffers[partition_row];
1113
1114            let batch = &pb_state.record_batch;
1115            pb_state.record_batch = batch.slice(*n_prune, batch.num_rows() - n_prune);
1116            pb_state.n_out_row = 0;
1117
1118            // Update state indices since we have pruned some rows from the beginning:
1119            for window_agg_state in self.window_agg_states.iter_mut() {
1120                window_agg_state[partition_row].state.prune_state(*n_prune);
1121            }
1122        }
1123    }
1124
1125    /// Prunes the section of the input batch whose aggregate results
1126    /// are calculated and emitted.
1127    fn prune_input_batch(&mut self, n_out: usize) -> Result<()> {
1128        // Prune first n_out rows from the input_buffer
1129        let n_to_keep = self.input_buffer.num_rows() - n_out;
1130        let batch_to_keep = self
1131            .input_buffer
1132            .columns()
1133            .iter()
1134            .map(|elem| elem.slice(n_out, n_to_keep))
1135            .collect::<Vec<_>>();
1136        self.input_buffer = RecordBatch::try_new_with_options(
1137            self.input_buffer.schema(),
1138            batch_to_keep,
1139            &RecordBatchOptions::new().with_row_count(Some(n_to_keep)),
1140        )?;
1141        Ok(())
1142    }
1143
1144    /// Prunes emitted parts from WindowAggState `out_col` field.
1145    fn prune_out_columns(&mut self) {
1146        // We store generated columns for each window expression in the `out_col`
1147        // field of `WindowAggState`. Given how many rows are emitted, we remove
1148        // these sections from state.
1149        for partition_window_agg_states in self.window_agg_states.iter_mut() {
1150            // Remove `n_out` entries from the `out_col` field of `WindowAggState`.
1151            // `n_out` is stored in `self.partition_buffers` for each partition.
1152            // If `is_end` is set, directly remove them; this shrinks the hash map.
1153            partition_window_agg_states
1154                .retain(|_, partition_batch_state| !partition_batch_state.state.is_end);
1155            for (
1156                partition_key,
1157                WindowState {
1158                    state: WindowAggState { out_col, .. },
1159                    ..
1160                },
1161            ) in partition_window_agg_states
1162            {
1163                let partition_batch = &mut self.partition_buffers[partition_key];
1164                let n_to_del = partition_batch.n_out_row;
1165                let n_to_keep = out_col.len() - n_to_del;
1166                *out_col = out_col.slice(n_to_del, n_to_keep);
1167            }
1168        }
1169    }
1170}
1171
1172impl RecordBatchStream for BoundedWindowAggStream {
1173    /// Get the schema
1174    fn schema(&self) -> SchemaRef {
1175        Arc::clone(&self.schema)
1176    }
1177}
1178
/// Returns the position and value of the smallest item, or `None` if `data`
/// is empty.
///
/// When several items tie for the minimum the earliest one is returned;
/// items that cannot be compared (e.g. NaN) are treated as equal.
fn argmin<T: PartialOrd>(data: impl Iterator<Item = T>) -> Option<(usize, T)> {
    let mut indexed = data.enumerate();
    let mut best = indexed.next()?;
    for candidate in indexed {
        // Replace only on a strictly smaller value so that ties (and
        // incomparable pairs) keep the earlier item.
        if let Some(Ordering::Greater) = best.1.partial_cmp(&candidate.1) {
            best = candidate;
        }
    }
    Some(best)
}
1184
1185/// Calculates the section we can show results for expression
1186fn get_aggregate_result_out_column(
1187    partition_window_agg_states: &PartitionWindowAggStates,
1188    len_to_show: usize,
1189) -> Result<ArrayRef> {
1190    let mut result = None;
1191    let mut running_length = 0;
1192    let mut batches_to_concat = vec![];
1193    // We assume that iteration order is according to insertion order
1194    for (
1195        _,
1196        WindowState {
1197            state: WindowAggState { out_col, .. },
1198            ..
1199        },
1200    ) in partition_window_agg_states
1201    {
1202        if running_length < len_to_show {
1203            let n_to_use = min(len_to_show - running_length, out_col.len());
1204            let slice_to_use = if n_to_use == out_col.len() {
1205                // avoid slice when the entire column is used
1206                Arc::clone(out_col)
1207            } else {
1208                out_col.slice(0, n_to_use)
1209            };
1210            batches_to_concat.push(slice_to_use);
1211            running_length += n_to_use;
1212        } else {
1213            break;
1214        }
1215    }
1216
1217    if !batches_to_concat.is_empty() {
1218        let array_refs: Vec<&dyn Array> =
1219            batches_to_concat.iter().map(|a| a.as_ref()).collect();
1220        result = Some(concat(&array_refs)?);
1221    }
1222
1223    if running_length != len_to_show {
1224        return exec_err!(
1225            "Generated row number should be {len_to_show}, it is {running_length}"
1226        );
1227    }
1228    result.ok_or_else(|| exec_datafusion_err!("Should contain something"))
1229}
1230
1231/// Constructs a batch from the last row of batch in the argument.
1232pub(crate) fn get_last_row_batch(batch: &RecordBatch) -> Result<RecordBatch> {
1233    if batch.num_rows() == 0 {
1234        return exec_err!("Latest batch should have at least 1 row");
1235    }
1236    Ok(batch.slice(batch.num_rows() - 1, 1))
1237}
1238
1239#[cfg(test)]
1240mod tests {
1241    use std::pin::Pin;
1242    use std::sync::Arc;
1243    use std::task::{Context, Poll};
1244    use std::time::Duration;
1245
1246    use crate::common::collect;
1247    use crate::expressions::PhysicalSortExpr;
1248    use crate::projection::{ProjectionExec, ProjectionExpr};
1249    use crate::streaming::{PartitionStream, StreamingTableExec};
1250    use crate::test::TestMemoryExec;
1251    use crate::windows::{
1252        BoundedWindowAggExec, InputOrderMode, create_udwf_window_expr, create_window_expr,
1253    };
1254    use crate::{ExecutionPlan, displayable, execute_stream};
1255
1256    use arrow::array::{
1257        RecordBatch,
1258        builder::{Int64Builder, UInt64Builder},
1259    };
1260    use arrow::compute::SortOptions;
1261    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1262    use datafusion_common::test_util::batches_to_string;
1263    use datafusion_common::{Result, ScalarValue, exec_datafusion_err};
1264    use datafusion_execution::config::SessionConfig;
1265    use datafusion_execution::{
1266        RecordBatchStream, SendableRecordBatchStream, TaskContext,
1267    };
1268    use datafusion_expr::{
1269        WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
1270    };
1271    use datafusion_functions_aggregate::count::count_udaf;
1272    use datafusion_functions_window::nth_value::last_value_udwf;
1273    use datafusion_functions_window::nth_value::nth_value_udwf;
1274    use datafusion_physical_expr::expressions::{Column, Literal, col};
1275    use datafusion_physical_expr::window::StandardWindowExpr;
1276    use datafusion_physical_expr::{LexOrdering, PhysicalExpr};
1277
1278    use futures::future::Shared;
1279    use futures::{FutureExt, Stream, StreamExt, pin_mut, ready};
1280    use insta::assert_snapshot;
1281    use itertools::Itertools;
1282    use tokio::time::timeout;
1283
    /// Test source that emits a fixed list of batches, sleeping between
    /// emissions, and then either ends or stays pending.
    #[derive(Debug, Clone)]
    struct TestStreamPartition {
        /// Schema reported by the stream.
        schema: SchemaRef,
        /// Batches to emit, in order.
        batches: Vec<RecordBatch>,
        /// Index into `batches` of the next batch to emit.
        idx: usize,
        /// Whether the stream is currently sleeping or ready to emit.
        state: PolingState,
        /// Duration of each sleep.
        sleep_duration: Duration,
        /// If true, the stream ends (yields `None`) once `batches` is
        /// exhausted; otherwise it keeps sleeping and never completes.
        send_exit: bool,
    }
1293
1294    impl PartitionStream for TestStreamPartition {
1295        fn schema(&self) -> &SchemaRef {
1296            &self.schema
1297        }
1298
1299        fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
1300            // We create an iterator from the record batches and map them into Ok values,
1301            // converting the iterator into a futures::stream::Stream
1302            Box::pin(self.clone())
1303        }
1304    }
1305
1306    impl Stream for TestStreamPartition {
1307        type Item = Result<RecordBatch>;
1308
1309        fn poll_next(
1310            mut self: Pin<&mut Self>,
1311            cx: &mut Context<'_>,
1312        ) -> Poll<Option<Self::Item>> {
1313            self.poll_next_inner(cx)
1314        }
1315    }
1316
    /// Polling state of `TestStreamPartition`.
    #[derive(Debug, Clone)]
    enum PolingState {
        /// Waiting on the contained sleep future; once it completes the
        /// stream switches to `BatchReturn`.
        Sleep(Shared<futures::future::BoxFuture<'static, ()>>),
        /// Ready to emit the next batch (or to signal the end of the data).
        BatchReturn,
    }
1322
1323    impl TestStreamPartition {
1324        fn poll_next_inner(
1325            self: &mut Pin<&mut Self>,
1326            cx: &mut Context<'_>,
1327        ) -> Poll<Option<Result<RecordBatch>>> {
1328            loop {
1329                match &mut self.state {
1330                    PolingState::BatchReturn => {
1331                        // Wait for self.sleep_duration before sending any new data
1332                        let f = tokio::time::sleep(self.sleep_duration).boxed().shared();
1333                        self.state = PolingState::Sleep(f);
1334                        let input_batch = if let Some(batch) =
1335                            self.batches.clone().get(self.idx)
1336                        {
1337                            batch.clone()
1338                        } else if self.send_exit {
1339                            // Send None to signal end of data
1340                            return Poll::Ready(None);
1341                        } else {
1342                            // Go to sleep mode
1343                            let f =
1344                                tokio::time::sleep(self.sleep_duration).boxed().shared();
1345                            self.state = PolingState::Sleep(f);
1346                            continue;
1347                        };
1348                        self.idx += 1;
1349                        return Poll::Ready(Some(Ok(input_batch)));
1350                    }
1351                    PolingState::Sleep(future) => {
1352                        pin_mut!(future);
1353                        ready!(future.poll_unpin(cx));
1354                        self.state = PolingState::BatchReturn;
1355                    }
1356                }
1357            }
1358        }
1359    }
1360
1361    impl RecordBatchStream for TestStreamPartition {
1362        fn schema(&self) -> SchemaRef {
1363            Arc::clone(&self.schema)
1364        }
1365    }
1366
1367    fn bounded_window_exec_pb_latent_range(
1368        input: Arc<dyn ExecutionPlan>,
1369        n_future_range: usize,
1370        hash: &str,
1371        order_by: &str,
1372    ) -> Result<Arc<dyn ExecutionPlan>> {
1373        let schema = input.schema();
1374        let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf());
1375        let col_expr =
1376            Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc<dyn PhysicalExpr>;
1377        let args = vec![col_expr];
1378        let partitionby_exprs = vec![col(hash, &schema)?];
1379        let orderby_exprs = vec![PhysicalSortExpr {
1380            expr: col(order_by, &schema)?,
1381            options: SortOptions::default(),
1382        }];
1383        let window_frame = WindowFrame::new_bounds(
1384            WindowFrameUnits::Range,
1385            WindowFrameBound::CurrentRow,
1386            WindowFrameBound::Following(ScalarValue::UInt64(Some(n_future_range as u64))),
1387        );
1388        let fn_name = format!(
1389            "{window_fn}({args:?}) PARTITION BY: [{partitionby_exprs:?}], ORDER BY: [{orderby_exprs:?}]"
1390        );
1391        let input_order_mode = InputOrderMode::Linear;
1392        Ok(Arc::new(BoundedWindowAggExec::try_new(
1393            vec![create_window_expr(
1394                &window_fn,
1395                fn_name,
1396                &args,
1397                &partitionby_exprs,
1398                &orderby_exprs,
1399                Arc::new(window_frame),
1400                input.schema(),
1401                false,
1402                false,
1403                None,
1404            )?],
1405            input,
1406            input_order_mode,
1407            true,
1408        )?))
1409    }
1410
1411    fn projection_exec(input: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
1412        let schema = input.schema();
1413        let exprs = input
1414            .schema()
1415            .fields
1416            .iter()
1417            .enumerate()
1418            .map(|(idx, field)| {
1419                let name = if field.name().len() > 20 {
1420                    format!("col_{idx}")
1421                } else {
1422                    field.name().clone()
1423                };
1424                let expr = col(field.name(), &schema).unwrap();
1425                (expr, name)
1426            })
1427            .collect::<Vec<_>>();
1428        let proj_exprs: Vec<ProjectionExpr> = exprs
1429            .into_iter()
1430            .map(|(expr, alias)| ProjectionExpr { expr, alias })
1431            .collect();
1432        Ok(Arc::new(ProjectionExec::try_new(proj_exprs, input)?))
1433    }
1434
1435    fn task_context_helper() -> TaskContext {
1436        let task_ctx = TaskContext::default();
1437        // Create session context with config
1438        let session_config = SessionConfig::new()
1439            .with_batch_size(1)
1440            .with_target_partitions(2)
1441            .with_round_robin_repartition(false);
1442        task_ctx.with_session_config(session_config)
1443    }
1444
1445    fn task_context() -> Arc<TaskContext> {
1446        Arc::new(task_context_helper())
1447    }
1448
1449    pub async fn collect_stream(
1450        mut stream: SendableRecordBatchStream,
1451        results: &mut Vec<RecordBatch>,
1452    ) -> Result<()> {
1453        while let Some(item) = stream.next().await {
1454            results.push(item?);
1455        }
1456        Ok(())
1457    }
1458
1459    /// Execute the [ExecutionPlan] and collect the results in memory
1460    pub async fn collect_with_timeout(
1461        plan: Arc<dyn ExecutionPlan>,
1462        context: Arc<TaskContext>,
1463        timeout_duration: Duration,
1464    ) -> Result<Vec<RecordBatch>> {
1465        let stream = execute_stream(plan, context)?;
1466        let mut results = vec![];
1467
1468        // Execute the asynchronous operation with a timeout
1469        if timeout(timeout_duration, collect_stream(stream, &mut results))
1470            .await
1471            .is_ok()
1472        {
1473            return Err(exec_datafusion_err!("shouldn't have completed"));
1474        };
1475
1476        Ok(results)
1477    }
1478
1479    fn test_schema() -> SchemaRef {
1480        Arc::new(Schema::new(vec![
1481            Field::new("sn", DataType::UInt64, true),
1482            Field::new("hash", DataType::Int64, true),
1483        ]))
1484    }
1485
1486    fn schema_orders(schema: &SchemaRef) -> Result<Vec<LexOrdering>> {
1487        let orderings = vec![
1488            [PhysicalSortExpr {
1489                expr: col("sn", schema)?,
1490                options: SortOptions {
1491                    descending: false,
1492                    nulls_first: false,
1493                },
1494            }]
1495            .into(),
1496        ];
1497        Ok(orderings)
1498    }
1499
1500    fn is_integer_division_safe(lhs: usize, rhs: usize) -> bool {
1501        let res = lhs / rhs;
1502        res * rhs == lhs
1503    }
1504    fn generate_batches(
1505        schema: &SchemaRef,
1506        n_row: usize,
1507        n_chunk: usize,
1508    ) -> Result<Vec<RecordBatch>> {
1509        let mut batches = vec![];
1510        assert!(n_row > 0);
1511        assert!(n_chunk > 0);
1512        assert!(is_integer_division_safe(n_row, n_chunk));
1513        let hash_replicate = 4;
1514
1515        let chunks = (0..n_row)
1516            .chunks(n_chunk)
1517            .into_iter()
1518            .map(|elem| elem.into_iter().collect::<Vec<_>>())
1519            .collect::<Vec<_>>();
1520
1521        // Send 2 RecordBatches at the source
1522        for sn_values in chunks {
1523            let mut sn1_array = UInt64Builder::with_capacity(sn_values.len());
1524            let mut hash_array = Int64Builder::with_capacity(sn_values.len());
1525
1526            for sn in sn_values {
1527                sn1_array.append_value(sn as u64);
1528                let hash_value = (2 - (sn / hash_replicate)) as i64;
1529                hash_array.append_value(hash_value);
1530            }
1531
1532            let batch = RecordBatch::try_new(
1533                Arc::clone(schema),
1534                vec![Arc::new(sn1_array.finish()), Arc::new(hash_array.finish())],
1535            )?;
1536            batches.push(batch);
1537        }
1538        Ok(batches)
1539    }
1540
1541    fn generate_never_ending_source(
1542        n_rows: usize,
1543        chunk_length: usize,
1544        n_partition: usize,
1545        is_infinite: bool,
1546        send_exit: bool,
1547        per_batch_wait_duration_in_millis: u64,
1548    ) -> Result<Arc<dyn ExecutionPlan>> {
1549        assert!(n_partition > 0);
1550
1551        // We use same hash value in the table. This makes sure that
1552        // After hashing computation will continue in only in one of the output partitions
1553        // In this case, data flow should still continue
1554        let schema = test_schema();
1555        let orderings = schema_orders(&schema)?;
1556
1557        // Source waits per_batch_wait_duration_in_millis ms before sending other batch
1558        let per_batch_wait_duration =
1559            Duration::from_millis(per_batch_wait_duration_in_millis);
1560
1561        let batches = generate_batches(&schema, n_rows, chunk_length)?;
1562
1563        // Source has 2 partitions
1564        let partitions = vec![
1565            Arc::new(TestStreamPartition {
1566                schema: Arc::clone(&schema),
1567                batches,
1568                idx: 0,
1569                state: PolingState::BatchReturn,
1570                sleep_duration: per_batch_wait_duration,
1571                send_exit,
1572            }) as _;
1573            n_partition
1574        ];
1575        let source = Arc::new(StreamingTableExec::try_new(
1576            Arc::clone(&schema),
1577            partitions,
1578            None,
1579            orderings,
1580            is_infinite,
1581            None,
1582        )?) as _;
1583        Ok(source)
1584    }
1585
1586    // Tests NTH_VALUE(negative index) with memoize feature
1587    // To be able to trigger memoize feature for NTH_VALUE we need to
1588    // - feed BoundedWindowAggExec with batch stream data.
1589    // - Window frame should contain UNBOUNDED PRECEDING.
1590    // It hard to ensure these conditions are met, from the sql query.
1591    #[tokio::test]
1592    async fn test_window_nth_value_bounded_memoize() -> Result<()> {
1593        let config = SessionConfig::new().with_target_partitions(1);
1594        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
1595
1596        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1597        // Create a new batch of data to insert into the table
1598        let batch = RecordBatch::try_new(
1599            Arc::clone(&schema),
1600            vec![Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3]))],
1601        )?;
1602
1603        let memory_exec = TestMemoryExec::try_new_exec(
1604            &[vec![batch.clone(), batch.clone(), batch.clone()]],
1605            Arc::clone(&schema),
1606            None,
1607        )?;
1608        let col_a = col("a", &schema)?;
1609        let nth_value_func1 = create_udwf_window_expr(
1610            &nth_value_udwf(),
1611            &[
1612                Arc::clone(&col_a),
1613                Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
1614            ],
1615            &schema,
1616            "nth_value(-1)".to_string(),
1617            false,
1618        )?
1619        .reverse_expr()
1620        .unwrap();
1621        let nth_value_func2 = create_udwf_window_expr(
1622            &nth_value_udwf(),
1623            &[
1624                Arc::clone(&col_a),
1625                Arc::new(Literal::new(ScalarValue::Int32(Some(2)))),
1626            ],
1627            &schema,
1628            "nth_value(-2)".to_string(),
1629            false,
1630        )?
1631        .reverse_expr()
1632        .unwrap();
1633
1634        let last_value_func = create_udwf_window_expr(
1635            &last_value_udwf(),
1636            &[Arc::clone(&col_a)],
1637            &schema,
1638            "last".to_string(),
1639            false,
1640        )?;
1641
1642        let window_exprs = vec![
1643            // LAST_VALUE(a)
1644            Arc::new(StandardWindowExpr::new(
1645                last_value_func,
1646                &[],
1647                &[],
1648                Arc::new(WindowFrame::new_bounds(
1649                    WindowFrameUnits::Rows,
1650                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1651                    WindowFrameBound::CurrentRow,
1652                )),
1653            )) as _,
1654            // NTH_VALUE(a, -1)
1655            Arc::new(StandardWindowExpr::new(
1656                nth_value_func1,
1657                &[],
1658                &[],
1659                Arc::new(WindowFrame::new_bounds(
1660                    WindowFrameUnits::Rows,
1661                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1662                    WindowFrameBound::CurrentRow,
1663                )),
1664            )) as _,
1665            // NTH_VALUE(a, -2)
1666            Arc::new(StandardWindowExpr::new(
1667                nth_value_func2,
1668                &[],
1669                &[],
1670                Arc::new(WindowFrame::new_bounds(
1671                    WindowFrameUnits::Rows,
1672                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1673                    WindowFrameBound::CurrentRow,
1674                )),
1675            )) as _,
1676        ];
1677        let physical_plan = BoundedWindowAggExec::try_new(
1678            window_exprs,
1679            memory_exec,
1680            InputOrderMode::Sorted,
1681            true,
1682        )
1683        .map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)?;
1684
1685        let batches = collect(physical_plan.execute(0, task_ctx)?).await?;
1686
1687        // Get string representation of the plan
1688        assert_snapshot!(displayable(physical_plan.as_ref()).indent(true), @r#"
1689        BoundedWindowAggExec: wdw=[last: Field { "last": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, nth_value(-1): Field { "nth_value(-1)": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, nth_value(-2): Field { "nth_value(-2)": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
1690          DataSourceExec: partitions=1, partition_sizes=[3]
1691        "#);
1692
1693        assert_snapshot!(batches_to_string(&batches), @r"
1694        +---+------+---------------+---------------+
1695        | a | last | nth_value(-1) | nth_value(-2) |
1696        +---+------+---------------+---------------+
1697        | 1 | 1    | 1             |               |
1698        | 2 | 2    | 2             | 1             |
1699        | 3 | 3    | 3             | 2             |
1700        | 1 | 1    | 1             | 3             |
1701        | 2 | 2    | 2             | 1             |
1702        | 3 | 3    | 3             | 2             |
1703        | 1 | 1    | 1             | 3             |
1704        | 2 | 2    | 2             | 1             |
1705        | 3 | 3    | 3             | 2             |
1706        +---+------+---------------+---------------+
1707        ");
1708        Ok(())
1709    }
1710
1711    // This test, tests whether most recent row guarantee by the input batch of the `BoundedWindowAggExec`
1712    // helps `BoundedWindowAggExec` to generate low latency result in the `Linear` mode.
1713    // Input data generated at the source is
1714    //       "+----+------+",
1715    //       "| sn | hash |",
1716    //       "+----+------+",
1717    //       "| 0  | 2    |",
1718    //       "| 1  | 2    |",
1719    //       "| 2  | 2    |",
1720    //       "| 3  | 2    |",
1721    //       "| 4  | 1    |",
1722    //       "| 5  | 1    |",
1723    //       "| 6  | 1    |",
1724    //       "| 7  | 1    |",
1725    //       "| 8  | 0    |",
1726    //       "| 9  | 0    |",
1727    //       "+----+------+",
1728    //
1729    // Effectively following query is run on this data
1730    //
1731    //   SELECT *, count(*) OVER(PARTITION BY duplicated_hash ORDER BY sn RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)
1732    //   FROM test;
1733    //
1734    // partition `duplicated_hash=2` receives following data from the input
1735    //
1736    //       "+----+------+",
1737    //       "| sn | hash |",
1738    //       "+----+------+",
1739    //       "| 0  | 2    |",
1740    //       "| 1  | 2    |",
1741    //       "| 2  | 2    |",
1742    //       "| 3  | 2    |",
1743    //       "+----+------+",
1744    // normally `BoundedWindowExec` can only generate following result from the input above
1745    //
1746    //       "+----+------+---------+",
1747    //       "| sn | hash |  count  |",
1748    //       "+----+------+---------+",
1749    //       "| 0  | 2    |  2      |",
1750    //       "| 1  | 2    |  2      |",
1751    //       "| 2  | 2    |<not yet>|",
1752    //       "| 3  | 2    |<not yet>|",
1753    //       "+----+------+---------+",
1754    // where result of last 2 row is missing. Since window frame end is not may change with future data
1755    // since window frame end is determined by 1 following (To generate result for row=3[where sn=2] we
1756    // need to received sn=4 to make sure window frame end bound won't change with future data).
1757    //
1758    // With the ability of different partitions to use global ordering at the input (where most up-to date
1759    //   row is
1760    //      "| 9  | 0    |",
1761    //   )
1762    //
1763    // `BoundedWindowExec` should be able to generate following result in the test
1764    //
1765    //       "+----+------+-------+",
1766    //       "| sn | hash | col_2 |",
1767    //       "+----+------+-------+",
1768    //       "| 0  | 2    | 2     |",
1769    //       "| 1  | 2    | 2     |",
1770    //       "| 2  | 2    | 2     |",
1771    //       "| 3  | 2    | 1     |",
1772    //       "| 4  | 1    | 2     |",
1773    //       "| 5  | 1    | 2     |",
1774    //       "| 6  | 1    | 2     |",
1775    //       "| 7  | 1    | 1     |",
1776    //       "+----+------+-------+",
1777    //
1778    // where result for all rows except last 2 is calculated (To calculate result for row 9 where sn=8
1779    //   we need to receive sn=10 value to calculate it result.).
1780    // In this test, out aim is to test for which portion of the input data `BoundedWindowExec` can generate
1781    // a result. To test this behaviour, we generated the data at the source infinitely (no `None` signal
1782    //    is sent to output from source). After, row:
1783    //
1784    //       "| 9  | 0    |",
1785    //
1786    // is sent. Source stops sending data to output. We collect, result emitted by the `BoundedWindowExec` at the
1787    // end of the pipeline with a timeout (Since no `None` is sent from source. Collection never ends otherwise).
1788    #[tokio::test]
1789    async fn bounded_window_exec_linear_mode_range_information() -> Result<()> {
1790        let n_rows = 10;
1791        let chunk_length = 2;
1792        let n_future_range = 1;
1793
1794        let timeout_duration = Duration::from_millis(2000);
1795
1796        let source =
1797            generate_never_ending_source(n_rows, chunk_length, 1, true, false, 5)?;
1798
1799        let window =
1800            bounded_window_exec_pb_latent_range(source, n_future_range, "hash", "sn")?;
1801
1802        let plan = projection_exec(window)?;
1803
1804        // Get string representation of the plan
1805        assert_snapshot!(displayable(plan.as_ref()).indent(true), @r#"
1806        ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: "sn", index: 0 }]) PARTITION BY: [[Column { name: "hash", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: "sn", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]@2 as col_2]
1807          BoundedWindowAggExec: wdw=[count([Column { name: "sn", index: 0 }]) PARTITION BY: [[Column { name: "hash", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: "sn", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]: Field { "count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]": Int64 }, frame: RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING], mode=[Linear]
1808            StreamingTableExec: partition_sizes=1, projection=[sn, hash], infinite_source=true, output_ordering=[sn@0 ASC NULLS LAST]
1809        "#);
1810
1811        let task_ctx = task_context();
1812        let batches = collect_with_timeout(plan, task_ctx, timeout_duration).await?;
1813
1814        assert_snapshot!(batches_to_string(&batches), @r"
1815        +----+------+-------+
1816        | sn | hash | col_2 |
1817        +----+------+-------+
1818        | 0  | 2    | 2     |
1819        | 1  | 2    | 2     |
1820        | 2  | 2    | 2     |
1821        | 3  | 2    | 1     |
1822        | 4  | 1    | 2     |
1823        | 5  | 1    | 2     |
1824        | 6  | 1    | 2     |
1825        | 7  | 1    | 1     |
1826        +----+------+-------+
1827        ");
1828
1829        Ok(())
1830    }
1831}