use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
use core::ops::ControlFlow;
/// A type that can be walked by a read-only [`Visitor`].
///
/// This file implements it for `Option`, `Vec` and `Box` wrappers (which
/// forward to their contents) and, through `visit_noop!`, for a set of leaf
/// scalar types that have nothing to visit.
pub trait Visit {
/// Traverses `self`, handing `visitor` along, and returns a `ControlFlow`
/// whose `Break` payload is the visitor's own `Break` type.
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
}
/// A type that can be walked by a [`VisitorMut`], which receives mutable
/// access to the nodes it is shown.
///
/// This file implements it for `Option`, `Vec` and `Box` wrappers (which
/// forward to their contents) and, through `visit_noop!`, for a set of leaf
/// scalar types that have nothing to visit.
pub trait VisitMut {
/// Traverses `self` mutably, handing `visitor` along, and returns a
/// `ControlFlow` whose `Break` payload is the visitor's own `Break` type.
fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
}
/// Visits the wrapped value when there is one; `None` has nothing to visit
/// and simply lets the traversal continue.
impl<T: Visit> Visit for Option<T> {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        match self {
            Some(inner) => inner.visit(visitor),
            None => ControlFlow::Continue(()),
        }
    }
}
/// Visits the elements front to back, stopping at the first element whose
/// traversal breaks and returning that break value.
impl<T: Visit> Visit for Vec<T> {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        self.iter().try_for_each(|item| item.visit(visitor))
    }
}
/// A box is transparent to traversal: visiting it visits the boxed value.
impl<T: Visit> Visit for Box<T> {
    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
        // Deref twice (`&Box<T>` -> `Box<T>` -> `T`) so the call lands on
        // `T`'s implementation instead of recursing into this one.
        (**self).visit(visitor)
    }
}
/// Mutably visits the wrapped value when there is one; `None` has nothing to
/// visit and simply lets the traversal continue.
impl<T: VisitMut> VisitMut for Option<T> {
    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
        match self {
            Some(inner) => inner.visit(visitor),
            None => ControlFlow::Continue(()),
        }
    }
}
/// Mutably visits the elements front to back, stopping at the first element
/// whose traversal breaks; later elements are left untouched.
impl<T: VisitMut> VisitMut for Vec<T> {
    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
        self.iter_mut().try_for_each(|item| item.visit(visitor))
    }
}
impl<T: VisitMut> VisitMut for Box<T> {
fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
T::visit(self, visitor)
}
}
/// Implements [`Visit`] and [`VisitMut`] for each listed type as a no-op:
/// the visitor is never consulted and the traversal always continues.
macro_rules! visit_noop {
    ($($t:ty),+) => {
        $(
            impl Visit for $t {
                fn visit<V: Visitor>(&self, _: &mut V) -> ControlFlow<V::Break> {
                    ControlFlow::Continue(())
                }
            }

            impl VisitMut for $t {
                fn visit<V: VisitorMut>(&mut self, _: &mut V) -> ControlFlow<V::Break> {
                    ControlFlow::Continue(())
                }
            }
        )+
    };
}

// Leaf scalar types: there is nothing inside them to traverse.
visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);

