datafusion_physical_plan/
async_func.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
19use crate::stream::RecordBatchStreamAdapter;
20use crate::{
21    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
22};
23use arrow::array::RecordBatch;
24use arrow_schema::{Fields, Schema, SchemaRef};
25use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
26use datafusion_common::{internal_err, Result};
27use datafusion_execution::{SendableRecordBatchStream, TaskContext};
28use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr;
29use datafusion_physical_expr::equivalence::ProjectionMapping;
30use datafusion_physical_expr::expressions::Column;
31use datafusion_physical_expr::ScalarFunctionExpr;
32use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
33use futures::stream::StreamExt;
34use log::trace;
35use std::any::Any;
36use std::sync::Arc;
37
38/// This structure evaluates a set of async expressions on a record
39/// batch producing a new record batch
40///
41/// The schema of the output of the AsyncFuncExec is:
42/// Input columns followed by one column for each async expression
43#[derive(Debug)]
44pub struct AsyncFuncExec {
45    /// The async expressions to evaluate
46    async_exprs: Vec<Arc<AsyncFuncExpr>>,
47    input: Arc<dyn ExecutionPlan>,
48    cache: PlanProperties,
49    metrics: ExecutionPlanMetricsSet,
50}
51
52impl AsyncFuncExec {
53    pub fn try_new(
54        async_exprs: Vec<Arc<AsyncFuncExpr>>,
55        input: Arc<dyn ExecutionPlan>,
56    ) -> Result<Self> {
57        let async_fields = async_exprs
58            .iter()
59            .map(|async_expr| async_expr.field(input.schema().as_ref()))
60            .collect::<Result<Vec<_>>>()?;
61
62        // compute the output schema: input schema then async expressions
63        let fields: Fields = input
64            .schema()
65            .fields()
66            .iter()
67            .cloned()
68            .chain(async_fields.into_iter().map(Arc::new))
69            .collect();
70
71        let schema = Arc::new(Schema::new(fields));
72        let tuples = async_exprs
73            .iter()
74            .map(|expr| (Arc::clone(&expr.func), expr.name().to_string()))
75            .collect::<Vec<_>>();
76        let async_expr_mapping = ProjectionMapping::try_new(tuples, &input.schema())?;
77        let cache =
78            AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?;
79        Ok(Self {
80            input,
81            async_exprs,
82            cache,
83            metrics: ExecutionPlanMetricsSet::new(),
84        })
85    }
86
87    /// This function creates the cache object that stores the plan properties
88    /// such as schema, equivalence properties, ordering, partitioning, etc.
89    fn compute_properties(
90        input: &Arc<dyn ExecutionPlan>,
91        schema: SchemaRef,
92        async_expr_mapping: &ProjectionMapping,
93    ) -> Result<PlanProperties> {
94        Ok(PlanProperties::new(
95            input
96                .equivalence_properties()
97                .project(async_expr_mapping, schema),
98            input.output_partitioning().clone(),
99            input.pipeline_behavior(),
100            input.boundedness(),
101        ))
102    }
103}
104
105impl DisplayAs for AsyncFuncExec {
106    fn fmt_as(
107        &self,
108        t: DisplayFormatType,
109        f: &mut std::fmt::Formatter,
110    ) -> std::fmt::Result {
111        let expr: Vec<String> = self
112            .async_exprs
113            .iter()
114            .map(|async_expr| async_expr.to_string())
115            .collect();
116        let exprs = expr.join(", ");
117        match t {
118            DisplayFormatType::Default | DisplayFormatType::Verbose => {
119                write!(f, "AsyncFuncExec: async_expr=[{exprs}]")
120            }
121            DisplayFormatType::TreeRender => {
122                writeln!(f, "format=async_expr")?;
123                writeln!(f, "async_expr={exprs}")?;
124                Ok(())
125            }
126        }
127    }
128}
129
130impl ExecutionPlan for AsyncFuncExec {
131    fn name(&self) -> &str {
132        "async_func"
133    }
134
135    fn as_any(&self) -> &dyn Any {
136        self
137    }
138
139    fn properties(&self) -> &PlanProperties {
140        &self.cache
141    }
142
143    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
144        vec![&self.input]
145    }
146
147    fn with_new_children(
148        self: Arc<Self>,
149        children: Vec<Arc<dyn ExecutionPlan>>,
150    ) -> Result<Arc<dyn ExecutionPlan>> {
151        if children.len() != 1 {
152            return internal_err!("AsyncFuncExec wrong number of children");
153        }
154        Ok(Arc::new(AsyncFuncExec::try_new(
155            self.async_exprs.clone(),
156            Arc::clone(&children[0]),
157        )?))
158    }
159
160    fn execute(
161        &self,
162        partition: usize,
163        context: Arc<TaskContext>,
164    ) -> Result<SendableRecordBatchStream> {
165        trace!(
166            "Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}",
167            partition,
168            context.session_id(),
169            context.task_id()
170        );
171        // TODO figure out how to record metrics
172
173        // first execute the input stream
174        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
175
176        // now, for each record batch, evaluate the async expressions and add the columns to the result
177        let async_exprs_captured = Arc::new(self.async_exprs.clone());
178        let schema_captured = self.schema();
179        let config_options_ref = Arc::clone(context.session_config().options());
180
181        let stream_with_async_functions = input_stream.then(move |batch| {
182            // need to clone *again* to capture the async_exprs and schema in the
183            // stream and satisfy lifetime requirements.
184            let async_exprs_captured = Arc::clone(&async_exprs_captured);
185            let schema_captured = Arc::clone(&schema_captured);
186            let config_options = Arc::clone(&config_options_ref);
187
188            async move {
189                let batch = batch?;
190                // append the result of evaluating the async expressions to the output
191                let mut output_arrays = batch.columns().to_vec();
192                for async_expr in async_exprs_captured.iter() {
193                    let output = async_expr
194                        .invoke_with_args(&batch, Arc::clone(&config_options))
195                        .await?;
196                    output_arrays.push(output.to_array(batch.num_rows())?);
197                }
198                let batch = RecordBatch::try_new(schema_captured, output_arrays)?;
199                Ok(batch)
200            }
201        });
202
203        // Adapt the stream with the output schema
204        let adapter =
205            RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions);
206        Ok(Box::pin(adapter))
207    }
208
209    fn metrics(&self) -> Option<MetricsSet> {
210        Some(self.metrics.clone_inner())
211    }
212}
213
214const ASYNC_FN_PREFIX: &str = "__async_fn_";
215
216/// Maps async_expressions to new columns
217///
218/// The output of the async functions are appended, in order, to the end of the input schema
219#[derive(Debug)]
220pub struct AsyncMapper {
221    /// the number of columns in the input plan
222    /// used to generate the output column names.
223    /// the first async expr is `__async_fn_0`, the second is `__async_fn_1`, etc
224    num_input_columns: usize,
225    /// the expressions to map
226    pub async_exprs: Vec<Arc<AsyncFuncExpr>>,
227}
228
229impl AsyncMapper {
230    pub fn new(num_input_columns: usize) -> Self {
231        Self {
232            num_input_columns,
233            async_exprs: Vec::new(),
234        }
235    }
236
237    pub fn is_empty(&self) -> bool {
238        self.async_exprs.is_empty()
239    }
240
241    pub fn next_column_name(&self) -> String {
242        format!("{}{}", ASYNC_FN_PREFIX, self.async_exprs.len())
243    }
244
245    /// Finds any references to async functions in the expression and adds them to the map
246    pub fn find_references(
247        &mut self,
248        physical_expr: &Arc<dyn PhysicalExpr>,
249        schema: &Schema,
250    ) -> Result<()> {
251        // recursively look for references to async functions
252        physical_expr.apply(|expr| {
253            if let Some(scalar_func_expr) =
254                expr.as_any().downcast_ref::<ScalarFunctionExpr>()
255            {
256                if scalar_func_expr.fun().as_async().is_some() {
257                    let next_name = self.next_column_name();
258                    self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new(
259                        next_name,
260                        Arc::clone(expr),
261                        schema,
262                    )?));
263                }
264            }
265            Ok(TreeNodeRecursion::Continue)
266        })?;
267        Ok(())
268    }
269
270    /// If the expression matches any of the async functions, return the new column
271    pub fn map_expr(
272        &self,
273        expr: Arc<dyn PhysicalExpr>,
274    ) -> Transformed<Arc<dyn PhysicalExpr>> {
275        // find the first matching async function if any
276        let Some(idx) =
277            self.async_exprs
278                .iter()
279                .enumerate()
280                .find_map(|(idx, async_expr)| {
281                    if async_expr.func == Arc::clone(&expr) {
282                        Some(idx)
283                    } else {
284                        None
285                    }
286                })
287        else {
288            return Transformed::no(expr);
289        };
290        // rewrite in terms of the output column
291        Transformed::yes(self.output_column(idx))
292    }
293
294    /// return the output column for the async function at index idx
295    pub fn output_column(&self, idx: usize) -> Arc<dyn PhysicalExpr> {
296        let async_expr = &self.async_exprs[idx];
297        let output_idx = self.num_input_columns + idx;
298        Arc::new(Column::new(async_expr.name(), output_idx))
299    }
300}