datafusion_python/
catalog.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::collections::HashSet;
20use std::sync::Arc;
21
22use async_trait::async_trait;
23use datafusion::catalog::{
24    CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
25};
26use datafusion::common::DataFusionError;
27use datafusion::datasource::TableProvider;
28use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
29use pyo3::exceptions::PyKeyError;
30use pyo3::prelude::*;
31use pyo3::types::PyCapsule;
32use pyo3::IntoPyObjectExt;
33
34use crate::dataset::Dataset;
35use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
36use crate::table::PyTable;
37use crate::utils::{validate_pycapsule, wait_for_future};
38
39#[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)]
40#[derive(Clone)]
41pub struct PyCatalog {
42    pub catalog: Arc<dyn CatalogProvider>,
43}
44
45#[pyclass(frozen, name = "RawSchema", module = "datafusion.catalog", subclass)]
46#[derive(Clone)]
47pub struct PySchema {
48    pub schema: Arc<dyn SchemaProvider>,
49}
50
51impl From<Arc<dyn CatalogProvider>> for PyCatalog {
52    fn from(catalog: Arc<dyn CatalogProvider>) -> Self {
53        Self { catalog }
54    }
55}
56
57impl From<Arc<dyn SchemaProvider>> for PySchema {
58    fn from(schema: Arc<dyn SchemaProvider>) -> Self {
59        Self { schema }
60    }
61}
62
63#[pymethods]
64impl PyCatalog {
65    #[new]
66    fn new(catalog: Py<PyAny>) -> Self {
67        let catalog_provider =
68            Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc<dyn CatalogProvider>;
69        catalog_provider.into()
70    }
71
72    #[staticmethod]
73    fn memory_catalog() -> Self {
74        let catalog_provider =
75            Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
76        catalog_provider.into()
77    }
78
79    fn schema_names(&self) -> HashSet<String> {
80        self.catalog.schema_names().into_iter().collect()
81    }
82
83    #[pyo3(signature = (name="public"))]
84    fn schema(&self, name: &str) -> PyResult<Py<PyAny>> {
85        let schema = self
86            .catalog
87            .schema(name)
88            .ok_or(PyKeyError::new_err(format!(
89                "Schema with name {name} doesn't exist."
90            )))?;
91
92        Python::attach(|py| {
93            match schema
94                .as_any()
95                .downcast_ref::<RustWrappedPySchemaProvider>()
96            {
97                Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)),
98                None => PySchema::from(schema).into_py_any(py),
99            }
100        })
101    }
102
103    fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
104        let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? {
105            let capsule = schema_provider
106                .getattr("__datafusion_schema_provider__")?
107                .call0()?;
108            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
109            validate_pycapsule(capsule, "datafusion_schema_provider")?;
110
111            let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
112            let provider: ForeignSchemaProvider = provider.into();
113            Arc::new(provider) as Arc<dyn SchemaProvider>
114        } else {
115            match schema_provider.extract::<PySchema>() {
116                Ok(py_schema) => py_schema.schema,
117                Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
118                    as Arc<dyn SchemaProvider>,
119            }
120        };
121
122        let _ = self
123            .catalog
124            .register_schema(name, provider)
125            .map_err(py_datafusion_err)?;
126
127        Ok(())
128    }
129
130    fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> {
131        let _ = self
132            .catalog
133            .deregister_schema(name, cascade)
134            .map_err(py_datafusion_err)?;
135
136        Ok(())
137    }
138
139    fn __repr__(&self) -> PyResult<String> {
140        let mut names: Vec<String> = self.schema_names().into_iter().collect();
141        names.sort();
142        Ok(format!("Catalog(schema_names=[{}])", names.join(", ")))
143    }
144}
145
146#[pymethods]
147impl PySchema {
148    #[new]
149    fn new(schema_provider: Py<PyAny>) -> Self {
150        let schema_provider =
151            Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc<dyn SchemaProvider>;
152        schema_provider.into()
153    }
154
155    #[staticmethod]
156    fn memory_schema() -> Self {
157        let schema_provider = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
158        schema_provider.into()
159    }
160
161    #[getter]
162    fn table_names(&self) -> HashSet<String> {
163        self.schema.table_names().into_iter().collect()
164    }
165
166    fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
167        if let Some(table) = wait_for_future(py, self.schema.table(name))?? {
168            Ok(PyTable::from(table))
169        } else {
170            Err(PyDataFusionError::Common(format!(
171                "Table not found: {name}"
172            )))
173        }
174    }
175
176    fn __repr__(&self) -> PyResult<String> {
177        let mut names: Vec<String> = self.table_names().into_iter().collect();
178        names.sort();
179        Ok(format!("Schema(table_names=[{}])", names.join(";")))
180    }
181
182    fn register_table(&self, name: &str, table_provider: &Bound<'_, PyAny>) -> PyResult<()> {
183        let table = PyTable::new(table_provider)?;
184
185        let _ = self
186            .schema
187            .register_table(name.to_string(), table.table)
188            .map_err(py_datafusion_err)?;
189
190        Ok(())
191    }
192
193    fn deregister_table(&self, name: &str) -> PyResult<()> {
194        let _ = self
195            .schema
196            .deregister_table(name)
197            .map_err(py_datafusion_err)?;
198
199        Ok(())
200    }
201}
202
203#[derive(Debug)]
204pub(crate) struct RustWrappedPySchemaProvider {
205    schema_provider: Py<PyAny>,
206    owner_name: Option<String>,
207}
208
209impl RustWrappedPySchemaProvider {
210    pub fn new(schema_provider: Py<PyAny>) -> Self {
211        let owner_name = Python::attach(|py| {
212            schema_provider
213                .bind(py)
214                .getattr("owner_name")
215                .ok()
216                .map(|name| name.to_string())
217        });
218
219        Self {
220            schema_provider,
221            owner_name,
222        }
223    }
224
225    fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
226        Python::attach(|py| {
227            let provider = self.schema_provider.bind(py);
228            let py_table_method = provider.getattr("table")?;
229
230            let py_table = py_table_method.call((name,), None)?;
231            if py_table.is_none() {
232                return Ok(None);
233            }
234
235            let table = PyTable::new(&py_table)?;
236
237            Ok(Some(table.table))
238        })
239    }
240}
241
242#[async_trait]
243impl SchemaProvider for RustWrappedPySchemaProvider {
244    fn owner_name(&self) -> Option<&str> {
245        self.owner_name.as_deref()
246    }
247
248    fn as_any(&self) -> &dyn Any {
249        self
250    }
251
252    fn table_names(&self) -> Vec<String> {
253        Python::attach(|py| {
254            let provider = self.schema_provider.bind(py);
255
256            provider
257                .getattr("table_names")
258                .and_then(|names| names.extract::<Vec<String>>())
259                .unwrap_or_else(|err| {
260                    log::error!("Unable to get table_names: {err}");
261                    Vec::default()
262                })
263        })
264    }
265
266    async fn table(
267        &self,
268        name: &str,
269    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
270        self.table_inner(name).map_err(to_datafusion_err)
271    }
272
273    fn register_table(
274        &self,
275        name: String,
276        table: Arc<dyn TableProvider>,
277    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
278        let py_table = PyTable::from(table);
279        Python::attach(|py| {
280            let provider = self.schema_provider.bind(py);
281            let _ = provider
282                .call_method1("register_table", (name, py_table))
283                .map_err(to_datafusion_err)?;
284            // Since the definition of `register_table` says that an error
285            // will be returned if the table already exists, there is no
286            // case where we want to return a table provider as output.
287            Ok(None)
288        })
289    }
290
291    fn deregister_table(
292        &self,
293        name: &str,
294    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
295        Python::attach(|py| {
296            let provider = self.schema_provider.bind(py);
297            let table = provider
298                .call_method1("deregister_table", (name,))
299                .map_err(to_datafusion_err)?;
300            if table.is_none() {
301                return Ok(None);
302            }
303
304            // If we can turn this table provider into a `Dataset`, return it.
305            // Otherwise, return None.
306            let dataset = match Dataset::new(&table, py) {
307                Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
308                Err(_) => None,
309            };
310
311            Ok(dataset)
312        })
313    }
314
315    fn table_exist(&self, name: &str) -> bool {
316        Python::attach(|py| {
317            let provider = self.schema_provider.bind(py);
318            provider
319                .call_method1("table_exist", (name,))
320                .and_then(|pyobj| pyobj.extract())
321                .unwrap_or(false)
322        })
323    }
324}
325
326#[derive(Debug)]
327pub(crate) struct RustWrappedPyCatalogProvider {
328    pub(crate) catalog_provider: Py<PyAny>,
329}
330
331impl RustWrappedPyCatalogProvider {
332    pub fn new(catalog_provider: Py<PyAny>) -> Self {
333        Self { catalog_provider }
334    }
335
336    fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> {
337        Python::attach(|py| {
338            let provider = self.catalog_provider.bind(py);
339
340            let py_schema = provider.call_method1("schema", (name,))?;
341            if py_schema.is_none() {
342                return Ok(None);
343            }
344
345            if py_schema.hasattr("__datafusion_schema_provider__")? {
346                let capsule = provider
347                    .getattr("__datafusion_schema_provider__")?
348                    .call0()?;
349                let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
350                validate_pycapsule(capsule, "datafusion_schema_provider")?;
351
352                let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
353                let provider: ForeignSchemaProvider = provider.into();
354
355                Ok(Some(Arc::new(provider) as Arc<dyn SchemaProvider>))
356            } else {
357                if let Ok(inner_schema) = py_schema.getattr("schema") {
358                    if let Ok(inner_schema) = inner_schema.extract::<PySchema>() {
359                        return Ok(Some(inner_schema.schema));
360                    }
361                }
362                match py_schema.extract::<PySchema>() {
363                    Ok(inner_schema) => Ok(Some(inner_schema.schema)),
364                    Err(_) => {
365                        let py_schema = RustWrappedPySchemaProvider::new(py_schema.into());
366
367                        Ok(Some(Arc::new(py_schema) as Arc<dyn SchemaProvider>))
368                    }
369                }
370            }
371        })
372    }
373}
374
375#[async_trait]
376impl CatalogProvider for RustWrappedPyCatalogProvider {
377    fn as_any(&self) -> &dyn Any {
378        self
379    }
380
381    fn schema_names(&self) -> Vec<String> {
382        Python::attach(|py| {
383            let provider = self.catalog_provider.bind(py);
384            provider
385                .getattr("schema_names")
386                .and_then(|names| names.extract::<Vec<String>>())
387                .unwrap_or_else(|err| {
388                    log::error!("Unable to get schema_names: {err}");
389                    Vec::default()
390                })
391        })
392    }
393
394    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
395        self.schema_inner(name).unwrap_or_else(|err| {
396            log::error!("CatalogProvider schema returned error: {err}");
397            None
398        })
399    }
400
401    fn register_schema(
402        &self,
403        name: &str,
404        schema: Arc<dyn SchemaProvider>,
405    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
406        // JRIGHT HERE
407        // let py_schema: PySchema = schema.into();
408        Python::attach(|py| {
409            let py_schema = match schema
410                .as_any()
411                .downcast_ref::<RustWrappedPySchemaProvider>()
412            {
413                Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(),
414                None => &PySchema::from(schema)
415                    .into_py_any(py)
416                    .map_err(to_datafusion_err)?,
417            };
418
419            let provider = self.catalog_provider.bind(py);
420            let schema = provider
421                .call_method1("register_schema", (name, py_schema))
422                .map_err(to_datafusion_err)?;
423            if schema.is_none() {
424                return Ok(None);
425            }
426
427            let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
428                as Arc<dyn SchemaProvider>;
429
430            Ok(Some(schema))
431        })
432    }
433
434    fn deregister_schema(
435        &self,
436        name: &str,
437        cascade: bool,
438    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
439        Python::attach(|py| {
440            let provider = self.catalog_provider.bind(py);
441            let schema = provider
442                .call_method1("deregister_schema", (name, cascade))
443                .map_err(to_datafusion_err)?;
444            if schema.is_none() {
445                return Ok(None);
446            }
447
448            let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
449                as Arc<dyn SchemaProvider>;
450
451            Ok(Some(schema))
452        })
453    }
454}
455
456pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
457    m.add_class::<PyCatalog>()?;
458    m.add_class::<PySchema>()?;
459    m.add_class::<PyTable>()?;
460
461    Ok(())
462}