datafusion_table_providers/mysql/federation.rs

1use crate::sql::db_connection_pool::dbconnection::{get_schema, Error as DbError};
2use crate::sql::sql_provider_datafusion::{get_stream, to_execution_error};
3use arrow::datatypes::SchemaRef;
4use async_trait::async_trait;
5use datafusion::sql::sqlparser::ast::{self, VisitMut};
6use datafusion::sql::unparser::dialect::Dialect;
7use datafusion_federation::sql::{
8    AstAnalyzer, RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource,
9};
10use datafusion_federation::{FederatedTableProviderAdaptor, FederatedTableSource};
11use futures::TryStreamExt;
12use snafu::ResultExt;
13use std::sync::Arc;
14
15use super::mysql_window::MySQLWindowVisitor;
16use super::sql_table::MySQLTable;
17use datafusion::{
18    datasource::TableProvider,
19    error::{DataFusionError, Result as DataFusionResult},
20    execution::SendableRecordBatchStream,
21    physical_plan::stream::RecordBatchStreamAdapter,
22    sql::TableReference,
23};
24
25impl MySQLTable {
26    fn create_federated_table_source(
27        self: Arc<Self>,
28    ) -> DataFusionResult<Arc<dyn FederatedTableSource>> {
29        let table_reference = self.base_table.table_reference.clone();
30        let schema = Arc::clone(&Arc::clone(&self).base_table.schema());
31        let fed_provider = Arc::new(SQLFederationProvider::new(self));
32        Ok(Arc::new(SQLTableSource::new_with_schema(
33            fed_provider,
34            RemoteTableRef::from(table_reference),
35            schema,
36        )))
37    }
38
39    pub fn create_federated_table_provider(
40        self: Arc<Self>,
41    ) -> DataFusionResult<FederatedTableProviderAdaptor> {
42        let table_source = Self::create_federated_table_source(Arc::clone(&self))?;
43        Ok(FederatedTableProviderAdaptor::new_with_provider(
44            table_source,
45            self,
46        ))
47    }
48}
49
50#[allow(clippy::unnecessary_wraps)]
51fn mysql_ast_analyzer(ast: ast::Statement) -> Result<ast::Statement, DataFusionError> {
52    match ast {
53        ast::Statement::Query(query) => {
54            let mut new_query = query.clone();
55
56            let mut window_visitor = MySQLWindowVisitor::default();
57            new_query.visit(&mut window_visitor);
58
59            Ok(ast::Statement::Query(new_query))
60        }
61        _ => Ok(ast),
62    }
63}
64
65#[async_trait]
66impl SQLExecutor for MySQLTable {
67    fn name(&self) -> &str {
68        self.base_table.name()
69    }
70
71    fn compute_context(&self) -> Option<String> {
72        self.base_table.compute_context()
73    }
74
75    fn dialect(&self) -> Arc<dyn Dialect> {
76        self.base_table.dialect()
77    }
78
79    fn ast_analyzer(&self) -> Option<AstAnalyzer> {
80        Some(Box::new(mysql_ast_analyzer))
81    }
82
83    fn execute(
84        &self,
85        query: &str,
86        schema: SchemaRef,
87    ) -> DataFusionResult<SendableRecordBatchStream> {
88        let fut = get_stream(
89            self.base_table.clone_pool(),
90            query.to_string(),
91            Arc::clone(&schema),
92        );
93
94        let stream = futures::stream::once(fut).try_flatten();
95        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
96    }
97
98    async fn table_names(&self) -> DataFusionResult<Vec<String>> {
99        Err(DataFusionError::NotImplemented(
100            "table inference not implemented".to_string(),
101        ))
102    }
103
104    async fn get_table_schema(&self, table_name: &str) -> DataFusionResult<SchemaRef> {
105        let conn = self
106            .base_table
107            .clone_pool()
108            .connect()
109            .await
110            .map_err(to_execution_error)?;
111        get_schema(conn, &TableReference::from(table_name))
112            .await
113            .boxed()
114            .map_err(|e| DbError::UnableToGetSchema { source: e })
115            .map_err(to_execution_error)
116    }
117}