Skip to main content

datafusion_physical_plan/windows/
bounded_window_agg_exec.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Stream and channel implementations for window function expressions.
19//! The executor given here uses bounded memory (does not maintain all
20//! the input data seen so far), which makes it appropriate when processing
21//! infinite inputs.
22
23use std::cmp::{Ordering, min};
24use std::collections::VecDeque;
25use std::pin::Pin;
26use std::sync::Arc;
27use std::task::{Context, Poll};
28
29use super::utils::create_schema;
30use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
31use crate::stream::EmptyRecordBatchStream;
32use crate::windows::{
33    calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs,
34    window_equivalence_properties,
35};
36use crate::{
37    ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
38    ExecutionPlanProperties, InputOrderMode, PlanProperties, RecordBatchStream,
39    SendableRecordBatchStream, Statistics, WindowExpr, check_if_same_properties,
40};
41
42use arrow::compute::take_record_batch;
43use arrow::{
44    array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder},
45    compute::{concat, concat_batches, sort_to_indices, take_arrays},
46    datatypes::SchemaRef,
47    record_batch::RecordBatch,
48};
49use datafusion_common::hash_utils::create_hashes;
50use datafusion_common::stats::Precision;
51use datafusion_common::utils::{
52    evaluate_partition_ranges, get_at_indices, get_row_at_idx,
53};
54use datafusion_common::{
55    HashMap, Result, arrow_datafusion_err, exec_datafusion_err, exec_err,
56};
57use datafusion_execution::TaskContext;
58use datafusion_expr::ColumnarValue;
59use datafusion_expr::window_state::{PartitionBatchState, WindowAggState};
60use datafusion_physical_expr::window::{
61    PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState,
62};
63use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
64use datafusion_physical_expr_common::sort_expr::{
65    OrderingRequirements, PhysicalSortExpr,
66};
67
68use crate::execution_plan::CardinalityEffect;
69use datafusion_common::hash_utils::RandomState;
70use futures::stream::Stream;
71use futures::{StreamExt, ready};
72use hashbrown::hash_table::HashTable;
73use indexmap::IndexMap;
74use log::debug;
75
/// Window execution plan
#[derive(Debug, Clone)]
pub struct BoundedWindowAggExec {
    /// Input plan
    input: Arc<dyn ExecutionPlan>,
    /// Window function expression
    window_expr: Vec<Arc<dyn WindowExpr>>,
    /// Schema after the window is run
    schema: SchemaRef,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Describes how the input is ordered relative to the partition keys
    pub input_order_mode: InputOrderMode,
    /// Partition by indices that define ordering
    // For example, if input ordering is ORDER BY a, b and window expression
    // contains PARTITION BY b, a; `ordered_partition_by_indices` would be 1, 0.
    // Similarly, if window expression contains PARTITION BY a, b; then
    // `ordered_partition_by_indices` would be 0, 1.
    // See `get_ordered_partition_by_indices` for more details.
    ordered_partition_by_indices: Vec<usize>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: Arc<PlanProperties>,
    /// If `can_repartition` is false, `partition_keys()` always returns an
    /// empty vector.
    can_repartition: bool,
}
101
102impl BoundedWindowAggExec {
103    /// Create a new execution plan for window aggregates
104    pub fn try_new(
105        window_expr: Vec<Arc<dyn WindowExpr>>,
106        input: Arc<dyn ExecutionPlan>,
107        input_order_mode: InputOrderMode,
108        can_repartition: bool,
109    ) -> Result<Self> {
110        let schema = create_schema(&input.schema(), &window_expr)?;
111        let schema = Arc::new(schema);
112        let partition_by_exprs = window_expr[0].partition_by();
113        let ordered_partition_by_indices = match &input_order_mode {
114            InputOrderMode::Sorted => {
115                let indices = get_ordered_partition_by_indices(
116                    window_expr[0].partition_by(),
117                    &input,
118                )?;
119                if indices.len() == partition_by_exprs.len() {
120                    indices
121                } else {
122                    (0..partition_by_exprs.len()).collect::<Vec<_>>()
123                }
124            }
125            InputOrderMode::PartiallySorted(ordered_indices) => ordered_indices.clone(),
126            InputOrderMode::Linear => {
127                vec![]
128            }
129        };
130        let cache = Self::compute_properties(&input, &schema, &window_expr)?;
131        Ok(Self {
132            input,
133            window_expr,
134            schema,
135            metrics: ExecutionPlanMetricsSet::new(),
136            input_order_mode,
137            ordered_partition_by_indices,
138            cache: Arc::new(cache),
139            can_repartition,
140        })
141    }
142
143    /// Window expressions
144    pub fn window_expr(&self) -> &[Arc<dyn WindowExpr>] {
145        &self.window_expr
146    }
147
    /// Returns the input plan of this operator.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }
152
153    /// Return the output sort order of partition keys: For example
154    /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a
155    // We are sure that partition by columns are always at the beginning of sort_keys
156    // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely
157    // to calculate partition separation points
158    pub fn partition_by_sort_keys(&self) -> Result<Vec<PhysicalSortExpr>> {
159        let partition_by = self.window_expr()[0].partition_by();
160        get_partition_by_sort_exprs(
161            &self.input,
162            partition_by,
163            &self.ordered_partition_by_indices,
164        )
165    }
166
167    /// Initializes the appropriate [`PartitionSearcher`] implementation from
168    /// the state.
169    fn get_search_algo(&self) -> Result<Box<dyn PartitionSearcher>> {
170        let partition_by_sort_keys = self.partition_by_sort_keys()?;
171        let ordered_partition_by_indices = self.ordered_partition_by_indices.clone();
172        let input_schema = self.input().schema();
173        Ok(match &self.input_order_mode {
174            InputOrderMode::Sorted => {
175                // In Sorted mode, all partition by columns should be ordered.
176                if self.window_expr()[0].partition_by().len()
177                    != ordered_partition_by_indices.len()
178                {
179                    return exec_err!(
180                        "All partition by columns should have an ordering in Sorted mode."
181                    );
182                }
183                Box::new(SortedSearch {
184                    partition_by_sort_keys,
185                    ordered_partition_by_indices,
186                    input_schema,
187                })
188            }
189            InputOrderMode::Linear | InputOrderMode::PartiallySorted(_) => Box::new(
190                LinearSearch::new(ordered_partition_by_indices, input_schema),
191            ),
192        })
193    }
194
195    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
196    fn compute_properties(
197        input: &Arc<dyn ExecutionPlan>,
198        schema: &SchemaRef,
199        window_exprs: &[Arc<dyn WindowExpr>],
200    ) -> Result<PlanProperties> {
201        // Calculate equivalence properties:
202        let eq_properties = window_equivalence_properties(schema, input, window_exprs)?;
203
204        // As we can have repartitioning using the partition keys, this can
205        // be either one or more than one, depending on the presence of
206        // repartitioning.
207        let output_partitioning = input.output_partitioning().clone();
208
209        // Construct properties cache
210        Ok(PlanProperties::new(
211            eq_properties,
212            output_partitioning,
213            // TODO: Emission type and boundedness information can be enhanced here
214            input.pipeline_behavior(),
215            input.boundedness(),
216        ))
217    }
218
219    pub fn partition_keys(&self) -> Vec<Arc<dyn PhysicalExpr>> {
220        if !self.can_repartition {
221            vec![]
222        } else {
223            let all_partition_keys = self
224                .window_expr()
225                .iter()
226                .map(|expr| expr.partition_by().to_vec())
227                .collect::<Vec<_>>();
228
229            all_partition_keys
230                .into_iter()
231                .min_by_key(|s| s.len())
232                .unwrap_or_else(Vec::new)
233        }
234    }
235
236    fn statistics_helper(&self, statistics: Statistics) -> Result<Statistics> {
237        let win_cols = self.window_expr.len();
238        let input_cols = self.input.schema().fields().len();
239        // TODO stats: some windowing function will maintain invariants such as min, max...
240        let mut column_statistics = Vec::with_capacity(win_cols + input_cols);
241        // copy stats of the input to the beginning of the schema.
242        column_statistics.extend(statistics.column_statistics);
243        for _ in 0..win_cols {
244            column_statistics.push(ColumnStatistics::new_unknown())
245        }
246        Ok(Statistics {
247            num_rows: statistics.num_rows,
248            column_statistics,
249            total_byte_size: Precision::Absent,
250        })
251    }
252
253    fn with_new_children_and_same_properties(
254        &self,
255        mut children: Vec<Arc<dyn ExecutionPlan>>,
256    ) -> Self {
257        Self {
258            input: children.swap_remove(0),
259            metrics: ExecutionPlanMetricsSet::new(),
260            ..Self::clone(self)
261        }
262    }
263}
264
265impl DisplayAs for BoundedWindowAggExec {
266    fn fmt_as(
267        &self,
268        t: DisplayFormatType,
269        f: &mut std::fmt::Formatter,
270    ) -> std::fmt::Result {
271        match t {
272            DisplayFormatType::Default | DisplayFormatType::Verbose => {
273                write!(f, "BoundedWindowAggExec: ")?;
274                let g: Vec<String> = self
275                    .window_expr
276                    .iter()
277                    .map(|e| {
278                        let field = match e.field() {
279                            Ok(f) => f.to_string(),
280                            Err(e) => format!("{e:?}"),
281                        };
282                        format!(
283                            "{}: {}, frame: {}",
284                            e.name().to_owned(),
285                            field,
286                            e.get_window_frame()
287                        )
288                    })
289                    .collect();
290                let mode = &self.input_order_mode;
291                write!(f, "wdw=[{}], mode=[{:?}]", g.join(", "), mode)?;
292            }
293            DisplayFormatType::TreeRender => {
294                let g: Vec<String> = self
295                    .window_expr
296                    .iter()
297                    .map(|e| e.name().to_owned().to_string())
298                    .collect();
299                writeln!(f, "select_list={}", g.join(", "))?;
300
301                let mode = &self.input_order_mode;
302                writeln!(f, "mode={mode:?}")?;
303            }
304        }
305        Ok(())
306    }
307}
308
309impl ExecutionPlan for BoundedWindowAggExec {
    /// Returns the static name of this operator.
    fn name(&self) -> &'static str {
        "BoundedWindowAggExec"
    }
