Skip to main content

datafusion_physical_expr/
scalar_function.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Declaration of built-in (scalar) functions.
//! This module contains built-in functions' enumeration and metadata.
//!
//! Generally, a function has:
//! * a signature
//! * a return type, which is a function of the incoming arguments' types
//! * the computation, which must accept each valid signature
//!
//! * Signature: see `Signature`
//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64.
//!
//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed
//! to a function that supports f64, it is coerced to f64.
31
32use std::fmt::{self, Debug, Formatter};
33use std::hash::{Hash, Hasher};
34use std::sync::Arc;
35
36use crate::PhysicalExpr;
37use crate::expressions::Literal;
38
39use arrow::array::{Array, RecordBatch};
40use arrow::datatypes::{DataType, FieldRef, Schema};
41use datafusion_common::config::{ConfigEntry, ConfigOptions};
42use datafusion_common::{Result, ScalarValue, internal_err};
43use datafusion_expr::interval_arithmetic::Interval;
44use datafusion_expr::sort_properties::ExprProperties;
45use datafusion_expr::type_coercion::functions::fields_with_udf;
46use datafusion_expr::{
47    ColumnarValue, ExpressionPlacement, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
48    ScalarUDFImpl, Volatility, expr_vec_fmt,
49};
50
51/// Physical expression of a scalar function
52pub struct ScalarFunctionExpr {
53    fun: Arc<ScalarUDF>,
54    name: String,
55    args: Vec<Arc<dyn PhysicalExpr>>,
56    return_field: FieldRef,
57    config_options: Arc<ConfigOptions>,
58}
59
60impl Debug for ScalarFunctionExpr {
61    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
62        f.debug_struct("ScalarFunctionExpr")
63            .field("fun", &"<FUNC>")
64            .field("name", &self.name)
65            .field("args", &self.args)
66            .field("return_field", &self.return_field)
67            .finish()
68    }
69}
70
71impl ScalarFunctionExpr {
72    /// Create a new Scalar function
73    pub fn new(
74        name: &str,
75        fun: Arc<ScalarUDF>,
76        args: Vec<Arc<dyn PhysicalExpr>>,
77        return_field: FieldRef,
78        config_options: Arc<ConfigOptions>,
79    ) -> Self {
80        Self {
81            fun,
82            name: name.to_owned(),
83            args,
84            return_field,
85            config_options,
86        }
87    }
88
89    /// Create a new Scalar function
90    pub fn try_new(
91        fun: Arc<ScalarUDF>,
92        args: Vec<Arc<dyn PhysicalExpr>>,
93        schema: &Schema,
94        config_options: Arc<ConfigOptions>,
95    ) -> Result<Self> {
96        let name = fun.name().to_string();
97        let arg_fields = args
98            .iter()
99            .map(|e| e.return_field(schema))
100            .collect::<Result<Vec<_>>>()?;
101
102        // verify that input data types is consistent with function's `TypeSignature`
103        fields_with_udf(&arg_fields, fun.as_ref())?;
104
105        let arguments = args
106            .iter()
107            .map(|e| e.downcast_ref::<Literal>().map(|literal| literal.value()))
108            .collect::<Vec<_>>();
109        let ret_args = ReturnFieldArgs {
110            arg_fields: &arg_fields,
111            scalar_arguments: &arguments,
112        };
113        let return_field = fun.return_field_from_args(ret_args)?;
114        Ok(Self {
115            fun,
116            name,
117            args,
118            return_field,
119            config_options,
120        })
121    }
122
123    /// Get the scalar function implementation
124    pub fn fun(&self) -> &ScalarUDF {
125        &self.fun
126    }
127
128    /// The name for this expression
129    pub fn name(&self) -> &str {
130        &self.name
131    }
132
133    /// Input arguments
134    pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
135        &self.args
136    }
137
138    /// Data type produced by this expression
139    pub fn return_type(&self) -> &DataType {
140        self.return_field.data_type()
141    }
142
143    pub fn with_nullable(mut self, nullable: bool) -> Self {
144        self.return_field = self
145            .return_field
146            .as_ref()
147            .clone()
148            .with_nullable(nullable)
149            .into();
150        self
151    }
152
153    pub fn nullable(&self) -> bool {
154        self.return_field.is_nullable()
155    }
156
157    pub fn config_options(&self) -> &ConfigOptions {
158        &self.config_options
159    }
160
161    /// Given an arbitrary PhysicalExpr attempt to downcast it to a ScalarFunctionExpr
162    /// and verify that its inner function is of type T.
163    /// If the downcast fails, or the function is not of type T, returns `None`.
164    /// Otherwise returns `Some(ScalarFunctionExpr)`.
165    pub fn try_downcast_func<T>(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr>
166    where
167        T: ScalarUDFImpl,
168    {
169        match expr.downcast_ref::<ScalarFunctionExpr>() {
170            Some(scalar_expr) if scalar_expr.fun().inner().is::<T>() => Some(scalar_expr),
171            _ => None,
172        }
173    }
174}
175
176impl fmt::Display for ScalarFunctionExpr {
177    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
178        write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
179    }
180}
181
182impl PartialEq for ScalarFunctionExpr {
183    fn eq(&self, o: &Self) -> bool {
184        if std::ptr::eq(self, o) {
185            // The equality implementation is somewhat expensive, so let's short-circuit when possible.
186            return true;
187        }
188        let Self {
189            fun,
190            name,
191            args,
192            return_field,
193            config_options,
194        } = self;
195        fun.eq(&o.fun)
196            && name.eq(&o.name)
197            && args.eq(&o.args)
198            && return_field.eq(&o.return_field)
199            && (Arc::ptr_eq(config_options, &o.config_options)
200                || sorted_config_entries(config_options)
201                    == sorted_config_entries(&o.config_options))
202    }
203}
204impl Eq for ScalarFunctionExpr {}
205impl Hash for ScalarFunctionExpr {
206    fn hash<H: Hasher>(&self, state: &mut H) {
207        let Self {
208            fun,
209            name,
210            args,
211            return_field,
212            config_options: _, // expensive to hash, and often equal
213        } = self;
214        fun.hash(state);
215        name.hash(state);
216        args.hash(state);
217        return_field.hash(state);
218    }
219}
220
221fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
222    let mut entries = config_options.entries();
223    entries.sort_by(|l, r| l.key.cmp(&r.key));
224    entries
225}
226
227impl PhysicalExpr for ScalarFunctionExpr {
228    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
229        Ok(self.return_field.data_type().clone())
230    }
231
232    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
233        Ok(self.return_field.is_nullable())
234    }
235
236    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
237        let args = self
238            .args
239            .iter()
240            .map(|e| e.evaluate(batch))
241            .collect::<Result<Vec<_>>>()?;
242
243        let arg_fields = self
244            .args
245            .iter()
246            .map(|e| e.return_field(batch.schema_ref()))
247            .collect::<Result<Vec<_>>>()?;
248
249        let input_empty = args.is_empty();
250        let input_all_scalar = args
251            .iter()
252            .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
253
254        // evaluate the function
255        let output = self.fun.invoke_with_args(ScalarFunctionArgs {
256            args,
257            arg_fields,
258            number_rows: batch.num_rows(),
259            return_field: Arc::clone(&self.return_field),
260            config_options: Arc::clone(&self.config_options),
261        })?;
262
263        if let ColumnarValue::Array(array) = &output
264            && array.len() != batch.num_rows()
265        {
266            // If the arguments are a non-empty slice of scalar values, we can assume that
267            // returning a one-element array is equivalent to returning a scalar.
268            let preserve_scalar = array.len() == 1 && !input_empty && input_all_scalar;
269            return if preserve_scalar {
270                ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
271            } else {
272                internal_err!(
273                    "UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
274                    self.name,
275                    batch.num_rows(),
276                    array.len()
277                )
278            };
279        }
280        Ok(output)
281    }
282
283    fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
284        Ok(Arc::clone(&self.return_field))
285    }
286
287    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
288        self.args.iter().collect()
289    }
290
291    fn with_new_children(
292        self: Arc<Self>,
293        children: Vec<Arc<dyn PhysicalExpr>>,
294    ) -> Result<Arc<dyn PhysicalExpr>> {
295        Ok(Arc::new(ScalarFunctionExpr::new(
296            &self.name,
297            Arc::clone(&self.fun),
298            children,
299            Arc::clone(&self.return_field),
300            Arc::clone(&self.config_options),
301        )))
302    }
303
304    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
305        self.fun.evaluate_bounds(children)
306    }
307
308    fn propagate_constraints(
309        &self,
310        interval: &Interval,
311        children: &[&Interval],
312    ) -> Result<Option<Vec<Interval>>> {
313        self.fun.propagate_constraints(interval, children)
314    }
315
316    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
317        let sort_properties = self.fun.output_ordering(children)?;
318        let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
319        let children_range = children
320            .iter()
321            .map(|props| &props.range)
322            .collect::<Vec<_>>();
323        let range = self.fun().evaluate_bounds(&children_range)?;
324
325        Ok(ExprProperties {
326            sort_properties,
327            range,
328            preserves_lex_ordering,
329        })
330    }
331
332    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
333        write!(f, "{}(", self.name)?;
334        for (i, expr) in self.args.iter().enumerate() {
335            if i > 0 {
336                write!(f, ", ")?;
337            }
338            expr.fmt_sql(f)?;
339        }
340        write!(f, ")")
341    }
342
343    fn is_volatile_node(&self) -> bool {
344        self.fun.signature().volatility == Volatility::Volatile
345    }
346
347    fn placement(&self) -> ExpressionPlacement {
348        let arg_placements: Vec<_> =
349            self.args.iter().map(|arg| arg.placement()).collect();
350        self.fun.placement(&arg_placements)
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Column;
    use arrow::datatypes::Field;
    use datafusion_expr::{ScalarUDFImpl, Signature};
    use datafusion_physical_expr_common::physical_expr::is_volatile;

    /// Mock UDF whose signature, and therefore volatility, is chosen by the
    /// test that creates it.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockScalarUDF {
        signature: Signature,
    }

    impl ScalarUDFImpl for MockScalarUDF {
        fn name(&self) -> &str {
            "mock_function"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42))))
        }
    }

    /// Builds a mock UDF that takes one `Float32` argument and has the given
    /// volatility.
    fn mock_udf(volatility: Volatility) -> Arc<ScalarUDF> {
        let signature = Signature::uniform(1, vec![DataType::Float32], volatility);
        Arc::new(ScalarUDF::from(MockScalarUDF { signature }))
    }

    #[test]
    fn test_scalar_function_volatile_node() {
        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
        let args = vec![Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>];
        let config_options = Arc::new(ConfigOptions::new());

        // The volatile function is checked first, then the non-volatile one.
        let cases = [(Volatility::Volatile, true), (Volatility::Stable, false)];
        for (volatility, expect_volatile) in cases {
            let expr = ScalarFunctionExpr::try_new(
                mock_udf(volatility),
                args.clone(),
                &schema,
                Arc::clone(&config_options),
            )
            .unwrap();

            // Both the node itself and the tree-level helper must agree.
            assert_eq!(expr.is_volatile_node(), expect_volatile);
            let expr: Arc<dyn PhysicalExpr> = Arc::new(expr);
            assert_eq!(is_volatile(&expr), expect_volatile);
        }
    }
}