datafusion_physical_plan/windows/
bounded_window_agg_exec.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Stream and channel implementations for window function expressions.
//! The executor given here uses bounded memory (it does not retain all of
//! the input data seen so far), which makes it suitable for processing
//! infinite inputs.
22
23use std::any::Any;
24use std::cmp::{min, Ordering};
25use std::collections::VecDeque;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::task::{Context, Poll};
29
30use super::utils::create_schema;
31use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
32use crate::windows::{
33    calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs,
34    window_equivalence_properties,
35};
36use crate::{
37    ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
38    ExecutionPlanProperties, InputOrderMode, PlanProperties, RecordBatchStream,
39    SendableRecordBatchStream, Statistics, WindowExpr,
40};
41use ahash::RandomState;
42use arrow::compute::take_record_batch;
43use arrow::{
44    array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder},
45    compute::{concat, concat_batches, sort_to_indices, take_arrays},
46    datatypes::SchemaRef,
47    record_batch::RecordBatch,
48};
49use datafusion_common::hash_utils::create_hashes;
50use datafusion_common::stats::Precision;
51use datafusion_common::utils::{
52    evaluate_partition_ranges, get_at_indices, get_row_at_idx,
53};
54use datafusion_common::{
55    arrow_datafusion_err, exec_err, DataFusionError, HashMap, Result,
56};
57use datafusion_execution::TaskContext;
58use datafusion_expr::window_state::{PartitionBatchState, WindowAggState};
59use datafusion_expr::ColumnarValue;
60use datafusion_physical_expr::window::{
61    PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState,
62};
63use datafusion_physical_expr::PhysicalExpr;
64use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
65
66use futures::stream::Stream;
67use futures::{ready, StreamExt};
68use hashbrown::hash_table::HashTable;
69use indexmap::IndexMap;
70use log::debug;
71
72/// Window execution plan
73#[derive(Debug, Clone)]
74pub struct BoundedWindowAggExec {
75    /// Input plan
76    input: Arc<dyn ExecutionPlan>,
77    /// Window function expression
78    window_expr: Vec<Arc<dyn WindowExpr>>,
79    /// Schema after the window is run
80    schema: SchemaRef,
81    /// Execution metrics
82    metrics: ExecutionPlanMetricsSet,
83    /// Describes how the input is ordered relative to the partition keys
84    pub input_order_mode: InputOrderMode,
85    /// Partition by indices that define ordering
86    // For example, if input ordering is ORDER BY a, b and window expression
87    // contains PARTITION BY b, a; `ordered_partition_by_indices` would be 1, 0.
88    // Similarly, if window expression contains PARTITION BY a, b; then
89    // `ordered_partition_by_indices` would be 0, 1.
90    // See `get_ordered_partition_by_indices` for more details.
91    ordered_partition_by_indices: Vec<usize>,
92    /// Cache holding plan properties like equivalences, output partitioning etc.
93    cache: PlanProperties,
94    /// If `can_rerepartition` is false, partition_keys is always empty.
95    can_repartition: bool,
96}
97
98impl BoundedWindowAggExec {
99    /// Create a new execution plan for window aggregates
100    pub fn try_new(
101        window_expr: Vec<Arc<dyn WindowExpr>>,
102        input: Arc<dyn ExecutionPlan>,
103        input_order_mode: InputOrderMode,
104        can_repartition: bool,
105    ) -> Result<Self> {
106        let schema = create_schema(&input.schema(), &window_expr)?;
107        let schema = Arc::new(schema);
108        let partition_by_exprs = window_expr[0].partition_by();
109        let ordered_partition_by_indices = match &input_order_mode {
110            InputOrderMode::Sorted => {
111                let indices = get_ordered_partition_by_indices(
112                    window_expr[0].partition_by(),
113                    &input,
114                );
115                if indices.len() == partition_by_exprs.len() {
116                    indices
117                } else {
118                    (0..partition_by_exprs.len()).collect::<Vec<_>>()
119                }
120            }
121            InputOrderMode::PartiallySorted(ordered_indices) => ordered_indices.clone(),
122            InputOrderMode::Linear => {
123                vec![]
124            }
125        };
126        let cache = Self::compute_properties(&input, &schema, &window_expr);
127        Ok(Self {
128            input,
129            window_expr,
130            schema,
131            metrics: ExecutionPlanMetricsSet::new(),
132            input_order_mode,
133            ordered_partition_by_indices,
134            cache,
135            can_repartition,
136        })
137    }
138
139    /// Window expressions
140    pub fn window_expr(&self) -> &[Arc<dyn WindowExpr>] {
141        &self.window_expr
142    }
143
144    /// Input plan
145    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
146        &self.input
147    }
148
149    /// Return the output sort order of partition keys: For example
150    /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a
151    // We are sure that partition by columns are always at the beginning of sort_keys
152    // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely
153    // to calculate partition separation points
154    pub fn partition_by_sort_keys(&self) -> Result<LexOrdering> {
155        let partition_by = self.window_expr()[0].partition_by();
156        get_partition_by_sort_exprs(
157            &self.input,
158            partition_by,
159            &self.ordered_partition_by_indices,
160        )
161    }
162
163    /// Initializes the appropriate [`PartitionSearcher`] implementation from
164    /// the state.
165    fn get_search_algo(&self) -> Result<Box<dyn PartitionSearcher>> {
166        let partition_by_sort_keys = self.partition_by_sort_keys()?;
167        let ordered_partition_by_indices = self.ordered_partition_by_indices.clone();
168        let input_schema = self.input().schema();
169        Ok(match &self.input_order_mode {
170            InputOrderMode::Sorted => {
171                // In Sorted mode, all partition by columns should be ordered.
172                if self.window_expr()[0].partition_by().len()
173                    != ordered_partition_by_indices.len()
174                {
175                    return exec_err!("All partition by columns should have an ordering in Sorted mode.");
176                }
177                Box::new(SortedSearch {
178                    partition_by_sort_keys,
179                    ordered_partition_by_indices,
180                    input_schema,
181                })
182            }
183            InputOrderMode::Linear | InputOrderMode::PartiallySorted(_) => Box::new(
184                LinearSearch::new(ordered_partition_by_indices, input_schema),
185            ),
186        })
187    }
188
189    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
190    fn compute_properties(
191        input: &Arc<dyn ExecutionPlan>,
192        schema: &SchemaRef,
193        window_exprs: &[Arc<dyn WindowExpr>],
194    ) -> PlanProperties {
195        // Calculate equivalence properties:
196        let eq_properties = window_equivalence_properties(schema, input, window_exprs);
197
198        // As we can have repartitioning using the partition keys, this can
199        // be either one or more than one, depending on the presence of
200        // repartitioning.
201        let output_partitioning = input.output_partitioning().clone();
202
203        // Construct properties cache
204        PlanProperties::new(
205            eq_properties,
206            output_partitioning,
207            // TODO: Emission type and boundedness information can be enhanced here
208            input.pipeline_behavior(),
209            input.boundedness(),
210        )
211    }
212
213    pub fn partition_keys(&self) -> Vec<Arc<dyn PhysicalExpr>> {
214        if !self.can_repartition {
215            vec![]
216        } else {
217            let all_partition_keys = self
218                .window_expr()
219                .iter()
220                .map(|expr| expr.partition_by().to_vec())
221                .collect::<Vec<_>>();
222
223            all_partition_keys
224                .into_iter()
225                .min_by_key(|s| s.len())
226                .unwrap_or_else(Vec::new)
227        }
228    }
229
230    fn statistics_helper(&self, statistics: Statistics) -> Result<Statistics> {
231        let win_cols = self.window_expr.len();
232        let input_cols = self.input.schema().fields().len();
233        // TODO stats: some windowing function will maintain invariants such as min, max...
234        let mut column_statistics = Vec::with_capacity(win_cols + input_cols);
235        // copy stats of the input to the beginning of the schema.
236        column_statistics.extend(statistics.column_statistics);
237        for _ in 0..win_cols {
238            column_statistics.push(ColumnStatistics::new_unknown())
239        }
240        Ok(Statistics {
241            num_rows: statistics.num_rows,
242            column_statistics,
243            total_byte_size: Precision::Absent,
244        })
245    }
246}
247
248impl DisplayAs for BoundedWindowAggExec {
249    fn fmt_as(
250        &self,
251        t: DisplayFormatType,
252        f: &mut std::fmt::Formatter,
253    ) -> std::fmt::Result {
254        match t {
255            DisplayFormatType::Default | DisplayFormatType::Verbose => {
256                write!(f, "BoundedWindowAggExec: ")?;
257                let g: Vec<String> = self
258                    .window_expr
259                    .iter()
260                    .map(|e| {
261                        format!(
262                            "{}: {:?}, frame: {:?}",
263                            e.name().to_owned(),
264                            e.field(),
265                            e.get_window_frame()
266                        )
267                    })
268                    .collect();
269                let mode = &self.input_order_mode;
270                write!(f, "wdw=[{}], mode=[{:?}]", g.join(", "), mode)?;
271            }
272            DisplayFormatType::TreeRender => {
273                let g: Vec<String> = self
274                    .window_expr
275                    .iter()
276                    .map(|e| e.name().to_owned().to_string())
277                    .collect();
278                writeln!(f, "select_list={}", g.join(", "))?;
279
280                let mode = &self.input_order_mode;
281                writeln!(f, "mode={mode:?}")?;
282            }
283        }
284        Ok(())
285    }
286}
287
288impl ExecutionPlan for BoundedWindowAggExec {
289    fn name(&self) -> &'static str {
290        "BoundedWindowAggExec"
291    }
292
293    /// Return a reference to Any that can be used for downcasting
294    fn as_any(&self) -> &dyn Any {
295        self
296    }
297
298    fn properties(&self) -> &PlanProperties {
299        &self.cache
300    }
301
302    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
303        vec![&self.input]
304    }
305
306    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
307        let partition_bys = self.window_expr()[0].partition_by();
308        let order_keys = self.window_expr()[0].order_by();
309        let partition_bys = self
310            .ordered_partition_by_indices
311            .iter()
312            .map(|idx| &partition_bys[*idx]);
313        vec![calc_requirements(partition_bys, order_keys.iter())]
314    }
315
316    fn required_input_distribution(&self) -> Vec<Distribution> {
317        if self.partition_keys().is_empty() {
318            debug!("No partition defined for BoundedWindowAggExec!!!");
319            vec![Distribution::SinglePartition]
320        } else {
321            vec![Distribution::HashPartitioned(self.partition_keys().clone())]
322        }
323    }
324
325    fn maintains_input_order(&self) -> Vec<bool> {
326        vec![true]
327    }
328
329    fn with_new_children(
330        self: Arc<Self>,
331        children: Vec<Arc<dyn ExecutionPlan>>,
332    ) -> Result<Arc<dyn ExecutionPlan>> {
333        Ok(Arc::new(BoundedWindowAggExec::try_new(
334            self.window_expr.clone(),
335            Arc::clone(&children[0]),
336            self.input_order_mode.clone(),
337            self.can_repartition,
338        )?))
339    }
340
341    fn execute(
342        &self,
343        partition: usize,
344        context: Arc<TaskContext>,
345    ) -> Result<SendableRecordBatchStream> {
346        let input = self.input.execute(partition, context)?;
347        let search_mode = self.get_search_algo()?;
348        let stream = Box::pin(BoundedWindowAggStream::new(
349            Arc::clone(&self.schema),
350            self.window_expr.clone(),
351            input,
352            BaselineMetrics::new(&self.metrics, partition),
353            search_mode,
354        )?);
355        Ok(stream)
356    }
357
358    fn metrics(&self) -> Option<MetricsSet> {
359        Some(self.metrics.clone_inner())
360    }
361
362    fn statistics(&self) -> Result<Statistics> {
363        self.partition_statistics(None)
364    }
365
366    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
367        let input_stat = self.input.partition_statistics(partition)?;
368        self.statistics_helper(input_stat)
369    }
370}
371
372/// Trait that specifies how we search for (or calculate) partitions. It has two
373/// implementations: [`SortedSearch`] and [`LinearSearch`].
374trait PartitionSearcher: Send {
375    /// This method constructs output columns using the result of each window expression
376    /// (each entry in the output vector comes from a window expression).
377    /// Executor when producing output concatenates `input_buffer` (corresponding section), and
378    /// result of this function to generate output `RecordBatch`. `input_buffer` is used to determine
379    /// which sections of the window expression results should be used to generate output.
380    /// `partition_buffers` contains corresponding section of the `RecordBatch` for each partition.
381    /// `window_agg_states` stores per partition state for each window expression.
382    /// None case means that no result is generated
383    /// `Some(Vec<ArrayRef>)` is the result of each window expression.
384    fn calculate_out_columns(
385        &mut self,
386        input_buffer: &RecordBatch,
387        window_agg_states: &[PartitionWindowAggStates],
388        partition_buffers: &mut PartitionBatches,
389        window_expr: &[Arc<dyn WindowExpr>],
390    ) -> Result<Option<Vec<ArrayRef>>>;
391
392    /// Determine whether `[InputOrderMode]` is `[InputOrderMode::Linear]` or not.
393    fn is_mode_linear(&self) -> bool {
394        false
395    }
396
397    // Constructs corresponding batches for each partition for the record_batch.
398    fn evaluate_partition_batches(
399        &mut self,
400        record_batch: &RecordBatch,
401        window_expr: &[Arc<dyn WindowExpr>],
402    ) -> Result<Vec<(PartitionKey, RecordBatch)>>;
403
404    /// Prunes the state.
405    fn prune(&mut self, _n_out: usize) {}
406
407    /// Marks the partition as done if we are sure that corresponding partition
408    /// cannot receive any more values.
409    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches);
410
411    /// Updates `input_buffer` and `partition_buffers` with the new `record_batch`.
412    fn update_partition_batch(
413        &mut self,
414        input_buffer: &mut RecordBatch,
415        record_batch: RecordBatch,
416        window_expr: &[Arc<dyn WindowExpr>],
417        partition_buffers: &mut PartitionBatches,
418    ) -> Result<()> {
419        if record_batch.num_rows() == 0 {
420            return Ok(());
421        }
422        let partition_batches =
423            self.evaluate_partition_batches(&record_batch, window_expr)?;
424        for (partition_row, partition_batch) in partition_batches {
425            let partition_batch_state = partition_buffers
426                .entry(partition_row)
427                // Use input_schema for the buffer schema, not `record_batch.schema()`
428                // as it may not have the "correct" schema in terms of output
429                // nullability constraints. For details, see the following issue:
430                // https://github.com/apache/datafusion/issues/9320
431                .or_insert_with(|| {
432                    PartitionBatchState::new(Arc::clone(self.input_schema()))
433                });
434            partition_batch_state.extend(&partition_batch)?;
435        }
436
437        if self.is_mode_linear() {
438            // In `Linear` mode, it is guaranteed that the first ORDER BY column
439            // is sorted across partitions. Note that only the first ORDER BY
440            // column is guaranteed to be ordered. As a counter example, consider
441            // the case, `PARTITION BY b, ORDER BY a, c` when the input is sorted
442            // by `[a, b, c]`. In this case, `BoundedWindowAggExec` mode will be
443            // `Linear`. However, we cannot guarantee that the last row of the
444            // input data will be the "last" data in terms of the ordering requirement
445            // `[a, c]` -- it will be the "last" data in terms of `[a, b, c]`.
446            // Hence, only column `a` should be used as a guarantee of the "last"
447            // data across partitions. For other modes (`Sorted`, `PartiallySorted`),
448            // we do not need to keep track of the most recent row guarantee across
449            // partitions. Since leading ordering separates partitions, guaranteed
450            // by the most recent row, already prune the previous partitions completely.
451            let last_row = get_last_row_batch(&record_batch)?;
452            for (_, partition_batch) in partition_buffers.iter_mut() {
453                partition_batch.set_most_recent_row(last_row.clone());
454            }
455        }
456        self.mark_partition_end(partition_buffers);
457
458        *input_buffer = if input_buffer.num_rows() == 0 {
459            record_batch
460        } else {
461            concat_batches(self.input_schema(), [input_buffer, &record_batch])?
462        };
463
464        Ok(())
465    }
466
467    fn input_schema(&self) -> &SchemaRef;
468}
469
470/// This object encapsulates the algorithm state for a simple linear scan
471/// algorithm for computing partitions.
472pub struct LinearSearch {
473    /// Keeps the hash of input buffer calculated from PARTITION BY columns.
474    /// Its length is equal to the `input_buffer` length.
475    input_buffer_hashes: VecDeque<u64>,
476    /// Used during hash value calculation.
477    random_state: RandomState,
478    /// Input ordering and partition by key ordering need not be the same, so
479    /// this vector stores the mapping between them. For instance, if the input
480    /// is ordered by a, b and the window expression contains a PARTITION BY b, a
481    /// clause, this attribute stores [1, 0].
482    ordered_partition_by_indices: Vec<usize>,
483    /// We use this [`HashTable`] to calculate unique partitions for each new
484    /// RecordBatch. First entry in the tuple is the hash value, the second
485    /// entry is the unique ID for each partition (increments from 0 to n).
486    row_map_batch: HashTable<(u64, usize)>,
487    /// We use this [`HashTable`] to calculate the output columns that we can
488    /// produce at each cycle. First entry in the tuple is the hash value, the
489    /// second entry is the unique ID for each partition (increments from 0 to n).
490    /// The third entry stores how many new outputs are calculated for the
491    /// corresponding partition.
492    row_map_out: HashTable<(u64, usize, usize)>,
493    input_schema: SchemaRef,
494}
495
496impl PartitionSearcher for LinearSearch {
497    /// This method constructs output columns using the result of each window expression.
498    // Assume input buffer is         |      Partition Buffers would be (Where each partition and its data is separated)
499    // a, 2                           |      a, 2
500    // b, 2                           |      a, 2
501    // a, 2                           |      a, 2
502    // b, 2                           |
503    // a, 2                           |      b, 2
504    // b, 2                           |      b, 2
505    // b, 2                           |      b, 2
506    //                                |      b, 2
507    // Also assume we happen to calculate 2 new values for a, and 3 for b (To be calculate missing values we may need to consider future values).
508    // Partition buffers effectively will be
509    // a, 2, 1
510    // a, 2, 2
511    // a, 2, (missing)
512    //
513    // b, 2, 1
514    // b, 2, 2
515    // b, 2, 3
516    // b, 2, (missing)
517    // When partition buffers are mapped back to the original record batch. Result becomes
518    // a, 2, 1
519    // b, 2, 1
520    // a, 2, 2
521    // b, 2, 2
522    // a, 2, (missing)
523    // b, 2, 3
524    // b, 2, (missing)
525    // This function calculates the column result of window expression(s) (First 4 entry of 3rd column in the above section.)
526    // 1
527    // 1
528    // 2
529    // 2
530    // Above section corresponds to calculated result which can be emitted without breaking input buffer ordering.
531    fn calculate_out_columns(
532        &mut self,
533        input_buffer: &RecordBatch,
534        window_agg_states: &[PartitionWindowAggStates],
535        partition_buffers: &mut PartitionBatches,
536        window_expr: &[Arc<dyn WindowExpr>],
537    ) -> Result<Option<Vec<ArrayRef>>> {
538        let partition_output_indices = self.calc_partition_output_indices(
539            input_buffer,
540            window_agg_states,
541            window_expr,
542        )?;
543
544        let n_window_col = window_agg_states.len();
545        let mut new_columns = vec![vec![]; n_window_col];
546        // Size of all_indices can be at most input_buffer.num_rows():
547        let mut all_indices = UInt32Builder::with_capacity(input_buffer.num_rows());
548        for (row, indices) in partition_output_indices {
549            let length = indices.len();
550            for (idx, window_agg_state) in window_agg_states.iter().enumerate() {
551                let partition = &window_agg_state[&row];
552                let values = Arc::clone(&partition.state.out_col.slice(0, length));
553                new_columns[idx].push(values);
554            }
555            let partition_batch_state = &mut partition_buffers[&row];
556            // Store how many rows are generated for each partition
557            partition_batch_state.n_out_row = length;
558            // For each row keep corresponding index in the input record batch
559            all_indices.append_slice(&indices);
560        }
561        let all_indices = all_indices.finish();
562        if all_indices.is_empty() {
563            // We couldn't generate any new value, return early:
564            return Ok(None);
565        }
566
567        // Concatenate results for each column by converting `Vec<Vec<ArrayRef>>`
568        // to Vec<ArrayRef> where inner `Vec<ArrayRef>`s are converted to `ArrayRef`s.
569        let new_columns = new_columns
570            .iter()
571            .map(|items| {
572                concat(&items.iter().map(|e| e.as_ref()).collect::<Vec<_>>())
573                    .map_err(|e| arrow_datafusion_err!(e))
574            })
575            .collect::<Result<Vec<_>>>()?;
576        // We should emit columns according to row index ordering.
577        let sorted_indices = sort_to_indices(&all_indices, None, None)?;
578        // Construct new column according to row ordering. This fixes ordering
579        take_arrays(&new_columns, &sorted_indices, None)
580            .map(Some)
581            .map_err(|e| arrow_datafusion_err!(e))
582    }
583
584    fn evaluate_partition_batches(
585        &mut self,
586        record_batch: &RecordBatch,
587        window_expr: &[Arc<dyn WindowExpr>],
588    ) -> Result<Vec<(PartitionKey, RecordBatch)>> {
589        let partition_bys =
590            evaluate_partition_by_column_values(record_batch, window_expr)?;
591        // NOTE: In Linear or PartiallySorted modes, we are sure that
592        //       `partition_bys` are not empty.
593        // Calculate indices for each partition and construct a new record
594        // batch from the rows at these indices for each partition:
595        self.get_per_partition_indices(&partition_bys, record_batch)?
596            .into_iter()
597            .map(|(row, indices)| {
598                let mut new_indices = UInt32Builder::with_capacity(indices.len());
599                new_indices.append_slice(&indices);
600                let indices = new_indices.finish();
601                Ok((row, take_record_batch(record_batch, &indices)?))
602            })
603            .collect()
604    }
605
606    fn prune(&mut self, n_out: usize) {
607        // Delete hashes for the rows that are outputted.
608        self.input_buffer_hashes.drain(0..n_out);
609    }
610
611    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches) {
612        // We should be in the `PartiallySorted` case, otherwise we can not
613        // tell when we are at the end of a given partition.
614        if !self.ordered_partition_by_indices.is_empty() {
615            if let Some((last_row, _)) = partition_buffers.last() {
616                let last_sorted_cols = self
617                    .ordered_partition_by_indices
618                    .iter()
619                    .map(|idx| last_row[*idx].clone())
620                    .collect::<Vec<_>>();
621                for (row, partition_batch_state) in partition_buffers.iter_mut() {
622                    let sorted_cols = self
623                        .ordered_partition_by_indices
624                        .iter()
625                        .map(|idx| &row[*idx]);
626                    // All the partitions other than `last_sorted_cols` are done.
627                    // We are sure that we will no longer receive values for these
628                    // partitions (arrival of a new value would violate ordering).
629                    partition_batch_state.is_end = !sorted_cols.eq(&last_sorted_cols);
630                }
631            }
632        }
633    }
634
635    fn is_mode_linear(&self) -> bool {
636        self.ordered_partition_by_indices.is_empty()
637    }
638
639    fn input_schema(&self) -> &SchemaRef {
640        &self.input_schema
641    }
642}
643
644impl LinearSearch {
645    /// Initialize a new [`LinearSearch`] partition searcher.
646    fn new(ordered_partition_by_indices: Vec<usize>, input_schema: SchemaRef) -> Self {
647        LinearSearch {
648            input_buffer_hashes: VecDeque::new(),
649            random_state: Default::default(),
650            ordered_partition_by_indices,
651            row_map_batch: HashTable::with_capacity(256),
652            row_map_out: HashTable::with_capacity(256),
653            input_schema,
654        }
655    }
656
657    /// Calculate indices of each partition (according to PARTITION BY expression)
658    /// `columns` contain partition by expression results.
659    fn get_per_partition_indices(
660        &mut self,
661        columns: &[ArrayRef],
662        batch: &RecordBatch,
663    ) -> Result<Vec<(PartitionKey, Vec<u32>)>> {
664        let mut batch_hashes = vec![0; batch.num_rows()];
665        create_hashes(columns, &self.random_state, &mut batch_hashes)?;
666        self.input_buffer_hashes.extend(&batch_hashes);
667        // reset row_map for new calculation
668        self.row_map_batch.clear();
669        // res stores PartitionKey and row indices (indices where these partition occurs in the `batch`) for each partition.
670        let mut result: Vec<(PartitionKey, Vec<u32>)> = vec![];
671        for (hash, row_idx) in batch_hashes.into_iter().zip(0u32..) {
672            let entry = self.row_map_batch.find_mut(hash, |(_, group_idx)| {
673                // We can safely get the first index of the partition indices
674                // since partition indices has one element during initialization.
675                let row = get_row_at_idx(columns, row_idx as usize).unwrap();
676                // Handle hash collusions with an equality check:
677                row.eq(&result[*group_idx].0)
678            });
679            if let Some((_, group_idx)) = entry {
680                result[*group_idx].1.push(row_idx)
681            } else {
682                self.row_map_batch.insert_unique(
683                    hash,
684                    (hash, result.len()),
685                    |(hash, _)| *hash,
686                );
687                let row = get_row_at_idx(columns, row_idx as usize)?;
688                // This is a new partition its only index is row_idx for now.
689                result.push((row, vec![row_idx]));
690            }
691        }
692        Ok(result)
693    }
694
695    /// Calculates partition keys and result indices for each partition.
696    /// The return value is a vector of tuples where the first entry stores
697    /// the partition key (unique for each partition) and the second entry
698    /// stores indices of the rows for which the partition is constructed.
699    fn calc_partition_output_indices(
700        &mut self,
701        input_buffer: &RecordBatch,
702        window_agg_states: &[PartitionWindowAggStates],
703        window_expr: &[Arc<dyn WindowExpr>],
704    ) -> Result<Vec<(PartitionKey, Vec<u32>)>> {
705        let partition_by_columns =
706            evaluate_partition_by_column_values(input_buffer, window_expr)?;
707        // Reset the row_map state:
708        self.row_map_out.clear();
709        let mut partition_indices: Vec<(PartitionKey, Vec<u32>)> = vec![];
710        for (hash, row_idx) in self.input_buffer_hashes.iter().zip(0u32..) {
711            let entry = self.row_map_out.find_mut(*hash, |(_, group_idx, _)| {
712                let row =
713                    get_row_at_idx(&partition_by_columns, row_idx as usize).unwrap();
714                row == partition_indices[*group_idx].0
715            });
716            if let Some((_, group_idx, n_out)) = entry {
717                let (_, indices) = &mut partition_indices[*group_idx];
718                if indices.len() >= *n_out {
719                    break;
720                }
721                indices.push(row_idx);
722            } else {
723                let row = get_row_at_idx(&partition_by_columns, row_idx as usize)?;
724                let min_out = window_agg_states
725                    .iter()
726                    .map(|window_agg_state| {
727                        window_agg_state
728                            .get(&row)
729                            .map(|partition| partition.state.out_col.len())
730                            .unwrap_or(0)
731                    })
732                    .min()
733                    .unwrap_or(0);
734                if min_out == 0 {
735                    break;
736                }
737                self.row_map_out.insert_unique(
738                    *hash,
739                    (*hash, partition_indices.len(), min_out),
740                    |(hash, _, _)| *hash,
741                );
742                partition_indices.push((row, vec![row_idx]));
743            }
744        }
745        Ok(partition_indices)
746    }
747}
748
/// This object encapsulates the algorithm state for sorted searching
/// when computing partitions.
///
/// Partitions are located as contiguous ranges of equal PARTITION BY values
/// (see `evaluate_partition_batches`). This relies on the input being ordered
/// by those columns.
pub struct SortedSearch {
    /// Stores partition by columns and their ordering information
    partition_by_sort_keys: LexOrdering,
    /// Input ordering and partition by key ordering need not be the same, so
    /// this vector stores the mapping between them. For instance, if the input
    /// is ordered by a, b and the window expression contains a PARTITION BY b, a
    /// clause, this attribute stores [1, 0].
    ordered_partition_by_indices: Vec<usize>,
    /// Schema of the input record batches; returned by `input_schema()`.
    input_schema: SchemaRef,
}
761
762impl PartitionSearcher for SortedSearch {
763    /// This method constructs new output columns using the result of each window expression.
764    fn calculate_out_columns(
765        &mut self,
766        _input_buffer: &RecordBatch,
767        window_agg_states: &[PartitionWindowAggStates],
768        partition_buffers: &mut PartitionBatches,
769        _window_expr: &[Arc<dyn WindowExpr>],
770    ) -> Result<Option<Vec<ArrayRef>>> {
771        let n_out = self.calculate_n_out_row(window_agg_states, partition_buffers);
772        if n_out == 0 {
773            Ok(None)
774        } else {
775            window_agg_states
776                .iter()
777                .map(|map| get_aggregate_result_out_column(map, n_out).map(Some))
778                .collect()
779        }
780    }
781
782    fn evaluate_partition_batches(
783        &mut self,
784        record_batch: &RecordBatch,
785        _window_expr: &[Arc<dyn WindowExpr>],
786    ) -> Result<Vec<(PartitionKey, RecordBatch)>> {
787        let num_rows = record_batch.num_rows();
788        // Calculate result of partition by column expressions
789        let partition_columns = self
790            .partition_by_sort_keys
791            .iter()
792            .map(|elem| elem.evaluate_to_sort_column(record_batch))
793            .collect::<Result<Vec<_>>>()?;
794        // Reorder `partition_columns` such that its ordering matches input ordering.
795        let partition_columns_ordered =
796            get_at_indices(&partition_columns, &self.ordered_partition_by_indices)?;
797        let partition_points =
798            evaluate_partition_ranges(num_rows, &partition_columns_ordered)?;
799        let partition_bys = partition_columns
800            .into_iter()
801            .map(|arr| arr.values)
802            .collect::<Vec<ArrayRef>>();
803
804        partition_points
805            .iter()
806            .map(|range| {
807                let row = get_row_at_idx(&partition_bys, range.start)?;
808                let len = range.end - range.start;
809                let slice = record_batch.slice(range.start, len);
810                Ok((row, slice))
811            })
812            .collect::<Result<Vec<_>>>()
813    }
814
815    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches) {
816        // In Sorted case. We can mark all partitions besides last partition as ended.
817        // We are sure that those partitions will never receive any values.
818        // (Otherwise ordering invariant is violated.)
819        let n_partitions = partition_buffers.len();
820        for (idx, (_, partition_batch_state)) in partition_buffers.iter_mut().enumerate()
821        {
822            partition_batch_state.is_end |= idx < n_partitions - 1;
823        }
824    }
825
826    fn input_schema(&self) -> &SchemaRef {
827        &self.input_schema
828    }
829}
830
831impl SortedSearch {
832    /// Calculates how many rows we can output.
833    fn calculate_n_out_row(
834        &mut self,
835        window_agg_states: &[PartitionWindowAggStates],
836        partition_buffers: &mut PartitionBatches,
837    ) -> usize {
838        // Different window aggregators may produce results at different rates.
839        // We produce the overall batch result only as fast as the slowest one.
840        let mut counts = vec![];
841        let out_col_counts = window_agg_states.iter().map(|window_agg_state| {
842            // Store how many elements are generated for the current
843            // window expression:
844            let mut cur_window_expr_out_result_len = 0;
845            // We iterate over `window_agg_state`, which is an IndexMap.
846            // Iterations follow the insertion order, hence we preserve
847            // sorting when partition columns are sorted.
848            let mut per_partition_out_results = HashMap::new();
849            for (row, WindowState { state, .. }) in window_agg_state.iter() {
850                cur_window_expr_out_result_len += state.out_col.len();
851                let count = per_partition_out_results.entry(row).or_insert(0);
852                if *count < state.out_col.len() {
853                    *count = state.out_col.len();
854                }
855                // If we do not generate all results for the current
856                // partition, we do not generate results for next
857                // partition --  otherwise we will lose input ordering.
858                if state.n_row_result_missing > 0 {
859                    break;
860                }
861            }
862            counts.push(per_partition_out_results);
863            cur_window_expr_out_result_len
864        });
865        argmin(out_col_counts).map_or(0, |(min_idx, minima)| {
866            for (row, count) in counts.swap_remove(min_idx).into_iter() {
867                let partition_batch = &mut partition_buffers[row];
868                partition_batch.n_out_row = count;
869            }
870            minima
871        })
872    }
873}
874
875/// Calculates partition by expression results for each window expression
876/// on `record_batch`.
877fn evaluate_partition_by_column_values(
878    record_batch: &RecordBatch,
879    window_expr: &[Arc<dyn WindowExpr>],
880) -> Result<Vec<ArrayRef>> {
881    window_expr[0]
882        .partition_by()
883        .iter()
884        .map(|item| match item.evaluate(record_batch)? {
885            ColumnarValue::Array(array) => Ok(array),
886            ColumnarValue::Scalar(scalar) => {
887                scalar.to_array_of_size(record_batch.num_rows())
888            }
889        })
890        .collect()
891}
892
/// Stream for the bounded window aggregation plan.
pub struct BoundedWindowAggStream {
    /// Output schema of the stream; every emitted batch is built with it.
    schema: SchemaRef,
    /// Upstream stream that supplies the input record batches.
    input: SendableRecordBatchStream,
    /// The record batch the executor receives as input (i.e. the columns needed
    /// while calculating aggregation results). Rows are removed from the front
    /// once their results have been emitted.
    input_buffer: RecordBatch,
    /// We separate `input_buffer` based on partitions (as
    /// determined by PARTITION BY columns) and store them per partition
    /// in `partition_buffers`. We use this variable when calculating results
    /// for each window expression. This enables us to use the same batch for
    /// different window expressions without copying.
    // Note that we could keep record batches for each window expression in
    // `PartitionWindowAggStates`. However, this would use more memory (as
    // many times as the number of window expressions).
    partition_buffers: PartitionBatches,
    /// An executor can run multiple window expressions if the PARTITION BY
    /// and ORDER BY sections are the same. We keep the state of each window
    /// expression inside `window_agg_states`.
    window_agg_states: Vec<PartitionWindowAggStates>,
    /// Set once the input is exhausted; afterwards polling yields `None`.
    finished: bool,
    /// Window expressions evaluated by this stream. `window_agg_states` holds
    /// one entry per expression, in the same order.
    window_expr: Vec<Arc<dyn WindowExpr>>,
    /// Metrics recording poll outcomes and elapsed compute time.
    baseline_metrics: BaselineMetrics,
    /// Search mode for partition columns. This determines the algorithm with
    /// which we group each partition.
    search_mode: Box<dyn PartitionSearcher>,
}
920
impl BoundedWindowAggStream {
    /// Prunes sections of the state that are no longer needed when calculating
    /// results (as determined by window frame boundaries and number of results generated).
    ///
    /// `n_out` is the number of rows that were just emitted. It is used to drop
    /// the leading rows of `self.input_buffer` and is forwarded to the search
    /// algorithm's `prune`.
    // For instance, if the first `n` (not necessarily the same as `n_out`) elements are no
    // longer needed to calculate the window expression result (they fall outside the window
    // frame boundary), we retract the first `n` elements from `self.partition_buffers` in the
    // corresponding partition.
    // Likewise, if `n_out` rows are calculated, we can remove the first `n_out` rows
    // from `self.input_buffer`.
    fn prune_state(&mut self, n_out: usize) -> Result<()> {
        // Prune `self.window_agg_states`. This must run before
        // `prune_partition_batches`: it reads each partition's `n_out_row`,
        // which `prune_partition_batches` resets to 0.
        self.prune_out_columns();
        // Prune `self.partition_buffers`:
        self.prune_partition_batches();
        // Prune `self.input_buffer`:
        self.prune_input_batch(n_out)?;
        // Prune internal state of search algorithm.
        self.search_mode.prune(n_out);
        Ok(())
    }
}
941
942impl Stream for BoundedWindowAggStream {
943    type Item = Result<RecordBatch>;
944
945    fn poll_next(
946        mut self: Pin<&mut Self>,
947        cx: &mut Context<'_>,
948    ) -> Poll<Option<Self::Item>> {
949        let poll = self.poll_next_inner(cx);
950        self.baseline_metrics.record_poll(poll)
951    }
952}
953
954impl BoundedWindowAggStream {
955    /// Create a new BoundedWindowAggStream
956    fn new(
957        schema: SchemaRef,
958        window_expr: Vec<Arc<dyn WindowExpr>>,
959        input: SendableRecordBatchStream,
960        baseline_metrics: BaselineMetrics,
961        search_mode: Box<dyn PartitionSearcher>,
962    ) -> Result<Self> {
963        let state = window_expr.iter().map(|_| IndexMap::new()).collect();
964        let empty_batch = RecordBatch::new_empty(Arc::clone(&schema));
965        Ok(Self {
966            schema,
967            input,
968            input_buffer: empty_batch,
969            partition_buffers: IndexMap::new(),
970            window_agg_states: state,
971            finished: false,
972            window_expr,
973            baseline_metrics,
974            search_mode,
975        })
976    }
977
978    fn compute_aggregates(&mut self) -> Result<Option<RecordBatch>> {
979        // calculate window cols
980        for (cur_window_expr, state) in
981            self.window_expr.iter().zip(&mut self.window_agg_states)
982        {
983            cur_window_expr.evaluate_stateful(&self.partition_buffers, state)?;
984        }
985
986        let schema = Arc::clone(&self.schema);
987        let window_expr_out = self.search_mode.calculate_out_columns(
988            &self.input_buffer,
989            &self.window_agg_states,
990            &mut self.partition_buffers,
991            &self.window_expr,
992        )?;
993        if let Some(window_expr_out) = window_expr_out {
994            let n_out = window_expr_out[0].len();
995            // right append new columns to corresponding section in the original input buffer.
996            let columns_to_show = self
997                .input_buffer
998                .columns()
999                .iter()
1000                .map(|elem| elem.slice(0, n_out))
1001                .chain(window_expr_out)
1002                .collect::<Vec<_>>();
1003            let n_generated = columns_to_show[0].len();
1004            self.prune_state(n_generated)?;
1005            Ok(Some(RecordBatch::try_new(schema, columns_to_show)?))
1006        } else {
1007            Ok(None)
1008        }
1009    }
1010
1011    #[inline]
1012    fn poll_next_inner(
1013        &mut self,
1014        cx: &mut Context<'_>,
1015    ) -> Poll<Option<Result<RecordBatch>>> {
1016        if self.finished {
1017            return Poll::Ready(None);
1018        }
1019
1020        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
1021        match ready!(self.input.poll_next_unpin(cx)) {
1022            Some(Ok(batch)) => {
1023                // Start the timer for compute time within this operator. It will be
1024                // stopped when dropped.
1025                let _timer = elapsed_compute.timer();
1026
1027                self.search_mode.update_partition_batch(
1028                    &mut self.input_buffer,
1029                    batch,
1030                    &self.window_expr,
1031                    &mut self.partition_buffers,
1032                )?;
1033                if let Some(batch) = self.compute_aggregates()? {
1034                    return Poll::Ready(Some(Ok(batch)));
1035                }
1036                self.poll_next_inner(cx)
1037            }
1038            Some(Err(e)) => Poll::Ready(Some(Err(e))),
1039            None => {
1040                let _timer = elapsed_compute.timer();
1041
1042                self.finished = true;
1043                for (_, partition_batch_state) in self.partition_buffers.iter_mut() {
1044                    partition_batch_state.is_end = true;
1045                }
1046                if let Some(batch) = self.compute_aggregates()? {
1047                    return Poll::Ready(Some(Ok(batch)));
1048                }
1049                Poll::Ready(None)
1050            }
1051        }
1052    }
1053
1054    /// Prunes the sections of the record batch (for each partition)
1055    /// that we no longer need to calculate the window function result.
1056    fn prune_partition_batches(&mut self) {
1057        // Remove partitions which we know already ended (is_end flag is true).
1058        // Since the retain method preserves insertion order, we still have
1059        // ordering in between partitions after removal.
1060        self.partition_buffers
1061            .retain(|_, partition_batch_state| !partition_batch_state.is_end);
1062
1063        // The data in `self.partition_batches` is used by all window expressions.
1064        // Therefore, when removing from `self.partition_batches`, we need to remove
1065        // from the earliest range boundary among all window expressions. Variable
1066        // `n_prune_each_partition` fill the earliest range boundary information for
1067        // each partition. This way, we can delete the no-longer-needed sections from
1068        // `self.partition_batches`.
1069        // For instance, if window frame one uses [10, 20] and window frame two uses
1070        // [5, 15]; we only prune the first 5 elements from the corresponding record
1071        // batch in `self.partition_batches`.
1072
1073        // Calculate how many elements to prune for each partition batch
1074        let mut n_prune_each_partition = HashMap::new();
1075        for window_agg_state in self.window_agg_states.iter_mut() {
1076            window_agg_state.retain(|_, WindowState { state, .. }| !state.is_end);
1077            for (partition_row, WindowState { state: value, .. }) in window_agg_state {
1078                let n_prune =
1079                    min(value.window_frame_range.start, value.last_calculated_index);
1080                if let Some(current) = n_prune_each_partition.get_mut(partition_row) {
1081                    if n_prune < *current {
1082                        *current = n_prune;
1083                    }
1084                } else {
1085                    n_prune_each_partition.insert(partition_row.clone(), n_prune);
1086                }
1087            }
1088        }
1089
1090        // Retract no longer needed parts during window calculations from partition batch:
1091        for (partition_row, n_prune) in n_prune_each_partition.iter() {
1092            let pb_state = &mut self.partition_buffers[partition_row];
1093
1094            let batch = &pb_state.record_batch;
1095            pb_state.record_batch = batch.slice(*n_prune, batch.num_rows() - n_prune);
1096            pb_state.n_out_row = 0;
1097
1098            // Update state indices since we have pruned some rows from the beginning:
1099            for window_agg_state in self.window_agg_states.iter_mut() {
1100                window_agg_state[partition_row].state.prune_state(*n_prune);
1101            }
1102        }
1103    }
1104
1105    /// Prunes the section of the input batch whose aggregate results
1106    /// are calculated and emitted.
1107    fn prune_input_batch(&mut self, n_out: usize) -> Result<()> {
1108        // Prune first n_out rows from the input_buffer
1109        let n_to_keep = self.input_buffer.num_rows() - n_out;
1110        let batch_to_keep = self
1111            .input_buffer
1112            .columns()
1113            .iter()
1114            .map(|elem| elem.slice(n_out, n_to_keep))
1115            .collect::<Vec<_>>();
1116        self.input_buffer = RecordBatch::try_new_with_options(
1117            self.input_buffer.schema(),
1118            batch_to_keep,
1119            &RecordBatchOptions::new().with_row_count(Some(n_to_keep)),
1120        )?;
1121        Ok(())
1122    }
1123
1124    /// Prunes emitted parts from WindowAggState `out_col` field.
1125    fn prune_out_columns(&mut self) {
1126        // We store generated columns for each window expression in the `out_col`
1127        // field of `WindowAggState`. Given how many rows are emitted, we remove
1128        // these sections from state.
1129        for partition_window_agg_states in self.window_agg_states.iter_mut() {
1130            // Remove `n_out` entries from the `out_col` field of `WindowAggState`.
1131            // `n_out` is stored in `self.partition_buffers` for each partition.
1132            // If `is_end` is set, directly remove them; this shrinks the hash map.
1133            partition_window_agg_states
1134                .retain(|_, partition_batch_state| !partition_batch_state.state.is_end);
1135            for (
1136                partition_key,
1137                WindowState {
1138                    state: WindowAggState { out_col, .. },
1139                    ..
1140                },
1141            ) in partition_window_agg_states
1142            {
1143                let partition_batch = &mut self.partition_buffers[partition_key];
1144                let n_to_del = partition_batch.n_out_row;
1145                let n_to_keep = out_col.len() - n_to_del;
1146                *out_col = out_col.slice(n_to_del, n_to_keep);
1147            }
1148        }
1149    }
1150}
1151
1152impl RecordBatchStream for BoundedWindowAggStream {
1153    /// Get the schema
1154    fn schema(&self) -> SchemaRef {
1155        Arc::clone(&self.schema)
1156    }
1157}
1158
/// Returns the index and value of the minimum entry, or `None` if `data` is empty.
///
/// On ties, or when two values are incomparable (e.g. NaN), the earlier entry
/// is kept. The current minimum is replaced only when it is strictly greater
/// than a later candidate.
fn argmin<T: PartialOrd>(data: impl Iterator<Item = T>) -> Option<(usize, T)> {
    data.enumerate().reduce(|current, candidate| {
        if current.1.partial_cmp(&candidate.1) == Some(Ordering::Greater) {
            candidate
        } else {
            current
        }
    })
}
1164
1165/// Calculates the section we can show results for expression
1166fn get_aggregate_result_out_column(
1167    partition_window_agg_states: &PartitionWindowAggStates,
1168    len_to_show: usize,
1169) -> Result<ArrayRef> {
1170    let mut result = None;
1171    let mut running_length = 0;
1172    // We assume that iteration order is according to insertion order
1173    for (
1174        _,
1175        WindowState {
1176            state: WindowAggState { out_col, .. },
1177            ..
1178        },
1179    ) in partition_window_agg_states
1180    {
1181        if running_length < len_to_show {
1182            let n_to_use = min(len_to_show - running_length, out_col.len());
1183            let slice_to_use = out_col.slice(0, n_to_use);
1184            result = Some(match result {
1185                Some(arr) => concat(&[&arr, &slice_to_use])?,
1186                None => slice_to_use,
1187            });
1188            running_length += n_to_use;
1189        } else {
1190            break;
1191        }
1192    }
1193    if running_length != len_to_show {
1194        return exec_err!(
1195            "Generated row number should be {len_to_show}, it is {running_length}"
1196        );
1197    }
1198    result
1199        .ok_or_else(|| DataFusionError::Execution("Should contain something".to_string()))
1200}
1201
1202/// Constructs a batch from the last row of batch in the argument.
1203pub(crate) fn get_last_row_batch(batch: &RecordBatch) -> Result<RecordBatch> {
1204    if batch.num_rows() == 0 {
1205        return exec_err!("Latest batch should have at least 1 row");
1206    }
1207    Ok(batch.slice(batch.num_rows() - 1, 1))
1208}
1209
1210#[cfg(test)]
1211mod tests {
1212    use std::pin::Pin;
1213    use std::sync::Arc;
1214    use std::task::{Context, Poll};
1215    use std::time::Duration;
1216
1217    use crate::common::collect;
1218    use crate::expressions::PhysicalSortExpr;
1219    use crate::projection::ProjectionExec;
1220    use crate::streaming::{PartitionStream, StreamingTableExec};
1221    use crate::test::TestMemoryExec;
1222    use crate::windows::{
1223        create_udwf_window_expr, create_window_expr, BoundedWindowAggExec, InputOrderMode,
1224    };
1225    use crate::{execute_stream, get_plan_string, ExecutionPlan};
1226
1227    use arrow::array::{
1228        builder::{Int64Builder, UInt64Builder},
1229        RecordBatch,
1230    };
1231    use arrow::compute::SortOptions;
1232    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1233    use datafusion_common::test_util::batches_to_string;
1234    use datafusion_common::{exec_datafusion_err, Result, ScalarValue};
1235    use datafusion_execution::config::SessionConfig;
1236    use datafusion_execution::{
1237        RecordBatchStream, SendableRecordBatchStream, TaskContext,
1238    };
1239    use datafusion_expr::{
1240        WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
1241    };
1242    use datafusion_functions_aggregate::count::count_udaf;
1243    use datafusion_functions_window::nth_value::last_value_udwf;
1244    use datafusion_functions_window::nth_value::nth_value_udwf;
1245    use datafusion_physical_expr::expressions::{col, Column, Literal};
1246    use datafusion_physical_expr::window::StandardWindowExpr;
1247    use datafusion_physical_expr::{LexOrdering, PhysicalExpr};
1248
1249    use futures::future::Shared;
1250    use futures::{pin_mut, ready, FutureExt, Stream, StreamExt};
1251    use insta::assert_snapshot;
1252    use itertools::Itertools;
1253    use tokio::time::timeout;
1254
1255    #[derive(Debug, Clone)]
1256    struct TestStreamPartition {
1257        schema: SchemaRef,
1258        batches: Vec<RecordBatch>,
1259        idx: usize,
1260        state: PolingState,
1261        sleep_duration: Duration,
1262        send_exit: bool,
1263    }
1264
1265    impl PartitionStream for TestStreamPartition {
1266        fn schema(&self) -> &SchemaRef {
1267            &self.schema
1268        }
1269
1270        fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
1271            // We create an iterator from the record batches and map them into Ok values,
1272            // converting the iterator into a futures::stream::Stream
1273            Box::pin(self.clone())
1274        }
1275    }
1276
1277    impl Stream for TestStreamPartition {
1278        type Item = Result<RecordBatch>;
1279
1280        fn poll_next(
1281            mut self: Pin<&mut Self>,
1282            cx: &mut Context<'_>,
1283        ) -> Poll<Option<Self::Item>> {
1284            self.poll_next_inner(cx)
1285        }
1286    }
1287
1288    #[derive(Debug, Clone)]
1289    enum PolingState {
1290        Sleep(Shared<futures::future::BoxFuture<'static, ()>>),
1291        BatchReturn,
1292    }
1293
1294    impl TestStreamPartition {
1295        fn poll_next_inner(
1296            self: &mut Pin<&mut Self>,
1297            cx: &mut Context<'_>,
1298        ) -> Poll<Option<Result<RecordBatch>>> {
1299            loop {
1300                match &mut self.state {
1301                    PolingState::BatchReturn => {
1302                        // Wait for self.sleep_duration before sending any new data
1303                        let f = tokio::time::sleep(self.sleep_duration).boxed().shared();
1304                        self.state = PolingState::Sleep(f);
1305                        let input_batch = if let Some(batch) =
1306                            self.batches.clone().get(self.idx)
1307                        {
1308                            batch.clone()
1309                        } else if self.send_exit {
1310                            // Send None to signal end of data
1311                            return Poll::Ready(None);
1312                        } else {
1313                            // Go to sleep mode
1314                            let f =
1315                                tokio::time::sleep(self.sleep_duration).boxed().shared();
1316                            self.state = PolingState::Sleep(f);
1317                            continue;
1318                        };
1319                        self.idx += 1;
1320                        return Poll::Ready(Some(Ok(input_batch)));
1321                    }
1322                    PolingState::Sleep(future) => {
1323                        pin_mut!(future);
1324                        ready!(future.poll_unpin(cx));
1325                        self.state = PolingState::BatchReturn;
1326                    }
1327                }
1328            }
1329        }
1330    }
1331
1332    impl RecordBatchStream for TestStreamPartition {
1333        fn schema(&self) -> SchemaRef {
1334            Arc::clone(&self.schema)
1335        }
1336    }
1337
1338    fn bounded_window_exec_pb_latent_range(
1339        input: Arc<dyn ExecutionPlan>,
1340        n_future_range: usize,
1341        hash: &str,
1342        order_by: &str,
1343    ) -> Result<Arc<dyn ExecutionPlan>> {
1344        let schema = input.schema();
1345        let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf());
1346        let col_expr =
1347            Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc<dyn PhysicalExpr>;
1348        let args = vec![col_expr];
1349        let partitionby_exprs = vec![col(hash, &schema)?];
1350        let orderby_exprs = LexOrdering::new(vec![PhysicalSortExpr {
1351            expr: col(order_by, &schema)?,
1352            options: SortOptions::default(),
1353        }]);
1354        let window_frame = WindowFrame::new_bounds(
1355            WindowFrameUnits::Range,
1356            WindowFrameBound::CurrentRow,
1357            WindowFrameBound::Following(ScalarValue::UInt64(Some(n_future_range as u64))),
1358        );
1359        let fn_name = format!(
1360            "{window_fn}({args:?}) PARTITION BY: [{partitionby_exprs:?}], ORDER BY: [{orderby_exprs:?}]"
1361        );
1362        let input_order_mode = InputOrderMode::Linear;
1363        Ok(Arc::new(BoundedWindowAggExec::try_new(
1364            vec![create_window_expr(
1365                &window_fn,
1366                fn_name,
1367                &args,
1368                &partitionby_exprs,
1369                orderby_exprs.as_ref(),
1370                Arc::new(window_frame),
1371                &input.schema(),
1372                false,
1373            )?],
1374            input,
1375            input_order_mode,
1376            true,
1377        )?))
1378    }
1379
1380    fn projection_exec(input: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
1381        let schema = input.schema();
1382        let exprs = input
1383            .schema()
1384            .fields
1385            .iter()
1386            .enumerate()
1387            .map(|(idx, field)| {
1388                let name = if field.name().len() > 20 {
1389                    format!("col_{idx}")
1390                } else {
1391                    field.name().clone()
1392                };
1393                let expr = col(field.name(), &schema).unwrap();
1394                (expr, name)
1395            })
1396            .collect::<Vec<_>>();
1397        Ok(Arc::new(ProjectionExec::try_new(exprs, input)?))
1398    }
1399
1400    fn task_context_helper() -> TaskContext {
1401        let task_ctx = TaskContext::default();
1402        // Create session context with config
1403        let session_config = SessionConfig::new()
1404            .with_batch_size(1)
1405            .with_target_partitions(2)
1406            .with_round_robin_repartition(false);
1407        task_ctx.with_session_config(session_config)
1408    }
1409
1410    fn task_context() -> Arc<TaskContext> {
1411        Arc::new(task_context_helper())
1412    }
1413
1414    pub async fn collect_stream(
1415        mut stream: SendableRecordBatchStream,
1416        results: &mut Vec<RecordBatch>,
1417    ) -> Result<()> {
1418        while let Some(item) = stream.next().await {
1419            results.push(item?);
1420        }
1421        Ok(())
1422    }
1423
1424    /// Execute the [ExecutionPlan] and collect the results in memory
1425    pub async fn collect_with_timeout(
1426        plan: Arc<dyn ExecutionPlan>,
1427        context: Arc<TaskContext>,
1428        timeout_duration: Duration,
1429    ) -> Result<Vec<RecordBatch>> {
1430        let stream = execute_stream(plan, context)?;
1431        let mut results = vec![];
1432
1433        // Execute the asynchronous operation with a timeout
1434        if timeout(timeout_duration, collect_stream(stream, &mut results))
1435            .await
1436            .is_ok()
1437        {
1438            return Err(exec_datafusion_err!("shouldn't have completed"));
1439        };
1440
1441        Ok(results)
1442    }
1443
1444    /// Execute the [ExecutionPlan] and collect the results in memory
1445    #[allow(dead_code)]
1446    pub async fn collect_bonafide(
1447        plan: Arc<dyn ExecutionPlan>,
1448        context: Arc<TaskContext>,
1449    ) -> Result<Vec<RecordBatch>> {
1450        let stream = execute_stream(plan, context)?;
1451        let mut results = vec![];
1452
1453        collect_stream(stream, &mut results).await?;
1454
1455        Ok(results)
1456    }
1457
1458    fn test_schema() -> SchemaRef {
1459        Arc::new(Schema::new(vec![
1460            Field::new("sn", DataType::UInt64, true),
1461            Field::new("hash", DataType::Int64, true),
1462        ]))
1463    }
1464
1465    fn schema_orders(schema: &SchemaRef) -> Result<Vec<LexOrdering>> {
1466        let orderings = vec![LexOrdering::new(vec![PhysicalSortExpr {
1467            expr: col("sn", schema)?,
1468            options: SortOptions {
1469                descending: false,
1470                nulls_first: false,
1471            },
1472        }])];
1473        Ok(orderings)
1474    }
1475
1476    fn is_integer_division_safe(lhs: usize, rhs: usize) -> bool {
1477        let res = lhs / rhs;
1478        res * rhs == lhs
1479    }
1480    fn generate_batches(
1481        schema: &SchemaRef,
1482        n_row: usize,
1483        n_chunk: usize,
1484    ) -> Result<Vec<RecordBatch>> {
1485        let mut batches = vec![];
1486        assert!(n_row > 0);
1487        assert!(n_chunk > 0);
1488        assert!(is_integer_division_safe(n_row, n_chunk));
1489        let hash_replicate = 4;
1490
1491        let chunks = (0..n_row)
1492            .chunks(n_chunk)
1493            .into_iter()
1494            .map(|elem| elem.into_iter().collect::<Vec<_>>())
1495            .collect::<Vec<_>>();
1496
1497        // Send 2 RecordBatches at the source
1498        for sn_values in chunks {
1499            let mut sn1_array = UInt64Builder::with_capacity(sn_values.len());
1500            let mut hash_array = Int64Builder::with_capacity(sn_values.len());
1501
1502            for sn in sn_values {
1503                sn1_array.append_value(sn as u64);
1504                let hash_value = (2 - (sn / hash_replicate)) as i64;
1505                hash_array.append_value(hash_value);
1506            }
1507
1508            let batch = RecordBatch::try_new(
1509                Arc::clone(schema),
1510                vec![Arc::new(sn1_array.finish()), Arc::new(hash_array.finish())],
1511            )?;
1512            batches.push(batch);
1513        }
1514        Ok(batches)
1515    }
1516
1517    fn generate_never_ending_source(
1518        n_rows: usize,
1519        chunk_length: usize,
1520        n_partition: usize,
1521        is_infinite: bool,
1522        send_exit: bool,
1523        per_batch_wait_duration_in_millis: u64,
1524    ) -> Result<Arc<dyn ExecutionPlan>> {
1525        assert!(n_partition > 0);
1526
1527        // We use same hash value in the table. This makes sure that
1528        // After hashing computation will continue in only in one of the output partitions
1529        // In this case, data flow should still continue
1530        let schema = test_schema();
1531        let orderings = schema_orders(&schema)?;
1532
1533        // Source waits per_batch_wait_duration_in_millis ms before sending other batch
1534        let per_batch_wait_duration =
1535            Duration::from_millis(per_batch_wait_duration_in_millis);
1536
1537        let batches = generate_batches(&schema, n_rows, chunk_length)?;
1538
1539        // Source has 2 partitions
1540        let partitions = vec![
1541            Arc::new(TestStreamPartition {
1542                schema: Arc::clone(&schema),
1543                batches,
1544                idx: 0,
1545                state: PolingState::BatchReturn,
1546                sleep_duration: per_batch_wait_duration,
1547                send_exit,
1548            }) as _;
1549            n_partition
1550        ];
1551        let source = Arc::new(StreamingTableExec::try_new(
1552            Arc::clone(&schema),
1553            partitions,
1554            None,
1555            orderings,
1556            is_infinite,
1557            None,
1558        )?) as _;
1559        Ok(source)
1560    }
1561
1562    // Tests NTH_VALUE(negative index) with memoize feature
1563    // To be able to trigger memoize feature for NTH_VALUE we need to
1564    // - feed BoundedWindowAggExec with batch stream data.
1565    // - Window frame should contain UNBOUNDED PRECEDING.
1566    // It hard to ensure these conditions are met, from the sql query.
1567    #[tokio::test]
1568    async fn test_window_nth_value_bounded_memoize() -> Result<()> {
1569        let config = SessionConfig::new().with_target_partitions(1);
1570        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
1571
1572        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1573        // Create a new batch of data to insert into the table
1574        let batch = RecordBatch::try_new(
1575            Arc::clone(&schema),
1576            vec![Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3]))],
1577        )?;
1578
1579        let memory_exec = TestMemoryExec::try_new_exec(
1580            &[vec![batch.clone(), batch.clone(), batch.clone()]],
1581            Arc::clone(&schema),
1582            None,
1583        )?;
1584        let col_a = col("a", &schema)?;
1585        let nth_value_func1 = create_udwf_window_expr(
1586            &nth_value_udwf(),
1587            &[
1588                Arc::clone(&col_a),
1589                Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
1590            ],
1591            &schema,
1592            "nth_value(-1)".to_string(),
1593            false,
1594        )?
1595        .reverse_expr()
1596        .unwrap();
1597        let nth_value_func2 = create_udwf_window_expr(
1598            &nth_value_udwf(),
1599            &[
1600                Arc::clone(&col_a),
1601                Arc::new(Literal::new(ScalarValue::Int32(Some(2)))),
1602            ],
1603            &schema,
1604            "nth_value(-2)".to_string(),
1605            false,
1606        )?
1607        .reverse_expr()
1608        .unwrap();
1609
1610        let last_value_func = create_udwf_window_expr(
1611            &last_value_udwf(),
1612            &[Arc::clone(&col_a)],
1613            &schema,
1614            "last".to_string(),
1615            false,
1616        )?;
1617
1618        let window_exprs = vec![
1619            // LAST_VALUE(a)
1620            Arc::new(StandardWindowExpr::new(
1621                last_value_func,
1622                &[],
1623                &LexOrdering::default(),
1624                Arc::new(WindowFrame::new_bounds(
1625                    WindowFrameUnits::Rows,
1626                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1627                    WindowFrameBound::CurrentRow,
1628                )),
1629            )) as _,
1630            // NTH_VALUE(a, -1)
1631            Arc::new(StandardWindowExpr::new(
1632                nth_value_func1,
1633                &[],
1634                &LexOrdering::default(),
1635                Arc::new(WindowFrame::new_bounds(
1636                    WindowFrameUnits::Rows,
1637                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1638                    WindowFrameBound::CurrentRow,
1639                )),
1640            )) as _,
1641            // NTH_VALUE(a, -2)
1642            Arc::new(StandardWindowExpr::new(
1643                nth_value_func2,
1644                &[],
1645                &LexOrdering::default(),
1646                Arc::new(WindowFrame::new_bounds(
1647                    WindowFrameUnits::Rows,
1648                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1649                    WindowFrameBound::CurrentRow,
1650                )),
1651            )) as _,
1652        ];
1653        let physical_plan = BoundedWindowAggExec::try_new(
1654            window_exprs,
1655            memory_exec,
1656            InputOrderMode::Sorted,
1657            true,
1658        )
1659        .map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)?;
1660
1661        let batches = collect(physical_plan.execute(0, task_ctx)?).await?;
1662
1663        let expected = vec![
1664            "BoundedWindowAggExec: wdw=[last: Ok(Field { name: \"last\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }, nth_value(-1): Ok(Field { name: \"nth_value(-1)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }, nth_value(-2): Ok(Field { name: \"nth_value(-2)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]",
1665            "  DataSourceExec: partitions=1, partition_sizes=[3]",
1666        ];
1667        // Get string representation of the plan
1668        let actual = get_plan_string(&physical_plan);
1669        assert_eq!(
1670            expected, actual,
1671            "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
1672        );
1673
1674        assert_snapshot!(batches_to_string(&batches), @r#"
1675            +---+------+---------------+---------------+
1676            | a | last | nth_value(-1) | nth_value(-2) |
1677            +---+------+---------------+---------------+
1678            | 1 | 1    | 1             |               |
1679            | 2 | 2    | 2             | 1             |
1680            | 3 | 3    | 3             | 2             |
1681            | 1 | 1    | 1             | 3             |
1682            | 2 | 2    | 2             | 1             |
1683            | 3 | 3    | 3             | 2             |
1684            | 1 | 1    | 1             | 3             |
1685            | 2 | 2    | 2             | 1             |
1686            | 3 | 3    | 3             | 2             |
1687            +---+------+---------------+---------------+
1688            "#);
1689        Ok(())
1690    }
1691
1692    // This test, tests whether most recent row guarantee by the input batch of the `BoundedWindowAggExec`
1693    // helps `BoundedWindowAggExec` to generate low latency result in the `Linear` mode.
1694    // Input data generated at the source is
1695    //       "+----+------+",
1696    //       "| sn | hash |",
1697    //       "+----+------+",
1698    //       "| 0  | 2    |",
1699    //       "| 1  | 2    |",
1700    //       "| 2  | 2    |",
1701    //       "| 3  | 2    |",
1702    //       "| 4  | 1    |",
1703    //       "| 5  | 1    |",
1704    //       "| 6  | 1    |",
1705    //       "| 7  | 1    |",
1706    //       "| 8  | 0    |",
1707    //       "| 9  | 0    |",
1708    //       "+----+------+",
1709    //
1710    // Effectively following query is run on this data
1711    //
1712    //   SELECT *, count(*) OVER(PARTITION BY duplicated_hash ORDER BY sn RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)
1713    //   FROM test;
1714    //
1715    // partition `duplicated_hash=2` receives following data from the input
1716    //
1717    //       "+----+------+",
1718    //       "| sn | hash |",
1719    //       "+----+------+",
1720    //       "| 0  | 2    |",
1721    //       "| 1  | 2    |",
1722    //       "| 2  | 2    |",
1723    //       "| 3  | 2    |",
1724    //       "+----+------+",
1725    // normally `BoundedWindowExec` can only generate following result from the input above
1726    //
1727    //       "+----+------+---------+",
1728    //       "| sn | hash |  count  |",
1729    //       "+----+------+---------+",
1730    //       "| 0  | 2    |  2      |",
1731    //       "| 1  | 2    |  2      |",
1732    //       "| 2  | 2    |<not yet>|",
1733    //       "| 3  | 2    |<not yet>|",
1734    //       "+----+------+---------+",
1735    // where result of last 2 row is missing. Since window frame end is not may change with future data
1736    // since window frame end is determined by 1 following (To generate result for row=3[where sn=2] we
1737    // need to received sn=4 to make sure window frame end bound won't change with future data).
1738    //
1739    // With the ability of different partitions to use global ordering at the input (where most up-to date
1740    //   row is
1741    //      "| 9  | 0    |",
1742    //   )
1743    //
1744    // `BoundedWindowExec` should be able to generate following result in the test
1745    //
1746    //       "+----+------+-------+",
1747    //       "| sn | hash | col_2 |",
1748    //       "+----+------+-------+",
1749    //       "| 0  | 2    | 2     |",
1750    //       "| 1  | 2    | 2     |",
1751    //       "| 2  | 2    | 2     |",
1752    //       "| 3  | 2    | 1     |",
1753    //       "| 4  | 1    | 2     |",
1754    //       "| 5  | 1    | 2     |",
1755    //       "| 6  | 1    | 2     |",
1756    //       "| 7  | 1    | 1     |",
1757    //       "+----+------+-------+",
1758    //
1759    // where result for all rows except last 2 is calculated (To calculate result for row 9 where sn=8
1760    //   we need to receive sn=10 value to calculate it result.).
1761    // In this test, out aim is to test for which portion of the input data `BoundedWindowExec` can generate
1762    // a result. To test this behaviour, we generated the data at the source infinitely (no `None` signal
1763    //    is sent to output from source). After, row:
1764    //
1765    //       "| 9  | 0    |",
1766    //
1767    // is sent. Source stops sending data to output. We collect, result emitted by the `BoundedWindowExec` at the
1768    // end of the pipeline with a timeout (Since no `None` is sent from source. Collection never ends otherwise).
1769    #[tokio::test]
1770    async fn bounded_window_exec_linear_mode_range_information() -> Result<()> {
1771        let n_rows = 10;
1772        let chunk_length = 2;
1773        let n_future_range = 1;
1774
1775        let timeout_duration = Duration::from_millis(2000);
1776
1777        let source =
1778            generate_never_ending_source(n_rows, chunk_length, 1, true, false, 5)?;
1779
1780        let window =
1781            bounded_window_exec_pb_latent_range(source, n_future_range, "hash", "sn")?;
1782
1783        let plan = projection_exec(window)?;
1784
1785        let expected_plan = vec![
1786            "ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]@2 as col_2]",
1787            "  BoundedWindowAggExec: wdw=[count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]: Ok(Field { name: \"count([Column { name: \\\"sn\\\", index: 0 }]) PARTITION BY: [[Column { name: \\\"hash\\\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \\\"sn\\\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Linear]",
1788            "    StreamingTableExec: partition_sizes=1, projection=[sn, hash], infinite_source=true, output_ordering=[sn@0 ASC NULLS LAST]",
1789        ];
1790
1791        // Get string representation of the plan
1792        let actual = get_plan_string(&plan);
1793        assert_eq!(
1794            expected_plan, actual,
1795            "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected_plan:#?}\nactual:\n\n{actual:#?}\n\n"
1796        );
1797
1798        let task_ctx = task_context();
1799        let batches = collect_with_timeout(plan, task_ctx, timeout_duration).await?;
1800
1801        assert_snapshot!(batches_to_string(&batches), @r#"
1802            +----+------+-------+
1803            | sn | hash | col_2 |
1804            +----+------+-------+
1805            | 0  | 2    | 2     |
1806            | 1  | 2    | 2     |
1807            | 2  | 2    | 2     |
1808            | 3  | 2    | 1     |
1809            | 4  | 1    | 2     |
1810            | 5  | 1    | 2     |
1811            | 6  | 1    | 2     |
1812            | 7  | 1    | 1     |
1813            +----+------+-------+
1814            "#);
1815
1816        Ok(())
1817    }
1818}