313
    /// Returns the cached plan properties of this operator.
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }
318
    /// Returns the input plan as the only child of this operator.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }
322
323    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
324        let partition_bys = self.window_expr()[0].partition_by();
325        let order_keys = self.window_expr()[0].order_by();
326        let partition_bys = self
327            .ordered_partition_by_indices
328            .iter()
329            .map(|idx| &partition_bys[*idx]);
330        vec![calc_requirements(partition_bys, order_keys)]
331    }
332
333    fn required_input_distribution(&self) -> Vec<Distribution> {
334        if self.partition_keys().is_empty() {
335            debug!("No partition defined for BoundedWindowAggExec!!!");
336            vec![Distribution::SinglePartition]
337        } else {
338            vec![Distribution::HashPartitioned(self.partition_keys().clone())]
339        }
340    }
341
    /// Reports that the order of the single input is maintained in the output.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }
345
346    fn with_new_children(
347        self: Arc<Self>,
348        children: Vec<Arc<dyn ExecutionPlan>>,
349    ) -> Result<Arc<dyn ExecutionPlan>> {
350        check_if_same_properties!(self, children);
351        Ok(Arc::new(BoundedWindowAggExec::try_new(
352            self.window_expr.clone(),
353            Arc::clone(&children[0]),
354            self.input_order_mode.clone(),
355            self.can_repartition,
356        )?))
357    }
358
359    fn execute(
360        &self,
361        partition: usize,
362        context: Arc<TaskContext>,
363    ) -> Result<SendableRecordBatchStream> {
364        let input = self.input.execute(partition, context)?;
365        let search_mode = self.get_search_algo()?;
366        let stream = Box::pin(BoundedWindowAggStream::new(
367            Arc::clone(&self.schema),
368            self.window_expr.clone(),
369            input,
370            BaselineMetrics::new(&self.metrics, partition),
371            search_mode,
372        )?);
373        Ok(stream)
374    }
375
    /// Returns a snapshot of the execution metrics of this operator.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
