datafusion_table_providers/sql/sql_provider_datafusion/mod.rs

//! # SQL DataFusion TableProvider
//!
//! This module implements a SQL TableProvider for DataFusion.
//!
//! This is used as a fallback if the `datafusion-federation` optimizer is not enabled.
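//!
//! ## Example (sketch)
//!
//! A minimal usage sketch modeled on the DuckDB test at the bottom of this file; it
//! assumes the `duckdb` feature and that `pool` has already been coerced to
//! `Arc<dyn DbConnectionPool<T, P> + Send + Sync>` as in that test:
//!
//! ```rust,ignore
//! use std::sync::Arc;
//! use datafusion::execution::context::SessionContext;
//!
//! // `pool` is an Arc<dyn DbConnectionPool<T, P> + Send + Sync>, e.g. built from
//! // DuckDbConnectionPool::new_memory() as in the duckdb tests below.
//! let table = SqlTable::new("duckdb", &pool, "test").await?;
//!
//! let ctx = SessionContext::new();
//! ctx.register_table("test_datafusion", Arc::new(table))?;
//! ctx.sql("SELECT * FROM test_datafusion LIMIT 1").await?.show().await?;
//! ```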

use crate::sql::db_connection_pool::{
    self,
    dbconnection::{get_schema, query_arrow},
    DbConnectionPool,
};
use async_trait::async_trait;
use datafusion::{
    catalog::Session,
    physical_plan::execution_plan::{Boundedness, EmissionType},
    sql::unparser::dialect::{DefaultDialect, Dialect},
};
use futures::TryStreamExt;
use snafu::prelude::*;
use std::{any::Any, fmt, sync::Arc};
use std::{
    fmt::{Display, Formatter},
    sync::LazyLock,
};

use datafusion::{
    arrow::datatypes::{DataType, Field, Schema, SchemaRef},
    datasource::TableProvider,
    error::{DataFusionError, Result as DataFusionResult},
    execution::TaskContext,
    logical_expr::{
        logical_plan::builder::LogicalTableSource, Expr, LogicalPlan, LogicalPlanBuilder,
        TableProviderFilterPushDown, TableType,
    },
    physical_expr::EquivalenceProperties,
    physical_plan::{
        stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan,
        Partitioning, PlanProperties, SendableRecordBatchStream,
    },
    sql::{unparser::Unparser, TableReference},
};

mod expr;
#[cfg(feature = "federation")]
pub mod federation;

#[derive(Debug, Snafu)]
pub enum Error {
    #[snafu(display("Unable to get a DB connection from the pool: {source}"))]
    UnableToGetConnectionFromPool { source: db_connection_pool::Error },

    #[snafu(display("Unable to get schema: {source}"))]
    UnableToGetSchema {
        source: db_connection_pool::dbconnection::Error,
    },

    #[snafu(display("Unable to generate SQL: {source}"))]
    UnableToGenerateSQL { source: DataFusionError },
}

pub type Result<T, E = Error> = std::result::Result<T, E>;

#[derive(Clone)]
pub struct SqlTable<T: 'static, P: 'static> {
    name: String,
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    schema: SchemaRef,
    pub table_reference: TableReference,
    dialect: Option<Arc<dyn Dialect + Send + Sync>>,
}

impl<T, P> fmt::Debug for SqlTable<T, P> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("SqlTable")
            .field("name", &self.name)
            .field("schema", &self.schema)
            .field("table_reference", &self.table_reference)
            .finish()
    }
}

impl<T, P> SqlTable<T, P> {
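    /// Creates a `SqlTable` by acquiring a connection from the pool and fetching the
    /// schema of `table_reference` from the remote database.
    ///
    /// Use [`Self::new_with_schema`] to skip the schema round trip when the schema is
    /// already known.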
    pub async fn new(
        name: &str,
        pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
        table_reference: impl Into<TableReference>,
    ) -> Result<Self> {
        let table_reference = table_reference.into();
        let conn = pool
            .connect()
            .await
            .context(UnableToGetConnectionFromPoolSnafu)?;

        let schema = get_schema(conn, &table_reference)
            .await
            .context(UnableToGetSchemaSnafu)?;

        Ok(Self::new_with_schema(name, pool, schema, table_reference))
    }

    pub fn new_with_schema(
        name: &str,
        pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
        schema: impl Into<SchemaRef>,
        table_reference: impl Into<TableReference>,
    ) -> Self {
        Self {
            name: name.to_owned(),
            pool: Arc::clone(pool),
            schema: schema.into(),
            table_reference: table_reference.into(),
            dialect: None,
        }
    }

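    /// Unparses a scan over this table (projection, filters, limit) into a SQL string
    /// using the configured dialect, falling back to [`DefaultDialect`].
    ///
    /// A small sketch of the expected shape, mirroring the Sqlite-dialect unit test
    /// below (the exact identifier quoting depends on the dialect):
    ///
    /// ```rust,ignore
    /// // Project column 0 only, no filters, no limit:
    /// let sql = sql_table.scan_to_sql(Some(&vec![0]), &[], None)?;
    /// // e.g. with SqliteDialect: SELECT `users`.`name` FROM `users`
    /// ```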
    pub fn scan_to_sql(
        &self,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DataFusionResult<String> {
        let logical_plan = self.create_logical_plan(projection, filters, limit)?;
        let sql = Unparser::new(self.dialect())
            .plan_to_sql(&logical_plan)?
            .to_string();

        Ok(sql)
    }

    fn create_logical_plan(
        &self,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DataFusionResult<LogicalPlan> {
        let table_source = LogicalTableSource::new(self.schema());
        LogicalPlanBuilder::scan_with_filters(
            self.table_reference.clone(),
            Arc::new(table_source),
            projection.cloned(),
            filters.to_vec(),
        )?
        .limit(0, limit)?
        .build()
    }

    fn create_physical_plan(
        &self,
        projection: Option<&Vec<usize>>,
        sql: String,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(SqlExec::new(
            projection,
            &self.schema(),
            Arc::clone(&self.pool),
            sql,
        )?))
    }

    #[must_use]
    pub fn with_dialect(self, dialect: Arc<dyn Dialect + Send + Sync>) -> Self {
        Self {
            dialect: Some(dialect),
            ..self
        }
    }

    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    #[must_use]
    pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
        Arc::clone(&self.pool)
    }

    fn dialect(&self) -> &(dyn Dialect + Send + Sync) {
        match &self.dialect {
            Some(dialect) => dialect.as_ref(),
            None => &DefaultDialect {},
        }
    }
}

#[async_trait]
impl<T, P> TableProvider for SqlTable<T, P> {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
        let filter_push_down: Vec<TableProviderFilterPushDown> = filters
            .iter()
            .map(|f| match Unparser::new(self.dialect()).expr_to_sql(f) {
                // The DataFusion unparser currently does not correctly handle unparsing subquery expressions on TableScan filters.
                Ok(_) => match expr::expr_contains_subquery(f) {
                    Ok(true) => TableProviderFilterPushDown::Unsupported,
                    Ok(false) => TableProviderFilterPushDown::Exact,
                    Err(_) => TableProviderFilterPushDown::Unsupported,
                },
                Err(_) => TableProviderFilterPushDown::Unsupported,
            })
            .collect();

        Ok(filter_push_down)
    }

    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let sql = self.scan_to_sql(projection, filters, limit)?;
        self.create_physical_plan(projection, sql)
    }
}

impl<T, P> Display for SqlTable<T, P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "SqlTable {}", self.name)
    }
}

static ONE_COLUMN_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![Field::new("1", DataType::Int64, true)])));

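/// Projects `schema` onto the given column indices, with one special case: an empty
/// projection (`Some(&vec![])`) yields the single nullable `Int64` column `"1"`,
/// matching the `SELECT 1` that such a plan unparses to (see the comment inside).
///
/// A sketch of the two interesting cases:
///
/// ```rust,ignore
/// // Normal projection: keep columns 0 and 2 of `schema`.
/// let projected = project_schema_safe(&schema, Some(&vec![0, 2]))?;
/// // Empty projection: the one-column `SELECT 1` schema.
/// let one = project_schema_safe(&schema, Some(&vec![]))?;
/// assert_eq!(one.fields().len(), 1);
/// ```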
pub fn project_schema_safe(
    schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DataFusionResult<SchemaRef> {
    let schema = match projection {
        Some(columns) => {
            if columns.is_empty() {
                // If the projection is Some([]) then it gets unparsed as `SELECT 1`, so return a schema with a single Int64 column.
                //
                // See: <https://github.com/apache/datafusion/blob/83ce79c39412a4f150167d00e40ea05948c4870f/datafusion/sql/src/unparser/plan.rs#L998>
                Arc::clone(&ONE_COLUMN_SCHEMA)
            } else {
                Arc::new(schema.project(columns)?)
            }
        }
        None => Arc::clone(schema),
    };
    Ok(schema)
}

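/// [`ExecutionPlan`] that executes a single SQL string against a connection obtained
/// from the pool and streams the results back as Arrow record batches.
///
/// The plan reports one unknown partition, incremental emission, and bounded input;
/// it has no children, so `with_new_children` returns `self` unchanged.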
#[derive(Clone)]
pub struct SqlExec<T, P> {
    projected_schema: SchemaRef,
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    sql: String,
    properties: PlanProperties,
}

impl<T, P> SqlExec<T, P> {
    pub fn new(
        projection: Option<&Vec<usize>>,
        schema: &SchemaRef,
        pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
        sql: String,
    ) -> DataFusionResult<Self> {
        let projected_schema = project_schema_safe(schema, projection)?;

        Ok(Self {
            projected_schema: Arc::clone(&projected_schema),
            pool,
            sql,
            properties: PlanProperties::new(
                EquivalenceProperties::new(projected_schema),
                Partitioning::UnknownPartitioning(1),
                EmissionType::Incremental,
                Boundedness::Bounded,
            ),
        })
    }

    #[must_use]
    pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
        Arc::clone(&self.pool)
    }

    pub fn sql(&self) -> Result<String> {
        Ok(self.sql.clone())
    }
}

impl<T, P> std::fmt::Debug for SqlExec<T, P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let sql = self.sql().unwrap_or_default();
        write!(f, "SqlExec sql={sql}")
    }
}

impl<T, P> DisplayAs for SqlExec<T, P> {
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        let sql = self.sql().unwrap_or_default();
        write!(f, "SqlExec sql={sql}")
    }
}

impl<T: 'static, P: 'static> ExecutionPlan for SqlExec<T, P> {
    fn name(&self) -> &'static str {
        "SqlExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.projected_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }

    fn execute(
        &self,
        _partition: usize,
        _context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let sql = self.sql().map_err(to_execution_error)?;
        tracing::debug!("SqlExec sql: {sql}");

        let schema = self.schema();

        let fut = get_stream(Arc::clone(&self.pool), sql, Arc::clone(&schema));

        let stream = futures::stream::once(fut).try_flatten();
        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
    }
}

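/// Acquires a connection from the pool and executes `sql`, returning the result as a
/// [`SendableRecordBatchStream`] with the given projected schema. Errors from the pool
/// or the query are mapped to [`DataFusionError::Execution`].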
pub async fn get_stream<T: 'static, P: 'static>(
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    sql: String,
    projected_schema: SchemaRef,
) -> DataFusionResult<SendableRecordBatchStream> {
    let conn = pool.connect().await.map_err(to_execution_error)?;

    query_arrow(conn, sql, Some(projected_schema))
        .await
        .map_err(to_execution_error)
}

#[allow(clippy::needless_pass_by_value)]
pub fn to_execution_error(
    e: impl Into<Box<dyn std::error::Error + Send + Sync>>,
) -> DataFusionError {
    DataFusionError::Execution(e.into().to_string())
}

#[cfg(test)]
mod tests {
    use std::{error::Error, sync::Arc};

    use datafusion::execution::context::SessionContext;
    use datafusion::sql::TableReference;
    use tracing::{level_filters::LevelFilter, subscriber::DefaultGuard, Dispatch};

    use crate::sql::sql_provider_datafusion::SqlTable;

    fn setup_tracing() -> DefaultGuard {
        let subscriber: tracing_subscriber::FmtSubscriber = tracing_subscriber::fmt()
            .with_max_level(LevelFilter::DEBUG)
            .finish();

        let dispatch = Dispatch::new(subscriber);
        tracing::dispatcher::set_default(&dispatch)
    }

    mod sql_table_plan_to_sql_tests {
        use std::any::Any;

        use async_trait::async_trait;
        use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
        use datafusion::sql::unparser::dialect::{Dialect, SqliteDialect};
        use datafusion::{
            logical_expr::{col, lit},
            sql::TableReference,
        };

        use crate::sql::db_connection_pool::{
            dbconnection::DbConnection, DbConnectionPool, JoinPushDown,
        };

        use super::*;

        struct MockConn {}

        impl DbConnection<(), &'static dyn ToString> for MockConn {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn as_any_mut(&mut self) -> &mut dyn Any {
                self
            }
        }

        struct MockDBPool {}

        #[async_trait]
        impl DbConnectionPool<(), &'static dyn ToString> for MockDBPool {
            async fn connect(
                &self,
            ) -> Result<
                Box<dyn DbConnection<(), &'static dyn ToString>>,
                Box<dyn Error + Send + Sync>,
            > {
                Ok(Box::new(MockConn {}))
            }

            fn join_push_down(&self) -> JoinPushDown {
                JoinPushDown::Disallow
            }
        }

        fn new_sql_table(
            table_reference: &'static str,
            dialect: Option<Arc<dyn Dialect + Send + Sync>>,
        ) -> Result<SqlTable<(), &'static dyn ToString>, Box<dyn Error + Send + Sync>> {
            let fields = vec![
                Field::new("name", DataType::Utf8, false),
                Field::new("age", DataType::Int16, false),
                Field::new(
                    "createdDate",
                    DataType::Timestamp(TimeUnit::Millisecond, None),
                    false,
                ),
                Field::new("userId", DataType::LargeUtf8, false),
                Field::new("active", DataType::Boolean, false),
                Field::new("5e48", DataType::LargeUtf8, false),
            ];
            let schema = Arc::new(Schema::new(fields));
            let pool = Arc::new(MockDBPool {})
                as Arc<dyn DbConnectionPool<(), &'static dyn ToString> + Send + Sync>;
            let table_ref = TableReference::parse_str(table_reference);

            let sql_table = SqlTable::new_with_schema(table_reference, &pool, schema, table_ref);
            if let Some(dialect) = dialect {
                Ok(sql_table.with_dialect(dialect))
            } else {
                Ok(sql_table)
            }
        }

        #[tokio::test]
        async fn test_sql_to_string() -> Result<(), Box<dyn Error + Send + Sync>> {
            let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
            let result = sql_table.scan_to_sql(Some(&vec![0]), &[], None)?;
            assert_eq!(result, r#"SELECT `users`.`name` FROM `users`"#);
            Ok(())
        }

        #[tokio::test]
        async fn test_sql_to_string_with_filters_and_limit(
        ) -> Result<(), Box<dyn Error + Send + Sync>> {
            let filters = vec![col("age").gt_eq(lit(30)).and(col("name").eq(lit("x")))];
            let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
            let result = sql_table.scan_to_sql(Some(&vec![0, 1]), &filters, Some(3))?;
            assert_eq!(
                result,
                r#"SELECT `users`.`name`, `users`.`age` FROM `users` WHERE ((`users`.`age` >= 30) AND (`users`.`name` = 'x')) LIMIT 3"#
            );
            Ok(())
        }
    }

    #[test]
    fn test_references() {
        let table_ref = TableReference::bare("test");
        assert_eq!(format!("{table_ref}"), "test");
    }

    // XXX move this to duckdb mod??
    #[cfg(feature = "duckdb")]
    mod duckdb_tests {
        use super::*;
        use crate::sql::db_connection_pool::dbconnection::duckdbconn::{
            DuckDBSyncParameter, DuckDbConnection,
        };
        use crate::sql::db_connection_pool::{duckdbpool::DuckDbConnectionPool, DbConnectionPool};
        use duckdb::DuckdbConnectionManager;

        #[tokio::test]
        async fn test_duckdb_table() -> Result<(), Box<dyn Error + Send + Sync>> {
            let t = setup_tracing();
            let ctx = SessionContext::new();
            let pool: Arc<
                dyn DbConnectionPool<
                        r2d2::PooledConnection<DuckdbConnectionManager>,
                        Box<dyn DuckDBSyncParameter>,
                    > + Send
                    + Sync,
            > = Arc::new(DuckDbConnectionPool::new_memory()?);
            let conn = pool.connect().await?;
            let db_conn = conn
                .as_any()
                .downcast_ref::<DuckDbConnection>()
                .expect("Unable to downcast to DuckDbConnection");
            db_conn.conn.execute_batch(
                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
            )?;
            let duckdb_table = SqlTable::new("duckdb", &pool, "test").await?;
            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
            let sql = "SELECT * FROM test_datafusion limit 1";
            let df = ctx.sql(sql).await?;
            df.show().await?;
            drop(t);
            Ok(())
        }

        #[tokio::test]
        async fn test_duckdb_table_filter() -> Result<(), Box<dyn Error + Send + Sync>> {
            let t = setup_tracing();
            let ctx = SessionContext::new();
            let pool: Arc<
                dyn DbConnectionPool<
                        r2d2::PooledConnection<DuckdbConnectionManager>,
                        Box<dyn DuckDBSyncParameter>,
                    > + Send
                    + Sync,
            > = Arc::new(DuckDbConnectionPool::new_memory()?);
            let conn = pool.connect().await?;
            let db_conn = conn
                .as_any()
                .downcast_ref::<DuckDbConnection>()
                .expect("Unable to downcast to DuckDbConnection");
            db_conn.conn.execute_batch(
                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
            )?;
            let duckdb_table = SqlTable::new("duckdb", &pool, "test").await?;
            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
            let sql = "SELECT * FROM test_datafusion where a > 1 and b = 'bar' limit 1";
            let df = ctx.sql(sql).await?;
            df.show().await?;
            drop(t);
            Ok(())
        }
    }
}