datafusion_table_providers/sql/sql_provider_datafusion/
mod.rs

1//! # SQL DataFusion TableProvider
2//!
3//! This module implements a SQL TableProvider for DataFusion.
4//!
5//! This is used as a fallback if the `datafusion-federation` optimizer is not enabled.
6
7use crate::sql::db_connection_pool::{
8    self,
9    dbconnection::{get_schema, query_arrow},
10    DbConnectionPool,
11};
12use async_trait::async_trait;
13use datafusion::{
14    catalog::Session,
15    physical_plan::execution_plan::{Boundedness, EmissionType},
16    sql::unparser::dialect::{
17        DefaultDialect, Dialect, MySqlDialect, PostgreSqlDialect, SqliteDialect,
18    },
19};
20use futures::TryStreamExt;
21use snafu::prelude::*;
22use std::fmt::{Display, Formatter};
23use std::{any::Any, fmt, sync::Arc};
24
25use datafusion::{
26    arrow::datatypes::SchemaRef,
27    datasource::TableProvider,
28    error::{DataFusionError, Result as DataFusionResult},
29    execution::TaskContext,
30    logical_expr::{
31        logical_plan::builder::LogicalTableSource, Expr, LogicalPlan, LogicalPlanBuilder,
32        TableProviderFilterPushDown, TableType,
33    },
34    physical_expr::EquivalenceProperties,
35    physical_plan::{
36        stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan,
37        Partitioning, PlanProperties, SendableRecordBatchStream,
38    },
39    sql::{unparser::Unparser, TableReference},
40};
41
42#[cfg(feature = "federation")]
43pub mod federation;
44
/// Errors surfaced by [`SqlTable`] construction and SQL generation.
#[derive(Debug, Snafu)]
pub enum Error {
    /// The connection pool failed to hand out a connection.
    #[snafu(display("Unable to get a DB connection from the pool: {source}"))]
    UnableToGetConnectionFromPool { source: db_connection_pool::Error },

    /// Fetching the remote table's schema over a pooled connection failed.
    #[snafu(display("Unable to get schema: {source}"))]
    UnableToGetSchema {
        source: db_connection_pool::dbconnection::Error,
    },

