1use crate::sql::db_connection_pool::{
8 self,
9 dbconnection::{get_schema, query_arrow},
10 DbConnectionPool,
11};
12use async_trait::async_trait;
13use datafusion::{
14 catalog::Session,
15 physical_plan::execution_plan::{Boundedness, EmissionType},
16 sql::unparser::dialect::{DefaultDialect, Dialect},
17};
18use futures::TryStreamExt;
19use snafu::prelude::*;
20use std::{any::Any, fmt, sync::Arc};
21use std::{
22 fmt::{Display, Formatter},
23 sync::LazyLock,
24};
25
26use datafusion::{
27 arrow::datatypes::{DataType, Field, Schema, SchemaRef},
28 datasource::TableProvider,
29 error::{DataFusionError, Result as DataFusionResult},
30 execution::TaskContext,
31 logical_expr::{
32 logical_plan::builder::LogicalTableSource, Expr, LogicalPlan, LogicalPlanBuilder,
33 TableProviderFilterPushDown, TableType,
34 },
35 physical_expr::EquivalenceProperties,
36 physical_plan::{
37 stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan,
38 Partitioning, PlanProperties, SendableRecordBatchStream,
39 },
40 sql::{unparser::Unparser, TableReference},
41};
42
43mod expr;
44#[cfg(feature = "federation")]
45pub mod federation;
46
/// Errors produced while building or querying a [`SqlTable`].
#[derive(Debug, Snafu)]
pub enum Error {
    /// The connection pool failed to hand out a connection.
    #[snafu(display("Unable to get a DB connection from the pool: {source}"))]
    UnableToGetConnectionFromPool { source: db_connection_pool::Error },

    /// The remote table's Arrow schema could not be retrieved over the connection.
    #[snafu(display("Unable to get schema: {source}"))]
    UnableToGetSchema {
        source: db_connection_pool::dbconnection::Error,
    },

    /// Unparsing a DataFusion logical plan back into SQL text failed.
    #[snafu(display("Unable to generate SQL: {source}"))]
    UnableToGenerateSQL { source: DataFusionError },
}

