1use crate::sql::db_connection_pool::{
8 self,
9 dbconnection::{get_schema, query_arrow},
10 DbConnectionPool,
11};
12use async_trait::async_trait;
13use datafusion::{
14 catalog::Session,
15 physical_plan::execution_plan::{Boundedness, EmissionType},
16 sql::unparser::dialect::{DefaultDialect, Dialect},
17};
18use futures::TryStreamExt;
19use snafu::prelude::*;
20use std::{any::Any, fmt, sync::Arc};
21use std::{
22 fmt::{Display, Formatter},
23 sync::LazyLock,
24};
25
26use datafusion::{
27 arrow::datatypes::{DataType, Field, Schema, SchemaRef},
28 datasource::TableProvider,
29 error::{DataFusionError, Result as DataFusionResult},
30 execution::TaskContext,
31 logical_expr::{
32 logical_plan::builder::LogicalTableSource, Expr, LogicalPlan, LogicalPlanBuilder,
33 TableProviderFilterPushDown, TableType,
34 },
35 physical_expr::EquivalenceProperties,
36 physical_plan::{
37 stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan,
38 Partitioning, PlanProperties, SendableRecordBatchStream,
39 },
40 sql::{unparser::Unparser, TableReference},
41};
42
43mod expr;
44#[cfg(feature = "federation")]
45pub mod federation;
46
/// Errors produced while constructing a [`SqlTable`] or generating SQL for it.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Acquiring a connection from the underlying pool failed.
    #[snafu(display("Unable to get a DB connection from the pool: {source}"))]
    UnableToGetConnectionFromPool { source: db_connection_pool::Error },

    /// Fetching the remote table's Arrow schema over a connection failed.
    #[snafu(display("Unable to get schema: {source}"))]
    UnableToGetSchema {
        source: db_connection_pool::dbconnection::Error,
    },

    /// Unparsing a logical plan back into SQL text failed.
    #[snafu(display("Unable to generate SQL: {source}"))]
    UnableToGenerateSQL { source: DataFusionError },
}
60
61pub type Result<T, E = Error> = std::result::Result<T, E>;
62
/// A DataFusion table provider that pushes scans down to a remote SQL
/// database reached through a [`DbConnectionPool`].
#[derive(Clone)]
pub struct SqlTable<T: 'static, P: 'static> {
    // Display name used by the `Debug`/`Display` impls.
    name: String,
    // Pool used to open a connection whenever the generated SQL is executed.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    // Arrow schema of the remote table.
    schema: SchemaRef,
    pub table_reference: TableReference,
    // Dialect used when unparsing plans to SQL; `None` means `DefaultDialect`.
    dialect: Option<Arc<dyn Dialect + Send + Sync>>,
}
71
72impl<T, P> fmt::Debug for SqlTable<T, P> {
73 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
74 f.debug_struct("SqlTable")
75 .field("name", &self.name)
76 .field("schema", &self.schema)
77 .field("table_reference", &self.table_reference)
78 .finish()
79 }
80}
81
82impl<T, P> SqlTable<T, P> {
83 pub async fn new(
84 name: &str,
85 pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
86 table_reference: impl Into<TableReference>,
87 ) -> Result<Self> {
88 let table_reference = table_reference.into();
89 let conn = pool
90 .connect()
91 .await
92 .context(UnableToGetConnectionFromPoolSnafu)?;
93
94 let schema = get_schema(conn, &table_reference)
95 .await
96 .context(UnableToGetSchemaSnafu)?;
97
98 Ok(Self::new_with_schema(name, pool, schema, table_reference))
99 }
100
101 pub fn new_with_schema(
102 name: &str,
103 pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
104 schema: impl Into<SchemaRef>,
105 table_reference: impl Into<TableReference>,
106 ) -> Self {
107 Self {
108 name: name.to_owned(),
109 pool: Arc::clone(pool),
110 schema: schema.into(),
111 table_reference: table_reference.into(),
112 dialect: None,
113 }
114 }
115
116 pub fn scan_to_sql(
117 &self,
118 projection: Option<&Vec<usize>>,
119 filters: &[Expr],
120 limit: Option<usize>,
121 ) -> DataFusionResult<String> {
122 let logical_plan = self.create_logical_plan(projection, filters, limit)?;
123 let sql = Unparser::new(self.dialect())
124 .plan_to_sql(&logical_plan)?
125 .to_string();
126
127 Ok(sql)
128 }
129
130 fn create_logical_plan(
131 &self,
132 projection: Option<&Vec<usize>>,
133 filters: &[Expr],
134 limit: Option<usize>,
135 ) -> DataFusionResult<LogicalPlan> {
136 let table_source = LogicalTableSource::new(self.schema());
137 LogicalPlanBuilder::scan_with_filters(
138 self.table_reference.clone(),
139 Arc::new(table_source),
140 projection.cloned(),
141 filters.to_vec(),
142 )?
143 .limit(0, limit)?
144 .build()
145 }
146
147 fn create_physical_plan(
148 &self,
149 projection: Option<&Vec<usize>>,
150 sql: String,
151 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
152 Ok(Arc::new(SqlExec::new(
153 projection,
154 &self.schema(),
155 Arc::clone(&self.pool),
156 sql,
157 )?))
158 }
159
160 #[must_use]
161 pub fn with_dialect(self, dialect: Arc<dyn Dialect + Send + Sync>) -> Self {
162 Self {
163 dialect: Some(dialect),
164 ..self
165 }
166 }
167
168 #[must_use]
169 pub fn name(&self) -> &str {
170 &self.name
171 }
172
173 #[must_use]
174 pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
175 Arc::clone(&self.pool)
176 }
177
178 fn dialect(&self) -> &(dyn Dialect + Send + Sync) {
179 match &self.dialect {
180 Some(dialect) => dialect.as_ref(),
181 None => &DefaultDialect {},
182 }
183 }
184}
185
186#[async_trait]
187impl<T, P> TableProvider for SqlTable<T, P> {
188 fn as_any(&self) -> &dyn Any {
189 self
190 }
191
192 fn schema(&self) -> SchemaRef {
193 Arc::clone(&self.schema)
194 }
195
196 fn table_type(&self) -> TableType {
197 TableType::Base
198 }
199
200 fn supports_filters_pushdown(
201 &self,
202 filters: &[&Expr],
203 ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
204 let filter_push_down = default_filter_pushdown(filters, self.dialect());
205 Ok(filter_push_down)
206 }
207
208 async fn scan(
209 &self,
210 _state: &dyn Session,
211 projection: Option<&Vec<usize>>,
212 filters: &[Expr],
213 limit: Option<usize>,
214 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
215 let sql = self.scan_to_sql(projection, filters, limit)?;
216 return self.create_physical_plan(projection, sql);
217 }
218}
219
220pub fn default_filter_pushdown(
221 filters: &[&Expr],
222 dialect: &dyn Dialect,
223) -> Vec<TableProviderFilterPushDown> {
224 filters
225 .iter()
226 .map(|f| match Unparser::new(dialect).expr_to_sql(f) {
227 Ok(_) => match expr::expr_contains_subquery(f) {
229 Ok(true) => TableProviderFilterPushDown::Unsupported,
230 Ok(false) => TableProviderFilterPushDown::Exact,
231 Err(_) => TableProviderFilterPushDown::Unsupported,
232 },
233 Err(_) => TableProviderFilterPushDown::Unsupported,
234 })
235 .collect()
236}
237
238impl<T, P> Display for SqlTable<T, P> {
239 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240 write!(f, "SqlTable {}", self.name)
241 }
242}
243
// Placeholder schema substituted when a scan projects zero columns: a single
// nullable Int64 column named "1". Presumably used for queries that need row
// counts but no column data (e.g. COUNT(*)) — confirm against callers.
static ONE_COLUMN_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![Field::new("1", DataType::Int64, true)])));
246
247pub fn project_schema_safe(
248 schema: &SchemaRef,
249 projection: Option<&Vec<usize>>,
250) -> DataFusionResult<SchemaRef> {
251 let schema = match projection {
252 Some(columns) => {
253 if columns.is_empty() {
254 Arc::clone(&ONE_COLUMN_SCHEMA)
258 } else {
259 Arc::new(schema.project(columns)?)
260 }
261 }
262 None => Arc::clone(schema),
263 };
264 Ok(schema)
265}
266
/// Leaf physical plan node that executes a pre-rendered SQL string against a
/// pooled database connection and streams back Arrow record batches.
#[derive(Clone)]
pub struct SqlExec<T, P> {
    // Output schema after applying the scan's projection.
    projected_schema: SchemaRef,
    // Pool used to open a connection at execution time.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    // The SQL statement to execute.
    sql: String,
    // Cached plan properties (single unknown partition, bounded).
    properties: PlanProperties,
}
274
275impl<T, P> SqlExec<T, P> {
276 pub fn new(
277 projection: Option<&Vec<usize>>,
278 schema: &SchemaRef,
279 pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
280 sql: String,
281 ) -> DataFusionResult<Self> {
282 let projected_schema = project_schema_safe(schema, projection)?;
283
284 Ok(Self {
285 projected_schema: Arc::clone(&projected_schema),
286 pool,
287 sql,
288 properties: PlanProperties::new(
289 EquivalenceProperties::new(projected_schema),
290 Partitioning::UnknownPartitioning(1),
291 EmissionType::Incremental,
292 Boundedness::Bounded,
293 ),
294 })
295 }
296
297 #[must_use]
298 pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
299 Arc::clone(&self.pool)
300 }
301
302 pub fn sql(&self) -> Result<String> {
303 Ok(self.sql.clone())
304 }
305}
306
307impl<T, P> std::fmt::Debug for SqlExec<T, P> {
308 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
309 let sql = self.sql().unwrap_or_default();
310 write!(f, "SqlExec sql={sql}")
311 }
312}
313
314impl<T, P> DisplayAs for SqlExec<T, P> {
315 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
316 let sql = self.sql().unwrap_or_default();
317 write!(f, "SqlExec sql={sql}")
318 }
319}
320
321impl<T: 'static, P: 'static> ExecutionPlan for SqlExec<T, P> {
322 fn name(&self) -> &'static str {
323 "SqlExec"
324 }
325
326 fn as_any(&self) -> &dyn Any {
327 self
328 }
329
330 fn schema(&self) -> SchemaRef {
331 Arc::clone(&self.projected_schema)
332 }
333
334 fn properties(&self) -> &PlanProperties {
335 &self.properties
336 }
337
338 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
339 vec![]
340 }
341
342 fn with_new_children(
343 self: Arc<Self>,
344 _children: Vec<Arc<dyn ExecutionPlan>>,
345 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
346 Ok(self)
347 }
348
349 fn execute(
350 &self,
351 _partition: usize,
352 _context: Arc<TaskContext>,
353 ) -> DataFusionResult<SendableRecordBatchStream> {
354 let sql = self.sql().map_err(to_execution_error)?;
355 tracing::debug!("SqlExec sql: {sql}");
356
357 let schema = self.schema();
358
359 let fut = get_stream(Arc::clone(&self.pool), sql, Arc::clone(&schema));
360
361 let stream = futures::stream::once(fut).try_flatten();
362 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
363 }
364}
365
366pub async fn get_stream<T: 'static, P: 'static>(
367 pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
368 sql: String,
369 projected_schema: SchemaRef,
370) -> DataFusionResult<SendableRecordBatchStream> {
371 let conn = pool.connect().await.map_err(to_execution_error)?;
372
373 query_arrow(conn, sql, Some(projected_schema))
374 .await
375 .map_err(to_execution_error)
376}
377
378#[allow(clippy::needless_pass_by_value)]
379pub fn to_execution_error(
380 e: impl Into<Box<dyn std::error::Error + Send + Sync>>,
381) -> DataFusionError {
382 DataFusionError::Execution(format!("{}", e.into()).to_string())
383}
384
385#[cfg(test)]
386mod tests {
387 use std::{error::Error, sync::Arc};
388
389 use datafusion::execution::context::SessionContext;
390 use datafusion::sql::TableReference;
391 use tracing::{level_filters::LevelFilter, subscriber::DefaultGuard, Dispatch};
392
393 use crate::sql::sql_provider_datafusion::SqlTable;
394
    /// Installs a thread-local DEBUG-level fmt subscriber for the duration of
    /// a test; logging reverts when the returned guard is dropped.
    fn setup_tracing() -> DefaultGuard {
        let subscriber: tracing_subscriber::FmtSubscriber = tracing_subscriber::fmt()
            .with_max_level(LevelFilter::DEBUG)
            .finish();

        let dispatch = Dispatch::new(subscriber);
        tracing::dispatcher::set_default(&dispatch)
    }
403
404 mod sql_table_plan_to_sql_tests {
405 use std::any::Any;
406
407 use async_trait::async_trait;
408 use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
409 use datafusion::sql::unparser::dialect::{Dialect, SqliteDialect};
410 use datafusion::{
411 logical_expr::{col, lit},
412 sql::TableReference,
413 };
414
415 use crate::sql::db_connection_pool::{
416 dbconnection::DbConnection, DbConnectionPool, JoinPushDown,
417 };
418
419 use super::*;
420
        // Connection stub: only the `Any` downcast plumbing is needed for
        // these unparse-only tests; no query methods are exercised.
        struct MockConn {}

        impl DbConnection<(), &'static dyn ToString> for MockConn {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn as_any_mut(&mut self) -> &mut dyn Any {
                self
            }
        }

        // Pool stub that always hands out a fresh `MockConn` and disallows
        // join pushdown.
        struct MockDBPool {}

        #[async_trait]
        impl DbConnectionPool<(), &'static dyn ToString> for MockDBPool {
            async fn connect(
                &self,
            ) -> Result<
                Box<dyn DbConnection<(), &'static dyn ToString>>,
                Box<dyn Error + Send + Sync>,
            > {
                Ok(Box::new(MockConn {}))
            }

            fn join_push_down(&self) -> JoinPushDown {
                JoinPushDown::Disallow
            }
        }
