1mod analyzer;
2pub mod ast_analyzer;
3mod executor;
4mod schema;
5mod table;
6mod table_reference;
7
8use std::{any::Any, fmt, sync::Arc, vec};
9
10use analyzer::RewriteTableScanAnalyzer;
11use async_trait::async_trait;
12use datafusion::{
13 arrow::datatypes::{Schema, SchemaRef},
14 common::{
15 tree_node::{Transformed, TreeNode},
16 Statistics,
17 },
18 error::{DataFusionError, Result},
19 execution::{context::SessionState, TaskContext},
20 logical_expr::{Extension, LogicalPlan},
21 optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
22 physical_expr::EquivalenceProperties,
23 physical_plan::{
24 execution_plan::{Boundedness, EmissionType},
25 metrics::MetricsSet,
26 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
27 SendableRecordBatchStream,
28 },
29 sql::{sqlparser::ast::Statement, unparser::Unparser},
30};
31
32pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef};
33pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
34pub use table::{RemoteTable, SQLTable, SQLTableSource};
35pub use table_reference::RemoteTableRef;
36
37use crate::{
38 get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
39};
40
/// A [`FederationProvider`] that is backed by a [`SQLExecutor`].
///
/// The provider pairs the executor with the [`Optimizer`] that is handed out
/// through [`FederationProvider::optimizer`].
#[derive(Debug)]
pub struct SQLFederationProvider {
    /// Optimizer returned from [`FederationProvider::optimizer`].
    pub optimizer: Arc<Optimizer>,
    /// Executor used for the compute context and for building the optimizer rule.
    pub executor: Arc<dyn SQLExecutor>,
}
47
48impl SQLFederationProvider {
49 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
50 Self {
51 optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
52 SQLFederationOptimizerRule::new(executor.clone()),
53 )])),
54 executor,
55 }
56 }
57}
58
59impl FederationProvider for SQLFederationProvider {
60 fn name(&self) -> &str {
61 "sql_federation_provider"
62 }
63
64 fn compute_context(&self) -> Option<String> {
65 self.executor.compute_context()
66 }
67
68 fn optimizer(&self) -> Option<Arc<Optimizer>> {
69 Some(self.optimizer.clone())
70 }
71}
72
/// Optimizer rule that wraps a logical plan into a federated extension node
/// planned by a [`SQLFederationPlanner`].
#[derive(Debug)]
struct SQLFederationOptimizerRule {
    /// Planner attached to every `FederatedPlanNode` this rule creates.
    planner: Arc<SQLFederationPlanner>,
}
77
78impl SQLFederationOptimizerRule {
79 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
80 Self {
81 planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
82 }
83 }
84}
85
86impl OptimizerRule for SQLFederationOptimizerRule {
87 fn rewrite(
93 &self,
94 plan: LogicalPlan,
95 _config: &dyn OptimizerConfig,
96 ) -> Result<Transformed<LogicalPlan>> {
97 if let LogicalPlan::Extension(Extension { ref node }) = plan {
98 if node.name() == "Federated" {
99 return Ok(Transformed::no(plan));
101 }
102 }
103
104 let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
105 let ext_node = Extension {
106 node: Arc::new(fed_plan),
107 };
108
109 let mut plan = LogicalPlan::Extension(ext_node);
110 if let Some(mut rewriter) = self.planner.executor.logical_optimizer() {
111 plan = rewriter(plan)?;
112 }
113
114 Ok(Transformed::yes(plan))
115 }
116
117 fn name(&self) -> &str {
119 "federate_sql"
120 }
121
122 fn supports_rewrite(&self) -> bool {
124 true
125 }
126}
127
/// [`FederationPlanner`] that turns a federated logical plan into a
/// [`VirtualExecutionPlan`] driven by a [`SQLExecutor`].
#[derive(Debug)]
pub struct SQLFederationPlanner {
    /// Executor that supplies statistics and is handed to the created execution plan.
    pub executor: Arc<dyn SQLExecutor>,
}
132
133impl SQLFederationPlanner {
134 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
135 Self { executor }
136 }
137}
138
139#[async_trait]
140impl FederationPlanner for SQLFederationPlanner {
141 async fn plan_federation(
142 &self,
143 node: &FederatedPlanNode,
144 _session_state: &SessionState,
145 ) -> Result<Arc<dyn ExecutionPlan>> {
146 let schema = Arc::new(node.plan().schema().as_arrow().clone());
147 let plan = node.plan().clone();
148 let statistics = self.executor.statistics(&plan).await?;
149 let input = Arc::new(VirtualExecutionPlan::new(
150 plan,
151 Arc::clone(&self.executor),
152 statistics,
153 ));
154 let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
155 Ok(Arc::new(schema_cast_exec))
156 }
157}
158
/// Physical plan node that holds a logical plan and executes it as SQL through
/// a [`SQLExecutor`].
#[derive(Debug, Clone)]
pub struct VirtualExecutionPlan {
    /// Logical plan that is unparsed to SQL when the node is executed.
    plan: LogicalPlan,
    /// Executor that provides the dialect and runs the generated SQL.
    executor: Arc<dyn SQLExecutor>,
    /// Plan properties computed once in `new` from the plan's Arrow schema.
    props: PlanProperties,
    /// Statistics supplied by the caller of `new`.
    statistics: Statistics,
}
166
167impl VirtualExecutionPlan {
168 pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>, statistics: Statistics) -> Self {
169 let schema: Schema = plan.schema().as_arrow().clone();
170 let props = PlanProperties::new(
171 EquivalenceProperties::new(Arc::new(schema)),
172 Partitioning::UnknownPartitioning(1),
173 EmissionType::Incremental,
174 Boundedness::Bounded,
175 );
176 Self {
177 plan,
178 executor,
179 props,
180 statistics,
181 }
182 }
183
184 pub fn plan(&self) -> &LogicalPlan {
185 &self.plan
186 }
187
188 pub fn executor(&self) -> &Arc<dyn SQLExecutor> {
189 &self.executor
190 }
191
192 pub fn statistics(&self) -> &Statistics {
193 &self.statistics
194 }
195
196 fn schema(&self) -> SchemaRef {
197 let df_schema = self.plan.schema().as_arrow().clone();
198 Arc::new(df_schema)
199 }
200
201 fn final_sql(&self) -> Result<String> {
202 let plan = self.plan.clone();
203 let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
204 let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?;
205 let plan = apply_logical_optimizers(plan, logical_optimizers)?;
206 let ast = self.plan_to_statement(&plan)?;
207 let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
208 let ast = apply_ast_analyzers(ast, ast_analyzers)?;
209 Ok(ast.to_string())
210 }
211
212 fn rewrite_with_executor_ast_analyzer(
213 &self,
214 ast: Statement,
215 ) -> Result<Statement, datafusion::error::DataFusionError> {
216 if let Some(mut analyzer) = self.executor.ast_analyzer() {
217 Ok(analyzer(ast)?)
218 } else {
219 Ok(ast)
220 }
221 }
222
223 fn plan_to_statement(&self, plan: &LogicalPlan) -> Result<Statement> {
224 Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
225 }
226}
227
228fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<AstAnalyzer>)> {
229 let mut logical_optimizers = vec![];
230 let mut ast_analyzers = vec![];
231
232 plan.apply(|node| {
233 if let LogicalPlan::TableScan(table) = node {
234 let provider = get_table_source(&table.source)
235 .expect("caller is virtual exec so this is valid")
236 .expect("caller is virtual exec so this is valid");
237 if let Some(source) = provider.as_any().downcast_ref::<SQLTableSource>() {
238 if let Some(analyzer) = source.table.logical_optimizer() {
239 logical_optimizers.push(analyzer);
240 }
241 if let Some(analyzer) = source.table.ast_analyzer() {
242 ast_analyzers.push(analyzer);
243 }
244 }
245 }
246 Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
247 })?;
248
249 Ok((logical_optimizers, ast_analyzers))
250}
251
252fn apply_logical_optimizers(
253 mut plan: LogicalPlan,
254 analyzers: Vec<LogicalOptimizer>,
255) -> Result<LogicalPlan> {
256 for mut analyzer in analyzers {
257 let old_schema = plan.schema().clone();
258 plan = analyzer(plan)?;
259 let new_schema = plan.schema();
260 if &old_schema != new_schema {
261 return Err(DataFusionError::Execution(format!(
262 "Schema altered during logical analysis, expected: {}, found: {}",
263 old_schema, new_schema
264 )));
265 }
266 }
267 Ok(plan)
268}
269
270fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) -> Result<Statement> {
271 for mut analyzer in analyzers {
272 statement = analyzer(statement)?;
273 }
274 Ok(statement)
275}
276
277impl DisplayAs for VirtualExecutionPlan {
278 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
279 write!(f, "VirtualExecutionPlan")?;
280 write!(f, " name={}", self.executor.name())?;
281 if let Some(ctx) = self.executor.compute_context() {
282 write!(f, " compute_context={ctx}")?;
283 };
284 let mut plan = self.plan.clone();
285 if let Ok(statement) = self.plan_to_statement(&plan) {
286 write!(f, " initial_sql={statement}")?;
287 }
288
289 let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) {
290 Ok(analyzers) => analyzers,
291 Err(_) => return Ok(()),
292 };
293
294 let old_plan = plan.clone();
295
296 plan = match apply_logical_optimizers(plan, logical_optimizers) {
297 Ok(plan) => plan,
298 _ => return Ok(()),
299 };
300
301 let statement = match self.plan_to_statement(&plan) {
302 Ok(statement) => statement,
303 _ => return Ok(()),
304 };
305
306 if plan != old_plan {
307 write!(f, " rewritten_logical_sql={statement}")?;
308 }
309
310 let old_statement = statement.clone();
311 let statement = match self.rewrite_with_executor_ast_analyzer(statement) {
312 Ok(statement) => statement,
313 _ => return Ok(()),
314 };
315 if old_statement != statement {
316 write!(f, " rewritten_executor_sql={statement}")?;
317 }
318
319 let old_statement = statement.clone();
320 let statement = match apply_ast_analyzers(statement, ast_analyzers) {
321 Ok(statement) => statement,
322 _ => return Ok(()),
323 };
324 if old_statement != statement {
325 write!(f, " rewritten_ast_analyzer={statement}")?;
326 }
327
328 Ok(())
329 }
330}
331
332impl ExecutionPlan for VirtualExecutionPlan {
333 fn name(&self) -> &str {
334 "sql_federation_exec"
335 }
336
337 fn as_any(&self) -> &dyn Any {
338 self
339 }
340
341 fn schema(&self) -> SchemaRef {
342 self.schema()
343 }
344
345 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
346 vec![]
347 }
348
349 fn with_new_children(
350 self: Arc<Self>,
351 _: Vec<Arc<dyn ExecutionPlan>>,
352 ) -> Result<Arc<dyn ExecutionPlan>> {
353 Ok(self)
354 }
355
356 fn execute(
357 &self,
358 _partition: usize,
359 _context: Arc<TaskContext>,
360 ) -> Result<SendableRecordBatchStream> {
361 self.executor.execute(&self.final_sql()?, self.schema())
362 }
363
364 fn properties(&self) -> &PlanProperties {
365 &self.props
366 }
367
368 fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
369 Ok(self.statistics.clone())
370 }
371
372 fn metrics(&self) -> Option<MetricsSet> {
373 self.executor.metrics()
374 }
375}
376
#[cfg(test)]
mod tests {

