// datafusion_physical_plan/async_func.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
18use crate::coalesce::LimitedBatchCoalescer;
19use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
20use crate::stream::RecordBatchStreamAdapter;
21use crate::{
22    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
23    check_if_same_properties,
24};
25use arrow::array::RecordBatch;
26use arrow_schema::{Fields, Schema, SchemaRef};
27use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
28use datafusion_common::{Result, assert_eq_or_internal_err};
29use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
30use datafusion_physical_expr::ScalarFunctionExpr;
31use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr;
32use datafusion_physical_expr::equivalence::ProjectionMapping;
33use datafusion_physical_expr::expressions::Column;
34use datafusion_physical_expr_common::metrics::{BaselineMetrics, RecordOutput};
35use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
36use futures::Stream;
37use futures::stream::StreamExt;
38use log::trace;
39use std::any::Any;
40use std::pin::Pin;
41use std::sync::Arc;
42use std::task::{Context, Poll, ready};
43
/// This structure evaluates a set of async expressions on a record
/// batch producing a new record batch
///
/// The schema of the output of the AsyncFuncExec is:
/// Input columns followed by one column for each async expression
#[derive(Debug, Clone)]
pub struct AsyncFuncExec {
    /// The async expressions to evaluate; their results are appended, in
    /// order, after the input columns in the output batches
    async_exprs: Vec<Arc<AsyncFuncExpr>>,
    /// The child plan whose output is fed to the async expressions
    input: Arc<dyn ExecutionPlan>,
    /// Cached plan properties (schema, partitioning, equivalences, ...)
    cache: Arc<PlanProperties>,
    /// Execution metrics, one set per partition
    metrics: ExecutionPlanMetricsSet,
}
57
58impl AsyncFuncExec {
59    pub fn try_new(
60        async_exprs: Vec<Arc<AsyncFuncExpr>>,
61        input: Arc<dyn ExecutionPlan>,
62    ) -> Result<Self> {
63        let async_fields = async_exprs
64            .iter()
65            .map(|async_expr| async_expr.field(input.schema().as_ref()))
66            .collect::<Result<Vec<_>>>()?;
67
68        // compute the output schema: input schema then async expressions
69        let fields: Fields = input
70            .schema()
71            .fields()
72            .iter()
73            .cloned()
74            .chain(async_fields.into_iter().map(Arc::new))
75            .collect();
76
77        let schema = Arc::new(Schema::new(fields));
78        let tuples = async_exprs
79            .iter()
80            .map(|expr| (Arc::clone(&expr.func), expr.name().to_string()))
81            .collect::<Vec<_>>();
82        let async_expr_mapping = ProjectionMapping::try_new(tuples, &input.schema())?;
83        let cache =
84            AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?;
85        Ok(Self {
86            input,
87            async_exprs,
88            cache: Arc::new(cache),
89            metrics: ExecutionPlanMetricsSet::new(),
90        })
91    }
92
93    /// This function creates the cache object that stores the plan properties
94    /// such as schema, equivalence properties, ordering, partitioning, etc.
95    fn compute_properties(
96        input: &Arc<dyn ExecutionPlan>,
97        schema: SchemaRef,
98        async_expr_mapping: &ProjectionMapping,
99    ) -> Result<PlanProperties> {
100        Ok(PlanProperties::new(
101            input
102                .equivalence_properties()
103                .project(async_expr_mapping, schema),
104            input.output_partitioning().clone(),
105            input.pipeline_behavior(),
106            input.boundedness(),
107        ))
108    }
109
110    pub fn async_exprs(&self) -> &[Arc<AsyncFuncExpr>] {
111        &self.async_exprs
112    }
113
114    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
115        &self.input
116    }
117
118    fn with_new_children_and_same_properties(
119        &self,
120        mut children: Vec<Arc<dyn ExecutionPlan>>,
121    ) -> Self {
122        Self {
123            input: children.swap_remove(0),
124            metrics: ExecutionPlanMetricsSet::new(),
125            ..Self::clone(self)
126        }
127    }
128}
129
130impl DisplayAs for AsyncFuncExec {
131    fn fmt_as(
132        &self,
133        t: DisplayFormatType,
134        f: &mut std::fmt::Formatter,
135    ) -> std::fmt::Result {
136        let expr: Vec<String> = self
137            .async_exprs
138            .iter()
139            .map(|async_expr| async_expr.to_string())
140            .collect();
141        let exprs = expr.join(", ");
142        match t {
143            DisplayFormatType::Default | DisplayFormatType::Verbose => {
144                write!(f, "AsyncFuncExec: async_expr=[{exprs}]")
145            }
146            DisplayFormatType::TreeRender => {
147                writeln!(f, "format=async_expr")?;
148                writeln!(f, "async_expr={exprs}")?;
149                Ok(())
150            }
151        }
152    }
153}
154
155impl ExecutionPlan for AsyncFuncExec {
156    fn name(&self) -> &str {
157        "async_func"
158    }
159
160    fn as_any(&self) -> &dyn Any {
161        self
162    }
163
164    fn properties(&self) -> &Arc<PlanProperties> {
165        &self.cache
166    }
167
168    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
169        vec![&self.input]
170    }
171
172    fn with_new_children(
173        self: Arc<Self>,
174        mut children: Vec<Arc<dyn ExecutionPlan>>,
175    ) -> Result<Arc<dyn ExecutionPlan>> {
176        assert_eq_or_internal_err!(
177            children.len(),
178            1,
179            "AsyncFuncExec wrong number of children"
180        );
181        check_if_same_properties!(self, children);
182        Ok(Arc::new(AsyncFuncExec::try_new(
183            self.async_exprs.clone(),
184            children.swap_remove(0),
185        )?))
186    }
187
188    fn execute(
189        &self,
190        partition: usize,
191        context: Arc<TaskContext>,
192    ) -> Result<SendableRecordBatchStream> {
193        trace!(
194            "Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}",
195            partition,
196            context.session_id(),
197            context.task_id()
198        );
199
200        // first execute the input stream
201        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
202
203        // TODO: Track `elapsed_compute` in `BaselineMetrics`
204        // Issue: <https://github.com/apache/datafusion/issues/19658>
205        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
206
207        // now, for each record batch, evaluate the async expressions and add the columns to the result
208        let async_exprs_captured = Arc::new(self.async_exprs.clone());
209        let schema_captured = self.schema();
210        let config_options_ref = Arc::clone(context.session_config().options());
211
212        let coalesced_input_stream = CoalesceInputStream {
213            input_stream,
214            batch_coalescer: LimitedBatchCoalescer::new(
215                Arc::clone(&self.input.schema()),
216                config_options_ref.execution.batch_size,
217                None,
218            ),
219        };
220
221        let stream_with_async_functions = coalesced_input_stream.then(move |batch| {
222            // need to clone *again* to capture the async_exprs and schema in the
223            // stream and satisfy lifetime requirements.
224            let async_exprs_captured = Arc::clone(&async_exprs_captured);
225            let schema_captured = Arc::clone(&schema_captured);
226            let config_options = Arc::clone(&config_options_ref);
227            let baseline_metrics_captured = baseline_metrics.clone();
228
229            async move {
230                let batch = batch?;
231                // append the result of evaluating the async expressions to the output
232                let mut output_arrays = batch.columns().to_vec();
233                for async_expr in async_exprs_captured.iter() {
234                    let output = async_expr
235                        .invoke_with_args(&batch, Arc::clone(&config_options))
236                        .await?;
237                    output_arrays.push(output.to_array(batch.num_rows())?);
238                }
239                let batch = RecordBatch::try_new(schema_captured, output_arrays)?;
240
241                Ok(batch.record_output(&baseline_metrics_captured))
242            }
243        });
244
245        // Adapt the stream with the output schema
246        let adapter =
247            RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions);
248        Ok(Box::pin(adapter))
249    }
250
251    fn metrics(&self) -> Option<MetricsSet> {
252        Some(self.metrics.clone_inner())
253    }
254}
255
/// Stream adapter that buffers input batches and re-emits them coalesced
/// to the coalescer's configured target size.
struct CoalesceInputStream {
    /// The upstream record batch stream being coalesced
    input_stream: Pin<Box<dyn RecordBatchStream + Send>>,
    /// Accumulates rows until a target-sized batch is complete
    batch_coalescer: LimitedBatchCoalescer,
}
260
impl Stream for CoalesceInputStream {
    type Item = Result<RecordBatch>;

