Skip to main content

datafusion_federation/sql/
mod.rs

1mod analyzer;
2pub mod ast_analyzer;
3mod executor;
4mod schema;
5mod table;
6mod table_reference;
7
8use std::{any::Any, fmt, sync::Arc, vec};
9
10use analyzer::RewriteTableScanAnalyzer;
11use async_trait::async_trait;
12use datafusion::{
13    arrow::datatypes::{Schema, SchemaRef},
14    common::{
15        tree_node::{Transformed, TreeNode},
16        Statistics,
17    },
18    config::ConfigOptions,
19    error::{DataFusionError, Result},
20    execution::{context::SessionState, TaskContext},
21    logical_expr::{Extension, LogicalPlan},
22    optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
23    physical_expr::EquivalenceProperties,
24    physical_plan::{
25        execution_plan::{Boundedness, EmissionType},
26        filter_pushdown::{
27            ChildPushdownResult, FilterPushdownPhase, FilterPushdownPropagation, PushedDown,
28        },
29        metrics::MetricsSet,
30        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, PlanProperties,
31        SendableRecordBatchStream,
32    },
33    sql::{sqlparser::ast::Statement, unparser::Unparser},
34};
35
36pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef};
37pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
38pub use table::{RemoteTable, SQLTable, SQLTableSource};
39pub use table_reference::RemoteTableRef;
40
41use crate::{
42    get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
43};
44
45// SQLFederationProvider provides federation to SQL DMBSs.
46#[derive(Debug)]
47pub struct SQLFederationProvider {
48    pub optimizer: Arc<Optimizer>,
49    pub executor: Arc<dyn SQLExecutor>,
50}
51
52impl SQLFederationProvider {
53    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
54        Self {
55            optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
56                SQLFederationOptimizerRule::new(executor.clone()),
57            )])),
58            executor,
59        }
60    }
61}
62
63impl FederationProvider for SQLFederationProvider {
64    fn name(&self) -> &str {
65        "sql_federation_provider"
66    }
67
68    fn compute_context(&self) -> Option<String> {
69        self.executor.compute_context()
70    }
71
72    fn optimizer(&self) -> Option<Arc<Optimizer>> {
73        Some(self.optimizer.clone())
74    }
75}
76
77#[derive(Debug)]
78struct SQLFederationOptimizerRule {
79    planner: Arc<SQLFederationPlanner>,
80}
81
82impl SQLFederationOptimizerRule {
83    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
84        Self {
85            planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
86        }
87    }
88}
89
90impl OptimizerRule for SQLFederationOptimizerRule {
91    /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes`
92    /// if the plan was rewritten and `Transformed::no` if it was not.
93    ///
94    /// Note: this function is only called if [`Self::supports_rewrite`] returns
95    /// true. Otherwise the Optimizer calls  [`Self::try_optimize`]
96    fn rewrite(
97        &self,
98        plan: LogicalPlan,
99        _config: &dyn OptimizerConfig,
100    ) -> Result<Transformed<LogicalPlan>> {
101        if let LogicalPlan::Extension(Extension { ref node }) = plan {
102            if node.name() == "Federated" {
103                // Avoid attempting double federation
104                return Ok(Transformed::no(plan));
105            }
106        }
107
108        let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
109        let ext_node = Extension {
110            node: Arc::new(fed_plan),
111        };
112
113        let mut plan = LogicalPlan::Extension(ext_node);
114        if let Some(mut rewriter) = self.planner.executor.logical_optimizer() {
115            plan = rewriter(plan)?;
116        }
117
118        Ok(Transformed::yes(plan))
119    }
120
121    /// A human readable name for this analyzer rule
122    fn name(&self) -> &str {
123        "federate_sql"
124    }
125
126    /// Does this rule support rewriting owned plans (rather than by reference)?
127    fn supports_rewrite(&self) -> bool {
128        true
129    }
130}
131
132#[derive(Debug)]
133pub struct SQLFederationPlanner {
134    pub executor: Arc<dyn SQLExecutor>,
135}
136
137impl SQLFederationPlanner {
138    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
139        Self { executor }
140    }
141}
142
143#[async_trait]
144impl FederationPlanner for SQLFederationPlanner {
145    async fn plan_federation(
146        &self,
147        node: &FederatedPlanNode,
148        _session_state: &SessionState,
149    ) -> Result<Arc<dyn ExecutionPlan>> {
150        let schema = Arc::new(node.plan().schema().as_arrow().clone());
151        let plan = node.plan().clone();
152        let statistics = self.executor.statistics(&plan).await?;
153        let input = Arc::new(VirtualExecutionPlan::new(
154            plan,
155            Arc::clone(&self.executor),
156            statistics,
157        ));
158        let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
159        Ok(Arc::new(schema_cast_exec))
160    }
161}
162
163#[derive(Debug, Clone)]
164pub struct VirtualExecutionPlan {
165    plan: LogicalPlan,
166    executor: Arc<dyn SQLExecutor>,
167    props: PlanProperties,
168    statistics: Statistics,
169    filters: Vec<Arc<dyn PhysicalExpr>>,
170}
171
172impl VirtualExecutionPlan {
173    pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>, statistics: Statistics) -> Self {
174        let schema: Schema = plan.schema().as_arrow().clone();
175        let props = PlanProperties::new(
176            EquivalenceProperties::new(Arc::new(schema)),
177            Partitioning::UnknownPartitioning(1),
178            EmissionType::Incremental,
179            Boundedness::Bounded,
180        );
181        Self {
182            plan,
183            executor,
184            props,
185            statistics,
186            filters: Vec::new(),
187        }
188    }
189
190    pub fn plan(&self) -> &LogicalPlan {
191        &self.plan
192    }
193
194    pub fn executor(&self) -> &Arc<dyn SQLExecutor> {
195        &self.executor
196    }
197
198    pub fn statistics(&self) -> &Statistics {
199        &self.statistics
200    }
201
202    fn schema(&self) -> SchemaRef {
203        let df_schema = self.plan.schema().as_arrow().clone();
204        Arc::new(df_schema)
205    }
206
207    fn final_sql(&self) -> Result<String> {
208        let plan = self.plan.clone();
209        let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
210        let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?;
211        let plan = apply_logical_optimizers(plan, logical_optimizers)?;
212        let ast = self.plan_to_statement(&plan)?;
213        let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
214        let ast = apply_ast_analyzers(ast, ast_analyzers)?;
215        Ok(ast.to_string())
216    }
217
218    fn rewrite_with_executor_ast_analyzer(
219        &self,
220        ast: Statement,
221    ) -> Result<Statement, datafusion::error::DataFusionError> {
222        if let Some(mut analyzer) = self.executor.ast_analyzer() {
223            Ok(analyzer(ast)?)
224        } else {
225            Ok(ast)
226        }
227    }
228
229    fn plan_to_statement(&self, plan: &LogicalPlan) -> Result<Statement> {
230        Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
231    }
232}
233
234fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<AstAnalyzer>)> {
235    let mut logical_optimizers = vec![];
236    let mut ast_analyzers = vec![];
237
238    plan.apply(|node| {
239        if let LogicalPlan::TableScan(table) = node {
240            let provider = get_table_source(&table.source)
241                .expect("caller is virtual exec so this is valid")
242                .expect("caller is virtual exec so this is valid");
243            if let Some(source) = provider.as_any().downcast_ref::<SQLTableSource>() {
244                if let Some(analyzer) = source.table.logical_optimizer() {
245                    logical_optimizers.push(analyzer);
246                }
247                if let Some(analyzer) = source.table.ast_analyzer() {
248                    ast_analyzers.push(analyzer);
249                }
250            }
251        }
252        Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
253    })?;
254
255    Ok((logical_optimizers, ast_analyzers))
256}
257
258fn apply_logical_optimizers(
259    mut plan: LogicalPlan,
260    analyzers: Vec<LogicalOptimizer>,
261) -> Result<LogicalPlan> {
262    for mut analyzer in analyzers {
263        let old_schema = plan.schema().clone();
264        plan = analyzer(plan)?;
265        let new_schema = plan.schema();
266        if &old_schema != new_schema {
267            return Err(DataFusionError::Execution(format!(
268                "Schema altered during logical analysis, expected: {}, found: {}",
269                old_schema, new_schema
270            )));
271        }
272    }
273    Ok(plan)
274}
275
276fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) -> Result<Statement> {
277    for mut analyzer in analyzers {
278        statement = analyzer(statement)?;
279    }
280    Ok(statement)
281}
282
283impl DisplayAs for VirtualExecutionPlan {
284    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
285        write!(f, "VirtualExecutionPlan")?;
286        write!(f, " name={}", self.executor.name())?;
287        if let Some(ctx) = self.executor.compute_context() {
288            write!(f, " compute_context={ctx}")?;
289        };
290        let mut plan = self.plan.clone();
291        if let Ok(statement) = self.plan_to_statement(&plan) {
292            write!(f, " initial_sql={statement}")?;
293        }
294
295        let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) {
296            Ok(analyzers) => analyzers,
297            Err(_) => return Ok(()),
298        };
299
300        let old_plan = plan.clone();
301
302        plan = match apply_logical_optimizers(plan, logical_optimizers) {
303            Ok(plan) => plan,
304            _ => return Ok(()),
305        };
306
307        let statement = match self.plan_to_statement(&plan) {
308            Ok(statement) => statement,
309            _ => return Ok(()),
310        };
311
312        if plan != old_plan {
313            write!(f, " rewritten_logical_sql={statement}")?;
314        }
315
316        let old_statement = statement.clone();
317        let statement = match self.rewrite_with_executor_ast_analyzer(statement) {
318            Ok(statement) => statement,
319            _ => return Ok(()),
320        };
321        if old_statement != statement {
322            write!(f, " rewritten_executor_sql={statement}")?;
323        }
324
325        let old_statement = statement.clone();
326        let statement = match apply_ast_analyzers(statement, ast_analyzers) {
327            Ok(statement) => statement,
328            _ => return Ok(()),
329        };
330        if old_statement != statement {
331            write!(f, " rewritten_ast_analyzer={statement}")?;
332        }
333
334        Ok(())
335    }
336}
337
338impl ExecutionPlan for VirtualExecutionPlan {
339    fn name(&self) -> &str {
340        "sql_federation_exec"
341    }
342
343    fn as_any(&self) -> &dyn Any {
344        self
345    }
346
347    fn schema(&self) -> SchemaRef {
348        self.schema()
349    }
350
351    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
352        vec![]
353    }
354
355    fn with_new_children(
356        self: Arc<Self>,
357        _: Vec<Arc<dyn ExecutionPlan>>,
358    ) -> Result<Arc<dyn ExecutionPlan>> {
359        Ok(self)
360    }
361
362    fn execute(
363        &self,
364        _partition: usize,
365        _context: Arc<TaskContext>,
366    ) -> Result<SendableRecordBatchStream> {
367        self.executor
368            .execute(&self.final_sql()?, self.schema(), &self.filters)
369    }
370
371    fn properties(&self) -> &PlanProperties {
372        &self.props
373    }
374
375    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
376        Ok(self.statistics.clone())
377    }
378
379    fn metrics(&self) -> Option<MetricsSet> {
380        self.executor.metrics()
381    }
382
383    fn handle_child_pushdown_result(
384        &self,
385        _phase: FilterPushdownPhase,
386        child_pushdown_result: ChildPushdownResult,
387        _config: &ConfigOptions,
388    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
389        let parent_filters: Vec<_> = child_pushdown_result
390            .clone()
391            .parent_filters
392            .into_iter()
393            .map(|f| f.filter)
394            .collect();
395
396        if parent_filters.is_empty() {
397            return Ok(FilterPushdownPropagation {
398                filters: vec![],
399                updated_node: None,
400            });
401        }
402
403        let filters_pushed_down = vec![PushedDown::Yes; parent_filters.len()];
404        let mut node = self.clone();
405        node.filters = parent_filters;
406
407        Ok(FilterPushdownPropagation {
408            filters: filters_pushed_down,
409            updated_node: Some(Arc::new(node)),
410        })
411    }
412}
413
#[cfg(test)]
mod tests {

    use std::collections::HashSet;
    use std::sync::Arc;

    use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
    use crate::FederatedTableProviderAdaptor;
    use async_trait::async_trait;
    use datafusion::arrow::datatypes::{Schema, SchemaRef};
    use datafusion::common::tree_node::TreeNodeRecursion;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::sql::unparser::dialect::Dialect;
    use datafusion::sql::unparser::{self};
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        datasource::TableProvider,
        execution::context::SessionContext,
    };

    use super::table::RemoteTable;
    use super::*;

    /// Executor used only for planning / SQL generation; its execution methods are
    /// `unimplemented!()`.
    #[derive(Debug, Clone)]
    struct TestExecutor {
        compute_context: String,
    }

    #[async_trait]
    impl SQLExecutor for TestExecutor {
        fn name(&self) -> &str {
            "TestExecutor"
        }

        fn compute_context(&self) -> Option<String> {
            Some(self.compute_context.clone())
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(unparser::dialect::DefaultDialect {})
        }

        fn execute(
            &self,
            _query: &str,
            _schema: SchemaRef,
            _filters: &[Arc<dyn PhysicalExpr>],
        ) -> Result<SendableRecordBatchStream> {
            unimplemented!()
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            unimplemented!()
        }
    }

    /// Builds a federated provider for remote table `name` with columns
    /// `a: Int64`, `b: Utf8`, `c: Date32`, backed by `executor`.
    fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc<dyn TableProvider> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_ref = RemoteTableRef::try_from(name).unwrap();
        let table = Arc::new(RemoteTable::new(table_ref, schema));
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource { provider, table });
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    #[tokio::test]
    async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "a".into(),
        };

        let test_executor_b = TestExecutor {
            compute_context: "b".into(),
        };

        let table_a1_ref = "table_a1".to_string();
        let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());

        let table_a2_ref = "table_a2".to_string();
        let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a);

        let table_b1_ref = "table_b1(1)".to_string();
        let table_b1_df_ref = "table_local_b1".to_string();

        let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b);

        // Build the context from the crate's default session state.
        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(table_a1_ref.clone(), table_a1)?;
        ctx.register_table(table_a2_ref.clone(), table_a2)?;
        ctx.register_table(table_b1_df_ref.clone(), table_b1)?;

        let query = r#"
            SELECT * FROM table_a1
            UNION ALL
            SELECT * FROM table_a2
            UNION ALL
            SELECT * FROM table_local_b1;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut table_a1_federated = false;
        let mut table_a2_federated = false;
        let mut table_b1_federated = false;

        // Propagate traversal errors instead of discarding them.
        logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == table_a1_ref {
                                table_a1_federated = true;
                            }
                            if table.table_name.table() == table_a2_ref {
                                table_a2_federated = true;
                            }
                            // assuming table name is rewritten via analyzer
                            if table.table_name.table() == table_b1_df_ref {
                                table_b1_federated = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    })?;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        assert!(table_a1_federated);
        assert!(table_a2_federated);
        assert!(table_b1_federated);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        // A `final_sql` failure now fails the test with the real error.
        physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .expect("sql_federation_exec nodes are VirtualExecutionPlan");

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        let expected = vec![
            "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1",
            "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2",
            "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1",
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }

    #[tokio::test]
    async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "test".into(),
        };

        let lowercase_table_ref = "default.table".to_string();
        let lowercase_local_table_ref = "dftable".to_string();
        let lowercase_table =
            get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone());

        let capitalized_table_ref = "default.Table(1)".to_string();
        let capitalized_local_table_ref = "dfview".to_string();
        let capitalized_table =
            get_test_table_provider(capitalized_table_ref.clone(), test_executor_a);

        // Build the context from the crate's default session state.
        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table)?;
        ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table)?;

        let query = r#"
                SELECT * FROM dftable
                UNION ALL
                SELECT * FROM dfview;
            "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut lowercase_table = false;
        let mut capitalized_table = false;

        // Propagate traversal errors instead of discarding them.
        logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == lowercase_local_table_ref {
                                lowercase_table = true;
                            }
                            if table.table_name.table() == capitalized_local_table_ref {
                                capitalized_table = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    })?;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        assert!(lowercase_table);
        assert!(capitalized_table);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        // A `final_sql` failure now fails the test with the real error.
        physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .expect("sql_federation_exec nodes are VirtualExecutionPlan");

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        let expected = vec![
            r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1) AS Table"#,
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }
}