datafusion_python/
udtf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use pyo3::prelude::*;
19use std::sync::Arc;
20
21use crate::errors::{py_datafusion_err, to_datafusion_err};
22use crate::expr::PyExpr;
23use crate::table::PyTable;
24use crate::utils::validate_pycapsule;
25use datafusion::catalog::{TableFunctionImpl, TableProvider};
26use datafusion::error::Result as DataFusionResult;
27use datafusion::logical_expr::Expr;
28use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction};
29use pyo3::types::{PyCapsule, PyTuple};
30
31/// Represents a user defined table function
32#[pyclass(frozen, name = "TableFunction", module = "datafusion")]
33#[derive(Debug, Clone)]
34pub struct PyTableFunction {
35    pub(crate) name: String,
36    pub(crate) inner: PyTableFunctionInner,
37}
38
39// TODO: Implement pure python based user defined table functions
40#[derive(Debug, Clone)]
41pub(crate) enum PyTableFunctionInner {
42    PythonFunction(Arc<PyObject>),
43    FFIFunction(Arc<dyn TableFunctionImpl>),
44}
45
46#[pymethods]
47impl PyTableFunction {
48    #[new]
49    #[pyo3(signature=(name, func))]
50    pub fn new(name: &str, func: Bound<'_, PyAny>) -> PyResult<Self> {
51        let inner = if func.hasattr("__datafusion_table_function__")? {
52            let capsule = func.getattr("__datafusion_table_function__")?.call0()?;
53            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
54            validate_pycapsule(capsule, "datafusion_table_function")?;
55
56            let ffi_func = unsafe { capsule.reference::<FFI_TableFunction>() };
57            let foreign_func: ForeignTableFunction = ffi_func.to_owned().into();
58
59            PyTableFunctionInner::FFIFunction(Arc::new(foreign_func))
60        } else {
61            let py_obj = Arc::new(func.unbind());
62            PyTableFunctionInner::PythonFunction(py_obj)
63        };
64
65        Ok(Self {
66            name: name.to_string(),
67            inner,
68        })
69    }
70
71    #[pyo3(signature = (*args))]
72    pub fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyTable> {
73        let args: Vec<Expr> = args.iter().map(|e| e.expr.clone()).collect();
74        let table_provider = self.call(&args).map_err(py_datafusion_err)?;
75
76        Ok(PyTable::from(table_provider))
77    }
78
79    fn __repr__(&self) -> PyResult<String> {
80        Ok(format!("TableUDF({})", self.name))
81    }
82}
83
84#[allow(clippy::result_large_err)]
85fn call_python_table_function(
86    func: &Arc<PyObject>,
87    args: &[Expr],
88) -> DataFusionResult<Arc<dyn TableProvider>> {
89    let args = args
90        .iter()
91        .map(|arg| PyExpr::from(arg.clone()))
92        .collect::<Vec<_>>();
93
94    // move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
95    Python::with_gil(|py| {
96        let py_args = PyTuple::new(py, args)?;
97        let provider_obj = func.call1(py, py_args)?;
98        let provider = provider_obj.bind(py);
99
100        Ok::<Arc<dyn TableProvider>, PyErr>(PyTable::new(provider)?.table)
101    })
102    .map_err(to_datafusion_err)
103}
104
105impl TableFunctionImpl for PyTableFunction {
106    fn call(&self, args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
107        match &self.inner {
108            PyTableFunctionInner::FFIFunction(func) => func.call(args),
109            PyTableFunctionInner::PythonFunction(obj) => call_python_table_function(obj, args),
110        }
111    }
112}