1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use std::sync::Arc;
use pyo3::{prelude::*, types::PyTuple};
use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::{PyArrowConvert, PyArrowType};
use datafusion::error::DataFusionError;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::physical_plan::udf::ScalarUDF;
use datafusion_expr::create_udf;
use datafusion_expr::function::ScalarFunctionImplementation;
use crate::expr::PyExpr;
use crate::utils::parse_volatility;
fn to_rust_function(func: PyObject) -> ScalarFunctionImplementation {
make_scalar_function(
move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
Python::with_gil(|py| {
let py_args = args
.iter()
.map(|arg| arg.data().to_owned().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
let py_args = PyTuple::new(py, py_args);
let value = func
.as_ref(py)
.call(py_args, None)
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
let array_data = ArrayData::from_pyarrow(value).unwrap();
Ok(make_array(array_data))
})
},
)
}
/// A scalar user-defined function whose implementation is a Python callable.
#[pyclass(name = "ScalarUDF", module = "datafusion", subclass)]
#[derive(Debug, Clone)]
pub struct PyScalarUDF {
    /// The DataFusion UDF built from the Python callable.
    pub(crate) function: ScalarUDF,
}

#[pymethods]
impl PyScalarUDF {
    /// Creates a scalar UDF named `name` that evaluates the Python callable `func`.
    ///
    /// `input_types` and `return_type` are the Arrow data types of the
    /// function's arguments and result. `volatility` is parsed with
    /// `parse_volatility`, and its error is returned if the string is rejected.
    #[new]
    #[pyo3(signature = (name, func, input_types, return_type, volatility))]
    fn new(
        name: &str,
        func: PyObject,
        input_types: PyArrowType<Vec<DataType>>,
        return_type: PyArrowType<DataType>,
        volatility: &str,
    ) -> PyResult<Self> {
        let function = create_udf(
            name,
            input_types.0,
            Arc::new(return_type.0),
            parse_volatility(volatility)?,
            to_rust_function(func),
        );
        Ok(Self { function })
    }

    /// Builds an expression that applies this UDF to the given argument expressions.
    #[pyo3(signature = (*args))]
    fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
        // `args` is owned, so move each inner `Expr` out instead of cloning it.
        let args = args.into_iter().map(|e| e.expr).collect();
        Ok(self.function.call(args).into())
    }

    /// Returns a representation of the form `ScalarUDF(<name>)`.
    fn __repr__(&self) -> PyResult<String> {
        Ok(format!("ScalarUDF({})", self.function.name))
    }
}