datafusion_federation/sql/
ast_analyzer.rs1use std::ops::ControlFlow;
2
3use datafusion::sql::{
4 sqlparser::ast::{
5 FunctionArg, Ident, ObjectName, Statement, TableAlias, TableFactor, TableFunctionArgs,
6 VisitMut, VisitorMut,
7 },
8 TableReference,
9};
10
11use super::AstAnalyzer;
12
13pub fn replace_table_args_analyzer(mut visitor: TableArgReplace) -> AstAnalyzer {
14 let x = move |mut statement: Statement| {
15 let _ = VisitMut::visit(&mut statement, &mut visitor);
16 Ok(statement)
17 };
18 Box::new(x)
19}
20
21#[derive(Debug, Clone, PartialEq, Eq, Default)]
40pub struct TableArgReplace {
41 pub tables: Vec<(TableReference, TableFunctionArgs)>,
42}
43
44impl TableArgReplace {
45 pub fn new(tables: Vec<(TableReference, Vec<FunctionArg>)>) -> Self {
47 Self {
48 tables: tables
49 .into_iter()
50 .map(|(table, args)| {
51 (
52 table,
53 TableFunctionArgs {
54 args,
55 settings: None,
56 },
57 )
58 })
59 .collect(),
60 }
61 }
62
63 pub fn with(mut self, table: TableReference, args: Vec<FunctionArg>) -> Self {
65 self.tables.push((
66 table,
67 TableFunctionArgs {
68 args,
69 settings: None,
70 },
71 ));
72 self
73 }
74
75 pub fn into_analyzer(self) -> AstAnalyzer {
77 replace_table_args_analyzer(self)
78 }
79}
80
81impl VisitorMut for TableArgReplace {
82 type Break = ();
83 fn pre_visit_table_factor(
84 &mut self,
85 table_factor: &mut TableFactor,
86 ) -> ControlFlow<Self::Break> {
87 if let TableFactor::Table {
88 name, args, alias, ..
89 } = table_factor
90 {
91 let name_as_tableref = name_to_table_reference(name);
92 if let Some((table, arg)) = self
93 .tables
94 .iter()
95 .find(|(t, _)| t.resolved_eq(&name_as_tableref))
96 {
97 *args = Some(arg.clone());
98 if alias.is_none() {
99 *alias = Some(TableAlias {
100 name: Ident::new(table.table()),
101 columns: vec![],
102 })
103 }
104 }
105 }
106 ControlFlow::Continue(())
107 }
108}
109
110fn name_to_table_reference(name: &ObjectName) -> TableReference {
111 let first = name
112 .0
113 .first()
114 .map(|n| n.as_ident().expect("expected Ident").value.to_string());
115 let second = name
116 .0
117 .get(1)
118 .map(|n| n.as_ident().expect("expected Ident").value.to_string());
119 let third = name
120 .0
121 .get(2)
122 .map(|n| n.as_ident().expect("expected Ident").value.to_string());
123
124 match (first, second, third) {
125 (Some(first), Some(second), Some(third)) => TableReference::full(first, second, third),
126 (Some(first), Some(second), None) => TableReference::partial(first, second),
127 (Some(first), None, None) => TableReference::bare(first),
128 _ => panic!("Invalid table name"),
129 }
130}