1mod analyzer;
2pub mod ast_analyzer;
3mod executor;
4mod schema;
5mod table;
6mod table_reference;
7
8use std::{any::Any, fmt, sync::Arc, vec};
9
10use analyzer::RewriteTableScanAnalyzer;
11use async_trait::async_trait;
12use datafusion::{
13 arrow::datatypes::{Schema, SchemaRef},
14 common::{
15 tree_node::{Transformed, TreeNode},
16 Statistics,
17 },
18 config::ConfigOptions,
19 error::{DataFusionError, Result},
20 execution::{context::SessionState, TaskContext},
21 logical_expr::{Extension, LogicalPlan},
22 optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
23 physical_expr::EquivalenceProperties,
24 physical_plan::{
25 execution_plan::{Boundedness, EmissionType},
26 filter_pushdown::{
27 ChildPushdownResult, FilterPushdownPhase, FilterPushdownPropagation, PushedDown,
28 },
29 metrics::MetricsSet,
30 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, PlanProperties,
31 SendableRecordBatchStream,
32 },
33 sql::{sqlparser::ast::Statement, unparser::Unparser},
34};
35
36pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef, SqlQueryRewriter};
37pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
38pub use table::{RemoteTable, SQLTable, SQLTableSource};
39pub use table_reference::RemoteTableRef;
40
41use crate::{
42 get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
43};
44
/// A federation provider that pushes query subtrees down to a remote SQL
/// engine by unparsing them back to SQL and delegating to a [`SQLExecutor`].
#[derive(Debug)]
pub struct SQLFederationProvider {
    /// Optimizer pre-loaded with the single SQL federation rewrite rule;
    /// returned verbatim from [`FederationProvider::optimizer`].
    pub optimizer: Arc<Optimizer>,
    /// Executor that runs the generated SQL against the remote engine.
    pub executor: Arc<dyn SQLExecutor>,
}
51
52impl SQLFederationProvider {
53 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
54 Self {
55 optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
56 SQLFederationOptimizerRule::new(executor.clone()),
57 )])),
58 executor,
59 }
60 }
61}
62
63impl FederationProvider for SQLFederationProvider {
64 fn name(&self) -> &str {
65 "sql_federation_provider"
66 }
67
68 fn compute_context(&self) -> Option<String> {
69 self.executor.compute_context()
70 }
71
72 fn optimizer(&self) -> Option<Arc<Optimizer>> {
73 Some(self.optimizer.clone())
74 }
75}
76
/// Optimizer rule that wraps an entire federatable [`LogicalPlan`] in a
/// `Federated` extension node, to be physically planned later by
/// [`SQLFederationPlanner`].
#[derive(Debug)]
struct SQLFederationOptimizerRule {
    // Planner shared by every federated extension node this rule produces.
    planner: Arc<SQLFederationPlanner>,
}
81
82impl SQLFederationOptimizerRule {
83 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
84 Self {
85 planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
86 }
87 }
88}
89
90impl OptimizerRule for SQLFederationOptimizerRule {
91 fn rewrite(
97 &self,
98 plan: LogicalPlan,
99 _config: &dyn OptimizerConfig,
100 ) -> Result<Transformed<LogicalPlan>> {
101 if let LogicalPlan::Extension(Extension { ref node }) = plan {
102 if node.name() == "Federated" {
103 return Ok(Transformed::no(plan));
105 }
106 }
107
108 let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
109 let ext_node = Extension {
110 node: Arc::new(fed_plan),
111 };
112
113 let mut plan = LogicalPlan::Extension(ext_node);
114 if let Some(mut rewriter) = self.planner.executor.logical_optimizer() {
115 plan = rewriter(plan)?;
116 }
117
118 Ok(Transformed::yes(plan))
119 }
120
121 fn name(&self) -> &str {
123 "federate_sql"
124 }
125
126 fn supports_rewrite(&self) -> bool {
128 true
129 }
130}
131
/// Plans `Federated` extension nodes into [`VirtualExecutionPlan`]s that
/// delegate execution of the unparsed SQL to the remote [`SQLExecutor`].
#[derive(Debug)]
pub struct SQLFederationPlanner {
    /// Executor used by every physical plan this planner produces.
    pub executor: Arc<dyn SQLExecutor>,
}
136
impl SQLFederationPlanner {
    /// Creates a planner whose plans execute through `executor`.
    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
        Self { executor }
    }
}
142
143#[async_trait]
144impl FederationPlanner for SQLFederationPlanner {
145 async fn plan_federation(
146 &self,
147 node: &FederatedPlanNode,
148 _session_state: &SessionState,
149 ) -> Result<Arc<dyn ExecutionPlan>> {
150 let schema = Arc::new(node.plan().schema().as_arrow().clone());
151 let plan = node.plan().clone();
152 let statistics = self.executor.statistics(&plan).await?;
153 let input = Arc::new(VirtualExecutionPlan::new(
154 plan,
155 Arc::clone(&self.executor),
156 statistics,
157 ));
158 let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
159 Ok(Arc::new(schema_cast_exec))
160 }
161}
162
/// Physical node that unparses its captured logical plan back into SQL and
/// executes that SQL through the remote [`SQLExecutor`].
#[derive(Debug, Clone)]
pub struct VirtualExecutionPlan {
    // Logical plan to unparse; table scans still reference the remote tables.
    plan: LogicalPlan,
    // Remote engine that runs the generated SQL.
    executor: Arc<dyn SQLExecutor>,
    // Precomputed properties: single unknown partition, bounded, incremental.
    props: Arc<PlanProperties>,
    // Statistics reported by the executor at planning time.
    statistics: Statistics,
    // Filters accepted via pushdown; forwarded verbatim to the executor.
    filters: Vec<Arc<dyn PhysicalExpr>>,
}
171
172impl VirtualExecutionPlan {
173 pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>, statistics: Statistics) -> Self {
174 let schema: Schema = plan.schema().as_arrow().clone();
175 let props = Arc::new(PlanProperties::new(
176 EquivalenceProperties::new(Arc::new(schema)),
177 Partitioning::UnknownPartitioning(1),
178 EmissionType::Incremental,
179 Boundedness::Bounded,
180 ));
181 Self {
182 plan,
183 executor,
184 props,
185 statistics,
186 filters: Vec::new(),
187 }
188 }
189
190 pub fn plan(&self) -> &LogicalPlan {
191 &self.plan
192 }
193
194 pub fn executor(&self) -> &Arc<dyn SQLExecutor> {
195 &self.executor
196 }
197
198 pub fn statistics(&self) -> &Statistics {
199 &self.statistics
200 }
201
202 fn schema(&self) -> SchemaRef {
203 let df_schema = self.plan.schema().as_arrow().clone();
204 Arc::new(df_schema)
205 }
206
207 fn final_sql(&self) -> Result<String> {
208 let plan = self.plan.clone();
209 let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
210 let (logical_optimizers, ast_analyzers, sql_query_rewriters) = gather_analyzers(&plan)?;
211 let plan = apply_logical_optimizers(plan, logical_optimizers)?;
212 let ast = self.plan_to_statement(&plan)?;
213 let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
214 let ast = apply_ast_analyzers(ast, ast_analyzers)?;
215 apply_sql_query_rewriters(ast.to_string(), sql_query_rewriters)
216 }
217
218 fn rewrite_with_executor_ast_analyzer(
219 &self,
220 ast: Statement,
221 ) -> Result<Statement, datafusion::error::DataFusionError> {
222 if let Some(mut analyzer) = self.executor.ast_analyzer() {
223 Ok(analyzer(ast)?)
224 } else {
225 Ok(ast)
226 }
227 }
228
229 fn plan_to_statement(&self, plan: &LogicalPlan) -> Result<Statement> {
230 Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
231 }
232}
233
234fn gather_analyzers(
235 plan: &LogicalPlan,
236) -> Result<(
237 Vec<LogicalOptimizer>,
238 Vec<AstAnalyzer>,
239 Vec<SqlQueryRewriter>,
240)> {
241 let mut logical_optimizers = vec![];
242 let mut ast_analyzers = vec![];
243 let mut sql_query_rewriters = vec![];
244
245 plan.apply(|node| {
246 if let LogicalPlan::TableScan(table) = node {
247 let provider = get_table_source(&table.source)
248 .expect("caller is virtual exec so this is valid")
249 .expect("caller is virtual exec so this is valid");
250 if let Some(source) = provider.as_any().downcast_ref::<SQLTableSource>() {
251 if let Some(analyzer) = source.table.logical_optimizer() {
252 logical_optimizers.push(analyzer);
253 }
254 if let Some(analyzer) = source.table.ast_analyzer() {
255 ast_analyzers.push(analyzer);
256 }
257 if let Some(rewriter) = source.table.sql_query_rewriter() {
258 sql_query_rewriters.push(rewriter);
259 }
260 }
261 }
262 Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
263 })?;
264
265 Ok((logical_optimizers, ast_analyzers, sql_query_rewriters))
266}
267
268fn apply_logical_optimizers(
269 mut plan: LogicalPlan,
270 analyzers: Vec<LogicalOptimizer>,
271) -> Result<LogicalPlan> {
272 for mut analyzer in analyzers {
273 let old_schema = plan.schema().clone();
274 plan = analyzer(plan)?;
275 let new_schema = plan.schema();
276 if &old_schema != new_schema {
277 return Err(DataFusionError::Execution(format!(
278 "Schema altered during logical analysis, expected: {}, found: {}",
279 old_schema, new_schema
280 )));
281 }
282 }
283 Ok(plan)
284}
285
286fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) -> Result<Statement> {
287 for mut analyzer in analyzers {
288 statement = analyzer(statement)?;
289 }
290 Ok(statement)
291}
292
293fn apply_sql_query_rewriters(
294 mut query: String,
295 rewriters: Vec<SqlQueryRewriter>,
296) -> Result<String> {
297 for mut rewriter in rewriters {
298 query = rewriter(query)?;
299 }
300 Ok(query)
301}
302
impl DisplayAs for VirtualExecutionPlan {
    /// Renders the node name plus, best-effort, each intermediate SQL rewrite
    /// stage. Each stage is only printed when it actually changed something,
    /// and any rewrite failure simply stops further detail (returns Ok).
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
        write!(f, "VirtualExecutionPlan")?;
        write!(f, " name={}", self.executor.name())?;
        if let Some(ctx) = self.executor.compute_context() {
            write!(f, " compute_context={ctx}")?;
        };
        // Fall back to the unrewritten plan if table-scan rewriting fails.
        let mut plan = match RewriteTableScanAnalyzer::rewrite(self.plan.clone()) {
            Ok(plan) => plan,
            Err(_) => self.plan.clone(),
        };
        if let Ok(statement) = self.plan_to_statement(&plan) {
            write!(f, " base_sql={statement}")?;
        }

