Skip to main content

datafusion_physical_plan/
async_func.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::coalesce::LimitedBatchCoalescer;
19use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
20use crate::stream::RecordBatchStreamAdapter;
21use crate::{
22    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
23};
24use arrow::array::RecordBatch;
25use arrow_schema::{Fields, Schema, SchemaRef};
26use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
27use datafusion_common::{Result, assert_eq_or_internal_err};
28use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
29use datafusion_physical_expr::ScalarFunctionExpr;
30use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr;
31use datafusion_physical_expr::equivalence::ProjectionMapping;
32use datafusion_physical_expr::expressions::Column;
33use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
34use futures::Stream;
35use futures::stream::StreamExt;
36use log::trace;
37use std::any::Any;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::task::{Context, Poll, ready};
41
/// This structure evaluates a set of async expressions on a record
/// batch producing a new record batch
///
/// The schema of the output of the AsyncFuncExec is:
/// Input columns followed by one column for each async expression
#[derive(Debug)]
pub struct AsyncFuncExec {
    /// The async expressions to evaluate
    async_exprs: Vec<Arc<AsyncFuncExpr>>,
    /// The child plan producing the input batches
    input: Arc<dyn ExecutionPlan>,
    /// Cached plan properties (output schema, equivalences, partitioning, ...)
    cache: PlanProperties,
    /// Execution metrics for this operator
    metrics: ExecutionPlanMetricsSet,
}
55
56impl AsyncFuncExec {
57    pub fn try_new(
58        async_exprs: Vec<Arc<AsyncFuncExpr>>,
59        input: Arc<dyn ExecutionPlan>,
60    ) -> Result<Self> {
61        let async_fields = async_exprs
62            .iter()
63            .map(|async_expr| async_expr.field(input.schema().as_ref()))
64            .collect::<Result<Vec<_>>>()?;
65
66        // compute the output schema: input schema then async expressions
67        let fields: Fields = input
68            .schema()
69            .fields()
70            .iter()
71            .cloned()
72            .chain(async_fields.into_iter().map(Arc::new))
73            .collect();
74
75        let schema = Arc::new(Schema::new(fields));
76        let tuples = async_exprs
77            .iter()
78            .map(|expr| (Arc::clone(&expr.func), expr.name().to_string()))
79            .collect::<Vec<_>>();
80        let async_expr_mapping = ProjectionMapping::try_new(tuples, &input.schema())?;
81        let cache =
82            AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?;
83        Ok(Self {
84            input,
85            async_exprs,
86            cache,
87            metrics: ExecutionPlanMetricsSet::new(),
88        })
89    }
90
91    /// This function creates the cache object that stores the plan properties
92    /// such as schema, equivalence properties, ordering, partitioning, etc.
93    fn compute_properties(
94        input: &Arc<dyn ExecutionPlan>,
95        schema: SchemaRef,
96        async_expr_mapping: &ProjectionMapping,
97    ) -> Result<PlanProperties> {
98        Ok(PlanProperties::new(
99            input
100                .equivalence_properties()
101                .project(async_expr_mapping, schema),
102            input.output_partitioning().clone(),
103            input.pipeline_behavior(),
104            input.boundedness(),
105        ))
106    }
107
108    pub fn async_exprs(&self) -> &[Arc<AsyncFuncExpr>] {
109        &self.async_exprs
110    }
111
112    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
113        &self.input
114    }
115}
116
117impl DisplayAs for AsyncFuncExec {
118    fn fmt_as(
119        &self,
120        t: DisplayFormatType,
121        f: &mut std::fmt::Formatter,
122    ) -> std::fmt::Result {
123        let expr: Vec<String> = self
124            .async_exprs
125            .iter()
126            .map(|async_expr| async_expr.to_string())
127            .collect();
128        let exprs = expr.join(", ");
129        match t {
130            DisplayFormatType::Default | DisplayFormatType::Verbose => {
131                write!(f, "AsyncFuncExec: async_expr=[{exprs}]")
132            }
133            DisplayFormatType::TreeRender => {
134                writeln!(f, "format=async_expr")?;
135                writeln!(f, "async_expr={exprs}")?;
136                Ok(())
137            }
138        }
139    }
140}
141
142impl ExecutionPlan for AsyncFuncExec {
143    fn name(&self) -> &str {
144        "async_func"
145    }
146
147    fn as_any(&self) -> &dyn Any {
148        self
149    }
150
151    fn properties(&self) -> &PlanProperties {
152        &self.cache
153    }
154
155    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
156        vec![&self.input]
157    }
158
159    fn with_new_children(
160        self: Arc<Self>,
161        children: Vec<Arc<dyn ExecutionPlan>>,
162    ) -> Result<Arc<dyn ExecutionPlan>> {
163        assert_eq_or_internal_err!(
164            children.len(),
165            1,
166            "AsyncFuncExec wrong number of children"
167        );
168        Ok(Arc::new(AsyncFuncExec::try_new(
169            self.async_exprs.clone(),
170            Arc::clone(&children[0]),
171        )?))
172    }
173
174    fn execute(
175        &self,
176        partition: usize,
177        context: Arc<TaskContext>,
178    ) -> Result<SendableRecordBatchStream> {
179        trace!(
180            "Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}",
181            partition,
182            context.session_id(),
183            context.task_id()
184        );
185        // TODO figure out how to record metrics
186
187        // first execute the input stream
188        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
189
190        // now, for each record batch, evaluate the async expressions and add the columns to the result
191        let async_exprs_captured = Arc::new(self.async_exprs.clone());
192        let schema_captured = self.schema();
193        let config_options_ref = Arc::clone(context.session_config().options());
194
195        let coalesced_input_stream = CoalesceInputStream {
196            input_stream,
197            batch_coalescer: LimitedBatchCoalescer::new(
198                Arc::clone(&self.input.schema()),
199                config_options_ref.execution.batch_size,
200                None,
201            ),
202        };
203
204        let stream_with_async_functions = coalesced_input_stream.then(move |batch| {
205            // need to clone *again* to capture the async_exprs and schema in the
206            // stream and satisfy lifetime requirements.
207            let async_exprs_captured = Arc::clone(&async_exprs_captured);
208            let schema_captured = Arc::clone(&schema_captured);
209            let config_options = Arc::clone(&config_options_ref);
210
211            async move {
212                let batch = batch?;
213                // append the result of evaluating the async expressions to the output
214                let mut output_arrays = batch.columns().to_vec();
215                for async_expr in async_exprs_captured.iter() {
216                    let output = async_expr
217                        .invoke_with_args(&batch, Arc::clone(&config_options))
218                        .await?;
219                    output_arrays.push(output.to_array(batch.num_rows())?);
220                }
221                let batch = RecordBatch::try_new(schema_captured, output_arrays)?;
222                Ok(batch)
223            }
224        });
225
226        // Adapt the stream with the output schema
227        let adapter =
228            RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions);
229        Ok(Box::pin(adapter))
230    }
231
232    fn metrics(&self) -> Option<MetricsSet> {
233        Some(self.metrics.clone_inner())
234    }
235}
236
/// Stream adapter that buffers incoming record batches and re-emits them
/// coalesced to a target batch size (via [`LimitedBatchCoalescer`]).
struct CoalesceInputStream {
    /// The underlying stream of input record batches
    input_stream: Pin<Box<dyn RecordBatchStream + Send>>,
    /// Accumulates pushed batches and yields completed batches of the target size
    batch_coalescer: LimitedBatchCoalescer,
}
241
impl Stream for CoalesceInputStream {
    type Item = Result<RecordBatch>;

