// datafusion_table_providers/common.rs

1use std::{any::Any, sync::Arc};
2
3use crate::sql::db_connection_pool::dbconnection::{get_schemas, get_tables};
4use crate::sql::db_connection_pool::DbConnectionPool;
5use crate::sql::sql_provider_datafusion::SqlTable;
6use async_trait::async_trait;
7use dashmap::DashMap;
8use datafusion::error::{DataFusionError, Result as DataFusionResult};
9use datafusion::{
10    catalog::{CatalogProvider, SchemaProvider, TableProvider},
11    sql::TableReference,
12};
13
14type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
15type Pool<T, P> = Arc<dyn DbConnectionPool<T, P> + Send + Sync>;
16
17#[derive(Debug)]
18pub struct DatabaseCatalogProvider {
19    schemas: DashMap<String, Arc<dyn SchemaProvider>>,
20}
21
22impl DatabaseCatalogProvider {
23    pub async fn try_new<T: 'static, P: 'static>(pool: Pool<T, P>) -> Result<Self> {
24        let conn = pool.connect().await?;
25
26        let schemas = get_schemas(conn).await?;
27        let schema_map = DashMap::new();
28
29        for schema in schemas {
30            let provider = DatabaseSchemaProvider::try_new(schema.clone(), pool.clone()).await?;
31            schema_map.insert(schema, Arc::new(provider) as Arc<dyn SchemaProvider>);
32        }
33
34        Ok(Self {
35            schemas: schema_map,
36        })
37    }
38}
39
40impl CatalogProvider for DatabaseCatalogProvider {
41    fn as_any(&self) -> &dyn Any {
42        self
43    }
44
45    fn schema_names(&self) -> Vec<String> {
46        self.schemas.iter().map(|s| s.key().clone()).collect()
47    }
48
49    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
50        self.schemas.get(name).map(|s| s.clone())
51    }
52}
53
54pub struct DatabaseSchemaProvider<T, P> {
55    name: String,
56    tables: Vec<String>,
57    pool: Pool<T, P>,
58}
59
60impl<T, P> std::fmt::Debug for DatabaseSchemaProvider<T, P> {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        write!(f, "DatabaseSchemaProvider {{ name: {:?} }}", self.name)
63    }
64}
65
66impl<T, P: 'static> DatabaseSchemaProvider<T, P> {
67    pub async fn try_new(name: String, pool: Pool<T, P>) -> Result<Self> {
68        let conn = pool.connect().await?;
69        let tables = get_tables(conn, &name).await?;
70
71        Ok(Self { name, tables, pool })
72    }
73}
74
75#[async_trait]
76impl<T: 'static, P: 'static> SchemaProvider for DatabaseSchemaProvider<T, P> {
77    fn as_any(&self) -> &dyn Any {
78        self
79    }
80
81    fn table_names(&self) -> Vec<String> {
82        self.tables.clone()
83    }
84
85    async fn table(&self, table: &str) -> DataFusionResult<Option<Arc<dyn TableProvider>>> {
86        if self.table_exist(table) {
87            SqlTable::new(
88                &self.name,
89                &self.pool,
90                TableReference::partial(self.name.clone(), table.to_string()),
91                None,
92            )
93            .await
94            .map(|v| Some(Arc::new(v) as Arc<dyn TableProvider>))
95            .map_err(|e| DataFusionError::External(Box::new(e)))
96        } else {
97            Ok(None)
98        }
99    }
100
101    fn table_exist(&self, name: &str) -> bool {
102        self.tables.contains(&name.to_string())
103    }
104}