379
380    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
381        let input_stat =
382            Arc::unwrap_or_clone(self.input.partition_statistics(partition)?);
383        Ok(Arc::new(self.statistics_helper(input_stat)?))
384    }
385
    /// Reports that this operator outputs as many rows as it receives.
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }
389}
390
391/// Trait that specifies how we search for (or calculate) partitions. It has two
392/// implementations: [`SortedSearch`] and [`LinearSearch`].
393trait PartitionSearcher: Send {
394    /// This method constructs output columns using the result of each window expression
395    /// (each entry in the output vector comes from a window expression).
396    /// Executor when producing output concatenates `input_buffer` (corresponding section), and
397    /// result of this function to generate output `RecordBatch`. `input_buffer` is used to determine
398    /// which sections of the window expression results should be used to generate output.
399    /// `partition_buffers` contains corresponding section of the `RecordBatch` for each partition.
400    /// `window_agg_states` stores per partition state for each window expression.
401    /// None case means that no result is generated
402    /// `Some(Vec<ArrayRef>)` is the result of each window expression.
403    fn calculate_out_columns(
404        &mut self,
405        input_buffer: &RecordBatch,
406        window_agg_states: &[PartitionWindowAggStates],
407        partition_buffers: &mut PartitionBatches,
408        window_expr: &[Arc<dyn WindowExpr>],
409    ) -> Result<Option<Vec<ArrayRef>>>;
410
411    /// Determine whether `[InputOrderMode]` is `[InputOrderMode::Linear]` or not.
412    fn is_mode_linear(&self) -> bool {
413        false
414    }
415
416    // Constructs corresponding batches for each partition for the record_batch.
417    fn evaluate_partition_batches(
418        &mut self,
419        record_batch: &RecordBatch,
420        window_expr: &[Arc<dyn WindowExpr>],
421    ) -> Result<Vec<(PartitionKey, RecordBatch)>>;
422
423    /// Prunes the state.
424    fn prune(&mut self, _n_out: usize) {}
425
426    /// Marks the partition as done if we are sure that corresponding partition
427    /// cannot receive any more values.
428    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches);
429
430    /// Updates `input_buffer` and `partition_buffers` with the new `record_batch`.
431    fn update_partition_batch(
432        &mut self,
433        input_buffer: &mut RecordBatch,
434        record_batch: RecordBatch,
435        window_expr: &[Arc<dyn WindowExpr>],
436        partition_buffers: &mut PartitionBatches,
437    ) -> Result<()> {
438        if record_batch.num_rows() == 0 {
439            return Ok(());
440        }
441        let partition_batches =
442            self.evaluate_partition_batches(&record_batch, window_expr)?;
443        for (partition_row, partition_batch) in partition_batches {
444            if let Some(partition_batch_state) = partition_buffers.get_mut(&partition_row)
445            {
446                partition_batch_state.extend(&partition_batch)?
447            } else {
448                let options = RecordBatchOptions::new()
449                    .with_row_count(Some(partition_batch.num_rows()));
450                // Use input_schema for the buffer schema, not `record_batch.schema()`
451                // as it may not have the "correct" schema in terms of output
452                // nullability constraints. For details, see the following issue:
453                // https://github.com/apache/datafusion/issues/9320
454                let partition_batch = RecordBatch::try_new_with_options(
455                    Arc::clone(self.input_schema()),
456                    partition_batch.columns().to_vec(),
457                    &options,
458                )?;
459                let partition_batch_state =
460                    PartitionBatchState::new_with_batch(partition_batch);
461                partition_buffers.insert(partition_row, partition_batch_state);
462            }
463        }
464
465        if self.is_mode_linear() {
466            // In `Linear` mode, it is guaranteed that the first ORDER BY column
467            // is sorted across partitions. Note that only the first ORDER BY
468            // column is guaranteed to be ordered. As a counter example, consider
469            // the case, `PARTITION BY b, ORDER BY a, c` when the input is sorted
470            // by `[a, b, c]`. In this case, `BoundedWindowAggExec` mode will be
471            // `Linear`. However, we cannot guarantee that the last row of the
472            // input data will be the "last" data in terms of the ordering requirement
473            // `[a, c]` -- it will be the "last" data in terms of `[a, b, c]`.
474            // Hence, only column `a` should be used as a guarantee of the "last"
475            // data across partitions. For other modes (`Sorted`, `PartiallySorted`),
476            // we do not need to keep track of the most recent row guarantee across
477            // partitions. Since leading ordering separates partitions, guaranteed
478            // by the most recent row, already prune the previous partitions completely.
479            let last_row = get_last_row_batch(&record_batch)?;
480            for (_, partition_batch) in partition_buffers.iter_mut() {
481                partition_batch.set_most_recent_row(last_row.clone());
482            }
483        }
484        self.mark_partition_end(partition_buffers);
485
486        *input_buffer = if input_buffer.num_rows() == 0 {
487            record_batch
488        } else {
489            concat_batches(self.input_schema(), [input_buffer, &record_batch])?
490        };
491
492        Ok(())
493    }
494
495    fn input_schema(&self) -> &SchemaRef;
496}
497
/// This object encapsulates the algorithm state for a simple linear scan
/// algorithm for computing partitions.
pub struct LinearSearch {
    /// Keeps the hash of input buffer calculated from PARTITION BY columns.
    /// Its length is equal to the `input_buffer` length.
    input_buffer_hashes: VecDeque<u64>,
    /// Used during hash value calculation.
    random_state: RandomState,
    /// Input ordering and partition by key ordering need not be the same, so
    /// this vector stores the mapping between them. For instance, if the input
    /// is ordered by a, b and the window expression contains a PARTITION BY b, a
    /// clause, this attribute stores [1, 0].
    ordered_partition_by_indices: Vec<usize>,
    /// We use this [`HashTable`] to calculate unique partitions for each new
    /// RecordBatch. First entry in the tuple is the hash value, the second
    /// entry is the unique ID for each partition (increments from 0 to n).
    row_map_batch: HashTable<(u64, usize)>,
    /// We use this [`HashTable`] to calculate the output columns that we can
    /// produce at each cycle. First entry in the tuple is the hash value, the
    /// second entry is the unique ID for each partition (increments from 0 to n).
    /// The third entry stores how many new outputs are calculated for the
    /// corresponding partition.
    row_map_out: HashTable<(u64, usize, usize)>,
    /// Schema of the input, exposed through `input_schema()`.
    input_schema: SchemaRef,
}
523
524impl PartitionSearcher for LinearSearch {
525    /// This method constructs output columns using the result of each window expression.
526    // Assume input buffer is         |      Partition Buffers would be (Where each partition and its data is separated)
527    // a, 2                           |      a, 2
528    // b, 2                           |      a, 2
529    // a, 2                           |      a, 2
530    // b, 2                           |
531    // a, 2                           |      b, 2
532    // b, 2                           |      b, 2
533    // b, 2                           |      b, 2
534    //                                |      b, 2
535    // Also assume we happen to calculate 2 new values for a, and 3 for b (To be calculate missing values we may need to consider future values).
536    // Partition buffers effectively will be
537    // a, 2, 1
538    // a, 2, 2
539    // a, 2, (missing)
540    //
541    // b, 2, 1
542    // b, 2, 2
543    // b, 2, 3
544    // b, 2, (missing)
545    // When partition buffers are mapped back to the original record batch. Result becomes
546    // a, 2, 1
547    // b, 2, 1
548    // a, 2, 2
549    // b, 2, 2
550    // a, 2, (missing)
551    // b, 2, 3
552    // b, 2, (missing)
553    // This function calculates the column result of window expression(s) (First 4 entry of 3rd column in the above section.)
554    // 1
555    // 1
556    // 2
557    // 2
558    // Above section corresponds to calculated result which can be emitted without breaking input buffer ordering.
559    fn calculate_out_columns(
560        &mut self,
561        input_buffer: &RecordBatch,
562        window_agg_states: &[PartitionWindowAggStates],
563        partition_buffers: &mut PartitionBatches,
564        window_expr: &[Arc<dyn WindowExpr>],
565    ) -> Result<Option<Vec<ArrayRef>>> {
566        let partition_output_indices = self.calc_partition_output_indices(
567            input_buffer,
568            window_agg_states,
569            window_expr,
570        )?;
571
572        let n_window_col = window_agg_states.len();
573        let mut new_columns = vec![vec![]; n_window_col];
574        // Size of all_indices can be at most input_buffer.num_rows():
575        let mut all_indices = UInt32Builder::with_capacity(input_buffer.num_rows());
576        for (row, indices) in partition_output_indices {
577            let length = indices.len();
578            for (idx, window_agg_state) in window_agg_states.iter().enumerate() {
579                let partition = &window_agg_state[&row];
580                let values = Arc::clone(&partition.state.out_col.slice(0, length));
581                new_columns[idx].push(values);
582            }
583            let partition_batch_state = &mut partition_buffers[&row];
584            // Store how many rows are generated for each partition
585            partition_batch_state.n_out_row = length;
586            // For each row keep corresponding index in the input record batch
587            all_indices.append_slice(&indices);
588        }
589        let all_indices = all_indices.finish();
590        if all_indices.is_empty() {
591            // We couldn't generate any new value, return early:
592            return Ok(None);
593        }
594
595        // Concatenate results for each column by converting `Vec<Vec<ArrayRef>>`
596        // to Vec<ArrayRef> where inner `Vec<ArrayRef>`s are converted to `ArrayRef`s.
597        let new_columns = new_columns
598            .iter()
599            .map(|items| {
600                concat(&items.iter().map(|e| e.as_ref()).collect::<Vec<_>>())
601                    .map_err(|e| arrow_datafusion_err!(e))
602            })
603            .collect::<Result<Vec<_>>>()?;
604        // We should emit columns according to row index ordering.
605        let sorted_indices = sort_to_indices(&all_indices, None, None)?;
606        // Construct new column according to row ordering. This fixes ordering
607        take_arrays(&new_columns, &sorted_indices, None)
608            .map(Some)
609            .map_err(|e| arrow_datafusion_err!(e))
610    }
611
612    fn evaluate_partition_batches(
613        &mut self,
614        record_batch: &RecordBatch,
615        window_expr: &[Arc<dyn WindowExpr>],
616    ) -> Result<Vec<(PartitionKey, RecordBatch)>> {
617        let partition_bys =
618            evaluate_partition_by_column_values(record_batch, window_expr)?;
619        // NOTE: In Linear or PartiallySorted modes, we are sure that
620        //       `partition_bys` are not empty.
621        // Calculate indices for each partition and construct a new record
622        // batch from the rows at these indices for each partition:
623        self.get_per_partition_indices(&partition_bys, record_batch)?
624            .into_iter()
625            .map(|(row, indices)| {
626                let mut new_indices = UInt32Builder::with_capacity(indices.len());
627                new_indices.append_slice(&indices);
628                let indices = new_indices.finish();
629                Ok((row, take_record_batch(record_batch, &indices)?))
630            })
631            .collect()
632    }
633
    /// Removes the first `n_out` entries of the buffered input hashes.
    fn prune(&mut self, n_out: usize) {
        // Delete hashes for the rows that are outputted.
        self.input_buffer_hashes.drain(0..n_out);
    }
638
639    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches) {
640        // We should be in the `PartiallySorted` case, otherwise we can not
641        // tell when we are at the end of a given partition.
642        if !self.ordered_partition_by_indices.is_empty()
643            && let Some((last_row, _)) = partition_buffers.last()
644        {
645            let last_sorted_cols = self
646                .ordered_partition_by_indices
647                .iter()
648                .map(|idx| last_row[*idx].clone())
649                .collect::<Vec<_>>();
650            for (row, partition_batch_state) in partition_buffers.iter_mut() {
651                let sorted_cols = self
652                    .ordered_partition_by_indices
653                    .iter()
654                    .map(|idx| &row[*idx]);
655                // All the partitions other than `last_sorted_cols` are done.
656                // We are sure that we will no longer receive values for these
657                // partitions (arrival of a new value would violate ordering).
658                partition_batch_state.is_end = !sorted_cols.eq(&last_sorted_cols);
659            }
660        }
661    }
662
    /// Returns `true` when there are no ordered partition-by indices.
    fn is_mode_linear(&self) -> bool {
        self.ordered_partition_by_indices.is_empty()
    }
666
    /// Returns the input schema this searcher was created with.
    fn input_schema(&self) -> &SchemaRef {
        &self.input_schema
    }
