datafusion_python/
catalog.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::dataset::Dataset;
19use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
20use crate::table::PyTable;
21use crate::utils::{validate_pycapsule, wait_for_future};
22use async_trait::async_trait;
23use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
24use datafusion::common::DataFusionError;
25use datafusion::{
26    catalog::{CatalogProvider, SchemaProvider},
27    datasource::TableProvider,
28};
29use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
30use pyo3::exceptions::PyKeyError;
31use pyo3::prelude::*;
32use pyo3::types::PyCapsule;
33use pyo3::IntoPyObjectExt;
34use std::any::Any;
35use std::collections::HashSet;
36use std::sync::Arc;
37
38#[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)]
39#[derive(Clone)]
40pub struct PyCatalog {
41    pub catalog: Arc<dyn CatalogProvider>,
42}
43
44#[pyclass(frozen, name = "RawSchema", module = "datafusion.catalog", subclass)]
45#[derive(Clone)]
46pub struct PySchema {
47    pub schema: Arc<dyn SchemaProvider>,
48}
49
50impl From<Arc<dyn CatalogProvider>> for PyCatalog {
51    fn from(catalog: Arc<dyn CatalogProvider>) -> Self {
52        Self { catalog }
53    }
54}
55
56impl From<Arc<dyn SchemaProvider>> for PySchema {
57    fn from(schema: Arc<dyn SchemaProvider>) -> Self {
58        Self { schema }
59    }
60}
61
62#[pymethods]
63impl PyCatalog {
64    #[new]
65    fn new(catalog: PyObject) -> Self {
66        let catalog_provider =
67            Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc<dyn CatalogProvider>;
68        catalog_provider.into()
69    }
70
71    #[staticmethod]
72    fn memory_catalog() -> Self {
73        let catalog_provider =
74            Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
75        catalog_provider.into()
76    }
77
78    fn schema_names(&self) -> HashSet<String> {
79        self.catalog.schema_names().into_iter().collect()
80    }
81
82    #[pyo3(signature = (name="public"))]
83    fn schema(&self, name: &str) -> PyResult<PyObject> {
84        let schema = self
85            .catalog
86            .schema(name)
87            .ok_or(PyKeyError::new_err(format!(
88                "Schema with name {name} doesn't exist."
89            )))?;
90
91        Python::with_gil(|py| {
92            match schema
93                .as_any()
94                .downcast_ref::<RustWrappedPySchemaProvider>()
95            {
96                Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)),
97                None => PySchema::from(schema).into_py_any(py),
98            }
99        })
100    }
101
102    fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
103        let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? {
104            let capsule = schema_provider
105                .getattr("__datafusion_schema_provider__")?
106                .call0()?;
107            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
108            validate_pycapsule(capsule, "datafusion_schema_provider")?;
109
110            let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
111            let provider: ForeignSchemaProvider = provider.into();
112            Arc::new(provider) as Arc<dyn SchemaProvider>
113        } else {
114            match schema_provider.extract::<PySchema>() {
115                Ok(py_schema) => py_schema.schema,
116                Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
117                    as Arc<dyn SchemaProvider>,
118            }
119        };
120
121        let _ = self
122            .catalog
123            .register_schema(name, provider)
124            .map_err(py_datafusion_err)?;
125
126        Ok(())
127    }
128
129    fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> {
130        let _ = self
131            .catalog
132            .deregister_schema(name, cascade)
133            .map_err(py_datafusion_err)?;
134
135        Ok(())
136    }
137
138    fn __repr__(&self) -> PyResult<String> {
139        let mut names: Vec<String> = self.schema_names().into_iter().collect();
140        names.sort();
141        Ok(format!("Catalog(schema_names=[{}])", names.join(", ")))
142    }
143}
144
145#[pymethods]
146impl PySchema {
147    #[new]
148    fn new(schema_provider: PyObject) -> Self {
149        let schema_provider =
150            Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc<dyn SchemaProvider>;
151        schema_provider.into()
152    }
153
154    #[staticmethod]
155    fn memory_schema() -> Self {
156        let schema_provider = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
157        schema_provider.into()
158    }
159
160    #[getter]
161    fn table_names(&self) -> HashSet<String> {
162        self.schema.table_names().into_iter().collect()
163    }
164
165    fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
166        if let Some(table) = wait_for_future(py, self.schema.table(name))?? {
167            Ok(PyTable::from(table))
168        } else {
169            Err(PyDataFusionError::Common(format!(
170                "Table not found: {name}"
171            )))
172        }
173    }
174
175    fn __repr__(&self) -> PyResult<String> {
176        let mut names: Vec<String> = self.table_names().into_iter().collect();
177        names.sort();
178        Ok(format!("Schema(table_names=[{}])", names.join(";")))
179    }
180
181    fn register_table(&self, name: &str, table_provider: &Bound<'_, PyAny>) -> PyResult<()> {
182        let table = PyTable::new(table_provider)?;
183
184        let _ = self
185            .schema
186            .register_table(name.to_string(), table.table)
187            .map_err(py_datafusion_err)?;
188
189        Ok(())
190    }
191
192    fn deregister_table(&self, name: &str) -> PyResult<()> {
193        let _ = self
194            .schema
195            .deregister_table(name)
196            .map_err(py_datafusion_err)?;
197
198        Ok(())
199    }
200}
201
202#[derive(Debug)]
203pub(crate) struct RustWrappedPySchemaProvider {
204    schema_provider: PyObject,
205    owner_name: Option<String>,
206}
207
208impl RustWrappedPySchemaProvider {
209    pub fn new(schema_provider: PyObject) -> Self {
210        let owner_name = Python::with_gil(|py| {
211            schema_provider
212                .bind(py)
213                .getattr("owner_name")
214                .ok()
215                .map(|name| name.to_string())
216        });
217
218        Self {
219            schema_provider,
220            owner_name,
221        }
222    }
223
224    fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
225        Python::with_gil(|py| {
226            let provider = self.schema_provider.bind(py);
227            let py_table_method = provider.getattr("table")?;
228
229            let py_table = py_table_method.call((name,), None)?;
230            if py_table.is_none() {
231                return Ok(None);
232            }
233
234            let table = PyTable::new(&py_table)?;
235
236            Ok(Some(table.table))
237        })
238    }
239}
240
241#[async_trait]
242impl SchemaProvider for RustWrappedPySchemaProvider {
243    fn owner_name(&self) -> Option<&str> {
244        self.owner_name.as_deref()
245    }
246
247    fn as_any(&self) -> &dyn Any {
248        self
249    }
250
251    fn table_names(&self) -> Vec<String> {
252        Python::with_gil(|py| {
253            let provider = self.schema_provider.bind(py);
254
255            provider
256                .getattr("table_names")
257                .and_then(|names| names.extract::<Vec<String>>())
258                .unwrap_or_else(|err| {
259                    log::error!("Unable to get table_names: {err}");
260                    Vec::default()
261                })
262        })
263    }
264
265    async fn table(
266        &self,
267        name: &str,
268    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
269        self.table_inner(name).map_err(to_datafusion_err)
270    }
271
272    fn register_table(
273        &self,
274        name: String,
275        table: Arc<dyn TableProvider>,
276    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
277        let py_table = PyTable::from(table);
278        Python::with_gil(|py| {
279            let provider = self.schema_provider.bind(py);
280            let _ = provider
281                .call_method1("register_table", (name, py_table))
282                .map_err(to_datafusion_err)?;
283            // Since the definition of `register_table` says that an error
284            // will be returned if the table already exists, there is no
285            // case where we want to return a table provider as output.
286            Ok(None)
287        })
288    }
289
290    fn deregister_table(
291        &self,
292        name: &str,
293    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
294        Python::with_gil(|py| {
295            let provider = self.schema_provider.bind(py);
296            let table = provider
297                .call_method1("deregister_table", (name,))
298                .map_err(to_datafusion_err)?;
299            if table.is_none() {
300                return Ok(None);
301            }
302
303            // If we can turn this table provider into a `Dataset`, return it.
304            // Otherwise, return None.
305            let dataset = match Dataset::new(&table, py) {
306                Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
307                Err(_) => None,
308            };
309
310            Ok(dataset)
311        })
312    }
313
314    fn table_exist(&self, name: &str) -> bool {
315        Python::with_gil(|py| {
316            let provider = self.schema_provider.bind(py);
317            provider
318                .call_method1("table_exist", (name,))
319                .and_then(|pyobj| pyobj.extract())
320                .unwrap_or(false)
321        })
322    }
323}
324
325#[derive(Debug)]
326pub(crate) struct RustWrappedPyCatalogProvider {
327    pub(crate) catalog_provider: PyObject,
328}
329
330impl RustWrappedPyCatalogProvider {
331    pub fn new(catalog_provider: PyObject) -> Self {
332        Self { catalog_provider }
333    }
334
335    fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> {
336        Python::with_gil(|py| {
337            let provider = self.catalog_provider.bind(py);
338
339            let py_schema = provider.call_method1("schema", (name,))?;
340            if py_schema.is_none() {
341                return Ok(None);
342            }
343
344            if py_schema.hasattr("__datafusion_schema_provider__")? {
345                let capsule = provider
346                    .getattr("__datafusion_schema_provider__")?
347                    .call0()?;
348                let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
349                validate_pycapsule(capsule, "datafusion_schema_provider")?;
350
351                let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
352                let provider: ForeignSchemaProvider = provider.into();
353
354                Ok(Some(Arc::new(provider) as Arc<dyn SchemaProvider>))
355            } else {
356                if let Ok(inner_schema) = py_schema.getattr("schema") {
357                    if let Ok(inner_schema) = inner_schema.extract::<PySchema>() {
358                        return Ok(Some(inner_schema.schema));
359                    }
360                }
361                match py_schema.extract::<PySchema>() {
362                    Ok(inner_schema) => Ok(Some(inner_schema.schema)),
363                    Err(_) => {
364                        let py_schema = RustWrappedPySchemaProvider::new(py_schema.into());
365
366                        Ok(Some(Arc::new(py_schema) as Arc<dyn SchemaProvider>))
367                    }
368                }
369            }
370        })
371    }
372}
373
374#[async_trait]
375impl CatalogProvider for RustWrappedPyCatalogProvider {
376    fn as_any(&self) -> &dyn Any {
377        self
378    }
379
380    fn schema_names(&self) -> Vec<String> {
381        Python::with_gil(|py| {
382            let provider = self.catalog_provider.bind(py);
383            provider
384                .getattr("schema_names")
385                .and_then(|names| names.extract::<Vec<String>>())
386                .unwrap_or_else(|err| {
387                    log::error!("Unable to get schema_names: {err}");
388                    Vec::default()
389                })
390        })
391    }
392
393    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
394        self.schema_inner(name).unwrap_or_else(|err| {
395            log::error!("CatalogProvider schema returned error: {err}");
396            None
397        })
398    }
399
400    fn register_schema(
401        &self,
402        name: &str,
403        schema: Arc<dyn SchemaProvider>,
404    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
405        // JRIGHT HERE
406        // let py_schema: PySchema = schema.into();
407        Python::with_gil(|py| {
408            let py_schema = match schema
409                .as_any()
410                .downcast_ref::<RustWrappedPySchemaProvider>()
411            {
412                Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(),
413                None => &PySchema::from(schema)
414                    .into_py_any(py)
415                    .map_err(to_datafusion_err)?,
416            };
417
418            let provider = self.catalog_provider.bind(py);
419            let schema = provider
420                .call_method1("register_schema", (name, py_schema))
421                .map_err(to_datafusion_err)?;
422            if schema.is_none() {
423                return Ok(None);
424            }
425
426            let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
427                as Arc<dyn SchemaProvider>;
428
429            Ok(Some(schema))
430        })
431    }
432
433    fn deregister_schema(
434        &self,
435        name: &str,
436        cascade: bool,
437    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
438        Python::with_gil(|py| {
439            let provider = self.catalog_provider.bind(py);
440            let schema = provider
441                .call_method1("deregister_schema", (name, cascade))
442                .map_err(to_datafusion_err)?;
443            if schema.is_none() {
444                return Ok(None);
445            }
446
447            let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
448                as Arc<dyn SchemaProvider>;
449
450            Ok(Some(schema))
451        })
452    }
453}
454
455pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
456    m.add_class::<PyCatalog>()?;
457    m.add_class::<PySchema>()?;
458    m.add_class::<PyTable>()?;
459
460    Ok(())
461}