1mod analyzer;
2pub mod ast_analyzer;
3mod executor;
4mod schema;
5mod table;
6mod table_reference;
7
8use std::{any::Any, fmt, sync::Arc, vec};
9
10use analyzer::RewriteTableScanAnalyzer;
11use async_trait::async_trait;
12use datafusion::{
13 arrow::datatypes::{Schema, SchemaRef},
14 common::tree_node::{Transformed, TreeNode},
15 common::Statistics,
16 error::{DataFusionError, Result},
17 execution::{context::SessionState, TaskContext},
18 logical_expr::{Extension, LogicalPlan},
19 optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
20 physical_expr::EquivalenceProperties,
21 physical_plan::{
22 execution_plan::{Boundedness, EmissionType},
23 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
24 SendableRecordBatchStream,
25 },
26 sql::{sqlparser::ast::Statement, unparser::Unparser},
27};
28
29pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef};
30pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
31pub use table::{RemoteTable, SQLTable, SQLTableSource};
32pub use table_reference::RemoteTableRef;
33
34use crate::{
35 get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
36};
37
38#[derive(Debug)]
40pub struct SQLFederationProvider {
41 pub optimizer: Arc<Optimizer>,
42 pub executor: Arc<dyn SQLExecutor>,
43}
44
45impl SQLFederationProvider {
46 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
47 Self {
48 optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
49 SQLFederationOptimizerRule::new(executor.clone()),
50 )])),
51 executor,
52 }
53 }
54}
55
56impl FederationProvider for SQLFederationProvider {
57 fn name(&self) -> &str {
58 "sql_federation_provider"
59 }
60
61 fn compute_context(&self) -> Option<String> {
62 self.executor.compute_context()
63 }
64
65 fn optimizer(&self) -> Option<Arc<Optimizer>> {
66 Some(self.optimizer.clone())
67 }
68}
69
70#[derive(Debug)]
71struct SQLFederationOptimizerRule {
72 planner: Arc<SQLFederationPlanner>,
73}
74
75impl SQLFederationOptimizerRule {
76 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
77 Self {
78 planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
79 }
80 }
81}
82
83impl OptimizerRule for SQLFederationOptimizerRule {
84 fn rewrite(
90 &self,
91 plan: LogicalPlan,
92 _config: &dyn OptimizerConfig,
93 ) -> Result<Transformed<LogicalPlan>> {
94 if let LogicalPlan::Extension(Extension { ref node }) = plan {
95 if node.name() == "Federated" {
96 return Ok(Transformed::no(plan));
98 }
99 }
100
101 let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
102 let ext_node = Extension {
103 node: Arc::new(fed_plan),
104 };
105
106 let mut plan = LogicalPlan::Extension(ext_node);
107 if let Some(mut rewriter) = self.planner.executor.logical_optimizer() {
108 plan = rewriter(plan)?;
109 }
110
111 Ok(Transformed::yes(plan))
112 }
113
114 fn name(&self) -> &str {
116 "federate_sql"
117 }
118
119 fn supports_rewrite(&self) -> bool {
121 true
122 }
123}
124
125#[derive(Debug)]
126pub struct SQLFederationPlanner {
127 pub executor: Arc<dyn SQLExecutor>,
128}
129
130impl SQLFederationPlanner {
131 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
132 Self { executor }
133 }
134}
135
136#[async_trait]
137impl FederationPlanner for SQLFederationPlanner {
138 async fn plan_federation(
139 &self,
140 node: &FederatedPlanNode,
141 _session_state: &SessionState,
142 ) -> Result<Arc<dyn ExecutionPlan>> {
143 let schema = Arc::new(node.plan().schema().as_arrow().clone());
144 let plan = node.plan().clone();
145 let statistics = self.executor.statistics(&plan).await?;
146 let input = Arc::new(VirtualExecutionPlan::new(
147 plan,
148 Arc::clone(&self.executor),
149 statistics,
150 ));
151 let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
152 Ok(Arc::new(schema_cast_exec))
153 }
154}
155
156#[derive(Debug, Clone)]
157pub struct VirtualExecutionPlan {
158 plan: LogicalPlan,
159 executor: Arc<dyn SQLExecutor>,
160 props: PlanProperties,
161 statistics: Statistics,
162}
163
164impl VirtualExecutionPlan {
165 pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>, statistics: Statistics) -> Self {
166 let schema: Schema = plan.schema().as_ref().into();
167 let props = PlanProperties::new(
168 EquivalenceProperties::new(Arc::new(schema)),
169 Partitioning::UnknownPartitioning(1),
170 EmissionType::Incremental,
171 Boundedness::Bounded,
172 );
173 Self {
174 plan,
175 executor,
176 props,
177 statistics,
178 }
179 }
180
181 pub fn plan(&self) -> &LogicalPlan {
182 &self.plan
183 }
184
185 pub fn executor(&self) -> &Arc<dyn SQLExecutor> {
186 &self.executor
187 }
188
189 pub fn statistics(&self) -> &Statistics {
190 &self.statistics
191 }
192
193 fn schema(&self) -> SchemaRef {
194 let df_schema = self.plan.schema().as_ref();
195 Arc::new(Schema::from(df_schema))
196 }
197
198 fn final_sql(&self) -> Result<String> {
199 let plan = self.plan.clone();
200 let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
201 let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?;
202 let plan = apply_logical_optimizers(plan, logical_optimizers)?;
203 let ast = self.plan_to_statement(&plan)?;
204 let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
205 let ast = apply_ast_analyzers(ast, ast_analyzers)?;
206 Ok(ast.to_string())
207 }
208
209 fn rewrite_with_executor_ast_analyzer(
210 &self,
211 ast: Statement,
212 ) -> Result<Statement, datafusion::error::DataFusionError> {
213 if let Some(mut analyzer) = self.executor.ast_analyzer() {
214 Ok(analyzer(ast)?)
215 } else {
216 Ok(ast)
217 }
218 }
219
220 fn plan_to_statement(&self, plan: &LogicalPlan) -> Result<Statement> {
221 Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
222 }
223}
224
225fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<AstAnalyzer>)> {
226 let mut logical_optimizers = vec![];
227 let mut ast_analyzers = vec![];
228
229 plan.apply(|node| {
230 if let LogicalPlan::TableScan(table) = node {
231 let provider = get_table_source(&table.source)
232 .expect("caller is virtual exec so this is valid")
233 .expect("caller is virtual exec so this is valid");
234 if let Some(source) = provider.as_any().downcast_ref::<SQLTableSource>() {
235 if let Some(analyzer) = source.table.logical_optimizer() {
236 logical_optimizers.push(analyzer);
237 }
238 if let Some(analyzer) = source.table.ast_analyzer() {
239 ast_analyzers.push(analyzer);
240 }
241 }
242 }
243 Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
244 })?;
245
246 Ok((logical_optimizers, ast_analyzers))
247}
248
249fn apply_logical_optimizers(
250 mut plan: LogicalPlan,
251 analyzers: Vec<LogicalOptimizer>,
252) -> Result<LogicalPlan> {
253 for mut analyzer in analyzers {
254 let old_schema = plan.schema().clone();
255 plan = analyzer(plan)?;
256 let new_schema = plan.schema();
257 if &old_schema != new_schema {
258 return Err(DataFusionError::Execution(format!(
259 "Schema altered during logical analysis, expected: {}, found: {}",
260 old_schema, new_schema
261 )));
262 }
263 }
264 Ok(plan)
265}
266
267fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) -> Result<Statement> {
268 for mut analyzer in analyzers {
269 statement = analyzer(statement)?;
270 }
271 Ok(statement)
272}
273
274impl DisplayAs for VirtualExecutionPlan {
275 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
276 write!(f, "VirtualExecutionPlan")?;
277 write!(f, " name={}", self.executor.name())?;
278 if let Some(ctx) = self.executor.compute_context() {
279 write!(f, " compute_context={ctx}")?;
280 };
281 let mut plan = self.plan.clone();
282 if let Ok(statement) = self.plan_to_statement(&plan) {
283 write!(f, " initial_sql={statement}")?;
284 }
285
286 let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) {
287 Ok(analyzers) => analyzers,
288 Err(_) => return Ok(()),
289 };
290
291 let old_plan = plan.clone();
292
293 plan = match apply_logical_optimizers(plan, logical_optimizers) {
294 Ok(plan) => plan,
295 _ => return Ok(()),
296 };
297
298 let statement = match self.plan_to_statement(&plan) {
299 Ok(statement) => statement,
300 _ => return Ok(()),
301 };
302
303 if plan != old_plan {
304 write!(f, " rewritten_logical_sql={statement}")?;
305 }
306
307 let old_statement = statement.clone();
308 let statement = match self.rewrite_with_executor_ast_analyzer(statement) {
309 Ok(statement) => statement,
310 _ => return Ok(()),
311 };
312 if old_statement != statement {
313 write!(f, " rewritten_executor_sql={statement}")?;
314 }
315
316 let old_statement = statement.clone();
317 let statement = match apply_ast_analyzers(statement, ast_analyzers) {
318 Ok(statement) => statement,
319 _ => return Ok(()),
320 };
321 if old_statement != statement {
322 write!(f, " rewritten_ast_analyzer={statement}")?;
323 }
324
325 Ok(())
326 }
327}
328
329impl ExecutionPlan for VirtualExecutionPlan {
330 fn name(&self) -> &str {
331 "sql_federation_exec"
332 }
333
334 fn as_any(&self) -> &dyn Any {
335 self
336 }
337
338 fn schema(&self) -> SchemaRef {
339 self.schema()
340 }
341
342 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
343 vec![]
344 }
345
346 fn with_new_children(
347 self: Arc<Self>,
348 _: Vec<Arc<dyn ExecutionPlan>>,
349 ) -> Result<Arc<dyn ExecutionPlan>> {
350 Ok(self)
351 }
352
353 fn execute(
354 &self,
355 _partition: usize,
356 _context: Arc<TaskContext>,
357 ) -> Result<SendableRecordBatchStream> {
358 self.executor.execute(&self.final_sql()?, self.schema())
359 }
360
361 fn properties(&self) -> &PlanProperties {
362 &self.props
363 }
364
365 fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
366 Ok(self.statistics.clone())
367 }
368}
369
#[cfg(test)]
mod tests {

