use crate::async_udf::AsyncScalarUDF;
use crate::expr::schema_name_from_exprs_comma_separated_without_space;
use crate::preimage::PreimageResult;
use crate::simplify::{ExprSimplifyResult, SimplifyContext};
use crate::sort_properties::{ExprProperties, SortProperties};
use crate::udf_eq::UdfEq;
use crate::{ColumnarValue, Documentation, Expr, Signature};
use arrow::datatypes::{DataType, Field, FieldRef};
#[cfg(debug_assertions)]
use datafusion_common::assert_or_internal_err;
use datafusion_common::config::ConfigOptions;
use datafusion_common::{ExprSchema, Result, ScalarValue, not_impl_err};
use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
use datafusion_expr_common::interval_arithmetic::Interval;
use datafusion_expr_common::placement::ExpressionPlacement;
use std::any::Any;
use std::cmp::Ordering;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
/// A scalar user-defined function: a cheaply clonable handle around a shared
/// [`ScalarUDFImpl`].
///
/// Cloning only clones the inner `Arc`. Equality and hashing are delegated to the
/// wrapped implementation (see the `PartialEq` and `Hash` impls for this type).
#[derive(Debug, Clone)]
pub struct ScalarUDF {
    /// The implementation every method on this type delegates to.
    inner: Arc<dyn ScalarUDFImpl>,
}
impl PartialEq for ScalarUDF {
fn eq(&self, other: &Self) -> bool {
self.inner.dyn_eq(other.inner.as_any())
}
}
impl PartialOrd for ScalarUDF {
    /// Orders functions by name, then by signature, then by aliases.
    ///
    /// Returns `None` when signatures or aliases are not comparable, or when all
    /// three properties match but the functions are not equal according to
    /// `PartialEq` (no consistent ordering exists for such a pair).
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = match self.name().cmp(other.name()) {
            Ordering::Equal => match self.signature().partial_cmp(other.signature())? {
                Ordering::Equal => self.aliases().partial_cmp(other.aliases())?,
                decided => decided,
            },
            decided => decided,
        };

        if ordering == Ordering::Equal {
            // Indistinguishable by every observed property: only genuinely equal
            // functions may be reported as `Equal`; anything else is unordered.
            return (self == other).then_some(Ordering::Equal);
        }

