datafusion_physical_plan/
async_func.rs1use crate::coalesce::LimitedBatchCoalescer;
19use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
20use crate::stream::{EmptyRecordBatchStream, RecordBatchStreamAdapter};
21use crate::{
22 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
23 check_if_same_properties,
24};
25use arrow::array::RecordBatch;
26use arrow_schema::{Fields, Schema, SchemaRef};
27use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
28use datafusion_common::{Result, assert_eq_or_internal_err};
29use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
30use datafusion_physical_expr::ScalarFunctionExpr;
31use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr;
32use datafusion_physical_expr::equivalence::ProjectionMapping;
33use datafusion_physical_expr::expressions::Column;
34use datafusion_physical_expr_common::metrics::{BaselineMetrics, RecordOutput};
35use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
36use futures::Stream;
37use futures::stream::StreamExt;
38use log::trace;
39use std::pin::Pin;
40use std::sync::Arc;
41use std::task::{Context, Poll, ready};
42
43#[derive(Debug, Clone)]
49pub struct AsyncFuncExec {
50 async_exprs: Vec<Arc<AsyncFuncExpr>>,
52 input: Arc<dyn ExecutionPlan>,
53 cache: Arc<PlanProperties>,
54 metrics: ExecutionPlanMetricsSet,
55}
56
57impl AsyncFuncExec {
58 pub fn try_new(
59 async_exprs: Vec<Arc<AsyncFuncExpr>>,
60 input: Arc<dyn ExecutionPlan>,
61 ) -> Result<Self> {
62 let async_fields = async_exprs
63 .iter()
64 .map(|async_expr| async_expr.field(input.schema().as_ref()))
65 .collect::<Result<Vec<_>>>()?;
66
67 let fields: Fields = input
69 .schema()
70 .fields()
71 .iter()
72 .cloned()
73 .chain(async_fields.into_iter().map(Arc::new))
74 .collect();
75
76 let schema = Arc::new(Schema::new(fields));
77 let tuples = async_exprs
78 .iter()
79 .map(|expr| (Arc::clone(&expr.func), expr.name().to_string()))
80 .collect::<Vec<_>>();
81 let async_expr_mapping = ProjectionMapping::try_new(tuples, &input.schema())?;
82 let cache =
83 AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?;
84 Ok(Self {
85 input,
86 async_exprs,
87 cache: Arc::new(cache),
88 metrics: ExecutionPlanMetricsSet::new(),
89 })
90 }
91
92 fn compute_properties(
95 input: &Arc<dyn ExecutionPlan>,
96 schema: SchemaRef,
97 async_expr_mapping: &ProjectionMapping,
98 ) -> Result<PlanProperties> {
99 Ok(PlanProperties::new(
100 input
101 .equivalence_properties()
102 .project(async_expr_mapping, schema),
103 input.output_partitioning().clone(),
104 input.pipeline_behavior(),
105 input.boundedness(),
106 ))
107 }
108
109 pub fn async_exprs(&self) -> &[Arc<AsyncFuncExpr>] {
110 &self.async_exprs
111 }
112
113 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
114 &self.input
115 }
116
117 fn with_new_children_and_same_properties(
118 &self,
119 mut children: Vec<Arc<dyn ExecutionPlan>>,
120 ) -> Self {
121 Self {
122 input: children.swap_remove(0),
123 metrics: ExecutionPlanMetricsSet::new(),
124 ..Self::clone(self)
125 }
126 }
127}
128
129impl DisplayAs for AsyncFuncExec {
130 fn fmt_as(
131 &self,
132 t: DisplayFormatType,
133 f: &mut std::fmt::Formatter,
134 ) -> std::fmt::Result {
135 let expr: Vec<String> = self
136 .async_exprs
137 .iter()
138 .map(|async_expr| async_expr.to_string())
139 .collect();
140 let exprs = expr.join(", ");
141 match t {
142 DisplayFormatType::Default | DisplayFormatType::Verbose => {
143 write!(f, "AsyncFuncExec: async_expr=[{exprs}]")
144 }
145 DisplayFormatType::TreeRender => {
146 writeln!(f, "format=async_expr")?;
147 writeln!(f, "async_expr={exprs}")?;
148 Ok(())
149 }
150 }
151 }
152}
153
154impl ExecutionPlan for AsyncFuncExec {
155 fn name(&self) -> &str {
156 "async_func"
157 }
158
159 fn properties(&self) -> &Arc<PlanProperties> {
160 &self.cache
161 }
162
163 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
164 vec![&self.input]
165 }
166
167 fn with_new_children(
168 self: Arc<Self>,
169 mut children: Vec<Arc<dyn ExecutionPlan>>,
170 ) -> Result<Arc<dyn ExecutionPlan>> {
171 assert_eq_or_internal_err!(
172 children.len(),
173 1,
174 "AsyncFuncExec wrong number of children"
175 );
176 check_if_same_properties!(self, children);
177 Ok(Arc::new(AsyncFuncExec::try_new(
178 self.async_exprs.clone(),
179 children.swap_remove(0),
180 )?))
181 }
182
183 fn execute(
184 &self,
185 partition: usize,
186 context: Arc<TaskContext>,
187 ) -> Result<SendableRecordBatchStream> {
188 trace!(
189 "Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}",
190 partition,
191 context.session_id(),
192 context.task_id()
193 );
194
195 let input_stream = self.input.execute(partition, Arc::clone(&context))?;
197
198 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
201
202 let async_exprs_captured = Arc::new(self.async_exprs.clone());
204 let schema_captured = self.schema();
205 let config_options_ref = Arc::clone(context.session_config().options());
206
207 let coalesced_input_stream = CoalesceInputStream {
208 input_stream,
209 batch_coalescer: LimitedBatchCoalescer::new(
210 Arc::clone(&self.input.schema()),
211 config_options_ref.execution.batch_size,
212 None,
213 ),
214 };
215
216 let stream_with_async_functions = coalesced_input_stream.then(move |batch| {
217 let async_exprs_captured = Arc::clone(&async_exprs_captured);
220 let schema_captured = Arc::clone(&schema_captured);
221 let config_options = Arc::clone(&config_options_ref);
222 let baseline_metrics_captured = baseline_metrics.clone();
223
224 async move {
225 let batch = batch?;
226 let mut output_arrays = batch.columns().to_vec();
228 for async_expr in async_exprs_captured.iter() {
229 let output = async_expr
230 .invoke_with_args(&batch, Arc::clone(&config_options))
231 .await?;
232 output_arrays.push(output.to_array(batch.num_rows())?);
233 }
234 let batch = RecordBatch::try_new(schema_captured, output_arrays)?;
235
236 Ok(batch.record_output(&baseline_metrics_captured))
237 }
238 });
239
240 let adapter =
242 RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions);
243 Ok(Box::pin(adapter))
244 }
245
246 fn metrics(&self) -> Option<MetricsSet> {
247 Some(self.metrics.clone_inner())
248 }
249}
250
251struct CoalesceInputStream {
252 input_stream: Pin<Box<dyn RecordBatchStream + Send>>,
253 batch_coalescer: LimitedBatchCoalescer,
254}
255
256impl Stream for CoalesceInputStream {
257 type Item = Result<RecordBatch>;
258
259 fn poll_next(
260 mut self: Pin<&mut Self>,
261 cx: &mut Context<'_>,
262 ) -> Poll<Option<Self::Item>> {
263 let mut completed = false;
264
265 loop {
266 if let Some(batch) = self.batch_coalescer.next_completed_batch() {
267 return Poll::Ready(Some(Ok(batch)));
268 }
269
270 if completed {
271 return Poll::Ready(None);
272 }
273
274 match ready!(self.input_stream.poll_next_unpin(cx)) {
275 Some(Ok(batch)) => {
276 if let Err(err) = self.batch_coalescer.push_batch(batch) {
277 return Poll::Ready(Some(Err(err)));
278 }
279 }
280 Some(err) => {
281 return Poll::Ready(Some(err));
282 }
283 None => {
284 completed = true;
285 let input_schema = self.input_stream.schema();
287 self.input_stream =
288 Box::pin(EmptyRecordBatchStream::new(input_schema));
289 if let Err(err) = self.batch_coalescer.finish() {
290 return Poll::Ready(Some(Err(err)));
291 }
292 }
293 }
294 }
295 }
296}
297
/// Prefix of the column names generated by `AsyncMapper::next_column_name`.
const ASYNC_FN_PREFIX: &str = "__async_fn_";
299
300#[derive(Debug)]
304pub struct AsyncMapper {
305 num_input_columns: usize,
309 pub async_exprs: Vec<Arc<AsyncFuncExpr>>,
311}
312
313impl AsyncMapper {
314 pub fn new(num_input_columns: usize) -> Self {
315 Self {
316 num_input_columns,
317 async_exprs: Vec::new(),
318 }
319 }
320
321 pub fn is_empty(&self) -> bool {
322 self.async_exprs.is_empty()
323 }
324
325 pub fn next_column_name(&self) -> String {
326 format!("{}{}", ASYNC_FN_PREFIX, self.async_exprs.len())
327 }
328
329 pub fn find_references(
331 &mut self,
332 physical_expr: &Arc<dyn PhysicalExpr>,
333 schema: &Schema,
334 ) -> Result<()> {
335 physical_expr.apply(|expr| {
337 if let Some(scalar_func_expr) = expr.downcast_ref::<ScalarFunctionExpr>()
338 && scalar_func_expr.fun().as_async().is_some()
339 {
340 let next_name = self.next_column_name();
341 self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new(
342 next_name,
343 Arc::clone(expr),
344 schema,
345 )?));
346 }
347 Ok(TreeNodeRecursion::Continue)
348 })?;
349 Ok(())
350 }
351
352 pub fn map_expr(
354 &self,
355 expr: Arc<dyn PhysicalExpr>,
356 ) -> Transformed<Arc<dyn PhysicalExpr>> {
357 let Some(idx) =
359 self.async_exprs
360 .iter()
361 .enumerate()
362 .find_map(|(idx, async_expr)| {
363 if async_expr.func == Arc::clone(&expr) {
364 Some(idx)
365 } else {
366 None
367 }
368 })
369 else {
370 return Transformed::no(expr);
371 };
372 Transformed::yes(self.output_column(idx))
374 }
375
376 pub fn output_column(&self, idx: usize) -> Arc<dyn PhysicalExpr> {
378 let async_expr = &self.async_exprs[idx];
379 let output_idx = self.num_input_columns + idx;
380 Arc::new(Column::new(async_expr.name(), output_idx))
381 }
382}
383
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{RecordBatch, UInt32Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_execution::{TaskContext, config::SessionConfig};
    use futures::StreamExt;

    use crate::{ExecutionPlan, async_func::AsyncFuncExec, test::TestMemoryExec};

    /// Fifty 6-row input batches (300 rows) with a batch size of 200 must come
    /// out of the operator as a 200-row batch followed by a 100-row batch.
    #[tokio::test]
    async fn test_async_fn_with_coalescing() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));

        let six_rows = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6]))],
        )?;
        let batches: Vec<RecordBatch> = std::iter::repeat_n(six_rows, 50).collect();

        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(200)),
        );

        // No async expressions: only the coalescing behaviour is under test.
        let input =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
        let exec = AsyncFuncExec::try_new(vec![], input)?;

        let mut stream = exec.execute(0, Arc::clone(&task_ctx))?;
        for expected_rows in [200, 100] {
            let batch = stream
                .next()
                .await
                .expect("expected to get a record batch")?;
            assert_eq!(expected_rows, batch.num_rows());
        }

        Ok(())
    }
}