datafusion_table_providers/util/
table_arg_replace.rs

1use std::ops::ControlFlow;
2
3use datafusion::sql::{
4    sqlparser::ast::{
5        FunctionArg, Ident, ObjectName, TableAlias, TableFactor, TableFunctionArgs, VisitorMut,
6    },
7    TableReference,
8};
9
10#[derive(Debug, Clone, PartialEq, Eq, Default)]
11pub struct TableArgReplace {
12    pub tables: Vec<(TableReference, TableFunctionArgs)>,
13}
14
15impl TableArgReplace {
16    /// Constructs a new `TableArgReplace` instance.
17    pub fn new(tables: Vec<(TableReference, Vec<FunctionArg>)>) -> Self {
18        Self {
19            tables: tables
20                .into_iter()
21                .map(|(table, args)| {
22                    (
23                        table,
24                        TableFunctionArgs {
25                            args,
26                            settings: None,
27                        },
28                    )
29                })
30                .collect(),
31        }
32    }
33
34    /// Adds a new table argument replacement.
35    pub fn with(mut self, table: TableReference, args: Vec<FunctionArg>) -> Self {
36        self.tables.push((
37            table,
38            TableFunctionArgs {
39                args,
40                settings: None,
41            },
42        ));
43        self
44    }
45
46    #[cfg(feature = "federation")]
47    /// Converts the `TableArgReplace` instance into an `AstAnalyzer`.
48    pub fn into_analyzer(self) -> datafusion_federation::sql::AstAnalyzer {
49        let mut visitor = self;
50        let x = move |mut statement: datafusion::sql::sqlparser::ast::Statement| {
51            let _ = datafusion::sql::sqlparser::ast::VisitMut::visit(&mut statement, &mut visitor);
52            Ok(statement)
53        };
54        Box::new(x)
55    }
56}
57
58impl VisitorMut for TableArgReplace {
59    type Break = ();
60    fn pre_visit_table_factor(
61        &mut self,
62        table_factor: &mut TableFactor,
63    ) -> ControlFlow<Self::Break> {
64        if let TableFactor::Table {
65            name, args, alias, ..
66        } = table_factor
67        {
68            let name_as_tableref = name_to_table_reference(name);
69            if let Some((table, arg)) = self
70                .tables
71                .iter()
72                .find(|(t, _)| t.resolved_eq(&name_as_tableref))
73            {
74                *args = Some(arg.clone());
75                if alias.is_none() {
76                    *alias = Some(TableAlias {
77                        name: Ident::new(table.table()),
78                        columns: vec![],
79                    })
80                }
81            }
82        }
83        ControlFlow::Continue(())
84    }
85}
86
87fn name_to_table_reference(name: &ObjectName) -> TableReference {
88    let first = name
89        .0
90        .first()
91        .map(|n| n.as_ident().expect("expected Ident").value.to_string());
92    let second = name
93        .0
94        .get(1)
95        .map(|n| n.as_ident().expect("expected Ident").value.to_string());
96    let third = name
97        .0
98        .get(2)
99        .map(|n| n.as_ident().expect("expected Ident").value.to_string());
100
101    match (first, second, third) {
102        (Some(first), Some(second), Some(third)) => TableReference::full(first, second, third),
103        (Some(first), Some(second), None) => TableReference::partial(first, second),
104        (Some(first), None, None) => TableReference::bare(first),
105        _ => panic!("Invalid table name"),
106    }
107}