datafusion_physical_plan/aggregates/
topk_stream.rs1use crate::aggregates::topk::priority_map::PriorityMap;
21use crate::aggregates::{
22 aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
23 PhysicalGroupBy,
24};
25use crate::{RecordBatchStream, SendableRecordBatchStream};
26use arrow::array::{Array, ArrayRef, RecordBatch};
27use arrow::datatypes::SchemaRef;
28use arrow::util::pretty::print_batches;
29use datafusion_common::DataFusionError;
30use datafusion_common::Result;
31use datafusion_execution::TaskContext;
32use datafusion_physical_expr::PhysicalExpr;
33use futures::stream::{Stream, StreamExt};
34use log::{trace, Level};
35use std::pin::Pin;
36use std::sync::Arc;
37use std::task::{Context, Poll};
38
39pub struct GroupedTopKAggregateStream {
40 partition: usize,
41 row_count: usize,
42 started: bool,
43 schema: SchemaRef,
44 input: SendableRecordBatchStream,
45 aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
46 group_by: PhysicalGroupBy,
47 priority_map: PriorityMap,
48}
49
50impl GroupedTopKAggregateStream {
51 pub fn new(
52 aggr: &AggregateExec,
53 context: Arc<TaskContext>,
54 partition: usize,
55 limit: usize,
56 ) -> Result<Self> {
57 let agg_schema = Arc::clone(&aggr.schema);
58 let group_by = aggr.group_by.clone();
59 let input = aggr.input.execute(partition, Arc::clone(&context))?;
60 let aggregate_arguments =
61 aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?;
62 let (val_field, desc) = aggr
63 .get_minmax_desc()
64 .ok_or_else(|| DataFusionError::Internal("Min/max required".to_string()))?;
65
66 let (expr, _) = &aggr.group_expr().expr()[0];
67 let kt = expr.data_type(&aggr.input().schema())?;
68 let vt = val_field.data_type().clone();
69
70 let priority_map = PriorityMap::new(kt, vt, limit, desc)?;
71
72 Ok(GroupedTopKAggregateStream {
73 partition,
74 started: false,
75 row_count: 0,
76 schema: agg_schema,
77 input,
78 aggregate_arguments,
79 group_by,
80 priority_map,
81 })
82 }
83}
84
85impl RecordBatchStream for GroupedTopKAggregateStream {
86 fn schema(&self) -> SchemaRef {
87 Arc::clone(&self.schema)
88 }
89}
90
91impl GroupedTopKAggregateStream {
92 fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
93 let len = ids.len();
94 self.priority_map.set_batch(ids, Arc::clone(&vals));
95
96 let has_nulls = vals.null_count() > 0;
97 for row_idx in 0..len {
98 if has_nulls && vals.is_null(row_idx) {
99 continue;
100 }
101 self.priority_map.insert(row_idx)?;
102 }
103 Ok(())
104 }
105}
106
107impl Stream for GroupedTopKAggregateStream {
108 type Item = Result<RecordBatch>;
109
110 fn poll_next(
111 mut self: Pin<&mut Self>,
112 cx: &mut Context<'_>,
113 ) -> Poll<Option<Self::Item>> {
114 while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
115 match res {
116 Some(Ok(batch)) => {
118 self.started = true;
119 trace!(
120 "partition {} has {} rows and got batch with {} rows",
121 self.partition,
122 self.row_count,
123 batch.num_rows()
124 );
125 if log::log_enabled!(Level::Trace) && batch.num_rows() < 20 {
126 print_batches(std::slice::from_ref(&batch))?;
127 }
128 self.row_count += batch.num_rows();
129 let batches = &[batch];
130 let group_by_values =
131 evaluate_group_by(&self.group_by, batches.first().unwrap())?;
132 assert_eq!(
133 group_by_values.len(),
134 1,
135 "Exactly 1 group value required"
136 );
137 assert_eq!(
138 group_by_values[0].len(),
139 1,
140 "Exactly 1 group value required"
141 );
142 let group_by_values = Arc::clone(&group_by_values[0][0]);
143 let input_values = evaluate_many(
144 &self.aggregate_arguments,
145 batches.first().unwrap(),
146 )?;
147 assert_eq!(input_values.len(), 1, "Exactly 1 input required");
148 assert_eq!(input_values[0].len(), 1, "Exactly 1 input required");
149 let input_values = Arc::clone(&input_values[0][0]);
150
151 (*self).intern(group_by_values, input_values)?;
153 }
154 None => {
156 if self.priority_map.is_empty() {
157 trace!("partition {} emit None", self.partition);
158 return Poll::Ready(None);
159 }
160 let cols = self.priority_map.emit()?;
161 let batch = RecordBatch::try_new(Arc::clone(&self.schema), cols)?;
162 trace!(
163 "partition {} emit batch with {} rows",
164 self.partition,
165 batch.num_rows()
166 );
167 if log::log_enabled!(Level::Trace) {
168 print_batches(std::slice::from_ref(&batch))?;
169 }
170 return Poll::Ready(Some(Ok(batch)));
171 }
172 Some(Err(e)) => {
174 return Poll::Ready(Some(Err(e)));
175 }
176 }
177 }
178 Poll::Pending
179 }
180}