450
        /// Builds a `SqlTable` over a fixed six-column schema backed by the
        /// mock pool, optionally configured with an explicit `dialect`.
        fn new_sql_table(
            table_reference: &'static str,
            dialect: Option<Arc<dyn Dialect + Send + Sync>>,
        ) -> Result<SqlTable<(), &'static dyn ToString>, Box<dyn Error + Send + Sync>> {
            let fields = vec![
                Field::new("name", DataType::Utf8, false),
                Field::new("age", DataType::Int16, false),
                Field::new(
                    "createdDate",
                    DataType::Timestamp(TimeUnit::Millisecond, None),
                    false,
                ),
                Field::new("userId", DataType::LargeUtf8, false),
                Field::new("active", DataType::Boolean, false),
                Field::new("5e48", DataType::LargeUtf8, false),
            ];
            let schema = Arc::new(Schema::new(fields));
            let pool = Arc::new(MockDBPool {})
                as Arc<dyn DbConnectionPool<(), &'static dyn ToString> + Send + Sync>;
            let table_ref = TableReference::parse_str(table_reference);

            let sql_table = SqlTable::new_with_schema(table_reference, &pool, schema, table_ref);
            if let Some(dialect) = dialect {
                Ok(sql_table.with_dialect(dialect))
            } else {
                Ok(sql_table)
            }
        }
