datafusion_federation/sql/
mod.rs

1mod executor;
2mod schema;
3
4use std::{any::Any, collections::HashMap, fmt, sync::Arc, vec};
5
6use async_trait::async_trait;
7use datafusion::{
8    arrow::datatypes::{Schema, SchemaRef},
9    common::{tree_node::Transformed, Column},
10    error::Result,
11    execution::{context::SessionState, TaskContext},
12    logical_expr::{
13        expr::{
14            AggregateFunction, Alias, Exists, InList, InSubquery, PlannedReplaceSelectItem,
15            ScalarFunction, Sort, Unnest, WildcardOptions, WindowFunction,
16        },
17        Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, Limit, LogicalPlan,
18        Subquery, TryCast,
19    },
20    optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
21    physical_expr::EquivalenceProperties,
22    physical_plan::{
23        execution_plan::{Boundedness, EmissionType},
24        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
25        SendableRecordBatchStream,
26    },
27    sql::{
28        sqlparser::ast::Statement,
29        unparser::{plan_to_sql, Unparser},
30        TableReference,
31    },
32};
33
34pub use executor::{AstAnalyzer, SQLExecutor, SQLExecutorRef};
35pub use schema::{MultiSchemaProvider, SQLSchemaProvider, SQLTableSource};
36
37use crate::{
38    get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
39};
40
41// #[macro_use]
42// extern crate derive_builder;
43
44// SQLFederationProvider provides federation to SQL DMBSs.
45#[derive(Debug)]
46pub struct SQLFederationProvider {
47    optimizer: Arc<Optimizer>,
48    executor: Arc<dyn SQLExecutor>,
49}
50
51impl SQLFederationProvider {
52    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
53        Self {
54            optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
55                SQLFederationOptimizerRule::new(executor.clone()),
56            )])),
57            executor,
58        }
59    }
60}
61
62impl FederationProvider for SQLFederationProvider {
63    fn name(&self) -> &str {
64        "sql_federation_provider"
65    }
66
67    fn compute_context(&self) -> Option<String> {
68        self.executor.compute_context()
69    }
70
71    fn optimizer(&self) -> Option<Arc<Optimizer>> {
72        Some(self.optimizer.clone())
73    }
74}
75
76#[derive(Debug)]
77struct SQLFederationOptimizerRule {
78    planner: Arc<dyn FederationPlanner>,
79}
80
81impl SQLFederationOptimizerRule {
82    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
83        Self {
84            planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
85        }
86    }
87}
88
89impl OptimizerRule for SQLFederationOptimizerRule {
90    /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes`
91    /// if the plan was rewritten and `Transformed::no` if it was not.
92    ///
93    /// Note: this function is only called if [`Self::supports_rewrite`] returns
94    /// true. Otherwise the Optimizer calls  [`Self::try_optimize`]
95    fn rewrite(
96        &self,
97        plan: LogicalPlan,
98        _config: &dyn OptimizerConfig,
99    ) -> Result<Transformed<LogicalPlan>> {
100        if let LogicalPlan::Extension(Extension { ref node }) = plan {
101            if node.name() == "Federated" {
102                // Avoid attempting double federation
103                return Ok(Transformed::no(plan));
104            }
105        }
106        // Simply accept the entire plan for now
107        let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
108        let ext_node = Extension {
109            node: Arc::new(fed_plan),
110        };
111        Ok(Transformed::yes(LogicalPlan::Extension(ext_node)))
112    }
113
114    /// A human readable name for this analyzer rule
115    fn name(&self) -> &str {
116        "federate_sql"
117    }
118
119    /// Does this rule support rewriting owned plans (rather than by reference)?
120    fn supports_rewrite(&self) -> bool {
121        true
122    }
123}
124
125/// Rewrite table scans to use the original federated table name.
126fn rewrite_table_scans(
127    plan: &LogicalPlan,
128    known_rewrites: &mut HashMap<TableReference, TableReference>,
129) -> Result<LogicalPlan> {
130    if plan.inputs().is_empty() {
131        if let LogicalPlan::TableScan(table_scan) = plan {
132            let original_table_name = table_scan.table_name.clone();
133            let mut new_table_scan = table_scan.clone();
134
135            let Some(federated_source) = get_table_source(&table_scan.source)? else {
136                // Not a federated source
137                return Ok(plan.clone());
138            };
139
140            match federated_source.as_any().downcast_ref::<SQLTableSource>() {
141                Some(sql_table_source) => {
142                    let remote_table_name = TableReference::from(sql_table_source.table_name());
143                    known_rewrites.insert(original_table_name, remote_table_name.clone());
144
145                    // Rewrite the schema of this node to have the remote table as the qualifier.
146                    let new_schema = (*new_table_scan.projected_schema)
147                        .clone()
148                        .replace_qualifier(remote_table_name.clone());
149                    new_table_scan.projected_schema = Arc::new(new_schema);
150                    new_table_scan.table_name = remote_table_name;
151                }
152                None => {
153                    // Not a SQLTableSource (is this possible?)
154                    return Ok(plan.clone());
155                }
156            }
157
158            return Ok(LogicalPlan::TableScan(new_table_scan));
159        } else {
160            return Ok(plan.clone());
161        }
162    }
163
164    let rewritten_inputs = plan
165        .inputs()
166        .into_iter()
167        .map(|i| rewrite_table_scans(i, known_rewrites))
168        .collect::<Result<Vec<_>>>()?;
169
170    if let LogicalPlan::Limit(limit) = plan {
171        let rewritten_skip = limit
172            .skip
173            .as_ref()
174            .map(|skip| rewrite_table_scans_in_expr(*skip.clone(), known_rewrites).map(Box::new))
175            .transpose()?;
176
177        let rewritten_fetch = limit
178            .fetch
179            .as_ref()
180            .map(|fetch| rewrite_table_scans_in_expr(*fetch.clone(), known_rewrites).map(Box::new))
181            .transpose()?;
182
183        // explicitly set fetch and skip
184        let new_plan = LogicalPlan::Limit(Limit {
185            skip: rewritten_skip,
186            fetch: rewritten_fetch,
187            input: Arc::new(rewritten_inputs[0].clone()),
188        });
189
190        return Ok(new_plan);
191    }
192
193    let mut new_expressions = vec![];
194    for expression in plan.expressions() {
195        let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?;
196        new_expressions.push(new_expr);
197    }
198
199    let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?;
200
201    Ok(new_plan)
202}
203
/// Replaces whole-name occurrences of `table_ref_str` in `col_name` with
/// `rewrite`, searching from byte offset `start_pos`.
///
/// An occurrence is only replaced when it is not part of a longer name:
/// the character before it must not be alphabetic, numeric, `_` or `.`, and
/// the character after it must not be alphabetic, numeric or `_`.
/// All qualifying occurrences are replaced; text inserted by a replacement is
/// not searched again.
///
/// Returns `Some(new_name)` if at least one occurrence was replaced and `None`
/// otherwise. `None` is also returned when `table_ref_str` is empty or when
/// `start_pos` is past the end of `col_name` (or not on a character boundary).
fn rewrite_column_name_in_expr(
    col_name: &str,
    table_ref_str: &str,
    rewrite: &str,
    start_pos: usize,
) -> Option<String> {
    /// Characters that may be part of an identifier.
    fn is_name_char(c: char) -> bool {
        c.is_alphabetic() || c.is_numeric() || c == '_'
    }

    // An empty needle matches at every position and would never advance.
    if table_ref_str.is_empty() {
        return None;
    }

    let mut rewritten: Option<String> = None;
    let mut search_from = start_pos;

    loop {
        // Continue in the most recent version of the name.
        let haystack = rewritten.as_deref().unwrap_or(col_name);

        let Some(offset) = haystack
            .get(search_from..)
            .and_then(|tail| tail.find(table_ref_str))
        else {
            return rewritten;
        };

        // `find` yields byte offsets, so neighbours are looked up by slicing
        // at those offsets rather than by counting characters.
        let idx = search_from + offset;
        let end = idx + table_ref_str.len();

        let joined_to_prev = haystack[..idx]
            .chars()
            .next_back()
            .is_some_and(|c| is_name_char(c) || c == '.');
        let joined_to_next = haystack[end..].chars().next().is_some_and(is_name_char);

        if joined_to_prev || joined_to_next {
            // Part of another name: leave it alone and look further right.
            search_from = end;
            continue;
        }

        // Found full match, replace this occurrence and resume right after the
        // inserted text.
        let updated = format!("{}{}{}", &haystack[..idx], rewrite, &haystack[end..]);
        search_from = idx + rewrite.len();
        rewritten = Some(updated);
    }
}
270
271fn rewrite_table_scans_in_expr(
272    expr: Expr,
273    known_rewrites: &mut HashMap<TableReference, TableReference>,
274) -> Result<Expr> {
275    match expr {
276        Expr::ScalarSubquery(subquery) => {
277            let new_subquery = rewrite_table_scans(&subquery.subquery, known_rewrites)?;
278            let outer_ref_columns = subquery
279                .outer_ref_columns
280                .into_iter()
281                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
282                .collect::<Result<Vec<Expr>>>()?;
283            Ok(Expr::ScalarSubquery(Subquery {
284                subquery: Arc::new(new_subquery),
285                outer_ref_columns,
286            }))
287        }
288        Expr::BinaryExpr(binary_expr) => {
289            let left = rewrite_table_scans_in_expr(*binary_expr.left, known_rewrites)?;
290            let right = rewrite_table_scans_in_expr(*binary_expr.right, known_rewrites)?;
291            Ok(Expr::BinaryExpr(BinaryExpr::new(
292                Box::new(left),
293                binary_expr.op,
294                Box::new(right),
295            )))
296        }
297        Expr::Column(mut col) => {
298            if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) {
299                Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name)))
300            } else {
301                // This prevent over-eager rewrite and only pass the column into below rewritten
302                // rule like MAX(...)
303                if col.relation.is_some() {
304                    return Ok(Expr::Column(col));
305                }
306
307                // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so.
308                // This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)"
309                let (new_name, was_rewritten) = known_rewrites.iter().fold(
310                    (col.name.to_string(), false),
311                    |(col_name, was_rewritten), (table_ref, rewrite)| {
312                        match rewrite_column_name_in_expr(
313                            &col_name,
314                            &table_ref.to_string(),
315                            &rewrite.to_string(),
316                            0,
317                        ) {
318                            Some(new_name) => (new_name, true),
319                            None => (col_name, was_rewritten),
320                        }
321                    },
322                );
323                if was_rewritten {
324                    Ok(Expr::Column(Column::new(col.relation.take(), new_name)))
325                } else {
326                    Ok(Expr::Column(col))
327                }
328            }
329        }
330        Expr::Alias(alias) => {
331            let expr = rewrite_table_scans_in_expr(*alias.expr, known_rewrites)?;
332            if let Some(relation) = &alias.relation {
333                if let Some(rewrite) = known_rewrites.get(relation) {
334                    return Ok(Expr::Alias(Alias::new(
335                        expr,
336                        Some(rewrite.clone()),
337                        alias.name,
338                    )));
339                }
340            }
341            Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name)))
342        }
343        Expr::Like(like) => {
344            let expr = rewrite_table_scans_in_expr(*like.expr, known_rewrites)?;
345            let pattern = rewrite_table_scans_in_expr(*like.pattern, known_rewrites)?;
346            Ok(Expr::Like(Like::new(
347                like.negated,
348                Box::new(expr),
349                Box::new(pattern),
350                like.escape_char,
351                like.case_insensitive,
352            )))
353        }
354        Expr::SimilarTo(similar_to) => {
355            let expr = rewrite_table_scans_in_expr(*similar_to.expr, known_rewrites)?;
356            let pattern = rewrite_table_scans_in_expr(*similar_to.pattern, known_rewrites)?;
357            Ok(Expr::SimilarTo(Like::new(
358                similar_to.negated,
359                Box::new(expr),
360                Box::new(pattern),
361                similar_to.escape_char,
362                similar_to.case_insensitive,
363            )))
364        }
365        Expr::Not(e) => {
366            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
367            Ok(Expr::Not(Box::new(expr)))
368        }
369        Expr::IsNotNull(e) => {
370            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
371            Ok(Expr::IsNotNull(Box::new(expr)))
372        }
373        Expr::IsNull(e) => {
374            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
375            Ok(Expr::IsNull(Box::new(expr)))
376        }
377        Expr::IsTrue(e) => {
378            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
379            Ok(Expr::IsTrue(Box::new(expr)))
380        }
381        Expr::IsFalse(e) => {
382            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
383            Ok(Expr::IsFalse(Box::new(expr)))
384        }
385        Expr::IsUnknown(e) => {
386            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
387            Ok(Expr::IsUnknown(Box::new(expr)))
388        }
389        Expr::IsNotTrue(e) => {
390            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
391            Ok(Expr::IsNotTrue(Box::new(expr)))
392        }
393        Expr::IsNotFalse(e) => {
394            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
395            Ok(Expr::IsNotFalse(Box::new(expr)))
396        }
397        Expr::IsNotUnknown(e) => {
398            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
399            Ok(Expr::IsNotUnknown(Box::new(expr)))
400        }
401        Expr::Negative(e) => {
402            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
403            Ok(Expr::Negative(Box::new(expr)))
404        }
405        Expr::Between(between) => {
406            let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?;
407            let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?;
408            let high = rewrite_table_scans_in_expr(*between.high, known_rewrites)?;
409            Ok(Expr::Between(Between::new(
410                Box::new(expr),
411                between.negated,
412                Box::new(low),
413                Box::new(high),
414            )))
415        }
416        Expr::Case(case) => {
417            let expr = case
418                .expr
419                .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites))
420                .transpose()?
421                .map(Box::new);
422            let else_expr = case
423                .else_expr
424                .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites))
425                .transpose()?
426                .map(Box::new);
427            let when_expr = case
428                .when_then_expr
429                .into_iter()
430                .map(|(when, then)| {
431                    let when = rewrite_table_scans_in_expr(*when, known_rewrites);
432                    let then = rewrite_table_scans_in_expr(*then, known_rewrites);
433
434                    match (when, then) {
435                        (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))),
436                        (Err(e), _) | (_, Err(e)) => Err(e),
437                    }
438                })
439                .collect::<Result<Vec<(Box<Expr>, Box<Expr>)>>>()?;
440            Ok(Expr::Case(Case::new(expr, when_expr, else_expr)))
441        }
442        Expr::Cast(cast) => {
443            let expr = rewrite_table_scans_in_expr(*cast.expr, known_rewrites)?;
444            Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type)))
445        }
446        Expr::TryCast(try_cast) => {
447            let expr = rewrite_table_scans_in_expr(*try_cast.expr, known_rewrites)?;
448            Ok(Expr::TryCast(TryCast::new(
449                Box::new(expr),
450                try_cast.data_type,
451            )))
452        }
453        Expr::ScalarFunction(sf) => {
454            let args = sf
455                .args
456                .into_iter()
457                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
458                .collect::<Result<Vec<Expr>>>()?;
459            Ok(Expr::ScalarFunction(ScalarFunction {
460                func: sf.func,
461                args,
462            }))
463        }
464        Expr::AggregateFunction(af) => {
465            let args = af
466                .args
467                .into_iter()
468                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
469                .collect::<Result<Vec<Expr>>>()?;
470            let filter = af
471                .filter
472                .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites))
473                .transpose()?
474                .map(Box::new);
475            let order_by = af
476                .order_by
477                .map(|e| {
478                    e.into_iter()
479                        .map(|sort| {
480                            Ok(Sort {
481                                expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?,
482                                ..sort
483                            })
484                        })
485                        .collect::<Result<Vec<_>>>()
486                })
487                .transpose()?;
488            Ok(Expr::AggregateFunction(AggregateFunction {
489                func: af.func,
490                args,
491                distinct: af.distinct,
492                filter,
493                order_by,
494                null_treatment: af.null_treatment,
495            }))
496        }
497        Expr::WindowFunction(wf) => {
498            let args = wf
499                .args
500                .into_iter()
501                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
502                .collect::<Result<Vec<Expr>>>()?;
503            let partition_by = wf
504                .partition_by
505                .into_iter()
506                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
507                .collect::<Result<Vec<Expr>>>()?;
508            let order_by = wf
509                .order_by
510                .into_iter()
511                .map(|sort| {
512                    Ok(Sort {
513                        expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?,
514                        ..sort
515                    })
516                })
517                .collect::<Result<Vec<_>>>()?;
518            Ok(Expr::WindowFunction(WindowFunction {
519                fun: wf.fun,
520                args,
521                partition_by,
522                order_by,
523                window_frame: wf.window_frame,
524                null_treatment: wf.null_treatment,
525            }))
526        }
527        Expr::InList(il) => {
528            let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?;
529            let list = il
530                .list
531                .into_iter()
532                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
533                .collect::<Result<Vec<Expr>>>()?;
534            Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated)))
535        }
536        Expr::Exists(exists) => {
537            let subquery_plan = rewrite_table_scans(&exists.subquery.subquery, known_rewrites)?;
538            let outer_ref_columns = exists
539                .subquery
540                .outer_ref_columns
541                .into_iter()
542                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
543                .collect::<Result<Vec<Expr>>>()?;
544            let subquery = Subquery {
545                subquery: Arc::new(subquery_plan),
546                outer_ref_columns,
547            };
548            Ok(Expr::Exists(Exists::new(subquery, exists.negated)))
549        }
550        Expr::InSubquery(is) => {
551            let expr = rewrite_table_scans_in_expr(*is.expr, known_rewrites)?;
552            let subquery_plan = rewrite_table_scans(&is.subquery.subquery, known_rewrites)?;
553            let outer_ref_columns = is
554                .subquery
555                .outer_ref_columns
556                .into_iter()
557                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
558                .collect::<Result<Vec<Expr>>>()?;
559            let subquery = Subquery {
560                subquery: Arc::new(subquery_plan),
561                outer_ref_columns,
562            };
563            Ok(Expr::InSubquery(InSubquery::new(
564                Box::new(expr),
565                subquery,
566                is.negated,
567            )))
568        }
569        Expr::Wildcard { qualifier, options } => {
570            let options = WildcardOptions {
571                replace: options
572                    .replace
573                    .map(|replace| -> Result<PlannedReplaceSelectItem> {
574                        Ok(PlannedReplaceSelectItem {
575                            planned_expressions: replace
576                                .planned_expressions
577                                .into_iter()
578                                .map(|expr| rewrite_table_scans_in_expr(expr, known_rewrites))
579                                .collect::<Result<Vec<_>>>()?,
580                            ..replace
581                        })
582                    })
583                    .transpose()?,
584                ..*options
585            };
586            if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) {
587                Ok(Expr::Wildcard {
588                    qualifier: Some(rewrite.clone()),
589                    options: Box::new(options),
590                })
591            } else {
592                Ok(Expr::Wildcard {
593                    qualifier,
594                    options: Box::new(options),
595                })
596            }
597        }
598        Expr::GroupingSet(gs) => match gs {
599            GroupingSet::Rollup(exprs) => {
600                let exprs = exprs
601                    .into_iter()
602                    .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
603                    .collect::<Result<Vec<Expr>>>()?;
604                Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs)))
605            }
606            GroupingSet::Cube(exprs) => {
607                let exprs = exprs
608                    .into_iter()
609                    .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
610                    .collect::<Result<Vec<Expr>>>()?;
611                Ok(Expr::GroupingSet(GroupingSet::Cube(exprs)))
612            }
613            GroupingSet::GroupingSets(vec_exprs) => {
614                let vec_exprs = vec_exprs
615                    .into_iter()
616                    .map(|exprs| {
617                        exprs
618                            .into_iter()
619                            .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
620                            .collect::<Result<Vec<Expr>>>()
621                    })
622                    .collect::<Result<Vec<Vec<Expr>>>>()?;
623                Ok(Expr::GroupingSet(GroupingSet::GroupingSets(vec_exprs)))
624            }
625        },
626        Expr::OuterReferenceColumn(dt, col) => {
627            if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) {
628                Ok(Expr::OuterReferenceColumn(
629                    dt,
630                    Column::new(Some(rewrite.clone()), &col.name),
631                ))
632            } else {
633                Ok(Expr::OuterReferenceColumn(dt, col))
634            }
635        }
636        Expr::Unnest(unnest) => {
637            let expr = rewrite_table_scans_in_expr(*unnest.expr, known_rewrites)?;
638            Ok(Expr::Unnest(Unnest::new(expr)))
639        }
640        Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr),
641    }
642}
643
644struct SQLFederationPlanner {
645    executor: Arc<dyn SQLExecutor>,
646}
647
648impl SQLFederationPlanner {
649    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
650        Self { executor }
651    }
652}
653
654#[async_trait]
655impl FederationPlanner for SQLFederationPlanner {
656    async fn plan_federation(
657        &self,
658        node: &FederatedPlanNode,
659        _session_state: &SessionState,
660    ) -> Result<Arc<dyn ExecutionPlan>> {
661        let schema = Arc::new(node.plan().schema().as_arrow().clone());
662        let input = Arc::new(VirtualExecutionPlan::new(
663            node.plan().clone(),
664            Arc::clone(&self.executor),
665        ));
666        let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
667        Ok(Arc::new(schema_cast_exec))
668    }
669}
670
671#[derive(Debug, Clone)]
672struct VirtualExecutionPlan {
673    plan: LogicalPlan,
674    executor: Arc<dyn SQLExecutor>,
675    props: PlanProperties,
676}
677
678impl VirtualExecutionPlan {
679    pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>) -> Self {
680        let schema: Schema = plan.schema().as_ref().into();
681        let props = PlanProperties::new(
682            EquivalenceProperties::new(Arc::new(schema)),
683            Partitioning::UnknownPartitioning(1),
684            EmissionType::Incremental,
685            Boundedness::Bounded,
686        );
687        Self {
688            plan,
689            executor,
690            props,
691        }
692    }
693
694    fn schema(&self) -> SchemaRef {
695        let df_schema = self.plan.schema().as_ref();
696        Arc::new(Schema::from(df_schema))
697    }
698
699    fn sql(&self) -> Result<String> {
700        // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table.
701        let mut known_rewrites = HashMap::new();
702        let plan = &rewrite_table_scans(&self.plan, &mut known_rewrites)?;
703        let mut ast = self.plan_to_sql(plan)?;
704
705        if let Some(analyzer) = self.executor.ast_analyzer() {
706            ast = analyzer(ast)?;
707        }
708
709        Ok(format!("{ast}"))
710    }
711
712    fn plan_to_sql(&self, plan: &LogicalPlan) -> Result<Statement> {
713        Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
714    }
715}
716
717impl DisplayAs for VirtualExecutionPlan {
718    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
719        write!(f, "VirtualExecutionPlan")?;
720        let Ok(ast) = plan_to_sql(&self.plan) else {
721            return Ok(());
722        };
723        write!(f, " name={}", self.executor.name())?;
724        if let Some(ctx) = self.executor.compute_context() {
725            write!(f, " compute_context={ctx}")?;
726        };
727
728        write!(f, " sql={ast}")?;
729        if let Ok(query) = self.sql() {
730            write!(f, " rewritten_sql={query}")?;
731        };
732
733        write!(f, " sql={ast}")
734    }
735}
736
737impl ExecutionPlan for VirtualExecutionPlan {
738    fn name(&self) -> &str {
739        "sql_federation_exec"
740    }
741
742    fn as_any(&self) -> &dyn Any {
743        self
744    }
745
746    fn schema(&self) -> SchemaRef {
747        self.schema()
748    }
749
750    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
751        vec![]
752    }
753
754    fn with_new_children(
755        self: Arc<Self>,
756        _: Vec<Arc<dyn ExecutionPlan>>,
757    ) -> Result<Arc<dyn ExecutionPlan>> {
758        Ok(self)
759    }
760
761    fn execute(
762        &self,
763        _partition: usize,
764        _context: Arc<TaskContext>,
765    ) -> Result<SendableRecordBatchStream> {
766        let query = self.plan_to_sql(&self.plan)?.to_string();
767        self.executor.execute(query.as_str(), self.schema())
768    }
769
770    fn properties(&self) -> &PlanProperties {
771        &self.props
772    }
773}
774
#[cfg(test)]
mod tests {
    //! Tests for `rewrite_table_scans`: federated table references in a logical
    //! plan must be replaced by the remote table name before the plan is
    //! unparsed back to SQL.

    use crate::FederatedTableProviderAdaptor;
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        catalog::SchemaProvider,
        catalog_common::MemorySchemaProvider,
        common::Column,
        datasource::{DefaultTableSource, TableProvider},
        error::DataFusionError,
        execution::context::SessionContext,
        logical_expr::LogicalPlanBuilder,
        sql::unparser::dialect::{DefaultDialect, Dialect},
    };

    use super::*;

    /// Stub executor used only to construct a `SQLFederationProvider`.
    ///
    /// Every operation that would contact a remote engine returns
    /// `DataFusionError::NotImplemented`; the tests below only rewrite and
    /// unparse plans and never execute them.
    struct TestSQLExecutor {}

    #[async_trait]
    impl SQLExecutor for TestSQLExecutor {
        fn name(&self) -> &str {
            "test_sql_table_source"
        }

        fn compute_context(&self) -> Option<String> {
            None
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(DefaultDialect {})
        }

        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
            Err(DataFusionError::NotImplemented(
                "execute not implemented".to_string(),
            ))
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            Err(DataFusionError::NotImplemented(
                "table inference not implemented".to_string(),
            ))
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            Err(DataFusionError::NotImplemented(
                "table inference not implemented".to_string(),
            ))
        }
    }

    /// Builds a federated table provider whose remote table is named
    /// `remote_table` and has the columns `a: Int64`, `b: Utf8`, `c: Date32`.
    fn get_test_table_provider() -> Arc<dyn TableProvider> {
        let sql_federation_provider =
            Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {})));

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ]));
        let table_source = Arc::new(
            SQLTableSource::new_with_schema(
                sql_federation_provider,
                "remote_table".to_string(),
                schema,
            )
            .expect("to have a valid SQLTableSource"),
        );
        Arc::new(FederatedTableProviderAdaptor::new(table_source))
    }

    /// Wraps the test table provider in a `DefaultTableSource` so it can be
    /// handed to `LogicalPlanBuilder::scan`.
    fn get_test_table_source() -> Arc<DefaultTableSource> {
        Arc::new(DefaultTableSource::new(get_test_table_provider()))
    }

    /// Creates a session in which the test table is registered twice in the
    /// default `datafusion` catalog: as `foo.df_table` (in a freshly created
    /// `foo` schema) and as `app_table` (in the existing `public` schema).
    fn get_test_df_context() -> SessionContext {
        let ctx = SessionContext::new();
        let catalog = ctx
            .catalog("datafusion")
            .expect("default catalog is datafusion");

        let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc<dyn SchemaProvider>;
        catalog
            .register_schema("foo", Arc::clone(&foo_schema))
            .expect("to register schema");
        foo_schema
            .register_table("df_table".to_string(), get_test_table_provider())
            .expect("to register table");

        let public_schema = catalog
            .schema("public")
            .expect("public schema should exist");
        public_schema
            .register_table("app_table".to_string(), get_test_table_provider())
            .expect("to register table");

        ctx
    }

    /// Installs a debug-level tracing subscriber for the test process.
    fn init_tracing() {
        let subscriber = tracing_subscriber::FmtSubscriber::builder()
            .with_env_filter("debug")
            .with_ansi(true)
            .finish();
        // Several tests call this within one process; the error returned when a
        // global subscriber is already installed is deliberately ignored.
        let _ = tracing::subscriber::set_global_default(subscriber);
    }

    /// Rewrites the table scans of `plan`, unparses the result to SQL and
    /// asserts that it equals `expected_sql`.
    ///
    /// `label` describes the case under test and is included in the assertion
    /// failure message. Errors from rewriting or unparsing are propagated.
    fn assert_rewritten_sql(plan: &LogicalPlan, expected_sql: &str, label: &str) -> Result<()> {
        let mut known_rewrites = HashMap::new();
        let rewritten_plan = rewrite_table_scans(plan, &mut known_rewrites)?;

        println!("rewritten_plan: \n{rewritten_plan:#?}");

        let unparsed_sql = plan_to_sql(&rewritten_plan)?.to_string();

        println!("unparsed_sql: \n{unparsed_sql}");

        assert_eq!(unparsed_sql, expected_sql, "SQL under test: {label}");

        Ok(())
    }

    /// Plans `sql_query` in `ctx` and checks that, after table scans are
    /// rewritten, the plan unparses to `expected_sql`.
    async fn test_sql(ctx: &SessionContext, sql_query: &str, expected_sql: &str) -> Result<()> {
        let data_frame = ctx.sql(sql_query).await?;

        println!("before optimization: \n{:#?}", data_frame.logical_plan());

        assert_rewritten_sql(data_frame.logical_plan(), expected_sql, sql_query)
    }

    /// A plan built directly with `LogicalPlanBuilder` has its qualified
    /// `foo.df_table` references rewritten to `remote_table`.
    #[test]
    fn test_rewrite_table_scans_basic() -> Result<()> {
        let plan = LogicalPlanBuilder::scan("foo.df_table", get_test_table_source(), None)?
            .project(vec![
                Expr::Column(Column::from_qualified_name("foo.df_table.a")),
                Expr::Column(Column::from_qualified_name("foo.df_table.b")),
                Expr::Column(Column::from_qualified_name("foo.df_table.c")),
            ])?
            .build()?;

        assert_rewritten_sql(
            &plan,
            r#"SELECT remote_table.a, remote_table.b, remote_table.c FROM remote_table"#,
            "SELECT a, b, c FROM foo.df_table (built with LogicalPlanBuilder)",
        )
    }

    /// Table references inside aggregate expressions, aliases and subqueries
    /// are rewritten.
    #[tokio::test]
    async fn test_rewrite_table_scans_agg() -> Result<()> {
        init_tracing();
        let ctx = get_test_df_context();

        // (query against the local names, expected SQL against the remote table)
        let agg_tests = [
            (
                "SELECT MAX(a) FROM foo.df_table",
                r#"SELECT max(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT foo.df_table.a FROM foo.df_table",
                r#"SELECT remote_table.a FROM remote_table"#,
            ),
            (
                "SELECT MIN(a) FROM foo.df_table",
                r#"SELECT min(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT AVG(a) FROM foo.df_table",
                r#"SELECT avg(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT SUM(a) FROM foo.df_table",
                r#"SELECT sum(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT COUNT(a) FROM foo.df_table",
                r#"SELECT count(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT COUNT(a) as cnt FROM foo.df_table",
                r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#,
            ),
            (
                "SELECT app_table from (SELECT a as app_table FROM app_table) b",
                r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#,
            ),
            (
                "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b",
                r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#,
            ),
            // multiple occurrences of the same table in single aggregation expression
            (
                "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table",
                r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#,
            ),
            // different tables in single aggregation expression
            (
                "SELECT COUNT(CASE WHEN appt.a > 0 THEN appt.a ELSE dft.a END) FROM app_table as appt, foo.df_table as dft",
                "SELECT count(CASE WHEN (appt.a > 0) THEN appt.a ELSE dft.a END) FROM remote_table AS appt CROSS JOIN remote_table AS dft",
            ),
        ];

        for (sql_query, expected_sql) in agg_tests {
            test_sql(&ctx, sql_query, expected_sql).await?;
        }

        Ok(())
    }

    /// Column aliases that merely contain the local table name (`app_table_a`,
    /// `aapp_table`) must be left untouched by the rewrite.
    #[tokio::test]
    async fn test_rewrite_table_scans_alias() -> Result<()> {
        init_tracing();
        let ctx = get_test_df_context();

        let tests = [
            (
                "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)",
                r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#,
            ),
            (
                "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)",
                r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#,
            ),
            (
                "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)",
                r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM remote_table)"#,
            ),
        ];

        for (sql_query, expected_sql) in tests {
            test_sql(&ctx, sql_query, expected_sql).await?;
        }

        Ok(())
    }

    /// `LIMIT` / `OFFSET` clauses, including zero values, survive the rewrite.
    #[tokio::test]
    async fn test_rewrite_table_scans_limit_offset() -> Result<()> {
        init_tracing();
        let ctx = get_test_df_context();

        let tests = [
            // Basic LIMIT
            (
                "SELECT a FROM foo.df_table LIMIT 5",
                r#"SELECT remote_table.a FROM remote_table LIMIT 5"#,
            ),
            // Basic OFFSET
            (
                "SELECT a FROM foo.df_table OFFSET 5",
                r#"SELECT remote_table.a FROM remote_table OFFSET 5"#,
            ),
            // OFFSET after LIMIT
            (
                "SELECT a FROM foo.df_table LIMIT 10 OFFSET 5",
                r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#,
            ),
            // LIMIT after OFFSET: the expected output lists LIMIT first
            (
                "SELECT a FROM foo.df_table OFFSET 5 LIMIT 10",
                r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#,
            ),
            // Zero OFFSET
            (
                "SELECT a FROM foo.df_table OFFSET 0",
                r#"SELECT remote_table.a FROM remote_table OFFSET 0"#,
            ),
            // Zero LIMIT
            (
                "SELECT a FROM foo.df_table LIMIT 0",
                r#"SELECT remote_table.a FROM remote_table LIMIT 0"#,
            ),
            // Zero LIMIT and OFFSET
            (
                "SELECT a FROM foo.df_table LIMIT 0 OFFSET 0",
                r#"SELECT remote_table.a FROM remote_table LIMIT 0 OFFSET 0"#,
            ),
        ];

        for (sql_query, expected_sql) in tests {
            test_sql(&ctx, sql_query, expected_sql).await?;
        }

        Ok(())
    }
}