datafusion_remote_table/
table.rs

1use crate::{
2    ConnectionOptions, DFResult, DefaultTransform, DefaultUnparser, Pool, RemoteSchemaRef,
3    RemoteTableExec, Transform, Unparse, connect, transform_schema,
4};
5use datafusion::arrow::datatypes::SchemaRef;
6use datafusion::catalog::{Session, TableProvider};
7use datafusion::common::stats::Precision;
8use datafusion::common::tree_node::{Transformed, TreeNode};
9use datafusion::common::{Column, Statistics};
10use datafusion::datasource::TableType;
11use datafusion::error::DataFusionError;
12use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
13use datafusion::physical_plan::ExecutionPlan;
14use log::{debug, warn};
15use std::any::Any;
16use std::sync::Arc;
17
18#[derive(Debug)]
19pub struct RemoteTable {
20    pub(crate) conn_options: ConnectionOptions,
21    pub(crate) sql: String,
22    pub(crate) table_schema: SchemaRef,
23    pub(crate) transformed_table_schema: SchemaRef,
24    pub(crate) remote_schema: Option<RemoteSchemaRef>,
25    pub(crate) transform: Arc<dyn Transform>,
26    pub(crate) unparser: Arc<dyn Unparse>,
27    pub(crate) pool: Arc<dyn Pool>,
28}
29
30impl RemoteTable {
31    pub async fn try_new(
32        conn_options: ConnectionOptions,
33        sql: impl Into<String>,
34    ) -> DFResult<Self> {
35        Self::try_new_with_schema_transform_unparser(
36            conn_options,
37            sql,
38            None,
39            Arc::new(DefaultTransform {}),
40            Arc::new(DefaultUnparser {}),
41        )
42        .await
43    }
44
45    pub async fn try_new_with_schema(
46        conn_options: ConnectionOptions,
47        sql: impl Into<String>,
48        table_schema: SchemaRef,
49    ) -> DFResult<Self> {
50        Self::try_new_with_schema_transform_unparser(
51            conn_options,
52            sql,
53            Some(table_schema),
54            Arc::new(DefaultTransform {}),
55            Arc::new(DefaultUnparser {}),
56        )
57        .await
58    }
59
60    pub async fn try_new_with_transform(
61        conn_options: ConnectionOptions,
62        sql: impl Into<String>,
63        transform: Arc<dyn Transform>,
64    ) -> DFResult<Self> {
65        Self::try_new_with_schema_transform_unparser(
66            conn_options,
67            sql,
68            None,
69            transform,
70            Arc::new(DefaultUnparser {}),
71        )
72        .await
73    }
74
75    pub async fn try_new_with_schema_transform_unparser(
76        conn_options: ConnectionOptions,
77        sql: impl Into<String>,
78        table_schema: Option<SchemaRef>,
79        transform: Arc<dyn Transform>,
80        unparser: Arc<dyn Unparse>,
81    ) -> DFResult<Self> {
82        let sql = sql.into();
83        let pool = connect(&conn_options).await?;
84
85        let (table_schema, remote_schema) = if let Some(table_schema) = table_schema {
86            let remote_schema = if transform.as_any().is::<DefaultTransform>() {
87                None
88            } else {
89                // Infer remote schema
90                let conn = pool.get().await?;
91                conn.infer_schema(&sql).await.ok()
92            };
93            (table_schema, remote_schema)
94        } else {
95            // Infer table schema
96            let conn = pool.get().await?;
97            match conn.infer_schema(&sql).await {
98                Ok(remote_schema) => {
99                    let inferred_table_schema = Arc::new(remote_schema.to_arrow_schema());
100                    (inferred_table_schema, Some(remote_schema))
101                }
102                Err(e) => {
103                    return Err(DataFusionError::Execution(format!(
104                        "Failed to infer schema: {e}"
105                    )));
106                }
107            }
108        };
109
110        let transformed_table_schema = transform_schema(
111            table_schema.clone(),
112            transform.as_ref(),
113            remote_schema.as_ref(),
114        )?;
115
116        Ok(RemoteTable {
117            conn_options,
118            sql,
119            table_schema,
120            transformed_table_schema,
121            remote_schema,
122            transform,
123            unparser,
124            pool,
125        })
126    }
127
128    pub fn remote_schema(&self) -> Option<RemoteSchemaRef> {
129        self.remote_schema.clone()
130    }
131}
132
133#[async_trait::async_trait]
134impl TableProvider for RemoteTable {
135    fn as_any(&self) -> &dyn Any {
136        self
137    }
138
139    fn schema(&self) -> SchemaRef {
140        self.transformed_table_schema.clone()
141    }
142
143    fn table_type(&self) -> TableType {
144        TableType::Base
145    }
146
147    async fn scan(
148        &self,
149        _state: &dyn Session,
150        projection: Option<&Vec<usize>>,
151        filters: &[Expr],
152        limit: Option<usize>,
153    ) -> DFResult<Arc<dyn ExecutionPlan>> {
154        let transformed_table_schema = transform_schema(
155            self.table_schema.clone(),
156            self.transform.as_ref(),
157            self.remote_schema.as_ref(),
158        )?;
159        let rewritten_filters = rewrite_filters_column(
160            filters.to_vec(),
161            &self.table_schema,
162            &transformed_table_schema,
163        )?;
164        let mut unparsed_filters = vec![];
165        for filter in rewritten_filters {
166            unparsed_filters.push(
167                self.unparser
168                    .unparse_filter(&filter, self.conn_options.db_type())?,
169            );
170        }
171
172        Ok(Arc::new(RemoteTableExec::try_new(
173            self.conn_options.clone(),
174            self.sql.clone(),
175            self.table_schema.clone(),
176            self.remote_schema.clone(),
177            projection.cloned(),
178            unparsed_filters,
179            limit,
180            self.transform.clone(),
181            self.pool.get().await?,
182        )?))
183    }
184
185    fn supports_filters_pushdown(
186        &self,
187        filters: &[&Expr],
188    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
189        if !self
190            .conn_options
191            .db_type()
192            .support_rewrite_with_filters_limit(&self.sql)
193        {
194            return Ok(vec![
195                TableProviderFilterPushDown::Unsupported;
196                filters.len()
197            ]);
198        }
199        let mut pushdown = vec![];
200        for filter in filters {
201            pushdown.push(
202                self.unparser
203                    .support_filter_pushdown(filter, self.conn_options.db_type())?,
204            );
205        }
206        Ok(pushdown)
207    }
208
209    fn statistics(&self) -> Option<Statistics> {
210        if let Some(count1_query) = self.conn_options.db_type().try_count1_query(&self.sql) {
211            let conn_options = self.conn_options.clone();
212            let row_count_result = tokio::task::block_in_place(|| {
213                tokio::runtime::Handle::current().block_on(async {
214                    let pool = connect(&conn_options).await?;
215                    let conn = pool.get().await?;
216                    conn_options
217                        .db_type()
218                        .fetch_count(conn, &conn_options, &count1_query)
219                        .await
220                })
221            });
222
223            match row_count_result {
224                Ok(row_count) => {
225                    let column_stat =
226                        Statistics::unknown_column(self.transformed_table_schema.as_ref());
227                    Some(Statistics {
228                        num_rows: Precision::Exact(row_count),
229                        total_byte_size: Precision::Absent,
230                        column_statistics: column_stat,
231                    })
232                }
233                Err(e) => {
234                    warn!("[remote-table] Failed to fetch table statistics: {e}");
235                    None
236                }
237            }
238        } else {
239            debug!(
240                "[remote-table] Query can not be rewritten as count1 query: {}",
241                self.sql
242            );
243            None
244        }
245    }
246}
247
248pub(crate) fn rewrite_filters_column(
249    filters: Vec<Expr>,
250    table_schema: &SchemaRef,
251    transformed_table_schema: &SchemaRef,
252) -> DFResult<Vec<Expr>> {
253    filters
254        .into_iter()
255        .map(|f| {
256            f.transform_down(|e| {
257                if let Expr::Column(col) = e {
258                    let col_idx = transformed_table_schema.index_of(col.name())?;
259                    let row_name = table_schema.field(col_idx).name().to_string();
260                    Ok(Transformed::yes(Expr::Column(Column::new_unqualified(
261                        row_name,
262                    ))))
263                } else {
264                    Ok(Transformed::no(e))
265                }
266            })
267            .map(|trans| trans.data)
268        })
269        .collect::<DFResult<Vec<_>>>()
270}