use crate::async_udf::AsyncScalarUDF;
use crate::expr::schema_name_from_exprs_comma_separated_without_space;
use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
use crate::sort_properties::{ExprProperties, SortProperties};
use crate::udf_eq::UdfEq;
use crate::{ColumnarValue, Documentation, Expr, Signature};
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::config::ConfigOptions;
use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue};
use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
use datafusion_expr_common::interval_arithmetic::Interval;
use std::any::Any;
use std::cmp::Ordering;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
/// Handle to a scalar user-defined function.
///
/// Wraps a shared [`ScalarUDFImpl`] trait object. Cloning a `ScalarUDF` only
/// clones the `Arc`, so every clone shares the same underlying implementation.
#[derive(Debug, Clone)]
pub struct ScalarUDF {
// Shared implementation that supplies the function's name, signature and behavior.
inner: Arc<dyn ScalarUDFImpl>,
}
/// Equality is decided by the wrapped implementation.
///
/// The right-hand implementation is passed as `&dyn Any` to the left-hand
/// implementation's `dyn_eq`, so the implementation type defines what "equal"
/// means.
impl PartialEq for ScalarUDF {
    fn eq(&self, other: &Self) -> bool {
        let rhs = other.inner.as_any();
        self.inner.dyn_eq(rhs)
    }
}
/// Orders functions by name, then by signature, then by aliases.
///
/// Returns `None` when the signature or alias comparison is itself undefined,
/// and also when all three compare equal but the two functions are not `==`.
/// The second case keeps `partial_cmp` consistent with `PartialEq`:
/// `Some(Ordering::Equal)` is only produced for functions that are equal.
impl PartialOrd for ScalarUDF {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Fall through to the next key only while the previous one ties.
        let ordering = match self.name().cmp(other.name()) {
            Ordering::Equal => {
                match self.signature().partial_cmp(other.signature())? {
                    Ordering::Equal => self.aliases().partial_cmp(other.aliases())?,
                    by_signature => by_signature,
                }
            }
            by_name => by_name,
        };

