1mod analyzer;
2pub mod ast_analyzer;
3mod executor;
4mod schema;
5mod table;
6mod table_reference;
7
8use std::{any::Any, fmt, sync::Arc, vec};
9
10use analyzer::RewriteTableScanAnalyzer;
11use async_trait::async_trait;
12use datafusion::{
13 arrow::datatypes::{Schema, SchemaRef},
14 common::tree_node::{Transformed, TreeNode},
15 error::{DataFusionError, Result},
16 execution::{context::SessionState, TaskContext},
17 logical_expr::{Extension, LogicalPlan},
18 optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
19 physical_expr::EquivalenceProperties,
20 physical_plan::{
21 execution_plan::{Boundedness, EmissionType},
22 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
23 SendableRecordBatchStream,
24 },
25 sql::{sqlparser::ast::Statement, unparser::Unparser},
26};
27
28pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef};
29pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
30pub use table::{RemoteTable, SQLTableSource};
31pub use table_reference::RemoteTableRef;
32
33use crate::{
34 get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
35};
36
37#[derive(Debug)]
39pub struct SQLFederationProvider {
40 optimizer: Arc<Optimizer>,
41 executor: Arc<dyn SQLExecutor>,
42}
43
44impl SQLFederationProvider {
45 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
46 Self {
47 optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
48 SQLFederationOptimizerRule::new(executor.clone()),
49 )])),
50 executor,
51 }
52 }
53}
54
55impl FederationProvider for SQLFederationProvider {
56 fn name(&self) -> &str {
57 "sql_federation_provider"
58 }
59
60 fn compute_context(&self) -> Option<String> {
61 self.executor.compute_context()
62 }
63
64 fn optimizer(&self) -> Option<Arc<Optimizer>> {
65 Some(self.optimizer.clone())
66 }
67}
68
69#[derive(Debug)]
70struct SQLFederationOptimizerRule {
71 planner: Arc<SQLFederationPlanner>,
72}
73
74impl SQLFederationOptimizerRule {
75 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
76 Self {
77 planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
78 }
79 }
80}
81
82impl OptimizerRule for SQLFederationOptimizerRule {
83 fn rewrite(
89 &self,
90 plan: LogicalPlan,
91 _config: &dyn OptimizerConfig,
92 ) -> Result<Transformed<LogicalPlan>> {
93 if let LogicalPlan::Extension(Extension { ref node }) = plan {
94 if node.name() == "Federated" {
95 return Ok(Transformed::no(plan));
97 }
98 }
99
100 let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
101 let ext_node = Extension {
102 node: Arc::new(fed_plan),
103 };
104
105 let mut plan = LogicalPlan::Extension(ext_node);
106 if let Some(mut rewriter) = self.planner.executor.logical_optimizer() {
107 plan = rewriter(plan)?;
108 }
109
110 Ok(Transformed::yes(plan))
111 }
112
113 fn name(&self) -> &str {
115 "federate_sql"
116 }
117
118 fn supports_rewrite(&self) -> bool {
120 true
121 }
122}
123
124#[derive(Debug)]
125struct SQLFederationPlanner {
126 executor: Arc<dyn SQLExecutor>,
127}
128
129impl SQLFederationPlanner {
130 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
131 Self { executor }
132 }
133}
134
135#[async_trait]
136impl FederationPlanner for SQLFederationPlanner {
137 async fn plan_federation(
138 &self,
139 node: &FederatedPlanNode,
140 _session_state: &SessionState,
141 ) -> Result<Arc<dyn ExecutionPlan>> {
142 let schema = Arc::new(node.plan().schema().as_arrow().clone());
143 let input = Arc::new(VirtualExecutionPlan::new(
144 node.plan().clone(),
145 Arc::clone(&self.executor),
146 ));
147 let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
148 Ok(Arc::new(schema_cast_exec))
149 }
150}
151
152#[derive(Debug, Clone)]
153struct VirtualExecutionPlan {
154 plan: LogicalPlan,
155 executor: Arc<dyn SQLExecutor>,
156 props: PlanProperties,
157}
158
159impl VirtualExecutionPlan {
160 pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>) -> Self {
161 let schema: Schema = plan.schema().as_ref().into();
162 let props = PlanProperties::new(
163 EquivalenceProperties::new(Arc::new(schema)),
164 Partitioning::UnknownPartitioning(1),
165 EmissionType::Incremental,
166 Boundedness::Bounded,
167 );
168 Self {
169 plan,
170 executor,
171 props,
172 }
173 }
174
175 fn schema(&self) -> SchemaRef {
176 let df_schema = self.plan.schema().as_ref();
177 Arc::new(Schema::from(df_schema))
178 }
179
180 fn final_sql(&self) -> Result<String> {
181 let plan = self.plan.clone();
182 let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
183 let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?;
184 let plan = apply_logical_optimizers(plan, logical_optimizers)?;
185 let ast = self.plan_to_statement(&plan)?;
186 let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
187 let ast = apply_ast_analyzers(ast, ast_analyzers)?;
188 Ok(ast.to_string())
189 }
190
191 fn rewrite_with_executor_ast_analyzer(
192 &self,
193 ast: Statement,
194 ) -> Result<Statement, datafusion::error::DataFusionError> {
195 if let Some(mut analyzer) = self.executor.ast_analyzer() {
196 Ok(analyzer(ast)?)
197 } else {
198 Ok(ast)
199 }
200 }
201
202 fn plan_to_statement(&self, plan: &LogicalPlan) -> Result<Statement> {
203 Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
204 }
205}
206
207fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<AstAnalyzer>)> {
208 let mut logical_optimizers = vec![];
209 let mut ast_analyzers = vec![];
210
211 plan.apply(|node| {
212 if let LogicalPlan::TableScan(table) = node {
213 let provider = get_table_source(&table.source)
214 .expect("caller is virtual exec so this is valid")
215 .expect("caller is virtual exec so this is valid");
216 if let Some(source) = provider.as_any().downcast_ref::<SQLTableSource>() {
217 if let Some(analyzer) = source.table.logical_optimizer() {
218 logical_optimizers.push(analyzer);
219 }
220 if let Some(analyzer) = source.table.ast_analyzer() {
221 ast_analyzers.push(analyzer);
222 }
223 }
224 }
225 Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
226 })?;
227
228 Ok((logical_optimizers, ast_analyzers))
229}
230
231fn apply_logical_optimizers(
232 mut plan: LogicalPlan,
233 analyzers: Vec<LogicalOptimizer>,
234) -> Result<LogicalPlan> {
235 for mut analyzer in analyzers {
236 let old_schema = plan.schema().clone();
237 plan = analyzer(plan)?;
238 let new_schema = plan.schema();
239 if &old_schema != new_schema {
240 return Err(DataFusionError::Execution(format!(
241 "Schema altered during logical analysis, expected: {}, found: {}",
242 old_schema, new_schema
243 )));
244 }
245 }
246 Ok(plan)
247}
248
249fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) -> Result<Statement> {
250 for mut analyzer in analyzers {
251 statement = analyzer(statement)?;
252 }
253 Ok(statement)
254}
255
256impl DisplayAs for VirtualExecutionPlan {
257 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
258 write!(f, "VirtualExecutionPlan")?;
259 write!(f, " name={}", self.executor.name())?;
260 if let Some(ctx) = self.executor.compute_context() {
261 write!(f, " compute_context={ctx}")?;
262 };
263 let mut plan = self.plan.clone();
264 if let Ok(statement) = self.plan_to_statement(&plan) {
265 write!(f, " initial_sql={statement}")?;
266 }
267
268 let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) {
269 Ok(analyzers) => analyzers,
270 Err(_) => return Ok(()),
271 };
272
273 let old_plan = plan.clone();
274
275 plan = match apply_logical_optimizers(plan, logical_optimizers) {
276 Ok(plan) => plan,
277 _ => return Ok(()),
278 };
279
280 let statement = match self.plan_to_statement(&plan) {
281 Ok(statement) => statement,
282 _ => return Ok(()),
283 };
284
285 if plan != old_plan {
286 write!(f, " rewritten_logical_sql={statement}")?;
287 }
288
289 let old_statement = statement.clone();
290 let statement = match self.rewrite_with_executor_ast_analyzer(statement) {
291 Ok(statement) => statement,
292 _ => return Ok(()),
293 };
294 if old_statement != statement {
295 write!(f, " rewritten_executor_sql={statement}")?;
296 }
297
298 let old_statement = statement.clone();
299 let statement = match apply_ast_analyzers(statement, ast_analyzers) {
300 Ok(statement) => statement,
301 _ => return Ok(()),
302 };
303 if old_statement != statement {
304 write!(f, " rewritten_ast_analyzer={statement}")?;
305 }
306
307 Ok(())
308 }
309}
310
311impl ExecutionPlan for VirtualExecutionPlan {
312 fn name(&self) -> &str {
313 "sql_federation_exec"
314 }
315
316 fn as_any(&self) -> &dyn Any {
317 self
318 }
319
320 fn schema(&self) -> SchemaRef {
321 self.schema()
322 }
323
324 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
325 vec![]
326 }
327
328 fn with_new_children(
329 self: Arc<Self>,
330 _: Vec<Arc<dyn ExecutionPlan>>,
331 ) -> Result<Arc<dyn ExecutionPlan>> {
332 Ok(self)
333 }
334
335 fn execute(
336 &self,
337 _partition: usize,
338 _context: Arc<TaskContext>,
339 ) -> Result<SendableRecordBatchStream> {
340 self.executor.execute(&self.final_sql()?, self.schema())
341 }
342
343 fn properties(&self) -> &PlanProperties {
344 &self.props
345 }
346}
347
#[cfg(test)]
mod tests {

