datafusion_table_providers/sqlite/
federation.rs

1use crate::sql::db_connection_pool::dbconnection::{get_schema, Error as DbError};
2use crate::sql::sql_provider_datafusion::{get_stream, to_execution_error};
3use async_trait::async_trait;
4use datafusion::arrow::datatypes::SchemaRef;
5use datafusion::sql::sqlparser::ast::{self, VisitMut};
6use datafusion::sql::unparser::dialect::Dialect;
7use datafusion_federation::sql::{
8    AstAnalyzer, RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource,
9};
10use datafusion_federation::{FederatedTableProviderAdaptor, FederatedTableSource};
11use futures::TryStreamExt;
12use snafu::ResultExt;
13use std::sync::Arc;
14
15use super::sql_table::SQLiteTable;
16use super::sqlite_interval::SQLiteIntervalVisitor;
17use datafusion::{
18    datasource::TableProvider,
19    error::{DataFusionError, Result as DataFusionResult},
20    execution::SendableRecordBatchStream,
21    physical_plan::stream::RecordBatchStreamAdapter,
22    sql::TableReference,
23};
24
25impl<T, P> SQLiteTable<T, P> {
26    fn create_federated_table_source(
27        self: Arc<Self>,
28    ) -> DataFusionResult<Arc<dyn FederatedTableSource>> {
29        let table_reference = self.base_table.table_reference.clone();
30        let schema = Arc::clone(&Arc::clone(&self).base_table.schema());
31        let fed_provider = Arc::new(SQLFederationProvider::new(self));
32        Ok(Arc::new(SQLTableSource::new_with_schema(
33            fed_provider,
34            RemoteTableRef::from(table_reference),
35            schema,
36        )))
37    }
38
39    pub fn create_federated_table_provider(
40        self: Arc<Self>,
41    ) -> DataFusionResult<FederatedTableProviderAdaptor> {
42        let table_source = Self::create_federated_table_source(Arc::clone(&self))?;
43        Ok(FederatedTableProviderAdaptor::new_with_provider(
44            table_source,
45            self,
46        ))
47    }
48}
49
50#[allow(clippy::unnecessary_wraps)]
51fn sqlite_ast_analyzer(ast: ast::Statement) -> Result<ast::Statement, DataFusionError> {
52    match ast {
53        ast::Statement::Query(query) => {
54            let mut new_query = query.clone();
55
56            // iterate over the query and find any INTERVAL statements
57            // find the column they target, and replace the INTERVAL and column with e.g. datetime(column, '+1 day')
58            let mut interval_visitor = SQLiteIntervalVisitor::default();
59            new_query.visit(&mut interval_visitor);
60
61            Ok(ast::Statement::Query(new_query))
62        }
63        _ => Ok(ast),
64    }
65}
66
67#[async_trait]
68impl<T, P> SQLExecutor for SQLiteTable<T, P> {
69    fn name(&self) -> &str {
70        self.base_table.name()
71    }
72
73    fn compute_context(&self) -> Option<String> {
74        self.base_table.compute_context()
75    }
76
77    fn dialect(&self) -> Arc<dyn Dialect> {
78        self.base_table.dialect()
79    }
80
81    fn ast_analyzer(&self) -> Option<AstAnalyzer> {
82        Some(Box::new(sqlite_ast_analyzer))
83    }
84
85    fn execute(
86        &self,
87        query: &str,
88        schema: SchemaRef,
89    ) -> DataFusionResult<SendableRecordBatchStream> {
90        let fut = get_stream(
91            self.base_table.clone_pool(),
92            query.to_string(),
93            Arc::clone(&schema),
94        );
95
96        let stream = futures::stream::once(fut).try_flatten();
97        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
98    }
99
100    async fn table_names(&self) -> DataFusionResult<Vec<String>> {
101        Err(DataFusionError::NotImplemented(
102            "table inference not implemented".to_string(),
103        ))
104    }
105
106    async fn get_table_schema(&self, table_name: &str) -> DataFusionResult<SchemaRef> {
107        let conn = self
108            .base_table
109            .clone_pool()
110            .connect()
111            .await
112            .map_err(to_execution_error)?;
113        get_schema(conn, &TableReference::from(table_name))
114            .await
115            .boxed()
116            .map_err(|e| DbError::UnableToGetSchema { source: e })
117            .map_err(to_execution_error)
118    }
119}