1mod analyzer;
2pub mod ast_analyzer;
3mod executor;
4mod schema;
5mod table;
6mod table_reference;
7
8use std::{any::Any, fmt, sync::Arc, vec};
9
10use analyzer::RewriteTableScanAnalyzer;
11use async_trait::async_trait;
12use datafusion::{
13 arrow::datatypes::{Schema, SchemaRef},
14 common::{
15 tree_node::{Transformed, TreeNode},
16 Statistics,
17 },
18 config::ConfigOptions,
19 error::{DataFusionError, Result},
20 execution::{context::SessionState, TaskContext},
21 logical_expr::{Extension, LogicalPlan},
22 optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
23 physical_expr::EquivalenceProperties,
24 physical_plan::{
25 execution_plan::{Boundedness, EmissionType},
26 filter_pushdown::{
27 ChildPushdownResult, FilterPushdownPhase, FilterPushdownPropagation, PushedDown,
28 },
29 metrics::MetricsSet,
30 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, PlanProperties,
31 SendableRecordBatchStream,
32 },
33 sql::{sqlparser::ast::Statement, unparser::Unparser},
34};
35
36pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef};
37pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
38pub use table::{RemoteTable, SQLTable, SQLTableSource};
39pub use table_reference::RemoteTableRef;
40
41use crate::{
42 get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
43};
44
/// A `FederationProvider` that federates sub-plans by unparsing them back to
/// SQL and delegating execution to a remote engine via a [`SQLExecutor`].
#[derive(Debug)]
pub struct SQLFederationProvider {
    // Optimizer holding the single SQLFederationOptimizerRule built in `new`.
    pub optimizer: Arc<Optimizer>,
    // Executor that runs the unparsed SQL against the remote engine.
    pub executor: Arc<dyn SQLExecutor>,
}
51
52impl SQLFederationProvider {
53 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
54 Self {
55 optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
56 SQLFederationOptimizerRule::new(executor.clone()),
57 )])),
58 executor,
59 }
60 }
61}
62
impl FederationProvider for SQLFederationProvider {
    // Stable identifier for this provider implementation.
    fn name(&self) -> &str {
        "sql_federation_provider"
    }

    // Delegates to the executor so plans from different compute contexts
    // are not federated together.
    fn compute_context(&self) -> Option<String> {
        self.executor.compute_context()
    }

