datafusion_federation/sql/
mod.rs

1mod analyzer;
2pub mod ast_analyzer;
3mod executor;
4mod schema;
5mod table;
6mod table_reference;
7
8use std::{any::Any, fmt, sync::Arc, vec};
9
10use analyzer::RewriteTableScanAnalyzer;
11use async_trait::async_trait;
12use datafusion::{
13    arrow::datatypes::{Schema, SchemaRef},
14    common::tree_node::{Transformed, TreeNode},
15    error::{DataFusionError, Result},
16    execution::{context::SessionState, TaskContext},
17    logical_expr::{Extension, LogicalPlan},
18    optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
19    physical_expr::EquivalenceProperties,
20    physical_plan::{
21        execution_plan::{Boundedness, EmissionType},
22        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
23        SendableRecordBatchStream,
24    },
25    sql::{sqlparser::ast::Statement, unparser::Unparser},
26};
27
28pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef};
29pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
30pub use table::{RemoteTable, SQLTableSource};
31pub use table_reference::RemoteTableRef;
32
33use crate::{
34    get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
35};
36
37// SQLFederationProvider provides federation to SQL DMBSs.
38#[derive(Debug)]
39pub struct SQLFederationProvider {
40    optimizer: Arc<Optimizer>,
41    executor: Arc<dyn SQLExecutor>,
42}
43
44impl SQLFederationProvider {
45    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
46        Self {
47            optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
48                SQLFederationOptimizerRule::new(executor.clone()),
49            )])),
50            executor,
51        }
52    }
53}
54
55impl FederationProvider for SQLFederationProvider {
56    fn name(&self) -> &str {
57        "sql_federation_provider"
58    }
59
60    fn compute_context(&self) -> Option<String> {
61        self.executor.compute_context()
62    }
63
64    fn optimizer(&self) -> Option<Arc<Optimizer>> {
65        Some(self.optimizer.clone())
66    }
67}
68
69#[derive(Debug)]
70struct SQLFederationOptimizerRule {
71    planner: Arc<SQLFederationPlanner>,
72}
73
74impl SQLFederationOptimizerRule {
75    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
76        Self {
77            planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
78        }
79    }
80}
81
82impl OptimizerRule for SQLFederationOptimizerRule {
83    /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes`
84    /// if the plan was rewritten and `Transformed::no` if it was not.
85    ///
86    /// Note: this function is only called if [`Self::supports_rewrite`] returns
87    /// true. Otherwise the Optimizer calls  [`Self::try_optimize`]
88    fn rewrite(
89        &self,
90        plan: LogicalPlan,
91        _config: &dyn OptimizerConfig,
92    ) -> Result<Transformed<LogicalPlan>> {
93        if let LogicalPlan::Extension(Extension { ref node }) = plan {
94            if node.name() == "Federated" {
95                // Avoid attempting double federation
96                return Ok(Transformed::no(plan));
97            }
98        }
99
100        let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
101        let ext_node = Extension {
102            node: Arc::new(fed_plan),
103        };
104
105        let mut plan = LogicalPlan::Extension(ext_node);
106        if let Some(mut rewriter) = self.planner.executor.logical_optimizer() {
107            plan = rewriter(plan)?;
108        }
109
110        Ok(Transformed::yes(plan))
111    }
112
113    /// A human readable name for this analyzer rule
114    fn name(&self) -> &str {
115        "federate_sql"
116    }
117
118    /// Does this rule support rewriting owned plans (rather than by reference)?
119    fn supports_rewrite(&self) -> bool {
120        true
121    }
122}
123
124#[derive(Debug)]
125struct SQLFederationPlanner {
126    executor: Arc<dyn SQLExecutor>,
127}
128
129impl SQLFederationPlanner {
130    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
131        Self { executor }
132    }
133}
134
135#[async_trait]
136impl FederationPlanner for SQLFederationPlanner {
137    async fn plan_federation(
138        &self,
139        node: &FederatedPlanNode,
140        _session_state: &SessionState,
141    ) -> Result<Arc<dyn ExecutionPlan>> {
142        let schema = Arc::new(node.plan().schema().as_arrow().clone());
143        let input = Arc::new(VirtualExecutionPlan::new(
144            node.plan().clone(),
145            Arc::clone(&self.executor),
146        ));
147        let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
148        Ok(Arc::new(schema_cast_exec))
149    }
150}
151
152#[derive(Debug, Clone)]
153struct VirtualExecutionPlan {
154    plan: LogicalPlan,
155    executor: Arc<dyn SQLExecutor>,
156    props: PlanProperties,
157}
158
159impl VirtualExecutionPlan {
160    pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>) -> Self {
161        let schema: Schema = plan.schema().as_ref().into();
162        let props = PlanProperties::new(
163            EquivalenceProperties::new(Arc::new(schema)),
164            Partitioning::UnknownPartitioning(1),
165            EmissionType::Incremental,
166            Boundedness::Bounded,
167        );
168        Self {
169            plan,
170            executor,
171            props,
172        }
173    }
174
175    fn schema(&self) -> SchemaRef {
176        let df_schema = self.plan.schema().as_ref();
177        Arc::new(Schema::from(df_schema))
178    }
179
180    fn final_sql(&self) -> Result<String> {
181        let plan = self.plan.clone();
182        let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
183        let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?;
184        let plan = apply_logical_optimizers(plan, logical_optimizers)?;
185        let ast = self.plan_to_statement(&plan)?;
186        let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
187        let ast = apply_ast_analyzers(ast, ast_analyzers)?;
188        Ok(ast.to_string())
189    }
190
191    fn rewrite_with_executor_ast_analyzer(
192        &self,
193        ast: Statement,
194    ) -> Result<Statement, datafusion::error::DataFusionError> {
195        if let Some(mut analyzer) = self.executor.ast_analyzer() {
196            Ok(analyzer(ast)?)
197        } else {
198            Ok(ast)
199        }
200    }
201
202    fn plan_to_statement(&self, plan: &LogicalPlan) -> Result<Statement> {
203        Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
204    }
205}
206
207fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<AstAnalyzer>)> {
208    let mut logical_optimizers = vec![];
209    let mut ast_analyzers = vec![];
210
211    plan.apply(|node| {
212        if let LogicalPlan::TableScan(table) = node {
213            let provider = get_table_source(&table.source)
214                .expect("caller is virtual exec so this is valid")
215                .expect("caller is virtual exec so this is valid");
216            if let Some(source) = provider.as_any().downcast_ref::<SQLTableSource>() {
217                if let Some(analyzer) = source.table.logical_optimizer() {
218                    logical_optimizers.push(analyzer);
219                }
220                if let Some(analyzer) = source.table.ast_analyzer() {
221                    ast_analyzers.push(analyzer);
222                }
223            }
224        }
225        Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
226    })?;
227
228    Ok((logical_optimizers, ast_analyzers))
229}
230
231fn apply_logical_optimizers(
232    mut plan: LogicalPlan,
233    analyzers: Vec<LogicalOptimizer>,
234) -> Result<LogicalPlan> {
235    for mut analyzer in analyzers {
236        let old_schema = plan.schema().clone();
237        plan = analyzer(plan)?;
238        let new_schema = plan.schema();
239        if &old_schema != new_schema {
240            return Err(DataFusionError::Execution(format!(
241                "Schema altered during logical analysis, expected: {}, found: {}",
242                old_schema, new_schema
243            )));
244        }
245    }
246    Ok(plan)
247}
248
249fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) -> Result<Statement> {
250    for mut analyzer in analyzers {
251        statement = analyzer(statement)?;
252    }
253    Ok(statement)
254}
255
256impl DisplayAs for VirtualExecutionPlan {
257    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
258        write!(f, "VirtualExecutionPlan")?;
259        write!(f, " name={}", self.executor.name())?;
260        if let Some(ctx) = self.executor.compute_context() {
261            write!(f, " compute_context={ctx}")?;
262        };
263        let mut plan = self.plan.clone();
264        if let Ok(statement) = self.plan_to_statement(&plan) {
265            write!(f, " initial_sql={statement}")?;
266        }
267
268        let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) {
269            Ok(analyzers) => analyzers,
270            Err(_) => return Ok(()),
271        };
272
273        let old_plan = plan.clone();
274
275        plan = match apply_logical_optimizers(plan, logical_optimizers) {
276            Ok(plan) => plan,
277            _ => return Ok(()),
278        };
279
280        let statement = match self.plan_to_statement(&plan) {
281            Ok(statement) => statement,
282            _ => return Ok(()),
283        };
284
285        if plan != old_plan {
286            write!(f, " rewritten_logical_sql={statement}")?;
287        }
288
289        let old_statement = statement.clone();
290        let statement = match self.rewrite_with_executor_ast_analyzer(statement) {
291            Ok(statement) => statement,
292            _ => return Ok(()),
293        };
294        if old_statement != statement {
295            write!(f, " rewritten_executor_sql={statement}")?;
296        }
297
298        let old_statement = statement.clone();
299        let statement = match apply_ast_analyzers(statement, ast_analyzers) {
300            Ok(statement) => statement,
301            _ => return Ok(()),
302        };
303        if old_statement != statement {
304            write!(f, " rewritten_ast_analyzer={statement}")?;
305        }
306
307        Ok(())
308    }
309}
310
311impl ExecutionPlan for VirtualExecutionPlan {
312    fn name(&self) -> &str {
313        "sql_federation_exec"
314    }
315
316    fn as_any(&self) -> &dyn Any {
317        self
318    }
319
320    fn schema(&self) -> SchemaRef {
321        self.schema()
322    }
323
324    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
325        vec![]
326    }
327
328    fn with_new_children(
329        self: Arc<Self>,
330        _: Vec<Arc<dyn ExecutionPlan>>,
331    ) -> Result<Arc<dyn ExecutionPlan>> {
332        Ok(self)
333    }
334
335    fn execute(
336        &self,
337        _partition: usize,
338        _context: Arc<TaskContext>,
339    ) -> Result<SendableRecordBatchStream> {
340        self.executor.execute(&self.final_sql()?, self.schema())
341    }
342
343    fn properties(&self) -> &PlanProperties {
344        &self.props
345    }
346}
347
348#[cfg(test)]
349mod tests {
350
351    use std::collections::HashSet;
352    use std::sync::Arc;
353
354    use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
355    use crate::FederatedTableProviderAdaptor;
356    use async_trait::async_trait;
357    use datafusion::arrow::datatypes::{Schema, SchemaRef};
358    use datafusion::common::tree_node::TreeNodeRecursion;
359    use datafusion::execution::SendableRecordBatchStream;
360    use datafusion::sql::unparser::dialect::Dialect;
361    use datafusion::sql::unparser::{self};
362    use datafusion::{
363        arrow::datatypes::{DataType, Field},
364        datasource::TableProvider,
365        execution::context::SessionContext,
366    };
367
368    use super::table::RemoteTable;
369    use super::*;
370
371    #[derive(Debug, Clone)]
372    struct TestExecutor {
373        compute_context: String,
374    }
375
376    #[async_trait]
377    impl SQLExecutor for TestExecutor {
378        fn name(&self) -> &str {
379            "TestExecutor"
380        }
381
382        fn compute_context(&self) -> Option<String> {
383            Some(self.compute_context.clone())
384        }
385
386        fn dialect(&self) -> Arc<dyn Dialect> {
387            Arc::new(unparser::dialect::DefaultDialect {})
388        }
389
390        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
391            unimplemented!()
392        }
393
394        async fn table_names(&self) -> Result<Vec<String>> {
395            unimplemented!()
396        }
397
398        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
399            unimplemented!()
400        }
401    }
402
403    fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc<dyn TableProvider> {
404        let schema = Arc::new(Schema::new(vec![
405            Field::new("a", DataType::Int64, false),
406            Field::new("b", DataType::Utf8, false),
407            Field::new("c", DataType::Date32, false),
408        ]));
409        let table_ref = RemoteTableRef::try_from(name).unwrap();
410        let table = Arc::new(RemoteTable::new(table_ref, schema));
411        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
412        let table_source = Arc::new(SQLTableSource { provider, table });
413        Arc::new(FederatedTableProviderAdaptor::new(table_source))
414    }
415
416    #[tokio::test]
417    async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
418        let test_executor_a = TestExecutor {
419            compute_context: "a".into(),
420        };
421
422        let test_executor_b = TestExecutor {
423            compute_context: "b".into(),
424        };
425
426        let table_a1_ref = "table_a1".to_string();
427        let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());
428
429        let table_a2_ref = "table_a2".to_string();
430        let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a);
431
432        let table_b1_ref = "table_b1(1)".to_string();
433        let table_b1_df_ref = "table_local_b1".to_string();
434
435        let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b);
436
437        // Create a new SessionState with the optimizer rule we created above
438        let state = crate::default_session_state();
439        let ctx = SessionContext::new_with_state(state);
440
441        ctx.register_table(table_a1_ref.clone(), table_a1).unwrap();
442        ctx.register_table(table_a2_ref.clone(), table_a2).unwrap();
443        ctx.register_table(table_b1_df_ref.clone(), table_b1)
444            .unwrap();
445
446        let query = r#"
447            SELECT * FROM table_a1
448            UNION ALL
449            SELECT * FROM table_a2
450            UNION ALL
451            SELECT * FROM table_local_b1;
452        "#;
453
454        let df = ctx.sql(query).await?;
455
456        let logical_plan = df.into_optimized_plan()?;
457
458        let mut table_a1_federated = false;
459        let mut table_a2_federated = false;
460        let mut table_b1_federated = false;
461
462        let _ = logical_plan.apply(|node| {
463            if let LogicalPlan::Extension(node) = node {
464                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
465                    let _ = node.plan().apply(|node| {
466                        if let LogicalPlan::TableScan(table) = node {
467                            if table.table_name.table() == table_a1_ref {
468                                table_a1_federated = true;
469                            }
470                            if table.table_name.table() == table_a2_ref {
471                                table_a2_federated = true;
472                            }
473                            // assuming table name is rewritten via analyzer
474                            if table.table_name.table() == table_b1_df_ref {
475                                table_b1_federated = true;
476                            }
477                        }
478                        Ok(TreeNodeRecursion::Continue)
479                    });
480                }
481            }
482            Ok(TreeNodeRecursion::Continue)
483        });
484
485        assert!(table_a1_federated);
486        assert!(table_a2_federated);
487        assert!(table_b1_federated);
488
489        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;
490
491        let mut final_queries = vec![];
492
493        let _ = physical_plan.apply(|node| {
494            if node.name() == "sql_federation_exec" {
495                let node = node
496                    .as_any()
497                    .downcast_ref::<VirtualExecutionPlan>()
498                    .unwrap();
499
500                final_queries.push(node.final_sql()?);
501            }
502            Ok(TreeNodeRecursion::Continue)
503        });
504
505        let expected = vec![
506            "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1",
507            "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2",
508            "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1",
509        ];
510
511        assert_eq!(
512            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
513            HashSet::from_iter(expected)
514        );
515
516        Ok(())
517    }
518
519    #[tokio::test]
520    async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
521        let test_executor_a = TestExecutor {
522            compute_context: "test".into(),
523        };
524
525        let lowercase_table_ref = "default.table".to_string();
526        let lowercase_local_table_ref = "dftable".to_string();
527        let lowercase_table =
528            get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone());
529
530        let capitalized_table_ref = "default.Table(1)".to_string();
531        let capitalized_local_table_ref = "dfview".to_string();
532        let capitalized_table =
533            get_test_table_provider(capitalized_table_ref.clone(), test_executor_a);
534
535        // Create a new SessionState with the optimizer rule we created above
536        let state = crate::default_session_state();
537        let ctx = SessionContext::new_with_state(state);
538
539        ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table)
540            .unwrap();
541        ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table)
542            .unwrap();
543
544        let query = r#"
545                SELECT * FROM dftable
546                UNION ALL
547                SELECT * FROM dfview;
548            "#;
549
550        let df = ctx.sql(query).await?;
551
552        let logical_plan = df.into_optimized_plan()?;
553
554        let mut lowercase_table = false;
555        let mut capitalized_table = false;
556
557        let _ = logical_plan.apply(|node| {
558            if let LogicalPlan::Extension(node) = node {
559                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
560                    let _ = node.plan().apply(|node| {
561                        if let LogicalPlan::TableScan(table) = node {
562                            if table.table_name.table() == lowercase_local_table_ref {
563                                lowercase_table = true;
564                            }
565                            if table.table_name.table() == capitalized_local_table_ref {
566                                capitalized_table = true;
567                            }
568                        }
569                        Ok(TreeNodeRecursion::Continue)
570                    });
571                }
572            }
573            Ok(TreeNodeRecursion::Continue)
574        });
575
576        assert!(lowercase_table);
577        assert!(capitalized_table);
578
579        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;
580
581        let mut final_queries = vec![];
582
583        let _ = physical_plan.apply(|node| {
584            if node.name() == "sql_federation_exec" {
585                let node = node
586                    .as_any()
587                    .downcast_ref::<VirtualExecutionPlan>()
588                    .unwrap();
589
590                final_queries.push(node.final_sql()?);
591            }
592            Ok(TreeNodeRecursion::Continue)
593        });
594
595        let expected = vec![
596            r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1) AS Table"#,
597        ];
598
599        assert_eq!(
600            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
601            HashSet::from_iter(expected)
602        );
603
604        Ok(())
605    }
606}