// datafusion_remote_table/table.rs

1use crate::{
2    ConnectionOptions, DFResult, DefaultTransform, DefaultUnparser, Pool, RemoteSchemaRef,
3    RemoteTableExec, Transform, Unparse, connect, transform_schema,
4};
5use datafusion::arrow::datatypes::SchemaRef;
6use datafusion::catalog::{Session, TableProvider};
7use datafusion::common::Column;
8use datafusion::common::tree_node::{Transformed, TreeNode};
9use datafusion::datasource::TableType;
10use datafusion::error::DataFusionError;
11use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
12use datafusion::physical_plan::ExecutionPlan;
13use std::any::Any;
14use std::sync::Arc;
15
/// A DataFusion table whose rows come from running `sql` against a remote database.
#[derive(Debug)]
pub struct RemoteTable {
    /// Options used to connect to the remote database; `pool` is created from these.
    pub(crate) conn_options: ConnectionOptions,
    /// The remote query whose result set this table exposes.
    pub(crate) sql: String,
    /// Arrow schema before `transform` is applied: either supplied by the caller or
    /// converted from the inferred remote schema.
    pub(crate) table_schema: SchemaRef,
    /// `table_schema` after `transform` has been applied; this is the schema reported to
    /// DataFusion by `TableProvider::schema`.
    pub(crate) transformed_table_schema: SchemaRef,
    /// Schema reported by the remote database, when it was inferred (may be `None` when the
    /// caller supplied `table_schema`).
    pub(crate) remote_schema: Option<RemoteSchemaRef>,
    /// Transform used to derive `transformed_table_schema`; also handed to `RemoteTableExec`.
    pub(crate) transform: Arc<dyn Transform>,
    /// Converts pushed-down filter expressions into the remote database's representation
    /// (presumably SQL text — confirm against the `Unparse` implementations).
    pub(crate) unparser: Arc<dyn Unparse>,
    /// Connection pool used for schema inference and for the connections given to scans.
    pub(crate) pool: Arc<dyn Pool>,
}
27
28impl RemoteTable {
29    pub async fn try_new(
30        conn_options: ConnectionOptions,
31        sql: impl Into<String>,
32    ) -> DFResult<Self> {
33        Self::try_new_with_schema_transform_unparser(
34            conn_options,
35            sql,
36            None,
37            Arc::new(DefaultTransform {}),
38            Arc::new(DefaultUnparser {}),
39        )
40        .await
41    }
42
43    pub async fn try_new_with_schema(
44        conn_options: ConnectionOptions,
45        sql: impl Into<String>,
46        table_schema: SchemaRef,
47    ) -> DFResult<Self> {
48        Self::try_new_with_schema_transform_unparser(
49            conn_options,
50            sql,
51            Some(table_schema),
52            Arc::new(DefaultTransform {}),
53            Arc::new(DefaultUnparser {}),
54        )
55        .await
56    }
57
58    pub async fn try_new_with_transform(
59        conn_options: ConnectionOptions,
60        sql: impl Into<String>,
61        transform: Arc<dyn Transform>,
62    ) -> DFResult<Self> {
63        Self::try_new_with_schema_transform_unparser(
64            conn_options,
65            sql,
66            None,
67            transform,
68            Arc::new(DefaultUnparser {}),
69        )
70        .await
71    }
72
73    pub async fn try_new_with_schema_transform_unparser(
74        conn_options: ConnectionOptions,
75        sql: impl Into<String>,
76        table_schema: Option<SchemaRef>,
77        transform: Arc<dyn Transform>,
78        unparser: Arc<dyn Unparse>,
79    ) -> DFResult<Self> {
80        let sql = sql.into();
81        let pool = connect(&conn_options).await?;
82
83        let (table_schema, remote_schema) = if let Some(table_schema) = table_schema {
84            let remote_schema = if transform.as_any().is::<DefaultTransform>() {
85                None
86            } else {
87                // Infer remote schema
88                let conn = pool.get().await?;
89                conn.infer_schema(&sql).await.ok()
90            };
91            (table_schema, remote_schema)
92        } else {
93            // Infer table schema
94            let conn = pool.get().await?;
95            match conn.infer_schema(&sql).await {
96                Ok(remote_schema) => {
97                    let inferred_table_schema = Arc::new(remote_schema.to_arrow_schema());
98                    (inferred_table_schema, Some(remote_schema))
99                }
100                Err(e) => {
101                    return Err(DataFusionError::Execution(format!(
102                        "Failed to infer schema: {e}"
103                    )));
104                }
105            }
106        };
107
108        let transformed_table_schema = transform_schema(
109            table_schema.clone(),
110            transform.as_ref(),
111            remote_schema.as_ref(),
112        )?;
113
114        Ok(RemoteTable {
115            conn_options,
116            sql,
117            table_schema,
118            transformed_table_schema,
119            remote_schema,
120            transform,
121            unparser,
122            pool,
123        })
124    }
125
126    pub fn remote_schema(&self) -> Option<RemoteSchemaRef> {
127        self.remote_schema.clone()
128    }
129}
130
131#[async_trait::async_trait]
132impl TableProvider for RemoteTable {
133    fn as_any(&self) -> &dyn Any {
134        self
135    }
136
137    fn schema(&self) -> SchemaRef {
138        self.transformed_table_schema.clone()
139    }
140
141    fn table_type(&self) -> TableType {
142        TableType::Base
143    }
144
145    async fn scan(
146        &self,
147        _state: &dyn Session,
148        projection: Option<&Vec<usize>>,
149        filters: &[Expr],
150        limit: Option<usize>,
151    ) -> DFResult<Arc<dyn ExecutionPlan>> {
152        let transformed_table_schema = transform_schema(
153            self.table_schema.clone(),
154            self.transform.as_ref(),
155            self.remote_schema.as_ref(),
156        )?;
157        let rewritten_filters = rewrite_filters_column(
158            filters.to_vec(),
159            &self.table_schema,
160            &transformed_table_schema,
161        )?;
162        let mut unparsed_filters = vec![];
163        for filter in rewritten_filters {
164            unparsed_filters.push(
165                self.unparser
166                    .unparse_filter(&filter, self.conn_options.db_type())?,
167            );
168        }
169
170        Ok(Arc::new(RemoteTableExec::try_new(
171            self.conn_options.clone(),
172            self.sql.clone(),
173            self.table_schema.clone(),
174            self.remote_schema.clone(),
175            projection.cloned(),
176            unparsed_filters,
177            limit,
178            self.transform.clone(),
179            self.pool.get().await?,
180        )?))
181    }
182
183    fn supports_filters_pushdown(
184        &self,
185        filters: &[&Expr],
186    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
187        if !self
188            .conn_options
189            .db_type()
190            .support_rewrite_with_filters_limit(&self.sql)
191        {
192            return Ok(vec![
193                TableProviderFilterPushDown::Unsupported;
194                filters.len()
195            ]);
196        }
197        let mut pushdown = vec![];
198        for filter in filters {
199            pushdown.push(
200                self.unparser
201                    .support_filter_pushdown(filter, self.conn_options.db_type())?,
202            );
203        }
204        Ok(pushdown)
205    }
206}
207
208pub(crate) fn rewrite_filters_column(
209    filters: Vec<Expr>,
210    table_schema: &SchemaRef,
211    transformed_table_schema: &SchemaRef,
212) -> DFResult<Vec<Expr>> {
213    filters
214        .into_iter()
215        .map(|f| {
216            f.transform_down(|e| {
217                if let Expr::Column(col) = e {
218                    let col_idx = transformed_table_schema.index_of(col.name())?;
219                    let row_name = table_schema.field(col_idx).name().to_string();
220                    Ok(Transformed::yes(Expr::Column(Column::new_unqualified(
221                        row_name,
222                    ))))
223                } else {
224                    Ok(Transformed::no(e))
225                }
226            })
227            .map(|trans| trans.data)
228        })
229        .collect::<DFResult<Vec<_>>>()
230}