datafusion_remote_table/
table.rs

1use crate::{
2    ConnectionOptions, DFResult, DefaultTransform, DefaultUnparser, Pool, RemoteDbType,
3    RemoteSchemaRef, RemoteTableInsertExec, RemoteTableScanExec, Transform, Unparse, connect,
4    transform_schema,
5};
6use datafusion::arrow::datatypes::SchemaRef;
7use datafusion::catalog::{Session, TableProvider};
8use datafusion::common::stats::Precision;
9use datafusion::common::tree_node::{Transformed, TreeNode};
10use datafusion::common::{Column, Statistics};
11use datafusion::datasource::TableType;
12use datafusion::error::DataFusionError;
13use datafusion::logical_expr::dml::InsertOp;
14use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
15use datafusion::physical_plan::ExecutionPlan;
16use log::{debug, warn};
17use std::any::Any;
18use std::sync::Arc;
19
/// What a [`RemoteTable`] reads from: either a raw SQL query or a table reference.
///
/// `Table` holds the identifier parts (e.g. `["schema", "table"]`), which are
/// rendered dot-joined by `Display`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RemoteSource {
    /// A SQL query executed verbatim against the remote database.
    Query(String),
    /// A table reference given as its identifier parts.
    Table(Vec<String>),
}
25
26impl RemoteSource {
27    pub fn query(&self, db_type: RemoteDbType) -> String {
28        match self {
29            RemoteSource::Query(query) => query.clone(),
30            RemoteSource::Table(table_identifiers) => db_type.select_all_query(table_identifiers),
31        }
32    }
33}
34
35impl std::fmt::Display for RemoteSource {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        match self {
38            RemoteSource::Query(query) => write!(f, "{query}"),
39            RemoteSource::Table(table) => write!(f, "{}", table.join(".")),
40        }
41    }
42}
43
44impl From<String> for RemoteSource {
45    fn from(query: String) -> Self {
46        RemoteSource::Query(query)
47    }
48}
49
50impl From<&String> for RemoteSource {
51    fn from(query: &String) -> Self {
52        RemoteSource::Query(query.clone())
53    }
54}
55
56impl From<&str> for RemoteSource {
57    fn from(query: &str) -> Self {
58        RemoteSource::Query(query.to_string())
59    }
60}
61
62impl From<Vec<String>> for RemoteSource {
63    fn from(table_identifiers: Vec<String>) -> Self {
64        RemoteSource::Table(table_identifiers)
65    }
66}
67
68impl From<Vec<&str>> for RemoteSource {
69    fn from(table_identifiers: Vec<&str>) -> Self {
70        RemoteSource::Table(
71            table_identifiers
72                .into_iter()
73                .map(|s| s.to_string())
74                .collect(),
75        )
76    }
77}
78
79impl From<Vec<&String>> for RemoteSource {
80    fn from(table_identifiers: Vec<&String>) -> Self {
81        RemoteSource::Table(table_identifiers.into_iter().cloned().collect())
82    }
83}
84
/// A DataFusion [`TableProvider`] backed by a table or query on a remote database.
#[derive(Debug)]
pub struct RemoteTable {
    /// Connection settings; `db_type()` on these selects database-specific behavior.
    pub(crate) conn_options: Arc<ConnectionOptions>,
    /// The remote table reference or SQL query this table exposes.
    pub(crate) source: RemoteSource,
    /// Arrow schema of the source before `transform` is applied.
    pub(crate) table_schema: SchemaRef,
    /// `table_schema` after `transform` has been applied; this is the schema
    /// reported to DataFusion by `schema()`.
    pub(crate) transformed_table_schema: SchemaRef,
    /// Schema as described by the remote database, when supplied or inferred.
    pub(crate) remote_schema: Option<RemoteSchemaRef>,
    /// Schema/data transform; also handed to the scan plan.
    pub(crate) transform: Arc<dyn Transform>,
    /// Unparses pushed-down filters for the remote database; also handed to the
    /// insert plan.
    pub(crate) unparser: Arc<dyn Unparse>,
    /// Connection pool used for schema inference, scans and inserts.
    pub(crate) pool: Arc<dyn Pool>,
}
96
97impl RemoteTable {
98    pub async fn try_new(
99        conn_options: impl Into<ConnectionOptions>,
100        source: impl Into<RemoteSource>,
101    ) -> DFResult<Self> {
102        Self::try_new_with_schema_transform_unparser(
103            conn_options,
104            source,
105            None,
106            None,
107            Arc::new(DefaultTransform {}),
108            Arc::new(DefaultUnparser {}),
109        )
110        .await
111    }
112
113    pub async fn try_new_with_schema(
114        conn_options: impl Into<ConnectionOptions>,
115        source: impl Into<RemoteSource>,
116        table_schema: SchemaRef,
117    ) -> DFResult<Self> {
118        Self::try_new_with_schema_transform_unparser(
119            conn_options,
120            source,
121            Some(table_schema),
122            None,
123            Arc::new(DefaultTransform {}),
124            Arc::new(DefaultUnparser {}),
125        )
126        .await
127    }
128
129    pub async fn try_new_with_remote_schema(
130        conn_options: impl Into<ConnectionOptions>,
131        source: impl Into<RemoteSource>,
132        remote_schema: RemoteSchemaRef,
133    ) -> DFResult<Self> {
134        Self::try_new_with_schema_transform_unparser(
135            conn_options,
136            source,
137            None,
138            Some(remote_schema),
139            Arc::new(DefaultTransform {}),
140            Arc::new(DefaultUnparser {}),
141        )
142        .await
143    }
144
145    pub async fn try_new_with_transform(
146        conn_options: impl Into<ConnectionOptions>,
147        source: impl Into<RemoteSource>,
148        transform: Arc<dyn Transform>,
149    ) -> DFResult<Self> {
150        Self::try_new_with_schema_transform_unparser(
151            conn_options,
152            source,
153            None,
154            None,
155            transform,
156            Arc::new(DefaultUnparser {}),
157        )
158        .await
159    }
160
161    pub async fn try_new_with_schema_transform_unparser(
162        conn_options: impl Into<ConnectionOptions>,
163        source: impl Into<RemoteSource>,
164        table_schema: Option<SchemaRef>,
165        remote_schema: Option<RemoteSchemaRef>,
166        transform: Arc<dyn Transform>,
167        unparser: Arc<dyn Unparse>,
168    ) -> DFResult<Self> {
169        let conn_options = conn_options.into();
170        let source = source.into();
171
172        if let RemoteSource::Table(table) = &source
173            && table.is_empty()
174        {
175            return Err(DataFusionError::Plan(
176                "Table source is empty vec".to_string(),
177            ));
178        }
179
180        let now = std::time::Instant::now();
181        let pool = connect(&conn_options).await?;
182        debug!(
183            "[remote-table] Creating connection pool cost: {}ms",
184            now.elapsed().as_millis()
185        );
186
187        let (table_schema, remote_schema): (SchemaRef, Option<RemoteSchemaRef>) =
188            match (table_schema, remote_schema) {
189                (Some(table_schema), Some(remote_schema)) => (table_schema, Some(remote_schema)),
190                (Some(table_schema), None) => {
191                    let remote_schema = if transform.as_any().is::<DefaultTransform>()
192                        && matches!(source, RemoteSource::Query(_))
193                    {
194                        None
195                    } else {
196                        // Infer remote schema
197                        let now = std::time::Instant::now();
198                        let conn = pool.get().await?;
199                        let remote_schema_opt = conn.infer_schema(&source).await.ok();
200                        debug!(
201                            "[remote-table] Inferring remote schema cost: {}ms",
202                            now.elapsed().as_millis()
203                        );
204                        remote_schema_opt
205                    };
206                    (table_schema, remote_schema)
207                }
208                (None, Some(remote_schema)) => (
209                    Arc::new(remote_schema.to_arrow_schema()),
210                    Some(remote_schema),
211                ),
212                (None, None) => {
213                    // Infer table schema
214                    let now = std::time::Instant::now();
215                    let conn = pool.get().await?;
216                    match conn.infer_schema(&source).await {
217                        Ok(remote_schema) => {
218                            debug!(
219                                "[remote-table] Inferring table schema cost: {}ms",
220                                now.elapsed().as_millis()
221                            );
222                            let inferred_table_schema = Arc::new(remote_schema.to_arrow_schema());
223                            (inferred_table_schema, Some(remote_schema))
224                        }
225                        Err(e) => {
226                            return Err(DataFusionError::Execution(format!(
227                                "Failed to infer schema: {e}"
228                            )));
229                        }
230                    }
231                }
232            };
233
234        let transformed_table_schema = transform_schema(
235            table_schema.clone(),
236            transform.as_ref(),
237            remote_schema.as_ref(),
238        )?;
239
240        Ok(RemoteTable {
241            conn_options: Arc::new(conn_options),
242            source,
243            table_schema,
244            transformed_table_schema,
245            remote_schema,
246            transform,
247            unparser,
248            pool,
249        })
250    }
251
252    pub fn remote_schema(&self) -> Option<RemoteSchemaRef> {
253        self.remote_schema.clone()
254    }
255}
256
257#[async_trait::async_trait]
258impl TableProvider for RemoteTable {
259    fn as_any(&self) -> &dyn Any {
260        self
261    }
262
263    fn schema(&self) -> SchemaRef {
264        self.transformed_table_schema.clone()
265    }
266
267    fn table_type(&self) -> TableType {
268        TableType::Base
269    }
270
271    async fn scan(
272        &self,
273        _state: &dyn Session,
274        projection: Option<&Vec<usize>>,
275        filters: &[Expr],
276        limit: Option<usize>,
277    ) -> DFResult<Arc<dyn ExecutionPlan>> {
278        let transformed_table_schema = transform_schema(
279            self.table_schema.clone(),
280            self.transform.as_ref(),
281            self.remote_schema.as_ref(),
282        )?;
283        let rewritten_filters = rewrite_filters_column(
284            filters.to_vec(),
285            &self.table_schema,
286            &transformed_table_schema,
287        )?;
288        let mut unparsed_filters = vec![];
289        for filter in rewritten_filters {
290            unparsed_filters.push(
291                self.unparser
292                    .unparse_filter(&filter, self.conn_options.db_type())?,
293            );
294        }
295
296        let now = std::time::Instant::now();
297        let conn = self.pool.get().await?;
298        debug!(
299            "[remote-table] Getting connection from pool cost: {}ms",
300            now.elapsed().as_millis()
301        );
302
303        Ok(Arc::new(RemoteTableScanExec::try_new(
304            self.conn_options.clone(),
305            self.source.clone(),
306            self.table_schema.clone(),
307            self.remote_schema.clone(),
308            projection.cloned(),
309            unparsed_filters,
310            limit,
311            self.transform.clone(),
312            conn,
313        )?))
314    }
315
316    fn supports_filters_pushdown(
317        &self,
318        filters: &[&Expr],
319    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
320        let db_type = self.conn_options.db_type();
321        if !db_type.support_rewrite_with_filters_limit(&self.source) {
322            return Ok(vec![
323                TableProviderFilterPushDown::Unsupported;
324                filters.len()
325            ]);
326        }
327        let mut pushdown = vec![];
328        for filter in filters {
329            pushdown.push(
330                self.unparser
331                    .support_filter_pushdown(filter, self.conn_options.db_type())?,
332            );
333        }
334        Ok(pushdown)
335    }
336
337    fn statistics(&self) -> Option<Statistics> {
338        let db_type = self.conn_options.db_type();
339        if let Some(count1_query) = db_type.try_count1_query(&self.source) {
340            let conn_options = self.conn_options.clone();
341            let row_count_result = tokio::task::block_in_place(|| {
342                tokio::runtime::Handle::current().block_on(async {
343                    let pool = connect(&conn_options).await?;
344                    let conn = pool.get().await?;
345                    conn_options
346                        .db_type()
347                        .fetch_count(conn, &conn_options, &count1_query)
348                        .await
349                })
350            });
351
352            match row_count_result {
353                Ok(row_count) => {
354                    let column_stat =
355                        Statistics::unknown_column(self.transformed_table_schema.as_ref());
356                    Some(Statistics {
357                        num_rows: Precision::Exact(row_count),
358                        total_byte_size: Precision::Absent,
359                        column_statistics: column_stat,
360                    })
361                }
362                Err(e) => {
363                    warn!("[remote-table] Failed to fetch table statistics: {e}");
364                    None
365                }
366            }
367        } else {
368            debug!(
369                "[remote-table] Query can not be rewritten as count1 query: {}",
370                self.source
371            );
372            None
373        }
374    }
375
376    async fn insert_into(
377        &self,
378        _state: &dyn Session,
379        input: Arc<dyn ExecutionPlan>,
380        insert_op: InsertOp,
381    ) -> DFResult<Arc<dyn ExecutionPlan>> {
382        match insert_op {
383            InsertOp::Append => {}
384            InsertOp::Overwrite | InsertOp::Replace => {
385                return Err(DataFusionError::Execution(
386                    "Only support append insert operation".to_string(),
387                ));
388            }
389        }
390
391        let remote_schema = self
392            .remote_schema
393            .as_ref()
394            .ok_or(DataFusionError::Execution(
395                "Remote schema is not available".to_string(),
396            ))?
397            .clone();
398
399        let RemoteSource::Table(table) = &self.source else {
400            return Err(DataFusionError::Execution(
401                "Only support insert operation for table".to_string(),
402            ));
403        };
404
405        let now = std::time::Instant::now();
406        let conn = self.pool.get().await?;
407        debug!(
408            "[remote-table] Getting connection from pool cost: {}ms",
409            now.elapsed().as_millis()
410        );
411
412        let exec = RemoteTableInsertExec::new(
413            input,
414            self.conn_options.clone(),
415            self.unparser.clone(),
416            table.clone(),
417            remote_schema,
418            conn,
419        );
420        Ok(Arc::new(exec))
421    }
422}
423
424pub(crate) fn rewrite_filters_column(
425    filters: Vec<Expr>,
426    table_schema: &SchemaRef,
427    transformed_table_schema: &SchemaRef,
428) -> DFResult<Vec<Expr>> {
429    filters
430        .into_iter()
431        .map(|f| {
432            f.transform_down(|e| {
433                if let Expr::Column(col) = e {
434                    let col_idx = transformed_table_schema.index_of(col.name())?;
435                    let row_name = table_schema.field(col_idx).name().to_string();
436                    Ok(Transformed::yes(Expr::Column(Column::new_unqualified(
437                        row_name,
438                    ))))
439                } else {
440                    Ok(Transformed::no(e))
441                }
442            })
443            .map(|trans| trans.data)
444        })
445        .collect::<DFResult<Vec<_>>>()
446}