datafusion_federation/sql/
schema.rs

1use std::{any::Any, sync::Arc};
2
3use async_trait::async_trait;
4use datafusion::logical_expr::{TableSource, TableType};
5use datafusion::{
6    arrow::datatypes::SchemaRef, catalog::SchemaProvider, datasource::TableProvider, error::Result,
7};
8use futures::future::join_all;
9
10use crate::{
11    sql::SQLFederationProvider, FederatedTableProviderAdaptor, FederatedTableSource,
12    FederationProvider,
13};
14
/// Schema provider backed by a fixed list of [`SQLTableSource`]s.
#[derive(Debug)]
pub struct SQLSchemaProvider {
    /// Table sources exposed through this schema.
    tables: Vec<Arc<SQLTableSource>>,
}
20
21impl SQLSchemaProvider {
22    pub async fn new(provider: Arc<SQLFederationProvider>) -> Result<Self> {
23        let tables = Arc::clone(&provider).executor.table_names().await?;
24
25        Self::new_with_tables(provider, tables).await
26    }
27
28    pub async fn new_with_tables(
29        provider: Arc<SQLFederationProvider>,
30        tables: Vec<String>,
31    ) -> Result<Self> {
32        let futures: Vec<_> = tables
33            .into_iter()
34            .map(|t| SQLTableSource::new(Arc::clone(&provider), t))
35            .collect();
36        let results: Result<Vec<_>> = join_all(futures).await.into_iter().collect();
37        let sources = results?.into_iter().map(Arc::new).collect();
38        Ok(Self::new_with_table_sources(sources))
39    }
40
41    pub fn new_with_table_sources(tables: Vec<Arc<SQLTableSource>>) -> Self {
42        Self { tables }
43    }
44}
45
46#[async_trait]
47impl SchemaProvider for SQLSchemaProvider {
48    fn as_any(&self) -> &dyn Any {
49        self
50    }
51
52    fn table_names(&self) -> Vec<String> {
53        self.tables.iter().map(|s| s.table_name.clone()).collect()
54    }
55
56    async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
57        if let Some(source) = self
58            .tables
59            .iter()
60            .find(|s| s.table_name.eq_ignore_ascii_case(name))
61        {
62            let adaptor = FederatedTableProviderAdaptor::new(
63                Arc::clone(source) as Arc<dyn FederatedTableSource>
64            );
65            return Ok(Some(Arc::new(adaptor)));
66        }
67        Ok(None)
68    }
69
70    fn table_exist(&self, name: &str) -> bool {
71        self.tables
72            .iter()
73            .any(|s| s.table_name.eq_ignore_ascii_case(name))
74    }
75}
76
/// Schema provider that combines several child [`SchemaProvider`]s.
#[derive(Debug)]
pub struct MultiSchemaProvider {
    /// Child providers, consulted in order.
    children: Vec<Arc<dyn SchemaProvider>>,
}
81
82impl MultiSchemaProvider {
83    pub fn new(children: Vec<Arc<dyn SchemaProvider>>) -> Self {
84        Self { children }
85    }
86}
87
88#[async_trait]
89impl SchemaProvider for MultiSchemaProvider {
90    fn as_any(&self) -> &dyn Any {
91        self
92    }
93
94    fn table_names(&self) -> Vec<String> {
95        self.children.iter().flat_map(|p| p.table_names()).collect()
96    }
97
98    async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
99        for child in &self.children {
100            if let Ok(Some(table)) = child.table(name).await {
101                return Ok(Some(table));
102            }
103        }
104        Ok(None)
105    }
106
107    fn table_exist(&self, name: &str) -> bool {
108        self.children.iter().any(|p| p.table_exist(name))
109    }
110}
111
/// A single table, identified by name, served by a [`SQLFederationProvider`].
#[derive(Debug)]
pub struct SQLTableSource {
    /// Federation provider this table belongs to.
    provider: Arc<SQLFederationProvider>,
    /// Name of the table.
    table_name: String,
    /// Schema of the table.
    schema: SchemaRef,
}
118
119impl SQLTableSource {
120    // creates a SQLTableSource and infers the table schema
121    pub async fn new(provider: Arc<SQLFederationProvider>, table_name: String) -> Result<Self> {
122        let schema = Arc::clone(&provider)
123            .executor
124            .get_table_schema(table_name.as_str())
125            .await?;
126        Self::new_with_schema(provider, table_name, schema)
127    }
128
129    pub fn new_with_schema(
130        provider: Arc<SQLFederationProvider>,
131        table_name: String,
132        schema: SchemaRef,
133    ) -> Result<Self> {
134        Ok(Self {
135            provider,
136            table_name,
137            schema,
138        })
139    }
140
141    pub fn table_name(&self) -> &str {
142        self.table_name.as_str()
143    }
144}
145
146impl FederatedTableSource for SQLTableSource {
147    fn federation_provider(&self) -> Arc<dyn FederationProvider> {
148        Arc::clone(&self.provider) as Arc<dyn FederationProvider>
149    }
150}
151
152impl TableSource for SQLTableSource {
153    fn as_any(&self) -> &dyn Any {
154        self
155    }
156    fn schema(&self) -> SchemaRef {
157        Arc::clone(&self.schema)
158    }
159    fn table_type(&self) -> TableType {
160        TableType::Temporary
161    }
162}