Skip to main content

datafusion_python/
catalog.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::collections::HashSet;
20use std::ptr::NonNull;
21use std::sync::Arc;
22
23use async_trait::async_trait;
24use datafusion::catalog::{
25    CatalogProvider, CatalogProviderList, MemoryCatalogProvider, MemoryCatalogProviderList,
26    MemorySchemaProvider, SchemaProvider,
27};
28use datafusion::common::DataFusionError;
29use datafusion::datasource::TableProvider;
30use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
31use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
32use datafusion_ffi::schema_provider::FFI_SchemaProvider;
33use datafusion_python_util::{
34    create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, wait_for_future,
35};
36use pyo3::IntoPyObjectExt;
37use pyo3::exceptions::PyKeyError;
38use pyo3::prelude::*;
39use pyo3::types::PyCapsule;
40
41use crate::context::PySessionContext;
42use crate::dataset::Dataset;
43use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
44use crate::table::PyTable;
45
// Python-facing handle around a DataFusion `CatalogProviderList`, exposed to
// Python as `RawCatalogList` in the `datafusion.catalog` module.
#[pyclass(
    from_py_object,
    frozen,
    name = "RawCatalogList",
    module = "datafusion.catalog",
    subclass
)]
#[derive(Clone)]
pub struct PyCatalogList {
    // The wrapped catalog list that every method on this class delegates to.
    pub catalog_list: Arc<dyn CatalogProviderList>,
    // Logical extension codec passed along when catalog providers are extracted
    // from Python objects and when catalogs are handed back to Python.
    codec: Arc<FFI_LogicalExtensionCodec>,
}
58
// Python-facing handle around a DataFusion `CatalogProvider`, exposed to
// Python as `RawCatalog` in the `datafusion.catalog` module.
#[pyclass(
    from_py_object,
    frozen,
    name = "RawCatalog",
    module = "datafusion.catalog",
    subclass
)]
#[derive(Clone)]
pub struct PyCatalog {
    // The wrapped catalog that every method on this class delegates to.
    pub catalog: Arc<dyn CatalogProvider>,
    // Logical extension codec passed along when schema providers are extracted
    // from Python objects and when schemas are handed back to Python.
    codec: Arc<FFI_LogicalExtensionCodec>,
}
71
// Python-facing handle around a DataFusion `SchemaProvider`, exposed to
// Python as `RawSchema` in the `datafusion.catalog` module.
#[pyclass(
    from_py_object,
    frozen,
    name = "RawSchema",
    module = "datafusion.catalog",
    subclass
)]
#[derive(Clone)]
pub struct PySchema {
    // The wrapped schema that every method on this class delegates to.
    pub schema: Arc<dyn SchemaProvider>,
    // Logical extension codec; turned into a capsule and passed to `PyTable::new`
    // when a table is registered.
    codec: Arc<FFI_LogicalExtensionCodec>,
}
84
85impl PyCatalog {
86    pub(crate) fn new_from_parts(
87        catalog: Arc<dyn CatalogProvider>,
88        codec: Arc<FFI_LogicalExtensionCodec>,
89    ) -> Self {
90        Self { catalog, codec }
91    }
92}
93
94impl PySchema {
95    pub(crate) fn new_from_parts(
96        schema: Arc<dyn SchemaProvider>,
97        codec: Arc<FFI_LogicalExtensionCodec>,
98    ) -> Self {
99        Self { schema, codec }
100    }
101}
102
103#[pymethods]
104impl PyCatalogList {
105    #[new]
106    pub fn new(
107        py: Python,
108        catalog_list: Py<PyAny>,
109        session: Option<Bound<PyAny>>,
110    ) -> PyResult<Self> {
111        let codec = extract_logical_extension_codec(py, session)?;
112        let catalog_list = Arc::new(RustWrappedPyCatalogProviderList::new(
113            catalog_list,
114            codec.clone(),
115        )) as Arc<dyn CatalogProviderList>;
116        Ok(Self {
117            catalog_list,
118            codec,
119        })
120    }
121
122    #[staticmethod]
123    pub fn memory_catalog_list(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
124        let codec = extract_logical_extension_codec(py, session)?;
125        let catalog_list =
126            Arc::new(MemoryCatalogProviderList::default()) as Arc<dyn CatalogProviderList>;
127        Ok(Self {
128            catalog_list,
129            codec,
130        })
131    }
132
133    pub fn catalog_names(&self) -> HashSet<String> {
134        self.catalog_list.catalog_names().into_iter().collect()
135    }
136
137    #[pyo3(signature = (name="public"))]
138    pub fn catalog(&self, name: &str) -> PyResult<Py<PyAny>> {
139        let catalog = self
140            .catalog_list
141            .catalog(name)
142            .ok_or(PyKeyError::new_err(format!(
143                "Schema with name {name} doesn't exist."
144            )))?;
145
146        Python::attach(|py| {
147            match catalog
148                .as_any()
149                .downcast_ref::<RustWrappedPyCatalogProvider>()
150            {
151                Some(wrapped_catalog) => Ok(wrapped_catalog.catalog_provider.clone_ref(py)),
152                None => PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py),
153            }
154        })
155    }
156
157    pub fn register_catalog(&self, name: &str, catalog_provider: Bound<'_, PyAny>) -> PyResult<()> {
158        let provider = extract_catalog_provider_from_pyobj(catalog_provider, self.codec.as_ref())?;
159
160        let _ = self
161            .catalog_list
162            .register_catalog(name.to_owned(), provider);
163
164        Ok(())
165    }
166
167    pub fn __repr__(&self) -> PyResult<String> {
168        let mut names: Vec<String> = self.catalog_names().into_iter().collect();
169        names.sort();
170        Ok(format!("CatalogList(catalog_names=[{}])", names.join(", ")))
171    }
172}
173
174#[pymethods]
175impl PyCatalog {
176    #[new]
177    pub fn new(py: Python, catalog: Py<PyAny>, session: Option<Bound<PyAny>>) -> PyResult<Self> {
178        let codec = extract_logical_extension_codec(py, session)?;
179        let catalog = Arc::new(RustWrappedPyCatalogProvider::new(catalog, codec.clone()))
180            as Arc<dyn CatalogProvider>;
181        Ok(Self { catalog, codec })
182    }
183
184    #[staticmethod]
185    pub fn memory_catalog(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
186        let codec = extract_logical_extension_codec(py, session)?;
187        let catalog = Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
188        Ok(Self { catalog, codec })
189    }
190
191    pub fn schema_names(&self) -> HashSet<String> {
192        self.catalog.schema_names().into_iter().collect()
193    }
194
195    #[pyo3(signature = (name="public"))]
196    pub fn schema(&self, name: &str) -> PyResult<Py<PyAny>> {
197        let schema = self
198            .catalog
199            .schema(name)
200            .ok_or(PyKeyError::new_err(format!(
201                "Schema with name {name} doesn't exist."
202            )))?;
203
204        Python::attach(|py| {
205            match schema
206                .as_any()
207                .downcast_ref::<RustWrappedPySchemaProvider>()
208            {
209                Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)),
210                None => PySchema::new_from_parts(schema, self.codec.clone()).into_py_any(py),
211            }
212        })
213    }
214
215    pub fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
216        let provider = extract_schema_provider_from_pyobj(schema_provider, self.codec.as_ref())?;
217
218        let _ = self
219            .catalog
220            .register_schema(name, provider)
221            .map_err(py_datafusion_err)?;
222
223        Ok(())
224    }
225
226    pub fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> {
227        let _ = self
228            .catalog
229            .deregister_schema(name, cascade)
230            .map_err(py_datafusion_err)?;
231
232        Ok(())
233    }
234
235    pub fn __repr__(&self) -> PyResult<String> {
236        let mut names: Vec<String> = self.schema_names().into_iter().collect();
237        names.sort();
238        Ok(format!("Catalog(schema_names=[{}])", names.join(", ")))
239    }
240}
241
242#[pymethods]
243impl PySchema {
244    #[new]
245    pub fn new(
246        py: Python,
247        schema_provider: Py<PyAny>,
248        session: Option<Bound<PyAny>>,
249    ) -> PyResult<Self> {
250        let codec = extract_logical_extension_codec(py, session)?;
251        let schema =
252            Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc<dyn SchemaProvider>;
253        Ok(Self { schema, codec })
254    }
255
256    #[staticmethod]
257    fn memory_schema(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
258        let codec = extract_logical_extension_codec(py, session)?;
259        let schema = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
260        Ok(Self { schema, codec })
261    }
262
263    #[getter]
264    fn table_names(&self) -> HashSet<String> {
265        self.schema.table_names().into_iter().collect()
266    }
267
268    fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
269        if let Some(table) = wait_for_future(py, self.schema.table(name))?? {
270            Ok(PyTable::from(table))
271        } else {
272            Err(PyDataFusionError::Common(format!(
273                "Table not found: {name}"
274            )))
275        }
276    }
277
278    fn __repr__(&self) -> PyResult<String> {
279        let mut names: Vec<String> = self.table_names().into_iter().collect();
280        names.sort();
281        Ok(format!("Schema(table_names=[{}])", names.join(";")))
282    }
283
284    fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
285        let py = table_provider.py();
286        let codec_capsule = create_logical_extension_capsule(py, self.codec.as_ref())?
287            .as_any()
288            .clone();
289
290        let table = PyTable::new(table_provider, Some(codec_capsule))?;
291
292        let _ = self
293            .schema
294            .register_table(name.to_string(), table.table)
295            .map_err(py_datafusion_err)?;
296
297        Ok(())
298    }
299
300    fn deregister_table(&self, name: &str) -> PyResult<()> {
301        let _ = self
302            .schema
303            .deregister_table(name)
304            .map_err(py_datafusion_err)?;
305
306        Ok(())
307    }
308
309    fn table_exist(&self, name: &str) -> bool {
310        self.schema.table_exist(name)
311    }
312}
313
/// Adapter that lets a Python object act as a DataFusion `SchemaProvider` by
/// forwarding each trait call to a same-named Python attribute or method.
#[derive(Debug)]
pub(crate) struct RustWrappedPySchemaProvider {
    /// The Python object that implements the schema behaviour.
    schema_provider: Py<PyAny>,
    /// Value of the object's `owner_name` attribute, read once when the wrapper
    /// is built; `None` when the attribute lookup fails.
    owner_name: Option<String>,
}
319
320impl RustWrappedPySchemaProvider {
321    pub fn new(schema_provider: Py<PyAny>) -> Self {
322        let owner_name = Python::attach(|py| {
323            schema_provider
324                .bind(py)
325                .getattr("owner_name")
326                .ok()
327                .map(|name| name.to_string())
328        });
329
330        Self {
331            schema_provider,
332            owner_name,
333        }
334    }
335
336    fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
337        Python::attach(|py| {
338            let provider = self.schema_provider.bind(py);
339            let py_table_method = provider.getattr("table")?;
340
341            let py_table = py_table_method.call((name,), None)?;
342            if py_table.is_none() {
343                return Ok(None);
344            }
345
346            let table = PyTable::new(py_table, None)?;
347
348            Ok(Some(table.table))
349        })
350    }
351}
352
353#[async_trait]
354impl SchemaProvider for RustWrappedPySchemaProvider {
355    fn owner_name(&self) -> Option<&str> {
356        self.owner_name.as_deref()
357    }
358
359    fn as_any(&self) -> &dyn Any {
360        self
361    }
362
363    fn table_names(&self) -> Vec<String> {
364        Python::attach(|py| {
365            let provider = self.schema_provider.bind(py);
366
367            provider
368                .getattr("table_names")
369                .and_then(|names| names.extract::<Vec<String>>())
370                .unwrap_or_else(|err| {
371                    log::error!("Unable to get table_names: {err}");
372                    Vec::default()
373                })
374        })
375    }
376
377    async fn table(
378        &self,
379        name: &str,
380    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
381        self.table_inner(name)
382            .map_err(|e| DataFusionError::External(Box::new(e)))
383    }
384
385    fn register_table(
386        &self,
387        name: String,
388        table: Arc<dyn TableProvider>,
389    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
390        let py_table = PyTable::from(table);
391        Python::attach(|py| {
392            let provider = self.schema_provider.bind(py);
393            let _ = provider
394                .call_method1("register_table", (name, py_table))
395                .map_err(to_datafusion_err)?;
396            // Since the definition of `register_table` says that an error
397            // will be returned if the table already exists, there is no
398            // case where we want to return a table provider as output.
399            Ok(None)
400        })
401    }
402
403    fn deregister_table(
404        &self,
405        name: &str,
406    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
407        Python::attach(|py| {
408            let provider = self.schema_provider.bind(py);
409            let table = provider
410                .call_method1("deregister_table", (name,))
411                .map_err(to_datafusion_err)?;
412            if table.is_none() {
413                return Ok(None);
414            }
415
416            // If we can turn this table provider into a `Dataset`, return it.
417            // Otherwise, return None.
418            let dataset = match Dataset::new(&table, py) {
419                Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
420                Err(_) => None,
421            };
422
423            Ok(dataset)
424        })
425    }
426
427    fn table_exist(&self, name: &str) -> bool {
428        Python::attach(|py| {
429            let provider = self.schema_provider.bind(py);
430            provider
431                .call_method1("table_exist", (name,))
432                .and_then(|pyobj| pyobj.extract())
433                .unwrap_or(false)
434        })
435    }
436}
437
/// Adapter that lets a Python object act as a DataFusion `CatalogProvider` by
/// forwarding each trait call to a same-named Python method.
#[derive(Debug)]
pub(crate) struct RustWrappedPyCatalogProvider {
    /// The Python object that implements the catalog behaviour.
    pub(crate) catalog_provider: Py<PyAny>,
    /// Logical extension codec used when converting schema objects to and from Python.
    codec: Arc<FFI_LogicalExtensionCodec>,
}
443
444impl RustWrappedPyCatalogProvider {
445    pub fn new(catalog_provider: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
446        Self {
447            catalog_provider,
448            codec,
449        }
450    }
451
452    fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> {
453        Python::attach(|py| {
454            let provider = self.catalog_provider.bind(py);
455
456            let py_schema = provider.call_method1("schema", (name,))?;
457            if py_schema.is_none() {
458                return Ok(None);
459            }
460
461            extract_schema_provider_from_pyobj(py_schema, self.codec.as_ref()).map(Some)
462        })
463    }
464}
465
466#[async_trait]
467impl CatalogProvider for RustWrappedPyCatalogProvider {
468    fn as_any(&self) -> &dyn Any {
469        self
470    }
471
472    fn schema_names(&self) -> Vec<String> {
473        Python::attach(|py| {
474            let provider = self.catalog_provider.bind(py);
475            provider
476                .call_method0("schema_names")
477                .and_then(|names| names.extract::<HashSet<String>>())
478                .map(|names| names.into_iter().collect())
479                .unwrap_or_else(|err| {
480                    log::error!("Unable to get schema_names: {err}");
481                    Vec::default()
482                })
483        })
484    }
485
486    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
487        self.schema_inner(name).unwrap_or_else(|err| {
488            log::error!("CatalogProvider schema returned error: {err}");
489            None
490        })
491    }
492
493    fn register_schema(
494        &self,
495        name: &str,
496        schema: Arc<dyn SchemaProvider>,
497    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
498        Python::attach(|py| {
499            let py_schema = match schema
500                .as_any()
501                .downcast_ref::<RustWrappedPySchemaProvider>()
502            {
503                Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(),
504                None => &PySchema::new_from_parts(schema, self.codec.clone())
505                    .into_py_any(py)
506                    .map_err(to_datafusion_err)?,
507            };
508
509            let provider = self.catalog_provider.bind(py);
510            let schema = provider
511                .call_method1("register_schema", (name, py_schema))
512                .map_err(to_datafusion_err)?;
513            if schema.is_none() {
514                return Ok(None);
515            }
516
517            let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
518                as Arc<dyn SchemaProvider>;
519
520            Ok(Some(schema))
521        })
522    }
523
524    fn deregister_schema(
525        &self,
526        name: &str,
527        cascade: bool,
528    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
529        Python::attach(|py| {
530            let provider = self.catalog_provider.bind(py);
531            let schema = provider
532                .call_method1("deregister_schema", (name, cascade))
533                .map_err(to_datafusion_err)?;
534            if schema.is_none() {
535                return Ok(None);
536            }
537
538            let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
539                as Arc<dyn SchemaProvider>;
540
541            Ok(Some(schema))
542        })
543    }
544}
545
/// Adapter that lets a Python object act as a DataFusion `CatalogProviderList`
/// by forwarding each trait call to a same-named Python method.
#[derive(Debug)]
pub(crate) struct RustWrappedPyCatalogProviderList {
    /// The Python object that implements the catalog-list behaviour.
    pub(crate) catalog_provider_list: Py<PyAny>,
    /// Logical extension codec used when converting catalog objects to and from Python.
    codec: Arc<FFI_LogicalExtensionCodec>,
}
551
552impl RustWrappedPyCatalogProviderList {
553    pub fn new(catalog_provider_list: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
554        Self {
555            catalog_provider_list,
556            codec,
557        }
558    }
559
560    fn catalog_inner(&self, name: &str) -> PyResult<Option<Arc<dyn CatalogProvider>>> {
561        Python::attach(|py| {
562            let provider = self.catalog_provider_list.bind(py);
563
564            let py_schema = provider.call_method1("catalog", (name,))?;
565            if py_schema.is_none() {
566                return Ok(None);
567            }
568
569            extract_catalog_provider_from_pyobj(py_schema, self.codec.as_ref()).map(Some)
570        })
571    }
572}
573
574#[async_trait]
575impl CatalogProviderList for RustWrappedPyCatalogProviderList {
576    fn as_any(&self) -> &dyn Any {
577        self
578    }
579
580    fn catalog_names(&self) -> Vec<String> {
581        Python::attach(|py| {
582            let provider = self.catalog_provider_list.bind(py);
583            provider
584                .call_method0("catalog_names")
585                .and_then(|names| names.extract::<HashSet<String>>())
586                .map(|names| names.into_iter().collect())
587                .unwrap_or_else(|err| {
588                    log::error!("Unable to get catalog_names: {err}");
589                    Vec::default()
590                })
591        })
592    }
593
594    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
595        self.catalog_inner(name).unwrap_or_else(|err| {
596            log::error!("CatalogProvider catalog returned error: {err}");
597            None
598        })
599    }
600
601    fn register_catalog(
602        &self,
603        name: String,
604        catalog: Arc<dyn CatalogProvider>,
605    ) -> Option<Arc<dyn CatalogProvider>> {
606        Python::attach(|py| {
607            let py_catalog = match catalog
608                .as_any()
609                .downcast_ref::<RustWrappedPyCatalogProvider>()
610            {
611                Some(wrapped_schema) => wrapped_schema.catalog_provider.as_any().clone_ref(py),
612                None => {
613                    match PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py) {
614                        Ok(c) => c,
615                        Err(err) => {
616                            log::error!(
617                                "register_catalog returned error during conversion to PyAny: {err}"
618                            );
619                            return None;
620                        }
621                    }
622                }
623            };
624
625            let provider = self.catalog_provider_list.bind(py);
626            let catalog = match provider.call_method1("register_catalog", (name, py_catalog)) {
627                Ok(c) => c,
628                Err(err) => {
629                    log::error!("register_catalog returned error: {err}");
630                    return None;
631                }
632            };
633            if catalog.is_none() {
634                return None;
635            }
636
637            let catalog = Arc::new(RustWrappedPyCatalogProvider::new(
638                catalog.into(),
639                self.codec.clone(),
640            )) as Arc<dyn CatalogProvider>;
641
642            Some(catalog)
643        })
644    }
645}
646
647fn extract_catalog_provider_from_pyobj(
648    mut catalog_provider: Bound<PyAny>,
649    codec: &FFI_LogicalExtensionCodec,
650) -> PyResult<Arc<dyn CatalogProvider>> {
651    if catalog_provider.hasattr("__datafusion_catalog_provider__")? {
652        let py = catalog_provider.py();
653        let codec_capsule = create_logical_extension_capsule(py, codec)?;
654        catalog_provider = catalog_provider
655            .getattr("__datafusion_catalog_provider__")?
656            .call1((codec_capsule,))?;
657    }
658
659    let provider = if let Ok(capsule) = catalog_provider.cast::<PyCapsule>() {
660        let data: NonNull<FFI_CatalogProvider> = capsule
661            .pointer_checked(Some(c"datafusion_catalog_provider"))?
662            .cast();
663        let provider = unsafe { data.as_ref() };
664        let provider: Arc<dyn CatalogProvider + Send> = provider.into();
665        provider as Arc<dyn CatalogProvider>
666    } else {
667        match catalog_provider.extract::<PyCatalog>() {
668            Ok(py_catalog) => py_catalog.catalog,
669            Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(
670                catalog_provider.into(),
671                Arc::new(codec.clone()),
672            )) as Arc<dyn CatalogProvider>,
673        }
674    };
675
676    Ok(provider)
677}
678
679fn extract_schema_provider_from_pyobj(
680    mut schema_provider: Bound<PyAny>,
681    codec: &FFI_LogicalExtensionCodec,
682) -> PyResult<Arc<dyn SchemaProvider>> {
683    if schema_provider.hasattr("__datafusion_schema_provider__")? {
684        let py = schema_provider.py();
685        let codec_capsule = create_logical_extension_capsule(py, codec)?;
686        schema_provider = schema_provider
687            .getattr("__datafusion_schema_provider__")?
688            .call1((codec_capsule,))?;
689    }
690
691    let provider = if let Ok(capsule) = schema_provider.cast::<PyCapsule>() {
692        let data: NonNull<FFI_SchemaProvider> = capsule
693            .pointer_checked(Some(c"datafusion_schema_provider"))?
694            .cast();
695        let provider = unsafe { data.as_ref() };
696        let provider: Arc<dyn SchemaProvider + Send> = provider.into();
697        provider as Arc<dyn SchemaProvider>
698    } else {
699        match schema_provider.extract::<PySchema>() {
700            Ok(py_schema) => py_schema.schema,
701            Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
702                as Arc<dyn SchemaProvider>,
703        }
704    };
705
706    Ok(provider)
707}
708
709fn extract_logical_extension_codec(
710    py: Python,
711    obj: Option<Bound<PyAny>>,
712) -> PyResult<Arc<FFI_LogicalExtensionCodec>> {
713    let obj = match obj {
714        Some(obj) => obj,
715        None => PySessionContext::global_ctx()?.into_bound_py_any(py)?,
716    };
717    ffi_logical_codec_from_pycapsule(obj).map(Arc::new)
718}
719
720pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
721    m.add_class::<PyCatalog>()?;
722    m.add_class::<PySchema>()?;
723    m.add_class::<PyTable>()?;
724
725    Ok(())
726}