datafusion_expr/udf.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDF`]: Scalar User Defined Functions
19
20use crate::async_udf::AsyncScalarUDF;
21use crate::expr::schema_name_from_exprs_comma_separated_without_space;
22use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
23use crate::sort_properties::{ExprProperties, SortProperties};
24use crate::udf_eq::UdfEq;
25use crate::{ColumnarValue, Documentation, Expr, Signature};
26use arrow::datatypes::{DataType, Field, FieldRef};
27use datafusion_common::config::ConfigOptions;
28use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue};
29use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
30use datafusion_expr_common::interval_arithmetic::Interval;
31use std::any::Any;
32use std::cmp::Ordering;
33use std::fmt::Debug;
34use std::hash::{Hash, Hasher};
35use std::sync::Arc;
36
37/// Logical representation of a Scalar User Defined Function.
38///
39/// A scalar function produces a single row output for each row of input. This
40/// struct contains the information DataFusion needs to plan and invoke
41/// functions you supply such as name, type signature, return type, and actual
42/// implementation.
43///
44/// 1. For simple use cases, use [`create_udf`] (examples in [`simple_udf.rs`]).
45///
46/// 2. For advanced use cases, use [`ScalarUDFImpl`] which provides full API
47/// access (examples in [`advanced_udf.rs`]).
48///
49/// See [`Self::call`] to create an `Expr` which invokes a `ScalarUDF` with arguments.
50///
51/// # API Note
52///
53/// This is a separate struct from [`ScalarUDFImpl`] to maintain backwards
54/// compatibility with the older API.
55///
56/// [`create_udf`]: crate::expr_fn::create_udf
57/// [`simple_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
58/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
59#[derive(Debug, Clone)]
60pub struct ScalarUDF {
61 inner: Arc<dyn ScalarUDFImpl>,
62}
63
64impl PartialEq for ScalarUDF {
65 fn eq(&self, other: &Self) -> bool {
66 self.inner.dyn_eq(other.inner.as_any())
67 }
68}
69
70impl PartialOrd for ScalarUDF {
71 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
72 let mut cmp = self.name().cmp(other.name());
73 if cmp == Ordering::Equal {
74 cmp = self.signature().partial_cmp(other.signature())?;
75 }
76 if cmp == Ordering::Equal {
77 cmp = self.aliases().partial_cmp(other.aliases())?;
78 }
79 // Contract for PartialOrd and PartialEq consistency requires that
80 // a == b if and only if partial_cmp(a, b) == Some(Equal).
81 if cmp == Ordering::Equal && self != other {
82 // Functions may have other properties besides name and signature
83 // that differentiate two instances (e.g. type, or arbitrary parameters).
84 // We cannot return Some(Equal) in such case.
85 return None;
86 }
87 debug_assert!(
88 cmp == Ordering::Equal || self != other,
89 "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
90 The functions compare as equal, but they are not equal based on general properties that \
91 the PartialOrd implementation observes,",
92 self.name(), other.name()
93 );
94 Some(cmp)
95 }
96}
97
98impl Eq for ScalarUDF {}
99
100impl Hash for ScalarUDF {
101 fn hash<H: Hasher>(&self, state: &mut H) {
102 self.inner.dyn_hash(state)
103 }
104}
105
106impl ScalarUDF {
107 /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
108 ///
109 /// Note this is the same as using the `From` impl (`ScalarUDF::from`)
110 pub fn new_from_impl<F>(fun: F) -> ScalarUDF
111 where
112 F: ScalarUDFImpl + 'static,
113 {
114 Self::new_from_shared_impl(Arc::new(fun))
115 }
116
117 /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
118 pub fn new_from_shared_impl(fun: Arc<dyn ScalarUDFImpl>) -> ScalarUDF {
119 Self { inner: fun }
120 }
121
122 /// Return the underlying [`ScalarUDFImpl`] trait object for this function
123 pub fn inner(&self) -> &Arc<dyn ScalarUDFImpl> {
124 &self.inner
125 }
126
127 /// Adds additional names that can be used to invoke this function, in
128 /// addition to `name`
129 ///
130 /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly.
131 pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
132 Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases))
133 }
134
135 /// Returns a [`Expr`] logical expression to call this UDF with specified
136 /// arguments.
137 ///
138 /// This utility allows easily calling UDFs
139 ///
140 /// # Example
141 /// ```no_run
142 /// use datafusion_expr::{col, lit, ScalarUDF};
143 /// # fn my_udf() -> ScalarUDF { unimplemented!() }
144 /// let my_func: ScalarUDF = my_udf();
145 /// // Create an expr for `my_func(a, 12.3)`
146 /// let expr = my_func.call(vec![col("a"), lit(12.3)]);
147 /// ```
148 pub fn call(&self, args: Vec<Expr>) -> Expr {
149 Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
150 Arc::new(self.clone()),
151 args,
152 ))
153 }
154
155 /// Returns this function's name.
156 ///
157 /// See [`ScalarUDFImpl::name`] for more details.
158 pub fn name(&self) -> &str {
159 self.inner.name()
160 }
161
162 /// Returns this function's display_name.
163 ///
164 /// See [`ScalarUDFImpl::display_name`] for more details
165 #[deprecated(
166 since = "50.0.0",
167 note = "This method is unused and will be removed in a future release"
168 )]
169 pub fn display_name(&self, args: &[Expr]) -> Result<String> {
170 #[expect(deprecated)]
171 self.inner.display_name(args)
172 }
173
174 /// Returns this function's schema_name.
175 ///
176 /// See [`ScalarUDFImpl::schema_name`] for more details
177 pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
178 self.inner.schema_name(args)
179 }
180
181 /// Returns the aliases for this function.
182 ///
183 /// See [`ScalarUDF::with_aliases`] for more details
184 pub fn aliases(&self) -> &[String] {
185 self.inner.aliases()
186 }
187
188 /// Returns this function's [`Signature`] (what input types are accepted).
189 ///
190 /// See [`ScalarUDFImpl::signature`] for more details.
191 pub fn signature(&self) -> &Signature {
192 self.inner.signature()
193 }
194
195 /// The datatype this function returns given the input argument types.
196 /// This function is used when the input arguments are [`DataType`]s.
197 ///
198 /// # Notes
199 ///
200 /// If a function implement [`ScalarUDFImpl::return_field_from_args`],
201 /// its [`ScalarUDFImpl::return_type`] should raise an error.
202 ///
203 /// See [`ScalarUDFImpl::return_type`] for more details.
204 pub fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
205 self.inner.return_type(arg_types)
206 }
207
208 /// Return the datatype this function returns given the input argument types.
209 ///
210 /// See [`ScalarUDFImpl::return_field_from_args`] for more details.
211 pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
212 self.inner.return_field_from_args(args)
213 }
214
215 /// Do the function rewrite
216 ///
217 /// See [`ScalarUDFImpl::simplify`] for more details.
218 pub fn simplify(
219 &self,
220 args: Vec<Expr>,
221 info: &dyn SimplifyInfo,
222 ) -> Result<ExprSimplifyResult> {
223 self.inner.simplify(args, info)
224 }
225
226 #[deprecated(since = "50.0.0", note = "Use `return_field_from_args` instead.")]
227 pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
228 #[allow(deprecated)]
229 self.inner.is_nullable(args, schema)
230 }
231
232 /// Invoke the function on `args`, returning the appropriate result.
233 ///
234 /// See [`ScalarUDFImpl::invoke_with_args`] for details.
235 pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
236 #[cfg(debug_assertions)]
237 let return_field = Arc::clone(&args.return_field);
238 let result = self.inner.invoke_with_args(args)?;
239 // Maybe this could be enabled always?
240 // This doesn't use debug_assert!, but it's meant to run anywhere except on production. It's same in spirit, thus conditioning on debug_assertions.
241 #[cfg(debug_assertions)]
242 {
243 if &result.data_type() != return_field.data_type() {
244 return datafusion_common::internal_err!("Function '{}' returned value of type '{:?}' while the following type was promised at planning time and expected: '{:?}'",
245 self.name(),
246 result.data_type(),
247 return_field.data_type()
248 );
249 }
250 // TODO verify return data is non-null when it was promised to be?
251 }
252 Ok(result)
253 }
254
255 /// Get the circuits of inner implementation
256 pub fn short_circuits(&self) -> bool {
257 self.inner.short_circuits()
258 }
259
260 /// Computes the output interval for a [`ScalarUDF`], given the input
261 /// intervals.
262 ///
263 /// # Parameters
264 ///
265 /// * `inputs` are the intervals for the inputs (children) of this function.
266 ///
267 /// # Example
268 ///
269 /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
270 /// then the output interval would be `[0, 3]`.
271 pub fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
272 self.inner.evaluate_bounds(inputs)
273 }
274
275 /// Updates bounds for child expressions, given a known interval for this
276 /// function. This is used to propagate constraints down through an expression
277 /// tree.
278 ///
279 /// # Parameters
280 ///
281 /// * `interval` is the currently known interval for this function.
282 /// * `inputs` are the current intervals for the inputs (children) of this function.
283 ///
284 /// # Returns
285 ///
286 /// A `Vec` of new intervals for the children, in order.
287 ///
288 /// If constraint propagation reveals an infeasibility for any child, returns
289 /// [`None`]. If none of the children intervals change as a result of
290 /// propagation, may return an empty vector instead of cloning `children`.
291 /// This is the default (and conservative) return value.
292 ///
293 /// # Example
294 ///
295 /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
296 /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
297 pub fn propagate_constraints(
298 &self,
299 interval: &Interval,
300 inputs: &[&Interval],
301 ) -> Result<Option<Vec<Interval>>> {
302 self.inner.propagate_constraints(interval, inputs)
303 }
304
305 /// Calculates the [`SortProperties`] of this function based on its
306 /// children's properties.
307 pub fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
308 self.inner.output_ordering(inputs)
309 }
310
311 pub fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
312 self.inner.preserves_lex_ordering(inputs)
313 }
314
315 /// See [`ScalarUDFImpl::coerce_types`] for more details.
316 pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
317 self.inner.coerce_types(arg_types)
318 }
319
320 /// Returns the documentation for this Scalar UDF.
321 ///
322 /// Documentation can be accessed programmatically as well as
323 /// generating publicly facing documentation.
324 pub fn documentation(&self) -> Option<&Documentation> {
325 self.inner.documentation()
326 }
327
328 /// Return true if this function is an async function
329 pub fn as_async(&self) -> Option<&AsyncScalarUDF> {
330 self.inner().as_any().downcast_ref::<AsyncScalarUDF>()
331 }
332}
333
334impl<F> From<F> for ScalarUDF
335where
336 F: ScalarUDFImpl + 'static,
337{
338 fn from(fun: F) -> Self {
339 Self::new_from_impl(fun)
340 }
341}
342
343/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a
344/// scalar function.
345#[derive(Debug, Clone)]
346pub struct ScalarFunctionArgs {
347 /// The evaluated arguments to the function
348 pub args: Vec<ColumnarValue>,
349 /// Field associated with each arg, if it exists
350 pub arg_fields: Vec<FieldRef>,
351 /// The number of rows in record batch being evaluated
352 pub number_rows: usize,
353 /// The return field of the scalar function returned (from `return_type`
354 /// or `return_field_from_args`) when creating the physical expression
355 /// from the logical expression
356 pub return_field: FieldRef,
357 /// The config options at execution time
358 pub config_options: Arc<ConfigOptions>,
359}
360
361impl ScalarFunctionArgs {
362 /// The return type of the function. See [`Self::return_field`] for more
363 /// details.
364 pub fn return_type(&self) -> &DataType {
365 self.return_field.data_type()
366 }
367}
368
369/// Information about arguments passed to the function
370///
371/// This structure contains metadata about how the function was called
372/// such as the type of the arguments, any scalar arguments and if the
373/// arguments can (ever) be null
374///
375/// See [`ScalarUDFImpl::return_field_from_args`] for more information
376#[derive(Debug)]
377pub struct ReturnFieldArgs<'a> {
378 /// The data types of the arguments to the function
379 pub arg_fields: &'a [FieldRef],
380 /// Is argument `i` to the function a scalar (constant)?
381 ///
382 /// If the argument `i` is not a scalar, it will be None
383 ///
384 /// For example, if a function is called like `my_function(column_a, 5)`
385 /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]`
386 pub scalar_arguments: &'a [Option<&'a ScalarValue>],
387}
388
389/// Trait for implementing user defined scalar functions.
390///
391/// This trait exposes the full API for implementing user defined functions and
392/// can be used to implement any function.
393///
394/// See [`advanced_udf.rs`] for a full example with complete implementation and
395/// [`ScalarUDF`] for other available options.
396///
397/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
398///
399/// # Basic Example
400/// ```
401/// # use std::any::Any;
402/// # use std::sync::LazyLock;
403/// # use arrow::datatypes::DataType;
404/// # use datafusion_common::{DataFusionError, plan_err, Result};
405/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
406/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
407/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
408/// /// This struct for a simple UDF that adds one to an int32
409/// #[derive(Debug, PartialEq, Eq, Hash)]
410/// struct AddOne {
411/// signature: Signature,
412/// }
413///
414/// impl AddOne {
415/// fn new() -> Self {
416/// Self {
417/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable),
418/// }
419/// }
420/// }
421///
422/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
423/// Documentation::builder(DOC_SECTION_MATH, "Add one to an int32", "add_one(2)")
424/// .with_argument("arg1", "The int32 number to add one to")
425/// .build()
426/// });
427///
428/// fn get_doc() -> &'static Documentation {
429/// &DOCUMENTATION
430/// }
431///
432/// /// Implement the ScalarUDFImpl trait for AddOne
433/// impl ScalarUDFImpl for AddOne {
434/// fn as_any(&self) -> &dyn Any { self }
435/// fn name(&self) -> &str { "add_one" }
436/// fn signature(&self) -> &Signature { &self.signature }
437/// fn return_type(&self, args: &[DataType]) -> Result<DataType> {
438/// if !matches!(args.get(0), Some(&DataType::Int32)) {
439/// return plan_err!("add_one only accepts Int32 arguments");
440/// }
441/// Ok(DataType::Int32)
442/// }
443/// // The actual implementation would add one to the argument
444/// fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
445/// unimplemented!()
446/// }
447/// fn documentation(&self) -> Option<&Documentation> {
448/// Some(get_doc())
449/// }
450/// }
451///
452/// // Create a new ScalarUDF from the implementation
453/// let add_one = ScalarUDF::from(AddOne::new());
454///
455/// // Call the function `add_one(col)`
456/// let expr = add_one.call(vec![col("a")]);
457/// ```
458pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
459 /// Returns this object as an [`Any`] trait object
460 fn as_any(&self) -> &dyn Any;
461
462 /// Returns this function's name
463 fn name(&self) -> &str;
464
465 /// Returns any aliases (alternate names) for this function.
466 ///
467 /// Aliases can be used to invoke the same function using different names.
468 /// For example in some databases `now()` and `current_timestamp()` are
469 /// aliases for the same function. This behavior can be obtained by
470 /// returning `current_timestamp` as an alias for the `now` function.
471 ///
472 /// Note: `aliases` should only include names other than [`Self::name`].
473 /// Defaults to `[]` (no aliases)
474 fn aliases(&self) -> &[String] {
475 &[]
476 }
477
478 /// Returns the user-defined display name of function, given the arguments
479 ///
480 /// This can be used to customize the output column name generated by this
481 /// function.
482 ///
483 /// Defaults to `name(args[0], args[1], ...)`
484 #[deprecated(
485 since = "50.0.0",
486 note = "This method is unused and will be removed in a future release"
487 )]
488 fn display_name(&self, args: &[Expr]) -> Result<String> {
489 let names: Vec<String> = args.iter().map(ToString::to_string).collect();
490 // TODO: join with ", " to standardize the formatting of Vec<Expr>, <https://github.com/apache/datafusion/issues/10364>
491 Ok(format!("{}({})", self.name(), names.join(",")))
492 }
493
494 /// Returns the name of the column this expression would create
495 ///
496 /// See [`Expr::schema_name`] for details
497 fn schema_name(&self, args: &[Expr]) -> Result<String> {
498 Ok(format!(
499 "{}({})",
500 self.name(),
501 schema_name_from_exprs_comma_separated_without_space(args)?
502 ))
503 }
504
505 /// Returns a [`Signature`] describing the argument types for which this
506 /// function has an implementation, and the function's [`Volatility`].
507 ///
508 /// See [`Signature`] for more details on argument type handling
509 /// and [`Self::return_type`] for computing the return type.
510 ///
511 /// [`Volatility`]: datafusion_expr_common::signature::Volatility
512 fn signature(&self) -> &Signature;
513
514 /// [`DataType`] returned by this function, given the types of the
515 /// arguments.
516 ///
517 /// # Arguments
518 ///
519 /// `arg_types` Data types of the arguments. The implementation of
520 /// `return_type` can assume that some other part of the code has coerced
521 /// the actual argument types to match [`Self::signature`].
522 ///
523 /// # Notes
524 ///
525 /// If you provide an implementation for [`Self::return_field_from_args`],
526 /// DataFusion will not call `return_type` (this function). While it is
527 /// valid to to put [`unimplemented!()`] or [`unreachable!()`], it is
528 /// recommended to return [`DataFusionError::Internal`] instead, which
529 /// reduces the severity of symptoms if bugs occur (an error rather than a
530 /// panic).
531 ///
532 /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal
533 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
534
535 /// Create a new instance of this function with updated configuration.
536 ///
537 /// This method is called when configuration options change at runtime
538 /// (e.g., via `SET` statements) to allow functions that depend on
539 /// configuration to update themselves accordingly.
540 ///
541 /// Note the current [`ConfigOptions`] are also passed to [`Self::invoke_with_args`] so
542 /// this API is not needed for functions where the values may
543 /// depend on the current options.
544 ///
545 /// This API is useful for functions where the return
546 /// **type** depends on the configuration options, such as the `now()` function
547 /// which depends on the current timezone.
548 ///
549 /// # Arguments
550 ///
551 /// * `config` - The updated configuration options
552 ///
553 /// # Returns
554 ///
555 /// * `Some(ScalarUDF)` - A new instance of this function configured with the new settings
556 /// * `None` - If this function does not change with new configuration settings (the default)
557 fn with_updated_config(&self, _config: &ConfigOptions) -> Option<ScalarUDF> {
558 None
559 }
560
561 /// What type will be returned by this function, given the arguments?
562 ///
563 /// By default, this function calls [`Self::return_type`] with the
564 /// types of each argument.
565 ///
566 /// # Notes
567 ///
568 /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient,
569 /// as the result type is typically a deterministic function of the input types
570 /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly
571 /// is generally unnecessary unless the return type depends on runtime values.
572 ///
573 /// This function can be used for more advanced cases such as:
574 ///
575 /// 1. specifying nullability
576 /// 2. return types based on the **values** of the arguments (rather than
577 /// their **types**.
578 ///
579 /// # Example creating `Field`
580 ///
581 /// Note the name of the [`Field`] is ignored, except for structured types such as
582 /// `DataType::Struct`.
583 ///
584 /// ```rust
585 /// # use std::sync::Arc;
586 /// # use arrow::datatypes::{DataType, Field, FieldRef};
587 /// # use datafusion_common::Result;
588 /// # use datafusion_expr::ReturnFieldArgs;
589 /// # struct Example{}
590 /// # impl Example {
591 /// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
592 /// // report output is only nullable if any one of the arguments are nullable
593 /// let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
594 /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true));
595 /// Ok(field)
596 /// }
597 /// # }
598 /// ```
599 ///
600 /// # Output Type based on Values
601 ///
602 /// For example, the following two function calls get the same argument
603 /// types (something and a `Utf8` string) but return different types based
604 /// on the value of the second argument:
605 ///
606 /// * `arrow_cast(x, 'Int16')` --> `Int16`
607 /// * `arrow_cast(x, 'Float32')` --> `Float32`
608 ///
609 /// # Requirements
610 ///
611 /// This function **must** consistently return the same type for the same
612 /// logical input even if the input is simplified (e.g. it must return the same
613 /// value for `('foo' | 'bar')` as it does for ('foobar').
614 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
615 let data_types = args
616 .arg_fields
617 .iter()
618 .map(|f| f.data_type())
619 .cloned()
620 .collect::<Vec<_>>();
621 let return_type = self.return_type(&data_types)?;
622 Ok(Arc::new(Field::new(self.name(), return_type, true)))
623 }
624
625 #[deprecated(
626 since = "45.0.0",
627 note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error"
628 )]
629 fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
630 true
631 }
632
633 /// Invoke the function returning the appropriate result.
634 ///
635 /// # Performance
636 ///
637 /// For the best performance, the implementations should handle the common case
638 /// when one or more of their arguments are constant values (aka
639 /// [`ColumnarValue::Scalar`]).
640 ///
641 /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
642 /// to arrays, which will likely be simpler code, but be slower.
643 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue>;
644
645 /// Optionally apply per-UDF simplification / rewrite rules.
646 ///
647 /// This can be used to apply function specific simplification rules during
648 /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
649 /// implementation does nothing.
650 ///
651 /// Note that DataFusion handles simplifying arguments and "constant
652 /// folding" (replacing a function call with constant arguments such as
653 /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
654 /// optimizations manually for specific UDFs.
655 ///
656 /// # Arguments
657 /// * `args`: The arguments of the function
658 /// * `info`: The necessary information for simplification
659 ///
660 /// # Returns
661 /// [`ExprSimplifyResult`] indicating the result of the simplification NOTE
662 /// if the function cannot be simplified, the arguments *MUST* be returned
663 /// unmodified
664 ///
665 /// # Notes
666 ///
667 /// The returned expression must have the same schema as the original
668 /// expression, including both the data type and nullability. For example,
669 /// if the original expression is nullable, the returned expression must
670 /// also be nullable, otherwise it may lead to schema verification errors
671 /// later in query planning.
672 fn simplify(
673 &self,
674 args: Vec<Expr>,
675 _info: &dyn SimplifyInfo,
676 ) -> Result<ExprSimplifyResult> {
677 Ok(ExprSimplifyResult::Original(args))
678 }
679
680 /// Returns true if some of this `exprs` subexpressions may not be evaluated
681 /// and thus any side effects (like divide by zero) may not be encountered.
682 ///
683 /// Setting this to true prevents certain optimizations such as common
684 /// subexpression elimination
685 fn short_circuits(&self) -> bool {
686 false
687 }
688
689 /// Computes the output [`Interval`] for a [`ScalarUDFImpl`], given the input
690 /// intervals.
691 ///
692 /// # Parameters
693 ///
694 /// * `children` are the intervals for the children (inputs) of this function.
695 ///
696 /// # Example
697 ///
698 /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
699 /// then the output interval would be `[0, 3]`.
700 fn evaluate_bounds(&self, _input: &[&Interval]) -> Result<Interval> {
701 // We cannot assume the input datatype is the same of output type.
702 Interval::make_unbounded(&DataType::Null)
703 }
704
705 /// Updates bounds for child expressions, given a known [`Interval`]s for this
706 /// function.
707 ///
708 /// This function is used to propagate constraints down through an
709 /// expression tree.
710 ///
711 /// # Parameters
712 ///
713 /// * `interval` is the currently known interval for this function.
714 /// * `inputs` are the current intervals for the inputs (children) of this function.
715 ///
716 /// # Returns
717 ///
718 /// A `Vec` of new intervals for the children, in order.
719 ///
720 /// If constraint propagation reveals an infeasibility for any child, returns
721 /// [`None`]. If none of the children intervals change as a result of
722 /// propagation, may return an empty vector instead of cloning `children`.
723 /// This is the default (and conservative) return value.
724 ///
725 /// # Example
726 ///
727 /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
728 /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
729 fn propagate_constraints(
730 &self,
731 _interval: &Interval,
732 _inputs: &[&Interval],
733 ) -> Result<Option<Vec<Interval>>> {
734 Ok(Some(vec![]))
735 }
736
737 /// Calculates the [`SortProperties`] of this function based on its children's properties.
738 fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
739 if !self.preserves_lex_ordering(inputs)? {
740 return Ok(SortProperties::Unordered);
741 }
742
743 let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else {
744 return Ok(SortProperties::Singleton);
745 };
746
747 if inputs
748 .iter()
749 .skip(1)
750 .all(|input| &input.sort_properties == first_order)
751 {
752 Ok(*first_order)
753 } else {
754 Ok(SortProperties::Unordered)
755 }
756 }
757
758 /// Returns true if the function preserves lexicographical ordering based on
759 /// the input ordering.
760 ///
761 /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not.
762 fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
763 Ok(false)
764 }
765
766 /// Coerce arguments of a function call to types that the function can evaluate.
767 ///
768 /// This function is only called if [`ScalarUDFImpl::signature`] returns
769 /// [`crate::TypeSignature::UserDefined`]. Most UDFs should return one of
770 /// the other variants of [`TypeSignature`] which handle common cases.
771 ///
772 /// See the [type coercion module](crate::type_coercion)
773 /// documentation for more details on type coercion
774 ///
775 /// [`TypeSignature`]: crate::TypeSignature
776 ///
777 /// For example, if your function requires a floating point arguments, but the user calls
778 /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]`
779 /// to ensure the argument is converted to `1::double`
780 ///
781 /// # Parameters
782 /// * `arg_types`: The argument types of the arguments this function with
783 ///
784 /// # Return value
785 /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
786 /// arguments to these specific types.
787 fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
788 not_impl_err!("Function {} does not implement coerce_types", self.name())
789 }
790
791 /// Returns the documentation for this Scalar UDF.
792 ///
793 /// Documentation can be accessed programmatically as well as generating
794 /// publicly facing documentation.
795 fn documentation(&self) -> Option<&Documentation> {
796 None
797 }
798}
799
800/// ScalarUDF that adds an alias to the underlying function. It is better to
801/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
802#[derive(Debug, PartialEq, Eq, Hash)]
803struct AliasedScalarUDFImpl {
804 inner: UdfEq<Arc<dyn ScalarUDFImpl>>,
805 aliases: Vec<String>,
806}
807
808impl AliasedScalarUDFImpl {
809 pub fn new(
810 inner: Arc<dyn ScalarUDFImpl>,
811 new_aliases: impl IntoIterator<Item = &'static str>,
812 ) -> Self {
813 let mut aliases = inner.aliases().to_vec();
814 aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
815 Self {
816 inner: inner.into(),
817 aliases,
818 }
819 }
820}
821
822#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method
823impl ScalarUDFImpl for AliasedScalarUDFImpl {
824 fn as_any(&self) -> &dyn Any {
825 self
826 }
827
828 fn name(&self) -> &str {
829 self.inner.name()
830 }
831
832 fn display_name(&self, args: &[Expr]) -> Result<String> {
833 #[expect(deprecated)]
834 self.inner.display_name(args)
835 }
836
837 fn schema_name(&self, args: &[Expr]) -> Result<String> {
838 self.inner.schema_name(args)
839 }
840
841 fn signature(&self) -> &Signature {
842 self.inner.signature()
843 }
844
845 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
846 self.inner.return_type(arg_types)
847 }
848
849 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
850 self.inner.return_field_from_args(args)
851 }
852
853 fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
854 #[allow(deprecated)]
855 self.inner.is_nullable(args, schema)
856 }
857
858 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
859 self.inner.invoke_with_args(args)
860 }
861
862 fn with_updated_config(&self, _config: &ConfigOptions) -> Option<ScalarUDF> {
863 None
864 }
865
866 fn aliases(&self) -> &[String] {
867 &self.aliases
868 }
869
870 fn simplify(
871 &self,
872 args: Vec<Expr>,
873 info: &dyn SimplifyInfo,
874 ) -> Result<ExprSimplifyResult> {
875 self.inner.simplify(args, info)
876 }
877
878 fn short_circuits(&self) -> bool {
879 self.inner.short_circuits()
880 }
881
882 fn evaluate_bounds(&self, input: &[&Interval]) -> Result<Interval> {
883 self.inner.evaluate_bounds(input)
884 }
885
886 fn propagate_constraints(
887 &self,
888 interval: &Interval,
889 inputs: &[&Interval],
890 ) -> Result<Option<Vec<Interval>>> {
891 self.inner.propagate_constraints(interval, inputs)
892 }
893
894 fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
895 self.inner.output_ordering(inputs)
896 }
897
898 fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
899 self.inner.preserves_lex_ordering(inputs)
900 }
901
902 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
903 self.inner.coerce_types(arg_types)
904 }
905
906 fn documentation(&self) -> Option<&Documentation> {
907 self.inner.documentation()
908 }
909}
910
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr_common::signature::Volatility;
    use std::hash::DefaultHasher;

    /// Minimal UDF whose identity depends on both its name and an extra
    /// parameter, used to exercise the equality, hashing and ordering impls.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDFImpl {
        name: &'static str,
        param: &'static str,
        signature: Signature,
    }

    impl ScalarUDFImpl for TestScalarUDFImpl {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    // PartialEq must agree with both Hash and PartialOrd, so all three are
    // checked side by side.
    #[test]
    fn test_partial_eq_hash_and_partial_ord() {
        // A parameterized function
        let foo_a = make_udf("foo", "a");

        // An independently constructed, otherwise identical instance
        let foo_a_again = make_udf("foo", "a");
        assert_eq!(foo_a, foo_a_again);
        assert_eq!(hash_of(&foo_a), hash_of(&foo_a_again));
        assert_eq!(foo_a.partial_cmp(&foo_a_again), Some(Ordering::Equal));

        // Same name, different parameter
        let foo_b = make_udf("foo", "b");
        assert_ne!(foo_a, foo_b);
        assert_ne!(hash_of(&foo_a), hash_of(&foo_b)); // hash can collide for different values but does not collide in this test
        assert_eq!(foo_a.partial_cmp(&foo_b), None);

        // Different name, same parameter
        let other_a = make_udf("other", "a");
        assert_ne!(foo_a, other_a);
        assert_ne!(hash_of(&foo_a), hash_of(&other_a)); // hash can collide for different values but does not collide in this test
        assert_eq!(foo_a.partial_cmp(&other_a), Some(Ordering::Less));

        // Both name and parameter differ
        assert_ne!(foo_b, other_a);
        assert_ne!(hash_of(&foo_b), hash_of(&other_a)); // hash can collide for different values but does not collide in this test
        assert_eq!(foo_b.partial_cmp(&other_a), Some(Ordering::Less));
    }

    /// Builds a single-argument test UDF with the given name and parameter.
    fn make_udf(name: &'static str, param: &'static str) -> ScalarUDF {
        let signature = Signature::any(1, Volatility::Immutable);
        ScalarUDF::from(TestScalarUDFImpl {
            name,
            param,
            signature,
        })
    }

    /// Hashes `value` with a freshly created `DefaultHasher`.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }
}