1use std::sync::Arc;
19
20use pyo3::{prelude::*, types::PyTuple};
21
22use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
23use datafusion::arrow::datatypes::DataType;
24use datafusion::arrow::pyarrow::FromPyArrow;
25use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
26use datafusion::error::DataFusionError;
27use datafusion::logical_expr::function::ScalarFunctionImplementation;
28use datafusion::logical_expr::ScalarUDF;
29use datafusion::logical_expr::{create_udf, ColumnarValue};
30
31use crate::errors::to_datafusion_err;
32use crate::expr::PyExpr;
33use crate::utils::parse_volatility;
34
35fn pyarrow_function_to_rust(
37 func: PyObject,
38) -> impl Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
39 move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
40 Python::with_gil(|py| {
41 let py_args = args
43 .iter()
44 .map(|arg| {
45 arg.into_data()
46 .to_pyarrow(py)
47 .map_err(|e| DataFusionError::Execution(format!("{e:?}")))
48 })
49 .collect::<Result<Vec<_>, _>>()?;
50 let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;
51
52 let value = func
54 .call(py, py_args, None)
55 .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
56
57 let array_data = ArrayData::from_pyarrow_bound(value.bind(py))
59 .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
60 Ok(make_array(array_data))
61 })
62 }
63}
64
65fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
69 let pyarrow_func = pyarrow_function_to_rust(func);
71
72 Arc::new(move |args: &[ColumnarValue]| {
74 let array_refs = ColumnarValue::values_to_arrays(args)?;
75 let array_result = pyarrow_func(&array_refs)?;
76 Ok(array_result.into())
77 })
78}
79
/// Python-facing wrapper around a DataFusion [`ScalarUDF`].
///
/// Exposed to Python as the class `ScalarUDF` in the `datafusion` module and
/// declared with `subclass`, so Python code may derive from it.
#[pyclass(name = "ScalarUDF", module = "datafusion", subclass)]
#[derive(Debug, Clone)]
pub struct PyScalarUDF {
    /// The wrapped DataFusion scalar UDF; crate-visible so other bindings
    /// can reach it directly.
    pub(crate) function: ScalarUDF,
}
86
87#[pymethods]
88impl PyScalarUDF {
89 #[new]
90 #[pyo3(signature=(name, func, input_types, return_type, volatility))]
91 fn new(
92 name: &str,
93 func: PyObject,
94 input_types: PyArrowType<Vec<DataType>>,
95 return_type: PyArrowType<DataType>,
96 volatility: &str,
97 ) -> PyResult<Self> {
98 let function = create_udf(
99 name,
100 input_types.0,
101 return_type.0,
102 parse_volatility(volatility)?,
103 to_scalar_function_impl(func),
104 );
105 Ok(Self { function })
106 }
107
108 #[pyo3(signature = (*args))]
110 fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
111 let args = args.iter().map(|e| e.expr.clone()).collect();
112 Ok(self.function.call(args).into())
113 }
114
115 fn __repr__(&self) -> PyResult<String> {
116 Ok(format!("ScalarUDF({})", self.function.name()))
117 }
118}