datafusion_python/
udtf.rs1use pyo3::prelude::*;
19use std::sync::Arc;
20
21use crate::dataframe::PyTableProvider;
22use crate::errors::{py_datafusion_err, to_datafusion_err};
23use crate::expr::PyExpr;
24use crate::utils::validate_pycapsule;
25use datafusion::catalog::{TableFunctionImpl, TableProvider};
26use datafusion::error::Result as DataFusionResult;
27use datafusion::logical_expr::Expr;
28use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
29use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction};
30use pyo3::exceptions::PyNotImplementedError;
31use pyo3::types::{PyCapsule, PyTuple};
32
33#[pyclass(name = "TableFunction", module = "datafusion")]
35#[derive(Debug, Clone)]
36pub struct PyTableFunction {
37 pub(crate) name: String,
38 pub(crate) inner: PyTableFunctionInner,
39}
40
41#[derive(Debug, Clone)]
43pub(crate) enum PyTableFunctionInner {
44 PythonFunction(Arc<PyObject>),
45 FFIFunction(Arc<dyn TableFunctionImpl>),
46}
47
48#[pymethods]
49impl PyTableFunction {
50 #[new]
51 #[pyo3(signature=(name, func))]
52 pub fn new(name: &str, func: Bound<'_, PyAny>) -> PyResult<Self> {
53 let inner = if func.hasattr("__datafusion_table_function__")? {
54 let capsule = func.getattr("__datafusion_table_function__")?.call0()?;
55 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
56 validate_pycapsule(capsule, "datafusion_table_function")?;
57
58 let ffi_func = unsafe { capsule.reference::<FFI_TableFunction>() };
59 let foreign_func: ForeignTableFunction = ffi_func.to_owned().into();
60
61 PyTableFunctionInner::FFIFunction(Arc::new(foreign_func))
62 } else {
63 let py_obj = Arc::new(func.unbind());
64 PyTableFunctionInner::PythonFunction(py_obj)
65 };
66
67 Ok(Self {
68 name: name.to_string(),
69 inner,
70 })
71 }
72
73 #[pyo3(signature = (*args))]
74 pub fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyTableProvider> {
75 let args: Vec<Expr> = args.iter().map(|e| e.expr.clone()).collect();
76 let table_provider = self.call(&args).map_err(py_datafusion_err)?;
77
78 Ok(PyTableProvider::new(table_provider))
79 }
80
81 fn __repr__(&self) -> PyResult<String> {
82 Ok(format!("TableUDF({})", self.name))
83 }
84}
85
86#[allow(clippy::result_large_err)]
87fn call_python_table_function(
88 func: &Arc<PyObject>,
89 args: &[Expr],
90) -> DataFusionResult<Arc<dyn TableProvider>> {
91 let args = args
92 .iter()
93 .map(|arg| PyExpr::from(arg.clone()))
94 .collect::<Vec<_>>();
95
96 Python::with_gil(|py| {
98 let py_args = PyTuple::new(py, args)?;
99 let provider_obj = func.call1(py, py_args)?;
100 let provider = provider_obj.bind(py);
101
102 if provider.hasattr("__datafusion_table_provider__")? {
103 let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
104 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
105 validate_pycapsule(capsule, "datafusion_table_provider")?;
106
107 let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
108 let provider: ForeignTableProvider = provider.into();
109
110 Ok(Arc::new(provider) as Arc<dyn TableProvider>)
111 } else {
112 Err(PyNotImplementedError::new_err(
113 "__datafusion_table_provider__ does not exist on Table Provider object.",
114 ))
115 }
116 })
117 .map_err(to_datafusion_err)
118}
119
120impl TableFunctionImpl for PyTableFunction {
121 fn call(&self, args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
122 match &self.inner {
123 PyTableFunctionInner::FFIFunction(func) => func.call(args),
124 PyTableFunctionInner::PythonFunction(obj) => call_python_table_function(obj, args),
125 }
126 }
127}