datafusion_remote_table/table.rs

1use crate::{
2    ConnectionOptions, DFResult, DefaultTransform, DefaultUnparser, Pool, RemoteSchemaRef,
3    RemoteTableExec, Transform, Unparse, connect, transform_schema,
4};
5use datafusion::arrow::datatypes::SchemaRef;
6use datafusion::catalog::{Session, TableProvider};
7use datafusion::common::stats::Precision;
8use datafusion::common::tree_node::{Transformed, TreeNode};
9use datafusion::common::{Column, Statistics};
10use datafusion::datasource::TableType;
11use datafusion::error::DataFusionError;
12use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
13use datafusion::physical_plan::ExecutionPlan;
14use log::{debug, warn};
15use std::any::Any;
16use std::sync::Arc;
17
18#[derive(Debug)]
19pub struct RemoteTable {
20    pub(crate) conn_options: ConnectionOptions,
21    pub(crate) sql: String,
22    pub(crate) table_schema: SchemaRef,
23    pub(crate) transformed_table_schema: SchemaRef,
24    pub(crate) remote_schema: Option<RemoteSchemaRef>,
25    pub(crate) transform: Arc<dyn Transform>,
26    pub(crate) unparser: Arc<dyn Unparse>,
27    pub(crate) pool: Arc<dyn Pool>,
28}
29
30impl RemoteTable {
31    pub async fn try_new(
32        conn_options: impl Into<ConnectionOptions>,
33        sql: impl Into<String>,
34    ) -> DFResult<Self> {
35        Self::try_new_with_schema_transform_unparser(
36            conn_options,
37            sql,
38            None,
39            Arc::new(DefaultTransform {}),
40            Arc::new(DefaultUnparser {}),
41        )
42        .await
43    }
44
45    pub async fn try_new_with_schema(
46        conn_options: impl Into<ConnectionOptions>,
47        sql: impl Into<String>,
48        table_schema: SchemaRef,
49    ) -> DFResult<Self> {
50        Self::try_new_with_schema_transform_unparser(
51            conn_options,
52            sql,
53            Some(table_schema),
54            Arc::new(DefaultTransform {}),
55            Arc::new(DefaultUnparser {}),
56        )
57        .await
58    }
59
60    pub async fn try_new_with_transform(
61        conn_options: impl Into<ConnectionOptions>,
62        sql: impl Into<String>,
63        transform: Arc<dyn Transform>,
64    ) -> DFResult<Self> {
65        Self::try_new_with_schema_transform_unparser(
66            conn_options,
67            sql,
68            None,
69            transform,
70            Arc::new(DefaultUnparser {}),
71        )
72        .await
73    }
74
75    pub async fn try_new_with_schema_transform_unparser(
76        conn_options: impl Into<ConnectionOptions>,
77        sql: impl Into<String>,
78        table_schema: Option<SchemaRef>,
79        transform: Arc<dyn Transform>,
80        unparser: Arc<dyn Unparse>,
81    ) -> DFResult<Self> {
82        let conn_options = conn_options.into();
83        let sql = sql.into();
84
85        let now = std::time::Instant::now();
86        let pool = connect(&conn_options).await?;
87        debug!(
88            "[remote-table] Creating connection pool cost: {}ms",
89            now.elapsed().as_millis()
90        );
91
92        let (table_schema, remote_schema) = if let Some(table_schema) = table_schema {
93            let remote_schema = if transform.as_any().is::<DefaultTransform>() {
94                None
95            } else {
96                // Infer remote schema
97                let now = std::time::Instant::now();
98                let conn = pool.get().await?;
99                let remote_schema_opt = conn.infer_schema(&sql).await.ok();
100                debug!(
101                    "[remote-table] Inferring remote schema cost: {}ms",
102                    now.elapsed().as_millis()
103                );
104                remote_schema_opt
105            };
106            (table_schema, remote_schema)
107        } else {
108            // Infer table schema
109            let now = std::time::Instant::now();
110            let conn = pool.get().await?;
111            match conn.infer_schema(&sql).await {
112                Ok(remote_schema) => {
113                    debug!(
114                        "[remote-table] Inferring table schema cost: {}ms",
115                        now.elapsed().as_millis()
116                    );
117                    let inferred_table_schema = Arc::new(remote_schema.to_arrow_schema());
118                    (inferred_table_schema, Some(remote_schema))
119                }
120                Err(e) => {
121                    return Err(DataFusionError::Execution(format!(
122                        "Failed to infer schema: {e}"
123                    )));
124                }
125            }
126        };
127
128        let transformed_table_schema = transform_schema(
129            table_schema.clone(),
130            transform.as_ref(),
131            remote_schema.as_ref(),
132        )?;
133
134        Ok(RemoteTable {
135            conn_options,
136            sql,
137            table_schema,
138            transformed_table_schema,
139            remote_schema,
140            transform,
141            unparser,
142            pool,
143        })
144    }
145
146    pub fn remote_schema(&self) -> Option<RemoteSchemaRef> {
147        self.remote_schema.clone()
148    }
149}
150
151#[async_trait::async_trait]
152impl TableProvider for RemoteTable {
153    fn as_any(&self) -> &dyn Any {
154        self
155    }
156
157    fn schema(&self) -> SchemaRef {
158        self.transformed_table_schema.clone()
159    }
160
161    fn table_type(&self) -> TableType {
162        TableType::Base
163    }
164
165    async fn scan(
166        &self,
167        _state: &dyn Session,
168        projection: Option<&Vec<usize>>,
169        filters: &[Expr],
170        limit: Option<usize>,
171    ) -> DFResult<Arc<dyn ExecutionPlan>> {
172        let transformed_table_schema = transform_schema(
173            self.table_schema.clone(),
174            self.transform.as_ref(),
175            self.remote_schema.as_ref(),
176        )?;
177        let rewritten_filters = rewrite_filters_column(
178            filters.to_vec(),
179            &self.table_schema,
180            &transformed_table_schema,
181        )?;
182        let mut unparsed_filters = vec![];
183        for filter in rewritten_filters {
184            unparsed_filters.push(
185                self.unparser
186                    .unparse_filter(&filter, self.conn_options.db_type())?,
187            );
188        }
189
190        let now = std::time::Instant::now();
191        let conn = self.pool.get().await?;
192        debug!(
193            "[remote-table] Getting connection from pool cost: {}ms",
194            now.elapsed().as_millis()
195        );
196
197        Ok(Arc::new(RemoteTableExec::try_new(
198            self.conn_options.clone(),
199            self.sql.clone(),
200            self.table_schema.clone(),
201            self.remote_schema.clone(),
202            projection.cloned(),
203            unparsed_filters,
204            limit,
205            self.transform.clone(),
206            conn,
207        )?))
208    }
209
210    fn supports_filters_pushdown(
211        &self,
212        filters: &[&Expr],
213    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
214        if !self
215            .conn_options
216            .db_type()
217            .support_rewrite_with_filters_limit(&self.sql)
218        {
219            return Ok(vec![
220                TableProviderFilterPushDown::Unsupported;
221                filters.len()
222            ]);
223        }
224        let mut pushdown = vec![];
225        for filter in filters {
226            pushdown.push(
227                self.unparser
228                    .support_filter_pushdown(filter, self.conn_options.db_type())?,
229            );
230        }
231        Ok(pushdown)
232    }
233
234    fn statistics(&self) -> Option<Statistics> {
235        if let Some(count1_query) = self.conn_options.db_type().try_count1_query(&self.sql) {
236            let conn_options = self.conn_options.clone();
237            let row_count_result = tokio::task::block_in_place(|| {
238                tokio::runtime::Handle::current().block_on(async {
239                    let pool = connect(&conn_options).await?;
240                    let conn = pool.get().await?;
241                    conn_options
242                        .db_type()
243                        .fetch_count(conn, &conn_options, &count1_query)
244                        .await
245                })
246            });
247
248            match row_count_result {
249                Ok(row_count) => {
250                    let column_stat =
251                        Statistics::unknown_column(self.transformed_table_schema.as_ref());
252                    Some(Statistics {
253                        num_rows: Precision::Exact(row_count),
254                        total_byte_size: Precision::Absent,
255                        column_statistics: column_stat,
256                    })
257                }
258                Err(e) => {
259                    warn!("[remote-table] Failed to fetch table statistics: {e}");
260                    None
261                }
262            }
263        } else {
264            debug!(
265                "[remote-table] Query can not be rewritten as count1 query: {}",
266                self.sql
267            );
268            None
269        }
270    }
271}
272
273pub(crate) fn rewrite_filters_column(
274    filters: Vec<Expr>,
275    table_schema: &SchemaRef,
276    transformed_table_schema: &SchemaRef,
277) -> DFResult<Vec<Expr>> {
278    filters
279        .into_iter()
280        .map(|f| {
281            f.transform_down(|e| {
282                if let Expr::Column(col) = e {
283                    let col_idx = transformed_table_schema.index_of(col.name())?;
284                    let row_name = table_schema.field(col_idx).name().to_string();
285                    Ok(Transformed::yes(Expr::Column(Column::new_unqualified(
286                        row_name,
287                    ))))
288                } else {
289                    Ok(Transformed::no(e))
290                }
291            })
292            .map(|trans| trans.data)
293        })
294        .collect::<DFResult<Vec<_>>>()
295}