use crate::ScalarFunctionExpr;
use arrow::array::{make_array, MutableArrayData, RecordBatch};
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_common::{internal_err, not_impl_err};
use datafusion_expr::async_udf::AsyncScalarUDF;
use datafusion_expr::ScalarFunctionArgs;
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::any::Any;
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
/// Physical expression wrapping an asynchronous scalar UDF invocation.
///
/// `func` is expected to be a [`ScalarFunctionExpr`] (validated in
/// [`AsyncFuncExpr::try_new`]); the resolved output field is cached in
/// `return_field` so it does not need to be recomputed per invocation.
#[derive(Debug, Clone, Eq)]
pub struct AsyncFuncExpr {
    // Name of the output column produced by this expression.
    pub name: String,
    // The wrapped function; expected to be a `ScalarFunctionExpr` backed by
    // an `AsyncScalarUDF` (checked in `try_new`).
    pub func: Arc<dyn PhysicalExpr>,
    // Output field (type/nullability) resolved once at construction time.
    // Note: not part of `PartialEq`/`Hash` — it is derived from `func`.
    return_field: FieldRef,
}
impl Display for AsyncFuncExpr {
    /// Renders as `async_expr(name=<name>, expr=<inner expr>)`, matching the
    /// format used in plan output.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let Self { name, func, .. } = self;
        write!(f, "async_expr(name={name}, expr={func})")
    }
}
impl PartialEq for AsyncFuncExpr {
fn eq(&self, other: &Self) -> bool {
self.name == other.name && self.func == Arc::clone(&other.func)
}
}
impl Hash for AsyncFuncExpr {
    /// Hashes the output name and the wrapped expression. `return_field` is
    /// intentionally excluded, mirroring the `PartialEq` impl.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self { name, func, .. } = self;
        name.hash(state);
        func.as_ref().hash(state);
    }
}
impl AsyncFuncExpr {
/// Creates a new `AsyncFuncExpr`.
///
/// The output field is resolved against `schema` once here and cached.
///
/// # Errors
/// Returns an internal error if `func` is not a [`ScalarFunctionExpr`], and
/// propagates any error from resolving the function's return field.
pub fn try_new(
    name: impl Into<String>,
    func: Arc<dyn PhysicalExpr>,
    schema: &Schema,
) -> Result<Self> {
    // Only `ScalarFunctionExpr` is supported; reject anything else up front.
    if func.as_any().downcast_ref::<ScalarFunctionExpr>().is_none() {
        return internal_err!(
            "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
            func
        );
    }
    let return_field = func.return_field(schema)?;
    let name = name.into();
    Ok(Self {
        name,
        func,
        return_field,
    })
}
/// Returns the name of the output column produced by this expression.
pub fn name(&self) -> &str {
    self.name.as_str()
}
/// Returns the output [`Field`] of this expression when evaluated against
/// `input_schema`, delegating type and nullability to the wrapped function.
pub fn field(&self, input_schema: &Schema) -> Result<Field> {
    let data_type = self.func.data_type(input_schema)?;
    let nullable = self.func.nullable(input_schema)?;
    Ok(Field::new(&self.name, data_type, nullable))
}
/// Returns the ideal batch size (if any) declared by the underlying
/// [`AsyncScalarUDF`].
///
/// # Errors
/// Returns a not-implemented error when the wrapped expression is not a
/// `ScalarFunctionExpr` backed by an async UDF.
pub fn ideal_batch_size(&self) -> Result<Option<usize>> {
    let async_udf = self
        .func
        .as_any()
        .downcast_ref::<ScalarFunctionExpr>()
        .and_then(|expr| {
            expr.fun().inner().as_any().downcast_ref::<AsyncScalarUDF>()
        });
    match async_udf {
        Some(udf) => Ok(udf.ideal_batch_size()),
        None => not_impl_err!("Can't get ideal_batch_size from {:?}", self.func),
    }
}
/// Invokes the wrapped async scalar UDF on `batch` and returns the result
/// as a single [`ColumnarValue::Array`].
///
/// If the UDF declares an ideal batch size, the input is sliced into chunks
/// of at most that many rows, each chunk is sent through a separate async
/// invocation, and the per-chunk results are concatenated at the end.
///
/// # Errors
/// - internal error if the wrapped expression is not a [`ScalarFunctionExpr`]
/// - not-implemented error if the scalar function is not an [`AsyncScalarUDF`]
/// - any error raised while evaluating arguments or invoking the UDF
pub async fn invoke_with_args(
    &self,
    batch: &RecordBatch,
    config_options: Arc<ConfigOptions>,
) -> Result<ColumnarValue> {
    let Some(scalar_function_expr) =
        self.func.as_any().downcast_ref::<ScalarFunctionExpr>()
    else {
        return internal_err!(
            "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
            self.func
        );
    };
    let Some(async_udf) = scalar_function_expr
        .fun()
        .inner()
        .as_any()
        .downcast_ref::<AsyncScalarUDF>()
    else {
        return not_impl_err!(
            "Don't know how to evaluate async function: {:?}",
            scalar_function_expr
        );
    };
    // Resolve the argument fields once; they are identical for every chunk.
    let arg_fields = scalar_function_expr
        .args()
        .iter()
        .map(|e| e.return_field(batch.schema_ref()))
        .collect::<Result<Vec<_>>>()?;
    let mut result_batches = vec![];
    if let Some(ideal_batch_size) = self.ideal_batch_size()? {
        // Feed the UDF slices of at most `ideal_batch_size` rows.
        let mut remainder = batch.clone();
        while remainder.num_rows() > 0 {
            let size = ideal_batch_size.min(remainder.num_rows());
            let current_batch = remainder.slice(0, size);
            remainder = remainder.slice(size, remainder.num_rows() - size);
            // Fix: the argument here was a mangled token (`¤t_batch`,
            // mojibake for `&current_batch`), which did not compile.
            let args = scalar_function_expr
                .args()
                .iter()
                .map(|e| e.evaluate(&current_batch))
                .collect::<Result<Vec<_>>>()?;
            result_batches.push(
                async_udf
                    .invoke_async_with_args(ScalarFunctionArgs {
                        args,
                        arg_fields: arg_fields.clone(),
                        number_rows: current_batch.num_rows(),
                        return_field: Arc::clone(&self.return_field),
                        config_options: Arc::clone(&config_options),
                    })
                    .await?,
            );
        }
    } else {
        // No preferred batch size: evaluate the whole batch in one call.
        let args = scalar_function_expr
            .args()
            .iter()
            .map(|e| e.evaluate(batch))
            .collect::<Result<Vec<_>>>()?;
        result_batches.push(
            async_udf
                .invoke_async_with_args(ScalarFunctionArgs {
                    // `args` is already an owned Vec; the previous
                    // `args.to_vec()` cloned it for no benefit.
                    args,
                    arg_fields,
                    number_rows: batch.num_rows(),
                    return_field: Arc::clone(&self.return_field),
                    config_options: Arc::clone(&config_options),
                })
                .await?,
        );
    }
    // Concatenate the per-chunk result arrays into a single output array.
    let datas = ColumnarValue::values_to_arrays(&result_batches)?
        .iter()
        .map(|b| b.to_data())
        .collect::<Vec<_>>();
    let total_len = datas.iter().map(|d| d.len()).sum();
    let mut mutable = MutableArrayData::new(datas.iter().collect(), false, total_len);
    datas.iter().enumerate().for_each(|(i, data)| {
        mutable.extend(i, 0, data.len());
    });
    Ok(ColumnarValue::Array(make_array(mutable.freeze())))
}
}
impl PhysicalExpr for AsyncFuncExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output data type: delegates to the wrapped function.
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        self.func.data_type(input_schema)
    }

    /// Nullability: delegates to the wrapped function.
    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
        self.func.nullable(input_schema)
    }

    /// Synchronous evaluation is intentionally unsupported; results must be
    /// produced through the async `invoke_with_args` path instead.
    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
        not_impl_err!("AsyncFuncExpr.evaluate")
    }

    // NOTE(review): exposes the wrapped function's children (not the wrapped
    // function itself); `with_new_children` mirrors this by rebuilding the
    // inner function from the new children.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        self.func.children()
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        // Rebuild the inner function with the new children, keeping the name
        // and the previously resolved return field.
        let new_func = Arc::clone(&self.func).with_new_children(children)?;
        Ok(Arc::new(AsyncFuncExpr {
            name: self.name.clone(),
            func: new_func,
            return_field: Arc::clone(&self.return_field),
        }))
    }

    /// SQL rendering delegates to the inner expression's `Display`.
    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.func)
    }
}