datafusion_python/
udtf.rs1use std::sync::Arc;
19
20use datafusion::catalog::{TableFunctionImpl, TableProvider};
21use datafusion::error::Result as DataFusionResult;
22use datafusion::logical_expr::Expr;
23use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction};
24use pyo3::prelude::*;
25use pyo3::types::{PyCapsule, PyTuple};
26
27use crate::errors::{py_datafusion_err, to_datafusion_err};
28use crate::expr::PyExpr;
29use crate::table::PyTable;
30use crate::utils::validate_pycapsule;
31
32#[pyclass(frozen, name = "TableFunction", module = "datafusion")]
34#[derive(Debug, Clone)]
35pub struct PyTableFunction {
36 pub(crate) name: String,
37 pub(crate) inner: PyTableFunctionInner,
38}
39
40#[derive(Debug, Clone)]
42pub(crate) enum PyTableFunctionInner {
43 PythonFunction(Arc<Py<PyAny>>),
44 FFIFunction(Arc<dyn TableFunctionImpl>),
45}
46
47#[pymethods]
48impl PyTableFunction {
49 #[new]
50 #[pyo3(signature=(name, func))]
51 pub fn new(name: &str, func: Bound<'_, PyAny>) -> PyResult<Self> {
52 let inner = if func.hasattr("__datafusion_table_function__")? {
53 let capsule = func.getattr("__datafusion_table_function__")?.call0()?;
54 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
55 validate_pycapsule(capsule, "datafusion_table_function")?;
56
57 let ffi_func = unsafe { capsule.reference::<FFI_TableFunction>() };
58 let foreign_func: ForeignTableFunction = ffi_func.to_owned().into();
59
60 PyTableFunctionInner::FFIFunction(Arc::new(foreign_func))
61 } else {
62 let py_obj = Arc::new(func.unbind());
63 PyTableFunctionInner::PythonFunction(py_obj)
64 };
65
66 Ok(Self {
67 name: name.to_string(),
68 inner,
69 })
70 }
71
72 #[pyo3(signature = (*args))]
73 pub fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyTable> {
74 let args: Vec<Expr> = args.iter().map(|e| e.expr.clone()).collect();
75 let table_provider = self.call(&args).map_err(py_datafusion_err)?;
76
77 Ok(PyTable::from(table_provider))
78 }
79
80 fn __repr__(&self) -> PyResult<String> {
81 Ok(format!("TableUDF({})", self.name))
82 }
83}
84
85#[allow(clippy::result_large_err)]
86fn call_python_table_function(
87 func: &Arc<Py<PyAny>>,
88 args: &[Expr],
89) -> DataFusionResult<Arc<dyn TableProvider>> {
90 let args = args
91 .iter()
92 .map(|arg| PyExpr::from(arg.clone()))
93 .collect::<Vec<_>>();
94
95 Python::attach(|py| {
97 let py_args = PyTuple::new(py, args)?;
98 let provider_obj = func.call1(py, py_args)?;
99 let provider = provider_obj.bind(py);
100
101 Ok::<Arc<dyn TableProvider>, PyErr>(PyTable::new(provider)?.table)
102 })
103 .map_err(to_datafusion_err)
104}
105
106impl TableFunctionImpl for PyTableFunction {
107 fn call(&self, args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
108 match &self.inner {
109 PyTableFunctionInner::FFIFunction(func) => func.call(args),
110 PyTableFunctionInner::PythonFunction(obj) => call_python_table_function(obj, args),
111 }
112 }
113}