1use crate::dataset::Dataset;
19use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
20use crate::table::PyTable;
21use crate::utils::{validate_pycapsule, wait_for_future};
22use async_trait::async_trait;
23use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
24use datafusion::common::DataFusionError;
25use datafusion::{
26 catalog::{CatalogProvider, SchemaProvider},
27 datasource::TableProvider,
28};
29use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
30use pyo3::exceptions::PyKeyError;
31use pyo3::prelude::*;
32use pyo3::types::PyCapsule;
33use pyo3::IntoPyObjectExt;
34use std::any::Any;
35use std::collections::HashSet;
36use std::sync::Arc;
37
// Python-visible catalog handle, exported as `RawCatalog` in the `datafusion.catalog`
// module. The class is frozen, subclassable from Python, and cheap to clone because it
// only holds a shared pointer to the underlying provider.
#[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)]
#[derive(Clone)]
pub struct PyCatalog {
    // The catalog provider every method of this class delegates to.
    pub catalog: Arc<dyn CatalogProvider>,
}
43
// Python-visible schema handle, exported as `RawSchema` in the `datafusion.catalog`
// module. The class is frozen, subclassable from Python, and cheap to clone because it
// only holds a shared pointer to the underlying provider.
#[pyclass(frozen, name = "RawSchema", module = "datafusion.catalog", subclass)]
#[derive(Clone)]
pub struct PySchema {
    // The schema provider every method of this class delegates to.
    pub schema: Arc<dyn SchemaProvider>,
}
49
50impl From<Arc<dyn CatalogProvider>> for PyCatalog {
51 fn from(catalog: Arc<dyn CatalogProvider>) -> Self {
52 Self { catalog }
53 }
54}
55
56impl From<Arc<dyn SchemaProvider>> for PySchema {
57 fn from(schema: Arc<dyn SchemaProvider>) -> Self {
58 Self { schema }
59 }
60}
61
62#[pymethods]
63impl PyCatalog {
64 #[new]
65 fn new(catalog: PyObject) -> Self {
66 let catalog_provider =
67 Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc<dyn CatalogProvider>;
68 catalog_provider.into()
69 }
70
71 #[staticmethod]
72 fn memory_catalog() -> Self {
73 let catalog_provider =
74 Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
75 catalog_provider.into()
76 }
77
78 fn schema_names(&self) -> HashSet<String> {
79 self.catalog.schema_names().into_iter().collect()
80 }
81
82 #[pyo3(signature = (name="public"))]
83 fn schema(&self, name: &str) -> PyResult<PyObject> {
84 let schema = self
85 .catalog
86 .schema(name)
87 .ok_or(PyKeyError::new_err(format!(
88 "Schema with name {name} doesn't exist."
89 )))?;
90
91 Python::with_gil(|py| {
92 match schema
93 .as_any()
94 .downcast_ref::<RustWrappedPySchemaProvider>()
95 {
96 Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)),
97 None => PySchema::from(schema).into_py_any(py),
98 }
99 })
100 }
101
102 fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
103 let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? {
104 let capsule = schema_provider
105 .getattr("__datafusion_schema_provider__")?
106 .call0()?;
107 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
108 validate_pycapsule(capsule, "datafusion_schema_provider")?;
109
110 let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
111 let provider: ForeignSchemaProvider = provider.into();
112 Arc::new(provider) as Arc<dyn SchemaProvider>
113 } else {
114 match schema_provider.extract::<PySchema>() {
115 Ok(py_schema) => py_schema.schema,
116 Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
117 as Arc<dyn SchemaProvider>,
118 }
119 };
120
121 let _ = self
122 .catalog
123 .register_schema(name, provider)
124 .map_err(py_datafusion_err)?;
125
126 Ok(())
127 }
128
129 fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> {
130 let _ = self
131 .catalog
132 .deregister_schema(name, cascade)
133 .map_err(py_datafusion_err)?;
134
135 Ok(())
136 }
137
138 fn __repr__(&self) -> PyResult<String> {
139 let mut names: Vec<String> = self.schema_names().into_iter().collect();
140 names.sort();
141 Ok(format!("Catalog(schema_names=[{}])", names.join(", ")))
142 }
143}
144
145#[pymethods]
146impl PySchema {
147 #[new]
148 fn new(schema_provider: PyObject) -> Self {
149 let schema_provider =
150 Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc<dyn SchemaProvider>;
151 schema_provider.into()
152 }
153
154 #[staticmethod]
155 fn memory_schema() -> Self {
156 let schema_provider = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
157 schema_provider.into()
158 }
159
160 #[getter]
161 fn table_names(&self) -> HashSet<String> {
162 self.schema.table_names().into_iter().collect()
163 }
164
165 fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
166 if let Some(table) = wait_for_future(py, self.schema.table(name))?? {
167 Ok(PyTable::from(table))
168 } else {
169 Err(PyDataFusionError::Common(format!(
170 "Table not found: {name}"
171 )))
172 }
173 }
174
175 fn __repr__(&self) -> PyResult<String> {
176 let mut names: Vec<String> = self.table_names().into_iter().collect();
177 names.sort();
178 Ok(format!("Schema(table_names=[{}])", names.join(";")))
179 }
180
181 fn register_table(&self, name: &str, table_provider: &Bound<'_, PyAny>) -> PyResult<()> {
182 let table = PyTable::new(table_provider)?;
183
184 let _ = self
185 .schema
186 .register_table(name.to_string(), table.table)
187 .map_err(py_datafusion_err)?;
188
189 Ok(())
190 }
191
192 fn deregister_table(&self, name: &str) -> PyResult<()> {
193 let _ = self
194 .schema
195 .deregister_table(name)
196 .map_err(py_datafusion_err)?;
197
198 Ok(())
199 }
200}
201
/// Adapter that lets a Python object act as a DataFusion `SchemaProvider`.
///
/// Trait methods are implemented by acquiring the GIL and calling the identically named
/// attribute or method on the wrapped object.
#[derive(Debug)]
pub(crate) struct RustWrappedPySchemaProvider {
    /// The wrapped Python object that actually implements the schema.
    schema_provider: PyObject,
    /// Owner name captured from the object's `owner_name` attribute when the wrapper is
    /// constructed, if one could be obtained.
    owner_name: Option<String>,
}
207
208impl RustWrappedPySchemaProvider {
209 pub fn new(schema_provider: PyObject) -> Self {
210 let owner_name = Python::with_gil(|py| {
211 schema_provider
212 .bind(py)
213 .getattr("owner_name")
214 .ok()
215 .map(|name| name.to_string())
216 });
217
218 Self {
219 schema_provider,
220 owner_name,
221 }
222 }
223
224 fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
225 Python::with_gil(|py| {
226 let provider = self.schema_provider.bind(py);
227 let py_table_method = provider.getattr("table")?;
228
229 let py_table = py_table_method.call((name,), None)?;
230 if py_table.is_none() {
231 return Ok(None);
232 }
233
234 let table = PyTable::new(&py_table)?;
235
236 Ok(Some(table.table))
237 })
238 }
239}
240
241#[async_trait]
242impl SchemaProvider for RustWrappedPySchemaProvider {
243 fn owner_name(&self) -> Option<&str> {
244 self.owner_name.as_deref()
245 }
246
247 fn as_any(&self) -> &dyn Any {
248 self
249 }
250
251 fn table_names(&self) -> Vec<String> {
252 Python::with_gil(|py| {
253 let provider = self.schema_provider.bind(py);
254
255 provider
256 .getattr("table_names")
257 .and_then(|names| names.extract::<Vec<String>>())
258 .unwrap_or_else(|err| {
259 log::error!("Unable to get table_names: {err}");
260 Vec::default()
261 })
262 })
263 }
264
265 async fn table(
266 &self,
267 name: &str,
268 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
269 self.table_inner(name).map_err(to_datafusion_err)
270 }
271
272 fn register_table(
273 &self,
274 name: String,
275 table: Arc<dyn TableProvider>,
276 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
277 let py_table = PyTable::from(table);
278 Python::with_gil(|py| {
279 let provider = self.schema_provider.bind(py);
280 let _ = provider
281 .call_method1("register_table", (name, py_table))
282 .map_err(to_datafusion_err)?;
283 Ok(None)
287 })
288 }
289
290 fn deregister_table(
291 &self,
292 name: &str,
293 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
294 Python::with_gil(|py| {
295 let provider = self.schema_provider.bind(py);
296 let table = provider
297 .call_method1("deregister_table", (name,))
298 .map_err(to_datafusion_err)?;
299 if table.is_none() {
300 return Ok(None);
301 }
302
303 let dataset = match Dataset::new(&table, py) {
306 Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
307 Err(_) => None,
308 };
309
310 Ok(dataset)
311 })
312 }
313
314 fn table_exist(&self, name: &str) -> bool {
315 Python::with_gil(|py| {
316 let provider = self.schema_provider.bind(py);
317 provider
318 .call_method1("table_exist", (name,))
319 .and_then(|pyobj| pyobj.extract())
320 .unwrap_or(false)
321 })
322 }
323}
324
/// Adapter that lets a Python object act as a DataFusion `CatalogProvider`.
///
/// Trait methods are implemented by acquiring the GIL and calling the identically named
/// attribute or method on the wrapped object.
#[derive(Debug)]
pub(crate) struct RustWrappedPyCatalogProvider {
    /// The wrapped Python object that actually implements the catalog.
    pub(crate) catalog_provider: PyObject,
}
329
330impl RustWrappedPyCatalogProvider {
331 pub fn new(catalog_provider: PyObject) -> Self {
332 Self { catalog_provider }
333 }
334
335 fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> {
336 Python::with_gil(|py| {
337 let provider = self.catalog_provider.bind(py);
338
339 let py_schema = provider.call_method1("schema", (name,))?;
340 if py_schema.is_none() {
341 return Ok(None);
342 }
343
344 if py_schema.hasattr("__datafusion_schema_provider__")? {
345 let capsule = provider
346 .getattr("__datafusion_schema_provider__")?
347 .call0()?;
348 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
349 validate_pycapsule(capsule, "datafusion_schema_provider")?;
350
351 let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
352 let provider: ForeignSchemaProvider = provider.into();
353
354 Ok(Some(Arc::new(provider) as Arc<dyn SchemaProvider>))
355 } else {
356 if let Ok(inner_schema) = py_schema.getattr("schema") {
357 if let Ok(inner_schema) = inner_schema.extract::<PySchema>() {
358 return Ok(Some(inner_schema.schema));
359 }
360 }
361 match py_schema.extract::<PySchema>() {
362 Ok(inner_schema) => Ok(Some(inner_schema.schema)),
363 Err(_) => {
364 let py_schema = RustWrappedPySchemaProvider::new(py_schema.into());
365
366 Ok(Some(Arc::new(py_schema) as Arc<dyn SchemaProvider>))
367 }
368 }
369 }
370 })
371 }
372}
373
374#[async_trait]
375impl CatalogProvider for RustWrappedPyCatalogProvider {
376 fn as_any(&self) -> &dyn Any {
377 self
378 }
379
380 fn schema_names(&self) -> Vec<String> {
381 Python::with_gil(|py| {
382 let provider = self.catalog_provider.bind(py);
383 provider
384 .getattr("schema_names")
385 .and_then(|names| names.extract::<Vec<String>>())
386 .unwrap_or_else(|err| {
387 log::error!("Unable to get schema_names: {err}");
388 Vec::default()
389 })
390 })
391 }
392
393 fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
394 self.schema_inner(name).unwrap_or_else(|err| {
395 log::error!("CatalogProvider schema returned error: {err}");
396 None
397 })
398 }
399
400 fn register_schema(
401 &self,
402 name: &str,
403 schema: Arc<dyn SchemaProvider>,
404 ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
405 Python::with_gil(|py| {
408 let py_schema = match schema
409 .as_any()
410 .downcast_ref::<RustWrappedPySchemaProvider>()
411 {
412 Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(),
413 None => &PySchema::from(schema)
414 .into_py_any(py)
415 .map_err(to_datafusion_err)?,
416 };
417
418 let provider = self.catalog_provider.bind(py);
419 let schema = provider
420 .call_method1("register_schema", (name, py_schema))
421 .map_err(to_datafusion_err)?;
422 if schema.is_none() {
423 return Ok(None);
424 }
425
426 let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
427 as Arc<dyn SchemaProvider>;
428
429 Ok(Some(schema))
430 })
431 }
432
433 fn deregister_schema(
434 &self,
435 name: &str,
436 cascade: bool,
437 ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
438 Python::with_gil(|py| {
439 let provider = self.catalog_provider.bind(py);
440 let schema = provider
441 .call_method1("deregister_schema", (name, cascade))
442 .map_err(to_datafusion_err)?;
443 if schema.is_none() {
444 return Ok(None);
445 }
446
447 let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
448 as Arc<dyn SchemaProvider>;
449
450 Ok(Some(schema))
451 })
452 }
453}
454
455pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
456 m.add_class::<PyCatalog>()?;
457 m.add_class::<PySchema>()?;
458 m.add_class::<PyTable>()?;
459
460 Ok(())
461}