Skip to main content

datafusion_expr/
udf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDF`]: Scalar User Defined Functions
19
20use crate::async_udf::AsyncScalarUDF;
21use crate::expr::schema_name_from_exprs_comma_separated_without_space;
22use crate::preimage::PreimageResult;
23use crate::simplify::{ExprSimplifyResult, SimplifyContext};
24use crate::sort_properties::{ExprProperties, SortProperties};
25use crate::udf_eq::UdfEq;
26use crate::{ColumnarValue, Documentation, Expr, Signature};
27use arrow::datatypes::{DataType, Field, FieldRef};
28#[cfg(debug_assertions)]
29use datafusion_common::assert_or_internal_err;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::{ExprSchema, Result, ScalarValue, not_impl_err};
32use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
33use datafusion_expr_common::interval_arithmetic::Interval;
34use datafusion_expr_common::placement::ExpressionPlacement;
35use std::any::Any;
36use std::cmp::Ordering;
37use std::fmt::Debug;
38use std::hash::{Hash, Hasher};
39use std::sync::Arc;
40
41/// Logical representation of a Scalar User Defined Function.
42///
43/// A scalar function produces a single row output for each row of input. This
44/// struct contains the information DataFusion needs to plan and invoke
45/// functions you supply such as name, type signature, return type, and actual
46/// implementation.
47///
48/// 1. For simple use cases, use [`create_udf`] (examples in [`simple_udf.rs`]).
49///
50/// 2. For advanced use cases, use [`ScalarUDFImpl`] which provides full API
51///    access (examples in  [`advanced_udf.rs`]).
52///
53/// See [`Self::call`] to create an `Expr` which invokes a `ScalarUDF` with arguments.
54///
55/// # API Note
56///
57/// This is a separate struct from [`ScalarUDFImpl`] to maintain backwards
58/// compatibility with the older API.
59///
60/// [`create_udf`]: crate::expr_fn::create_udf
61/// [`simple_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/simple_udf.rs
62/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udf.rs
63#[derive(Debug, Clone)]
64pub struct ScalarUDF {
65    inner: Arc<dyn ScalarUDFImpl>,
66}
67
68impl PartialEq for ScalarUDF {
69    fn eq(&self, other: &Self) -> bool {
70        self.inner.dyn_eq(other.inner.as_any())
71    }
72}
73
74impl PartialOrd for ScalarUDF {
75    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
76        let mut cmp = self.name().cmp(other.name());
77        if cmp == Ordering::Equal {
78            cmp = self.signature().partial_cmp(other.signature())?;
79        }
80        if cmp == Ordering::Equal {
81            cmp = self.aliases().partial_cmp(other.aliases())?;
82        }
83        // Contract for PartialOrd and PartialEq consistency requires that
84        // a == b if and only if partial_cmp(a, b) == Some(Equal).
85        if cmp == Ordering::Equal && self != other {
86            // Functions may have other properties besides name and signature
87            // that differentiate two instances (e.g. type, or arbitrary parameters).
88            // We cannot return Some(Equal) in such case.
89            return None;
90        }
91        debug_assert!(
92            cmp == Ordering::Equal || self != other,
93            "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
94            The functions compare as equal, but they are not equal based on general properties that \
95            the PartialOrd implementation observes,",
96            self.name(),
97            other.name()
98        );
99        Some(cmp)
100    }
101}
102
103impl Eq for ScalarUDF {}
104
105impl Hash for ScalarUDF {
106    fn hash<H: Hasher>(&self, state: &mut H) {
107        self.inner.dyn_hash(state)
108    }
109}
110
111impl ScalarUDF {
112    /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
113    ///
114    /// Note this is the same as using the `From` impl (`ScalarUDF::from`)
115    pub fn new_from_impl<F>(fun: F) -> ScalarUDF
116    where
117        F: ScalarUDFImpl + 'static,
118    {
119        Self::new_from_shared_impl(Arc::new(fun))
120    }
121
122    /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
123    pub fn new_from_shared_impl(fun: Arc<dyn ScalarUDFImpl>) -> ScalarUDF {
124        Self { inner: fun }
125    }
126
127    /// Return the underlying [`ScalarUDFImpl`] trait object for this function
128    pub fn inner(&self) -> &Arc<dyn ScalarUDFImpl> {
129        &self.inner
130    }
131
132    /// Adds additional names that can be used to invoke this function, in
133    /// addition to `name`
134    ///
135    /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly.
136    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
137        Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases))
138    }
139
140    /// Returns a [`Expr`] logical expression to call this UDF with specified
141    /// arguments.
142    ///
143    /// This utility allows easily calling UDFs
144    ///
145    /// # Example
146    /// ```no_run
147    /// use datafusion_expr::{col, lit, ScalarUDF};
148    /// # fn my_udf() -> ScalarUDF { unimplemented!() }
149    /// let my_func: ScalarUDF = my_udf();
150    /// // Create an expr for `my_func(a, 12.3)`
151    /// let expr = my_func.call(vec![col("a"), lit(12.3)]);
152    /// ```
153    pub fn call(&self, args: Vec<Expr>) -> Expr {
154        Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
155            Arc::new(self.clone()),
156            args,
157        ))
158    }
159
160    /// Returns this function's name.
161    ///
162    /// See [`ScalarUDFImpl::name`] for more details.
163    pub fn name(&self) -> &str {
164        self.inner.name()
165    }
166
167    /// Returns this function's display_name.
168    ///
169    /// See [`ScalarUDFImpl::display_name`] for more details
170    #[deprecated(
171        since = "50.0.0",
172        note = "This method is unused and will be removed in a future release"
173    )]
174    pub fn display_name(&self, args: &[Expr]) -> Result<String> {
175        #[expect(deprecated)]
176        self.inner.display_name(args)
177    }
178
179    /// Returns this function's schema_name.
180    ///
181    /// See [`ScalarUDFImpl::schema_name`] for more details
182    pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
183        self.inner.schema_name(args)
184    }
185
186    /// Returns the aliases for this function.
187    ///
188    /// See [`ScalarUDF::with_aliases`] for more details
189    pub fn aliases(&self) -> &[String] {
190        self.inner.aliases()
191    }
192
193    /// Returns this function's [`Signature`] (what input types are accepted).
194    ///
195    /// See [`ScalarUDFImpl::signature`] for more details.
196    pub fn signature(&self) -> &Signature {
197        self.inner.signature()
198    }
199
200    /// The datatype this function returns given the input argument types.
201    /// This function is used when the input arguments are [`DataType`]s.
202    ///
203    ///  # Notes
204    ///
205    /// If a function implement [`ScalarUDFImpl::return_field_from_args`],
206    /// its [`ScalarUDFImpl::return_type`] should raise an error.
207    ///
208    /// See [`ScalarUDFImpl::return_type`] for more details.
209    pub fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
210        self.inner.return_type(arg_types)
211    }
212
213    /// Return the datatype this function returns given the input argument types.
214    ///
215    /// See [`ScalarUDFImpl::return_field_from_args`] for more details.
216    pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
217        self.inner.return_field_from_args(args)
218    }
219
220    /// Do the function rewrite
221    ///
222    /// See [`ScalarUDFImpl::simplify`] for more details.
223    pub fn simplify(
224        &self,
225        args: Vec<Expr>,
226        info: &SimplifyContext,
227    ) -> Result<ExprSimplifyResult> {
228        self.inner.simplify(args, info)
229    }
230
231    #[deprecated(since = "50.0.0", note = "Use `return_field_from_args` instead.")]
232    pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
233        #[expect(deprecated)]
234        self.inner.is_nullable(args, schema)
235    }
236
237    /// Return a preimage
238    ///
239    /// See [`ScalarUDFImpl::preimage`] for more details.
240    pub fn preimage(
241        &self,
242        args: &[Expr],
243        lit_expr: &Expr,
244        info: &SimplifyContext,
245    ) -> Result<PreimageResult> {
246        self.inner.preimage(args, lit_expr, info)
247    }
248
249    /// Invoke the function on `args`, returning the appropriate result.
250    ///
251    /// See [`ScalarUDFImpl::invoke_with_args`] for details.
252    pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
253        #[cfg(debug_assertions)]
254        let return_field = Arc::clone(&args.return_field);
255        let result = self.inner.invoke_with_args(args)?;
256        // Maybe this could be enabled always?
257        // This doesn't use debug_assert!, but it's meant to run anywhere except on production. It's same in spirit, thus conditioning on debug_assertions.
258        #[cfg(debug_assertions)]
259        {
260            let result_data_type = result.data_type();
261            let expected_type = return_field.data_type();
262            assert_or_internal_err!(
263                result_data_type == *expected_type,
264                "Function '{}' returned value of type '{:?}' while the following type was promised at planning time and expected: '{:?}'",
265                self.name(),
266                result_data_type,
267                expected_type
268            );
269            // TODO verify return data is non-null when it was promised to be?
270        }
271        Ok(result)
272    }
273
274    /// Determines which of the arguments passed to this function are evaluated eagerly
275    /// and which may be evaluated lazily.
276    ///
277    /// See [ScalarUDFImpl::conditional_arguments] for more information.
278    pub fn conditional_arguments<'a>(
279        &self,
280        args: &'a [Expr],
281    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
282        self.inner.conditional_arguments(args)
283    }
284
285    /// Returns true if some of this `exprs` subexpressions may not be evaluated
286    /// and thus any side effects (like divide by zero) may not be encountered.
287    ///
288    /// See [ScalarUDFImpl::short_circuits] for more information.
289    pub fn short_circuits(&self) -> bool {
290        self.inner.short_circuits()
291    }
292
293    /// Computes the output interval for a [`ScalarUDF`], given the input
294    /// intervals.
295    ///
296    /// # Parameters
297    ///
298    /// * `inputs` are the intervals for the inputs (children) of this function.
299    ///
300    /// # Example
301    ///
302    /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
303    /// then the output interval would be `[0, 3]`.
304    pub fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
305        self.inner.evaluate_bounds(inputs)
306    }
307
308    /// Updates bounds for child expressions, given a known interval for this
309    /// function. This is used to propagate constraints down through an expression
310    /// tree.
311    ///
312    /// # Parameters
313    ///
314    /// * `interval` is the currently known interval for this function.
315    /// * `inputs` are the current intervals for the inputs (children) of this function.
316    ///
317    /// # Returns
318    ///
319    /// A `Vec` of new intervals for the children, in order.
320    ///
321    /// If constraint propagation reveals an infeasibility for any child, returns
322    /// [`None`]. If none of the children intervals change as a result of
323    /// propagation, may return an empty vector instead of cloning `children`.
324    /// This is the default (and conservative) return value.
325    ///
326    /// # Example
327    ///
328    /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
329    /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
330    pub fn propagate_constraints(
331        &self,
332        interval: &Interval,
333        inputs: &[&Interval],
334    ) -> Result<Option<Vec<Interval>>> {
335        self.inner.propagate_constraints(interval, inputs)
336    }
337
338    /// Calculates the [`SortProperties`] of this function based on its
339    /// children's properties.
340    pub fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
341        self.inner.output_ordering(inputs)
342    }
343
344    pub fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
345        self.inner.preserves_lex_ordering(inputs)
346    }
347
348    /// See [`ScalarUDFImpl::coerce_types`] for more details.
349    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
350        self.inner.coerce_types(arg_types)
351    }
352
353    /// Returns the documentation for this Scalar UDF.
354    ///
355    /// Documentation can be accessed programmatically as well as
356    /// generating publicly facing documentation.
357    pub fn documentation(&self) -> Option<&Documentation> {
358        self.inner.documentation()
359    }
360
361    /// Return true if this function is an async function
362    pub fn as_async(&self) -> Option<&AsyncScalarUDF> {
363        self.inner().as_any().downcast_ref::<AsyncScalarUDF>()
364    }
365
366    /// Returns placement information for this function.
367    ///
368    /// See [`ScalarUDFImpl::placement`] for more details.
369    pub fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement {
370        self.inner.placement(args)
371    }
372}
373
374impl<F> From<F> for ScalarUDF
375where
376    F: ScalarUDFImpl + 'static,
377{
378    fn from(fun: F) -> Self {
379        Self::new_from_impl(fun)
380    }
381}
382
383/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a
384/// scalar function.
385#[derive(Debug, Clone)]
386pub struct ScalarFunctionArgs {
387    /// The evaluated arguments to the function
388    pub args: Vec<ColumnarValue>,
389    /// Field associated with each arg, if it exists
390    pub arg_fields: Vec<FieldRef>,
391    /// The number of rows in record batch being evaluated
392    pub number_rows: usize,
393    /// The return field of the scalar function returned (from `return_type`
394    /// or `return_field_from_args`) when creating the physical expression
395    /// from the logical expression
396    pub return_field: FieldRef,
397    /// The config options at execution time
398    pub config_options: Arc<ConfigOptions>,
399}
400
401impl ScalarFunctionArgs {
402    /// The return type of the function. See [`Self::return_field`] for more
403    /// details.
404    pub fn return_type(&self) -> &DataType {
405        self.return_field.data_type()
406    }
407}
408
409/// Information about arguments passed to the function
410///
411/// This structure contains metadata about how the function was called
412/// such as the type of the arguments, any scalar arguments and if the
413/// arguments can (ever) be null
414///
415/// See [`ScalarUDFImpl::return_field_from_args`] for more information
416#[derive(Debug)]
417pub struct ReturnFieldArgs<'a> {
418    /// The data types of the arguments to the function
419    pub arg_fields: &'a [FieldRef],
420    /// Is argument `i` to the function a scalar (constant)?
421    ///
422    /// If the argument `i` is not a scalar, it will be None
423    ///
424    /// For example, if a function is called like `my_function(column_a, 5)`
425    /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]`
426    pub scalar_arguments: &'a [Option<&'a ScalarValue>],
427}
428
429/// Trait for implementing user defined scalar functions.
430///
431/// This trait exposes the full API for implementing user defined functions and
432/// can be used to implement any function.
433///
434/// See [`advanced_udf.rs`] for a full example with complete implementation and
435/// [`ScalarUDF`] for other available options.
436///
437/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udf.rs
438///
439/// # Basic Example
440/// ```
441/// # use std::any::Any;
442/// # use std::sync::LazyLock;
443/// # use arrow::datatypes::DataType;
444/// # use datafusion_common::{DataFusionError, plan_err, Result};
445/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
446/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
447/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
448/// /// This struct for a simple UDF that adds one to an int32
449/// #[derive(Debug, PartialEq, Eq, Hash)]
450/// struct AddOne {
451///   signature: Signature,
452/// }
453///
454/// impl AddOne {
455///   fn new() -> Self {
456///     Self {
457///       signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable),
458///      }
459///   }
460/// }
461///
462/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
463///         Documentation::builder(DOC_SECTION_MATH, "Add one to an int32", "add_one(2)")
464///             .with_argument("arg1", "The int32 number to add one to")
465///             .build()
466///     });
467///
468/// fn get_doc() -> &'static Documentation {
469///     &DOCUMENTATION
470/// }
471///
472/// /// Implement the ScalarUDFImpl trait for AddOne
473/// impl ScalarUDFImpl for AddOne {
474///    fn as_any(&self) -> &dyn Any { self }
475///    fn name(&self) -> &str { "add_one" }
476///    fn signature(&self) -> &Signature { &self.signature }
477///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
478///      if !matches!(args.get(0), Some(&DataType::Int32)) {
479///        return plan_err!("add_one only accepts Int32 arguments");
480///      }
481///      Ok(DataType::Int32)
482///    }
483///    // The actual implementation would add one to the argument
484///    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
485///         unimplemented!()
486///    }
487///    fn documentation(&self) -> Option<&Documentation> {
488///         Some(get_doc())
489///     }
490/// }
491///
492/// // Create a new ScalarUDF from the implementation
493/// let add_one = ScalarUDF::from(AddOne::new());
494///
495/// // Call the function `add_one(col)`
496/// let expr = add_one.call(vec![col("a")]);
497/// ```
498pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
499    /// Returns this object as an [`Any`] trait object
500    fn as_any(&self) -> &dyn Any;
501
502    /// Returns this function's name
503    fn name(&self) -> &str;
504
505    /// Returns any aliases (alternate names) for this function.
506    ///
507    /// Aliases can be used to invoke the same function using different names.
508    /// For example in some databases `now()` and `current_timestamp()` are
509    /// aliases for the same function. This behavior can be obtained by
510    /// returning `current_timestamp` as an alias for the `now` function.
511    ///
512    /// Note: `aliases` should only include names other than [`Self::name`].
513    /// Defaults to `[]` (no aliases)
514    fn aliases(&self) -> &[String] {
515        &[]
516    }
517
518    /// Returns the user-defined display name of function, given the arguments
519    ///
520    /// This can be used to customize the output column name generated by this
521    /// function.
522    ///
523    /// Defaults to `name(args[0], args[1], ...)`
524    #[deprecated(
525        since = "50.0.0",
526        note = "This method is unused and will be removed in a future release"
527    )]
528    fn display_name(&self, args: &[Expr]) -> Result<String> {
529        let names: Vec<String> = args.iter().map(ToString::to_string).collect();
530        // TODO: join with ", " to standardize the formatting of Vec<Expr>, <https://github.com/apache/datafusion/issues/10364>
531        Ok(format!("{}({})", self.name(), names.join(",")))
532    }
533
534    /// Returns the name of the column this expression would create
535    ///
536    /// See [`Expr::schema_name`] for details
537    fn schema_name(&self, args: &[Expr]) -> Result<String> {
538        Ok(format!(
539            "{}({})",
540            self.name(),
541            schema_name_from_exprs_comma_separated_without_space(args)?
542        ))
543    }
544
545    /// Returns a [`Signature`] describing the argument types for which this
546    /// function has an implementation, and the function's [`Volatility`].
547    ///
548    /// See [`Signature`] for more details on argument type handling
549    /// and [`Self::return_type`] for computing the return type.
550    ///
551    /// [`Volatility`]: datafusion_expr_common::signature::Volatility
552    fn signature(&self) -> &Signature;
553
554    /// [`DataType`] returned by this function, given the types of the
555    /// arguments.
556    ///
557    /// # Arguments
558    ///
559    /// `arg_types` Data types of the arguments. The implementation of
560    /// `return_type` can assume that some other part of the code has coerced
561    /// the actual argument types to match [`Self::signature`].
562    ///
563    /// # Notes
564    ///
565    /// If you provide an implementation for [`Self::return_field_from_args`],
566    /// DataFusion will not call `return_type` (this function). While it is
567    /// valid to to put [`unimplemented!()`] or [`unreachable!()`], it is
568    /// recommended to return [`DataFusionError::Internal`] instead, which
569    /// reduces the severity of symptoms if bugs occur (an error rather than a
570    /// panic).
571    ///
572    /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal
573    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
574
575    /// Create a new instance of this function with updated configuration.
576    ///
577    /// This method is called when configuration options change at runtime
578    /// (e.g., via `SET` statements) to allow functions that depend on
579    /// configuration to update themselves accordingly.
580    ///
581    /// Note the current [`ConfigOptions`] are also passed to [`Self::invoke_with_args`] so
582    /// this API is not needed for functions where the values may
583    /// depend on the current options.
584    ///
585    /// This API is useful for functions where the return
586    /// **type** depends on the configuration options, such as the `now()` function
587    /// which depends on the current timezone.
588    ///
589    /// # Arguments
590    ///
591    /// * `config` - The updated configuration options
592    ///
593    /// # Returns
594    ///
595    /// * `Some(ScalarUDF)` - A new instance of this function configured with the new settings
596    /// * `None` - If this function does not change with new configuration settings (the default)
597    fn with_updated_config(&self, _config: &ConfigOptions) -> Option<ScalarUDF> {
598        None
599    }
600
601    /// What type will be returned by this function, given the arguments?
602    ///
603    /// By default, this function calls [`Self::return_type`] with the
604    /// types of each argument.
605    ///
606    /// # Notes
607    ///
608    /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient,
609    /// as the result type is typically a deterministic function of the input types
610    /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly
611    /// is generally unnecessary unless the return type depends on runtime values.
612    ///
613    /// This function can be used for more advanced cases such as:
614    ///
615    /// 1. specifying nullability
616    /// 2. return types based on the **values** of the arguments (rather than
617    ///    their **types**.
618    ///
619    /// # Example creating `Field`
620    ///
621    /// Note the name of the [`Field`] is ignored, except for structured types such as
622    /// `DataType::Struct`.
623    ///
624    /// ```rust
625    /// # use std::sync::Arc;
626    /// # use arrow::datatypes::{DataType, Field, FieldRef};
627    /// # use datafusion_common::Result;
628    /// # use datafusion_expr::ReturnFieldArgs;
629    /// # struct Example{}
630    /// # impl Example {
631    /// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
632    ///     // report output is only nullable if any one of the arguments are nullable
633    ///     let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
634    ///     let field = Arc::new(Field::new("ignored_name", DataType::Int32, nullable));
635    ///     Ok(field)
636    /// }
637    /// # }
638    /// ```
639    ///
640    /// # Output Type based on Values
641    ///
642    /// For example, the following two function calls get the same argument
643    /// types (something and a `Utf8` string) but return different types based
644    /// on the value of the second argument:
645    ///
646    /// * `arrow_cast(x, 'Int16')` --> `Int16`
647    /// * `arrow_cast(x, 'Float32')` --> `Float32`
648    ///
649    /// # Requirements
650    ///
651    /// This function **must** consistently return the same type for the same
652    /// logical input even if the input is simplified (e.g. it must return the same
653    /// value for `('foo' | 'bar')` as it does for ('foobar').
654    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
655        let data_types = args
656            .arg_fields
657            .iter()
658            .map(|f| f.data_type())
659            .cloned()
660            .collect::<Vec<_>>();
661        let return_type = self.return_type(&data_types)?;
662        Ok(Arc::new(Field::new(self.name(), return_type, true)))
663    }
664
665    #[deprecated(
666        since = "45.0.0",
667        note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error"
668    )]
669    fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
670        true
671    }
672
673    /// Invoke the function returning the appropriate result.
674    ///
675    /// # Performance
676    ///
677    /// For the best performance, the implementations should handle the common case
678    /// when one or more of their arguments are constant values (aka
679    /// [`ColumnarValue::Scalar`]).
680    ///
681    /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
682    /// to arrays, which will likely be simpler code, but be slower.
683    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue>;
684
685    /// Optionally apply per-UDF simplification / rewrite rules.
686    ///
687    /// This can be used to apply function specific simplification rules during
688    /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
689    /// implementation does nothing.
690    ///
691    /// Note that DataFusion handles simplifying arguments and  "constant
692    /// folding" (replacing a function call with constant arguments such as
693    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
694    /// optimizations manually for specific UDFs.
695    ///
696    /// # Arguments
697    /// * `args`: The arguments of the function
698    /// * `info`: The necessary information for simplification
699    ///
700    /// # Returns
701    /// [`ExprSimplifyResult`] indicating the result of the simplification NOTE
702    /// if the function cannot be simplified, the arguments *MUST* be returned
703    /// unmodified
704    ///
705    /// # Notes
706    ///
707    /// The returned expression must have the same schema as the original
708    /// expression, including both the data type and nullability. For example,
709    /// if the original expression is nullable, the returned expression must
710    /// also be nullable, otherwise it may lead to schema verification errors
711    /// later in query planning.
712    fn simplify(
713        &self,
714        args: Vec<Expr>,
715        _info: &SimplifyContext,
716    ) -> Result<ExprSimplifyResult> {
717        Ok(ExprSimplifyResult::Original(args))
718    }
719
720    /// Returns a single contiguous preimage for this function and the specified
721    /// scalar expression, if any.
722    ///
723    /// Currently only applies to `=, !=, >, >=, <, <=, is distinct from, is not distinct from` predicates
724    /// # Return Value
725    ///
726    /// Implementations should return a half-open interval: inclusive lower
727    /// bound and exclusive upper bound. This is slightly different from normal
728    /// [`Interval`] semantics where the upper bound is closed (inclusive).
729    /// Typically this means the upper endpoint must be adjusted to the next
730    /// value not included in the preimage. See the Half-Open Intervals section
731    /// below for more details.
732    ///
733    /// # Background
734    ///
735    /// Inspired by the [ClickHouse Paper], a "preimage rewrite" transforms a
736    /// predicate containing a function call into a predicate containing an
737    /// equivalent set of input literal (constant) values. The resulting
738    /// predicate can often be further optimized by other rewrites (see
739    /// Examples).
740    ///
741    /// From the paper:
742    ///
743    /// > some functions can compute the preimage of a given function result.
744    /// > This is used to replace comparisons of constants with function calls
745    /// > on the key columns by comparing the key column value with the preimage.
746    /// > For example, `toYear(k) = 2024` can be replaced by
747    /// > `k >= 2024-01-01 && k < 2025-01-01`
748    ///
749    /// For example, given an expression like
750    /// ```sql
751    /// date_part('YEAR', k) = 2024
752    /// ```
753    ///
754    /// The interval `[2024-01-01, 2025-12-31`]` contains all possible input
755    /// values (preimage values) for which the function `date_part(YEAR, k)`
756    /// produces the output value `2024` (image value). Returning the interval
757    /// (note upper bound adjusted up) `[2024-01-01, 2025-01-01]` the expression
758    /// can be rewritten to
759    ///
760    /// ```sql
761    /// k >= '2024-01-01' AND k < '2025-01-01'
762    /// ```
763    ///
764    /// which is a simpler and a more canonical form, making it easier for other
765    /// optimizer passes to recognize and apply further transformations.
766    ///
767    /// # Examples
768    ///
769    /// Case 1:
770    ///
771    /// Original:
772    /// ```sql
773    /// date_part('YEAR', k) = 2024 AND k >= '2024-06-01'
774    /// ```
775    ///
776    /// After preimage rewrite:
777    /// ```sql
778    /// k >= '2024-01-01' AND k < '2025-01-01' AND k >= '2024-06-01'
779    /// ```
780    ///
781    /// Since this form is much simpler, the optimizer can combine and simplify
782    /// sub-expressions further into:
783    /// ```sql
784    /// k >= '2024-06-01' AND k < '2025-01-01'
785    /// ```
786    ///
787    /// Case 2:
788    ///
789    /// For min/max pruning, simpler predicates such as:
790    /// ```sql
791    /// k >= '2024-01-01' AND k < '2025-01-01'
792    /// ```
793    /// are much easier for the pruner to reason about. See [PruningPredicate]
    /// for background on predicate pruning.