479
        /// A single-column projection unparses to a backtick-quoted SQLite
        /// SELECT over the first schema column.
        #[tokio::test]
        async fn test_sql_to_string() -> Result<(), Box<dyn Error + Send + Sync>> {
            let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
            let result = sql_table.scan_to_sql(Some(&vec![0]), &[], None)?;
            assert_eq!(result, r#"SELECT `users`.`name` FROM `users`"#);
            Ok(())
        }
487
        /// Projection, conjunctive filters, and a limit all appear in the
        /// generated SQLite SQL.
        #[tokio::test]
        async fn test_sql_to_string_with_filters_and_limit(
        ) -> Result<(), Box<dyn Error + Send + Sync>> {
            let filters = vec![col("age").gt_eq(lit(30)).and(col("name").eq(lit("x")))];
            let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
            let result = sql_table.scan_to_sql(Some(&vec![0, 1]), &filters, Some(3))?;
            assert_eq!(
                result,
                r#"SELECT `users`.`name`, `users`.`age` FROM `users` WHERE ((`users`.`age` >= 30) AND (`users`.`name` = 'x')) LIMIT 3"#
            );
            Ok(())
        }
500 }
501
502 #[test]
503 fn test_references() {
504 let table_ref = TableReference::bare("test");
505 assert_eq!(format!("{table_ref}"), "test");
506 }
507
508 #[cfg(feature = "duckdb")]
510 mod duckdb_tests {
511 use super::*;
512 use crate::sql::db_connection_pool::dbconnection::duckdbconn::{
513 DuckDBSyncParameter, DuckDbConnection,
514 };
515 use crate::sql::db_connection_pool::{duckdbpool::DuckDbConnectionPool, DbConnectionPool};
516 use duckdb::DuckdbConnectionManager;
517
        /// End-to-end smoke test: register an in-memory DuckDB table with
        /// DataFusion via `SqlTable` and run a LIMIT query through it.
        #[tokio::test]
        async fn test_duckdb_table() -> Result<(), Box<dyn Error + Send + Sync>> {
            let t = setup_tracing();
            let ctx = SessionContext::new();
            let pool: Arc<
                dyn DbConnectionPool<
                        r2d2::PooledConnection<DuckdbConnectionManager>,
                        Box<dyn DuckDBSyncParameter>,
                    > + Send
                    + Sync,
            > = Arc::new(DuckDbConnectionPool::new_memory()?)
                as Arc<
                    dyn DbConnectionPool<
                            r2d2::PooledConnection<DuckdbConnectionManager>,
                            Box<dyn DuckDBSyncParameter>,
                        > + Send
                        + Sync,
                >;
            let conn = pool.connect().await?;
            // Downcast to reach the raw DuckDB connection for fixture setup.
            let db_conn = conn
                .as_any()
                .downcast_ref::<DuckDbConnection>()
                .expect("Unable to downcast to DuckDbConnection");
            db_conn.conn.execute_batch(
                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
            )?;
            let duckdb_table = SqlTable::new("duckdb", &pool, "test").await?;
            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
            let sql = "SELECT * FROM test_datafusion limit 1";
            let df = ctx.sql(sql).await?;
            df.show().await?;
            // Keep the tracing guard alive until the query has completed.
            drop(t);
            Ok(())
        }
552
        /// Same as `test_duckdb_table`, but exercises WHERE-clause filter
        /// pushdown alongside the LIMIT.
        #[tokio::test]
        async fn test_duckdb_table_filter() -> Result<(), Box<dyn Error + Send + Sync>> {
            let t = setup_tracing();
            let ctx = SessionContext::new();
            let pool: Arc<
                dyn DbConnectionPool<
                        r2d2::PooledConnection<DuckdbConnectionManager>,
                        Box<dyn DuckDBSyncParameter>,
                    > + Send
                    + Sync,
            > = Arc::new(DuckDbConnectionPool::new_memory()?)
                as Arc<
                    dyn DbConnectionPool<
                            r2d2::PooledConnection<DuckdbConnectionManager>,
                            Box<dyn DuckDBSyncParameter>,
                        > + Send
                        + Sync,
                >;
            let conn = pool.connect().await?;
            // Downcast to reach the raw DuckDB connection for fixture setup.
            let db_conn = conn
                .as_any()
                .downcast_ref::<DuckDbConnection>()
                .expect("Unable to downcast to DuckDbConnection");
            db_conn.conn.execute_batch(
                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
            )?;
            let duckdb_table = SqlTable::new("duckdb", &pool, "test").await?;
            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
            let sql = "SELECT * FROM test_datafusion where a > 1 and b = 'bar' limit 1";
            let df = ctx.sql(sql).await?;
            df.show().await?;
            // Keep the tracing guard alive until the query has completed.
            drop(t);
            Ok(())
        }
587 }
588}