    /// Drive the coalescer: drain any completed batch first, otherwise pull
    /// from the input stream, pushing batches into the coalescer until one
    /// completes or the input is exhausted.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // set once the input stream ends and the coalescer has been finished;
        // after that, draining the remaining completed batches ends the stream
        let mut completed = false;

        loop {
            // emit a completed batch if one is ready
            if let Some(batch) = self.batch_coalescer.next_completed_batch() {
                return Poll::Ready(Some(Ok(batch)));
            }

            // input exhausted and nothing left buffered: end of stream
            if completed {
                return Poll::Ready(None);
            }

            // `ready!` returns Poll::Pending to the caller if the input is not ready
            match ready!(self.input_stream.poll_next_unpin(cx)) {
                Some(Ok(batch)) => {
                    if let Err(err) = self.batch_coalescer.push_batch(batch) {
                        return Poll::Ready(Some(Err(err)));
                    }
                }
                // propagate input errors (only Err reaches here; Ok matched above)
                Some(err) => {
                    return Poll::Ready(Some(err));
                }
                None => {
                    // input finished: flush any partial batch, then loop to drain
                    completed = true;
                    if let Err(err) = self.batch_coalescer.finish() {
                        return Poll::Ready(Some(Err(err)));
                    }
                }
            }
        }
    }
}
279
280const ASYNC_FN_PREFIX: &str = "__async_fn_";
281
/// Maps async_expressions to new columns
///
/// The output of the async functions are appended, in order, to the end of the input schema
#[derive(Debug)]
pub struct AsyncMapper {
    /// the number of columns in the input plan
    /// used to generate the output column names.
    /// the first async expr is `__async_fn_0`, the second is `__async_fn_1`, etc
    num_input_columns: usize,
    /// the expressions to map
    pub async_exprs: Vec<Arc<AsyncFuncExpr>>,
}
294
295impl AsyncMapper {
296    pub fn new(num_input_columns: usize) -> Self {
297        Self {
298            num_input_columns,
299            async_exprs: Vec::new(),
300        }
301    }
302
303    pub fn is_empty(&self) -> bool {
304        self.async_exprs.is_empty()
305    }
306
307    pub fn next_column_name(&self) -> String {
308        format!("{}{}", ASYNC_FN_PREFIX, self.async_exprs.len())
309    }
310
311    /// Finds any references to async functions in the expression and adds them to the map
312    pub fn find_references(
313        &mut self,
314        physical_expr: &Arc<dyn PhysicalExpr>,
315        schema: &Schema,
316    ) -> Result<()> {
317        // recursively look for references to async functions
318        physical_expr.apply(|expr| {
319            if let Some(scalar_func_expr) =
320                expr.as_any().downcast_ref::<ScalarFunctionExpr>()
321                && scalar_func_expr.fun().as_async().is_some()
322            {
323                let next_name = self.next_column_name();
324                self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new(
325                    next_name,
326                    Arc::clone(expr),
327                    schema,
328                )?));
329            }
330            Ok(TreeNodeRecursion::Continue)
331        })?;
332        Ok(())
333    }
334
335    /// If the expression matches any of the async functions, return the new column
336    pub fn map_expr(
337        &self,
338        expr: Arc<dyn PhysicalExpr>,
339    ) -> Transformed<Arc<dyn PhysicalExpr>> {
340        // find the first matching async function if any
341        let Some(idx) =
342            self.async_exprs
343                .iter()
344                .enumerate()
345                .find_map(|(idx, async_expr)| {
346                    if async_expr.func == Arc::clone(&expr) {
347                        Some(idx)
348                    } else {
349                        None
350                    }
351                })
352        else {
353            return Transformed::no(expr);
354        };
355        // rewrite in terms of the output column
356        Transformed::yes(self.output_column(idx))
357    }
358
359    /// return the output column for the async function at index idx
360    pub fn output_column(&self, idx: usize) -> Arc<dyn PhysicalExpr> {
361        let async_expr = &self.async_exprs[idx];
362        let output_idx = self.num_input_columns + idx;
363        Arc::new(Column::new(async_expr.name(), output_idx))
364    }
365}
366
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{RecordBatch, UInt32Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_execution::{TaskContext, config::SessionConfig};
    use futures::StreamExt;

    use crate::{ExecutionPlan, async_func::AsyncFuncExec, test::TestMemoryExec};

    /// 50 input batches of 6 rows each (300 rows total) with a batch size of
    /// 200 should coalesce into one 200-row batch followed by a 100-row batch.
    #[tokio::test]
    async fn test_async_fn_with_coalescing() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));

        let small_batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6]))],
        )?;
        let input_batches = vec![small_batch; 50];

        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(200)),
        );

        let source =
            TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;
        // no async expressions: output schema equals input schema, but the
        // coalescing path is still exercised
        let exec = AsyncFuncExec::try_new(vec![], source)?;

        let mut stream = exec.execute(0, Arc::clone(&task_ctx))?;

        let first = stream
            .next()
            .await
            .expect("expected to get a record batch")?;
        assert_eq!(200, first.num_rows());

        let second = stream
            .next()
            .await
            .expect("expected to get a record batch")?;
        assert_eq!(100, second.num_rows());

        Ok(())
    }
}
413}