Skip to main content

datafusion_federation/sql/
mod.rs

1mod analyzer;
2pub mod ast_analyzer;
3mod executor;
4mod schema;
5mod table;
6mod table_reference;
7
8use std::{any::Any, fmt, sync::Arc, vec};
9
10use analyzer::RewriteTableScanAnalyzer;
11use async_trait::async_trait;
12use datafusion::{
13    arrow::datatypes::{Schema, SchemaRef},
14    common::{
15        tree_node::{Transformed, TreeNode},
16        Statistics,
17    },
18    config::ConfigOptions,
19    error::{DataFusionError, Result},
20    execution::{context::SessionState, TaskContext},
21    logical_expr::{Extension, LogicalPlan},
22    optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
23    physical_expr::EquivalenceProperties,
24    physical_plan::{
25        execution_plan::{Boundedness, EmissionType},
26        filter_pushdown::{
27            ChildPushdownResult, FilterPushdownPhase, FilterPushdownPropagation, PushedDown,
28        },
29        metrics::MetricsSet,
30        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, PlanProperties,
31        SendableRecordBatchStream,
32    },
33    sql::{sqlparser::ast::Statement, unparser::Unparser},
34};
35
36pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef, SqlQueryRewriter};
37pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
38pub use table::{RemoteTable, SQLTable, SQLTableSource};
39pub use table_reference::RemoteTableRef;
40
41use crate::{
42    get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
43};
44
45// SQLFederationProvider provides federation to SQL DMBSs.
46#[derive(Debug)]
47pub struct SQLFederationProvider {
48    pub optimizer: Arc<Optimizer>,
49    pub executor: Arc<dyn SQLExecutor>,
50}
51
52impl SQLFederationProvider {
53    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
54        Self {
55            optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
56                SQLFederationOptimizerRule::new(executor.clone()),
57            )])),
58            executor,
59        }
60    }
61}
62
63impl FederationProvider for SQLFederationProvider {
64    fn name(&self) -> &str {
65        "sql_federation_provider"
66    }
67
68    fn compute_context(&self) -> Option<String> {
69        self.executor.compute_context()
70    }
71
72    fn optimizer(&self) -> Option<Arc<Optimizer>> {
73        Some(self.optimizer.clone())
74    }
75}
76
77#[derive(Debug)]
78struct SQLFederationOptimizerRule {
79    planner: Arc<SQLFederationPlanner>,
80}
81
82impl SQLFederationOptimizerRule {
83    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
84        Self {
85            planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
86        }
87    }
88}
89
90impl OptimizerRule for SQLFederationOptimizerRule {
91    /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes`
92    /// if the plan was rewritten and `Transformed::no` if it was not.
93    ///
94    /// Note: this function is only called if [`Self::supports_rewrite`] returns
95    /// true. Otherwise the Optimizer calls  [`Self::try_optimize`]
96    fn rewrite(
97        &self,
98        plan: LogicalPlan,
99        _config: &dyn OptimizerConfig,
100    ) -> Result<Transformed<LogicalPlan>> {
101        if let LogicalPlan::Extension(Extension { ref node }) = plan {
102            if node.name() == "Federated" {
103                // Avoid attempting double federation
104                return Ok(Transformed::no(plan));
105            }
106        }
107
108        let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
109        let ext_node = Extension {
110            node: Arc::new(fed_plan),
111        };
112
113        let mut plan = LogicalPlan::Extension(ext_node);
114        if let Some(mut rewriter) = self.planner.executor.logical_optimizer() {
115            plan = rewriter(plan)?;
116        }
117
118        Ok(Transformed::yes(plan))
119    }
120
121    /// A human readable name for this analyzer rule
122    fn name(&self) -> &str {
123        "federate_sql"
124    }
125
126    /// Does this rule support rewriting owned plans (rather than by reference)?
127    fn supports_rewrite(&self) -> bool {
128        true
129    }
130}
131
132#[derive(Debug)]
133pub struct SQLFederationPlanner {
134    pub executor: Arc<dyn SQLExecutor>,
135}
136
137impl SQLFederationPlanner {
138    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
139        Self { executor }
140    }
141}
142
143#[async_trait]
144impl FederationPlanner for SQLFederationPlanner {
145    async fn plan_federation(
146        &self,
147        node: &FederatedPlanNode,
148        _session_state: &SessionState,
149    ) -> Result<Arc<dyn ExecutionPlan>> {
150        let schema = Arc::new(node.plan().schema().as_arrow().clone());
151        let plan = node.plan().clone();
152        let statistics = self.executor.statistics(&plan).await?;
153        let input = Arc::new(VirtualExecutionPlan::new(
154            plan,
155            Arc::clone(&self.executor),
156            statistics,
157        ));
158        let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
159        Ok(Arc::new(schema_cast_exec))
160    }
161}
162
163#[derive(Debug, Clone)]
164pub struct VirtualExecutionPlan {
165    plan: LogicalPlan,
166    executor: Arc<dyn SQLExecutor>,
167    props: Arc<PlanProperties>,
168    statistics: Statistics,
169    filters: Vec<Arc<dyn PhysicalExpr>>,
170}
171
172impl VirtualExecutionPlan {
173    pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>, statistics: Statistics) -> Self {
174        let schema: Schema = plan.schema().as_arrow().clone();
175        let props = Arc::new(PlanProperties::new(
176            EquivalenceProperties::new(Arc::new(schema)),
177            Partitioning::UnknownPartitioning(1),
178            EmissionType::Incremental,
179            Boundedness::Bounded,
180        ));
181        Self {
182            plan,
183            executor,
184            props,
185            statistics,
186            filters: Vec::new(),
187        }
188    }
189
190    pub fn plan(&self) -> &LogicalPlan {
191        &self.plan
192    }
193
194    pub fn executor(&self) -> &Arc<dyn SQLExecutor> {
195        &self.executor
196    }
197
198    pub fn statistics(&self) -> &Statistics {
199        &self.statistics
200    }
201
202    fn schema(&self) -> SchemaRef {
203        let df_schema = self.plan.schema().as_arrow().clone();
204        Arc::new(df_schema)
205    }
206
207    fn final_sql(&self) -> Result<String> {
208        let plan = self.plan.clone();
209        let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
210        let (logical_optimizers, ast_analyzers, sql_query_rewriters) = gather_analyzers(&plan)?;
211        let plan = apply_logical_optimizers(plan, logical_optimizers)?;
212        let ast = self.plan_to_statement(&plan)?;
213        let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
214        let ast = apply_ast_analyzers(ast, ast_analyzers)?;
215        apply_sql_query_rewriters(ast.to_string(), sql_query_rewriters)
216    }
217
218    fn rewrite_with_executor_ast_analyzer(
219        &self,
220        ast: Statement,
221    ) -> Result<Statement, datafusion::error::DataFusionError> {
222        if let Some(mut analyzer) = self.executor.ast_analyzer() {
223            Ok(analyzer(ast)?)
224        } else {
225            Ok(ast)
226        }
227    }
228
229    fn plan_to_statement(&self, plan: &LogicalPlan) -> Result<Statement> {
230        Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
231    }
232}
233
234fn gather_analyzers(
235    plan: &LogicalPlan,
236) -> Result<(
237    Vec<LogicalOptimizer>,
238    Vec<AstAnalyzer>,
239    Vec<SqlQueryRewriter>,
240)> {
241    let mut logical_optimizers = vec![];
242    let mut ast_analyzers = vec![];
243    let mut sql_query_rewriters = vec![];
244
245    plan.apply(|node| {
246        if let LogicalPlan::TableScan(table) = node {
247            let provider = get_table_source(&table.source)
248                .expect("caller is virtual exec so this is valid")
249                .expect("caller is virtual exec so this is valid");
250            if let Some(source) = provider.as_any().downcast_ref::<SQLTableSource>() {
251                if let Some(analyzer) = source.table.logical_optimizer() {
252                    logical_optimizers.push(analyzer);
253                }
254                if let Some(analyzer) = source.table.ast_analyzer() {
255                    ast_analyzers.push(analyzer);
256                }
257                if let Some(rewriter) = source.table.sql_query_rewriter() {
258                    sql_query_rewriters.push(rewriter);
259                }
260            }
261        }
262        Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
263    })?;
264
265    Ok((logical_optimizers, ast_analyzers, sql_query_rewriters))
266}
267
268fn apply_logical_optimizers(
269    mut plan: LogicalPlan,
270    analyzers: Vec<LogicalOptimizer>,
271) -> Result<LogicalPlan> {
272    for mut analyzer in analyzers {
273        let old_schema = plan.schema().clone();
274        plan = analyzer(plan)?;
275        let new_schema = plan.schema();
276        if &old_schema != new_schema {
277            return Err(DataFusionError::Execution(format!(
278                "Schema altered during logical analysis, expected: {}, found: {}",
279                old_schema, new_schema
280            )));
281        }
282    }
283    Ok(plan)
284}
285
286fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) -> Result<Statement> {
287    for mut analyzer in analyzers {
288        statement = analyzer(statement)?;
289    }
290    Ok(statement)
291}
292
293fn apply_sql_query_rewriters(
294    mut query: String,
295    rewriters: Vec<SqlQueryRewriter>,
296) -> Result<String> {
297    for mut rewriter in rewriters {
298        query = rewriter(query)?;
299    }
300    Ok(query)
301}
302
303impl DisplayAs for VirtualExecutionPlan {
304    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
305        write!(f, "VirtualExecutionPlan")?;
306        write!(f, " name={}", self.executor.name())?;
307        if let Some(ctx) = self.executor.compute_context() {
308            write!(f, " compute_context={ctx}")?;
309        };
310        let mut plan = match RewriteTableScanAnalyzer::rewrite(self.plan.clone()) {
311            Ok(plan) => plan,
312            Err(_) => self.plan.clone(),
313        };
314        if let Ok(statement) = self.plan_to_statement(&plan) {
315            write!(f, " base_sql={statement}")?;
316        }
317
318        let (logical_optimizers, ast_analyzers, sql_query_rewriters) = match gather_analyzers(&plan)
319        {
320            Ok(analyzers) => analyzers,
321            Err(_) => return Ok(()),
322        };
323
324        let old_plan = plan.clone();
325
326        plan = match apply_logical_optimizers(plan, logical_optimizers) {
327            Ok(plan) => plan,
328            _ => return Ok(()),
329        };
330
331        let statement = match self.plan_to_statement(&plan) {
332            Ok(statement) => statement,
333            _ => return Ok(()),
334        };
335
336        if plan != old_plan {
337            write!(f, " rewritten_logical_sql={statement}")?;
338        }
339
340        let old_statement = statement.clone();
341        let statement = match self.rewrite_with_executor_ast_analyzer(statement) {
342            Ok(statement) => statement,
343            _ => return Ok(()),
344        };
345        if old_statement != statement {
346            write!(f, " rewritten_executor_sql={statement}")?;
347        }
348
349        let old_statement = statement.clone();
350        let statement = match apply_ast_analyzers(statement, ast_analyzers) {
351            Ok(statement) => statement,
352            _ => return Ok(()),
353        };
354        if old_statement != statement {
355            write!(f, " rewritten_ast_analyzer={statement}")?;
356        }
357
358        let sql = statement.to_string();
359        let rewritten_sql = match apply_sql_query_rewriters(sql.clone(), sql_query_rewriters) {
360            Ok(sql) => sql,
361            _ => return Ok(()),
362        };
363        if sql != rewritten_sql {
364            write!(f, " rewritten_sql_query={rewritten_sql}")?;
365        }
366
367        Ok(())
368    }
369}
370
371impl ExecutionPlan for VirtualExecutionPlan {
372    fn name(&self) -> &str {
373        "sql_federation_exec"
374    }
375
376    fn as_any(&self) -> &dyn Any {
377        self
378    }
379
380    fn schema(&self) -> SchemaRef {
381        self.schema()
382    }
383
384    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
385        vec![]
386    }
387
388    fn with_new_children(
389        self: Arc<Self>,
390        _: Vec<Arc<dyn ExecutionPlan>>,
391    ) -> Result<Arc<dyn ExecutionPlan>> {
392        Ok(self)
393    }
394
395    fn execute(
396        &self,
397        _partition: usize,
398        _context: Arc<TaskContext>,
399    ) -> Result<SendableRecordBatchStream> {
400        self.executor
401            .execute(&self.final_sql()?, self.schema(), &self.filters)
402    }
403
404    fn properties(&self) -> &Arc<PlanProperties> {
405        &self.props
406    }
407
408    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
409        Ok(self.statistics.clone())
410    }
411
412    fn metrics(&self) -> Option<MetricsSet> {
413        self.executor.metrics()
414    }
415
416    fn handle_child_pushdown_result(
417        &self,
418        _phase: FilterPushdownPhase,
419        child_pushdown_result: ChildPushdownResult,
420        _config: &ConfigOptions,
421    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
422        let parent_filters: Vec<_> = child_pushdown_result
423            .clone()
424            .parent_filters
425            .into_iter()
426            .map(|f| f.filter)
427            .collect();
428
429        if parent_filters.is_empty() {
430            return Ok(FilterPushdownPropagation {
431                filters: vec![],
432                updated_node: None,
433            });
434        }
435
436        let filters_pushed_down = vec![PushedDown::Yes; parent_filters.len()];
437        let mut node = self.clone();
438        node.filters = parent_filters;
439
440        Ok(FilterPushdownPropagation {
441            filters: filters_pushed_down,
442            updated_node: Some(Arc::new(node)),
443        })
444    }
445}
446
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use crate::sql::{
        RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTable, SQLTableSource,
    };
    use crate::FederatedTableProviderAdaptor;
    use async_trait::async_trait;
    use datafusion::arrow::datatypes::{Schema, SchemaRef};
    use datafusion::common::tree_node::TreeNodeRecursion;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::sql::unparser;
    use datafusion::sql::unparser::dialect::Dialect;
    use datafusion::sql::TableReference;
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        datasource::TableProvider,
        execution::context::SessionContext,
    };

    use super::table::RemoteTable;
    use super::*;

    #[derive(Debug, Clone)]
    struct TestExecutor {
        compute_context: String,
    }

    #[async_trait]
    impl SQLExecutor for TestExecutor {
        fn name(&self) -> &str {
            "TestExecutor"
        }

        fn compute_context(&self) -> Option<String> {
            Some(self.compute_context.clone())
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(unparser::dialect::DefaultDialect {})
        }

        fn execute(
            &self,
            _query: &str,
            _schema: SchemaRef,
            _filters: &[Arc<dyn PhysicalExpr>],
        ) -> Result<SendableRecordBatchStream> {
            unimplemented!()
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            unimplemented!()
        }
    }

    fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc<dyn TableProvider> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_ref = RemoteTableRef::try_from(name).unwrap();
        let table = Arc::new(RemoteTable::new(table_ref, schema));
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource { provider, table });
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    fn get_test_table_provider_with_table(
        table: Arc<dyn SQLTable>,
        executor: TestExecutor,
    ) -> Arc<dyn TableProvider> {
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource::new_with_table(provider, table));
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    /// Returns the names of all tables scanned inside `FederatedPlanNode`s of
    /// `plan`, propagating any traversal error.
    fn federated_table_names(plan: &LogicalPlan) -> Result<HashSet<String>> {
        let mut names = HashSet::new();
        plan.apply(|node| {
            if let LogicalPlan::Extension(ext) = node {
                if let Some(fed) = ext.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    fed.plan().apply(|inner| {
                        if let LogicalPlan::TableScan(scan) = inner {
                            names.insert(scan.table_name.table().to_string());
                        }
                        Ok(TreeNodeRecursion::Continue)
                    })?;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;
        Ok(names)
    }

    /// Returns the final SQL of every `VirtualExecutionPlan` in `plan`,
    /// propagating any error raised while building it.
    fn collect_final_sql(plan: &Arc<dyn ExecutionPlan>) -> Result<Vec<String>> {
        let mut final_queries = vec![];
        plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .expect("sql_federation_exec nodes are VirtualExecutionPlan");
                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        })?;
        Ok(final_queries)
    }

    #[derive(Debug)]
    struct SqlRewriteTable {
        table: RemoteTable,
        rewrite_calls: Arc<AtomicUsize>,
        suffix: String,
    }

    impl SqlRewriteTable {
        fn new(
            table_ref: RemoteTableRef,
            schema: SchemaRef,
            rewrite_calls: Arc<AtomicUsize>,
            suffix: impl Into<String>,
        ) -> Self {
            Self {
                table: RemoteTable::new(table_ref, schema),
                rewrite_calls,
                suffix: suffix.into(),
            }
        }
    }

    impl SQLTable for SqlRewriteTable {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn table_reference(&self) -> TableReference {
            self.table.table_reference().clone()
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(self.table.schema())
        }

        fn sql_query_rewriter(&self) -> Option<SqlQueryRewriter> {
            let rewrite_calls = Arc::clone(&self.rewrite_calls);
            let suffix = self.suffix.clone();
            Some(Box::new(move |sql| {
                rewrite_calls.fetch_add(1, Ordering::SeqCst);
                Ok(format!("{sql} {suffix}"))
            }))
        }
    }

    #[tokio::test]
    async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "a".into(),
        };

        let test_executor_b = TestExecutor {
            compute_context: "b".into(),
        };

        let table_a1_ref = "table_a1".to_string();
        let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());

        let table_a2_ref = "table_a2".to_string();
        let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a);

        let table_b1_ref = "table_b1(1)".to_string();
        let table_b1_df_ref = "table_local_b1".to_string();

        let table_b1 = get_test_table_provider(table_b1_ref, test_executor_b);

        // Use the crate's default session state so the federation rules are installed.
        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(table_a1_ref.clone(), table_a1).unwrap();
        ctx.register_table(table_a2_ref.clone(), table_a2).unwrap();
        ctx.register_table(table_b1_df_ref.clone(), table_b1)
            .unwrap();

        let query = r#"
            SELECT * FROM table_a1
            UNION ALL
            SELECT * FROM table_a2
            UNION ALL
            SELECT * FROM table_local_b1;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let federated = federated_table_names(&logical_plan)?;
        assert!(federated.contains(&table_a1_ref));
        assert!(federated.contains(&table_a2_ref));
        // assuming table name is rewritten via analyzer
        assert!(federated.contains(&table_b1_df_ref));

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;
        let final_queries = collect_final_sql(&physical_plan)?;

        let expected = vec![
            "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1",
            "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2",
            "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1",
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }

    #[tokio::test]
    async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "test".into(),
        };

        let lowercase_table_ref = "default.table".to_string();
        let lowercase_local_table_ref = "dftable".to_string();
        let lowercase_table =
            get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone());

        let capitalized_table_ref = "default.Table(1)".to_string();
        let capitalized_local_table_ref = "dfview".to_string();
        let capitalized_table =
            get_test_table_provider(capitalized_table_ref.clone(), test_executor_a);

        // Use the crate's default session state so the federation rules are installed.
        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table)
            .unwrap();
        ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table)
            .unwrap();

        let query = r#"
                SELECT * FROM dftable
                UNION ALL
                SELECT * FROM dfview;
            "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let federated = federated_table_names(&logical_plan)?;
        assert!(federated.contains(&lowercase_local_table_ref));
        assert!(federated.contains(&capitalized_local_table_ref));

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;
        let final_queries = collect_final_sql(&physical_plan)?;

        let expected = vec![
            r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1) AS Table"#,
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }

    #[tokio::test]
    async fn sql_query_rewriter_hook_invoked_and_rewrites_sql() -> Result<(), DataFusionError> {
        let executor = TestExecutor {
            compute_context: "rewrite".into(),
        };
        let rewrite_calls = Arc::new(AtomicUsize::new(0));
        let table_ref = "table_with_rewriter".to_string();
        let table = Arc::new(SqlRewriteTable::new(
            table_ref.clone().try_into().unwrap(),
            Arc::new(Schema::new(vec![
                Field::new("a", DataType::Int64, false),
                Field::new("b", DataType::Utf8, false),
                Field::new("c", DataType::Date32, false),
            ])),
            Arc::clone(&rewrite_calls),
            "/* rewritten by sql_query_rewriter */",
        ));
        let table_provider = get_test_table_provider_with_table(table, executor);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);
        ctx.register_table(table_ref.clone(), table_provider)
            .unwrap();

        let query = format!("SELECT * FROM {table_ref}");
        let df = ctx.sql(&query).await?;
        let logical_plan = df.into_optimized_plan()?;
        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let final_queries = collect_final_sql(&physical_plan)?;

        let [final_query] = final_queries.as_slice() else {
            panic!("expected a single federated SQL query");
        };

        assert!(final_query.ends_with("/* rewritten by sql_query_rewriter */"));
        assert_eq!(rewrite_calls.load(Ordering::SeqCst), 1);

        Ok(())
    }
}