use crate::expr::schema_name_from_exprs_comma_separated_without_space;
use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
use crate::sort_properties::{ExprProperties, SortProperties};
use crate::{ColumnarValue, Documentation, Expr, Signature};
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue};
use datafusion_expr_common::interval_arithmetic::Interval;
use std::any::Any;
use std::cmp::Ordering;
use std::fmt::Debug;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
/// Handle to a scalar user-defined function (UDF).
///
/// This is a thin wrapper around a shared [`ScalarUDFImpl`] trait object:
/// cloning it clones the `Arc`, and the methods on this type forward to the
/// wrapped implementation.
#[derive(Debug, Clone)]
pub struct ScalarUDF {
/// Shared implementation that the methods of this type delegate to.
inner: Arc<dyn ScalarUDFImpl>,
}
impl PartialEq for ScalarUDF {
fn eq(&self, other: &Self) -> bool {
self.inner.equals(other.inner.as_ref())
}
}
impl PartialOrd for ScalarUDF {
    /// Orders UDFs by name, then signature, then aliases.
    ///
    /// Returns `None` when the signatures (or aliases) are incomparable, and
    /// also when every compared component ties but the two UDFs are not equal
    /// according to [`PartialEq`]. The latter keeps the `PartialOrd` contract
    /// intact: `partial_cmp` yields `Some(Ordering::Equal)` only when
    /// `self == other`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let mut ordering = self.name().cmp(other.name());
        if ordering == Ordering::Equal {
            ordering = self.signature().partial_cmp(other.signature())?;
        }
        if ordering == Ordering::Equal {
            ordering = self.aliases().partial_cmp(other.aliases())?;
        }
        // Equality is delegated to the implementation and may look at more
        // than name/signature/aliases; never claim `Equal` for unequal values.
        if ordering == Ordering::Equal && self != other {
            return None;
        }
        Some(ordering)
    }
}
// Marker impl: `Eq` adds no methods. It declares the `PartialEq` relation on
// `ScalarUDF` (which is delegated to `ScalarUDFImpl::equals`) to be a full
// equivalence relation, so `equals` implementations are expected to be
// reflexive, symmetric and transitive.
impl Eq for ScalarUDF {}
impl Hash for ScalarUDF {
    /// Feeds the implementation-provided [`ScalarUDFImpl::hash_value`] digest
    /// into `state`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let digest: u64 = self.inner.hash_value();
        state.write_u64(digest);
    }
}
impl ScalarUDF {
/// Creates a `ScalarUDF` from any [`ScalarUDFImpl`] value by placing it
/// behind a new `Arc`.
pub fn new_from_impl<F>(fun: F) -> ScalarUDF
where
F: ScalarUDFImpl + 'static,
{
Self::new_from_shared_impl(Arc::new(fun))
}
/// Creates a `ScalarUDF` around an implementation that is already shared.
pub fn new_from_shared_impl(fun: Arc<dyn ScalarUDFImpl>) -> ScalarUDF {
Self { inner: fun }
}
/// Returns the shared implementation wrapped by this handle.
pub fn inner(&self) -> &Arc<dyn ScalarUDFImpl> {
&self.inner
}
/// Consumes this UDF and returns a new one that reports the additional
/// `aliases`.
///
/// The current implementation is wrapped in an `AliasedScalarUDFImpl`,
/// which appends `aliases` to whatever aliases the implementation already
/// reports.
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases))
}
/// Builds an [`Expr::ScalarFunction`] that applies a clone of this UDF to
/// `args`.
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
Arc::new(self.clone()),
args,
))
}
/// Forwards to [`ScalarUDFImpl::name`].
pub fn name(&self) -> &str {
self.inner.name()
}
/// Forwards to [`ScalarUDFImpl::display_name`].
pub fn display_name(&self, args: &[Expr]) -> Result<String> {
self.inner.display_name(args)
}
/// Forwards to [`ScalarUDFImpl::schema_name`].
pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
self.inner.schema_name(args)
}
/// Forwards to [`ScalarUDFImpl::aliases`].
pub fn aliases(&self) -> &[String] {
self.inner.aliases()
}
/// Forwards to [`ScalarUDFImpl::signature`].
pub fn signature(&self) -> &Signature {
self.inner.signature()
}
/// Forwards to [`ScalarUDFImpl::return_type`].
pub fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}
/// Forwards to [`ScalarUDFImpl::return_field_from_args`].
pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
self.inner.return_field_from_args(args)
}
/// Forwards to [`ScalarUDFImpl::simplify`].
pub fn simplify(
&self,
args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
self.inner.simplify(args, info)
}
/// Forwards to [`ScalarUDFImpl::is_nullable`].
///
/// The trait method is deprecated, hence the `allow` on this wrapper.
#[allow(deprecated)]
pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
self.inner.is_nullable(args, schema)
}
/// Forwards to [`ScalarUDFImpl::invoke_with_args`].
pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
self.inner.invoke_with_args(args)
}
/// Forwards to [`ScalarUDFImpl::short_circuits`].
pub fn short_circuits(&self) -> bool {
self.inner.short_circuits()
}
/// Forwards to [`ScalarUDFImpl::evaluate_bounds`].
pub fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
self.inner.evaluate_bounds(inputs)
}
/// Forwards to [`ScalarUDFImpl::propagate_constraints`].
pub fn propagate_constraints(
&self,
interval: &Interval,
inputs: &[&Interval],
) -> Result<Option<Vec<Interval>>> {
self.inner.propagate_constraints(interval, inputs)
}
/// Forwards to [`ScalarUDFImpl::output_ordering`].
pub fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
self.inner.output_ordering(inputs)
}
/// Forwards to [`ScalarUDFImpl::preserves_lex_ordering`].
pub fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
self.inner.preserves_lex_ordering(inputs)
}
/// Forwards to [`ScalarUDFImpl::coerce_types`].
pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
/// Forwards to [`ScalarUDFImpl::documentation`].
pub fn documentation(&self) -> Option<&Documentation> {
self.inner.documentation()
}
}
impl<F> From<F> for ScalarUDF
where
F: ScalarUDFImpl + 'static,
{
fn from(fun: F) -> Self {
Self::new_from_impl(fun)
}
}
/// Argument bundle passed to [`ScalarUDFImpl::invoke_with_args`].
pub struct ScalarFunctionArgs {
/// Argument values for the invocation.
pub args: Vec<ColumnarValue>,
/// Field metadata for the arguments.
/// NOTE(review): presumably one entry per element of `args`, in the same
/// order — confirm against callers.
pub arg_fields: Vec<FieldRef>,
/// NOTE(review): presumably the number of rows the invocation should
/// produce — confirm against callers.
pub number_rows: usize,
/// Field describing the value to return; `return_type()` exposes its data
/// type.
pub return_field: FieldRef,
}
impl ScalarFunctionArgs {
pub fn return_type(&self) -> &DataType {
self.return_field.data_type()
}
}
/// Borrowed inputs for [`ScalarUDFImpl::return_field_from_args`].
#[derive(Debug)]
pub struct ReturnFieldArgs<'a> {
/// Fields of the function arguments; the default
/// `return_field_from_args` derives the argument data types from these.
pub arg_fields: &'a [FieldRef],
/// Optional scalar values for the arguments.
/// NOTE(review): presumably one slot per argument, `Some` only when the
/// argument is a known scalar — confirm against callers.
pub scalar_arguments: &'a [Option<&'a ScalarValue>],
}
/// Trait for implementing a scalar user-defined function.
///
/// Implementors must provide [`Self::as_any`], [`Self::name`],
/// [`Self::signature`], [`Self::return_type`] and
/// [`Self::invoke_with_args`]; every other method has a default.
pub trait ScalarUDFImpl: Debug + Send + Sync {
    /// Returns this implementation as [`Any`] so callers can downcast it to
    /// the concrete type. The default [`Self::equals`] and
    /// [`Self::hash_value`] use the returned value's type id, so this is
    /// expected to return `self`.
    fn as_any(&self) -> &dyn Any;

