mod analyzer;
pub mod ast_analyzer;
mod executor;
mod schema;
mod table;
mod table_reference;

use std::{any::Any, fmt, sync::Arc, vec};

use analyzer::RewriteTableScanAnalyzer;
use async_trait::async_trait;
use datafusion::{
    arrow::datatypes::{Schema, SchemaRef},
    common::tree_node::{Transformed, TreeNode},
    common::Statistics,
    error::{DataFusionError, Result},
    execution::{context::SessionState, TaskContext},
    logical_expr::{Extension, LogicalPlan},
    optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
    physical_expr::EquivalenceProperties,
    physical_plan::{
        execution_plan::{Boundedness, EmissionType},
        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
        SendableRecordBatchStream,
    },
    sql::{sqlparser::ast::Statement, unparser::Unparser},
};

pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef};
pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
pub use table::{RemoteTable, SQLTable, SQLTableSource};
pub use table_reference::RemoteTableRef;

use crate::{
    get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
};
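/// A [`FederationProvider`] that federates query fragments to a remote SQL
/// engine through the wrapped [`SQLExecutor`].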
#[derive(Debug)]
pub struct SQLFederationProvider {
    pub optimizer: Arc<Optimizer>,
    pub executor: Arc<dyn SQLExecutor>,
}

impl SQLFederationProvider {
    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
        Self {
            optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
                SQLFederationOptimizerRule::new(executor.clone()),
            )])),
            executor,
        }
    }
}

impl FederationProvider for SQLFederationProvider {
    fn name(&self) -> &str {
        "sql_federation_provider"
    }

    fn compute_context(&self) -> Option<String> {
        self.executor.compute_context()
    }

    fn optimizer(&self) -> Option<Arc<Optimizer>> {
        Some(self.optimizer.clone())
    }
}
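/// Optimizer rule that wraps a federated logical plan in a [`FederatedPlanNode`]
/// extension so it can later be planned against the remote executor by
/// [`SQLFederationPlanner`].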
#[derive(Debug)]
struct SQLFederationOptimizerRule {
    planner: Arc<SQLFederationPlanner>,
}

impl SQLFederationOptimizerRule {
    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
        Self {
            planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
        }
    }
}

impl OptimizerRule for SQLFederationOptimizerRule {
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        if let LogicalPlan::Extension(Extension { ref node }) = plan {
            if node.name() == "Federated" {
                return Ok(Transformed::no(plan));
            }
        }

        let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
        let ext_node = Extension {
            node: Arc::new(fed_plan),
        };

        let mut plan = LogicalPlan::Extension(ext_node);
        if let Some(mut rewriter) = self.planner.executor.logical_optimizer() {
            plan = rewriter(plan)?;
        }

        Ok(Transformed::yes(plan))
    }

    fn name(&self) -> &str {
        "federate_sql"
    }

    fn supports_rewrite(&self) -> bool {
        true
    }
}
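/// Plans a [`FederatedPlanNode`] into a physical plan that executes the
/// federated query on the remote executor and casts the resulting record
/// batches back to the plan's Arrow schema.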
#[derive(Debug)]
pub struct SQLFederationPlanner {
    pub executor: Arc<dyn SQLExecutor>,
}

impl SQLFederationPlanner {
    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
        Self { executor }
    }
}

#[async_trait]
impl FederationPlanner for SQLFederationPlanner {
    async fn plan_federation(
        &self,
        node: &FederatedPlanNode,
        _session_state: &SessionState,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let schema = Arc::new(node.plan().schema().as_arrow().clone());
        let plan = node.plan().clone();
        let statistics = self.executor.statistics(&plan).await?;
        let input = Arc::new(VirtualExecutionPlan::new(
            plan,
            Arc::clone(&self.executor),
            statistics,
        ));
        let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
        Ok(Arc::new(schema_cast_exec))
    }
}
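/// Physical leaf node that unparses its logical plan back to SQL and delegates
/// execution of the resulting query string to the wrapped [`SQLExecutor`].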
#[derive(Debug, Clone)]
struct VirtualExecutionPlan {
    plan: LogicalPlan,
    executor: Arc<dyn SQLExecutor>,
    props: PlanProperties,
    statistics: Statistics,
}

impl VirtualExecutionPlan {
    pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>, statistics: Statistics) -> Self {
        let schema: Schema = plan.schema().as_ref().into();
        let props = PlanProperties::new(
            EquivalenceProperties::new(Arc::new(schema)),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        Self {
            plan,
            executor,
            props,
            statistics,
        }
    }

    fn schema(&self) -> SchemaRef {
        let df_schema = self.plan.schema().as_ref();
        Arc::new(Schema::from(df_schema))
    }
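    /// Builds the SQL string sent to the executor: table scans are rewritten to
    /// their remote references, per-table logical optimizers are applied, the
    /// plan is unparsed with the executor's dialect, and finally the executor's
    /// and per-table AST analyzers rewrite the resulting statement.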
    fn final_sql(&self) -> Result<String> {
        let plan = self.plan.clone();
        let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
        let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?;
        let plan = apply_logical_optimizers(plan, logical_optimizers)?;
        let ast = self.plan_to_statement(&plan)?;
        let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
        let ast = apply_ast_analyzers(ast, ast_analyzers)?;
        Ok(ast.to_string())
    }

    fn rewrite_with_executor_ast_analyzer(
        &self,
        ast: Statement,
    ) -> Result<Statement, datafusion::error::DataFusionError> {
        if let Some(mut analyzer) = self.executor.ast_analyzer() {
            Ok(analyzer(ast)?)
        } else {
            Ok(ast)
        }
    }

    fn plan_to_statement(&self, plan: &LogicalPlan) -> Result<Statement> {
        Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
    }
}
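/// Walks `plan` and collects the optional logical optimizers and AST analyzers
/// registered on each [`SQLTableSource`] found in its table scans.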
fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<AstAnalyzer>)> {
    let mut logical_optimizers = vec![];
    let mut ast_analyzers = vec![];

    plan.apply(|node| {
        if let LogicalPlan::TableScan(table) = node {
            let provider = get_table_source(&table.source)
                .expect("caller is virtual exec so this is valid")
                .expect("caller is virtual exec so this is valid");
            if let Some(source) = provider.as_any().downcast_ref::<SQLTableSource>() {
                if let Some(analyzer) = source.table.logical_optimizer() {
                    logical_optimizers.push(analyzer);
                }
                if let Some(analyzer) = source.table.ast_analyzer() {
                    ast_analyzers.push(analyzer);
                }
            }
        }
        Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
    })?;

    Ok((logical_optimizers, ast_analyzers))
}
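/// Applies the collected logical optimizers in order, erroring out if any of
/// them alters the plan's schema.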
fn apply_logical_optimizers(
    mut plan: LogicalPlan,
    analyzers: Vec<LogicalOptimizer>,
) -> Result<LogicalPlan> {
    for mut analyzer in analyzers {
        let old_schema = plan.schema().clone();
        plan = analyzer(plan)?;
        let new_schema = plan.schema();
        if &old_schema != new_schema {
            return Err(DataFusionError::Execution(format!(
                "Schema altered during logical analysis, expected: {}, found: {}",
                old_schema, new_schema
            )));
        }
    }
    Ok(plan)
}
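/// Applies the collected AST analyzers to the unparsed SQL statement in order.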
fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) -> Result<Statement> {
    for mut analyzer in analyzers {
        statement = analyzer(statement)?;
    }
    Ok(statement)
}
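// Displays the executor name, compute context, and the SQL produced at each
// rewrite stage; if any stage fails, the remaining detail is simply omitted.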
impl DisplayAs for VirtualExecutionPlan {
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
        write!(f, "VirtualExecutionPlan")?;
        write!(f, " name={}", self.executor.name())?;
        if let Some(ctx) = self.executor.compute_context() {
            write!(f, " compute_context={ctx}")?;
        };
        let mut plan = self.plan.clone();
        if let Ok(statement) = self.plan_to_statement(&plan) {
            write!(f, " initial_sql={statement}")?;
        }

        let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) {
            Ok(analyzers) => analyzers,
            Err(_) => return Ok(()),
        };

        let old_plan = plan.clone();

        plan = match apply_logical_optimizers(plan, logical_optimizers) {
            Ok(plan) => plan,
            _ => return Ok(()),
        };

        let statement = match self.plan_to_statement(&plan) {
            Ok(statement) => statement,
            _ => return Ok(()),
        };

        if plan != old_plan {
            write!(f, " rewritten_logical_sql={statement}")?;
        }

        let old_statement = statement.clone();
        let statement = match self.rewrite_with_executor_ast_analyzer(statement) {
            Ok(statement) => statement,
            _ => return Ok(()),
        };
        if old_statement != statement {
            write!(f, " rewritten_executor_sql={statement}")?;
        }

        let old_statement = statement.clone();
        let statement = match apply_ast_analyzers(statement, ast_analyzers) {
            Ok(statement) => statement,
            _ => return Ok(()),
        };
        if old_statement != statement {
            write!(f, " rewritten_ast_analyzer={statement}")?;
        }

        Ok(())
    }
}
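// Leaf `ExecutionPlan`: no children, a single unknown partition; `execute`
// sends the final SQL string to the executor and streams back the result.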
impl ExecutionPlan for VirtualExecutionPlan {
    fn name(&self) -> &str {
        "sql_federation_exec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        _: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }

    fn execute(
        &self,
        _partition: usize,
        _context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        self.executor.execute(&self.final_sql()?, self.schema())
    }

    fn properties(&self) -> &PlanProperties {
        &self.props
    }

    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
        Ok(self.statistics.clone())
    }
}

#[cfg(test)]
mod tests {

    use std::collections::HashSet;
    use std::sync::Arc;

    use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
    use crate::FederatedTableProviderAdaptor;
    use async_trait::async_trait;
    use datafusion::arrow::datatypes::{Schema, SchemaRef};
    use datafusion::common::tree_node::TreeNodeRecursion;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::sql::unparser::dialect::Dialect;
    use datafusion::sql::unparser::{self};
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        datasource::TableProvider,
        execution::context::SessionContext,
    };

    use super::table::RemoteTable;
    use super::*;
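    /// Minimal [`SQLExecutor`] for tests: only name, compute context, and
    /// dialect are implemented; execution and schema lookups are unimplemented.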
    #[derive(Debug, Clone)]
    struct TestExecutor {
        compute_context: String,
    }

    #[async_trait]
    impl SQLExecutor for TestExecutor {
        fn name(&self) -> &str {
            "TestExecutor"
        }

        fn compute_context(&self) -> Option<String> {
            Some(self.compute_context.clone())
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(unparser::dialect::DefaultDialect {})
        }

        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
            unimplemented!()
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            unimplemented!()
        }
    }
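    /// Builds a federated `TableProvider` over a [`RemoteTable`] with a fixed
    /// three-column schema, backed by the given test executor.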
    fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc<dyn TableProvider> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_ref = RemoteTableRef::try_from(name).unwrap();
        let table = Arc::new(RemoteTable::new(table_ref, schema));
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
        let table_source = Arc::new(SQLTableSource { provider, table });
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }
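    // Registers three remote tables (two on executor "a", one on executor "b"
    // under a different local name), checks that every scan ends up inside a
    // `FederatedPlanNode`, and compares the final SQL sent to each executor.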
    #[tokio::test]
    async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "a".into(),
        };

        let test_executor_b = TestExecutor {
            compute_context: "b".into(),
        };

        let table_a1_ref = "table_a1".to_string();
        let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());

        let table_a2_ref = "table_a2".to_string();
        let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a);

        let table_b1_ref = "table_b1(1)".to_string();
        let table_b1_df_ref = "table_local_b1".to_string();

        let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(table_a1_ref.clone(), table_a1).unwrap();
        ctx.register_table(table_a2_ref.clone(), table_a2).unwrap();
        ctx.register_table(table_b1_df_ref.clone(), table_b1)
            .unwrap();

        let query = r#"
            SELECT * FROM table_a1
            UNION ALL
            SELECT * FROM table_a2
            UNION ALL
            SELECT * FROM table_local_b1;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut table_a1_federated = false;
        let mut table_a2_federated = false;
        let mut table_b1_federated = false;

        let _ = logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    let _ = node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == table_a1_ref {
                                table_a1_federated = true;
                            }
                            if table.table_name.table() == table_a2_ref {
                                table_a2_federated = true;
                            }
                            if table.table_name.table() == table_b1_df_ref {
                                table_b1_federated = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    });
                }
            }
            Ok(TreeNodeRecursion::Continue)
        });

        assert!(table_a1_federated);
        assert!(table_a2_federated);
        assert!(table_b1_federated);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        let _ = physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        });

        let expected = vec![
            "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1",
            "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2",
            "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1",
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }
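    // Same check as above, but the local registration names differ from the
    // remote references, which require quoting ("default"."Table") and carry
    // table-function arguments; both tables share one executor, so the whole
    // UNION is unparsed as a single remote query.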
    #[tokio::test]
    async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
        let test_executor_a = TestExecutor {
            compute_context: "test".into(),
        };

        let lowercase_table_ref = "default.table".to_string();
        let lowercase_local_table_ref = "dftable".to_string();
        let lowercase_table =
            get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone());

        let capitalized_table_ref = "default.Table(1)".to_string();
        let capitalized_local_table_ref = "dfview".to_string();
        let capitalized_table =
            get_test_table_provider(capitalized_table_ref.clone(), test_executor_a);

        let state = crate::default_session_state();
        let ctx = SessionContext::new_with_state(state);

        ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table)
            .unwrap();
        ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table)
            .unwrap();

        let query = r#"
            SELECT * FROM dftable
            UNION ALL
            SELECT * FROM dfview;
        "#;

        let df = ctx.sql(query).await?;

        let logical_plan = df.into_optimized_plan()?;

        let mut lowercase_table = false;
        let mut capitalized_table = false;

        let _ = logical_plan.apply(|node| {
            if let LogicalPlan::Extension(node) = node {
                if let Some(node) = node.node.as_any().downcast_ref::<FederatedPlanNode>() {
                    let _ = node.plan().apply(|node| {
                        if let LogicalPlan::TableScan(table) = node {
                            if table.table_name.table() == lowercase_local_table_ref {
                                lowercase_table = true;
                            }
                            if table.table_name.table() == capitalized_local_table_ref {
                                capitalized_table = true;
                            }
                        }
                        Ok(TreeNodeRecursion::Continue)
                    });
                }
            }
            Ok(TreeNodeRecursion::Continue)
        });

        assert!(lowercase_table);
        assert!(capitalized_table);

        let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;

        let mut final_queries = vec![];

        let _ = physical_plan.apply(|node| {
            if node.name() == "sql_federation_exec" {
                let node = node
                    .as_any()
                    .downcast_ref::<VirtualExecutionPlan>()
                    .unwrap();

                final_queries.push(node.final_sql()?);
            }
            Ok(TreeNodeRecursion::Continue)
        });

        let expected = vec![
            r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1) AS Table"#,
        ];

        assert_eq!(
            HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())),
            HashSet::from_iter(expected)
        );

        Ok(())
    }
}