        let (logical_optimizers, ast_analyzers, sql_query_rewriters) = match gather_analyzers(&plan)
        {
            Ok(analyzers) => analyzers,
            Err(_) => return Ok(()),
        };

        // Keep the pre-optimization plan so we can detect changes below.
        let old_plan = plan.clone();

        plan = match apply_logical_optimizers(plan, logical_optimizers) {
            Ok(plan) => plan,
            _ => return Ok(()),
        };

        let statement = match self.plan_to_statement(&plan) {
            Ok(statement) => statement,
            _ => return Ok(()),
        };

        if plan != old_plan {
            write!(f, " rewritten_logical_sql={statement}")?;
        }

        // Executor-level AST analyzer stage.
        let old_statement = statement.clone();
        let statement = match self.rewrite_with_executor_ast_analyzer(statement) {
            Ok(statement) => statement,
            _ => return Ok(()),
        };
        if old_statement != statement {
            write!(f, " rewritten_executor_sql={statement}")?;
        }

        // Per-table AST analyzer stage.
        let old_statement = statement.clone();
        let statement = match apply_ast_analyzers(statement, ast_analyzers) {
            Ok(statement) => statement,
            _ => return Ok(()),
        };
        if old_statement != statement {
            write!(f, " rewritten_ast_analyzer={statement}")?;
        }

