datafusion_table_providers/sql/sql_provider_datafusion/federation.rs

1use crate::sql::db_connection_pool::{dbconnection::get_schema, JoinPushDown};
2use async_trait::async_trait;
3use datafusion_federation::sql::{
4    RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource,
5};
6use datafusion_federation::{FederatedTableProviderAdaptor, FederatedTableSource};
7use futures::TryStreamExt;
8use snafu::prelude::*;
9use std::sync::Arc;
10
11use crate::sql::sql_provider_datafusion::{
12    get_stream, to_execution_error, SqlTable, UnableToGetSchemaSnafu,
13};
14use datafusion::{
15    arrow::datatypes::SchemaRef,
16    error::{DataFusionError, Result as DataFusionResult},
17    physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream},
18    sql::{
19        unparser::dialect::{DefaultDialect, Dialect},
20        TableReference,
21    },
22};
23
24impl<T, P> SqlTable<T, P> {
25    // Return the current memory location of the object as a unique identifier
26    fn unique_id(&self) -> usize {
27        std::ptr::from_ref(self) as usize
28    }
29
30    fn arc_dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
31        match &self.dialect {
32            Some(dialect) => Arc::clone(dialect),
33            None => Arc::new(DefaultDialect {}),
34        }
35    }
36
37    fn create_federated_table_source(
38        self: Arc<Self>,
39    ) -> DataFusionResult<Arc<dyn FederatedTableSource>> {
40        let table_reference = self.table_reference.clone();
41        let schema = Arc::clone(&self.schema);
42        let fed_provider = Arc::new(SQLFederationProvider::new(self));
43        Ok(Arc::new(SQLTableSource::new_with_schema(
44            fed_provider,
45            RemoteTableRef::from(table_reference),
46            schema,
47        )))
48    }
49
50    pub fn create_federated_table_provider(
51        self: Arc<Self>,
52    ) -> DataFusionResult<FederatedTableProviderAdaptor> {
53        let table_source = Self::create_federated_table_source(Arc::clone(&self))?;
54        Ok(FederatedTableProviderAdaptor::new_with_provider(
55            table_source,
56            self,
57        ))
58    }
59}
60
61#[async_trait]
62impl<T, P> SQLExecutor for SqlTable<T, P> {
63    fn name(&self) -> &str {
64        &self.name
65    }
66
67    fn compute_context(&self) -> Option<String> {
68        match self.pool.join_push_down() {
69            JoinPushDown::AllowedFor(context) => Some(context),
70            // Don't return None here - it will cause incorrect federation with other providers of the same name that also have a compute_context of None.
71            // Instead return a random string that will never match any other provider's context.
72            JoinPushDown::Disallow => Some(format!("{}", self.unique_id())),
73        }
74    }
75
76    fn dialect(&self) -> Arc<dyn Dialect> {
77        self.arc_dialect()
78    }
79
80    fn execute(
81        &self,
82        query: &str,
83        schema: SchemaRef,
84    ) -> DataFusionResult<SendableRecordBatchStream> {
85        let fut = get_stream(
86            Arc::clone(&self.pool),
87            query.to_string(),
88            Arc::clone(&schema),
89        );
90
91        let stream = futures::stream::once(fut).try_flatten();
92        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
93    }
94
95    async fn table_names(&self) -> DataFusionResult<Vec<String>> {
96        Err(DataFusionError::NotImplemented(
97            "table inference not implemented".to_string(),
98        ))
99    }
100
101    async fn get_table_schema(&self, table_name: &str) -> DataFusionResult<SchemaRef> {
102        let conn = self.pool.connect().await.map_err(to_execution_error)?;
103        get_schema(conn, &TableReference::from(table_name))
104            .await
105            .context(UnableToGetSchemaSnafu)
106            .map_err(to_execution_error)
107    }
108}