// datafusion_remote_table/table.rs

1use crate::{
2    Connection, ConnectionOptions, DFResult, DefaultLiteralizer, DefaultTransform, Literalize,
3    Pool, RemoteDbType, RemoteSchema, RemoteSchemaRef, RemoteTableInsertExec, RemoteTableScanExec,
4    Transform, TransformArgs, connect, transform_schema,
5};
6use arrow::datatypes::SchemaRef;
7use datafusion_catalog::{Session, TableProvider};
8use datafusion_common::DataFusionError;
9use datafusion_common::Statistics;
10use datafusion_common::stats::Precision;
11use datafusion_expr::TableType;
12use datafusion_expr::dml::InsertOp;
13use datafusion_expr::{Expr, TableProviderFilterPushDown};
14use datafusion_physical_plan::ExecutionPlan;
15use log::debug;
16use std::any::Any;
17use std::sync::Arc;
18use tokio::sync::OnceCell;
19
/// Where a [`RemoteTable`] reads its rows from.
///
/// Either an arbitrary SQL query, or a table addressed by its identifier
/// components (rendered joined by `.` when displayed).
#[derive(Clone, Debug)]
pub enum RemoteSource {
    /// A raw SQL query string.
    Query(String),
    /// A table reference split into its identifier components.
    Table(Vec<String>),
}
25
26impl RemoteSource {
27    pub fn query(&self, db_type: RemoteDbType) -> String {
28        match self {
29            RemoteSource::Query(query) => query.clone(),
30            RemoteSource::Table(table_identifiers) => db_type.select_all_query(table_identifiers),
31        }
32    }
33}
34
35impl std::fmt::Display for RemoteSource {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        match self {
38            RemoteSource::Query(query) => write!(f, "{query}"),
39            RemoteSource::Table(table) => write!(f, "{}", table.join(".")),
40        }
41    }
42}
43
44impl From<String> for RemoteSource {
45    fn from(query: String) -> Self {
46        RemoteSource::Query(query)
47    }
48}
49
50impl From<&String> for RemoteSource {
51    fn from(query: &String) -> Self {
52        RemoteSource::Query(query.clone())
53    }
54}
55
56impl From<&str> for RemoteSource {
57    fn from(query: &str) -> Self {
58        RemoteSource::Query(query.to_string())
59    }
60}
61
62impl From<Vec<String>> for RemoteSource {
63    fn from(table_identifiers: Vec<String>) -> Self {
64        RemoteSource::Table(table_identifiers)
65    }
66}
67
68impl From<Vec<&str>> for RemoteSource {
69    fn from(table_identifiers: Vec<&str>) -> Self {
70        RemoteSource::Table(
71            table_identifiers
72                .into_iter()
73                .map(|s| s.to_string())
74                .collect(),
75        )
76    }
77}
78
79impl From<Vec<&String>> for RemoteSource {
80    fn from(table_identifiers: Vec<&String>) -> Self {
81        RemoteSource::Table(table_identifiers.into_iter().cloned().collect())
82    }
83}
84
85#[derive(Debug)]
86pub struct RemoteTable {
87    pub(crate) conn_options: Arc<ConnectionOptions>,
88    pub(crate) pool: LazyPool,
89    pub(crate) source: RemoteSource,
90    pub(crate) table_schema: SchemaRef,
91    pub(crate) transformed_table_schema: SchemaRef,
92    pub(crate) remote_schema: Option<RemoteSchemaRef>,
93    pub(crate) transform: Arc<dyn Transform>,
94    pub(crate) literalizer: Arc<dyn Literalize>,
95    pub(crate) row_count: Option<usize>,
96}
97
98impl RemoteTable {
99    pub async fn try_new(
100        conn_options: impl Into<ConnectionOptions>,
101        source: impl Into<RemoteSource>,
102    ) -> DFResult<Self> {
103        Self::try_new_with_schema_transform_literalizer(
104            conn_options,
105            source,
106            None,
107            None,
108            Arc::new(DefaultTransform {}),
109            Arc::new(DefaultLiteralizer {}),
110            false,
111        )
112        .await
113    }
114
115    pub async fn try_new_with_schema(
116        conn_options: impl Into<ConnectionOptions>,
117        source: impl Into<RemoteSource>,
118        table_schema: SchemaRef,
119    ) -> DFResult<Self> {
120        Self::try_new_with_schema_transform_literalizer(
121            conn_options,
122            source,
123            Some(table_schema),
124            None,
125            Arc::new(DefaultTransform {}),
126            Arc::new(DefaultLiteralizer {}),
127            false,
128        )
129        .await
130    }
131
132    pub async fn try_new_with_remote_schema(
133        conn_options: impl Into<ConnectionOptions>,
134        source: impl Into<RemoteSource>,
135        remote_schema: RemoteSchemaRef,
136    ) -> DFResult<Self> {
137        Self::try_new_with_schema_transform_literalizer(
138            conn_options,
139            source,
140            None,
141            Some(remote_schema),
142            Arc::new(DefaultTransform {}),
143            Arc::new(DefaultLiteralizer {}),
144            false,
145        )
146        .await
147    }
148
149    pub async fn try_new_with_transform(
150        conn_options: impl Into<ConnectionOptions>,
151        source: impl Into<RemoteSource>,
152        transform: Arc<dyn Transform>,
153    ) -> DFResult<Self> {
154        Self::try_new_with_schema_transform_literalizer(
155            conn_options,
156            source,
157            None,
158            None,
159            transform,
160            Arc::new(DefaultLiteralizer {}),
161            false,
162        )
163        .await
164    }
165
166    pub async fn try_new_with_schema_transform(
167        conn_options: impl Into<ConnectionOptions>,
168        source: impl Into<RemoteSource>,
169        table_schema: SchemaRef,
170        transform: Arc<dyn Transform>,
171    ) -> DFResult<Self> {
172        Self::try_new_with_schema_transform_literalizer(
173            conn_options,
174            source,
175            Some(table_schema),
176            None,
177            transform,
178            Arc::new(DefaultLiteralizer {}),
179            false,
180        )
181        .await
182    }
183
184    pub async fn try_new_with_remote_schema_transform(
185        conn_options: impl Into<ConnectionOptions>,
186        source: impl Into<RemoteSource>,
187        remote_schema: RemoteSchemaRef,
188        transform: Arc<dyn Transform>,
189    ) -> DFResult<Self> {
190        Self::try_new_with_schema_transform_literalizer(
191            conn_options,
192            source,
193            None,
194            Some(remote_schema),
195            transform,
196            Arc::new(DefaultLiteralizer {}),
197            false,
198        )
199        .await
200    }
201
202    pub async fn try_new_with_schema_transform_literalizer(
203        conn_options: impl Into<ConnectionOptions>,
204        source: impl Into<RemoteSource>,
205        table_schema: Option<SchemaRef>,
206        remote_schema: Option<RemoteSchemaRef>,
207        transform: Arc<dyn Transform>,
208        literalizer: Arc<dyn Literalize>,
209        enable_table_statistics: bool,
210    ) -> DFResult<Self> {
211        let conn_options = Arc::new(conn_options.into());
212        let source = source.into();
213
214        if let RemoteSource::Table(table) = &source
215            && table.is_empty()
216        {
217            return Err(DataFusionError::Plan(
218                "Table source is empty vec".to_string(),
219            ));
220        }
221
222        let pool = LazyPool::new(conn_options.clone());
223
224        let infer_schema_fn =
225            async |pool: &LazyPool, source: &RemoteSource| -> DFResult<RemoteSchemaRef> {
226                let now = std::time::Instant::now();
227                let conn = pool.get().await?;
228                let remote_schema = conn.infer_schema(source).await?;
229                debug!(
230                    "[remote-table] Inferring remote schema cost: {}ms",
231                    now.elapsed().as_millis()
232                );
233                Ok(remote_schema)
234            };
235
236        let (table_schema, remote_schema_opt): (SchemaRef, Option<RemoteSchemaRef>) =
237            match (table_schema, remote_schema) {
238                (Some(table_schema), Some(remote_schema)) => (table_schema, Some(remote_schema)),
239                (Some(table_schema), None) => {
240                    let remote_schema = if transform.as_any().is::<DefaultTransform>()
241                        && matches!(source, RemoteSource::Query(_))
242                    {
243                        None
244                    } else {
245                        // Infer remote schema
246                        let remote_schema = infer_schema_fn(&pool, &source).await?;
247                        Some(remote_schema)
248                    };
249                    (table_schema, remote_schema)
250                }
251                (None, Some(remote_schema)) => (
252                    Arc::new(remote_schema.to_arrow_schema()),
253                    Some(remote_schema),
254                ),
255                (None, None) => {
256                    // Infer table schema
257                    let remote_schema = infer_schema_fn(&pool, &source).await?;
258                    let inferred_table_schema = Arc::new(remote_schema.to_arrow_schema());
259                    (inferred_table_schema, Some(remote_schema))
260                }
261            };
262
263        if let Some(remote_schema) = &remote_schema_opt
264            && table_schema.fields.len() != remote_schema.fields.len()
265        {
266            return Err(DataFusionError::Plan(format!(
267                "fields length of table schema is not matched with remote schema. table schema: {table_schema}, remote schema: {remote_schema:?}"
268            )));
269        }
270
271        let transformed_table_schema = transform_schema(
272            transform.as_ref(),
273            table_schema.clone(),
274            remote_schema_opt.as_ref(),
275            conn_options.db_type(),
276        )?;
277
278        let row_count = if enable_table_statistics {
279            fetch_row_count(&pool, &conn_options, &source, &[], None).await?
280        } else {
281            None
282        };
283
284        Ok(RemoteTable {
285            conn_options,
286            pool,
287            source,
288            table_schema,
289            transformed_table_schema,
290            remote_schema: remote_schema_opt,
291            transform,
292            literalizer,
293            row_count,
294        })
295    }
296
297    pub fn remote_schema(&self) -> Option<RemoteSchemaRef> {
298        self.remote_schema.clone()
299    }
300
301    pub async fn pool(&self) -> DFResult<&Arc<dyn Pool>> {
302        self.pool.get_or_init_pool().await
303    }
304}
305
306#[async_trait::async_trait]
307impl TableProvider for RemoteTable {
308    fn as_any(&self) -> &dyn Any {
309        self
310    }
311
312    fn schema(&self) -> SchemaRef {
313        self.transformed_table_schema.clone()
314    }
315
316    fn table_type(&self) -> TableType {
317        TableType::Base
318    }
319
320    async fn scan(
321        &self,
322        _state: &dyn Session,
323        projection: Option<&Vec<usize>>,
324        filters: &[Expr],
325        limit: Option<usize>,
326    ) -> DFResult<Arc<dyn ExecutionPlan>> {
327        let remote_schema = if self.transform.as_any().is::<DefaultTransform>() {
328            Arc::new(RemoteSchema::empty())
329        } else {
330            let Some(remote_schema) = &self.remote_schema else {
331                return Err(DataFusionError::Plan(
332                    "remote schema is none but transform is not DefaultTransform".to_string(),
333                ));
334            };
335            remote_schema.clone()
336        };
337        let mut unparsed_filters = vec![];
338        for filter in filters {
339            let args = TransformArgs {
340                db_type: self.conn_options.db_type(),
341                table_schema: &self.table_schema,
342                remote_schema: &remote_schema,
343            };
344            unparsed_filters.push(self.transform.unparse_filter(filter, args)?);
345        }
346
347        let row_count = fetch_row_count(
348            &self.pool,
349            &self.conn_options,
350            &self.source,
351            &unparsed_filters,
352            None,
353        )
354        .await?;
355
356        Ok(Arc::new(RemoteTableScanExec::try_new(
357            self.conn_options.clone(),
358            self.pool.clone(),
359            self.source.clone(),
360            self.table_schema.clone(),
361            self.remote_schema.clone(),
362            projection.cloned(),
363            unparsed_filters,
364            limit,
365            self.transform.clone(),
366            row_count,
367        )?))
368    }
369
370    fn supports_filters_pushdown(
371        &self,
372        filters: &[&Expr],
373    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
374        let db_type = self.conn_options.db_type();
375        if !db_type.support_rewrite_with_filters_limit(&self.source) {
376            return Ok(vec![
377                TableProviderFilterPushDown::Unsupported;
378                filters.len()
379            ]);
380        }
381
382        let remote_schema = if self.transform.as_any().is::<DefaultTransform>() {
383            Arc::new(RemoteSchema::empty())
384        } else {
385            let Some(remote_schema) = &self.remote_schema else {
386                return Err(DataFusionError::Plan(
387                    "remote schema is none but transform is not DefaultTransform".to_string(),
388                ));
389            };
390            remote_schema.clone()
391        };
392
393        let mut pushdown = vec![];
394        for filter in filters {
395            let args = TransformArgs {
396                db_type: self.conn_options.db_type(),
397                table_schema: &self.table_schema,
398                remote_schema: &remote_schema,
399            };
400            pushdown.push(self.transform.support_filter_pushdown(filter, args)?);
401        }
402        Ok(pushdown)
403    }
404
405    fn statistics(&self) -> Option<Statistics> {
406        self.row_count.map(|count| {
407            let column_stat = Statistics::unknown_column(self.transformed_table_schema.as_ref());
408            Statistics {
409                num_rows: Precision::Exact(count),
410                total_byte_size: Precision::Absent,
411                column_statistics: column_stat,
412            }
413        })
414    }
415
416    async fn insert_into(
417        &self,
418        _state: &dyn Session,
419        input: Arc<dyn ExecutionPlan>,
420        insert_op: InsertOp,
421    ) -> DFResult<Arc<dyn ExecutionPlan>> {
422        match insert_op {
423            InsertOp::Append => {}
424            InsertOp::Overwrite | InsertOp::Replace => {
425                return Err(DataFusionError::Execution(
426                    "Only support append insert operation".to_string(),
427                ));
428            }
429        }
430
431        let remote_schema = self
432            .remote_schema
433            .as_ref()
434            .ok_or(DataFusionError::Execution(
435                "Remote schema is not available".to_string(),
436            ))?
437            .clone();
438
439        let RemoteSource::Table(table) = &self.source else {
440            return Err(DataFusionError::Execution(
441                "Only support insert operation for table".to_string(),
442            ));
443        };
444
445        let exec = RemoteTableInsertExec::new(
446            input,
447            self.conn_options.clone(),
448            self.pool.clone(),
449            self.literalizer.clone(),
450            table.clone(),
451            remote_schema,
452        );
453        Ok(Arc::new(exec))
454    }
455}
456
457#[derive(Debug, Clone)]
458pub struct LazyPool {
459    pub conn_options: Arc<ConnectionOptions>,
460    pub pool: Arc<OnceCell<Arc<dyn Pool>>>,
461}
462
463impl LazyPool {
464    pub fn new(conn_options: Arc<ConnectionOptions>) -> Self {
465        Self {
466            conn_options,
467            pool: Arc::new(OnceCell::new()),
468        }
469    }
470
471    pub async fn get_or_init_pool(&self) -> DFResult<&Arc<dyn Pool>> {
472        self.pool
473            .get_or_try_init(|| async { connect(&self.conn_options).await })
474            .await
475    }
476
477    pub async fn get(&self) -> DFResult<Arc<dyn Connection>> {
478        let pool = self.get_or_init_pool().await?;
479        pool.get().await
480    }
481}
482
483pub(crate) async fn fetch_row_count(
484    pool: &LazyPool,
485    conn_options: &ConnectionOptions,
486    source: &RemoteSource,
487    unparsed_filters: &[String],
488    limit: Option<usize>,
489) -> DFResult<Option<usize>> {
490    let db_type = conn_options.db_type();
491    let count1_query = if unparsed_filters.is_empty() && limit.is_none() {
492        db_type.try_count1_query(source)
493    } else {
494        let real_sql = db_type.rewrite_query(source, unparsed_filters, limit);
495        db_type.try_count1_query(&RemoteSource::Query(real_sql))
496    };
497
498    if let Some(count1_query) = count1_query {
499        debug!("[remote-table] fetching row count with query: {count1_query}");
500        let conn = pool.get().await?;
501        let row_count = db_type
502            .fetch_count(conn, conn_options, &count1_query)
503            .await?;
504        Ok(Some(row_count))
505    } else {
506        Ok(None)
507    }
508}