datafusion_federation/sql/
mod.rs

1mod analyzer;
2pub mod ast_analyzer;
3mod executor;
4mod schema;
5mod table;
6mod table_reference;
7
8use std::{any::Any, fmt, sync::Arc, vec};
9
10use analyzer::RewriteTableScanAnalyzer;
11use async_trait::async_trait;
12use datafusion::{
13    arrow::datatypes::{Schema, SchemaRef},
14    common::tree_node::{Transformed, TreeNode},
15    common::Statistics,
16    error::{DataFusionError, Result},
17    execution::{context::SessionState, TaskContext},
18    logical_expr::{Extension, LogicalPlan},
19    optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
20    physical_expr::EquivalenceProperties,
21    physical_plan::{
22        execution_plan::{Boundedness, EmissionType},
23        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
24        SendableRecordBatchStream,
25    },
26    sql::{sqlparser::ast::Statement, unparser::Unparser},
27};
28
29pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef};
30pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
31pub use table::{RemoteTable, SQLTable, SQLTableSource};
32pub use table_reference::RemoteTableRef;
33
34use crate::{
35    get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
36};
37
38// SQLFederationProvider provides federation to SQL DMBSs.
39#[derive(Debug)]
40pub struct SQLFederationProvider {
41    pub optimizer: Arc<Optimizer>,
42    pub executor: Arc<dyn SQLExecutor>,
43}
44
45impl SQLFederationProvider {
46    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
47        Self {
48            optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
49                SQLFederationOptimizerRule::new(executor.clone()),
50            )])),
51            executor,
52        }
53    }
54}
55
56impl FederationProvider for SQLFederationProvider {
57    fn name(&self) -> &str {
58        "sql_federation_provider"
59    }
60
61    fn compute_context(&self) -> Option<String> {
62        self.executor.compute_context()
63    }
64
65    fn optimizer(&self) -> Option<Arc<Optimizer>> {
66        Some(self.optimizer.clone())
67    }
68}
69
70#[derive(Debug)]
71struct SQLFederationOptimizerRule {
72    planner: Arc<SQLFederationPlanner>,
73}
74
75impl SQLFederationOptimizerRule {
76    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
77        Self {
78            planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
79        }
80    }
81}
82
83impl OptimizerRule for SQLFederationOptimizerRule {
84    /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes`
85    /// if the plan was rewritten and `Transformed::no` if it was not.
86    ///
87    /// Note: this function is only called if [`Self::supports_rewrite`] returns
88    /// true. Otherwise the Optimizer calls  [`Self::try_optimize`]
89    fn rewrite(
90        &self,
91        plan: LogicalPlan,
92        _config: &dyn OptimizerConfig,
93    ) -> Result<Transformed<LogicalPlan>> {
94        if let LogicalPlan::Extension(Extension { ref node }) = plan {
95            if node.name() == "Federated" {
96                // Avoid attempting double federation
97                return Ok(Transformed::no(plan));
98            }
99        }
100
101        let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
102        let ext_node = Extension {
103            node: Arc::new(fed_plan),
104        };
105
106        let mut plan = LogicalPlan::Extension(ext_node);
107        if let Some(mut rewriter) = self.planner.executor.logical_optimizer() {
108            plan = rewriter(plan)?;
109        }
110
111        Ok(Transformed::yes(plan))
112    }
113
114    /// A human readable name for this analyzer rule
115    fn name(&self) -> &str {
116        "federate_sql"
117    }
118
119    /// Does this rule support rewriting owned plans (rather than by reference)?
120    fn supports_rewrite(&self) -> bool {
121        true
122    }
123}
124
125#[derive(Debug)]
126pub struct SQLFederationPlanner {
127    pub executor: Arc<dyn SQLExecutor>,
128}
129
130impl SQLFederationPlanner {
131    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
132        Self { executor }
133    }
134}
135
136#[async_trait]
137impl FederationPlanner for SQLFederationPlanner {
138    async fn plan_federation(
139        &self,
140        node: &FederatedPlanNode,
141        _session_state: &SessionState,
142    ) -> Result<Arc<dyn ExecutionPlan>> {
143        let schema = Arc::new(node.plan().schema().as_arrow().clone());
144        let plan = node.plan().clone();
145        let statistics = self.executor.statistics(&plan).await?;
146        let input = Arc::new(VirtualExecutionPlan::new(
147            plan,
148            Arc::clone(&self.executor),
149            statistics,
150        ));
151        let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
152        Ok(Arc::new(schema_cast_exec))
153    }
154}
155
156#[derive(Debug, Clone)]
157pub struct VirtualExecutionPlan {
158    plan: LogicalPlan,
159    executor: Arc<dyn SQLExecutor>,
160    props: PlanProperties,
161    statistics: Statistics,
162}
163
164impl VirtualExecutionPlan {
165    pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>, statistics: Statistics) -> Self {
166        let schema: Schema = plan.schema().as_ref().into();
167        let props = PlanProperties::new(
168            EquivalenceProperties::new(Arc::new(schema)),
169            Partitioning::UnknownPartitioning(1),
170            EmissionType::Incremental,
171            Boundedness::Bounded,
172        );
173        Self {
174            plan,
175            executor,
176            props,
177            statistics,
178        }
179    }
180
181    pub fn plan(&self) -> &LogicalPlan {
182        &self.plan
183    }
184
185    pub fn executor(&self) -> &Arc<dyn SQLExecutor> {
186        &self.executor
187    }
188
189    pub fn statistics(&self) -> &Statistics {
190        &self.statistics
191    }
192
193    fn schema(&self) -> SchemaRef {
194        let df_schema = self.plan.schema().as_ref();
195        Arc::new(Schema::from(df_schema))
196    }
197
198    fn final_sql(&self) -> Result<String> {
199        let plan = self.plan.clone();
200        let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
201        let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?;
202        let plan = apply_logical_optimizers(plan, logical_optimizers)?;
203        let ast = self.plan_to_statement(&plan)?;
204        let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
205        let ast = apply_ast_analyzers(ast, ast_analyzers)?;
206        Ok(ast.to_string())
207    }
208
209    fn rewrite_with_executor_ast_analyzer(
210        &self,
211        ast: Statement,
212    ) -> Result<Statement, datafusion::error::DataFusionError> {
213        if let Some(mut analyzer) = self.executor.ast_analyzer() {
214            Ok(analyzer(ast)?)
215        } else {
216            Ok(ast)
217        }
218    }
219
220    fn plan_to_statement(&self, plan: &LogicalPlan) -> Result<Statement> {
221        Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
222    }
223}
224
225fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<AstAnalyzer>)> {
226    let mut logical_optimizers = vec![];
227    let mut ast_analyzers = vec![];
228
229    plan.apply(|node| {
230        if let LogicalPlan::TableScan(table) = node {
231            let provider = get_table_source(&table.source)
232                .expect("caller is virtual exec so this is valid")
233                .expect("caller is virtual exec so this is valid");
234            if let Some(source) = provider.as_any().downcast_ref::<SQLTableSource>() {
235                if let Some(analyzer) = source.table.logical_optimizer() {
236                    logical_optimizers.push(analyzer);
237                }
238                if let Some(analyzer) = source.table.ast_analyzer() {
239                    ast_analyzers.push(analyzer);
240                }
241            }
242        }
243        Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
244    })?;
245
246    Ok((logical_optimizers, ast_analyzers))
247}
248
249fn apply_logical_optimizers(
250    mut plan: LogicalPlan,
251    analyzers: Vec<LogicalOptimizer>,
252) -> Result<LogicalPlan> {
253    for mut analyzer in analyzers {
254        let old_schema = plan.schema().clone();
255        plan = analyzer(plan)?;
256        let new_schema = plan.schema();
257        if &old_schema != new_schema {
258            return Err(DataFusionError::Execution(format!(
259                "Schema altered during logical analysis, expected: {}, found: {}",
260                old_schema, new_schema
261            )));
262        }
263    }
264    Ok(plan)
265}
266
267fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) -> Result<Statement> {
268    for mut analyzer in analyzers {
269        statement = analyzer(statement)?;
270    }
271    Ok(statement)
272}
273
274impl DisplayAs for VirtualExecutionPlan {
275    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
276        write!(f, "VirtualExecutionPlan")?;
277        write!(f, " name={}", self.executor.name())?;
278        if let Some(ctx) = self.executor.compute_context() {
279            write!(f, " compute_context={ctx}")?;
280        };
281        let mut plan = self.plan.clone();
282        if let Ok(statement) = self.plan_to_statement(&plan) {
283            write!(f, " initial_sql={statement}")?;
284        }
285
286        let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) {
287            Ok(analyzers) => analyzers,
288            Err(_) => return Ok(()),
289        };
290
291        let old_plan = plan.clone();
292
293        plan = match apply_logical_optimizers(plan, logical_optimizers) {
294            Ok(plan) => plan,
295            _ => return Ok(()),
296        };
297
298        let statement = match self.plan_to_statement(&plan) {
299            Ok(statement) => statement,
300            _ => return Ok(()),
301        };
302
303        if plan != old_plan {
304            write!(f, " rewritten_logical_sql={statement}")?;
305        }
306
307        let old_statement = statement.clone();
308        let statement = match self.rewrite_with_executor_ast_analyzer(statement) {
309            Ok(statement) => statement,
310            _ => return Ok(()),
311        };
312        if old_statement != statement {
313            write!(f, " rewritten_executor_sql={statement}")?;
314        }
315
316        let old_statement = statement.clone();
317        let statement = match apply_ast_analyzers(statement, ast_analyzers) {
318            Ok(statement) => statement,
319            _ => return Ok(()),
320        };
321        if old_statement != statement {
322            write!(f, " rewritten_ast_analyzer={statement}")?;
323        }
324
325        Ok(())
326    }
327}
328
329impl ExecutionPlan for VirtualExecutionPlan {
330    fn name(&self) -> &str {
331        "sql_federation_exec"
332    }
333
334    fn as_any(&self) -> &dyn Any {
335        self
336    }
337
338    fn schema(&self) -> SchemaRef {
339        self.schema()
340    }
341
342    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
343        vec![]
344    }
345
346    fn with_new_children(
347        self: Arc<Self>,
348        _: Vec<Arc<dyn ExecutionPlan>>,
349    ) -> Result<Arc<dyn ExecutionPlan>> {
350        Ok(self)
351    }
352
353    fn execute(
354        &self,
355        _partition: usize,
356        _context: Arc<TaskContext>,
357    ) -> Result<SendableRecordBatchStream> {
358        self.executor.execute(&self.final_sql()?, self.schema())
359    }
360
361    fn properties(&self) -> &PlanProperties {
362        &self.props
363    }
364
365    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
366        Ok(self.statistics.clone())
367    }
368}
369
#[cfg(test)]
mod tests {

    use std::collections::HashSet;
    use std::sync::Arc;

    use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
    use crate::FederatedTableProviderAdaptor;
    use async_trait::async_trait;
    use datafusion::arrow::datatypes::{Schema, SchemaRef};
    use datafusion::common::tree_node::TreeNodeRecursion;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::sql::unparser::dialect::Dialect;
    use datafusion::sql::unparser::{self};
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        datasource::TableProvider,
        execution::context::SessionContext,
    };

    use super::table::RemoteTable;
    use super::*;

    /// Executor stub: only the name, compute context, and dialect are
    /// implemented; the tests never execute SQL or list tables.
    #[derive(Debug, Clone)]
    struct TestExecutor {
        compute_context: String,
    }

    #[async_trait]
    impl SQLExecutor for TestExecutor {
        fn name(&self) -> &str {
            "TestExecutor"
        }

        fn compute_context(&self) -> Option<String> {
            Some(self.compute_context.clone())
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(unparser::dialect::DefaultDialect {})
        }

        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
            unimplemented!()
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            unimplemented!()
        }
    }

    /// Builds a federated table provider for the remote table `name`, with a
    /// fixed three-column schema (`a`, `b`, `c`), backed by `executor`.
    fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc<dyn TableProvider> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_ref = RemoteTableRef::try_from(name).unwrap();
        let table = Arc::new(RemoteTable::new(table_ref, schema));
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource { provider, table });
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    /// Returns the table names of all table scans found inside the
    /// `FederatedPlanNode` extension nodes of `logical_plan`.
    ///
    /// Traversal errors are propagated instead of being discarded.
    fn federated_table_names(logical_plan: &LogicalPlan) -> Result<HashSet<String>> {
        let mut names = HashSet::new();

        logical_plan.apply(|node| {
            if let LogicalPlan::Extension(extension) = node {
                if let Some(federated) = extension
                    .node
                    .as_any()
                    .downcast_ref::<FederatedPlanNode>()
                {
                    // The wrapped plan is walked explicitly, as the outer
                    // traversal is not relied upon to enter it.
                    federated.plan().apply(|inner| {
                        if let LogicalPlan::TableScan(scan) = inner {
                            names.insert(scan.table_name.table().to_string());
                        }
                        Ok(TreeNodeRecursion::Continue)
                    })?;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        Ok(names)
    }

    /// Returns the final SQL of every `sql_federation_exec` node in
    /// `physical_plan`.
    ///
    /// A failure to generate SQL is returned to the caller rather than
    /// silently ending the traversal.
    fn collect_final_queries(physical_plan: &Arc<dyn ExecutionPlan>) -> Result<Vec<String>> {
        let mut final_queries = vec![];

        physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let exec = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .expect("a node named sql_federation_exec is a VirtualExecutionPlan");

                final_queries.push(exec.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        Ok(final_queries)
    }

    #[tokio::test]
    async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "a".into(),
        };

        let test_executor_b = TestExecutor {
            compute_context: "b".into(),
        };

        let table_a1_ref = "table_a1".to_string();
        let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());

        let table_a2_ref = "table_a2".to_string();
        let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a);

        let table_b1_ref = "table_b1(1)".to_string();
        let table_b1_df_ref = "table_local_b1".to_string();

        let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b);

        // Use the crate's default session state for the context.
        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(table_a1_ref.clone(), table_a1).unwrap();
        ctx.register_table(table_a2_ref.clone(), table_a2).unwrap();
        ctx.register_table(table_b1_df_ref.clone(), table_b1)
            .unwrap();

        let query = r#"
            SELECT * FROM table_a1
            UNION ALL
            SELECT * FROM table_a2
            UNION ALL
            SELECT * FROM table_local_b1;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let federated = federated_table_names(&logical_plan)?;

        assert!(federated.contains(table_a1_ref.as_str()));
        assert!(federated.contains(table_a2_ref.as_str()));
        // In the logical plan the scan still carries the locally registered
        // name; the remote name only shows up in the final SQL below.
        assert!(federated.contains(table_b1_df_ref.as_str()));

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let final_queries = collect_final_queries(&physical_plan)?;

        let expected = vec![
            "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1",
            "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2",
            "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1",
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }

    #[tokio::test]
    async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "test".into(),
        };

        let lowercase_table_ref = "default.table".to_string();
        let lowercase_local_table_ref = "dftable".to_string();
        let lowercase_table =
            get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone());

        let capitalized_table_ref = "default.Table(1)".to_string();
        let capitalized_local_table_ref = "dfview".to_string();
        let capitalized_table =
            get_test_table_provider(capitalized_table_ref.clone(), test_executor_a);

        // Use the crate's default session state for the context.
        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table)
            .unwrap();
        ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table)
            .unwrap();

        let query = r#"
                SELECT * FROM dftable
                UNION ALL
                SELECT * FROM dfview;
            "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let federated = federated_table_names(&logical_plan)?;

        assert!(federated.contains(lowercase_local_table_ref.as_str()));
        assert!(federated.contains(capitalized_local_table_ref.as_str()));

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let final_queries = collect_final_queries(&physical_plan)?;

        let expected = vec![
            r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1) AS Table"#,
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }
}