datafusion_remote_table/table.rs

1use crate::{
2    ConnectionOptions, DFResult, Pool, RemoteSchemaRef, RemoteTableExec, Transform, connect,
3    transform_schema,
4};
5use datafusion::arrow::datatypes::SchemaRef;
6use datafusion::catalog::{Session, TableProvider};
7use datafusion::datasource::TableType;
8use datafusion::error::DataFusionError;
9use datafusion::logical_expr::Expr;
10use datafusion::physical_plan::ExecutionPlan;
11use std::any::Any;
12use std::sync::Arc;
13
14#[derive(Debug)]
15pub struct RemoteTable {
16    pub(crate) conn_options: ConnectionOptions,
17    pub(crate) sql: String,
18    pub(crate) table_schema: SchemaRef,
19    pub(crate) transformed_table_schema: SchemaRef,
20    pub(crate) remote_schema: Option<RemoteSchemaRef>,
21    pub(crate) transform: Option<Arc<dyn Transform>>,
22    pub(crate) pool: Arc<dyn Pool>,
23}
24
25impl RemoteTable {
26    pub async fn try_new(
27        conn_options: ConnectionOptions,
28        sql: impl Into<String>,
29    ) -> DFResult<Self> {
30        Self::try_new_with_schema_transform(conn_options, sql, None, None).await
31    }
32
33    pub async fn try_new_with_schema(
34        conn_options: ConnectionOptions,
35        sql: impl Into<String>,
36        table_schema: SchemaRef,
37    ) -> DFResult<Self> {
38        Self::try_new_with_schema_transform(conn_options, sql, Some(table_schema), None).await
39    }
40
41    pub async fn try_new_with_transform(
42        conn_options: ConnectionOptions,
43        sql: impl Into<String>,
44        transform: Arc<dyn Transform>,
45    ) -> DFResult<Self> {
46        Self::try_new_with_schema_transform(conn_options, sql, None, Some(transform)).await
47    }
48
49    pub async fn try_new_with_schema_transform(
50        conn_options: ConnectionOptions,
51        sql: impl Into<String>,
52        table_schema: Option<SchemaRef>,
53        transform: Option<Arc<dyn Transform>>,
54    ) -> DFResult<Self> {
55        let sql = sql.into();
56        let pool = connect(&conn_options).await?;
57        let conn = pool.get().await?;
58        let (table_schema, remote_schema) = match conn.infer_schema(&sql).await {
59            Ok((remote_schema, inferred_table_schema)) => (
60                table_schema.unwrap_or(inferred_table_schema),
61                Some(remote_schema),
62            ),
63            Err(e) => {
64                if let Some(table_schema) = table_schema {
65                    (table_schema, None)
66                } else {
67                    return Err(DataFusionError::Execution(format!(
68                        "Failed to infer schema: {e}"
69                    )));
70                }
71            }
72        };
73        let transformed_table_schema = transform_schema(
74            table_schema.clone(),
75            transform.as_ref(),
76            remote_schema.as_ref(),
77        )?;
78        Ok(RemoteTable {
79            conn_options,
80            sql,
81            table_schema,
82            transformed_table_schema,
83            remote_schema,
84            transform,
85            pool,
86        })
87    }
88
89    pub fn remote_schema(&self) -> Option<RemoteSchemaRef> {
90        self.remote_schema.clone()
91    }
92}
93
94#[async_trait::async_trait]
95impl TableProvider for RemoteTable {
96    fn as_any(&self) -> &dyn Any {
97        self
98    }
99
100    fn schema(&self) -> SchemaRef {
101        self.transformed_table_schema.clone()
102    }
103
104    fn table_type(&self) -> TableType {
105        TableType::View
106    }
107
108    async fn scan(
109        &self,
110        _state: &dyn Session,
111        projection: Option<&Vec<usize>>,
112        _filters: &[Expr],
113        limit: Option<usize>,
114    ) -> DFResult<Arc<dyn ExecutionPlan>> {
115        // TODO support filter pushdown
116        Ok(Arc::new(RemoteTableExec::try_new(
117            self.conn_options.clone(),
118            self.sql.clone(),
119            self.table_schema.clone(),
120            self.remote_schema.clone(),
121            projection.cloned(),
122            limit,
123            self.transform.clone(),
124            self.pool.get().await?,
125        )?))
126    }
127}