1use crate::sql::db_connection_pool::{
8 self,
9 dbconnection::{get_schema, query_arrow},
10 DbConnectionPool,
11};
12use async_trait::async_trait;
13use datafusion::{
14 catalog::Session,
15 physical_plan::execution_plan::{Boundedness, EmissionType},
16 sql::unparser::dialect::{DefaultDialect, Dialect},
17};
18use futures::TryStreamExt;
19use snafu::prelude::*;
20use std::fmt::{Display, Formatter};
21use std::{any::Any, fmt, sync::Arc};
22
23use datafusion::{
24 arrow::datatypes::SchemaRef,
25 datasource::TableProvider,
26 error::{DataFusionError, Result as DataFusionResult},
27 execution::TaskContext,
28 logical_expr::{
29 logical_plan::builder::LogicalTableSource, Expr, LogicalPlan, LogicalPlanBuilder,
30 TableProviderFilterPushDown, TableType,
31 },
32 physical_expr::EquivalenceProperties,
33 physical_plan::{
34 stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan,
35 Partitioning, PlanProperties, SendableRecordBatchStream,
36 },
37 sql::{unparser::Unparser, TableReference},
38};
39
40#[cfg(feature = "federation")]
41pub mod federation;
42
/// Errors surfaced while building or querying a SQL-backed table provider.
#[derive(Debug, Snafu)]
pub enum Error {
    // Pool handed back an error instead of a live connection.
    #[snafu(display("Unable to get a DB connection from the pool: {source}"))]
    UnableToGetConnectionFromPool { source: db_connection_pool::Error },

    // Schema discovery against the remote table failed.
    #[snafu(display("Unable to get schema: {source}"))]
    UnableToGetSchema {
        source: db_connection_pool::dbconnection::Error,
    },

