Skip to main content

datafusion_expr/
udf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDF`]: Scalar User Defined Functions
19
20use crate::async_udf::AsyncScalarUDF;
21use crate::expr::schema_name_from_exprs_comma_separated_without_space;
22use crate::preimage::PreimageResult;
23use crate::simplify::{ExprSimplifyResult, SimplifyContext};
24use crate::sort_properties::{ExprProperties, SortProperties};
25use crate::udf_eq::UdfEq;
26use crate::{ColumnarValue, Documentation, Expr, Signature};
27use arrow::datatypes::{DataType, Field, FieldRef};
28#[cfg(debug_assertions)]
29use datafusion_common::assert_or_internal_err;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::{ExprSchema, Result, ScalarValue, not_impl_err};
32use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
33use datafusion_expr_common::interval_arithmetic::Interval;
34use datafusion_expr_common::placement::ExpressionPlacement;
35use std::any::Any;
36use std::cmp::Ordering;
37use std::fmt::Debug;
38use std::hash::{Hash, Hasher};
39use std::sync::Arc;
40
41/// Describes how a struct-producing UDF's output fields correspond to its
42/// input arguments. This enables the optimizer to propagate orderings
43/// through struct projections (e.g., so that sorting by a struct field
44/// can be recognized as equivalent to sorting by the source column).
45///
46/// See [`ScalarUDFImpl::struct_field_mapping`] for details.
47pub struct StructFieldMapping {
48    /// The UDF used to construct field access expressions on the output.
49    /// For example, the `get_field` UDF for accessing struct fields.
50    pub field_accessor: Arc<ScalarUDF>,
51    /// For each output field: the literal arguments to pass to the
52    /// `field_accessor` UDF (after the base expression), and the index
53    /// of the corresponding input argument that produces the field's value.
54    ///
55    /// For `named_struct('a', col1, 'b', col2)`, this would be:
56    /// `[(["a"], 1), (["b"], 3)]` — field `"a"` comes from arg index 1.
57    pub fields: Vec<(Vec<ScalarValue>, usize)>,
58}
59
60/// Logical representation of a Scalar User Defined Function.
61///
62/// A scalar function produces a single row output for each row of input. This
63/// struct contains the information DataFusion needs to plan and invoke
64/// functions you supply such as name, type signature, return type, and actual
65/// implementation.
66///
67/// 1. For simple use cases, use [`create_udf`] (examples in [`simple_udf.rs`]).
68///
69/// 2. For advanced use cases, use [`ScalarUDFImpl`] which provides full API
70///    access (examples in  [`advanced_udf.rs`]).
71///
72/// See [`Self::call`] to create an `Expr` which invokes a `ScalarUDF` with arguments.
73///
74/// # API Note
75///
76/// This is a separate struct from [`ScalarUDFImpl`] to maintain backwards
77/// compatibility with the older API.
78///
79/// [`create_udf`]: crate::expr_fn::create_udf
80/// [`simple_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/simple_udf.rs
81/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udf.rs
82#[derive(Debug, Clone)]
83pub struct ScalarUDF {
84    inner: Arc<dyn ScalarUDFImpl>,
85}
86
87impl PartialEq for ScalarUDF {
88    fn eq(&self, other: &Self) -> bool {
89        self.inner.as_ref().dyn_eq(other.inner.as_ref() as &dyn Any)
90    }
91}
92
93impl PartialOrd for ScalarUDF {
94    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
95        let mut cmp = self.name().cmp(other.name());
96        if cmp == Ordering::Equal {
97            cmp = self.signature().partial_cmp(other.signature())?;
98        }
99        if cmp == Ordering::Equal {
100            cmp = self.aliases().partial_cmp(other.aliases())?;
101        }
102        // Contract for PartialOrd and PartialEq consistency requires that
103        // a == b if and only if partial_cmp(a, b) == Some(Equal).
104        if cmp == Ordering::Equal && self != other {
105            // Functions may have other properties besides name and signature
106            // that differentiate two instances (e.g. type, or arbitrary parameters).
107            // We cannot return Some(Equal) in such case.
108            return None;
109        }
110        debug_assert!(
111            cmp == Ordering::Equal || self != other,
112            "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
113            The functions compare as equal, but they are not equal based on general properties that \
114            the PartialOrd implementation observes,",
115            self.name(),
116            other.name()
117        );
118        Some(cmp)
119    }
120}
121
122impl Eq for ScalarUDF {}
123
124impl Hash for ScalarUDF {
125    fn hash<H: Hasher>(&self, state: &mut H) {
126        self.inner.dyn_hash(state)
127    }
128}
129
130impl ScalarUDF {
131    /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
132    ///
133    /// Note this is the same as using the `From` impl (`ScalarUDF::from`)
134    pub fn new_from_impl<F>(fun: F) -> ScalarUDF
135    where
136        F: ScalarUDFImpl + 'static,
137    {
138        Self::new_from_shared_impl(Arc::new(fun))
139    }
140
141    /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
142    pub fn new_from_shared_impl(fun: Arc<dyn ScalarUDFImpl>) -> ScalarUDF {
143        Self { inner: fun }
144    }
145
146    /// Return the underlying [`ScalarUDFImpl`] trait object for this function
147    pub fn inner(&self) -> &Arc<dyn ScalarUDFImpl> {
148        &self.inner
149    }
150
151    /// Adds additional names that can be used to invoke this function, in
152    /// addition to `name`
153    ///
154    /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly.
155    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
156        Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases))
157    }
158
159    /// Returns a [`Expr`] logical expression to call this UDF with specified
160    /// arguments.
161    ///
162    /// This utility allows easily calling UDFs
163    ///
164    /// # Example
165    /// ```no_run
166    /// use datafusion_expr::{col, lit, ScalarUDF};
167    /// # fn my_udf() -> ScalarUDF { unimplemented!() }
168    /// let my_func: ScalarUDF = my_udf();
169    /// // Create an expr for `my_func(a, 12.3)`
170    /// let expr = my_func.call(vec![col("a"), lit(12.3)]);
171    /// ```
172    pub fn call(&self, args: Vec<Expr>) -> Expr {
173        Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
174            Arc::new(self.clone()),
175            args,
176        ))
177    }
178
179    /// Returns this function's name.
180    ///
181    /// See [`ScalarUDFImpl::name`] for more details.
182    pub fn name(&self) -> &str {
183        self.inner.name()
184    }
185
186    /// Returns this function's display_name.
187    ///
188    /// See [`ScalarUDFImpl::display_name`] for more details
189    #[deprecated(
190        since = "50.0.0",
191        note = "This method is unused and will be removed in a future release"
192    )]
193    pub fn display_name(&self, args: &[Expr]) -> Result<String> {
194        #[expect(deprecated)]
195        self.inner.display_name(args)
196    }
197
198    /// Returns this function's schema_name.
199    ///
200    /// See [`ScalarUDFImpl::schema_name`] for more details
201    pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
202        self.inner.schema_name(args)
203    }
204
205    /// Returns the aliases for this function.
206    ///
207    /// See [`ScalarUDF::with_aliases`] for more details
208    pub fn aliases(&self) -> &[String] {
209        self.inner.aliases()
210    }
211
212    /// Returns this function's [`Signature`] (what input types are accepted).
213    ///
214    /// See [`ScalarUDFImpl::signature`] for more details.
215    pub fn signature(&self) -> &Signature {
216        self.inner.signature()
217    }
218
219    /// The datatype this function returns given the input argument types.
220    /// This function is used when the input arguments are [`DataType`]s.
221    ///
222    ///  # Notes
223    ///
224    /// If a function implement [`ScalarUDFImpl::return_field_from_args`],
225    /// its [`ScalarUDFImpl::return_type`] should raise an error.
226    ///
227    /// See [`ScalarUDFImpl::return_type`] for more details.
228    pub fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
229        self.inner.return_type(arg_types)
230    }
231
232    /// Return the datatype this function returns given the input argument types.
233    ///
234    /// See [`ScalarUDFImpl::return_field_from_args`] for more details.
235    pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
236        self.inner.return_field_from_args(args)
237    }
238
239    /// Returns this scalar function's simplification result.
240    ///
241    /// See [`ScalarUDFImpl::simplify`] for more details.
242    pub fn simplify(
243        &self,
244        args: Vec<Expr>,
245        info: &SimplifyContext,
246    ) -> Result<ExprSimplifyResult> {
247        self.inner.simplify(args, info)
248    }
249
250    #[deprecated(since = "50.0.0", note = "Use `return_field_from_args` instead.")]
251    pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
252        #[expect(deprecated)]
253        self.inner.is_nullable(args, schema)
254    }
255
256    /// Return a preimage
257    ///
258    /// See [`ScalarUDFImpl::preimage`] for more details.
259    pub fn preimage(
260        &self,
261        args: &[Expr],
262        lit_expr: &Expr,
263        info: &SimplifyContext,
264    ) -> Result<PreimageResult> {
265        self.inner.preimage(args, lit_expr, info)
266    }
267
268    /// Invoke the function on `args`, returning the appropriate result.
269    ///
270    /// See [`ScalarUDFImpl::invoke_with_args`] for details.
271    pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
272        #[cfg(debug_assertions)]
273        let return_field = Arc::clone(&args.return_field);
274        let result = self.inner.invoke_with_args(args)?;
275        // Maybe this could be enabled always?
276        // This doesn't use debug_assert!, but it's meant to run anywhere except on production. It's same in spirit, thus conditioning on debug_assertions.
277        #[cfg(debug_assertions)]
278        {
279            let result_data_type = result.data_type();
280            let expected_type = return_field.data_type();
281            assert_or_internal_err!(
282                result_data_type == *expected_type,
283                "Function '{}' returned value of type '{}' while the following type was promised at planning time and expected: '{}'",
284                self.name(),
285                result_data_type,
286                expected_type
287            );
288            // TODO verify return data is non-null when it was promised to be?
289        }
290        Ok(result)
291    }
292
293    /// Determines which of the arguments passed to this function are evaluated eagerly
294    /// and which may be evaluated lazily.
295    ///
296    /// See [ScalarUDFImpl::conditional_arguments] for more information.
297    pub fn conditional_arguments<'a>(
298        &self,
299        args: &'a [Expr],
300    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
301        self.inner.conditional_arguments(args)
302    }
303
304    /// Returns true if some of this `exprs` subexpressions may not be evaluated
305    /// and thus any side effects (like divide by zero) may not be encountered.
306    ///
307    /// See [ScalarUDFImpl::short_circuits] for more information.
308    pub fn short_circuits(&self) -> bool {
309        self.inner.short_circuits()
310    }
311
312    /// Computes the output interval for a [`ScalarUDF`], given the input
313    /// intervals.
314    ///
315    /// # Parameters
316    ///
317    /// * `inputs` are the intervals for the inputs (children) of this function.
318    ///
319    /// # Example
320    ///
321    /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
322    /// then the output interval would be `[0, 3]`.
323    pub fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
324        self.inner.evaluate_bounds(inputs)
325    }
326
327    /// See [`ScalarUDFImpl::struct_field_mapping`] for more details.
328    pub fn struct_field_mapping(
329        &self,
330        literal_args: &[Option<ScalarValue>],
331    ) -> Option<StructFieldMapping> {
332        self.inner.struct_field_mapping(literal_args)
333    }
334
335    /// Updates bounds for child expressions, given a known interval for this
336    /// function. This is used to propagate constraints down through an expression
337    /// tree.
338    ///
339    /// # Parameters
340    ///
341    /// * `interval` is the currently known interval for this function.
342    /// * `inputs` are the current intervals for the inputs (children) of this function.
343    ///
344    /// # Returns
345    ///
346    /// A `Vec` of new intervals for the children, in order.
347    ///
348    /// If constraint propagation reveals an infeasibility for any child, returns
349    /// [`None`]. If none of the children intervals change as a result of
350    /// propagation, may return an empty vector instead of cloning `children`.
351    /// This is the default (and conservative) return value.
352    ///
353    /// # Example
354    ///
355    /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
356    /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
357    pub fn propagate_constraints(
358        &self,
359        interval: &Interval,
360        inputs: &[&Interval],
361    ) -> Result<Option<Vec<Interval>>> {
362        self.inner.propagate_constraints(interval, inputs)
363    }
364
365    /// Calculates the [`SortProperties`] of this function based on its
366    /// children's properties.
367    pub fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
368        self.inner.output_ordering(inputs)
369    }
370
371    pub fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
372        self.inner.preserves_lex_ordering(inputs)
373    }
374
375    /// See [`ScalarUDFImpl::coerce_types`] for more details.
376    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
377        self.inner.coerce_types(arg_types)
378    }
379
380    /// Returns the documentation for this Scalar UDF.
381    ///
382    /// Documentation can be accessed programmatically as well as
383    /// generating publicly facing documentation.
384    pub fn documentation(&self) -> Option<&Documentation> {
385        self.inner.documentation()
386    }
387
388    /// Return true if this function is an async function
389    pub fn as_async(&self) -> Option<&AsyncScalarUDF> {
390        self.inner().downcast_ref::<AsyncScalarUDF>()
391    }
392
393    /// Returns placement information for this function.
394    ///
395    /// See [`ScalarUDFImpl::placement`] for more details.
396    pub fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement {
397        self.inner.placement(args)
398    }
399}
400
401impl<F> From<F> for ScalarUDF
402where
403    F: ScalarUDFImpl + 'static,
404{
405    fn from(fun: F) -> Self {
406        Self::new_from_impl(fun)
407    }
408}
409
410/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a
411/// scalar function.
412#[derive(Debug, Clone)]
413pub struct ScalarFunctionArgs {
414    /// The evaluated arguments to the function
415    pub args: Vec<ColumnarValue>,
416    /// Field associated with each arg, if it exists
417    pub arg_fields: Vec<FieldRef>,
418    /// The number of rows in record batch being evaluated
419    pub number_rows: usize,
420    /// The return field of the scalar function returned (from `return_type`
421    /// or `return_field_from_args`) when creating the physical expression
422    /// from the logical expression
423    pub return_field: FieldRef,
424    /// The config options at execution time
425    pub config_options: Arc<ConfigOptions>,
426}
427
428impl ScalarFunctionArgs {
429    /// The return type of the function. See [`Self::return_field`] for more
430    /// details.
431    pub fn return_type(&self) -> &DataType {
432        self.return_field.data_type()
433    }
434}
435
436/// Information about arguments passed to the function
437///
438/// This structure contains metadata about how the function was called
439/// such as the type of the arguments, any scalar arguments and if the
440/// arguments can (ever) be null
441///
442/// See [`ScalarUDFImpl::return_field_from_args`] for more information
443#[derive(Debug)]
444pub struct ReturnFieldArgs<'a> {
445    /// The data types of the arguments to the function
446    pub arg_fields: &'a [FieldRef],
447    /// Is argument `i` to the function a scalar (constant)?
448    ///
449    /// If the argument `i` is not a scalar, it will be None
450    ///
451    /// For example, if a function is called like `my_function(column_a, 5)`
452    /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]`
453    pub scalar_arguments: &'a [Option<&'a ScalarValue>],
454}
455
456/// Trait for implementing user defined scalar functions.
457///
458/// This trait exposes the full API for implementing user defined functions and
459/// can be used to implement any function.
460///
461/// See [`advanced_udf.rs`] for a full example with complete implementation and
462/// [`ScalarUDF`] for other available options.
463///
464/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udf.rs
465///
466/// # Basic Example
467/// ```
468/// # use std::any::Any;
469/// # use std::sync::LazyLock;
470/// # use arrow::datatypes::DataType;
471/// # use datafusion_common::{DataFusionError, plan_err, Result};
472/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
473/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
474/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
475/// /// This struct for a simple UDF that adds one to an int32
476/// #[derive(Debug, PartialEq, Eq, Hash)]
477/// struct AddOne {
478///   signature: Signature,
479/// }
480///
481/// impl AddOne {
482///   fn new() -> Self {
483///     Self {
484///       signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable),
485///      }
486///   }
487/// }
488///
489/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
490///         Documentation::builder(DOC_SECTION_MATH, "Add one to an int32", "add_one(2)")
491///             .with_argument("arg1", "The int32 number to add one to")
492///             .build()
493///     });
494///
495/// fn get_doc() -> &'static Documentation {
496///     &DOCUMENTATION
497/// }
498///
499/// /// Implement the ScalarUDFImpl trait for AddOne
500/// impl ScalarUDFImpl for AddOne {
501///    fn name(&self) -> &str { "add_one" }
502///    fn signature(&self) -> &Signature { &self.signature }
503///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
504///      if !matches!(args.get(0), Some(&DataType::Int32)) {
505///        return plan_err!("add_one only accepts Int32 arguments");
506///      }
507///      Ok(DataType::Int32)
508///    }
509///    // The actual implementation would add one to the argument
510///    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
511///         unimplemented!()
512///    }
513///    fn documentation(&self) -> Option<&Documentation> {
514///         Some(get_doc())
515///     }
516/// }
517///
518/// // Create a new ScalarUDF from the implementation
519/// let add_one = ScalarUDF::from(AddOne::new());
520///
521/// // Call the function `add_one(col)`
522/// let expr = add_one.call(vec![col("a")]);
523/// ```
524pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any {
525    /// Returns this function's name
526    fn name(&self) -> &str;
527
528    /// Returns any aliases (alternate names) for this function.
529    ///
530    /// Aliases can be used to invoke the same function using different names.
531    /// For example in some databases `now()` and `current_timestamp()` are
532    /// aliases for the same function. This behavior can be obtained by
533    /// returning `current_timestamp` as an alias for the `now` function.
534    ///
535    /// Note: `aliases` should only include names other than [`Self::name`].
536    /// Defaults to `[]` (no aliases)
537    fn aliases(&self) -> &[String] {
538        &[]
539    }
540
541    /// Returns the user-defined display name of function, given the arguments
542    ///
543    /// This can be used to customize the output column name generated by this
544    /// function.
545    ///
546    /// Defaults to `name(args[0], args[1], ...)`
547    #[deprecated(
548        since = "50.0.0",
549        note = "This method is unused and will be removed in a future release"
550    )]
551    fn display_name(&self, args: &[Expr]) -> Result<String> {
552        let names: Vec<String> = args.iter().map(ToString::to_string).collect();
553        // TODO: join with ", " to standardize the formatting of Vec<Expr>, <https://github.com/apache/datafusion/issues/10364>
554        Ok(format!("{}({})", self.name(), names.join(",")))
555    }
556
557    /// Returns the name of the column this expression would create
558    ///
559    /// See [`Expr::schema_name`] for details
560    fn schema_name(&self, args: &[Expr]) -> Result<String> {
561        Ok(format!(
562            "{}({})",
563            self.name(),
564            schema_name_from_exprs_comma_separated_without_space(args)?
565        ))
566    }
567
568    /// Returns a [`Signature`] describing the argument types for which this
569    /// function has an implementation, and the function's [`Volatility`].
570    ///
571    /// See [`Signature`] for more details on argument type handling
572    /// and [`Self::return_type`] for computing the return type.
573    ///
574    /// [`Volatility`]: datafusion_expr_common::signature::Volatility
575    fn signature(&self) -> &Signature;
576
577    /// [`DataType`] returned by this function, given the types of the
578    /// arguments.
579    ///
580    /// # Arguments
581    ///
582    /// `arg_types` Data types of the arguments. The implementation of
583    /// `return_type` can assume that some other part of the code has coerced
584    /// the actual argument types to match [`Self::signature`].
585    ///
586    /// # Notes
587    ///
588    /// If you provide an implementation for [`Self::return_field_from_args`],
589    /// DataFusion will not call `return_type` (this function). While it is
590    /// valid to put [`unimplemented!()`] or [`unreachable!()`], it is
591    /// recommended to return [`DataFusionError::Internal`] instead, which
592    /// reduces the severity of symptoms if bugs occur (an error rather than a
593    /// panic).
594    ///
595    /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal
596    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
597
598    /// Create a new instance of this function with updated configuration.
599    ///
600    /// This method is called when configuration options change at runtime
601    /// (e.g., via `SET` statements) to allow functions that depend on
602    /// configuration to update themselves accordingly.
603    ///
604    /// Note the current [`ConfigOptions`] are also passed to [`Self::invoke_with_args`] so
605    /// this API is not needed for functions where the values may
606    /// depend on the current options.
607    ///
608    /// This API is useful for functions where the return
609    /// **type** depends on the configuration options, such as the `now()` function
610    /// which depends on the current timezone.
611    ///
612    /// # Arguments
613    ///
614    /// * `config` - The updated configuration options
615    ///
616    /// # Returns
617    ///
618    /// * `Some(ScalarUDF)` - A new instance of this function configured with the new settings
619    /// * `None` - If this function does not change with new configuration settings (the default)
620    fn with_updated_config(&self, _config: &ConfigOptions) -> Option<ScalarUDF> {
621        None
622    }
623
624    /// What type will be returned by this function, given the arguments?
625    ///
626    /// By default, this function calls [`Self::return_type`] with the
627    /// types of each argument.
628    ///
629    /// # Notes
630    ///
631    /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient,
632    /// as the result type is typically a deterministic function of the input types
633    /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly
634    /// is generally unnecessary unless the return type depends on runtime values.
635    ///
636    /// This function can be used for more advanced cases such as:
637    ///
638    /// 1. specifying nullability
639    /// 2. return types based on the **values** of the arguments (rather than
640    ///    their **types**.
641    ///
642    /// # Example creating `Field`
643    ///
644    /// Note the name of the [`Field`] is ignored, except for structured types such as
645    /// `DataType::Struct`.
646    ///
647    /// ```rust
648    /// # use std::sync::Arc;
649    /// # use arrow::datatypes::{DataType, Field, FieldRef};
650    /// # use datafusion_common::Result;
651    /// # use datafusion_expr::ReturnFieldArgs;
652    /// # struct Example{}
653    /// # impl Example {
654    /// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
655    ///     // report output is only nullable if any one of the arguments are nullable
656    ///     let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
657    ///     let field = Arc::new(Field::new("ignored_name", DataType::Int32, nullable));
658    ///     Ok(field)
659    /// }
660    /// # }
661    /// ```
662    ///
663    /// # Output Type based on Values
664    ///
665    /// For example, the following two function calls get the same argument
666    /// types (something and a `Utf8` string) but return different types based
667    /// on the value of the second argument:
668    ///
669    /// * `arrow_cast(x, 'Int16')` --> `Int16`
670    /// * `arrow_cast(x, 'Float32')` --> `Float32`
671    ///
672    /// # Requirements
673    ///
674    /// This function **must** consistently return the same type for the same
675    /// logical input even if the input is simplified (e.g. it must return the same
676    /// value for `('foo' | 'bar')` as it does for ('foobar').
677    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
678        let data_types = args
679            .arg_fields
680            .iter()
681            .map(|f| f.data_type())
682            .cloned()
683            .collect::<Vec<_>>();
684        let return_type = self.return_type(&data_types)?;
685        Ok(Arc::new(Field::new(self.name(), return_type, true)))
686    }
687
688    #[deprecated(
689        since = "45.0.0",
690        note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error"
691    )]
692    fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
693        true
694    }
695
696    /// Invoke the function returning the appropriate result.
697    ///
698    /// # Performance
699    ///
700    /// For the best performance, the implementations should handle the common case
701    /// when one or more of their arguments are constant values (aka
702    /// [`ColumnarValue::Scalar`]).
703    ///
704    /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
705    /// to arrays, which will likely be simpler code, but be slower.
706    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue>;
707
708    /// Optionally apply per-UDF simplification / rewrite rules.
709    ///
710    /// This can be used to apply function specific simplification rules during
711    /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
712    /// implementation does nothing.
713    ///
714    /// Note that DataFusion handles simplifying arguments and  "constant
715    /// folding" (replacing a function call with constant arguments such as
716    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
717    /// optimizations manually for specific UDFs.
718    ///
719    /// # Arguments
720    /// * `args`: The arguments of the function
721    /// * `info`: The necessary information for simplification
722    ///
723    /// # Returns
724    /// [`ExprSimplifyResult`] indicating the result of the simplification NOTE
725    /// if the function cannot be simplified, the arguments *MUST* be returned
726    /// unmodified
727    ///
728    /// # Notes
729    ///
730    /// The returned expression must have the same schema as the original
731    /// expression, including both the data type and nullability. For example,
732    /// if the original expression is nullable, the returned expression must
733    /// also be nullable, otherwise it may lead to schema verification errors
734    /// later in query planning.
735    fn simplify(
736        &self,
737        args: Vec<Expr>,
738        _info: &SimplifyContext,
739    ) -> Result<ExprSimplifyResult> {
740        Ok(ExprSimplifyResult::Original(args))
741    }
742
743    /// Returns a single contiguous preimage for this function and the specified
744    /// scalar expression, if any.
745    ///
746    /// Currently only applies to `=, !=, >, >=, <, <=, is distinct from, is not distinct from` predicates
747    /// # Return Value
748    ///
749    /// Implementations should return a half-open interval: inclusive lower
750    /// bound and exclusive upper bound. This is slightly different from normal
751    /// [`Interval`] semantics where the upper bound is closed (inclusive).
752    /// Typically this means the upper endpoint must be adjusted to the next
753    /// value not included in the preimage. See the Half-Open Intervals section
754    /// below for more details.
755    ///
756    /// # Background
757    ///
758    /// Inspired by the [ClickHouse Paper], a "preimage rewrite" transforms a
759    /// predicate containing a function call into a predicate containing an
760    /// equivalent set of input literal (constant) values. The resulting
761    /// predicate can often be further optimized by other rewrites (see
762    /// Examples).
763    ///
764    /// From the paper:
765    ///
766    /// > some functions can compute the preimage of a given function result.
767    /// > This is used to replace comparisons of constants with function calls
768    /// > on the key columns by comparing the key column value with the preimage.
769    /// > For example, `toYear(k) = 2024` can be replaced by
770    /// > `k >= 2024-01-01 && k < 2025-01-01`
771    ///
772    /// For example, given an expression like
773    /// ```sql
774    /// date_part('YEAR', k) = 2024
775    /// ```
776    ///
777    /// The interval `[2024-01-01, 2025-12-31`]` contains all possible input
778    /// values (preimage values) for which the function `date_part(YEAR, k)`
779    /// produces the output value `2024` (image value). Returning the interval
780    /// (note upper bound adjusted up) `[2024-01-01, 2025-01-01]` the expression
781    /// can be rewritten to
782    ///
783    /// ```sql
784    /// k >= '2024-01-01' AND k < '2025-01-01'
785    /// ```
786    ///
787    /// which is a simpler and a more canonical form, making it easier for other
788    /// optimizer passes to recognize and apply further transformations.
789    ///
790    /// # Examples
791    ///
792    /// Case 1:
793    ///
794    /// Original:
795    /// ```sql
796    /// date_part('YEAR', k) = 2024 AND k >= '2024-06-01'
797    /// ```
798    ///
799    /// After preimage rewrite:
800    /// ```sql
801    /// k >= '2024-01-01' AND k < '2025-01-01' AND k >= '2024-06-01'
802    /// ```
803    ///
804    /// Since this form is much simpler, the optimizer can combine and simplify
805    /// sub-expressions further into:
806    /// ```sql
807    /// k >= '2024-06-01' AND k < '2025-01-01'
808    /// ```
809    ///
810    /// Case 2:
811    ///
812    /// For min/max pruning, simpler predicates such as:
813    /// ```sql
814    /// k >= '2024-01-01' AND k < '2025-01-01'
815    /// ```
816    /// are much easier for the pruner to reason about. See [PruningPredicate]
817    /// for the backgrounds of predicate pruning.
818    ///
819    /// The trade-off with the preimage rewrite is that evaluating the rewritten
820    /// form might be slightly more expensive than evaluating the original
821    /// expression. In practice, this cost is usually outweighed by the more
822    /// aggressive optimization opportunities it enables.
823    ///
824    /// # Half-Open Intervals
825    ///
826    /// The preimage API uses half-open intervals, which makes the rewrite
827    /// easier to implement by avoiding calculations to adjust the upper bound.
828    /// For example, if a function returns its input unchanged and the desired
829    /// output is the single value `5`, a closed interval could be represented
830    /// as `[5, 5]`, but then the rewrite would require adjusting the upper
831    /// bound to `6` to create a proper range predicate. With a half-open
832    /// interval, the same range is represented as `[5, 6)`, which already
833    /// forms a valid predicate.
834    ///
835    /// [PruningPredicate]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/pruning/struct.PruningPredicate.html
836    /// [ClickHouse Paper]:  https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf
837    /// [image]: https://en.wikipedia.org/wiki/Image_(mathematics)#Image_of_an_element
838    /// [preimage]: https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image
839    fn preimage(
840        &self,
841        _args: &[Expr],
842        _lit_expr: &Expr,
843        _info: &SimplifyContext,
844    ) -> Result<PreimageResult> {
845        Ok(PreimageResult::None)
846    }
847
848    /// Returns true if some of this `exprs` subexpressions may not be evaluated
849    /// and thus any side effects (like divide by zero) may not be encountered.
850    ///
851    /// Setting this to true prevents certain optimizations such as common
852    /// subexpression elimination
853    ///
854    /// When overriding this function to return `true`, [ScalarUDFImpl::conditional_arguments] can also be
855    /// overridden to report more accurately which arguments are eagerly evaluated and which ones
856    /// lazily.
857    fn short_circuits(&self) -> bool {
858        false
859    }
860
861    /// Determines which of the arguments passed to this function are evaluated eagerly
862    /// and which may be evaluated lazily.
863    ///
864    /// If this function returns `None`, all arguments are eagerly evaluated.
865    /// Returning `None` is a micro optimization that saves a needless `Vec`
866    /// allocation.
867    ///
868    /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager`
869    /// are the arguments that are always evaluated, and `lazy` are the
870    /// arguments that may be evaluated lazily (i.e. may not be evaluated at all
871    /// in some cases).
872    ///
873    /// Implementations must ensure that the two returned `Vec`s are disjunct,
874    /// and that each argument from `args` is present in one the two `Vec`s.
875    ///
876    /// When overriding this function, [ScalarUDFImpl::short_circuits] must
877    /// be overridden to return `true`.
878    fn conditional_arguments<'a>(
879        &self,
880        args: &'a [Expr],
881    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
882        if self.short_circuits() {
883            Some((vec![], args.iter().collect()))
884        } else {
885            None
886        }
887    }
888
889    /// Computes the output [`Interval`] for a [`ScalarUDFImpl`], given the input
890    /// intervals.
891    ///
892    /// # Parameters
893    ///
894    /// * `children` are the intervals for the children (inputs) of this function.
895    ///
896    /// # Example
897    ///
898    /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
899    /// then the output interval would be `[0, 3]`.
900    fn evaluate_bounds(&self, _input: &[&Interval]) -> Result<Interval> {
901        // We cannot assume the input datatype is the same of output type.
902        Interval::make_unbounded(&DataType::Null)
903    }
904
905    /// Updates bounds for child expressions, given a known [`Interval`]s for this
906    /// function.
907    ///
908    /// This function is used to propagate constraints down through an
909    /// expression tree.
910    ///
911    /// # Parameters
912    ///
913    /// * `interval` is the currently known interval for this function.
914    /// * `inputs` are the current intervals for the inputs (children) of this function.
915    ///
916    /// # Returns
917    ///
918    /// A `Vec` of new intervals for the children, in order.
919    ///
920    /// If constraint propagation reveals an infeasibility for any child, returns
921    /// [`None`]. If none of the children intervals change as a result of
922    /// propagation, may return an empty vector instead of cloning `children`.
923    /// This is the default (and conservative) return value.
924    ///
925    /// # Example
926    ///
927    /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
928    /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
929    fn propagate_constraints(
930        &self,
931        _interval: &Interval,
932        _inputs: &[&Interval],
933    ) -> Result<Option<Vec<Interval>>> {
934        Ok(Some(vec![]))
935    }
936
937    /// Calculates the [`SortProperties`] of this function based on its children's properties.
938    fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
939        if !self.preserves_lex_ordering(inputs)? {
940            return Ok(SortProperties::Unordered);
941        }
942
943        let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else {
944            return Ok(SortProperties::Singleton);
945        };
946
947        if inputs
948            .iter()
949            .skip(1)
950            .all(|input| &input.sort_properties == first_order)
951        {
952            Ok(*first_order)
953        } else {
954            Ok(SortProperties::Unordered)
955        }
956    }
957
958    /// Returns true if the function preserves lexicographical ordering based on
959    /// the input ordering.
960    ///
961    /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not.
962    fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
963        Ok(false)
964    }
965
966    /// Coerce arguments of a function call to types that the function can evaluate.
967    ///
968    /// This function is only called if [`ScalarUDFImpl::signature`] returns
969    /// [`crate::TypeSignature::UserDefined`]. Most UDFs should return one of
970    /// the other variants of [`TypeSignature`] which handle common cases.
971    ///
972    /// See the [type coercion module](crate::type_coercion)
973    /// documentation for more details on type coercion
974    ///
975    /// [`TypeSignature`]: crate::TypeSignature
976    ///
977    /// For example, if your function requires a floating point arguments, but the user calls
978    /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]`
979    /// to ensure the argument is converted to `1::double`
980    ///
981    /// # Parameters
982    /// * `arg_types`: The argument types of the arguments  this function with
983    ///
984    /// # Return value
985    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
986    /// arguments to these specific types.
987    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
988        not_impl_err!("Function {} does not implement coerce_types", self.name())
989    }
990
991    /// For struct-producing functions, return how output fields map to input
992    /// arguments. This enables the optimizer to propagate orderings through
993    /// struct projections.
994    ///
995    /// `literal_args[i]` is `Some(value)` if argument `i` is a known literal,
996    /// allowing extraction of field names from arguments like
997    /// `named_struct('field_name', value, ...)`.
998    ///
999    /// For example, `named_struct('a', col1, 'b', col2)` would return a
1000    /// mapping indicating that output field `'a'` (accessed via
1001    /// `get_field(output, 'a')`) corresponds to input argument `col1` at
1002    /// index 1, and field `'b'` corresponds to `col2` at index 3.
1003    fn struct_field_mapping(
1004        &self,
1005        _literal_args: &[Option<ScalarValue>],
1006    ) -> Option<StructFieldMapping> {
1007        None
1008    }
1009
1010    /// Returns the documentation for this Scalar UDF.
1011    ///
1012    /// Documentation can be accessed programmatically as well as generating
1013    /// publicly facing documentation.
1014    fn documentation(&self) -> Option<&Documentation> {
1015        None
1016    }
1017
1018    /// Returns placement information for this function.
1019    ///
1020    /// This is used by optimizers to make decisions about expression placement,
1021    /// such as whether to push expressions down through projections.
1022    ///
1023    /// The default implementation returns [`ExpressionPlacement::KeepInPlace`],
1024    /// meaning the expression should be kept where it is in the plan.
1025    ///
1026    /// Override this method to indicate that the function can be pushed down
1027    /// closer to the data source.
1028    fn placement(&self, _args: &[ExpressionPlacement]) -> ExpressionPlacement {
1029        ExpressionPlacement::KeepInPlace
1030    }
1031}
1032
1033impl dyn ScalarUDFImpl {
1034    /// Returns `true` if the implementation is of type `T`.
1035    ///
1036    /// Works correctly when called on `Arc<dyn ScalarUDFImpl>` via auto-deref.
1037    pub fn is<T: ScalarUDFImpl>(&self) -> bool {
1038        (self as &dyn Any).is::<T>()
1039    }
1040
1041    /// Attempts to downcast to a concrete type `T`, returning `None` if the
1042    /// implementation is not of that type.
1043    ///
1044    /// Works correctly when called on `Arc<dyn ScalarUDFImpl>` via auto-deref,
1045    /// unlike `(&arc as &dyn Any).downcast_ref::<T>()` which would attempt to
1046    /// downcast the `Arc` itself.
1047    pub fn downcast_ref<T: ScalarUDFImpl>(&self) -> Option<&T> {
1048        (self as &dyn Any).downcast_ref()
1049    }
1050}
1051
1052/// ScalarUDF that adds an alias to the underlying function. It is better to
1053/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
1054#[derive(Debug, PartialEq, Eq, Hash)]
1055struct AliasedScalarUDFImpl {
1056    inner: UdfEq<Arc<dyn ScalarUDFImpl>>,
1057    aliases: Vec<String>,
1058}
1059
1060impl AliasedScalarUDFImpl {
1061    pub fn new(
1062        inner: Arc<dyn ScalarUDFImpl>,
1063        new_aliases: impl IntoIterator<Item = &'static str>,
1064    ) -> Self {
1065        let mut aliases = inner.aliases().to_vec();
1066        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1067        Self {
1068            inner: inner.into(),
1069            aliases,
1070        }
1071    }
1072}
1073
1074#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method
1075impl ScalarUDFImpl for AliasedScalarUDFImpl {
1076    fn name(&self) -> &str {
1077        self.inner.name()
1078    }
1079
1080    fn display_name(&self, args: &[Expr]) -> Result<String> {
1081        #[expect(deprecated)]
1082        self.inner.display_name(args)
1083    }
1084
1085    fn schema_name(&self, args: &[Expr]) -> Result<String> {
1086        self.inner.schema_name(args)
1087    }
1088
1089    fn signature(&self) -> &Signature {
1090        self.inner.signature()
1091    }
1092
1093    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1094        self.inner.return_type(arg_types)
1095    }
1096
1097    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
1098        self.inner.return_field_from_args(args)
1099    }
1100
1101    fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
1102        #[expect(deprecated)]
1103        self.inner.is_nullable(args, schema)
1104    }
1105
1106    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1107        self.inner.invoke_with_args(args)
1108    }
1109
1110    fn with_updated_config(&self, _config: &ConfigOptions) -> Option<ScalarUDF> {
1111        None
1112    }
1113
1114    fn aliases(&self) -> &[String] {
1115        &self.aliases
1116    }
1117
1118    fn simplify(
1119        &self,
1120        args: Vec<Expr>,
1121        info: &SimplifyContext,
1122    ) -> Result<ExprSimplifyResult> {
1123        self.inner.simplify(args, info)
1124    }
1125
1126    fn preimage(
1127        &self,
1128        args: &[Expr],
1129        lit_expr: &Expr,
1130        info: &SimplifyContext,
1131    ) -> Result<PreimageResult> {
1132        self.inner.preimage(args, lit_expr, info)
1133    }
1134
1135    fn conditional_arguments<'a>(
1136        &self,
1137        args: &'a [Expr],
1138    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
1139        self.inner.conditional_arguments(args)
1140    }
1141
1142    fn short_circuits(&self) -> bool {
1143        self.inner.short_circuits()
1144    }
1145
1146    fn evaluate_bounds(&self, input: &[&Interval]) -> Result<Interval> {
1147        self.inner.evaluate_bounds(input)
1148    }
1149
1150    fn propagate_constraints(
1151        &self,
1152        interval: &Interval,
1153        inputs: &[&Interval],
1154    ) -> Result<Option<Vec<Interval>>> {
1155        self.inner.propagate_constraints(interval, inputs)
1156    }
1157
1158    fn struct_field_mapping(
1159        &self,
1160        literal_args: &[Option<ScalarValue>],
1161    ) -> Option<StructFieldMapping> {
1162        self.inner.struct_field_mapping(literal_args)
1163    }
1164
1165    fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
1166        self.inner.output_ordering(inputs)
1167    }
1168
1169    fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
1170        self.inner.preserves_lex_ordering(inputs)
1171    }
1172
1173    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1174        self.inner.coerce_types(arg_types)
1175    }
1176
1177    fn documentation(&self) -> Option<&Documentation> {
1178        self.inner.documentation()
1179    }
1180
1181    fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement {
1182        self.inner.placement(args)
1183    }
1184}
1185
1186#[cfg(test)]
1187mod tests {
1188    use super::*;
1189    use datafusion_expr_common::signature::Volatility;
1190    use std::hash::DefaultHasher;
1191
1192    #[derive(Debug, PartialEq, Eq, Hash)]
1193    struct TestScalarUDFImpl {
1194        name: &'static str,
1195        field: &'static str,
1196        signature: Signature,
1197    }
1198    impl ScalarUDFImpl for TestScalarUDFImpl {
1199        fn name(&self) -> &str {
1200            self.name
1201        }
1202
1203        fn signature(&self) -> &Signature {
1204            &self.signature
1205        }
1206
1207        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1208            unimplemented!()
1209        }
1210
1211        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1212            unimplemented!()
1213        }
1214    }
1215
1216    // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd
1217    // must be consistent, so they are tested together.
1218    #[test]
1219    fn test_partial_eq_hash_and_partial_ord() {
1220        // A parameterized function
1221        let f = test_func("foo", "a");
1222
1223        // Same like `f`, different instance
1224        let f2 = test_func("foo", "a");
1225        assert_eq!(f, f2);
1226        assert_eq!(hash(&f), hash(&f2));
1227        assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal));
1228
1229        // Different parameter
1230        let b = test_func("foo", "b");
1231        assert_ne!(f, b);
1232        assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test
1233        assert_eq!(f.partial_cmp(&b), None);
1234
1235        // Different name
1236        let o = test_func("other", "a");
1237        assert_ne!(f, o);
1238        assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test
1239        assert_eq!(f.partial_cmp(&o), Some(Ordering::Less));
1240
1241        // Different name and parameter
1242        assert_ne!(b, o);
1243        assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test
1244        assert_eq!(b.partial_cmp(&o), Some(Ordering::Less));
1245    }
1246
1247    fn test_func(name: &'static str, parameter: &'static str) -> ScalarUDF {
1248        ScalarUDF::from(TestScalarUDFImpl {
1249            name,
1250            field: parameter,
1251            signature: Signature::any(1, Volatility::Immutable),
1252        })
1253    }
1254
1255    fn hash<T: Hash>(value: &T) -> u64 {
1256        let hasher = &mut DefaultHasher::new();
1257        value.hash(hasher);
1258        hasher.finish()
1259    }
1260}