use std::any::Any;
use std::hash::{Hash, Hasher};
use std::ptr::NonNull;
use std::sync::Arc;
use arrow::datatypes::{Field, FieldRef};
use arrow::pyarrow::ToPyArrow;
use datafusion::arrow::array::{ArrayData, make_array};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType};
use datafusion::common::internal_err;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_ffi::udf::FFI_ScalarUDF;
use datafusion_python_util::parse_volatility;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple};
use crate::array::PyArrowArrayExportable;
use crate::errors::{PyDataFusionResult, to_datafusion_err};
use crate::expr::PyExpr;
/// A scalar UDF whose body is a Python callable.
///
/// The type implements `ScalarUDFImpl` further down in this file: on
/// invocation the arguments are exported to PyArrow, `func` is called with
/// them, and the returned PyArrow array is imported back.
#[derive(Debug)]
struct PythonFunctionScalarUDF {
/// Name reported through `ScalarUDFImpl::name`.
name: String,
/// The Python object that is called to evaluate the function.
func: Py<PyAny>,
/// Exact-match signature built from the declared input field types and the
/// volatility passed to `new`.
signature: Signature,
/// Field returned (by cloning the `Arc`) from `return_field_from_args`,
/// independent of the call-site arguments.
return_field: FieldRef,
}
impl PythonFunctionScalarUDF {
fn new(
name: String,
func: Py<PyAny>,
input_fields: Vec<Field>,
return_field: Field,
volatility: Volatility,
) -> Self {
let input_types = input_fields.iter().map(|f| f.data_type().clone()).collect();
let signature = Signature::exact(input_types, volatility);
Self {
name,
func,
signature,
return_field: Arc::new(return_field),
}
}
}
impl Eq for PythonFunctionScalarUDF {}

impl PartialEq for PythonFunctionScalarUDF {
    /// Two UDFs are equal when name, signature and return field match and the
    /// Python callables compare equal under Python's `==`.
    fn eq(&self, other: &Self) -> bool {
        // Compare the pure-Rust parts first so Python is only consulted when
        // everything else already matches.
        if self.name != other.name
            || self.signature != other.signature
            || self.return_field != other.return_field
        {
            return false;
        }

        Python::attach(|py| {
            let lhs = self.func.bind(py);
            let rhs = other.func.bind(py);
            // A Python exception raised by the comparison counts as "not equal".
            lhs.eq(rhs).unwrap_or(false)
        })
    }
}
impl Hash for PythonFunctionScalarUDF {
    /// Feeds name, signature, return field and the Python-level hash of the
    /// callable into `state`, in that order.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.name.hash(state);
        self.signature.hash(state);
        self.return_field.hash(state);

