// datafusion_physical_expr/async_scalar_function.rs

use crate::ScalarFunctionExpr;
use arrow::array::{make_array, MutableArrayData, RecordBatch};
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_common::{internal_err, not_impl_err};
use datafusion_expr::async_udf::AsyncScalarUDF;
use datafusion_expr::ScalarFunctionArgs;
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::any::Any;
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

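/// Physical expression for a scalar UDF that must be evaluated asynchronously.
///
/// Wraps a [`ScalarFunctionExpr`] whose inner UDF is an [`AsyncScalarUDF`];
/// results are produced via [`AsyncFuncExpr::invoke_with_args`] rather than
/// the synchronous [`PhysicalExpr::evaluate`] path.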
#[derive(Debug, Clone, Eq)]
pub struct AsyncFuncExpr {
    /// Name of the output column produced by this expression
    pub name: String,
    /// The wrapped function expression (must be a `ScalarFunctionExpr`)
    pub func: Arc<dyn PhysicalExpr>,
    /// The field (name, data type, nullability) this expression returns
    return_field: FieldRef,
}

impl Display for AsyncFuncExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "async_expr(name={}, expr={})", self.name, self.func)
    }
}

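// Manual `PartialEq` and `Hash` implementations: compare and hash by the
// expression name and the wrapped physical expression.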
impl PartialEq for AsyncFuncExpr {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name && self.func == other.func
    }
}

impl Hash for AsyncFuncExpr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.name.hash(state);
        self.func.as_ref().hash(state);
    }
}

impl AsyncFuncExpr {
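    /// Create a new `AsyncFuncExpr`.
    ///
    /// Returns an error if `func` is not a [`ScalarFunctionExpr`].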
    pub fn try_new(
        name: impl Into<String>,
        func: Arc<dyn PhysicalExpr>,
        schema: &Schema,
    ) -> Result<Self> {
        let Some(_) = func.as_any().downcast_ref::<ScalarFunctionExpr>() else {
            return internal_err!(
                "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
                func
            );
        };

        let return_field = func.return_field(schema)?;
        Ok(Self {
            name: name.into(),
            func,
            return_field,
        })
    }

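    /// Return the name of the output column for this expression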
    pub fn name(&self) -> &str {
        &self.name
    }

    pub fn field(&self, input_schema: &Schema) -> Result<Field> {
        Ok(Field::new(
            &self.name,
            self.func.data_type(input_schema)?,
            self.func.nullable(input_schema)?,
        ))
    }

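    /// Return the preferred batch size of the wrapped [`AsyncScalarUDF`],
    /// or `None` if the UDF does not declare one.
    ///
    /// Returns an error if the inner expression does not wrap an async UDF.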
    pub fn ideal_batch_size(&self) -> Result<Option<usize>> {
        if let Some(expr) = self.func.as_any().downcast_ref::<ScalarFunctionExpr>() {
            if let Some(udf) =
                expr.fun().inner().as_any().downcast_ref::<AsyncScalarUDF>()
            {
                return Ok(udf.ideal_batch_size());
            }
        }
        not_impl_err!("Can't get ideal_batch_size from {:?}", self.func)
    }

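    /// Evaluate the async UDF against `batch`, returning a single
    /// [`ColumnarValue::Array`] with one row per input row.
    ///
    /// If the UDF reports an ideal batch size, the input is split into chunks
    /// of at most that many rows and the per-chunk results are concatenated.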
    pub async fn invoke_with_args(
        &self,
        batch: &RecordBatch,
        option: &ConfigOptions,
    ) -> Result<ColumnarValue> {
        let Some(scalar_function_expr) =
            self.func.as_any().downcast_ref::<ScalarFunctionExpr>()
        else {
            return internal_err!(
                "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
                self.func
            );
        };

        let Some(async_udf) = scalar_function_expr
            .fun()
            .inner()
            .as_any()
            .downcast_ref::<AsyncScalarUDF>()
        else {
            return not_impl_err!(
                "Don't know how to evaluate async function: {:?}",
                scalar_function_expr
            );
        };

        let arg_fields = scalar_function_expr
            .args()
            .iter()
            .map(|e| e.return_field(batch.schema_ref()))
            .collect::<Result<Vec<_>>>()?;

        let mut result_batches = vec![];
        if let Some(ideal_batch_size) = self.ideal_batch_size()? {
            // The UDF declares a preferred batch size, so evaluate the input
            // in chunks of at most `ideal_batch_size` rows.
            let mut remainder = batch.clone();
            while remainder.num_rows() > 0 {
                let size = if ideal_batch_size > remainder.num_rows() {
                    remainder.num_rows()
                } else {
                    ideal_batch_size
                };

                // Split off the next chunk and keep the rest for later iterations
                let current_batch = remainder.slice(0, size);
                remainder = remainder.slice(size, remainder.num_rows() - size);
                let args = scalar_function_expr
                    .args()
                    .iter()
                    .map(|e| e.evaluate(&current_batch))
                    .collect::<Result<Vec<_>>>()?;
                result_batches.push(
                    async_udf
                        .invoke_async_with_args(
                            ScalarFunctionArgs {
                                args,
                                arg_fields: arg_fields.clone(),
                                number_rows: current_batch.num_rows(),
                                return_field: Arc::clone(&self.return_field),
                            },
                            option,
                        )
                        .await?,
                );
            }
        } else {
            // No preferred batch size: evaluate all arguments against the
            // entire input batch and invoke the UDF once.
            let args = scalar_function_expr
                .args()
                .iter()
                .map(|e| e.evaluate(batch))
                .collect::<Result<Vec<_>>>()?;

            result_batches.push(
                async_udf
                    .invoke_async_with_args(
                        ScalarFunctionArgs {
                            args,
                            arg_fields,
                            number_rows: batch.num_rows(),
                            return_field: Arc::clone(&self.return_field),
                        },
                        option,
                    )
                    .await?,
            );
        }

        // Concatenate the per-chunk result arrays into a single output array
        let datas = result_batches
            .iter()
            .map(|b| b.to_data())
            .collect::<Vec<_>>();
        let total_len = datas.iter().map(|d| d.len()).sum();
        let mut mutable = MutableArrayData::new(datas.iter().collect(), false, total_len);
        datas.iter().enumerate().for_each(|(i, data)| {
            mutable.extend(i, 0, data.len());
        });
        let array_ref = make_array(mutable.freeze());
        Ok(ColumnarValue::Array(array_ref))
    }
}

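// `PhysicalExpr` implementation that delegates schema information to the
// wrapped expression. The synchronous `evaluate` path is unsupported; async
// evaluation goes through `invoke_with_args`.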
impl PhysicalExpr for AsyncFuncExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        self.func.data_type(input_schema)
    }

    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
        self.func.nullable(input_schema)
    }

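    // Async expressions cannot be evaluated synchronously; callers are
    // expected to use `invoke_with_args` instead.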
    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
        not_impl_err!("AsyncFuncExpr.evaluate")
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        self.func.children()
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let new_func = Arc::clone(&self.func).with_new_children(children)?;
        Ok(Arc::new(AsyncFuncExpr {
            name: self.name.clone(),
            func: new_func,
            return_field: Arc::clone(&self.return_field),
        }))
    }

    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.func)
    }
}