datafusion_physical_expr/
async_scalar_function.rs1use crate::ScalarFunctionExpr;
19use arrow::array::RecordBatch;
20use arrow::compute::concat;
21use arrow::datatypes::{DataType, Field, FieldRef, Schema};
22use datafusion_common::Result;
23use datafusion_common::config::ConfigOptions;
24use datafusion_common::{internal_err, not_impl_err};
25use datafusion_expr::ScalarFunctionArgs;
26use datafusion_expr::async_udf::AsyncScalarUDF;
27use datafusion_expr_common::columnar_value::ColumnarValue;
28use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
29use std::fmt::Display;
30use std::hash::{Hash, Hasher};
31use std::sync::Arc;
32
33#[derive(Debug, Clone, Eq)]
35pub struct AsyncFuncExpr {
36 pub name: String,
38 pub func: Arc<dyn PhysicalExpr>,
40 return_field: FieldRef,
42}
43
44impl Display for AsyncFuncExpr {
45 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
46 write!(f, "async_expr(name={}, expr={})", self.name, self.func)
47 }
48}
49
50impl PartialEq for AsyncFuncExpr {
51 fn eq(&self, other: &Self) -> bool {
52 self.name == other.name && self.func == Arc::clone(&other.func)
53 }
54}
55
56impl Hash for AsyncFuncExpr {
57 fn hash<H: Hasher>(&self, state: &mut H) {
58 self.name.hash(state);
59 self.func.as_ref().hash(state);
60 }
61}
62
63impl AsyncFuncExpr {
64 pub fn try_new(
66 name: impl Into<String>,
67 func: Arc<dyn PhysicalExpr>,
68 schema: &Schema,
69 ) -> Result<Self> {
70 let Some(_) = func.downcast_ref::<ScalarFunctionExpr>() else {
71 return internal_err!(
72 "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
73 func
74 );
75 };
76
77 let return_field = func.return_field(schema)?;
78 Ok(Self {
79 name: name.into(),
80 func,
81 return_field,
82 })
83 }
84
85 pub fn name(&self) -> &str {
87 &self.name
88 }
89
90 pub fn field(&self, input_schema: &Schema) -> Result<Field> {
92 Ok(Field::new(
93 &self.name,
94 self.func.data_type(input_schema)?,
95 self.func.nullable(input_schema)?,
96 ))
97 }
98
99 pub fn ideal_batch_size(&self) -> Result<Option<usize>> {
101 if let Some(expr) = self.func.downcast_ref::<ScalarFunctionExpr>()
102 && let Some(udf) = expr.fun().inner().downcast_ref::<AsyncScalarUDF>()
103 {
104 return Ok(udf.ideal_batch_size());
105 }
106 not_impl_err!("Can't get ideal_batch_size from {:?}", self.func)
107 }
108
109 pub async fn invoke_with_args(
113 &self,
114 batch: &RecordBatch,
115 config_options: Arc<ConfigOptions>,
116 ) -> Result<ColumnarValue> {
117 let Some(scalar_function_expr) = self.func.downcast_ref::<ScalarFunctionExpr>()
118 else {
119 return internal_err!(
120 "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
121 self.func
122 );
123 };
124
125 let Some(async_udf) = scalar_function_expr
126 .fun()
127 .inner()
128 .downcast_ref::<AsyncScalarUDF>()
129 else {
130 return not_impl_err!(
131 "Don't know how to evaluate async function: {:?}",
132 scalar_function_expr
133 );
134 };
135
136 let arg_fields = scalar_function_expr
137 .args()
138 .iter()
139 .map(|e| e.return_field(batch.schema_ref()))
140 .collect::<Result<Vec<_>>>()?;
141
142 let mut result_batches = vec![];
143 if let Some(ideal_batch_size) = self.ideal_batch_size()? {
144 let mut remainder = batch.clone();
145 while remainder.num_rows() > 0 {
146 let size = if ideal_batch_size > remainder.num_rows() {
147 remainder.num_rows()
148 } else {
149 ideal_batch_size
150 };
151
152 let current_batch = remainder.slice(0, size); remainder = remainder.slice(size, remainder.num_rows() - size);
154 let args = scalar_function_expr
155 .args()
156 .iter()
157 .map(|e| e.evaluate(¤t_batch))
158 .collect::<Result<Vec<_>>>()?;
159 result_batches.push(
160 async_udf
161 .invoke_async_with_args(ScalarFunctionArgs {
162 args,
163 arg_fields: arg_fields.clone(),
164 number_rows: current_batch.num_rows(),
165 return_field: Arc::clone(&self.return_field),
166 config_options: Arc::clone(&config_options),
167 })
168 .await?,
169 );
170 }
171 } else {
172 let args = scalar_function_expr
173 .args()
174 .iter()
175 .map(|e| e.evaluate(batch))
176 .collect::<Result<Vec<_>>>()?;
177
178 result_batches.push(
179 async_udf
180 .invoke_async_with_args(ScalarFunctionArgs {
181 args: args.to_vec(),
182 arg_fields,
183 number_rows: batch.num_rows(),
184 return_field: Arc::clone(&self.return_field),
185 config_options: Arc::clone(&config_options),
186 })
187 .await?,
188 );
189 }
190
191 let datas = result_batches
192 .into_iter()
193 .map(|cv| match cv {
194 ColumnarValue::Array(arr) => Ok(arr),
195 ColumnarValue::Scalar(scalar) => Ok(scalar.to_array_of_size(1)?),
196 })
197 .collect::<Result<Vec<_>>>()?;
198
199 let dyn_arrays = datas
201 .iter()
202 .map(|arr| arr as &dyn arrow::array::Array)
203 .collect::<Vec<_>>();
204 let result_array = concat(&dyn_arrays)?;
205 Ok(ColumnarValue::Array(result_array))
206 }
207}
208
209impl PhysicalExpr for AsyncFuncExpr {
210 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
211 self.func.data_type(input_schema)
212 }
213
214 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
215 self.func.nullable(input_schema)
216 }
217
218 fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
219 not_impl_err!("AsyncFuncExpr.evaluate")
221 }
222
223 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
224 self.func.children()
225 }
226
227 fn with_new_children(
228 self: Arc<Self>,
229 children: Vec<Arc<dyn PhysicalExpr>>,
230 ) -> Result<Arc<dyn PhysicalExpr>> {
231 let new_func = Arc::clone(&self.func).with_new_children(children)?;
232 Ok(Arc::new(AsyncFuncExpr {
233 name: self.name.clone(),
234 func: new_func,
235 return_field: Arc::clone(&self.return_field),
236 }))
237 }
238
239 fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240 write!(f, "{}", self.func)
241 }
242}