datafusion_table_providers/sql/sql_provider_datafusion/
federation.rs1use crate::sql::db_connection_pool::{dbconnection::get_schema, JoinPushDown};
2use async_trait::async_trait;
3use datafusion_federation::sql::{
4 RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource,
5};
6use datafusion_federation::{FederatedTableProviderAdaptor, FederatedTableSource};
7use futures::TryStreamExt;
8use snafu::prelude::*;
9use std::sync::Arc;
10
11use crate::sql::sql_provider_datafusion::{
12 get_stream, to_execution_error, SqlTable, UnableToGetSchemaSnafu,
13};
14use datafusion::{
15 arrow::datatypes::SchemaRef,
16 error::{DataFusionError, Result as DataFusionResult},
17 physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream},
18 sql::{
19 unparser::dialect::{DefaultDialect, Dialect},
20 TableReference,
21 },
22};
23
24impl<T, P> SqlTable<T, P> {
25 fn unique_id(&self) -> usize {
27 std::ptr::from_ref(self) as usize
28 }
29
30 fn arc_dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
31 match &self.dialect {
32 Some(dialect) => Arc::clone(dialect),
33 None => Arc::new(DefaultDialect {}),
34 }
35 }
36
37 fn create_federated_table_source(
38 self: Arc<Self>,
39 ) -> DataFusionResult<Arc<dyn FederatedTableSource>> {
40 let table_reference = self.table_reference.clone();
41 let schema = Arc::clone(&self.schema);
42 let fed_provider = Arc::new(SQLFederationProvider::new(self));
43 Ok(Arc::new(SQLTableSource::new_with_schema(
44 fed_provider,
45 RemoteTableRef::from(table_reference),
46 schema,
47 )))
48 }
49
50 pub fn create_federated_table_provider(
51 self: Arc<Self>,
52 ) -> DataFusionResult<FederatedTableProviderAdaptor> {
53 let table_source = Self::create_federated_table_source(Arc::clone(&self))?;
54 Ok(FederatedTableProviderAdaptor::new_with_provider(
55 table_source,
56 self,
57 ))
58 }
59}
60
61#[async_trait]
62impl<T, P> SQLExecutor for SqlTable<T, P> {
63 fn name(&self) -> &str {
64 &self.name
65 }
66
67 fn compute_context(&self) -> Option<String> {
68 match self.pool.join_push_down() {
69 JoinPushDown::AllowedFor(context) => Some(context),
70 JoinPushDown::Disallow => Some(format!("{}", self.unique_id())),
73 }
74 }
75
76 fn dialect(&self) -> Arc<dyn Dialect> {
77 self.arc_dialect()
78 }
79
80 fn execute(
81 &self,
82 query: &str,
83 schema: SchemaRef,
84 ) -> DataFusionResult<SendableRecordBatchStream> {
85 let fut = get_stream(
86 Arc::clone(&self.pool),
87 query.to_string(),
88 Arc::clone(&schema),
89 );
90
91 let stream = futures::stream::once(fut).try_flatten();
92 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
93 }
94
95 async fn table_names(&self) -> DataFusionResult<Vec<String>> {
96 Err(DataFusionError::NotImplemented(
97 "table inference not implemented".to_string(),
98 ))
99 }
100
101 async fn get_table_schema(&self, table_name: &str) -> DataFusionResult<SchemaRef> {
102 let conn = self.pool.connect().await.map_err(to_execution_error)?;
103 get_schema(conn, &TableReference::from(table_name))
104 .await
105 .context(UnableToGetSchemaSnafu)
106 .map_err(to_execution_error)
107 }
108}