    /// Unparsing a DataFusion logical plan back into SQL text failed.
    #[snafu(display("Unable to generate SQL: {source}"))]
    UnableToGenerateSQL { source: DataFusionError },
}
58
/// The database engine a [`SqlTable`] targets.
///
/// Selects which SQL dialect is used when unparsing DataFusion plans back
/// to SQL text (see [`Engine::dialect`]).
#[derive(Clone, Copy, Debug)]
pub enum Engine {
    Spark,
    SQLite,
    DuckDB,
    ODBC,
    Postgres,
    MySQL,
    /// Fallback when the engine is unknown; unparses with the default dialect.
    Default,
}
69
70impl Engine {
71    /// Get the corresponding `Dialect` to use for unparsing
72    pub fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
73        match self {
74            Engine::SQLite => Arc::new(SqliteDialect {}),
75            Engine::Postgres => Arc::new(PostgreSqlDialect {}),
76            Engine::MySQL => Arc::new(MySqlDialect {}),
77            Engine::Spark | Engine::DuckDB | Engine::ODBC | Engine::Default => {
78                Arc::new(DefaultDialect {})
79            }
80        }
81    }
82}
83
/// Convenience alias defaulting the error type to this module's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
85
/// A DataFusion `TableProvider` backed by a remote SQL table reached
/// through a [`DbConnectionPool`].
#[derive(Clone)]
pub struct SqlTable<T: 'static, P: 'static> {
    // Human-readable provider name supplied at construction time.
    name: String,
    // Pool handing out connections of type `T` taking parameters of type `P`.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    // Arrow schema of the remote table.
    schema: SchemaRef,
    pub table_reference: TableReference,
    // Chooses the SQL dialect used when unparsing plans back to SQL.
    engine: Engine,
}
94
95impl<T, P> fmt::Debug for SqlTable<T, P> {
96    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
97        f.debug_struct("SqlTable")
98            .field("name", &self.name)
99            .field("schema", &self.schema)
100            .field("table_reference", &self.table_reference)
101            .field("engine", &self.engine)
102            .finish()
103    }
104}
105
106impl<T, P> SqlTable<T, P> {
107    pub async fn new(
108        name: &str,
109        pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
110        table_reference: impl Into<TableReference>,
111        engine: Option<Engine>,
112    ) -> Result<Self> {
113        let table_reference = table_reference.into();
114        let conn = pool
115            .connect()
116            .await
117            .context(UnableToGetConnectionFromPoolSnafu)?;
118
119        let schema = get_schema(conn, &table_reference)
120            .await
121            .context(UnableToGetSchemaSnafu)?;
122
123        Ok(Self::new_with_schema(
124            name,
125            pool,
126            schema,
127            table_reference,
128            engine,
129        ))
130    }
131
132    pub fn new_with_schema(
133        name: &str,
134        pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
135        schema: impl Into<SchemaRef>,
136        table_reference: impl Into<TableReference>,
137        engine: Option<Engine>,
138    ) -> Self {
139        let engine = engine.unwrap_or(Engine::Default);
140        Self {
141            name: name.to_owned(),
142            pool: Arc::clone(pool),
143            schema: schema.into(),
144            table_reference: table_reference.into(),
145            engine,
146        }
147    }
148
149    pub fn scan_to_sql(
150        &self,
151        projection: Option<&Vec<usize>>,
152        filters: &[Expr],
153        limit: Option<usize>,
154    ) -> DataFusionResult<String> {
155        let logical_plan = self.create_logical_plan(projection, filters, limit)?;
156        let sql = Unparser::new(self.engine.dialect().as_ref())
157            .plan_to_sql(&logical_plan)?
158            .to_string();
159
160        Ok(sql)
161    }
162
163    fn create_logical_plan(
164        &self,
165        projection: Option<&Vec<usize>>,
166        filters: &[Expr],
167        limit: Option<usize>,
168    ) -> DataFusionResult<LogicalPlan> {
169        let table_source = LogicalTableSource::new(self.schema());
170        LogicalPlanBuilder::scan_with_filters(
171            self.table_reference.clone(),
172            Arc::new(table_source),
173            projection.cloned(),
174            filters.to_vec(),
175        )?
176        .limit(0, limit)?
177        .build()
178    }
179
180    fn create_physical_plan(
181        &self,
182        projection: Option<&Vec<usize>>,
183        sql: String,
184    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
185        Ok(Arc::new(SqlExec::new(
186            projection,
187            &self.schema(),
188            Arc::clone(&self.pool),
189            sql,
190        )?))
191    }
192
193    // Return the current memory location of the object as a unique identifier
194    fn unique_id(&self) -> usize {
195        std::ptr::from_ref(self) as usize
196    }
197
198    #[must_use]
199    pub fn name(&self) -> &str {
200        &self.name
201    }
202
203    #[must_use]
204    pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
205        Arc::clone(&self.pool)
206    }
207}
208
209#[async_trait]
210impl<T, P> TableProvider for SqlTable<T, P> {
211    fn as_any(&self) -> &dyn Any {
212        self
213    }
214
215    fn schema(&self) -> SchemaRef {
216        Arc::clone(&self.schema)
217    }
218
219    fn table_type(&self) -> TableType {
220        TableType::Base
221    }
222
223    fn supports_filters_pushdown(
224        &self,
225        filters: &[&Expr],
226    ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
227        let filter_push_down: Vec<TableProviderFilterPushDown> = filters
228            .iter()
229            .map(
230                |f| match Unparser::new(self.engine.dialect().as_ref()).expr_to_sql(f) {
231                    Ok(_) => TableProviderFilterPushDown::Exact,
232                    Err(_) => TableProviderFilterPushDown::Unsupported,
233                },
234            )
235            .collect();
236
237        Ok(filter_push_down)
238    }
239
240    async fn scan(
241        &self,
242        _state: &dyn Session,
243        projection: Option<&Vec<usize>>,
244        filters: &[Expr],
245        limit: Option<usize>,
246    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
247        let sql = self.scan_to_sql(projection, filters, limit)?;
248        return self.create_physical_plan(projection, sql);
249    }
250}
251
252impl<T, P> Display for SqlTable<T, P> {
253    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
254        write!(f, "SqlTable {}", self.name)
255    }
256}
257
258pub fn project_schema_safe(
259    schema: &SchemaRef,
260    projection: Option<&Vec<usize>>,
261) -> DataFusionResult<SchemaRef> {
262    let schema = match projection {
263        Some(columns) => {
264            if columns.is_empty() {
265                Arc::clone(schema)
266            } else {
267                Arc::new(schema.project(columns)?)
268            }
269        }
270        None => Arc::clone(schema),
271    };
272    Ok(schema)
273}
274
/// Physical plan node that executes a fixed SQL string over a pooled
/// connection and streams the resulting record batches.
#[derive(Clone)]
pub struct SqlExec<T, P> {
    // Output schema after applying the (non-empty) projection.
    projected_schema: SchemaRef,
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    // Fully rendered SQL text, executed verbatim by `execute`.
    sql: String,
    // Single unknown partition, bounded, incremental emission.
    properties: PlanProperties,
}
282
283impl<T, P> SqlExec<T, P> {
284    pub fn new(
285        projection: Option<&Vec<usize>>,
286        schema: &SchemaRef,
287        pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
288        sql: String,
289    ) -> DataFusionResult<Self> {
290        let projected_schema = project_schema_safe(schema, projection)?;
291
292        Ok(Self {
293            projected_schema: Arc::clone(&projected_schema),
294            pool,
295            sql,
296            properties: PlanProperties::new(
297                EquivalenceProperties::new(projected_schema),
298                Partitioning::UnknownPartitioning(1),
299                EmissionType::Incremental,
300                Boundedness::Bounded,
301            ),
302        })
303    }
304
305    #[must_use]
306    pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
307        Arc::clone(&self.pool)
308    }
309
310    pub fn sql(&self) -> Result<String> {
311        Ok(self.sql.clone())
312    }
313}
314
315impl<T, P> std::fmt::Debug for SqlExec<T, P> {
316    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
317        let sql = self.sql().unwrap_or_default();
318        write!(f, "SqlExec sql={sql}")
319    }
320}
321
322impl<T, P> DisplayAs for SqlExec<T, P> {
323    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
324        let sql = self.sql().unwrap_or_default();
325        write!(f, "SqlExec sql={sql}")
326    }
327}
328
329impl<T: 'static, P: 'static> ExecutionPlan for SqlExec<T, P> {
330    fn name(&self) -> &'static str {
331        "SqlExec"
332    }
333
334    fn as_any(&self) -> &dyn Any {
335        self
336    }
337
338    fn schema(&self) -> SchemaRef {
339        Arc::clone(&self.projected_schema)
340    }
341
342    fn properties(&self) -> &PlanProperties {
343        &self.properties
344    }
345
346    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
347        vec![]
348    }
349
350    fn with_new_children(
351        self: Arc<Self>,
352        _children: Vec<Arc<dyn ExecutionPlan>>,
353    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
354        Ok(self)
355    }
356
357    fn execute(
358        &self,
359        _partition: usize,
360        _context: Arc<TaskContext>,
361    ) -> DataFusionResult<SendableRecordBatchStream> {
362        let sql = self.sql().map_err(to_execution_error)?;
363        tracing::debug!("SqlExec sql: {sql}");
364
365        let schema = self.schema();
366
367        let fut = get_stream(Arc::clone(&self.pool), sql, Arc::clone(&schema));
368
369        let stream = futures::stream::once(fut).try_flatten();
370        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
371    }
372}
373
374pub async fn get_stream<T: 'static, P: 'static>(
375    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
376    sql: String,
377    projected_schema: SchemaRef,
378) -> DataFusionResult<SendableRecordBatchStream> {
379    let conn = pool.connect().await.map_err(to_execution_error)?;
380
381    query_arrow(conn, sql, Some(projected_schema))
382        .await
383        .map_err(to_execution_error)
384}
385
386#[allow(clippy::needless_pass_by_value)]
387pub fn to_execution_error(
388    e: impl Into<Box<dyn std::error::Error + Send + Sync>>,
389) -> DataFusionError {
390    DataFusionError::Execution(format!("{}", e.into()).to_string())
391}
392
393#[cfg(test)]
394mod tests {
395    use std::{error::Error, sync::Arc};
396
397    use datafusion::execution::context::SessionContext;
398    use datafusion::sql::TableReference;
399    use tracing::{level_filters::LevelFilter, subscriber::DefaultGuard, Dispatch};
400
401    use crate::sql::sql_provider_datafusion::SqlTable;
402
    // Installs a DEBUG-level fmt subscriber as the thread-default
    // dispatcher; dropping the returned guard restores the previous one.
    fn setup_tracing() -> DefaultGuard {
        let subscriber: tracing_subscriber::FmtSubscriber = tracing_subscriber::fmt()
            .with_max_level(LevelFilter::DEBUG)
            .finish();

        let dispatch = Dispatch::new(subscriber);
        tracing::dispatcher::set_default(&dispatch)
    }
411
    // Plan-to-SQL unparsing tests over a mock connection pool; no real
    // database is contacted — only `scan_to_sql` is exercised.
    mod sql_table_plan_to_sql_tests {
        use std::any::Any;

        use async_trait::async_trait;
        use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
        use datafusion::{
            logical_expr::{col, lit},
            sql::TableReference,
        };

        use crate::sql::db_connection_pool::{
            dbconnection::DbConnection, DbConnectionPool, JoinPushDown,
        };

        use super::super::Engine;
        use super::*;

        // Connection stub: only the `Any` plumbing is implemented because
        // these tests never execute a query.
        struct MockConn {}

        impl DbConnection<(), &'static dyn ToString> for MockConn {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn as_any_mut(&mut self) -> &mut dyn Any {
                self
            }
        }

        // Pool stub that always hands out a fresh `MockConn`.
        struct MockDBPool {}

        #[async_trait]
        impl DbConnectionPool<(), &'static dyn ToString> for MockDBPool {
            async fn connect(
                &self,
            ) -> Result<
                Box<dyn DbConnection<(), &'static dyn ToString>>,
                Box<dyn Error + Send + Sync>,
            > {
                Ok(Box::new(MockConn {}))
            }

            fn join_push_down(&self) -> JoinPushDown {
                JoinPushDown::Disallow
            }
        }

        // Builds a `SqlTable` over a fixed six-column schema backed by the
        // mock pool, so `scan_to_sql` can be tested without a database.
        fn new_sql_table(
            table_reference: &'static str,
            engine: Option<Engine>,
        ) -> Result<SqlTable<(), &'static dyn ToString>, Box<dyn Error + Send + Sync>> {
            let fields = vec![
                Field::new("name", DataType::Utf8, false),
                Field::new("age", DataType::Int16, false),
                Field::new(
                    "createdDate",
                    DataType::Timestamp(TimeUnit::Millisecond, None),
                    false,
                ),
                Field::new("userId", DataType::LargeUtf8, false),
                Field::new("active", DataType::Boolean, false),
                // NOTE(review): digit-leading identifier — presumably here to
                // exercise dialect quoting; confirm intent.
                Field::new("5e48", DataType::LargeUtf8, false),
            ];
            let schema = Arc::new(Schema::new(fields));
            let pool = Arc::new(MockDBPool {})
                as Arc<dyn DbConnectionPool<(), &'static dyn ToString> + Send + Sync>;
            let table_ref = TableReference::parse_str(table_reference);

            Ok(SqlTable::new_with_schema(
                table_reference,
                &pool,
                schema,
                table_ref,
                engine,
            ))
        }

        #[tokio::test]
        async fn test_sql_to_string() -> Result<(), Box<dyn Error + Send + Sync>> {
            let sql_table = new_sql_table("users", Some(Engine::SQLite))?;
            let result = sql_table.scan_to_sql(Some(&vec![0]), &[], None)?;
            assert_eq!(result, r#"SELECT `users`.`name` FROM `users`"#);
            Ok(())
        }

        #[tokio::test]
        async fn test_sql_to_string_with_filters_and_limit(
        ) -> Result<(), Box<dyn Error + Send + Sync>> {
            let filters = vec![col("age").gt_eq(lit(30)).and(col("name").eq(lit("x")))];
            let sql_table = new_sql_table("users", Some(Engine::SQLite))?;
            let result = sql_table.scan_to_sql(Some(&vec![0, 1]), &filters, Some(3))?;
            assert_eq!(
                result,
                r#"SELECT `users`.`name`, `users`.`age` FROM `users` WHERE ((`users`.`age` >= 30) AND (`users`.`name` = 'x')) LIMIT 3"#
            );
            Ok(())
        }

        // TODO cover more test cases with different Engines & Dialects
    }
