1use std::any::Any;
19use std::collections::HashSet;
20use std::sync::Arc;
21
22use async_trait::async_trait;
23use datafusion::catalog::{
24 CatalogProvider, CatalogProviderList, MemoryCatalogProvider, MemoryCatalogProviderList,
25 MemorySchemaProvider, SchemaProvider,
26};
27use datafusion::common::DataFusionError;
28use datafusion::datasource::TableProvider;
29use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
30use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
31use datafusion_ffi::schema_provider::FFI_SchemaProvider;
32use pyo3::IntoPyObjectExt;
33use pyo3::exceptions::PyKeyError;
34use pyo3::prelude::*;
35use pyo3::types::PyCapsule;
36
37use crate::dataset::Dataset;
38use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
39use crate::table::PyTable;
40use crate::utils::{
41 create_logical_extension_capsule, extract_logical_extension_codec, validate_pycapsule,
42 wait_for_future,
43};
44
45#[pyclass(
46 frozen,
47 name = "RawCatalogList",
48 module = "datafusion.catalog",
49 subclass
50)]
51#[derive(Clone)]
52pub struct PyCatalogList {
53 pub catalog_list: Arc<dyn CatalogProviderList>,
54 codec: Arc<FFI_LogicalExtensionCodec>,
55}
56
57#[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)]
58#[derive(Clone)]
59pub struct PyCatalog {
60 pub catalog: Arc<dyn CatalogProvider>,
61 codec: Arc<FFI_LogicalExtensionCodec>,
62}
63
64#[pyclass(frozen, name = "RawSchema", module = "datafusion.catalog", subclass)]
65#[derive(Clone)]
66pub struct PySchema {
67 pub schema: Arc<dyn SchemaProvider>,
68 codec: Arc<FFI_LogicalExtensionCodec>,
69}
70
71impl PyCatalog {
72 pub(crate) fn new_from_parts(
73 catalog: Arc<dyn CatalogProvider>,
74 codec: Arc<FFI_LogicalExtensionCodec>,
75 ) -> Self {
76 Self { catalog, codec }
77 }
78}
79
80impl PySchema {
81 pub(crate) fn new_from_parts(
82 schema: Arc<dyn SchemaProvider>,
83 codec: Arc<FFI_LogicalExtensionCodec>,
84 ) -> Self {
85 Self { schema, codec }
86 }
87}
88
89#[pymethods]
90impl PyCatalogList {
91 #[new]
92 pub fn new(
93 py: Python,
94 catalog_list: Py<PyAny>,
95 session: Option<Bound<PyAny>>,
96 ) -> PyResult<Self> {
97 let codec = extract_logical_extension_codec(py, session)?;
98 let catalog_list = Arc::new(RustWrappedPyCatalogProviderList::new(
99 catalog_list,
100 codec.clone(),
101 )) as Arc<dyn CatalogProviderList>;
102 Ok(Self {
103 catalog_list,
104 codec,
105 })
106 }
107
108 #[staticmethod]
109 pub fn memory_catalog_list(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
110 let codec = extract_logical_extension_codec(py, session)?;
111 let catalog_list =
112 Arc::new(MemoryCatalogProviderList::default()) as Arc<dyn CatalogProviderList>;
113 Ok(Self {
114 catalog_list,
115 codec,
116 })
117 }
118
119 pub fn catalog_names(&self) -> HashSet<String> {
120 self.catalog_list.catalog_names().into_iter().collect()
121 }
122
123 #[pyo3(signature = (name="public"))]
124 pub fn catalog(&self, name: &str) -> PyResult<Py<PyAny>> {
125 let catalog = self
126 .catalog_list
127 .catalog(name)
128 .ok_or(PyKeyError::new_err(format!(
129 "Schema with name {name} doesn't exist."
130 )))?;
131
132 Python::attach(|py| {
133 match catalog
134 .as_any()
135 .downcast_ref::<RustWrappedPyCatalogProvider>()
136 {
137 Some(wrapped_catalog) => Ok(wrapped_catalog.catalog_provider.clone_ref(py)),
138 None => PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py),
139 }
140 })
141 }
142
143 pub fn register_catalog(&self, name: &str, catalog_provider: Bound<'_, PyAny>) -> PyResult<()> {
144 let provider = extract_catalog_provider_from_pyobj(catalog_provider, self.codec.as_ref())?;
145
146 let _ = self
147 .catalog_list
148 .register_catalog(name.to_owned(), provider);
149
150 Ok(())
151 }
152
153 pub fn __repr__(&self) -> PyResult<String> {
154 let mut names: Vec<String> = self.catalog_names().into_iter().collect();
155 names.sort();
156 Ok(format!("CatalogList(catalog_names=[{}])", names.join(", ")))
157 }
158}
159
160#[pymethods]
161impl PyCatalog {
162 #[new]
163 pub fn new(py: Python, catalog: Py<PyAny>, session: Option<Bound<PyAny>>) -> PyResult<Self> {
164 let codec = extract_logical_extension_codec(py, session)?;
165 let catalog = Arc::new(RustWrappedPyCatalogProvider::new(catalog, codec.clone()))
166 as Arc<dyn CatalogProvider>;
167 Ok(Self { catalog, codec })
168 }
169
170 #[staticmethod]
171 pub fn memory_catalog(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
172 let codec = extract_logical_extension_codec(py, session)?;
173 let catalog = Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
174 Ok(Self { catalog, codec })
175 }
176
177 pub fn schema_names(&self) -> HashSet<String> {
178 self.catalog.schema_names().into_iter().collect()
179 }
180
181 #[pyo3(signature = (name="public"))]
182 pub fn schema(&self, name: &str) -> PyResult<Py<PyAny>> {
183 let schema = self
184 .catalog
185 .schema(name)
186 .ok_or(PyKeyError::new_err(format!(
187 "Schema with name {name} doesn't exist."
188 )))?;
189
190 Python::attach(|py| {
191 match schema
192 .as_any()
193 .downcast_ref::<RustWrappedPySchemaProvider>()
194 {
195 Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)),
196 None => PySchema::new_from_parts(schema, self.codec.clone()).into_py_any(py),
197 }
198 })
199 }
200
201 pub fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
202 let provider = extract_schema_provider_from_pyobj(schema_provider, self.codec.as_ref())?;
203
204 let _ = self
205 .catalog
206 .register_schema(name, provider)
207 .map_err(py_datafusion_err)?;
208
209 Ok(())
210 }
211
212 pub fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> {
213 let _ = self
214 .catalog
215 .deregister_schema(name, cascade)
216 .map_err(py_datafusion_err)?;
217
218 Ok(())
219 }
220
221 pub fn __repr__(&self) -> PyResult<String> {
222 let mut names: Vec<String> = self.schema_names().into_iter().collect();
223 names.sort();
224 Ok(format!("Catalog(schema_names=[{}])", names.join(", ")))
225 }
226}
227
228#[pymethods]
229impl PySchema {
230 #[new]
231 pub fn new(
232 py: Python,
233 schema_provider: Py<PyAny>,
234 session: Option<Bound<PyAny>>,
235 ) -> PyResult<Self> {
236 let codec = extract_logical_extension_codec(py, session)?;
237 let schema =
238 Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc<dyn SchemaProvider>;
239 Ok(Self { schema, codec })
240 }
241
242 #[staticmethod]
243 fn memory_schema(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
244 let codec = extract_logical_extension_codec(py, session)?;
245 let schema = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
246 Ok(Self { schema, codec })
247 }
248
249 #[getter]
250 fn table_names(&self) -> HashSet<String> {
251 self.schema.table_names().into_iter().collect()
252 }
253
254 fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
255 if let Some(table) = wait_for_future(py, self.schema.table(name))?? {
256 Ok(PyTable::from(table))
257 } else {
258 Err(PyDataFusionError::Common(format!(
259 "Table not found: {name}"
260 )))
261 }
262 }
263
264 fn __repr__(&self) -> PyResult<String> {
265 let mut names: Vec<String> = self.table_names().into_iter().collect();
266 names.sort();
267 Ok(format!("Schema(table_names=[{}])", names.join(";")))
268 }
269
270 fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
271 let py = table_provider.py();
272 let codec_capsule = create_logical_extension_capsule(py, self.codec.as_ref())?
273 .as_any()
274 .clone();
275
276 let table = PyTable::new(table_provider, Some(codec_capsule))?;
277
278 let _ = self
279 .schema
280 .register_table(name.to_string(), table.table)
281 .map_err(py_datafusion_err)?;
282
283 Ok(())
284 }
285
286 fn deregister_table(&self, name: &str) -> PyResult<()> {
287 let _ = self
288 .schema
289 .deregister_table(name)
290 .map_err(py_datafusion_err)?;
291
292 Ok(())
293 }
294
295 fn table_exist(&self, name: &str) -> bool {
296 self.schema.table_exist(name)
297 }
298}
299
300#[derive(Debug)]
301pub(crate) struct RustWrappedPySchemaProvider {
302 schema_provider: Py<PyAny>,
303 owner_name: Option<String>,
304}
305
306impl RustWrappedPySchemaProvider {
307 pub fn new(schema_provider: Py<PyAny>) -> Self {
308 let owner_name = Python::attach(|py| {
309 schema_provider
310 .bind(py)
311 .getattr("owner_name")
312 .ok()
313 .map(|name| name.to_string())
314 });
315
316 Self {
317 schema_provider,
318 owner_name,
319 }
320 }
321
322 fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
323 Python::attach(|py| {
324 let provider = self.schema_provider.bind(py);
325 let py_table_method = provider.getattr("table")?;
326
327 let py_table = py_table_method.call((name,), None)?;
328 if py_table.is_none() {
329 return Ok(None);
330 }
331
332 let table = PyTable::new(py_table, None)?;
333
334 Ok(Some(table.table))
335 })
336 }
337}
338
339#[async_trait]
340impl SchemaProvider for RustWrappedPySchemaProvider {
341 fn owner_name(&self) -> Option<&str> {
342 self.owner_name.as_deref()
343 }
344
345 fn as_any(&self) -> &dyn Any {
346 self
347 }
348
349 fn table_names(&self) -> Vec<String> {
350 Python::attach(|py| {
351 let provider = self.schema_provider.bind(py);
352
353 provider
354 .getattr("table_names")
355 .and_then(|names| names.extract::<Vec<String>>())
356 .unwrap_or_else(|err| {
357 log::error!("Unable to get table_names: {err}");
358 Vec::default()
359 })
360 })
361 }
362
363 async fn table(
364 &self,
365 name: &str,
366 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
367 self.table_inner(name)
368 .map_err(|e| DataFusionError::External(Box::new(e)))
369 }
370
371 fn register_table(
372 &self,
373 name: String,
374 table: Arc<dyn TableProvider>,
375 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
376 let py_table = PyTable::from(table);
377 Python::attach(|py| {
378 let provider = self.schema_provider.bind(py);
379 let _ = provider
380 .call_method1("register_table", (name, py_table))
381 .map_err(to_datafusion_err)?;
382 Ok(None)
386 })
387 }
388
389 fn deregister_table(
390 &self,
391 name: &str,
392 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
393 Python::attach(|py| {
394 let provider = self.schema_provider.bind(py);
395 let table = provider
396 .call_method1("deregister_table", (name,))
397 .map_err(to_datafusion_err)?;
398 if table.is_none() {
399 return Ok(None);
400 }
401
402 let dataset = match Dataset::new(&table, py) {
405 Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
406 Err(_) => None,
407 };
408
409 Ok(dataset)
410 })
411 }
412
413 fn table_exist(&self, name: &str) -> bool {
414 Python::attach(|py| {
415 let provider = self.schema_provider.bind(py);
416 provider
417 .call_method1("table_exist", (name,))
418 .and_then(|pyobj| pyobj.extract())
419 .unwrap_or(false)
420 })
421 }
422}
423
424#[derive(Debug)]
425pub(crate) struct RustWrappedPyCatalogProvider {
426 pub(crate) catalog_provider: Py<PyAny>,
427 codec: Arc<FFI_LogicalExtensionCodec>,
428}
429
430impl RustWrappedPyCatalogProvider {
431 pub fn new(catalog_provider: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
432 Self {
433 catalog_provider,
434 codec,
435 }
436 }
437
438 fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> {
439 Python::attach(|py| {
440 let provider = self.catalog_provider.bind(py);
441
442 let py_schema = provider.call_method1("schema", (name,))?;
443 if py_schema.is_none() {
444 return Ok(None);
445 }
446
447 extract_schema_provider_from_pyobj(py_schema, self.codec.as_ref()).map(Some)
448 })
449 }
450}
451
452#[async_trait]
453impl CatalogProvider for RustWrappedPyCatalogProvider {
454 fn as_any(&self) -> &dyn Any {
455 self
456 }
457
458 fn schema_names(&self) -> Vec<String> {
459 Python::attach(|py| {
460 let provider = self.catalog_provider.bind(py);
461 provider
462 .call_method0("schema_names")
463 .and_then(|names| names.extract::<HashSet<String>>())
464 .map(|names| names.into_iter().collect())
465 .unwrap_or_else(|err| {
466 log::error!("Unable to get schema_names: {err}");
467 Vec::default()
468 })
469 })
470 }
471
472 fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
473 self.schema_inner(name).unwrap_or_else(|err| {
474 log::error!("CatalogProvider schema returned error: {err}");
475 None
476 })
477 }
478
479 fn register_schema(
480 &self,
481 name: &str,
482 schema: Arc<dyn SchemaProvider>,
483 ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
484 Python::attach(|py| {
485 let py_schema = match schema
486 .as_any()
487 .downcast_ref::<RustWrappedPySchemaProvider>()
488 {
489 Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(),
490 None => &PySchema::new_from_parts(schema, self.codec.clone())
491 .into_py_any(py)
492 .map_err(to_datafusion_err)?,
493 };
494
495 let provider = self.catalog_provider.bind(py);
496 let schema = provider
497 .call_method1("register_schema", (name, py_schema))
498 .map_err(to_datafusion_err)?;
499 if schema.is_none() {
500 return Ok(None);
501 }
502
503 let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
504 as Arc<dyn SchemaProvider>;
505
506 Ok(Some(schema))
507 })
508 }
509
510 fn deregister_schema(
511 &self,
512 name: &str,
513 cascade: bool,
514 ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
515 Python::attach(|py| {
516 let provider = self.catalog_provider.bind(py);
517 let schema = provider
518 .call_method1("deregister_schema", (name, cascade))
519 .map_err(to_datafusion_err)?;
520 if schema.is_none() {
521 return Ok(None);
522 }
523
524 let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
525 as Arc<dyn SchemaProvider>;
526
527 Ok(Some(schema))
528 })
529 }
530}
531
532#[derive(Debug)]
533pub(crate) struct RustWrappedPyCatalogProviderList {
534 pub(crate) catalog_provider_list: Py<PyAny>,
535 codec: Arc<FFI_LogicalExtensionCodec>,
536}
537
538impl RustWrappedPyCatalogProviderList {
539 pub fn new(catalog_provider_list: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
540 Self {
541 catalog_provider_list,
542 codec,
543 }
544 }
545
546 fn catalog_inner(&self, name: &str) -> PyResult<Option<Arc<dyn CatalogProvider>>> {
547 Python::attach(|py| {
548 let provider = self.catalog_provider_list.bind(py);
549
550 let py_schema = provider.call_method1("catalog", (name,))?;
551 if py_schema.is_none() {
552 return Ok(None);
553 }
554
555 extract_catalog_provider_from_pyobj(py_schema, self.codec.as_ref()).map(Some)
556 })
557 }
558}
559
560#[async_trait]
561impl CatalogProviderList for RustWrappedPyCatalogProviderList {
562 fn as_any(&self) -> &dyn Any {
563 self
564 }
565
566 fn catalog_names(&self) -> Vec<String> {
567 Python::attach(|py| {
568 let provider = self.catalog_provider_list.bind(py);
569 provider
570 .call_method0("catalog_names")
571 .and_then(|names| names.extract::<HashSet<String>>())
572 .map(|names| names.into_iter().collect())
573 .unwrap_or_else(|err| {
574 log::error!("Unable to get catalog_names: {err}");
575 Vec::default()
576 })
577 })
578 }
579
580 fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
581 self.catalog_inner(name).unwrap_or_else(|err| {
582 log::error!("CatalogProvider catalog returned error: {err}");
583 None
584 })
585 }
586
587 fn register_catalog(
588 &self,
589 name: String,
590 catalog: Arc<dyn CatalogProvider>,
591 ) -> Option<Arc<dyn CatalogProvider>> {
592 Python::attach(|py| {
593 let py_catalog = match catalog
594 .as_any()
595 .downcast_ref::<RustWrappedPyCatalogProvider>()
596 {
597 Some(wrapped_schema) => wrapped_schema.catalog_provider.as_any().clone_ref(py),
598 None => {
599 match PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py) {
600 Ok(c) => c,
601 Err(err) => {
602 log::error!(
603 "register_catalog returned error during conversion to PyAny: {err}"
604 );
605 return None;
606 }
607 }
608 }
609 };
610
611 let provider = self.catalog_provider_list.bind(py);
612 let catalog = match provider.call_method1("register_catalog", (name, py_catalog)) {
613 Ok(c) => c,
614 Err(err) => {
615 log::error!("register_catalog returned error: {err}");
616 return None;
617 }
618 };
619 if catalog.is_none() {
620 return None;
621 }
622
623 let catalog = Arc::new(RustWrappedPyCatalogProvider::new(
624 catalog.into(),
625 self.codec.clone(),
626 )) as Arc<dyn CatalogProvider>;
627
628 Some(catalog)
629 })
630 }
631}
632
633fn extract_catalog_provider_from_pyobj(
634 mut catalog_provider: Bound<PyAny>,
635 codec: &FFI_LogicalExtensionCodec,
636) -> PyResult<Arc<dyn CatalogProvider>> {
637 if catalog_provider.hasattr("__datafusion_catalog_provider__")? {
638 let py = catalog_provider.py();
639 let codec_capsule = create_logical_extension_capsule(py, codec)?;
640 catalog_provider = catalog_provider
641 .getattr("__datafusion_catalog_provider__")?
642 .call1((codec_capsule,))?;
643 }
644
645 let provider = if let Ok(capsule) = catalog_provider.downcast::<PyCapsule>() {
646 validate_pycapsule(capsule, "datafusion_catalog_provider")?;
647
648 let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
649 let provider: Arc<dyn CatalogProvider + Send> = provider.into();
650 provider as Arc<dyn CatalogProvider>
651 } else {
652 match catalog_provider.extract::<PyCatalog>() {
653 Ok(py_catalog) => py_catalog.catalog,
654 Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(
655 catalog_provider.into(),
656 Arc::new(codec.clone()),
657 )) as Arc<dyn CatalogProvider>,
658 }
659 };
660
661 Ok(provider)
662}
663
664fn extract_schema_provider_from_pyobj(
665 mut schema_provider: Bound<PyAny>,
666 codec: &FFI_LogicalExtensionCodec,
667) -> PyResult<Arc<dyn SchemaProvider>> {
668 if schema_provider.hasattr("__datafusion_schema_provider__")? {
669 let py = schema_provider.py();
670 let codec_capsule = create_logical_extension_capsule(py, codec)?;
671 schema_provider = schema_provider
672 .getattr("__datafusion_schema_provider__")?
673 .call1((codec_capsule,))?;
674 }
675
676 let provider = if let Ok(capsule) = schema_provider.downcast::<PyCapsule>() {
677 validate_pycapsule(capsule, "datafusion_schema_provider")?;
678
679 let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
680 let provider: Arc<dyn SchemaProvider + Send> = provider.into();
681 provider as Arc<dyn SchemaProvider>
682 } else {
683 match schema_provider.extract::<PySchema>() {
684 Ok(py_schema) => py_schema.schema,
685 Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
686 as Arc<dyn SchemaProvider>,
687 }
688 };
689
690 Ok(provider)
691}
692
693pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
694 m.add_class::<PyCatalog>()?;
695 m.add_class::<PySchema>()?;
696 m.add_class::<PyTable>()?;
697
698 Ok(())
699}