    /// Returns the function's name.
    fn name(&self) -> &str;

    /// Returns the display name for a call with `args`.
    ///
    /// Default: `name(arg1,arg2,...)`, with each argument rendered through
    /// its `ToString` implementation and joined by `,`.
    fn display_name(&self, args: &[Expr]) -> Result<String> {
        let names: Vec<String> = args.iter().map(ToString::to_string).collect();
        Ok(format!("{}({})", self.name(), names.join(",")))
    }

    /// Returns the schema name for a call with `args`.
    ///
    /// Default: `name(...)` where the argument list is produced by
    /// `schema_name_from_exprs_comma_separated_without_space`.
    fn schema_name(&self, args: &[Expr]) -> Result<String> {
        Ok(format!(
            "{}({})",
            self.name(),
            schema_name_from_exprs_comma_separated_without_space(args)?
        ))
    }

    /// Returns the function's [`Signature`].
    fn signature(&self) -> &Signature;

    /// Returns the data type produced for arguments of types `arg_types`.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;

    /// Returns the field describing the function's result.
    ///
    /// Default: collects the data types of `args.arg_fields`, asks
    /// [`Self::return_type`] for the result type, and returns a nullable
    /// field named after the function.
    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
        let data_types = args
            .arg_fields
            .iter()
            .map(|f| f.data_type())
            .cloned()
            .collect::<Vec<_>>();
        let return_type = self.return_type(&data_types)?;
        Ok(Arc::new(Field::new(self.name(), return_type, true)))
    }

    /// Reports whether the result may be null. Default: always `true`.
    #[deprecated(
        since = "45.0.0",
        note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error"
    )]
    fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
        true
    }

    /// Evaluates the function for the given arguments.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue>;

    /// Returns alternative names for the function. Default: none.
    fn aliases(&self) -> &[String] {
        &[]
    }

    /// Optionally rewrites a call to this function.
    ///
    /// Default: no rewrite; the arguments are handed back unchanged as
    /// `ExprSimplifyResult::Original`.
    fn simplify(
        &self,
        args: Vec<Expr>,
        _info: &dyn SimplifyInfo,
    ) -> Result<ExprSimplifyResult> {
        Ok(ExprSimplifyResult::Original(args))
    }

    /// Reports whether the function short-circuits. Default: `false`.
    fn short_circuits(&self) -> bool {
        false
    }

    /// Computes bounds of the result from bounds of the inputs.
    ///
    /// Default: an unbounded interval of `DataType::Null`, regardless of the
    /// inputs.
    fn evaluate_bounds(&self, _input: &[&Interval]) -> Result<Interval> {
        Interval::make_unbounded(&DataType::Null)
    }

    /// Propagates a constraint on the result back to the inputs.
    ///
    /// Default: `Some` of an empty vector.
    fn propagate_constraints(
        &self,
        _interval: &Interval,
        _inputs: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        Ok(Some(vec![]))
    }

    /// Derives the sort properties of the result from those of the inputs.
    ///
    /// Default:
    /// * `Unordered` unless [`Self::preserves_lex_ordering`] returns `true`;
    /// * `Singleton` when there are no inputs;
    /// * the first input's sort properties when every input shares them;
    /// * `Unordered` otherwise.
    fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
        if !self.preserves_lex_ordering(inputs)? {
            return Ok(SortProperties::Unordered);
        }
        let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else {
            return Ok(SortProperties::Singleton);
        };
        if inputs
            .iter()
            .skip(1)
            .all(|input| &input.sort_properties == first_order)
        {
            Ok(*first_order)
        } else {
            Ok(SortProperties::Unordered)
        }
    }

    /// Reports whether the function preserves the lexicographical ordering
    /// of its inputs. Default: `false`.
    fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
        Ok(false)
    }

    /// Coerces `arg_types` to the types the function accepts.
    ///
    /// Default: a "not implemented" error naming the function.
    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
        not_impl_err!("Function {} does not implement coerce_types", self.name())
    }

    /// Returns `true` when `other` is considered the same function.
    ///
    /// Default: both must be the same concrete type (per [`Self::as_any`])
    /// and agree on name, aliases and signature. Comparing the concrete type
    /// keeps the relation symmetric when one side is a wrapper type whose
    /// own `equals` rejects every other type.
    ///
    /// Overrides must stay consistent with [`Self::hash_value`]: equal
    /// implementations must produce equal hashes.
    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
        self.as_any().type_id() == other.as_any().type_id()
            && self.name() == other.name()
            && self.aliases() == other.aliases()
            && self.signature() == other.signature()
    }

    /// Returns a hash consistent with [`Self::equals`].
    ///
    /// Default: hashes the concrete type id, name, aliases and signature,
    /// i.e. exactly the components the default `equals` compares.
    fn hash_value(&self) -> u64 {
        let mut hasher = DefaultHasher::new();
        self.as_any().type_id().hash(&mut hasher);
        self.name().hash(&mut hasher);
        self.aliases().hash(&mut hasher);
        self.signature().hash(&mut hasher);
        hasher.finish()
    }

    /// Returns the function's documentation, if any. Default: `None`.
    fn documentation(&self) -> Option<&Documentation> {
        None
    }
}
/// [`ScalarUDFImpl`] wrapper that reports extra aliases and otherwise
/// delegates to the wrapped implementation.
#[derive(Debug)]
struct AliasedScalarUDFImpl {
/// Implementation being wrapped.
inner: Arc<dyn ScalarUDFImpl>,
/// Aliases reported by this wrapper: the wrapped implementation's aliases
/// followed by the ones supplied at construction.
aliases: Vec<String>,
}
impl AliasedScalarUDFImpl {
    /// Wraps `inner`, reporting its existing aliases followed by
    /// `new_aliases` (in the order given).
    pub fn new(
        inner: Arc<dyn ScalarUDFImpl>,
        new_aliases: impl IntoIterator<Item = &'static str>,
    ) -> Self {
        let existing = inner.aliases().iter().cloned();
        let added = new_aliases.into_iter().map(String::from);
        let aliases: Vec<String> = existing.chain(added).collect();
        Self { inner, aliases }
    }
}
impl ScalarUDFImpl for AliasedScalarUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
self.inner.name()
}
fn display_name(&self, args: &[Expr]) -> Result<String> {
self.inner.display_name(args)
}
fn schema_name(&self, args: &[Expr]) -> Result<String> {
self.inner.schema_name(args)
}
fn signature(&self) -> &Signature {
self.inner.signature()
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
self.inner.return_field_from_args(args)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
self.inner.invoke_with_args(args)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
fn simplify(
&self,
args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
self.inner.simplify(args, info)
}
fn short_circuits(&self) -> bool {
self.inner.short_circuits()
}
fn evaluate_bounds(&self, input: &[&Interval]) -> Result<Interval> {
self.inner.evaluate_bounds(input)
}
fn propagate_constraints(
&self,
interval: &Interval,
inputs: &[&Interval],
) -> Result<Option<Vec<Interval>>> {
self.inner.propagate_constraints(interval, inputs)
}
fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
self.inner.output_ordering(inputs)
}
fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
self.inner.preserves_lex_ordering(inputs)
}
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<AliasedScalarUDFImpl>() {
self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
} else {
false
}
}
fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.inner.hash_value().hash(hasher);
self.aliases.hash(hasher);
hasher.finish()
}
fn documentation(&self) -> Option<&Documentation> {
self.inner.documentation()
}
}
/// Documentation sections used to group scalar functions.
pub mod scalar_doc_sections {
    use crate::DocSection;