    use std::collections::HashSet;
    use std::sync::Arc;

    use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
    use crate::FederatedTableProviderAdaptor;
    use async_trait::async_trait;
    use datafusion::arrow::datatypes::{Schema, SchemaRef};
    use datafusion::common::tree_node::TreeNodeRecursion;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::sql::unparser::dialect::Dialect;
    use datafusion::sql::unparser::{self};
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        datasource::TableProvider,
        execution::context::SessionContext,
    };

    use super::table::RemoteTable;
    use super::*;

    /// Minimal executor for planning-only tests.
    ///
    /// It reports a configurable compute context and the default unparser
    /// dialect. `execute`, `table_names` and `get_table_schema` are
    /// unimplemented: the tests only inspect the SQL that would be sent.
    #[derive(Debug, Clone)]
    struct TestExecutor {
        compute_context: String,
    }

    #[async_trait]
    impl SQLExecutor for TestExecutor {
        fn name(&self) -> &str {
            "TestExecutor"
        }

        fn compute_context(&self) -> Option<String> {
            Some(self.compute_context.clone())
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(unparser::dialect::DefaultDialect {})
        }

        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
            unimplemented!()
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            unimplemented!()
        }
    }

    /// Builds a federated table provider for remote table `name`.
    ///
    /// The schema is fixed (`a: Int64`, `b: Utf8`, `c: Date32`) and the table is
    /// served by `executor`.
    fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc<dyn TableProvider> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_ref = RemoteTableRef::try_from(name).unwrap();
        let table = Arc::new(RemoteTable::new(table_ref, schema));
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource { provider, table });
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    /// Tables on two different compute contexts must each be federated and
    /// produce one remote query per context/table.
    #[tokio::test]
    async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "a".into(),
        };

        let test_executor_b = TestExecutor {
            compute_context: "b".into(),
        };

        let table_a1_ref = "table_a1".to_string();
        let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());

        let table_a2_ref = "table_a2".to_string();
        let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a);

        let table_b1_ref = "table_b1(1)".to_string();
        let table_b1_df_ref = "table_local_b1".to_string();

        let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(table_a1_ref.clone(), table_a1).unwrap();
        ctx.register_table(table_a2_ref.clone(), table_a2).unwrap();
        ctx.register_table(table_b1_df_ref.clone(), table_b1)
            .unwrap();

        let query = r#"
            SELECT * FROM table_a1
            UNION ALL
            SELECT * FROM table_a2
            UNION ALL
            SELECT * FROM table_local_b1;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut table_a1_federated = false;
        let mut table_a2_federated = false;
        let mut table_b1_federated = false;

        // Propagate traversal errors instead of discarding them with `let _`,
        // so a failure reports its real cause.
        logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == table_a1_ref {
                                table_a1_federated = true;
                            }
                            if table.table_name.table() == table_a2_ref {
                                table_a2_federated = true;
                            }
                            if table.table_name.table() == table_b1_df_ref {
                                table_b1_federated = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    })?;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        assert!(table_a1_federated);
        assert!(table_a2_federated);
        assert!(table_b1_federated);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        // A `final_sql` error must fail the test here rather than being
        // swallowed and resurfacing as a confusing set mismatch below.
        physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        let expected = vec![
            "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1",
            "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2",
            "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1",
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }

    /// Two local names mapped to differently-cased remote tables on the same
    /// compute context must be federated into a single remote query that keeps
    /// the remote identifiers' quoting.
    #[tokio::test]
    async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "test".into(),
        };

        let lowercase_table_ref = "default.table".to_string();
        let lowercase_local_table_ref = "dftable".to_string();
        let lowercase_table =
            get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone());

        let capitalized_table_ref = "default.Table(1)".to_string();
        let capitalized_local_table_ref = "dfview".to_string();
        let capitalized_table =
            get_test_table_provider(capitalized_table_ref.clone(), test_executor_a);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table)
            .unwrap();
        ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table)
            .unwrap();

        let query = r#"
            SELECT * FROM dftable
            UNION ALL
            SELECT * FROM dfview;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut lowercase_table = false;
        let mut capitalized_table = false;

        logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == lowercase_local_table_ref {
                                lowercase_table = true;
                            }
                            if table.table_name.table() == capitalized_local_table_ref {
                                capitalized_table = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    })?;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        assert!(lowercase_table);
        assert!(capitalized_table);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        })?;

        let expected = vec![
            r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1) AS Table"#,
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }
}