// datafusion_physical_expr/scalar_function.rs

use std::any::Any;
33use std::fmt::{self, Debug, Formatter};
34use std::hash::Hash;
35use std::sync::Arc;
36
37use crate::expressions::Literal;
38use crate::PhysicalExpr;
39
40use arrow::array::{Array, RecordBatch};
41use arrow::datatypes::{DataType, FieldRef, Schema};
42use datafusion_common::{internal_err, Result, ScalarValue};
43use datafusion_expr::interval_arithmetic::Interval;
44use datafusion_expr::sort_properties::ExprProperties;
45use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
46use datafusion_expr::{
47 expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
48};
49
/// Physical expression representing a scalar function call.
///
/// Bundles the [`ScalarUDF`] implementation together with the argument
/// expressions to evaluate against each input batch and the resolved
/// return field (data type + nullability) of the call.
#[derive(Eq, PartialEq, Hash)]
pub struct ScalarFunctionExpr {
    // The scalar UDF implementation to invoke
    fun: Arc<ScalarUDF>,
    // Display name of the function (used by Display/Debug and error messages)
    name: String,
    // Argument expressions, evaluated against the input batch
    args: Vec<Arc<dyn PhysicalExpr>>,
    // Field describing the value this expression produces
    return_field: FieldRef,
}
58
// Manual `Debug` impl: the UDF itself is rendered as the literal
// placeholder string "<FUNC>" instead of its own debug representation;
// all other fields use their derived/standard `Debug` output.
impl Debug for ScalarFunctionExpr {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.debug_struct("ScalarFunctionExpr")
            .field("fun", &"<FUNC>")
            .field("name", &self.name)
            .field("args", &self.args)
            .field("return_field", &self.return_field)
            .finish()
    }
}
69
70impl ScalarFunctionExpr {
71 pub fn new(
73 name: &str,
74 fun: Arc<ScalarUDF>,
75 args: Vec<Arc<dyn PhysicalExpr>>,
76 return_field: FieldRef,
77 ) -> Self {
78 Self {
79 fun,
80 name: name.to_owned(),
81 args,
82 return_field,
83 }
84 }
85
86 pub fn try_new(
88 fun: Arc<ScalarUDF>,
89 args: Vec<Arc<dyn PhysicalExpr>>,
90 schema: &Schema,
91 ) -> Result<Self> {
92 let name = fun.name().to_string();
93 let arg_fields = args
94 .iter()
95 .map(|e| e.return_field(schema))
96 .collect::<Result<Vec<_>>>()?;
97
98 let arg_types = arg_fields
100 .iter()
101 .map(|f| f.data_type().clone())
102 .collect::<Vec<_>>();
103 data_types_with_scalar_udf(&arg_types, &fun)?;
104
105 let arguments = args
106 .iter()
107 .map(|e| {
108 e.as_any()
109 .downcast_ref::<Literal>()
110 .map(|literal| literal.value())
111 })
112 .collect::<Vec<_>>();
113 let ret_args = ReturnFieldArgs {
114 arg_fields: &arg_fields,
115 scalar_arguments: &arguments,
116 };
117 let return_field = fun.return_field_from_args(ret_args)?;
118 Ok(Self {
119 fun,
120 name,
121 args,
122 return_field,
123 })
124 }
125
126 pub fn fun(&self) -> &ScalarUDF {
128 &self.fun
129 }
130
131 pub fn name(&self) -> &str {
133 &self.name
134 }
135
136 pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
138 &self.args
139 }
140
141 pub fn return_type(&self) -> &DataType {
143 self.return_field.data_type()
144 }
145
146 pub fn with_nullable(mut self, nullable: bool) -> Self {
147 self.return_field = self
148 .return_field
149 .as_ref()
150 .clone()
151 .with_nullable(nullable)
152 .into();
153 self
154 }
155
156 pub fn nullable(&self) -> bool {
157 self.return_field.is_nullable()
158 }
159}
160
161impl fmt::Display for ScalarFunctionExpr {
162 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
163 write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
164 }
165}
166
167impl PhysicalExpr for ScalarFunctionExpr {
168 fn as_any(&self) -> &dyn Any {
170 self
171 }
172
173 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
174 Ok(self.return_field.data_type().clone())
175 }
176
177 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
178 Ok(self.return_field.is_nullable())
179 }
180
181 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
182 let args = self
183 .args
184 .iter()
185 .map(|e| e.evaluate(batch))
186 .collect::<Result<Vec<_>>>()?;
187
188 let arg_fields = self
189 .args
190 .iter()
191 .map(|e| e.return_field(batch.schema_ref()))
192 .collect::<Result<Vec<_>>>()?;
193
194 let input_empty = args.is_empty();
195 let input_all_scalar = args
196 .iter()
197 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
198
199 let output = self.fun.invoke_with_args(ScalarFunctionArgs {
201 args,
202 arg_fields,
203 number_rows: batch.num_rows(),
204 return_field: Arc::clone(&self.return_field),
205 })?;
206
207 if let ColumnarValue::Array(array) = &output {
208 if array.len() != batch.num_rows() {
209 let preserve_scalar =
212 array.len() == 1 && !input_empty && input_all_scalar;
213 return if preserve_scalar {
214 ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
215 } else {
216 internal_err!("UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
217 self.name, batch.num_rows(), array.len())
218 };
219 }
220 }
221 Ok(output)
222 }
223
224 fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
225 Ok(Arc::clone(&self.return_field))
226 }
227
228 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
229 self.args.iter().collect()
230 }
231
232 fn with_new_children(
233 self: Arc<Self>,
234 children: Vec<Arc<dyn PhysicalExpr>>,
235 ) -> Result<Arc<dyn PhysicalExpr>> {
236 Ok(Arc::new(ScalarFunctionExpr::new(
237 &self.name,
238 Arc::clone(&self.fun),
239 children,
240 Arc::clone(&self.return_field),
241 )))
242 }
243
244 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
245 self.fun.evaluate_bounds(children)
246 }
247
248 fn propagate_constraints(
249 &self,
250 interval: &Interval,
251 children: &[&Interval],
252 ) -> Result<Option<Vec<Interval>>> {
253 self.fun.propagate_constraints(interval, children)
254 }
255
256 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
257 let sort_properties = self.fun.output_ordering(children)?;
258 let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
259 let children_range = children
260 .iter()
261 .map(|props| &props.range)
262 .collect::<Vec<_>>();
263 let range = self.fun().evaluate_bounds(&children_range)?;
264
265 Ok(ExprProperties {
266 sort_properties,
267 range,
268 preserves_lex_ordering,
269 })
270 }
271
272 fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
273 write!(f, "{}(", self.name)?;
274 for (i, expr) in self.args.iter().enumerate() {
275 if i > 0 {
276 write!(f, ", ")?;
277 }
278 expr.fmt_sql(f)?;
279 }
280 write!(f, ")")
281 }
282}