    /// Single source of truth for the section list, in presentation order.
    /// Both accessors below are derived from it so they cannot drift apart.
    const ALL_DOC_SECTIONS: [DocSection; 12] = [
        DOC_SECTION_MATH,
        DOC_SECTION_CONDITIONAL,
        DOC_SECTION_STRING,
        DOC_SECTION_BINARY_STRING,
        DOC_SECTION_REGEX,
        DOC_SECTION_DATETIME,
        DOC_SECTION_ARRAY,
        DOC_SECTION_STRUCT,
        DOC_SECTION_MAP,
        DOC_SECTION_HASHING,
        DOC_SECTION_UNION,
        DOC_SECTION_OTHER,
    ];

    /// Returns all scalar documentation sections as an owned vector.
    pub fn doc_sections() -> Vec<DocSection> {
        Vec::from(ALL_DOC_SECTIONS)
    }

    /// Returns all scalar documentation sections as a static slice; usable
    /// in `const` contexts.
    pub const fn doc_sections_const() -> &'static [DocSection] {
        &ALL_DOC_SECTIONS
    }

    /// Section for math functions.
    pub const DOC_SECTION_MATH: DocSection = DocSection {
        include: true,
        label: "Math Functions",
        description: None,
    };

    /// Section for conditional functions.
    pub const DOC_SECTION_CONDITIONAL: DocSection = DocSection {
        include: true,
        label: "Conditional Functions",
        description: None,
    };

    /// Section for string functions.
    pub const DOC_SECTION_STRING: DocSection = DocSection {
        include: true,
        label: "String Functions",
        description: None,
    };

    /// Section for binary string functions.
    pub const DOC_SECTION_BINARY_STRING: DocSection = DocSection {
        include: true,
        label: "Binary String Functions",
        description: None,
    };

    /// Section for regular expression functions.
    pub const DOC_SECTION_REGEX: DocSection = DocSection {
        include: true,
        label: "Regular Expression Functions",
        description: Some(
            r#"Apache DataFusion uses a [PCRE-like](https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions)
regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax)
(minus support for several features including look-around and backreferences).
The following regular expression functions are supported:"#,
        ),
    };

    /// Section for time and date functions.
    pub const DOC_SECTION_DATETIME: DocSection = DocSection {
        include: true,
        label: "Time and Date Functions",
        description: None,
    };

    /// Section for array functions.
    pub const DOC_SECTION_ARRAY: DocSection = DocSection {
        include: true,
        label: "Array Functions",
        description: None,
    };

    /// Section for struct functions.
    pub const DOC_SECTION_STRUCT: DocSection = DocSection {
        include: true,
        label: "Struct Functions",
        description: None,
    };

    /// Section for map functions.
    pub const DOC_SECTION_MAP: DocSection = DocSection {
        include: true,
        label: "Map Functions",
        description: None,
    };

    /// Section for hashing functions.
    pub const DOC_SECTION_HASHING: DocSection = DocSection {
        include: true,
        label: "Hashing Functions",
        description: None,
    };

    /// Catch-all section for functions that fit no other section.
    pub const DOC_SECTION_OTHER: DocSection = DocSection {
        include: true,
        label: "Other Functions",
        description: None,
    };

    /// Section for functions operating on the union data type.
    pub const DOC_SECTION_UNION: DocSection = DocSection {
        include: true,
        label: "Union Functions",
        description: Some("Functions to work with the union data type, also known as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator"),
    };
}