    /// Poll the inner stream, pushing its batches into the coalescer and
    /// yielding completed (target-sized) batches as they become available.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Set once the input stream is exhausted; after that we only drain
        // what the coalescer has buffered, then return `None`.
        let mut completed = false;

        loop {
            // Emit any batch the coalescer has already completed before
            // pulling more input.
            if let Some(batch) = self.batch_coalescer.next_completed_batch() {
                return Poll::Ready(Some(Ok(batch)));
            }

            if completed {
                return Poll::Ready(None);
            }

            match ready!(self.input_stream.poll_next_unpin(cx)) {
                Some(Ok(batch)) => {
                    if let Err(err) = self.batch_coalescer.push_batch(batch) {
                        return Poll::Ready(Some(Err(err)));
                    }
                }
                Some(err) => {
                    // Propagate input errors unchanged.
                    return Poll::Ready(Some(err));
                }
                None => {
                    completed = true;
                    // Flush any partially-filled batch so the next loop
                    // iteration can emit it via `next_completed_batch`.
                    if let Err(err) = self.batch_coalescer.finish() {
                        return Poll::Ready(Some(Err(err)));
                    }
                }
            }
        }
    }
}
298
/// Prefix for generated output column names of async expressions
/// (`__async_fn_0`, `__async_fn_1`, ...)
const ASYNC_FN_PREFIX: &str = "__async_fn_";
300
/// Maps async_expressions to new columns
///
/// The output of the async functions are appended, in order, to the end of the input schema
#[derive(Debug)]
pub struct AsyncMapper {
    /// the number of columns in the input plan,
    /// used to compute the index of each generated output column.
    /// the first async expr is `__async_fn_0`, the second is `__async_fn_1`, etc
    num_input_columns: usize,
    /// the expressions to map, in registration order
    pub async_exprs: Vec<Arc<AsyncFuncExpr>>,
}
313
314impl AsyncMapper {
315    pub fn new(num_input_columns: usize) -> Self {
316        Self {
317            num_input_columns,
318            async_exprs: Vec::new(),
319        }
320    }
321
322    pub fn is_empty(&self) -> bool {
323        self.async_exprs.is_empty()
324    }
325
326    pub fn next_column_name(&self) -> String {
327        format!("{}{}", ASYNC_FN_PREFIX, self.async_exprs.len())
328    }
329
330    /// Finds any references to async functions in the expression and adds them to the map
331    pub fn find_references(
332        &mut self,
333        physical_expr: &Arc<dyn PhysicalExpr>,
334        schema: &Schema,
335    ) -> Result<()> {
336        // recursively look for references to async functions
337        physical_expr.apply(|expr| {
338            if let Some(scalar_func_expr) =
339                expr.as_any().downcast_ref::<ScalarFunctionExpr>()
340                && scalar_func_expr.fun().as_async().is_some()
341            {
342                let next_name = self.next_column_name();
343                self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new(
344                    next_name,
345                    Arc::clone(expr),
346                    schema,
347                )?));
348            }
349            Ok(TreeNodeRecursion::Continue)
350        })?;
351        Ok(())
352    }
353
354    /// If the expression matches any of the async functions, return the new column
355    pub fn map_expr(
356        &self,
357        expr: Arc<dyn PhysicalExpr>,
358    ) -> Transformed<Arc<dyn PhysicalExpr>> {
359        // find the first matching async function if any
360        let Some(idx) =
361            self.async_exprs
362                .iter()
363                .enumerate()
364                .find_map(|(idx, async_expr)| {
365                    if async_expr.func == Arc::clone(&expr) {
366                        Some(idx)
367                    } else {
368                        None
369                    }
370                })
371        else {
372            return Transformed::no(expr);
373        };
374        // rewrite in terms of the output column
375        Transformed::yes(self.output_column(idx))
376    }
377
378    /// return the output column for the async function at index idx
379    pub fn output_column(&self, idx: usize) -> Arc<dyn PhysicalExpr> {
380        let async_expr = &self.async_exprs[idx];
381        let output_idx = self.num_input_columns + idx;
382        Arc::new(Column::new(async_expr.name(), output_idx))
383    }
384}
385
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{RecordBatch, UInt32Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_execution::{TaskContext, config::SessionConfig};
    use futures::StreamExt;

    use crate::{ExecutionPlan, async_func::AsyncFuncExec, test::TestMemoryExec};

    /// 50 input batches of 6 rows each (300 rows total) with a batch size
    /// of 200 should be coalesced into one 200-row batch followed by one
    /// 100-row batch.
    #[tokio::test]
    async fn test_async_fn_with_coalescing() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));

        let column = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
        let six_row_batch =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(column)])?;
        let input_batches: Vec<RecordBatch> =
            std::iter::repeat_n(six_row_batch, 50).collect();

        let config = SessionConfig::new().with_batch_size(200);
        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));

        let source =
            TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;
        // No async expressions: output batches are just the coalesced input.
        let exec = AsyncFuncExec::try_new(vec![], source)?;

        let mut stream = exec.execute(0, Arc::clone(&task_ctx))?;

        let first = stream
            .next()
            .await
            .expect("expected to get a record batch")?;
        assert_eq!(200, first.num_rows());

        let second = stream
            .next()
            .await
            .expect("expected to get a record batch")?;
        assert_eq!(100, second.num_rows());

        Ok(())
    }
}