datafusion_remote_table/
table.rs

1use crate::connection::RemoteDbType;
2use crate::{
3    ConnectionOptions, DFResult, Pool, RemoteSchemaRef, RemoteTableExec, Transform, connect,
4    transform_schema,
5};
6use datafusion::arrow::datatypes::SchemaRef;
7use datafusion::catalog::{Session, TableProvider};
8use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
9use datafusion::datasource::TableType;
10use datafusion::error::DataFusionError;
11use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
12use datafusion::physical_plan::ExecutionPlan;
13use datafusion::sql::unparser::Unparser;
14use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
15use std::any::Any;
16use std::sync::Arc;
17
18#[derive(Debug)]
19pub struct RemoteTable {
20    pub(crate) conn_options: ConnectionOptions,
21    pub(crate) sql: String,
22    pub(crate) table_schema: SchemaRef,
23    pub(crate) transformed_table_schema: SchemaRef,
24    pub(crate) remote_schema: Option<RemoteSchemaRef>,
25    pub(crate) transform: Option<Arc<dyn Transform>>,
26    pub(crate) pool: Arc<dyn Pool>,
27}
28
29impl RemoteTable {
30    pub async fn try_new(
31        conn_options: ConnectionOptions,
32        sql: impl Into<String>,
33    ) -> DFResult<Self> {
34        Self::try_new_with_schema_transform(conn_options, sql, None, None).await
35    }
36
37    pub async fn try_new_with_schema(
38        conn_options: ConnectionOptions,
39        sql: impl Into<String>,
40        table_schema: SchemaRef,
41    ) -> DFResult<Self> {
42        Self::try_new_with_schema_transform(conn_options, sql, Some(table_schema), None).await
43    }
44
45    pub async fn try_new_with_transform(
46        conn_options: ConnectionOptions,
47        sql: impl Into<String>,
48        transform: Arc<dyn Transform>,
49    ) -> DFResult<Self> {
50        Self::try_new_with_schema_transform(conn_options, sql, None, Some(transform)).await
51    }
52
53    pub async fn try_new_with_schema_transform(
54        conn_options: ConnectionOptions,
55        sql: impl Into<String>,
56        table_schema: Option<SchemaRef>,
57        transform: Option<Arc<dyn Transform>>,
58    ) -> DFResult<Self> {
59        let sql = sql.into();
60        let pool = connect(&conn_options).await?;
61        let conn = pool.get().await?;
62        let (table_schema, remote_schema) = match conn.infer_schema(&sql).await {
63            Ok((remote_schema, inferred_table_schema)) => (
64                table_schema.unwrap_or(inferred_table_schema),
65                Some(remote_schema),
66            ),
67            Err(e) => {
68                if let Some(table_schema) = table_schema {
69                    (table_schema, None)
70                } else {
71                    return Err(DataFusionError::Execution(format!(
72                        "Failed to infer schema: {e}"
73                    )));
74                }
75            }
76        };
77        let transformed_table_schema = transform_schema(
78            table_schema.clone(),
79            transform.as_ref(),
80            remote_schema.as_ref(),
81        )?;
82        Ok(RemoteTable {
83            conn_options,
84            sql,
85            table_schema,
86            transformed_table_schema,
87            remote_schema,
88            transform,
89            pool,
90        })
91    }
92
93    pub fn remote_schema(&self) -> Option<RemoteSchemaRef> {
94        self.remote_schema.clone()
95    }
96}
97
98#[async_trait::async_trait]
99impl TableProvider for RemoteTable {
100    fn as_any(&self) -> &dyn Any {
101        self
102    }
103
104    fn schema(&self) -> SchemaRef {
105        self.transformed_table_schema.clone()
106    }
107
108    fn table_type(&self) -> TableType {
109        TableType::View
110    }
111
112    async fn scan(
113        &self,
114        _state: &dyn Session,
115        projection: Option<&Vec<usize>>,
116        filters: &[Expr],
117        limit: Option<usize>,
118    ) -> DFResult<Arc<dyn ExecutionPlan>> {
119        // TODO support filter pushdown
120        let supported_filters = filters
121            .iter()
122            .filter(|f| {
123                let pushdown = support_filter_pushdown(self.conn_options.db_type(), &self.sql, f);
124                matches!(
125                    pushdown,
126                    TableProviderFilterPushDown::Exact | TableProviderFilterPushDown::Inexact
127                )
128            })
129            .cloned()
130            .collect::<Vec<_>>();
131
132        Ok(Arc::new(RemoteTableExec::try_new(
133            self.conn_options.clone(),
134            self.sql.clone(),
135            self.table_schema.clone(),
136            self.remote_schema.clone(),
137            projection.cloned(),
138            supported_filters,
139            limit,
140            self.transform.clone(),
141            self.pool.get().await?,
142        )?))
143    }
144
145    fn supports_filters_pushdown(
146        &self,
147        filters: &[&Expr],
148    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
149        Ok(filters
150            .iter()
151            .map(|f| support_filter_pushdown(self.conn_options.db_type(), &self.sql, f))
152            .collect())
153    }
154}
155
156pub(crate) fn support_filter_pushdown(
157    db_type: RemoteDbType,
158    sql: &str,
159    filter: &Expr,
160) -> TableProviderFilterPushDown {
161    if !db_type.support_rewrite_with_filters_limit(sql) {
162        return TableProviderFilterPushDown::Unsupported;
163    }
164    let unparser = match db_type {
165        RemoteDbType::Mysql => Unparser::new(&MySqlDialect {}),
166        RemoteDbType::Postgres => Unparser::new(&PostgreSqlDialect {}),
167        RemoteDbType::Sqlite => Unparser::new(&SqliteDialect {}),
168        RemoteDbType::Oracle => return TableProviderFilterPushDown::Unsupported,
169    };
170    if unparser.expr_to_sql(filter).is_err() {
171        return TableProviderFilterPushDown::Unsupported;
172    }
173
174    let mut pushdown = TableProviderFilterPushDown::Exact;
175    filter
176        .apply(|e| {
177            if matches!(e, Expr::ScalarFunction(_)) {
178                pushdown = TableProviderFilterPushDown::Unsupported;
179            }
180            Ok(TreeNodeRecursion::Continue)
181        })
182        .expect("won't fail");
183
184    pushdown
185}