1mod executor;
2mod schema;
3
4use std::{any::Any, collections::HashMap, fmt, sync::Arc, vec};
5
6use async_trait::async_trait;
7use datafusion::{
8 arrow::datatypes::{Schema, SchemaRef},
9 common::{tree_node::Transformed, Column},
10 error::Result,
11 execution::{context::SessionState, TaskContext},
12 logical_expr::{
13 expr::{
14 AggregateFunction, AggregateFunctionParams, Alias, Exists, InList, InSubquery,
15 PlannedReplaceSelectItem, ScalarFunction, Sort, Unnest, WildcardOptions,
16 WindowFunction, WindowFunctionParams,
17 },
18 Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, Limit, LogicalPlan,
19 Subquery, TryCast,
20 },
21 optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
22 physical_expr::EquivalenceProperties,
23 physical_plan::{
24 execution_plan::{Boundedness, EmissionType},
25 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
26 SendableRecordBatchStream,
27 },
28 sql::{
29 sqlparser::ast::Statement,
30 unparser::{plan_to_sql, Unparser},
31 TableReference,
32 },
33};
34
35pub use executor::{AstAnalyzer, SQLExecutor, SQLExecutorRef};
36pub use schema::{MultiSchemaProvider, SQLSchemaProvider, SQLTableSource};
37
38use crate::{
39 get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
40};
41
/// Federation provider backed by a [`SQLExecutor`].
///
/// Holds the executor together with an [`Optimizer`] that is built in
/// [`SQLFederationProvider::new`] and contains a single
/// `SQLFederationOptimizerRule` bound to that same executor.
#[derive(Debug)]
pub struct SQLFederationProvider {
    /// Optimizer handed out through `FederationProvider::optimizer`.
    optimizer: Arc<Optimizer>,
    /// Executor used to answer `compute_context` and shared with the optimizer rule.
    executor: Arc<dyn SQLExecutor>,
}
51
52impl SQLFederationProvider {
53 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
54 Self {
55 optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
56 SQLFederationOptimizerRule::new(executor.clone()),
57 )])),
58 executor,
59 }
60 }
61}
62
63impl FederationProvider for SQLFederationProvider {
64 fn name(&self) -> &str {
65 "sql_federation_provider"
66 }
67
68 fn compute_context(&self) -> Option<String> {
69 self.executor.compute_context()
70 }
71
72 fn optimizer(&self) -> Option<Arc<Optimizer>> {
73 Some(self.optimizer.clone())
74 }
75}
76
/// Optimizer rule that wraps a logical plan into a `FederatedPlanNode`
/// extension node (see the `OptimizerRule` implementation in this file).
#[derive(Debug)]
struct SQLFederationOptimizerRule {
    /// Planner attached to every `FederatedPlanNode` this rule creates.
    planner: Arc<dyn FederationPlanner>,
}
81
82impl SQLFederationOptimizerRule {
83 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
84 Self {
85 planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
86 }
87 }
88}
89
90impl OptimizerRule for SQLFederationOptimizerRule {
91 fn rewrite(
97 &self,
98 plan: LogicalPlan,
99 _config: &dyn OptimizerConfig,
100 ) -> Result<Transformed<LogicalPlan>> {
101 if let LogicalPlan::Extension(Extension { ref node }) = plan {
102 if node.name() == "Federated" {
103 return Ok(Transformed::no(plan));
105 }
106 }
107 let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
109 let ext_node = Extension {
110 node: Arc::new(fed_plan),
111 };
112 Ok(Transformed::yes(LogicalPlan::Extension(ext_node)))
113 }
114
115 fn name(&self) -> &str {
117 "federate_sql"
118 }
119
120 fn supports_rewrite(&self) -> bool {
122 true
123 }
124}
125
126fn rewrite_table_scans(
128 plan: &LogicalPlan,
129 known_rewrites: &mut HashMap<TableReference, TableReference>,
130) -> Result<LogicalPlan> {
131 if plan.inputs().is_empty() {
132 if let LogicalPlan::TableScan(table_scan) = plan {
133 let original_table_name = table_scan.table_name.clone();
134 let mut new_table_scan = table_scan.clone();
135
136 let Some(federated_source) = get_table_source(&table_scan.source)? else {
137 return Ok(plan.clone());
139 };
140
141 match federated_source.as_any().downcast_ref::<SQLTableSource>() {
142 Some(sql_table_source) => {
143 let remote_table_name = TableReference::from(sql_table_source.table_name());
144 known_rewrites.insert(original_table_name, remote_table_name.clone());
145
146 let new_schema = (*new_table_scan.projected_schema)
148 .clone()
149 .replace_qualifier(remote_table_name.clone());
150 new_table_scan.projected_schema = Arc::new(new_schema);
151 new_table_scan.table_name = remote_table_name;
152 }
153 None => {
154 return Ok(plan.clone());
156 }
157 }
158
159 return Ok(LogicalPlan::TableScan(new_table_scan));
160 } else {
161 return Ok(plan.clone());
162 }
163 }
164
165 let rewritten_inputs = plan
166 .inputs()
167 .into_iter()
168 .map(|i| rewrite_table_scans(i, known_rewrites))
169 .collect::<Result<Vec<_>>>()?;
170
171 if let LogicalPlan::Limit(limit) = plan {
172 let rewritten_skip = limit
173 .skip
174 .as_ref()
175 .map(|skip| rewrite_table_scans_in_expr(*skip.clone(), known_rewrites).map(Box::new))
176 .transpose()?;
177
178 let rewritten_fetch = limit
179 .fetch
180 .as_ref()
181 .map(|fetch| rewrite_table_scans_in_expr(*fetch.clone(), known_rewrites).map(Box::new))
182 .transpose()?;
183
184 let new_plan = LogicalPlan::Limit(Limit {
186 skip: rewritten_skip,
187 fetch: rewritten_fetch,
188 input: Arc::new(rewritten_inputs[0].clone()),
189 });
190
191 return Ok(new_plan);
192 }
193
194 let mut new_expressions = vec![];
195 for expression in plan.expressions() {
196 let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?;
197 new_expressions.push(new_expr);
198 }
199
200 let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?;
201
202 Ok(new_plan)
203}
204
/// Replaces stand-alone occurrences of `table_ref_str` inside `col_name`
/// with `rewrite`, searching from byte offset `start_pos`.
///
/// An occurrence is considered part of a longer identifier, and is skipped,
/// when it is immediately preceded by an alphanumeric character, `_` or `.`,
/// or immediately followed by an alphanumeric character or `_`.
///
/// After a replacement the search continues in the rewritten string, right
/// after the inserted text, so every stand-alone occurrence is replaced and
/// text introduced by `rewrite` is never re-scanned.
///
/// Returns `Some(new_name)` when at least one occurrence was replaced and
/// `None` otherwise. `None` is also returned when `table_ref_str` is empty or
/// when `start_pos` is past the end of `col_name` or not on a character
/// boundary.
fn rewrite_column_name_in_expr(
    col_name: &str,
    table_ref_str: &str,
    rewrite: &str,
    start_pos: usize,
) -> Option<String> {
    /// Characters that may continue an identifier on either side of a match.
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    // An empty needle matches at every position and would never make
    // progress; there is nothing meaningful to rewrite.
    if table_ref_str.is_empty() {
        return None;
    }

    let mut rewritten: Option<String> = None;
    let mut search_from = start_pos;

    loop {
        // Search the latest version of the name (the original until the
        // first replacement happens).
        let haystack = rewritten.as_deref().unwrap_or(col_name);
        let Some(offset) = haystack
            .get(search_from..)
            .and_then(|tail| tail.find(table_ref_str))
        else {
            break;
        };

        // `find` yields byte offsets, so the neighbouring characters are
        // looked up by slicing at those offsets rather than by `chars().nth`,
        // which counts characters and is wrong for multi-byte UTF-8 text.
        let match_start = search_from + offset;
        let match_end = match_start + table_ref_str.len();

        let preceded_by_ident = haystack[..match_start]
            .chars()
            .next_back()
            .is_some_and(|c| is_ident_char(c) || c == '.');
        let followed_by_ident = haystack[match_end..]
            .chars()
            .next()
            .is_some_and(is_ident_char);

        if preceded_by_ident || followed_by_ident {
            // Part of another name: leave it alone and look further right.
            search_from = match_end;
            continue;
        }

        let new_name = format!(
            "{}{}{}",
            &haystack[..match_start],
            rewrite,
            &haystack[match_end..]
        );
        search_from = match_start + rewrite.len();
        rewritten = Some(new_name);
    }

    rewritten
}
271
/// Rewrites table references inside `expr` according to `known_rewrites`.
///
/// The expression tree is rebuilt variant by variant:
/// * qualified columns, aliases, wildcards and outer-reference columns get
///   their relation/qualifier replaced when it is a key of `known_rewrites`;
/// * unqualified columns have their *name* rewritten textually through
///   `rewrite_column_name_in_expr`;
/// * subquery plans are rewritten with `rewrite_table_scans`, which may add
///   further entries to `known_rewrites`;
/// * every other composite variant is rebuilt from its recursively rewritten
///   children, and leaf variants are returned unchanged.
///
/// # Errors
///
/// Returns the first error produced while rewriting a child expression or a
/// subquery plan.
fn rewrite_table_scans_in_expr(
    expr: Expr,
    known_rewrites: &mut HashMap<TableReference, TableReference>,
) -> Result<Expr> {
    match expr {
        // Subquery: rewrite the nested plan, then its outer reference columns.
        Expr::ScalarSubquery(subquery) => {
            let new_subquery = rewrite_table_scans(&subquery.subquery, known_rewrites)?;
            let outer_ref_columns = subquery
                .outer_ref_columns
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            Ok(Expr::ScalarSubquery(Subquery {
                subquery: Arc::new(new_subquery),
                outer_ref_columns,
            }))
        }
        Expr::BinaryExpr(binary_expr) => {
            let left = rewrite_table_scans_in_expr(*binary_expr.left, known_rewrites)?;
            let right = rewrite_table_scans_in_expr(*binary_expr.right, known_rewrites)?;
            Ok(Expr::BinaryExpr(BinaryExpr::new(
                Box::new(left),
                binary_expr.op,
                Box::new(right),
            )))
        }
        Expr::Column(mut col) => {
            // Qualified column whose relation has a known rewrite: swap the relation.
            if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) {
                Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name)))
            } else {
                // Qualified by a relation we know nothing about: leave untouched.
                if col.relation.is_some() {
                    return Ok(Expr::Column(col));
                }

                // Unqualified column: the name itself may contain a table
                // reference, so try a textual rewrite for every known pair.
                // NOTE(review): `HashMap` iteration order is unspecified, so
                // the order in which the rewrites are applied is too.
                let (new_name, was_rewritten) = known_rewrites.iter().fold(
                    (col.name.to_string(), false),
                    |(col_name, was_rewritten), (table_ref, rewrite)| {
                        match rewrite_column_name_in_expr(
                            &col_name,
                            &table_ref.to_string(),
                            &rewrite.to_string(),
                            0,
                        ) {
                            Some(new_name) => (new_name, true),
                            None => (col_name, was_rewritten),
                        }
                    },
                );
                if was_rewritten {
                    // `col.relation` is `None` here (the `Some` case returned above).
                    Ok(Expr::Column(Column::new(col.relation.take(), new_name)))
                } else {
                    Ok(Expr::Column(col))
                }
            }
        }
        Expr::Alias(alias) => {
            let expr = rewrite_table_scans_in_expr(*alias.expr, known_rewrites)?;
            // Replace the alias relation only when a rewrite is known for it.
            if let Some(relation) = &alias.relation {
                if let Some(rewrite) = known_rewrites.get(relation) {
                    return Ok(Expr::Alias(Alias::new(
                        expr,
                        Some(rewrite.clone()),
                        alias.name,
                    )));
                }
            }
            Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name)))
        }
        Expr::Like(like) => {
            let expr = rewrite_table_scans_in_expr(*like.expr, known_rewrites)?;
            let pattern = rewrite_table_scans_in_expr(*like.pattern, known_rewrites)?;
            Ok(Expr::Like(Like::new(
                like.negated,
                Box::new(expr),
                Box::new(pattern),
                like.escape_char,
                like.case_insensitive,
            )))
        }
        Expr::SimilarTo(similar_to) => {
            let expr = rewrite_table_scans_in_expr(*similar_to.expr, known_rewrites)?;
            let pattern = rewrite_table_scans_in_expr(*similar_to.pattern, known_rewrites)?;
            Ok(Expr::SimilarTo(Like::new(
                similar_to.negated,
                Box::new(expr),
                Box::new(pattern),
                similar_to.escape_char,
                similar_to.case_insensitive,
            )))
        }
        // Unary wrappers: rewrite the inner expression and re-wrap it.
        Expr::Not(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::Not(Box::new(expr)))
        }
        Expr::IsNotNull(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNotNull(Box::new(expr)))
        }
        Expr::IsNull(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNull(Box::new(expr)))
        }
        Expr::IsTrue(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsTrue(Box::new(expr)))
        }
        Expr::IsFalse(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsFalse(Box::new(expr)))
        }
        Expr::IsUnknown(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsUnknown(Box::new(expr)))
        }
        Expr::IsNotTrue(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNotTrue(Box::new(expr)))
        }
        Expr::IsNotFalse(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNotFalse(Box::new(expr)))
        }
        Expr::IsNotUnknown(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNotUnknown(Box::new(expr)))
        }
        Expr::Negative(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::Negative(Box::new(expr)))
        }
        Expr::Between(between) => {
            let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?;
            let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?;
            let high = rewrite_table_scans_in_expr(*between.high, known_rewrites)?;
            Ok(Expr::Between(Between::new(
                Box::new(expr),
                between.negated,
                Box::new(low),
                Box::new(high),
            )))
        }
        Expr::Case(case) => {
            // Optional base expression and ELSE branch, then each WHEN/THEN pair.
            let expr = case
                .expr
                .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites))
                .transpose()?
                .map(Box::new);
            let else_expr = case
                .else_expr
                .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites))
                .transpose()?
                .map(Box::new);
            let when_expr = case
                .when_then_expr
                .into_iter()
                .map(|(when, then)| {
                    let when = rewrite_table_scans_in_expr(*when, known_rewrites);
                    let then = rewrite_table_scans_in_expr(*then, known_rewrites);

                    // Both sides are evaluated before errors are inspected;
                    // an error from `when` takes precedence.
                    match (when, then) {
                        (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))),
                        (Err(e), _) | (_, Err(e)) => Err(e),
                    }
                })
                .collect::<Result<Vec<(Box<Expr>, Box<Expr>)>>>()?;
            Ok(Expr::Case(Case::new(expr, when_expr, else_expr)))
        }
        Expr::Cast(cast) => {
            let expr = rewrite_table_scans_in_expr(*cast.expr, known_rewrites)?;
            Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type)))
        }
        Expr::TryCast(try_cast) => {
            let expr = rewrite_table_scans_in_expr(*try_cast.expr, known_rewrites)?;
            Ok(Expr::TryCast(TryCast::new(
                Box::new(expr),
                try_cast.data_type,
            )))
        }
        Expr::ScalarFunction(sf) => {
            let args = sf
                .args
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            Ok(Expr::ScalarFunction(ScalarFunction {
                func: sf.func,
                args,
            }))
        }
        Expr::AggregateFunction(af) => {
            // Arguments, FILTER clause and ORDER BY expressions are rewritten;
            // the remaining parameters are carried over as they are.
            let args = af
                .params
                .args
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let filter = af
                .params
                .filter
                .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites))
                .transpose()?
                .map(Box::new);
            let order_by = af
                .params
                .order_by
                .map(|e| {
                    e.into_iter()
                        .map(|sort| {
                            Ok(Sort {
                                expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?,
                                ..sort
                            })
                        })
                        .collect::<Result<Vec<_>>>()
                })
                .transpose()?;
            let params = AggregateFunctionParams {
                args,
                distinct: af.params.distinct,
                filter,
                order_by,
                null_treatment: af.params.null_treatment,
            };
            Ok(Expr::AggregateFunction(AggregateFunction {
                func: af.func,
                params,
            }))
        }
        Expr::WindowFunction(wf) => {
            // Arguments, PARTITION BY and ORDER BY are rewritten; the window
            // frame and null treatment are carried over as they are.
            let args = wf
                .params
                .args
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let partition_by = wf
                .params
                .partition_by
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let order_by = wf
                .params
                .order_by
                .into_iter()
                .map(|sort| {
                    Ok(Sort {
                        expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?,
                        ..sort
                    })
                })
                .collect::<Result<Vec<_>>>()?;
            let params = WindowFunctionParams {
                args,
                partition_by,
                order_by,
                window_frame: wf.params.window_frame,
                null_treatment: wf.params.null_treatment,
            };
            Ok(Expr::WindowFunction(WindowFunction {
                fun: wf.fun,
                params,
            }))
        }
        Expr::InList(il) => {
            let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?;
            let list = il
                .list
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated)))
        }
        Expr::Exists(exists) => {
            let subquery_plan = rewrite_table_scans(&exists.subquery.subquery, known_rewrites)?;
            let outer_ref_columns = exists
                .subquery
                .outer_ref_columns
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let subquery = Subquery {
                subquery: Arc::new(subquery_plan),
                outer_ref_columns,
            };
            Ok(Expr::Exists(Exists::new(subquery, exists.negated)))
        }
        Expr::InSubquery(is) => {
            let expr = rewrite_table_scans_in_expr(*is.expr, known_rewrites)?;
            let subquery_plan = rewrite_table_scans(&is.subquery.subquery, known_rewrites)?;
            let outer_ref_columns = is
                .subquery
                .outer_ref_columns
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let subquery = Subquery {
                subquery: Arc::new(subquery_plan),
                outer_ref_columns,
            };
            Ok(Expr::InSubquery(InSubquery::new(
                Box::new(expr),
                subquery,
                is.negated,
            )))
        }
        // The variant is matched under `#[expect(deprecated)]`; it is still
        // handled here so the match stays exhaustive.
        #[expect(deprecated)]
        Expr::Wildcard { qualifier, options } => {
            // Rewrite the planned REPLACE expressions, keep every other option.
            let options = WildcardOptions {
                replace: options
                    .replace
                    .map(|replace| -> Result<PlannedReplaceSelectItem> {
                        Ok(PlannedReplaceSelectItem {
                            planned_expressions: replace
                                .planned_expressions
                                .into_iter()
                                .map(|expr| rewrite_table_scans_in_expr(expr, known_rewrites))
                                .collect::<Result<Vec<_>>>()?,
                            ..replace
                        })
                    })
                    .transpose()?,
                ..*options
            };
            // Swap the qualifier only when a rewrite is known for it.
            if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) {
                Ok(Expr::Wildcard {
                    qualifier: Some(rewrite.clone()),
                    options: Box::new(options),
                })
            } else {
                Ok(Expr::Wildcard {
                    qualifier,
                    options: Box::new(options),
                })
            }
        }
        Expr::GroupingSet(gs) => match gs {
            GroupingSet::Rollup(exprs) => {
                let exprs = exprs
                    .into_iter()
                    .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                    .collect::<Result<Vec<Expr>>>()?;
                Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs)))
            }
            GroupingSet::Cube(exprs) => {
                let exprs = exprs
                    .into_iter()
                    .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                    .collect::<Result<Vec<Expr>>>()?;
                Ok(Expr::GroupingSet(GroupingSet::Cube(exprs)))
            }
            GroupingSet::GroupingSets(vec_exprs) => {
                let vec_exprs = vec_exprs
                    .into_iter()
                    .map(|exprs| {
                        exprs
                            .into_iter()
                            .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                            .collect::<Result<Vec<Expr>>>()
                    })
                    .collect::<Result<Vec<Vec<Expr>>>>()?;
                Ok(Expr::GroupingSet(GroupingSet::GroupingSets(vec_exprs)))
            }
        },
        Expr::OuterReferenceColumn(dt, col) => {
            // Only the relation is rewritten; unlike `Expr::Column`, no
            // textual rewrite of the column name is attempted here.
            if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) {
                Ok(Expr::OuterReferenceColumn(
                    dt,
                    Column::new(Some(rewrite.clone()), &col.name),
                ))
            } else {
                Ok(Expr::OuterReferenceColumn(dt, col))
            }
        }
        Expr::Unnest(unnest) => {
            let expr = rewrite_table_scans_in_expr(*unnest.expr, known_rewrites)?;
            Ok(Expr::Unnest(Unnest::new(expr)))
        }
        // Leaf expressions are returned unchanged.
        Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr),
    }
}
658
/// `FederationPlanner` implementation that turns a federated plan node into
/// an execution plan driven by a [`SQLExecutor`].
struct SQLFederationPlanner {
    /// Executor handed to every `VirtualExecutionPlan` this planner creates.
    executor: Arc<dyn SQLExecutor>,
}
662
663impl SQLFederationPlanner {
664 pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
665 Self { executor }
666 }
667}
668
669#[async_trait]
670impl FederationPlanner for SQLFederationPlanner {
671 async fn plan_federation(
672 &self,
673 node: &FederatedPlanNode,
674 _session_state: &SessionState,
675 ) -> Result<Arc<dyn ExecutionPlan>> {
676 let schema = Arc::new(node.plan().schema().as_arrow().clone());
677 let input = Arc::new(VirtualExecutionPlan::new(
678 node.plan().clone(),
679 Arc::clone(&self.executor),
680 ));
681 let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
682 Ok(Arc::new(schema_cast_exec))
683 }
684}
685
/// Leaf execution plan that carries a logical plan and runs it by unparsing
/// it to SQL and passing the query to a [`SQLExecutor`].
#[derive(Debug, Clone)]
struct VirtualExecutionPlan {
    /// Logical plan this node stands for.
    plan: LogicalPlan,
    /// Executor that provides the SQL dialect and runs the query.
    executor: Arc<dyn SQLExecutor>,
    /// Plan properties computed once in `new` and returned by `properties`.
    props: PlanProperties,
}
692
693impl VirtualExecutionPlan {
694 pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>) -> Self {
695 let schema: Schema = plan.schema().as_ref().into();
696 let props = PlanProperties::new(
697 EquivalenceProperties::new(Arc::new(schema)),
698 Partitioning::UnknownPartitioning(1),
699 EmissionType::Incremental,
700 Boundedness::Bounded,
701 );
702 Self {
703 plan,
704 executor,
705 props,
706 }
707 }
708
709 fn schema(&self) -> SchemaRef {
710 let df_schema = self.plan.schema().as_ref();
711 Arc::new(Schema::from(df_schema))
712 }
713
714 fn sql(&self) -> Result<String> {
715 let mut known_rewrites = HashMap::new();
717 let plan = &rewrite_table_scans(&self.plan, &mut known_rewrites)?;
718 let mut ast = self.plan_to_sql(plan)?;
719
720 if let Some(analyzer) = self.executor.ast_analyzer() {
721 ast = analyzer(ast)?;
722 }
723
724 Ok(format!("{ast}"))
725 }
726
727 fn plan_to_sql(&self, plan: &LogicalPlan) -> Result<Statement> {
728 Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
729 }
730}
731
732impl DisplayAs for VirtualExecutionPlan {
733 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
734 write!(f, "VirtualExecutionPlan")?;
735 let Ok(ast) = plan_to_sql(&self.plan) else {
736 return Ok(());
737 };
738 write!(f, " name={}", self.executor.name())?;
739 if let Some(ctx) = self.executor.compute_context() {
740 write!(f, " compute_context={ctx}")?;
741 };
742
743 write!(f, " sql={ast}")?;
744 if let Ok(query) = self.sql() {
745 write!(f, " rewritten_sql={query}")?;
746 };
747
748 write!(f, " sql={ast}")
749 }
750}
751
752impl ExecutionPlan for VirtualExecutionPlan {
753 fn name(&self) -> &str {
754 "sql_federation_exec"
755 }
756
757 fn as_any(&self) -> &dyn Any {
758 self
759 }
760
761 fn schema(&self) -> SchemaRef {
762 self.schema()
763 }
764
765 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
766 vec![]
767 }
768
769 fn with_new_children(
770 self: Arc<Self>,
771 _: Vec<Arc<dyn ExecutionPlan>>,
772 ) -> Result<Arc<dyn ExecutionPlan>> {
773 Ok(self)
774 }
775
776 fn execute(
777 &self,
778 _partition: usize,
779 _context: Arc<TaskContext>,
780 ) -> Result<SendableRecordBatchStream> {
781 let query = self.plan_to_sql(&self.plan)?.to_string();
782 self.executor.execute(query.as_str(), self.schema())
783 }
784
785 fn properties(&self) -> &PlanProperties {
786 &self.props
787 }
788}
789
/// Tests for the table-scan rewrite: each case builds a logical plan over
/// tables backed by a `SQLTableSource` named `remote_table`, runs
/// `rewrite_table_scans`, unparses the result and compares the SQL text.
#[cfg(test)]
mod tests {
    use crate::FederatedTableProviderAdaptor;
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        catalog::{MemorySchemaProvider, SchemaProvider},
        common::Column,
        datasource::{DefaultTableSource, TableProvider},
        error::DataFusionError,
        execution::context::SessionContext,
        logical_expr::LogicalPlanBuilder,
        sql::{unparser::dialect::DefaultDialect, unparser::dialect::Dialect},
    };

    use super::*;

    /// Stub executor: provides a name and the default dialect; every method
    /// that would reach a remote system returns `NotImplemented`.
    struct TestSQLExecutor {}

    #[async_trait]
    impl SQLExecutor for TestSQLExecutor {
        fn name(&self) -> &str {
            "test_sql_table_source"
        }

        fn compute_context(&self) -> Option<String> {
            None
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(DefaultDialect {})
        }

        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
            Err(DataFusionError::NotImplemented(
                "execute not implemented".to_string(),
            ))
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            Err(DataFusionError::NotImplemented(
                "table inference not implemented".to_string(),
            ))
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            Err(DataFusionError::NotImplemented(
                "table inference not implemented".to_string(),
            ))
        }
    }

    /// Builds a federated table provider whose remote table is called
    /// `remote_table` and has columns `a` (Int64), `b` (Utf8), `c` (Date32).
    fn get_test_table_provider() -> Arc<dyn TableProvider> {
        let sql_federation_provider =
            Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {})));

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_source = Arc::new(
            SQLTableSource::new_with_schema(
                sql_federation_provider,
                "remote_table".to_string(),
                schema,
            )
            .expect("to have a valid SQLTableSource"),
        );
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    /// Wraps the test table provider in a `DefaultTableSource` for use with
    /// `LogicalPlanBuilder::scan`.
    fn get_test_table_source() -> Arc<DefaultTableSource> {
        Arc::new(DefaultTableSource::new(get_test_table_provider()))
    }

    /// Creates a session with two local names for the remote table:
    /// `foo.df_table` (in a newly registered `foo` schema) and `app_table`
    /// (in the existing `public` schema).
    fn get_test_df_context() -> SessionContext {
        let ctx = SessionContext::new();
        let catalog = ctx
            .catalog("datafusion")
            .expect("default catalog is datafusion");
        let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc<dyn SchemaProvider>;
        catalog
            .register_schema("foo", Arc::clone(&foo_schema))
            .expect("to register schema");
        foo_schema
            .register_table("df_table".to_string(), get_test_table_provider())
            .expect("to register table");

        let public_schema = catalog
            .schema("public")
            .expect("public schema should exist");
        public_schema
            .register_table("app_table".to_string(), get_test_table_provider())
            .expect("to register table");

        ctx
    }

    /// A projection over a scan of `foo.df_table` must be unparsed against
    /// `remote_table`, for both the table name and the column qualifiers.
    #[test]
    fn test_rewrite_table_scans_basic() -> Result<()> {
        let default_table_source = get_test_table_source();
        let plan =
            LogicalPlanBuilder::scan("foo.df_table", default_table_source, None)?.project(vec![
                Expr::Column(Column::from_qualified_name("foo.df_table.a")),
                Expr::Column(Column::from_qualified_name("foo.df_table.b")),
                Expr::Column(Column::from_qualified_name("foo.df_table.c")),
            ])?;

        let mut known_rewrites = HashMap::new();
        let rewritten_plan = rewrite_table_scans(&plan.build()?, &mut known_rewrites)?;

        println!("rewritten_plan: \n{:#?}", rewritten_plan);

        let unparsed_sql = plan_to_sql(&rewritten_plan)?;

        println!("unparsed_sql: \n{unparsed_sql}");

        assert_eq!(
            format!("{unparsed_sql}"),
            r#"SELECT remote_table.a, remote_table.b, remote_table.c FROM remote_table"#
        );

        Ok(())
    }

    /// Installs a debug-level tracing subscriber. The result of
    /// `set_global_default` is ignored, so calling this from several tests
    /// is harmless.
    fn init_tracing() {
        let subscriber = tracing_subscriber::FmtSubscriber::builder()
            .with_env_filter("debug")
            .with_ansi(true)
            .finish();
        let _ = tracing::subscriber::set_global_default(subscriber);
    }

    /// Aggregates, aliases that reuse a table name, and CASE expressions.
    #[tokio::test]
    async fn test_rewrite_table_scans_agg() -> Result<()> {
        init_tracing();
        let ctx = get_test_df_context();

        // (input SQL, expected SQL after the rewrite)
        let agg_tests = vec![
            (
                "SELECT MAX(a) FROM foo.df_table",
                r#"SELECT max(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT foo.df_table.a FROM foo.df_table",
                r#"SELECT remote_table.a FROM remote_table"#,
            ),
            (
                "SELECT MIN(a) FROM foo.df_table",
                r#"SELECT min(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT AVG(a) FROM foo.df_table",
                r#"SELECT avg(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT SUM(a) FROM foo.df_table",
                r#"SELECT sum(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT COUNT(a) FROM foo.df_table",
                r#"SELECT count(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT COUNT(a) as cnt FROM foo.df_table",
                r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#,
            ),
            // NOTE(review): identical to the previous case.
            (
                "SELECT COUNT(a) as cnt FROM foo.df_table",
                r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#,
            ),
            // An alias spelled like a table name must not be rewritten.
            (
                "SELECT app_table from (SELECT a as app_table FROM app_table) b",
                r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#,
            ),
            (
                "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b",
                r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#,
            ),
            // Several references to the same table inside one expression.
            (
                "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table",
                r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#,
            ),
            // Table aliases are kept; only the underlying table names change.
            (
                "SELECT COUNT(CASE WHEN appt.a > 0 THEN appt.a ELSE dft.a END) FROM app_table as appt, foo.df_table as dft",
                "SELECT count(CASE WHEN (appt.a > 0) THEN appt.a ELSE dft.a END) FROM remote_table AS appt CROSS JOIN remote_table AS dft"
            ),
        ];

        for test in agg_tests {
            test_sql(&ctx, test.0, test.1).await?;
        }

        Ok(())
    }

    /// Column aliases that merely contain a table name as a substring must
    /// be left untouched.
    #[tokio::test]
    async fn test_rewrite_table_scans_alias() -> Result<()> {
        init_tracing();
        let ctx = get_test_df_context();

        let tests = vec![
            (
                "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)",
                r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#,
            ),
            (
                "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)",
                r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#,
            ),
            (
                "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)",
                r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM remote_table)"#,
            ),
        ];

        for test in tests {
            test_sql(&ctx, test.0, test.1).await?;
        }

        Ok(())
    }

    /// Plans `sql_query` in `ctx`, rewrites the (unoptimized) logical plan,
    /// unparses it with the free `plan_to_sql` function and asserts that the
    /// resulting text equals `expected_sql`.
    async fn test_sql(
        ctx: &SessionContext,
        sql_query: &str,
        expected_sql: &str,
    ) -> Result<(), datafusion::error::DataFusionError> {
        let data_frame = ctx.sql(sql_query).await?;

        println!("before optimization: \n{:#?}", data_frame.logical_plan());

        let mut known_rewrites = HashMap::new();
        let rewritten_plan = rewrite_table_scans(data_frame.logical_plan(), &mut known_rewrites)?;

        println!("rewritten_plan: \n{:#?}", rewritten_plan);

        let unparsed_sql = plan_to_sql(&rewritten_plan)?;

        println!("unparsed_sql: \n{unparsed_sql}");

        assert_eq!(
            format!("{unparsed_sql}"),
            expected_sql,
            "SQL under test: {}",
            sql_query
        );

        Ok(())
    }

    /// LIMIT / OFFSET combinations, including zero values and both
    /// clause orders.
    #[tokio::test]
    async fn test_rewrite_table_scans_limit_offset() -> Result<()> {
        init_tracing();
        let ctx = get_test_df_context();

        let tests = vec![
            // LIMIT only
            (
                "SELECT a FROM foo.df_table LIMIT 5",
                r#"SELECT remote_table.a FROM remote_table LIMIT 5"#,
            ),
            // OFFSET only
            (
                "SELECT a FROM foo.df_table OFFSET 5",
                r#"SELECT remote_table.a FROM remote_table OFFSET 5"#,
            ),
            // LIMIT followed by OFFSET
            (
                "SELECT a FROM foo.df_table LIMIT 10 OFFSET 5",
                r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#,
            ),
            // OFFSET followed by LIMIT: unparsed in LIMIT-then-OFFSET order
            (
                "SELECT a FROM foo.df_table OFFSET 5 LIMIT 10",
                r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#,
            ),
            // OFFSET zero
            (
                "SELECT a FROM foo.df_table OFFSET 0",
                r#"SELECT remote_table.a FROM remote_table OFFSET 0"#,
            ),
            // LIMIT zero
            (
                "SELECT a FROM foo.df_table LIMIT 0",
                r#"SELECT remote_table.a FROM remote_table LIMIT 0"#,
            ),
            // LIMIT zero with OFFSET zero
            (
                "SELECT a FROM foo.df_table LIMIT 0 OFFSET 0",
                r#"SELECT remote_table.a FROM remote_table LIMIT 0 OFFSET 0"#,
            ),
        ];

        for test in tests {
            test_sql(&ctx, test.0, test.1).await?;
        }

        Ok(())
    }
}