datafusion_table_providers/sql/sql_provider_datafusion/
mod.rs

1//! # SQL DataFusion TableProvider
2//!
3//! This module implements a SQL TableProvider for DataFusion.
4//!
5//! This is used as a fallback if the `datafusion-federation` optimizer is not enabled.
6
7use crate::sql::db_connection_pool::{
8    self,
9    dbconnection::{get_schema, query_arrow},
10    DbConnectionPool,
11};
12use async_trait::async_trait;
13use datafusion::{
14    catalog::Session,
15    physical_plan::execution_plan::{Boundedness, EmissionType},
16    sql::unparser::dialect::{DefaultDialect, Dialect},
17};
18use futures::TryStreamExt;
19use snafu::prelude::*;
20use std::fmt::{Display, Formatter};
21use std::{any::Any, fmt, sync::Arc};
22
23use datafusion::{
24    arrow::datatypes::SchemaRef,
25    datasource::TableProvider,
26    error::{DataFusionError, Result as DataFusionResult},
27    execution::TaskContext,
28    logical_expr::{
29        logical_plan::builder::LogicalTableSource, Expr, LogicalPlan, LogicalPlanBuilder,
30        TableProviderFilterPushDown, TableType,
31    },
32    physical_expr::EquivalenceProperties,
33    physical_plan::{
34        stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan,
35        Partitioning, PlanProperties, SendableRecordBatchStream,
36    },
37    sql::{unparser::Unparser, TableReference},
38};
39
40#[cfg(feature = "federation")]
41pub mod federation;
42
/// Errors produced while building or executing a [`SqlTable`] scan.
#[derive(Debug, Snafu)]
pub enum Error {
    /// The connection pool could not hand out a connection.
    #[snafu(display("Unable to get a DB connection from the pool: {source}"))]
    UnableToGetConnectionFromPool { source: db_connection_pool::Error },

    /// The remote table's Arrow schema could not be retrieved.
    #[snafu(display("Unable to get schema: {source}"))]
    UnableToGetSchema {
        source: db_connection_pool::dbconnection::Error,
    },

    /// Unparsing a DataFusion logical plan back into SQL text failed.
    #[snafu(display("Unable to generate SQL: {source}"))]
    UnableToGenerateSQL { source: DataFusionError },
}

