datafusion_federation/sql/
mod.rs

1mod analyzer;
2pub mod ast_analyzer;
3mod executor;
4mod schema;
5mod table;
6mod table_reference;
7
8use std::{any::Any, fmt, sync::Arc, vec};
9
10use analyzer::RewriteTableScanAnalyzer;
11use async_trait::async_trait;
12use datafusion::{
13    arrow::datatypes::{Schema, SchemaRef},
14    common::tree_node::{Transformed, TreeNode},
15    common::Statistics,
16    error::{DataFusionError, Result},
17    execution::{context::SessionState, TaskContext},
18    logical_expr::{Extension, LogicalPlan},
19    optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
20    physical_expr::EquivalenceProperties,
21    physical_plan::{
22        execution_plan::{Boundedness, EmissionType},
23        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
24        SendableRecordBatchStream,
25    },
26    sql::{sqlparser::ast::Statement, unparser::Unparser},
27};
28
29pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef};
30pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
31pub use table::{RemoteTable, SQLTable, SQLTableSource};
32pub use table_reference::RemoteTableRef;
33
34use crate::{
35    get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
36};
37
38// SQLFederationProvider provides federation to SQL DMBSs.
39#[derive(Debug)]
40pub struct SQLFederationProvider {
41    pub optimizer: Arc<Optimizer>,
42    pub executor: Arc<dyn SQLExecutor>,
43}
44
45impl SQLFederationProvider {
46    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
47        Self {
48            optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
49                SQLFederationOptimizerRule::new(executor.clone()),
50            )])),
51            executor,
52        }
53    }
54}
55
56impl FederationProvider for SQLFederationProvider {
57    fn name(&self) -> &str {
58        "sql_federation_provider"
59    }
60
61    fn compute_context(&self) -> Option<String> {
62        self.executor.compute_context()
63    }
64
65    fn optimizer(&self) -> Option<Arc<Optimizer>> {
66        Some(self.optimizer.clone())
67    }
68}
69
70#[derive(Debug)]
71struct SQLFederationOptimizerRule {
72    planner: Arc<SQLFederationPlanner>,
73}
74
75impl SQLFederationOptimizerRule {
76    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
77        Self {
78            planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
79        }
80    }
81}
82
83impl OptimizerRule for SQLFederationOptimizerRule {
84    /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes`
85    /// if the plan was rewritten and `Transformed::no` if it was not.
86    ///
87    /// Note: this function is only called if [`Self::supports_rewrite`] returns
88    /// true. Otherwise the Optimizer calls  [`Self::try_optimize`]
89    fn rewrite(
90        &self,
91        plan: LogicalPlan,
92        _config: &dyn OptimizerConfig,
93    ) -> Result<Transformed<LogicalPlan>> {
94        if let LogicalPlan::Extension(Extension { ref node }) = plan {
95            if node.name() == "Federated" {
96                // Avoid attempting double federation
97                return Ok(Transformed::no(plan));
98            }
99        }
100
101        let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
102        let ext_node = Extension {
103            node: Arc::new(fed_plan),
104        };
105
106        let mut plan = LogicalPlan::Extension(ext_node);
107        if let Some(mut rewriter) = self.planner.executor.logical_optimizer() {
108            plan = rewriter(plan)?;
109        }
110
111        Ok(Transformed::yes(plan))
112    }
113
114    /// A human readable name for this analyzer rule
115    fn name(&self) -> &str {
116        "federate_sql"
117    }
118
119    /// Does this rule support rewriting owned plans (rather than by reference)?
120    fn supports_rewrite(&self) -> bool {
121        true
122    }
123}
124
125#[derive(Debug)]
126pub struct SQLFederationPlanner {
127    pub executor: Arc<dyn SQLExecutor>,
128}
129
130impl SQLFederationPlanner {
131    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
132        Self { executor }
133    }
134}
135
136#[async_trait]
137impl FederationPlanner for SQLFederationPlanner {
138    async fn plan_federation(
139        &self,
140        node: &FederatedPlanNode,
141        _session_state: &SessionState,
142    ) -> Result<Arc<dyn ExecutionPlan>> {
143        let schema = Arc::new(node.plan().schema().as_arrow().clone());
144        let plan = node.plan().clone();
145        let statistics = self.executor.statistics(&plan).await?;
146        let input = Arc::new(VirtualExecutionPlan::new(
147            plan,
148            Arc::clone(&self.executor),
149            statistics,
150        ));
151        let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
152        Ok(Arc::new(schema_cast_exec))
153    }
154}
155
156#[derive(Debug, Clone)]
157struct VirtualExecutionPlan {
158    plan: LogicalPlan,
159    executor: Arc<dyn SQLExecutor>,
160    props: PlanProperties,
161    statistics: Statistics,
162}
163
164impl VirtualExecutionPlan {
165    pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>, statistics: Statistics) -> Self {
166        let schema: Schema = plan.schema().as_ref().into();
167        let props = PlanProperties::new(
168            EquivalenceProperties::new(Arc::new(schema)),
169            Partitioning::UnknownPartitioning(1),
170            EmissionType::Incremental,
171            Boundedness::Bounded,
172        );
173        Self {
174            plan,
175            executor,
176            props,
177            statistics,
178        }
179    }
180
181    fn schema(&self) -> SchemaRef {
182        let df_schema = self.plan.schema().as_ref();
183        Arc::new(Schema::from(df_schema))
184    }
185
186    fn final_sql(&self) -> Result<String> {
187        let plan = self.plan.clone();
188        let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
189        let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?;
190        let plan = apply_logical_optimizers(plan, logical_optimizers)?;
191        let ast = self.plan_to_statement(&plan)?;
192        let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
193        let ast = apply_ast_analyzers(ast, ast_analyzers)?;
194        Ok(ast.to_string())
195    }
196
197    fn rewrite_with_executor_ast_analyzer(
198        &self,
199        ast: Statement,
200    ) -> Result<Statement, datafusion::error::DataFusionError> {
201        if let Some(mut analyzer) = self.executor.ast_analyzer() {
202            Ok(analyzer(ast)?)
203        } else {
204            Ok(ast)
205        }
206    }
207
208    fn plan_to_statement(&self, plan: &LogicalPlan) -> Result<Statement> {
209        Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
210    }
211}
212
213fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<AstAnalyzer>)> {
214    let mut logical_optimizers = vec![];
215    let mut ast_analyzers = vec![];
216
217    plan.apply(|node| {
218        if let LogicalPlan::TableScan(table) = node {
219            let provider = get_table_source(&table.source)
220                .expect("caller is virtual exec so this is valid")
221                .expect("caller is virtual exec so this is valid");
222            if let Some(source) = provider.as_any().downcast_ref::<SQLTableSource>() {
223                if let Some(analyzer) = source.table.logical_optimizer() {
224                    logical_optimizers.push(analyzer);
225                }
226                if let Some(analyzer) = source.table.ast_analyzer() {
227                    ast_analyzers.push(analyzer);
228                }
229            }
230        }
231        Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
232    })?;
233
234    Ok((logical_optimizers, ast_analyzers))
235}
236
237fn apply_logical_optimizers(
238    mut plan: LogicalPlan,
239    analyzers: Vec<LogicalOptimizer>,
240) -> Result<LogicalPlan> {
241    for mut analyzer in analyzers {
242        let old_schema = plan.schema().clone();
243        plan = analyzer(plan)?;
244        let new_schema = plan.schema();
245        if &old_schema != new_schema {
246            return Err(DataFusionError::Execution(format!(
247                "Schema altered during logical analysis, expected: {}, found: {}",
248                old_schema, new_schema
249            )));
250        }
251    }
252    Ok(plan)
253}
254
255fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) -> Result<Statement> {
256    for mut analyzer in analyzers {
257        statement = analyzer(statement)?;
258    }
259    Ok(statement)
260}
261
262impl DisplayAs for VirtualExecutionPlan {
263    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
264        write!(f, "VirtualExecutionPlan")?;
265        write!(f, " name={}", self.executor.name())?;
266        if let Some(ctx) = self.executor.compute_context() {
267            write!(f, " compute_context={ctx}")?;
268        };
269        let mut plan = self.plan.clone();
270        if let Ok(statement) = self.plan_to_statement(&plan) {
271            write!(f, " initial_sql={statement}")?;
272        }
273
274        let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) {
275            Ok(analyzers) => analyzers,
276            Err(_) => return Ok(()),
277        };
278
279        let old_plan = plan.clone();
280
281        plan = match apply_logical_optimizers(plan, logical_optimizers) {
282            Ok(plan) => plan,
283            _ => return Ok(()),
284        };
285
286        let statement = match self.plan_to_statement(&plan) {
287            Ok(statement) => statement,
288            _ => return Ok(()),
289        };
290
291        if plan != old_plan {
292            write!(f, " rewritten_logical_sql={statement}")?;
293        }
294
295        let old_statement = statement.clone();
296        let statement = match self.rewrite_with_executor_ast_analyzer(statement) {
297            Ok(statement) => statement,
298            _ => return Ok(()),
299        };
300        if old_statement != statement {
301            write!(f, " rewritten_executor_sql={statement}")?;
302        }
303
304        let old_statement = statement.clone();
305        let statement = match apply_ast_analyzers(statement, ast_analyzers) {
306            Ok(statement) => statement,
307            _ => return Ok(()),
308        };
309        if old_statement != statement {
310            write!(f, " rewritten_ast_analyzer={statement}")?;
311        }
312
313        Ok(())
314    }
315}
316
317impl ExecutionPlan for VirtualExecutionPlan {
318    fn name(&self) -> &str {
319        "sql_federation_exec"
320    }
321
322    fn as_any(&self) -> &dyn Any {
323        self
324    }
325
326    fn schema(&self) -> SchemaRef {
327        self.schema()
328    }
329
330    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
331        vec![]
332    }
333
334    fn with_new_children(
335        self: Arc<Self>,
336        _: Vec<Arc<dyn ExecutionPlan>>,
337    ) -> Result<Arc<dyn ExecutionPlan>> {
338        Ok(self)
339    }
340
341    fn execute(
342        &self,
343        _partition: usize,
344        _context: Arc<TaskContext>,
345    ) -> Result<SendableRecordBatchStream> {
346        self.executor.execute(&self.final_sql()?, self.schema())
347    }
348
349    fn properties(&self) -> &PlanProperties {
350        &self.props
351    }
352
353    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
354        Ok(self.statistics.clone())
355    }
356}
357
#[cfg(test)]
mod tests {

    use std::collections::HashSet;
    use std::sync::Arc;

    use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
    use crate::FederatedTableProviderAdaptor;
    use async_trait::async_trait;
    use datafusion::arrow::datatypes::{Schema, SchemaRef};
    use datafusion::common::tree_node::TreeNodeRecursion;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::sql::unparser::dialect::Dialect;
    use datafusion::sql::unparser::{self};
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        datasource::TableProvider,
        execution::context::SessionContext,
    };

    use super::table::RemoteTable;
    use super::*;

    /// Executor stub: only name, compute context and dialect are implemented.
    #[derive(Debug, Clone)]
    struct TestExecutor {
        compute_context: String,
    }

    #[async_trait]
    impl SQLExecutor for TestExecutor {
        fn name(&self) -> &str {
            "TestExecutor"
        }

        fn compute_context(&self) -> Option<String> {
            Some(self.compute_context.clone())
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(unparser::dialect::DefaultDialect {})
        }

        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
            unimplemented!()
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            unimplemented!()
        }
    }

    /// Builds a federated table provider for remote table `name` with a fixed
    /// (a: Int64, b: Utf8, c: Date32) schema, federated through `executor`.
    fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc<dyn TableProvider> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_ref = RemoteTableRef::try_from(name).unwrap();
        let table = Arc::new(RemoteTable::new(table_ref, schema));
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource { provider, table });
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    /// Names (as seen by DataFusion) of every table scanned inside a
    /// `FederatedPlanNode` of `plan`. Traversal errors are propagated.
    fn federated_scan_names(plan: &LogicalPlan) -> Result<HashSet<String>> {
        let mut names = HashSet::new();
        plan.apply(|node| {
            if let LogicalPlan::Extension(ext) = node {
                if let Some(fed) = ext.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    fed.plan().apply(|inner| {
                        if let LogicalPlan::TableScan(scan) = inner {
                            names.insert(scan.table_name.table().to_string());
                        }
                        Ok(TreeNodeRecursion::Continue)
                    })?;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;
        Ok(names)
    }

    /// Final SQL of every `sql_federation_exec` node in `plan`.
    /// Errors from `final_sql` are propagated instead of being dropped.
    fn final_sql_queries(plan: &Arc<dyn ExecutionPlan>) -> Result<Vec<String>> {
        let mut queries = vec![];
        plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .expect("sql_federation_exec nodes are VirtualExecutionPlan");
                queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        })?;
        Ok(queries)
    }

    #[tokio::test]
    async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "a".into(),
        };

        let test_executor_b = TestExecutor {
            compute_context: "b".into(),
        };

        let table_a1_ref = "table_a1".to_string();
        let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());

        let table_a2_ref = "table_a2".to_string();
        let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a);

        let table_b1_ref = "table_b1(1)".to_string();
        let table_b1_df_ref = "table_local_b1".to_string();

        let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b);

        // Build the context from the crate's default session state.
        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(table_a1_ref.clone(), table_a1).unwrap();
        ctx.register_table(table_a2_ref.clone(), table_a2).unwrap();
        ctx.register_table(table_b1_df_ref.clone(), table_b1).unwrap();

        let query = r#"
            SELECT * FROM table_a1
            UNION ALL
            SELECT * FROM table_a2
            UNION ALL
            SELECT * FROM table_local_b1;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        // Scans inside federated nodes still carry the locally registered names;
        // the remote names only appear in the generated SQL checked below.
        let federated = federated_scan_names(&logical_plan)?;
        assert!(federated.contains(table_a1_ref.as_str()));
        assert!(federated.contains(table_a2_ref.as_str()));
        assert!(federated.contains(table_b1_df_ref.as_str()));

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let final_queries = final_sql_queries(&physical_plan)?;

        let expected = vec![
            "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1",
            "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2",
            "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1",
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }

    #[tokio::test]
    async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "test".into(),
        };

        let lowercase_table_ref = "default.table".to_string();
        let lowercase_local_table_ref = "dftable".to_string();
        let lowercase_table =
            get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone());

        let capitalized_table_ref = "default.Table(1)".to_string();
        let capitalized_local_table_ref = "dfview".to_string();
        let capitalized_table =
            get_test_table_provider(capitalized_table_ref.clone(), test_executor_a);

        // Build the context from the crate's default session state.
        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table)
            .unwrap();
        ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table)
            .unwrap();

        let query = r#"
                SELECT * FROM dftable
                UNION ALL
                SELECT * FROM dfview;
            "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let federated = federated_scan_names(&logical_plan)?;
        assert!(federated.contains(lowercase_local_table_ref.as_str()));
        assert!(federated.contains(capitalized_local_table_ref.as_str()));

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let final_queries = final_sql_queries(&physical_plan)?;

        let expected = vec![
            r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1) AS Table"#,
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }
}