datafusion_physical_expr/
scalar_function.rs1use std::fmt::{self, Debug, Formatter};
33use std::hash::{Hash, Hasher};
34use std::sync::Arc;
35
36use crate::PhysicalExpr;
37use crate::expressions::Literal;
38
39use arrow::array::{Array, RecordBatch};
40use arrow::datatypes::{DataType, FieldRef, Schema};
41use datafusion_common::config::{ConfigEntry, ConfigOptions};
42use datafusion_common::{Result, ScalarValue, internal_err};
43use datafusion_expr::interval_arithmetic::Interval;
44use datafusion_expr::sort_properties::ExprProperties;
45use datafusion_expr::type_coercion::functions::fields_with_udf;
46use datafusion_expr::{
47 ColumnarValue, ExpressionPlacement, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
48 ScalarUDFImpl, Volatility, expr_vec_fmt,
49};
50
51pub struct ScalarFunctionExpr {
53 fun: Arc<ScalarUDF>,
54 name: String,
55 args: Vec<Arc<dyn PhysicalExpr>>,
56 return_field: FieldRef,
57 config_options: Arc<ConfigOptions>,
58}
59
60impl Debug for ScalarFunctionExpr {
61 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
62 f.debug_struct("ScalarFunctionExpr")
63 .field("fun", &"<FUNC>")
64 .field("name", &self.name)
65 .field("args", &self.args)
66 .field("return_field", &self.return_field)
67 .finish()
68 }
69}
70
71impl ScalarFunctionExpr {
72 pub fn new(
74 name: &str,
75 fun: Arc<ScalarUDF>,
76 args: Vec<Arc<dyn PhysicalExpr>>,
77 return_field: FieldRef,
78 config_options: Arc<ConfigOptions>,
79 ) -> Self {
80 Self {
81 fun,
82 name: name.to_owned(),
83 args,
84 return_field,
85 config_options,
86 }
87 }
88
89 pub fn try_new(
91 fun: Arc<ScalarUDF>,
92 args: Vec<Arc<dyn PhysicalExpr>>,
93 schema: &Schema,
94 config_options: Arc<ConfigOptions>,
95 ) -> Result<Self> {
96 let name = fun.name().to_string();
97 let arg_fields = args
98 .iter()
99 .map(|e| e.return_field(schema))
100 .collect::<Result<Vec<_>>>()?;
101
102 fields_with_udf(&arg_fields, fun.as_ref())?;
104
105 let arguments = args
106 .iter()
107 .map(|e| e.downcast_ref::<Literal>().map(|literal| literal.value()))
108 .collect::<Vec<_>>();
109 let ret_args = ReturnFieldArgs {
110 arg_fields: &arg_fields,
111 scalar_arguments: &arguments,
112 };
113 let return_field = fun.return_field_from_args(ret_args)?;
114 Ok(Self {
115 fun,
116 name,
117 args,
118 return_field,
119 config_options,
120 })
121 }
122
123 pub fn fun(&self) -> &ScalarUDF {
125 &self.fun
126 }
127
128 pub fn name(&self) -> &str {
130 &self.name
131 }
132
133 pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
135 &self.args
136 }
137
138 pub fn return_type(&self) -> &DataType {
140 self.return_field.data_type()
141 }
142
143 pub fn with_nullable(mut self, nullable: bool) -> Self {
144 self.return_field = self
145 .return_field
146 .as_ref()
147 .clone()
148 .with_nullable(nullable)
149 .into();
150 self
151 }
152
153 pub fn nullable(&self) -> bool {
154 self.return_field.is_nullable()
155 }
156
157 pub fn config_options(&self) -> &ConfigOptions {
158 &self.config_options
159 }
160
161 pub fn try_downcast_func<T>(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr>
166 where
167 T: ScalarUDFImpl,
168 {
169 match expr.downcast_ref::<ScalarFunctionExpr>() {
170 Some(scalar_expr) if scalar_expr.fun().inner().is::<T>() => Some(scalar_expr),
171 _ => None,
172 }
173 }
174}
175
176impl fmt::Display for ScalarFunctionExpr {
177 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
178 write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
179 }
180}
181
182impl PartialEq for ScalarFunctionExpr {
183 fn eq(&self, o: &Self) -> bool {
184 if std::ptr::eq(self, o) {
185 return true;
187 }
188 let Self {
189 fun,
190 name,
191 args,
192 return_field,
193 config_options,
194 } = self;
195 fun.eq(&o.fun)
196 && name.eq(&o.name)
197 && args.eq(&o.args)
198 && return_field.eq(&o.return_field)
199 && (Arc::ptr_eq(config_options, &o.config_options)
200 || sorted_config_entries(config_options)
201 == sorted_config_entries(&o.config_options))
202 }
203}
204impl Eq for ScalarFunctionExpr {}
205impl Hash for ScalarFunctionExpr {
206 fn hash<H: Hasher>(&self, state: &mut H) {
207 let Self {
208 fun,
209 name,
210 args,
211 return_field,
212 config_options: _, } = self;
214 fun.hash(state);
215 name.hash(state);
216 args.hash(state);
217 return_field.hash(state);
218 }
219}
220
221fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
222 let mut entries = config_options.entries();
223 entries.sort_by(|l, r| l.key.cmp(&r.key));
224 entries
225}
226
227impl PhysicalExpr for ScalarFunctionExpr {
228 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
229 Ok(self.return_field.data_type().clone())
230 }
231
232 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
233 Ok(self.return_field.is_nullable())
234 }
235
236 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
237 let args = self
238 .args
239 .iter()
240 .map(|e| e.evaluate(batch))
241 .collect::<Result<Vec<_>>>()?;
242
243 let arg_fields = self
244 .args
245 .iter()
246 .map(|e| e.return_field(batch.schema_ref()))
247 .collect::<Result<Vec<_>>>()?;
248
249 let input_empty = args.is_empty();
250 let input_all_scalar = args
251 .iter()
252 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
253
254 let output = self.fun.invoke_with_args(ScalarFunctionArgs {
256 args,
257 arg_fields,
258 number_rows: batch.num_rows(),
259 return_field: Arc::clone(&self.return_field),
260 config_options: Arc::clone(&self.config_options),
261 })?;
262
263 if let ColumnarValue::Array(array) = &output
264 && array.len() != batch.num_rows()
265 {
266 let preserve_scalar = array.len() == 1 && !input_empty && input_all_scalar;
269 return if preserve_scalar {
270 ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
271 } else {
272 internal_err!(
273 "UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
274 self.name,
275 batch.num_rows(),
276 array.len()
277 )
278 };
279 }
280 Ok(output)
281 }
282
283 fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
284 Ok(Arc::clone(&self.return_field))
285 }
286
287 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
288 self.args.iter().collect()
289 }
290
291 fn with_new_children(
292 self: Arc<Self>,
293 children: Vec<Arc<dyn PhysicalExpr>>,
294 ) -> Result<Arc<dyn PhysicalExpr>> {
295 Ok(Arc::new(ScalarFunctionExpr::new(
296 &self.name,
297 Arc::clone(&self.fun),
298 children,
299 Arc::clone(&self.return_field),
300 Arc::clone(&self.config_options),
301 )))
302 }
303
304 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
305 self.fun.evaluate_bounds(children)
306 }
307
308 fn propagate_constraints(
309 &self,
310 interval: &Interval,
311 children: &[&Interval],
312 ) -> Result<Option<Vec<Interval>>> {
313 self.fun.propagate_constraints(interval, children)
314 }
315
316 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
317 let sort_properties = self.fun.output_ordering(children)?;
318 let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
319 let children_range = children
320 .iter()
321 .map(|props| &props.range)
322 .collect::<Vec<_>>();
323 let range = self.fun().evaluate_bounds(&children_range)?;
324
325 Ok(ExprProperties {
326 sort_properties,
327 range,
328 preserves_lex_ordering,
329 })
330 }
331
332 fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
333 write!(f, "{}(", self.name)?;
334 for (i, expr) in self.args.iter().enumerate() {
335 if i > 0 {
336 write!(f, ", ")?;
337 }
338 expr.fmt_sql(f)?;
339 }
340 write!(f, ")")
341 }
342
343 fn is_volatile_node(&self) -> bool {
344 self.fun.signature().volatility == Volatility::Volatile
345 }
346
347 fn placement(&self) -> ExpressionPlacement {
348 let arg_placements: Vec<_> =
349 self.args.iter().map(|arg| arg.placement()).collect();
350 self.fun.placement(&arg_placements)
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Column;
    use arrow::datatypes::Field;
    use datafusion_expr::{ScalarUDFImpl, Signature};
    use datafusion_physical_expr_common::physical_expr::is_volatile;

    /// Minimal UDF whose only configurable property is its signature.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockScalarUDF {
        signature: Signature,
    }

    impl ScalarUDFImpl for MockScalarUDF {
        fn name(&self) -> &str {
            "mock_function"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42))))
        }
    }

    /// Wraps a mock UDF that takes one `Float32` argument with the given
    /// volatility.
    fn mock_udf(volatility: Volatility) -> Arc<ScalarUDF> {
        Arc::new(ScalarUDF::from(MockScalarUDF {
            signature: Signature::uniform(1, vec![DataType::Float32], volatility),
        }))
    }

    #[test]
    fn test_scalar_function_volatile_node() {
        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
        let args = vec![Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>];
        let config_options = Arc::new(ConfigOptions::new());

        // A volatile UDF marks both the node and the whole tree as volatile.
        let volatile_expr = ScalarFunctionExpr::try_new(
            mock_udf(Volatility::Volatile),
            args.clone(),
            &schema,
            Arc::clone(&config_options),
        )
        .unwrap();
        assert!(volatile_expr.is_volatile_node());
        let volatile_arc: Arc<dyn PhysicalExpr> = Arc::new(volatile_expr);
        assert!(is_volatile(&volatile_arc));

        // A stable UDF over a plain column is not volatile at any level.
        let stable_expr = ScalarFunctionExpr::try_new(
            mock_udf(Volatility::Stable),
            args,
            &schema,
            config_options,
        )
        .unwrap();
        assert!(!stable_expr.is_volatile_node());
        let stable_arc: Arc<dyn PhysicalExpr> = Arc::new(stable_expr);
        assert!(!is_volatile(&stable_arc));
    }
}