/// Module-local result alias defaulting the error type to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
58
/// A DataFusion `TableProvider` backed by a remote SQL table reached through
/// a [`DbConnectionPool`]. Scans are unparsed back into SQL text and executed
/// against the database (see `scan_to_sql` / `SqlExec` below).
#[derive(Clone)]
pub struct SqlTable<T: 'static, P: 'static> {
    // Display name for this provider (used by `Display` and `Debug`).
    name: String,
    // Pool used to obtain connections when a scan executes.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    // Arrow schema of the remote table.
    schema: SchemaRef,
    // Fully-qualified reference to the remote table used in generated SQL.
    pub table_reference: TableReference,
    // Optional unparser dialect; `None` falls back to `DefaultDialect`.
    dialect: Option<Arc<dyn Dialect + Send + Sync>>,
}
67
68impl<T, P> fmt::Debug for SqlTable<T, P> {
69    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
70        f.debug_struct("SqlTable")
71            .field("name", &self.name)
72            .field("schema", &self.schema)
73            .field("table_reference", &self.table_reference)
74            .finish()
75    }
76}
77
78impl<T, P> SqlTable<T, P> {
79    pub async fn new(
80        name: &str,
81        pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
82        table_reference: impl Into<TableReference>,
83    ) -> Result<Self> {
84        let table_reference = table_reference.into();
85        let conn = pool
86            .connect()
87            .await
88            .context(UnableToGetConnectionFromPoolSnafu)?;
89
90        let schema = get_schema(conn, &table_reference)
91            .await
92            .context(UnableToGetSchemaSnafu)?;
93
94        Ok(Self::new_with_schema(name, pool, schema, table_reference))
95    }
96
97    pub fn new_with_schema(
98        name: &str,
99        pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
100        schema: impl Into<SchemaRef>,
101        table_reference: impl Into<TableReference>,
102    ) -> Self {
103        Self {
104            name: name.to_owned(),
105            pool: Arc::clone(pool),
106            schema: schema.into(),
107            table_reference: table_reference.into(),
108            dialect: None,
109        }
110    }
111
112    pub fn scan_to_sql(
113        &self,
114        projection: Option<&Vec<usize>>,
115        filters: &[Expr],
116        limit: Option<usize>,
117    ) -> DataFusionResult<String> {
118        let logical_plan = self.create_logical_plan(projection, filters, limit)?;
119        let sql = Unparser::new(self.dialect())
120            .plan_to_sql(&logical_plan)?
121            .to_string();
122
123        Ok(sql)
124    }
125
126    fn create_logical_plan(
127        &self,
128        projection: Option<&Vec<usize>>,
129        filters: &[Expr],
130        limit: Option<usize>,
131    ) -> DataFusionResult<LogicalPlan> {
132        let table_source = LogicalTableSource::new(self.schema());
133        LogicalPlanBuilder::scan_with_filters(
134            self.table_reference.clone(),
135            Arc::new(table_source),
136            projection.cloned(),
137            filters.to_vec(),
138        )?
139        .limit(0, limit)?
140        .build()
141    }
142
143    fn create_physical_plan(
144        &self,
145        projection: Option<&Vec<usize>>,
146        sql: String,
147    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
148        Ok(Arc::new(SqlExec::new(
149            projection,
150            &self.schema(),
151            Arc::clone(&self.pool),
152            sql,
153        )?))
154    }
155
156    // Return the current memory location of the object as a unique identifier
157    fn unique_id(&self) -> usize {
158        std::ptr::from_ref(self) as usize
159    }
160
161    #[must_use]
162    pub fn with_dialect(self, dialect: Arc<dyn Dialect + Send + Sync>) -> Self {
163        Self {
164            dialect: Some(dialect),
165            ..self
166        }
167    }
168
169    #[must_use]
170    pub fn name(&self) -> &str {
171        &self.name
172    }
173
174    #[must_use]
175    pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
176        Arc::clone(&self.pool)
177    }
178
179    fn dialect(&self) -> &(dyn Dialect + Send + Sync) {
180        match &self.dialect {
181            Some(dialect) => dialect.as_ref(),
182            None => &DefaultDialect {},
183        }
184    }
185
186    fn arc_dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
187        match &self.dialect {
188            Some(dialect) => Arc::clone(dialect),
189            None => Arc::new(DefaultDialect {}),
190        }
191    }
192}
193
194#[async_trait]
195impl<T, P> TableProvider for SqlTable<T, P> {
196    fn as_any(&self) -> &dyn Any {
197        self
198    }
199
200    fn schema(&self) -> SchemaRef {
201        Arc::clone(&self.schema)
202    }
203
204    fn table_type(&self) -> TableType {
205        TableType::Base
206    }
207
208    fn supports_filters_pushdown(
209        &self,
210        filters: &[&Expr],
211    ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
212        let filter_push_down: Vec<TableProviderFilterPushDown> = filters
213            .iter()
214            .map(|f| match Unparser::new(self.dialect()).expr_to_sql(f) {
215                Ok(_) => TableProviderFilterPushDown::Exact,
216                Err(_) => TableProviderFilterPushDown::Unsupported,
217            })
218            .collect();
219
220        Ok(filter_push_down)
221    }
222
223    async fn scan(
224        &self,
225        _state: &dyn Session,
226        projection: Option<&Vec<usize>>,
227        filters: &[Expr],
228        limit: Option<usize>,
229    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
230        let sql = self.scan_to_sql(projection, filters, limit)?;
231        return self.create_physical_plan(projection, sql);
232    }
233}
234
235impl<T, P> Display for SqlTable<T, P> {
236    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237        write!(f, "SqlTable {}", self.name)
238    }
239}
240
241pub fn project_schema_safe(
242    schema: &SchemaRef,
243    projection: Option<&Vec<usize>>,
244) -> DataFusionResult<SchemaRef> {
245    let schema = match projection {
246        Some(columns) => {
247            if columns.is_empty() {
248                Arc::clone(schema)
249            } else {
250                Arc::new(schema.project(columns)?)
251            }
252        }
253        None => Arc::clone(schema),
254    };
255    Ok(schema)
256}
257
/// Physical plan node that runs a SQL string against the remote database and
/// streams the results back as Arrow record batches.
#[derive(Clone)]
pub struct SqlExec<T, P> {
    // Schema after applying the scan's projection.
    projected_schema: SchemaRef,
    // Pool used to obtain a connection when `execute` is polled.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    // The SQL query text to execute remotely.
    sql: String,
    // Cached plan properties (single unknown partition, bounded input).
    properties: PlanProperties,
}
265
266impl<T, P> SqlExec<T, P> {
267    pub fn new(
268        projection: Option<&Vec<usize>>,
269        schema: &SchemaRef,
270        pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
271        sql: String,
272    ) -> DataFusionResult<Self> {
273        let projected_schema = project_schema_safe(schema, projection)?;
274
275        Ok(Self {
276            projected_schema: Arc::clone(&projected_schema),
277            pool,
278            sql,
279            properties: PlanProperties::new(
280                EquivalenceProperties::new(projected_schema),
281                Partitioning::UnknownPartitioning(1),
282                EmissionType::Incremental,
283                Boundedness::Bounded,
284            ),
285        })
286    }
287
288    #[must_use]
289    pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
290        Arc::clone(&self.pool)
291    }
292
293    pub fn sql(&self) -> Result<String> {
294        Ok(self.sql.clone())
295    }
296}
297
298impl<T, P> std::fmt::Debug for SqlExec<T, P> {
299    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
300        let sql = self.sql().unwrap_or_default();
301        write!(f, "SqlExec sql={sql}")
302    }
303}
304
305impl<T, P> DisplayAs for SqlExec<T, P> {
306    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
307        let sql = self.sql().unwrap_or_default();
308        write!(f, "SqlExec sql={sql}")
309    }
310}
311
312impl<T: 'static, P: 'static> ExecutionPlan for SqlExec<T, P> {
313    fn name(&self) -> &'static str {
314        "SqlExec"
315    }
316
317    fn as_any(&self) -> &dyn Any {
318        self
319    }
320
321    fn schema(&self) -> SchemaRef {
322        Arc::clone(&self.projected_schema)
323    }
324
325    fn properties(&self) -> &PlanProperties {
326        &self.properties
327    }
328
329    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
330        vec![]
331    }
332
333    fn with_new_children(
334        self: Arc<Self>,
335        _children: Vec<Arc<dyn ExecutionPlan>>,
336    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
337        Ok(self)
338    }
339
340    fn execute(
341        &self,
342        _partition: usize,
343        _context: Arc<TaskContext>,
344    ) -> DataFusionResult<SendableRecordBatchStream> {
345        let sql = self.sql().map_err(to_execution_error)?;
346        tracing::debug!("SqlExec sql: {sql}");
347
348        let schema = self.schema();
349
350        let fut = get_stream(Arc::clone(&self.pool), sql, Arc::clone(&schema));
351
352        let stream = futures::stream::once(fut).try_flatten();
353        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
354    }
355}
356
357pub async fn get_stream<T: 'static, P: 'static>(
358    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
359    sql: String,
360    projected_schema: SchemaRef,
361) -> DataFusionResult<SendableRecordBatchStream> {
362    let conn = pool.connect().await.map_err(to_execution_error)?;
363
364    query_arrow(conn, sql, Some(projected_schema))
365        .await
366        .map_err(to_execution_error)
367}
368
369#[allow(clippy::needless_pass_by_value)]
370pub fn to_execution_error(
371    e: impl Into<Box<dyn std::error::Error + Send + Sync>>,
372) -> DataFusionError {
373    DataFusionError::Execution(format!("{}", e.into()).to_string())
374}
375
376#[cfg(test)]
377mod tests {
378    use std::{error::Error, sync::Arc};
379
380    use datafusion::execution::context::SessionContext;
381    use datafusion::sql::TableReference;
382    use tracing::{level_filters::LevelFilter, subscriber::DefaultGuard, Dispatch};
383
384    use crate::sql::sql_provider_datafusion::SqlTable;
385
    // Installs a DEBUG-level fmt subscriber as the thread-local default;
    // tracing reverts to the previous dispatcher when the returned guard is
    // dropped, so tests hold the guard for their duration.
    fn setup_tracing() -> DefaultGuard {
        let subscriber: tracing_subscriber::FmtSubscriber = tracing_subscriber::fmt()
            .with_max_level(LevelFilter::DEBUG)
            .finish();

        let dispatch = Dispatch::new(subscriber);
        tracing::dispatcher::set_default(&dispatch)
    }
394
    // Tests for `SqlTable::scan_to_sql` using a mock pool, so no real
    // database connection is needed to exercise plan-to-SQL unparsing.
    mod sql_table_plan_to_sql_tests {
        use std::any::Any;

        use async_trait::async_trait;
        use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
        use datafusion::sql::unparser::dialect::{Dialect, SqliteDialect};
        use datafusion::{
            logical_expr::{col, lit},
            sql::TableReference,
        };

        use crate::sql::db_connection_pool::{
            dbconnection::DbConnection, DbConnectionPool, JoinPushDown,
        };

        use super::*;

        // Connection stub: only implements the downcasting hooks; it is
        // never actually queried because these tests stop at SQL generation.
        struct MockConn {}

        impl DbConnection<(), &'static dyn ToString> for MockConn {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn as_any_mut(&mut self) -> &mut dyn Any {
                self
            }
        }

        // Pool stub that always hands out a `MockConn`.
        struct MockDBPool {}

        #[async_trait]
        impl DbConnectionPool<(), &'static dyn ToString> for MockDBPool {
            async fn connect(
                &self,
            ) -> Result<
                Box<dyn DbConnection<(), &'static dyn ToString>>,
                Box<dyn Error + Send + Sync>,
            > {
                Ok(Box::new(MockConn {}))
            }

            fn join_push_down(&self) -> JoinPushDown {
                JoinPushDown::Disallow
            }
        }

        // Builds a `SqlTable` over a fixed six-column schema (including
        // mixed-case and digit-leading names to exercise identifier quoting),
        // optionally configured with a dialect.
        fn new_sql_table(
            table_reference: &'static str,
            dialect: Option<Arc<dyn Dialect + Send + Sync>>,
        ) -> Result<SqlTable<(), &'static dyn ToString>, Box<dyn Error + Send + Sync>> {
            let fields = vec![
                Field::new("name", DataType::Utf8, false),
                Field::new("age", DataType::Int16, false),
                Field::new(
                    "createdDate",
                    DataType::Timestamp(TimeUnit::Millisecond, None),
                    false,
                ),
                Field::new("userId", DataType::LargeUtf8, false),
                Field::new("active", DataType::Boolean, false),
                Field::new("5e48", DataType::LargeUtf8, false),
            ];
            let schema = Arc::new(Schema::new(fields));
            let pool = Arc::new(MockDBPool {})
                as Arc<dyn DbConnectionPool<(), &'static dyn ToString> + Send + Sync>;
            let table_ref = TableReference::parse_str(table_reference);

            let sql_table = SqlTable::new_with_schema(table_reference, &pool, schema, table_ref);
            if let Some(dialect) = dialect {
                Ok(sql_table.with_dialect(dialect))
            } else {
                Ok(sql_table)
            }
        }

        // Projection-only scan should render a backtick-quoted SELECT in the
        // SQLite dialect.
        #[tokio::test]
        async fn test_sql_to_string() -> Result<(), Box<dyn Error + Send + Sync>> {
            let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
            let result = sql_table.scan_to_sql(Some(&vec![0]), &[], None)?;
            assert_eq!(result, r#"SELECT `users`.`name` FROM `users`"#);
            Ok(())
        }

        // Filters and limit should appear as WHERE and LIMIT clauses.
        #[tokio::test]
        async fn test_sql_to_string_with_filters_and_limit(
        ) -> Result<(), Box<dyn Error + Send + Sync>> {
            let filters = vec![col("age").gt_eq(lit(30)).and(col("name").eq(lit("x")))];
            let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
            let result = sql_table.scan_to_sql(Some(&vec![0, 1]), &filters, Some(3))?;
            assert_eq!(
                result,
                r#"SELECT `users`.`name`, `users`.`age` FROM `users` WHERE ((`users`.`age` >= 30) AND (`users`.`name` = 'x')) LIMIT 3"#
            );
            Ok(())
        }
    }
