datafusion_physical_expr/
scalar_function.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Declaration of built-in (scalar) functions.
19//! This module contains built-in functions' enumeration and metadata.
20//!
21//! Generally, a function has:
22//! * a signature
//! * a return type that is a function of the incoming arguments' types
24//! * the computation, that must accept each valid signature
25//!
26//! * Signature: see `Signature`
27//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64.
28//!
//! This module also has a set of coercion rules to improve user experience: if an i32 argument
//! is passed to a function that supports f64, it is coerced to f64.
31
32use std::any::Any;
33use std::fmt::{self, Debug, Formatter};
34use std::hash::Hash;
35use std::sync::Arc;
36
37use crate::expressions::Literal;
38use crate::PhysicalExpr;
39
40use arrow::array::{Array, RecordBatch};
41use arrow::datatypes::{DataType, FieldRef, Schema};
42use datafusion_common::{internal_err, Result, ScalarValue};
43use datafusion_expr::interval_arithmetic::Interval;
44use datafusion_expr::sort_properties::ExprProperties;
45use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
46use datafusion_expr::{
47    expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
48};
49
50/// Physical expression of a scalar function
/// Physical expression of a scalar function
#[derive(Eq, PartialEq, Hash)]
pub struct ScalarFunctionExpr {
    /// The scalar UDF implementation that computes this expression
    fun: Arc<ScalarUDF>,
    /// Name used for `Display`/`fmt_sql` output (set from `fun.name()` by `try_new`)
    name: String,
    /// Physical expressions producing the function's input arguments
    args: Vec<Arc<dyn PhysicalExpr>>,
    /// Output field (data type + nullability) this expression produces
    return_field: FieldRef,
}
58
59impl Debug for ScalarFunctionExpr {
60    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
61        f.debug_struct("ScalarFunctionExpr")
62            .field("fun", &"<FUNC>")
63            .field("name", &self.name)
64            .field("args", &self.args)
65            .field("return_field", &self.return_field)
66            .finish()
67    }
68}
69
70impl ScalarFunctionExpr {
71    /// Create a new Scalar function
72    pub fn new(
73        name: &str,
74        fun: Arc<ScalarUDF>,
75        args: Vec<Arc<dyn PhysicalExpr>>,
76        return_field: FieldRef,
77    ) -> Self {
78        Self {
79            fun,
80            name: name.to_owned(),
81            args,
82            return_field,
83        }
84    }
85
86    /// Create a new Scalar function
87    pub fn try_new(
88        fun: Arc<ScalarUDF>,
89        args: Vec<Arc<dyn PhysicalExpr>>,
90        schema: &Schema,
91    ) -> Result<Self> {
92        let name = fun.name().to_string();
93        let arg_fields = args
94            .iter()
95            .map(|e| e.return_field(schema))
96            .collect::<Result<Vec<_>>>()?;
97
98        // verify that input data types is consistent with function's `TypeSignature`
99        let arg_types = arg_fields
100            .iter()
101            .map(|f| f.data_type().clone())
102            .collect::<Vec<_>>();
103        data_types_with_scalar_udf(&arg_types, &fun)?;
104
105        let arguments = args
106            .iter()
107            .map(|e| {
108                e.as_any()
109                    .downcast_ref::<Literal>()
110                    .map(|literal| literal.value())
111            })
112            .collect::<Vec<_>>();
113        let ret_args = ReturnFieldArgs {
114            arg_fields: &arg_fields,
115            scalar_arguments: &arguments,
116        };
117        let return_field = fun.return_field_from_args(ret_args)?;
118        Ok(Self {
119            fun,
120            name,
121            args,
122            return_field,
123        })
124    }
125
126    /// Get the scalar function implementation
127    pub fn fun(&self) -> &ScalarUDF {
128        &self.fun
129    }
130
131    /// The name for this expression
132    pub fn name(&self) -> &str {
133        &self.name
134    }
135
136    /// Input arguments
137    pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
138        &self.args
139    }
140
141    /// Data type produced by this expression
142    pub fn return_type(&self) -> &DataType {
143        self.return_field.data_type()
144    }
145
146    pub fn with_nullable(mut self, nullable: bool) -> Self {
147        self.return_field = self
148            .return_field
149            .as_ref()
150            .clone()
151            .with_nullable(nullable)
152            .into();
153        self
154    }
155
156    pub fn nullable(&self) -> bool {
157        self.return_field.is_nullable()
158    }
159}
160
161impl fmt::Display for ScalarFunctionExpr {
162    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
163        write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
164    }
165}
166
167impl PhysicalExpr for ScalarFunctionExpr {
168    /// Return a reference to Any that can be used for downcasting
169    fn as_any(&self) -> &dyn Any {
170        self
171    }
172
173    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
174        Ok(self.return_field.data_type().clone())
175    }
176
177    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
178        Ok(self.return_field.is_nullable())
179    }
180
181    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
182        let args = self
183            .args
184            .iter()
185            .map(|e| e.evaluate(batch))
186            .collect::<Result<Vec<_>>>()?;
187
188        let arg_fields = self
189            .args
190            .iter()
191            .map(|e| e.return_field(batch.schema_ref()))
192            .collect::<Result<Vec<_>>>()?;
193
194        let input_empty = args.is_empty();
195        let input_all_scalar = args
196            .iter()
197            .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
198
199        // evaluate the function
200        let output = self.fun.invoke_with_args(ScalarFunctionArgs {
201            args,
202            arg_fields,
203            number_rows: batch.num_rows(),
204            return_field: Arc::clone(&self.return_field),
205        })?;
206
207        if let ColumnarValue::Array(array) = &output {
208            if array.len() != batch.num_rows() {
209                // If the arguments are a non-empty slice of scalar values, we can assume that
210                // returning a one-element array is equivalent to returning a scalar.
211                let preserve_scalar =
212                    array.len() == 1 && !input_empty && input_all_scalar;
213                return if preserve_scalar {
214                    ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
215                } else {
216                    internal_err!("UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
217                            self.name, batch.num_rows(), array.len())
218                };
219            }
220        }
221        Ok(output)
222    }
223
224    fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
225        Ok(Arc::clone(&self.return_field))
226    }
227
228    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
229        self.args.iter().collect()
230    }
231
232    fn with_new_children(
233        self: Arc<Self>,
234        children: Vec<Arc<dyn PhysicalExpr>>,
235    ) -> Result<Arc<dyn PhysicalExpr>> {
236        Ok(Arc::new(ScalarFunctionExpr::new(
237            &self.name,
238            Arc::clone(&self.fun),
239            children,
240            Arc::clone(&self.return_field),
241        )))
242    }
243
244    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
245        self.fun.evaluate_bounds(children)
246    }
247
248    fn propagate_constraints(
249        &self,
250        interval: &Interval,
251        children: &[&Interval],
252    ) -> Result<Option<Vec<Interval>>> {
253        self.fun.propagate_constraints(interval, children)
254    }
255
256    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
257        let sort_properties = self.fun.output_ordering(children)?;
258        let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
259        let children_range = children
260            .iter()
261            .map(|props| &props.range)
262            .collect::<Vec<_>>();
263        let range = self.fun().evaluate_bounds(&children_range)?;
264
265        Ok(ExprProperties {
266            sort_properties,
267            range,
268            preserves_lex_ordering,
269        })
270    }
271
272    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
273        write!(f, "{}(", self.name)?;
274        for (i, expr) in self.args.iter().enumerate() {
275            if i > 0 {
276                write!(f, ", ")?;
277            }
278            expr.fmt_sql(f)?;
279        }
280        write!(f, ")")
281    }
282}