datafusion_physical_plan/windows/
window_agg_exec.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Stream and channel implementations for window function expressions.
19
20use std::any::Any;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
25use super::utils::create_schema;
26use crate::execution_plan::EmissionType;
27use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
28use crate::windows::{
29    calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs,
30    window_equivalence_properties,
31};
32use crate::{
33    ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
34    ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream,
35    SendableRecordBatchStream, Statistics, WindowExpr,
36};
37
38use arrow::array::ArrayRef;
39use arrow::compute::{concat, concat_batches};
40use arrow::datatypes::SchemaRef;
41use arrow::error::ArrowError;
42use arrow::record_batch::RecordBatch;
43use datafusion_common::stats::Precision;
44use datafusion_common::utils::{evaluate_partition_ranges, transpose};
45use datafusion_common::{internal_err, Result};
46use datafusion_execution::TaskContext;
47use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
48
49use futures::{ready, Stream, StreamExt};
50
51/// Window execution plan
52#[derive(Debug, Clone)]
53pub struct WindowAggExec {
54    /// Input plan
55    pub(crate) input: Arc<dyn ExecutionPlan>,
56    /// Window function expression
57    window_expr: Vec<Arc<dyn WindowExpr>>,
58    /// Schema after the window is run
59    schema: SchemaRef,
60    /// Execution metrics
61    metrics: ExecutionPlanMetricsSet,
62    /// Partition by indices that defines preset for existing ordering
63    // see `get_ordered_partition_by_indices` for more details.
64    ordered_partition_by_indices: Vec<usize>,
65    /// Cache holding plan properties like equivalences, output partitioning etc.
66    cache: PlanProperties,
67    /// If `can_partition` is false, partition_keys is always empty.
68    can_repartition: bool,
69}
70
71impl WindowAggExec {
72    /// Create a new execution plan for window aggregates
73    pub fn try_new(
74        window_expr: Vec<Arc<dyn WindowExpr>>,
75        input: Arc<dyn ExecutionPlan>,
76        can_repartition: bool,
77    ) -> Result<Self> {
78        let schema = create_schema(&input.schema(), &window_expr)?;
79        let schema = Arc::new(schema);
80
81        let ordered_partition_by_indices =
82            get_ordered_partition_by_indices(window_expr[0].partition_by(), &input);
83        let cache = Self::compute_properties(Arc::clone(&schema), &input, &window_expr);
84        Ok(Self {
85            input,
86            window_expr,
87            schema,
88            metrics: ExecutionPlanMetricsSet::new(),
89            ordered_partition_by_indices,
90            cache,
91            can_repartition,
92        })
93    }
94
95    /// Window expressions
96    pub fn window_expr(&self) -> &[Arc<dyn WindowExpr>] {
97        &self.window_expr
98    }
99
100    /// Input plan
101    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
102        &self.input
103    }
104
105    /// Return the output sort order of partition keys: For example
106    /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a
107    // We are sure that partition by columns are always at the beginning of sort_keys
108    // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely
109    // to calculate partition separation points
110    pub fn partition_by_sort_keys(&self) -> Result<LexOrdering> {
111        let partition_by = self.window_expr()[0].partition_by();
112        get_partition_by_sort_exprs(
113            &self.input,
114            partition_by,
115            &self.ordered_partition_by_indices,
116        )
117    }
118
119    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
120    fn compute_properties(
121        schema: SchemaRef,
122        input: &Arc<dyn ExecutionPlan>,
123        window_exprs: &[Arc<dyn WindowExpr>],
124    ) -> PlanProperties {
125        // Calculate equivalence properties:
126        let eq_properties = window_equivalence_properties(&schema, input, window_exprs);
127
128        // Get output partitioning:
129        // Because we can have repartitioning using the partition keys this
130        // would be either 1 or more than 1 depending on the presence of repartitioning.
131        let output_partitioning = input.output_partitioning().clone();
132
133        // Construct properties cache:
134        PlanProperties::new(
135            eq_properties,
136            output_partitioning,
137            // TODO: Emission type and boundedness information can be enhanced here
138            EmissionType::Final,
139            input.boundedness(),
140        )
141    }
142
143    pub fn partition_keys(&self) -> Vec<Arc<dyn PhysicalExpr>> {
144        if !self.can_repartition {
145            vec![]
146        } else {
147            let all_partition_keys = self
148                .window_expr()
149                .iter()
150                .map(|expr| expr.partition_by().to_vec())
151                .collect::<Vec<_>>();
152
153            all_partition_keys
154                .into_iter()
155                .min_by_key(|s| s.len())
156                .unwrap_or_else(Vec::new)
157        }
158    }
159
160    fn statistics_inner(&self) -> Result<Statistics> {
161        let input_stat = self.input.partition_statistics(None)?;
162        let win_cols = self.window_expr.len();
163        let input_cols = self.input.schema().fields().len();
164        // TODO stats: some windowing function will maintain invariants such as min, max...
165        let mut column_statistics = Vec::with_capacity(win_cols + input_cols);
166        // copy stats of the input to the beginning of the schema.
167        column_statistics.extend(input_stat.column_statistics);
168        for _ in 0..win_cols {
169            column_statistics.push(ColumnStatistics::new_unknown())
170        }
171        Ok(Statistics {
172            num_rows: input_stat.num_rows,
173            column_statistics,
174            total_byte_size: Precision::Absent,
175        })
176    }
177}
178
179impl DisplayAs for WindowAggExec {
180    fn fmt_as(
181        &self,
182        t: DisplayFormatType,
183        f: &mut std::fmt::Formatter,
184    ) -> std::fmt::Result {
185        match t {
186            DisplayFormatType::Default | DisplayFormatType::Verbose => {
187                write!(f, "WindowAggExec: ")?;
188                let g: Vec<String> = self
189                    .window_expr
190                    .iter()
191                    .map(|e| {
192                        format!(
193                            "{}: {:?}, frame: {:?}",
194                            e.name().to_owned(),
195                            e.field(),
196                            e.get_window_frame()
197                        )
198                    })
199                    .collect();
200                write!(f, "wdw=[{}]", g.join(", "))?;
201            }
202            DisplayFormatType::TreeRender => {
203                let g: Vec<String> = self
204                    .window_expr
205                    .iter()
206                    .map(|e| e.name().to_owned().to_string())
207                    .collect();
208                writeln!(f, "select_list={}", g.join(", "))?;
209            }
210        }
211        Ok(())
212    }
213}
214
215impl ExecutionPlan for WindowAggExec {
216    fn name(&self) -> &'static str {
217        "WindowAggExec"
218    }
219
220    /// Return a reference to Any that can be used for downcasting
221    fn as_any(&self) -> &dyn Any {
222        self
223    }
224
225    fn properties(&self) -> &PlanProperties {
226        &self.cache
227    }
228
229    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
230        vec![&self.input]
231    }
232
233    fn maintains_input_order(&self) -> Vec<bool> {
234        vec![true]
235    }
236
237    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
238        let partition_bys = self.window_expr()[0].partition_by();
239        let order_keys = self.window_expr()[0].order_by();
240        if self.ordered_partition_by_indices.len() < partition_bys.len() {
241            vec![calc_requirements(partition_bys, order_keys.iter())]
242        } else {
243            let partition_bys = self
244                .ordered_partition_by_indices
245                .iter()
246                .map(|idx| &partition_bys[*idx]);
247            vec![calc_requirements(partition_bys, order_keys.iter())]
248        }
249    }
250
251    fn required_input_distribution(&self) -> Vec<Distribution> {
252        if self.partition_keys().is_empty() {
253            vec![Distribution::SinglePartition]
254        } else {
255            vec![Distribution::HashPartitioned(self.partition_keys())]
256        }
257    }
258
259    fn with_new_children(
260        self: Arc<Self>,
261        children: Vec<Arc<dyn ExecutionPlan>>,
262    ) -> Result<Arc<dyn ExecutionPlan>> {
263        Ok(Arc::new(WindowAggExec::try_new(
264            self.window_expr.clone(),
265            Arc::clone(&children[0]),
266            true,
267        )?))
268    }
269
270    fn execute(
271        &self,
272        partition: usize,
273        context: Arc<TaskContext>,
274    ) -> Result<SendableRecordBatchStream> {
275        let input = self.input.execute(partition, context)?;
276        let stream = Box::pin(WindowAggStream::new(
277            Arc::clone(&self.schema),
278            self.window_expr.clone(),
279            input,
280            BaselineMetrics::new(&self.metrics, partition),
281            self.partition_by_sort_keys()?,
282            self.ordered_partition_by_indices.clone(),
283        )?);
284        Ok(stream)
285    }
286
287    fn metrics(&self) -> Option<MetricsSet> {
288        Some(self.metrics.clone_inner())
289    }
290
291    fn statistics(&self) -> Result<Statistics> {
292        self.statistics_inner()
293    }
294
295    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
296        if partition.is_none() {
297            self.statistics_inner()
298        } else {
299            Ok(Statistics::new_unknown(&self.schema()))
300        }
301    }
302}
303
304/// Compute the window aggregate columns
305fn compute_window_aggregates(
306    window_expr: &[Arc<dyn WindowExpr>],
307    batch: &RecordBatch,
308) -> Result<Vec<ArrayRef>> {
309    window_expr
310        .iter()
311        .map(|window_expr| window_expr.evaluate(batch))
312        .collect()
313}
314
315/// stream for window aggregation plan
316pub struct WindowAggStream {
317    schema: SchemaRef,
318    input: SendableRecordBatchStream,
319    batches: Vec<RecordBatch>,
320    finished: bool,
321    window_expr: Vec<Arc<dyn WindowExpr>>,
322    partition_by_sort_keys: LexOrdering,
323    baseline_metrics: BaselineMetrics,
324    ordered_partition_by_indices: Vec<usize>,
325}
326
327impl WindowAggStream {
328    /// Create a new WindowAggStream
329    pub fn new(
330        schema: SchemaRef,
331        window_expr: Vec<Arc<dyn WindowExpr>>,
332        input: SendableRecordBatchStream,
333        baseline_metrics: BaselineMetrics,
334        partition_by_sort_keys: LexOrdering,
335        ordered_partition_by_indices: Vec<usize>,
336    ) -> Result<Self> {
337        // In WindowAggExec all partition by columns should be ordered.
338        if window_expr[0].partition_by().len() != ordered_partition_by_indices.len() {
339            return internal_err!("All partition by columns should have an ordering");
340        }
341        Ok(Self {
342            schema,
343            input,
344            batches: vec![],
345            finished: false,
346            window_expr,
347            baseline_metrics,
348            partition_by_sort_keys,
349            ordered_partition_by_indices,
350        })
351    }
352
353    fn compute_aggregates(&self) -> Result<Option<RecordBatch>> {
354        // record compute time on drop
355        let _timer = self.baseline_metrics.elapsed_compute().timer();
356
357        let batch = concat_batches(&self.input.schema(), &self.batches)?;
358        if batch.num_rows() == 0 {
359            return Ok(None);
360        }
361
362        let partition_by_sort_keys = self
363            .ordered_partition_by_indices
364            .iter()
365            .map(|idx| self.partition_by_sort_keys[*idx].evaluate_to_sort_column(&batch))
366            .collect::<Result<Vec<_>>>()?;
367        let partition_points =
368            evaluate_partition_ranges(batch.num_rows(), &partition_by_sort_keys)?;
369
370        let mut partition_results = vec![];
371        // Calculate window cols
372        for partition_point in partition_points {
373            let length = partition_point.end - partition_point.start;
374            partition_results.push(compute_window_aggregates(
375                &self.window_expr,
376                &batch.slice(partition_point.start, length),
377            )?)
378        }
379        let columns = transpose(partition_results)
380            .iter()
381            .map(|elems| concat(&elems.iter().map(|x| x.as_ref()).collect::<Vec<_>>()))
382            .collect::<Vec<_>>()
383            .into_iter()
384            .collect::<Result<Vec<ArrayRef>, ArrowError>>()?;
385
386        // combine with the original cols
387        // note the setup of window aggregates is that they newly calculated window
388        // expression results are always appended to the columns
389        let mut batch_columns = batch.columns().to_vec();
390        // calculate window cols
391        batch_columns.extend_from_slice(&columns);
392        Ok(Some(RecordBatch::try_new(
393            Arc::clone(&self.schema),
394            batch_columns,
395        )?))
396    }
397}
398
399impl Stream for WindowAggStream {
400    type Item = Result<RecordBatch>;
401
402    fn poll_next(
403        mut self: Pin<&mut Self>,
404        cx: &mut Context<'_>,
405    ) -> Poll<Option<Self::Item>> {
406        let poll = self.poll_next_inner(cx);
407        self.baseline_metrics.record_poll(poll)
408    }
409}
410
411impl WindowAggStream {
412    #[inline]
413    fn poll_next_inner(
414        &mut self,
415        cx: &mut Context<'_>,
416    ) -> Poll<Option<Result<RecordBatch>>> {
417        if self.finished {
418            return Poll::Ready(None);
419        }
420
421        loop {
422            return Poll::Ready(Some(match ready!(self.input.poll_next_unpin(cx)) {
423                Some(Ok(batch)) => {
424                    self.batches.push(batch);
425                    continue;
426                }
427                Some(Err(e)) => Err(e),
428                None => {
429                    let Some(result) = self.compute_aggregates()? else {
430                        return Poll::Ready(None);
431                    };
432                    self.finished = true;
433                    // Empty record batches should not be emitted.
434                    // They need to be treated as  [`Option<RecordBatch>`]es and handled separately
435                    debug_assert!(result.num_rows() > 0);
436                    Ok(result)
437                }
438            }));
439        }
440    }
441}
442
443impl RecordBatchStream for WindowAggStream {
444    /// Get the schema
445    fn schema(&self) -> SchemaRef {
446        Arc::clone(&self.schema)
447    }
448}