492
493    #[test]
494    fn test_references() {
495        let table_ref = TableReference::bare("test");
496        assert_eq!(format!("{table_ref}"), "test");
497    }
498
    // XXX move this to duckdb mod??
    // End-to-end tests against an in-memory DuckDB: register a `SqlTable`
    // with DataFusion and run real queries through it.
    #[cfg(feature = "duckdb")]
    mod duckdb_tests {
        use super::*;
        use crate::sql::db_connection_pool::dbconnection::duckdbconn::{
            DuckDBSyncParameter, DuckDbConnection,
        };
        use crate::sql::db_connection_pool::{duckdbpool::DuckDbConnectionPool, DbConnectionPool};
        use duckdb::DuckdbConnectionManager;

        // Plain projection + limit round-trips through DuckDB.
        #[tokio::test]
        async fn test_duckdb_table() -> Result<(), Box<dyn Error + Send + Sync>> {
            let t = setup_tracing();
            let ctx = SessionContext::new();
            let pool: Arc<
                dyn DbConnectionPool<
                        r2d2::PooledConnection<DuckdbConnectionManager>,
                        Box<dyn DuckDBSyncParameter>,
                    > + Send
                    + Sync,
            > = Arc::new(DuckDbConnectionPool::new_memory()?)
                as Arc<
                    dyn DbConnectionPool<
                            r2d2::PooledConnection<DuckdbConnectionManager>,
                            Box<dyn DuckDBSyncParameter>,
                        > + Send
                        + Sync,
                >;
            let conn = pool.connect().await?;
            // Downcast to the concrete connection to seed test data directly.
            let db_conn = conn
                .as_any()
                .downcast_ref::<DuckDbConnection>()
                .expect("Unable to downcast to DuckDbConnection");
            db_conn.conn.execute_batch(
                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
            )?;
            let duckdb_table = SqlTable::new("duckdb", &pool, "test").await?;
            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
            let sql = "SELECT * FROM test_datafusion limit 1";
            let df = ctx.sql(sql).await?;
            df.show().await?;
            // Guard dropped explicitly to restore the previous tracing dispatcher.
            drop(t);
            Ok(())
        }

        // Same round-trip but with WHERE filters that should be pushed down.
        #[tokio::test]
        async fn test_duckdb_table_filter() -> Result<(), Box<dyn Error + Send + Sync>> {
            let t = setup_tracing();
            let ctx = SessionContext::new();
            let pool: Arc<
                dyn DbConnectionPool<
                        r2d2::PooledConnection<DuckdbConnectionManager>,
                        Box<dyn DuckDBSyncParameter>,
                    > + Send
                    + Sync,
            > = Arc::new(DuckDbConnectionPool::new_memory()?)
                as Arc<
                    dyn DbConnectionPool<
                            r2d2::PooledConnection<DuckdbConnectionManager>,
                            Box<dyn DuckDBSyncParameter>,
                        > + Send
                        + Sync,
                >;
            let conn = pool.connect().await?;
            // Downcast to the concrete connection to seed test data directly.
            let db_conn = conn
                .as_any()
                .downcast_ref::<DuckDbConnection>()
                .expect("Unable to downcast to DuckDbConnection");
            db_conn.conn.execute_batch(
                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
            )?;
            let duckdb_table = SqlTable::new("duckdb", &pool, "test").await?;
            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
            let sql = "SELECT * FROM test_datafusion where a > 1 and b = 'bar' limit 1";
            let df = ctx.sql(sql).await?;
            df.show().await?;
            // Guard dropped explicitly to restore the previous tracing dispatcher.
            drop(t);
            Ok(())
        }
    }
579}