datafusion_expr/
udf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDF`]: Scalar User Defined Functions
19
20use crate::async_udf::AsyncScalarUDF;
21use crate::expr::schema_name_from_exprs_comma_separated_without_space;
22use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
23use crate::sort_properties::{ExprProperties, SortProperties};
24use crate::udf_eq::UdfEq;
25use crate::{ColumnarValue, Documentation, Expr, Signature};
26use arrow::datatypes::{DataType, Field, FieldRef};
27use datafusion_common::config::ConfigOptions;
28use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue};
29use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
30use datafusion_expr_common::interval_arithmetic::Interval;
31use std::any::Any;
32use std::cmp::Ordering;
33use std::fmt::Debug;
34use std::hash::{Hash, Hasher};
35use std::sync::Arc;
36
/// Logical representation of a Scalar User Defined Function.
///
/// A scalar function produces a single row output for each row of input. This
/// struct contains the information DataFusion needs to plan and invoke
/// functions you supply such as name, type signature, return type, and actual
/// implementation.
///
/// 1. For simple use cases, use [`create_udf`] (examples in [`simple_udf.rs`]).
///
/// 2. For advanced use cases, use [`ScalarUDFImpl`] which provides full API
///    access (examples in [`advanced_udf.rs`]).
///
/// See [`Self::call`] to create an `Expr` which invokes a `ScalarUDF` with arguments.
///
/// # API Note
///
/// This is a separate struct from [`ScalarUDFImpl`] to maintain backwards
/// compatibility with the older API.
///
/// [`create_udf`]: crate::expr_fn::create_udf
/// [`simple_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
#[derive(Debug, Clone)]
pub struct ScalarUDF {
    /// The shared [`ScalarUDFImpl`] trait object this wrapper delegates to.
    inner: Arc<dyn ScalarUDFImpl>,
}
63
64impl PartialEq for ScalarUDF {
65    fn eq(&self, other: &Self) -> bool {
66        self.inner.dyn_eq(other.inner.as_any())
67    }
68}
69
70impl PartialOrd for ScalarUDF {
71    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
72        let mut cmp = self.name().cmp(other.name());
73        if cmp == Ordering::Equal {
74            cmp = self.signature().partial_cmp(other.signature())?;
75        }
76        if cmp == Ordering::Equal {
77            cmp = self.aliases().partial_cmp(other.aliases())?;
78        }
79        // Contract for PartialOrd and PartialEq consistency requires that
80        // a == b if and only if partial_cmp(a, b) == Some(Equal).
81        if cmp == Ordering::Equal && self != other {
82            // Functions may have other properties besides name and signature
83            // that differentiate two instances (e.g. type, or arbitrary parameters).
84            // We cannot return Some(Equal) in such case.
85            return None;
86        }
87        debug_assert!(
88            cmp == Ordering::Equal || self != other,
89            "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
90            The functions compare as equal, but they are not equal based on general properties that \
91            the PartialOrd implementation observes,",
92            self.name(), other.name()
93        );
94        Some(cmp)
95    }
96}
97
// Marker impl only: the comparison itself is defined by `PartialEq for ScalarUDF`.
impl Eq for ScalarUDF {}
99
100impl Hash for ScalarUDF {
101    fn hash<H: Hasher>(&self, state: &mut H) {
102        self.inner.dyn_hash(state)
103    }
104}
105
106impl ScalarUDF {
107    /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
108    ///
109    /// Note this is the same as using the `From` impl (`ScalarUDF::from`)
110    pub fn new_from_impl<F>(fun: F) -> ScalarUDF
111    where
112        F: ScalarUDFImpl + 'static,
113    {
114        Self::new_from_shared_impl(Arc::new(fun))
115    }
116
117    /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
118    pub fn new_from_shared_impl(fun: Arc<dyn ScalarUDFImpl>) -> ScalarUDF {
119        Self { inner: fun }
120    }
121
122    /// Return the underlying [`ScalarUDFImpl`] trait object for this function
123    pub fn inner(&self) -> &Arc<dyn ScalarUDFImpl> {
124        &self.inner
125    }
126
127    /// Adds additional names that can be used to invoke this function, in
128    /// addition to `name`
129    ///
130    /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly.
131    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
132        Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases))
133    }
134
135    /// Returns a [`Expr`] logical expression to call this UDF with specified
136    /// arguments.
137    ///
138    /// This utility allows easily calling UDFs
139    ///
140    /// # Example
141    /// ```no_run
142    /// use datafusion_expr::{col, lit, ScalarUDF};
143    /// # fn my_udf() -> ScalarUDF { unimplemented!() }
144    /// let my_func: ScalarUDF = my_udf();
145    /// // Create an expr for `my_func(a, 12.3)`
146    /// let expr = my_func.call(vec![col("a"), lit(12.3)]);
147    /// ```
148    pub fn call(&self, args: Vec<Expr>) -> Expr {
149        Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
150            Arc::new(self.clone()),
151            args,
152        ))
153    }
154
155    /// Returns this function's name.
156    ///
157    /// See [`ScalarUDFImpl::name`] for more details.
158    pub fn name(&self) -> &str {
159        self.inner.name()
160    }
161
162    /// Returns this function's display_name.
163    ///
164    /// See [`ScalarUDFImpl::display_name`] for more details
165    #[deprecated(
166        since = "50.0.0",
167        note = "This method is unused and will be removed in a future release"
168    )]
169    pub fn display_name(&self, args: &[Expr]) -> Result<String> {
170        #[expect(deprecated)]
171        self.inner.display_name(args)
172    }
173
174    /// Returns this function's schema_name.
175    ///
176    /// See [`ScalarUDFImpl::schema_name`] for more details
177    pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
178        self.inner.schema_name(args)
179    }
180
181    /// Returns the aliases for this function.
182    ///
183    /// See [`ScalarUDF::with_aliases`] for more details
184    pub fn aliases(&self) -> &[String] {
185        self.inner.aliases()
186    }
187
188    /// Returns this function's [`Signature`] (what input types are accepted).
189    ///
190    /// See [`ScalarUDFImpl::signature`] for more details.
191    pub fn signature(&self) -> &Signature {
192        self.inner.signature()
193    }
194
195    /// The datatype this function returns given the input argument types.
196    /// This function is used when the input arguments are [`DataType`]s.
197    ///
198    ///  # Notes
199    ///
200    /// If a function implement [`ScalarUDFImpl::return_field_from_args`],
201    /// its [`ScalarUDFImpl::return_type`] should raise an error.
202    ///
203    /// See [`ScalarUDFImpl::return_type`] for more details.
204    pub fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
205        self.inner.return_type(arg_types)
206    }
207
208    /// Return the datatype this function returns given the input argument types.
209    ///
210    /// See [`ScalarUDFImpl::return_field_from_args`] for more details.
211    pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
212        self.inner.return_field_from_args(args)
213    }
214
215    /// Do the function rewrite
216    ///
217    /// See [`ScalarUDFImpl::simplify`] for more details.
218    pub fn simplify(
219        &self,
220        args: Vec<Expr>,
221        info: &dyn SimplifyInfo,
222    ) -> Result<ExprSimplifyResult> {
223        self.inner.simplify(args, info)
224    }
225
226    #[deprecated(since = "50.0.0", note = "Use `return_field_from_args` instead.")]
227    pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
228        #[allow(deprecated)]
229        self.inner.is_nullable(args, schema)
230    }
231
232    /// Invoke the function on `args`, returning the appropriate result.
233    ///
234    /// See [`ScalarUDFImpl::invoke_with_args`] for details.
235    pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
236        self.inner.invoke_with_args(args)
237    }
238
239    /// Get the circuits of inner implementation
240    pub fn short_circuits(&self) -> bool {
241        self.inner.short_circuits()
242    }
243
244    /// Computes the output interval for a [`ScalarUDF`], given the input
245    /// intervals.
246    ///
247    /// # Parameters
248    ///
249    /// * `inputs` are the intervals for the inputs (children) of this function.
250    ///
251    /// # Example
252    ///
253    /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
254    /// then the output interval would be `[0, 3]`.
255    pub fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
256        self.inner.evaluate_bounds(inputs)
257    }
258
259    /// Updates bounds for child expressions, given a known interval for this
260    /// function. This is used to propagate constraints down through an expression
261    /// tree.
262    ///
263    /// # Parameters
264    ///
265    /// * `interval` is the currently known interval for this function.
266    /// * `inputs` are the current intervals for the inputs (children) of this function.
267    ///
268    /// # Returns
269    ///
270    /// A `Vec` of new intervals for the children, in order.
271    ///
272    /// If constraint propagation reveals an infeasibility for any child, returns
273    /// [`None`]. If none of the children intervals change as a result of
274    /// propagation, may return an empty vector instead of cloning `children`.
275    /// This is the default (and conservative) return value.
276    ///
277    /// # Example
278    ///
279    /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
280    /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
281    pub fn propagate_constraints(
282        &self,
283        interval: &Interval,
284        inputs: &[&Interval],
285    ) -> Result<Option<Vec<Interval>>> {
286        self.inner.propagate_constraints(interval, inputs)
287    }
288
289    /// Calculates the [`SortProperties`] of this function based on its
290    /// children's properties.
291    pub fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
292        self.inner.output_ordering(inputs)
293    }
294
295    pub fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
296        self.inner.preserves_lex_ordering(inputs)
297    }
298
299    /// See [`ScalarUDFImpl::coerce_types`] for more details.
300    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
301        self.inner.coerce_types(arg_types)
302    }
303
304    /// Returns the documentation for this Scalar UDF.
305    ///
306    /// Documentation can be accessed programmatically as well as
307    /// generating publicly facing documentation.
308    pub fn documentation(&self) -> Option<&Documentation> {
309        self.inner.documentation()
310    }
311
312    /// Return true if this function is an async function
313    pub fn as_async(&self) -> Option<&AsyncScalarUDF> {
314        self.inner().as_any().downcast_ref::<AsyncScalarUDF>()
315    }
316}
317
318impl<F> From<F> for ScalarUDF
319where
320    F: ScalarUDFImpl + 'static,
321{
322    fn from(fun: F) -> Self {
323        Self::new_from_impl(fun)
324    }
325}
326
/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a
/// scalar function.
#[derive(Debug, Clone)]
pub struct ScalarFunctionArgs {
    /// The evaluated arguments to the function
    pub args: Vec<ColumnarValue>,
    /// Field associated with each arg, if it exists
    pub arg_fields: Vec<FieldRef>,
    /// The number of rows in the record batch being evaluated
    pub number_rows: usize,
    /// The return field of the scalar function, as returned from `return_type`
    /// or `return_field_from_args` when the physical expression was created
    /// from the logical expression
    pub return_field: FieldRef,
    /// The config options at execution time
    pub config_options: Arc<ConfigOptions>,
}
344
345impl ScalarFunctionArgs {
346    /// The return type of the function. See [`Self::return_field`] for more
347    /// details.
348    pub fn return_type(&self) -> &DataType {
349        self.return_field.data_type()
350    }
351}
352
/// Information about arguments passed to the function
///
/// This structure contains metadata about how the function was called
/// such as the type of the arguments, any scalar arguments and if the
/// arguments can (ever) be null
///
/// See [`ScalarUDFImpl::return_field_from_args`] for more information
#[derive(Debug)]
pub struct ReturnFieldArgs<'a> {
    /// The fields (carrying the data type and nullability) of the arguments
    /// to the function
    pub arg_fields: &'a [FieldRef],
    /// Is argument `i` to the function a scalar (constant)?
    ///
    /// If the argument `i` is not a scalar, it will be None
    ///
    /// For example, if a function is called like `my_function(column_a, 5)`
    /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]`
    pub scalar_arguments: &'a [Option<&'a ScalarValue>],
}
372
373/// Trait for implementing user defined scalar functions.
374///
375/// This trait exposes the full API for implementing user defined functions and
376/// can be used to implement any function.
377///
378/// See [`advanced_udf.rs`] for a full example with complete implementation and
379/// [`ScalarUDF`] for other available options.
380///
381/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
382///
383/// # Basic Example
384/// ```
385/// # use std::any::Any;
386/// # use std::sync::LazyLock;
387/// # use arrow::datatypes::DataType;
388/// # use datafusion_common::{DataFusionError, plan_err, Result};
389/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
390/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
391/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
392/// /// This struct for a simple UDF that adds one to an int32
393/// #[derive(Debug, PartialEq, Eq, Hash)]
394/// struct AddOne {
395///   signature: Signature,
396/// }
397///
398/// impl AddOne {
399///   fn new() -> Self {
400///     Self {
401///       signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable),
402///      }
403///   }
404/// }
405///
406/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
407///         Documentation::builder(DOC_SECTION_MATH, "Add one to an int32", "add_one(2)")
408///             .with_argument("arg1", "The int32 number to add one to")
409///             .build()
410///     });
411///
412/// fn get_doc() -> &'static Documentation {
413///     &DOCUMENTATION
414/// }
415///
416/// /// Implement the ScalarUDFImpl trait for AddOne
417/// impl ScalarUDFImpl for AddOne {
418///    fn as_any(&self) -> &dyn Any { self }
419///    fn name(&self) -> &str { "add_one" }
420///    fn signature(&self) -> &Signature { &self.signature }
421///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
422///      if !matches!(args.get(0), Some(&DataType::Int32)) {
423///        return plan_err!("add_one only accepts Int32 arguments");
424///      }
425///      Ok(DataType::Int32)
426///    }
427///    // The actual implementation would add one to the argument
428///    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
429///         unimplemented!()
430///    }
431///    fn documentation(&self) -> Option<&Documentation> {
432///         Some(get_doc())
433///     }
434/// }
435///
436/// // Create a new ScalarUDF from the implementation
437/// let add_one = ScalarUDF::from(AddOne::new());
438///
439/// // Call the function `add_one(col)`
440/// let expr = add_one.call(vec![col("a")]);
441/// ```
442pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
443    /// Returns this object as an [`Any`] trait object
444    fn as_any(&self) -> &dyn Any;
445
446    /// Returns this function's name
447    fn name(&self) -> &str;
448
449    /// Returns any aliases (alternate names) for this function.
450    ///
451    /// Aliases can be used to invoke the same function using different names.
452    /// For example in some databases `now()` and `current_timestamp()` are
453    /// aliases for the same function. This behavior can be obtained by
454    /// returning `current_timestamp` as an alias for the `now` function.
455    ///
456    /// Note: `aliases` should only include names other than [`Self::name`].
457    /// Defaults to `[]` (no aliases)
458    fn aliases(&self) -> &[String] {
459        &[]
460    }
461
462    /// Returns the user-defined display name of function, given the arguments
463    ///
464    /// This can be used to customize the output column name generated by this
465    /// function.
466    ///
467    /// Defaults to `name(args[0], args[1], ...)`
468    #[deprecated(
469        since = "50.0.0",
470        note = "This method is unused and will be removed in a future release"
471    )]
472    fn display_name(&self, args: &[Expr]) -> Result<String> {
473        let names: Vec<String> = args.iter().map(ToString::to_string).collect();
474        // TODO: join with ", " to standardize the formatting of Vec<Expr>, <https://github.com/apache/datafusion/issues/10364>
475        Ok(format!("{}({})", self.name(), names.join(",")))
476    }
477
478    /// Returns the name of the column this expression would create
479    ///
480    /// See [`Expr::schema_name`] for details
481    fn schema_name(&self, args: &[Expr]) -> Result<String> {
482        Ok(format!(
483            "{}({})",
484            self.name(),
485            schema_name_from_exprs_comma_separated_without_space(args)?
486        ))
487    }
488
489    /// Returns a [`Signature`] describing the argument types for which this
490    /// function has an implementation, and the function's [`Volatility`].
491    ///
492    /// See [`Signature`] for more details on argument type handling
493    /// and [`Self::return_type`] for computing the return type.
494    ///
495    /// [`Volatility`]: datafusion_expr_common::signature::Volatility
496    fn signature(&self) -> &Signature;
497
498    /// [`DataType`] returned by this function, given the types of the
499    /// arguments.
500    ///
501    /// # Arguments
502    ///
503    /// `arg_types` Data types of the arguments. The implementation of
504    /// `return_type` can assume that some other part of the code has coerced
505    /// the actual argument types to match [`Self::signature`].
506    ///
507    /// # Notes
508    ///
509    /// If you provide an implementation for [`Self::return_field_from_args`],
510    /// DataFusion will not call `return_type` (this function). While it is
511    /// valid to to put [`unimplemented!()`] or [`unreachable!()`], it is
512    /// recommended to return [`DataFusionError::Internal`] instead, which
513    /// reduces the severity of symptoms if bugs occur (an error rather than a
514    /// panic).
515    ///
516    /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal
517    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
518
519    /// What type will be returned by this function, given the arguments?
520    ///
521    /// By default, this function calls [`Self::return_type`] with the
522    /// types of each argument.
523    ///
524    /// # Notes
525    ///
526    /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient,
527    /// as the result type is typically a deterministic function of the input types
528    /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly
529    /// is generally unnecessary unless the return type depends on runtime values.
530    ///
531    /// This function can be used for more advanced cases such as:
532    ///
533    /// 1. specifying nullability
534    /// 2. return types based on the **values** of the arguments (rather than
535    ///    their **types**.
536    ///
537    /// # Example creating `Field`
538    ///
539    /// Note the name of the [`Field`] is ignored, except for structured types such as
540    /// `DataType::Struct`.
541    ///
542    /// ```rust
543    /// # use std::sync::Arc;
544    /// # use arrow::datatypes::{DataType, Field, FieldRef};
545    /// # use datafusion_common::Result;
546    /// # use datafusion_expr::ReturnFieldArgs;
547    /// # struct Example{}
548    /// # impl Example {
549    /// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
550    ///   // report output is only nullable if any one of the arguments are nullable
551    ///   let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
552    ///   let field = Arc::new(Field::new("ignored_name", DataType::Int32, true));
553    ///   Ok(field)
554    /// }
555    /// # }
556    /// ```
557    ///
558    /// # Output Type based on Values
559    ///
560    /// For example, the following two function calls get the same argument
561    /// types (something and a `Utf8` string) but return different types based
562    /// on the value of the second argument:
563    ///
564    /// * `arrow_cast(x, 'Int16')` --> `Int16`
565    /// * `arrow_cast(x, 'Float32')` --> `Float32`
566    ///
567    /// # Requirements
568    ///
569    /// This function **must** consistently return the same type for the same
570    /// logical input even if the input is simplified (e.g. it must return the same
571    /// value for `('foo' | 'bar')` as it does for ('foobar').
572    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
573        let data_types = args
574            .arg_fields
575            .iter()
576            .map(|f| f.data_type())
577            .cloned()
578            .collect::<Vec<_>>();
579        let return_type = self.return_type(&data_types)?;
580        Ok(Arc::new(Field::new(self.name(), return_type, true)))
581    }
582
583    #[deprecated(
584        since = "45.0.0",
585        note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error"
586    )]
587    fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
588        true
589    }
590
591    /// Invoke the function returning the appropriate result.
592    ///
593    /// # Performance
594    ///
595    /// For the best performance, the implementations should handle the common case
596    /// when one or more of their arguments are constant values (aka
597    /// [`ColumnarValue::Scalar`]).
598    ///
599    /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
600    /// to arrays, which will likely be simpler code, but be slower.
601    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue>;
602
603    /// Optionally apply per-UDF simplification / rewrite rules.
604    ///
605    /// This can be used to apply function specific simplification rules during
606    /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
607    /// implementation does nothing.
608    ///
609    /// Note that DataFusion handles simplifying arguments and  "constant
610    /// folding" (replacing a function call with constant arguments such as
611    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
612    /// optimizations manually for specific UDFs.
613    ///
614    /// # Arguments
615    /// * `args`: The arguments of the function
616    /// * `info`: The necessary information for simplification
617    ///
618    /// # Returns
619    /// [`ExprSimplifyResult`] indicating the result of the simplification NOTE
620    /// if the function cannot be simplified, the arguments *MUST* be returned
621    /// unmodified
622    fn simplify(
623        &self,
624        args: Vec<Expr>,
625        _info: &dyn SimplifyInfo,
626    ) -> Result<ExprSimplifyResult> {
627        Ok(ExprSimplifyResult::Original(args))
628    }
629
630    /// Returns true if some of this `exprs` subexpressions may not be evaluated
631    /// and thus any side effects (like divide by zero) may not be encountered.
632    ///
633    /// Setting this to true prevents certain optimizations such as common
634    /// subexpression elimination
635    fn short_circuits(&self) -> bool {
636        false
637    }
638
639    /// Computes the output [`Interval`] for a [`ScalarUDFImpl`], given the input
640    /// intervals.
641    ///
642    /// # Parameters
643    ///
644    /// * `children` are the intervals for the children (inputs) of this function.
645    ///
646    /// # Example
647    ///
648    /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
649    /// then the output interval would be `[0, 3]`.
650    fn evaluate_bounds(&self, _input: &[&Interval]) -> Result<Interval> {
651        // We cannot assume the input datatype is the same of output type.
652        Interval::make_unbounded(&DataType::Null)
653    }
654
655    /// Updates bounds for child expressions, given a known [`Interval`]s for this
656    /// function.
657    ///
658    /// This function is used to propagate constraints down through an
659    /// expression tree.
660    ///
661    /// # Parameters
662    ///
663    /// * `interval` is the currently known interval for this function.
664    /// * `inputs` are the current intervals for the inputs (children) of this function.
665    ///
666    /// # Returns
667    ///
668    /// A `Vec` of new intervals for the children, in order.
669    ///
670    /// If constraint propagation reveals an infeasibility for any child, returns
671    /// [`None`]. If none of the children intervals change as a result of
672    /// propagation, may return an empty vector instead of cloning `children`.
673    /// This is the default (and conservative) return value.
674    ///
675    /// # Example
676    ///
677    /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
678    /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
679    fn propagate_constraints(
680        &self,
681        _interval: &Interval,
682        _inputs: &[&Interval],
683    ) -> Result<Option<Vec<Interval>>> {
684        Ok(Some(vec![]))
685    }
686
687    /// Calculates the [`SortProperties`] of this function based on its children's properties.
688    fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
689        if !self.preserves_lex_ordering(inputs)? {
690            return Ok(SortProperties::Unordered);
691        }
692
693        let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else {
694            return Ok(SortProperties::Singleton);
695        };
696
697        if inputs
698            .iter()
699            .skip(1)
700            .all(|input| &input.sort_properties == first_order)
701        {
702            Ok(*first_order)
703        } else {
704            Ok(SortProperties::Unordered)
705        }
706    }
707
708    /// Returns true if the function preserves lexicographical ordering based on
709    /// the input ordering.
710    ///
711    /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not.
712    fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
713        Ok(false)
714    }
715
716    /// Coerce arguments of a function call to types that the function can evaluate.
717    ///
718    /// This function is only called if [`ScalarUDFImpl::signature`] returns
719    /// [`crate::TypeSignature::UserDefined`]. Most UDFs should return one of
720    /// the other variants of [`TypeSignature`] which handle common cases.
721    ///
722    /// See the [type coercion module](crate::type_coercion)
723    /// documentation for more details on type coercion
724    ///
725    /// [`TypeSignature`]: crate::TypeSignature
726    ///
727    /// For example, if your function requires a floating point arguments, but the user calls
728    /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]`
729    /// to ensure the argument is converted to `1::double`
730    ///
731    /// # Parameters
732    /// * `arg_types`: The argument types of the arguments  this function with
733    ///
734    /// # Return value
735    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
736    /// arguments to these specific types.
737    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
738        not_impl_err!("Function {} does not implement coerce_types", self.name())
739    }
740
741    /// Returns the documentation for this Scalar UDF.
742    ///
743    /// Documentation can be accessed programmatically as well as generating
744    /// publicly facing documentation.
745    fn documentation(&self) -> Option<&Documentation> {
746        None
747    }
748}
749
/// ScalarUDF that adds an alias to the underlying function. It is better to
/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
#[derive(Debug, PartialEq, Eq, Hash)]
struct AliasedScalarUDFImpl {
    /// The wrapped function that every call other than `aliases` is
    /// forwarded to.
    inner: UdfEq<Arc<dyn ScalarUDFImpl>>,
    /// The wrapped function's own aliases followed by the added ones.
    aliases: Vec<String>,
}
757
758impl AliasedScalarUDFImpl {
759    pub fn new(
760        inner: Arc<dyn ScalarUDFImpl>,
761        new_aliases: impl IntoIterator<Item = &'static str>,
762    ) -> Self {
763        let mut aliases = inner.aliases().to_vec();
764        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
765        Self {
766            inner: inner.into(),
767            aliases,
768        }
769    }
770}
771
// Every method except `as_any` and `aliases` forwards to the wrapped function,
// so a trait method that is not overridden here would silently fall back to
// the trait default instead of the wrapped function's behavior; the lint
// below flags such omissions.
#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method
impl ScalarUDFImpl for AliasedScalarUDFImpl {
    /// Returns this wrapper itself (not the wrapped function) as `Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the name of the wrapped function.
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Forwards to the wrapped function's `display_name`.
    fn display_name(&self, args: &[Expr]) -> Result<String> {
        // The forwarded call is expected to trigger the `deprecated` lint.
        #[expect(deprecated)]
        self.inner.display_name(args)
    }

    /// Forwards to the wrapped function's `schema_name`.
    fn schema_name(&self, args: &[Expr]) -> Result<String> {
        self.inner.schema_name(args)
    }

    /// Returns the signature of the wrapped function.
    fn signature(&self) -> &Signature {
        self.inner.signature()
    }

    /// Forwards to the wrapped function's `return_type`.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        self.inner.return_type(arg_types)
    }

    /// Forwards to the wrapped function's `return_field_from_args`.
    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
        self.inner.return_field_from_args(args)
    }

    /// Forwards to the wrapped function's `is_nullable`.
    fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
        // NOTE(review): this uses `allow` while `display_name` above uses
        // `expect` for the same lint; consider aligning the two if
        // `is_nullable` is still deprecated on the trait.
        #[allow(deprecated)]
        self.inner.is_nullable(args, schema)
    }

    /// Forwards to the wrapped function's `invoke_with_args`.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        self.inner.invoke_with_args(args)
    }

    /// Returns the alias list stored on this wrapper, i.e. the list built in
    /// `AliasedScalarUDFImpl::new`, rather than asking the wrapped function.
    fn aliases(&self) -> &[String] {
        &self.aliases
    }

    /// Forwards to the wrapped function's `simplify`.
    fn simplify(
        &self,
        args: Vec<Expr>,
        info: &dyn SimplifyInfo,
    ) -> Result<ExprSimplifyResult> {
        self.inner.simplify(args, info)
    }

    /// Forwards to the wrapped function's `short_circuits`.
    fn short_circuits(&self) -> bool {
        self.inner.short_circuits()
    }

    /// Forwards to the wrapped function's `evaluate_bounds`.
    fn evaluate_bounds(&self, input: &[&Interval]) -> Result<Interval> {
        self.inner.evaluate_bounds(input)
    }

    /// Forwards to the wrapped function's `propagate_constraints`.
    fn propagate_constraints(
        &self,
        interval: &Interval,
        inputs: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        self.inner.propagate_constraints(interval, inputs)
    }

    /// Forwards to the wrapped function's `output_ordering`.
    fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
        self.inner.output_ordering(inputs)
    }

    /// Forwards to the wrapped function's `preserves_lex_ordering`.
    fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
        self.inner.preserves_lex_ordering(inputs)
    }

    /// Forwards to the wrapped function's `coerce_types`.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        self.inner.coerce_types(arg_types)
    }

    /// Returns the documentation of the wrapped function, if it has any.
    fn documentation(&self) -> Option<&Documentation> {
        self.inner.documentation()
    }
}
856
857// Scalar UDF doc sections for use in public documentation
858pub mod scalar_doc_sections {
859    use crate::DocSection;
860
861    pub fn doc_sections() -> Vec<DocSection> {
862        vec![
863            DOC_SECTION_MATH,
864            DOC_SECTION_CONDITIONAL,
865            DOC_SECTION_STRING,
866            DOC_SECTION_BINARY_STRING,
867            DOC_SECTION_REGEX,
868            DOC_SECTION_DATETIME,
869            DOC_SECTION_ARRAY,
870            DOC_SECTION_STRUCT,
871            DOC_SECTION_MAP,
872            DOC_SECTION_HASHING,
873            DOC_SECTION_UNION,
874            DOC_SECTION_OTHER,
875        ]
876    }
877
878    pub const fn doc_sections_const() -> &'static [DocSection] {
879        &[
880            DOC_SECTION_MATH,
881            DOC_SECTION_CONDITIONAL,
882            DOC_SECTION_STRING,
883            DOC_SECTION_BINARY_STRING,
884            DOC_SECTION_REGEX,
885            DOC_SECTION_DATETIME,
886            DOC_SECTION_ARRAY,
887            DOC_SECTION_STRUCT,
888            DOC_SECTION_MAP,
889            DOC_SECTION_HASHING,
890            DOC_SECTION_UNION,
891            DOC_SECTION_OTHER,
892        ]
893    }
894
895    pub const DOC_SECTION_MATH: DocSection = DocSection {
896        include: true,
897        label: "Math Functions",
898        description: None,
899    };
900
901    pub const DOC_SECTION_CONDITIONAL: DocSection = DocSection {
902        include: true,
903        label: "Conditional Functions",
904        description: None,
905    };
906
907    pub const DOC_SECTION_STRING: DocSection = DocSection {
908        include: true,
909        label: "String Functions",
910        description: None,
911    };
912
913    pub const DOC_SECTION_BINARY_STRING: DocSection = DocSection {
914        include: true,
915        label: "Binary String Functions",
916        description: None,
917    };
918
919    pub const DOC_SECTION_REGEX: DocSection = DocSection {
920        include: true,
921        label: "Regular Expression Functions",
922        description: Some(
923            r#"Apache DataFusion uses a [PCRE-like](https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions)
924regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax)
925(minus support for several features including look-around and backreferences).
926The following regular expression functions are supported:"#,
927        ),
928    };
929
930    pub const DOC_SECTION_DATETIME: DocSection = DocSection {
931        include: true,
932        label: "Time and Date Functions",
933        description: None,
934    };
935
936    pub const DOC_SECTION_ARRAY: DocSection = DocSection {
937        include: true,
938        label: "Array Functions",
939        description: None,
940    };
941
942    pub const DOC_SECTION_STRUCT: DocSection = DocSection {
943        include: true,
944        label: "Struct Functions",
945        description: None,
946    };
947
948    pub const DOC_SECTION_MAP: DocSection = DocSection {
949        include: true,
950        label: "Map Functions",
951        description: None,
952    };
953
954    pub const DOC_SECTION_HASHING: DocSection = DocSection {
955        include: true,
956        label: "Hashing Functions",
957        description: None,
958    };
959
960    pub const DOC_SECTION_OTHER: DocSection = DocSection {
961        include: true,
962        label: "Other Functions",
963        description: None,
964    };
965
966    pub const DOC_SECTION_UNION: DocSection = DocSection {
967        include: true,
968        label: "Union Functions",
969        description: Some("Functions to work with the union data type, also know as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator"),
970    };
971}
972
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr_common::signature::Volatility;
    use std::hash::DefaultHasher;

    /// Minimal UDF used to exercise equality, hashing and ordering of
    /// [`ScalarUDF`]. `field` plays the role of a function parameter: two
    /// instances with the same `name` can still differ in `field`.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDFImpl {
        name: &'static str,
        field: &'static str,
        signature: Signature,
    }

    impl ScalarUDFImpl for TestScalarUDFImpl {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        // Never called by the tests in this module.
        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        // Never called by the tests in this module.
        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    /// Builds a [`ScalarUDF`] around a [`TestScalarUDFImpl`] with the given
    /// name and parameter.
    fn test_func(name: &'static str, parameter: &'static str) -> ScalarUDF {
        let signature = Signature::any(1, Volatility::Immutable);
        ScalarUDF::from(TestScalarUDFImpl {
            name,
            field: parameter,
            signature,
        })
    }

    /// Hashes `value` with a freshly created [`DefaultHasher`].
    fn hash<T: Hash>(value: &T) -> u64 {
        let mut state = DefaultHasher::new();
        value.hash(&mut state);
        state.finish()
    }

    // `PartialEq` must agree with `Hash`, and `PartialEq` must agree with
    // `PartialOrd`, so all three are checked side by side.
    #[test]
    fn test_partial_eq_hash_and_partial_ord() {
        // Reference function with a parameter.
        let base = test_func("foo", "a");

        // A separately constructed but identical function.
        let twin = test_func("foo", "a");
        assert_eq!(base, twin);
        assert_eq!(hash(&base), hash(&twin));
        assert_eq!(base.partial_cmp(&twin), Some(Ordering::Equal));

        // Same name, different parameter.
        let other_param = test_func("foo", "b");
        assert_ne!(base, other_param);
        // Hashes may collide in general; they do not for these inputs.
        assert_ne!(hash(&base), hash(&other_param));
        assert_eq!(base.partial_cmp(&other_param), None);

        // Different name, same parameter.
        let other_name = test_func("other", "a");
        assert_ne!(base, other_name);
        // Hashes may collide in general; they do not for these inputs.
        assert_ne!(hash(&base), hash(&other_name));
        assert_eq!(base.partial_cmp(&other_name), Some(Ordering::Less));

        // Both name and parameter differ.
        assert_ne!(other_param, other_name);
        // Hashes may collide in general; they do not for these inputs.
        assert_ne!(hash(&other_param), hash(&other_name));
        assert_eq!(
            other_param.partial_cmp(&other_name),
            Some(Ordering::Less)
        );
    }
}