1mod executor;
2mod schema;
3
4use std::{any::Any, collections::HashMap, fmt, sync::Arc, vec};
5
6use async_trait::async_trait;
7use datafusion::{
8 arrow::datatypes::{Schema, SchemaRef},
9 common::{tree_node::Transformed, Column},
10 error::Result,
11 execution::{context::SessionState, TaskContext},
12 logical_expr::{
13 expr::{
14 AggregateFunction, Alias, Exists, InList, InSubquery, PlannedReplaceSelectItem,
15 ScalarFunction, Sort, Unnest, WildcardOptions, WindowFunction,
16 },
17 Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, Limit, LogicalPlan,
18 Subquery, TryCast,
19 },
20 optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
21 physical_expr::EquivalenceProperties,
22 physical_plan::{
23 execution_plan::{Boundedness, EmissionType},
24 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
25 SendableRecordBatchStream,
26 },
27 sql::{
28 sqlparser::ast::Statement,
29 unparser::{plan_to_sql, Unparser},
30 TableReference,
31 },
32};
33
34pub use executor::{AstAnalyzer, SQLExecutor, SQLExecutorRef};
35pub use schema::{MultiSchemaProvider, SQLSchemaProvider, SQLTableSource};
36
37use crate::{
38 get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
39};
40
/// Federation provider that pushes DataFusion plans down to a remote SQL engine.
///
/// Owns the optimizer (with the single federation rule) and the executor that
/// unparses and runs SQL against the remote engine.
#[derive(Debug)]
pub struct SQLFederationProvider {
    // Optimizer holding the one rule that wraps plans in a federated node.
    optimizer: Arc<Optimizer>,
    // Executor used to unparse plans to SQL and run them remotely.
    executor: Arc<dyn SQLExecutor>,
}
50
51impl SQLFederationProvider {
52 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
53 Self {
54 optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
55 SQLFederationOptimizerRule::new(executor.clone()),
56 )])),
57 executor,
58 }
59 }
60}
61
62impl FederationProvider for SQLFederationProvider {
63 fn name(&self) -> &str {
64 "sql_federation_provider"
65 }
66
67 fn compute_context(&self) -> Option<String> {
68 self.executor.compute_context()
69 }
70
71 fn optimizer(&self) -> Option<Arc<Optimizer>> {
72 Some(self.optimizer.clone())
73 }
74}
75
/// Optimizer rule that wraps an entire logical plan in a `FederatedPlanNode`
/// extension so it can later be planned by the SQL federation planner.
#[derive(Debug)]
struct SQLFederationOptimizerRule {
    // Planner used to turn the wrapped logical plan into a physical plan.
    planner: Arc<dyn FederationPlanner>,
}
80
81impl SQLFederationOptimizerRule {
82 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
83 Self {
84 planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
85 }
86 }
87}
88
89impl OptimizerRule for SQLFederationOptimizerRule {
90 fn rewrite(
96 &self,
97 plan: LogicalPlan,
98 _config: &dyn OptimizerConfig,
99 ) -> Result<Transformed<LogicalPlan>> {
100 if let LogicalPlan::Extension(Extension { ref node }) = plan {
101 if node.name() == "Federated" {
102 return Ok(Transformed::no(plan));
104 }
105 }
106 let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
108 let ext_node = Extension {
109 node: Arc::new(fed_plan),
110 };
111 Ok(Transformed::yes(LogicalPlan::Extension(ext_node)))
112 }
113
114 fn name(&self) -> &str {
116 "federate_sql"
117 }
118
119 fn supports_rewrite(&self) -> bool {
121 true
122 }
123}
124
125fn rewrite_table_scans(
127 plan: &LogicalPlan,
128 known_rewrites: &mut HashMap<TableReference, TableReference>,
129) -> Result<LogicalPlan> {
130 if plan.inputs().is_empty() {
131 if let LogicalPlan::TableScan(table_scan) = plan {
132 let original_table_name = table_scan.table_name.clone();
133 let mut new_table_scan = table_scan.clone();
134
135 let Some(federated_source) = get_table_source(&table_scan.source)? else {
136 return Ok(plan.clone());
138 };
139
140 match federated_source.as_any().downcast_ref::<SQLTableSource>() {
141 Some(sql_table_source) => {
142 let remote_table_name = TableReference::from(sql_table_source.table_name());
143 known_rewrites.insert(original_table_name, remote_table_name.clone());
144
145 let new_schema = (*new_table_scan.projected_schema)
147 .clone()
148 .replace_qualifier(remote_table_name.clone());
149 new_table_scan.projected_schema = Arc::new(new_schema);
150 new_table_scan.table_name = remote_table_name;
151 }
152 None => {
153 return Ok(plan.clone());
155 }
156 }
157
158 return Ok(LogicalPlan::TableScan(new_table_scan));
159 } else {
160 return Ok(plan.clone());
161 }
162 }
163
164 let rewritten_inputs = plan
165 .inputs()
166 .into_iter()
167 .map(|i| rewrite_table_scans(i, known_rewrites))
168 .collect::<Result<Vec<_>>>()?;
169
170 if let LogicalPlan::Limit(limit) = plan {
171 let rewritten_skip = limit
172 .skip
173 .as_ref()
174 .map(|skip| rewrite_table_scans_in_expr(*skip.clone(), known_rewrites).map(Box::new))
175 .transpose()?;
176
177 let rewritten_fetch = limit
178 .fetch
179 .as_ref()
180 .map(|fetch| rewrite_table_scans_in_expr(*fetch.clone(), known_rewrites).map(Box::new))
181 .transpose()?;
182
183 let new_plan = LogicalPlan::Limit(Limit {
185 skip: rewritten_skip,
186 fetch: rewritten_fetch,
187 input: Arc::new(rewritten_inputs[0].clone()),
188 });
189
190 return Ok(new_plan);
191 }
192
193 let mut new_expressions = vec![];
194 for expression in plan.expressions() {
195 let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?;
196 new_expressions.push(new_expr);
197 }
198
199 let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?;
200
201 Ok(new_plan)
202}
203
/// Replaces whole-identifier occurrences of `table_ref_str` with `rewrite`
/// inside a flattened column name (e.g. `"foo.df_table.a"` from an alias),
/// starting the search at byte offset `start_pos`.
///
/// Returns `None` when no occurrence was rewritten. A match is skipped when it
/// is embedded in a larger identifier (adjacent alphanumeric, `_`, or a
/// leading `.`).
///
/// Note: `str::find` returns *byte* indices, so neighbor characters are
/// inspected via byte-boundary slicing (the previous `chars().nth(idx)`
/// approach misread indices on non-ASCII names and was O(n) per lookup).
fn rewrite_column_name_in_expr(
    col_name: &str,
    table_ref_str: &str,
    rewrite: &str,
    start_pos: usize,
) -> Option<String> {
    if start_pos >= col_name.len() {
        return None;
    }

    // Byte index of the next occurrence, relative to the whole string.
    let idx = start_pos + col_name[start_pos..].find(table_ref_str)?;
    let end = idx + table_ref_str.len();

    // Skip matches preceded by an identifier character or a qualifier dot
    // (e.g. `xfoo.df_table` or `a.foo.df_table` must not be rewritten).
    if idx > 0 {
        if let Some(prev_char) = col_name[..idx].chars().last() {
            if prev_char.is_alphanumeric() || prev_char == '_' || prev_char == '.' {
                return rewrite_column_name_in_expr(col_name, table_ref_str, rewrite, end);
            }
        }
    }

    // Skip matches followed by an identifier character (e.g. `foo.df_tablex`).
    if let Some(next_char) = col_name[end..].chars().next() {
        if next_char.is_alphanumeric() || next_char == '_' {
            return rewrite_column_name_in_expr(col_name, table_ref_str, rewrite, end);
        }
    }

    // Valid occurrence: splice in the rewrite, then continue after it so
    // further occurrences are also handled.
    let rewritten_name = format!("{}{}{}", &col_name[..idx], rewrite, &col_name[end..]);
    Some(
        rewrite_column_name_in_expr(&rewritten_name, table_ref_str, rewrite, idx + rewrite.len())
            .unwrap_or(rewritten_name),
    )
}
270
/// Recursively rewrites table references inside an expression using the
/// `local -> remote` mappings collected by `rewrite_table_scans`.
///
/// Walks every `Expr` variant, rewriting qualified `Column` relations,
/// subquery plans, and (for unqualified columns) table references that were
/// flattened into the column name itself.
fn rewrite_table_scans_in_expr(
    expr: Expr,
    known_rewrites: &mut HashMap<TableReference, TableReference>,
) -> Result<Expr> {
    match expr {
        Expr::ScalarSubquery(subquery) => {
            // The subquery body is a full plan: rewrite its table scans too.
            let new_subquery = rewrite_table_scans(&subquery.subquery, known_rewrites)?;
            let outer_ref_columns = subquery
                .outer_ref_columns
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            Ok(Expr::ScalarSubquery(Subquery {
                subquery: Arc::new(new_subquery),
                outer_ref_columns,
            }))
        }
        Expr::BinaryExpr(binary_expr) => {
            let left = rewrite_table_scans_in_expr(*binary_expr.left, known_rewrites)?;
            let right = rewrite_table_scans_in_expr(*binary_expr.right, known_rewrites)?;
            Ok(Expr::BinaryExpr(BinaryExpr::new(
                Box::new(left),
                binary_expr.op,
                Box::new(right),
            )))
        }
        Expr::Column(mut col) => {
            // Qualified column whose relation has a known rewrite: swap the qualifier.
            if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) {
                Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name)))
            } else {
                // Qualified but unknown relation (e.g. a subquery alias): leave as-is.
                if col.relation.is_some() {
                    return Ok(Expr::Column(col));
                }

                // Unqualified column: the table reference may be flattened into
                // the name itself (e.g. "foo.df_table.a"); try each known rewrite.
                let (new_name, was_rewritten) = known_rewrites.iter().fold(
                    (col.name.to_string(), false),
                    |(col_name, was_rewritten), (table_ref, rewrite)| {
                        match rewrite_column_name_in_expr(
                            &col_name,
                            &table_ref.to_string(),
                            &rewrite.to_string(),
                            0,
                        ) {
                            Some(new_name) => (new_name, true),
                            None => (col_name, was_rewritten),
                        }
                    },
                );
                if was_rewritten {
                    Ok(Expr::Column(Column::new(col.relation.take(), new_name)))
                } else {
                    Ok(Expr::Column(col))
                }
            }
        }
        Expr::Alias(alias) => {
            let expr = rewrite_table_scans_in_expr(*alias.expr, known_rewrites)?;
            // An alias may itself carry a relation qualifier that needs rewriting.
            if let Some(relation) = &alias.relation {
                if let Some(rewrite) = known_rewrites.get(relation) {
                    return Ok(Expr::Alias(Alias::new(
                        expr,
                        Some(rewrite.clone()),
                        alias.name,
                    )));
                }
            }
            Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name)))
        }
        Expr::Like(like) => {
            let expr = rewrite_table_scans_in_expr(*like.expr, known_rewrites)?;
            let pattern = rewrite_table_scans_in_expr(*like.pattern, known_rewrites)?;
            Ok(Expr::Like(Like::new(
                like.negated,
                Box::new(expr),
                Box::new(pattern),
                like.escape_char,
                like.case_insensitive,
            )))
        }
        Expr::SimilarTo(similar_to) => {
            let expr = rewrite_table_scans_in_expr(*similar_to.expr, known_rewrites)?;
            let pattern = rewrite_table_scans_in_expr(*similar_to.pattern, known_rewrites)?;
            Ok(Expr::SimilarTo(Like::new(
                similar_to.negated,
                Box::new(expr),
                Box::new(pattern),
                similar_to.escape_char,
                similar_to.case_insensitive,
            )))
        }
        Expr::Not(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::Not(Box::new(expr)))
        }
        Expr::IsNotNull(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNotNull(Box::new(expr)))
        }
        Expr::IsNull(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNull(Box::new(expr)))
        }
        Expr::IsTrue(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsTrue(Box::new(expr)))
        }
        Expr::IsFalse(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsFalse(Box::new(expr)))
        }
        Expr::IsUnknown(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsUnknown(Box::new(expr)))
        }
        Expr::IsNotTrue(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNotTrue(Box::new(expr)))
        }
        Expr::IsNotFalse(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNotFalse(Box::new(expr)))
        }
        Expr::IsNotUnknown(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNotUnknown(Box::new(expr)))
        }
        Expr::Negative(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::Negative(Box::new(expr)))
        }
        Expr::Between(between) => {
            let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?;
            let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?;
            let high = rewrite_table_scans_in_expr(*between.high, known_rewrites)?;
            Ok(Expr::Between(Between::new(
                Box::new(expr),
                between.negated,
                Box::new(low),
                Box::new(high),
            )))
        }
        Expr::Case(case) => {
            let expr = case
                .expr
                .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites))
                .transpose()?
                .map(Box::new);
            let else_expr = case
                .else_expr
                .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites))
                .transpose()?
                .map(Box::new);
            let when_expr = case
                .when_then_expr
                .into_iter()
                .map(|(when, then)| {
                    let when = rewrite_table_scans_in_expr(*when, known_rewrites);
                    let then = rewrite_table_scans_in_expr(*then, known_rewrites);

                    match (when, then) {
                        (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))),
                        (Err(e), _) | (_, Err(e)) => Err(e),
                    }
                })
                .collect::<Result<Vec<(Box<Expr>, Box<Expr>)>>>()?;
            Ok(Expr::Case(Case::new(expr, when_expr, else_expr)))
        }
        Expr::Cast(cast) => {
            let expr = rewrite_table_scans_in_expr(*cast.expr, known_rewrites)?;
            Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type)))
        }
        Expr::TryCast(try_cast) => {
            let expr = rewrite_table_scans_in_expr(*try_cast.expr, known_rewrites)?;
            Ok(Expr::TryCast(TryCast::new(
                Box::new(expr),
                try_cast.data_type,
            )))
        }
        Expr::ScalarFunction(sf) => {
            let args = sf
                .args
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            Ok(Expr::ScalarFunction(ScalarFunction {
                func: sf.func,
                args,
            }))
        }
        Expr::AggregateFunction(af) => {
            let args = af
                .args
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let filter = af
                .filter
                .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites))
                .transpose()?
                .map(Box::new);
            let order_by = af
                .order_by
                .map(|e| {
                    e.into_iter()
                        .map(|sort| {
                            Ok(Sort {
                                expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?,
                                ..sort
                            })
                        })
                        .collect::<Result<Vec<_>>>()
                })
                .transpose()?;
            Ok(Expr::AggregateFunction(AggregateFunction {
                func: af.func,
                args,
                distinct: af.distinct,
                filter,
                order_by,
                null_treatment: af.null_treatment,
            }))
        }
        Expr::WindowFunction(wf) => {
            let args = wf
                .args
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let partition_by = wf
                .partition_by
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let order_by = wf
                .order_by
                .into_iter()
                .map(|sort| {
                    Ok(Sort {
                        expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?,
                        ..sort
                    })
                })
                .collect::<Result<Vec<_>>>()?;
            Ok(Expr::WindowFunction(WindowFunction {
                fun: wf.fun,
                args,
                partition_by,
                order_by,
                window_frame: wf.window_frame,
                null_treatment: wf.null_treatment,
            }))
        }
        Expr::InList(il) => {
            let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?;
            let list = il
                .list
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated)))
        }
        Expr::Exists(exists) => {
            let subquery_plan = rewrite_table_scans(&exists.subquery.subquery, known_rewrites)?;
            let outer_ref_columns = exists
                .subquery
                .outer_ref_columns
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let subquery = Subquery {
                subquery: Arc::new(subquery_plan),
                outer_ref_columns,
            };
            Ok(Expr::Exists(Exists::new(subquery, exists.negated)))
        }
        Expr::InSubquery(is) => {
            let expr = rewrite_table_scans_in_expr(*is.expr, known_rewrites)?;
            let subquery_plan = rewrite_table_scans(&is.subquery.subquery, known_rewrites)?;
            let outer_ref_columns = is
                .subquery
                .outer_ref_columns
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let subquery = Subquery {
                subquery: Arc::new(subquery_plan),
                outer_ref_columns,
            };
            Ok(Expr::InSubquery(InSubquery::new(
                Box::new(expr),
                subquery,
                is.negated,
            )))
        }
        Expr::Wildcard { qualifier, options } => {
            // Rewrite any planned REPLACE expressions carried by the wildcard.
            let options = WildcardOptions {
                replace: options
                    .replace
                    .map(|replace| -> Result<PlannedReplaceSelectItem> {
                        Ok(PlannedReplaceSelectItem {
                            planned_expressions: replace
                                .planned_expressions
                                .into_iter()
                                .map(|expr| rewrite_table_scans_in_expr(expr, known_rewrites))
                                .collect::<Result<Vec<_>>>()?,
                            ..replace
                        })
                    })
                    .transpose()?,
                ..*options
            };
            if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) {
                Ok(Expr::Wildcard {
                    qualifier: Some(rewrite.clone()),
                    options: Box::new(options),
                })
            } else {
                Ok(Expr::Wildcard {
                    qualifier,
                    options: Box::new(options),
                })
            }
        }
        Expr::GroupingSet(gs) => match gs {
            GroupingSet::Rollup(exprs) => {
                let exprs = exprs
                    .into_iter()
                    .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                    .collect::<Result<Vec<Expr>>>()?;
                Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs)))
            }
            GroupingSet::Cube(exprs) => {
                let exprs = exprs
                    .into_iter()
                    .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                    .collect::<Result<Vec<Expr>>>()?;
                Ok(Expr::GroupingSet(GroupingSet::Cube(exprs)))
            }
            GroupingSet::GroupingSets(vec_exprs) => {
                let vec_exprs = vec_exprs
                    .into_iter()
                    .map(|exprs| {
                        exprs
                            .into_iter()
                            .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                            .collect::<Result<Vec<Expr>>>()
                    })
                    .collect::<Result<Vec<Vec<Expr>>>>()?;
                Ok(Expr::GroupingSet(GroupingSet::GroupingSets(vec_exprs)))
            }
        },
        Expr::OuterReferenceColumn(dt, col) => {
            if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) {
                Ok(Expr::OuterReferenceColumn(
                    dt,
                    Column::new(Some(rewrite.clone()), &col.name),
                ))
            } else {
                Ok(Expr::OuterReferenceColumn(dt, col))
            }
        }
        Expr::Unnest(unnest) => {
            let expr = rewrite_table_scans_in_expr(*unnest.expr, known_rewrites)?;
            Ok(Expr::Unnest(Unnest::new(expr)))
        }
        // Leaf expressions with no table references to rewrite.
        Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr),
    }
}
643
/// Plans `FederatedPlanNode`s into `VirtualExecutionPlan`s backed by a
/// SQL executor.
struct SQLFederationPlanner {
    // Executor the resulting physical plan delegates to.
    executor: Arc<dyn SQLExecutor>,
}