    use std::collections::HashSet;
    use std::sync::Arc;

    use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
    use crate::FederatedTableProviderAdaptor;
    use async_trait::async_trait;
    use datafusion::arrow::datatypes::{Schema, SchemaRef};
    use datafusion::common::tree_node::TreeNodeRecursion;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::sql::unparser::dialect::Dialect;
    use datafusion::sql::unparser::{self};
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        datasource::TableProvider,
        execution::context::SessionContext,
    };

    use super::table::RemoteTable;
    use super::*;

    /// Executor stub: only the metadata needed for planning and SQL generation
    /// is implemented; executing or introspecting tables is not supported.
    #[derive(Debug, Clone)]
    struct TestExecutor {
        compute_context: String,
    }

    #[async_trait]
    impl SQLExecutor for TestExecutor {
        fn name(&self) -> &str {
            "TestExecutor"
        }

        fn compute_context(&self) -> Option<String> {
            Some(self.compute_context.clone())
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(unparser::dialect::DefaultDialect {})
        }

        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
            unimplemented!()
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            unimplemented!()
        }
    }

    /// Builds a federated table provider for the remote table `name`, with a
    /// fixed three-column schema, served by `executor`.
    fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc<dyn TableProvider> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_ref = RemoteTableRef::try_from(name).unwrap();
        let table = Arc::new(RemoteTable::new(table_ref, schema));
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource { provider, table });
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    #[tokio::test]
    async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "a".into(),
        };

        let test_executor_b = TestExecutor {
            compute_context: "b".into(),
        };

        let table_a1_ref = "table_a1".to_string();
        let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());

        let table_a2_ref = "table_a2".to_string();
        let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a);

        let table_b1_ref = "table_b1(1)".to_string();
        let table_b1_df_ref = "table_local_b1".to_string();

        let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(table_a1_ref.clone(), table_a1).unwrap();
        ctx.register_table(table_a2_ref.clone(), table_a2).unwrap();
        ctx.register_table(table_b1_df_ref.clone(), table_b1)
            .unwrap();

        let query = r#"
            SELECT * FROM table_a1
            UNION ALL
            SELECT * FROM table_a2
            UNION ALL
            SELECT * FROM table_local_b1;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut table_a1_federated = false;
        let mut table_a2_federated = false;
        let mut table_b1_federated = false;

        // Propagate traversal errors instead of discarding them, so a failure
        // is reported at its source rather than as a later assertion mismatch.
        logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == table_a1_ref {
                                table_a1_federated = true;
                            }
                            if table.table_name.table() == table_a2_ref {
                                table_a2_federated = true;
                            }
                            if table.table_name.table() == table_b1_df_ref {
                                table_b1_federated = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    })?;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        assert!(table_a1_federated);
        assert!(table_a2_federated);
        assert!(table_b1_federated);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        // A failing `final_sql()` now fails the test with its own error.
        physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        let expected = vec![
            "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1",
            "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2",
            "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1",
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }

    #[tokio::test]
    async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "test".into(),
        };

        let lowercase_table_ref = "default.table".to_string();
        let lowercase_local_table_ref = "dftable".to_string();
        let lowercase_table =
            get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone());

        let capitalized_table_ref = "default.Table(1)".to_string();
        let capitalized_local_table_ref = "dfview".to_string();
        let capitalized_table =
            get_test_table_provider(capitalized_table_ref.clone(), test_executor_a);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table)
            .unwrap();
        ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table)
            .unwrap();

        let query = r#"
            SELECT * FROM dftable
            UNION ALL
            SELECT * FROM dfview;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut lowercase_table = false;
        let mut capitalized_table = false;

        // Propagate traversal errors instead of discarding them.
        logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == lowercase_local_table_ref {
                                lowercase_table = true;
                            }
                            if table.table_name.table() == capitalized_local_table_ref {
                                capitalized_table = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    })?;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        assert!(lowercase_table);
        assert!(capitalized_table);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        // A failing `final_sql()` now fails the test with its own error.
        physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        let expected = vec![
            r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1) AS Table"#,
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }
}