    // Exposes the optimizer carrying the federation rewrite rule.
    fn optimizer(&self) -> Option<Arc<Optimizer>> {
        Some(self.optimizer.clone())
    }
}
76
/// Optimizer rule that wraps a federated sub-plan in a `Federated` extension
/// node, deferring physical planning to [`SQLFederationPlanner`].
#[derive(Debug)]
struct SQLFederationOptimizerRule {
    // Planner invoked later, at physical-planning time, for the wrapped plan.
    planner: Arc<SQLFederationPlanner>,
}
81
82impl SQLFederationOptimizerRule {
83 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
84 Self {
85 planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
86 }
87 }
88}
89
90impl OptimizerRule for SQLFederationOptimizerRule {
91 fn rewrite(
97 &self,
98 plan: LogicalPlan,
99 _config: &dyn OptimizerConfig,
100 ) -> Result<Transformed<LogicalPlan>> {
101 if let LogicalPlan::Extension(Extension { ref node }) = plan {
102 if node.name() == "Federated" {
103 return Ok(Transformed::no(plan));
105 }
106 }
107
108 let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
109 let ext_node = Extension {
110 node: Arc::new(fed_plan),
111 };
112
113 let mut plan = LogicalPlan::Extension(ext_node);
114 if let Some(mut rewriter) = self.planner.executor.logical_optimizer() {
115 plan = rewriter(plan)?;
116 }
117
118 Ok(Transformed::yes(plan))
119 }
120
121 fn name(&self) -> &str {
123 "federate_sql"
124 }
125
126 fn supports_rewrite(&self) -> bool {
128 true
129 }
130}
131
/// Physical planner for `Federated` extension nodes: turns the wrapped
/// logical plan into a [`VirtualExecutionPlan`] driven by the executor.
#[derive(Debug)]
pub struct SQLFederationPlanner {
    // Executor that will ultimately run the unparsed SQL.
    pub executor: Arc<dyn SQLExecutor>,
}
136
impl SQLFederationPlanner {
    /// Creates a planner backed by `executor`.
    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
        Self { executor }
    }
}
142
143#[async_trait]
144impl FederationPlanner for SQLFederationPlanner {
145 async fn plan_federation(
146 &self,
147 node: &FederatedPlanNode,
148 _session_state: &SessionState,
149 ) -> Result<Arc<dyn ExecutionPlan>> {
150 let schema = Arc::new(node.plan().schema().as_arrow().clone());
151 let plan = node.plan().clone();
152 let statistics = self.executor.statistics(&plan).await?;
153 let input = Arc::new(VirtualExecutionPlan::new(
154 plan,
155 Arc::clone(&self.executor),
156 statistics,
157 ));
158 let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
159 Ok(Arc::new(schema_cast_exec))
160 }
161}
162
/// Leaf physical node representing a sub-plan executed by a remote engine:
/// at execution time the logical plan is unparsed to SQL and handed to the
/// executor.
#[derive(Debug, Clone)]
pub struct VirtualExecutionPlan {
    // Logical plan to be unparsed into SQL.
    plan: LogicalPlan,
    // Remote engine that runs the generated SQL.
    executor: Arc<dyn SQLExecutor>,
    // Cached plan properties (schema, single unknown partition, bounded).
    props: PlanProperties,
    // Statistics obtained from the executor at planning time.
    statistics: Statistics,
    // Filters pushed down by parents; forwarded to the executor on execute.
    filters: Vec<Arc<dyn PhysicalExpr>>,
}
171
172impl VirtualExecutionPlan {
173 pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>, statistics: Statistics) -> Self {
174 let schema: Schema = plan.schema().as_arrow().clone();
175 let props = PlanProperties::new(
176 EquivalenceProperties::new(Arc::new(schema)),
177 Partitioning::UnknownPartitioning(1),
178 EmissionType::Incremental,
179 Boundedness::Bounded,
180 );
181 Self {
182 plan,
183 executor,
184 props,
185 statistics,
186 filters: Vec::new(),
187 }
188 }
189
190 pub fn plan(&self) -> &LogicalPlan {
191 &self.plan
192 }
193
194 pub fn executor(&self) -> &Arc<dyn SQLExecutor> {
195 &self.executor
196 }
197
198 pub fn statistics(&self) -> &Statistics {
199 &self.statistics
200 }
201
202 fn schema(&self) -> SchemaRef {
203 let df_schema = self.plan.schema().as_arrow().clone();
204 Arc::new(df_schema)
205 }
206
207 fn final_sql(&self) -> Result<String> {
208 let plan = self.plan.clone();
209 let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
210 let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?;
211 let plan = apply_logical_optimizers(plan, logical_optimizers)?;
212 let ast = self.plan_to_statement(&plan)?;
213 let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
214 let ast = apply_ast_analyzers(ast, ast_analyzers)?;
215 Ok(ast.to_string())
216 }
217
218 fn rewrite_with_executor_ast_analyzer(
219 &self,
220 ast: Statement,
221 ) -> Result<Statement, datafusion::error::DataFusionError> {
222 if let Some(mut analyzer) = self.executor.ast_analyzer() {
223 Ok(analyzer(ast)?)
224 } else {
225 Ok(ast)
226 }
227 }
228
229 fn plan_to_statement(&self, plan: &LogicalPlan) -> Result<Statement> {
230 Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
231 }
232}
233
234fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<AstAnalyzer>)> {
235 let mut logical_optimizers = vec![];
236 let mut ast_analyzers = vec![];
237
238 plan.apply(|node| {
239 if let LogicalPlan::TableScan(table) = node {
240 let provider = get_table_source(&table.source)
241 .expect("caller is virtual exec so this is valid")
242 .expect("caller is virtual exec so this is valid");
243 if let Some(source) = provider.as_any().downcast_ref::<SQLTableSource>() {
244 if let Some(analyzer) = source.table.logical_optimizer() {
245 logical_optimizers.push(analyzer);
246 }
247 if let Some(analyzer) = source.table.ast_analyzer() {
248 ast_analyzers.push(analyzer);
249 }
250 }
251 }
252 Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
253 })?;
254
255 Ok((logical_optimizers, ast_analyzers))
256}
257
258fn apply_logical_optimizers(
259 mut plan: LogicalPlan,
260 analyzers: Vec<LogicalOptimizer>,
261) -> Result<LogicalPlan> {
262 for mut analyzer in analyzers {
263 let old_schema = plan.schema().clone();
264 plan = analyzer(plan)?;
265 let new_schema = plan.schema();
266 if &old_schema != new_schema {
267 return Err(DataFusionError::Execution(format!(
268 "Schema altered during logical analysis, expected: {}, found: {}",
269 old_schema, new_schema
270 )));
271 }
272 }
273 Ok(plan)
274}
275
276fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) -> Result<Statement> {
277 for mut analyzer in analyzers {
278 statement = analyzer(statement)?;
279 }
280 Ok(statement)
281}
282
impl DisplayAs for VirtualExecutionPlan {
    /// Renders the node name plus, best-effort, the generated SQL at each
    /// rewrite stage. Any failure while regenerating SQL ends the output
    /// early (`Ok(())`) rather than failing the whole EXPLAIN.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
        write!(f, "VirtualExecutionPlan")?;
        write!(f, " name={}", self.executor.name())?;
        if let Some(ctx) = self.executor.compute_context() {
            write!(f, " compute_context={ctx}")?;
        };
        let mut plan = self.plan.clone();
        // SQL before any optimizer/analyzer rewrites.
        // NOTE(review): unlike `final_sql`, this skips the
        // RewriteTableScanAnalyzer step — confirm that is intentional.
        if let Ok(statement) = self.plan_to_statement(&plan) {
            write!(f, " initial_sql={statement}")?;
        }

