datafusion_physical_expr/
scalar_function.rs1use std::any::Any;
33use std::fmt::{self, Debug, Formatter};
34use std::hash::{Hash, Hasher};
35use std::sync::Arc;
36
37use crate::expressions::Literal;
38use crate::PhysicalExpr;
39
40use arrow::array::{Array, RecordBatch};
41use arrow::datatypes::{DataType, FieldRef, Schema};
42use datafusion_common::config::{ConfigEntry, ConfigOptions};
43use datafusion_common::{internal_err, Result, ScalarValue};
44use datafusion_expr::interval_arithmetic::Interval;
45use datafusion_expr::sort_properties::ExprProperties;
46use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
47use datafusion_expr::{
48 expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
49 Volatility,
50};
51
52pub struct ScalarFunctionExpr {
54 fun: Arc<ScalarUDF>,
55 name: String,
56 args: Vec<Arc<dyn PhysicalExpr>>,
57 return_field: FieldRef,
58 config_options: Arc<ConfigOptions>,
59}
60
61impl Debug for ScalarFunctionExpr {
62 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
63 f.debug_struct("ScalarFunctionExpr")
64 .field("fun", &"<FUNC>")
65 .field("name", &self.name)
66 .field("args", &self.args)
67 .field("return_field", &self.return_field)
68 .finish()
69 }
70}
71
72impl ScalarFunctionExpr {
73 pub fn new(
75 name: &str,
76 fun: Arc<ScalarUDF>,
77 args: Vec<Arc<dyn PhysicalExpr>>,
78 return_field: FieldRef,
79 config_options: Arc<ConfigOptions>,
80 ) -> Self {
81 Self {
82 fun,
83 name: name.to_owned(),
84 args,
85 return_field,
86 config_options,
87 }
88 }
89
90 pub fn try_new(
92 fun: Arc<ScalarUDF>,
93 args: Vec<Arc<dyn PhysicalExpr>>,
94 schema: &Schema,
95 config_options: Arc<ConfigOptions>,
96 ) -> Result<Self> {
97 let name = fun.name().to_string();
98 let arg_fields = args
99 .iter()
100 .map(|e| e.return_field(schema))
101 .collect::<Result<Vec<_>>>()?;
102
103 let arg_types = arg_fields
105 .iter()
106 .map(|f| f.data_type().clone())
107 .collect::<Vec<_>>();
108 data_types_with_scalar_udf(&arg_types, &fun)?;
109
110 let arguments = args
111 .iter()
112 .map(|e| {
113 e.as_any()
114 .downcast_ref::<Literal>()
115 .map(|literal| literal.value())
116 })
117 .collect::<Vec<_>>();
118 let ret_args = ReturnFieldArgs {
119 arg_fields: &arg_fields,
120 scalar_arguments: &arguments,
121 };
122 let return_field = fun.return_field_from_args(ret_args)?;
123 Ok(Self {
124 fun,
125 name,
126 args,
127 return_field,
128 config_options,
129 })
130 }
131
132 pub fn fun(&self) -> &ScalarUDF {
134 &self.fun
135 }
136
137 pub fn name(&self) -> &str {
139 &self.name
140 }
141
142 pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
144 &self.args
145 }
146
147 pub fn return_type(&self) -> &DataType {
149 self.return_field.data_type()
150 }
151
152 pub fn with_nullable(mut self, nullable: bool) -> Self {
153 self.return_field = self
154 .return_field
155 .as_ref()
156 .clone()
157 .with_nullable(nullable)
158 .into();
159 self
160 }
161
162 pub fn nullable(&self) -> bool {
163 self.return_field.is_nullable()
164 }
165
166 pub fn config_options(&self) -> &ConfigOptions {
167 &self.config_options
168 }
169
170 pub fn try_downcast_func<T>(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr>
175 where
176 T: 'static,
177 {
178 match expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
179 Some(scalar_expr)
180 if scalar_expr
181 .fun()
182 .inner()
183 .as_any()
184 .downcast_ref::<T>()
185 .is_some() =>
186 {
187 Some(scalar_expr)
188 }
189 _ => None,
190 }
191 }
192}
193
194impl fmt::Display for ScalarFunctionExpr {
195 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
196 write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
197 }
198}
199
200impl PartialEq for ScalarFunctionExpr {
201 fn eq(&self, o: &Self) -> bool {
202 if std::ptr::eq(self, o) {
203 return true;
205 }
206 let Self {
207 fun,
208 name,
209 args,
210 return_field,
211 config_options,
212 } = self;
213 fun.eq(&o.fun)
214 && name.eq(&o.name)
215 && args.eq(&o.args)
216 && return_field.eq(&o.return_field)
217 && (Arc::ptr_eq(config_options, &o.config_options)
218 || sorted_config_entries(config_options)
219 == sorted_config_entries(&o.config_options))
220 }
221}
222impl Eq for ScalarFunctionExpr {}
223impl Hash for ScalarFunctionExpr {
224 fn hash<H: Hasher>(&self, state: &mut H) {
225 let Self {
226 fun,
227 name,
228 args,
229 return_field,
230 config_options: _, } = self;
232 fun.hash(state);
233 name.hash(state);
234 args.hash(state);
235 return_field.hash(state);
236 }
237}
238
239fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
240 let mut entries = config_options.entries();
241 entries.sort_by(|l, r| l.key.cmp(&r.key));
242 entries
243}
244
245impl PhysicalExpr for ScalarFunctionExpr {
246 fn as_any(&self) -> &dyn Any {
248 self
249 }
250
251 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
252 Ok(self.return_field.data_type().clone())
253 }
254
255 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
256 Ok(self.return_field.is_nullable())
257 }
258
259 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
260 let args = self
261 .args
262 .iter()
263 .map(|e| e.evaluate(batch))
264 .collect::<Result<Vec<_>>>()?;
265
266 let arg_fields = self
267 .args
268 .iter()
269 .map(|e| e.return_field(batch.schema_ref()))
270 .collect::<Result<Vec<_>>>()?;
271
272 let input_empty = args.is_empty();
273 let input_all_scalar = args
274 .iter()
275 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
276
277 let output = self.fun.invoke_with_args(ScalarFunctionArgs {
279 args,
280 arg_fields,
281 number_rows: batch.num_rows(),
282 return_field: Arc::clone(&self.return_field),
283 config_options: Arc::clone(&self.config_options),
284 })?;
285
286 if let ColumnarValue::Array(array) = &output {
287 if array.len() != batch.num_rows() {
288 let preserve_scalar =
291 array.len() == 1 && !input_empty && input_all_scalar;
292 return if preserve_scalar {
293 ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
294 } else {
295 internal_err!("UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
296 self.name, batch.num_rows(), array.len())
297 };
298 }
299 }
300 Ok(output)
301 }
302
303 fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
304 Ok(Arc::clone(&self.return_field))
305 }
306
307 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
308 self.args.iter().collect()
309 }
310
311 fn with_new_children(
312 self: Arc<Self>,
313 children: Vec<Arc<dyn PhysicalExpr>>,
314 ) -> Result<Arc<dyn PhysicalExpr>> {
315 Ok(Arc::new(ScalarFunctionExpr::new(
316 &self.name,
317 Arc::clone(&self.fun),
318 children,
319 Arc::clone(&self.return_field),
320 Arc::clone(&self.config_options),
321 )))
322 }
323
324 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
325 self.fun.evaluate_bounds(children)
326 }
327
328 fn propagate_constraints(
329 &self,
330 interval: &Interval,
331 children: &[&Interval],
332 ) -> Result<Option<Vec<Interval>>> {
333 self.fun.propagate_constraints(interval, children)
334 }
335
336 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
337 let sort_properties = self.fun.output_ordering(children)?;
338 let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
339 let children_range = children
340 .iter()
341 .map(|props| &props.range)
342 .collect::<Vec<_>>();
343 let range = self.fun().evaluate_bounds(&children_range)?;
344
345 Ok(ExprProperties {
346 sort_properties,
347 range,
348 preserves_lex_ordering,
349 })
350 }
351
352 fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
353 write!(f, "{}(", self.name)?;
354 for (i, expr) in self.args.iter().enumerate() {
355 if i > 0 {
356 write!(f, ", ")?;
357 }
358 expr.fmt_sql(f)?;
359 }
360 write!(f, ")")
361 }
362
363 fn is_volatile_node(&self) -> bool {
364 self.fun.signature().volatility == Volatility::Volatile
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Column;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature};
    use datafusion_physical_expr_common::physical_expr::is_volatile;
    use std::any::Any;

    /// Test UDF named `mock_function` that always returns the scalar
    /// `Int32(42)`; only its signature (and so its volatility) varies.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockScalarUDF {
        signature: Signature,
    }

    impl ScalarUDFImpl for MockScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "mock_function"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42))))
        }
    }

    /// Builds `mock_function(a)` over a non-nullable `Float32` column `a`,
    /// with the UDF declared at `volatility`.
    fn mock_call(
        volatility: Volatility,
        config_options: Arc<ConfigOptions>,
    ) -> ScalarFunctionExpr {
        let udf = Arc::new(ScalarUDF::from(MockScalarUDF {
            signature: Signature::uniform(1, vec![DataType::Float32], volatility),
        }));
        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
        let args = vec![Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>];
        ScalarFunctionExpr::try_new(udf, args, &schema, config_options).unwrap()
    }

    #[test]
    fn test_scalar_function_volatile_node() {
        let config_options = Arc::new(ConfigOptions::new());

        // Volatile UDF: both the node check and `is_volatile` must report it.
        let volatile_expr = mock_call(Volatility::Volatile, Arc::clone(&config_options));
        assert!(volatile_expr.is_volatile_node());
        let volatile_tree: Arc<dyn PhysicalExpr> = Arc::new(volatile_expr);
        assert!(is_volatile(&volatile_tree));

        // Stable UDF over a plain column: neither check reports volatility.
        let stable_expr = mock_call(Volatility::Stable, config_options);
        assert!(!stable_expr.is_volatile_node());
        let stable_tree: Arc<dyn PhysicalExpr> = Arc::new(stable_expr);
        assert!(!is_volatile(&stable_tree));
    }
}