        // Functions that differ in name/signature/aliases must not be `PartialEq`-equal.
        debug_assert!(
            self != other,
            "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
            The functions compare as equal, but they are not equal based on general properties that \
            the PartialOrd implementation observes,",
            self.name(),
            other.name()
        );
        Some(ordering)
    }
}
// Marker impl: the `DynEq`-based `PartialEq` is treated as a full equivalence relation.
// NOTE(review): this relies on every `ScalarUDFImpl`'s `dyn_eq` being reflexive — confirm.
impl Eq for ScalarUDF {}
impl Hash for ScalarUDF {
    /// Hashes through the wrapped implementation's `dyn_hash`, mirroring how
    /// `PartialEq` delegates to `dyn_eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let implementation = &*self.inner;
        implementation.dyn_hash(state);
    }
}
impl ScalarUDF {
pub fn new_from_impl<F>(fun: F) -> ScalarUDF
where
F: ScalarUDFImpl + 'static,
{
Self::new_from_shared_impl(Arc::new(fun))
}
pub fn new_from_shared_impl(fun: Arc<dyn ScalarUDFImpl>) -> ScalarUDF {
Self { inner: fun }
}
pub fn inner(&self) -> &Arc<dyn ScalarUDFImpl> {
&self.inner
}
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases))
}
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
Arc::new(self.clone()),
args,
))
}
pub fn name(&self) -> &str {
self.inner.name()
}
#[deprecated(
since = "50.0.0",
note = "This method is unused and will be removed in a future release"
)]
pub fn display_name(&self, args: &[Expr]) -> Result<String> {
#[expect(deprecated)]
self.inner.display_name(args)
}
pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
self.inner.schema_name(args)
}
pub fn aliases(&self) -> &[String] {
self.inner.aliases()
}
pub fn signature(&self) -> &Signature {
self.inner.signature()
}
pub fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}
pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
self.inner.return_field_from_args(args)
}
pub fn simplify(
&self,
args: Vec<Expr>,
info: &SimplifyContext,
) -> Result<ExprSimplifyResult> {
self.inner.simplify(args, info)
}
#[deprecated(since = "50.0.0", note = "Use `return_field_from_args` instead.")]
pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
#[expect(deprecated)]
self.inner.is_nullable(args, schema)
}
pub fn preimage(
&self,
args: &[Expr],
lit_expr: &Expr,
info: &SimplifyContext,
) -> Result<PreimageResult> {
self.inner.preimage(args, lit_expr, info)
}
pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
#[cfg(debug_assertions)]
let return_field = Arc::clone(&args.return_field);
let result = self.inner.invoke_with_args(args)?;
#[cfg(debug_assertions)]
{
let result_data_type = result.data_type();
let expected_type = return_field.data_type();
assert_or_internal_err!(
result_data_type == *expected_type,
"Function '{}' returned value of type '{:?}' while the following type was promised at planning time and expected: '{:?}'",
self.name(),
result_data_type,
expected_type
);
}
Ok(result)
}
pub fn conditional_arguments<'a>(
&self,
args: &'a [Expr],
) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
self.inner.conditional_arguments(args)
}
pub fn short_circuits(&self) -> bool {
self.inner.short_circuits()
}
pub fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
self.inner.evaluate_bounds(inputs)
}
pub fn propagate_constraints(
&self,
interval: &Interval,
inputs: &[&Interval],
) -> Result<Option<Vec<Interval>>> {
self.inner.propagate_constraints(interval, inputs)
}
pub fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
self.inner.output_ordering(inputs)
}
pub fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
self.inner.preserves_lex_ordering(inputs)
}
pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
pub fn documentation(&self) -> Option<&Documentation> {
self.inner.documentation()
}
pub fn as_async(&self) -> Option<&AsyncScalarUDF> {
self.inner().as_any().downcast_ref::<AsyncScalarUDF>()
}
pub fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement {
self.inner.placement(args)
}
}
impl<F> From<F> for ScalarUDF
where
F: ScalarUDFImpl + 'static,
{
fn from(fun: F) -> Self {
Self::new_from_impl(fun)
}
}
/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`].
#[derive(Debug, Clone)]
pub struct ScalarFunctionArgs {
    /// The evaluated argument values.
    pub args: Vec<ColumnarValue>,
    /// One field per argument describing its type and metadata.
    pub arg_fields: Vec<FieldRef>,
    /// Number of rows to produce.
    /// NOTE(review): presumably relevant when every argument is a scalar — confirm.
    pub number_rows: usize,
    /// The return field promised at planning time; in debug builds
    /// `ScalarUDF::invoke_with_args` checks the result's type against it.
    pub return_field: FieldRef,
    /// Configuration options in effect for this invocation.
    pub config_options: Arc<ConfigOptions>,
}

impl ScalarFunctionArgs {
    /// Returns the data type of `return_field`.
    pub fn return_type(&self) -> &DataType {
        let field = &*self.return_field;
        field.data_type()
    }
}
/// Arguments passed to [`ScalarUDFImpl::return_field_from_args`].
#[derive(Debug)]
pub struct ReturnFieldArgs<'a> {
    /// The fields of the function's arguments.
    pub arg_fields: &'a [FieldRef],
    /// One optional scalar value per argument.
    /// NOTE(review): presumably `Some` only for arguments that are constant literals — confirm with callers.
    pub scalar_arguments: &'a [Option<&'a ScalarValue>],
}
pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
fn as_any(&self) -> &dyn Any;
fn name(&self) -> &str;
fn aliases(&self) -> &[String] {
&[]
}
#[deprecated(
since = "50.0.0",
note = "This method is unused and will be removed in a future release"
)]
fn display_name(&self, args: &[Expr]) -> Result<String> {
let names: Vec<String> = args.iter().map(ToString::to_string).collect();
Ok(format!("{}({})", self.name(), names.join(",")))
}
fn schema_name(&self, args: &[Expr]) -> Result<String> {
Ok(format!(
"{}({})",
self.name(),
schema_name_from_exprs_comma_separated_without_space(args)?
))
}
fn signature(&self) -> &Signature;
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
fn with_updated_config(&self, _config: &ConfigOptions) -> Option<ScalarUDF> {
None
}
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
let data_types = args
.arg_fields
.iter()
.map(|f| f.data_type())
.cloned()
.collect::<Vec<_>>();
let return_type = self.return_type(&data_types)?;
Ok(Arc::new(Field::new(self.name(), return_type, true)))
}
#[deprecated(
since = "45.0.0",
note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error"
)]
fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
true
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue>;
fn simplify(
&self,
args: Vec<Expr>,
_info: &SimplifyContext,
) -> Result<ExprSimplifyResult> {
Ok(ExprSimplifyResult::Original(args))
}
fn preimage(
&self,
_args: &[Expr],
_lit_expr: &Expr,
_info: &SimplifyContext,
) -> Result<PreimageResult> {
Ok(PreimageResult::None)
}
fn short_circuits(&self) -> bool {
false
}
fn conditional_arguments<'a>(
&self,
args: &'a [Expr],
) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
if self.short_circuits() {
Some((vec![], args.iter().collect()))
} else {
None
}
}
fn evaluate_bounds(&self, _input: &[&Interval]) -> Result<Interval> {
Interval::make_unbounded(&DataType::Null)
}
fn propagate_constraints(
&self,
_interval: &Interval,
_inputs: &[&Interval],
) -> Result<Option<Vec<Interval>>> {
Ok(Some(vec![]))
}
fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
if !self.preserves_lex_ordering(inputs)? {
return Ok(SortProperties::Unordered);
}
let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else {
return Ok(SortProperties::Singleton);
};
if inputs
.iter()
.skip(1)
.all(|input| &input.sort_properties == first_order)
{
Ok(*first_order)
} else {
Ok(SortProperties::Unordered)
}
}
fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
Ok(false)
}
fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
not_impl_err!("Function {} does not implement coerce_types", self.name())
}
fn documentation(&self) -> Option<&Documentation> {
None
}
fn placement(&self, _args: &[ExpressionPlacement]) -> ExpressionPlacement {
ExpressionPlacement::KeepInPlace
}
}
/// A [`ScalarUDFImpl`] wrapper that adds aliases to another implementation.
///
/// Created by [`ScalarUDF::with_aliases`]. Equality and hashing are derived, so two
/// wrappers are equal only when both the wrapped implementation and the alias lists match.
#[derive(Debug, PartialEq, Eq, Hash)]
struct AliasedScalarUDFImpl {
    /// The wrapped implementation, held in `UdfEq` so the derived `PartialEq`/`Hash` apply.
    inner: UdfEq<Arc<dyn ScalarUDFImpl>>,
    /// Full alias list: the wrapped implementation's own aliases followed by the added ones.
    aliases: Vec<String>,
}
impl AliasedScalarUDFImpl {
    /// Wraps `inner`, keeping its existing aliases and appending `new_aliases`.
    pub fn new(
        inner: Arc<dyn ScalarUDFImpl>,
        new_aliases: impl IntoIterator<Item = &'static str>,
    ) -> Self {
        let aliases: Vec<String> = inner
            .aliases()
            .iter()
            .cloned()
            .chain(new_aliases.into_iter().map(String::from))
            .collect();
        Self {
            inner: inner.into(),
            aliases,
        }
    }
}
// The wrapper must forward every trait method to `inner`; a method silently falling
// back to its default would bypass the wrapped implementation, so warn on any gap.
#[warn(clippy::missing_trait_methods)]
impl ScalarUDFImpl for AliasedScalarUDFImpl {
    /// Returns the wrapper itself (not `inner`), so downcasts see `AliasedScalarUDFImpl`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.inner.name()
    }

    fn display_name(&self, args: &[Expr]) -> Result<String> {
        #[expect(deprecated)]
        self.inner.display_name(args)
    }

    fn schema_name(&self, args: &[Expr]) -> Result<String> {
        self.inner.schema_name(args)
    }

    fn signature(&self) -> &Signature {
        self.inner.signature()
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        self.inner.return_type(arg_types)
    }

    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
        self.inner.return_field_from_args(args)
    }

    fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
        #[expect(deprecated)]
        self.inner.is_nullable(args, schema)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        self.inner.invoke_with_args(args)
    }

    /// Forwards the configuration update to the wrapped implementation.
    ///
    /// Previously this always returned `None`, so a configuration-dependent function
    /// wrapped via `ScalarUDF::with_aliases` never saw configuration changes. When the
    /// inner function produces an updated instance, it is re-wrapped with this
    /// wrapper's aliases so the aliasing is preserved; otherwise `None` is returned.
    fn with_updated_config(&self, config: &ConfigOptions) -> Option<ScalarUDF> {
        let updated = self.inner.with_updated_config(config)?;
        Some(ScalarUDF::new_from_impl(Self {
            inner: Arc::clone(updated.inner()).into(),
            aliases: self.aliases.clone(),
        }))
    }

    /// The wrapper's own alias list (inner aliases plus the added ones).
    fn aliases(&self) -> &[String] {
        &self.aliases
    }

    fn simplify(
        &self,
        args: Vec<Expr>,
        info: &SimplifyContext,
    ) -> Result<ExprSimplifyResult> {
        self.inner.simplify(args, info)
    }

    fn preimage(
        &self,
        args: &[Expr],
        lit_expr: &Expr,
        info: &SimplifyContext,
    ) -> Result<PreimageResult> {
        self.inner.preimage(args, lit_expr, info)
    }

    fn conditional_arguments<'a>(
        &self,
        args: &'a [Expr],
    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
        self.inner.conditional_arguments(args)
    }

    fn short_circuits(&self) -> bool {
        self.inner.short_circuits()
    }

    fn evaluate_bounds(&self, input: &[&Interval]) -> Result<Interval> {
        self.inner.evaluate_bounds(input)
    }

    fn propagate_constraints(
        &self,
        interval: &Interval,
        inputs: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        self.inner.propagate_constraints(interval, inputs)
    }

    fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
        self.inner.output_ordering(inputs)
    }

    fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
        self.inner.preserves_lex_ordering(inputs)
    }

    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        self.inner.coerce_types(arg_types)
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.inner.documentation()
    }

    fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement {
        self.inner.placement(args)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr_common::signature::Volatility;
    use std::hash::DefaultHasher;

    /// Minimal implementation whose derived equality and hash cover every field,
    /// so two instances with the same name can still differ via `field`.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDFImpl {
        name: &'static str,
        field: &'static str,
        signature: Signature,
    }

    impl ScalarUDFImpl for TestScalarUDFImpl {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    #[test]
    fn test_partial_eq_hash_and_partial_ord() {
        // Identical construction: equal, same hash, ordered as Equal.
        let foo_a = test_func("foo", "a");
        let foo_a_again = test_func("foo", "a");
        assert_eq!(foo_a, foo_a_again);
        assert_eq!(hash(&foo_a), hash(&foo_a_again));
        assert_eq!(foo_a.partial_cmp(&foo_a_again), Some(Ordering::Equal));

        // Same name/signature/aliases but different state: unequal and unordered.
        let foo_b = test_func("foo", "b");
        assert_ne!(foo_a, foo_b);
        assert_ne!(hash(&foo_a), hash(&foo_b));
        assert_eq!(foo_a.partial_cmp(&foo_b), None);

        // Different names: ordered by name.
        let other_a = test_func("other", "a");
        assert_ne!(foo_a, other_a);
        assert_ne!(hash(&foo_a), hash(&other_a));
        assert_eq!(foo_a.partial_cmp(&other_a), Some(Ordering::Less));

        assert_ne!(foo_b, other_a);
        assert_ne!(hash(&foo_b), hash(&other_a));
        assert_eq!(foo_b.partial_cmp(&other_a), Some(Ordering::Less));
    }

    /// Builds a one-argument test UDF named `name` whose identity also depends on `parameter`.
    fn test_func(name: &'static str, parameter: &'static str) -> ScalarUDF {
        let implementation = TestScalarUDFImpl {
            name,
            field: parameter,
            signature: Signature::any(1, Volatility::Immutable),
        };
        ScalarUDF::from(implementation)
    }

    /// Hashes `value` with a freshly created `DefaultHasher`.
    fn hash<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }
}