    // Unparsing a logical plan back to a SQL string failed.
    #[snafu(display("Unable to generate SQL: {source}"))]
    UnableToGenerateSQL { source: DataFusionError },
}
56
/// Module-local result alias; defaults the error type to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
58
/// A DataFusion [`TableProvider`] backed by a remote SQL table reached
/// through a generic [`DbConnectionPool`].
#[derive(Clone)]
pub struct SqlTable<T: 'static, P: 'static> {
    // Human-readable name, used in `Debug`/`Display` output.
    name: String,
    // Pool from which connections are drawn at scan time.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    // Schema captured at construction; served to DataFusion as-is.
    schema: SchemaRef,
    // Fully qualified reference to the remote table.
    pub table_reference: TableReference,
    // Optional SQL dialect for unparsing; `None` falls back to `DefaultDialect`.
    dialect: Option<Arc<dyn Dialect + Send + Sync>>,
}
67
68impl<T, P> fmt::Debug for SqlTable<T, P> {
69 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
70 f.debug_struct("SqlTable")
71 .field("name", &self.name)
72 .field("schema", &self.schema)
73 .field("table_reference", &self.table_reference)
74 .finish()
75 }
76}
77
78impl<T, P> SqlTable<T, P> {
79 pub async fn new(
80 name: &str,
81 pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
82 table_reference: impl Into<TableReference>,
83 ) -> Result<Self> {
84 let table_reference = table_reference.into();
85 let conn = pool
86 .connect()
87 .await
88 .context(UnableToGetConnectionFromPoolSnafu)?;
89
90 let schema = get_schema(conn, &table_reference)
91 .await
92 .context(UnableToGetSchemaSnafu)?;
93
94 Ok(Self::new_with_schema(name, pool, schema, table_reference))
95 }
96
97 pub fn new_with_schema(
98 name: &str,
99 pool: &Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
100 schema: impl Into<SchemaRef>,
101 table_reference: impl Into<TableReference>,
102 ) -> Self {
103 Self {
104 name: name.to_owned(),
105 pool: Arc::clone(pool),
106 schema: schema.into(),
107 table_reference: table_reference.into(),
108 dialect: None,
109 }
110 }
111
112 pub fn scan_to_sql(
113 &self,
114 projection: Option<&Vec<usize>>,
115 filters: &[Expr],
116 limit: Option<usize>,
117 ) -> DataFusionResult<String> {
118 let logical_plan = self.create_logical_plan(projection, filters, limit)?;
119 let sql = Unparser::new(self.dialect())
120 .plan_to_sql(&logical_plan)?
121 .to_string();
122
123 Ok(sql)
124 }
125
126 fn create_logical_plan(
127 &self,
128 projection: Option<&Vec<usize>>,
129 filters: &[Expr],
130 limit: Option<usize>,
131 ) -> DataFusionResult<LogicalPlan> {
132 let table_source = LogicalTableSource::new(self.schema());
133 LogicalPlanBuilder::scan_with_filters(
134 self.table_reference.clone(),
135 Arc::new(table_source),
136 projection.cloned(),
137 filters.to_vec(),
138 )?
139 .limit(0, limit)?
140 .build()
141 }
142
143 fn create_physical_plan(
144 &self,
145 projection: Option<&Vec<usize>>,
146 sql: String,
147 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
148 Ok(Arc::new(SqlExec::new(
149 projection,
150 &self.schema(),
151 Arc::clone(&self.pool),
152 sql,
153 )?))
154 }
155
156 fn unique_id(&self) -> usize {
158 std::ptr::from_ref(self) as usize
159 }
160
161 #[must_use]
162 pub fn with_dialect(self, dialect: Arc<dyn Dialect + Send + Sync>) -> Self {
163 Self {
164 dialect: Some(dialect),
165 ..self
166 }
167 }
168
169 #[must_use]
170 pub fn name(&self) -> &str {
171 &self.name
172 }
173
174 #[must_use]
175 pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
176 Arc::clone(&self.pool)
177 }
178
179 fn dialect(&self) -> &(dyn Dialect + Send + Sync) {
180 match &self.dialect {
181 Some(dialect) => dialect.as_ref(),
182 None => &DefaultDialect {},
183 }
184 }
185
186 fn arc_dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
187 match &self.dialect {
188 Some(dialect) => Arc::clone(dialect),
189 None => Arc::new(DefaultDialect {}),
190 }
191 }
192}
193
194#[async_trait]
195impl<T, P> TableProvider for SqlTable<T, P> {
196 fn as_any(&self) -> &dyn Any {
197 self
198 }
199
200 fn schema(&self) -> SchemaRef {
201 Arc::clone(&self.schema)
202 }
203
204 fn table_type(&self) -> TableType {
205 TableType::Base
206 }
207
208 fn supports_filters_pushdown(
209 &self,
210 filters: &[&Expr],
211 ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
212 let filter_push_down: Vec<TableProviderFilterPushDown> = filters
213 .iter()
214 .map(|f| match Unparser::new(self.dialect()).expr_to_sql(f) {
215 Ok(_) => TableProviderFilterPushDown::Exact,
216 Err(_) => TableProviderFilterPushDown::Unsupported,
217 })
218 .collect();
219
220 Ok(filter_push_down)
221 }
222
223 async fn scan(
224 &self,
225 _state: &dyn Session,
226 projection: Option<&Vec<usize>>,
227 filters: &[Expr],
228 limit: Option<usize>,
229 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
230 let sql = self.scan_to_sql(projection, filters, limit)?;
231 return self.create_physical_plan(projection, sql);
232 }
233}
234
235impl<T, P> Display for SqlTable<T, P> {
236 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237 write!(f, "SqlTable {}", self.name)
238 }
239}
240
241pub fn project_schema_safe(
242 schema: &SchemaRef,
243 projection: Option<&Vec<usize>>,
244) -> DataFusionResult<SchemaRef> {
245 let schema = match projection {
246 Some(columns) => {
247 if columns.is_empty() {
248 Arc::clone(schema)
249 } else {
250 Arc::new(schema.project(columns)?)
251 }
252 }
253 None => Arc::clone(schema),
254 };
255 Ok(schema)
256}
257
/// Physical plan node that runs a fixed SQL string against a connection
/// pool and streams the results back as Arrow record batches.
#[derive(Clone)]
pub struct SqlExec<T, P> {
    // Schema after applying the scan's column projection.
    projected_schema: SchemaRef,
    // Pool used to obtain a connection when `execute` is called.
    pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
    // SQL statement executed remotely.
    sql: String,
    // Cached plan properties (single unknown partition, bounded input).
    properties: PlanProperties,
}
265
266impl<T, P> SqlExec<T, P> {
267 pub fn new(
268 projection: Option<&Vec<usize>>,
269 schema: &SchemaRef,
270 pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
271 sql: String,
272 ) -> DataFusionResult<Self> {
273 let projected_schema = project_schema_safe(schema, projection)?;
274
275 Ok(Self {
276 projected_schema: Arc::clone(&projected_schema),
277 pool,
278 sql,
279 properties: PlanProperties::new(
280 EquivalenceProperties::new(projected_schema),
281 Partitioning::UnknownPartitioning(1),
282 EmissionType::Incremental,
283 Boundedness::Bounded,
284 ),
285 })
286 }
287
288 #[must_use]
289 pub fn clone_pool(&self) -> Arc<dyn DbConnectionPool<T, P> + Send + Sync> {
290 Arc::clone(&self.pool)
291 }
292
293 pub fn sql(&self) -> Result<String> {
294 Ok(self.sql.clone())
295 }
296}
297
298impl<T, P> std::fmt::Debug for SqlExec<T, P> {
299 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
300 let sql = self.sql().unwrap_or_default();
301 write!(f, "SqlExec sql={sql}")
302 }
303}
304
305impl<T, P> DisplayAs for SqlExec<T, P> {
306 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
307 let sql = self.sql().unwrap_or_default();
308 write!(f, "SqlExec sql={sql}")
309 }
310}
311
312impl<T: 'static, P: 'static> ExecutionPlan for SqlExec<T, P> {
313 fn name(&self) -> &'static str {
314 "SqlExec"
315 }
316
317 fn as_any(&self) -> &dyn Any {
318 self
319 }
320
321 fn schema(&self) -> SchemaRef {
322 Arc::clone(&self.projected_schema)
323 }
324
325 fn properties(&self) -> &PlanProperties {
326 &self.properties
327 }
328
329 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
330 vec![]
331 }
332
333 fn with_new_children(
334 self: Arc<Self>,
335 _children: Vec<Arc<dyn ExecutionPlan>>,
336 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
337 Ok(self)
338 }
339
340 fn execute(
341 &self,
342 _partition: usize,
343 _context: Arc<TaskContext>,
344 ) -> DataFusionResult<SendableRecordBatchStream> {
345 let sql = self.sql().map_err(to_execution_error)?;
346 tracing::debug!("SqlExec sql: {sql}");
347
348 let schema = self.schema();
349
350 let fut = get_stream(Arc::clone(&self.pool), sql, Arc::clone(&schema));
351
352 let stream = futures::stream::once(fut).try_flatten();
353 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
354 }
355}
356
357pub async fn get_stream<T: 'static, P: 'static>(
358 pool: Arc<dyn DbConnectionPool<T, P> + Send + Sync>,
359 sql: String,
360 projected_schema: SchemaRef,
361) -> DataFusionResult<SendableRecordBatchStream> {
362 let conn = pool.connect().await.map_err(to_execution_error)?;
363
364 query_arrow(conn, sql, Some(projected_schema))
365 .await
366 .map_err(to_execution_error)
367}
368
369#[allow(clippy::needless_pass_by_value)]
370pub fn to_execution_error(
371 e: impl Into<Box<dyn std::error::Error + Send + Sync>>,
372) -> DataFusionError {
373 DataFusionError::Execution(format!("{}", e.into()).to_string())
374}
375
#[cfg(test)]
mod tests {
    use std::{error::Error, sync::Arc};

