datafusion_physical_expr/
scalar_function.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Declaration of built-in (scalar) functions.
19//! This module contains built-in functions' enumeration and metadata.
20//!
21//! Generally, a function has:
22//! * a signature
23//! * a return type, that is a function of the incoming argument's types
24//! * the computation, that must accept each valid signature
25//!
26//! * Signature: see `Signature`
27//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64.
28//!
29//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed
30//! to a function that supports f64, it is coerced to f64.
31
32use std::any::Any;
33use std::fmt::{self, Debug, Formatter};
34use std::hash::{Hash, Hasher};
35use std::sync::Arc;
36
37use crate::expressions::Literal;
38use crate::PhysicalExpr;
39
40use arrow::array::{Array, RecordBatch};
41use arrow::datatypes::{DataType, FieldRef, Schema};
42use datafusion_common::config::{ConfigEntry, ConfigOptions};
43use datafusion_common::{internal_err, Result, ScalarValue};
44use datafusion_expr::interval_arithmetic::Interval;
45use datafusion_expr::sort_properties::ExprProperties;
46use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
47use datafusion_expr::{
48    expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
49    Volatility,
50};
51
52/// Physical expression of a scalar function
53pub struct ScalarFunctionExpr {
54    fun: Arc<ScalarUDF>,
55    name: String,
56    args: Vec<Arc<dyn PhysicalExpr>>,
57    return_field: FieldRef,
58    config_options: Arc<ConfigOptions>,
59}
60
61impl Debug for ScalarFunctionExpr {
62    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
63        f.debug_struct("ScalarFunctionExpr")
64            .field("fun", &"<FUNC>")
65            .field("name", &self.name)
66            .field("args", &self.args)
67            .field("return_field", &self.return_field)
68            .finish()
69    }
70}
71
72impl ScalarFunctionExpr {
73    /// Create a new Scalar function
74    pub fn new(
75        name: &str,
76        fun: Arc<ScalarUDF>,
77        args: Vec<Arc<dyn PhysicalExpr>>,
78        return_field: FieldRef,
79        config_options: Arc<ConfigOptions>,
80    ) -> Self {
81        Self {
82            fun,
83            name: name.to_owned(),
84            args,
85            return_field,
86            config_options,
87        }
88    }
89
90    /// Create a new Scalar function
91    pub fn try_new(
92        fun: Arc<ScalarUDF>,
93        args: Vec<Arc<dyn PhysicalExpr>>,
94        schema: &Schema,
95        config_options: Arc<ConfigOptions>,
96    ) -> Result<Self> {
97        let name = fun.name().to_string();
98        let arg_fields = args
99            .iter()
100            .map(|e| e.return_field(schema))
101            .collect::<Result<Vec<_>>>()?;
102
103        // verify that input data types is consistent with function's `TypeSignature`
104        let arg_types = arg_fields
105            .iter()
106            .map(|f| f.data_type().clone())
107            .collect::<Vec<_>>();
108        data_types_with_scalar_udf(&arg_types, &fun)?;
109
110        let arguments = args
111            .iter()
112            .map(|e| {
113                e.as_any()
114                    .downcast_ref::<Literal>()
115                    .map(|literal| literal.value())
116            })
117            .collect::<Vec<_>>();
118        let ret_args = ReturnFieldArgs {
119            arg_fields: &arg_fields,
120            scalar_arguments: &arguments,
121        };
122        let return_field = fun.return_field_from_args(ret_args)?;
123        Ok(Self {
124            fun,
125            name,
126            args,
127            return_field,
128            config_options,
129        })
130    }
131
132    /// Get the scalar function implementation
133    pub fn fun(&self) -> &ScalarUDF {
134        &self.fun
135    }
136
137    /// The name for this expression
138    pub fn name(&self) -> &str {
139        &self.name
140    }
141
142    /// Input arguments
143    pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
144        &self.args
145    }
146
147    /// Data type produced by this expression
148    pub fn return_type(&self) -> &DataType {
149        self.return_field.data_type()
150    }
151
152    pub fn with_nullable(mut self, nullable: bool) -> Self {
153        self.return_field = self
154            .return_field
155            .as_ref()
156            .clone()
157            .with_nullable(nullable)
158            .into();
159        self
160    }
161
162    pub fn nullable(&self) -> bool {
163        self.return_field.is_nullable()
164    }
165
166    pub fn config_options(&self) -> &ConfigOptions {
167        &self.config_options
168    }
169
170    /// Given an arbitrary PhysicalExpr attempt to downcast it to a ScalarFunctionExpr
171    /// and verify that its inner function is of type T.
172    /// If the downcast fails, or the function is not of type T, returns `None`.
173    /// Otherwise returns `Some(ScalarFunctionExpr)`.
174    pub fn try_downcast_func<T>(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr>
175    where
176        T: 'static,
177    {
178        match expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
179            Some(scalar_expr)
180                if scalar_expr
181                    .fun()
182                    .inner()
183                    .as_any()
184                    .downcast_ref::<T>()
185                    .is_some() =>
186            {
187                Some(scalar_expr)
188            }
189            _ => None,
190        }
191    }
192}
193
194impl fmt::Display for ScalarFunctionExpr {
195    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
196        write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
197    }
198}
199
200impl PartialEq for ScalarFunctionExpr {
201    fn eq(&self, o: &Self) -> bool {
202        if std::ptr::eq(self, o) {
203            // The equality implementation is somewhat expensive, so let's short-circuit when possible.
204            return true;
205        }
206        let Self {
207            fun,
208            name,
209            args,
210            return_field,
211            config_options,
212        } = self;
213        fun.eq(&o.fun)
214            && name.eq(&o.name)
215            && args.eq(&o.args)
216            && return_field.eq(&o.return_field)
217            && (Arc::ptr_eq(config_options, &o.config_options)
218                || sorted_config_entries(config_options)
219                    == sorted_config_entries(&o.config_options))
220    }
221}
222impl Eq for ScalarFunctionExpr {}
223impl Hash for ScalarFunctionExpr {
224    fn hash<H: Hasher>(&self, state: &mut H) {
225        let Self {
226            fun,
227            name,
228            args,
229            return_field,
230            config_options: _, // expensive to hash, and often equal
231        } = self;
232        fun.hash(state);
233        name.hash(state);
234        args.hash(state);
235        return_field.hash(state);
236    }
237}
238
239fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
240    let mut entries = config_options.entries();
241    entries.sort_by(|l, r| l.key.cmp(&r.key));
242    entries
243}
244
245impl PhysicalExpr for ScalarFunctionExpr {
246    /// Return a reference to Any that can be used for downcasting
247    fn as_any(&self) -> &dyn Any {
248        self
249    }
250
251    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
252        Ok(self.return_field.data_type().clone())
253    }
254
255    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
256        Ok(self.return_field.is_nullable())
257    }
258
259    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
260        let args = self
261            .args
262            .iter()
263            .map(|e| e.evaluate(batch))
264            .collect::<Result<Vec<_>>>()?;
265
266        let arg_fields = self
267            .args
268            .iter()
269            .map(|e| e.return_field(batch.schema_ref()))
270            .collect::<Result<Vec<_>>>()?;
271
272        let input_empty = args.is_empty();
273        let input_all_scalar = args
274            .iter()
275            .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
276
277        // evaluate the function
278        let output = self.fun.invoke_with_args(ScalarFunctionArgs {
279            args,
280            arg_fields,
281            number_rows: batch.num_rows(),
282            return_field: Arc::clone(&self.return_field),
283            config_options: Arc::clone(&self.config_options),
284        })?;
285
286        if let ColumnarValue::Array(array) = &output {
287            if array.len() != batch.num_rows() {
288                // If the arguments are a non-empty slice of scalar values, we can assume that
289                // returning a one-element array is equivalent to returning a scalar.
290                let preserve_scalar =
291                    array.len() == 1 && !input_empty && input_all_scalar;
292                return if preserve_scalar {
293                    ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
294                } else {
295                    internal_err!("UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
296                            self.name, batch.num_rows(), array.len())
297                };
298            }
299        }
300        Ok(output)
301    }
302
303    fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
304        Ok(Arc::clone(&self.return_field))
305    }
306
307    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
308        self.args.iter().collect()
309    }
310
311    fn with_new_children(
312        self: Arc<Self>,
313        children: Vec<Arc<dyn PhysicalExpr>>,
314    ) -> Result<Arc<dyn PhysicalExpr>> {
315        Ok(Arc::new(ScalarFunctionExpr::new(
316            &self.name,
317            Arc::clone(&self.fun),
318            children,
319            Arc::clone(&self.return_field),
320            Arc::clone(&self.config_options),
321        )))
322    }
323
324    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
325        self.fun.evaluate_bounds(children)
326    }
327
328    fn propagate_constraints(
329        &self,
330        interval: &Interval,
331        children: &[&Interval],
332    ) -> Result<Option<Vec<Interval>>> {
333        self.fun.propagate_constraints(interval, children)
334    }
335
336    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
337        let sort_properties = self.fun.output_ordering(children)?;
338        let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
339        let children_range = children
340            .iter()
341            .map(|props| &props.range)
342            .collect::<Vec<_>>();
343        let range = self.fun().evaluate_bounds(&children_range)?;
344
345        Ok(ExprProperties {
346            sort_properties,
347            range,
348            preserves_lex_ordering,
349        })
350    }
351
352    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
353        write!(f, "{}(", self.name)?;
354        for (i, expr) in self.args.iter().enumerate() {
355            if i > 0 {
356                write!(f, ", ")?;
357            }
358            expr.fmt_sql(f)?;
359        }
360        write!(f, ")")
361    }
362
363    fn is_volatile_node(&self) -> bool {
364        self.fun.signature().volatility == Volatility::Volatile
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Column;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature};
    use datafusion_physical_expr_common::physical_expr::is_volatile;
    use std::any::Any;

    /// Mock UDF whose volatility is determined by the signature it is given
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockScalarUDF {
        signature: Signature,
    }

    impl ScalarUDFImpl for MockScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "mock_function"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42))))
        }
    }

    /// Builds a single-argument (`Float32`) mock UDF with the given volatility
    fn mock_udf(volatility: Volatility) -> Arc<ScalarUDF> {
        let signature = Signature::uniform(1, vec![DataType::Float32], volatility);
        Arc::new(ScalarUDF::from(MockScalarUDF { signature }))
    }

    #[test]
    fn test_scalar_function_volatile_node() {
        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
        let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
        let config_options = Arc::new(ConfigOptions::new());

        // A volatile UDF must be reported as volatile, a stable one must not.
        let cases = [(Volatility::Volatile, true), (Volatility::Stable, false)];
        for (volatility, expect_volatile) in cases {
            let expr = ScalarFunctionExpr::try_new(
                mock_udf(volatility),
                vec![Arc::clone(&column)],
                &schema,
                Arc::clone(&config_options),
            )
            .unwrap();

            assert_eq!(expr.is_volatile_node(), expect_volatile);
            let expr: Arc<dyn PhysicalExpr> = Arc::new(expr);
            assert_eq!(is_volatile(&expr), expect_volatile);
        }
    }
}