// Only compiled when the `bigdecimal` feature is enabled.
#[cfg(feature = "bigdecimal")]
visit_noop!(bigdecimal::BigDecimal);
/// Read-only visitor over five kinds of AST node: queries, relations
/// (`ObjectName`), table factors, expressions and statements.
///
/// Every hook has a default body returning `ControlFlow::Continue(())`, so an
/// implementor overrides only the hooks it needs. The hooks are called by
/// [`Visit`] implementations for the AST types, which are not defined in this
/// file. The expectations in this file's tests show each `pre_*` hook firing
/// before a node's children and the matching `post_*` hook after them.
pub trait Visitor {
/// Payload of `ControlFlow::Break` when a hook asks to stop the traversal.
type Break;
/// Hook for a `Query`; the `pre` half of the query pre/post pair.
fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a `Query`; the `post` half of the query pre/post pair.
fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a relation name; the `pre` half of the relation pre/post pair.
fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a relation name; the `post` half of the relation pre/post pair.
fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a `TableFactor`; the `pre` half of the table-factor pre/post pair.
fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a `TableFactor`; the `post` half of the table-factor pre/post pair.
fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for an `Expr`; the `pre` half of the expression pre/post pair.
fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for an `Expr`; the `post` half of the expression pre/post pair.
fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a `Statement`; the `pre` half of the statement pre/post pair.
fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a `Statement`; the `post` half of the statement pre/post pair.
fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
}
/// Mutating counterpart of [`Visitor`]: the same ten hooks, but each one
/// receives `&mut` access to the node it is shown.
///
/// Every hook has a default body returning `ControlFlow::Continue(())`, so an
/// implementor overrides only the hooks it needs. The hooks are called by
/// [`VisitMut`] implementations for the AST types, which are not defined in
/// this file.
pub trait VisitorMut {
/// Payload of `ControlFlow::Break` when a hook asks to stop the traversal.
type Break;
/// Hook for a `Query`; the `pre` half of the query pre/post pair.
fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a `Query`; the `post` half of the query pre/post pair.
fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a relation name; the `pre` half of the relation pre/post pair.
fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a relation name; the `post` half of the relation pre/post pair.
fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a `TableFactor`; the `pre` half of the table-factor pre/post pair.
fn pre_visit_table_factor(
&mut self,
_table_factor: &mut TableFactor,
) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a `TableFactor`; the `post` half of the table-factor pre/post pair.
fn post_visit_table_factor(
&mut self,
_table_factor: &mut TableFactor,
) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for an `Expr`; the `pre` half of the expression pre/post pair.
fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for an `Expr`; the `post` half of the expression pre/post pair.
fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a `Statement`; the `pre` half of the statement pre/post pair.
fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
/// Hook for a `Statement`; the `post` half of the statement pre/post pair.
fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
}
/// Adapter that lets a plain closure act as a relation visitor.
///
/// Wrapping an `FnMut(&ObjectName)` closure yields a [`Visitor`] that calls it
/// from `pre_visit_relation`; wrapping an `FnMut(&mut ObjectName)` closure
/// yields a [`VisitorMut`] that calls it from `post_visit_relation`. All other
/// hooks keep their default bodies.
struct RelationVisitor<F>(F);

impl<F, E> Visitor for RelationVisitor<F>
where
    F: FnMut(&ObjectName) -> ControlFlow<E>,
{
    type Break = E;

    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
        (self.0)(relation)
    }
}

impl<F, E> VisitorMut for RelationVisitor<F>
where
    F: FnMut(&mut ObjectName) -> ControlFlow<E>,
{
    type Break = E;

    // NOTE(review): the mutable adapter hooks the *post* callback while the
    // shared one hooks *pre*; presumably so a relation is rewritten only after
    // its traversal finished — confirm before changing.
    fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
        (self.0)(relation)
    }
}
/// Calls `f` with a shared reference to every relation reported while
/// traversing `v`.
///
/// `f` runs from the `pre_visit_relation` hook. Returning
/// `ControlFlow::Break(e)` from `f` stops the traversal and `Break(e)` is
/// returned; otherwise the result is `ControlFlow::Continue(())`.
pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
where
    V: Visit,
    F: FnMut(&ObjectName) -> ControlFlow<E>,
{
    v.visit(&mut RelationVisitor(f))
}
/// Calls `f` with a mutable reference to every relation reported while
/// traversing `v`.
///
/// `f` runs from the `post_visit_relation` hook. Returning
/// `ControlFlow::Break(e)` from `f` stops the traversal and `Break(e)` is
/// returned; otherwise the result is `ControlFlow::Continue(())`.
pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
where
    V: VisitMut,
    F: FnMut(&mut ObjectName) -> ControlFlow<E>,
{
    v.visit(&mut RelationVisitor(f))
}
/// Adapter that lets a plain closure act as an expression visitor.
///
/// Wrapping an `FnMut(&Expr)` closure yields a [`Visitor`] that calls it from
/// `pre_visit_expr`; wrapping an `FnMut(&mut Expr)` closure yields a
/// [`VisitorMut`] that calls it from `post_visit_expr`. All other hooks keep
/// their default bodies.
struct ExprVisitor<F>(F);

impl<F, E> Visitor for ExprVisitor<F>
where
    F: FnMut(&Expr) -> ControlFlow<E>,
{
    type Break = E;

    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
        (self.0)(expr)
    }
}