    use datafusion::execution::context::SessionContext;
    use datafusion::sql::TableReference;
    use tracing::{level_filters::LevelFilter, subscriber::DefaultGuard, Dispatch};

    use crate::sql::sql_provider_datafusion::SqlTable;

    // Installs a DEBUG-level tracing subscriber for the current thread;
    // it stays active until the returned guard is dropped.
    fn setup_tracing() -> DefaultGuard {
        let subscriber: tracing_subscriber::FmtSubscriber = tracing_subscriber::fmt()
            .with_max_level(LevelFilter::DEBUG)
            .finish();

        let dispatch = Dispatch::new(subscriber);
        tracing::dispatcher::set_default(&dispatch)
    }

    // Tests that a scan over a mock table unparses to the expected SQL text.
    mod sql_table_plan_to_sql_tests {
        use std::any::Any;

        use async_trait::async_trait;
        use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
        use datafusion::sql::unparser::dialect::{Dialect, SqliteDialect};
        use datafusion::{
            logical_expr::{col, lit},
            sql::TableReference,
        };

        use crate::sql::db_connection_pool::{
            dbconnection::DbConnection, DbConnectionPool, JoinPushDown,
        };

        use super::*;

        // Connection stub: only the downcast hooks required by the trait.
        struct MockConn {}

