datafusion_table_providers/sql/sql_provider_datafusion/
mod.rs

1//! # SQL DataFusion TableProvider
2//!
3//! This module implements a SQL TableProvider for DataFusion.
4//!
5//! This is used as a fallback if the `datafusion-federation` optimizer is not enabled.
6
7use crate::sql::db_connection_pool::{
8    self,
9    dbconnection::{get_schema, query_arrow},
10    DbConnectionPool,
11};
12use async_trait::async_trait;
13use datafusion::{
14    catalog::Session,
15    physical_plan::execution_plan::{Boundedness, EmissionType},
16    sql::unparser::dialect::{DefaultDialect, Dialect},
17};
18use futures::TryStreamExt;
19use snafu::prelude::*;
20use std::{any::Any, fmt, sync::Arc};
21use std::{
22    fmt::{Display, Formatter},
23    sync::LazyLock,
24};
25
26use datafusion::{
27    arrow::datatypes::{DataType, Field, Schema, SchemaRef},
28    datasource::TableProvider,
29    error::{DataFusionError, Result as DataFusionResult},
30    execution::TaskContext,
31    logical_expr::{
32        logical_plan::builder::LogicalTableSource, Expr, LogicalPlan, LogicalPlanBuilder,
33        TableProviderFilterPushDown, TableType,
34    },
35    physical_expr::EquivalenceProperties,
36    physical_plan::{
37        stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan,
38        Partitioning, PlanProperties, SendableRecordBatchStream,
39    },
40    sql::{unparser::Unparser, TableReference},
41};
42
/// Expression helpers used by filter pushdown (e.g. detecting subquery expressions).
mod expr;
/// Integration with `datafusion-federation`; compiled only with the `federation` feature.
#[cfg(feature = "federation")]
pub mod federation;
46
/// Errors produced while constructing or using a [`SqlTable`].
#[derive(Debug, Snafu)]
pub enum Error {
    /// A connection could not be acquired from the underlying pool.
    #[snafu(display("Unable to get a DB connection from the pool: {source}"))]
    UnableToGetConnectionFromPool { source: db_connection_pool::Error },

    /// The remote table's Arrow schema could not be retrieved.
    #[snafu(display("Unable to get schema: {source}"))]
    UnableToGetSchema {
        source: db_connection_pool::dbconnection::Error,
    },