impl<F, E> VisitorMut for ExprVisitor<F>
where
    F: FnMut(&mut Expr) -> ControlFlow<E>,
{
    type Break = E;

    // NOTE(review): the mutable adapter hooks the *post* callback while the
    // shared one hooks *pre*; presumably so an expression is rewritten only
    // after its sub-expressions were traversed — confirm before changing.
    fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
        (self.0)(expr)
    }
}
/// Calls `f` with a shared reference to every expression reported while
/// traversing `v`.
///
/// `f` runs from the `pre_visit_expr` hook. Returning
/// `ControlFlow::Break(e)` from `f` stops the traversal and `Break(e)` is
/// returned; otherwise the result is `ControlFlow::Continue(())`.
pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
where
    V: Visit,
    F: FnMut(&Expr) -> ControlFlow<E>,
{
    v.visit(&mut ExprVisitor(f))
}
/// Calls `f` with a mutable reference to every expression reported while
/// traversing `v`.
///
/// `f` runs from the `post_visit_expr` hook. Returning
/// `ControlFlow::Break(e)` from `f` stops the traversal and `Break(e)` is
/// returned; otherwise the result is `ControlFlow::Continue(())`.
pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
where
    V: VisitMut,
    F: FnMut(&mut Expr) -> ControlFlow<E>,
{
    let mut adapter = ExprVisitor(f);
    v.visit(&mut adapter)
}
/// Adapter that lets a plain closure act as a statement visitor.
///
/// Wrapping an `FnMut(&Statement)` closure yields a [`Visitor`] that calls it
/// from `pre_visit_statement`; wrapping an `FnMut(&mut Statement)` closure
/// yields a [`VisitorMut`] that calls it from `post_visit_statement`. All
/// other hooks keep their default bodies.
struct StatementVisitor<F>(F);

impl<F, E> Visitor for StatementVisitor<F>
where
    F: FnMut(&Statement) -> ControlFlow<E>,
{
    type Break = E;

    fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
        (self.0)(statement)
    }
}