    use std::collections::HashSet;
    use std::sync::Arc;

    use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
    use crate::FederatedTableProviderAdaptor;
    use async_trait::async_trait;
    use datafusion::arrow::datatypes::{Schema, SchemaRef};
    use datafusion::common::tree_node::TreeNodeRecursion;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::sql::unparser::dialect::Dialect;
    use datafusion::sql::unparser::{self};
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        datasource::TableProvider,
        execution::context::SessionContext,
    };

    use super::table::RemoteTable;
    use super::*;

    /// Executor stub: only the metadata methods are implemented, everything
    /// that would touch a real backend is `unimplemented!()`.
    #[derive(Debug, Clone)]
    struct TestExecutor {
        compute_context: String,
    }

    #[async_trait]
    impl SQLExecutor for TestExecutor {
        fn name(&self) -> &str {
            "TestExecutor"
        }

        fn compute_context(&self) -> Option<String> {
            Some(self.compute_context.clone())
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(unparser::dialect::DefaultDialect {})
        }

        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
            unimplemented!()
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            unimplemented!()
        }
    }

    /// Builds a federated table provider for remote table `name` with a fixed
    /// three-column schema, served by `executor`.
    fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc<dyn TableProvider> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_ref = RemoteTableRef::try_from(name).unwrap();
        let table = Arc::new(RemoteTable::new(table_ref, schema));
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource { provider, table });
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    /// Returns the names of all tables scanned inside the `FederatedPlanNode`s
    /// of `plan`. Traversal errors are propagated instead of being discarded.
    fn federated_table_names(plan: &LogicalPlan) -> Result<HashSet<String>> {
        let mut names = HashSet::new();

        plan.apply(|node| {
            if let LogicalPlan::Extension(extension) = node {
                if let Some(federated) = extension
                    .node
                    .as_any()
                    .downcast_ref::<FederatedPlanNode>()
                {
                    federated.plan().apply(|inner| {
                        if let LogicalPlan::TableScan(scan) = inner {
                            names.insert(scan.table_name.table().to_string());
                        }
                        Ok(TreeNodeRecursion::Continue)
                    })?;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        Ok(names)
    }

    /// Returns the final SQL of every `VirtualExecutionPlan` in `plan`.
    /// A failure to render SQL is returned to the caller so the test reports
    /// the real cause rather than a set mismatch.
    fn collect_final_queries(plan: &Arc<dyn ExecutionPlan>) -> Result<Vec<String>> {
        let mut queries = vec![];

        plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let exec = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                queries.push(exec.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        Ok(queries)
    }

    #[tokio::test]
    async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "a".into(),
        };

        let test_executor_b = TestExecutor {
            compute_context: "b".into(),
        };

        let table_a1_ref = "table_a1".to_string();
        let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());

        let table_a2_ref = "table_a2".to_string();
        let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a);

        // The remote reference carries arguments; locally it is registered
        // under a different name.
        let table_b1_ref = "table_b1(1)".to_string();
        let table_b1_df_ref = "table_local_b1".to_string();

        let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(table_a1_ref.clone(), table_a1)?;
        ctx.register_table(table_a2_ref.clone(), table_a2)?;
        ctx.register_table(table_b1_df_ref.clone(), table_b1)?;

        let query = r#"
            SELECT * FROM table_a1
            UNION ALL
            SELECT * FROM table_a2
            UNION ALL
            SELECT * FROM table_local_b1;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let federated = federated_table_names(&logical_plan)?;
        assert!(federated.contains(table_a1_ref.as_str()));
        assert!(federated.contains(table_a2_ref.as_str()));
        assert!(federated.contains(table_b1_df_ref.as_str()));

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let final_queries = collect_final_queries(&physical_plan)?;

        let expected = vec![
            "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1",
            "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2",
            "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1",
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }

    #[tokio::test]
    async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "test".into(),
        };

        let lowercase_table_ref = "default.table".to_string();
        let lowercase_local_table_ref = "dftable".to_string();
        let lowercase_table =
            get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone());

        let capitalized_table_ref = "default.Table(1)".to_string();
        let capitalized_local_table_ref = "dfview".to_string();
        let capitalized_table =
            get_test_table_provider(capitalized_table_ref.clone(), test_executor_a);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table)?;
        ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table)?;

        let query = r#"
            SELECT * FROM dftable
            UNION ALL
            SELECT * FROM dfview;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let federated = federated_table_names(&logical_plan)?;
        assert!(federated.contains(lowercase_local_table_ref.as_str()));
        assert!(federated.contains(capitalized_local_table_ref.as_str()));

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let final_queries = collect_final_queries(&physical_plan)?;

        let expected = vec![
            r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1) AS Table"#,
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }
}