    /// SQL generation failed; wraps the underlying [`DataFusionError`].
    #[snafu(display("Unable to generate SQL: {source}"))]
    UnableToGenerateSQL { source: DataFusionError },
}
60
/// Module-local result alias defaulting the error type to this module's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
62
/// A DataFusion `TableProvider` that serves scans by unparsing the logical
/// plan back into SQL text and executing it over a database connection pool.
#[derive(Clone)]
pub struct SqlTable<T: 'static, P: 'static> {
    /// Human-readable name, used in `Debug`/`Display` output.
    name: String,
    /// Pool from which connections are obtained when a scan executes.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    /// Arrow schema of the remote table.
    schema: SchemaRef,
    /// Reference to the remote table that scans are generated against.
    pub table_reference: TableReference,
    /// SQL dialect used for unparsing; `None` falls back to `DefaultDialect`.
    dialect: Option<Arc<dyn Dialect + Send + Sync>>,
}
71
72impl<T, P> fmt::Debug for SqlTable<T, P> {
73    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
74        f.debug_struct("SqlTable")
75            .field("name", &self.name)
76            .field("schema", &self.schema)
77            .field("table_reference", &self.table_reference)
78            .finish()
79    }
80}
81
82impl<T, P> SqlTable<T, P> {
83    pub async fn new(
84        name: &str,
85        pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
86        table_reference: impl Into<TableReference>,
87    ) -> Result<Self> {
88        let table_reference = table_reference.into();
89        let conn = pool
90            .connect()
91            .await
92            .context(UnableToGetConnectionFromPoolSnafu)?;
93
94        let schema = get_schema(conn, &table_reference)
95            .await
96            .context(UnableToGetSchemaSnafu)?;
97
98        Ok(Self::new_with_schema(name, pool, schema, table_reference))
99    }
100
101    pub fn new_with_schema(
102        name: &str,
103        pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
104        schema: impl Into<SchemaRef>,
105        table_reference: impl Into<TableReference>,
106    ) -> Self {
107        Self {
108            name: name.to_owned(),
109            pool: Arc::clone(pool),
110            schema: schema.into(),
111            table_reference: table_reference.into(),
112            dialect: None,
113        }
114    }
115
116    pub fn scan_to_sql(
117        &self,
118        projection: Option<&Vec<usize>>,
119        filters: &[Expr],
120        limit: Option<usize>,
121    ) -> DataFusionResult<String> {
122        let logical_plan = self.create_logical_plan(projection, filters, limit)?;
123        let sql = Unparser::new(self.dialect())
124            .plan_to_sql(&logical_plan)?
125            .to_string();
126
127        Ok(sql)
128    }
129
130    fn create_logical_plan(
131        &self,
132        projection: Option<&Vec<usize>>,
133        filters: &[Expr],
134        limit: Option<usize>,
135    ) -> DataFusionResult<LogicalPlan> {
136        let table_source = LogicalTableSource::new(self.schema());
137        LogicalPlanBuilder::scan_with_filters(
138            self.table_reference.clone(),
139            Arc::new(table_source),
140            projection.cloned(),
141            filters.to_vec(),
142        )?
143        .limit(0, limit)?
144        .build()
145    }
146
147    fn create_physical_plan(
148        &self,
149        projection: Option<&Vec<usize>>,
150        sql: String,
151    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
152        Ok(Arc::new(SqlExec::new(
153            projection,
154            &self.schema(),
155            Arc::clone(&self.pool),
156            sql,
157        )?))
158    }
159
160    #[must_use]
161    pub fn with_dialect(self, dialect: Arc<dyn Dialect + Send + Sync>) -> Self {
162        Self {
163            dialect: Some(dialect),
164            ..self
165        }
166    }
167
168    #[must_use]
169    pub fn name(&self) -> &str {
170        &self.name
171    }
172
173    #[must_use]
174    pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
175        Arc::clone(&self.pool)
176    }
177
178    fn dialect(&self) -> &(dyn Dialect + Send + Sync) {
179        match &self.dialect {
180            Some(dialect) => dialect.as_ref(),
181            None => &DefaultDialect {},
182        }
183    }
184}
185
186#[async_trait]
187impl<T, P> TableProvider for SqlTable<T, P> {
188    fn as_any(&self) -> &dyn Any {
189        self
190    }
191
192    fn schema(&self) -> SchemaRef {
193        Arc::clone(&self.schema)
194    }
195
196    fn table_type(&self) -> TableType {
197        TableType::Base
198    }
199
200    fn supports_filters_pushdown(
201        &self,
202        filters: &[&Expr],
203    ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
204        let filter_push_down = default_filter_pushdown(filters, self.dialect());
205        Ok(filter_push_down)
206    }
207
208    async fn scan(
209        &self,
210        _state: &dyn Session,
211        projection: Option<&Vec<usize>>,
212        filters: &[Expr],
213        limit: Option<usize>,
214    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
215        let sql = self.scan_to_sql(projection, filters, limit)?;
216        return self.create_physical_plan(projection, sql);
217    }
218}
219
220pub fn default_filter_pushdown(
221    filters: &[&Expr],
222    dialect: &dyn Dialect,
223) -> Vec<TableProviderFilterPushDown> {
224    filters
225        .iter()
226        .map(|f| match Unparser::new(dialect).expr_to_sql(f) {
227            // The DataFusion unparser currently does not correctly handle unparsing subquery expressions on TableScan filters.
228            Ok(_) => match expr::expr_contains_subquery(f) {
229                Ok(true) => TableProviderFilterPushDown::Unsupported,
230                Ok(false) => TableProviderFilterPushDown::Exact,
231                Err(_) => TableProviderFilterPushDown::Unsupported,
232            },
233            Err(_) => TableProviderFilterPushDown::Unsupported,
234        })
235        .collect()
236}
237
238impl<T, P> Display for SqlTable<T, P> {
239    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240        write!(f, "SqlTable {}", self.name)
241    }
242}
243
// Schema used for an empty projection: `Some([])` unparses as `SELECT 1`,
// which produces a single nullable Int64 column named "1".
static ONE_COLUMN_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![Field::new("1", DataType::Int64, true)])));
246
247pub fn project_schema_safe(
248    schema: &SchemaRef,
249    projection: Option<&Vec<usize>>,
250) -> DataFusionResult<SchemaRef> {
251    let schema = match projection {
252        Some(columns) => {
253            if columns.is_empty() {
254                // If the projection is Some([]) then it gets unparsed as `SELECT 1`, so return a schema with a single Int64 column.
255                //
256                // See: <https://github.com/apache/datafusion/blob/83ce79c39412a4f150167d00e40ea05948c4870f/datafusion/sql/src/unparser/plan.rs#L998>
257                Arc::clone(&ONE_COLUMN_SCHEMA)
258            } else {
259                Arc::new(schema.project(columns)?)
260            }
261        }
262        None => Arc::clone(schema),
263    };
264    Ok(schema)
265}
266
/// Physical plan node that executes one SQL statement against a connection
/// pool and streams the result as Arrow record batches.
#[derive(Clone)]
pub struct SqlExec<T, P> {
    /// Output schema after projection (see [`project_schema_safe`]).
    projected_schema: SchemaRef,
    /// Pool used to obtain a connection at `execute` time.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    /// SQL text to run.
    sql: String,
    /// Cached plan properties (single unknown partition, bounded, incremental).
    properties: PlanProperties,
}
274
275impl<T, P> SqlExec<T, P> {
276    pub fn new(
277        projection: Option<&Vec<usize>>,
278        schema: &SchemaRef,
279        pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
280        sql: String,
281    ) -> DataFusionResult<Self> {
282        let projected_schema = project_schema_safe(schema, projection)?;
283
284        Ok(Self {
285            projected_schema: Arc::clone(&projected_schema),
286            pool,
287            sql,
288            properties: PlanProperties::new(
289                EquivalenceProperties::new(projected_schema),
290                Partitioning::UnknownPartitioning(1),
291                EmissionType::Incremental,
292                Boundedness::Bounded,
293            ),
294        })
295    }
296
297    #[must_use]
298    pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
299        Arc::clone(&self.pool)
300    }
301
302    pub fn sql(&self) -> Result<String> {
303        Ok(self.sql.clone())
304    }
305}
306
307impl<T, P> std::fmt::Debug for SqlExec<T, P> {
308    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
309        let sql = self.sql().unwrap_or_default();
310        write!(f, "SqlExec sql={sql}")
311    }
312}
313
314impl<T, P> DisplayAs for SqlExec<T, P> {
315    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
316        let sql = self.sql().unwrap_or_default();
317        write!(f, "SqlExec sql={sql}")
318    }
319}
320
321impl<T: 'static, P: 'static> ExecutionPlan for SqlExec<T, P> {
322    fn name(&self) -> &'static str {
323        "SqlExec"
324    }
325
326    fn as_any(&self) -> &dyn Any {
327        self
328    }
329
330    fn schema(&self) -> SchemaRef {
331        Arc::clone(&self.projected_schema)
332    }
333
334    fn properties(&self) -> &PlanProperties {
335        &self.properties
336    }
337
338    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
339        vec![]
340    }
341
342    fn with_new_children(
343        self: Arc<Self>,
344        _children: Vec<Arc<dyn ExecutionPlan>>,
345    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
346        Ok(self)
347    }
348
349    fn execute(
350        &self,
351        _partition: usize,
352        _context: Arc<TaskContext>,
353    ) -> DataFusionResult<SendableRecordBatchStream> {
354        let sql = self.sql().map_err(to_execution_error)?;
355        tracing::debug!("SqlExec sql: {sql}");
356
357        let schema = self.schema();
358
359        let fut = get_stream(Arc::clone(&self.pool), sql, Arc::clone(&schema));
360
361        let stream = futures::stream::once(fut).try_flatten();
362        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
363    }
364}
365
366pub async fn get_stream<T: 'static, P: 'static>(
367    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
368    sql: String,
369    projected_schema: SchemaRef,
370) -> DataFusionResult<SendableRecordBatchStream> {
371    let conn = pool.connect().await.map_err(to_execution_error)?;
372
373    query_arrow(conn, sql, Some(projected_schema))
374        .await
375        .map_err(to_execution_error)
376}
377
378#[allow(clippy::needless_pass_by_value)]
379pub fn to_execution_error(
380    e: impl Into<Box<dyn std::error::Error + Send + Sync>>,
381) -> DataFusionError {
382    DataFusionError::Execution(format!("{}", e.into()).to_string())
383}
384
#[cfg(test)]
mod tests {
    use std::{error::Error, sync::Arc};