        impl DbConnection<(), &'static dyn ToString> for MockConn {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn as_any_mut(&mut self) -> &mut dyn Any {
                self
            }
        }

        // Pool stub that always hands out a `MockConn` and disallows
        // join push-down.
        struct MockDBPool {}

        #[async_trait]
        impl DbConnectionPool<(), &'static dyn ToString> for MockDBPool {
            async fn connect(
                &self,
            ) -> Result<
                Box<dyn DbConnection<(), &'static dyn ToString>>,
                Box<dyn Error + Send + Sync>,
            > {
                Ok(Box::new(MockConn {}))
            }

            fn join_push_down(&self) -> JoinPushDown {
                JoinPushDown::Disallow
            }
        }

        // Builds a `SqlTable` over a fixed six-column schema backed by the
        // mock pool, optionally configured with a specific SQL dialect.
        fn new_sql_table(
            table_reference: &'static str,
            dialect: Option<Arc<dyn Dialect + Send + Sync>>,
        ) -> Result<SqlTable<(), &'static dyn ToString>, Box<dyn Error + Send + Sync>> {
            let fields = vec![
                Field::new("name", DataType::Utf8, false),
                Field::new("age", DataType::Int16, false),
                Field::new(
                    "createdDate",
                    DataType::Timestamp(TimeUnit::Millisecond, None),
                    false,
                ),
                Field::new("userId", DataType::LargeUtf8, false),
                Field::new("active", DataType::Boolean, false),
                // Deliberately awkward column name (starts with a digit) to
                // exercise identifier quoting in the unparser.
                Field::new("5e48", DataType::LargeUtf8, false),
            ];
            let schema = Arc::new(Schema::new(fields));
            let pool = Arc::new(MockDBPool {})
                as Arc<dyn DbConnectionPool<(), &'static dyn ToString> + Send + Sync>;
            let table_ref = TableReference::parse_str(table_reference);

            let sql_table = SqlTable::new_with_schema(table_reference, &pool, schema, table_ref);
            if let Some(dialect) = dialect {
                Ok(sql_table.with_dialect(dialect))
            } else {
                Ok(sql_table)
            }
        }

        // Projection-only scan unparses to a backtick-quoted SELECT.
        #[tokio::test]
        async fn test_sql_to_string() -> Result<(), Box<dyn Error + Send + Sync>> {
            let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
            let result = sql_table.scan_to_sql(Some(&vec![0]), &[], None)?;
            assert_eq!(result, r#"SELECT `users`.`name` FROM `users`"#);
            Ok(())
        }

        // Filters and limit are carried through into WHERE and LIMIT clauses.
        #[tokio::test]
        async fn test_sql_to_string_with_filters_and_limit(
        ) -> Result<(), Box<dyn Error + Send + Sync>> {
            let filters = vec![col("age").gt_eq(lit(30)).and(col("name").eq(lit("x")))];
            let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
            let result = sql_table.scan_to_sql(Some(&vec![0, 1]), &filters, Some(3))?;
            assert_eq!(
                result,
                r#"SELECT `users`.`name`, `users`.`age` FROM `users` WHERE ((`users`.`age` >= 30) AND (`users`.`name` = 'x')) LIMIT 3"#
            );
            Ok(())
        }
    }

    // Sanity check on TableReference's Display formatting.
    #[test]
    fn test_references() {
        let table_ref = TableReference::bare("test");
        assert_eq!(format!("{table_ref}"), "test");
    }

    // End-to-end tests against an in-memory DuckDB; only compiled when the
    // `duckdb` feature is enabled.
    #[cfg(feature = "duckdb")]
    mod duckdb_tests {
        use super::*;
        use crate::sql::db_connection_pool::dbconnection::duckdbconn::{
            DuckDBSyncParameter, DuckDbConnection,
        };
        use crate::sql::db_connection_pool::{duckdbpool::DuckDbConnectionPool, DbConnectionPool};
        use duckdb::DuckdbConnectionManager;

        // Registers a DuckDB-backed SqlTable and runs a LIMIT query through
        // DataFusion end to end.
        #[tokio::test]
        async fn test_duckdb_table() -> Result<(), Box<dyn Error + Send + Sync>> {
            let t = setup_tracing();
            let ctx = SessionContext::new();
            let pool: Arc<
                dyn DbConnectionPool<
                    r2d2::PooledConnection<DuckdbConnectionManager>,
                    Box<dyn DuckDBSyncParameter>,
                > + Send
                    + Sync,
            > = Arc::new(DuckDbConnectionPool::new_memory()?)
                as Arc<
                    dyn DbConnectionPool<
                        r2d2::PooledConnection<DuckdbConnectionManager>,
                        Box<dyn DuckDBSyncParameter>,
                    > + Send
                        + Sync,
                >;
            let conn = pool.connect().await?;
            // Downcast to the concrete connection to seed the test table.
            let db_conn = conn
                .as_any()
                .downcast_ref::<DuckDbConnection>()
                .expect("Unable to downcast to DuckDbConnection");
            db_conn.conn.execute_batch(
                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
            )?;
            let duckdb_table = SqlTable::new("duckdb", &pool, "test").await?;
            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
            let sql = "SELECT * FROM test_datafusion limit 1";
            let df = ctx.sql(sql).await?;
            df.show().await?;
            drop(t);
            Ok(())
        }

        // Same as above but exercises filter push-down via a WHERE clause.
        #[tokio::test]
        async fn test_duckdb_table_filter() -> Result<(), Box<dyn Error + Send + Sync>> {
            let t = setup_tracing();
            let ctx = SessionContext::new();
            let pool: Arc<
                dyn DbConnectionPool<
                    r2d2::PooledConnection<DuckdbConnectionManager>,
                    Box<dyn DuckDBSyncParameter>,
                > + Send
                    + Sync,
            > = Arc::new(DuckDbConnectionPool::new_memory()?)
                as Arc<
                    dyn DbConnectionPool<
                        r2d2::PooledConnection<DuckdbConnectionManager>,
                        Box<dyn DuckDBSyncParameter>,
                    > + Send
                        + Sync,
                >;
            let conn = pool.connect().await?;
            let db_conn = conn
                .as_any()
                .downcast_ref::<DuckDbConnection>()
                .expect("Unable to downcast to DuckDbConnection");
            db_conn.conn.execute_batch(
                "CREATE TABLE test (a INTEGER, b VARCHAR); INSERT INTO test VALUES (3, 'bar');",
            )?;
            let duckdb_table = SqlTable::new("duckdb", &pool, "test").await?;
            ctx.register_table("test_datafusion", Arc::new(duckdb_table))?;
            let sql = "SELECT * FROM test_datafusion where a > 1 and b = 'bar' limit 1";
            let df = ctx.sql(sql).await?;
            df.show().await?;
            drop(t);
            Ok(())
        }
    }
}