datafusion_federation/sql/
mod.rs

1mod analyzer;
2pub mod ast_analyzer;
3mod executor;
4mod schema;
5mod table;
6mod table_reference;
7
8use std::{any::Any, fmt, sync::Arc, vec};
9
10use analyzer::RewriteTableScanAnalyzer;
11use async_trait::async_trait;
12use datafusion::{
13    arrow::datatypes::{Schema, SchemaRef},
14    common::{
15        tree_node::{Transformed, TreeNode},
16        Statistics,
17    },
18    error::{DataFusionError, Result},
19    execution::{context::SessionState, TaskContext},
20    logical_expr::{Extension, LogicalPlan},
21    optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
22    physical_expr::EquivalenceProperties,
23    physical_plan::{
24        execution_plan::{Boundedness, EmissionType},
25        metrics::MetricsSet,
26        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
27        SendableRecordBatchStream,
28    },
29    sql::{sqlparser::ast::Statement, unparser::Unparser},
30};
31
32pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef};
33pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
34pub use table::{RemoteTable, SQLTable, SQLTableSource};
35pub use table_reference::RemoteTableRef;
36
37use crate::{
38    get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
39};
40
41// SQLFederationProvider provides federation to SQL DMBSs.
42#[derive(Debug)]
43pub struct SQLFederationProvider {
44    pub optimizer: Arc<Optimizer>,
45    pub executor: Arc<dyn SQLExecutor>,
46}
47
48impl SQLFederationProvider {
49    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
50        Self {
51            optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
52                SQLFederationOptimizerRule::new(executor.clone()),
53            )])),
54            executor,
55        }
56    }
57}
58
59impl FederationProvider for SQLFederationProvider {
60    fn name(&self) -> &str {
61        "sql_federation_provider"
62    }
63
64    fn compute_context(&self) -> Option<String> {
65        self.executor.compute_context()
66    }
67
68    fn optimizer(&self) -> Option<Arc<Optimizer>> {
69        Some(self.optimizer.clone())
70    }
71}
72
73#[derive(Debug)]
74struct SQLFederationOptimizerRule {
75    planner: Arc<SQLFederationPlanner>,
76}
77
78impl SQLFederationOptimizerRule {
79    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
80        Self {
81            planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
82        }
83    }
84}
85
86impl OptimizerRule for SQLFederationOptimizerRule {
87    /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes`
88    /// if the plan was rewritten and `Transformed::no` if it was not.
89    ///
90    /// Note: this function is only called if [`Self::supports_rewrite`] returns
91    /// true. Otherwise the Optimizer calls  [`Self::try_optimize`]
92    fn rewrite(
93        &self,
94        plan: LogicalPlan,
95        _config: &dyn OptimizerConfig,
96    ) -> Result<Transformed<LogicalPlan>> {
97        if let LogicalPlan::Extension(Extension { ref node }) = plan {
98            if node.name() == "Federated" {
99                // Avoid attempting double federation
100                return Ok(Transformed::no(plan));
101            }
102        }
103
104        let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
105        let ext_node = Extension {
106            node: Arc::new(fed_plan),
107        };
108
109        let mut plan = LogicalPlan::Extension(ext_node);
110        if let Some(mut rewriter) = self.planner.executor.logical_optimizer() {
111            plan = rewriter(plan)?;
112        }
113
114        Ok(Transformed::yes(plan))
115    }
116
117    /// A human readable name for this analyzer rule
118    fn name(&self) -> &str {
119        "federate_sql"
120    }
121
122    /// Does this rule support rewriting owned plans (rather than by reference)?
123    fn supports_rewrite(&self) -> bool {
124        true
125    }
126}
127
128#[derive(Debug)]
129pub struct SQLFederationPlanner {
130    pub executor: Arc<dyn SQLExecutor>,
131}
132
133impl SQLFederationPlanner {
134    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
135        Self { executor }
136    }
137}
138
139#[async_trait]
140impl FederationPlanner for SQLFederationPlanner {
141    async fn plan_federation(
142        &self,
143        node: &FederatedPlanNode,
144        _session_state: &SessionState,
145    ) -> Result<Arc<dyn ExecutionPlan>> {
146        let schema = Arc::new(node.plan().schema().as_arrow().clone());
147        let plan = node.plan().clone();
148        let statistics = self.executor.statistics(&plan).await?;
149        let input = Arc::new(VirtualExecutionPlan::new(
150            plan,
151            Arc::clone(&self.executor),
152            statistics,
153        ));
154        let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
155        Ok(Arc::new(schema_cast_exec))
156    }
157}
158
/// Leaf [`ExecutionPlan`] (it has no children) that runs a federated
/// [`LogicalPlan`] by unparsing it to SQL and handing that SQL to a
/// [`SQLExecutor`].
#[derive(Debug, Clone)]
pub struct VirtualExecutionPlan {
    /// Logical plan that is unparsed into the SQL sent to `executor`.
    plan: LogicalPlan,
    /// Executor that runs the generated SQL.
    executor: Arc<dyn SQLExecutor>,
    /// Plan properties computed once in `new` (schema, one unknown partition,
    /// incremental emission, bounded).
    props: PlanProperties,
    /// Statistics supplied at construction and reported as-is.
    statistics: Statistics,
}
166
167impl VirtualExecutionPlan {
168    pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>, statistics: Statistics) -> Self {
169        let schema: Schema = plan.schema().as_arrow().clone();
170        let props = PlanProperties::new(
171            EquivalenceProperties::new(Arc::new(schema)),
172            Partitioning::UnknownPartitioning(1),
173            EmissionType::Incremental,
174            Boundedness::Bounded,
175        );
176        Self {
177            plan,
178            executor,
179            props,
180            statistics,
181        }
182    }
183
184    pub fn plan(&self) -> &LogicalPlan {
185        &self.plan
186    }
187
188    pub fn executor(&self) -> &Arc<dyn SQLExecutor> {
189        &self.executor
190    }
191
192    pub fn statistics(&self) -> &Statistics {
193        &self.statistics
194    }
195
196    fn schema(&self) -> SchemaRef {
197        let df_schema = self.plan.schema().as_arrow().clone();
198        Arc::new(df_schema)
199    }
200
201    fn final_sql(&self) -> Result<String> {
202        let plan = self.plan.clone();
203        let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
204        let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?;
205        let plan = apply_logical_optimizers(plan, logical_optimizers)?;
206        let ast = self.plan_to_statement(&plan)?;
207        let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
208        let ast = apply_ast_analyzers(ast, ast_analyzers)?;
209        Ok(ast.to_string())
210    }
211
212    fn rewrite_with_executor_ast_analyzer(
213        &self,
214        ast: Statement,
215    ) -> Result<Statement, datafusion::error::DataFusionError> {
216        if let Some(mut analyzer) = self.executor.ast_analyzer() {
217            Ok(analyzer(ast)?)
218        } else {
219            Ok(ast)
220        }
221    }
222
223    fn plan_to_statement(&self, plan: &LogicalPlan) -> Result<Statement> {
224        Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
225    }
226}
227
228fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<AstAnalyzer>)> {
229    let mut logical_optimizers = vec![];
230    let mut ast_analyzers = vec![];
231
232    plan.apply(|node| {
233        if let LogicalPlan::TableScan(table) = node {
234            let provider = get_table_source(&table.source)
235                .expect("caller is virtual exec so this is valid")
236                .expect("caller is virtual exec so this is valid");
237            if let Some(source) = provider.as_any().downcast_ref::<SQLTableSource>() {
238                if let Some(analyzer) = source.table.logical_optimizer() {
239                    logical_optimizers.push(analyzer);
240                }
241                if let Some(analyzer) = source.table.ast_analyzer() {
242                    ast_analyzers.push(analyzer);
243                }
244            }
245        }
246        Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
247    })?;
248
249    Ok((logical_optimizers, ast_analyzers))
250}
251
252fn apply_logical_optimizers(
253    mut plan: LogicalPlan,
254    analyzers: Vec<LogicalOptimizer>,
255) -> Result<LogicalPlan> {
256    for mut analyzer in analyzers {
257        let old_schema = plan.schema().clone();
258        plan = analyzer(plan)?;
259        let new_schema = plan.schema();
260        if &old_schema != new_schema {
261            return Err(DataFusionError::Execution(format!(
262                "Schema altered during logical analysis, expected: {}, found: {}",
263                old_schema, new_schema
264            )));
265        }
266    }
267    Ok(plan)
268}
269
270fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) -> Result<Statement> {
271    for mut analyzer in analyzers {
272        statement = analyzer(statement)?;
273    }
274    Ok(statement)
275}
276
277impl DisplayAs for VirtualExecutionPlan {
278    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
279        write!(f, "VirtualExecutionPlan")?;
280        write!(f, " name={}", self.executor.name())?;
281        if let Some(ctx) = self.executor.compute_context() {
282            write!(f, " compute_context={ctx}")?;
283        };
284        let mut plan = self.plan.clone();
285        if let Ok(statement) = self.plan_to_statement(&plan) {
286            write!(f, " initial_sql={statement}")?;
287        }
288
289        let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) {
290            Ok(analyzers) => analyzers,
291            Err(_) => return Ok(()),
292        };
293
294        let old_plan = plan.clone();
295
296        plan = match apply_logical_optimizers(plan, logical_optimizers) {
297            Ok(plan) => plan,
298            _ => return Ok(()),
299        };
300
301        let statement = match self.plan_to_statement(&plan) {
302            Ok(statement) => statement,
303            _ => return Ok(()),
304        };
305
306        if plan != old_plan {
307            write!(f, " rewritten_logical_sql={statement}")?;
308        }
309
310        let old_statement = statement.clone();
311        let statement = match self.rewrite_with_executor_ast_analyzer(statement) {
312            Ok(statement) => statement,
313            _ => return Ok(()),
314        };
315        if old_statement != statement {
316            write!(f, " rewritten_executor_sql={statement}")?;
317        }
318
319        let old_statement = statement.clone();
320        let statement = match apply_ast_analyzers(statement, ast_analyzers) {
321            Ok(statement) => statement,
322            _ => return Ok(()),
323        };
324        if old_statement != statement {
325            write!(f, " rewritten_ast_analyzer={statement}")?;
326        }
327
328        Ok(())
329    }
330}
331
332impl ExecutionPlan for VirtualExecutionPlan {
333    fn name(&self) -> &str {
334        "sql_federation_exec"
335    }
336
337    fn as_any(&self) -> &dyn Any {
338        self
339    }
340
341    fn schema(&self) -> SchemaRef {
342        self.schema()
343    }
344
345    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
346        vec![]
347    }
348
349    fn with_new_children(
350        self: Arc<Self>,
351        _: Vec<Arc<dyn ExecutionPlan>>,
352    ) -> Result<Arc<dyn ExecutionPlan>> {
353        Ok(self)
354    }
355
356    fn execute(
357        &self,
358        _partition: usize,
359        _context: Arc<TaskContext>,
360    ) -> Result<SendableRecordBatchStream> {
361        self.executor.execute(&self.final_sql()?, self.schema())
362    }
363
364    fn properties(&self) -> &PlanProperties {
365        &self.props
366    }
367
368    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
369        Ok(self.statistics.clone())
370    }
371
372    fn metrics(&self) -> Option<MetricsSet> {
373        self.executor.metrics()
374    }
375}
376
#[cfg(test)]
mod tests {

    use std::collections::HashSet;
    use std::sync::Arc;

    use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
    use crate::FederatedTableProviderAdaptor;
    use async_trait::async_trait;
    use datafusion::arrow::datatypes::{Schema, SchemaRef};
    use datafusion::common::tree_node::TreeNodeRecursion;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::sql::unparser::dialect::Dialect;
    use datafusion::sql::unparser::{self};
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        datasource::TableProvider,
        execution::context::SessionContext,
    };

    use super::table::RemoteTable;
    use super::*;

    /// Planning-only [`SQLExecutor`]: it reports a name, compute context and the
    /// default dialect, and panics if asked to actually run or list anything.
    #[derive(Debug, Clone)]
    struct TestExecutor {
        compute_context: String,
    }

    #[async_trait]
    impl SQLExecutor for TestExecutor {
        fn name(&self) -> &str {
            "TestExecutor"
        }

        fn compute_context(&self) -> Option<String> {
            Some(self.compute_context.clone())
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(unparser::dialect::DefaultDialect {})
        }

        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
            unimplemented!()
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            unimplemented!()
        }
    }

    /// Builds a federated table for the remote reference `name` with columns
    /// `a: Int64`, `b: Utf8`, `c: Date32`, backed by `executor`.
    fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc<dyn TableProvider> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_ref = RemoteTableRef::try_from(name).unwrap();
        let table = Arc::new(RemoteTable::new(table_ref, schema));
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource { provider, table });
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    #[tokio::test]
    async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "a".into(),
        };

        let test_executor_b = TestExecutor {
            compute_context: "b".into(),
        };

        let table_a1_ref = "table_a1".to_string();
        let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());

        let table_a2_ref = "table_a2".to_string();
        let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a);

        let table_b1_ref = "table_b1(1)".to_string();
        let table_b1_df_ref = "table_local_b1".to_string();

        let table_b1 = get_test_table_provider(table_b1_ref, test_executor_b);

        // Start from the crate's default session state, which is expected to
        // install the federation rules.
        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(table_a1_ref.clone(), table_a1).unwrap();
        ctx.register_table(table_a2_ref.clone(), table_a2).unwrap();
        ctx.register_table(table_b1_df_ref.clone(), table_b1)
            .unwrap();

        let query = r#"
            SELECT * FROM table_a1
            UNION ALL
            SELECT * FROM table_a2
            UNION ALL
            SELECT * FROM table_local_b1;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut table_a1_federated = false;
        let mut table_a2_federated = false;
        let mut table_b1_federated = false;

        // Propagate traversal errors instead of discarding them.
        logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == table_a1_ref {
                                table_a1_federated = true;
                            }
                            if table.table_name.table() == table_a2_ref {
                                table_a2_federated = true;
                            }
                            // assuming table name is rewritten via analyzer
                            if table.table_name.table() == table_b1_df_ref {
                                table_b1_federated = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    })?;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        assert!(table_a1_federated);
        assert!(table_a2_federated);
        assert!(table_b1_federated);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        // A `final_sql` failure now fails the test with its real error.
        physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        let expected = vec![
            "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1",
            "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2",
            "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1",
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }

    #[tokio::test]
    async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "test".into(),
        };

        let lowercase_table_ref = "default.table".to_string();
        let lowercase_local_table_ref = "dftable".to_string();
        let lowercase_table = get_test_table_provider(lowercase_table_ref, test_executor_a.clone());

        let capitalized_table_ref = "default.Table(1)".to_string();
        let capitalized_local_table_ref = "dfview".to_string();
        let capitalized_table = get_test_table_provider(capitalized_table_ref, test_executor_a);

        // Start from the crate's default session state, which is expected to
        // install the federation rules.
        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table)
            .unwrap();
        ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table)
            .unwrap();

        let query = r#"
                SELECT * FROM dftable
                UNION ALL
                SELECT * FROM dfview;
            "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut lowercase_table = false;
        let mut capitalized_table = false;

        // Propagate traversal errors instead of discarding them.
        logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == lowercase_local_table_ref {
                                lowercase_table = true;
                            }
                            if table.table_name.table() == capitalized_local_table_ref {
                                capitalized_table = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    })?;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        assert!(lowercase_table);
        assert!(capitalized_table);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        // A `final_sql` failure now fails the test with its real error.
        physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        let expected = vec![
            r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1) AS Table"#,
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }
}