datafusion_federation/sql/
mod.rs

1mod executor;
2mod schema;
3
4use std::{any::Any, collections::HashMap, fmt, sync::Arc, vec};
5
6use async_trait::async_trait;
7use datafusion::{
8    arrow::datatypes::{Schema, SchemaRef},
9    common::{tree_node::Transformed, Column},
10    error::Result,
11    execution::{context::SessionState, TaskContext},
12    logical_expr::{
13        expr::{
14            AggregateFunction, AggregateFunctionParams, Alias, Exists, InList, InSubquery,
15            PlannedReplaceSelectItem, ScalarFunction, Sort, Unnest, WildcardOptions,
16            WindowFunction, WindowFunctionParams,
17        },
18        Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, Limit, LogicalPlan,
19        Subquery, TryCast,
20    },
21    optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
22    physical_expr::EquivalenceProperties,
23    physical_plan::{
24        execution_plan::{Boundedness, EmissionType},
25        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
26        SendableRecordBatchStream,
27    },
28    sql::{
29        sqlparser::ast::Statement,
30        unparser::{plan_to_sql, Unparser},
31        TableReference,
32    },
33};
34
35pub use executor::{AstAnalyzer, SQLExecutor, SQLExecutorRef};
36pub use schema::{MultiSchemaProvider, SQLSchemaProvider, SQLTableSource};
37
38use crate::{
39    get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
40};
41
42// #[macro_use]
43// extern crate derive_builder;
44
45// SQLFederationProvider provides federation to SQL DMBSs.
46#[derive(Debug)]
47pub struct SQLFederationProvider {
48    optimizer: Arc<Optimizer>,
49    executor: Arc<dyn SQLExecutor>,
50}
51
52impl SQLFederationProvider {
53    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
54        Self {
55            optimizer: Arc::new(Optimizer::with_rules(vec![Arc::new(
56                SQLFederationOptimizerRule::new(executor.clone()),
57            )])),
58            executor,
59        }
60    }
61}
62
63impl FederationProvider for SQLFederationProvider {
64    fn name(&self) -> &str {
65        "sql_federation_provider"
66    }
67
68    fn compute_context(&self) -> Option<String> {
69        self.executor.compute_context()
70    }
71
72    fn optimizer(&self) -> Option<Arc<Optimizer>> {
73        Some(self.optimizer.clone())
74    }
75}
76
77#[derive(Debug)]
78struct SQLFederationOptimizerRule {
79    planner: Arc<dyn FederationPlanner>,
80}
81
82impl SQLFederationOptimizerRule {
83    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
84        Self {
85            planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))),
86        }
87    }
88}
89
90impl OptimizerRule for SQLFederationOptimizerRule {
91    /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes`
92    /// if the plan was rewritten and `Transformed::no` if it was not.
93    ///
94    /// Note: this function is only called if [`Self::supports_rewrite`] returns
95    /// true. Otherwise the Optimizer calls  [`Self::try_optimize`]
96    fn rewrite(
97        &self,
98        plan: LogicalPlan,
99        _config: &dyn OptimizerConfig,
100    ) -> Result<Transformed<LogicalPlan>> {
101        if let LogicalPlan::Extension(Extension { ref node }) = plan {
102            if node.name() == "Federated" {
103                // Avoid attempting double federation
104                return Ok(Transformed::no(plan));
105            }
106        }
107        // Simply accept the entire plan for now
108        let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone());
109        let ext_node = Extension {
110            node: Arc::new(fed_plan),
111        };
112        Ok(Transformed::yes(LogicalPlan::Extension(ext_node)))
113    }
114
115    /// A human readable name for this analyzer rule
116    fn name(&self) -> &str {
117        "federate_sql"
118    }
119
120    /// Does this rule support rewriting owned plans (rather than by reference)?
121    fn supports_rewrite(&self) -> bool {
122        true
123    }
124}
125
126/// Rewrite table scans to use the original federated table name.
127fn rewrite_table_scans(
128    plan: &LogicalPlan,
129    known_rewrites: &mut HashMap<TableReference, TableReference>,
130) -> Result<LogicalPlan> {
131    if plan.inputs().is_empty() {
132        if let LogicalPlan::TableScan(table_scan) = plan {
133            let original_table_name = table_scan.table_name.clone();
134            let mut new_table_scan = table_scan.clone();
135
136            let Some(federated_source) = get_table_source(&table_scan.source)? else {
137                // Not a federated source
138                return Ok(plan.clone());
139            };
140
141            match federated_source.as_any().downcast_ref::<SQLTableSource>() {
142                Some(sql_table_source) => {
143                    let remote_table_name = TableReference::from(sql_table_source.table_name());
144                    known_rewrites.insert(original_table_name, remote_table_name.clone());
145
146                    // Rewrite the schema of this node to have the remote table as the qualifier.
147                    let new_schema = (*new_table_scan.projected_schema)
148                        .clone()
149                        .replace_qualifier(remote_table_name.clone());
150                    new_table_scan.projected_schema = Arc::new(new_schema);
151                    new_table_scan.table_name = remote_table_name;
152                }
153                None => {
154                    // Not a SQLTableSource (is this possible?)
155                    return Ok(plan.clone());
156                }
157            }
158
159            return Ok(LogicalPlan::TableScan(new_table_scan));
160        } else {
161            return Ok(plan.clone());
162        }
163    }
164
165    let rewritten_inputs = plan
166        .inputs()
167        .into_iter()
168        .map(|i| rewrite_table_scans(i, known_rewrites))
169        .collect::<Result<Vec<_>>>()?;
170
171    if let LogicalPlan::Limit(limit) = plan {
172        let rewritten_skip = limit
173            .skip
174            .as_ref()
175            .map(|skip| rewrite_table_scans_in_expr(*skip.clone(), known_rewrites).map(Box::new))
176            .transpose()?;
177
178        let rewritten_fetch = limit
179            .fetch
180            .as_ref()
181            .map(|fetch| rewrite_table_scans_in_expr(*fetch.clone(), known_rewrites).map(Box::new))
182            .transpose()?;
183
184        // explicitly set fetch and skip
185        let new_plan = LogicalPlan::Limit(Limit {
186            skip: rewritten_skip,
187            fetch: rewritten_fetch,
188            input: Arc::new(rewritten_inputs[0].clone()),
189        });
190
191        return Ok(new_plan);
192    }
193
194    let mut new_expressions = vec![];
195    for expression in plan.expressions() {
196        let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?;
197        new_expressions.push(new_expr);
198    }
199
200    let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?;
201
202    Ok(new_plan)
203}
204
/// Replaces whole-name occurrences of `table_ref_str` in `col_name` with
/// `rewrite`, searching from byte offset `start_pos`.
///
/// An occurrence is only rewritten when it is not part of a longer name:
/// the character before it must not be alphanumeric, `_` or `.`, and the
/// character after it must not be alphanumeric or `_`. All qualifying
/// occurrences are rewritten; text inserted by a rewrite is never re-scanned.
///
/// Returns `Some(new_name)` if at least one occurrence was rewritten and
/// `None` otherwise. `None` is also returned when `table_ref_str` is empty or
/// `start_pos` is past the end of `col_name` or not on a character boundary.
fn rewrite_column_name_in_expr(
    col_name: &str,
    table_ref_str: &str,
    rewrite: &str,
    start_pos: usize,
) -> Option<String> {
    // An empty needle matches everywhere and could never make progress.
    if table_ref_str.is_empty() {
        return None;
    }

    let is_name_char = |c: char| c.is_alphanumeric() || c == '_';

    let mut rewritten: Option<String> = None;
    let mut pos = start_pos;

    loop {
        // Search the latest version of the name (the original until the first rewrite).
        let haystack: &str = rewritten.as_deref().unwrap_or(col_name);

        // `get` (rather than indexing) turns an out-of-range or non-boundary
        // position into "no more matches" instead of a panic.
        let Some(rel) = haystack
            .get(pos..)
            .and_then(|tail| tail.find(table_ref_str))
        else {
            break;
        };
        let idx = pos + rel;
        let end = idx + table_ref_str.len();

        // `idx` and `end` are byte offsets, so the neighbouring characters are
        // read by slicing, not by counting chars from the start of the string.
        let prev_is_part_of_name = haystack[..idx]
            .chars()
            .next_back()
            .is_some_and(|c| is_name_char(c) || c == '.');
        let next_is_part_of_name = haystack[end..].chars().next().is_some_and(is_name_char);

        if prev_is_part_of_name || next_is_part_of_name {
            // Part of another name: skip this occurrence.
            pos = end;
            continue;
        }

        // Found full match, replace this occurrence with `rewrite`.
        let mut new_name = String::with_capacity(haystack.len() - table_ref_str.len() + rewrite.len());
        new_name.push_str(&haystack[..idx]);
        new_name.push_str(rewrite);
        new_name.push_str(&haystack[end..]);

        // Continue after the inserted text so the replacement itself is not re-scanned.
        pos = idx + rewrite.len();
        rewritten = Some(new_name);
    }

    rewritten
}
271
/// Recursively rewrites table references inside `expr` using `known_rewrites`.
///
/// `known_rewrites` maps a table reference found in the plan to the reference
/// that should replace it. The function consumes `expr`, rebuilds every
/// variant with its child expressions rewritten, and:
/// - replaces the qualifier of columns, aliases, wildcards and outer-reference
///   columns when that qualifier is a key of `known_rewrites`;
/// - for unqualified columns, rewrites table names embedded in the column name
///   itself (see [`rewrite_column_name_in_expr`]);
/// - rewrites the plans of scalar / `EXISTS` / `IN` subqueries through
///   [`rewrite_table_scans`], which may add further entries to `known_rewrites`.
///
/// Scalar variables, literals and placeholders are returned unchanged.
fn rewrite_table_scans_in_expr(
    expr: Expr,
    known_rewrites: &mut HashMap<TableReference, TableReference>,
) -> Result<Expr> {
    match expr {
        Expr::ScalarSubquery(subquery) => {
            // Rewrite the nested plan first, then the expressions referencing the outer query.
            let new_subquery = rewrite_table_scans(&subquery.subquery, known_rewrites)?;
            let outer_ref_columns = subquery
                .outer_ref_columns
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            Ok(Expr::ScalarSubquery(Subquery {
                subquery: Arc::new(new_subquery),
                outer_ref_columns,
            }))
        }
        Expr::BinaryExpr(binary_expr) => {
            let left = rewrite_table_scans_in_expr(*binary_expr.left, known_rewrites)?;
            let right = rewrite_table_scans_in_expr(*binary_expr.right, known_rewrites)?;
            Ok(Expr::BinaryExpr(BinaryExpr::new(
                Box::new(left),
                binary_expr.op,
                Box::new(right),
            )))
        }
        Expr::Column(mut col) => {
            // Qualified column whose qualifier has a known rewrite: swap the qualifier.
            if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) {
                Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name)))
            } else {
                // This prevent over-eager rewrite and only pass the column into below rewritten
                // rule like MAX(...)
                if col.relation.is_some() {
                    return Ok(Expr::Column(col));
                }

                // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so.
                // This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)"
                // Each known rewrite is applied in turn to the name produced by the previous one.
                let (new_name, was_rewritten) = known_rewrites.iter().fold(
                    (col.name.to_string(), false),
                    |(col_name, was_rewritten), (table_ref, rewrite)| {
                        match rewrite_column_name_in_expr(
                            &col_name,
                            &table_ref.to_string(),
                            &rewrite.to_string(),
                            0,
                        ) {
                            Some(new_name) => (new_name, true),
                            None => (col_name, was_rewritten),
                        }
                    },
                );
                if was_rewritten {
                    // `col.relation` is `None` on this path (qualified columns returned above).
                    Ok(Expr::Column(Column::new(col.relation.take(), new_name)))
                } else {
                    Ok(Expr::Column(col))
                }
            }
        }
        Expr::Alias(alias) => {
            let expr = rewrite_table_scans_in_expr(*alias.expr, known_rewrites)?;
            // Swap the alias qualifier when it has a known rewrite; the alias name is kept.
            if let Some(relation) = &alias.relation {
                if let Some(rewrite) = known_rewrites.get(relation) {
                    return Ok(Expr::Alias(Alias::new(
                        expr,
                        Some(rewrite.clone()),
                        alias.name,
                    )));
                }
            }
            Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name)))
        }
        Expr::Like(like) => {
            let expr = rewrite_table_scans_in_expr(*like.expr, known_rewrites)?;
            let pattern = rewrite_table_scans_in_expr(*like.pattern, known_rewrites)?;
            Ok(Expr::Like(Like::new(
                like.negated,
                Box::new(expr),
                Box::new(pattern),
                like.escape_char,
                like.case_insensitive,
            )))
        }
        Expr::SimilarTo(similar_to) => {
            let expr = rewrite_table_scans_in_expr(*similar_to.expr, known_rewrites)?;
            let pattern = rewrite_table_scans_in_expr(*similar_to.pattern, known_rewrites)?;
            Ok(Expr::SimilarTo(Like::new(
                similar_to.negated,
                Box::new(expr),
                Box::new(pattern),
                similar_to.escape_char,
                similar_to.case_insensitive,
            )))
        }
        // Unary wrappers: rewrite the single child and re-wrap it in the same variant.
        Expr::Not(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::Not(Box::new(expr)))
        }
        Expr::IsNotNull(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNotNull(Box::new(expr)))
        }
        Expr::IsNull(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNull(Box::new(expr)))
        }
        Expr::IsTrue(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsTrue(Box::new(expr)))
        }
        Expr::IsFalse(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsFalse(Box::new(expr)))
        }
        Expr::IsUnknown(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsUnknown(Box::new(expr)))
        }
        Expr::IsNotTrue(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNotTrue(Box::new(expr)))
        }
        Expr::IsNotFalse(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNotFalse(Box::new(expr)))
        }
        Expr::IsNotUnknown(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::IsNotUnknown(Box::new(expr)))
        }
        Expr::Negative(e) => {
            let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?;
            Ok(Expr::Negative(Box::new(expr)))
        }
        Expr::Between(between) => {
            let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?;
            let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?;
            let high = rewrite_table_scans_in_expr(*between.high, known_rewrites)?;
            Ok(Expr::Between(Between::new(
                Box::new(expr),
                between.negated,
                Box::new(low),
                Box::new(high),
            )))
        }
        Expr::Case(case) => {
            // Optional operand and ELSE branch, then every WHEN/THEN pair.
            let expr = case
                .expr
                .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites))
                .transpose()?
                .map(Box::new);
            let else_expr = case
                .else_expr
                .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites))
                .transpose()?
                .map(Box::new);
            let when_expr = case
                .when_then_expr
                .into_iter()
                .map(|(when, then)| {
                    let when = rewrite_table_scans_in_expr(*when, known_rewrites);
                    let then = rewrite_table_scans_in_expr(*then, known_rewrites);

                    // If both sides fail, the WHEN error is the one reported.
                    match (when, then) {
                        (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))),
                        (Err(e), _) | (_, Err(e)) => Err(e),
                    }
                })
                .collect::<Result<Vec<(Box<Expr>, Box<Expr>)>>>()?;
            Ok(Expr::Case(Case::new(expr, when_expr, else_expr)))
        }
        Expr::Cast(cast) => {
            let expr = rewrite_table_scans_in_expr(*cast.expr, known_rewrites)?;
            Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type)))
        }
        Expr::TryCast(try_cast) => {
            let expr = rewrite_table_scans_in_expr(*try_cast.expr, known_rewrites)?;
            Ok(Expr::TryCast(TryCast::new(
                Box::new(expr),
                try_cast.data_type,
            )))
        }
        Expr::ScalarFunction(sf) => {
            let args = sf
                .args
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            Ok(Expr::ScalarFunction(ScalarFunction {
                func: sf.func,
                args,
            }))
        }
        Expr::AggregateFunction(af) => {
            // Arguments, FILTER clause and ORDER BY expressions are rewritten;
            // the remaining parameters are carried over unchanged.
            let args = af
                .params
                .args
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let filter = af
                .params
                .filter
                .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites))
                .transpose()?
                .map(Box::new);
            let order_by = af
                .params
                .order_by
                .map(|e| {
                    e.into_iter()
                        .map(|sort| {
                            Ok(Sort {
                                expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?,
                                ..sort
                            })
                        })
                        .collect::<Result<Vec<_>>>()
                })
                .transpose()?;
            let params = AggregateFunctionParams {
                args,
                distinct: af.params.distinct,
                filter,
                order_by,
                null_treatment: af.params.null_treatment,
            };
            Ok(Expr::AggregateFunction(AggregateFunction {
                func: af.func,
                params,
            }))
        }
        Expr::WindowFunction(wf) => {
            // Arguments, PARTITION BY and ORDER BY expressions are rewritten;
            // the window frame and null treatment are carried over unchanged.
            let args = wf
                .params
                .args
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let partition_by = wf
                .params
                .partition_by
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let order_by = wf
                .params
                .order_by
                .into_iter()
                .map(|sort| {
                    Ok(Sort {
                        expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?,
                        ..sort
                    })
                })
                .collect::<Result<Vec<_>>>()?;
            let params = WindowFunctionParams {
                args,
                partition_by,
                order_by,
                window_frame: wf.params.window_frame,
                null_treatment: wf.params.null_treatment,
            };
            Ok(Expr::WindowFunction(WindowFunction {
                fun: wf.fun,
                params,
            }))
        }
        Expr::InList(il) => {
            let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?;
            let list = il
                .list
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated)))
        }
        Expr::Exists(exists) => {
            let subquery_plan = rewrite_table_scans(&exists.subquery.subquery, known_rewrites)?;
            let outer_ref_columns = exists
                .subquery
                .outer_ref_columns
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let subquery = Subquery {
                subquery: Arc::new(subquery_plan),
                outer_ref_columns,
            };
            Ok(Expr::Exists(Exists::new(subquery, exists.negated)))
        }
        Expr::InSubquery(is) => {
            let expr = rewrite_table_scans_in_expr(*is.expr, known_rewrites)?;
            let subquery_plan = rewrite_table_scans(&is.subquery.subquery, known_rewrites)?;
            let outer_ref_columns = is
                .subquery
                .outer_ref_columns
                .into_iter()
                .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                .collect::<Result<Vec<Expr>>>()?;
            let subquery = Subquery {
                subquery: Arc::new(subquery_plan),
                outer_ref_columns,
            };
            Ok(Expr::InSubquery(InSubquery::new(
                Box::new(expr),
                subquery,
                is.negated,
            )))
        }
        // TODO: remove the next line after `Expr::Wildcard` is removed in datafusion
        #[expect(deprecated)]
        Expr::Wildcard { qualifier, options } => {
            // Rewrite the planned REPLACE expressions (if any); other options are kept as-is.
            let options = WildcardOptions {
                replace: options
                    .replace
                    .map(|replace| -> Result<PlannedReplaceSelectItem> {
                        Ok(PlannedReplaceSelectItem {
                            planned_expressions: replace
                                .planned_expressions
                                .into_iter()
                                .map(|expr| rewrite_table_scans_in_expr(expr, known_rewrites))
                                .collect::<Result<Vec<_>>>()?,
                            ..replace
                        })
                    })
                    .transpose()?,
                ..*options
            };
            // Swap the wildcard qualifier when it has a known rewrite.
            if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) {
                Ok(Expr::Wildcard {
                    qualifier: Some(rewrite.clone()),
                    options: Box::new(options),
                })
            } else {
                Ok(Expr::Wildcard {
                    qualifier,
                    options: Box::new(options),
                })
            }
        }
        Expr::GroupingSet(gs) => match gs {
            GroupingSet::Rollup(exprs) => {
                let exprs = exprs
                    .into_iter()
                    .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                    .collect::<Result<Vec<Expr>>>()?;
                Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs)))
            }
            GroupingSet::Cube(exprs) => {
                let exprs = exprs
                    .into_iter()
                    .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                    .collect::<Result<Vec<Expr>>>()?;
                Ok(Expr::GroupingSet(GroupingSet::Cube(exprs)))
            }
            GroupingSet::GroupingSets(vec_exprs) => {
                let vec_exprs = vec_exprs
                    .into_iter()
                    .map(|exprs| {
                        exprs
                            .into_iter()
                            .map(|e| rewrite_table_scans_in_expr(e, known_rewrites))
                            .collect::<Result<Vec<Expr>>>()
                    })
                    .collect::<Result<Vec<Vec<Expr>>>>()?;
                Ok(Expr::GroupingSet(GroupingSet::GroupingSets(vec_exprs)))
            }
        },
        Expr::OuterReferenceColumn(dt, col) => {
            // Only the qualifier is rewritten; unqualified outer references are left untouched.
            if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) {
                Ok(Expr::OuterReferenceColumn(
                    dt,
                    Column::new(Some(rewrite.clone()), &col.name),
                ))
            } else {
                Ok(Expr::OuterReferenceColumn(dt, col))
            }
        }
        Expr::Unnest(unnest) => {
            let expr = rewrite_table_scans_in_expr(*unnest.expr, known_rewrites)?;
            Ok(Expr::Unnest(Unnest::new(expr)))
        }
        // These variants hold no child expressions to rewrite; returned unchanged.
        Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr),
    }
}
658
659struct SQLFederationPlanner {
660    executor: Arc<dyn SQLExecutor>,
661}
662
663impl SQLFederationPlanner {
664    pub fn new(executor: Arc<dyn SQLExecutor>) -> Self {
665        Self { executor }
666    }
667}
668
669#[async_trait]
670impl FederationPlanner for SQLFederationPlanner {
671    async fn plan_federation(
672        &self,
673        node: &FederatedPlanNode,
674        _session_state: &SessionState,
675    ) -> Result<Arc<dyn ExecutionPlan>> {
676        let schema = Arc::new(node.plan().schema().as_arrow().clone());
677        let input = Arc::new(VirtualExecutionPlan::new(
678            node.plan().clone(),
679            Arc::clone(&self.executor),
680        ));
681        let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema);
682        Ok(Arc::new(schema_cast_exec))
683    }
684}
685
686#[derive(Debug, Clone)]
687struct VirtualExecutionPlan {
688    plan: LogicalPlan,
689    executor: Arc<dyn SQLExecutor>,
690    props: PlanProperties,
691}
692
693impl VirtualExecutionPlan {
694    pub fn new(plan: LogicalPlan, executor: Arc<dyn SQLExecutor>) -> Self {
695        let schema: Schema = plan.schema().as_ref().into();
696        let props = PlanProperties::new(
697            EquivalenceProperties::new(Arc::new(schema)),
698            Partitioning::UnknownPartitioning(1),
699            EmissionType::Incremental,
700            Boundedness::Bounded,
701        );
702        Self {
703            plan,
704            executor,
705            props,
706        }
707    }
708
709    fn schema(&self) -> SchemaRef {
710        let df_schema = self.plan.schema().as_ref();
711        Arc::new(Schema::from(df_schema))
712    }
713
714    fn sql(&self) -> Result<String> {
715        // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table.
716        let mut known_rewrites = HashMap::new();
717        let plan = &rewrite_table_scans(&self.plan, &mut known_rewrites)?;
718        let mut ast = self.plan_to_sql(plan)?;
719
720        if let Some(analyzer) = self.executor.ast_analyzer() {
721            ast = analyzer(ast)?;
722        }
723
724        Ok(format!("{ast}"))
725    }
726
727    fn plan_to_sql(&self, plan: &LogicalPlan) -> Result<Statement> {
728        Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan)
729    }
730}
731
732impl DisplayAs for VirtualExecutionPlan {
733    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
734        write!(f, "VirtualExecutionPlan")?;
735        let Ok(ast) = plan_to_sql(&self.plan) else {
736            return Ok(());
737        };
738        write!(f, " name={}", self.executor.name())?;
739        if let Some(ctx) = self.executor.compute_context() {
740            write!(f, " compute_context={ctx}")?;
741        };
742
743        write!(f, " sql={ast}")?;
744        if let Ok(query) = self.sql() {
745            write!(f, " rewritten_sql={query}")?;
746        };
747
748        write!(f, " sql={ast}")
749    }
750}
751
752impl ExecutionPlan for VirtualExecutionPlan {
753    fn name(&self) -> &str {
754        "sql_federation_exec"
755    }
756
757    fn as_any(&self) -> &dyn Any {
758        self
759    }
760
761    fn schema(&self) -> SchemaRef {
762        self.schema()
763    }
764
765    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
766        vec![]
767    }
768
769    fn with_new_children(
770        self: Arc<Self>,
771        _: Vec<Arc<dyn ExecutionPlan>>,
772    ) -> Result<Arc<dyn ExecutionPlan>> {
773        Ok(self)
774    }
775
776    fn execute(
777        &self,
778        _partition: usize,
779        _context: Arc<TaskContext>,
780    ) -> Result<SendableRecordBatchStream> {
781        let query = self.plan_to_sql(&self.plan)?.to_string();
782        self.executor.execute(query.as_str(), self.schema())
783    }
784
785    fn properties(&self) -> &PlanProperties {
786        &self.props
787    }
788}
789
#[cfg(test)]
mod tests {
    //! Tests for `rewrite_table_scans`: each test rewrites a logical plan that
    //! scans a federated table, unparses the result back to SQL, and compares
    //! the text with the expected remote query.

    use crate::FederatedTableProviderAdaptor;
    use datafusion::{
        arrow::datatypes::{DataType, Field},
        catalog::{MemorySchemaProvider, SchemaProvider},
        common::Column,
        datasource::{DefaultTableSource, TableProvider},
        error::DataFusionError,
        execution::context::SessionContext,
        logical_expr::LogicalPlanBuilder,
        sql::unparser::dialect::{DefaultDialect, Dialect},
    };

    use super::*;

    /// Stub [`SQLExecutor`] for the tests: it supplies a name and a dialect,
    /// and every other method returns `NotImplemented`.
    struct TestSQLExecutor {}

    /// Builds the `NotImplemented` error returned by the stubbed executor methods.
    fn not_implemented<T>(message: &str) -> Result<T> {
        Err(DataFusionError::NotImplemented(message.to_string()))
    }

    #[async_trait]
    impl SQLExecutor for TestSQLExecutor {
        fn name(&self) -> &str {
            "test_sql_table_source"
        }

        fn compute_context(&self) -> Option<String> {
            None
        }

        fn dialect(&self) -> Arc<dyn Dialect> {
            Arc::new(DefaultDialect {})
        }

        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
            not_implemented("execute not implemented")
        }

        async fn table_names(&self) -> Result<Vec<String>> {
            not_implemented("table inference not implemented")
        }

        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
            not_implemented("table inference not implemented")
        }
    }

    /// Creates a federated table provider backed by [`TestSQLExecutor`] whose
    /// remote table is named `remote_table` and has columns `a: Int64`,
    /// `b: Utf8` and `c: Date32` (all non-nullable).
    fn get_test_table_provider() -> Arc<dyn TableProvider> {
        let provider = Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {})));

        let fields = vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Date32, false),
        ];
        let source = SQLTableSource::new_with_schema(
            provider,
            "remote_table".to_string(),
            Arc::new(Schema::new(fields)),
        )
        .expect("to have a valid SQLTableSource");

        Arc::new(FederatedTableProviderAdaptor::new(Arc::new(source)))
    }

    /// Wraps [`get_test_table_provider`] in a [`DefaultTableSource`] for use
    /// with [`LogicalPlanBuilder`].
    fn get_test_table_source() -> Arc<DefaultTableSource> {
        Arc::new(DefaultTableSource::new(get_test_table_provider()))
    }

    /// Builds a session in which the test provider is registered twice in the
    /// default `datafusion` catalog: as `foo.df_table` (in a freshly created
    /// `foo` schema) and as `public.app_table`.
    fn get_test_df_context() -> SessionContext {
        let ctx = SessionContext::new();
        let catalog = ctx
            .catalog("datafusion")
            .expect("default catalog is datafusion");

        let foo_schema: Arc<dyn SchemaProvider> = Arc::new(MemorySchemaProvider::new());
        catalog
            .register_schema("foo", Arc::clone(&foo_schema))
            .expect("to register schema");
        foo_schema
            .register_table("df_table".to_string(), get_test_table_provider())
            .expect("to register table");

        catalog
            .schema("public")
            .expect("public schema should exist")
            .register_table("app_table".to_string(), get_test_table_provider())
            .expect("to register table");

        ctx
    }

    /// Installs a debug-level tracing subscriber as the global default.
    ///
    /// The result of `set_global_default` is deliberately discarded: several
    /// tests call this helper, and presumably only the first call in a process
    /// can succeed.
    fn init_tracing() {
        let subscriber = tracing_subscriber::FmtSubscriber::builder()
            .with_env_filter("debug")
            .with_ansi(true)
            .finish();
        let _ = tracing::subscriber::set_global_default(subscriber);
    }

    /// Plans `sql_query` in `ctx`, rewrites its table scans, unparses the
    /// rewritten plan and asserts that the resulting SQL equals `expected_sql`.
    ///
    /// The intermediate plans and the unparsed SQL are printed to help debug a
    /// failing case.
    async fn test_sql(ctx: &SessionContext, sql_query: &str, expected_sql: &str) -> Result<()> {
        let data_frame = ctx.sql(sql_query).await?;
        println!("before optimization: \n{:#?}", data_frame.logical_plan());

        let mut known_rewrites = HashMap::new();
        let rewritten_plan = rewrite_table_scans(data_frame.logical_plan(), &mut known_rewrites)?;
        println!("rewritten_plan: \n{rewritten_plan:#?}");

        let unparsed_sql = plan_to_sql(&rewritten_plan)?.to_string();
        println!("unparsed_sql: \n{unparsed_sql}");

        assert_eq!(unparsed_sql, expected_sql, "SQL under test: {sql_query}");

        Ok(())
    }

    /// Runs [`test_sql`] for every `(input SQL, expected SQL)` pair in order,
    /// stopping at the first error.
    async fn test_sql_cases(ctx: &SessionContext, cases: &[(&str, &str)]) -> Result<()> {
        for &(sql_query, expected_sql) in cases {
            test_sql(ctx, sql_query, expected_sql).await?;
        }
        Ok(())
    }

    #[test]
    fn test_rewrite_table_scans_basic() -> Result<()> {
        let plan = LogicalPlanBuilder::scan("foo.df_table", get_test_table_source(), None)?
            .project(vec![
                Expr::Column(Column::from_qualified_name("foo.df_table.a")),
                Expr::Column(Column::from_qualified_name("foo.df_table.b")),
                Expr::Column(Column::from_qualified_name("foo.df_table.c")),
            ])?
            .build()?;

        let mut known_rewrites = HashMap::new();
        let rewritten_plan = rewrite_table_scans(&plan, &mut known_rewrites)?;
        println!("rewritten_plan: \n{rewritten_plan:#?}");

        let unparsed_sql = plan_to_sql(&rewritten_plan)?.to_string();
        println!("unparsed_sql: \n{unparsed_sql}");

        assert_eq!(
            unparsed_sql,
            r#"SELECT remote_table.a, remote_table.b, remote_table.c FROM remote_table"#
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rewrite_table_scans_agg() -> Result<()> {
        init_tracing();
        let ctx = get_test_df_context();

        let cases = [
            (
                "SELECT MAX(a) FROM foo.df_table",
                r#"SELECT max(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT foo.df_table.a FROM foo.df_table",
                r#"SELECT remote_table.a FROM remote_table"#,
            ),
            (
                "SELECT MIN(a) FROM foo.df_table",
                r#"SELECT min(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT AVG(a) FROM foo.df_table",
                r#"SELECT avg(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT SUM(a) FROM foo.df_table",
                r#"SELECT sum(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT COUNT(a) FROM foo.df_table",
                r#"SELECT count(remote_table.a) FROM remote_table"#,
            ),
            (
                "SELECT COUNT(a) as cnt FROM foo.df_table",
                r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#,
            ),
            (
                "SELECT app_table from (SELECT a as app_table FROM app_table) b",
                r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#,
            ),
            (
                "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b",
                r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#,
            ),
            // multiple occurrences of the same table in single aggregation expression
            (
                "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table",
                r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#,
            ),
            // different tables in single aggregation expression
            (
                "SELECT COUNT(CASE WHEN appt.a > 0 THEN appt.a ELSE dft.a END) FROM app_table as appt, foo.df_table as dft",
                "SELECT count(CASE WHEN (appt.a > 0) THEN appt.a ELSE dft.a END) FROM remote_table AS appt CROSS JOIN remote_table AS dft",
            ),
        ];

        test_sql_cases(&ctx, &cases).await
    }

    #[tokio::test]
    async fn test_rewrite_table_scans_alias() -> Result<()> {
        init_tracing();
        let ctx = get_test_df_context();

        let cases = [
            (
                "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)",
                r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#,
            ),
            (
                "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)",
                r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#,
            ),
            (
                "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)",
                r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM remote_table)"#,
            ),
        ];

        test_sql_cases(&ctx, &cases).await
    }

    #[tokio::test]
    async fn test_rewrite_table_scans_limit_offset() -> Result<()> {
        init_tracing();
        let ctx = get_test_df_context();

        let cases = [
            // Basic LIMIT
            (
                "SELECT a FROM foo.df_table LIMIT 5",
                r#"SELECT remote_table.a FROM remote_table LIMIT 5"#,
            ),
            // Basic OFFSET
            (
                "SELECT a FROM foo.df_table OFFSET 5",
                r#"SELECT remote_table.a FROM remote_table OFFSET 5"#,
            ),
            // OFFSET after LIMIT
            (
                "SELECT a FROM foo.df_table LIMIT 10 OFFSET 5",
                r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#,
            ),
            // LIMIT after OFFSET
            (
                "SELECT a FROM foo.df_table OFFSET 5 LIMIT 10",
                r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#,
            ),
            // Zero OFFSET
            (
                "SELECT a FROM foo.df_table OFFSET 0",
                r#"SELECT remote_table.a FROM remote_table OFFSET 0"#,
            ),
            // Zero LIMIT
            (
                "SELECT a FROM foo.df_table LIMIT 0",
                r#"SELECT remote_table.a FROM remote_table LIMIT 0"#,
            ),
            // Zero LIMIT and OFFSET
            (
                "SELECT a FROM foo.df_table LIMIT 0 OFFSET 0",
                r#"SELECT remote_table.a FROM remote_table LIMIT 0 OFFSET 0"#,
            ),
        ];

        test_sql_cases(&ctx, &cases).await
    }
}