datafusion_physical_plan/aggregates/topk_stream.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! A memory-conscious aggregation implementation that limits group buckets to a fixed number
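//!
//! Rather than accumulating every distinct group, this stream keeps only the
//! current top `limit` groups, ranked by a single min/max aggregate, in a
//! bounded [`PriorityMap`]. This serves queries of the general shape
//! `SELECT k, MAX(v) FROM t GROUP BY k ORDER BY MAX(v) DESC LIMIT n`.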

use crate::aggregates::topk::priority_map::PriorityMap;
use crate::aggregates::{
    aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
    PhysicalGroupBy,
};
use crate::{RecordBatchStream, SendableRecordBatchStream};
use arrow::array::{Array, ArrayRef, RecordBatch};
use arrow::datatypes::SchemaRef;
use arrow::util::pretty::print_batches;
use datafusion_common::{DataFusionError, Result};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalExpr;
use futures::stream::{Stream, StreamExt};
use log::{trace, Level};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

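/// A [`Stream`] of [`RecordBatch`]es for a grouped aggregation with a
/// pushed-down limit: batches from `input` are folded into a fixed-capacity
/// [`PriorityMap`], and the surviving top groups are emitted as a single
/// batch once the input is exhausted.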
pub struct GroupedTopKAggregateStream {
    /// Partition of the input this stream reads (used only for tracing)
    partition: usize,
    /// Total number of input rows seen so far (used only for tracing)
    row_count: usize,
    started: bool,
    /// Output schema of the aggregation
    schema: SchemaRef,
    /// Stream of input batches to aggregate
    input: SendableRecordBatchStream,
    /// Argument expressions for the single min/max aggregate
    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
    /// The single group-by expression
    group_by: PhysicalGroupBy,
    /// Bounded map holding only the current top `limit` groups
    priority_map: PriorityMap,
}

impl GroupedTopKAggregateStream {
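    /// Creates a stream over `partition` of `aggr`'s input that retains at
    /// most `limit` groups, ranked by the operator's single min/max aggregate.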
    pub fn new(
        aggr: &AggregateExec,
        context: Arc<TaskContext>,
        partition: usize,
        limit: usize,
    ) -> Result<Self> {
        let agg_schema = Arc::clone(&aggr.schema);
        let group_by = aggr.group_by.clone();
        let input = aggr.input.execute(partition, Arc::clone(&context))?;
        let aggregate_arguments =
            aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?;
        // this operator is only valid for a single min/max aggregate;
        // `desc` says whether we keep the largest or the smallest values
        let (val_field, desc) = aggr
            .get_minmax_desc()
            .ok_or_else(|| DataFusionError::Internal("Min/max required".to_string()))?;

        // likewise, exactly one group expression is supported: its type is the
        // key type of the priority map, the aggregate's type is the value type
        let (expr, _) = &aggr.group_expr().expr()[0];
        let kt = expr.data_type(&aggr.input().schema())?;
        let vt = val_field.data_type().clone();

        let priority_map = PriorityMap::new(kt, vt, limit, desc)?;

        Ok(GroupedTopKAggregateStream {
            partition,
            started: false,
            row_count: 0,
            schema: agg_schema,
            input,
            aggregate_arguments,
            group_by,
            priority_map,
        })
    }
}

impl RecordBatchStream for GroupedTopKAggregateStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl GroupedTopKAggregateStream {
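    /// Folds one batch of group keys (`ids`) and aggregate inputs (`vals`)
    /// into the priority map, which holds at most `limit` groups.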
    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
        let len = ids.len();
        // hand the map the arrays backing this batch so that `insert` can
        // refer to rows by index alone
        self.priority_map.set_batch(ids, Arc::clone(&vals));

        let has_nulls = vals.null_count() > 0;
        for row_idx in 0..len {
            // a null aggregate value can never rank in a min/max top-k, so skip it
            if has_nulls && vals.is_null(row_idx) {
                continue;
            }
            self.priority_map.insert(row_idx)?;
        }
        Ok(())
    }
}

impl Stream for GroupedTopKAggregateStream {
    type Item = Result<RecordBatch>;

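    // Drive the input to completion, folding each ready batch into the
    // priority map; only when the input is exhausted are the surviving top-k
    // groups emitted, as a single output batch.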
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
            match res {
                // got a batch: evaluate the group key and aggregate input,
                // then fold the rows into our priority map
                Some(Ok(batch)) => {
                    self.started = true;
                    trace!(
                        "partition {} has {} rows and got batch with {} rows",
                        self.partition,
                        self.row_count,
                        batch.num_rows()
                    );
                    if log::log_enabled!(Level::Trace) && batch.num_rows() < 20 {
                        print_batches(std::slice::from_ref(&batch))?;
                    }
                    self.row_count += batch.num_rows();

                    // this operator supports exactly one group expression and
                    // one aggregate argument, so each evaluation must yield a
                    // single column
                    let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
                    assert_eq!(
                        group_by_values.len(),
                        1,
                        "Exactly 1 group value required"
                    );
                    assert_eq!(
                        group_by_values[0].len(),
                        1,
                        "Exactly 1 group value required"
                    );
                    let group_by_values = Arc::clone(&group_by_values[0][0]);
                    let input_values =
                        evaluate_many(&self.aggregate_arguments, &batch)?;
                    assert_eq!(input_values.len(), 1, "Exactly 1 input required");
                    assert_eq!(input_values[0].len(), 1, "Exactly 1 input required");
                    let input_values = Arc::clone(&input_values[0][0]);

                    (*self).intern(group_by_values, input_values)?;
                }
                // inner is done, emit all rows and switch to producing output
                None => {
                    if self.priority_map.is_empty() {
                        trace!("partition {} emit None", self.partition);
                        return Poll::Ready(None);
                    }
                    let cols = self.priority_map.emit()?;
                    let batch = RecordBatch::try_new(Arc::clone(&self.schema), cols)?;
                    trace!(
                        "partition {} emit batch with {} rows",
                        self.partition,
                        batch.num_rows()
                    );
                    if log::log_enabled!(Level::Trace) {
                        print_batches(std::slice::from_ref(&batch))?;
                    }
                    return Poll::Ready(Some(Ok(batch)));
                }
                // inner had error, return to caller
                Some(Err(e)) => {
                    return Poll::Ready(Some(Err(e)));
                }
            }
        }
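        // the input returned Pending and has registered our waker; propagate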
        Poll::Pending
    }
}