        let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) {
            Ok(analyzers) => analyzers,
            Err(_) => return Ok(()),
        };

        let old_plan = plan.clone();

        plan = match apply_logical_optimizers(plan, logical_optimizers) {
            Ok(plan) => plan,
            _ => return Ok(()),
        };

        let statement = match self.plan_to_statement(&plan) {
            Ok(statement) => statement,
            _ => return Ok(()),
        };

        // Each stage is only printed when its rewrite changed something.
        if plan != old_plan {
            write!(f, " rewritten_logical_sql={statement}")?;
        }

        let old_statement = statement.clone();
        let statement = match self.rewrite_with_executor_ast_analyzer(statement) {
            Ok(statement) => statement,
            _ => return Ok(()),
        };
        if old_statement != statement {
            write!(f, " rewritten_executor_sql={statement}")?;
        }

        let old_statement = statement.clone();
        let statement = match apply_ast_analyzers(statement, ast_analyzers) {
            Ok(statement) => statement,
            _ => return Ok(()),
        };
        if old_statement != statement {
            write!(f, " rewritten_ast_analyzer={statement}")?;
        }

        Ok(())
    }
}
337
338impl ExecutionPlan for VirtualExecutionPlan {
339 fn name(&self) -> &str {
340 "sql_federation_exec"
341 }
342
343 fn as_any(&self) -> &dyn Any {
344 self
345 }
346
347 fn schema(&self) -> SchemaRef {
348 self.schema()
349 }
350
351 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
352 vec![]
353 }
354
355 fn with_new_children(
356 self: Arc<Self>,
357 _: Vec<Arc<dyn ExecutionPlan>>,
358 ) -> Result<Arc<dyn ExecutionPlan>> {
359 Ok(self)
360 }
361
362 fn execute(
363 &self,
364 _partition: usize,
365 _context: Arc<TaskContext>,
366 ) -> Result<SendableRecordBatchStream> {
367 self.executor
368 .execute(&self.final_sql()?, self.schema(), &self.filters)
369 }
370
371 fn properties(&self) -> &PlanProperties {
372 &self.props
373 }
374
375 fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
376 Ok(self.statistics.clone())
377 }
378
379 fn metrics(&self) -> Option<MetricsSet> {
380 self.executor.metrics()
381 }
382
383 fn handle_child_pushdown_result(
384 &self,
385 _phase: FilterPushdownPhase,
386 child_pushdown_result: ChildPushdownResult,
387 _config: &ConfigOptions,
388 ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
389 let parent_filters: Vec<_> = child_pushdown_result
390 .clone()
391 .parent_filters
392 .into_iter()
393 .map(|f| f.filter)
394 .collect();
395
396 if parent_filters.is_empty() {
397 return Ok(FilterPushdownPropagation {
398 filters: vec![],
399 updated_node: None,
400 });
401 }
402
403 let filters_pushed_down = vec![PushedDown::Yes; parent_filters.len()];
404 let mut node = self.clone();
405 node.filters = parent_filters;
406
407 Ok(FilterPushdownPropagation {
408 filters: filters_pushed_down,
409 updated_node: Some(Arc::new(node)),
410 })
411 }
412}
413
#[cfg(test)]
mod tests {

    use std::collections::HashSet;
    use std::sync::Arc;

