Skip to main content

datafusion_python/
udtf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use datafusion::catalog::{TableFunctionImpl, TableProvider};
21use datafusion::error::Result as DataFusionResult;
22use datafusion::logical_expr::Expr;
23use datafusion_ffi::udtf::FFI_TableFunction;
24use pyo3::IntoPyObjectExt;
25use pyo3::exceptions::{PyImportError, PyTypeError};
26use pyo3::prelude::*;
27use pyo3::types::{PyCapsule, PyTuple, PyType};
28
29use crate::context::PySessionContext;
30use crate::errors::{py_datafusion_err, to_datafusion_err};
31use crate::expr::PyExpr;
32use crate::table::PyTable;
33use crate::utils::validate_pycapsule;
34
35/// Represents a user defined table function
36#[pyclass(frozen, name = "TableFunction", module = "datafusion")]
37#[derive(Debug, Clone)]
38pub struct PyTableFunction {
39    pub(crate) name: String,
40    pub(crate) inner: PyTableFunctionInner,
41}
42
43// TODO: Implement pure python based user defined table functions
44#[derive(Debug, Clone)]
45pub(crate) enum PyTableFunctionInner {
46    PythonFunction(Arc<Py<PyAny>>),
47    FFIFunction(Arc<dyn TableFunctionImpl>),
48}
49
50#[pymethods]
51impl PyTableFunction {
52    #[new]
53    #[pyo3(signature=(name, func, session))]
54    pub fn new(
55        name: &str,
56        func: Bound<'_, PyAny>,
57        session: Option<Bound<PyAny>>,
58    ) -> PyResult<Self> {
59        let inner = if func.hasattr("__datafusion_table_function__")? {
60            let py = func.py();
61            let session = match session {
62                Some(session) => session,
63                None => PySessionContext::global_ctx()?.into_bound_py_any(py)?,
64            };
65            let capsule = func
66                .getattr("__datafusion_table_function__")?
67                .call1((session,)).map_err(|err| {
68                if err.get_type(py).is(PyType::new::<PyTypeError>(py)) {
69                    PyImportError::new_err("Incompatible libraries. DataFusion 52.0.0 introduced an incompatible signature change for table functions. Either downgrade DataFusion or upgrade your function library.")
70                } else {
71                    err
72                }
73            })?;
74            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
75            validate_pycapsule(capsule, "datafusion_table_function")?;
76
77            let ffi_func = unsafe { capsule.reference::<FFI_TableFunction>() };
78            let foreign_func: Arc<dyn TableFunctionImpl> = ffi_func.to_owned().into();
79
80            PyTableFunctionInner::FFIFunction(foreign_func)
81        } else {
82            let py_obj = Arc::new(func.unbind());
83            PyTableFunctionInner::PythonFunction(py_obj)
84        };
85
86        Ok(Self {
87            name: name.to_string(),
88            inner,
89        })
90    }
91
92    #[pyo3(signature = (*args))]
93    pub fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyTable> {
94        let args: Vec<Expr> = args.iter().map(|e| e.expr.clone()).collect();
95        let table_provider = self.call(&args).map_err(py_datafusion_err)?;
96
97        Ok(PyTable::from(table_provider))
98    }
99
100    fn __repr__(&self) -> PyResult<String> {
101        Ok(format!("TableUDF({})", self.name))
102    }
103}
104
105#[allow(clippy::result_large_err)]
106fn call_python_table_function(
107    func: &Arc<Py<PyAny>>,
108    args: &[Expr],
109) -> DataFusionResult<Arc<dyn TableProvider>> {
110    let args = args
111        .iter()
112        .map(|arg| PyExpr::from(arg.clone()))
113        .collect::<Vec<_>>();
114
115    // move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
116    Python::attach(|py| {
117        let py_args = PyTuple::new(py, args)?;
118        let provider_obj = func.call1(py, py_args)?;
119        let provider = provider_obj.bind(py).clone();
120
121        Ok::<Arc<dyn TableProvider>, PyErr>(PyTable::new(provider, None)?.table)
122    })
123    .map_err(to_datafusion_err)
124}
125
126impl TableFunctionImpl for PyTableFunction {
127    fn call(&self, args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
128        match &self.inner {
129            PyTableFunctionInner::FFIFunction(func) => func.call(args),
130            PyTableFunctionInner::PythonFunction(obj) => call_python_table_function(obj, args),
131        }
132    }
133}