datafusion_python/
catalog.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::dataset::Dataset;
19use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
20use crate::utils::{validate_pycapsule, wait_for_future};
21use async_trait::async_trait;
22use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
23use datafusion::common::DataFusionError;
24use datafusion::{
25    arrow::pyarrow::ToPyArrow,
26    catalog::{CatalogProvider, SchemaProvider},
27    datasource::{TableProvider, TableType},
28};
29use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
30use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
31use pyo3::exceptions::PyKeyError;
32use pyo3::prelude::*;
33use pyo3::types::PyCapsule;
34use pyo3::IntoPyObjectExt;
35use std::any::Any;
36use std::collections::HashSet;
37use std::sync::Arc;
38
/// Python-visible handle to a DataFusion [`CatalogProvider`], exposed as
/// `datafusion.catalog.RawCatalog`. The `subclass` flag allows Python code to
/// derive from it.
#[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)]
#[derive(Clone)]
pub struct PyCatalog {
    /// Underlying provider; shared through an `Arc`, so clones of this handle
    /// refer to the same catalog.
    pub catalog: Arc<dyn CatalogProvider>,
}
44
/// Python-visible handle to a DataFusion [`SchemaProvider`], exposed as
/// `datafusion.catalog.RawSchema`. The `subclass` flag allows Python code to
/// derive from it.
#[pyclass(name = "RawSchema", module = "datafusion.catalog", subclass)]
#[derive(Clone)]
pub struct PySchema {
    /// Underlying provider; shared through an `Arc`, so clones of this handle
    /// refer to the same schema.
    pub schema: Arc<dyn SchemaProvider>,
}
50
/// Python-visible handle to a DataFusion [`TableProvider`], exposed as
/// `datafusion.catalog.RawTable`. The `subclass` flag allows Python code to
/// derive from it.
#[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)]
#[derive(Clone)]
pub struct PyTable {
    /// Underlying provider; shared through an `Arc`, so clones of this handle
    /// refer to the same table.
    pub table: Arc<dyn TableProvider>,
}
56
57impl From<Arc<dyn CatalogProvider>> for PyCatalog {
58    fn from(catalog: Arc<dyn CatalogProvider>) -> Self {
59        Self { catalog }
60    }
61}
62
63impl From<Arc<dyn SchemaProvider>> for PySchema {
64    fn from(schema: Arc<dyn SchemaProvider>) -> Self {
65        Self { schema }
66    }
67}
68
69impl PyTable {
70    pub fn new(table: Arc<dyn TableProvider>) -> Self {
71        Self { table }
72    }
73
74    pub fn table(&self) -> Arc<dyn TableProvider> {
75        self.table.clone()
76    }
77}
78
79#[pymethods]
80impl PyCatalog {
81    #[new]
82    fn new(catalog: PyObject) -> Self {
83        let catalog_provider =
84            Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc<dyn CatalogProvider>;
85        catalog_provider.into()
86    }
87
88    #[staticmethod]
89    fn memory_catalog() -> Self {
90        let catalog_provider =
91            Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
92        catalog_provider.into()
93    }
94
95    fn schema_names(&self) -> HashSet<String> {
96        self.catalog.schema_names().into_iter().collect()
97    }
98
99    #[pyo3(signature = (name="public"))]
100    fn schema(&self, name: &str) -> PyResult<PyObject> {
101        let schema = self
102            .catalog
103            .schema(name)
104            .ok_or(PyKeyError::new_err(format!(
105                "Schema with name {name} doesn't exist."
106            )))?;
107
108        Python::with_gil(|py| {
109            match schema
110                .as_any()
111                .downcast_ref::<RustWrappedPySchemaProvider>()
112            {
113                Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)),
114                None => PySchema::from(schema).into_py_any(py),
115            }
116        })
117    }
118
119    fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
120        let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? {
121            let capsule = schema_provider
122                .getattr("__datafusion_schema_provider__")?
123                .call0()?;
124            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
125            validate_pycapsule(capsule, "datafusion_schema_provider")?;
126
127            let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
128            let provider: ForeignSchemaProvider = provider.into();
129            Arc::new(provider) as Arc<dyn SchemaProvider>
130        } else {
131            match schema_provider.extract::<PySchema>() {
132                Ok(py_schema) => py_schema.schema,
133                Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
134                    as Arc<dyn SchemaProvider>,
135            }
136        };
137
138        let _ = self
139            .catalog
140            .register_schema(name, provider)
141            .map_err(py_datafusion_err)?;
142
143        Ok(())
144    }
145
146    fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> {
147        let _ = self
148            .catalog
149            .deregister_schema(name, cascade)
150            .map_err(py_datafusion_err)?;
151
152        Ok(())
153    }
154
155    fn __repr__(&self) -> PyResult<String> {
156        let mut names: Vec<String> = self.schema_names().into_iter().collect();
157        names.sort();
158        Ok(format!("Catalog(schema_names=[{}])", names.join(", ")))
159    }
160}
161
162#[pymethods]
163impl PySchema {
164    #[new]
165    fn new(schema_provider: PyObject) -> Self {
166        let schema_provider =
167            Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc<dyn SchemaProvider>;
168        schema_provider.into()
169    }
170
171    #[staticmethod]
172    fn memory_schema() -> Self {
173        let schema_provider = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
174        schema_provider.into()
175    }
176
177    #[getter]
178    fn table_names(&self) -> HashSet<String> {
179        self.schema.table_names().into_iter().collect()
180    }
181
182    fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
183        if let Some(table) = wait_for_future(py, self.schema.table(name))?? {
184            Ok(PyTable::new(table))
185        } else {
186            Err(PyDataFusionError::Common(format!(
187                "Table not found: {name}"
188            )))
189        }
190    }
191
192    fn __repr__(&self) -> PyResult<String> {
193        let mut names: Vec<String> = self.table_names().into_iter().collect();
194        names.sort();
195        Ok(format!("Schema(table_names=[{}])", names.join(";")))
196    }
197
198    fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
199        let provider = if table_provider.hasattr("__datafusion_table_provider__")? {
200            let capsule = table_provider
201                .getattr("__datafusion_table_provider__")?
202                .call0()?;
203            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
204            validate_pycapsule(capsule, "datafusion_table_provider")?;
205
206            let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
207            let provider: ForeignTableProvider = provider.into();
208            Arc::new(provider) as Arc<dyn TableProvider>
209        } else {
210            match table_provider.extract::<PyTable>() {
211                Ok(py_table) => py_table.table,
212                Err(_) => {
213                    let py = table_provider.py();
214                    let provider = Dataset::new(&table_provider, py)?;
215                    Arc::new(provider) as Arc<dyn TableProvider>
216                }
217            }
218        };
219
220        let _ = self
221            .schema
222            .register_table(name.to_string(), provider)
223            .map_err(py_datafusion_err)?;
224
225        Ok(())
226    }
227
228    fn deregister_table(&self, name: &str) -> PyResult<()> {
229        let _ = self
230            .schema
231            .deregister_table(name)
232            .map_err(py_datafusion_err)?;
233
234        Ok(())
235    }
236}
237
238#[pymethods]
239impl PyTable {
240    /// Get a reference to the schema for this table
241    #[getter]
242    fn schema(&self, py: Python) -> PyResult<PyObject> {
243        self.table.schema().to_pyarrow(py)
244    }
245
246    #[staticmethod]
247    fn from_dataset(py: Python<'_>, dataset: &Bound<'_, PyAny>) -> PyResult<Self> {
248        let ds = Arc::new(Dataset::new(dataset, py).map_err(py_datafusion_err)?)
249            as Arc<dyn TableProvider>;
250
251        Ok(Self::new(ds))
252    }
253
254    /// Get the type of this table for metadata/catalog purposes.
255    #[getter]
256    fn kind(&self) -> &str {
257        match self.table.table_type() {
258            TableType::Base => "physical",
259            TableType::View => "view",
260            TableType::Temporary => "temporary",
261        }
262    }
263
264    fn __repr__(&self) -> PyResult<String> {
265        let kind = self.kind();
266        Ok(format!("Table(kind={kind})"))
267    }
268
269    // fn scan
270    // fn statistics
271    // fn has_exact_statistics
272    // fn supports_filter_pushdown
273}
274
/// Adapter that lets an arbitrary Python object act as a DataFusion
/// [`SchemaProvider`].
///
/// Trait methods such as `table_names`, `table`, `register_table`,
/// `deregister_table` and `table_exist` acquire the GIL and call the
/// same-named attribute or method on the wrapped object.
#[derive(Debug)]
pub(crate) struct RustWrappedPySchemaProvider {
    /// The wrapped Python schema-provider object.
    schema_provider: PyObject,
    /// Owner name read from the Python object once, when the adapter is built.
    owner_name: Option<String>,
}
280
281impl RustWrappedPySchemaProvider {
282    pub fn new(schema_provider: PyObject) -> Self {
283        let owner_name = Python::with_gil(|py| {
284            schema_provider
285                .bind(py)
286                .getattr("owner_name")
287                .ok()
288                .map(|name| name.to_string())
289        });
290
291        Self {
292            schema_provider,
293            owner_name,
294        }
295    }
296
297    fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
298        Python::with_gil(|py| {
299            let provider = self.schema_provider.bind(py);
300            let py_table_method = provider.getattr("table")?;
301
302            let py_table = py_table_method.call((name,), None)?;
303            if py_table.is_none() {
304                return Ok(None);
305            }
306
307            if py_table.hasattr("__datafusion_table_provider__")? {
308                let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
309                let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
310                validate_pycapsule(capsule, "datafusion_table_provider")?;
311
312                let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
313                let provider: ForeignTableProvider = provider.into();
314
315                Ok(Some(Arc::new(provider) as Arc<dyn TableProvider>))
316            } else {
317                if let Ok(inner_table) = py_table.getattr("table") {
318                    if let Ok(inner_table) = inner_table.extract::<PyTable>() {
319                        return Ok(Some(inner_table.table));
320                    }
321                }
322
323                match py_table.extract::<PyTable>() {
324                    Ok(py_table) => Ok(Some(py_table.table)),
325                    Err(_) => {
326                        let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
327                        Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
328                    }
329                }
330            }
331        })
332    }
333}
334
335#[async_trait]
336impl SchemaProvider for RustWrappedPySchemaProvider {
337    fn owner_name(&self) -> Option<&str> {
338        self.owner_name.as_deref()
339    }
340
341    fn as_any(&self) -> &dyn Any {
342        self
343    }
344
345    fn table_names(&self) -> Vec<String> {
346        Python::with_gil(|py| {
347            let provider = self.schema_provider.bind(py);
348
349            provider
350                .getattr("table_names")
351                .and_then(|names| names.extract::<Vec<String>>())
352                .unwrap_or_else(|err| {
353                    log::error!("Unable to get table_names: {err}");
354                    Vec::default()
355                })
356        })
357    }
358
359    async fn table(
360        &self,
361        name: &str,
362    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
363        self.table_inner(name).map_err(to_datafusion_err)
364    }
365
366    fn register_table(
367        &self,
368        name: String,
369        table: Arc<dyn TableProvider>,
370    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
371        let py_table = PyTable::new(table);
372        Python::with_gil(|py| {
373            let provider = self.schema_provider.bind(py);
374            let _ = provider
375                .call_method1("register_table", (name, py_table))
376                .map_err(to_datafusion_err)?;
377            // Since the definition of `register_table` says that an error
378            // will be returned if the table already exists, there is no
379            // case where we want to return a table provider as output.
380            Ok(None)
381        })
382    }
383
384    fn deregister_table(
385        &self,
386        name: &str,
387    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
388        Python::with_gil(|py| {
389            let provider = self.schema_provider.bind(py);
390            let table = provider
391                .call_method1("deregister_table", (name,))
392                .map_err(to_datafusion_err)?;
393            if table.is_none() {
394                return Ok(None);
395            }
396
397            // If we can turn this table provider into a `Dataset`, return it.
398            // Otherwise, return None.
399            let dataset = match Dataset::new(&table, py) {
400                Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
401                Err(_) => None,
402            };
403
404            Ok(dataset)
405        })
406    }
407
408    fn table_exist(&self, name: &str) -> bool {
409        Python::with_gil(|py| {
410            let provider = self.schema_provider.bind(py);
411            provider
412                .call_method1("table_exist", (name,))
413                .and_then(|pyobj| pyobj.extract())
414                .unwrap_or(false)
415        })
416    }
417}
418
/// Adapter that lets an arbitrary Python object act as a DataFusion
/// [`CatalogProvider`].
///
/// Trait methods such as `schema_names`, `schema`, `register_schema` and
/// `deregister_schema` acquire the GIL and call the same-named attribute or
/// method on the wrapped object.
#[derive(Debug)]
pub(crate) struct RustWrappedPyCatalogProvider {
    /// The wrapped Python catalog-provider object.
    pub(crate) catalog_provider: PyObject,
}
423
424impl RustWrappedPyCatalogProvider {
425    pub fn new(catalog_provider: PyObject) -> Self {
426        Self { catalog_provider }
427    }
428
429    fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> {
430        Python::with_gil(|py| {
431            let provider = self.catalog_provider.bind(py);
432
433            let py_schema = provider.call_method1("schema", (name,))?;
434            if py_schema.is_none() {
435                return Ok(None);
436            }
437
438            if py_schema.hasattr("__datafusion_schema_provider__")? {
439                let capsule = provider
440                    .getattr("__datafusion_schema_provider__")?
441                    .call0()?;
442                let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
443                validate_pycapsule(capsule, "datafusion_schema_provider")?;
444
445                let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
446                let provider: ForeignSchemaProvider = provider.into();
447
448                Ok(Some(Arc::new(provider) as Arc<dyn SchemaProvider>))
449            } else {
450                if let Ok(inner_schema) = py_schema.getattr("schema") {
451                    if let Ok(inner_schema) = inner_schema.extract::<PySchema>() {
452                        return Ok(Some(inner_schema.schema));
453                    }
454                }
455                match py_schema.extract::<PySchema>() {
456                    Ok(inner_schema) => Ok(Some(inner_schema.schema)),
457                    Err(_) => {
458                        let py_schema = RustWrappedPySchemaProvider::new(py_schema.into());
459
460                        Ok(Some(Arc::new(py_schema) as Arc<dyn SchemaProvider>))
461                    }
462                }
463            }
464        })
465    }
466}
467
468#[async_trait]
469impl CatalogProvider for RustWrappedPyCatalogProvider {
470    fn as_any(&self) -> &dyn Any {
471        self
472    }
473
474    fn schema_names(&self) -> Vec<String> {
475        Python::with_gil(|py| {
476            let provider = self.catalog_provider.bind(py);
477            provider
478                .getattr("schema_names")
479                .and_then(|names| names.extract::<Vec<String>>())
480                .unwrap_or_else(|err| {
481                    log::error!("Unable to get schema_names: {err}");
482                    Vec::default()
483                })
484        })
485    }
486
487    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
488        self.schema_inner(name).unwrap_or_else(|err| {
489            log::error!("CatalogProvider schema returned error: {err}");
490            None
491        })
492    }
493
494    fn register_schema(
495        &self,
496        name: &str,
497        schema: Arc<dyn SchemaProvider>,
498    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
499        // JRIGHT HERE
500        // let py_schema: PySchema = schema.into();
501        Python::with_gil(|py| {
502            let py_schema = match schema
503                .as_any()
504                .downcast_ref::<RustWrappedPySchemaProvider>()
505            {
506                Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(),
507                None => &PySchema::from(schema)
508                    .into_py_any(py)
509                    .map_err(to_datafusion_err)?,
510            };
511
512            let provider = self.catalog_provider.bind(py);
513            let schema = provider
514                .call_method1("register_schema", (name, py_schema))
515                .map_err(to_datafusion_err)?;
516            if schema.is_none() {
517                return Ok(None);
518            }
519
520            let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
521                as Arc<dyn SchemaProvider>;
522
523            Ok(Some(schema))
524        })
525    }
526
527    fn deregister_schema(
528        &self,
529        name: &str,
530        cascade: bool,
531    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
532        Python::with_gil(|py| {
533            let provider = self.catalog_provider.bind(py);
534            let schema = provider
535                .call_method1("deregister_schema", (name, cascade))
536                .map_err(to_datafusion_err)?;
537            if schema.is_none() {
538                return Ok(None);
539            }
540
541            let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
542                as Arc<dyn SchemaProvider>;
543
544            Ok(Some(schema))
545        })
546    }
547}
548
549pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
550    m.add_class::<PyCatalog>()?;
551    m.add_class::<PySchema>()?;
552    m.add_class::<PyTable>()?;
553
554    Ok(())
555}