datafusion_physical_expr/scalar_function.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Declaration of built-in (scalar) functions.
//! This module contains built-in functions' enumeration and metadata.
//!
//! Generally, a function has:
//! * a signature
//! * a return type, which is a function of the incoming arguments' types
//! * the computation, which must accept each valid signature
//!
//! * Signature: see `Signature`
//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64.
//!
//! This module also has a set of coercion rules to improve user experience: if an i32 argument is passed
//! to a function that supports f64, it is coerced to f64.
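//!
//! A rough construction sketch (it mirrors the test at the bottom of this
//! file; `my_udf` and `batch` are placeholders supplied by the caller):
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! // `my_udf` is an `Arc<ScalarUDF>`, `batch` a `RecordBatch` with column "a"
//! let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
//! let args = vec![Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>];
//! let expr = ScalarFunctionExpr::try_new(
//!     my_udf,
//!     args,
//!     &schema,
//!     Arc::new(ConfigOptions::new()),
//! )?;
//! let value = expr.evaluate(&batch)?; // ColumnarValue
//! ```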

use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use crate::PhysicalExpr;
use crate::expressions::Literal;

use arrow::array::{Array, RecordBatch};
use arrow::datatypes::{DataType, FieldRef, Schema};
use datafusion_common::config::{ConfigEntry, ConfigOptions};
use datafusion_common::{Result, ScalarValue, internal_err};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::ExprProperties;
use datafusion_expr::type_coercion::functions::fields_with_udf;
use datafusion_expr::{
    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, Volatility,
    expr_vec_fmt,
};

/// Physical expression of a scalar function
pub struct ScalarFunctionExpr {
    fun: Arc<ScalarUDF>,
    name: String,
    args: Vec<Arc<dyn PhysicalExpr>>,
    return_field: FieldRef,
    config_options: Arc<ConfigOptions>,
}

impl Debug for ScalarFunctionExpr {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.debug_struct("ScalarFunctionExpr")
            .field("fun", &"<FUNC>")
            .field("name", &self.name)
            .field("args", &self.args)
            .field("return_field", &self.return_field)
            .finish()
    }
}

impl ScalarFunctionExpr {
    /// Create a new scalar function with an explicit, pre-computed return field
    pub fn new(
        name: &str,
        fun: Arc<ScalarUDF>,
        args: Vec<Arc<dyn PhysicalExpr>>,
        return_field: FieldRef,
        config_options: Arc<ConfigOptions>,
    ) -> Self {
        Self {
            fun,
            name: name.to_owned(),
            args,
            return_field,
            config_options,
        }
    }

    /// Create a new scalar function, deriving the return field from the
    /// function and the argument expressions' fields
    pub fn try_new(
        fun: Arc<ScalarUDF>,
        args: Vec<Arc<dyn PhysicalExpr>>,
        schema: &Schema,
        config_options: Arc<ConfigOptions>,
    ) -> Result<Self> {
        let name = fun.name().to_string();
        let arg_fields = args
            .iter()
            .map(|e| e.return_field(schema))
            .collect::<Result<Vec<_>>>()?;

        // verify that the input data types are consistent with the function's `TypeSignature`
        fields_with_udf(&arg_fields, fun.as_ref())?;

        let arguments = args
            .iter()
            .map(|e| {
                e.as_any()
                    .downcast_ref::<Literal>()
                    .map(|literal| literal.value())
            })
            .collect::<Vec<_>>();
        let ret_args = ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &arguments,
        };
        let return_field = fun.return_field_from_args(ret_args)?;
        Ok(Self {
            fun,
            name,
            args,
            return_field,
            config_options,
        })
    }

    /// Get the scalar function implementation
    pub fn fun(&self) -> &ScalarUDF {
        &self.fun
    }

    /// The name for this expression
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Input arguments
    pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
        &self.args
    }

    /// Data type produced by this expression
    pub fn return_type(&self) -> &DataType {
        self.return_field.data_type()
    }

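    /// Override the nullability of the return field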
    pub fn with_nullable(mut self, nullable: bool) -> Self {
        self.return_field = self
            .return_field
            .as_ref()
            .clone()
            .with_nullable(nullable)
            .into();
        self
    }

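    /// Whether this expression's return value can be null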
    pub fn nullable(&self) -> bool {
        self.return_field.is_nullable()
    }

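    /// Configuration options passed to the function when it is invoked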
    pub fn config_options(&self) -> &ConfigOptions {
        &self.config_options
    }

    /// Given an arbitrary PhysicalExpr, attempt to downcast it to a ScalarFunctionExpr
    /// and verify that its inner function is of type T.
    /// If the downcast fails, or the function is not of type T, returns `None`.
    /// Otherwise returns `Some(&ScalarFunctionExpr)`.
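    ///
    /// A usage sketch (`MyUdf` stands in for any concrete `ScalarUDFImpl` type):
    ///
    /// ```ignore
    /// if let Some(scalar_expr) =
    ///     ScalarFunctionExpr::try_downcast_func::<MyUdf>(expr.as_ref())
    /// {
    ///     // `expr` is a call to `MyUdf`; its arguments can be inspected here
    ///     let _args = scalar_expr.args();
    /// }
    /// ```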
    pub fn try_downcast_func<T>(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr>
    where
        T: 'static,
    {
        match expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
            Some(scalar_expr)
                if scalar_expr
                    .fun()
                    .inner()
                    .as_any()
                    .downcast_ref::<T>()
                    .is_some() =>
            {
                Some(scalar_expr)
            }
            _ => None,
        }
    }
}

impl fmt::Display for ScalarFunctionExpr {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
    }
}

impl PartialEq for ScalarFunctionExpr {
    fn eq(&self, o: &Self) -> bool {
        if std::ptr::eq(self, o) {
            // The equality implementation is somewhat expensive, so let's short-circuit when possible.
            return true;
        }
        let Self {
            fun,
            name,
            args,
            return_field,
            config_options,
        } = self;
        fun.eq(&o.fun)
            && name.eq(&o.name)
            && args.eq(&o.args)
            && return_field.eq(&o.return_field)
            && (Arc::ptr_eq(config_options, &o.config_options)
                || sorted_config_entries(config_options)
                    == sorted_config_entries(&o.config_options))
    }
}
impl Eq for ScalarFunctionExpr {}
impl Hash for ScalarFunctionExpr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self {
            fun,
            name,
            args,
            return_field,
            config_options: _, // expensive to hash, and often equal
        } = self;
        fun.hash(state);
        name.hash(state);
        args.hash(state);
        return_field.hash(state);
    }
}

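/// Returns the config entries sorted by key, so that two `ConfigOptions` values
/// can be compared entry-by-entry in the `PartialEq` implementation above.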
fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
    let mut entries = config_options.entries();
    entries.sort_by(|l, r| l.key.cmp(&r.key));
    entries
}

impl PhysicalExpr for ScalarFunctionExpr {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(self.return_field.data_type().clone())
    }

    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(self.return_field.is_nullable())
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        let args = self
            .args
            .iter()
            .map(|e| e.evaluate(batch))
            .collect::<Result<Vec<_>>>()?;

        let arg_fields = self
            .args
            .iter()
            .map(|e| e.return_field(batch.schema_ref()))
            .collect::<Result<Vec<_>>>()?;

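        // Remember whether there were any arguments and whether all of them were
        // scalars: this determines below whether a one-row array returned by the
        // UDF can be converted back into a scalar.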
        let input_empty = args.is_empty();
        let input_all_scalar = args
            .iter()
            .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));

        // evaluate the function
        let output = self.fun.invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: batch.num_rows(),
            return_field: Arc::clone(&self.return_field),
            config_options: Arc::clone(&self.config_options),
        })?;

        if let ColumnarValue::Array(array) = &output
            && array.len() != batch.num_rows()
        {
            // If the arguments are a non-empty slice of scalar values, we can assume that
            // returning a one-element array is equivalent to returning a scalar.
            let preserve_scalar = array.len() == 1 && !input_empty && input_all_scalar;
            return if preserve_scalar {
                ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
            } else {
                internal_err!(
                    "UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
                    self.name,
                    batch.num_rows(),
                    array.len()
                )
            };
        }
        Ok(output)
    }

    fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
        Ok(Arc::clone(&self.return_field))
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        self.args.iter().collect()
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(ScalarFunctionExpr::new(
            &self.name,
            Arc::clone(&self.fun),
            children,
            Arc::clone(&self.return_field),
            Arc::clone(&self.config_options),
        )))
    }

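    /// Interval bound evaluation is delegated to the underlying UDF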
    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
        self.fun.evaluate_bounds(children)
    }

    fn propagate_constraints(
        &self,
        interval: &Interval,
        children: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        self.fun.propagate_constraints(interval, children)
    }

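    /// Combines the UDF's reported output ordering, lexicographic ordering
    /// preservation, and value bounds into this expression's properties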
    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
        let sort_properties = self.fun.output_ordering(children)?;
        let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
        let children_range = children
            .iter()
            .map(|props| &props.range)
            .collect::<Vec<_>>();
        let range = self.fun().evaluate_bounds(&children_range)?;

        Ok(ExprProperties {
            sort_properties,
            range,
            preserves_lex_ordering,
        })
    }

    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}(", self.name)?;
        for (i, expr) in self.args.iter().enumerate() {
            if i > 0 {
                write!(f, ", ")?;
            }
            expr.fmt_sql(f)?;
        }
        write!(f, ")")
    }

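    /// A node is volatile when the underlying UDF's signature declares `Volatility::Volatile`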
    fn is_volatile_node(&self) -> bool {
        self.fun.signature().volatility == Volatility::Volatile
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Column;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature};
    use datafusion_physical_expr_common::physical_expr::is_volatile;
    use std::any::Any;

    /// Test helper to create a mock UDF with a specific volatility
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockScalarUDF {
        signature: Signature,
    }

    impl ScalarUDFImpl for MockScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "mock_function"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42))))
        }
    }

    #[test]
    fn test_scalar_function_volatile_node() {
        // Create a volatile UDF
        let volatile_udf = Arc::new(ScalarUDF::from(MockScalarUDF {
            signature: Signature::uniform(
                1,
                vec![DataType::Float32],
                Volatility::Volatile,
            ),
        }));

        // Create a non-volatile UDF
        let stable_udf = Arc::new(ScalarUDF::from(MockScalarUDF {
            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
        }));

        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
        let args = vec![Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>];
        let config_options = Arc::new(ConfigOptions::new());

        // Test volatile function
        let volatile_expr = ScalarFunctionExpr::try_new(
            volatile_udf,
            args.clone(),
            &schema,
            Arc::clone(&config_options),
        )
        .unwrap();

        assert!(volatile_expr.is_volatile_node());
        let volatile_arc: Arc<dyn PhysicalExpr> = Arc::new(volatile_expr);
        assert!(is_volatile(&volatile_arc));

        // Test non-volatile function
        let stable_expr =
            ScalarFunctionExpr::try_new(stable_udf, args, &schema, config_options)
                .unwrap();

        assert!(!stable_expr.is_volatile_node());
        let stable_arc: Arc<dyn PhysicalExpr> = Arc::new(stable_expr);
        assert!(!is_volatile(&stable_arc));
    }
}