        match ordering {
            // Name, signature and aliases all tie, yet the implementations
            // differ in some other property: no ordering can be reported.
            Ordering::Equal if self != other => None,
            resolved => {
                // A strict ordering between two functions that claim to be
                // equal means the implementation's equality is inconsistent.
                debug_assert!(
                    resolved == Ordering::Equal || self != other,
                    "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
                    The functions compare as equal, but they are not equal based on general properties that \
                    the PartialOrd implementation observes,",
                    self.name(), other.name()
                );
                Some(resolved)
            }
        }
    }
}
// Marker impl: declares that `ScalarUDF`'s `PartialEq` is a full equivalence
// relation. It adds no methods of its own.
impl Eq for ScalarUDF {}
/// Hashing is delegated to the wrapped implementation's `dyn_hash`.
impl Hash for ScalarUDF {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.inner.dyn_hash(state);
    }
}
impl ScalarUDF {
pub fn new_from_impl<F>(fun: F) -> ScalarUDF
where
F: ScalarUDFImpl + 'static,
{
Self::new_from_shared_impl(Arc::new(fun))
}
pub fn new_from_shared_impl(fun: Arc<dyn ScalarUDFImpl>) -> ScalarUDF {
Self { inner: fun }
}
pub fn inner(&self) -> &Arc<dyn ScalarUDFImpl> {
&self.inner
}
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases))
}
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
Arc::new(self.clone()),
args,
))
}
pub fn name(&self) -> &str {
self.inner.name()
}
#[deprecated(
since = "50.0.0",
note = "This method is unused and will be removed in a future release"
)]
pub fn display_name(&self, args: &[Expr]) -> Result<String> {
#[expect(deprecated)]
self.inner.display_name(args)
}
pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
self.inner.schema_name(args)
}
pub fn aliases(&self) -> &[String] {
self.inner.aliases()
}
pub fn signature(&self) -> &Signature {
self.inner.signature()
}
pub fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}
pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
self.inner.return_field_from_args(args)
}
pub fn simplify(
&self,
args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
self.inner.simplify(args, info)
}
#[deprecated(since = "50.0.0", note = "Use `return_field_from_args` instead.")]
pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
#[allow(deprecated)]
self.inner.is_nullable(args, schema)
}
pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
self.inner.invoke_with_args(args)
}
pub fn short_circuits(&self) -> bool {
self.inner.short_circuits()
}
pub fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
self.inner.evaluate_bounds(inputs)
}
pub fn propagate_constraints(
&self,
interval: &Interval,
inputs: &[&Interval],
) -> Result<Option<Vec<Interval>>> {
self.inner.propagate_constraints(interval, inputs)
}
pub fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
self.inner.output_ordering(inputs)
}
pub fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
self.inner.preserves_lex_ordering(inputs)
}
pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
pub fn documentation(&self) -> Option<&Documentation> {
self.inner.documentation()
}
pub fn as_async(&self) -> Option<&AsyncScalarUDF> {
self.inner().as_any().downcast_ref::<AsyncScalarUDF>()
}
}
impl<F> From<F> for ScalarUDF
where
F: ScalarUDFImpl + 'static,
{
fn from(fun: F) -> Self {
Self::new_from_impl(fun)
}
}
/// Arguments handed to [`ScalarUDFImpl::invoke_with_args`] when a scalar
/// function is evaluated.
#[derive(Debug, Clone)]
pub struct ScalarFunctionArgs {
/// Argument values for this invocation.
pub args: Vec<ColumnarValue>,
/// Field metadata for the arguments. Presumably one entry per element of
/// `args`, in the same order; confirm against callers.
pub arg_fields: Vec<FieldRef>,
/// Row count for this invocation. Presumably the number of rows in the batch
/// being evaluated; confirm against callers.
pub number_rows: usize,
/// Field describing the function's result; [`ScalarFunctionArgs::return_type`]
/// reads its data type.
pub return_field: FieldRef,
/// Shared configuration options made available to the function.
pub config_options: Arc<ConfigOptions>,
}
impl ScalarFunctionArgs {
    /// Returns the data type of `return_field`.
    pub fn return_type(&self) -> &DataType {
        let field = &self.return_field;
        field.data_type()
    }
}
/// Inputs to [`ScalarUDFImpl::return_field_from_args`].
#[derive(Debug)]
pub struct ReturnFieldArgs<'a> {
/// Fields of the function's arguments.
pub arg_fields: &'a [FieldRef],
/// Per-argument scalar values, `None` where no scalar is available.
/// Presumably `None` marks a non-constant argument and the slice lines up
/// with `arg_fields`; confirm against callers.
pub scalar_arguments: &'a [Option<&'a ScalarValue>],
}
/// Behavior of a scalar user-defined function.
///
/// Implementors must supply [`as_any`](Self::as_any), [`name`](Self::name),
/// [`signature`](Self::signature), [`return_type`](Self::return_type) and
/// [`invoke_with_args`](Self::invoke_with_args). Every other method has a
/// default that can be overridden.
pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
    /// Returns `self` as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Returns the function's name.
    fn name(&self) -> &str;

    /// Returns alternative names for the function. The default is none.
    fn aliases(&self) -> &[String] {
        &[]
    }

    /// Returns `name(arg1,arg2,...)`, rendering each argument with its
    /// `Display` form and joining with `,` (no space).
    #[deprecated(
        since = "50.0.0",
        note = "This method is unused and will be removed in a future release"
    )]
    fn display_name(&self, args: &[Expr]) -> Result<String> {
        let mut rendered: Vec<String> = Vec::with_capacity(args.len());
        for arg in args {
            rendered.push(arg.to_string());
        }
        Ok(format!("{}({})", self.name(), rendered.join(",")))
    }

    /// Returns `name(<args>)`, where `<args>` is produced by
    /// `schema_name_from_exprs_comma_separated_without_space`.
    ///
    /// # Errors
    /// Propagates any error from building the argument names.
    fn schema_name(&self, args: &[Expr]) -> Result<String> {
        let function_name = self.name();
        let argument_names = schema_name_from_exprs_comma_separated_without_space(args)?;
        Ok(format!("{function_name}({argument_names})"))
    }

    /// Returns the signature describing the arguments the function accepts.
    fn signature(&self) -> &Signature;

    /// Returns the result data type for the given argument data types.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;

    /// Returns the field describing the function's result.
    ///
    /// The default collects the data types of `args.arg_fields`, asks
    /// [`return_type`](Self::return_type) for the result type, and returns a
    /// nullable field named after the function. `args.scalar_arguments` is not
    /// consulted by the default.
    ///
    /// # Errors
    /// Propagates any error from [`return_type`](Self::return_type).
    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
        let arg_types: Vec<DataType> = args
            .arg_fields
            .iter()
            .map(|field| field.data_type().clone())
            .collect();
        let return_type = self.return_type(&arg_types)?;
        let field = Field::new(self.name(), return_type, true);
        Ok(Arc::new(field))
    }

    /// Reports whether the result may be null. The default is always `true`.
    #[deprecated(
        since = "45.0.0",
        note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error"
    )]
    fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
        true
    }

    /// Evaluates the function for the given arguments.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue>;

    /// Optionally rewrites a call to this function.
    ///
    /// The default performs no rewrite and hands `args` back unchanged as
    /// [`ExprSimplifyResult::Original`].
    fn simplify(
        &self,
        args: Vec<Expr>,
        _info: &dyn SimplifyInfo,
    ) -> Result<ExprSimplifyResult> {
        Ok(ExprSimplifyResult::Original(args))
    }

    /// Reports whether the function short-circuits. The default is `false`.
    fn short_circuits(&self) -> bool {
        false
    }

    /// Computes bounds on the output from bounds on the inputs.
    ///
    /// The default ignores the inputs and returns an unbounded interval built
    /// from `DataType::Null`.
    fn evaluate_bounds(&self, _input: &[&Interval]) -> Result<Interval> {
        Interval::make_unbounded(&DataType::Null)
    }

    /// Refines input intervals from a known output interval.
    ///
    /// The default ignores its arguments and returns `Some` of an empty vector.
    fn propagate_constraints(
        &self,
        _interval: &Interval,
        _inputs: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        Ok(Some(vec![]))
    }

    /// Derives the sort properties of the output from those of the inputs.
    ///
    /// Default behavior:
    /// - `Unordered` when [`preserves_lex_ordering`](Self::preserves_lex_ordering)
    ///   returns `false`;
    /// - `Singleton` when there are no inputs;
    /// - the inputs' shared sort properties when all of them are equal;
    /// - `Unordered` otherwise.
    fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
        if !self.preserves_lex_ordering(inputs)? {
            return Ok(SortProperties::Unordered);
        }

        match inputs.split_first() {
            None => Ok(SortProperties::Singleton),
            Some((first, rest)) => {
                let first_order = &first.sort_properties;
                let all_agree = rest
                    .iter()
                    .all(|input| &input.sort_properties == first_order);
                if all_agree {
                    Ok(*first_order)
                } else {
                    Ok(SortProperties::Unordered)
                }
            }
        }
    }

    /// Reports whether the function preserves the lexicographical ordering of
    /// its inputs. The default is `false`.
    fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
        Ok(false)
    }

    /// Returns the types the arguments should be coerced to.
    ///
    /// # Errors
    /// The default always returns a "not implemented" error naming the function.
    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
        not_impl_err!("Function {} does not implement coerce_types", self.name())
    }

    /// Returns the function's documentation. The default is `None`.
    fn documentation(&self) -> Option<&Documentation> {
        None
    }
}
/// Private adapter that exposes another [`ScalarUDFImpl`] under extra aliases.
///
/// Equality and hashing are derived over both fields, so two adapters are
/// equal only when their wrapped functions and their alias lists are equal.
#[derive(Debug, PartialEq, Eq, Hash)]
struct AliasedScalarUDFImpl {
// Wrapped implementation. `UdfEq` provides the `PartialEq`/`Eq`/`Hash` the
// derives above require. NOTE(review): presumably via the implementation's
// `dyn_eq`/`dyn_hash`; confirm in `crate::udf_eq`.
inner: UdfEq<Arc<dyn ScalarUDFImpl>>,
// Full alias list reported by `aliases()`: the wrapped function's own aliases
// followed by the ones added at construction.
aliases: Vec<String>,
}
impl AliasedScalarUDFImpl {
    /// Wraps `inner` so that it also answers to `new_aliases`.
    ///
    /// The resulting alias list is `inner`'s existing aliases followed by
    /// `new_aliases`, in iteration order. No de-duplication is performed.
    pub fn new(
        inner: Arc<dyn ScalarUDFImpl>,
        new_aliases: impl IntoIterator<Item = &'static str>,
    ) -> Self {
        let aliases: Vec<String> = inner
            .aliases()
            .iter()
            .cloned()
            .chain(new_aliases.into_iter().map(String::from))
            .collect();
        let inner = inner.into();
        Self { inner, aliases }
    }
}
#[warn(clippy::missing_trait_methods)] impl ScalarUDFImpl for AliasedScalarUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
self.inner.name()
}
fn display_name(&self, args: &[Expr]) -> Result<String> {
#[expect(deprecated)]
self.inner.display_name(args)
}
fn schema_name(&self, args: &[Expr]) -> Result<String> {
self.inner.schema_name(args)
}
fn signature(&self) -> &Signature {
self.inner.signature()
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
self.inner.return_field_from_args(args)
}
fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
#[allow(deprecated)]
self.inner.is_nullable(args, schema)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
self.inner.invoke_with_args(args)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
fn simplify(
&self,
args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
self.inner.simplify(args, info)
}
fn short_circuits(&self) -> bool {
self.inner.short_circuits()
}
fn evaluate_bounds(&self, input: &[&Interval]) -> Result<Interval> {
self.inner.evaluate_bounds(input)
}
fn propagate_constraints(
&self,
interval: &Interval,
inputs: &[&Interval],
) -> Result<Option<Vec<Interval>>> {
self.inner.propagate_constraints(interval, inputs)
}
fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
self.inner.output_ordering(inputs)
}
fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
self.inner.preserves_lex_ordering(inputs)
}
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
fn documentation(&self) -> Option<&Documentation> {
self.inner.documentation()
}
}
/// Documentation sections that scalar functions are grouped under.
pub mod scalar_doc_sections {
    use crate::DocSection;