670}
671
672impl LinearSearch {
673    /// Initialize a new [`LinearSearch`] partition searcher.
674    fn new(ordered_partition_by_indices: Vec<usize>, input_schema: SchemaRef) -> Self {
675        LinearSearch {
676            input_buffer_hashes: VecDeque::new(),
677            random_state: Default::default(),
678            ordered_partition_by_indices,
679            row_map_batch: HashTable::with_capacity(256),
680            row_map_out: HashTable::with_capacity(256),
681            input_schema,
682        }
683    }
684
685    /// Calculate indices of each partition (according to PARTITION BY expression)
686    /// `columns` contain partition by expression results.
687    fn get_per_partition_indices(
688        &mut self,
689        columns: &[ArrayRef],
690        batch: &RecordBatch,
691    ) -> Result<Vec<(PartitionKey, Vec<u32>)>> {
692        let mut batch_hashes = vec![0; batch.num_rows()];
693        create_hashes(columns, &self.random_state, &mut batch_hashes)?;
694        self.input_buffer_hashes.extend(&batch_hashes);
695        // reset row_map for new calculation
696        self.row_map_batch.clear();
697        // res stores PartitionKey and row indices (indices where these partition occurs in the `batch`) for each partition.
698        let mut result: Vec<(PartitionKey, Vec<u32>)> = vec![];
699        for (hash, row_idx) in batch_hashes.into_iter().zip(0u32..) {
700            let entry = self.row_map_batch.find_mut(hash, |(_, group_idx)| {
701                // We can safely get the first index of the partition indices
702                // since partition indices has one element during initialization.
703                let row = get_row_at_idx(columns, row_idx as usize).unwrap();
704                // Handle hash collusions with an equality check:
705                row.eq(&result[*group_idx].0)
706            });
707            if let Some((_, group_idx)) = entry {
708                result[*group_idx].1.push(row_idx)
709            } else {
710                self.row_map_batch.insert_unique(
711                    hash,
712                    (hash, result.len()),
713                    |(hash, _)| *hash,
714                );
715                let row = get_row_at_idx(columns, row_idx as usize)?;
716                // This is a new partition its only index is row_idx for now.
717                result.push((row, vec![row_idx]));
718            }
719        }
720        Ok(result)
721    }
722
723    /// Calculates partition keys and result indices for each partition.
724    /// The return value is a vector of tuples where the first entry stores
725    /// the partition key (unique for each partition) and the second entry
726    /// stores indices of the rows for which the partition is constructed.
727    fn calc_partition_output_indices(
728        &mut self,
729        input_buffer: &RecordBatch,
730        window_agg_states: &[PartitionWindowAggStates],
731        window_expr: &[Arc<dyn WindowExpr>],
732    ) -> Result<Vec<(PartitionKey, Vec<u32>)>> {
733        let partition_by_columns =
734            evaluate_partition_by_column_values(input_buffer, window_expr)?;
735        // Reset the row_map state:
736        self.row_map_out.clear();
737        let mut partition_indices: Vec<(PartitionKey, Vec<u32>)> = vec![];
738        for (hash, row_idx) in self.input_buffer_hashes.iter().zip(0u32..) {
739            let entry = self.row_map_out.find_mut(*hash, |(_, group_idx, _)| {
740                let row =
741                    get_row_at_idx(&partition_by_columns, row_idx as usize).unwrap();
742                row == partition_indices[*group_idx].0
743            });
744            if let Some((_, group_idx, n_out)) = entry {
745                let (_, indices) = &mut partition_indices[*group_idx];
746                if indices.len() >= *n_out {
747                    break;
748                }
749                indices.push(row_idx);
750            } else {
751                let row = get_row_at_idx(&partition_by_columns, row_idx as usize)?;
752                let min_out = window_agg_states
753                    .iter()
754                    .map(|window_agg_state| {
755                        window_agg_state
756                            .get(&row)
757                            .map(|partition| partition.state.out_col.len())
758                            .unwrap_or(0)
759                    })
760                    .min()
761                    .unwrap_or(0);
762                if min_out == 0 {
763                    break;
764                }
765                self.row_map_out.insert_unique(
766                    *hash,
767                    (*hash, partition_indices.len(), min_out),
768                    |(hash, _, _)| *hash,
769                );
770                partition_indices.push((row, vec![row_idx]));
771            }
772        }
773        Ok(partition_indices)
774    }
775}
776
/// This object encapsulates the algorithm state for sorted searching
/// when computing partitions.
pub struct SortedSearch {
    /// Stores partition by columns and their ordering information
    partition_by_sort_keys: Vec<PhysicalSortExpr>,
    /// Input ordering and partition by key ordering need not be the same, so
    /// this vector stores the mapping between them. For instance, if the input
    /// is ordered by a, b and the window expression contains a PARTITION BY b, a
    /// clause, this attribute stores [1, 0].
    ordered_partition_by_indices: Vec<usize>,
    /// Schema handed back by the `input_schema` accessor of the
    /// `PartitionSearcher` implementation.
    input_schema: SchemaRef,
}
789
790impl PartitionSearcher for SortedSearch {
791    /// This method constructs new output columns using the result of each window expression.
792    fn calculate_out_columns(
793        &mut self,
794        _input_buffer: &RecordBatch,
795        window_agg_states: &[PartitionWindowAggStates],
796        partition_buffers: &mut PartitionBatches,
797        _window_expr: &[Arc<dyn WindowExpr>],
798    ) -> Result<Option<Vec<ArrayRef>>> {
799        let n_out = self.calculate_n_out_row(window_agg_states, partition_buffers);
800        if n_out == 0 {
801            Ok(None)
802        } else {
803            window_agg_states
804                .iter()
805                .map(|map| get_aggregate_result_out_column(map, n_out).map(Some))
806                .collect()
807        }
808    }
809
810    fn evaluate_partition_batches(
811        &mut self,
812        record_batch: &RecordBatch,
813        _window_expr: &[Arc<dyn WindowExpr>],
814    ) -> Result<Vec<(PartitionKey, RecordBatch)>> {
815        let num_rows = record_batch.num_rows();
816        // Calculate result of partition by column expressions
817        let partition_columns = self
818            .partition_by_sort_keys
819            .iter()
820            .map(|elem| elem.evaluate_to_sort_column(record_batch))
821            .collect::<Result<Vec<_>>>()?;
822        // Reorder `partition_columns` such that its ordering matches input ordering.
823        let partition_columns_ordered =
824            get_at_indices(&partition_columns, &self.ordered_partition_by_indices)?;
825        let partition_points =
826            evaluate_partition_ranges(num_rows, &partition_columns_ordered)?;
827        let partition_bys = partition_columns
828            .into_iter()
829            .map(|arr| arr.values)
830            .collect::<Vec<ArrayRef>>();
831
832        partition_points
833            .iter()
834            .map(|range| {
835                let row = get_row_at_idx(&partition_bys, range.start)?;
836                let len = range.end - range.start;
837                let slice = record_batch.slice(range.start, len);
838                Ok((row, slice))
839            })
840            .collect::<Result<Vec<_>>>()
841    }
842
843    fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches) {
844        // In Sorted case. We can mark all partitions besides last partition as ended.
845        // We are sure that those partitions will never receive any values.
846        // (Otherwise ordering invariant is violated.)
847        let n_partitions = partition_buffers.len();
848        for (idx, (_, partition_batch_state)) in partition_buffers.iter_mut().enumerate()
849        {
850            partition_batch_state.is_end |= idx < n_partitions - 1;
851        }
852    }
853
854    fn input_schema(&self) -> &SchemaRef {
855        &self.input_schema
856    }
857}
858
859impl SortedSearch {
860    /// Calculates how many rows we can output.
861    fn calculate_n_out_row(
862        &mut self,
863        window_agg_states: &[PartitionWindowAggStates],
864        partition_buffers: &mut PartitionBatches,
865    ) -> usize {
866        // Different window aggregators may produce results at different rates.
867        // We produce the overall batch result only as fast as the slowest one.
868        let mut counts = vec![];
869        let out_col_counts = window_agg_states.iter().map(|window_agg_state| {
870            // Store how many elements are generated for the current
871            // window expression:
872            let mut cur_window_expr_out_result_len = 0;
873            // We iterate over `window_agg_state`, which is an IndexMap.
874            // Iterations follow the insertion order, hence we preserve
875            // sorting when partition columns are sorted.
876            let mut per_partition_out_results = HashMap::new();
877            for (row, WindowState { state, .. }) in window_agg_state.iter() {
878                cur_window_expr_out_result_len += state.out_col.len();
879                let count = per_partition_out_results.entry(row).or_insert(0);
880                if *count < state.out_col.len() {
881                    *count = state.out_col.len();
882                }
883                // If we do not generate all results for the current
884                // partition, we do not generate results for next
885                // partition --  otherwise we will lose input ordering.
886                if state.n_row_result_missing > 0 {
887                    break;
888                }
889            }
890            counts.push(per_partition_out_results);
891            cur_window_expr_out_result_len
892        });
893        argmin(out_col_counts).map_or(0, |(min_idx, minima)| {
894            let mut slowest_partition = counts.swap_remove(min_idx);
895            for (partition_key, partition_batch) in partition_buffers.iter_mut() {
896                if let Some(count) = slowest_partition.remove(partition_key) {
897                    partition_batch.n_out_row = count;
898                }
899            }
900            minima
901        })
902    }
903}
904
905/// Calculates partition by expression results for each window expression
906/// on `record_batch`.
907fn evaluate_partition_by_column_values(
908    record_batch: &RecordBatch,
909    window_expr: &[Arc<dyn WindowExpr>],
910) -> Result<Vec<ArrayRef>> {
911    window_expr[0]
912        .partition_by()
913        .iter()
914        .map(|item| match item.evaluate(record_batch)? {
915            ColumnarValue::Array(array) => Ok(array),
916            ColumnarValue::Scalar(scalar) => {
917                scalar.to_array_of_size(record_batch.num_rows())
918            }
919        })
920        .collect()
921}
922
/// Stream for the bounded window aggregation plan.
pub struct BoundedWindowAggStream {
    /// Schema of the batches this stream produces; also returned by its
    /// `RecordBatchStream::schema` implementation.
    schema: SchemaRef,
    /// Upstream stream that is polled for new input batches.
    input: SendableRecordBatchStream,
    /// The record batch executor receives as input (i.e. the columns needed
    /// while calculating aggregation results).
    input_buffer: RecordBatch,
    /// We separate `input_buffer` based on partitions (as
    /// determined by PARTITION BY columns) and store them per partition
    /// in `partition_buffers`. We use this variable when calculating results
    /// for each window expression. This enables us to use the same batch for
    /// different window expressions without copying.
    // Note that we could keep record batches for each window expression in
    // `PartitionWindowAggStates`. However, this would use more memory (as
    // many times as the number of window expressions).
    partition_buffers: PartitionBatches,
    /// An executor can run multiple window expressions if the PARTITION BY
    /// and ORDER BY sections are same. We keep state of the each window
    /// expression inside `window_agg_states`.
    window_agg_states: Vec<PartitionWindowAggStates>,
    /// Set once the input stream is exhausted; every later poll returns
    /// `None` immediately.
    finished: bool,
    /// Window expressions evaluated by this stream, in the same order as
    /// `window_agg_states`.
    window_expr: Vec<Arc<dyn WindowExpr>>,
    /// Metrics handle used to record poll results and compute time.
    baseline_metrics: BaselineMetrics,
    /// Search mode for partition columns. This determines the algorithm with
    /// which we group each partition.
    search_mode: Box<dyn PartitionSearcher>,
}
950
impl BoundedWindowAggStream {
    /// Prunes sections of the state that are no longer needed when calculating
    /// results (as determined by window frame boundaries and number of results generated).
    ///
    /// `n_out` is the number of rows that have just been emitted.
    // For instance, if first `n` (not necessarily same with `n_out`) elements are no longer needed to
    // calculate window expression result (outside the window frame boundary) we retract first `n` elements
    // from `self.partition_buffers` in corresponding partition.
    // For instance, if `n_out` number of rows are calculated, we can remove
    // first `n_out` rows from `self.input_buffer`.
    fn prune_state(&mut self, n_out: usize) -> Result<()> {
        // Prune `self.window_agg_states`. This has to run first: it reads the
        // per-partition `n_out_row` counters, which `prune_partition_batches`
        // resets to zero.
        self.prune_out_columns();
        // Prune `self.partition_buffers`:
        self.prune_partition_batches();
        // Prune `self.input_buffer`:
        self.prune_input_batch(n_out)?;
        // Prune internal state of search algorithm.
        self.search_mode.prune(n_out);
        Ok(())
    }
}
971
972impl Stream for BoundedWindowAggStream {
973    type Item = Result<RecordBatch>;
974
975    fn poll_next(
976        mut self: Pin<&mut Self>,
977        cx: &mut Context<'_>,
978    ) -> Poll<Option<Self::Item>> {
979        let poll = self.poll_next_inner(cx);
980        self.baseline_metrics.record_poll(poll)
981    }
982}
983
984impl BoundedWindowAggStream {
985    /// Create a new BoundedWindowAggStream
986    fn new(
987        schema: SchemaRef,
988        window_expr: Vec<Arc<dyn WindowExpr>>,
989        input: SendableRecordBatchStream,
990        baseline_metrics: BaselineMetrics,
991        search_mode: Box<dyn PartitionSearcher>,
992    ) -> Result<Self> {
993        let state = window_expr.iter().map(|_| IndexMap::new()).collect();
994        let empty_batch = RecordBatch::new_empty(Arc::clone(&schema));
995        Ok(Self {
996            schema,
997            input,
998            input_buffer: empty_batch,
999            partition_buffers: IndexMap::new(),
1000            window_agg_states: state,
1001            finished: false,
1002            window_expr,
1003            baseline_metrics,
1004            search_mode,
1005        })
1006    }
1007
1008    fn compute_aggregates(&mut self) -> Result<Option<RecordBatch>> {
1009        // calculate window cols
1010        for (cur_window_expr, state) in
1011            self.window_expr.iter().zip(&mut self.window_agg_states)
1012        {
1013            cur_window_expr.evaluate_stateful(&self.partition_buffers, state)?;
1014        }
1015
1016        let schema = Arc::clone(&self.schema);
1017        let window_expr_out = self.search_mode.calculate_out_columns(
1018            &self.input_buffer,
1019            &self.window_agg_states,
1020            &mut self.partition_buffers,
1021            &self.window_expr,
1022        )?;
1023        if let Some(window_expr_out) = window_expr_out {
1024            let n_out = window_expr_out[0].len();
1025            // right append new columns to corresponding section in the original input buffer.
1026            let columns_to_show = self
1027                .input_buffer
1028                .columns()
1029                .iter()
1030                .map(|elem| elem.slice(0, n_out))
1031                .chain(window_expr_out)
1032                .collect::<Vec<_>>();
1033            let n_generated = columns_to_show[0].len();
1034            self.prune_state(n_generated)?;
1035            Ok(Some(RecordBatch::try_new(schema, columns_to_show)?))
1036        } else {
1037            Ok(None)
1038        }
1039    }
1040
1041    #[inline]
1042    fn poll_next_inner(
1043        &mut self,
1044        cx: &mut Context<'_>,
1045    ) -> Poll<Option<Result<RecordBatch>>> {
1046        if self.finished {
1047            return Poll::Ready(None);
1048        }
1049
1050        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
1051        match ready!(self.input.poll_next_unpin(cx)) {
1052            Some(Ok(batch)) => {
1053                // Start the timer for compute time within this operator. It will be
1054                // stopped when dropped.
1055                let _timer = elapsed_compute.timer();
1056
1057                self.search_mode.update_partition_batch(
1058                    &mut self.input_buffer,
1059                    batch,
1060                    &self.window_expr,
1061                    &mut self.partition_buffers,
1062                )?;
1063                if let Some(batch) = self.compute_aggregates()? {
1064                    return Poll::Ready(Some(Ok(batch)));
1065                }
1066                self.poll_next_inner(cx)
1067            }
1068            Some(Err(e)) => Poll::Ready(Some(Err(e))),
1069            None => {
1070                let _timer = elapsed_compute.timer();
1071
1072                self.finished = true;
1073                // Release the input pipeline's resources before computing the
1074                // final aggregates.
1075                let input_schema = self.input.schema();
1076                self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
1077                for (_, partition_batch_state) in self.partition_buffers.iter_mut() {
1078                    partition_batch_state.is_end = true;
1079                }
1080                if let Some(batch) = self.compute_aggregates()? {
1081                    return Poll::Ready(Some(Ok(batch)));
1082                }
1083                Poll::Ready(None)
1084            }
1085        }
1086    }
1087
1088    /// Prunes the sections of the record batch (for each partition)
1089    /// that we no longer need to calculate the window function result.
1090    fn prune_partition_batches(&mut self) {
1091        // Remove partitions which we know already ended (is_end flag is true).
1092        // Since the retain method preserves insertion order, we still have
1093        // ordering in between partitions after removal.
1094        self.partition_buffers
1095            .retain(|_, partition_batch_state| !partition_batch_state.is_end);
1096
1097        // The data in `self.partition_batches` is used by all window expressions.
1098        // Therefore, when removing from `self.partition_batches`, we need to remove
1099        // from the earliest range boundary among all window expressions. Variable
1100        // `n_prune_each_partition` fill the earliest range boundary information for
1101        // each partition. This way, we can delete the no-longer-needed sections from
1102        // `self.partition_batches`.
1103        // For instance, if window frame one uses [10, 20] and window frame two uses
1104        // [5, 15]; we only prune the first 5 elements from the corresponding record
1105        // batch in `self.partition_batches`.
1106
1107        // Calculate how many elements to prune for each partition batch
1108        let mut n_prune_each_partition = HashMap::new();
1109        for window_agg_state in self.window_agg_states.iter_mut() {
1110            window_agg_state.retain(|_, WindowState { state, .. }| !state.is_end);
1111            for (partition_row, WindowState { state: value, .. }) in window_agg_state {
1112                let n_prune =
1113                    min(value.window_frame_range.start, value.last_calculated_index);
1114                if let Some(current) = n_prune_each_partition.get_mut(partition_row) {
1115                    if n_prune < *current {
1116                        *current = n_prune;
1117                    }
1118                } else {
1119                    n_prune_each_partition.insert(partition_row.clone(), n_prune);
1120                }
1121            }
1122        }
1123
1124        // Retract no longer needed parts during window calculations from partition batch:
1125        for (partition_row, n_prune) in n_prune_each_partition.iter() {
1126            let pb_state = &mut self.partition_buffers[partition_row];
1127
1128            let batch = &pb_state.record_batch;
1129            pb_state.record_batch = batch.slice(*n_prune, batch.num_rows() - n_prune);
1130            pb_state.n_out_row = 0;
1131
1132            // Update state indices since we have pruned some rows from the beginning:
1133            for window_agg_state in self.window_agg_states.iter_mut() {
1134                window_agg_state[partition_row].state.prune_state(*n_prune);
1135            }
1136        }
1137    }
1138
1139    /// Prunes the section of the input batch whose aggregate results
1140    /// are calculated and emitted.
1141    fn prune_input_batch(&mut self, n_out: usize) -> Result<()> {
1142        // Prune first n_out rows from the input_buffer
1143        let n_to_keep = self.input_buffer.num_rows() - n_out;
1144        let batch_to_keep = self
1145            .input_buffer
1146            .columns()
1147            .iter()
1148            .map(|elem| elem.slice(n_out, n_to_keep))
1149            .collect::<Vec<_>>();
1150        self.input_buffer = RecordBatch::try_new_with_options(
1151            self.input_buffer.schema(),
1152            batch_to_keep,
1153            &RecordBatchOptions::new().with_row_count(Some(n_to_keep)),
1154        )?;
1155        Ok(())
1156    }
1157
1158    /// Prunes emitted parts from WindowAggState `out_col` field.
1159    fn prune_out_columns(&mut self) {
1160        // We store generated columns for each window expression in the `out_col`
1161        // field of `WindowAggState`. Given how many rows are emitted, we remove
1162        // these sections from state.
1163        for partition_window_agg_states in self.window_agg_states.iter_mut() {
1164            // Remove `n_out` entries from the `out_col` field of `WindowAggState`.
1165            // `n_out` is stored in `self.partition_buffers` for each partition.
1166            // If `is_end` is set, directly remove them; this shrinks the hash map.
1167            partition_window_agg_states
1168                .retain(|_, partition_batch_state| !partition_batch_state.state.is_end);
1169            for (
1170                partition_key,
1171                WindowState {
1172                    state: WindowAggState { out_col, .. },
1173                    ..
1174                },
1175            ) in partition_window_agg_states
1176            {
1177                let partition_batch = &mut self.partition_buffers[partition_key];
1178                let n_to_del = partition_batch.n_out_row;
1179                let n_to_keep = out_col.len() - n_to_del;
1180                *out_col = out_col.slice(n_to_del, n_to_keep);
1181            }
1182        }
1183    }
1184}
1185
impl RecordBatchStream for BoundedWindowAggStream {
    /// Returns a new handle to the schema stored on this stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
1192
/// Returns the position and value of the smallest item of `data`, or `None`
/// when `data` is empty.
///
/// The earliest item wins ties. Items that cannot be compared
/// (`partial_cmp` returns `None`) are treated as equal, so the current
/// candidate is kept.
fn argmin<T: PartialOrd>(data: impl Iterator<Item = T>) -> Option<(usize, T)> {
    data.enumerate().reduce(|best, candidate| {
        // Replace the current best only when it is strictly greater.
        if best.1.partial_cmp(&candidate.1) == Some(Ordering::Greater) {
            candidate
        } else {
            best
        }
    })
}
1198
1199/// Calculates the section we can show results for expression
1200fn get_aggregate_result_out_column(
1201    partition_window_agg_states: &PartitionWindowAggStates,
1202    len_to_show: usize,
1203) -> Result<ArrayRef> {
1204    let mut result = None;
1205    let mut running_length = 0;
1206    let mut batches_to_concat = vec![];
1207    // We assume that iteration order is according to insertion order
1208    for (
1209        _,
1210        WindowState {
1211            state: WindowAggState { out_col, .. },
1212            ..
1213        },
1214    ) in partition_window_agg_states
1215    {
1216        if running_length < len_to_show {
1217            let n_to_use = min(len_to_show - running_length, out_col.len());
1218            let slice_to_use = if n_to_use == out_col.len() {
1219                // avoid slice when the entire column is used
1220                Arc::clone(out_col)
1221            } else {
1222                out_col.slice(0, n_to_use)
1223            };
1224            batches_to_concat.push(slice_to_use);
1225            running_length += n_to_use;
1226        } else {
1227            break;
1228        }
1229    }
1230
1231    if !batches_to_concat.is_empty() {
1232        let array_refs: Vec<&dyn Array> =
1233            batches_to_concat.iter().map(|a| a.as_ref()).collect();
1234        result = Some(concat(&array_refs)?);
1235    }
1236
1237    if running_length != len_to_show {
1238        return exec_err!(
1239            "Generated row number should be {len_to_show}, it is {running_length}"
1240        );
1241    }
1242    result.ok_or_else(|| exec_datafusion_err!("Should contain something"))
1243}
1244
1245/// Constructs a batch from the last row of batch in the argument.
1246pub(crate) fn get_last_row_batch(batch: &RecordBatch) -> Result<RecordBatch> {
1247    if batch.num_rows() == 0 {
1248        return exec_err!("Latest batch should have at least 1 row");
1249    }
1250    Ok(batch.slice(batch.num_rows() - 1, 1))
1251}
1252
1253#[cfg(test)]
1254mod tests {
1255    use std::pin::Pin;
1256    use std::sync::Arc;
1257    use std::task::{Context, Poll};
1258    use std::time::Duration;
1259
1260    use crate::common::collect;
1261    use crate::execution_plan::CardinalityEffect;
1262    use crate::expressions::PhysicalSortExpr;
1263    use crate::projection::{ProjectionExec, ProjectionExpr};
1264    use crate::streaming::{PartitionStream, StreamingTableExec};
1265    use crate::test::TestMemoryExec;
1266    use crate::windows::{
1267        BoundedWindowAggExec, InputOrderMode, create_udwf_window_expr, create_window_expr,
1268    };
1269    use crate::{ExecutionPlan, displayable, execute_stream};
1270
1271    use arrow::array::{
1272        RecordBatch,
1273        builder::{Int64Builder, UInt64Builder},
1274    };
1275    use arrow::compute::SortOptions;
1276    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1277    use datafusion_common::test_util::batches_to_string;
1278    use datafusion_common::{Result, ScalarValue, exec_datafusion_err};
1279    use datafusion_execution::config::SessionConfig;
1280    use datafusion_execution::{
1281        RecordBatchStream, SendableRecordBatchStream, TaskContext,
1282    };
1283    use datafusion_expr::{
1284        WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
1285    };
1286    use datafusion_functions_aggregate::count::count_udaf;
1287    use datafusion_functions_window::nth_value::last_value_udwf;
1288    use datafusion_functions_window::nth_value::nth_value_udwf;
1289    use datafusion_physical_expr::expressions::{Column, Literal, col};
1290    use datafusion_physical_expr::window::StandardWindowExpr;
1291    use datafusion_physical_expr::{LexOrdering, PhysicalExpr};
1292
1293    use futures::future::Shared;
1294    use futures::{FutureExt, Stream, StreamExt, pin_mut, ready};
1295    use insta::assert_snapshot;
1296    use itertools::Itertools;
1297    use tokio::time::timeout;
1298
/// Test input source: a cloneable stream that replays `batches` one at a
/// time, waiting `sleep_duration` before each new batch.
#[derive(Debug, Clone)]
struct TestStreamPartition {
    /// Schema reported by the stream.
    schema: SchemaRef,
    /// Batches handed out in order.
    batches: Vec<RecordBatch>,
    /// Index into `batches` of the next batch to return.
    idx: usize,
    /// Whether the next poll returns a batch or waits on a sleep future.
    state: PolingState,
    /// Length of the sleep scheduled each time a batch is requested.
    sleep_duration: Duration,
    /// When `true`, the stream ends (yields `None`) once `batches` is
    /// exhausted; otherwise it keeps sleeping and never completes.
    send_exit: bool,
}
1308
impl PartitionStream for TestStreamPartition {
    /// Returns the schema stored on this partition.
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Starts an execution of this partition; the task context is unused.
    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
        // `TestStreamPartition` is itself the stream: every execution gets
        // its own clone (batches, cursor and polling state), boxed and pinned.
        Box::pin(self.clone())
    }
}
1320
impl Stream for TestStreamPartition {
    type Item = Result<RecordBatch>;