    use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
    use crate::FederatedTableProviderAdaptor;
    use async_trait::async_trait;
    use datafusion::arrow::datatypes::{Schema, SchemaRef};
    use datafusion::common::tree_node::TreeNodeRecursion;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::sql::unparser::dialect::Dialect;
    use datafusion::sql::unparser::{self};
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        datasource::TableProvider,
        execution::context::SessionContext,
    };

    use super::table::RemoteTable;
    use super::*;

    /// Minimal [`SQLExecutor`] stub. Only `name`, `compute_context` and
    /// `dialect` are exercised; the execution entry points are never reached
    /// by these tests and therefore left unimplemented.
    #[derive(Debug, Clone)]
    struct TestExecutor {
        // Distinct contexts keep tables from being federated together.
        compute_context: String,
    }

    #[async_trait]
    impl SQLExecutor for TestExecutor {
        fn name(&self) -> &str {
            "TestExecutor"
        }

        fn compute_context(&self) -> Option<String> {
            Some(self.compute_context.clone())
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(unparser::dialect::DefaultDialect {})
        }

        // Never called: the tests stop at SQL generation.
        fn execute(
            &self,
            _query: &str,
            _schema: SchemaRef,
            _filters: &[Arc<dyn PhysicalExpr>],
        ) -> Result<SendableRecordBatchStream> {
            unimplemented!()
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            unimplemented!()
        }
    }

    /// Builds a federated table provider for remote table `name` (columns
    /// a: Int64, b: Utf8, c: Date32) backed by `executor`.
    fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc<dyn TableProvider> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_ref = RemoteTableRef::try_from(name).unwrap();
        let table = Arc::new(RemoteTable::new(table_ref, schema));
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource { provider, table });
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    /// Verifies that each table scan ends up inside a `FederatedPlanNode`
    /// and that the per-executor sub-plans unparse to the expected SQL
    /// (including the parameterized remote reference `table_b1(1)`).
    #[tokio::test]
    async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "a".into(),
        };

        let test_executor_b = TestExecutor {
            compute_context: "b".into(),
        };

        let table_a1_ref = "table_a1".to_string();
        let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());

        let table_a2_ref = "table_a2".to_string();
        let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a);

        // Remote name differs from the DataFusion-registered name.
        let table_b1_ref = "table_b1(1)".to_string();
        let table_b1_df_ref = "table_local_b1".to_string();

        let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(table_a1_ref.clone(), table_a1).unwrap();
        ctx.register_table(table_a2_ref.clone(), table_a2).unwrap();
        ctx.register_table(table_b1_df_ref.clone(), table_b1)
            .unwrap();

        let query = r#"
        SELECT * FROM table_a1
        UNION ALL
        SELECT * FROM table_a2
        UNION ALL
        SELECT * FROM table_local_b1;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut table_a1_federated = false;
        let mut table_a2_federated = false;
        let mut table_b1_federated = false;

        // Every scan must appear inside a FederatedPlanNode extension.
        let _ = logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    let _ = node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == table_a1_ref {
                                table_a1_federated = true;
                            }
                            if table.table_name.table() == table_a2_ref {
                                table_a2_federated = true;
                            }
                            // b1 is registered under its local DataFusion name.
                            if table.table_name.table() == table_b1_df_ref {
                                table_b1_federated = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    });
                }
            }
            Ok(TreeNodeRecursion::Continue)
        });

        assert!(table_a1_federated);
        assert!(table_a2_federated);
        assert!(table_b1_federated);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        // Collect the SQL each VirtualExecutionPlan would send remotely.
        let _ = physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        });

        let expected = vec![
            "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1",
            "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2",
            "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1",
        ];

        // Compare as sets: plan traversal order is not guaranteed.
        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }

    /// Verifies quoting/aliasing for multi-part and case-sensitive remote
    /// references (`default.table` vs `default.Table(1)`) when both tables
    /// share one executor and federate into a single query.
    #[tokio::test]
    async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "test".into(),
        };

        let lowercase_table_ref = "default.table".to_string();
        let lowercase_local_table_ref = "dftable".to_string();
        let lowercase_table =
            get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone());

        let capitalized_table_ref = "default.Table(1)".to_string();
        let capitalized_local_table_ref = "dfview".to_string();
        let capitalized_table =
            get_test_table_provider(capitalized_table_ref.clone(), test_executor_a);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table)
            .unwrap();
        ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table)
            .unwrap();

        let query = r#"
        SELECT * FROM dftable
        UNION ALL
        SELECT * FROM dfview;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut lowercase_table = false;
        let mut capitalized_table = false;

        let _ = logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    let _ = node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == lowercase_local_table_ref {
                                lowercase_table = true;
                            }
                            if table.table_name.table() == capitalized_local_table_ref {
                                capitalized_table = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    });
                }
            }
            Ok(TreeNodeRecursion::Continue)
        });

        assert!(lowercase_table);
        assert!(capitalized_table);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        let _ = physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        });

        // Same executor and compute context -> both scans federate into ONE
        // remote query; reserved/case-sensitive identifiers are quoted.
        let expected = vec![
            r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1) AS Table"#,
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }
}