    use datafusion::execution::context::SessionContext;
    use datafusion::sql::TableReference;
    use tracing::{level_filters::LevelFilter, subscriber::DefaultGuard, Dispatch};

    use crate::sql::sql_provider_datafusion::SqlTable;

    // Installs a DEBUG-level tracing subscriber for the current scope; the
    // returned guard restores the previous dispatcher when dropped.
    fn setup_tracing() -> DefaultGuard {
        let subscriber: tracing_subscriber::FmtSubscriber = tracing_subscriber::fmt()
            .with_max_level(LevelFilter::DEBUG)
            .finish();

        let dispatch = Dispatch::new(subscriber);
        tracing::dispatcher::set_default(&dispatch)
    }

    // Tests for `SqlTable::scan_to_sql` using mock pool/connection types, so
    // no real database is needed to exercise plan-to-SQL unparsing.
    mod sql_table_plan_to_sql_tests {
        use std::any::Any;

        use async_trait::async_trait;
        use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
        use datafusion::sql::unparser::dialect::{Dialect, SqliteDialect};
        use datafusion::{
            logical_expr::{col, lit},
            sql::TableReference,
        };

        use crate::sql::db_connection_pool::{
            dbconnection::DbConnection, DbConnectionPool, JoinPushDown,
        };

        use super::*;

