// datafusion_physical_expr/scalar_function.rs

use std::any::Any;
33use std::fmt::{self, Debug, Formatter};
34use std::hash::{Hash, Hasher};
35use std::sync::Arc;
36
37use crate::PhysicalExpr;
38use crate::expressions::Literal;
39
40use arrow::array::{Array, RecordBatch};
41use arrow::datatypes::{DataType, FieldRef, Schema};
42use datafusion_common::config::{ConfigEntry, ConfigOptions};
43use datafusion_common::{Result, ScalarValue, internal_err};
44use datafusion_expr::interval_arithmetic::Interval;
45use datafusion_expr::sort_properties::ExprProperties;
46use datafusion_expr::type_coercion::functions::fields_with_udf;
47use datafusion_expr::{
48 ColumnarValue, ExpressionPlacement, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
49 Volatility, expr_vec_fmt,
50};
51
/// Physical expression for invoking a scalar UDF.
///
/// Binds a [`ScalarUDF`] to concrete argument expressions and a resolved
/// return field, ready to be evaluated against `RecordBatch`es.
pub struct ScalarFunctionExpr {
    // The scalar UDF that is invoked
    fun: Arc<ScalarUDF>,
    // Display name of the function
    name: String,
    // Argument expressions evaluated to produce the function inputs
    args: Vec<Arc<dyn PhysicalExpr>>,
    // Resolved output field (data type + nullability) of the invocation
    return_field: FieldRef,
    // Session configuration forwarded to the UDF at invocation time
    config_options: Arc<ConfigOptions>,
}
60
61impl Debug for ScalarFunctionExpr {
62 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
63 f.debug_struct("ScalarFunctionExpr")
64 .field("fun", &"<FUNC>")
65 .field("name", &self.name)
66 .field("args", &self.args)
67 .field("return_field", &self.return_field)
68 .finish()
69 }
70}
71
72impl ScalarFunctionExpr {
73 pub fn new(
75 name: &str,
76 fun: Arc<ScalarUDF>,
77 args: Vec<Arc<dyn PhysicalExpr>>,
78 return_field: FieldRef,
79 config_options: Arc<ConfigOptions>,
80 ) -> Self {
81 Self {
82 fun,
83 name: name.to_owned(),
84 args,
85 return_field,
86 config_options,
87 }
88 }
89
90 pub fn try_new(
92 fun: Arc<ScalarUDF>,
93 args: Vec<Arc<dyn PhysicalExpr>>,
94 schema: &Schema,
95 config_options: Arc<ConfigOptions>,
96 ) -> Result<Self> {
97 let name = fun.name().to_string();
98 let arg_fields = args
99 .iter()
100 .map(|e| e.return_field(schema))
101 .collect::<Result<Vec<_>>>()?;
102
103 fields_with_udf(&arg_fields, fun.as_ref())?;
105
106 let arguments = args
107 .iter()
108 .map(|e| {
109 e.as_any()
110 .downcast_ref::<Literal>()
111 .map(|literal| literal.value())
112 })
113 .collect::<Vec<_>>();
114 let ret_args = ReturnFieldArgs {
115 arg_fields: &arg_fields,
116 scalar_arguments: &arguments,
117 };
118 let return_field = fun.return_field_from_args(ret_args)?;
119 Ok(Self {
120 fun,
121 name,
122 args,
123 return_field,
124 config_options,
125 })
126 }
127
128 pub fn fun(&self) -> &ScalarUDF {
130 &self.fun
131 }
132
133 pub fn name(&self) -> &str {
135 &self.name
136 }
137
138 pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
140 &self.args
141 }
142
143 pub fn return_type(&self) -> &DataType {
145 self.return_field.data_type()
146 }
147
148 pub fn with_nullable(mut self, nullable: bool) -> Self {
149 self.return_field = self
150 .return_field
151 .as_ref()
152 .clone()
153 .with_nullable(nullable)
154 .into();
155 self
156 }
157
158 pub fn nullable(&self) -> bool {
159 self.return_field.is_nullable()
160 }
161
162 pub fn config_options(&self) -> &ConfigOptions {
163 &self.config_options
164 }
165
166 pub fn try_downcast_func<T>(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr>
171 where
172 T: 'static,
173 {
174 match expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
175 Some(scalar_expr)
176 if scalar_expr
177 .fun()
178 .inner()
179 .as_any()
180 .downcast_ref::<T>()
181 .is_some() =>
182 {
183 Some(scalar_expr)
184 }
185 _ => None,
186 }
187 }
188}
189
impl fmt::Display for ScalarFunctionExpr {
    /// Formats the expression as `name(arg1, arg2, ...)` using each
    /// argument's `Display` impl via `expr_vec_fmt!`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
    }
}
195
196impl PartialEq for ScalarFunctionExpr {
197 fn eq(&self, o: &Self) -> bool {
198 if std::ptr::eq(self, o) {
199 return true;
201 }
202 let Self {
203 fun,
204 name,
205 args,
206 return_field,
207 config_options,
208 } = self;
209 fun.eq(&o.fun)
210 && name.eq(&o.name)
211 && args.eq(&o.args)
212 && return_field.eq(&o.return_field)
213 && (Arc::ptr_eq(config_options, &o.config_options)
214 || sorted_config_entries(config_options)
215 == sorted_config_entries(&o.config_options))
216 }
217}
218impl Eq for ScalarFunctionExpr {}
219impl Hash for ScalarFunctionExpr {
220 fn hash<H: Hasher>(&self, state: &mut H) {
221 let Self {
222 fun,
223 name,
224 args,
225 return_field,
226 config_options: _, } = self;
228 fun.hash(state);
229 name.hash(state);
230 args.hash(state);
231 return_field.hash(state);
232 }
233}
234
235fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
236 let mut entries = config_options.entries();
237 entries.sort_by(|l, r| l.key.cmp(&r.key));
238 entries
239}
240
impl PhysicalExpr for ScalarFunctionExpr {
    /// Return `self` as [`Any`] to support downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    // Return field was resolved at construction, so the schema is ignored.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(self.return_field.data_type().clone())
    }

    // Nullability was likewise resolved at construction.
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(self.return_field.is_nullable())
    }

    /// Evaluate the argument expressions against `batch`, invoke the UDF,
    /// and validate the shape of its output.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        // Evaluate each argument expression against the batch.
        let args = self
            .args
            .iter()
            .map(|e| e.evaluate(batch))
            .collect::<Result<Vec<_>>>()?;

        // Resolve each argument's field metadata for the UDF invocation.
        let arg_fields = self
            .args
            .iter()
            .map(|e| e.return_field(batch.schema_ref()))
            .collect::<Result<Vec<_>>>()?;

        // Record whether every input was a scalar; used below to decide
        // whether a 1-row array result may be collapsed back to a scalar.
        let input_empty = args.is_empty();
        let input_all_scalar = args
            .iter()
            .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));

        let output = self.fun.invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: batch.num_rows(),
            return_field: Arc::clone(&self.return_field),
            config_options: Arc::clone(&self.config_options),
        })?;

        // If the UDF returned an array whose length differs from the batch:
        // a single-row result of an all-scalar (non-empty) call is collapsed
        // back to a scalar; any other mismatch is an internal error.
        if let ColumnarValue::Array(array) = &output
            && array.len() != batch.num_rows()
        {
            let preserve_scalar = array.len() == 1 && !input_empty && input_all_scalar;
            return if preserve_scalar {
                ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
            } else {
                internal_err!(
                    "UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
                    self.name,
                    batch.num_rows(),
                    array.len()
                )
            };
        }
        Ok(output)
    }

    // The resolved return field, independent of the input schema.
    fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
        Ok(Arc::clone(&self.return_field))
    }

    // Children are exactly the argument expressions.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        self.args.iter().collect()
    }

    /// Rebuild this expression with new argument expressions, reusing the
    /// previously-resolved return field and configuration.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(ScalarFunctionExpr::new(
            &self.name,
            Arc::clone(&self.fun),
            children,
            Arc::clone(&self.return_field),
            Arc::clone(&self.config_options),
        )))
    }

    // Interval-arithmetic bounds are delegated to the UDF.
    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
        self.fun.evaluate_bounds(children)
    }

    // Constraint propagation is delegated to the UDF.
    fn propagate_constraints(
        &self,
        interval: &Interval,
        children: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        self.fun.propagate_constraints(interval, children)
    }

    /// Combine the UDF's ordering/lex-preservation answers with the bounds
    /// computed from the children's ranges.
    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
        let sort_properties = self.fun.output_ordering(children)?;
        let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
        let children_range = children
            .iter()
            .map(|props| &props.range)
            .collect::<Vec<_>>();
        let range = self.fun().evaluate_bounds(&children_range)?;

        Ok(ExprProperties {
            sort_properties,
            range,
            preserves_lex_ordering,
        })
    }

    /// SQL-style rendering: `name(arg1, arg2, ...)` using each argument's
    /// `fmt_sql` (may differ from `Display`).
    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}(", self.name)?;
        for (i, expr) in self.args.iter().enumerate() {
            if i > 0 {
                write!(f, ", ")?;
            }
            expr.fmt_sql(f)?;
        }
        write!(f, ")")
    }

    // Volatility comes from the UDF's declared signature.
    fn is_volatile_node(&self) -> bool {
        self.fun.signature().volatility == Volatility::Volatile
    }

    // Placement is delegated to the UDF, given each argument's placement.
    fn placement(&self) -> ExpressionPlacement {
        let arg_placements: Vec<_> =
            self.args.iter().map(|arg| arg.placement()).collect();
        self.fun.placement(&arg_placements)
    }
}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Column;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature};
    use datafusion_physical_expr_common::physical_expr::is_volatile;
    use std::any::Any;

    // Minimal UDF whose volatility is controlled entirely by its `Signature`;
    // always returns Int32(42).
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockScalarUDF {
        signature: Signature,
    }

    impl ScalarUDFImpl for MockScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "mock_function"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42))))
        }
    }

    // Verifies that `is_volatile_node` (and the free `is_volatile` helper)
    // reflect the UDF's declared volatility for both a Volatile and a Stable
    // signature.
    #[test]
    fn test_scalar_function_volatile_node() {
        let volatile_udf = Arc::new(ScalarUDF::from(MockScalarUDF {
            signature: Signature::uniform(
                1,
                vec![DataType::Float32],
                Volatility::Volatile,
            ),
        }));

        let stable_udf = Arc::new(ScalarUDF::from(MockScalarUDF {
            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
        }));

        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
        let args = vec![Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>];
        let config_options = Arc::new(ConfigOptions::new());

        let volatile_expr = ScalarFunctionExpr::try_new(
            volatile_udf,
            args.clone(),
            &schema,
            Arc::clone(&config_options),
        )
        .unwrap();

        assert!(volatile_expr.is_volatile_node());
        let volatile_arc: Arc<dyn PhysicalExpr> = Arc::new(volatile_expr);
        assert!(is_volatile(&volatile_arc));

        let stable_expr =
            ScalarFunctionExpr::try_new(stable_udf, args, &schema, config_options)
                .unwrap();

        assert!(!stable_expr.is_volatile_node());
        let stable_arc: Arc<dyn PhysicalExpr> = Arc::new(stable_expr);
        assert!(!is_volatile(&stable_arc));
    }
}