Skip to main content

sqlparser/ast/
visitor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
19
20use crate::ast::{Expr, ObjectName, Query, Select, Statement, TableFactor, ValueWithSpan};
21use core::ops::ControlFlow;
22
23/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
24/// recursively visiting parsed SQL statements.
25///
26/// # Note
27///
28/// This trait should be automatically derived for sqlparser AST nodes
29/// using the [Visit](sqlparser_derive::Visit) proc macro.
30///
31/// ```text
32/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
33/// ```
34pub trait Visit {
35    /// Visit this node with the provided [`Visitor`].
36    ///
37    /// Implementations should call the appropriate visitor hooks to traverse
38    /// child nodes and return a `ControlFlow` value to allow early exit.
39    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
40}
41
42/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
43/// recursively visiting parsed SQL statements.
44///
45/// # Note
46///
47/// This trait should be automatically derived for sqlparser AST nodes
48/// using the [VisitMut](sqlparser_derive::VisitMut) proc macro.
49///
50/// ```text
51/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
52/// ```
53pub trait VisitMut {
54    /// Mutably visit this node with the provided [`VisitorMut`].
55    ///
56    /// Implementations should call the appropriate mutable visitor hooks to
57    /// traverse and allow in-place mutation of child nodes. Returning a
58    /// `ControlFlow` value permits early termination of the traversal.
59    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
60}
61
62impl<T: Visit> Visit for Option<T> {
63    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
64        if let Some(s) = self {
65            s.visit(visitor)?;
66        }
67        ControlFlow::Continue(())
68    }
69}
70
71impl<T: Visit> Visit for Vec<T> {
72    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
73        for v in self {
74            v.visit(visitor)?;
75        }
76        ControlFlow::Continue(())
77    }
78}
79
80impl<T: Visit> Visit for Box<T> {
81    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
82        T::visit(self, visitor)
83    }
84}
85
86impl<T: VisitMut> VisitMut for Option<T> {
87    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
88        if let Some(s) = self {
89            s.visit(visitor)?;
90        }
91        ControlFlow::Continue(())
92    }
93}
94
95impl<T: VisitMut> VisitMut for Vec<T> {
96    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
97        for v in self {
98            v.visit(visitor)?;
99        }
100        ControlFlow::Continue(())
101    }
102}
103
104impl<T: VisitMut> VisitMut for Box<T> {
105    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
106        T::visit(self, visitor)
107    }
108}
109
110macro_rules! visit_noop {
111    ($($t:ty),+) => {
112        $(impl Visit for $t {
113            fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
114               ControlFlow::Continue(())
115            }
116        })+
117        $(impl VisitMut for $t {
118            fn visit<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
119               ControlFlow::Continue(())
120            }
121        })+
122    };
123}
124
125visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
126
127#[cfg(feature = "bigdecimal")]
128visit_noop!(bigdecimal::BigDecimal);
129
130/// A visitor that can be used to walk an AST tree.
131///
132/// `pre_visit_` methods are invoked before visiting all children of the
133/// node and `post_visit_` methods are invoked after visiting all
134/// children of the node.
135///
136/// # See also
137///
138/// These methods provide a more concise way of visiting nodes of a certain type:
139/// * [visit_relations]
140/// * [visit_expressions]
141/// * [visit_statements]
142///
143/// # Example
144/// ```
145/// # use sqlparser::parser::Parser;
146/// # use sqlparser::dialect::GenericDialect;
147/// # use sqlparser::ast::{Visit, Visitor, ObjectName, Expr};
148/// # use core::ops::ControlFlow;
149/// // A structure that records statements and relations
150/// #[derive(Default)]
151/// struct V {
152///    visited: Vec<String>,
153/// }
154///
155/// // Visit relations and exprs before children are visited (depth first walk)
156/// // Note you can also visit statements and visit exprs after children have been visited
157/// impl Visitor for V {
158///   type Break = ();
159///
160///   fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
161///     self.visited.push(format!("PRE: RELATION: {}", relation));
162///     ControlFlow::Continue(())
163///   }
164///
165///   fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
166///     self.visited.push(format!("PRE: EXPR: {}", expr));
167///     ControlFlow::Continue(())
168///   }
169/// }
170///
171/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
172/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
173///    .unwrap();
174///
175/// // Drive the visitor through the AST
176/// let mut visitor = V::default();
177/// statements.visit(&mut visitor);
178///
179/// // The visitor has visited statements and expressions in pre-traversal order
180/// let expected : Vec<_> = [
181///   "PRE: EXPR: a",
182///   "PRE: RELATION: foo",
183///   "PRE: EXPR: x IN (SELECT y FROM bar)",
184///   "PRE: EXPR: x",
185///   "PRE: EXPR: y",
186///   "PRE: RELATION: bar",
187/// ]
188///   .into_iter().map(|s| s.to_string()).collect();
189///
190/// assert_eq!(visitor.visited, expected);
191/// ```
192pub trait Visitor {
193    /// Type returned when the recursion returns early.
194    ///
195    /// Important note: The `Break` type should be kept as small as possible to prevent
196    /// stack overflow during recursion. If you need to return an error, consider
197    /// boxing it with `Box` to minimize stack usage.
198    type Break;
199
200    /// Invoked for any queries that appear in the AST before visiting children
201    fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
202        ControlFlow::Continue(())
203    }
204
205    /// Invoked for any queries that appear in the AST after visiting children
206    fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
207        ControlFlow::Continue(())
208    }
209
210    /// Invoked for any [Select] that appear in the AST before visiting children
211    fn pre_visit_select(&mut self, _select: &Select) -> ControlFlow<Self::Break> {
212        ControlFlow::Continue(())
213    }
214
215    /// Invoked for any [Select] that appear in the AST after visiting children
216    fn post_visit_select(&mut self, _select: &Select) -> ControlFlow<Self::Break> {
217        ControlFlow::Continue(())
218    }
219
220    /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
221    fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
222        ControlFlow::Continue(())
223    }
224
225    /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
226    fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
227        ControlFlow::Continue(())
228    }
229
230    /// Invoked for any table factors that appear in the AST before visiting children
231    fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
232        ControlFlow::Continue(())
233    }
234
235    /// Invoked for any table factors that appear in the AST after visiting children
236    fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
237        ControlFlow::Continue(())
238    }
239
240    /// Invoked for any expressions that appear in the AST before visiting children
241    fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
242        ControlFlow::Continue(())
243    }
244
245    /// Invoked for any expressions that appear in the AST
246    fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
247        ControlFlow::Continue(())
248    }
249
250    /// Invoked for any statements that appear in the AST before visiting children
251    fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
252        ControlFlow::Continue(())
253    }
254
255    /// Invoked for any statements that appear in the AST after visiting children
256    fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
257        ControlFlow::Continue(())
258    }
259
260    /// Invoked for any Value that appear in the AST before visiting children
261    fn pre_visit_value(&mut self, _value: &ValueWithSpan) -> ControlFlow<Self::Break> {
262        ControlFlow::Continue(())
263    }
264
265    /// Invoked for any Value that appear in the AST after visiting children
266    fn post_visit_value(&mut self, _value: &ValueWithSpan) -> ControlFlow<Self::Break> {
267        ControlFlow::Continue(())
268    }
269}
270
271/// A visitor that can be used to mutate an AST tree.
272///
273/// `pre_visit_` methods are invoked before visiting all children of the
274/// node and `post_visit_` methods are invoked after visiting all
275/// children of the node.
276///
277/// # See also
278///
279/// These methods provide a more concise way of visiting nodes of a certain type:
280/// * [visit_relations_mut]
281/// * [visit_expressions_mut]
282/// * [visit_statements_mut]
283///
284/// # Example
285/// ```
286/// # use sqlparser::parser::Parser;
287/// # use sqlparser::dialect::GenericDialect;
288/// # use sqlparser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
289/// # use core::ops::ControlFlow;
290///
291/// // A visitor that replaces "to_replace" with "replaced" in all expressions
292/// struct Replacer;
293///
294/// // Visit each expression after its children have been visited
295/// impl VisitorMut for Replacer {
296///   type Break = ();
297///
298///   fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
299///     if let Expr::Identifier(Ident{ value, ..}) = expr {
300///         *value = value.replace("to_replace", "replaced")
301///     }
302///     ControlFlow::Continue(())
303///   }
304/// }
305///
306/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
307/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
308///
309/// // Drive the visitor through the AST
310/// statements.visit(&mut Replacer);
311///
312/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
313/// ```
314pub trait VisitorMut {
315    /// Type returned when the recursion returns early.
316    ///
317    /// Important note: The `Break` type should be kept as small as possible to prevent
318    /// stack overflow during recursion. If you need to return an error, consider
319    /// boxing it with `Box` to minimize stack usage.
320    type Break;
321
322    /// Invoked for any queries that appear in the AST before visiting children
323    fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
324        ControlFlow::Continue(())
325    }
326
327    /// Invoked for any queries that appear in the AST after visiting children
328    fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
329        ControlFlow::Continue(())
330    }
331
332    /// Invoked for any [Select] that appear in the AST before visiting children
333    fn pre_visit_select(&mut self, _select: &mut Select) -> ControlFlow<Self::Break> {
334        ControlFlow::Continue(())
335    }
336
337    /// Invoked for any [Select] that appear in the AST after visiting children
338    fn post_visit_select(&mut self, _select: &mut Select) -> ControlFlow<Self::Break> {
339        ControlFlow::Continue(())
340    }
341
342    /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
343    fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
344        ControlFlow::Continue(())
345    }
346
347    /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
348    fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
349        ControlFlow::Continue(())
350    }
351
352    /// Invoked for any table factors that appear in the AST before visiting children
353    fn pre_visit_table_factor(
354        &mut self,
355        _table_factor: &mut TableFactor,
356    ) -> ControlFlow<Self::Break> {
357        ControlFlow::Continue(())
358    }
359
360    /// Invoked for any table factors that appear in the AST after visiting children
361    fn post_visit_table_factor(
362        &mut self,
363        _table_factor: &mut TableFactor,
364    ) -> ControlFlow<Self::Break> {
365        ControlFlow::Continue(())
366    }
367
368    /// Invoked for any expressions that appear in the AST before visiting children
369    fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
370        ControlFlow::Continue(())
371    }
372
373    /// Invoked for any expressions that appear in the AST
374    fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
375        ControlFlow::Continue(())
376    }
377
378    /// Invoked for any statements that appear in the AST before visiting children
379    fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
380        ControlFlow::Continue(())
381    }
382
383    /// Invoked for any statements that appear in the AST after visiting children
384    fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
385        ControlFlow::Continue(())
386    }
387
388    /// Invoked for any value that appear in the AST before visiting children
389    fn pre_visit_value(&mut self, _value: &mut ValueWithSpan) -> ControlFlow<Self::Break> {
390        ControlFlow::Continue(())
391    }
392
393    /// Invoked for any statements that appear in the AST after visiting children
394    fn post_visit_value(&mut self, _value: &mut ValueWithSpan) -> ControlFlow<Self::Break> {
395        ControlFlow::Continue(())
396    }
397}
398
399struct RelationVisitor<F>(F);
400
401impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
402    type Break = E;
403
404    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
405        self.0(relation)
406    }
407}
408
409impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
410    type Break = E;
411
412    fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
413        self.0(relation)
414    }
415}
416
417/// Invokes the provided closure on all relations (e.g. table names) present in `v`
418///
419/// # Example
420/// ```
421/// # use sqlparser::parser::Parser;
422/// # use sqlparser::dialect::GenericDialect;
423/// # use sqlparser::ast::{visit_relations};
424/// # use core::ops::ControlFlow;
425/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
426/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
427///    .unwrap();
428///
429/// // visit statements, capturing relations (table names)
430/// let mut visited = vec![];
431/// visit_relations(&statements, |relation| {
432///   visited.push(format!("RELATION: {}", relation));
433///   ControlFlow::<()>::Continue(())
434/// });
435///
436/// let expected : Vec<_> = [
437///   "RELATION: foo",
438///   "RELATION: bar",
439/// ]
440///   .into_iter().map(|s| s.to_string()).collect();
441///
442/// assert_eq!(visited, expected);
443/// ```
444pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
445where
446    V: Visit,
447    F: FnMut(&ObjectName) -> ControlFlow<E>,
448{
449    let mut visitor = RelationVisitor(f);
450    v.visit(&mut visitor)?;
451    ControlFlow::Continue(())
452}
453
454/// Invokes the provided closure with a mutable reference to all relations (e.g. table names)
455/// present in `v`.
456///
457/// When the closure mutates its argument, the new mutated relation will not be visited again.
458///
459/// # Example
460/// ```
461/// # use sqlparser::parser::Parser;
462/// # use sqlparser::dialect::GenericDialect;
463/// # use sqlparser::ast::{ObjectName, ObjectNamePart, Ident, visit_relations_mut};
464/// # use core::ops::ControlFlow;
465/// let sql = "SELECT a FROM foo";
466/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
467///    .unwrap();
468///
469/// // visit statements, renaming table foo to bar
470/// visit_relations_mut(&mut statements, |table| {
471///   table.0[0] = ObjectNamePart::Identifier(Ident::new("bar"));
472///   ControlFlow::<()>::Continue(())
473/// });
474///
475/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
476/// ```
477pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
478where
479    V: VisitMut,
480    F: FnMut(&mut ObjectName) -> ControlFlow<E>,
481{
482    let mut visitor = RelationVisitor(f);
483    v.visit(&mut visitor)?;
484    ControlFlow::Continue(())
485}
486
487struct ExprVisitor<F>(F);
488
489impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
490    type Break = E;
491
492    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
493        self.0(expr)
494    }
495}
496
497impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
498    type Break = E;
499
500    fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
501        self.0(expr)
502    }
503}
504
505/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
506///
507/// # Example
508/// ```
509/// # use sqlparser::parser::Parser;
510/// # use sqlparser::dialect::GenericDialect;
511/// # use sqlparser::ast::{visit_expressions};
512/// # use core::ops::ControlFlow;
513/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
514/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
515///    .unwrap();
516///
517/// // visit all expressions
518/// let mut visited = vec![];
519/// visit_expressions(&statements, |expr| {
520///   visited.push(format!("EXPR: {}", expr));
521///   ControlFlow::<()>::Continue(())
522/// });
523///
524/// let expected : Vec<_> = [
525///   "EXPR: a",
526///   "EXPR: x IN (SELECT y FROM bar)",
527///   "EXPR: x",
528///   "EXPR: y",
529/// ]
530///   .into_iter().map(|s| s.to_string()).collect();
531///
532/// assert_eq!(visited, expected);
533/// ```
534pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
535where
536    V: Visit,
537    F: FnMut(&Expr) -> ControlFlow<E>,
538{
539    let mut visitor = ExprVisitor(f);
540    v.visit(&mut visitor)?;
541    ControlFlow::Continue(())
542}
543
544/// Invokes the provided closure iteratively with a mutable reference to all expressions
545/// present in `v`.
546///
547/// This performs a depth-first search, so if the closure mutates the expression
548///
549/// # Example
550///
551/// ## Remove all select limits in sub-queries
552/// ```
553/// # use sqlparser::parser::Parser;
554/// # use sqlparser::dialect::GenericDialect;
555/// # use sqlparser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
556/// # use core::ops::ControlFlow;
557/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
558/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
559///
560/// // Remove all select limits in sub-queries
561/// visit_expressions_mut(&mut statements, |expr| {
562///   if let Expr::Subquery(q) = expr {
563///      q.limit_clause = None;
564///   }
565///   ControlFlow::<()>::Continue(())
566/// });
567///
568/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
569/// ```
570///
571/// ## Wrap column name in function call
572///
573/// This demonstrates how to effectively replace an expression with another more complicated one
574/// that references the original. This example avoids unnecessary allocations by using the
575/// [`std::mem`] family of functions.
576///
577/// ```
578/// # use sqlparser::parser::Parser;
579/// # use sqlparser::dialect::GenericDialect;
580/// # use sqlparser::ast::*;
581/// # use core::ops::ControlFlow;
582/// let sql = "SELECT x, y FROM t";
583/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
584///
585/// visit_expressions_mut(&mut statements, |expr| {
586///   if matches!(expr, Expr::Identifier(col_name) if col_name.value == "x") {
587///     let old_expr = std::mem::replace(expr, Expr::value(Value::Null));
588///     *expr = Expr::Function(Function {
589///           name: ObjectName::from(vec![Ident::new("f")]),
590///           uses_odbc_syntax: false,
591///           args: FunctionArguments::List(FunctionArgumentList {
592///               duplicate_treatment: None,
593///               args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
594///               clauses: vec![],
595///           }),
596///           null_treatment: None,
597///           filter: None,
598///           over: None,
599///           parameters: FunctionArguments::None,
600///           within_group: vec![],
601///      });
602///   }
603///   ControlFlow::<()>::Continue(())
604/// });
605///
606/// assert_eq!(statements[0].to_string(), "SELECT f(x), y FROM t");
607/// ```
608pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
609where
610    V: VisitMut,
611    F: FnMut(&mut Expr) -> ControlFlow<E>,
612{
613    v.visit(&mut ExprVisitor(f))?;
614    ControlFlow::Continue(())
615}
616
617struct StatementVisitor<F>(F);
618
619impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
620    type Break = E;
621
622    fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
623        self.0(statement)
624    }
625}
626
627impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
628    type Break = E;
629
630    fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
631        self.0(statement)
632    }
633}
634
635/// Invokes the provided closure iteratively with a mutable reference to all statements
636/// present in `v` (e.g. `SELECT`, `CREATE TABLE`, etc).
637///
638/// # Example
639/// ```
640/// # use sqlparser::parser::Parser;
641/// # use sqlparser::dialect::GenericDialect;
642/// # use sqlparser::ast::{visit_statements};
643/// # use core::ops::ControlFlow;
644/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
645/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
646///    .unwrap();
647///
648/// // visit all statements
649/// let mut visited = vec![];
650/// visit_statements(&statements, |stmt| {
651///   visited.push(format!("STATEMENT: {}", stmt));
652///   ControlFlow::<()>::Continue(())
653/// });
654///
655/// let expected : Vec<_> = [
656///   "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
657///   "STATEMENT: CREATE TABLE baz (q INT)"
658/// ]
659///   .into_iter().map(|s| s.to_string()).collect();
660///
661/// assert_eq!(visited, expected);
662/// ```
663pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
664where
665    V: Visit,
666    F: FnMut(&Statement) -> ControlFlow<E>,
667{
668    let mut visitor = StatementVisitor(f);
669    v.visit(&mut visitor)?;
670    ControlFlow::Continue(())
671}
672
673/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
674///
675/// # Example
676/// ```
677/// # use sqlparser::parser::Parser;
678/// # use sqlparser::dialect::GenericDialect;
679/// # use sqlparser::ast::{Statement, visit_statements_mut};
680/// # use core::ops::ControlFlow;
681/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
682/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
683///
684/// // Remove all select limits in outer statements (not in sub-queries)
685/// visit_statements_mut(&mut statements, |stmt| {
686///   if let Statement::Query(q) = stmt {
687///      q.limit_clause = None;
688///   }
689///   ControlFlow::<()>::Continue(())
690/// });
691///
692/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
693/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
694/// ```
695pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
696where
697    V: VisitMut,
698    F: FnMut(&mut Statement) -> ControlFlow<E>,
699{
700    v.visit(&mut StatementVisitor(f))?;
701    ControlFlow::Continue(())
702}
703
704#[cfg(test)]
705mod tests {
706    use super::*;
707    use crate::ast::Statement;
708    use crate::dialect::GenericDialect;
709    use crate::parser::Parser;
710    use crate::tokenizer::Tokenizer;
711
712    #[derive(Default)]
713    struct TestVisitor {
714        visited: Vec<String>,
715    }
716
717    impl Visitor for TestVisitor {
718        type Break = ();
719
720        /// Invoked for any queries that appear in the AST before visiting children
721        fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
722            self.visited.push(format!("PRE: QUERY: {query}"));
723            ControlFlow::Continue(())
724        }
725
726        /// Invoked for any queries that appear in the AST after visiting children
727        fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
728            self.visited.push(format!("POST: QUERY: {query}"));
729            ControlFlow::Continue(())
730        }
731
732        fn pre_visit_select(&mut self, select: &Select) -> ControlFlow<Self::Break> {
733            self.visited.push(format!("PRE: SELECT: {select}"));
734            ControlFlow::Continue(())
735        }
736
737        fn post_visit_select(&mut self, select: &Select) -> ControlFlow<Self::Break> {
738            self.visited.push(format!("POST: SELECT: {select}"));
739            ControlFlow::Continue(())
740        }
741
742        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
743            self.visited.push(format!("PRE: RELATION: {relation}"));
744            ControlFlow::Continue(())
745        }
746
747        fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
748            self.visited.push(format!("POST: RELATION: {relation}"));
749            ControlFlow::Continue(())
750        }
751
752        fn pre_visit_table_factor(
753            &mut self,
754            table_factor: &TableFactor,
755        ) -> ControlFlow<Self::Break> {
756            self.visited
757                .push(format!("PRE: TABLE FACTOR: {table_factor}"));
758            ControlFlow::Continue(())
759        }
760
761        fn post_visit_table_factor(
762            &mut self,
763            table_factor: &TableFactor,
764        ) -> ControlFlow<Self::Break> {
765            self.visited
766                .push(format!("POST: TABLE FACTOR: {table_factor}"));
767            ControlFlow::Continue(())
768        }
769
770        fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
771            self.visited.push(format!("PRE: EXPR: {expr}"));
772            ControlFlow::Continue(())
773        }
774
775        fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
776            self.visited.push(format!("POST: EXPR: {expr}"));
777            ControlFlow::Continue(())
778        }
779
780        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
781            self.visited.push(format!("PRE: STATEMENT: {statement}"));
782            ControlFlow::Continue(())
783        }
784
785        fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
786            self.visited.push(format!("POST: STATEMENT: {statement}"));
787            ControlFlow::Continue(())
788        }
789    }
790
791    fn do_visit<V: Visitor<Break = ()>>(sql: &str, visitor: &mut V) -> Statement {
792        let dialect = GenericDialect {};
793        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
794        let s = Parser::new(&dialect)
795            .with_tokens(tokens)
796            .parse_statement()
797            .unwrap();
798
799        let flow = s.visit(visitor);
800        assert_eq!(flow, ControlFlow::Continue(()));
801        s
802    }
803
804    #[test]
805    fn test_sql() {
806        let tests = vec![
807            (
808                "SELECT * from table_name as my_table",
809                vec![
810                    "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
811                    "PRE: QUERY: SELECT * FROM table_name AS my_table",
812                    "PRE: SELECT: SELECT * FROM table_name AS my_table",
813                    "PRE: TABLE FACTOR: table_name AS my_table",
814                    "PRE: RELATION: table_name",
815                    "POST: RELATION: table_name",
816                    "POST: TABLE FACTOR: table_name AS my_table",
817                    "POST: SELECT: SELECT * FROM table_name AS my_table",
818                    "POST: QUERY: SELECT * FROM table_name AS my_table",
819                    "POST: STATEMENT: SELECT * FROM table_name AS my_table",
820                ],
821            ),
822            (
823                "SELECT * from t1 join t2 on t1.id = t2.t1_id",
824                vec![
825                    "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
826                    "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
827                    "PRE: SELECT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
828                    "PRE: TABLE FACTOR: t1",
829                    "PRE: RELATION: t1",
830                    "POST: RELATION: t1",
831                    "POST: TABLE FACTOR: t1",
832                    "PRE: TABLE FACTOR: t2",
833                    "PRE: RELATION: t2",
834                    "POST: RELATION: t2",
835                    "POST: TABLE FACTOR: t2",
836                    "PRE: EXPR: t1.id = t2.t1_id",
837                    "PRE: EXPR: t1.id",
838                    "POST: EXPR: t1.id",
839                    "PRE: EXPR: t2.t1_id",
840                    "POST: EXPR: t2.t1_id",
841                    "POST: EXPR: t1.id = t2.t1_id",
842                    "POST: SELECT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
843                    "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
844                    "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
845                ],
846            ),
847            (
848                "SELECT * from t1 where EXISTS(SELECT column from t2)",
849                vec![
850                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
851                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
852                    "PRE: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
853                    "PRE: TABLE FACTOR: t1",
854                    "PRE: RELATION: t1",
855                    "POST: RELATION: t1",
856                    "POST: TABLE FACTOR: t1",
857                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
858                    "PRE: QUERY: SELECT column FROM t2",
859                    "PRE: SELECT: SELECT column FROM t2",
860                    "PRE: EXPR: column",
861                    "POST: EXPR: column",
862                    "PRE: TABLE FACTOR: t2",
863                    "PRE: RELATION: t2",
864                    "POST: RELATION: t2",
865                    "POST: TABLE FACTOR: t2",
866                    "POST: SELECT: SELECT column FROM t2",
867                    "POST: QUERY: SELECT column FROM t2",
868                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
869                    "POST: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
870                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
871                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
872                ],
873            ),
874            (
875                "SELECT * from t1 where EXISTS(SELECT column from t2)",
876                vec![
877                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
878                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
879                    "PRE: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
880                    "PRE: TABLE FACTOR: t1",
881                    "PRE: RELATION: t1",
882                    "POST: RELATION: t1",
883                    "POST: TABLE FACTOR: t1",
884                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
885                    "PRE: QUERY: SELECT column FROM t2",
886                    "PRE: SELECT: SELECT column FROM t2",
887                    "PRE: EXPR: column",
888                    "POST: EXPR: column",
889                    "PRE: TABLE FACTOR: t2",
890                    "PRE: RELATION: t2",
891                    "POST: RELATION: t2",
892                    "POST: TABLE FACTOR: t2",
893                    "POST: SELECT: SELECT column FROM t2",
894                    "POST: QUERY: SELECT column FROM t2",
895                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
896                    "POST: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
897                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
898                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
899                ],
900            ),
901            (
902                "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
903                vec![
904                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
905                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
906                    "PRE: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
907                    "PRE: TABLE FACTOR: t1",
908                    "PRE: RELATION: t1",
909                    "POST: RELATION: t1",
910                    "POST: TABLE FACTOR: t1",
911                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
912                    "PRE: QUERY: SELECT column FROM t2",
913                    "PRE: SELECT: SELECT column FROM t2",
914                    "PRE: EXPR: column",
915                    "POST: EXPR: column",
916                    "PRE: TABLE FACTOR: t2",
917                    "PRE: RELATION: t2",
918                    "POST: RELATION: t2",
919                    "POST: TABLE FACTOR: t2",
920                    "POST: SELECT: SELECT column FROM t2",
921                    "POST: QUERY: SELECT column FROM t2",
922                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
923                    "POST: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
924                    "PRE: SELECT: SELECT * FROM t3",
925                    "PRE: TABLE FACTOR: t3",
926                    "PRE: RELATION: t3",
927                    "POST: RELATION: t3",
928                    "POST: TABLE FACTOR: t3",
929                    "POST: SELECT: SELECT * FROM t3",
930                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
931                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
932                ],
933            ),
934            (
935                concat!(
936                    "SELECT * FROM monthly_sales ",
937                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
938                    "ORDER BY EMPID"
939                ),
940                vec![
941                    "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
942                    "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
943                    "PRE: SELECT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
944                    "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
945                    "PRE: TABLE FACTOR: monthly_sales",
946                    "PRE: RELATION: monthly_sales",
947                    "POST: RELATION: monthly_sales",
948                    "POST: TABLE FACTOR: monthly_sales",
949                    "PRE: EXPR: SUM(a.amount)",
950                    "PRE: EXPR: a.amount",
951                    "POST: EXPR: a.amount",
952                    "POST: EXPR: SUM(a.amount)",
953                    "PRE: EXPR: a.MONTH",
954                    "POST: EXPR: a.MONTH",
955                    "PRE: EXPR: 'JAN'",
956                    "POST: EXPR: 'JAN'",
957                    "PRE: EXPR: 'FEB'",
958                    "POST: EXPR: 'FEB'",
959                    "PRE: EXPR: 'MAR'",
960                    "POST: EXPR: 'MAR'",
961                    "PRE: EXPR: 'APR'",
962                    "POST: EXPR: 'APR'",
963                    "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
964                    "POST: SELECT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
965                    "PRE: EXPR: EMPID",
966                    "POST: EXPR: EMPID",
967                    "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
968                    "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
969                ]
970            ),
971            (
972                "SHOW COLUMNS FROM t1",
973                vec![
974                    "PRE: STATEMENT: SHOW COLUMNS FROM t1",
975                    "PRE: RELATION: t1",
976                    "POST: RELATION: t1",
977                    "POST: STATEMENT: SHOW COLUMNS FROM t1",
978                ],
979            ),
980        ];
981        for (sql, expected) in tests {
982            let mut visitor = TestVisitor::default();
983            let _ = do_visit(sql, &mut visitor);
984            let actual: Vec<_> = visitor.visited.iter().map(|x| x.as_str()).collect();
985            assert_eq!(actual, expected)
986        }
987    }
988
989    struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes
990
991    impl Visitor for QuickVisitor {
992        type Break = ();
993    }
994
995    #[test]
996    fn overflow() {
997        let cond = (0..1000)
998            .map(|n| format!("X = {n}"))
999            .collect::<Vec<_>>()
1000            .join(" OR ");
1001        let sql = format!("SELECT x where {cond}");
1002
1003        let dialect = GenericDialect {};
1004        let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
1005        let s = Parser::new(&dialect)
1006            .with_tokens(tokens)
1007            .parse_statement()
1008            .unwrap();
1009
1010        let mut visitor = QuickVisitor {};
1011        let flow = s.visit(&mut visitor);
1012        assert_eq!(flow, ControlFlow::Continue(()));
1013    }
1014}
1015
/// Tests for mutable traversal via [`VisitMut`] and [`VisitorMut`].
#[cfg(test)]
mod visit_mut_tests {
    use crate::ast::{Statement, Value, ValueWithSpan, VisitMut, VisitorMut};
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;
    use core::ops::ControlFlow;