    /// Returns every scalar-function documentation section as an owned vector.
    ///
    /// The order matches [`doc_sections_const`].
    pub fn doc_sections() -> Vec<DocSection> {
        vec![
            DOC_SECTION_MATH,
            DOC_SECTION_CONDITIONAL,
            DOC_SECTION_STRING,
            DOC_SECTION_BINARY_STRING,
            DOC_SECTION_REGEX,
            DOC_SECTION_DATETIME,
            DOC_SECTION_ARRAY,
            DOC_SECTION_STRUCT,
            DOC_SECTION_MAP,
            DOC_SECTION_HASHING,
            DOC_SECTION_UNION,
            DOC_SECTION_OTHER,
        ]
    }

    /// Returns every scalar-function documentation section as a static slice.
    ///
    /// Same contents and order as [`doc_sections`], but usable in `const`
    /// contexts and without allocating.
    pub const fn doc_sections_const() -> &'static [DocSection] {
        &[
            DOC_SECTION_MATH,
            DOC_SECTION_CONDITIONAL,
            DOC_SECTION_STRING,
            DOC_SECTION_BINARY_STRING,
            DOC_SECTION_REGEX,
            DOC_SECTION_DATETIME,
            DOC_SECTION_ARRAY,
            DOC_SECTION_STRUCT,
            DOC_SECTION_MAP,
            DOC_SECTION_HASHING,
            DOC_SECTION_UNION,
            DOC_SECTION_OTHER,
        ]
    }

    /// Section for math functions.
    pub const DOC_SECTION_MATH: DocSection = DocSection {
        include: true,
        label: "Math Functions",
        description: None,
    };

    /// Section for conditional functions.
    pub const DOC_SECTION_CONDITIONAL: DocSection = DocSection {
        include: true,
        label: "Conditional Functions",
        description: None,
    };

    /// Section for string functions.
    pub const DOC_SECTION_STRING: DocSection = DocSection {
        include: true,
        label: "String Functions",
        description: None,
    };

    /// Section for binary-string functions.
    pub const DOC_SECTION_BINARY_STRING: DocSection = DocSection {
        include: true,
        label: "Binary String Functions",
        description: None,
    };

    /// Section for regular-expression functions.
    ///
    /// The description is a multi-line raw string; its continuation lines are
    /// deliberately unindented so no leading whitespace ends up in the text.
    pub const DOC_SECTION_REGEX: DocSection = DocSection {
        include: true,
        label: "Regular Expression Functions",
        description: Some(
            r#"Apache DataFusion uses a [PCRE-like](https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions)
regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax)
(minus support for several features including look-around and backreferences).
The following regular expression functions are supported:"#,
        ),
    };

    /// Section for time and date functions.
    pub const DOC_SECTION_DATETIME: DocSection = DocSection {
        include: true,
        label: "Time and Date Functions",
        description: None,
    };

    /// Section for array functions.
    pub const DOC_SECTION_ARRAY: DocSection = DocSection {
        include: true,
        label: "Array Functions",
        description: None,
    };

    /// Section for struct functions.
    pub const DOC_SECTION_STRUCT: DocSection = DocSection {
        include: true,
        label: "Struct Functions",
        description: None,
    };

    /// Section for map functions.
    pub const DOC_SECTION_MAP: DocSection = DocSection {
        include: true,
        label: "Map Functions",
        description: None,
    };

    /// Section for hashing functions.
    pub const DOC_SECTION_HASHING: DocSection = DocSection {
        include: true,
        label: "Hashing Functions",
        description: None,
    };

    /// Section for functions that fit none of the other sections.
    pub const DOC_SECTION_OTHER: DocSection = DocSection {
        include: true,
        label: "Other Functions",
        description: None,
    };

    /// Section for functions that operate on the union data type.
    pub const DOC_SECTION_UNION: DocSection = DocSection {
        include: true,
        label: "Union Functions",
        description: Some("Functions to work with the union data type, also known as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator"),
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr_common::signature::Volatility;
    use std::hash::DefaultHasher;

    /// Minimal UDF used to exercise equality, hashing and ordering.
    ///
    /// `field` is not exposed through any `ScalarUDFImpl` method; it exists so
    /// that two functions can share a name and signature yet still differ.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDFImpl {
        name: &'static str,
        field: &'static str,
        signature: Signature,
    }

    impl ScalarUDFImpl for TestScalarUDFImpl {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    #[test]
    fn test_partial_eq_hash_and_partial_ord() {
        // Identical name and parameter: equal, same hash, ordered as Equal.
        let foo_a = test_func("foo", "a");
        let foo_a_twin = test_func("foo", "a");
        assert_eq!(foo_a, foo_a_twin);
        assert_eq!(hash(&foo_a), hash(&foo_a_twin));
        assert_eq!(foo_a.partial_cmp(&foo_a_twin), Some(Ordering::Equal));

        // Same name, different parameter: unequal and therefore unordered.
        let foo_b = test_func("foo", "b");
        assert_ne!(foo_a, foo_b);
        // Hashes may collide in principle; they do not for these values.
        assert_ne!(hash(&foo_a), hash(&foo_b));
        assert_eq!(foo_a.partial_cmp(&foo_b), None);

        // Different name: ordered by name.
        let other_a = test_func("other", "a");
        assert_ne!(foo_a, other_a);
        assert_ne!(hash(&foo_a), hash(&other_a));
        assert_eq!(foo_a.partial_cmp(&other_a), Some(Ordering::Less));

        // Different name and parameter: the name still decides the order.
        assert_ne!(foo_b, other_a);
        assert_ne!(hash(&foo_b), hash(&other_a));
        assert_eq!(foo_b.partial_cmp(&other_a), Some(Ordering::Less));
    }

    /// Builds a test UDF with the given name and distinguishing parameter.
    fn test_func(name: &'static str, parameter: &'static str) -> ScalarUDF {
        let signature = Signature::any(1, Volatility::Immutable);
        ScalarUDF::from(TestScalarUDFImpl {
            name,
            field: parameter,
            signature,
        })
    }

    /// Hashes `value` with a fresh `DefaultHasher`.
    fn hash<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }
}