1use crate::sql::db_connection_pool::{
8 self,
9 dbconnection::{get_schema, query_arrow},
10 DbConnectionPool,
11};
12use async_trait::async_trait;
13use datafusion::{
14 catalog::Session,
15 physical_plan::execution_plan::{Boundedness, EmissionType},
16 sql::unparser::dialect::{
17 DefaultDialect, Dialect, MySqlDialect, PostgreSqlDialect, SqliteDialect,
18 },
19};
20use futures::TryStreamExt;
21use snafu::prelude::*;
22use std::fmt::{Display, Formatter};
23use std::{any::Any, fmt, sync::Arc};
24
25use datafusion::{
26 arrow::datatypes::SchemaRef,
27 datasource::TableProvider,
28 error::{DataFusionError, Result as DataFusionResult},
29 execution::TaskContext,
30 logical_expr::{
31 logical_plan::builder::LogicalTableSource, Expr, LogicalPlan, LogicalPlanBuilder,
32 TableProviderFilterPushDown, TableType,
33 },
34 physical_expr::EquivalenceProperties,
35 physical_plan::{
36 stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan,
37 Partitioning, PlanProperties, SendableRecordBatchStream,
38 },
39 sql::{unparser::Unparser, TableReference},
40};
41
42#[cfg(feature = "federation")]
43pub mod federation;
44
/// Errors produced by this SQL-backed `TableProvider` module.
#[derive(Debug, Snafu)]
pub enum Error {
    /// The connection pool could not hand out a connection.
    #[snafu(display("Unable to get a DB connection from the pool: {source}"))]
    UnableToGetConnectionFromPool { source: db_connection_pool::Error },

    /// Schema discovery against the remote table failed.
    #[snafu(display("Unable to get schema: {source}"))]
    UnableToGetSchema {
        source: db_connection_pool::dbconnection::Error,
    },