impl SQLFederationPlanner {
    /// Creates a planner bound to `executor`.
    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
        Self { executor }
    }
}
653
654#[async_trait]
655impl FederationPlanner for SQLFederationPlanner {
656 async fn plan_federation(
657 &self,
658 node: &FederatedPlanNode,
659 _session_state: &SessionState,
660 ) -> Result<Arc<dyn ExecutionPlan>> {
661 let schema = Arc::new(node.plan().schema().as_arrow().clone());
662 let input = Arc::new(VirtualExecutionPlan::new(
663 node.plan().clone(),
664 Arc::clone(&self.executor),
665 ));
666 let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
667 Ok(Arc::new(schema_cast_exec))
668 }
669}
670
/// Physical node that executes a logical plan by unparsing it to SQL and
/// delegating execution to a remote `SQLExecutor`.
#[derive(Debug, Clone)]
struct VirtualExecutionPlan {
    // Logical plan to unparse and push down.
    plan: LogicalPlan,
    // Remote engine that runs the unparsed SQL.
    executor: Arc<dyn SQLExecutor>,
    // Cached plan properties (single unknown partition, bounded).
    props: PlanProperties,
}
677
678impl VirtualExecutionPlan {
679 pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>) -> Self {
680 let schema: Schema = plan.schema().as_ref().into();
681 let props = PlanProperties::new(
682 EquivalenceProperties::new(Arc::new(schema)),
683 Partitioning::UnknownPartitioning(1),
684 EmissionType::Incremental,
685 Boundedness::Bounded,
686 );
687 Self {
688 plan,
689 executor,
690 props,
691 }
692 }
693
694 fn schema(&self) -> SchemaRef {
695 let df_schema = self.plan.schema().as_ref();
696 Arc::new(Schema::from(df_schema))
697 }
698
699 fn sql(&self) -> Result<String> {
700 let mut known_rewrites = HashMap::new();
702 let plan = &rewrite_table_scans(&self.plan, &mut known_rewrites)?;
703 let mut ast = self.plan_to_sql(plan)?;
704
705 if let Some(analyzer) = self.executor.ast_analyzer() {
706 ast = analyzer(ast)?;
707 }
708
709 Ok(format!("{ast}"))
710 }
711
712 fn plan_to_sql(&self, plan: &LogicalPlan) -> Result<Statement> {
713 Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
714 }
715}
716
717impl DisplayAs for VirtualExecutionPlan {
718 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
719 write!(f, "VirtualExecutionPlan")?;
720 let Ok(ast) = plan_to_sql(&self.plan) else {
721 return Ok(());
722 };
723 write!(f, " name={}", self.executor.name())?;
724 if let Some(ctx) = self.executor.compute_context() {
725 write!(f, " compute_context={ctx}")?;
726 };
727
728 write!(f, " sql={ast}")?;
729 if let Ok(query) = self.sql() {
730 write!(f, " rewritten_sql={query}")?;
731 };
732
733 write!(f, " sql={ast}")
734 }
735}
736
737impl ExecutionPlan for VirtualExecutionPlan {
738 fn name(&self) -> &str {
739 "sql_federation_exec"
740 }
741
742 fn as_any(&self) -> &dyn Any {
743 self
744 }
745
746 fn schema(&self) -> SchemaRef {
747 self.schema()
748 }
749
750 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
751 vec![]
752 }
753
754 fn with_new_children(
755 self: Arc<Self>,
756 _: Vec<Arc<dyn ExecutionPlan>>,
757 ) -> Result<Arc<dyn ExecutionPlan>> {
758 Ok(self)
759 }
760
761 fn execute(
762 &self,
763 _partition: usize,
764 _context: Arc<TaskContext>,
765 ) -> Result<SendableRecordBatchStream> {
766 let query = self.plan_to_sql(&self.plan)?.to_string();
767 self.executor.execute(query.as_str(), self.schema())
768 }
769
770 fn properties(&self) -> &PlanProperties {
771 &self.props
772 }
773}
774
#[cfg(test)]
mod tests {
    use crate::FederatedTableProviderAdaptor;
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        catalog::SchemaProvider,
        catalog_common::MemorySchemaProvider,
        common::Column,
        datasource::{DefaultTableSource, TableProvider},
        error::DataFusionError,
        execution::context::SessionContext,
        logical_expr::LogicalPlanBuilder,
        sql::{unparser::dialect::DefaultDialect, unparser::dialect::Dialect},
    };

    use super::*;

    // Stub executor: every execution path errors; the tests only exercise
    // plan rewriting and unparsing, never remote execution.
    struct TestSQLExecutor {}

    #[async_trait]
    impl SQLExecutor for TestSQLExecutor {
        fn name(&self) -> &str {
            "test_sql_table_source"
        }

        fn compute_context(&self) -> Option<String> {
            None
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(DefaultDialect {})
        }

        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
            Err(DataFusionError::NotImplemented(
                "execute not implemented".to_string(),
            ))
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            Err(DataFusionError::NotImplemented(
                "table inference not implemented".to_string(),
            ))
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            Err(DataFusionError::NotImplemented(
                "table inference not implemented".to_string(),
            ))
        }
    }

    // Federated provider for a three-column table whose remote name is
    // "remote_table" — local scans of it should be rewritten to that name.
    fn get_test_table_provider() -> Arc<dyn TableProvider> {
        let sql_federation_provider =
            Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {})));

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_source = Arc::new(
            SQLTableSource::new_with_schema(
                sql_federation_provider,
                "remote_table".to_string(),
                schema,
            )
            .expect("to have a valid SQLTableSource"),
        );
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    fn get_test_table_source() -> Arc<DefaultTableSource> {
        Arc::new(DefaultTableSource::new(get_test_table_provider()))
    }

    // Session with the federated table registered twice: as "foo.df_table"
    // (custom schema) and as "app_table" (default public schema).
    fn get_test_df_context() -> SessionContext {
        let ctx = SessionContext::new();
        let catalog = ctx
            .catalog("datafusion")
            .expect("default catalog is datafusion");
        let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc<dyn SchemaProvider>;
        catalog
            .register_schema("foo", Arc::clone(&foo_schema))
            .expect("to register schema");
        foo_schema
            .register_table("df_table".to_string(), get_test_table_provider())
            .expect("to register table");

        let public_schema = catalog
            .schema("public")
            .expect("public schema should exist");
        public_schema
            .register_table("app_table".to_string(), get_test_table_provider())
            .expect("to register table");

        ctx
    }

    // A hand-built scan+projection should unparse with remote table names.
    #[test]
    fn test_rewrite_table_scans_basic() -> Result<()> {
        let default_table_source = get_test_table_source();
        let plan =
            LogicalPlanBuilder::scan("foo.df_table", default_table_source, None)?.project(vec![
                Expr::Column(Column::from_qualified_name("foo.df_table.a")),
                Expr::Column(Column::from_qualified_name("foo.df_table.b")),
                Expr::Column(Column::from_qualified_name("foo.df_table.c")),
            ])?;

        let mut known_rewrites = HashMap::new();
        let rewritten_plan = rewrite_table_scans(&plan.build()?, &mut known_rewrites)?;

        println!("rewritten_plan: \n{:#?}", rewritten_plan);

        let unparsed_sql = plan_to_sql(&rewritten_plan)?;

        println!("unparsed_sql: \n{unparsed_sql}");

        assert_eq!(
            format!("{unparsed_sql}"),
            r#"SELECT remote_table.a, remote_table.b, remote_table.c FROM remote_table"#
        );

        Ok(())
    }

    fn init_tracing() {
        let subscriber = tracing_subscriber::FmtSubscriber::builder()
            .with_env_filter("debug")
            .with_ansi(true)
            .finish();
        let _ = tracing::subscriber::set_global_default(subscriber);
    }

    #[tokio::test]
    async fn test_rewrite_table_scans_agg() -> Result<()> {
        init_tracing();
        let ctx = get_test_df_context();

        // Pairs of (input SQL, expected rewritten SQL).
        let agg_tests = vec![
            (
                "SELECT MAX(a) FROM foo.df_table",
                r#"SELECT max(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT foo.df_table.a FROM foo.df_table",
                r#"SELECT remote_table.a FROM remote_table"#,
            ),
            (
                "SELECT MIN(a) FROM foo.df_table",
                r#"SELECT min(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT AVG(a) FROM foo.df_table",
                r#"SELECT avg(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT SUM(a) FROM foo.df_table",
                r#"SELECT sum(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT COUNT(a) FROM foo.df_table",
                r#"SELECT count(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT COUNT(a) as cnt FROM foo.df_table",
                r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#,
            ),
            (
                "SELECT COUNT(a) as cnt FROM foo.df_table",
                r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#,
            ),
            (
                "SELECT app_table from (SELECT a as app_table FROM app_table) b",
                r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#,
            ),
            (
                "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b",
                r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#,
            ),
            (
                "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table",
                r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#,
            ),
            (
                "SELECT COUNT(CASE WHEN appt.a > 0 THEN appt.a ELSE dft.a END) FROM app_table as appt, foo.df_table as dft",
                "SELECT count(CASE WHEN (appt.a > 0) THEN appt.a ELSE dft.a END) FROM remote_table AS appt CROSS JOIN remote_table AS dft"
            ),
        ];

        for test in agg_tests {
            test_sql(&ctx, test.0, test.1).await?;
        }

        Ok(())
    }

    // Column aliases that embed or resemble table names must be rewritten
    // (or left alone) correctly.
    #[tokio::test]
    async fn test_rewrite_table_scans_alias() -> Result<()> {
        init_tracing();
        let ctx = get_test_df_context();

        let tests = vec![
            (
                "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)",
                r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#,
            ),
            (
                "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)",
                r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#,
            ),
            (
                "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)",
                r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM remote_table)"#,
            ),
        ];

        for test in tests {
            test_sql(&ctx, test.0, test.1).await?;
        }

        Ok(())
    }

    // Plans `sql_query`, rewrites its table scans, unparses, and compares the
    // result against `expected_sql`.
    async fn test_sql(
        ctx: &SessionContext,
        sql_query: &str,
        expected_sql: &str,
    ) -> Result<(), datafusion::error::DataFusionError> {
        let data_frame = ctx.sql(sql_query).await?;

        println!("before optimization: \n{:#?}", data_frame.logical_plan());

        let mut known_rewrites = HashMap::new();
        let rewritten_plan = rewrite_table_scans(data_frame.logical_plan(), &mut known_rewrites)?;

        println!("rewritten_plan: \n{:#?}", rewritten_plan);

        let unparsed_sql = plan_to_sql(&rewritten_plan)?;

        println!("unparsed_sql: \n{unparsed_sql}");

        assert_eq!(
            format!("{unparsed_sql}"),
            expected_sql,
            "SQL under test: {}",
            sql_query
        );

        Ok(())
    }

    // LIMIT/OFFSET expressions live outside `with_new_exprs` and are handled
    // by the dedicated Limit branch in `rewrite_table_scans`.
    #[tokio::test]
    async fn test_rewrite_table_scans_limit_offset() -> Result<()> {
        init_tracing();
        let ctx = get_test_df_context();

        let tests = vec![
            (
                "SELECT a FROM foo.df_table LIMIT 5",
                r#"SELECT remote_table.a FROM remote_table LIMIT 5"#,
            ),
            (
                "SELECT a FROM foo.df_table OFFSET 5",
                r#"SELECT remote_table.a FROM remote_table OFFSET 5"#,
            ),
            (
                "SELECT a FROM foo.df_table LIMIT 10 OFFSET 5",
                r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#,
            ),
            (
                "SELECT a FROM foo.df_table OFFSET 5 LIMIT 10",
                r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#,
            ),
            (
                "SELECT a FROM foo.df_table OFFSET 0",
                r#"SELECT remote_table.a FROM remote_table OFFSET 0"#,
            ),
            (
                "SELECT a FROM foo.df_table LIMIT 0",
                r#"SELECT remote_table.a FROM remote_table LIMIT 0"#,
            ),
            (
                "SELECT a FROM foo.df_table LIMIT 0 OFFSET 0",
                r#"SELECT remote_table.a FROM remote_table LIMIT 0 OFFSET 0"#,
            ),
        ];

        for test in tests {
            test_sql(&ctx, test.0, test.1).await?;
        }

        Ok(())
    }
}