Skip to main content

datafusion_physical_plan/
async_func.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::coalesce::LimitedBatchCoalescer;
19use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
20use crate::stream::{EmptyRecordBatchStream, RecordBatchStreamAdapter};
21use crate::{
22    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
23    check_if_same_properties,
24};
25use arrow::array::RecordBatch;
26use arrow_schema::{Fields, Schema, SchemaRef};
27use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
28use datafusion_common::{Result, assert_eq_or_internal_err};
29use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
30use datafusion_physical_expr::ScalarFunctionExpr;
31use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr;
32use datafusion_physical_expr::equivalence::ProjectionMapping;
33use datafusion_physical_expr::expressions::Column;
34use datafusion_physical_expr_common::metrics::{BaselineMetrics, RecordOutput};
35use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
36use futures::Stream;
37use futures::stream::StreamExt;
38use log::trace;
39use std::pin::Pin;
40use std::sync::Arc;
41use std::task::{Context, Poll, ready};
42
43/// This structure evaluates a set of async expressions on a record
44/// batch producing a new record batch
45///
46/// The schema of the output of the AsyncFuncExec is:
47/// Input columns followed by one column for each async expression
48#[derive(Debug, Clone)]
49pub struct AsyncFuncExec {
50    /// The async expressions to evaluate
51    async_exprs: Vec<Arc<AsyncFuncExpr>>,
52    input: Arc<dyn ExecutionPlan>,
53    cache: Arc<PlanProperties>,
54    metrics: ExecutionPlanMetricsSet,
55}
56
57impl AsyncFuncExec {
58    pub fn try_new(
59        async_exprs: Vec<Arc<AsyncFuncExpr>>,
60        input: Arc<dyn ExecutionPlan>,
61    ) -> Result<Self> {
62        let async_fields = async_exprs
63            .iter()
64            .map(|async_expr| async_expr.field(input.schema().as_ref()))
65            .collect::<Result<Vec<_>>>()?;
66
67        // compute the output schema: input schema then async expressions
68        let fields: Fields = input
69            .schema()
70            .fields()
71            .iter()
72            .cloned()
73            .chain(async_fields.into_iter().map(Arc::new))
74            .collect();
75
76        let schema = Arc::new(Schema::new(fields));
77        let tuples = async_exprs
78            .iter()
79            .map(|expr| (Arc::clone(&expr.func), expr.name().to_string()))
80            .collect::<Vec<_>>();
81        let async_expr_mapping = ProjectionMapping::try_new(tuples, &input.schema())?;
82        let cache =
83            AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?;
84        Ok(Self {
85            input,
86            async_exprs,
87            cache: Arc::new(cache),
88            metrics: ExecutionPlanMetricsSet::new(),
89        })
90    }
91
92    /// This function creates the cache object that stores the plan properties
93    /// such as schema, equivalence properties, ordering, partitioning, etc.
94    fn compute_properties(
95        input: &Arc<dyn ExecutionPlan>,
96        schema: SchemaRef,
97        async_expr_mapping: &ProjectionMapping,
98    ) -> Result<PlanProperties> {
99        Ok(PlanProperties::new(
100            input
101                .equivalence_properties()
102                .project(async_expr_mapping, schema),
103            input.output_partitioning().clone(),
104            input.pipeline_behavior(),
105            input.boundedness(),
106        ))
107    }
108
109    pub fn async_exprs(&self) -> &[Arc<AsyncFuncExpr>] {
110        &self.async_exprs
111    }
112
113    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
114        &self.input
115    }
116
117    fn with_new_children_and_same_properties(
118        &self,
119        mut children: Vec<Arc<dyn ExecutionPlan>>,
120    ) -> Self {
121        Self {
122            input: children.swap_remove(0),
123            metrics: ExecutionPlanMetricsSet::new(),
124            ..Self::clone(self)
125        }
126    }
127}
128
129impl DisplayAs for AsyncFuncExec {
130    fn fmt_as(
131        &self,
132        t: DisplayFormatType,
133        f: &mut std::fmt::Formatter,
134    ) -> std::fmt::Result {
135        let expr: Vec<String> = self
136            .async_exprs
137            .iter()
138            .map(|async_expr| async_expr.to_string())
139            .collect();
140        let exprs = expr.join(", ");
141        match t {
142            DisplayFormatType::Default | DisplayFormatType::Verbose => {
143                write!(f, "AsyncFuncExec: async_expr=[{exprs}]")
144            }
145            DisplayFormatType::TreeRender => {
146                writeln!(f, "format=async_expr")?;
147                writeln!(f, "async_expr={exprs}")?;
148                Ok(())
149            }
150        }
151    }
152}
153
154impl ExecutionPlan for AsyncFuncExec {
155    fn name(&self) -> &str {
156        "async_func"
157    }
158
159    fn properties(&self) -> &Arc<PlanProperties> {
160        &self.cache
161    }
162
163    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
164        vec![&self.input]
165    }
166
167    fn with_new_children(
168        self: Arc<Self>,
169        mut children: Vec<Arc<dyn ExecutionPlan>>,
170    ) -> Result<Arc<dyn ExecutionPlan>> {
171        assert_eq_or_internal_err!(
172            children.len(),
173            1,
174            "AsyncFuncExec wrong number of children"
175        );
176        check_if_same_properties!(self, children);
177        Ok(Arc::new(AsyncFuncExec::try_new(
178            self.async_exprs.clone(),
179            children.swap_remove(0),
180        )?))
181    }
182
183    fn execute(
184        &self,
185        partition: usize,
186        context: Arc<TaskContext>,
187    ) -> Result<SendableRecordBatchStream> {
188        trace!(
189            "Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}",
190            partition,
191            context.session_id(),
192            context.task_id()
193        );
194
195        // first execute the input stream
196        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
197
198        // TODO: Track `elapsed_compute` in `BaselineMetrics`
199        // Issue: <https://github.com/apache/datafusion/issues/19658>
200        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
201
202        // now, for each record batch, evaluate the async expressions and add the columns to the result
203        let async_exprs_captured = Arc::new(self.async_exprs.clone());
204        let schema_captured = self.schema();
205        let config_options_ref = Arc::clone(context.session_config().options());
206
207        let coalesced_input_stream = CoalesceInputStream {
208            input_stream,
209            batch_coalescer: LimitedBatchCoalescer::new(
210                Arc::clone(&self.input.schema()),
211                config_options_ref.execution.batch_size,
212                None,
213            ),
214        };
215
216        let stream_with_async_functions = coalesced_input_stream.then(move |batch| {
217            // need to clone *again* to capture the async_exprs and schema in the
218            // stream and satisfy lifetime requirements.
219            let async_exprs_captured = Arc::clone(&async_exprs_captured);
220            let schema_captured = Arc::clone(&schema_captured);
221            let config_options = Arc::clone(&config_options_ref);
222            let baseline_metrics_captured = baseline_metrics.clone();
223
224            async move {
225                let batch = batch?;
226                // append the result of evaluating the async expressions to the output
227                let mut output_arrays = batch.columns().to_vec();
228                for async_expr in async_exprs_captured.iter() {
229                    let output = async_expr
230                        .invoke_with_args(&batch, Arc::clone(&config_options))
231                        .await?;
232                    output_arrays.push(output.to_array(batch.num_rows())?);
233                }
234                let batch = RecordBatch::try_new(schema_captured, output_arrays)?;
235
236                Ok(batch.record_output(&baseline_metrics_captured))
237            }
238        });
239
240        // Adapt the stream with the output schema
241        let adapter =
242            RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions);
243        Ok(Box::pin(adapter))
244    }
245
246    fn metrics(&self) -> Option<MetricsSet> {
247        Some(self.metrics.clone_inner())
248    }
249}
250
251struct CoalesceInputStream {
252    input_stream: Pin<Box<dyn RecordBatchStream + Send>>,
253    batch_coalescer: LimitedBatchCoalescer,
254}
255
256impl Stream for CoalesceInputStream {
257    type Item = Result<RecordBatch>;
258
259    fn poll_next(
260        mut self: Pin<&mut Self>,
261        cx: &mut Context<'_>,
262    ) -> Poll<Option<Self::Item>> {
263        let mut completed = false;
264
265        loop {
266            if let Some(batch) = self.batch_coalescer.next_completed_batch() {
267                return Poll::Ready(Some(Ok(batch)));
268            }
269
270            if completed {
271                return Poll::Ready(None);
272            }
273
274            match ready!(self.input_stream.poll_next_unpin(cx)) {
275                Some(Ok(batch)) => {
276                    if let Err(err) = self.batch_coalescer.push_batch(batch) {
277                        return Poll::Ready(Some(Err(err)));
278                    }
279                }
280                Some(err) => {
281                    return Poll::Ready(Some(err));
282                }
283                None => {
284                    completed = true;
285                    // Release the input pipeline's resources.
286                    let input_schema = self.input_stream.schema();
287                    self.input_stream =
288                        Box::pin(EmptyRecordBatchStream::new(input_schema));
289                    if let Err(err) = self.batch_coalescer.finish() {
290                        return Poll::Ready(Some(Err(err)));
291                    }
292                }
293            }
294        }
295    }
296}
297
/// Prefix of the generated output column names for async expressions; the
/// index of the expression is appended to it (e.g. `__async_fn_0`).
const ASYNC_FN_PREFIX: &str = "__async_fn_";
299
300/// Maps async_expressions to new columns
301///
302/// The output of the async functions are appended, in order, to the end of the input schema
303#[derive(Debug)]
304pub struct AsyncMapper {
305    /// the number of columns in the input plan
306    /// used to generate the output column names.
307    /// the first async expr is `__async_fn_0`, the second is `__async_fn_1`, etc
308    num_input_columns: usize,
309    /// the expressions to map
310    pub async_exprs: Vec<Arc<AsyncFuncExpr>>,
311}
312
313impl AsyncMapper {
314    pub fn new(num_input_columns: usize) -> Self {
315        Self {
316            num_input_columns,
317            async_exprs: Vec::new(),
318        }
319    }
320
321    pub fn is_empty(&self) -> bool {
322        self.async_exprs.is_empty()
323    }
324
325    pub fn next_column_name(&self) -> String {
326        format!("{}{}", ASYNC_FN_PREFIX, self.async_exprs.len())
327    }
328
329    /// Finds any references to async functions in the expression and adds them to the map
330    pub fn find_references(
331        &mut self,
332        physical_expr: &Arc<dyn PhysicalExpr>,
333        schema: &Schema,
334    ) -> Result<()> {
335        // recursively look for references to async functions
336        physical_expr.apply(|expr| {
337            if let Some(scalar_func_expr) = expr.downcast_ref::<ScalarFunctionExpr>()
338                && scalar_func_expr.fun().as_async().is_some()
339            {
340                let next_name = self.next_column_name();
341                self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new(
342                    next_name,
343                    Arc::clone(expr),
344                    schema,
345                )?));
346            }
347            Ok(TreeNodeRecursion::Continue)
348        })?;
349        Ok(())
350    }
351
352    /// If the expression matches any of the async functions, return the new column
353    pub fn map_expr(
354        &self,
355        expr: Arc<dyn PhysicalExpr>,
356    ) -> Transformed<Arc<dyn PhysicalExpr>> {
357        // find the first matching async function if any
358        let Some(idx) =
359            self.async_exprs
360                .iter()
361                .enumerate()
362                .find_map(|(idx, async_expr)| {
363                    if async_expr.func == Arc::clone(&expr) {
364                        Some(idx)
365                    } else {
366                        None
367                    }
368                })
369        else {
370            return Transformed::no(expr);
371        };
372        // rewrite in terms of the output column
373        Transformed::yes(self.output_column(idx))
374    }
375
376    /// return the output column for the async function at index idx
377    pub fn output_column(&self, idx: usize) -> Arc<dyn PhysicalExpr> {
378        let async_expr = &self.async_exprs[idx];
379        let output_idx = self.num_input_columns + idx;
380        Arc::new(Column::new(async_expr.name(), output_idx))
381    }
382}
383
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{RecordBatch, UInt32Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_execution::{TaskContext, config::SessionConfig};
    use futures::StreamExt;

    use crate::{ExecutionPlan, async_func::AsyncFuncExec, test::TestMemoryExec};

    /// Feeds 50 batches of 6 rows (300 rows in total) through an
    /// `AsyncFuncExec` with a batch size of 200 and checks that the output
    /// arrives as a 200-row batch followed by a 100-row batch.
    #[tokio::test]
    async fn test_async_fn_with_coalescing() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));

        let six_rows = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6]))],
        )?;
        let batches: Vec<RecordBatch> = vec![six_rows; 50];

        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(200)),
        );

        let source =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
        // No async expressions: only the coalescing behavior is exercised.
        let exec = AsyncFuncExec::try_new(vec![], source)?;

        let mut stream = exec.execute(0, Arc::clone(&task_ctx))?;
        for expected_rows in [200, 100] {
            let batch = stream
                .next()
                .await
                .expect("expected to get a record batch")?;
            assert_eq!(expected_rows, batch.num_rows());
        }

        Ok(())
    }
}