1use crate::dataset::Dataset;
19use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
20use crate::utils::{validate_pycapsule, wait_for_future};
21use async_trait::async_trait;
22use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
23use datafusion::common::DataFusionError;
24use datafusion::{
25 arrow::pyarrow::ToPyArrow,
26 catalog::{CatalogProvider, SchemaProvider},
27 datasource::{TableProvider, TableType},
28};
29use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
30use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
31use pyo3::exceptions::PyKeyError;
32use pyo3::prelude::*;
33use pyo3::types::PyCapsule;
34use pyo3::IntoPyObjectExt;
35use std::any::Any;
36use std::collections::HashSet;
37use std::sync::Arc;
38
39#[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)]
40#[derive(Clone)]
41pub struct PyCatalog {
42 pub catalog: Arc<dyn CatalogProvider>,
43}
44
45#[pyclass(name = "RawSchema", module = "datafusion.catalog", subclass)]
46#[derive(Clone)]
47pub struct PySchema {
48 pub schema: Arc<dyn SchemaProvider>,
49}
50
51#[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)]
52#[derive(Clone)]
53pub struct PyTable {
54 pub table: Arc<dyn TableProvider>,
55}
56
57impl From<Arc<dyn CatalogProvider>> for PyCatalog {
58 fn from(catalog: Arc<dyn CatalogProvider>) -> Self {
59 Self { catalog }
60 }
61}
62
63impl From<Arc<dyn SchemaProvider>> for PySchema {
64 fn from(schema: Arc<dyn SchemaProvider>) -> Self {
65 Self { schema }
66 }
67}
68
69impl PyTable {
70 pub fn new(table: Arc<dyn TableProvider>) -> Self {
71 Self { table }
72 }
73
74 pub fn table(&self) -> Arc<dyn TableProvider> {
75 self.table.clone()
76 }
77}
78
79#[pymethods]
80impl PyCatalog {
81 #[new]
82 fn new(catalog: PyObject) -> Self {
83 let catalog_provider =
84 Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc<dyn CatalogProvider>;
85 catalog_provider.into()
86 }
87
88 #[staticmethod]
89 fn memory_catalog() -> Self {
90 let catalog_provider =
91 Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
92 catalog_provider.into()
93 }
94
95 fn schema_names(&self) -> HashSet<String> {
96 self.catalog.schema_names().into_iter().collect()
97 }
98
99 #[pyo3(signature = (name="public"))]
100 fn schema(&self, name: &str) -> PyResult<PyObject> {
101 let schema = self
102 .catalog
103 .schema(name)
104 .ok_or(PyKeyError::new_err(format!(
105 "Schema with name {name} doesn't exist."
106 )))?;
107
108 Python::with_gil(|py| {
109 match schema
110 .as_any()
111 .downcast_ref::<RustWrappedPySchemaProvider>()
112 {
113 Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)),
114 None => PySchema::from(schema).into_py_any(py),
115 }
116 })
117 }
118
119 fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
120 let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? {
121 let capsule = schema_provider
122 .getattr("__datafusion_schema_provider__")?
123 .call0()?;
124 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
125 validate_pycapsule(capsule, "datafusion_schema_provider")?;
126
127 let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
128 let provider: ForeignSchemaProvider = provider.into();
129 Arc::new(provider) as Arc<dyn SchemaProvider>
130 } else {
131 match schema_provider.extract::<PySchema>() {
132 Ok(py_schema) => py_schema.schema,
133 Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
134 as Arc<dyn SchemaProvider>,
135 }
136 };
137
138 let _ = self
139 .catalog
140 .register_schema(name, provider)
141 .map_err(py_datafusion_err)?;
142
143 Ok(())
144 }
145
146 fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> {
147 let _ = self
148 .catalog
149 .deregister_schema(name, cascade)
150 .map_err(py_datafusion_err)?;
151
152 Ok(())
153 }
154
155 fn __repr__(&self) -> PyResult<String> {
156 let mut names: Vec<String> = self.schema_names().into_iter().collect();
157 names.sort();
158 Ok(format!("Catalog(schema_names=[{}])", names.join(", ")))
159 }
160}
161
162#[pymethods]
163impl PySchema {
164 #[new]
165 fn new(schema_provider: PyObject) -> Self {
166 let schema_provider =
167 Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc<dyn SchemaProvider>;
168 schema_provider.into()
169 }
170
171 #[staticmethod]
172 fn memory_schema() -> Self {
173 let schema_provider = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
174 schema_provider.into()
175 }
176
177 #[getter]
178 fn table_names(&self) -> HashSet<String> {
179 self.schema.table_names().into_iter().collect()
180 }
181
182 fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
183 if let Some(table) = wait_for_future(py, self.schema.table(name))?? {
184 Ok(PyTable::new(table))
185 } else {
186 Err(PyDataFusionError::Common(format!(
187 "Table not found: {name}"
188 )))
189 }
190 }
191
192 fn __repr__(&self) -> PyResult<String> {
193 let mut names: Vec<String> = self.table_names().into_iter().collect();
194 names.sort();
195 Ok(format!("Schema(table_names=[{}])", names.join(";")))
196 }
197
198 fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
199 let provider = if table_provider.hasattr("__datafusion_table_provider__")? {
200 let capsule = table_provider
201 .getattr("__datafusion_table_provider__")?
202 .call0()?;
203 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
204 validate_pycapsule(capsule, "datafusion_table_provider")?;
205
206 let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
207 let provider: ForeignTableProvider = provider.into();
208 Arc::new(provider) as Arc<dyn TableProvider>
209 } else {
210 match table_provider.extract::<PyTable>() {
211 Ok(py_table) => py_table.table,
212 Err(_) => {
213 let py = table_provider.py();
214 let provider = Dataset::new(&table_provider, py)?;
215 Arc::new(provider) as Arc<dyn TableProvider>
216 }
217 }
218 };
219
220 let _ = self
221 .schema
222 .register_table(name.to_string(), provider)
223 .map_err(py_datafusion_err)?;
224
225 Ok(())
226 }
227
228 fn deregister_table(&self, name: &str) -> PyResult<()> {
229 let _ = self
230 .schema
231 .deregister_table(name)
232 .map_err(py_datafusion_err)?;
233
234 Ok(())
235 }
236}
237
238#[pymethods]
239impl PyTable {
240 #[getter]
242 fn schema(&self, py: Python) -> PyResult<PyObject> {
243 self.table.schema().to_pyarrow(py)
244 }
245
246 #[staticmethod]
247 fn from_dataset(py: Python<'_>, dataset: &Bound<'_, PyAny>) -> PyResult<Self> {
248 let ds = Arc::new(Dataset::new(dataset, py).map_err(py_datafusion_err)?)
249 as Arc<dyn TableProvider>;
250
251 Ok(Self::new(ds))
252 }
253
254 #[getter]
256 fn kind(&self) -> &str {
257 match self.table.table_type() {
258 TableType::Base => "physical",
259 TableType::View => "view",
260 TableType::Temporary => "temporary",
261 }
262 }
263
264 fn __repr__(&self) -> PyResult<String> {
265 let kind = self.kind();
266 Ok(format!("Table(kind={kind})"))
267 }
268
269 }
274
275#[derive(Debug)]
276pub(crate) struct RustWrappedPySchemaProvider {
277 schema_provider: PyObject,
278 owner_name: Option<String>,
279}
280
281impl RustWrappedPySchemaProvider {
282 pub fn new(schema_provider: PyObject) -> Self {
283 let owner_name = Python::with_gil(|py| {
284 schema_provider
285 .bind(py)
286 .getattr("owner_name")
287 .ok()
288 .map(|name| name.to_string())
289 });
290
291 Self {
292 schema_provider,
293 owner_name,
294 }
295 }
296
297 fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
298 Python::with_gil(|py| {
299 let provider = self.schema_provider.bind(py);
300 let py_table_method = provider.getattr("table")?;
301
302 let py_table = py_table_method.call((name,), None)?;
303 if py_table.is_none() {
304 return Ok(None);
305 }
306
307 if py_table.hasattr("__datafusion_table_provider__")? {
308 let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
309 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
310 validate_pycapsule(capsule, "datafusion_table_provider")?;
311
312 let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
313 let provider: ForeignTableProvider = provider.into();
314
315 Ok(Some(Arc::new(provider) as Arc<dyn TableProvider>))
316 } else {
317 if let Ok(inner_table) = py_table.getattr("table") {
318 if let Ok(inner_table) = inner_table.extract::<PyTable>() {
319 return Ok(Some(inner_table.table));
320 }
321 }
322
323 match py_table.extract::<PyTable>() {
324 Ok(py_table) => Ok(Some(py_table.table)),
325 Err(_) => {
326 let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
327 Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
328 }
329 }
330 }
331 })
332 }
333}
334
335#[async_trait]
336impl SchemaProvider for RustWrappedPySchemaProvider {
337 fn owner_name(&self) -> Option<&str> {
338 self.owner_name.as_deref()
339 }
340
341 fn as_any(&self) -> &dyn Any {
342 self
343 }
344
345 fn table_names(&self) -> Vec<String> {
346 Python::with_gil(|py| {
347 let provider = self.schema_provider.bind(py);
348
349 provider
350 .getattr("table_names")
351 .and_then(|names| names.extract::<Vec<String>>())
352 .unwrap_or_else(|err| {
353 log::error!("Unable to get table_names: {err}");
354 Vec::default()
355 })
356 })
357 }
358
359 async fn table(
360 &self,
361 name: &str,
362 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
363 self.table_inner(name).map_err(to_datafusion_err)
364 }
365
366 fn register_table(
367 &self,
368 name: String,
369 table: Arc<dyn TableProvider>,
370 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
371 let py_table = PyTable::new(table);
372 Python::with_gil(|py| {
373 let provider = self.schema_provider.bind(py);
374 let _ = provider
375 .call_method1("register_table", (name, py_table))
376 .map_err(to_datafusion_err)?;
377 Ok(None)
381 })
382 }
383
384 fn deregister_table(
385 &self,
386 name: &str,
387 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
388 Python::with_gil(|py| {
389 let provider = self.schema_provider.bind(py);
390 let table = provider
391 .call_method1("deregister_table", (name,))
392 .map_err(to_datafusion_err)?;
393 if table.is_none() {
394 return Ok(None);
395 }
396
397 let dataset = match Dataset::new(&table, py) {
400 Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
401 Err(_) => None,
402 };
403
404 Ok(dataset)
405 })
406 }
407
408 fn table_exist(&self, name: &str) -> bool {
409 Python::with_gil(|py| {
410 let provider = self.schema_provider.bind(py);
411 provider
412 .call_method1("table_exist", (name,))
413 .and_then(|pyobj| pyobj.extract())
414 .unwrap_or(false)
415 })
416 }
417}
418
419#[derive(Debug)]
420pub(crate) struct RustWrappedPyCatalogProvider {
421 pub(crate) catalog_provider: PyObject,
422}
423
424impl RustWrappedPyCatalogProvider {
425 pub fn new(catalog_provider: PyObject) -> Self {
426 Self { catalog_provider }
427 }
428
429 fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> {
430 Python::with_gil(|py| {
431 let provider = self.catalog_provider.bind(py);
432
433 let py_schema = provider.call_method1("schema", (name,))?;
434 if py_schema.is_none() {
435 return Ok(None);
436 }
437
438 if py_schema.hasattr("__datafusion_schema_provider__")? {
439 let capsule = provider
440 .getattr("__datafusion_schema_provider__")?
441 .call0()?;
442 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
443 validate_pycapsule(capsule, "datafusion_schema_provider")?;
444
445 let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
446 let provider: ForeignSchemaProvider = provider.into();
447
448 Ok(Some(Arc::new(provider) as Arc<dyn SchemaProvider>))
449 } else {
450 if let Ok(inner_schema) = py_schema.getattr("schema") {
451 if let Ok(inner_schema) = inner_schema.extract::<PySchema>() {
452 return Ok(Some(inner_schema.schema));
453 }
454 }
455 match py_schema.extract::<PySchema>() {
456 Ok(inner_schema) => Ok(Some(inner_schema.schema)),
457 Err(_) => {
458 let py_schema = RustWrappedPySchemaProvider::new(py_schema.into());
459
460 Ok(Some(Arc::new(py_schema) as Arc<dyn SchemaProvider>))
461 }
462 }
463 }
464 })
465 }
466}
467
468#[async_trait]
469impl CatalogProvider for RustWrappedPyCatalogProvider {
470 fn as_any(&self) -> &dyn Any {
471 self
472 }
473
474 fn schema_names(&self) -> Vec<String> {
475 Python::with_gil(|py| {
476 let provider = self.catalog_provider.bind(py);
477 provider
478 .getattr("schema_names")
479 .and_then(|names| names.extract::<Vec<String>>())
480 .unwrap_or_else(|err| {
481 log::error!("Unable to get schema_names: {err}");
482 Vec::default()
483 })
484 })
485 }
486
487 fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
488 self.schema_inner(name).unwrap_or_else(|err| {
489 log::error!("CatalogProvider schema returned error: {err}");
490 None
491 })
492 }
493
494 fn register_schema(
495 &self,
496 name: &str,
497 schema: Arc<dyn SchemaProvider>,
498 ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
499 Python::with_gil(|py| {
502 let py_schema = match schema
503 .as_any()
504 .downcast_ref::<RustWrappedPySchemaProvider>()
505 {
506 Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(),
507 None => &PySchema::from(schema)
508 .into_py_any(py)
509 .map_err(to_datafusion_err)?,
510 };
511
512 let provider = self.catalog_provider.bind(py);
513 let schema = provider
514 .call_method1("register_schema", (name, py_schema))
515 .map_err(to_datafusion_err)?;
516 if schema.is_none() {
517 return Ok(None);
518 }
519
520 let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
521 as Arc<dyn SchemaProvider>;
522
523 Ok(Some(schema))
524 })
525 }
526
527 fn deregister_schema(
528 &self,
529 name: &str,
530 cascade: bool,
531 ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
532 Python::with_gil(|py| {
533 let provider = self.catalog_provider.bind(py);
534 let schema = provider
535 .call_method1("deregister_schema", (name, cascade))
536 .map_err(to_datafusion_err)?;
537 if schema.is_none() {
538 return Ok(None);
539 }
540
541 let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
542 as Arc<dyn SchemaProvider>;
543
544 Ok(Some(schema))
545 })
546 }
547}
548
549pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
550 m.add_class::<PyCatalog>()?;
551 m.add_class::<PySchema>()?;
552 m.add_class::<PyTable>()?;
553
554 Ok(())
555}