795    ///
796    /// The trade-off with the preimage rewrite is that evaluating the rewritten
797    /// form might be slightly more expensive than evaluating the original
798    /// expression. In practice, this cost is usually outweighed by the more
799    /// aggressive optimization opportunities it enables.
800    ///
801    /// # Half-Open Intervals
802    ///
803    /// The preimage API uses half-open intervals, which makes the rewrite
804    /// easier to implement by avoiding calculations to adjust the upper bound.
805    /// For example, if a function returns its input unchanged and the desired
806    /// output is the single value `5`, a closed interval could be represented
807    /// as `[5, 5]`, but then the rewrite would require adjusting the upper
808    /// bound to `6` to create a proper range predicate. With a half-open
809    /// interval, the same range is represented as `[5, 6)`, which already
810    /// forms a valid predicate.
811    ///
812    /// [PruningPredicate]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/pruning/struct.PruningPredicate.html
813    /// [ClickHouse Paper]:  https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf
814    /// [image]: https://en.wikipedia.org/wiki/Image_(mathematics)#Image_of_an_element
815    /// [preimage]: https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image
816    fn preimage(
817        &self,
818        _args: &[Expr],
819        _lit_expr: &Expr,
820        _info: &SimplifyContext,
821    ) -> Result<PreimageResult> {
822        Ok(PreimageResult::None)
823    }
824
825    /// Returns true if some of this `exprs` subexpressions may not be evaluated
826    /// and thus any side effects (like divide by zero) may not be encountered.
827    ///
828    /// Setting this to true prevents certain optimizations such as common
829    /// subexpression elimination
830    ///
831    /// When overriding this function to return `true`, [ScalarUDFImpl::conditional_arguments] can also be
832    /// overridden to report more accurately which arguments are eagerly evaluated and which ones
833    /// lazily.
834    fn short_circuits(&self) -> bool {
835        false
836    }
837
838    /// Determines which of the arguments passed to this function are evaluated eagerly
839    /// and which may be evaluated lazily.
840    ///
841    /// If this function returns `None`, all arguments are eagerly evaluated.
842    /// Returning `None` is a micro optimization that saves a needless `Vec`
843    /// allocation.
844    ///
845    /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager`
846    /// are the arguments that are always evaluated, and `lazy` are the
847    /// arguments that may be evaluated lazily (i.e. may not be evaluated at all
848    /// in some cases).
849    ///
850    /// Implementations must ensure that the two returned `Vec`s are disjunct,
851    /// and that each argument from `args` is present in one the two `Vec`s.
852    ///
853    /// When overriding this function, [ScalarUDFImpl::short_circuits] must
854    /// be overridden to return `true`.
855    fn conditional_arguments<'a>(
856        &self,
857        args: &'a [Expr],
858    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
859        if self.short_circuits() {
860            Some((vec![], args.iter().collect()))
861        } else {
862            None
863        }
864    }
865
866    /// Computes the output [`Interval`] for a [`ScalarUDFImpl`], given the input
867    /// intervals.
868    ///
869    /// # Parameters
870    ///
871    /// * `children` are the intervals for the children (inputs) of this function.
872    ///
873    /// # Example
874    ///
875    /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
876    /// then the output interval would be `[0, 3]`.
877    fn evaluate_bounds(&self, _input: &[&Interval]) -> Result<Interval> {
878        // We cannot assume the input datatype is the same of output type.
879        Interval::make_unbounded(&DataType::Null)
880    }
881
882    /// Updates bounds for child expressions, given a known [`Interval`]s for this
883    /// function.
884    ///
885    /// This function is used to propagate constraints down through an
886    /// expression tree.
887    ///
888    /// # Parameters
889    ///
890    /// * `interval` is the currently known interval for this function.
891    /// * `inputs` are the current intervals for the inputs (children) of this function.
892    ///
893    /// # Returns
894    ///
895    /// A `Vec` of new intervals for the children, in order.
896    ///
897    /// If constraint propagation reveals an infeasibility for any child, returns
898    /// [`None`]. If none of the children intervals change as a result of
899    /// propagation, may return an empty vector instead of cloning `children`.
900    /// This is the default (and conservative) return value.
901    ///
902    /// # Example
903    ///
904    /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
905    /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
906    fn propagate_constraints(
907        &self,
908        _interval: &Interval,
909        _inputs: &[&Interval],
910    ) -> Result<Option<Vec<Interval>>> {
911        Ok(Some(vec![]))
912    }
913
914    /// Calculates the [`SortProperties`] of this function based on its children's properties.
915    fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
916        if !self.preserves_lex_ordering(inputs)? {
917            return Ok(SortProperties::Unordered);
918        }
919
920        let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else {
921            return Ok(SortProperties::Singleton);
922        };
923
924        if inputs
925            .iter()
926            .skip(1)
927            .all(|input| &input.sort_properties == first_order)
928        {
929            Ok(*first_order)
930        } else {
931            Ok(SortProperties::Unordered)
932        }
933    }
934
935    /// Returns true if the function preserves lexicographical ordering based on
936    /// the input ordering.
937    ///
938    /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not.
939    fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
940        Ok(false)
941    }
942
943    /// Coerce arguments of a function call to types that the function can evaluate.
944    ///
945    /// This function is only called if [`ScalarUDFImpl::signature`] returns
946    /// [`crate::TypeSignature::UserDefined`]. Most UDFs should return one of
947    /// the other variants of [`TypeSignature`] which handle common cases.
948    ///
949    /// See the [type coercion module](crate::type_coercion)
950    /// documentation for more details on type coercion
951    ///
952    /// [`TypeSignature`]: crate::TypeSignature
953    ///
954    /// For example, if your function requires a floating point arguments, but the user calls
955    /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]`
956    /// to ensure the argument is converted to `1::double`
957    ///
958    /// # Parameters
959    /// * `arg_types`: The argument types of the arguments  this function with
960    ///
961    /// # Return value
962    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
963    /// arguments to these specific types.
964    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
965        not_impl_err!("Function {} does not implement coerce_types", self.name())
966    }
967
968    /// Returns the documentation for this Scalar UDF.
969    ///
970    /// Documentation can be accessed programmatically as well as generating
971    /// publicly facing documentation.
972    fn documentation(&self) -> Option<&Documentation> {
973        None
974    }
975
976    /// Returns placement information for this function.
977    ///
978    /// This is used by optimizers to make decisions about expression placement,
979    /// such as whether to push expressions down through projections.
980    ///
981    /// The default implementation returns [`ExpressionPlacement::KeepInPlace`],
982    /// meaning the expression should be kept where it is in the plan.
983    ///
984    /// Override this method to indicate that the function can be pushed down
985    /// closer to the data source.
986    fn placement(&self, _args: &[ExpressionPlacement]) -> ExpressionPlacement {
987        ExpressionPlacement::KeepInPlace
988    }
989}
990
991/// ScalarUDF that adds an alias to the underlying function. It is better to
992/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
993#[derive(Debug, PartialEq, Eq, Hash)]
994struct AliasedScalarUDFImpl {
995    inner: UdfEq<Arc<dyn ScalarUDFImpl>>,
996    aliases: Vec<String>,
997}
998
999impl AliasedScalarUDFImpl {
1000    pub fn new(
1001        inner: Arc<dyn ScalarUDFImpl>,
1002        new_aliases: impl IntoIterator<Item = &'static str>,
1003    ) -> Self {
1004        let mut aliases = inner.aliases().to_vec();
1005        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1006        Self {
1007            inner: inner.into(),
1008            aliases,
1009        }
1010    }
1011}
1012
1013#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method
1014impl ScalarUDFImpl for AliasedScalarUDFImpl {
1015    fn as_any(&self) -> &dyn Any {
1016        self
1017    }
1018
1019    fn name(&self) -> &str {
1020        self.inner.name()
1021    }
1022
1023    fn display_name(&self, args: &[Expr]) -> Result<String> {
1024        #[expect(deprecated)]
1025        self.inner.display_name(args)
1026    }
1027
1028    fn schema_name(&self, args: &[Expr]) -> Result<String> {
1029        self.inner.schema_name(args)
1030    }
1031
1032    fn signature(&self) -> &Signature {
1033        self.inner.signature()
1034    }
1035
1036    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1037        self.inner.return_type(arg_types)
1038    }
1039
1040    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
1041        self.inner.return_field_from_args(args)
1042    }
1043
1044    fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
1045        #[expect(deprecated)]
1046        self.inner.is_nullable(args, schema)
1047    }
1048
1049    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1050        self.inner.invoke_with_args(args)
1051    }
1052
1053    fn with_updated_config(&self, _config: &ConfigOptions) -> Option<ScalarUDF> {
1054        None
1055    }
1056
1057    fn aliases(&self) -> &[String] {
1058        &self.aliases
1059    }
1060
1061    fn simplify(
1062        &self,
1063        args: Vec<Expr>,
1064        info: &SimplifyContext,
1065    ) -> Result<ExprSimplifyResult> {
1066        self.inner.simplify(args, info)
1067    }
1068
1069    fn preimage(
1070        &self,
1071        args: &[Expr],
1072        lit_expr: &Expr,
1073        info: &SimplifyContext,
1074    ) -> Result<PreimageResult> {
1075        self.inner.preimage(args, lit_expr, info)
1076    }
1077
1078    fn conditional_arguments<'a>(
1079        &self,
1080        args: &'a [Expr],
1081    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
1082        self.inner.conditional_arguments(args)
1083    }
1084
1085    fn short_circuits(&self) -> bool {
1086        self.inner.short_circuits()
1087    }
1088
1089    fn evaluate_bounds(&self, input: &[&Interval]) -> Result<Interval> {
1090        self.inner.evaluate_bounds(input)
1091    }
1092
1093    fn propagate_constraints(
1094        &self,
1095        interval: &Interval,
1096        inputs: &[&Interval],
1097    ) -> Result<Option<Vec<Interval>>> {
1098        self.inner.propagate_constraints(interval, inputs)
1099    }
1100
1101    fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
1102        self.inner.output_ordering(inputs)
1103    }
1104
1105    fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
1106        self.inner.preserves_lex_ordering(inputs)
1107    }
1108
1109    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1110        self.inner.coerce_types(arg_types)
1111    }
1112
1113    fn documentation(&self) -> Option<&Documentation> {
1114        self.inner.documentation()
1115    }
1116
1117    fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement {
1118        self.inner.placement(args)
1119    }
1120}
1121
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr_common::signature::Volatility;
    use std::hash::DefaultHasher;

    /// A minimal parameterized UDF. Per the derived impls, two instances are
    /// equal only when `name`, `field` and `signature` all match.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDFImpl {
        name: &'static str,
        field: &'static str,
        signature: Signature,
    }

    impl ScalarUDFImpl for TestScalarUDFImpl {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        // Never called: these tests only compare and hash functions.
        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    /// Wraps a `TestScalarUDFImpl` with the given name and parameter in a `ScalarUDF`.
    fn make_udf(name: &'static str, parameter: &'static str) -> ScalarUDF {
        let udf_impl = TestScalarUDFImpl {
            name,
            field: parameter,
            signature: Signature::any(1, Volatility::Immutable),
        };
        ScalarUDF::from(udf_impl)
    }

    /// Hashes `value` with a freshly created `DefaultHasher`.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut state = DefaultHasher::new();
        value.hash(&mut state);
        state.finish()
    }

    // `PartialEq` must agree with both `Hash` and `PartialOrd`, so all three
    // are checked together.
    #[test]
    fn test_partial_eq_hash_and_partial_ord() {
        let foo_a = make_udf("foo", "a");

        // An independently constructed but identical function.
        let foo_a_again = make_udf("foo", "a");
        assert_eq!(foo_a, foo_a_again);
        assert_eq!(hash_of(&foo_a), hash_of(&foo_a_again));
        assert_eq!(foo_a.partial_cmp(&foo_a_again), Some(Ordering::Equal));

        // Same name, different parameter. Hashes may collide in general,
        // but they do not for the inputs used here.
        let foo_b = make_udf("foo", "b");
        assert_ne!(foo_a, foo_b);
        assert_ne!(hash_of(&foo_a), hash_of(&foo_b));
        assert_eq!(foo_a.partial_cmp(&foo_b), None);

        // Different name, same parameter.
        let other_a = make_udf("other", "a");
        assert_ne!(foo_a, other_a);
        assert_ne!(hash_of(&foo_a), hash_of(&other_a));
        assert_eq!(foo_a.partial_cmp(&other_a), Some(Ordering::Less));

        // Different name and different parameter.
        assert_ne!(foo_b, other_a);
        assert_ne!(hash_of(&foo_b), hash_of(&other_a));
        assert_eq!(foo_b.partial_cmp(&other_a), Some(Ordering::Less));
    }
}