impl<F, E> VisitorMut for StatementVisitor<F>
where
    F: FnMut(&mut Statement) -> ControlFlow<E>,
{
    type Break = E;

    // NOTE(review): the mutable adapter hooks the *post* callback while the
    // shared one hooks *pre*; presumably so a statement is rewritten only
    // after its contents were traversed — confirm before changing.
    fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
        (self.0)(statement)
    }
}
/// Calls `f` with a shared reference to every statement reported while
/// traversing `v`.
///
/// `f` runs from the `pre_visit_statement` hook. Returning
/// `ControlFlow::Break(e)` from `f` stops the traversal and `Break(e)` is
/// returned; otherwise the result is `ControlFlow::Continue(())`.
pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
where
    V: Visit,
    F: FnMut(&Statement) -> ControlFlow<E>,
{
    v.visit(&mut StatementVisitor(f))
}
/// Calls `f` with a mutable reference to every statement reported while
/// traversing `v`.
///
/// `f` runs from the `post_visit_statement` hook. Returning
/// `ControlFlow::Break(e)` from `f` stops the traversal and `Break(e)` is
/// returned; otherwise the result is `ControlFlow::Continue(())`.
pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
where
    V: VisitMut,
    F: FnMut(&mut Statement) -> ControlFlow<E>,
{
    let mut adapter = StatementVisitor(f);
    v.visit(&mut adapter)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dialect::GenericDialect;
use crate::parser::Parser;
use crate::tokenizer::Tokenizer;
/// Visitor used by the tests: it never breaks and appends one line to
/// `visited` per hook call, formatted as `"<PRE|POST>: <KIND>: <node>"` using
/// the node's `Display` output.
#[derive(Default)]
struct TestVisitor {
    visited: Vec<String>,
}

impl TestVisitor {
    /// Appends `entry` to the log and lets the traversal continue.
    fn record(&mut self, entry: String) -> ControlFlow<()> {
        self.visited.push(entry);
        ControlFlow::Continue(())
    }
}

impl Visitor for TestVisitor {
    type Break = ();

    fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
        self.record(format!("PRE: QUERY: {query}"))
    }

    fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
        self.record(format!("POST: QUERY: {query}"))
    }

    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
        self.record(format!("PRE: RELATION: {relation}"))
    }

    fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
        self.record(format!("POST: RELATION: {relation}"))
    }

    fn pre_visit_table_factor(
        &mut self,
        table_factor: &TableFactor,
    ) -> ControlFlow<Self::Break> {
        self.record(format!("PRE: TABLE FACTOR: {table_factor}"))
    }

    fn post_visit_table_factor(
        &mut self,
        table_factor: &TableFactor,
    ) -> ControlFlow<Self::Break> {
        self.record(format!("POST: TABLE FACTOR: {table_factor}"))
    }

    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
        self.record(format!("PRE: EXPR: {expr}"))
    }

    fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
        self.record(format!("POST: EXPR: {expr}"))
    }

    fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
        self.record(format!("PRE: STATEMENT: {statement}"))
    }

    fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
        self.record(format!("POST: STATEMENT: {statement}"))
    }
}
fn do_visit(sql: &str) -> Vec<String> {
let dialect = GenericDialect {};
let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
let s = Parser::new(&dialect)
.with_tokens(tokens)
.parse_statement()
.unwrap();
let mut visitor = TestVisitor::default();
s.visit(&mut visitor);
visitor.visited
}
/// Checks the exact order of visitor hook calls for a handful of statements.
///
/// Each case pairs an input SQL string with the expected `TestVisitor` log.
/// The log shows that `pre_*` hooks fire before a node's children and
/// `post_*` hooks after them.
#[test]
fn test_sql() {
    let tests = vec![
        // Single aliased table.
        (
            "SELECT * from table_name as my_table",
            vec![
                "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
                "PRE: QUERY: SELECT * FROM table_name AS my_table",
                "PRE: TABLE FACTOR: table_name AS my_table",
                "PRE: RELATION: table_name",
                "POST: RELATION: table_name",
                "POST: TABLE FACTOR: table_name AS my_table",
                "POST: QUERY: SELECT * FROM table_name AS my_table",
                "POST: STATEMENT: SELECT * FROM table_name AS my_table",
            ],
        ),
        // Join with an ON expression.
        (
            "SELECT * from t1 join t2 on t1.id = t2.t1_id",
            vec![
                "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                "PRE: TABLE FACTOR: t1",
                "PRE: RELATION: t1",
                "POST: RELATION: t1",
                "POST: TABLE FACTOR: t1",
                "PRE: TABLE FACTOR: t2",
                "PRE: RELATION: t2",
                "POST: RELATION: t2",
                "POST: TABLE FACTOR: t2",
                "PRE: EXPR: t1.id = t2.t1_id",
                "PRE: EXPR: t1.id",
                "POST: EXPR: t1.id",
                "PRE: EXPR: t2.t1_id",
                "POST: EXPR: t2.t1_id",
                "POST: EXPR: t1.id = t2.t1_id",
                "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
            ],
        ),
        // Subquery nested inside an expression. This case used to be listed
        // twice, byte for byte; the duplicate added no coverage.
        (
            "SELECT * from t1 where EXISTS(SELECT column from t2)",
            vec![
                "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                "PRE: TABLE FACTOR: t1",
                "PRE: RELATION: t1",
                "POST: RELATION: t1",
                "POST: TABLE FACTOR: t1",
                "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                "PRE: QUERY: SELECT column FROM t2",
                "PRE: EXPR: column",
                "POST: EXPR: column",
                "PRE: TABLE FACTOR: t2",
                "PRE: RELATION: t2",
                "POST: RELATION: t2",
                "POST: TABLE FACTOR: t2",
                "POST: QUERY: SELECT column FROM t2",
                "POST: EXPR: EXISTS (SELECT column FROM t2)",
                "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
            ],
        ),
        // Set operation: both sides of the UNION are traversed.
        (
            "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
            vec![
                "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                "PRE: TABLE FACTOR: t1",
                "PRE: RELATION: t1",
                "POST: RELATION: t1",
                "POST: TABLE FACTOR: t1",
                "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                "PRE: QUERY: SELECT column FROM t2",
                "PRE: EXPR: column",
                "POST: EXPR: column",
                "PRE: TABLE FACTOR: t2",
                "PRE: RELATION: t2",
                "POST: RELATION: t2",
                "POST: TABLE FACTOR: t2",
                "POST: QUERY: SELECT column FROM t2",
                "POST: EXPR: EXISTS (SELECT column FROM t2)",
                "PRE: TABLE FACTOR: t3",
                "PRE: RELATION: t3",
                "POST: RELATION: t3",
                "POST: TABLE FACTOR: t3",
                "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
            ],
        ),
        // PIVOT: a table factor wrapping another table factor.
        (
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
            vec![
                "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                "PRE: TABLE FACTOR: monthly_sales",
                "PRE: RELATION: monthly_sales",
                "POST: RELATION: monthly_sales",
                "POST: TABLE FACTOR: monthly_sales",
                "PRE: EXPR: SUM(a.amount)",
                "PRE: EXPR: a.amount",
                "POST: EXPR: a.amount",
                "POST: EXPR: SUM(a.amount)",
                "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                "PRE: EXPR: EMPID",
                "POST: EXPR: EMPID",
                "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
            ],
        ),
    ];

    for (sql, expected) in tests {
        let actual = do_visit(sql);
        let actual: Vec<_> = actual.iter().map(|x| x.as_str()).collect();
        // Name the input so a failure points at the offending case.
        assert_eq!(actual, expected, "unexpected visit order for: {sql}");
    }
}
}