use crate::physical_expr::down_cast_any_ref;
use crate::utils::expr_list_eq_strict_order;
use crate::PhysicalExpr;
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_expr::BuiltinScalarFunction;
use datafusion_expr::ColumnarValue;
use datafusion_expr::ScalarFunctionImplementation;
use std::any::Any;
use std::fmt::Debug;
use std::fmt::{self, Formatter};
use std::sync::Arc;
/// Physical expression that applies a scalar function implementation to the
/// values produced by its argument expressions.
pub struct ScalarFunctionExpr {
    /// Function name, used for display, equality, and for recognising
    /// built-in functions that accept zero arguments.
    name: String,
    /// Implementation invoked with the evaluated argument values.
    fun: ScalarFunctionImplementation,
    /// Argument expressions, evaluated before `fun` is called.
    args: Vec<Arc<dyn PhysicalExpr>>,
    /// Output data type reported for this expression.
    return_type: DataType,
}
impl Debug for ScalarFunctionExpr {
    /// Debug-formats the expression. The function implementation itself is
    /// not printed; the placeholder `"<FUNC>"` appears in its place.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        // Explicit destructuring: adding a field becomes a compile error here
        // until the Debug output is updated to account for it.
        let Self {
            name,
            fun: _,
            args,
            return_type,
        } = self;

        let mut builder = f.debug_struct("ScalarFunctionExpr");
        builder.field("fun", &"<FUNC>");
        builder.field("name", name);
        builder.field("args", args);
        builder.field("return_type", return_type);
        builder.finish()
    }
}
impl ScalarFunctionExpr {
pub fn new(
name: &str,
fun: ScalarFunctionImplementation,
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: &DataType,
) -> Self {
Self {
fun,
name: name.to_owned(),
args,
return_type: return_type.clone(),
}
}
pub fn fun(&self) -> &ScalarFunctionImplementation {
&self.fun
}
pub fn name(&self) -> &str {
&self.name
}
pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
&self.args
}
pub fn return_type(&self) -> &DataType {
&self.return_type
}
}
impl fmt::Display for ScalarFunctionExpr {
    /// Renders the expression as `name(arg1, arg2, ...)`, using each
    /// argument's own `Display` output separated by `", "`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}(", self.name)?;
        // Write each argument straight into the formatter rather than
        // building a `String` per argument plus a joined `Vec<String>`.
        // This also lets a `Display` error from an argument propagate as
        // `fmt::Error` instead of panicking inside `format!`.
        for (idx, arg) in self.args.iter().enumerate() {
            if idx > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{arg}")?;
        }
        f.write_str(")")
    }
}
impl PhysicalExpr for ScalarFunctionExpr {
    /// Returns `self` as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the return type fixed at construction; the input schema is
    /// not consulted.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(self.return_type.clone())
    }

    /// Always reports `true` (nullable), regardless of the input schema.
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(true)
    }

    /// Evaluates every argument against `batch` and passes the resulting
    /// values to the wrapped function implementation.
    ///
    /// Special case: when there are no arguments and `name` parses as a
    /// [`BuiltinScalarFunction`] whose `supports_zero_argument()` is true,
    /// the function receives a single value built by
    /// `ColumnarValue::create_null_array(batch.num_rows())`. This conveys
    /// the batch's row count.
    ///
    /// # Errors
    /// Returns the first error produced while evaluating an argument, or the
    /// error returned by the function implementation.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        // The parsed function name only matters when there are no
        // arguments, so `is_empty()` short-circuits the string parse.
        // `evaluate` may run for many batches, so this avoids repeated
        // work on the common path.
        let pass_row_count = self.args.is_empty()
            && matches!(
                self.name.parse::<BuiltinScalarFunction>(),
                Ok(scalar_fun) if scalar_fun.supports_zero_argument()
            );

        let inputs = if pass_row_count {
            vec![ColumnarValue::create_null_array(batch.num_rows())]
        } else {
            self.args
                .iter()
                .map(|e| e.evaluate(batch))
                .collect::<Result<Vec<_>>>()?
        };

        let fun = self.fun.as_ref();
        (fun)(&inputs)
    }

    /// The argument expressions are this expression's children.
    fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.args.clone()
    }

    /// Builds a copy of this expression with `children` as its arguments,
    /// keeping the same name, function implementation and return type.
    ///
    /// NOTE(review): the number of `children` is not checked against the
    /// current argument count.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(ScalarFunctionExpr::new(
            &self.name,
            self.fun.clone(),
            children,
            self.return_type(),
        )))
    }
}
impl PartialEq<dyn Any> for ScalarFunctionExpr {
    /// Two expressions are equal when `other` downcasts to a
    /// `ScalarFunctionExpr` with:
    /// - the same name,
    /// - arguments that compare equal under `expr_list_eq_strict_order`,
    /// - the same return type.
    ///
    /// The function implementation (`fun`) is not compared.
    fn eq(&self, other: &dyn Any) -> bool {
        match down_cast_any_ref(other).downcast_ref::<Self>() {
            Some(rhs) => {
                self.name == rhs.name
                    && expr_list_eq_strict_order(&self.args, &rhs.args)
                    && self.return_type == rhs.return_type
            }
            None => false,
        }
    }
}