datafusion_python/
udtf.rs1use pyo3::prelude::*;
19use std::sync::Arc;
20
21use crate::errors::{py_datafusion_err, to_datafusion_err};
22use crate::expr::PyExpr;
23use crate::table::PyTable;
24use crate::utils::validate_pycapsule;
25use datafusion::catalog::{TableFunctionImpl, TableProvider};
26use datafusion::error::Result as DataFusionResult;
27use datafusion::logical_expr::Expr;
28use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction};
29use pyo3::types::{PyCapsule, PyTuple};
30
31#[pyclass(frozen, name = "TableFunction", module = "datafusion")]
33#[derive(Debug, Clone)]
34pub struct PyTableFunction {
35 pub(crate) name: String,
36 pub(crate) inner: PyTableFunctionInner,
37}
38
39#[derive(Debug, Clone)]
41pub(crate) enum PyTableFunctionInner {
42 PythonFunction(Arc<PyObject>),
43 FFIFunction(Arc<dyn TableFunctionImpl>),
44}
45
46#[pymethods]
47impl PyTableFunction {
48 #[new]
49 #[pyo3(signature=(name, func))]
50 pub fn new(name: &str, func: Bound<'_, PyAny>) -> PyResult<Self> {
51 let inner = if func.hasattr("__datafusion_table_function__")? {
52 let capsule = func.getattr("__datafusion_table_function__")?.call0()?;
53 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
54 validate_pycapsule(capsule, "datafusion_table_function")?;
55
56 let ffi_func = unsafe { capsule.reference::<FFI_TableFunction>() };
57 let foreign_func: ForeignTableFunction = ffi_func.to_owned().into();
58
59 PyTableFunctionInner::FFIFunction(Arc::new(foreign_func))
60 } else {
61 let py_obj = Arc::new(func.unbind());
62 PyTableFunctionInner::PythonFunction(py_obj)
63 };
64
65 Ok(Self {
66 name: name.to_string(),
67 inner,
68 })
69 }
70
71 #[pyo3(signature = (*args))]
72 pub fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyTable> {
73 let args: Vec<Expr> = args.iter().map(|e| e.expr.clone()).collect();
74 let table_provider = self.call(&args).map_err(py_datafusion_err)?;
75
76 Ok(PyTable::from(table_provider))
77 }
78
79 fn __repr__(&self) -> PyResult<String> {
80 Ok(format!("TableUDF({})", self.name))
81 }
82}
83
84#[allow(clippy::result_large_err)]
85fn call_python_table_function(
86 func: &Arc<PyObject>,
87 args: &[Expr],
88) -> DataFusionResult<Arc<dyn TableProvider>> {
89 let args = args
90 .iter()
91 .map(|arg| PyExpr::from(arg.clone()))
92 .collect::<Vec<_>>();
93
94 Python::with_gil(|py| {
96 let py_args = PyTuple::new(py, args)?;
97 let provider_obj = func.call1(py, py_args)?;
98 let provider = provider_obj.bind(py);
99
100 Ok::<Arc<dyn TableProvider>, PyErr>(PyTable::new(provider)?.table)
101 })
102 .map_err(to_datafusion_err)
103}
104
105impl TableFunctionImpl for PyTableFunction {
106 fn call(&self, args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
107 match &self.inner {
108 PyTableFunctionInner::FFIFunction(func) => func.call(args),
109 PyTableFunctionInner::PythonFunction(obj) => call_python_table_function(obj, args),
110 }
111 }
112}