datafusion_physical_expr/
async_scalar_function.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::ScalarFunctionExpr;
19use arrow::array::{make_array, MutableArrayData, RecordBatch};
20use arrow::datatypes::{DataType, Field, FieldRef, Schema};
21use datafusion_common::config::ConfigOptions;
22use datafusion_common::Result;
23use datafusion_common::{internal_err, not_impl_err};
24use datafusion_expr::async_udf::AsyncScalarUDF;
25use datafusion_expr::ScalarFunctionArgs;
26use datafusion_expr_common::columnar_value::ColumnarValue;
27use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
28use std::any::Any;
29use std::fmt::Display;
30use std::hash::{Hash, Hasher};
31use std::sync::Arc;
32
/// Wrapper around a scalar function that can be evaluated asynchronously.
///
/// Instances are created with `AsyncFuncExpr::try_new`, which rejects any
/// `func` that is not a `ScalarFunctionExpr` and resolves the return field
/// against the schema it is given.
#[derive(Debug, Clone, Eq)]
pub struct AsyncFuncExpr {
    /// The name of the output column this function will generate
    pub name: String,
    /// The actual function (always `ScalarFunctionExpr`)
    pub func: Arc<dyn PhysicalExpr>,
    /// The field that this function will return.
    ///
    /// Computed once in `try_new` and handed to the UDF as the `return_field`
    /// of every invocation.
    return_field: FieldRef,
}
43
44impl Display for AsyncFuncExpr {
45    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
46        write!(f, "async_expr(name={}, expr={})", self.name, self.func)
47    }
48}
49
50impl PartialEq for AsyncFuncExpr {
51    fn eq(&self, other: &Self) -> bool {
52        self.name == other.name && self.func == Arc::clone(&other.func)
53    }
54}
55
56impl Hash for AsyncFuncExpr {
57    fn hash<H: Hasher>(&self, state: &mut H) {
58        self.name.hash(state);
59        self.func.as_ref().hash(state);
60    }
61}
62
63impl AsyncFuncExpr {
64    /// create a new AsyncFuncExpr
65    pub fn try_new(
66        name: impl Into<String>,
67        func: Arc<dyn PhysicalExpr>,
68        schema: &Schema,
69    ) -> Result<Self> {
70        let Some(_) = func.as_any().downcast_ref::<ScalarFunctionExpr>() else {
71            return internal_err!(
72                "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
73                func
74            );
75        };
76
77        let return_field = func.return_field(schema)?;
78        Ok(Self {
79            name: name.into(),
80            func,
81            return_field,
82        })
83    }
84
85    /// return the name of the output column
86    pub fn name(&self) -> &str {
87        &self.name
88    }
89
90    /// Return the output field generated by evaluating this function
91    pub fn field(&self, input_schema: &Schema) -> Result<Field> {
92        Ok(Field::new(
93            &self.name,
94            self.func.data_type(input_schema)?,
95            self.func.nullable(input_schema)?,
96        ))
97    }
98
99    /// Return the ideal batch size for this function
100    pub fn ideal_batch_size(&self) -> Result<Option<usize>> {
101        if let Some(expr) = self.func.as_any().downcast_ref::<ScalarFunctionExpr>() {
102            if let Some(udf) =
103                expr.fun().inner().as_any().downcast_ref::<AsyncScalarUDF>()
104            {
105                return Ok(udf.ideal_batch_size());
106            }
107        }
108        not_impl_err!("Can't get ideal_batch_size from {:?}", self.func)
109    }
110
111    /// This (async) function is called for each record batch to evaluate the LLM expressions
112    ///
113    /// The output is the output of evaluating the async expression and the input record batch
114    pub async fn invoke_with_args(
115        &self,
116        batch: &RecordBatch,
117        config_options: Arc<ConfigOptions>,
118    ) -> Result<ColumnarValue> {
119        let Some(scalar_function_expr) =
120            self.func.as_any().downcast_ref::<ScalarFunctionExpr>()
121        else {
122            return internal_err!(
123                "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
124                self.func
125            );
126        };
127
128        let Some(async_udf) = scalar_function_expr
129            .fun()
130            .inner()
131            .as_any()
132            .downcast_ref::<AsyncScalarUDF>()
133        else {
134            return not_impl_err!(
135                "Don't know how to evaluate async function: {:?}",
136                scalar_function_expr
137            );
138        };
139
140        let arg_fields = scalar_function_expr
141            .args()
142            .iter()
143            .map(|e| e.return_field(batch.schema_ref()))
144            .collect::<Result<Vec<_>>>()?;
145
146        let mut result_batches = vec![];
147        if let Some(ideal_batch_size) = self.ideal_batch_size()? {
148            let mut remainder = batch.clone();
149            while remainder.num_rows() > 0 {
150                let size = if ideal_batch_size > remainder.num_rows() {
151                    remainder.num_rows()
152                } else {
153                    ideal_batch_size
154                };
155
156                let current_batch = remainder.slice(0, size); // get next 10 rows
157                remainder = remainder.slice(size, remainder.num_rows() - size);
158                let args = scalar_function_expr
159                    .args()
160                    .iter()
161                    .map(|e| e.evaluate(&current_batch))
162                    .collect::<Result<Vec<_>>>()?;
163                result_batches.push(
164                    async_udf
165                        .invoke_async_with_args(ScalarFunctionArgs {
166                            args,
167                            arg_fields: arg_fields.clone(),
168                            number_rows: current_batch.num_rows(),
169                            return_field: Arc::clone(&self.return_field),
170                            config_options: Arc::clone(&config_options),
171                        })
172                        .await?,
173                );
174            }
175        } else {
176            let args = scalar_function_expr
177                .args()
178                .iter()
179                .map(|e| e.evaluate(batch))
180                .collect::<Result<Vec<_>>>()?;
181
182            result_batches.push(
183                async_udf
184                    .invoke_async_with_args(ScalarFunctionArgs {
185                        args: args.to_vec(),
186                        arg_fields,
187                        number_rows: batch.num_rows(),
188                        return_field: Arc::clone(&self.return_field),
189                        config_options: Arc::clone(&config_options),
190                    })
191                    .await?,
192            );
193        }
194
195        let datas = ColumnarValue::values_to_arrays(&result_batches)?
196            .iter()
197            .map(|b| b.to_data())
198            .collect::<Vec<_>>();
199        let total_len = datas.iter().map(|d| d.len()).sum();
200        let mut mutable = MutableArrayData::new(datas.iter().collect(), false, total_len);
201        datas.iter().enumerate().for_each(|(i, data)| {
202            mutable.extend(i, 0, data.len());
203        });
204        let array_ref = make_array(mutable.freeze());
205        Ok(ColumnarValue::Array(array_ref))
206    }
207}
208
209impl PhysicalExpr for AsyncFuncExpr {
210    fn as_any(&self) -> &dyn Any {
211        self
212    }
213
214    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
215        self.func.data_type(input_schema)
216    }
217
218    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
219        self.func.nullable(input_schema)
220    }
221
222    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
223        // TODO: implement this for scalar value input
224        not_impl_err!("AsyncFuncExpr.evaluate")
225    }
226
227    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
228        self.func.children()
229    }
230
231    fn with_new_children(
232        self: Arc<Self>,
233        children: Vec<Arc<dyn PhysicalExpr>>,
234    ) -> Result<Arc<dyn PhysicalExpr>> {
235        let new_func = Arc::clone(&self.func).with_new_children(children)?;
236        Ok(Arc::new(AsyncFuncExpr {
237            name: self.name.clone(),
238            func: new_func,
239            return_field: Arc::clone(&self.return_field),
240        }))
241    }
242
243    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        write!(f, "{}", self.func)
245    }
246}