/// Module-local result alias defaulting the error type to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
62
/// A DataFusion table provider backed by a remote SQL table reached through a
/// [`DbConnectionPool`]. Scans are unparsed back into SQL text and executed on
/// the remote database (see the `TableProvider` impl below).
#[derive(Clone)]
pub struct SqlTable<T: 'static, P: 'static> {
    // Display name used in `Debug`/`Display` output.
    name: String,
    // Pool from which connections are drawn for schema lookup and query execution.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    // Arrow schema of the remote table.
    schema: SchemaRef,
    // Reference to the remote table, used when building the logical scan.
    pub table_reference: TableReference,
    // Dialect used when unparsing plans; `None` falls back to `DefaultDialect`.
    dialect: Option<Arc<dyn Dialect + Send + Sync>>,
}
71
72impl<T, P> fmt::Debug for SqlTable<T, P> {
73 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
74 f.debug_struct("SqlTable")
75 .field("name", &self.name)
76 .field("schema", &self.schema)
77 .field("table_reference", &self.table_reference)
78 .finish()
79 }
80}
81
82impl<T, P> SqlTable<T, P> {
83 pub async fn new(
84 name: &str,
85 pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
86 table_reference: impl Into<TableReference>,
87 ) -> Result<Self> {
88 let table_reference = table_reference.into();
89 let conn = pool
90 .connect()
91 .await
92 .context(UnableToGetConnectionFromPoolSnafu)?;
93
94 let schema = get_schema(conn, &table_reference)
95 .await
96 .context(UnableToGetSchemaSnafu)?;
97
98 Ok(Self::new_with_schema(name, pool, schema, table_reference))
99 }
100
101 pub fn new_with_schema(
102 name: &str,
103 pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
104 schema: impl Into<SchemaRef>,
105 table_reference: impl Into<TableReference>,
106 ) -> Self {
107 Self {
108 name: name.to_owned(),
109 pool: Arc::clone(pool),
110 schema: schema.into(),
111 table_reference: table_reference.into(),
112 dialect: None,
113 }
114 }
115
116 pub fn scan_to_sql(
117 &self,
118 projection: Option<&Vec<usize>>,
119 filters: &[Expr],
120 limit: Option<usize>,
121 ) -> DataFusionResult<String> {
122 let logical_plan = self.create_logical_plan(projection, filters, limit)?;
123 let sql = Unparser::new(self.dialect())
124 .plan_to_sql(&logical_plan)?
125 .to_string();
126
127 Ok(sql)
128 }
129
130 fn create_logical_plan(
131 &self,
132 projection: Option<&Vec<usize>>,
133 filters: &[Expr],
134 limit: Option<usize>,
135 ) -> DataFusionResult<LogicalPlan> {
136 let table_source = LogicalTableSource::new(self.schema());
137 LogicalPlanBuilder::scan_with_filters(
138 self.table_reference.clone(),
139 Arc::new(table_source),
140 projection.cloned(),
141 filters.to_vec(),
142 )?
143 .limit(0, limit)?
144 .build()
145 }
146
147 fn create_physical_plan(
148 &self,
149 projection: Option<&Vec<usize>>,
150 sql: String,
151 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
152 Ok(Arc::new(SqlExec::new(
153 projection,
154 &self.schema(),
155 Arc::clone(&self.pool),
156 sql,
157 )?))
158 }
159
160 #[must_use]
161 pub fn with_dialect(self, dialect: Arc<dyn Dialect + Send + Sync>) -> Self {
162 Self {
163 dialect: Some(dialect),
164 ..self
165 }
166 }
167
168 #[must_use]
169 pub fn name(&self) -> &str {
170 &self.name
171 }
172
173 #[must_use]
174 pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
175 Arc::clone(&self.pool)
176 }
177
178 fn dialect(&self) -> &(dyn Dialect + Send + Sync) {
179 match &self.dialect {
180 Some(dialect) => dialect.as_ref(),
181 None => &DefaultDialect {},
182 }
183 }
184}
185
186#[async_trait]
187impl<T, P> TableProvider for SqlTable<T, P> {
188 fn as_any(&self) -> &dyn Any {
189 self
190 }
191
192 fn schema(&self) -> SchemaRef {
193 Arc::clone(&self.schema)
194 }
195
196 fn table_type(&self) -> TableType {
197 TableType::Base
198 }
199
200 fn supports_filters_pushdown(
201 &self,
202 filters: &[&Expr],
203 ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
204 let filter_push_down: Vec<TableProviderFilterPushDown> = filters
205 .iter()
206 .map(|f| match Unparser::new(self.dialect()).expr_to_sql(f) {
207 Ok(_) => match expr::expr_contains_subquery(f) {
209 Ok(true) => TableProviderFilterPushDown::Unsupported,
210 Ok(false) => TableProviderFilterPushDown::Exact,
211 Err(_) => TableProviderFilterPushDown::Unsupported,
212 },
213 Err(_) => TableProviderFilterPushDown::Unsupported,
214 })
215 .collect();
216
217 Ok(filter_push_down)
218 }
219
220 async fn scan(
221 &self,
222 _state: &dyn Session,
223 projection: Option<&Vec<usize>>,
224 filters: &[Expr],
225 limit: Option<usize>,
226 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
227 let sql = self.scan_to_sql(projection, filters, limit)?;
228 return self.create_physical_plan(projection, sql);
229 }
230}
231
232impl<T, P> Display for SqlTable<T, P> {
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 write!(f, "SqlTable {}", self.name)
235 }
236}
237
// Placeholder schema with a single nullable Int64 column named "1", returned
// by `project_schema_safe` when a projection selects zero columns.
static ONE_COLUMN_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![Field::new("1", DataType::Int64, true)])));
240
241pub fn project_schema_safe(
242 schema: &SchemaRef,
243 projection: Option<&Vec<usize>>,
244) -> DataFusionResult<SchemaRef> {
245 let schema = match projection {
246 Some(columns) => {
247 if columns.is_empty() {
248 Arc::clone(&ONE_COLUMN_SCHEMA)
252 } else {
253 Arc::new(schema.project(columns)?)
254 }
255 }
256 None => Arc::clone(schema),
257 };
258 Ok(schema)
259}
260
/// Physical plan node that executes a pre-rendered SQL string against the
/// remote database and streams the result as Arrow record batches.
#[derive(Clone)]
pub struct SqlExec<T, P> {
    // Schema after applying the scan's projection (see `project_schema_safe`).
    projected_schema: SchemaRef,
    // Pool used to obtain a connection at execution time.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    // The SQL statement to run remotely.
    sql: String,
    // Cached plan properties: single unknown partition, bounded, incremental.
    properties: PlanProperties,
}
268
269impl<T, P> SqlExec<T, P> {
270 pub fn new(
271 projection: Option<&Vec<usize>>,
272 schema: &SchemaRef,
273 pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
274 sql: String,
275 ) -> DataFusionResult<Self> {
276 let projected_schema = project_schema_safe(schema, projection)?;
277
278 Ok(Self {
279 projected_schema: Arc::clone(&projected_schema),
280 pool,
281 sql,
282 properties: PlanProperties::new(
283 EquivalenceProperties::new(projected_schema),
284 Partitioning::UnknownPartitioning(1),
285 EmissionType::Incremental,
286 Boundedness::Bounded,
287 ),
288 })
289 }
290
291 #[must_use]
292 pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
293 Arc::clone(&self.pool)
294 }
295
296 pub fn sql(&self) -> Result<String> {
297 Ok(self.sql.clone())
298 }
299}
300
301impl<T, P> std::fmt::Debug for SqlExec<T, P> {
302 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
303 let sql = self.sql().unwrap_or_default();
304 write!(f, "SqlExec sql={sql}")
305 }
306}
307
308impl<T, P> DisplayAs for SqlExec<T, P> {
309 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
310 let sql = self.sql().unwrap_or_default();
311 write!(f, "SqlExec sql={sql}")
312 }
313}
314
315impl<T: 'static, P: 'static> ExecutionPlan for SqlExec<T, P> {
316 fn name(&self) -> &'static str {
317 "SqlExec"
318 }
319
320 fn as_any(&self) -> &dyn Any {
321 self
322 }
323
324 fn schema(&self) -> SchemaRef {
325 Arc::clone(&self.projected_schema)
326 }
327
328 fn properties(&self) -> &PlanProperties {
329 &self.properties
330 }
331
332 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
333 vec![]
334 }
335
336 fn with_new_children(
337 self: Arc<Self>,
338 _children: Vec<Arc<dyn ExecutionPlan>>,
339 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
340 Ok(self)
341 }
342
343 fn execute(
344 &self,
345 _partition: usize,
346 _context: Arc<TaskContext>,
347 ) -> DataFusionResult<SendableRecordBatchStream> {
348 let sql = self.sql().map_err(to_execution_error)?;
349 tracing::debug!("SqlExec sql: {sql}");
350
351 let schema = self.schema();
352
353 let fut = get_stream(Arc::clone(&self.pool), sql, Arc::clone(&schema));
354
355 let stream = futures::stream::once(fut).try_flatten();
356 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
357 }
358}
359
360pub async fn get_stream<T: 'static, P: 'static>(
361 pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
362 sql: String,
363 projected_schema: SchemaRef,
364) -> DataFusionResult<SendableRecordBatchStream> {
365 let conn = pool.connect().await.map_err(to_execution_error)?;
366
367 query_arrow(conn, sql, Some(projected_schema))
368 .await
369 .map_err(to_execution_error)
370}
371
372#[allow(clippy::needless_pass_by_value)]
373pub fn to_execution_error(
374 e: impl Into<Box<dyn std::error::Error + Send + Sync>>,
375) -> DataFusionError {
376 DataFusionError::Execution(format!("{}", e.into()).to_string())
377}
378
379#[cfg(test)]
380mod tests {
381 use std::{error::Error, sync::Arc};
382
383 use datafusion::execution::context::SessionContext;
384 use datafusion::sql::TableReference;
385 use tracing::{level_filters::LevelFilter, subscriber::DefaultGuard, Dispatch};
386
387 use crate::sql::sql_provider_datafusion::SqlTable;
388
389 fn setup_tracing() -> DefaultGuard {
390 let subscriber: tracing_subscriber::FmtSubscriber = tracing_subscriber::fmt()
391 .with_max_level(LevelFilter::DEBUG)
392 .finish();
393
394 let dispatch = Dispatch::new(subscriber);
395 tracing::dispatcher::set_default(&dispatch)
396 }
397
    // Tests that SQL generation (`scan_to_sql`) produces the expected text
    // without touching a real database, using mock pool/connection stubs.
    mod sql_table_plan_to_sql_tests {
        use std::any::Any;

        use async_trait::async_trait;
        use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
        use datafusion::sql::unparser::dialect::{Dialect, SqliteDialect};
        use datafusion::{
            logical_expr::{col, lit},
            sql::TableReference,
        };

        use crate::sql::db_connection_pool::{
            dbconnection::DbConnection, DbConnectionPool, JoinPushDown,
        };

        use super::*;

        // Connection stub; SQL generation never actually queries it.
        struct MockConn {}

        impl DbConnection<(), &'static dyn ToString> for MockConn {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn as_any_mut(&mut self) -> &mut dyn Any {
                self
            }
        }

        // Pool stub that always hands out a fresh `MockConn`.
        struct MockDBPool {}

        #[async_trait]
        impl DbConnectionPool<(), &'static dyn ToString> for MockDBPool {
            fn connect(
                &self,
            ) -> Result<
                Box<dyn DbConnection<(), &'static dyn ToString>>,
                Box<dyn Error + Send + Sync>,
            > {
                Ok(Box::new(MockConn {}))
            }

            fn join_push_down(&self) -> JoinPushDown {
                JoinPushDown::Disallow
            }
        }

        // Builds a `SqlTable` over a fixed six-column schema backed by the
        // mock pool, optionally with an explicit unparser dialect.
        fn new_sql_table(
            table_reference: &'static str,
            dialect: Option<Arc<dyn Dialect + Send + Sync>>,
        ) -> Result<SqlTable<(), &'static dyn ToString>, Box<dyn Error + Send + Sync>> {
            let fields = vec![
                Field::new("name", DataType::Utf8, false),
                Field::new("age", DataType::Int16, false),
                Field::new(
                    "createdDate",
                    DataType::Timestamp(TimeUnit::Millisecond, None),
                    false,
                ),
                Field::new("userId", DataType::LargeUtf8, false),
                Field::new("active", DataType::Boolean, false),
                // Identifier starting with a digit — exercises quoting.
                Field::new("5e48", DataType::LargeUtf8, false),
            ];
            let schema = Arc::new(Schema::new(fields));
            let pool = Arc::new(MockDBPool {})
                as Arc<dyn DbConnectionPool<(), &'static dyn ToString> + Send + Sync>;
            let table_ref = TableReference::parse_str(table_reference);

            let sql_table = SqlTable::new_with_schema(table_reference, &pool, schema, table_ref);
            if let Some(dialect) = dialect {
                Ok(sql_table.with_dialect(dialect))
            } else {
                Ok(sql_table)
            }
        }

        // Projection-only scan should unparse to a plain SELECT of one column.
        #[tokio::test]
        async fn test_sql_to_string() -> Result<(), Box<dyn Error + Send + Sync>> {
            let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
            let result = sql_table.scan_to_sql(Some(&vec![0]), &[], None)?;
            assert_eq!(result, r#"SELECT `users`.`name` FROM `users`"#);
            Ok(())
        }

        // Filters and limit should appear as WHERE and LIMIT clauses.
        #[tokio::test]
        async fn test_sql_to_string_with_filters_and_limit(
        ) -> Result<(), Box<dyn Error + Send + Sync>> {
            let filters = vec![col("age").gt_eq(lit(30)).and(col("name").eq(lit("x")))];
            let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
            let result = sql_table.scan_to_sql(Some(&vec![0, 1]), &filters, Some(3))?;
            assert_eq!(
                result,
                r#"SELECT `users`.`name`, `users`.`age` FROM `users` WHERE ((`users`.`age` >= 30) AND (`users`.`name` = 'x')) LIMIT 3"#
            );
            Ok(())
        }
    }
