datafusion_remote_table/
table.rs

1use crate::{
2    ConnectionOptions, DFResult, DefaultLiteralizer, DefaultTransform, Literalize, Pool,
3    RemoteDbType, RemoteSchema, RemoteSchemaRef, RemoteTableInsertExec, RemoteTableScanExec,
4    Transform, TransformArgs, connect, transform_schema,
5};
6use datafusion::arrow::datatypes::SchemaRef;
7use datafusion::catalog::{Session, TableProvider};
8use datafusion::common::Statistics;
9use datafusion::common::stats::Precision;
10use datafusion::datasource::TableType;
11use datafusion::error::DataFusionError;
12use datafusion::logical_expr::dml::InsertOp;
13use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
14use datafusion::physical_plan::ExecutionPlan;
15use log::{debug, warn};
16use std::any::Any;
17use std::sync::Arc;
18
/// Describes what a [`RemoteTable`] reads from on the remote database.
#[derive(Clone, Debug)]
pub enum RemoteSource {
    /// A SQL query; its text is returned unchanged by [`RemoteSource::query`].
    Query(String),
    /// The identifier parts of a table. They are joined with `.` when displayed and are
    /// handed to the database type to build a select-all query.
    Table(Vec<String>),
}
24
25impl RemoteSource {
26    pub fn query(&self, db_type: RemoteDbType) -> String {
27        match self {
28            RemoteSource::Query(query) => query.clone(),
29            RemoteSource::Table(table_identifiers) => db_type.select_all_query(table_identifiers),
30        }
31    }
32}
33
34impl std::fmt::Display for RemoteSource {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            RemoteSource::Query(query) => write!(f, "{query}"),
38            RemoteSource::Table(table) => write!(f, "{}", table.join(".")),
39        }
40    }
41}
42
43impl From<String> for RemoteSource {
44    fn from(query: String) -> Self {
45        RemoteSource::Query(query)
46    }
47}
48
49impl From<&String> for RemoteSource {
50    fn from(query: &String) -> Self {
51        RemoteSource::Query(query.clone())
52    }
53}
54
55impl From<&str> for RemoteSource {
56    fn from(query: &str) -> Self {
57        RemoteSource::Query(query.to_string())
58    }
59}
60
61impl From<Vec<String>> for RemoteSource {
62    fn from(table_identifiers: Vec<String>) -> Self {
63        RemoteSource::Table(table_identifiers)
64    }
65}
66
67impl From<Vec<&str>> for RemoteSource {
68    fn from(table_identifiers: Vec<&str>) -> Self {
69        RemoteSource::Table(
70            table_identifiers
71                .into_iter()
72                .map(|s| s.to_string())
73                .collect(),
74        )
75    }
76}
77
78impl From<Vec<&String>> for RemoteSource {
79    fn from(table_identifiers: Vec<&String>) -> Self {
80        RemoteSource::Table(table_identifiers.into_iter().cloned().collect())
81    }
82}
83
84#[derive(Debug)]
85pub struct RemoteTable {
86    pub(crate) conn_options: Arc<ConnectionOptions>,
87    pub(crate) source: RemoteSource,
88    pub(crate) table_schema: SchemaRef,
89    pub(crate) transformed_table_schema: SchemaRef,
90    pub(crate) remote_schema: Option<RemoteSchemaRef>,
91    pub(crate) transform: Arc<dyn Transform>,
92    pub(crate) literalizer: Arc<dyn Literalize>,
93    pub(crate) pool: Arc<dyn Pool>,
94}
95
96impl RemoteTable {
97    pub async fn try_new(
98        conn_options: impl Into<ConnectionOptions>,
99        source: impl Into<RemoteSource>,
100    ) -> DFResult<Self> {
101        Self::try_new_with_schema_transform_literalizer(
102            conn_options,
103            source,
104            None,
105            None,
106            Arc::new(DefaultTransform {}),
107            Arc::new(DefaultLiteralizer {}),
108        )
109        .await
110    }
111
112    pub async fn try_new_with_schema(
113        conn_options: impl Into<ConnectionOptions>,
114        source: impl Into<RemoteSource>,
115        table_schema: SchemaRef,
116    ) -> DFResult<Self> {
117        Self::try_new_with_schema_transform_literalizer(
118            conn_options,
119            source,
120            Some(table_schema),
121            None,
122            Arc::new(DefaultTransform {}),
123            Arc::new(DefaultLiteralizer {}),
124        )
125        .await
126    }
127
128    pub async fn try_new_with_remote_schema(
129        conn_options: impl Into<ConnectionOptions>,
130        source: impl Into<RemoteSource>,
131        remote_schema: RemoteSchemaRef,
132    ) -> DFResult<Self> {
133        Self::try_new_with_schema_transform_literalizer(
134            conn_options,
135            source,
136            None,
137            Some(remote_schema),
138            Arc::new(DefaultTransform {}),
139            Arc::new(DefaultLiteralizer {}),
140        )
141        .await
142    }
143
144    pub async fn try_new_with_transform(
145        conn_options: impl Into<ConnectionOptions>,
146        source: impl Into<RemoteSource>,
147        transform: Arc<dyn Transform>,
148    ) -> DFResult<Self> {
149        Self::try_new_with_schema_transform_literalizer(
150            conn_options,
151            source,
152            None,
153            None,
154            transform,
155            Arc::new(DefaultLiteralizer {}),
156        )
157        .await
158    }
159
160    pub async fn try_new_with_schema_transform_literalizer(
161        conn_options: impl Into<ConnectionOptions>,
162        source: impl Into<RemoteSource>,
163        table_schema: Option<SchemaRef>,
164        remote_schema: Option<RemoteSchemaRef>,
165        transform: Arc<dyn Transform>,
166        literalizer: Arc<dyn Literalize>,
167    ) -> DFResult<Self> {
168        let conn_options = conn_options.into();
169        let source = source.into();
170
171        if let RemoteSource::Table(table) = &source
172            && table.is_empty()
173        {
174            return Err(DataFusionError::Plan(
175                "Table source is empty vec".to_string(),
176            ));
177        }
178
179        let now = std::time::Instant::now();
180        let pool = connect(&conn_options).await?;
181        debug!(
182            "[remote-table] Creating connection pool cost: {}ms",
183            now.elapsed().as_millis()
184        );
185
186        let infer_schema_fn =
187            async |pool: &Arc<dyn Pool>, source: &RemoteSource| -> DFResult<RemoteSchemaRef> {
188                let now = std::time::Instant::now();
189                let conn = pool.get().await?;
190                let remote_schema = conn.infer_schema(source).await?;
191                debug!(
192                    "[remote-table] Inferring remote schema cost: {}ms",
193                    now.elapsed().as_millis()
194                );
195                Ok(remote_schema)
196            };
197
198        let (table_schema, remote_schema_opt): (SchemaRef, Option<RemoteSchemaRef>) =
199            match (table_schema, remote_schema) {
200                (Some(table_schema), Some(remote_schema)) => (table_schema, Some(remote_schema)),
201                (Some(table_schema), None) => {
202                    let remote_schema = if transform.as_any().is::<DefaultTransform>()
203                        && matches!(source, RemoteSource::Query(_))
204                    {
205                        None
206                    } else {
207                        // Infer remote schema
208                        let remote_schema = infer_schema_fn(&pool, &source).await?;
209                        Some(remote_schema)
210                    };
211                    (table_schema, remote_schema)
212                }
213                (None, Some(remote_schema)) => (
214                    Arc::new(remote_schema.to_arrow_schema()),
215                    Some(remote_schema),
216                ),
217                (None, None) => {
218                    // Infer table schema
219                    let remote_schema = infer_schema_fn(&pool, &source).await?;
220                    let inferred_table_schema = Arc::new(remote_schema.to_arrow_schema());
221                    (inferred_table_schema, Some(remote_schema))
222                }
223            };
224
225        let transformed_table_schema = transform_schema(
226            transform.as_ref(),
227            table_schema.clone(),
228            remote_schema_opt.as_ref(),
229            conn_options.db_type(),
230        )?;
231
232        Ok(RemoteTable {
233            conn_options: Arc::new(conn_options),
234            source,
235            table_schema,
236            transformed_table_schema,
237            remote_schema: remote_schema_opt,
238            transform,
239            literalizer,
240            pool,
241        })
242    }
243
244    pub fn remote_schema(&self) -> Option<RemoteSchemaRef> {
245        self.remote_schema.clone()
246    }
247}
248
249#[async_trait::async_trait]
250impl TableProvider for RemoteTable {
251    fn as_any(&self) -> &dyn Any {
252        self
253    }
254
255    fn schema(&self) -> SchemaRef {
256        self.transformed_table_schema.clone()
257    }
258
259    fn table_type(&self) -> TableType {
260        TableType::Base
261    }
262
263    async fn scan(
264        &self,
265        _state: &dyn Session,
266        projection: Option<&Vec<usize>>,
267        filters: &[Expr],
268        limit: Option<usize>,
269    ) -> DFResult<Arc<dyn ExecutionPlan>> {
270        let remote_schema = if self.transform.as_any().is::<DefaultTransform>() {
271            Arc::new(RemoteSchema::empty())
272        } else {
273            let Some(remote_schema) = &self.remote_schema else {
274                return Err(DataFusionError::Plan(
275                    "remote schema is none but transform is not DefaultTransform".to_string(),
276                ));
277            };
278            remote_schema.clone()
279        };
280        let mut unparsed_filters = vec![];
281        for filter in filters {
282            let args = TransformArgs {
283                db_type: self.conn_options.db_type(),
284                table_schema: &self.table_schema,
285                remote_schema: &remote_schema,
286            };
287            unparsed_filters.push(self.transform.unparse_filter(filter, args)?);
288        }
289
290        let now = std::time::Instant::now();
291        let conn = self.pool.get().await?;
292        debug!(
293            "[remote-table] Getting connection from pool cost: {}ms",
294            now.elapsed().as_millis()
295        );
296
297        Ok(Arc::new(RemoteTableScanExec::try_new(
298            self.conn_options.clone(),
299            self.source.clone(),
300            self.table_schema.clone(),
301            self.remote_schema.clone(),
302            projection.cloned(),
303            unparsed_filters,
304            limit,
305            self.transform.clone(),
306            conn,
307        )?))
308    }
309
310    fn supports_filters_pushdown(
311        &self,
312        filters: &[&Expr],
313    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
314        let db_type = self.conn_options.db_type();
315        if !db_type.support_rewrite_with_filters_limit(&self.source) {
316            return Ok(vec![
317                TableProviderFilterPushDown::Unsupported;
318                filters.len()
319            ]);
320        }
321
322        let remote_schema = if self.transform.as_any().is::<DefaultTransform>() {
323            Arc::new(RemoteSchema::empty())
324        } else {
325            let Some(remote_schema) = &self.remote_schema else {
326                return Err(DataFusionError::Plan(
327                    "remote schema is none but transform is not DefaultTransform".to_string(),
328                ));
329            };
330            remote_schema.clone()
331        };
332
333        let mut pushdown = vec![];
334        for filter in filters {
335            let args = TransformArgs {
336                db_type: self.conn_options.db_type(),
337                table_schema: &self.table_schema,
338                remote_schema: &remote_schema,
339            };
340            pushdown.push(self.transform.support_filter_pushdown(filter, args)?);
341        }
342        Ok(pushdown)
343    }
344
345    fn statistics(&self) -> Option<Statistics> {
346        let db_type = self.conn_options.db_type();
347        if let Some(count1_query) = db_type.try_count1_query(&self.source) {
348            let conn_options = self.conn_options.clone();
349            let row_count_result = tokio::task::block_in_place(|| {
350                tokio::runtime::Handle::current().block_on(async {
351                    let pool = connect(&conn_options).await?;
352                    let conn = pool.get().await?;
353                    conn_options
354                        .db_type()
355                        .fetch_count(conn, &conn_options, &count1_query)
356                        .await
357                })
358            });
359
360            match row_count_result {
361                Ok(row_count) => {
362                    let column_stat =
363                        Statistics::unknown_column(self.transformed_table_schema.as_ref());
364                    Some(Statistics {
365                        num_rows: Precision::Exact(row_count),
366                        total_byte_size: Precision::Absent,
367                        column_statistics: column_stat,
368                    })
369                }
370                Err(e) => {
371                    warn!("[remote-table] Failed to fetch table statistics: {e}");
372                    None
373                }
374            }
375        } else {
376            debug!(
377                "[remote-table] Query can not be rewritten as count1 query: {}",
378                self.source
379            );
380            None
381        }
382    }
383
384    async fn insert_into(
385        &self,
386        _state: &dyn Session,
387        input: Arc<dyn ExecutionPlan>,
388        insert_op: InsertOp,
389    ) -> DFResult<Arc<dyn ExecutionPlan>> {
390        match insert_op {
391            InsertOp::Append => {}
392            InsertOp::Overwrite | InsertOp::Replace => {
393                return Err(DataFusionError::Execution(
394                    "Only support append insert operation".to_string(),
395                ));
396            }
397        }
398
399        let remote_schema = self
400            .remote_schema
401            .as_ref()
402            .ok_or(DataFusionError::Execution(
403                "Remote schema is not available".to_string(),
404            ))?
405            .clone();
406
407        let RemoteSource::Table(table) = &self.source else {
408            return Err(DataFusionError::Execution(
409                "Only support insert operation for table".to_string(),
410            ));
411        };
412
413        let now = std::time::Instant::now();
414        let conn = self.pool.get().await?;
415        debug!(
416            "[remote-table] Getting connection from pool cost: {}ms",
417            now.elapsed().as_millis()
418        );
419
420        let exec = RemoteTableInsertExec::new(
421            input,
422            self.conn_options.clone(),
423            self.literalizer.clone(),
424            table.clone(),
425            remote_schema,
426            conn,
427        );
428        Ok(Arc::new(exec))
429    }
430}