    /// Unparsing a logical plan back into SQL text failed.
    #[snafu(display("Unable to generate SQL: {source}"))]
    UnableToGenerateSQL { source: DataFusionError },
}
58
/// Target database engine; determines the SQL dialect used when unparsing
/// DataFusion plans (see `Engine::dialect`).
#[derive(Clone, Copy, Debug)]
pub enum Engine {
    Spark,
    SQLite,
    DuckDB,
    ODBC,
    Postgres,
    MySQL,
    /// Fallback when no engine is specified; maps to the generic dialect.
    Default,
}
69
70impl Engine {
71 pub fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
73 match self {
74 Engine::SQLite => Arc::new(SqliteDialect {}),
75 Engine::Postgres => Arc::new(PostgreSqlDialect {}),
76 Engine::MySQL => Arc::new(MySqlDialect {}),
77 Engine::Spark | Engine::DuckDB | Engine::ODBC | Engine::Default => {
78 Arc::new(DefaultDialect {})
79 }
80 }
81 }
82}
83
/// Module-level result alias defaulting the error type to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
85
/// A DataFusion `TableProvider` backed by a remote SQL table reached through
/// a [`DbConnectionPool`].
///
/// `T` is the pool's connection type and `P` its query-parameter type.
#[derive(Clone)]
pub struct SqlTable<T: 'static, P: 'static> {
    // Name of this provider; used in `Display`/`Debug` output.
    name: String,
    // Pool used to obtain connections at scan time.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    // Arrow schema of the remote table.
    schema: SchemaRef,
    // Fully-qualified reference to the remote table.
    pub table_reference: TableReference,
    // Selects the SQL dialect used when unparsing plans.
    engine: Engine,
}
94
impl<T, P> fmt::Debug for SqlTable<T, P> {
    // Manual impl: `pool` is a trait object without `Debug`, so it is
    // deliberately omitted from the output.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("SqlTable")
            .field("name", &self.name)
            .field("schema", &self.schema)
            .field("table_reference", &self.table_reference)
            .field("engine", &self.engine)
            .finish()
    }
}
105
106impl<T, P> SqlTable<T, P> {
107 pub async fn new(
108 name: &str,
109 pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
110 table_reference: impl Into<TableReference>,
111 engine: Option<Engine>,
112 ) -> Result<Self> {
113 let table_reference = table_reference.into();
114 let conn = pool
115 .connect()
116 .await
117 .context(UnableToGetConnectionFromPoolSnafu)?;
118
119 let schema = get_schema(conn, &table_reference)
120 .await
121 .context(UnableToGetSchemaSnafu)?;
122
123 Ok(Self::new_with_schema(
124 name,
125 pool,
126 schema,
127 table_reference,
128 engine,
129 ))
130 }
131
132 pub fn new_with_schema(
133 name: &str,
134 pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
135 schema: impl Into<SchemaRef>,
136 table_reference: impl Into<TableReference>,
137 engine: Option<Engine>,
138 ) -> Self {
139 let engine = engine.unwrap_or(Engine::Default);
140 Self {
141 name: name.to_owned(),
142 pool: Arc::clone(pool),
143 schema: schema.into(),
144 table_reference: table_reference.into(),
145 engine,
146 }
147 }
148
149 pub fn scan_to_sql(
150 &self,
151 projection: Option<&Vec<usize>>,
152 filters: &[Expr],
153 limit: Option<usize>,
154 ) -> DataFusionResult<String> {
155 let logical_plan = self.create_logical_plan(projection, filters, limit)?;
156 let sql = Unparser::new(self.engine.dialect().as_ref())
157 .plan_to_sql(&logical_plan)?
158 .to_string();
159
160 Ok(sql)
161 }
162
163 fn create_logical_plan(
164 &self,
165 projection: Option<&Vec<usize>>,
166 filters: &[Expr],
167 limit: Option<usize>,
168 ) -> DataFusionResult<LogicalPlan> {
169 let table_source = LogicalTableSource::new(self.schema());
170 LogicalPlanBuilder::scan_with_filters(
171 self.table_reference.clone(),
172 Arc::new(table_source),
173 projection.cloned(),
174 filters.to_vec(),
175 )?
176 .limit(0, limit)?
177 .build()
178 }
179
180 fn create_physical_plan(
181 &self,
182 projection: Option<&Vec<usize>>,
183 sql: String,
184 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
185 Ok(Arc::new(SqlExec::new(
186 projection,
187 &self.schema(),
188 Arc::clone(&self.pool),
189 sql,
190 )?))
191 }
192
193 fn unique_id(&self) -> usize {
195 std::ptr::from_ref(self) as usize
196 }
197
198 #[must_use]
199 pub fn name(&self) -> &str {
200 &self.name
201 }
202
203 #[must_use]
204 pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
205 Arc::clone(&self.pool)
206 }
207}
208
209#[async_trait]
210impl<T, P> TableProvider for SqlTable<T, P> {
211 fn as_any(&self) -> &dyn Any {
212 self
213 }
214
215 fn schema(&self) -> SchemaRef {
216 Arc::clone(&self.schema)
217 }
218
219 fn table_type(&self) -> TableType {
220 TableType::Base
221 }
222
223 fn supports_filters_pushdown(
224 &self,
225 filters: &[&Expr],
226 ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
227 let filter_push_down: Vec<TableProviderFilterPushDown> = filters
228 .iter()
229 .map(
230 |f| match Unparser::new(self.engine.dialect().as_ref()).expr_to_sql(f) {
231 Ok(_) => TableProviderFilterPushDown::Exact,
232 Err(_) => TableProviderFilterPushDown::Unsupported,
233 },
234 )
235 .collect();
236
237 Ok(filter_push_down)
238 }
239
240 async fn scan(
241 &self,
242 _state: &dyn Session,
243 projection: Option<&Vec<usize>>,
244 filters: &[Expr],
245 limit: Option<usize>,
246 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
247 let sql = self.scan_to_sql(projection, filters, limit)?;
248 return self.create_physical_plan(projection, sql);
249 }
250}
251
252impl<T, P> Display for SqlTable<T, P> {
253 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
254 write!(f, "SqlTable {}", self.name)
255 }
256}
257
258pub fn project_schema_safe(
259 schema: &SchemaRef,
260 projection: Option<&Vec<usize>>,
261) -> DataFusionResult<SchemaRef> {
262 let schema = match projection {
263 Some(columns) => {
264 if columns.is_empty() {
265 Arc::clone(schema)
266 } else {
267 Arc::new(schema.project(columns)?)
268 }
269 }
270 None => Arc::clone(schema),
271 };
272 Ok(schema)
273}
274
/// Physical plan node that executes a fixed SQL string through the pool and
/// streams the result back as Arrow record batches.
#[derive(Clone)]
pub struct SqlExec<T, P> {
    // Schema after applying the scan projection.
    projected_schema: SchemaRef,
    // Pool used to obtain a connection when the node is executed.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    // The SQL statement to run.
    sql: String,
    // Cached plan properties (single unknown partition, bounded, incremental).
    properties: PlanProperties,
}
282
283impl<T, P> SqlExec<T, P> {
284 pub fn new(
285 projection: Option<&Vec<usize>>,
286 schema: &SchemaRef,
287 pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
288 sql: String,
289 ) -> DataFusionResult<Self> {
290 let projected_schema = project_schema_safe(schema, projection)?;
291
292 Ok(Self {
293 projected_schema: Arc::clone(&projected_schema),
294 pool,
295 sql,
296 properties: PlanProperties::new(
297 EquivalenceProperties::new(projected_schema),
298 Partitioning::UnknownPartitioning(1),
299 EmissionType::Incremental,
300 Boundedness::Bounded,
301 ),
302 })
303 }
304
305 #[must_use]
306 pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
307 Arc::clone(&self.pool)
308 }
309
310 pub fn sql(&self) -> Result<String> {
311 Ok(self.sql.clone())
312 }
313}
314
315impl<T, P> std::fmt::Debug for SqlExec<T, P> {
316 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
317 let sql = self.sql().unwrap_or_default();
318 write!(f, "SqlExec sql={sql}")
319 }
320}
321
322impl<T, P> DisplayAs for SqlExec<T, P> {
323 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
324 let sql = self.sql().unwrap_or_default();
325 write!(f, "SqlExec sql={sql}")
326 }
327}
328
329impl<T: 'static, P: 'static> ExecutionPlan for SqlExec<T, P> {
330 fn name(&self) -> &'static str {
331 "SqlExec"
332 }
333
334 fn as_any(&self) -> &dyn Any {
335 self
336 }
337
338 fn schema(&self) -> SchemaRef {
339 Arc::clone(&self.projected_schema)
340 }
341
342 fn properties(&self) -> &PlanProperties {
343 &self.properties
344 }
345
346 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
347 vec![]
348 }
349
350 fn with_new_children(
351 self: Arc<Self>,
352 _children: Vec<Arc<dyn ExecutionPlan>>,
353 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
354 Ok(self)
355 }
356
357 fn execute(
358 &self,
359 _partition: usize,
360 _context: Arc<TaskContext>,
361 ) -> DataFusionResult<SendableRecordBatchStream> {
362 let sql = self.sql().map_err(to_execution_error)?;
363 tracing::debug!("SqlExec sql: {sql}");
364
365 let schema = self.schema();
366
367 let fut = get_stream(Arc::clone(&self.pool), sql, Arc::clone(&schema));
368
369 let stream = futures::stream::once(fut).try_flatten();
370 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
371 }
372}
373
374pub async fn get_stream<T: 'static, P: 'static>(
375 pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
376 sql: String,
377 projected_schema: SchemaRef,
378) -> DataFusionResult<SendableRecordBatchStream> {
379 let conn = pool.connect().await.map_err(to_execution_error)?;
380
381 query_arrow(conn, sql, Some(projected_schema))
382 .await
383 .map_err(to_execution_error)
384}
385
386#[allow(clippy::needless_pass_by_value)]
387pub fn to_execution_error(
388 e: impl Into<Box<dyn std::error::Error + Send + Sync>>,
389) -> DataFusionError {
390 DataFusionError::Execution(format!("{}", e.into()).to_string())
391}
392
#[cfg(test)]
mod tests {
    use std::{error::Error, sync::Arc};

