1use std::any::Any;
19use std::collections::HashSet;
20use std::sync::Arc;
21
22use async_trait::async_trait;
23use datafusion::catalog::{
24 CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
25};
26use datafusion::common::DataFusionError;
27use datafusion::datasource::TableProvider;
28use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
29use pyo3::exceptions::PyKeyError;
30use pyo3::prelude::*;
31use pyo3::types::PyCapsule;
32use pyo3::IntoPyObjectExt;
33
34use crate::dataset::Dataset;
35use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
36use crate::table::PyTable;
37use crate::utils::{validate_pycapsule, wait_for_future};
38
39#[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)]
40#[derive(Clone)]
41pub struct PyCatalog {
42 pub catalog: Arc<dyn CatalogProvider>,
43}
44
45#[pyclass(frozen, name = "RawSchema", module = "datafusion.catalog", subclass)]
46#[derive(Clone)]
47pub struct PySchema {
48 pub schema: Arc<dyn SchemaProvider>,
49}
50
51impl From<Arc<dyn CatalogProvider>> for PyCatalog {
52 fn from(catalog: Arc<dyn CatalogProvider>) -> Self {
53 Self { catalog }
54 }
55}
56
57impl From<Arc<dyn SchemaProvider>> for PySchema {
58 fn from(schema: Arc<dyn SchemaProvider>) -> Self {
59 Self { schema }
60 }
61}
62
63#[pymethods]
64impl PyCatalog {
65 #[new]
66 fn new(catalog: Py<PyAny>) -> Self {
67 let catalog_provider =
68 Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc<dyn CatalogProvider>;
69 catalog_provider.into()
70 }
71
72 #[staticmethod]
73 fn memory_catalog() -> Self {
74 let catalog_provider =
75 Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
76 catalog_provider.into()
77 }
78
79 fn schema_names(&self) -> HashSet<String> {
80 self.catalog.schema_names().into_iter().collect()
81 }
82
83 #[pyo3(signature = (name="public"))]
84 fn schema(&self, name: &str) -> PyResult<Py<PyAny>> {
85 let schema = self
86 .catalog
87 .schema(name)
88 .ok_or(PyKeyError::new_err(format!(
89 "Schema with name {name} doesn't exist."
90 )))?;
91
92 Python::attach(|py| {
93 match schema
94 .as_any()
95 .downcast_ref::<RustWrappedPySchemaProvider>()
96 {
97 Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)),
98 None => PySchema::from(schema).into_py_any(py),
99 }
100 })
101 }
102
103 fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
104 let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? {
105 let capsule = schema_provider
106 .getattr("__datafusion_schema_provider__")?
107 .call0()?;
108 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
109 validate_pycapsule(capsule, "datafusion_schema_provider")?;
110
111 let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
112 let provider: ForeignSchemaProvider = provider.into();
113 Arc::new(provider) as Arc<dyn SchemaProvider>
114 } else {
115 match schema_provider.extract::<PySchema>() {
116 Ok(py_schema) => py_schema.schema,
117 Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
118 as Arc<dyn SchemaProvider>,
119 }
120 };
121
122 let _ = self
123 .catalog
124 .register_schema(name, provider)
125 .map_err(py_datafusion_err)?;
126
127 Ok(())
128 }
129
130 fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> {
131 let _ = self
132 .catalog
133 .deregister_schema(name, cascade)
134 .map_err(py_datafusion_err)?;
135
136 Ok(())
137 }
138
139 fn __repr__(&self) -> PyResult<String> {
140 let mut names: Vec<String> = self.schema_names().into_iter().collect();
141 names.sort();
142 Ok(format!("Catalog(schema_names=[{}])", names.join(", ")))
143 }
144}
145
146#[pymethods]
147impl PySchema {
148 #[new]
149 fn new(schema_provider: Py<PyAny>) -> Self {
150 let schema_provider =
151 Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc<dyn SchemaProvider>;
152 schema_provider.into()
153 }
154
155 #[staticmethod]
156 fn memory_schema() -> Self {
157 let schema_provider = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
158 schema_provider.into()
159 }
160
161 #[getter]
162 fn table_names(&self) -> HashSet<String> {
163 self.schema.table_names().into_iter().collect()
164 }
165
166 fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
167 if let Some(table) = wait_for_future(py, self.schema.table(name))?? {
168 Ok(PyTable::from(table))
169 } else {
170 Err(PyDataFusionError::Common(format!(
171 "Table not found: {name}"
172 )))
173 }
174 }
175
176 fn __repr__(&self) -> PyResult<String> {
177 let mut names: Vec<String> = self.table_names().into_iter().collect();
178 names.sort();
179 Ok(format!("Schema(table_names=[{}])", names.join(";")))
180 }
181
182 fn register_table(&self, name: &str, table_provider: &Bound<'_, PyAny>) -> PyResult<()> {
183 let table = PyTable::new(table_provider)?;
184
185 let _ = self
186 .schema
187 .register_table(name.to_string(), table.table)
188 .map_err(py_datafusion_err)?;
189
190 Ok(())
191 }
192
193 fn deregister_table(&self, name: &str) -> PyResult<()> {
194 let _ = self
195 .schema
196 .deregister_table(name)
197 .map_err(py_datafusion_err)?;
198
199 Ok(())
200 }
201}
202
203#[derive(Debug)]
204pub(crate) struct RustWrappedPySchemaProvider {
205 schema_provider: Py<PyAny>,
206 owner_name: Option<String>,
207}
208
209impl RustWrappedPySchemaProvider {
210 pub fn new(schema_provider: Py<PyAny>) -> Self {
211 let owner_name = Python::attach(|py| {
212 schema_provider
213 .bind(py)
214 .getattr("owner_name")
215 .ok()
216 .map(|name| name.to_string())
217 });
218
219 Self {
220 schema_provider,
221 owner_name,
222 }
223 }
224
225 fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
226 Python::attach(|py| {
227 let provider = self.schema_provider.bind(py);
228 let py_table_method = provider.getattr("table")?;
229
230 let py_table = py_table_method.call((name,), None)?;
231 if py_table.is_none() {
232 return Ok(None);
233 }
234
235 let table = PyTable::new(&py_table)?;
236
237 Ok(Some(table.table))
238 })
239 }
240}
241
242#[async_trait]
243impl SchemaProvider for RustWrappedPySchemaProvider {
244 fn owner_name(&self) -> Option<&str> {
245 self.owner_name.as_deref()
246 }
247
248 fn as_any(&self) -> &dyn Any {
249 self
250 }
251
252 fn table_names(&self) -> Vec<String> {
253 Python::attach(|py| {
254 let provider = self.schema_provider.bind(py);
255
256 provider
257 .getattr("table_names")
258 .and_then(|names| names.extract::<Vec<String>>())
259 .unwrap_or_else(|err| {
260 log::error!("Unable to get table_names: {err}");
261 Vec::default()
262 })
263 })
264 }
265
266 async fn table(
267 &self,
268 name: &str,
269 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
270 self.table_inner(name).map_err(to_datafusion_err)
271 }
272
273 fn register_table(
274 &self,
275 name: String,
276 table: Arc<dyn TableProvider>,
277 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
278 let py_table = PyTable::from(table);
279 Python::attach(|py| {
280 let provider = self.schema_provider.bind(py);
281 let _ = provider
282 .call_method1("register_table", (name, py_table))
283 .map_err(to_datafusion_err)?;
284 Ok(None)
288 })
289 }
290
291 fn deregister_table(
292 &self,
293 name: &str,
294 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
295 Python::attach(|py| {
296 let provider = self.schema_provider.bind(py);
297 let table = provider
298 .call_method1("deregister_table", (name,))
299 .map_err(to_datafusion_err)?;
300 if table.is_none() {
301 return Ok(None);
302 }
303
304 let dataset = match Dataset::new(&table, py) {
307 Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
308 Err(_) => None,
309 };
310
311 Ok(dataset)
312 })
313 }
314
315 fn table_exist(&self, name: &str) -> bool {
316 Python::attach(|py| {
317 let provider = self.schema_provider.bind(py);
318 provider
319 .call_method1("table_exist", (name,))
320 .and_then(|pyobj| pyobj.extract())
321 .unwrap_or(false)
322 })
323 }
324}
325
326#[derive(Debug)]
327pub(crate) struct RustWrappedPyCatalogProvider {
328 pub(crate) catalog_provider: Py<PyAny>,
329}
330
331impl RustWrappedPyCatalogProvider {
332 pub fn new(catalog_provider: Py<PyAny>) -> Self {
333 Self { catalog_provider }
334 }
335
336 fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> {
337 Python::attach(|py| {
338 let provider = self.catalog_provider.bind(py);
339
340 let py_schema = provider.call_method1("schema", (name,))?;
341 if py_schema.is_none() {
342 return Ok(None);
343 }
344
345 if py_schema.hasattr("__datafusion_schema_provider__")? {
346 let capsule = provider
347 .getattr("__datafusion_schema_provider__")?
348 .call0()?;
349 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
350 validate_pycapsule(capsule, "datafusion_schema_provider")?;
351
352 let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
353 let provider: ForeignSchemaProvider = provider.into();
354
355 Ok(Some(Arc::new(provider) as Arc<dyn SchemaProvider>))
356 } else {
357 if let Ok(inner_schema) = py_schema.getattr("schema") {
358 if let Ok(inner_schema) = inner_schema.extract::<PySchema>() {
359 return Ok(Some(inner_schema.schema));
360 }
361 }
362 match py_schema.extract::<PySchema>() {
363 Ok(inner_schema) => Ok(Some(inner_schema.schema)),
364 Err(_) => {
365 let py_schema = RustWrappedPySchemaProvider::new(py_schema.into());
366
367 Ok(Some(Arc::new(py_schema) as Arc<dyn SchemaProvider>))
368 }
369 }
370 }
371 })
372 }
373}
374
375#[async_trait]
376impl CatalogProvider for RustWrappedPyCatalogProvider {
377 fn as_any(&self) -> &dyn Any {
378 self
379 }
380
381 fn schema_names(&self) -> Vec<String> {
382 Python::attach(|py| {
383 let provider = self.catalog_provider.bind(py);
384 provider
385 .getattr("schema_names")
386 .and_then(|names| names.extract::<Vec<String>>())
387 .unwrap_or_else(|err| {
388 log::error!("Unable to get schema_names: {err}");
389 Vec::default()
390 })
391 })
392 }
393
394 fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
395 self.schema_inner(name).unwrap_or_else(|err| {
396 log::error!("CatalogProvider schema returned error: {err}");
397 None
398 })
399 }
400
401 fn register_schema(
402 &self,
403 name: &str,
404 schema: Arc<dyn SchemaProvider>,
405 ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
406 Python::attach(|py| {
409 let py_schema = match schema
410 .as_any()
411 .downcast_ref::<RustWrappedPySchemaProvider>()
412 {
413 Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(),
414 None => &PySchema::from(schema)
415 .into_py_any(py)
416 .map_err(to_datafusion_err)?,
417 };
418
419 let provider = self.catalog_provider.bind(py);
420 let schema = provider
421 .call_method1("register_schema", (name, py_schema))
422 .map_err(to_datafusion_err)?;
423 if schema.is_none() {
424 return Ok(None);
425 }
426
427 let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
428 as Arc<dyn SchemaProvider>;
429
430 Ok(Some(schema))
431 })
432 }
433
434 fn deregister_schema(
435 &self,
436 name: &str,
437 cascade: bool,
438 ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
439 Python::attach(|py| {
440 let provider = self.catalog_provider.bind(py);
441 let schema = provider
442 .call_method1("deregister_schema", (name, cascade))
443 .map_err(to_datafusion_err)?;
444 if schema.is_none() {
445 return Ok(None);
446 }
447
448 let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
449 as Arc<dyn SchemaProvider>;
450
451 Ok(Some(schema))
452 })
453 }
454}
455
456pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
457 m.add_class::<PyCatalog>()?;
458 m.add_class::<PySchema>()?;
459 m.add_class::<PyTable>()?;
460
461 Ok(())
462}