Skip to main content

datafusion_python/
catalog.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::collections::HashSet;
20use std::sync::Arc;
21
22use async_trait::async_trait;
23use datafusion::catalog::{
24    CatalogProvider, CatalogProviderList, MemoryCatalogProvider, MemoryCatalogProviderList,
25    MemorySchemaProvider, SchemaProvider,
26};
27use datafusion::common::DataFusionError;
28use datafusion::datasource::TableProvider;
29use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
30use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
31use datafusion_ffi::schema_provider::FFI_SchemaProvider;
32use pyo3::IntoPyObjectExt;
33use pyo3::exceptions::PyKeyError;
34use pyo3::prelude::*;
35use pyo3::types::PyCapsule;
36
37use crate::dataset::Dataset;
38use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
39use crate::table::PyTable;
40use crate::utils::{
41    create_logical_extension_capsule, extract_logical_extension_codec, validate_pycapsule,
42    wait_for_future,
43};
44
45#[pyclass(
46    frozen,
47    name = "RawCatalogList",
48    module = "datafusion.catalog",
49    subclass
50)]
51#[derive(Clone)]
52pub struct PyCatalogList {
53    pub catalog_list: Arc<dyn CatalogProviderList>,
54    codec: Arc<FFI_LogicalExtensionCodec>,
55}
56
57#[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)]
58#[derive(Clone)]
59pub struct PyCatalog {
60    pub catalog: Arc<dyn CatalogProvider>,
61    codec: Arc<FFI_LogicalExtensionCodec>,
62}
63
64#[pyclass(frozen, name = "RawSchema", module = "datafusion.catalog", subclass)]
65#[derive(Clone)]
66pub struct PySchema {
67    pub schema: Arc<dyn SchemaProvider>,
68    codec: Arc<FFI_LogicalExtensionCodec>,
69}
70
71impl PyCatalog {
72    pub(crate) fn new_from_parts(
73        catalog: Arc<dyn CatalogProvider>,
74        codec: Arc<FFI_LogicalExtensionCodec>,
75    ) -> Self {
76        Self { catalog, codec }
77    }
78}
79
80impl PySchema {
81    pub(crate) fn new_from_parts(
82        schema: Arc<dyn SchemaProvider>,
83        codec: Arc<FFI_LogicalExtensionCodec>,
84    ) -> Self {
85        Self { schema, codec }
86    }
87}
88
89#[pymethods]
90impl PyCatalogList {
91    #[new]
92    pub fn new(
93        py: Python,
94        catalog_list: Py<PyAny>,
95        session: Option<Bound<PyAny>>,
96    ) -> PyResult<Self> {
97        let codec = extract_logical_extension_codec(py, session)?;
98        let catalog_list = Arc::new(RustWrappedPyCatalogProviderList::new(
99            catalog_list,
100            codec.clone(),
101        )) as Arc<dyn CatalogProviderList>;
102        Ok(Self {
103            catalog_list,
104            codec,
105        })
106    }
107
108    #[staticmethod]
109    pub fn memory_catalog_list(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
110        let codec = extract_logical_extension_codec(py, session)?;
111        let catalog_list =
112            Arc::new(MemoryCatalogProviderList::default()) as Arc<dyn CatalogProviderList>;
113        Ok(Self {
114            catalog_list,
115            codec,
116        })
117    }
118
119    pub fn catalog_names(&self) -> HashSet<String> {
120        self.catalog_list.catalog_names().into_iter().collect()
121    }
122
123    #[pyo3(signature = (name="public"))]
124    pub fn catalog(&self, name: &str) -> PyResult<Py<PyAny>> {
125        let catalog = self
126            .catalog_list
127            .catalog(name)
128            .ok_or(PyKeyError::new_err(format!(
129                "Schema with name {name} doesn't exist."
130            )))?;
131
132        Python::attach(|py| {
133            match catalog
134                .as_any()
135                .downcast_ref::<RustWrappedPyCatalogProvider>()
136            {
137                Some(wrapped_catalog) => Ok(wrapped_catalog.catalog_provider.clone_ref(py)),
138                None => PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py),
139            }
140        })
141    }
142
143    pub fn register_catalog(&self, name: &str, catalog_provider: Bound<'_, PyAny>) -> PyResult<()> {
144        let provider = extract_catalog_provider_from_pyobj(catalog_provider, self.codec.as_ref())?;
145
146        let _ = self
147            .catalog_list
148            .register_catalog(name.to_owned(), provider);
149
150        Ok(())
151    }
152
153    pub fn __repr__(&self) -> PyResult<String> {
154        let mut names: Vec<String> = self.catalog_names().into_iter().collect();
155        names.sort();
156        Ok(format!("CatalogList(catalog_names=[{}])", names.join(", ")))
157    }
158}
159
160#[pymethods]
161impl PyCatalog {
162    #[new]
163    pub fn new(py: Python, catalog: Py<PyAny>, session: Option<Bound<PyAny>>) -> PyResult<Self> {
164        let codec = extract_logical_extension_codec(py, session)?;
165        let catalog = Arc::new(RustWrappedPyCatalogProvider::new(catalog, codec.clone()))
166            as Arc<dyn CatalogProvider>;
167        Ok(Self { catalog, codec })
168    }
169
170    #[staticmethod]
171    pub fn memory_catalog(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
172        let codec = extract_logical_extension_codec(py, session)?;
173        let catalog = Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
174        Ok(Self { catalog, codec })
175    }
176
177    pub fn schema_names(&self) -> HashSet<String> {
178        self.catalog.schema_names().into_iter().collect()
179    }
180
181    #[pyo3(signature = (name="public"))]
182    pub fn schema(&self, name: &str) -> PyResult<Py<PyAny>> {
183        let schema = self
184            .catalog
185            .schema(name)
186            .ok_or(PyKeyError::new_err(format!(
187                "Schema with name {name} doesn't exist."
188            )))?;
189
190        Python::attach(|py| {
191            match schema
192                .as_any()
193                .downcast_ref::<RustWrappedPySchemaProvider>()
194            {
195                Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)),
196                None => PySchema::new_from_parts(schema, self.codec.clone()).into_py_any(py),
197            }
198        })
199    }
200
201    pub fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
202        let provider = extract_schema_provider_from_pyobj(schema_provider, self.codec.as_ref())?;
203
204        let _ = self
205            .catalog
206            .register_schema(name, provider)
207            .map_err(py_datafusion_err)?;
208
209        Ok(())
210    }
211
212    pub fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> {
213        let _ = self
214            .catalog
215            .deregister_schema(name, cascade)
216            .map_err(py_datafusion_err)?;
217
218        Ok(())
219    }
220
221    pub fn __repr__(&self) -> PyResult<String> {
222        let mut names: Vec<String> = self.schema_names().into_iter().collect();
223        names.sort();
224        Ok(format!("Catalog(schema_names=[{}])", names.join(", ")))
225    }
226}
227
228#[pymethods]
229impl PySchema {
230    #[new]
231    pub fn new(
232        py: Python,
233        schema_provider: Py<PyAny>,
234        session: Option<Bound<PyAny>>,
235    ) -> PyResult<Self> {
236        let codec = extract_logical_extension_codec(py, session)?;
237        let schema =
238            Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc<dyn SchemaProvider>;
239        Ok(Self { schema, codec })
240    }
241
242    #[staticmethod]
243    fn memory_schema(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
244        let codec = extract_logical_extension_codec(py, session)?;
245        let schema = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
246        Ok(Self { schema, codec })
247    }
248
249    #[getter]
250    fn table_names(&self) -> HashSet<String> {
251        self.schema.table_names().into_iter().collect()
252    }
253
254    fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
255        if let Some(table) = wait_for_future(py, self.schema.table(name))?? {
256            Ok(PyTable::from(table))
257        } else {
258            Err(PyDataFusionError::Common(format!(
259                "Table not found: {name}"
260            )))
261        }
262    }
263
264    fn __repr__(&self) -> PyResult<String> {
265        let mut names: Vec<String> = self.table_names().into_iter().collect();
266        names.sort();
267        Ok(format!("Schema(table_names=[{}])", names.join(";")))
268    }
269
270    fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
271        let py = table_provider.py();
272        let codec_capsule = create_logical_extension_capsule(py, self.codec.as_ref())?
273            .as_any()
274            .clone();
275
276        let table = PyTable::new(table_provider, Some(codec_capsule))?;
277
278        let _ = self
279            .schema
280            .register_table(name.to_string(), table.table)
281            .map_err(py_datafusion_err)?;
282
283        Ok(())
284    }
285
286    fn deregister_table(&self, name: &str) -> PyResult<()> {
287        let _ = self
288            .schema
289            .deregister_table(name)
290            .map_err(py_datafusion_err)?;
291
292        Ok(())
293    }
294
295    fn table_exist(&self, name: &str) -> bool {
296        self.schema.table_exist(name)
297    }
298}
299
300#[derive(Debug)]
301pub(crate) struct RustWrappedPySchemaProvider {
302    schema_provider: Py<PyAny>,
303    owner_name: Option<String>,
304}
305
306impl RustWrappedPySchemaProvider {
307    pub fn new(schema_provider: Py<PyAny>) -> Self {
308        let owner_name = Python::attach(|py| {
309            schema_provider
310                .bind(py)
311                .getattr("owner_name")
312                .ok()
313                .map(|name| name.to_string())
314        });
315
316        Self {
317            schema_provider,
318            owner_name,
319        }
320    }
321
322    fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
323        Python::attach(|py| {
324            let provider = self.schema_provider.bind(py);
325            let py_table_method = provider.getattr("table")?;
326
327            let py_table = py_table_method.call((name,), None)?;
328            if py_table.is_none() {
329                return Ok(None);
330            }
331
332            let table = PyTable::new(py_table, None)?;
333
334            Ok(Some(table.table))
335        })
336    }
337}
338
339#[async_trait]
340impl SchemaProvider for RustWrappedPySchemaProvider {
341    fn owner_name(&self) -> Option<&str> {
342        self.owner_name.as_deref()
343    }
344
345    fn as_any(&self) -> &dyn Any {
346        self
347    }
348
349    fn table_names(&self) -> Vec<String> {
350        Python::attach(|py| {
351            let provider = self.schema_provider.bind(py);
352
353            provider
354                .getattr("table_names")
355                .and_then(|names| names.extract::<Vec<String>>())
356                .unwrap_or_else(|err| {
357                    log::error!("Unable to get table_names: {err}");
358                    Vec::default()
359                })
360        })
361    }
362
363    async fn table(
364        &self,
365        name: &str,
366    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
367        self.table_inner(name)
368            .map_err(|e| DataFusionError::External(Box::new(e)))
369    }
370
371    fn register_table(
372        &self,
373        name: String,
374        table: Arc<dyn TableProvider>,
375    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
376        let py_table = PyTable::from(table);
377        Python::attach(|py| {
378            let provider = self.schema_provider.bind(py);
379            let _ = provider
380                .call_method1("register_table", (name, py_table))
381                .map_err(to_datafusion_err)?;
382            // Since the definition of `register_table` says that an error
383            // will be returned if the table already exists, there is no
384            // case where we want to return a table provider as output.
385            Ok(None)
386        })
387    }
388
389    fn deregister_table(
390        &self,
391        name: &str,
392    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
393        Python::attach(|py| {
394            let provider = self.schema_provider.bind(py);
395            let table = provider
396                .call_method1("deregister_table", (name,))
397                .map_err(to_datafusion_err)?;
398            if table.is_none() {
399                return Ok(None);
400            }
401
402            // If we can turn this table provider into a `Dataset`, return it.
403            // Otherwise, return None.
404            let dataset = match Dataset::new(&table, py) {
405                Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
406                Err(_) => None,
407            };
408
409            Ok(dataset)
410        })
411    }
412
413    fn table_exist(&self, name: &str) -> bool {
414        Python::attach(|py| {
415            let provider = self.schema_provider.bind(py);
416            provider
417                .call_method1("table_exist", (name,))
418                .and_then(|pyobj| pyobj.extract())
419                .unwrap_or(false)
420        })
421    }
422}
423
424#[derive(Debug)]
425pub(crate) struct RustWrappedPyCatalogProvider {
426    pub(crate) catalog_provider: Py<PyAny>,
427    codec: Arc<FFI_LogicalExtensionCodec>,
428}
429
430impl RustWrappedPyCatalogProvider {
431    pub fn new(catalog_provider: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
432        Self {
433            catalog_provider,
434            codec,
435        }
436    }
437
438    fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> {
439        Python::attach(|py| {
440            let provider = self.catalog_provider.bind(py);
441
442            let py_schema = provider.call_method1("schema", (name,))?;
443            if py_schema.is_none() {
444                return Ok(None);
445            }
446
447            extract_schema_provider_from_pyobj(py_schema, self.codec.as_ref()).map(Some)
448        })
449    }
450}
451
452#[async_trait]
453impl CatalogProvider for RustWrappedPyCatalogProvider {
454    fn as_any(&self) -> &dyn Any {
455        self
456    }
457
458    fn schema_names(&self) -> Vec<String> {
459        Python::attach(|py| {
460            let provider = self.catalog_provider.bind(py);
461            provider
462                .call_method0("schema_names")
463                .and_then(|names| names.extract::<HashSet<String>>())
464                .map(|names| names.into_iter().collect())
465                .unwrap_or_else(|err| {
466                    log::error!("Unable to get schema_names: {err}");
467                    Vec::default()
468                })
469        })
470    }
471
472    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
473        self.schema_inner(name).unwrap_or_else(|err| {
474            log::error!("CatalogProvider schema returned error: {err}");
475            None
476        })
477    }
478
479    fn register_schema(
480        &self,
481        name: &str,
482        schema: Arc<dyn SchemaProvider>,
483    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
484        Python::attach(|py| {
485            let py_schema = match schema
486                .as_any()
487                .downcast_ref::<RustWrappedPySchemaProvider>()
488            {
489                Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(),
490                None => &PySchema::new_from_parts(schema, self.codec.clone())
491                    .into_py_any(py)
492                    .map_err(to_datafusion_err)?,
493            };
494
495            let provider = self.catalog_provider.bind(py);
496            let schema = provider
497                .call_method1("register_schema", (name, py_schema))
498                .map_err(to_datafusion_err)?;
499            if schema.is_none() {
500                return Ok(None);
501            }
502
503            let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
504                as Arc<dyn SchemaProvider>;
505
506            Ok(Some(schema))
507        })
508    }
509
510    fn deregister_schema(
511        &self,
512        name: &str,
513        cascade: bool,
514    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
515        Python::attach(|py| {
516            let provider = self.catalog_provider.bind(py);
517            let schema = provider
518                .call_method1("deregister_schema", (name, cascade))
519                .map_err(to_datafusion_err)?;
520            if schema.is_none() {
521                return Ok(None);
522            }
523
524            let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
525                as Arc<dyn SchemaProvider>;
526
527            Ok(Some(schema))
528        })
529    }
530}
531
532#[derive(Debug)]
533pub(crate) struct RustWrappedPyCatalogProviderList {
534    pub(crate) catalog_provider_list: Py<PyAny>,
535    codec: Arc<FFI_LogicalExtensionCodec>,
536}
537
538impl RustWrappedPyCatalogProviderList {
539    pub fn new(catalog_provider_list: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
540        Self {
541            catalog_provider_list,
542            codec,
543        }
544    }
545
546    fn catalog_inner(&self, name: &str) -> PyResult<Option<Arc<dyn CatalogProvider>>> {
547        Python::attach(|py| {
548            let provider = self.catalog_provider_list.bind(py);
549
550            let py_schema = provider.call_method1("catalog", (name,))?;
551            if py_schema.is_none() {
552                return Ok(None);
553            }
554
555            extract_catalog_provider_from_pyobj(py_schema, self.codec.as_ref()).map(Some)
556        })
557    }
558}
559
560#[async_trait]
561impl CatalogProviderList for RustWrappedPyCatalogProviderList {
562    fn as_any(&self) -> &dyn Any {
563        self
564    }
565
566    fn catalog_names(&self) -> Vec<String> {
567        Python::attach(|py| {
568            let provider = self.catalog_provider_list.bind(py);
569            provider
570                .call_method0("catalog_names")
571                .and_then(|names| names.extract::<HashSet<String>>())
572                .map(|names| names.into_iter().collect())
573                .unwrap_or_else(|err| {
574                    log::error!("Unable to get catalog_names: {err}");
575                    Vec::default()
576                })
577        })
578    }
579
580    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
581        self.catalog_inner(name).unwrap_or_else(|err| {
582            log::error!("CatalogProvider catalog returned error: {err}");
583            None
584        })
585    }
586
587    fn register_catalog(
588        &self,
589        name: String,
590        catalog: Arc<dyn CatalogProvider>,
591    ) -> Option<Arc<dyn CatalogProvider>> {
592        Python::attach(|py| {
593            let py_catalog = match catalog
594                .as_any()
595                .downcast_ref::<RustWrappedPyCatalogProvider>()
596            {
597                Some(wrapped_schema) => wrapped_schema.catalog_provider.as_any().clone_ref(py),
598                None => {
599                    match PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py) {
600                        Ok(c) => c,
601                        Err(err) => {
602                            log::error!(
603                                "register_catalog returned error during conversion to PyAny: {err}"
604                            );
605                            return None;
606                        }
607                    }
608                }
609            };
610
611            let provider = self.catalog_provider_list.bind(py);
612            let catalog = match provider.call_method1("register_catalog", (name, py_catalog)) {
613                Ok(c) => c,
614                Err(err) => {
615                    log::error!("register_catalog returned error: {err}");
616                    return None;
617                }
618            };
619            if catalog.is_none() {
620                return None;
621            }
622
623            let catalog = Arc::new(RustWrappedPyCatalogProvider::new(
624                catalog.into(),
625                self.codec.clone(),
626            )) as Arc<dyn CatalogProvider>;
627
628            Some(catalog)
629        })
630    }
631}
632
633fn extract_catalog_provider_from_pyobj(
634    mut catalog_provider: Bound<PyAny>,
635    codec: &FFI_LogicalExtensionCodec,
636) -> PyResult<Arc<dyn CatalogProvider>> {
637    if catalog_provider.hasattr("__datafusion_catalog_provider__")? {
638        let py = catalog_provider.py();
639        let codec_capsule = create_logical_extension_capsule(py, codec)?;
640        catalog_provider = catalog_provider
641            .getattr("__datafusion_catalog_provider__")?
642            .call1((codec_capsule,))?;
643    }
644
645    let provider = if let Ok(capsule) = catalog_provider.downcast::<PyCapsule>() {
646        validate_pycapsule(capsule, "datafusion_catalog_provider")?;
647
648        let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
649        let provider: Arc<dyn CatalogProvider + Send> = provider.into();
650        provider as Arc<dyn CatalogProvider>
651    } else {
652        match catalog_provider.extract::<PyCatalog>() {
653            Ok(py_catalog) => py_catalog.catalog,
654            Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(
655                catalog_provider.into(),
656                Arc::new(codec.clone()),
657            )) as Arc<dyn CatalogProvider>,
658        }
659    };
660
661    Ok(provider)
662}
663
664fn extract_schema_provider_from_pyobj(
665    mut schema_provider: Bound<PyAny>,
666    codec: &FFI_LogicalExtensionCodec,
667) -> PyResult<Arc<dyn SchemaProvider>> {
668    if schema_provider.hasattr("__datafusion_schema_provider__")? {
669        let py = schema_provider.py();
670        let codec_capsule = create_logical_extension_capsule(py, codec)?;
671        schema_provider = schema_provider
672            .getattr("__datafusion_schema_provider__")?
673            .call1((codec_capsule,))?;
674    }
675
676    let provider = if let Ok(capsule) = schema_provider.downcast::<PyCapsule>() {
677        validate_pycapsule(capsule, "datafusion_schema_provider")?;
678
679        let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
680        let provider: Arc<dyn SchemaProvider + Send> = provider.into();
681        provider as Arc<dyn SchemaProvider>
682    } else {
683        match schema_provider.extract::<PySchema>() {
684            Ok(py_schema) => py_schema.schema,
685            Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
686                as Arc<dyn SchemaProvider>,
687        }
688    };
689
690    Ok(provider)
691}
692
693pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
694    m.add_class::<PyCatalog>()?;
695    m.add_class::<PySchema>()?;
696    m.add_class::<PyTable>()?;
697
698    Ok(())
699}