512
513    #[test]
514    fn test_references() {
515        let table_ref = TableReference::bare("test");
516        assert_eq!(format!("{table_ref}"), "test");
517    }
518
519    // XXX move this to duckdb mod??
520    #[cfg(feature = "duckdb")]
521    mod duckdb_tests {
522        use super::*;
523        use crate::sql::db_connection_pool::dbconnection::duckdbconn::{
524            DuckDBSyncParameter, DuckDbConnection,
525        };
526        use crate::sql::db_connection_pool::{duckdbpool::DuckDbConnectionPool, DbConnectionPool};
527        use duckdb::DuckdbConnectionManager;
528
529        #[tokio::test]
530        async fn test_duckdb_table() -> Result<(), Box<dyn Error + Send + Sync>> {
531            let t = setup_tracing();
532            let ctx = SessionContext::new();
533            let pool: Arc<
534                dyn DbConnectionPool<
535                        r2d2::PooledConnection<DuckdbConnectionManager>,
536                        Box<dyn DuckDBSyncParameter>,
537                    > + Send
538                    + Sync,
539            > = Arc::new(DuckDbConnectionPool::new_memory()?)
540                as Arc<
541                    dyn DbConnectionPool<
542                            r2d2::PooledConnection<DuckdbConnectionManager>,
543                            Box<dyn DuckDBSyncParameter>,
544                        > + Send
545                        + Sync,
546                >;
547            let conn = pool.connect().await?;
548            let db_conn = conn
549                .as_any()
550                .downcast_ref::<DuckDbConnection>()
551                .expect("Unable to downcast to DuckDbConnection");
552            db_conn.conn.execute_batch(
553                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
554            )?;
555            let duckdb_table = SqlTable::new("duckdb", &pool, "test", None).await?;
556            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
557            let sql = "SELECT * FROM test_datafusion limit 1";
558            let df = ctx.sql(sql).await?;
559            df.show().await?;
560            drop(t);
561            Ok(())
562        }
563
564        #[tokio::test]
565        async fn test_duckdb_table_filter() -> Result<(), Box<dyn Error + Send + Sync>> {
566            let t = setup_tracing();
567            let ctx = SessionContext::new();
568            let pool: Arc<
569                dyn DbConnectionPool<
570                        r2d2::PooledConnection<DuckdbConnectionManager>,
571                        Box<dyn DuckDBSyncParameter>,
572                    > + Send
573                    + Sync,
574            > = Arc::new(DuckDbConnectionPool::new_memory()?)
575                as Arc<
576                    dyn DbConnectionPool<
577                            r2d2::PooledConnection<DuckdbConnectionManager>,
578                            Box<dyn DuckDBSyncParameter>,
579                        > + Send
580                        + Sync,
581                >;
582            let conn = pool.connect().await?;
583            let db_conn = conn
584                .as_any()
585                .downcast_ref::<DuckDbConnection>()
586                .expect("Unable to downcast to DuckDbConnection");
587            db_conn.conn.execute_batch(
588                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
589            )?;
590            let duckdb_table = SqlTable::new("duckdb", &pool, "test", None).await?;
591            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
592            let sql = "SELECT * FROM test_datafusion where a > 1 and b = 'bar' limit 1";
593            let df = ctx.sql(sql).await?;
594            df.show().await?;
595            drop(t);
596            Ok(())
597        }
598    }
599}