        // Minimal connection stub; only the downcast hooks are implemented.
        struct MockConn {}

        impl DbConnection<(), &'static dyn ToString> for MockConn {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn as_any_mut(&mut self) -> &mut dyn Any {
                self
            }
        }

        // Pool stub that always hands out a `MockConn`.
        struct MockDBPool {}

        #[async_trait]
        impl DbConnectionPool<(), &'static dyn ToString> for MockDBPool {
            async fn connect(
                &self,
            ) -> Result<
                Box<dyn DbConnection<(), &'static dyn ToString>>,
                Box<dyn Error + Send + Sync>,
            > {
                Ok(Box::new(MockConn {}))
            }

            fn join_push_down(&self) -> JoinPushDown {
                JoinPushDown::Disallow
            }
        }

        // Builds a `SqlTable` over a fixed six-column schema backed by the
        // mock pool, optionally configured with `dialect`.
        fn new_sql_table(
            table_reference: &'static str,
            dialect: Option<Arc<dyn Dialect + Send + Sync>>,
        ) -> Result<SqlTable<(), &'static dyn ToString>, Box<dyn Error + Send + Sync>> {
            let fields = vec![
                Field::new("name", DataType::Utf8, false),
                Field::new("age", DataType::Int16, false),
                Field::new(
                    "createdDate",
                    DataType::Timestamp(TimeUnit::Millisecond, None),
                    false,
                ),
                Field::new("userId", DataType::LargeUtf8, false),
                Field::new("active", DataType::Boolean, false),
                Field::new("5e48", DataType::LargeUtf8, false),
            ];
            let schema = Arc::new(Schema::new(fields));
            let pool = Arc::new(MockDBPool {})
                as Arc<dyn DbConnectionPool<(), &'static dyn ToString> + Send + Sync>;
            let table_ref = TableReference::parse_str(table_reference);

            let sql_table = SqlTable::new_with_schema(table_reference, &pool, schema, table_ref);
            if let Some(dialect) = dialect {
                Ok(sql_table.with_dialect(dialect))
            } else {
                Ok(sql_table)
            }
        }

        // Projection of a single column unparses with SQLite backtick quoting.
        #[tokio::test]
        async fn test_sql_to_string() -> Result<(), Box<dyn Error + Send + Sync>> {
            let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
            let result = sql_table.scan_to_sql(Some(&vec![0]), &[], None)?;
            assert_eq!(result, r#"SELECT `users`.`name` FROM `users`"#);
            Ok(())
        }

        // Filters become a WHERE clause and the limit a LIMIT clause.
        #[tokio::test]
        async fn test_sql_to_string_with_filters_and_limit(
        ) -> Result<(), Box<dyn Error + Send + Sync>> {
            let filters = vec![col("age").gt_eq(lit(30)).and(col("name").eq(lit("x")))];
            let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
            let result = sql_table.scan_to_sql(Some(&vec![0, 1]), &filters, Some(3))?;
            assert_eq!(
                result,
                r#"SELECT `users`.`name`, `users`.`age` FROM `users` WHERE ((`users`.`age` >= 30) AND (`users`.`name` = 'x')) LIMIT 3"#
            );
            Ok(())
        }
    }

    // A bare table reference displays as just its table name.
    #[test]
    fn test_references() {
        let table_ref = TableReference::bare("test");
        assert_eq!(format!("{table_ref}"), "test");
    }

    // TODO(review): consider relocating these tests into the duckdb module.
    #[cfg(feature = "duckdb")]
    mod duckdb_tests {
        use super::*;
        use crate::sql::db_connection_pool::dbconnection::duckdbconn::{
            DuckDBSyncParameter, DuckDbConnection,
        };
        use crate::sql::db_connection_pool::{duckdbpool::DuckDbConnectionPool, DbConnectionPool};
        use duckdb::DuckdbConnectionManager;

        // End-to-end: register a DuckDB-backed SqlTable and run a LIMIT query.
        #[tokio::test]
        async fn test_duckdb_table() -> Result<(), Box<dyn Error + Send + Sync>> {
            let t = setup_tracing();
            let ctx = SessionContext::new();
            let pool: Arc<
                dyn DbConnectionPool<
                        r2d2::PooledConnection<DuckdbConnectionManager>,
                        Box<dyn DuckDBSyncParameter>,
                    > + Send
                    + Sync,
            > = Arc::new(DuckDbConnectionPool::new_memory()?)
                as Arc<
                    dyn DbConnectionPool<
                            r2d2::PooledConnection<DuckdbConnectionManager>,
                            Box<dyn DuckDBSyncParameter>,
                        > + Send
                        + Sync,
                >;
            let conn = pool.connect().await?;
            let db_conn = conn
                .as_any()
                .downcast_ref::<DuckDbConnection>()
                .expect("Unable to downcast to DuckDbConnection");
            db_conn.conn.execute_batch(
                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
            )?;
            let duckdb_table = SqlTable::new("duckdb", &pool, "test").await?;
            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
            let sql = "SELECT * FROM test_datafusion limit 1";
            let df = ctx.sql(sql).await?;
            df.show().await?;
            drop(t);
            Ok(())
        }

        // End-to-end: same as above but with WHERE filters pushed down.
        #[tokio::test]
        async fn test_duckdb_table_filter() -> Result<(), Box<dyn Error + Send + Sync>> {
            let t = setup_tracing();
            let ctx = SessionContext::new();
            let pool: Arc<
                dyn DbConnectionPool<
                        r2d2::PooledConnection<DuckdbConnectionManager>,
                        Box<dyn DuckDBSyncParameter>,
                    > + Send
                    + Sync,
            > = Arc::new(DuckDbConnectionPool::new_memory()?)
                as Arc<
                    dyn DbConnectionPool<
                            r2d2::PooledConnection<DuckdbConnectionManager>,
                            Box<dyn DuckDBSyncParameter>,
                        > + Send
                        + Sync,
                >;
            let conn = pool.connect().await?;
            let db_conn = conn
                .as_any()
                .downcast_ref::<DuckDbConnection>()
                .expect("Unable to downcast to DuckDbConnection");
            db_conn.conn.execute_batch(
                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
            )?;
            let duckdb_table = SqlTable::new("duckdb", &pool, "test").await?;
            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
            let sql = "SELECT * FROM test_datafusion where a > 1 and b = 'bar' limit 1";
            let df = ctx.sql(sql).await?;
            df.show().await?;
            drop(t);
            Ok(())
        }
    }
}