datafusion_python/
udf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use pyo3::{prelude::*, types::PyTuple};
21
22use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
23use datafusion::arrow::datatypes::DataType;
24use datafusion::arrow::pyarrow::FromPyArrow;
25use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
26use datafusion::error::DataFusionError;
27use datafusion::logical_expr::function::ScalarFunctionImplementation;
28use datafusion::logical_expr::ScalarUDF;
29use datafusion::logical_expr::{create_udf, ColumnarValue};
30
31use crate::errors::to_datafusion_err;
32use crate::expr::PyExpr;
33use crate::utils::parse_volatility;
34
35/// Create a Rust callable function from a python function that expects pyarrow arrays
36fn pyarrow_function_to_rust(
37    func: PyObject,
38) -> impl Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
39    move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
40        Python::with_gil(|py| {
41            // 1. cast args to Pyarrow arrays
42            let py_args = args
43                .iter()
44                .map(|arg| {
45                    arg.into_data()
46                        .to_pyarrow(py)
47                        .map_err(|e| DataFusionError::Execution(format!("{e:?}")))
48                })
49                .collect::<Result<Vec<_>, _>>()?;
50            let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;
51
52            // 2. call function
53            let value = func
54                .call(py, py_args, None)
55                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
56
57            // 3. cast to arrow::array::Array
58            let array_data = ArrayData::from_pyarrow_bound(value.bind(py))
59                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
60            Ok(make_array(array_data))
61        })
62    }
63}
64
65/// Create a DataFusion's UDF implementation from a python function
66/// that expects pyarrow arrays. This is more efficient as it performs
67/// a zero-copy of the contents.
68fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
69    // Make the python function callable from rust
70    let pyarrow_func = pyarrow_function_to_rust(func);
71
72    // Convert input/output from datafusion ColumnarValue to arrow arrays
73    Arc::new(move |args: &[ColumnarValue]| {
74        let array_refs = ColumnarValue::values_to_arrays(args)?;
75        let array_result = pyarrow_func(&array_refs)?;
76        Ok(array_result.into())
77    })
78}
79
80/// Represents a PyScalarUDF
81#[pyclass(name = "ScalarUDF", module = "datafusion", subclass)]
82#[derive(Debug, Clone)]
83pub struct PyScalarUDF {
84    pub(crate) function: ScalarUDF,
85}
86
87#[pymethods]
88impl PyScalarUDF {
89    #[new]
90    #[pyo3(signature=(name, func, input_types, return_type, volatility))]
91    fn new(
92        name: &str,
93        func: PyObject,
94        input_types: PyArrowType<Vec<DataType>>,
95        return_type: PyArrowType<DataType>,
96        volatility: &str,
97    ) -> PyResult<Self> {
98        let function = create_udf(
99            name,
100            input_types.0,
101            return_type.0,
102            parse_volatility(volatility)?,
103            to_scalar_function_impl(func),
104        );
105        Ok(Self { function })
106    }
107
108    /// creates a new PyExpr with the call of the udf
109    #[pyo3(signature = (*args))]
110    fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
111        let args = args.iter().map(|e| e.expr.clone()).collect();
112        Ok(self.function.call(args).into())
113    }
114
115    fn __repr__(&self) -> PyResult<String> {
116        Ok(format!("ScalarUDF({})", self.function.name()))
117    }
118}