1use std::any::Any;
19use std::collections::HashSet;
20use std::ptr::NonNull;
21use std::sync::Arc;
22
23use async_trait::async_trait;
24use datafusion::catalog::{
25 CatalogProvider, CatalogProviderList, MemoryCatalogProvider, MemoryCatalogProviderList,
26 MemorySchemaProvider, SchemaProvider,
27};
28use datafusion::common::DataFusionError;
29use datafusion::datasource::TableProvider;
30use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
31use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
32use datafusion_ffi::schema_provider::FFI_SchemaProvider;
33use datafusion_python_util::{
34 create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, wait_for_future,
35};
36use pyo3::IntoPyObjectExt;
37use pyo3::exceptions::PyKeyError;
38use pyo3::prelude::*;
39use pyo3::types::PyCapsule;
40
41use crate::context::PySessionContext;
42use crate::dataset::Dataset;
43use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
44use crate::table::PyTable;
45
46#[pyclass(
47 from_py_object,
48 frozen,
49 name = "RawCatalogList",
50 module = "datafusion.catalog",
51 subclass
52)]
53#[derive(Clone)]
54pub struct PyCatalogList {
55 pub catalog_list: Arc<dyn CatalogProviderList>,
56 codec: Arc<FFI_LogicalExtensionCodec>,
57}
58
59#[pyclass(
60 from_py_object,
61 frozen,
62 name = "RawCatalog",
63 module = "datafusion.catalog",
64 subclass
65)]
66#[derive(Clone)]
67pub struct PyCatalog {
68 pub catalog: Arc<dyn CatalogProvider>,
69 codec: Arc<FFI_LogicalExtensionCodec>,
70}
71
72#[pyclass(
73 from_py_object,
74 frozen,
75 name = "RawSchema",
76 module = "datafusion.catalog",
77 subclass
78)]
79#[derive(Clone)]
80pub struct PySchema {
81 pub schema: Arc<dyn SchemaProvider>,
82 codec: Arc<FFI_LogicalExtensionCodec>,
83}
84
85impl PyCatalog {
86 pub(crate) fn new_from_parts(
87 catalog: Arc<dyn CatalogProvider>,
88 codec: Arc<FFI_LogicalExtensionCodec>,
89 ) -> Self {
90 Self { catalog, codec }
91 }
92}
93
94impl PySchema {
95 pub(crate) fn new_from_parts(
96 schema: Arc<dyn SchemaProvider>,
97 codec: Arc<FFI_LogicalExtensionCodec>,
98 ) -> Self {
99 Self { schema, codec }
100 }
101}
102
103#[pymethods]
104impl PyCatalogList {
105 #[new]
106 pub fn new(
107 py: Python,
108 catalog_list: Py<PyAny>,
109 session: Option<Bound<PyAny>>,
110 ) -> PyResult<Self> {
111 let codec = extract_logical_extension_codec(py, session)?;
112 let catalog_list = Arc::new(RustWrappedPyCatalogProviderList::new(
113 catalog_list,
114 codec.clone(),
115 )) as Arc<dyn CatalogProviderList>;
116 Ok(Self {
117 catalog_list,
118 codec,
119 })
120 }
121
122 #[staticmethod]
123 pub fn memory_catalog_list(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
124 let codec = extract_logical_extension_codec(py, session)?;
125 let catalog_list =
126 Arc::new(MemoryCatalogProviderList::default()) as Arc<dyn CatalogProviderList>;
127 Ok(Self {
128 catalog_list,
129 codec,
130 })
131 }
132
133 pub fn catalog_names(&self) -> HashSet<String> {
134 self.catalog_list.catalog_names().into_iter().collect()
135 }
136
137 #[pyo3(signature = (name="public"))]
138 pub fn catalog(&self, name: &str) -> PyResult<Py<PyAny>> {
139 let catalog = self
140 .catalog_list
141 .catalog(name)
142 .ok_or(PyKeyError::new_err(format!(
143 "Schema with name {name} doesn't exist."
144 )))?;
145
146 Python::attach(|py| {
147 match catalog
148 .as_any()
149 .downcast_ref::<RustWrappedPyCatalogProvider>()
150 {
151 Some(wrapped_catalog) => Ok(wrapped_catalog.catalog_provider.clone_ref(py)),
152 None => PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py),
153 }
154 })
155 }
156
157 pub fn register_catalog(&self, name: &str, catalog_provider: Bound<'_, PyAny>) -> PyResult<()> {
158 let provider = extract_catalog_provider_from_pyobj(catalog_provider, self.codec.as_ref())?;
159
160 let _ = self
161 .catalog_list
162 .register_catalog(name.to_owned(), provider);
163
164 Ok(())
165 }
166
167 pub fn __repr__(&self) -> PyResult<String> {
168 let mut names: Vec<String> = self.catalog_names().into_iter().collect();
169 names.sort();
170 Ok(format!("CatalogList(catalog_names=[{}])", names.join(", ")))
171 }
172}
173
174#[pymethods]
175impl PyCatalog {
176 #[new]
177 pub fn new(py: Python, catalog: Py<PyAny>, session: Option<Bound<PyAny>>) -> PyResult<Self> {
178 let codec = extract_logical_extension_codec(py, session)?;
179 let catalog = Arc::new(RustWrappedPyCatalogProvider::new(catalog, codec.clone()))
180 as Arc<dyn CatalogProvider>;
181 Ok(Self { catalog, codec })
182 }
183
184 #[staticmethod]
185 pub fn memory_catalog(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
186 let codec = extract_logical_extension_codec(py, session)?;
187 let catalog = Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
188 Ok(Self { catalog, codec })
189 }
190
191 pub fn schema_names(&self) -> HashSet<String> {
192 self.catalog.schema_names().into_iter().collect()
193 }
194
195 #[pyo3(signature = (name="public"))]
196 pub fn schema(&self, name: &str) -> PyResult<Py<PyAny>> {
197 let schema = self
198 .catalog
199 .schema(name)
200 .ok_or(PyKeyError::new_err(format!(
201 "Schema with name {name} doesn't exist."
202 )))?;
203
204 Python::attach(|py| {
205 match schema
206 .as_any()
207 .downcast_ref::<RustWrappedPySchemaProvider>()
208 {
209 Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)),
210 None => PySchema::new_from_parts(schema, self.codec.clone()).into_py_any(py),
211 }
212 })
213 }
214
215 pub fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
216 let provider = extract_schema_provider_from_pyobj(schema_provider, self.codec.as_ref())?;
217
218 let _ = self
219 .catalog
220 .register_schema(name, provider)
221 .map_err(py_datafusion_err)?;
222
223 Ok(())
224 }
225
226 pub fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> {
227 let _ = self
228 .catalog
229 .deregister_schema(name, cascade)
230 .map_err(py_datafusion_err)?;
231
232 Ok(())
233 }
234
235 pub fn __repr__(&self) -> PyResult<String> {
236 let mut names: Vec<String> = self.schema_names().into_iter().collect();
237 names.sort();
238 Ok(format!("Catalog(schema_names=[{}])", names.join(", ")))
239 }
240}
241
242#[pymethods]
243impl PySchema {
244 #[new]
245 pub fn new(
246 py: Python,
247 schema_provider: Py<PyAny>,
248 session: Option<Bound<PyAny>>,
249 ) -> PyResult<Self> {
250 let codec = extract_logical_extension_codec(py, session)?;
251 let schema =
252 Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc<dyn SchemaProvider>;
253 Ok(Self { schema, codec })
254 }
255
256 #[staticmethod]
257 fn memory_schema(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
258 let codec = extract_logical_extension_codec(py, session)?;
259 let schema = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
260 Ok(Self { schema, codec })
261 }
262
263 #[getter]
264 fn table_names(&self) -> HashSet<String> {
265 self.schema.table_names().into_iter().collect()
266 }
267
268 fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
269 if let Some(table) = wait_for_future(py, self.schema.table(name))?? {
270 Ok(PyTable::from(table))
271 } else {
272 Err(PyDataFusionError::Common(format!(
273 "Table not found: {name}"
274 )))
275 }
276 }
277
278 fn __repr__(&self) -> PyResult<String> {
279 let mut names: Vec<String> = self.table_names().into_iter().collect();
280 names.sort();
281 Ok(format!("Schema(table_names=[{}])", names.join(";")))
282 }
283
284 fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
285 let py = table_provider.py();
286 let codec_capsule = create_logical_extension_capsule(py, self.codec.as_ref())?
287 .as_any()
288 .clone();
289
290 let table = PyTable::new(table_provider, Some(codec_capsule))?;
291
292 let _ = self
293 .schema
294 .register_table(name.to_string(), table.table)
295 .map_err(py_datafusion_err)?;
296
297 Ok(())
298 }
299
300 fn deregister_table(&self, name: &str) -> PyResult<()> {
301 let _ = self
302 .schema
303 .deregister_table(name)
304 .map_err(py_datafusion_err)?;
305
306 Ok(())
307 }
308
309 fn table_exist(&self, name: &str) -> bool {
310 self.schema.table_exist(name)
311 }
312}
313
314#[derive(Debug)]
315pub(crate) struct RustWrappedPySchemaProvider {
316 schema_provider: Py<PyAny>,
317 owner_name: Option<String>,
318}
319
320impl RustWrappedPySchemaProvider {
321 pub fn new(schema_provider: Py<PyAny>) -> Self {
322 let owner_name = Python::attach(|py| {
323 schema_provider
324 .bind(py)
325 .getattr("owner_name")
326 .ok()
327 .map(|name| name.to_string())
328 });
329
330 Self {
331 schema_provider,
332 owner_name,
333 }
334 }
335
336 fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
337 Python::attach(|py| {
338 let provider = self.schema_provider.bind(py);
339 let py_table_method = provider.getattr("table")?;
340
341 let py_table = py_table_method.call((name,), None)?;
342 if py_table.is_none() {
343 return Ok(None);
344 }
345
346 let table = PyTable::new(py_table, None)?;
347
348 Ok(Some(table.table))
349 })
350 }
351}
352
353#[async_trait]
354impl SchemaProvider for RustWrappedPySchemaProvider {
355 fn owner_name(&self) -> Option<&str> {
356 self.owner_name.as_deref()
357 }
358
359 fn as_any(&self) -> &dyn Any {
360 self
361 }
362
363 fn table_names(&self) -> Vec<String> {
364 Python::attach(|py| {
365 let provider = self.schema_provider.bind(py);
366
367 provider
368 .getattr("table_names")
369 .and_then(|names| names.extract::<Vec<String>>())
370 .unwrap_or_else(|err| {
371 log::error!("Unable to get table_names: {err}");
372 Vec::default()
373 })
374 })
375 }
376
377 async fn table(
378 &self,
379 name: &str,
380 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
381 self.table_inner(name)
382 .map_err(|e| DataFusionError::External(Box::new(e)))
383 }
384
385 fn register_table(
386 &self,
387 name: String,
388 table: Arc<dyn TableProvider>,
389 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
390 let py_table = PyTable::from(table);
391 Python::attach(|py| {
392 let provider = self.schema_provider.bind(py);
393 let _ = provider
394 .call_method1("register_table", (name, py_table))
395 .map_err(to_datafusion_err)?;
396 Ok(None)
400 })
401 }
402
403 fn deregister_table(
404 &self,
405 name: &str,
406 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
407 Python::attach(|py| {
408 let provider = self.schema_provider.bind(py);
409 let table = provider
410 .call_method1("deregister_table", (name,))
411 .map_err(to_datafusion_err)?;
412 if table.is_none() {
413 return Ok(None);
414 }
415
416 let dataset = match Dataset::new(&table, py) {
419 Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
420 Err(_) => None,
421 };
422
423 Ok(dataset)
424 })
425 }
426
427 fn table_exist(&self, name: &str) -> bool {
428 Python::attach(|py| {
429 let provider = self.schema_provider.bind(py);
430 provider
431 .call_method1("table_exist", (name,))
432 .and_then(|pyobj| pyobj.extract())
433 .unwrap_or(false)
434 })
435 }
436}
437
438#[derive(Debug)]
439pub(crate) struct RustWrappedPyCatalogProvider {
440 pub(crate) catalog_provider: Py<PyAny>,
441 codec: Arc<FFI_LogicalExtensionCodec>,
442}
443
444impl RustWrappedPyCatalogProvider {
445 pub fn new(catalog_provider: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
446 Self {
447 catalog_provider,
448 codec,
449 }
450 }
451
452 fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> {
453 Python::attach(|py| {
454 let provider = self.catalog_provider.bind(py);
455
456 let py_schema = provider.call_method1("schema", (name,))?;
457 if py_schema.is_none() {
458 return Ok(None);
459 }
460
461 extract_schema_provider_from_pyobj(py_schema, self.codec.as_ref()).map(Some)
462 })
463 }
464}
465
466#[async_trait]
467impl CatalogProvider for RustWrappedPyCatalogProvider {
468 fn as_any(&self) -> &dyn Any {
469 self
470 }
471
472 fn schema_names(&self) -> Vec<String> {
473 Python::attach(|py| {
474 let provider = self.catalog_provider.bind(py);
475 provider
476 .call_method0("schema_names")
477 .and_then(|names| names.extract::<HashSet<String>>())
478 .map(|names| names.into_iter().collect())
479 .unwrap_or_else(|err| {
480 log::error!("Unable to get schema_names: {err}");
481 Vec::default()
482 })
483 })
484 }
485
486 fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
487 self.schema_inner(name).unwrap_or_else(|err| {
488 log::error!("CatalogProvider schema returned error: {err}");
489 None
490 })
491 }
492
493 fn register_schema(
494 &self,
495 name: &str,
496 schema: Arc<dyn SchemaProvider>,
497 ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
498 Python::attach(|py| {
499 let py_schema = match schema
500 .as_any()
501 .downcast_ref::<RustWrappedPySchemaProvider>()
502 {
503 Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(),
504 None => &PySchema::new_from_parts(schema, self.codec.clone())
505 .into_py_any(py)
506 .map_err(to_datafusion_err)?,
507 };
508
509 let provider = self.catalog_provider.bind(py);
510 let schema = provider
511 .call_method1("register_schema", (name, py_schema))
512 .map_err(to_datafusion_err)?;
513 if schema.is_none() {
514 return Ok(None);
515 }
516
517 let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
518 as Arc<dyn SchemaProvider>;
519
520 Ok(Some(schema))
521 })
522 }
523
524 fn deregister_schema(
525 &self,
526 name: &str,
527 cascade: bool,
528 ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
529 Python::attach(|py| {
530 let provider = self.catalog_provider.bind(py);
531 let schema = provider
532 .call_method1("deregister_schema", (name, cascade))
533 .map_err(to_datafusion_err)?;
534 if schema.is_none() {
535 return Ok(None);
536 }
537
538 let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
539 as Arc<dyn SchemaProvider>;
540
541 Ok(Some(schema))
542 })
543 }
544}
545
546#[derive(Debug)]
547pub(crate) struct RustWrappedPyCatalogProviderList {
548 pub(crate) catalog_provider_list: Py<PyAny>,
549 codec: Arc<FFI_LogicalExtensionCodec>,
550}
551
552impl RustWrappedPyCatalogProviderList {
553 pub fn new(catalog_provider_list: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
554 Self {
555 catalog_provider_list,
556 codec,
557 }
558 }
559
560 fn catalog_inner(&self, name: &str) -> PyResult<Option<Arc<dyn CatalogProvider>>> {
561 Python::attach(|py| {
562 let provider = self.catalog_provider_list.bind(py);
563
564 let py_schema = provider.call_method1("catalog", (name,))?;
565 if py_schema.is_none() {
566 return Ok(None);
567 }
568
569 extract_catalog_provider_from_pyobj(py_schema, self.codec.as_ref()).map(Some)
570 })
571 }
572}
573
574#[async_trait]
575impl CatalogProviderList for RustWrappedPyCatalogProviderList {
576 fn as_any(&self) -> &dyn Any {
577 self
578 }
579
580 fn catalog_names(&self) -> Vec<String> {
581 Python::attach(|py| {
582 let provider = self.catalog_provider_list.bind(py);
583 provider
584 .call_method0("catalog_names")
585 .and_then(|names| names.extract::<HashSet<String>>())
586 .map(|names| names.into_iter().collect())
587 .unwrap_or_else(|err| {
588 log::error!("Unable to get catalog_names: {err}");
589 Vec::default()
590 })
591 })
592 }
593
594 fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
595 self.catalog_inner(name).unwrap_or_else(|err| {
596 log::error!("CatalogProvider catalog returned error: {err}");
597 None
598 })
599 }
600
601 fn register_catalog(
602 &self,
603 name: String,
604 catalog: Arc<dyn CatalogProvider>,
605 ) -> Option<Arc<dyn CatalogProvider>> {
606 Python::attach(|py| {
607 let py_catalog = match catalog
608 .as_any()
609 .downcast_ref::<RustWrappedPyCatalogProvider>()
610 {
611 Some(wrapped_schema) => wrapped_schema.catalog_provider.as_any().clone_ref(py),
612 None => {
613 match PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py) {
614 Ok(c) => c,
615 Err(err) => {
616 log::error!(
617 "register_catalog returned error during conversion to PyAny: {err}"
618 );
619 return None;
620 }
621 }
622 }
623 };
624
625 let provider = self.catalog_provider_list.bind(py);
626 let catalog = match provider.call_method1("register_catalog", (name, py_catalog)) {
627 Ok(c) => c,
628 Err(err) => {
629 log::error!("register_catalog returned error: {err}");
630 return None;
631 }
632 };
633 if catalog.is_none() {
634 return None;
635 }
636
637 let catalog = Arc::new(RustWrappedPyCatalogProvider::new(
638 catalog.into(),
639 self.codec.clone(),
640 )) as Arc<dyn CatalogProvider>;
641
642 Some(catalog)
643 })
644 }
645}
646
647fn extract_catalog_provider_from_pyobj(
648 mut catalog_provider: Bound<PyAny>,
649 codec: &FFI_LogicalExtensionCodec,
650) -> PyResult<Arc<dyn CatalogProvider>> {
651 if catalog_provider.hasattr("__datafusion_catalog_provider__")? {
652 let py = catalog_provider.py();
653 let codec_capsule = create_logical_extension_capsule(py, codec)?;
654 catalog_provider = catalog_provider
655 .getattr("__datafusion_catalog_provider__")?
656 .call1((codec_capsule,))?;
657 }
658
659 let provider = if let Ok(capsule) = catalog_provider.cast::<PyCapsule>() {
660 let data: NonNull<FFI_CatalogProvider> = capsule
661 .pointer_checked(Some(c"datafusion_catalog_provider"))?
662 .cast();
663 let provider = unsafe { data.as_ref() };
664 let provider: Arc<dyn CatalogProvider + Send> = provider.into();
665 provider as Arc<dyn CatalogProvider>
666 } else {
667 match catalog_provider.extract::<PyCatalog>() {
668 Ok(py_catalog) => py_catalog.catalog,
669 Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(
670 catalog_provider.into(),
671 Arc::new(codec.clone()),
672 )) as Arc<dyn CatalogProvider>,
673 }
674 };
675
676 Ok(provider)
677}
678
679fn extract_schema_provider_from_pyobj(
680 mut schema_provider: Bound<PyAny>,
681 codec: &FFI_LogicalExtensionCodec,
682) -> PyResult<Arc<dyn SchemaProvider>> {
683 if schema_provider.hasattr("__datafusion_schema_provider__")? {
684 let py = schema_provider.py();
685 let codec_capsule = create_logical_extension_capsule(py, codec)?;
686 schema_provider = schema_provider
687 .getattr("__datafusion_schema_provider__")?
688 .call1((codec_capsule,))?;
689 }
690
691 let provider = if let Ok(capsule) = schema_provider.cast::<PyCapsule>() {
692 let data: NonNull<FFI_SchemaProvider> = capsule
693 .pointer_checked(Some(c"datafusion_schema_provider"))?
694 .cast();
695 let provider = unsafe { data.as_ref() };
696 let provider: Arc<dyn SchemaProvider + Send> = provider.into();
697 provider as Arc<dyn SchemaProvider>
698 } else {
699 match schema_provider.extract::<PySchema>() {
700 Ok(py_schema) => py_schema.schema,
701 Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
702 as Arc<dyn SchemaProvider>,
703 }
704 };
705
706 Ok(provider)
707}
708
709fn extract_logical_extension_codec(
710 py: Python,
711 obj: Option<Bound<PyAny>>,
712) -> PyResult<Arc<FFI_LogicalExtensionCodec>> {
713 let obj = match obj {
714 Some(obj) => obj,
715 None => PySessionContext::global_ctx()?.into_bound_py_any(py)?,
716 };
717 ffi_logical_codec_from_pycapsule(obj).map(Arc::new)
718}
719
720pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
721 m.add_class::<PyCatalog>()?;
722 m.add_class::<PySchema>()?;
723 m.add_class::<PyTable>()?;
724
725 Ok(())
726}