// datafusion_physical_plan/async_func.rs
use crate::coalesce::LimitedBatchCoalescer;
19use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
20use crate::stream::RecordBatchStreamAdapter;
21use crate::{
22 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
23};
24use arrow::array::RecordBatch;
25use arrow_schema::{Fields, Schema, SchemaRef};
26use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
27use datafusion_common::{Result, assert_eq_or_internal_err};
28use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
29use datafusion_physical_expr::ScalarFunctionExpr;
30use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr;
31use datafusion_physical_expr::equivalence::ProjectionMapping;
32use datafusion_physical_expr::expressions::Column;
33use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
34use futures::Stream;
35use futures::stream::StreamExt;
36use log::trace;
37use std::any::Any;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::task::{Context, Poll, ready};
41
/// Execution plan node that evaluates async scalar function expressions on
/// each input batch and appends each result as a new column after the input
/// columns (see `try_new`, which builds the combined output schema).
#[derive(Debug)]
pub struct AsyncFuncExec {
    /// The async expressions to evaluate; one appended output column each
    async_exprs: Vec<Arc<AsyncFuncExpr>>,
    /// The child plan producing input record batches
    input: Arc<dyn ExecutionPlan>,
    /// Cached plan properties (projected equivalence properties, input
    /// partitioning/boundedness) computed in `compute_properties`
    cache: PlanProperties,
    /// Execution metrics for this node
    metrics: ExecutionPlanMetricsSet,
}
55
56impl AsyncFuncExec {
57 pub fn try_new(
58 async_exprs: Vec<Arc<AsyncFuncExpr>>,
59 input: Arc<dyn ExecutionPlan>,
60 ) -> Result<Self> {
61 let async_fields = async_exprs
62 .iter()
63 .map(|async_expr| async_expr.field(input.schema().as_ref()))
64 .collect::<Result<Vec<_>>>()?;
65
66 let fields: Fields = input
68 .schema()
69 .fields()
70 .iter()
71 .cloned()
72 .chain(async_fields.into_iter().map(Arc::new))
73 .collect();
74
75 let schema = Arc::new(Schema::new(fields));
76 let tuples = async_exprs
77 .iter()
78 .map(|expr| (Arc::clone(&expr.func), expr.name().to_string()))
79 .collect::<Vec<_>>();
80 let async_expr_mapping = ProjectionMapping::try_new(tuples, &input.schema())?;
81 let cache =
82 AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?;
83 Ok(Self {
84 input,
85 async_exprs,
86 cache,
87 metrics: ExecutionPlanMetricsSet::new(),
88 })
89 }
90
91 fn compute_properties(
94 input: &Arc<dyn ExecutionPlan>,
95 schema: SchemaRef,
96 async_expr_mapping: &ProjectionMapping,
97 ) -> Result<PlanProperties> {
98 Ok(PlanProperties::new(
99 input
100 .equivalence_properties()
101 .project(async_expr_mapping, schema),
102 input.output_partitioning().clone(),
103 input.pipeline_behavior(),
104 input.boundedness(),
105 ))
106 }
107
108 pub fn async_exprs(&self) -> &[Arc<AsyncFuncExpr>] {
109 &self.async_exprs
110 }
111
112 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
113 &self.input
114 }
115}
116
117impl DisplayAs for AsyncFuncExec {
118 fn fmt_as(
119 &self,
120 t: DisplayFormatType,
121 f: &mut std::fmt::Formatter,
122 ) -> std::fmt::Result {
123 let expr: Vec<String> = self
124 .async_exprs
125 .iter()
126 .map(|async_expr| async_expr.to_string())
127 .collect();
128 let exprs = expr.join(", ");
129 match t {
130 DisplayFormatType::Default | DisplayFormatType::Verbose => {
131 write!(f, "AsyncFuncExec: async_expr=[{exprs}]")
132 }
133 DisplayFormatType::TreeRender => {
134 writeln!(f, "format=async_expr")?;
135 writeln!(f, "async_expr={exprs}")?;
136 Ok(())
137 }
138 }
139 }
140}
141
142impl ExecutionPlan for AsyncFuncExec {
143 fn name(&self) -> &str {
144 "async_func"
145 }
146
147 fn as_any(&self) -> &dyn Any {
148 self
149 }
150
151 fn properties(&self) -> &PlanProperties {
152 &self.cache
153 }
154
155 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
156 vec![&self.input]
157 }
158
159 fn with_new_children(
160 self: Arc<Self>,
161 children: Vec<Arc<dyn ExecutionPlan>>,
162 ) -> Result<Arc<dyn ExecutionPlan>> {
163 assert_eq_or_internal_err!(
164 children.len(),
165 1,
166 "AsyncFuncExec wrong number of children"
167 );
168 Ok(Arc::new(AsyncFuncExec::try_new(
169 self.async_exprs.clone(),
170 Arc::clone(&children[0]),
171 )?))
172 }
173
174 fn execute(
175 &self,
176 partition: usize,
177 context: Arc<TaskContext>,
178 ) -> Result<SendableRecordBatchStream> {
179 trace!(
180 "Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}",
181 partition,
182 context.session_id(),
183 context.task_id()
184 );
185 let input_stream = self.input.execute(partition, Arc::clone(&context))?;
189
190 let async_exprs_captured = Arc::new(self.async_exprs.clone());
192 let schema_captured = self.schema();
193 let config_options_ref = Arc::clone(context.session_config().options());
194
195 let coalesced_input_stream = CoalesceInputStream {
196 input_stream,
197 batch_coalescer: LimitedBatchCoalescer::new(
198 Arc::clone(&self.input.schema()),
199 config_options_ref.execution.batch_size,
200 None,
201 ),
202 };
203
204 let stream_with_async_functions = coalesced_input_stream.then(move |batch| {
205 let async_exprs_captured = Arc::clone(&async_exprs_captured);
208 let schema_captured = Arc::clone(&schema_captured);
209 let config_options = Arc::clone(&config_options_ref);
210
211 async move {
212 let batch = batch?;
213 let mut output_arrays = batch.columns().to_vec();
215 for async_expr in async_exprs_captured.iter() {
216 let output = async_expr
217 .invoke_with_args(&batch, Arc::clone(&config_options))
218 .await?;
219 output_arrays.push(output.to_array(batch.num_rows())?);
220 }
221 let batch = RecordBatch::try_new(schema_captured, output_arrays)?;
222 Ok(batch)
223 }
224 });
225
226 let adapter =
228 RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions);
229 Ok(Box::pin(adapter))
230 }
231
232 fn metrics(&self) -> Option<MetricsSet> {
233 Some(self.metrics.clone_inner())
234 }
235}
236
/// Stream adapter that buffers input batches in a [`LimitedBatchCoalescer`]
/// and yields coalesced batches (see the `Stream` impl below for the
/// drain-then-pull polling loop).
struct CoalesceInputStream {
    /// The underlying input stream of record batches
    input_stream: Pin<Box<dyn RecordBatchStream + Send>>,
    /// Buffers and merges incoming batches up to the target batch size
    batch_coalescer: LimitedBatchCoalescer,
}
241
impl Stream for CoalesceInputStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Set once the inner stream is exhausted and the coalescer has been
        // finished; any batches still buffered in the coalescer are drained
        // by subsequent iterations before the stream ends.
        let mut completed = false;

        loop {
            // First emit any batch the coalescer has already completed.
            if let Some(batch) = self.batch_coalescer.next_completed_batch() {
                return Poll::Ready(Some(Ok(batch)));
            }

            // Input exhausted and coalescer fully drained: end the stream.
            if completed {
                return Poll::Ready(None);
            }

            match ready!(self.input_stream.poll_next_unpin(cx)) {
                Some(Ok(batch)) => {
                    // Buffer the batch; a completed batch (if any) is picked
                    // up at the top of the next loop iteration.
                    if let Err(err) = self.batch_coalescer.push_batch(batch) {
                        return Poll::Ready(Some(Err(err)));
                    }
                }
                // Only the `Err` case remains here; forward it unchanged.
                Some(err) => {
                    return Poll::Ready(Some(err));
                }
                None => {
                    // Input is done: flush any partial batch still buffered in
                    // the coalescer, then loop to drain completed batches.
                    completed = true;
                    if let Err(err) = self.batch_coalescer.finish() {
                        return Poll::Ready(Some(Err(err)));
                    }
                }
            }
        }
    }
}
279
/// Prefix for the generated names of async function output columns
/// (e.g. `__async_fn_0`, `__async_fn_1`, ... — see `AsyncMapper::next_column_name`).
const ASYNC_FN_PREFIX: &str = "__async_fn_";
281
/// Finds async scalar function calls inside physical expressions and maps
/// them to output columns appended after the input columns.
#[derive(Debug)]
pub struct AsyncMapper {
    /// Number of columns in the input schema; async function output columns
    /// are placed at indexes `num_input_columns + i`
    num_input_columns: usize,
    /// The async function expressions discovered so far
    pub async_exprs: Vec<Arc<AsyncFuncExpr>>,
}
294
295impl AsyncMapper {
296 pub fn new(num_input_columns: usize) -> Self {
297 Self {
298 num_input_columns,
299 async_exprs: Vec::new(),
300 }
301 }
302
303 pub fn is_empty(&self) -> bool {
304 self.async_exprs.is_empty()
305 }
306
307 pub fn next_column_name(&self) -> String {
308 format!("{}{}", ASYNC_FN_PREFIX, self.async_exprs.len())
309 }
310
311 pub fn find_references(
313 &mut self,
314 physical_expr: &Arc<dyn PhysicalExpr>,
315 schema: &Schema,
316 ) -> Result<()> {
317 physical_expr.apply(|expr| {
319 if let Some(scalar_func_expr) =
320 expr.as_any().downcast_ref::<ScalarFunctionExpr>()
321 && scalar_func_expr.fun().as_async().is_some()
322 {
323 let next_name = self.next_column_name();
324 self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new(
325 next_name,
326 Arc::clone(expr),
327 schema,
328 )?));
329 }
330 Ok(TreeNodeRecursion::Continue)
331 })?;
332 Ok(())
333 }
334
335 pub fn map_expr(
337 &self,
338 expr: Arc<dyn PhysicalExpr>,
339 ) -> Transformed<Arc<dyn PhysicalExpr>> {
340 let Some(idx) =
342 self.async_exprs
343 .iter()
344 .enumerate()
345 .find_map(|(idx, async_expr)| {
346 if async_expr.func == Arc::clone(&expr) {
347 Some(idx)
348 } else {
349 None
350 }
351 })
352 else {
353 return Transformed::no(expr);
354 };
355 Transformed::yes(self.output_column(idx))
357 }
358
359 pub fn output_column(&self, idx: usize) -> Arc<dyn PhysicalExpr> {
361 let async_expr = &self.async_exprs[idx];
362 let output_idx = self.num_input_columns + idx;
363 Arc::new(Column::new(async_expr.name(), output_idx))
364 }
365}
366
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{RecordBatch, UInt32Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_execution::{TaskContext, config::SessionConfig};
    use futures::StreamExt;

    use crate::{ExecutionPlan, async_func::AsyncFuncExec, test::TestMemoryExec};

    /// Verifies that `AsyncFuncExec` coalesces small input batches up to the
    /// configured batch size before emitting them.
    #[tokio::test]
    async fn test_async_fn_with_coalescing() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));

        // One 6-row batch, repeated 50 times below: 300 rows total.
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6]))],
        )?;

        let batches: Vec<RecordBatch> = (0..50).map(|_| batch.clone()).collect();

        // Target batch size of 200 rows for the coalescer.
        let session_config = SessionConfig::new().with_batch_size(200);
        let task_ctx = TaskContext::default().with_session_config(session_config);
        let task_ctx = Arc::new(task_ctx);

        let test_exec =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
        // No async expressions: the node just passes data through, coalesced.
        let exec = AsyncFuncExec::try_new(vec![], test_exec)?;

        // 300 rows at batch size 200 => one full 200-row batch first ...
        let mut stream = exec.execute(0, Arc::clone(&task_ctx))?;
        let batch = stream
            .next()
            .await
            .expect("expected to get a record batch")?;
        assert_eq!(200, batch.num_rows());
        // ... followed by the remaining 100 rows.
        let batch = stream
            .next()
            .await
            .expect("expected to get a record batch")?;
        assert_eq!(100, batch.num_rows());

        Ok(())
    }
}