datafusion_physical_plan/
async_func.rs1use crate::coalesce::LimitedBatchCoalescer;
19use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
20use crate::stream::RecordBatchStreamAdapter;
21use crate::{
22 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
23 check_if_same_properties,
24};
25use arrow::array::RecordBatch;
26use arrow_schema::{Fields, Schema, SchemaRef};
27use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
28use datafusion_common::{Result, assert_eq_or_internal_err};
29use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
30use datafusion_physical_expr::ScalarFunctionExpr;
31use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr;
32use datafusion_physical_expr::equivalence::ProjectionMapping;
33use datafusion_physical_expr::expressions::Column;
34use datafusion_physical_expr_common::metrics::{BaselineMetrics, RecordOutput};
35use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
36use futures::Stream;
37use futures::stream::StreamExt;
38use log::trace;
39use std::any::Any;
40use std::pin::Pin;
41use std::sync::Arc;
42use std::task::{Context, Poll, ready};
43
44#[derive(Debug, Clone)]
50pub struct AsyncFuncExec {
51 async_exprs: Vec<Arc<AsyncFuncExpr>>,
53 input: Arc<dyn ExecutionPlan>,
54 cache: Arc<PlanProperties>,
55 metrics: ExecutionPlanMetricsSet,
56}
57
58impl AsyncFuncExec {
59 pub fn try_new(
60 async_exprs: Vec<Arc<AsyncFuncExpr>>,
61 input: Arc<dyn ExecutionPlan>,
62 ) -> Result<Self> {
63 let async_fields = async_exprs
64 .iter()
65 .map(|async_expr| async_expr.field(input.schema().as_ref()))
66 .collect::<Result<Vec<_>>>()?;
67
68 let fields: Fields = input
70 .schema()
71 .fields()
72 .iter()
73 .cloned()
74 .chain(async_fields.into_iter().map(Arc::new))
75 .collect();
76
77 let schema = Arc::new(Schema::new(fields));
78 let tuples = async_exprs
79 .iter()
80 .map(|expr| (Arc::clone(&expr.func), expr.name().to_string()))
81 .collect::<Vec<_>>();
82 let async_expr_mapping = ProjectionMapping::try_new(tuples, &input.schema())?;
83 let cache =
84 AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?;
85 Ok(Self {
86 input,
87 async_exprs,
88 cache: Arc::new(cache),
89 metrics: ExecutionPlanMetricsSet::new(),
90 })
91 }
92
93 fn compute_properties(
96 input: &Arc<dyn ExecutionPlan>,
97 schema: SchemaRef,
98 async_expr_mapping: &ProjectionMapping,
99 ) -> Result<PlanProperties> {
100 Ok(PlanProperties::new(
101 input
102 .equivalence_properties()
103 .project(async_expr_mapping, schema),
104 input.output_partitioning().clone(),
105 input.pipeline_behavior(),
106 input.boundedness(),
107 ))
108 }
109
110 pub fn async_exprs(&self) -> &[Arc<AsyncFuncExpr>] {
111 &self.async_exprs
112 }
113
114 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
115 &self.input
116 }
117
118 fn with_new_children_and_same_properties(
119 &self,
120 mut children: Vec<Arc<dyn ExecutionPlan>>,
121 ) -> Self {
122 Self {
123 input: children.swap_remove(0),
124 metrics: ExecutionPlanMetricsSet::new(),
125 ..Self::clone(self)
126 }
127 }
128}
129
130impl DisplayAs for AsyncFuncExec {
131 fn fmt_as(
132 &self,
133 t: DisplayFormatType,
134 f: &mut std::fmt::Formatter,
135 ) -> std::fmt::Result {
136 let expr: Vec<String> = self
137 .async_exprs
138 .iter()
139 .map(|async_expr| async_expr.to_string())
140 .collect();
141 let exprs = expr.join(", ");
142 match t {
143 DisplayFormatType::Default | DisplayFormatType::Verbose => {
144 write!(f, "AsyncFuncExec: async_expr=[{exprs}]")
145 }
146 DisplayFormatType::TreeRender => {
147 writeln!(f, "format=async_expr")?;
148 writeln!(f, "async_expr={exprs}")?;
149 Ok(())
150 }
151 }
152 }
153}
154
155impl ExecutionPlan for AsyncFuncExec {
156 fn name(&self) -> &str {
157 "async_func"
158 }
159
160 fn as_any(&self) -> &dyn Any {
161 self
162 }
163
164 fn properties(&self) -> &Arc<PlanProperties> {
165 &self.cache
166 }
167
168 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
169 vec![&self.input]
170 }
171
172 fn with_new_children(
173 self: Arc<Self>,
174 mut children: Vec<Arc<dyn ExecutionPlan>>,
175 ) -> Result<Arc<dyn ExecutionPlan>> {
176 assert_eq_or_internal_err!(
177 children.len(),
178 1,
179 "AsyncFuncExec wrong number of children"
180 );
181 check_if_same_properties!(self, children);
182 Ok(Arc::new(AsyncFuncExec::try_new(
183 self.async_exprs.clone(),
184 children.swap_remove(0),
185 )?))
186 }
187
188 fn execute(
189 &self,
190 partition: usize,
191 context: Arc<TaskContext>,
192 ) -> Result<SendableRecordBatchStream> {
193 trace!(
194 "Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}",
195 partition,
196 context.session_id(),
197 context.task_id()
198 );
199
200 let input_stream = self.input.execute(partition, Arc::clone(&context))?;
202
203 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
206
207 let async_exprs_captured = Arc::new(self.async_exprs.clone());
209 let schema_captured = self.schema();
210 let config_options_ref = Arc::clone(context.session_config().options());
211
212 let coalesced_input_stream = CoalesceInputStream {
213 input_stream,
214 batch_coalescer: LimitedBatchCoalescer::new(
215 Arc::clone(&self.input.schema()),
216 config_options_ref.execution.batch_size,
217 None,
218 ),
219 };
220
221 let stream_with_async_functions = coalesced_input_stream.then(move |batch| {
222 let async_exprs_captured = Arc::clone(&async_exprs_captured);
225 let schema_captured = Arc::clone(&schema_captured);
226 let config_options = Arc::clone(&config_options_ref);
227 let baseline_metrics_captured = baseline_metrics.clone();
228
229 async move {
230 let batch = batch?;
231 let mut output_arrays = batch.columns().to_vec();
233 for async_expr in async_exprs_captured.iter() {
234 let output = async_expr
235 .invoke_with_args(&batch, Arc::clone(&config_options))
236 .await?;
237 output_arrays.push(output.to_array(batch.num_rows())?);
238 }
239 let batch = RecordBatch::try_new(schema_captured, output_arrays)?;
240
241 Ok(batch.record_output(&baseline_metrics_captured))
242 }
243 });
244
245 let adapter =
247 RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions);
248 Ok(Box::pin(adapter))
249 }
250
251 fn metrics(&self) -> Option<MetricsSet> {
252 Some(self.metrics.clone_inner())
253 }
254}
255
256struct CoalesceInputStream {
257 input_stream: Pin<Box<dyn RecordBatchStream + Send>>,
258 batch_coalescer: LimitedBatchCoalescer,
259}
260
261impl Stream for CoalesceInputStream {
262 type Item = Result<RecordBatch>;
263
264 fn poll_next(
265 mut self: Pin<&mut Self>,
266 cx: &mut Context<'_>,
267 ) -> Poll<Option<Self::Item>> {
268 let mut completed = false;
269
270 loop {
271 if let Some(batch) = self.batch_coalescer.next_completed_batch() {
272 return Poll::Ready(Some(Ok(batch)));
273 }
274
275 if completed {
276 return Poll::Ready(None);
277 }
278
279 match ready!(self.input_stream.poll_next_unpin(cx)) {
280 Some(Ok(batch)) => {
281 if let Err(err) = self.batch_coalescer.push_batch(batch) {
282 return Poll::Ready(Some(Err(err)));
283 }
284 }
285 Some(err) => {
286 return Poll::Ready(Some(err));
287 }
288 None => {
289 completed = true;
290 if let Err(err) = self.batch_coalescer.finish() {
291 return Poll::Ready(Some(Err(err)));
292 }
293 }
294 }
295 }
296 }
297}
298
/// Prefix of the column names generated for async function results; the
/// index of the async expression is appended to it.
const ASYNC_FN_PREFIX: &str = "__async_fn_";
300
301#[derive(Debug)]
305pub struct AsyncMapper {
306 num_input_columns: usize,
310 pub async_exprs: Vec<Arc<AsyncFuncExpr>>,
312}
313
314impl AsyncMapper {
315 pub fn new(num_input_columns: usize) -> Self {
316 Self {
317 num_input_columns,
318 async_exprs: Vec::new(),
319 }
320 }
321
322 pub fn is_empty(&self) -> bool {
323 self.async_exprs.is_empty()
324 }
325
326 pub fn next_column_name(&self) -> String {
327 format!("{}{}", ASYNC_FN_PREFIX, self.async_exprs.len())
328 }
329
330 pub fn find_references(
332 &mut self,
333 physical_expr: &Arc<dyn PhysicalExpr>,
334 schema: &Schema,
335 ) -> Result<()> {
336 physical_expr.apply(|expr| {
338 if let Some(scalar_func_expr) =
339 expr.as_any().downcast_ref::<ScalarFunctionExpr>()
340 && scalar_func_expr.fun().as_async().is_some()
341 {
342 let next_name = self.next_column_name();
343 self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new(
344 next_name,
345 Arc::clone(expr),
346 schema,
347 )?));
348 }
349 Ok(TreeNodeRecursion::Continue)
350 })?;
351 Ok(())
352 }
353
354 pub fn map_expr(
356 &self,
357 expr: Arc<dyn PhysicalExpr>,
358 ) -> Transformed<Arc<dyn PhysicalExpr>> {
359 let Some(idx) =
361 self.async_exprs
362 .iter()
363 .enumerate()
364 .find_map(|(idx, async_expr)| {
365 if async_expr.func == Arc::clone(&expr) {
366 Some(idx)
367 } else {
368 None
369 }
370 })
371 else {
372 return Transformed::no(expr);
373 };
374 Transformed::yes(self.output_column(idx))
376 }
377
378 pub fn output_column(&self, idx: usize) -> Arc<dyn PhysicalExpr> {
380 let async_expr = &self.async_exprs[idx];
381 let output_idx = self.num_input_columns + idx;
382 Arc::new(Column::new(async_expr.name(), output_idx))
383 }
384}
385
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{RecordBatch, UInt32Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_execution::{TaskContext, config::SessionConfig};
    use futures::StreamExt;

    use crate::{ExecutionPlan, async_func::AsyncFuncExec, test::TestMemoryExec};

    /// 50 input batches of 6 rows (300 rows in total) with a batch size of
    /// 200 must come out as one 200-row batch followed by one 100-row batch.
    #[tokio::test]
    async fn test_async_fn_with_coalescing() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));

        let values = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(values)])?;
        let batches: Vec<RecordBatch> = vec![batch; 50];

        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(200)),
        );

        let test_exec =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
        let exec = AsyncFuncExec::try_new(vec![], test_exec)?;

        let mut stream = exec.execute(0, Arc::clone(&task_ctx))?;
        for expected_rows in [200, 100] {
            let batch = stream
                .next()
                .await
                .expect("expected to get a record batch")?;
            assert_eq!(expected_rows, batch.num_rows());
        }

        Ok(())
    }
}
432}