        // Final string-level rewriter stage.
        let sql = statement.to_string();
        let rewritten_sql = match apply_sql_query_rewriters(sql.clone(), sql_query_rewriters) {
            Ok(sql) => sql,
            _ => return Ok(()),
        };
        if sql != rewritten_sql {
            write!(f, " rewritten_sql_query={rewritten_sql}")?;
        }

        Ok(())
    }
}
370
371impl ExecutionPlan for VirtualExecutionPlan {
372 fn name(&self) -> &str {
373 "sql_federation_exec"
374 }
375
376 fn as_any(&self) -> &dyn Any {
377 self
378 }
379
380 fn schema(&self) -> SchemaRef {
381 self.schema()
382 }
383
384 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
385 vec![]
386 }
387
388 fn with_new_children(
389 self: Arc<Self>,
390 _: Vec<Arc<dyn ExecutionPlan>>,
391 ) -> Result<Arc<dyn ExecutionPlan>> {
392 Ok(self)
393 }
394
395 fn execute(
396 &self,
397 _partition: usize,
398 _context: Arc<TaskContext>,
399 ) -> Result<SendableRecordBatchStream> {
400 self.executor
401 .execute(&self.final_sql()?, self.schema(), &self.filters)
402 }
403
404 fn properties(&self) -> &Arc<PlanProperties> {
405 &self.props
406 }
407
408 fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
409 Ok(self.statistics.clone())
410 }
411
412 fn metrics(&self) -> Option<MetricsSet> {
413 self.executor.metrics()
414 }
415
416 fn handle_child_pushdown_result(
417 &self,
418 _phase: FilterPushdownPhase,
419 child_pushdown_result: ChildPushdownResult,
420 _config: &ConfigOptions,
421 ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
422 let parent_filters: Vec<_> = child_pushdown_result
423 .clone()
424 .parent_filters
425 .into_iter()
426 .map(|f| f.filter)
427 .collect();
428
429 if parent_filters.is_empty() {
430 return Ok(FilterPushdownPropagation {
431 filters: vec![],
432 updated_node: None,
433 });
434 }
435
436 let filters_pushed_down = vec![PushedDown::Yes; parent_filters.len()];
437 let mut node = self.clone();
438 node.filters = parent_filters;
439
440 Ok(FilterPushdownPropagation {
441 filters: filters_pushed_down,
442 updated_node: Some(Arc::new(node)),
443 })
444 }
445}
446
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use crate::sql::{
        RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTable, SQLTableSource,
    };
    use crate::FederatedTableProviderAdaptor;
    use async_trait::async_trait;
    use datafusion::arrow::datatypes::{Schema, SchemaRef};
    use datafusion::common::tree_node::TreeNodeRecursion;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::sql::unparser::dialect::Dialect;
    use datafusion::sql::unparser::{self};
    use datafusion::sql::TableReference;
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        datasource::TableProvider,
        execution::context::SessionContext,
    };

    use super::table::RemoteTable;
    use super::*;

    /// Minimal executor stub: these tests only exercise planning and SQL
    /// generation, so the execution entry points are left unimplemented.
    #[derive(Debug, Clone)]
    struct TestExecutor {
        // Distinguishes "remotes": tables with different compute contexts
        // must not be federated into the same remote query.
        compute_context: String,
    }

    #[async_trait]
    impl SQLExecutor for TestExecutor {
        fn name(&self) -> &str {
            "TestExecutor"
        }

        fn compute_context(&self) -> Option<String> {
            Some(self.compute_context.clone())
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(unparser::dialect::DefaultDialect {})
        }

        // Never called in these tests — plans are built but not executed.
        fn execute(
            &self,
            _query: &str,
            _schema: SchemaRef,
            _filters: &[Arc<dyn PhysicalExpr>],
        ) -> Result<SendableRecordBatchStream> {
            unimplemented!()
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            unimplemented!()
        }
    }

    /// Builds a federated table provider named `name` with a fixed
    /// (a: Int64, b: Utf8, c: Date32) schema.
    fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc<dyn TableProvider> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_ref = RemoteTableRef::try_from(name).unwrap();
        let table = Arc::new(RemoteTable::new(table_ref, schema));
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource { provider, table });
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    /// Like `get_test_table_provider`, but accepts a custom `SQLTable`
    /// implementation (used to install rewrite hooks).
    fn get_test_table_provider_with_table(
        table: Arc<dyn SQLTable>,
        executor: TestExecutor,
    ) -> Arc<dyn TableProvider> {
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource::new_with_table(provider, table));
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    /// Table whose `sql_query_rewriter` appends `suffix` to the generated SQL
    /// and counts how many times the rewriter runs.
    #[derive(Debug)]
    struct SqlRewriteTable {
        table: RemoteTable,
        // Shared counter incremented on every rewriter invocation.
        rewrite_calls: Arc<AtomicUsize>,
        suffix: String,
    }

    impl SqlRewriteTable {
        fn new(
            table_ref: RemoteTableRef,
            schema: SchemaRef,
            rewrite_calls: Arc<AtomicUsize>,
            suffix: impl Into<String>,
        ) -> Self {
            Self {
                table: RemoteTable::new(table_ref, schema),
                rewrite_calls,
                suffix: suffix.into(),
            }
        }
    }

    impl SQLTable for SqlRewriteTable {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn table_reference(&self) -> TableReference {
            self.table.table_reference().clone()
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(self.table.schema())
        }

        // The hook under test: appends the suffix and bumps the counter.
        fn sql_query_rewriter(&self) -> Option<SqlQueryRewriter> {
            let rewrite_calls = Arc::clone(&self.rewrite_calls);
            let suffix = self.suffix.clone();
            Some(Box::new(move |sql| {
                rewrite_calls.fetch_add(1, Ordering::SeqCst);
                Ok(format!("{sql} {suffix}"))
            }))
        }
    }

    /// Two tables on executor "a" and one on executor "b": all three scans
    /// must be federated, and the per-context grouping must produce one
    /// remote query per table here (no cross-context merging).
    #[tokio::test]
    async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "a".into(),
        };

        let test_executor_b = TestExecutor {
            compute_context: "b".into(),
        };

        let table_a1_ref = "table_a1".to_string();
        let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());

        let table_a2_ref = "table_a2".to_string();
        let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a);

        // Remote name differs from the DataFusion registration name, so the
        // final SQL must use the remote reference (with its argument list).
        let table_b1_ref = "table_b1(1)".to_string();
        let table_b1_df_ref = "table_local_b1".to_string();

        let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(table_a1_ref.clone(), table_a1).unwrap();
        ctx.register_table(table_a2_ref.clone(), table_a2).unwrap();
        ctx.register_table(table_b1_df_ref.clone(), table_b1)
            .unwrap();

        let query = r#"
            SELECT * FROM table_a1
            UNION ALL
            SELECT * FROM table_a2
            UNION ALL
            SELECT * FROM table_local_b1;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut table_a1_federated = false;
        let mut table_a2_federated = false;
        let mut table_b1_federated = false;

        // Every scan must appear inside a FederatedPlanNode extension.
        let _ = logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    let _ = node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == table_a1_ref {
                                table_a1_federated = true;
                            }
                            if table.table_name.table() == table_a2_ref {
                                table_a2_federated = true;
                            }
                            if table.table_name.table() == table_b1_df_ref {
                                table_b1_federated = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    });
                }
            }
            Ok(TreeNodeRecursion::Continue)
        });

        assert!(table_a1_federated);
        assert!(table_a2_federated);
        assert!(table_b1_federated);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        // Collect the SQL each virtual exec node would send to its executor.
        let _ = physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        });

        let expected = vec![
            "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1",
            "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2",
            "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1",
        ];

        // Order of the federated subplans is not guaranteed; compare as sets.
        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }

    /// Two tables on the same executor with case-sensitive, multi-part remote
    /// references: both scans share one federated query with proper quoting.
    #[tokio::test]
    async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "test".into(),
        };

        let lowercase_table_ref = "default.table".to_string();
        let lowercase_local_table_ref = "dftable".to_string();
        let lowercase_table =
            get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone());

        let capitalized_table_ref = "default.Table(1)".to_string();
        let capitalized_local_table_ref = "dfview".to_string();
        let capitalized_table =
            get_test_table_provider(capitalized_table_ref.clone(), test_executor_a);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table)
            .unwrap();
        ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table)
            .unwrap();

        let query = r#"
            SELECT * FROM dftable
            UNION ALL
            SELECT * FROM dfview;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut lowercase_table = false;
        let mut capitalized_table = false;

        let _ = logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    let _ = node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == lowercase_local_table_ref {
                                lowercase_table = true;
                            }
                            if table.table_name.table() == capitalized_local_table_ref {
                                capitalized_table = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    });
                }
            }
            Ok(TreeNodeRecursion::Continue)
        });

        assert!(lowercase_table);
        assert!(capitalized_table);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        let _ = physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        });

        // Same compute context, so both scans collapse into ONE remote query.
        let expected = vec![
            r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1) AS Table"#,
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }

    /// Verifies the `sql_query_rewriter` hook runs exactly once and its output
    /// is what `final_sql` returns.
    #[tokio::test]
    async fn sql_query_rewriter_hook_invoked_and_rewrites_sql() -> Result<(), DataFusionError> {
        let executor = TestExecutor {
            compute_context: "rewrite".into(),
        };
        let rewrite_calls = Arc::new(AtomicUsize::new(0));
        let table_ref = "table_with_rewriter".to_string();
        let table = Arc::new(SqlRewriteTable::new(
            table_ref.clone().try_into().unwrap(),
            Arc::new(Schema::new(vec![
                Field::new("a", DataType::Int64, false),
                Field::new("b", DataType::Utf8, false),
                Field::new("c", DataType::Date32, false),
            ])),
            Arc::clone(&rewrite_calls),
            "/* rewritten by sql_query_rewriter */",
        ));
        let table_provider = get_test_table_provider_with_table(table, executor);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);
        ctx.register_table(table_ref.clone(), table_provider)
            .unwrap();

        let query = format!("SELECT * FROM {table_ref}");
        let df = ctx.sql(&query).await?;
        let logical_plan = df.into_optimized_plan()?;
        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];
        let _ = physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();
                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        });

        let [final_query] = final_queries.as_slice() else {
            panic!("expected a single federated SQL query");
        };

        assert!(final_query.ends_with("/* rewritten by sql_query_rewriter */"));
        assert_eq!(rewrite_calls.load(Ordering::SeqCst), 1);

        Ok(())
    }
}