use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::expressions::Literal;
use crate::PhysicalExpr;
use arrow::array::{Array, RecordBatch};
use arrow::datatypes::{DataType, FieldRef, Schema};
use datafusion_common::config::{ConfigEntry, ConfigOptions};
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::ExprProperties;
use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
use datafusion_expr::{
expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
Volatility,
};
/// Physical expression for invoking a scalar UDF on the columns of a
/// `RecordBatch`.
pub struct ScalarFunctionExpr {
    // The scalar function implementation that is invoked at evaluation time
    fun: Arc<ScalarUDF>,
    // Display name of the function (used in `Display`/`fmt_sql` output)
    name: String,
    // Argument expressions, evaluated against the input batch in order
    args: Vec<Arc<dyn PhysicalExpr>>,
    // Resolved output field (data type + nullability) of this call
    return_field: FieldRef,
    // Session configuration forwarded to the UDF on every invocation
    config_options: Arc<ConfigOptions>,
}
impl Debug for ScalarFunctionExpr {
    /// Formats the expression for debugging.
    ///
    /// `fun` is abbreviated to a placeholder and `config_options` is omitted
    /// entirely (it is large and not useful in debug output), so we end with
    /// `finish_non_exhaustive()` — rendering a trailing `..` — instead of
    /// `finish()`, which would falsely present this as the complete field list.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.debug_struct("ScalarFunctionExpr")
            .field("fun", &"<FUNC>")
            .field("name", &self.name)
            .field("args", &self.args)
            .field("return_field", &self.return_field)
            .finish_non_exhaustive()
    }
}
impl ScalarFunctionExpr {
pub fn new(
name: &str,
fun: Arc<ScalarUDF>,
args: Vec<Arc<dyn PhysicalExpr>>,
return_field: FieldRef,
config_options: Arc<ConfigOptions>,
) -> Self {
Self {
fun,
name: name.to_owned(),
args,
return_field,
config_options,
}
}
pub fn try_new(
fun: Arc<ScalarUDF>,
args: Vec<Arc<dyn PhysicalExpr>>,
schema: &Schema,
config_options: Arc<ConfigOptions>,
) -> Result<Self> {
let name = fun.name().to_string();
let arg_fields = args
.iter()
.map(|e| e.return_field(schema))
.collect::<Result<Vec<_>>>()?;
let arg_types = arg_fields
.iter()
.map(|f| f.data_type().clone())
.collect::<Vec<_>>();
data_types_with_scalar_udf(&arg_types, &fun)?;
let arguments = args
.iter()
.map(|e| {
e.as_any()
.downcast_ref::<Literal>()
.map(|literal| literal.value())
})
.collect::<Vec<_>>();
let ret_args = ReturnFieldArgs {
arg_fields: &arg_fields,
scalar_arguments: &arguments,
};
let return_field = fun.return_field_from_args(ret_args)?;
Ok(Self {
fun,
name,
args,
return_field,
config_options,
})
}
pub fn fun(&self) -> &ScalarUDF {
&self.fun
}
pub fn name(&self) -> &str {
&self.name
}
pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
&self.args
}
pub fn return_type(&self) -> &DataType {
self.return_field.data_type()
}
pub fn with_nullable(mut self, nullable: bool) -> Self {
self.return_field = self
.return_field
.as_ref()
.clone()
.with_nullable(nullable)
.into();
self
}
pub fn nullable(&self) -> bool {
self.return_field.is_nullable()
}
pub fn config_options(&self) -> &ConfigOptions {
&self.config_options
}
pub fn try_downcast_func<T>(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr>
where
T: 'static,
{
match expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
Some(scalar_expr)
if scalar_expr
.fun()
.inner()
.as_any()
.downcast_ref::<T>()
.is_some() =>
{
Some(scalar_expr)
}
_ => None,
}
}
}
impl fmt::Display for ScalarFunctionExpr {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
}
}
impl PartialEq for ScalarFunctionExpr {
    fn eq(&self, other: &Self) -> bool {
        // Fast path: comparing an expression with itself.
        if std::ptr::eq(self, other) {
            return true;
        }
        // Destructure so that adding a field to the struct forces this
        // implementation to be revisited.
        let Self {
            fun,
            name,
            args,
            return_field,
            config_options,
        } = self;
        // Config comparison is last so the (potentially expensive) sorted
        // entry comparison only runs when everything else already matches,
        // and only when the two configs are not the same allocation.
        fun == &other.fun
            && name == &other.name
            && args == &other.args
            && return_field == &other.return_field
            && (Arc::ptr_eq(config_options, &other.config_options)
                || sorted_config_entries(config_options)
                    == sorted_config_entries(&other.config_options))
    }
}
impl Eq for ScalarFunctionExpr {}
impl Hash for ScalarFunctionExpr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // `config_options` is deliberately excluded. Omitting a field that
        // `Eq` does consider still satisfies the `Hash` contract
        // (a == b implies hash(a) == hash(b)); the destructure keeps the
        // field list exhaustive so new fields are not forgotten silently.
        let Self {
            fun,
            name,
            args,
            return_field,
            config_options: _,
        } = self;
        // Hashing the tuple feeds each field to the hasher in order, exactly
        // as hashing them one by one would.
        (fun, name, args, return_field).hash(state);
    }
}
/// Returns this configuration's entries ordered by key, so two
/// `ConfigOptions` can be compared independently of iteration order.
fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
    let mut sorted = config_options.entries();
    sorted.sort_by(|left, right| left.key.cmp(&right.key));
    sorted
}
impl PhysicalExpr for ScalarFunctionExpr {
    /// Returns `self` as [`Any`] so callers can downcast.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output data type, taken from the pre-resolved return field.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(self.return_field.data_type().clone())
    }

    /// Output nullability, taken from the pre-resolved return field.
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(self.return_field.is_nullable())
    }

    /// Evaluates all arguments against `batch`, invokes the UDF, and
    /// validates the shape of its output.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        // Evaluate each argument expression in order.
        let args = self
            .args
            .iter()
            .map(|e| e.evaluate(batch))
            .collect::<Result<Vec<_>>>()?;
        let arg_fields = self
            .args
            .iter()
            .map(|e| e.return_field(batch.schema_ref()))
            .collect::<Result<Vec<_>>>()?;
        // Record the input shape before `args` is moved into the UDF call;
        // needed below to decide whether a 1-row output may be re-scalarized.
        let input_empty = args.is_empty();
        let input_all_scalar = args
            .iter()
            .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
        let output = self.fun.invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: batch.num_rows(),
            return_field: Arc::clone(&self.return_field),
            config_options: Arc::clone(&self.config_options),
        })?;
        if let ColumnarValue::Array(array) = &output {
            if array.len() != batch.num_rows() {
                // A 1-row array is acceptable when every input was a scalar:
                // collapse it back to a scalar. Any other row-count mismatch
                // means the UDF misbehaved, which is an internal error.
                let preserve_scalar =
                    array.len() == 1 && !input_empty && input_all_scalar;
                return if preserve_scalar {
                    ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
                } else {
                    internal_err!("UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
                    self.name, batch.num_rows(), array.len())
                };
            }
        }
        Ok(output)
    }

    /// The pre-resolved return field (schema argument is not consulted).
    fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
        Ok(Arc::clone(&self.return_field))
    }

    /// The argument expressions are this node's children.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        self.args.iter().collect()
    }

    /// Rebuilds this expression with `children` as the new arguments.
    ///
    /// NOTE(review): the existing `return_field` is reused as-is; assumes the
    /// new children have the same types as the old ones — confirm at call sites.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(ScalarFunctionExpr::new(
            &self.name,
            Arc::clone(&self.fun),
            children,
            Arc::clone(&self.return_field),
            Arc::clone(&self.config_options),
        )))
    }

    /// Delegates interval bound evaluation to the UDF.
    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
        self.fun.evaluate_bounds(children)
    }

    /// Delegates constraint propagation to the UDF.
    fn propagate_constraints(
        &self,
        interval: &Interval,
        children: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        self.fun.propagate_constraints(interval, children)
    }

    /// Combines the UDF's ordering/lex-ordering answers with bounds computed
    /// from the children's ranges into this expression's properties.
    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
        let sort_properties = self.fun.output_ordering(children)?;
        let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
        let children_range = children
            .iter()
            .map(|props| &props.range)
            .collect::<Vec<_>>();
        let range = self.fun().evaluate_bounds(&children_range)?;
        Ok(ExprProperties {
            sort_properties,
            range,
            preserves_lex_ordering,
        })
    }

    /// Writes `name(arg1, arg2, ...)` using each argument's SQL rendering.
    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}(", self.name)?;
        for (i, expr) in self.args.iter().enumerate() {
            if i > 0 {
                write!(f, ", ")?;
            }
            expr.fmt_sql(f)?;
        }
        write!(f, ")")
    }

    /// A node is volatile exactly when its UDF's signature says so.
    fn is_volatile_node(&self) -> bool {
        self.fun.signature().volatility == Volatility::Volatile
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Column;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature};
    use datafusion_physical_expr_common::physical_expr::is_volatile;
    use std::any::Any;

    /// Minimal UDF for testing: behavior is fixed (always returns `42`);
    /// only the supplied `Signature` — and thus the volatility — varies.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockScalarUDF {
        signature: Signature,
    }

    impl ScalarUDFImpl for MockScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn name(&self) -> &str {
            "mock_function"
        }
        fn signature(&self) -> &Signature {
            &self.signature
        }
        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }
        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42))))
        }
    }

    /// `is_volatile_node` must mirror the UDF signature's volatility, and the
    /// free function `is_volatile` must agree for both volatile and stable UDFs.
    #[test]
    fn test_scalar_function_volatile_node() {
        let volatile_udf = Arc::new(ScalarUDF::from(MockScalarUDF {
            signature: Signature::uniform(
                1,
                vec![DataType::Float32],
                Volatility::Volatile,
            ),
        }));
        let stable_udf = Arc::new(ScalarUDF::from(MockScalarUDF {
            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
        }));
        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
        let args = vec![Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>];
        let config_options = Arc::new(ConfigOptions::new());

        // Volatile UDF: flagged both directly and through the helper.
        let volatile_expr = ScalarFunctionExpr::try_new(
            volatile_udf,
            args.clone(),
            &schema,
            Arc::clone(&config_options),
        )
        .unwrap();
        assert!(volatile_expr.is_volatile_node());
        let volatile_arc: Arc<dyn PhysicalExpr> = Arc::new(volatile_expr);
        assert!(is_volatile(&volatile_arc));

        // Stable UDF: not flagged by either path.
        let stable_expr =
            ScalarFunctionExpr::try_new(stable_udf, args, &schema, config_options)
                .unwrap();
        assert!(!stable_expr.is_volatile_node());
        let stable_arc: Arc<dyn PhysicalExpr> = Arc::new(stable_expr);
        assert!(!is_volatile(&stable_arc));
    }
}