    /// Visitor that overwrites every value it is handed with a numbered
    /// placeholder string.
    #[derive(Default)]
    struct MutatorVisitor {
        /// Count of values replaced so far; also the suffix of the most
        /// recently written placeholder.
        index: u64,
    }

    impl VisitorMut for MutatorVisitor {
        type Break = ();

        /// Replaces the visited value with `'REDACTED_<n>'`, where `<n>` is a
        /// running counter starting at 1.
        fn pre_visit_value(&mut self, value: &mut ValueWithSpan) -> ControlFlow<Self::Break> {
            // Bump first so the first placeholder is numbered 1, not 0.
            self.index += 1;
            let placeholder = format!("REDACTED_{}", self.index);
            value.value = Value::SingleQuotedString(placeholder);
            ControlFlow::Continue(())
        }

        /// Leaves the value untouched and keeps traversing.
        fn post_visit_value(&mut self, _value: &mut ValueWithSpan) -> ControlFlow<Self::Break> {
            ControlFlow::Continue(())
        }
    }

    /// Parses `sql` with [`GenericDialect`], runs `visitor` over the resulting
    /// statement in place, and returns the (possibly mutated) statement.
    ///
    /// # Panics
    ///
    /// Panics if tokenizing or parsing fails, or if the traversal does not
    /// finish with `ControlFlow::Continue(())`.
    fn do_visit_mut<V: VisitorMut<Break = ()>>(sql: &str, visitor: &mut V) -> Statement {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let parsed = Parser::new(&dialect).with_tokens(tokens).parse_statement();
        let mut statement = parsed.unwrap();

        // Only `VisitMut` is imported here, so this is the mutable traversal.
        assert_eq!(statement.visit(visitor), ControlFlow::Continue(()));
        statement
    }

    /// Checks that every literal in the input is replaced by a sequentially
    /// numbered placeholder while the rest of the statement is unchanged.
    #[test]
    fn test_value_redact() {
        // Each case is `(input SQL, expected SQL after redaction)`.
        let cases = [(
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('REDACTED_1', 'REDACTED_2', 'REDACTED_3', 'REDACTED_4')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
        )];

        for (sql, expected) in cases {
            // Fresh visitor per case so numbering restarts at 1.
            let mut visitor = MutatorVisitor::default();
            let redacted = do_visit_mut(sql, &mut visitor);
            assert_eq!(redacted.to_string(), expected)
        }
    }
}