// datafusion_physical_expr/async_scalar_function.rs

use crate::ScalarFunctionExpr;
use arrow::array::{make_array, MutableArrayData, RecordBatch};
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_common::{internal_err, not_impl_err};
use datafusion_expr::async_udf::AsyncScalarUDF;
use datafusion_expr::ScalarFunctionArgs;
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::any::Any;
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

/// A [`PhysicalExpr`] wrapping a [`ScalarFunctionExpr`] whose inner UDF is an
/// [`AsyncScalarUDF`], so the function can be invoked asynchronously.
#[derive(Debug, Clone, Eq)]
pub struct AsyncFuncExpr {
    /// Name of the output column produced by this expression
    pub name: String,
    /// The wrapped [`ScalarFunctionExpr`]
    pub func: Arc<dyn PhysicalExpr>,
    /// The field (data type and nullability) returned by this expression
    return_field: FieldRef,
}

impl Display for AsyncFuncExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "async_expr(name={}, expr={})", self.name, self.func)
    }
}

impl PartialEq for AsyncFuncExpr {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name && self.func == other.func
    }
}

impl Hash for AsyncFuncExpr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.name.hash(state);
        self.func.as_ref().hash(state);
    }
}

impl AsyncFuncExpr {
    /// Create a new `AsyncFuncExpr`.
    ///
    /// Returns an error if `func` is not a [`ScalarFunctionExpr`].
    pub fn try_new(
        name: impl Into<String>,
        func: Arc<dyn PhysicalExpr>,
        schema: &Schema,
    ) -> Result<Self> {
        let Some(_) = func.as_any().downcast_ref::<ScalarFunctionExpr>() else {
            return internal_err!(
                "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
                func
            );
        };

        let return_field = func.return_field(schema)?;
        Ok(Self {
            name: name.into(),
            func,
            return_field,
        })
    }

    /// Return the name of the output column produced by this expression
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Return the output [`Field`] of this expression for the given input schema
    pub fn field(&self, input_schema: &Schema) -> Result<Field> {
        Ok(Field::new(
            &self.name,
            self.func.data_type(input_schema)?,
            self.func.nullable(input_schema)?,
        ))
    }

    /// Return the ideal batch size of the underlying [`AsyncScalarUDF`], if any
    pub fn ideal_batch_size(&self) -> Result<Option<usize>> {
        if let Some(expr) = self.func.as_any().downcast_ref::<ScalarFunctionExpr>() {
            if let Some(udf) =
                expr.fun().inner().as_any().downcast_ref::<AsyncScalarUDF>()
            {
                return Ok(udf.ideal_batch_size());
            }
        }
        not_impl_err!("Can't get ideal_batch_size from {:?}", self.func)
    }

    /// Invoke the underlying [`AsyncScalarUDF`] on `batch`, splitting the input
    /// into chunks of at most `ideal_batch_size` rows (when one is specified)
    /// and concatenating the per-chunk results into a single array.
    pub async fn invoke_with_args(
        &self,
        batch: &RecordBatch,
        config_options: Arc<ConfigOptions>,
    ) -> Result<ColumnarValue> {
        let Some(scalar_function_expr) =
            self.func.as_any().downcast_ref::<ScalarFunctionExpr>()
        else {
            return internal_err!(
                "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
                self.func
            );
        };

        let Some(async_udf) = scalar_function_expr
            .fun()
            .inner()
            .as_any()
            .downcast_ref::<AsyncScalarUDF>()
        else {
            return not_impl_err!(
                "Don't know how to evaluate async function: {:?}",
                scalar_function_expr
            );
        };

        let arg_fields = scalar_function_expr
            .args()
            .iter()
            .map(|e| e.return_field(batch.schema_ref()))
            .collect::<Result<Vec<_>>>()?;

        let mut result_batches = vec![];
        if let Some(ideal_batch_size) = self.ideal_batch_size()? {
            // Invoke the UDF in chunks of at most `ideal_batch_size` rows
            let mut remainder = batch.clone();
            while remainder.num_rows() > 0 {
                let size = if ideal_batch_size > remainder.num_rows() {
                    remainder.num_rows()
                } else {
                    ideal_batch_size
                };

                let current_batch = remainder.slice(0, size);
                remainder = remainder.slice(size, remainder.num_rows() - size);
                let args = scalar_function_expr
                    .args()
                    .iter()
                    .map(|e| e.evaluate(&current_batch))
                    .collect::<Result<Vec<_>>>()?;
                result_batches.push(
                    async_udf
                        .invoke_async_with_args(ScalarFunctionArgs {
                            args,
                            arg_fields: arg_fields.clone(),
                            number_rows: current_batch.num_rows(),
                            return_field: Arc::clone(&self.return_field),
                            config_options: Arc::clone(&config_options),
                        })
                        .await?,
                );
            }
        } else {
            // No ideal batch size: invoke the UDF once over the entire batch
            let args = scalar_function_expr
                .args()
                .iter()
                .map(|e| e.evaluate(batch))
                .collect::<Result<Vec<_>>>()?;

            result_batches.push(
                async_udf
                    .invoke_async_with_args(ScalarFunctionArgs {
                        args,
                        arg_fields,
                        number_rows: batch.num_rows(),
                        return_field: Arc::clone(&self.return_field),
                        config_options: Arc::clone(&config_options),
                    })
                    .await?,
            );
        }

        // Concatenate the per-chunk results back into a single array
        let datas = ColumnarValue::values_to_arrays(&result_batches)?
            .iter()
            .map(|b| b.to_data())
            .collect::<Vec<_>>();
        let total_len = datas.iter().map(|d| d.len()).sum();
        let mut mutable = MutableArrayData::new(datas.iter().collect(), false, total_len);
        datas.iter().enumerate().for_each(|(i, data)| {
            mutable.extend(i, 0, data.len());
        });
        let array_ref = make_array(mutable.freeze());
        Ok(ColumnarValue::Array(array_ref))
    }
}

impl PhysicalExpr for AsyncFuncExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        self.func.data_type(input_schema)
    }

    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
        self.func.nullable(input_schema)
    }

    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
        not_impl_err!("AsyncFuncExpr.evaluate")
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        self.func.children()
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let new_func = Arc::clone(&self.func).with_new_children(children)?;
        Ok(Arc::new(AsyncFuncExpr {
            name: self.name.clone(),
            func: new_func,
            return_field: Arc::clone(&self.return_field),
        }))
    }

    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.func)
    }
}
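
// The module below is not part of the original file; it is a minimal, hedged
// sketch (module and test names are mine) illustrating two pieces of the
// `invoke_with_args` logic in isolation: the `ideal_batch_size` chunking
// arithmetic and the `MutableArrayData`-based concatenation of per-chunk
// results. It relies only on `arrow` APIs already used above and does not
// exercise the async UDF machinery itself.
#[cfg(test)]
mod async_func_expr_sketch_tests {
    use arrow::array::{make_array, Array, Int32Array, MutableArrayData, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    // Mirrors the slicing loop in `invoke_with_args`: the input batch is
    // consumed in chunks of at most `ideal_batch_size` rows, and the chunk
    // sizes sum back to the original row count.
    #[test]
    fn chunk_sizes_cover_all_rows() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from((0..10).collect::<Vec<i32>>()))],
        )
        .unwrap();

        let ideal_batch_size = 4;
        let mut remainder = batch.clone();
        let mut chunk_sizes = vec![];
        while remainder.num_rows() > 0 {
            let size = ideal_batch_size.min(remainder.num_rows());
            let current_batch = remainder.slice(0, size);
            remainder = remainder.slice(size, remainder.num_rows() - size);
            chunk_sizes.push(current_batch.num_rows());
        }
        assert_eq!(chunk_sizes, vec![4, 4, 2]);
    }

    // Mirrors the tail of `invoke_with_args`: per-chunk result arrays are
    // appended in order via `MutableArrayData` and frozen into a single array.
    #[test]
    fn concatenate_chunked_results() {
        let datas = vec![
            Int32Array::from(vec![1, 2, 3]).to_data(),
            Int32Array::from(vec![4, 5]).to_data(),
        ];

        let total_len = datas.iter().map(|d| d.len()).sum();
        let mut mutable = MutableArrayData::new(datas.iter().collect(), false, total_len);
        datas.iter().enumerate().for_each(|(i, data)| {
            mutable.extend(i, 0, data.len());
        });
        let array = make_array(mutable.freeze());
        assert_eq!(array.len(), 5);
    }
}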