1use crate::{
2 ConnectionOptions, DFResult, DefaultTransform, DefaultUnparser, Pool, RemoteSchemaRef,
3 RemoteTableExec, Transform, Unparse, connect, transform_schema,
4};
5use datafusion::arrow::datatypes::SchemaRef;
6use datafusion::catalog::{Session, TableProvider};
7use datafusion::common::Column;
8use datafusion::common::tree_node::{Transformed, TreeNode};
9use datafusion::datasource::TableType;
10use datafusion::error::DataFusionError;
11use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
12use datafusion::physical_plan::ExecutionPlan;
13use std::any::Any;
14use std::sync::Arc;
15
/// A DataFusion table provider backed by a SQL query executed against a
/// remote database (see the `TableProvider` impl below).
#[derive(Debug)]
pub struct RemoteTable {
    /// Options used to connect to the remote database; also consulted for
    /// the database dialect (`db_type()`) during filter unparsing.
    pub(crate) conn_options: ConnectionOptions,
    /// The remote SQL query this table wraps.
    pub(crate) sql: String,
    /// Schema of the query result as seen on the remote side, before the
    /// transform is applied (either user-supplied or inferred).
    pub(crate) table_schema: SchemaRef,
    /// `table_schema` after `transform` is applied; this is what `schema()`
    /// reports to DataFusion.
    pub(crate) transformed_table_schema: SchemaRef,
    /// Remote schema metadata; `None` when a user schema was supplied and the
    /// default transform made inference unnecessary (or inference failed
    /// best-effort).
    pub(crate) remote_schema: Option<RemoteSchemaRef>,
    /// Transformation applied to the schema (and, presumably, row data in the
    /// exec node) coming from the remote source.
    pub(crate) transform: Arc<dyn Transform>,
    /// Unparses DataFusion filter expressions into remote-dialect SQL for
    /// pushdown.
    pub(crate) unparser: Arc<dyn Unparse>,
    /// Connection pool to the remote database.
    pub(crate) pool: Arc<dyn Pool>,
}
27
28impl RemoteTable {
29 pub async fn try_new(
30 conn_options: ConnectionOptions,
31 sql: impl Into<String>,
32 ) -> DFResult<Self> {
33 Self::try_new_with_schema_transform_unparser(
34 conn_options,
35 sql,
36 None,
37 Arc::new(DefaultTransform {}),
38 Arc::new(DefaultUnparser {}),
39 )
40 .await
41 }
42
43 pub async fn try_new_with_schema(
44 conn_options: ConnectionOptions,
45 sql: impl Into<String>,
46 table_schema: SchemaRef,
47 ) -> DFResult<Self> {
48 Self::try_new_with_schema_transform_unparser(
49 conn_options,
50 sql,
51 Some(table_schema),
52 Arc::new(DefaultTransform {}),
53 Arc::new(DefaultUnparser {}),
54 )
55 .await
56 }
57
58 pub async fn try_new_with_transform(
59 conn_options: ConnectionOptions,
60 sql: impl Into<String>,
61 transform: Arc<dyn Transform>,
62 ) -> DFResult<Self> {
63 Self::try_new_with_schema_transform_unparser(
64 conn_options,
65 sql,
66 None,
67 transform,
68 Arc::new(DefaultUnparser {}),
69 )
70 .await
71 }
72
73 pub async fn try_new_with_schema_transform_unparser(
74 conn_options: ConnectionOptions,
75 sql: impl Into<String>,
76 table_schema: Option<SchemaRef>,
77 transform: Arc<dyn Transform>,
78 unparser: Arc<dyn Unparse>,
79 ) -> DFResult<Self> {
80 let sql = sql.into();
81 let pool = connect(&conn_options).await?;
82
83 let (table_schema, remote_schema) = if let Some(table_schema) = table_schema {
84 let remote_schema = if transform.as_any().is::<DefaultTransform>() {
85 None
86 } else {
87 let conn = pool.get().await?;
89 conn.infer_schema(&sql).await.ok()
90 };
91 (table_schema, remote_schema)
92 } else {
93 let conn = pool.get().await?;
95 match conn.infer_schema(&sql).await {
96 Ok(remote_schema) => {
97 let inferred_table_schema = Arc::new(remote_schema.to_arrow_schema());
98 (inferred_table_schema, Some(remote_schema))
99 }
100 Err(e) => {
101 return Err(DataFusionError::Execution(format!(
102 "Failed to infer schema: {e}"
103 )));
104 }
105 }
106 };
107
108 let transformed_table_schema = transform_schema(
109 table_schema.clone(),
110 transform.as_ref(),
111 remote_schema.as_ref(),
112 )?;
113
114 Ok(RemoteTable {
115 conn_options,
116 sql,
117 table_schema,
118 transformed_table_schema,
119 remote_schema,
120 transform,
121 unparser,
122 pool,
123 })
124 }
125
126 pub fn remote_schema(&self) -> Option<RemoteSchemaRef> {
127 self.remote_schema.clone()
128 }
129}
130
131#[async_trait::async_trait]
132impl TableProvider for RemoteTable {
133 fn as_any(&self) -> &dyn Any {
134 self
135 }
136
137 fn schema(&self) -> SchemaRef {
138 self.transformed_table_schema.clone()
139 }
140
141 fn table_type(&self) -> TableType {
142 TableType::Base
143 }
144
145 async fn scan(
146 &self,
147 _state: &dyn Session,
148 projection: Option<&Vec<usize>>,
149 filters: &[Expr],
150 limit: Option<usize>,
151 ) -> DFResult<Arc<dyn ExecutionPlan>> {
152 let transformed_table_schema = transform_schema(
153 self.table_schema.clone(),
154 self.transform.as_ref(),
155 self.remote_schema.as_ref(),
156 )?;
157 let rewritten_filters = rewrite_filters_column(
158 filters.to_vec(),
159 &self.table_schema,
160 &transformed_table_schema,
161 )?;
162 let mut unparsed_filters = vec![];
163 for filter in rewritten_filters {
164 unparsed_filters.push(
165 self.unparser
166 .unparse_filter(&filter, self.conn_options.db_type())?,
167 );
168 }
169
170 Ok(Arc::new(RemoteTableExec::try_new(
171 self.conn_options.clone(),
172 self.sql.clone(),
173 self.table_schema.clone(),
174 self.remote_schema.clone(),
175 projection.cloned(),
176 unparsed_filters,
177 limit,
178 self.transform.clone(),
179 self.pool.get().await?,
180 )?))
181 }
182
183 fn supports_filters_pushdown(
184 &self,
185 filters: &[&Expr],
186 ) -> DFResult<Vec<TableProviderFilterPushDown>> {
187 if !self
188 .conn_options
189 .db_type()
190 .support_rewrite_with_filters_limit(&self.sql)
191 {
192 return Ok(vec![
193 TableProviderFilterPushDown::Unsupported;
194 filters.len()
195 ]);
196 }
197 let mut pushdown = vec![];
198 for filter in filters {
199 pushdown.push(
200 self.unparser
201 .support_filter_pushdown(filter, self.conn_options.db_type())?,
202 );
203 }
204 Ok(pushdown)
205 }
206}
207
208pub(crate) fn rewrite_filters_column(
209 filters: Vec<Expr>,
210 table_schema: &SchemaRef,
211 transformed_table_schema: &SchemaRef,
212) -> DFResult<Vec<Expr>> {
213 filters
214 .into_iter()
215 .map(|f| {
216 f.transform_down(|e| {
217 if let Expr::Column(col) = e {
218 let col_idx = transformed_table_schema.index_of(col.name())?;
219 let row_name = table_schema.field(col_idx).name().to_string();
220 Ok(Transformed::yes(Expr::Column(Column::new_unqualified(
221 row_name,
222 ))))
223 } else {
224 Ok(Transformed::no(e))
225 }
226 })
227 .map(|trans| trans.data)
228 })
229 .collect::<DFResult<Vec<_>>>()
230}