495
496 #[test]
497 fn test_references() {
498 let table_ref = TableReference::bare("test");
499 assert_eq!(format!("{table_ref}"), "test");
500 }
501
502 #[cfg(feature = "duckdb")]
504 mod duckdb_tests {
505 use super::*;
506 use crate::sql::db_connection_pool::dbconnection::duckdbconn::{
507 DuckDBSyncParameter, DuckDbConnection,
508 };
509 use crate::sql::db_connection_pool::{duckdbpool::DuckDbConnectionPool, DbConnectionPool};
510 use duckdb::DuckdbConnectionManager;
511
512 #[tokio::test]
513 async fn test_duckdb_table() -> Result<(), Box<dyn Error + Send + Sync>> {
514 let t = setup_tracing();
515 let ctx = SessionContext::new();
516 let pool: Arc<
517 dyn DbConnectionPool<
518 r2d2::PooledConnection<DuckdbConnectionManager>,
519 Box<dyn DuckDBSyncParameter>,
520 > + Send
521 + Sync,
522 > = Arc::new(DuckDbConnectionPool::new_memory()?)
523 as Arc<
524 dyn DbConnectionPool<
525 r2d2::PooledConnection<DuckdbConnectionManager>,
526 Box<dyn DuckDBSyncParameter>,
527 > + Send
528 + Sync,
529 >;
530 let conn = pool.connect().await?;
531 let db_conn = conn
532 .as_any()
533 .downcast_ref::<DuckDbConnection>()
534 .expect("Unable to downcast to DuckDbConnection");
535 db_conn.conn.execute_batch(
536 "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
537 )?;
538 let duckdb_table = SqlTable::new("duckdb", &pool, "test").await?;
539 ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
540 let sql = "SELECT * FROM test_datafusion limit 1";
541 let df = ctx.sql(sql).await?;
542 df.show().await?;
543 drop(t);
544 Ok(())
545 }
546
547 #[tokio::test]
548 async fn test_duckdb_table_filter() -> Result<(), Box<dyn Error + Send + Sync>> {
549 let t = setup_tracing();
550 let ctx = SessionContext::new();
551 let pool: Arc<
552 dyn DbConnectionPool<
553 r2d2::PooledConnection<DuckdbConnectionManager>,
554 Box<dyn DuckDBSyncParameter>,
555 > + Send
556 + Sync,
557 > = Arc::new(DuckDbConnectionPool::new_memory()?)
558 as Arc<
559 dyn DbConnectionPool<
560 r2d2::PooledConnection<DuckdbConnectionManager>,
561 Box<dyn DuckDBSyncParameter>,
562 > + Send
563 + Sync,
564 >;
565 let conn = pool.connect().await?;
566 let db_conn = conn
567 .as_any()
568 .downcast_ref::<DuckDbConnection>()
569 .expect("Unable to downcast to DuckDbConnection");
570 db_conn.conn.execute_batch(
571 "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
572 )?;
573 let duckdb_table = SqlTable::new("duckdb", &pool, "test").await?;
574 ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
575 let sql = "SELECT * FROM test_datafusion where a > 1 and b = 'bar' limit 1";
576 let df = ctx.sql(sql).await?;
577 df.show().await?;
578 drop(t);
579 Ok(())
580 }
581 }
582}