datafusion_python/
udtf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use pyo3::prelude::*;
19use std::sync::Arc;
20
21use crate::dataframe::PyTableProvider;
22use crate::errors::{py_datafusion_err, to_datafusion_err};
23use crate::expr::PyExpr;
24use crate::utils::validate_pycapsule;
25use datafusion::catalog::{TableFunctionImpl, TableProvider};
26use datafusion::error::Result as DataFusionResult;
27use datafusion::logical_expr::Expr;
28use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
29use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction};
30use pyo3::exceptions::PyNotImplementedError;
31use pyo3::types::{PyCapsule, PyTuple};
32
33/// Represents a user defined table function
34#[pyclass(name = "TableFunction", module = "datafusion")]
35#[derive(Debug, Clone)]
36pub struct PyTableFunction {
37    pub(crate) name: String,
38    pub(crate) inner: PyTableFunctionInner,
39}
40
41// TODO: Implement pure python based user defined table functions
42#[derive(Debug, Clone)]
43pub(crate) enum PyTableFunctionInner {
44    PythonFunction(Arc<PyObject>),
45    FFIFunction(Arc<dyn TableFunctionImpl>),
46}
47
48#[pymethods]
49impl PyTableFunction {
50    #[new]
51    #[pyo3(signature=(name, func))]
52    pub fn new(name: &str, func: Bound<'_, PyAny>) -> PyResult<Self> {
53        let inner = if func.hasattr("__datafusion_table_function__")? {
54            let capsule = func.getattr("__datafusion_table_function__")?.call0()?;
55            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
56            validate_pycapsule(capsule, "datafusion_table_function")?;
57
58            let ffi_func = unsafe { capsule.reference::<FFI_TableFunction>() };
59            let foreign_func: ForeignTableFunction = ffi_func.to_owned().into();
60
61            PyTableFunctionInner::FFIFunction(Arc::new(foreign_func))
62        } else {
63            let py_obj = Arc::new(func.unbind());
64            PyTableFunctionInner::PythonFunction(py_obj)
65        };
66
67        Ok(Self {
68            name: name.to_string(),
69            inner,
70        })
71    }
72
73    #[pyo3(signature = (*args))]
74    pub fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyTableProvider> {
75        let args: Vec<Expr> = args.iter().map(|e| e.expr.clone()).collect();
76        let table_provider = self.call(&args).map_err(py_datafusion_err)?;
77
78        Ok(PyTableProvider::new(table_provider))
79    }
80
81    fn __repr__(&self) -> PyResult<String> {
82        Ok(format!("TableUDF({})", self.name))
83    }
84}
85
86#[allow(clippy::result_large_err)]
87fn call_python_table_function(
88    func: &Arc<PyObject>,
89    args: &[Expr],
90) -> DataFusionResult<Arc<dyn TableProvider>> {
91    let args = args
92        .iter()
93        .map(|arg| PyExpr::from(arg.clone()))
94        .collect::<Vec<_>>();
95
96    // move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
97    Python::with_gil(|py| {
98        let py_args = PyTuple::new(py, args)?;
99        let provider_obj = func.call1(py, py_args)?;
100        let provider = provider_obj.bind(py);
101
102        if provider.hasattr("__datafusion_table_provider__")? {
103            let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
104            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
105            validate_pycapsule(capsule, "datafusion_table_provider")?;
106
107            let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
108            let provider: ForeignTableProvider = provider.into();
109
110            Ok(Arc::new(provider) as Arc<dyn TableProvider>)
111        } else {
112            Err(PyNotImplementedError::new_err(
113                "__datafusion_table_provider__ does not exist on Table Provider object.",
114            ))
115        }
116    })
117    .map_err(to_datafusion_err)
118}
119
120impl TableFunctionImpl for PyTableFunction {
121    fn call(&self, args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
122        match &self.inner {
123            PyTableFunctionInner::FFIFunction(func) => func.call(args),
124            PyTableFunctionInner::PythonFunction(obj) => call_python_table_function(obj, args),
125        }
126    }
127}