datafusion_federation/sql/
ast_analyzer.rs1use std::ops::ControlFlow;
2
3use datafusion::sql::{
4 sqlparser::ast::{
5 FunctionArg, Ident, ObjectName, Statement, TableAlias, TableFactor, TableFunctionArgs,
6 VisitMut, VisitorMut,
7 },
8 TableReference,
9};
10
11use super::AstAnalyzer;
12
13pub fn replace_table_args_analyzer(mut visitor: TableArgReplace) -> AstAnalyzer {
14 let x = move |mut statement: Statement| {
15 let _ = VisitMut::visit(&mut statement, &mut visitor);
16 Ok(statement)
17 };
18 Box::new(x)
19}
20
21#[derive(Debug, Clone, PartialEq, Eq, Default)]
40pub struct TableArgReplace {
41 pub tables: Vec<(TableReference, TableFunctionArgs)>,
42}
43
44impl TableArgReplace {
45 pub fn new(tables: Vec<(TableReference, Vec<FunctionArg>)>) -> Self {
47 Self {
48 tables: tables
49 .into_iter()
50 .map(|(table, args)| {
51 (
52 table,
53 TableFunctionArgs {
54 args,
55 settings: None,
56 },
57 )
58 })
59 .collect(),
60 }
61 }
62
63 pub fn with(mut self, table: TableReference, args: Vec<FunctionArg>) -> Self {
65 self.tables.push((
66 table,
67 TableFunctionArgs {
68 args,
69 settings: None,
70 },
71 ));
72 self
73 }
74
75 pub fn into_analyzer(self) -> AstAnalyzer {
77 replace_table_args_analyzer(self)
78 }
79}
80
81impl VisitorMut for TableArgReplace {
82 type Break = ();
83 fn pre_visit_table_factor(
84 &mut self,
85 table_factor: &mut TableFactor,
86 ) -> ControlFlow<Self::Break> {
87 if let TableFactor::Table {
88 name, args, alias, ..
89 } = table_factor
90 {
91 let name_as_tableref = name_to_table_reference(name);
92 if let Some((table, arg)) = self
93 .tables
94 .iter()
95 .find(|(t, _)| t.resolved_eq(&name_as_tableref))
96 {
97 *args = Some(arg.clone());
98 if alias.is_none() {
99 *alias = Some(TableAlias {
100 explicit: true,
101 name: Ident::new(table.table()),
102 columns: vec![],
103 })
104 }
105 }
106 }
107 ControlFlow::Continue(())
108 }
109}
110
111fn name_to_table_reference(name: &ObjectName) -> TableReference {
112 let first = name
113 .0
114 .first()
115 .map(|n| n.as_ident().expect("expected Ident").value.to_string());
116 let second = name
117 .0
118 .get(1)
119 .map(|n| n.as_ident().expect("expected Ident").value.to_string());
120 let third = name
121 .0
122 .get(2)
123 .map(|n| n.as_ident().expect("expected Ident").value.to_string());
124
125 match (first, second, third) {
126 (Some(first), Some(second), Some(third)) => TableReference::full(first, second, third),
127 (Some(first), Some(second), None) => TableReference::partial(first, second),
128 (Some(first), None, None) => TableReference::bare(first),
129 _ => panic!("Invalid table name"),
130 }
131}