        // If Python's `hash()` raises for the callable, 0 is used in its place.
        let func_hash = Python::attach(|py| self.func.bind(py).hash().unwrap_or(0));
        state.write_isize(func_hash);
    }
}
impl ScalarUDFImpl for PythonFunctionScalarUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always returns an internal error.
    ///
    /// The return type of this UDF is reported through
    /// [`Self::return_field_from_args`]; this method is not expected to be
    /// called.
    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::common::Result<DataType> {
        internal_err!(
            "return_type should not be called when return_field_from_args is implemented."
        )
    }

    /// Returns the return field supplied at construction time, regardless of
    /// the call-site arguments.
    fn return_field_from_args(
        &self,
        _args: ReturnFieldArgs,
    ) -> datafusion::common::Result<FieldRef> {
        Ok(Arc::clone(&self.return_field))
    }

    /// Evaluates the UDF by calling the wrapped Python callable.
    ///
    /// Each argument is materialised as an array via `to_array(number_rows)`,
    /// paired with its field, exported to PyArrow and passed positionally to
    /// the callable. The callable's return value is imported as Arrow array
    /// data and returned as `ColumnarValue::Array`.
    ///
    /// # Errors
    ///
    /// Returns an error if an argument cannot be converted or exported, if
    /// the Python call raises, or if the returned object cannot be imported
    /// as Arrow array data. The last two are reported as
    /// `DataFusionError::Execution`.
    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion::common::Result<ColumnarValue> {
        let num_rows = args.number_rows;
        Python::attach(|py| {
            let py_args = args
                .args
                .into_iter()
                .zip(args.arg_fields)
                .map(|(arg, field)| {
                    let array = arg.to_array(num_rows)?;
                    PyArrowArrayExportable::new(array, field)
                        .to_pyarrow(py)
                        .map_err(to_datafusion_err)
                })
                .collect::<Result<Vec<_>, _>>()?;
            let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;

            // Positional call only; no keyword arguments are passed.
            let value = self
                .func
                .call(py, py_args, None)
                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;

            let array_data = ArrayData::from_pyarrow_bound(value.bind(py))
                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
            Ok(ColumnarValue::Array(make_array(array_data)))
        })
    }
}
// Python-visible wrapper around a DataFusion `ScalarUDF`.
//
// Per the attributes below it is exposed as `ScalarUDF` in the `datafusion`
// module, is frozen, and may be subclassed from Python.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on a
// `#[pyclass]` are picked up by PyO3 as the Python class docstring.
#[pyclass(
from_py_object,
frozen,
name = "ScalarUDF",
module = "datafusion",
subclass
)]
#[derive(Debug, Clone)]
pub struct PyScalarUDF {
// The wrapped DataFusion UDF; `__call__` and `__repr__` delegate to it.
pub(crate) function: ScalarUDF,
}
#[pymethods]
impl PyScalarUDF {
    // Constructor: wraps a Python callable as a scalar UDF with the given
    // input fields, return field and volatility (parsed from its string form).
    #[new]
    #[pyo3(signature=(name, func, input_types, return_type, volatility))]
    fn new(
        name: String,
        func: Py<PyAny>,
        input_types: PyArrowType<Vec<Field>>,
        return_type: PyArrowType<Field>,
        volatility: &str,
    ) -> PyResult<Self> {
        let volatility = parse_volatility(volatility)?;
        let udf_impl = PythonFunctionScalarUDF::new(
            name,
            func,
            input_types.0,
            return_type.0,
            volatility,
        );
        Ok(Self {
            function: ScalarUDF::new_from_impl(udf_impl),
        })
    }

    // Builds a `ScalarUDF` from an object exposing `__datafusion_scalar_udf__`,
    // a method expected to return a PyCapsule named `datafusion_scalar_udf`.
    // Fails with a `Common` error when the attribute is absent.
    #[staticmethod]
    pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
        if !func.hasattr("__datafusion_scalar_udf__")? {
            return Err(crate::errors::PyDataFusionError::Common(
                "__datafusion_scalar_udf__ does not exist on ScalarUDF object.".to_string(),
            ));
        }

        let exported = func.getattr("__datafusion_scalar_udf__")?.call0()?;
        let capsule = exported.cast::<PyCapsule>().map_err(to_datafusion_err)?;
        let ffi_ptr: NonNull<FFI_ScalarUDF> = capsule
            .pointer_checked(Some(c"datafusion_scalar_udf"))?
            .cast();

        // SAFETY: the capsule name was checked by `pointer_checked` above and
        // `exported` keeps the capsule alive for the duration of this borrow.
        // This assumes the producer stored a valid `FFI_ScalarUDF` under that
        // name -- NOTE(review): confirm against the capsule producer contract.
        let ffi_udf = unsafe { ffi_ptr.as_ref() };
        let shared_impl: Arc<dyn ScalarUDFImpl> = ffi_udf.into();

        Ok(Self {
            function: ScalarUDF::new_from_shared_impl(shared_impl),
        })
    }

    // Applies the UDF to the given expressions, producing a new expression.
    #[pyo3(signature = (*args))]
    fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
        let mut exprs = Vec::with_capacity(args.len());
        for py_expr in &args {
            exprs.push(py_expr.expr.clone());
        }
        let call_expr: PyExpr = self.function.call(exprs).into();
        Ok(call_expr)
    }

    // Python `repr()`: `ScalarUDF(<name>)`.
    fn __repr__(&self) -> PyResult<String> {
        let udf_name = self.function.name();
        Ok(format!("ScalarUDF({})", udf_name))
    }
}