// datafusion_physical_expr/async_scalar_function.rs

use crate::ScalarFunctionExpr;
use arrow::array::RecordBatch;
use arrow::compute::concat;
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
use datafusion_common::Result;
use datafusion_common::config::ConfigOptions;
use datafusion_common::{internal_err, not_impl_err};
use datafusion_expr::ScalarFunctionArgs;
use datafusion_expr::async_udf::AsyncScalarUDF;
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::any::Any;
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

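/// A [`PhysicalExpr`] wrapping a scalar function backed by an
/// [`AsyncScalarUDF`], so it can be evaluated asynchronously via
/// [`AsyncFuncExpr::invoke_with_args`].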
#[derive(Debug, Clone, Eq)]
pub struct AsyncFuncExpr {
    /// The name of the output column produced by this expression
    pub name: String,
    /// The wrapped function; must be a `ScalarFunctionExpr` (enforced by `try_new`)
    pub func: Arc<dyn PhysicalExpr>,
    /// The field (name, data type, nullability) returned by this expression
    return_field: FieldRef,
}

impl Display for AsyncFuncExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "async_expr(name={}, expr={})", self.name, self.func)
    }
}

impl PartialEq for AsyncFuncExpr {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name && self.func == other.func
    }
}

impl Hash for AsyncFuncExpr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.name.hash(state);
        self.func.as_ref().hash(state);
    }
}

impl AsyncFuncExpr {
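    /// Create a new `AsyncFuncExpr`.
    ///
    /// Returns an error if `func` is not a `ScalarFunctionExpr`.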
    pub fn try_new(
        name: impl Into<String>,
        func: Arc<dyn PhysicalExpr>,
        schema: &Schema,
    ) -> Result<Self> {
        let Some(_) = func.as_any().downcast_ref::<ScalarFunctionExpr>() else {
            return internal_err!(
                "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
                func
            );
        };

        let return_field = func.return_field(schema)?;
        Ok(Self {
            name: name.into(),
            func,
            return_field,
        })
    }

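    /// Return the name of the output column produced by this expression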
    pub fn name(&self) -> &str {
        &self.name
    }

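    /// Return the output [`Field`] of this expression for the given input schema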
    pub fn field(&self, input_schema: &Schema) -> Result<Field> {
        Ok(Field::new(
            &self.name,
            self.func.data_type(input_schema)?,
            self.func.nullable(input_schema)?,
        ))
    }

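    /// Return the preferred batch size reported by the wrapped [`AsyncScalarUDF`], if any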
    pub fn ideal_batch_size(&self) -> Result<Option<usize>> {
        if let Some(expr) = self.func.as_any().downcast_ref::<ScalarFunctionExpr>()
            && let Some(udf) =
                expr.fun().inner().as_any().downcast_ref::<AsyncScalarUDF>()
        {
            return Ok(udf.ideal_batch_size());
        }
        not_impl_err!("Can't get ideal_batch_size from {:?}", self.func)
    }

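    /// Asynchronously evaluate this expression against `batch` by invoking the
    /// wrapped [`AsyncScalarUDF`], returning a single [`ColumnarValue`].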
    pub async fn invoke_with_args(
        &self,
        batch: &RecordBatch,
        config_options: Arc<ConfigOptions>,
    ) -> Result<ColumnarValue> {
        let Some(scalar_function_expr) =
            self.func.as_any().downcast_ref::<ScalarFunctionExpr>()
        else {
            return internal_err!(
                "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
                self.func
            );
        };

        let Some(async_udf) = scalar_function_expr
            .fun()
            .inner()
            .as_any()
            .downcast_ref::<AsyncScalarUDF>()
        else {
            return not_impl_err!(
                "Don't know how to evaluate async function: {:?}",
                scalar_function_expr
            );
        };

        let arg_fields = scalar_function_expr
            .args()
            .iter()
            .map(|e| e.return_field(batch.schema_ref()))
            .collect::<Result<Vec<_>>>()?;

        let mut result_batches = vec![];
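        // If the UDF reports an ideal batch size, evaluate the input in chunks
        // of at most that many rows; otherwise evaluate it in a single call.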
        if let Some(ideal_batch_size) = self.ideal_batch_size()? {
            let mut remainder = batch.clone();
            while remainder.num_rows() > 0 {
                let size = if ideal_batch_size > remainder.num_rows() {
                    remainder.num_rows()
                } else {
                    ideal_batch_size
                };

                let current_batch = remainder.slice(0, size);
                remainder = remainder.slice(size, remainder.num_rows() - size);
                let args = scalar_function_expr
                    .args()
                    .iter()
                    .map(|e| e.evaluate(&current_batch))
                    .collect::<Result<Vec<_>>>()?;
                result_batches.push(
                    async_udf
                        .invoke_async_with_args(ScalarFunctionArgs {
                            args,
                            arg_fields: arg_fields.clone(),
                            number_rows: current_batch.num_rows(),
                            return_field: Arc::clone(&self.return_field),
                            config_options: Arc::clone(&config_options),
                        })
                        .await?,
                );
            }
        } else {
            let args = scalar_function_expr
                .args()
                .iter()
                .map(|e| e.evaluate(batch))
                .collect::<Result<Vec<_>>>()?;

            result_batches.push(
                async_udf
                    .invoke_async_with_args(ScalarFunctionArgs {
                        args: args.to_vec(),
                        arg_fields,
                        number_rows: batch.num_rows(),
                        return_field: Arc::clone(&self.return_field),
                        config_options: Arc::clone(&config_options),
                    })
                    .await?,
            );
        }

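        // Convert each per-chunk result to an array and concatenate them into
        // a single output array.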
        let datas = result_batches
            .into_iter()
            .map(|cv| match cv {
                ColumnarValue::Array(arr) => Ok(arr),
                ColumnarValue::Scalar(scalar) => Ok(scalar.to_array_of_size(1)?),
            })
            .collect::<Result<Vec<_>>>()?;

        let dyn_arrays = datas
            .iter()
            .map(|arr| arr as &dyn arrow::array::Array)
            .collect::<Vec<_>>();
        let result_array = concat(&dyn_arrays)?;
        Ok(ColumnarValue::Array(result_array))
    }
}

impl PhysicalExpr for AsyncFuncExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        self.func.data_type(input_schema)
    }

    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
        self.func.nullable(input_schema)
    }

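    // Synchronous evaluation is not supported: this expression must be
    // evaluated with the async `invoke_with_args` instead.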
    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
        not_impl_err!("AsyncFuncExpr.evaluate")
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        self.func.children()
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let new_func = Arc::clone(&self.func).with_new_children(children)?;
        Ok(Arc::new(AsyncFuncExpr {
            name: self.name.clone(),
            func: new_func,
            return_field: Arc::clone(&self.return_field),
        }))
    }

    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.func)
    }
}