sqlparser/ast/
visitor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
19
20use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor, Value};
21use core::ops::ControlFlow;
22
23/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
24/// recursively visiting parsed SQL statements.
25///
26/// # Note
27///
28/// This trait should be automatically derived for sqlparser AST nodes
29/// using the [Visit](sqlparser_derive::Visit) proc macro.
30///
31/// ```text
32/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
33/// ```
34pub trait Visit {
35    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
36}
37
38/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
39/// recursively visiting parsed SQL statements.
40///
41/// # Note
42///
43/// This trait should be automatically derived for sqlparser AST nodes
44/// using the [VisitMut](sqlparser_derive::VisitMut) proc macro.
45///
46/// ```text
47/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
48/// ```
49pub trait VisitMut {
50    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
51}
52
53impl<T: Visit> Visit for Option<T> {
54    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
55        if let Some(s) = self {
56            s.visit(visitor)?;
57        }
58        ControlFlow::Continue(())
59    }
60}
61
62impl<T: Visit> Visit for Vec<T> {
63    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
64        for v in self {
65            v.visit(visitor)?;
66        }
67        ControlFlow::Continue(())
68    }
69}
70
71impl<T: Visit> Visit for Box<T> {
72    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
73        T::visit(self, visitor)
74    }
75}
76
77impl<T: VisitMut> VisitMut for Option<T> {
78    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
79        if let Some(s) = self {
80            s.visit(visitor)?;
81        }
82        ControlFlow::Continue(())
83    }
84}
85
86impl<T: VisitMut> VisitMut for Vec<T> {
87    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
88        for v in self {
89            v.visit(visitor)?;
90        }
91        ControlFlow::Continue(())
92    }
93}
94
95impl<T: VisitMut> VisitMut for Box<T> {
96    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
97        T::visit(self, visitor)
98    }
99}
100
101macro_rules! visit_noop {
102    ($($t:ty),+) => {
103        $(impl Visit for $t {
104            fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
105               ControlFlow::Continue(())
106            }
107        })+
108        $(impl VisitMut for $t {
109            fn visit<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
110               ControlFlow::Continue(())
111            }
112        })+
113    };
114}
115
116visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
117
118#[cfg(feature = "bigdecimal")]
119visit_noop!(bigdecimal::BigDecimal);
120
121/// A visitor that can be used to walk an AST tree.
122///
123/// `pre_visit_` methods are invoked before visiting all children of the
124/// node and `post_visit_` methods are invoked after visiting all
125/// children of the node.
126///
127/// # See also
128///
129/// These methods provide a more concise way of visiting nodes of a certain type:
130/// * [visit_relations]
131/// * [visit_expressions]
132/// * [visit_statements]
133///
134/// # Example
135/// ```
136/// # use sqlparser::parser::Parser;
137/// # use sqlparser::dialect::GenericDialect;
138/// # use sqlparser::ast::{Visit, Visitor, ObjectName, Expr};
139/// # use core::ops::ControlFlow;
/// // A structure that records relations and expressions
141/// #[derive(Default)]
142/// struct V {
143///    visited: Vec<String>,
144/// }
145///
146/// // Visit relations and exprs before children are visited (depth first walk)
147/// // Note you can also visit statements and visit exprs after children have been visited
148/// impl Visitor for V {
149///   type Break = ();
150///
151///   fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
152///     self.visited.push(format!("PRE: RELATION: {}", relation));
153///     ControlFlow::Continue(())
154///   }
155///
156///   fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
157///     self.visited.push(format!("PRE: EXPR: {}", expr));
158///     ControlFlow::Continue(())
159///   }
160/// }
161///
162/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
163/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
164///    .unwrap();
165///
166/// // Drive the visitor through the AST
167/// let mut visitor = V::default();
168/// statements.visit(&mut visitor);
169///
/// // The visitor has visited relations and expressions in pre-traversal order
171/// let expected : Vec<_> = [
172///   "PRE: EXPR: a",
173///   "PRE: RELATION: foo",
174///   "PRE: EXPR: x IN (SELECT y FROM bar)",
175///   "PRE: EXPR: x",
176///   "PRE: EXPR: y",
177///   "PRE: RELATION: bar",
178/// ]
179///   .into_iter().map(|s| s.to_string()).collect();
180///
181/// assert_eq!(visitor.visited, expected);
182/// ```
pub trait Visitor {
    /// Type returned when the recursion returns early.
    ///
    /// Important note: The `Break` type should be kept as small as possible to prevent
    /// stack overflow during recursion. If you need to return an error, consider
    /// boxing it with `Box` to minimize stack usage.
    type Break;

    /// Invoked for any queries that appear in the AST before visiting children
    fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any queries that appear in the AST after visiting children
    fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
    fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
    fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any table factors that appear in the AST before visiting children
    fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any table factors that appear in the AST after visiting children
    fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any expressions that appear in the AST before visiting children
    fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any expressions that appear in the AST after visiting children
    fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any statements that appear in the AST before visiting children
    fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any statements that appear in the AST after visiting children
    fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any values that appear in the AST before visiting children
    fn pre_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any values that appear in the AST after visiting children
    fn post_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }
}
251
252/// A visitor that can be used to mutate an AST tree.
253///
254/// `pre_visit_` methods are invoked before visiting all children of the
255/// node and `post_visit_` methods are invoked after visiting all
256/// children of the node.
257///
258/// # See also
259///
260/// These methods provide a more concise way of visiting nodes of a certain type:
261/// * [visit_relations_mut]
262/// * [visit_expressions_mut]
263/// * [visit_statements_mut]
264///
265/// # Example
266/// ```
267/// # use sqlparser::parser::Parser;
268/// # use sqlparser::dialect::GenericDialect;
269/// # use sqlparser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
270/// # use core::ops::ControlFlow;
271///
272/// // A visitor that replaces "to_replace" with "replaced" in all expressions
273/// struct Replacer;
274///
275/// // Visit each expression after its children have been visited
276/// impl VisitorMut for Replacer {
277///   type Break = ();
278///
279///   fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
280///     if let Expr::Identifier(Ident{ value, ..}) = expr {
281///         *value = value.replace("to_replace", "replaced")
282///     }
283///     ControlFlow::Continue(())
284///   }
285/// }
286///
287/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
288/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
289///
290/// // Drive the visitor through the AST
291/// statements.visit(&mut Replacer);
292///
293/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
294/// ```
pub trait VisitorMut {
    /// Type returned when the recursion returns early.
    ///
    /// Important note: The `Break` type should be kept as small as possible to prevent
    /// stack overflow during recursion. If you need to return an error, consider
    /// boxing it with `Box` to minimize stack usage.
    type Break;

    /// Invoked for any queries that appear in the AST before visiting children
    fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any queries that appear in the AST after visiting children
    fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
    fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
    fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any table factors that appear in the AST before visiting children
    fn pre_visit_table_factor(
        &mut self,
        _table_factor: &mut TableFactor,
    ) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any table factors that appear in the AST after visiting children
    fn post_visit_table_factor(
        &mut self,
        _table_factor: &mut TableFactor,
    ) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any expressions that appear in the AST before visiting children
    fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any expressions that appear in the AST after visiting children
    fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any statements that appear in the AST before visiting children
    fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any statements that appear in the AST after visiting children
    fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any values that appear in the AST before visiting children
    fn pre_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }

    /// Invoked for any values that appear in the AST after visiting children
    fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
        ControlFlow::Continue(())
    }
}
369
370struct RelationVisitor<F>(F);
371
372impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
373    type Break = E;
374
375    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
376        self.0(relation)
377    }
378}
379
380impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
381    type Break = E;
382
383    fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
384        self.0(relation)
385    }
386}
387
388/// Invokes the provided closure on all relations (e.g. table names) present in `v`
389///
390/// # Example
391/// ```
392/// # use sqlparser::parser::Parser;
393/// # use sqlparser::dialect::GenericDialect;
394/// # use sqlparser::ast::{visit_relations};
395/// # use core::ops::ControlFlow;
396/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
397/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
398///    .unwrap();
399///
400/// // visit statements, capturing relations (table names)
401/// let mut visited = vec![];
402/// visit_relations(&statements, |relation| {
403///   visited.push(format!("RELATION: {}", relation));
404///   ControlFlow::<()>::Continue(())
405/// });
406///
407/// let expected : Vec<_> = [
408///   "RELATION: foo",
409///   "RELATION: bar",
410/// ]
411///   .into_iter().map(|s| s.to_string()).collect();
412///
413/// assert_eq!(visited, expected);
414/// ```
415pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
416where
417    V: Visit,
418    F: FnMut(&ObjectName) -> ControlFlow<E>,
419{
420    let mut visitor = RelationVisitor(f);
421    v.visit(&mut visitor)?;
422    ControlFlow::Continue(())
423}
424
425/// Invokes the provided closure with a mutable reference to all relations (e.g. table names)
426/// present in `v`.
427///
428/// When the closure mutates its argument, the new mutated relation will not be visited again.
429///
430/// # Example
431/// ```
432/// # use sqlparser::parser::Parser;
433/// # use sqlparser::dialect::GenericDialect;
434/// # use sqlparser::ast::{ObjectName, ObjectNamePart, Ident, visit_relations_mut};
435/// # use core::ops::ControlFlow;
436/// let sql = "SELECT a FROM foo";
437/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
438///    .unwrap();
439///
440/// // visit statements, renaming table foo to bar
441/// visit_relations_mut(&mut statements, |table| {
442///   table.0[0] = ObjectNamePart::Identifier(Ident::new("bar"));
443///   ControlFlow::<()>::Continue(())
444/// });
445///
446/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
447/// ```
448pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
449where
450    V: VisitMut,
451    F: FnMut(&mut ObjectName) -> ControlFlow<E>,
452{
453    let mut visitor = RelationVisitor(f);
454    v.visit(&mut visitor)?;
455    ControlFlow::Continue(())
456}
457
458struct ExprVisitor<F>(F);
459
460impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
461    type Break = E;
462
463    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
464        self.0(expr)
465    }
466}
467
468impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
469    type Break = E;
470
471    fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
472        self.0(expr)
473    }
474}
475
476/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
477///
478/// # Example
479/// ```
480/// # use sqlparser::parser::Parser;
481/// # use sqlparser::dialect::GenericDialect;
482/// # use sqlparser::ast::{visit_expressions};
483/// # use core::ops::ControlFlow;
484/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
485/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
486///    .unwrap();
487///
488/// // visit all expressions
489/// let mut visited = vec![];
490/// visit_expressions(&statements, |expr| {
491///   visited.push(format!("EXPR: {}", expr));
492///   ControlFlow::<()>::Continue(())
493/// });
494///
495/// let expected : Vec<_> = [
496///   "EXPR: a",
497///   "EXPR: x IN (SELECT y FROM bar)",
498///   "EXPR: x",
499///   "EXPR: y",
500/// ]
501///   .into_iter().map(|s| s.to_string()).collect();
502///
503/// assert_eq!(visited, expected);
504/// ```
505pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
506where
507    V: Visit,
508    F: FnMut(&Expr) -> ControlFlow<E>,
509{
510    let mut visitor = ExprVisitor(f);
511    v.visit(&mut visitor)?;
512    ControlFlow::Continue(())
513}
514
515/// Invokes the provided closure iteratively with a mutable reference to all expressions
516/// present in `v`.
517///
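/// This performs a depth-first search, invoking the closure on each expression after its
/// children have been visited, so if the closure mutates the expression, the new mutated
/// expression will not be visited again.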
519///
520/// # Example
521///
522/// ## Remove all select limits in sub-queries
523/// ```
524/// # use sqlparser::parser::Parser;
525/// # use sqlparser::dialect::GenericDialect;
526/// # use sqlparser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
527/// # use core::ops::ControlFlow;
528/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
529/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
530///
531/// // Remove all select limits in sub-queries
532/// visit_expressions_mut(&mut statements, |expr| {
533///   if let Expr::Subquery(q) = expr {
534///      q.limit_clause = None;
535///   }
536///   ControlFlow::<()>::Continue(())
537/// });
538///
539/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
540/// ```
541///
542/// ## Wrap column name in function call
543///
544/// This demonstrates how to effectively replace an expression with another more complicated one
545/// that references the original. This example avoids unnecessary allocations by using the
546/// [`std::mem`] family of functions.
547///
548/// ```
549/// # use sqlparser::parser::Parser;
550/// # use sqlparser::dialect::GenericDialect;
551/// # use sqlparser::ast::*;
552/// # use core::ops::ControlFlow;
553/// let sql = "SELECT x, y FROM t";
554/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
555///
556/// visit_expressions_mut(&mut statements, |expr| {
557///   if matches!(expr, Expr::Identifier(col_name) if col_name.value == "x") {
558///     let old_expr = std::mem::replace(expr, Expr::value(Value::Null));
559///     *expr = Expr::Function(Function {
560///           name: ObjectName::from(vec![Ident::new("f")]),
561///           uses_odbc_syntax: false,
562///           args: FunctionArguments::List(FunctionArgumentList {
563///               duplicate_treatment: None,
564///               args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
565///               clauses: vec![],
566///           }),
567///           null_treatment: None,
568///           filter: None,
569///           over: None,
570///           parameters: FunctionArguments::None,
571///           within_group: vec![],
572///      });
573///   }
574///   ControlFlow::<()>::Continue(())
575/// });
576///
577/// assert_eq!(statements[0].to_string(), "SELECT f(x), y FROM t");
578/// ```
579pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
580where
581    V: VisitMut,
582    F: FnMut(&mut Expr) -> ControlFlow<E>,
583{
584    v.visit(&mut ExprVisitor(f))?;
585    ControlFlow::Continue(())
586}
587
588struct StatementVisitor<F>(F);
589
590impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
591    type Break = E;
592
593    fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
594        self.0(statement)
595    }
596}
597
598impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
599    type Break = E;
600
601    fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
602        self.0(statement)
603    }
604}
605
/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc)
/// present in `v`.
608///
609/// # Example
610/// ```
611/// # use sqlparser::parser::Parser;
612/// # use sqlparser::dialect::GenericDialect;
613/// # use sqlparser::ast::{visit_statements};
614/// # use core::ops::ControlFlow;
615/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
616/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
617///    .unwrap();
618///
619/// // visit all statements
620/// let mut visited = vec![];
621/// visit_statements(&statements, |stmt| {
622///   visited.push(format!("STATEMENT: {}", stmt));
623///   ControlFlow::<()>::Continue(())
624/// });
625///
626/// let expected : Vec<_> = [
627///   "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
628///   "STATEMENT: CREATE TABLE baz (q INT)"
629/// ]
630///   .into_iter().map(|s| s.to_string()).collect();
631///
632/// assert_eq!(visited, expected);
633/// ```
634pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
635where
636    V: Visit,
637    F: FnMut(&Statement) -> ControlFlow<E>,
638{
639    let mut visitor = StatementVisitor(f);
640    v.visit(&mut visitor)?;
641    ControlFlow::Continue(())
642}
643
/// Invokes the provided closure iteratively with a mutable reference to all statements
/// present in `v` (e.g. `SELECT`, `CREATE TABLE`, etc).
645///
646/// # Example
647/// ```
648/// # use sqlparser::parser::Parser;
649/// # use sqlparser::dialect::GenericDialect;
650/// # use sqlparser::ast::{Statement, visit_statements_mut};
651/// # use core::ops::ControlFlow;
652/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
653/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
654///
655/// // Remove all select limits in outer statements (not in sub-queries)
656/// visit_statements_mut(&mut statements, |stmt| {
657///   if let Statement::Query(q) = stmt {
658///      q.limit_clause = None;
659///   }
660///   ControlFlow::<()>::Continue(())
661/// });
662///
663/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
664/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
665/// ```
666pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
667where
668    V: VisitMut,
669    F: FnMut(&mut Statement) -> ControlFlow<E>,
670{
671    v.visit(&mut StatementVisitor(f))?;
672    ControlFlow::Continue(())
673}
674
675#[cfg(test)]
676mod tests {
677    use super::*;
678    use crate::ast::Statement;
679    use crate::dialect::GenericDialect;
680    use crate::parser::Parser;
681    use crate::tokenizer::Tokenizer;
682
683    #[derive(Default)]
684    struct TestVisitor {
685        visited: Vec<String>,
686    }
687
688    impl Visitor for TestVisitor {
689        type Break = ();
690
691        /// Invoked for any queries that appear in the AST before visiting children
692        fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
693            self.visited.push(format!("PRE: QUERY: {query}"));
694            ControlFlow::Continue(())
695        }
696
697        /// Invoked for any queries that appear in the AST after visiting children
698        fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
699            self.visited.push(format!("POST: QUERY: {query}"));
700            ControlFlow::Continue(())
701        }
702
703        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
704            self.visited.push(format!("PRE: RELATION: {relation}"));
705            ControlFlow::Continue(())
706        }
707
708        fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
709            self.visited.push(format!("POST: RELATION: {relation}"));
710            ControlFlow::Continue(())
711        }
712
713        fn pre_visit_table_factor(
714            &mut self,
715            table_factor: &TableFactor,
716        ) -> ControlFlow<Self::Break> {
717            self.visited
718                .push(format!("PRE: TABLE FACTOR: {table_factor}"));
719            ControlFlow::Continue(())
720        }
721
722        fn post_visit_table_factor(
723            &mut self,
724            table_factor: &TableFactor,
725        ) -> ControlFlow<Self::Break> {
726            self.visited
727                .push(format!("POST: TABLE FACTOR: {table_factor}"));
728            ControlFlow::Continue(())
729        }
730
731        fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
732            self.visited.push(format!("PRE: EXPR: {expr}"));
733            ControlFlow::Continue(())
734        }
735
736        fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
737            self.visited.push(format!("POST: EXPR: {expr}"));
738            ControlFlow::Continue(())
739        }
740
741        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
742            self.visited.push(format!("PRE: STATEMENT: {statement}"));
743            ControlFlow::Continue(())
744        }
745
746        fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
747            self.visited.push(format!("POST: STATEMENT: {statement}"));
748            ControlFlow::Continue(())
749        }
750    }
751
752    fn do_visit<V: Visitor<Break = ()>>(sql: &str, visitor: &mut V) -> Statement {
753        let dialect = GenericDialect {};
754        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
755        let s = Parser::new(&dialect)
756            .with_tokens(tokens)
757            .parse_statement()
758            .unwrap();
759
760        let flow = s.visit(visitor);
761        assert_eq!(flow, ControlFlow::Continue(()));
762        s
763    }
764
765    #[test]
766    fn test_sql() {
767        let tests = vec![
768            (
769                "SELECT * from table_name as my_table",
770                vec![
771                    "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
772                    "PRE: QUERY: SELECT * FROM table_name AS my_table",
773                    "PRE: TABLE FACTOR: table_name AS my_table",
774                    "PRE: RELATION: table_name",
775                    "POST: RELATION: table_name",
776                    "POST: TABLE FACTOR: table_name AS my_table",
777                    "POST: QUERY: SELECT * FROM table_name AS my_table",
778                    "POST: STATEMENT: SELECT * FROM table_name AS my_table",
779                ],
780            ),
781            (
782                "SELECT * from t1 join t2 on t1.id = t2.t1_id",
783                vec![
784                    "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
785                    "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
786                    "PRE: TABLE FACTOR: t1",
787                    "PRE: RELATION: t1",
788                    "POST: RELATION: t1",
789                    "POST: TABLE FACTOR: t1",
790                    "PRE: TABLE FACTOR: t2",
791                    "PRE: RELATION: t2",
792                    "POST: RELATION: t2",
793                    "POST: TABLE FACTOR: t2",
794                    "PRE: EXPR: t1.id = t2.t1_id",
795                    "PRE: EXPR: t1.id",
796                    "POST: EXPR: t1.id",
797                    "PRE: EXPR: t2.t1_id",
798                    "POST: EXPR: t2.t1_id",
799                    "POST: EXPR: t1.id = t2.t1_id",
800                    "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
801                    "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
802                ],
803            ),
804            (
805                "SELECT * from t1 where EXISTS(SELECT column from t2)",
806                vec![
807                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
808                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
809                    "PRE: TABLE FACTOR: t1",
810                    "PRE: RELATION: t1",
811                    "POST: RELATION: t1",
812                    "POST: TABLE FACTOR: t1",
813                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
814                    "PRE: QUERY: SELECT column FROM t2",
815                    "PRE: EXPR: column",
816                    "POST: EXPR: column",
817                    "PRE: TABLE FACTOR: t2",
818                    "PRE: RELATION: t2",
819                    "POST: RELATION: t2",
820                    "POST: TABLE FACTOR: t2",
821                    "POST: QUERY: SELECT column FROM t2",
822                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
823                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
824                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
825                ],
826            ),
827            (
828                "SELECT * from t1 where EXISTS(SELECT column from t2)",
829                vec![
830                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
831                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
832                    "PRE: TABLE FACTOR: t1",
833                    "PRE: RELATION: t1",
834                    "POST: RELATION: t1",
835                    "POST: TABLE FACTOR: t1",
836                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
837                    "PRE: QUERY: SELECT column FROM t2",
838                    "PRE: EXPR: column",
839                    "POST: EXPR: column",
840                    "PRE: TABLE FACTOR: t2",
841                    "PRE: RELATION: t2",
842                    "POST: RELATION: t2",
843                    "POST: TABLE FACTOR: t2",
844                    "POST: QUERY: SELECT column FROM t2",
845                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
846                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
847                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
848                ],
849            ),
850            (
851                "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
852                vec![
853                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
854                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
855                    "PRE: TABLE FACTOR: t1",
856                    "PRE: RELATION: t1",
857                    "POST: RELATION: t1",
858                    "POST: TABLE FACTOR: t1",
859                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
860                    "PRE: QUERY: SELECT column FROM t2",
861                    "PRE: EXPR: column",
862                    "POST: EXPR: column",
863                    "PRE: TABLE FACTOR: t2",
864                    "PRE: RELATION: t2",
865                    "POST: RELATION: t2",
866                    "POST: TABLE FACTOR: t2",
867                    "POST: QUERY: SELECT column FROM t2",
868                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
869                    "PRE: TABLE FACTOR: t3",
870                    "PRE: RELATION: t3",
871                    "POST: RELATION: t3",
872                    "POST: TABLE FACTOR: t3",
873                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
874                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
875                ],
876            ),
877            (
878                concat!(
879                    "SELECT * FROM monthly_sales ",
880                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
881                    "ORDER BY EMPID"
882                ),
883                vec![
884                    "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
885                    "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
886                    "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
887                    "PRE: TABLE FACTOR: monthly_sales",
888                    "PRE: RELATION: monthly_sales",
889                    "POST: RELATION: monthly_sales",
890                    "POST: TABLE FACTOR: monthly_sales",
891                    "PRE: EXPR: SUM(a.amount)",
892                    "PRE: EXPR: a.amount",
893                    "POST: EXPR: a.amount",
894                    "POST: EXPR: SUM(a.amount)",
895                    "PRE: EXPR: a.MONTH",
896                    "POST: EXPR: a.MONTH",
897                    "PRE: EXPR: 'JAN'",
898                    "POST: EXPR: 'JAN'",
899                    "PRE: EXPR: 'FEB'",
900                    "POST: EXPR: 'FEB'",
901                    "PRE: EXPR: 'MAR'",
902                    "POST: EXPR: 'MAR'",
903                    "PRE: EXPR: 'APR'",
904                    "POST: EXPR: 'APR'",
905                    "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
906                    "PRE: EXPR: EMPID",
907                    "POST: EXPR: EMPID",
908                    "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
909                    "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
910                ]
911            ),
912            (
913                "SHOW COLUMNS FROM t1",
914                vec![
915                    "PRE: STATEMENT: SHOW COLUMNS FROM t1",
916                    "PRE: RELATION: t1",
917                    "POST: RELATION: t1",
918                    "POST: STATEMENT: SHOW COLUMNS FROM t1",
919                ],
920            ),
921        ];
922        for (sql, expected) in tests {
923            let mut visitor = TestVisitor::default();
924            let _ = do_visit(sql, &mut visitor);
925            let actual: Vec<_> = visitor.visited.iter().map(|x| x.as_str()).collect();
926            assert_eq!(actual, expected)
927        }
928    }
929
930    struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes
931
932    impl Visitor for QuickVisitor {
933        type Break = ();
934    }
935
936    #[test]
937    fn overflow() {
938        let cond = (0..1000)
939            .map(|n| format!("X = {n}"))
940            .collect::<Vec<_>>()
941            .join(" OR ");
942        let sql = format!("SELECT x where {cond}");
943
944        let dialect = GenericDialect {};
945        let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
946        let s = Parser::new(&dialect)
947            .with_tokens(tokens)
948            .parse_statement()
949            .unwrap();
950
951        let mut visitor = QuickVisitor {};
952        let flow = s.visit(&mut visitor);
953        assert_eq!(flow, ControlFlow::Continue(()));
954    }
955}
956
#[cfg(test)]
mod visit_mut_tests {
    use crate::ast::{Statement, Value, VisitMut, VisitorMut};
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;
    use core::ops::ControlFlow;

    /// Mutating visitor that overwrites every [`Value`] it is handed with a
    /// numbered `REDACTED_<n>` single-quoted string.
    #[derive(Default)]
    struct MutatorVisitor {
        /// Count of values replaced so far; also the suffix of the most
        /// recently generated placeholder.
        index: u64,
    }

    impl VisitorMut for MutatorVisitor {
        type Break = ();

        /// Replaces `value` with the next placeholder and keeps traversing.
        fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
            self.index += 1;
            let placeholder = format!("REDACTED_{}", self.index);
            *value = Value::SingleQuotedString(placeholder);
            ControlFlow::Continue(())
        }

        /// Leaves the value untouched and keeps traversing.
        fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
            ControlFlow::Continue(())
        }
    }

    /// Tokenizes and parses `sql` with [`GenericDialect`], runs `visitor` over
    /// the parsed statement in place, and returns the (possibly mutated)
    /// statement.
    ///
    /// Panics if tokenizing or parsing fails, or if the visitor stops the
    /// traversal early.
    fn do_visit_mut<V>(sql: &str, visitor: &mut V) -> Statement
    where
        V: VisitorMut<Break = ()>,
    {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let mut statement = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        // Spelled out as a trait call to make the mutable traversal explicit.
        let outcome = VisitMut::visit(&mut statement, visitor);
        assert_eq!(outcome, ControlFlow::Continue(()));
        statement
    }

    /// Checks that each string literal in the statement is swapped for a
    /// sequentially numbered placeholder while the rest of the rendered SQL
    /// is unchanged.
    #[test]
    fn test_value_redact() {
        const INPUT: &str = concat!(
            "SELECT * FROM monthly_sales ",
            "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
            "ORDER BY EMPID"
        );
        const REDACTED: &str = concat!(
            "SELECT * FROM monthly_sales ",
            "PIVOT(SUM(a.amount) FOR a.MONTH IN ('REDACTED_1', 'REDACTED_2', 'REDACTED_3', 'REDACTED_4')) AS p (c, d) ",
            "ORDER BY EMPID"
        );

        // (input SQL, expected rendering after redaction)
        let cases = [(INPUT, REDACTED)];

        for (sql, expected) in cases {
            // Fresh visitor per case so numbering restarts at 1.
            let mut visitor = MutatorVisitor::default();
            let rendered = do_visit_mut(sql, &mut visitor).to_string();
            assert_eq!(rendered, expected);
        }
    }
}