datafusion_physical_expr/
scalar_function.rs1use std::any::Any;
33use std::fmt::{self, Debug, Formatter};
34use std::hash::{Hash, Hasher};
35use std::sync::Arc;
36
37use crate::PhysicalExpr;
38use crate::expressions::Literal;
39
40use arrow::array::{Array, RecordBatch};
41use arrow::datatypes::{DataType, FieldRef, Schema};
42use datafusion_common::config::{ConfigEntry, ConfigOptions};
43use datafusion_common::{Result, ScalarValue, internal_err};
44use datafusion_expr::interval_arithmetic::Interval;
45use datafusion_expr::sort_properties::ExprProperties;
46use datafusion_expr::type_coercion::functions::fields_with_udf;
47use datafusion_expr::{
48 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, Volatility,
49 expr_vec_fmt,
50};
51
52pub struct ScalarFunctionExpr {
54 fun: Arc<ScalarUDF>,
55 name: String,
56 args: Vec<Arc<dyn PhysicalExpr>>,
57 return_field: FieldRef,
58 config_options: Arc<ConfigOptions>,
59}
60
61impl Debug for ScalarFunctionExpr {
62 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
63 f.debug_struct("ScalarFunctionExpr")
64 .field("fun", &"<FUNC>")
65 .field("name", &self.name)
66 .field("args", &self.args)
67 .field("return_field", &self.return_field)
68 .finish()
69 }
70}
71
72impl ScalarFunctionExpr {
73 pub fn new(
75 name: &str,
76 fun: Arc<ScalarUDF>,
77 args: Vec<Arc<dyn PhysicalExpr>>,
78 return_field: FieldRef,
79 config_options: Arc<ConfigOptions>,
80 ) -> Self {
81 Self {
82 fun,
83 name: name.to_owned(),
84 args,
85 return_field,
86 config_options,
87 }
88 }
89
90 pub fn try_new(
92 fun: Arc<ScalarUDF>,
93 args: Vec<Arc<dyn PhysicalExpr>>,
94 schema: &Schema,
95 config_options: Arc<ConfigOptions>,
96 ) -> Result<Self> {
97 let name = fun.name().to_string();
98 let arg_fields = args
99 .iter()
100 .map(|e| e.return_field(schema))
101 .collect::<Result<Vec<_>>>()?;
102
103 fields_with_udf(&arg_fields, fun.as_ref())?;
105
106 let arguments = args
107 .iter()
108 .map(|e| {
109 e.as_any()
110 .downcast_ref::<Literal>()
111 .map(|literal| literal.value())
112 })
113 .collect::<Vec<_>>();
114 let ret_args = ReturnFieldArgs {
115 arg_fields: &arg_fields,
116 scalar_arguments: &arguments,
117 };
118 let return_field = fun.return_field_from_args(ret_args)?;
119 Ok(Self {
120 fun,
121 name,
122 args,
123 return_field,
124 config_options,
125 })
126 }
127
128 pub fn fun(&self) -> &ScalarUDF {
130 &self.fun
131 }
132
133 pub fn name(&self) -> &str {
135 &self.name
136 }
137
138 pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
140 &self.args
141 }
142
143 pub fn return_type(&self) -> &DataType {
145 self.return_field.data_type()
146 }
147
148 pub fn with_nullable(mut self, nullable: bool) -> Self {
149 self.return_field = self
150 .return_field
151 .as_ref()
152 .clone()
153 .with_nullable(nullable)
154 .into();
155 self
156 }
157
158 pub fn nullable(&self) -> bool {
159 self.return_field.is_nullable()
160 }
161
162 pub fn config_options(&self) -> &ConfigOptions {
163 &self.config_options
164 }
165
166 pub fn try_downcast_func<T>(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr>
171 where
172 T: 'static,
173 {
174 match expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
175 Some(scalar_expr)
176 if scalar_expr
177 .fun()
178 .inner()
179 .as_any()
180 .downcast_ref::<T>()
181 .is_some() =>
182 {
183 Some(scalar_expr)
184 }
185 _ => None,
186 }
187 }
188}
189
190impl fmt::Display for ScalarFunctionExpr {
191 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
192 write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
193 }
194}
195
196impl PartialEq for ScalarFunctionExpr {
197 fn eq(&self, o: &Self) -> bool {
198 if std::ptr::eq(self, o) {
199 return true;
201 }
202 let Self {
203 fun,
204 name,
205 args,
206 return_field,
207 config_options,
208 } = self;
209 fun.eq(&o.fun)
210 && name.eq(&o.name)
211 && args.eq(&o.args)
212 && return_field.eq(&o.return_field)
213 && (Arc::ptr_eq(config_options, &o.config_options)
214 || sorted_config_entries(config_options)
215 == sorted_config_entries(&o.config_options))
216 }
217}
218impl Eq for ScalarFunctionExpr {}
219impl Hash for ScalarFunctionExpr {
220 fn hash<H: Hasher>(&self, state: &mut H) {
221 let Self {
222 fun,
223 name,
224 args,
225 return_field,
226 config_options: _, } = self;
228 fun.hash(state);
229 name.hash(state);
230 args.hash(state);
231 return_field.hash(state);
232 }
233}
234
235fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
236 let mut entries = config_options.entries();
237 entries.sort_by(|l, r| l.key.cmp(&r.key));
238 entries
239}
240
241impl PhysicalExpr for ScalarFunctionExpr {
242 fn as_any(&self) -> &dyn Any {
244 self
245 }
246
247 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
248 Ok(self.return_field.data_type().clone())
249 }
250
251 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
252 Ok(self.return_field.is_nullable())
253 }
254
255 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
256 let args = self
257 .args
258 .iter()
259 .map(|e| e.evaluate(batch))
260 .collect::<Result<Vec<_>>>()?;
261
262 let arg_fields = self
263 .args
264 .iter()
265 .map(|e| e.return_field(batch.schema_ref()))
266 .collect::<Result<Vec<_>>>()?;
267
268 let input_empty = args.is_empty();
269 let input_all_scalar = args
270 .iter()
271 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
272
273 let output = self.fun.invoke_with_args(ScalarFunctionArgs {
275 args,
276 arg_fields,
277 number_rows: batch.num_rows(),
278 return_field: Arc::clone(&self.return_field),
279 config_options: Arc::clone(&self.config_options),
280 })?;
281
282 if let ColumnarValue::Array(array) = &output
283 && array.len() != batch.num_rows()
284 {
285 let preserve_scalar = array.len() == 1 && !input_empty && input_all_scalar;
288 return if preserve_scalar {
289 ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
290 } else {
291 internal_err!(
292 "UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
293 self.name,
294 batch.num_rows(),
295 array.len()
296 )
297 };
298 }
299 Ok(output)
300 }
301
302 fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
303 Ok(Arc::clone(&self.return_field))
304 }
305
306 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
307 self.args.iter().collect()
308 }
309
310 fn with_new_children(
311 self: Arc<Self>,
312 children: Vec<Arc<dyn PhysicalExpr>>,
313 ) -> Result<Arc<dyn PhysicalExpr>> {
314 Ok(Arc::new(ScalarFunctionExpr::new(
315 &self.name,
316 Arc::clone(&self.fun),
317 children,
318 Arc::clone(&self.return_field),
319 Arc::clone(&self.config_options),
320 )))
321 }
322
323 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
324 self.fun.evaluate_bounds(children)
325 }
326
327 fn propagate_constraints(
328 &self,
329 interval: &Interval,
330 children: &[&Interval],
331 ) -> Result<Option<Vec<Interval>>> {
332 self.fun.propagate_constraints(interval, children)
333 }
334
335 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
336 let sort_properties = self.fun.output_ordering(children)?;
337 let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
338 let children_range = children
339 .iter()
340 .map(|props| &props.range)
341 .collect::<Vec<_>>();
342 let range = self.fun().evaluate_bounds(&children_range)?;
343
344 Ok(ExprProperties {
345 sort_properties,
346 range,
347 preserves_lex_ordering,
348 })
349 }
350
351 fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
352 write!(f, "{}(", self.name)?;
353 for (i, expr) in self.args.iter().enumerate() {
354 if i > 0 {
355 write!(f, ", ")?;
356 }
357 expr.fmt_sql(f)?;
358 }
359 write!(f, ")")
360 }
361
362 fn is_volatile_node(&self) -> bool {
363 self.fun.signature().volatility == Volatility::Volatile
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Column;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature};
    use datafusion_physical_expr_common::physical_expr::is_volatile;
    use std::any::Any;

    /// Minimal UDF whose only configurable property is its signature. It is
    /// used to check how volatility is reported.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockScalarUDF {
        signature: Signature,
    }

    impl ScalarUDFImpl for MockScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "mock_function"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42))))
        }
    }

    /// Builds a call to the mock UDF with the given volatility. Its single
    /// argument is column `a` of a one-column, non-nullable Float32 schema.
    fn mock_expr(volatility: Volatility) -> ScalarFunctionExpr {
        let udf = Arc::new(ScalarUDF::from(MockScalarUDF {
            signature: Signature::uniform(1, vec![DataType::Float32], volatility),
        }));
        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
        let args = vec![Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>];
        let config_options = Arc::new(ConfigOptions::new());
        ScalarFunctionExpr::try_new(udf, args, &schema, config_options).unwrap()
    }

    #[test]
    fn test_scalar_function_volatile_node() {
        let volatile_expr = mock_expr(Volatility::Volatile);
        assert!(volatile_expr.is_volatile_node());
        let volatile_arc: Arc<dyn PhysicalExpr> = Arc::new(volatile_expr);
        assert!(is_volatile(&volatile_arc));

        let stable_expr = mock_expr(Volatility::Stable);
        assert!(!stable_expr.is_volatile_node());
        let stable_arc: Arc<dyn PhysicalExpr> = Arc::new(stable_expr);
        assert!(!is_volatile(&stable_arc));
    }
}