datafusion_remote_table/
table.rs1use crate::connection::RemoteDbType;
2use crate::{
3 ConnectionOptions, DFResult, Pool, RemoteSchemaRef, RemoteTableExec, Transform, connect,
4 transform_schema,
5};
6use datafusion::arrow::datatypes::SchemaRef;
7use datafusion::catalog::{Session, TableProvider};
8use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
9use datafusion::datasource::TableType;
10use datafusion::error::DataFusionError;
11use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
12use datafusion::physical_plan::ExecutionPlan;
13use datafusion::sql::unparser::Unparser;
14use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
15use std::any::Any;
16use std::sync::Arc;
17
/// A DataFusion table backed by a SQL query executed against a remote database.
#[derive(Debug)]
pub struct RemoteTable {
    /// Options used to open the connection pool; also cloned into every scan.
    pub(crate) conn_options: ConnectionOptions,
    /// The SQL text whose result set this table exposes.
    pub(crate) sql: String,
    /// Arrow schema of the table before any transform is applied: either
    /// supplied by the caller or derived from the inferred remote schema.
    pub(crate) table_schema: SchemaRef,
    /// Result of `transform_schema` over `table_schema`; this is the schema
    /// reported through `TableProvider::schema`.
    pub(crate) transformed_table_schema: SchemaRef,
    /// Schema inferred from the remote database for `sql`, if inference was
    /// attempted and succeeded.
    pub(crate) remote_schema: Option<RemoteSchemaRef>,
    /// Optional user-provided transform forwarded to schema transformation
    /// and to every scan.
    pub(crate) transform: Option<Arc<dyn Transform>>,
    /// Connection pool obtained from `connect`; a connection is taken from it
    /// for schema inference and for each scan.
    pub(crate) pool: Arc<dyn Pool>,
}
28
29impl RemoteTable {
30 pub async fn try_new(
31 conn_options: ConnectionOptions,
32 sql: impl Into<String>,
33 ) -> DFResult<Self> {
34 Self::try_new_with_schema_transform(conn_options, sql, None, None).await
35 }
36
37 pub async fn try_new_with_schema(
38 conn_options: ConnectionOptions,
39 sql: impl Into<String>,
40 table_schema: SchemaRef,
41 ) -> DFResult<Self> {
42 Self::try_new_with_schema_transform(conn_options, sql, Some(table_schema), None).await
43 }
44
45 pub async fn try_new_with_transform(
46 conn_options: ConnectionOptions,
47 sql: impl Into<String>,
48 transform: Arc<dyn Transform>,
49 ) -> DFResult<Self> {
50 Self::try_new_with_schema_transform(conn_options, sql, None, Some(transform)).await
51 }
52
53 pub async fn try_new_with_schema_transform(
54 conn_options: ConnectionOptions,
55 sql: impl Into<String>,
56 table_schema: Option<SchemaRef>,
57 transform: Option<Arc<dyn Transform>>,
58 ) -> DFResult<Self> {
59 let sql = sql.into();
60 let pool = connect(&conn_options).await?;
61
62 let (table_schema, remote_schema) = if let Some(table_schema) = table_schema {
63 let remote_schema = if transform.is_some() {
64 let conn = pool.get().await?;
66 conn.infer_schema(&sql).await.ok()
67 } else {
68 None
69 };
70 (table_schema, remote_schema)
71 } else {
72 let conn = pool.get().await?;
74 match conn.infer_schema(&sql).await {
75 Ok(remote_schema) => {
76 let inferred_table_schema = Arc::new(remote_schema.to_arrow_schema());
77 (inferred_table_schema, Some(remote_schema))
78 }
79 Err(e) => {
80 return Err(DataFusionError::Execution(format!(
81 "Failed to infer schema: {e}"
82 )));
83 }
84 }
85 };
86
87 let transformed_table_schema = transform_schema(
88 table_schema.clone(),
89 transform.as_ref(),
90 remote_schema.as_ref(),
91 )?;
92
93 Ok(RemoteTable {
94 conn_options,
95 sql,
96 table_schema,
97 transformed_table_schema,
98 remote_schema,
99 transform,
100 pool,
101 })
102 }
103
104 pub fn remote_schema(&self) -> Option<RemoteSchemaRef> {
105 self.remote_schema.clone()
106 }
107}
108
109#[async_trait::async_trait]
110impl TableProvider for RemoteTable {
111 fn as_any(&self) -> &dyn Any {
112 self
113 }
114
115 fn schema(&self) -> SchemaRef {
116 self.transformed_table_schema.clone()
117 }
118
119 fn table_type(&self) -> TableType {
120 TableType::View
121 }
122
123 async fn scan(
124 &self,
125 _state: &dyn Session,
126 projection: Option<&Vec<usize>>,
127 filters: &[Expr],
128 limit: Option<usize>,
129 ) -> DFResult<Arc<dyn ExecutionPlan>> {
130 Ok(Arc::new(RemoteTableExec::try_new(
131 self.conn_options.clone(),
132 self.sql.clone(),
133 self.table_schema.clone(),
134 self.remote_schema.clone(),
135 projection.cloned(),
136 filters.to_vec(),
137 limit,
138 self.transform.clone(),
139 self.pool.get().await?,
140 )?))
141 }
142
143 fn supports_filters_pushdown(
144 &self,
145 filters: &[&Expr],
146 ) -> DFResult<Vec<TableProviderFilterPushDown>> {
147 Ok(filters
148 .iter()
149 .map(|f| support_filter_pushdown(self.conn_options.db_type(), &self.sql, f))
150 .collect())
151 }
152}
153
154pub(crate) fn support_filter_pushdown(
155 db_type: RemoteDbType,
156 sql: &str,
157 filter: &Expr,
158) -> TableProviderFilterPushDown {
159 if !db_type.support_rewrite_with_filters_limit(sql) {
160 return TableProviderFilterPushDown::Unsupported;
161 }
162 let unparser = match db_type {
163 RemoteDbType::Mysql => Unparser::new(&MySqlDialect {}),
164 RemoteDbType::Postgres => Unparser::new(&PostgreSqlDialect {}),
165 RemoteDbType::Sqlite => Unparser::new(&SqliteDialect {}),
166 RemoteDbType::Oracle => return TableProviderFilterPushDown::Unsupported,
167 RemoteDbType::Dm => return TableProviderFilterPushDown::Unsupported,
168 };
169 if unparser.expr_to_sql(filter).is_err() {
170 return TableProviderFilterPushDown::Unsupported;
171 }
172
173 let mut pushdown = TableProviderFilterPushDown::Exact;
174 filter
175 .apply(|e| {
176 if matches!(e, Expr::ScalarFunction(_)) {
177 pushdown = TableProviderFilterPushDown::Unsupported;
178 }
179 Ok(TreeNodeRecursion::Continue)
180 })
181 .expect("won't fail");
182
183 pushdown
184}