    /// Forwards to `poll_next_inner`, which implements the sleep/return
    /// state machine.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.poll_next_inner(cx)
    }
}
1331
/// Polling state of `TestStreamPartition`.
#[derive(Debug, Clone)]
enum PolingState {
    /// Waiting for the contained sleep future to complete; once it does the
    /// stream switches to `BatchReturn`.
    Sleep(Shared<futures::future::BoxFuture<'static, ()>>),
    /// The next poll tries to hand out the next batch.
    BatchReturn,
}
1337
1338    impl TestStreamPartition {
1339        fn poll_next_inner(
1340            self: &mut Pin<&mut Self>,
1341            cx: &mut Context<'_>,
1342        ) -> Poll<Option<Result<RecordBatch>>> {
1343            loop {
1344                match &mut self.state {
1345                    PolingState::BatchReturn => {
1346                        // Wait for self.sleep_duration before sending any new data
1347                        let f = tokio::time::sleep(self.sleep_duration).boxed().shared();
1348                        self.state = PolingState::Sleep(f);
1349                        let input_batch = if let Some(batch) =
1350                            self.batches.clone().get(self.idx)
1351                        {
1352                            batch.clone()
1353                        } else if self.send_exit {
1354                            // Send None to signal end of data
1355                            return Poll::Ready(None);
1356                        } else {
1357                            // Go to sleep mode
1358                            let f =
1359                                tokio::time::sleep(self.sleep_duration).boxed().shared();
1360                            self.state = PolingState::Sleep(f);
1361                            continue;
1362                        };
1363                        self.idx += 1;
1364                        return Poll::Ready(Some(Ok(input_batch)));
1365                    }
1366                    PolingState::Sleep(future) => {
1367                        pin_mut!(future);
1368                        ready!(future.poll_unpin(cx));
1369                        self.state = PolingState::BatchReturn;
1370                    }
1371                }
1372            }
1373        }
1374    }
1375
impl RecordBatchStream for TestStreamPartition {
    /// Returns a new handle to the schema stored on this stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
1381
1382    fn bounded_window_exec_pb_latent_range(
1383        input: Arc<dyn ExecutionPlan>,
1384        n_future_range: usize,
1385        hash: &str,
1386        order_by: &str,
1387    ) -> Result<Arc<dyn ExecutionPlan>> {
1388        let schema = input.schema();
1389        let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf());
1390        let col_expr =
1391            Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc<dyn PhysicalExpr>;
1392        let args = vec![col_expr];
1393        let partitionby_exprs = vec![col(hash, &schema)?];
1394        let orderby_exprs = vec![PhysicalSortExpr {
1395            expr: col(order_by, &schema)?,
1396            options: SortOptions::default(),
1397        }];
1398        let window_frame = WindowFrame::new_bounds(
1399            WindowFrameUnits::Range,
1400            WindowFrameBound::CurrentRow,
1401            WindowFrameBound::Following(ScalarValue::UInt64(Some(n_future_range as u64))),
1402        );
1403        let fn_name = format!(
1404            "{window_fn}({args:?}) PARTITION BY: [{partitionby_exprs:?}], ORDER BY: [{orderby_exprs:?}]"
1405        );
1406        let input_order_mode = InputOrderMode::Linear;
1407        Ok(Arc::new(BoundedWindowAggExec::try_new(
1408            vec![create_window_expr(
1409                &window_fn,
1410                fn_name,
1411                &args,
1412                &partitionby_exprs,
1413                &orderby_exprs,
1414                Arc::new(window_frame),
1415                input.schema(),
1416                false,
1417                false,
1418                None,
1419            )?],
1420            input,
1421            input_order_mode,
1422            true,
1423        )?))
1424    }
1425
1426    fn projection_exec(input: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
1427        let schema = input.schema();
1428        let exprs = input
1429            .schema()
1430            .fields
1431            .iter()
1432            .enumerate()
1433            .map(|(idx, field)| {
1434                let name = if field.name().len() > 20 {
1435                    format!("col_{idx}")
1436                } else {
1437                    field.name().clone()
1438                };
1439                let expr = col(field.name(), &schema).unwrap();
1440                (expr, name)
1441            })
1442            .collect::<Vec<_>>();
1443        let proj_exprs: Vec<ProjectionExpr> = exprs
1444            .into_iter()
1445            .map(|(expr, alias)| ProjectionExpr { expr, alias })
1446            .collect();
1447        Ok(Arc::new(ProjectionExec::try_new(proj_exprs, input)?))
1448    }
1449
1450    fn task_context_helper() -> TaskContext {
1451        let task_ctx = TaskContext::default();
1452        // Create session context with config
1453        let session_config = SessionConfig::new()
1454            .with_batch_size(1)
1455            .with_target_partitions(2)
1456            .with_round_robin_repartition(false);
1457        task_ctx.with_session_config(session_config)
1458    }
1459
/// Returns the context built by `task_context_helper`, wrapped in an `Arc`.
fn task_context() -> Arc<TaskContext> {
    Arc::new(task_context_helper())
}
1463
1464    pub async fn collect_stream(
1465        mut stream: SendableRecordBatchStream,
1466        results: &mut Vec<RecordBatch>,
1467    ) -> Result<()> {
1468        while let Some(item) = stream.next().await {
1469            results.push(item?);
1470        }
1471        Ok(())
1472    }
1473
1474    /// Execute the [ExecutionPlan] and collect the results in memory
1475    pub async fn collect_with_timeout(
1476        plan: Arc<dyn ExecutionPlan>,
1477        context: Arc<TaskContext>,
1478        timeout_duration: Duration,
1479    ) -> Result<Vec<RecordBatch>> {
1480        let stream = execute_stream(plan, context)?;
1481        let mut results = vec![];
1482
1483        // Execute the asynchronous operation with a timeout
1484        if timeout(timeout_duration, collect_stream(stream, &mut results))
1485            .await
1486            .is_ok()
1487        {
1488            return Err(exec_datafusion_err!("shouldn't have completed"));
1489        };
1490
1491        Ok(results)
1492    }
1493
1494    fn test_schema() -> SchemaRef {
1495        Arc::new(Schema::new(vec![
1496            Field::new("sn", DataType::UInt64, true),
1497            Field::new("hash", DataType::Int64, true),
1498        ]))
1499    }
1500
1501    fn schema_orders(schema: &SchemaRef) -> Result<Vec<LexOrdering>> {
1502        let orderings = vec![
1503            [PhysicalSortExpr {
1504                expr: col("sn", schema)?,
1505                options: SortOptions {
1506                    descending: false,
1507                    nulls_first: false,
1508                },
1509            }]
1510            .into(),
1511        ];
1512        Ok(orderings)
1513    }
1514
1515    fn is_integer_division_safe(lhs: usize, rhs: usize) -> bool {
1516        let res = lhs / rhs;
1517        res * rhs == lhs
1518    }
1519    fn generate_batches(
1520        schema: &SchemaRef,
1521        n_row: usize,
1522        n_chunk: usize,
1523    ) -> Result<Vec<RecordBatch>> {
1524        let mut batches = vec![];
1525        assert!(n_row > 0);
1526        assert!(n_chunk > 0);
1527        assert!(is_integer_division_safe(n_row, n_chunk));
1528        let hash_replicate = 4;
1529
1530        let chunks = (0..n_row)
1531            .chunks(n_chunk)
1532            .into_iter()
1533            .map(|elem| elem.into_iter().collect::<Vec<_>>())
1534            .collect::<Vec<_>>();
1535
1536        // Send 2 RecordBatches at the source
1537        for sn_values in chunks {
1538            let mut sn1_array = UInt64Builder::with_capacity(sn_values.len());
1539            let mut hash_array = Int64Builder::with_capacity(sn_values.len());
1540
1541            for sn in sn_values {
1542                sn1_array.append_value(sn as u64);
1543                let hash_value = (2 - (sn / hash_replicate)) as i64;
1544                hash_array.append_value(hash_value);
1545            }
1546
1547            let batch = RecordBatch::try_new(
1548                Arc::clone(schema),
1549                vec![Arc::new(sn1_array.finish()), Arc::new(hash_array.finish())],
1550            )?;
1551            batches.push(batch);
1552        }
1553        Ok(batches)
1554    }
1555
1556    fn generate_never_ending_source(
1557        n_rows: usize,
1558        chunk_length: usize,
1559        n_partition: usize,
1560        is_infinite: bool,
1561        send_exit: bool,
1562        per_batch_wait_duration_in_millis: u64,
1563    ) -> Result<Arc<dyn ExecutionPlan>> {
1564        assert!(n_partition > 0);
1565
1566        // We use same hash value in the table. This makes sure that
1567        // After hashing computation will continue in only in one of the output partitions
1568        // In this case, data flow should still continue
1569        let schema = test_schema();
1570        let orderings = schema_orders(&schema)?;
1571
1572        // Source waits per_batch_wait_duration_in_millis ms before sending other batch
1573        let per_batch_wait_duration =
1574            Duration::from_millis(per_batch_wait_duration_in_millis);
1575
1576        let batches = generate_batches(&schema, n_rows, chunk_length)?;
1577
1578        // Source has 2 partitions
1579        let partitions = vec![
1580            Arc::new(TestStreamPartition {
1581                schema: Arc::clone(&schema),
1582                batches,
1583                idx: 0,
1584                state: PolingState::BatchReturn,
1585                sleep_duration: per_batch_wait_duration,
1586                send_exit,
1587            }) as _;
1588            n_partition
1589        ];
1590        let source = Arc::new(StreamingTableExec::try_new(
1591            Arc::clone(&schema),
1592            partitions,
1593            None,
1594            orderings,
1595            is_infinite,
1596            None,
1597        )?) as _;
1598        Ok(source)
1599    }
1600
1601    // Tests NTH_VALUE(negative index) with memoize feature
1602    // To be able to trigger memoize feature for NTH_VALUE we need to
1603    // - feed BoundedWindowAggExec with batch stream data.
1604    // - Window frame should contain UNBOUNDED PRECEDING.
1605    // It hard to ensure these conditions are met, from the sql query.
1606    #[tokio::test]
1607    async fn test_window_nth_value_bounded_memoize() -> Result<()> {
1608        let config = SessionConfig::new().with_target_partitions(1);
1609        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
1610
1611        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1612        // Create a new batch of data to insert into the table
1613        let batch = RecordBatch::try_new(
1614            Arc::clone(&schema),
1615            vec![Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3]))],
1616        )?;
1617
1618        let memory_exec = TestMemoryExec::try_new_exec(
1619            &[vec![batch.clone(), batch.clone(), batch.clone()]],
1620            Arc::clone(&schema),
1621            None,
1622        )?;
1623        let col_a = col("a", &schema)?;
1624        let nth_value_func1 = create_udwf_window_expr(
1625            &nth_value_udwf(),
1626            &[
1627                Arc::clone(&col_a),
1628                Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
1629            ],
1630            &schema,
1631            "nth_value(-1)".to_string(),
1632            false,
1633        )?
1634        .reverse_expr()
1635        .unwrap();
1636        let nth_value_func2 = create_udwf_window_expr(
1637            &nth_value_udwf(),
1638            &[
1639                Arc::clone(&col_a),
1640                Arc::new(Literal::new(ScalarValue::Int32(Some(2)))),
1641            ],
1642            &schema,
1643            "nth_value(-2)".to_string(),
1644            false,
1645        )?
1646        .reverse_expr()
1647        .unwrap();
1648
1649        let last_value_func = create_udwf_window_expr(
1650            &last_value_udwf(),
1651            &[Arc::clone(&col_a)],
1652            &schema,
1653            "last".to_string(),
1654            false,
1655        )?;
1656
1657        let window_exprs = vec![
1658            // LAST_VALUE(a)
1659            Arc::new(StandardWindowExpr::new(
1660                last_value_func,
1661                &[],
1662                &[],
1663                Arc::new(WindowFrame::new_bounds(
1664                    WindowFrameUnits::Rows,
1665                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1666                    WindowFrameBound::CurrentRow,
1667                )),
1668            )) as _,
1669            // NTH_VALUE(a, -1)
1670            Arc::new(StandardWindowExpr::new(
1671                nth_value_func1,
1672                &[],
1673                &[],
1674                Arc::new(WindowFrame::new_bounds(
1675                    WindowFrameUnits::Rows,
1676                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1677                    WindowFrameBound::CurrentRow,
1678                )),
1679            )) as _,
1680            // NTH_VALUE(a, -2)
1681            Arc::new(StandardWindowExpr::new(
1682                nth_value_func2,
1683                &[],
1684                &[],
1685                Arc::new(WindowFrame::new_bounds(
1686                    WindowFrameUnits::Rows,
1687                    WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1688                    WindowFrameBound::CurrentRow,
1689                )),
1690            )) as _,
1691        ];
1692        let physical_plan = BoundedWindowAggExec::try_new(
1693            window_exprs,
1694            memory_exec,
1695            InputOrderMode::Sorted,
1696            true,
1697        )
1698        .map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)?;
1699
1700        let batches = collect(physical_plan.execute(0, task_ctx)?).await?;
1701
1702        // Get string representation of the plan
1703        assert_snapshot!(displayable(physical_plan.as_ref()).indent(true), @r#"
1704        BoundedWindowAggExec: wdw=[last: Field { "last": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, nth_value(-1): Field { "nth_value(-1)": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, nth_value(-2): Field { "nth_value(-2)": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
1705          DataSourceExec: partitions=1, partition_sizes=[3]
1706        "#);
1707
1708        assert_snapshot!(batches_to_string(&batches), @r"
1709        +---+------+---------------+---------------+
1710        | a | last | nth_value(-1) | nth_value(-2) |
1711        +---+------+---------------+---------------+
1712        | 1 | 1    | 1             |               |
1713        | 2 | 2    | 2             | 1             |
1714        | 3 | 3    | 3             | 2             |
1715        | 1 | 1    | 1             | 3             |
1716        | 2 | 2    | 2             | 1             |
1717        | 3 | 3    | 3             | 2             |
1718        | 1 | 1    | 1             | 3             |
1719        | 2 | 2    | 2             | 1             |
1720        | 3 | 3    | 3             | 2             |
1721        +---+------+---------------+---------------+
1722        ");
1723        Ok(())
1724    }
1725
    // This test checks whether the most-recent-row guarantee provided by the input batches of
    // `BoundedWindowAggExec` helps it generate low-latency results in the `Linear` mode.
1728    // Input data generated at the source is
1729    //       "+----+------+",
1730    //       "| sn | hash |",
1731    //       "+----+------+",
1732    //       "| 0  | 2    |",
1733    //       "| 1  | 2    |",
1734    //       "| 2  | 2    |",
1735    //       "| 3  | 2    |",
1736    //       "| 4  | 1    |",
1737    //       "| 5  | 1    |",
1738    //       "| 6  | 1    |",
1739    //       "| 7  | 1    |",
1740    //       "| 8  | 0    |",
1741    //       "| 9  | 0    |",
1742    //       "+----+------+",
1743    //
1744    // Effectively following query is run on this data
1745    //
1746    //   SELECT *, count(*) OVER(PARTITION BY duplicated_hash ORDER BY sn RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)
1747    //   FROM test;
1748    //
1749    // partition `duplicated_hash=2` receives following data from the input
1750    //
1751    //       "+----+------+",
1752    //       "| sn | hash |",
1753    //       "+----+------+",
1754    //       "| 0  | 2    |",
1755    //       "| 1  | 2    |",
1756    //       "| 2  | 2    |",
1757    //       "| 3  | 2    |",
1758    //       "+----+------+",
1759    // normally `BoundedWindowExec` can only generate following result from the input above
1760    //
1761    //       "+----+------+---------+",
1762    //       "| sn | hash |  count  |",
1763    //       "+----+------+---------+",
1764    //       "| 0  | 2    |  2      |",
1765    //       "| 1  | 2    |  2      |",
1766    //       "| 2  | 2    |<not yet>|",
1767    //       "| 3  | 2    |<not yet>|",
1768    //       "+----+------+---------+",
    // where the results of the last 2 rows are missing, because the window frame end may still
    // change with future data: it is determined by `1 FOLLOWING` (to generate the result for
    // row=3 [where sn=2] we need to receive sn=4 to make sure the frame end bound won't change).
1772    //
1773    // With the ability of different partitions to use global ordering at the input (where most up-to date
1774    //   row is
1775    //      "| 9  | 0    |",
1776    //   )
1777    //
1778    // `BoundedWindowExec` should be able to generate following result in the test
1779    //
1780    //       "+----+------+-------+",
1781    //       "| sn | hash | col_2 |",
1782    //       "+----+------+-------+",
1783    //       "| 0  | 2    | 2     |",
1784    //       "| 1  | 2    | 2     |",
1785    //       "| 2  | 2    | 2     |",
1786    //       "| 3  | 2    | 1     |",
1787    //       "| 4  | 1    | 2     |",
1788    //       "| 5  | 1    | 2     |",
1789    //       "| 6  | 1    | 2     |",
1790    //       "| 7  | 1    | 1     |",
1791    //       "+----+------+-------+",
1792    //
    // where the result for all rows except the last 2 is calculated (to calculate the result for
    //   row 9, where sn=8, we need to receive the sn=10 value).
    // In this test, our aim is to test for which portion of the input data `BoundedWindowExec` can
    // generate a result. To test this behaviour, we generate the data at the source infinitely
    //    (no `None` signal is sent to the output from the source). After the row
    //
    //       "| 9  | 0    |",
    //
    // is sent, the source stops sending data to the output. We collect the results emitted by the
    // `BoundedWindowExec` at the end of the pipeline with a timeout (since no `None` is sent from
    // the source, collection would never end otherwise).
1803    #[tokio::test]
1804    async fn bounded_window_exec_linear_mode_range_information() -> Result<()> {
1805        let n_rows = 10;
1806        let chunk_length = 2;
1807        let n_future_range = 1;
1808
1809        let timeout_duration = Duration::from_millis(2000);
1810
1811        let source =
1812            generate_never_ending_source(n_rows, chunk_length, 1, true, false, 5)?;
1813
1814        let window =
1815            bounded_window_exec_pb_latent_range(source, n_future_range, "hash", "sn")?;
1816
1817        let plan = projection_exec(window)?;
1818
1819        // Get string representation of the plan
1820        assert_snapshot!(displayable(plan.as_ref()).indent(true), @r#"
1821        ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: "sn", index: 0 }]) PARTITION BY: [[Column { name: "hash", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: "sn", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]@2 as col_2]
1822          BoundedWindowAggExec: wdw=[count([Column { name: "sn", index: 0 }]) PARTITION BY: [[Column { name: "hash", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: "sn", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]: Field { "count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]": Int64 }, frame: RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING], mode=[Linear]
1823            StreamingTableExec: partition_sizes=1, projection=[sn, hash], infinite_source=true, output_ordering=[sn@0 ASC NULLS LAST]
1824        "#);
1825
1826        let task_ctx = task_context();
1827        let batches = collect_with_timeout(plan, task_ctx, timeout_duration).await?;
1828
1829        assert_snapshot!(batches_to_string(&batches), @r"
1830        +----+------+-------+
1831        | sn | hash | col_2 |
1832        +----+------+-------+
1833        | 0  | 2    | 2     |
1834        | 1  | 2    | 2     |
1835        | 2  | 2    | 2     |
1836        | 3  | 2    | 1     |
1837        | 4  | 1    | 2     |
1838        | 5  | 1    | 2     |
1839        | 6  | 1    | 2     |
1840        | 7  | 1    | 1     |
1841        +----+------+-------+
1842        ");
1843
1844        Ok(())
1845    }
1846
1847    #[test]
1848    fn test_bounded_window_agg_cardinality_effect() -> Result<()> {
1849        let schema = test_schema();
1850        let input: Arc<dyn ExecutionPlan> =
1851            Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?);
1852        let plan = bounded_window_exec_pb_latent_range(input, 1, "hash", "sn")?;
1853        let plan = plan
1854            .downcast_ref::<BoundedWindowAggExec>()
1855            .expect("expected BoundedWindowAggExec");
1856
1857        assert!(matches!(
1858            plan.cardinality_effect(),
1859            CardinalityEffect::Equal
1860        ));
1861        Ok(())
1862    }
1863}