    use datafusion::execution::context::SessionContext;
    use datafusion::sql::TableReference;
    use tracing::{level_filters::LevelFilter, subscriber::DefaultGuard, Dispatch};

    use crate::sql::sql_provider_datafusion::SqlTable;

    // Installs a DEBUG-level fmt subscriber as the thread-default dispatcher;
    // it is removed when the returned guard is dropped.
    fn setup_tracing() -> DefaultGuard {
        let subscriber: tracing_subscriber::FmtSubscriber = tracing_subscriber::fmt()
            .with_max_level(LevelFilter::DEBUG)
            .finish();

        let dispatch = Dispatch::new(subscriber);
        tracing::dispatcher::set_default(&dispatch)
    }

    // Unit tests for plan -> SQL unparsing with a mock (non-executing) pool.
    mod sql_table_plan_to_sql_tests {
        use std::any::Any;

        use async_trait::async_trait;
        use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
        use datafusion::{
            logical_expr::{col, lit},
            sql::TableReference,
        };

        use crate::sql::db_connection_pool::{
            dbconnection::DbConnection, DbConnectionPool, JoinPushDown,
        };

        use super::super::Engine;
        use super::*;

        // No-op connection; only exists to satisfy the pool trait.
        struct MockConn {}

        impl DbConnection<(), &'static dyn ToString> for MockConn {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn as_any_mut(&mut self) -> &mut dyn Any {
                self
            }
        }

        // Pool that always hands out a `MockConn`; never touches a database.
        struct MockDBPool {}

        #[async_trait]
        impl DbConnectionPool<(), &'static dyn ToString> for MockDBPool {
            async fn connect(
                &self,
            ) -> Result<
                Box<dyn DbConnection<(), &'static dyn ToString>>,
                Box<dyn Error + Send + Sync>,
            > {
                Ok(Box::new(MockConn {}))
            }

            fn join_push_down(&self) -> JoinPushDown {
                JoinPushDown::Disallow
            }
        }

        // Builds a `SqlTable` over a fixed six-column schema and the mock pool.
        // The column set deliberately includes camelCase and digit-leading
        // names to exercise identifier quoting in the unparser.
        fn new_sql_table(
            table_reference: &'static str,
            engine: Option<Engine>,
        ) -> Result<SqlTable<(), &'static dyn ToString>, Box<dyn Error + Send + Sync>> {
            let fields = vec![
                Field::new("name", DataType::Utf8, false),
                Field::new("age", DataType::Int16, false),
                Field::new(
                    "createdDate",
                    DataType::Timestamp(TimeUnit::Millisecond, None),
                    false,
                ),
                Field::new("userId", DataType::LargeUtf8, false),
                Field::new("active", DataType::Boolean, false),
                Field::new("5e48", DataType::LargeUtf8, false),
            ];
            let schema = Arc::new(Schema::new(fields));
            let pool = Arc::new(MockDBPool {})
                as Arc<dyn DbConnectionPool<(), &'static dyn ToString> + Send + Sync>;
            let table_ref = TableReference::parse_str(table_reference);

            Ok(SqlTable::new_with_schema(
                table_reference,
                &pool,
                schema,
                table_ref,
                engine,
            ))
        }

        #[tokio::test]
        async fn test_sql_to_string() -> Result<(), Box<dyn Error + Send + Sync>> {
            let sql_table = new_sql_table("users", Some(Engine::SQLite))?;
            let result = sql_table.scan_to_sql(Some(&vec![0]), &[], None)?;
            assert_eq!(result, r#"SELECT `users`.`name` FROM `users`"#);
            Ok(())
        }

        #[tokio::test]
        async fn test_sql_to_string_with_filters_and_limit(
        ) -> Result<(), Box<dyn Error + Send + Sync>> {
            let filters = vec![col("age").gt_eq(lit(30)).and(col("name").eq(lit("x")))];
            let sql_table = new_sql_table("users", Some(Engine::SQLite))?;
            let result = sql_table.scan_to_sql(Some(&vec![0, 1]), &filters, Some(3))?;
            assert_eq!(
                result,
                r#"SELECT `users`.`name`, `users`.`age` FROM `users` WHERE ((`users`.`age` >= 30) AND (`users`.`name` = 'x')) LIMIT 3"#
            );
            Ok(())
        }

    }

    #[test]
    fn test_references() {
        let table_ref = TableReference::bare("test");
        assert_eq!(format!("{table_ref}"), "test");
    }

    // End-to-end tests against an in-memory DuckDB; compiled only with the
    // `duckdb` feature enabled.
    #[cfg(feature = "duckdb")]
    mod duckdb_tests {
        use super::*;
        use crate::sql::db_connection_pool::dbconnection::duckdbconn::{
            DuckDBSyncParameter, DuckDbConnection,
        };
        use crate::sql::db_connection_pool::{duckdbpool::DuckDbConnectionPool, DbConnectionPool};
        use duckdb::DuckdbConnectionManager;

        #[tokio::test]
        async fn test_duckdb_table() -> Result<(), Box<dyn Error + Send + Sync>> {
            let t = setup_tracing();
            let ctx = SessionContext::new();
            let pool: Arc<
                dyn DbConnectionPool<
                        r2d2::PooledConnection<DuckdbConnectionManager>,
                        Box<dyn DuckDBSyncParameter>,
                    > + Send
                    + Sync,
            > = Arc::new(DuckDbConnectionPool::new_memory()?)
                as Arc<
                    dyn DbConnectionPool<
                            r2d2::PooledConnection<DuckdbConnectionManager>,
                            Box<dyn DuckDBSyncParameter>,
                        > + Send
                        + Sync,
                >;
            let conn = pool.connect().await?;
            let db_conn = conn
                .as_any()
                .downcast_ref::<DuckDbConnection>()
                .expect("Unable to downcast to DuckDbConnection");
            db_conn.conn.execute_batch(
                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
            )?;
            let duckdb_table = SqlTable::new("duckdb", &pool, "test", None).await?;
            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
            let sql = "SELECT * FROM test_datafusion limit 1";
            let df = ctx.sql(sql).await?;
            df.show().await?;
            drop(t);
            Ok(())
        }

        #[tokio::test]
        async fn test_duckdb_table_filter() -> Result<(), Box<dyn Error + Send + Sync>> {
            let t = setup_tracing();
            let ctx = SessionContext::new();
            let pool: Arc<
                dyn DbConnectionPool<
                        r2d2::PooledConnection<DuckdbConnectionManager>,
                        Box<dyn DuckDBSyncParameter>,
                    > + Send
                    + Sync,
            > = Arc::new(DuckDbConnectionPool::new_memory()?)
                as Arc<
                    dyn DbConnectionPool<
                            r2d2::PooledConnection<DuckdbConnectionManager>,
                            Box<dyn DuckDBSyncParameter>,
                        > + Send
                        + Sync,
                >;
            let conn = pool.connect().await?;
            let db_conn = conn
                .as_any()
                .downcast_ref::<DuckDbConnection>()
                .expect("Unable to downcast to DuckDbConnection");
            db_conn.conn.execute_batch(
                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
            )?;
            let duckdb_table = SqlTable::new("duckdb", &pool, "test", None).await?;
            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
            let sql = "SELECT * FROM test_datafusion where a > 1 and b = 'bar' limit 1";
            let df = ctx.sql(sql).await?;
            df.show().await?;
            drop(t);
            Ok(())
        }
    }
}