sqlparser/ast/visitor.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Recursive visitors for AST nodes. See [`Visitor`] for more details.
19
20use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor, Value};
21use core::ops::ControlFlow;
22
23/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
24/// recursively visiting parsed SQL statements.
25///
26/// # Note
27///
28/// This trait should be automatically derived for sqlparser AST nodes
29/// using the [Visit](sqlparser_derive::Visit) proc macro.
30///
31/// ```text
32/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
33/// ```
34pub trait Visit {
35 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
36}
37
38/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
39/// recursively visiting parsed SQL statements.
40///
41/// # Note
42///
43/// This trait should be automatically derived for sqlparser AST nodes
44/// using the [VisitMut](sqlparser_derive::VisitMut) proc macro.
45///
46/// ```text
47/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
48/// ```
49pub trait VisitMut {
50 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
51}
52
53impl<T: Visit> Visit for Option<T> {
54 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
55 if let Some(s) = self {
56 s.visit(visitor)?;
57 }
58 ControlFlow::Continue(())
59 }
60}
61
62impl<T: Visit> Visit for Vec<T> {
63 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
64 for v in self {
65 v.visit(visitor)?;
66 }
67 ControlFlow::Continue(())
68 }
69}
70
71impl<T: Visit> Visit for Box<T> {
72 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
73 T::visit(self, visitor)
74 }
75}
76
77impl<T: VisitMut> VisitMut for Option<T> {
78 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
79 if let Some(s) = self {
80 s.visit(visitor)?;
81 }
82 ControlFlow::Continue(())
83 }
84}
85
86impl<T: VisitMut> VisitMut for Vec<T> {
87 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
88 for v in self {
89 v.visit(visitor)?;
90 }
91 ControlFlow::Continue(())
92 }
93}
94
95impl<T: VisitMut> VisitMut for Box<T> {
96 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
97 T::visit(self, visitor)
98 }
99}
100
// Generates `Visit` and `VisitMut` impls that do nothing for each listed type:
// the visitor is never called and the walk simply continues.
macro_rules! visit_noop {
    ($($leaf:ty),+) => {
        $(
            impl Visit for $leaf {
                fn visit<V: Visitor>(&self, _: &mut V) -> ControlFlow<V::Break> {
                    ControlFlow::Continue(())
                }
            }

            impl VisitMut for $leaf {
                fn visit<V: VisitorMut>(&mut self, _: &mut V) -> ControlFlow<V::Break> {
                    ControlFlow::Continue(())
                }
            }
        )+
    };
}
115
116visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
117
118#[cfg(feature = "bigdecimal")]
119visit_noop!(bigdecimal::BigDecimal);
120
121/// A visitor that can be used to walk an AST tree.
122///
123/// `pre_visit_` methods are invoked before visiting all children of the
124/// node and `post_visit_` methods are invoked after visiting all
125/// children of the node.
126///
127/// # See also
128///
129/// These methods provide a more concise way of visiting nodes of a certain type:
130/// * [visit_relations]
131/// * [visit_expressions]
132/// * [visit_statements]
133///
134/// # Example
135/// ```
136/// # use sqlparser::parser::Parser;
137/// # use sqlparser::dialect::GenericDialect;
138/// # use sqlparser::ast::{Visit, Visitor, ObjectName, Expr};
139/// # use core::ops::ControlFlow;
140/// // A structure that records statements and relations
141/// #[derive(Default)]
142/// struct V {
143/// visited: Vec<String>,
144/// }
145///
146/// // Visit relations and exprs before children are visited (depth first walk)
147/// // Note you can also visit statements and visit exprs after children have been visited
148/// impl Visitor for V {
149/// type Break = ();
150///
151/// fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
152/// self.visited.push(format!("PRE: RELATION: {}", relation));
153/// ControlFlow::Continue(())
154/// }
155///
156/// fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
157/// self.visited.push(format!("PRE: EXPR: {}", expr));
158/// ControlFlow::Continue(())
159/// }
160/// }
161///
162/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
163/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
164/// .unwrap();
165///
166/// // Drive the visitor through the AST
167/// let mut visitor = V::default();
168/// statements.visit(&mut visitor);
169///
170/// // The visitor has visited statements and expressions in pre-traversal order
171/// let expected : Vec<_> = [
172/// "PRE: EXPR: a",
173/// "PRE: RELATION: foo",
174/// "PRE: EXPR: x IN (SELECT y FROM bar)",
175/// "PRE: EXPR: x",
176/// "PRE: EXPR: y",
177/// "PRE: RELATION: bar",
178/// ]
179/// .into_iter().map(|s| s.to_string()).collect();
180///
181/// assert_eq!(visitor.visited, expected);
182/// ```
183pub trait Visitor {
184 /// Type returned when the recursion returns early.
185 ///
186 /// Important note: The `Break` type should be kept as small as possible to prevent
187 /// stack overflow during recursion. If you need to return an error, consider
188 /// boxing it with `Box` to minimize stack usage.
189 type Break;
190
191 /// Invoked for any queries that appear in the AST before visiting children
192 fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
193 ControlFlow::Continue(())
194 }
195
196 /// Invoked for any queries that appear in the AST after visiting children
197 fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
198 ControlFlow::Continue(())
199 }
200
201 /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
202 fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
203 ControlFlow::Continue(())
204 }
205
206 /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
207 fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
208 ControlFlow::Continue(())
209 }
210
211 /// Invoked for any table factors that appear in the AST before visiting children
212 fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
213 ControlFlow::Continue(())
214 }
215
216 /// Invoked for any table factors that appear in the AST after visiting children
217 fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
218 ControlFlow::Continue(())
219 }
220
221 /// Invoked for any expressions that appear in the AST before visiting children
222 fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
223 ControlFlow::Continue(())
224 }
225
226 /// Invoked for any expressions that appear in the AST
227 fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
228 ControlFlow::Continue(())
229 }
230
231 /// Invoked for any statements that appear in the AST before visiting children
232 fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
233 ControlFlow::Continue(())
234 }
235
236 /// Invoked for any statements that appear in the AST after visiting children
237 fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
238 ControlFlow::Continue(())
239 }
240
241 /// Invoked for any Value that appear in the AST before visiting children
242 fn pre_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
243 ControlFlow::Continue(())
244 }
245
246 /// Invoked for any Value that appear in the AST after visiting children
247 fn post_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
248 ControlFlow::Continue(())
249 }
250}
251
252/// A visitor that can be used to mutate an AST tree.
253///
254/// `pre_visit_` methods are invoked before visiting all children of the
255/// node and `post_visit_` methods are invoked after visiting all
256/// children of the node.
257///
258/// # See also
259///
260/// These methods provide a more concise way of visiting nodes of a certain type:
261/// * [visit_relations_mut]
262/// * [visit_expressions_mut]
263/// * [visit_statements_mut]
264///
265/// # Example
266/// ```
267/// # use sqlparser::parser::Parser;
268/// # use sqlparser::dialect::GenericDialect;
269/// # use sqlparser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
270/// # use core::ops::ControlFlow;
271///
272/// // A visitor that replaces "to_replace" with "replaced" in all expressions
273/// struct Replacer;
274///
275/// // Visit each expression after its children have been visited
276/// impl VisitorMut for Replacer {
277/// type Break = ();
278///
279/// fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
280/// if let Expr::Identifier(Ident{ value, ..}) = expr {
281/// *value = value.replace("to_replace", "replaced")
282/// }
283/// ControlFlow::Continue(())
284/// }
285/// }
286///
287/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
288/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
289///
290/// // Drive the visitor through the AST
291/// statements.visit(&mut Replacer);
292///
293/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
294/// ```
295pub trait VisitorMut {
296 /// Type returned when the recursion returns early.
297 ///
298 /// Important note: The `Break` type should be kept as small as possible to prevent
299 /// stack overflow during recursion. If you need to return an error, consider
300 /// boxing it with `Box` to minimize stack usage.
301 type Break;
302
303 /// Invoked for any queries that appear in the AST before visiting children
304 fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
305 ControlFlow::Continue(())
306 }
307
308 /// Invoked for any queries that appear in the AST after visiting children
309 fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
310 ControlFlow::Continue(())
311 }
312
313 /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
314 fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
315 ControlFlow::Continue(())
316 }
317
318 /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
319 fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
320 ControlFlow::Continue(())
321 }
322
323 /// Invoked for any table factors that appear in the AST before visiting children
324 fn pre_visit_table_factor(
325 &mut self,
326 _table_factor: &mut TableFactor,
327 ) -> ControlFlow<Self::Break> {
328 ControlFlow::Continue(())
329 }
330
331 /// Invoked for any table factors that appear in the AST after visiting children
332 fn post_visit_table_factor(
333 &mut self,
334 _table_factor: &mut TableFactor,
335 ) -> ControlFlow<Self::Break> {
336 ControlFlow::Continue(())
337 }
338
339 /// Invoked for any expressions that appear in the AST before visiting children
340 fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
341 ControlFlow::Continue(())
342 }
343
344 /// Invoked for any expressions that appear in the AST
345 fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
346 ControlFlow::Continue(())
347 }
348
349 /// Invoked for any statements that appear in the AST before visiting children
350 fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
351 ControlFlow::Continue(())
352 }
353
354 /// Invoked for any statements that appear in the AST after visiting children
355 fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
356 ControlFlow::Continue(())
357 }
358
359 /// Invoked for any value that appear in the AST before visiting children
360 fn pre_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
361 ControlFlow::Continue(())
362 }
363
364 /// Invoked for any statements that appear in the AST after visiting children
365 fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
366 ControlFlow::Continue(())
367 }
368}
369
370struct RelationVisitor<F>(F);
371
372impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
373 type Break = E;
374
375 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
376 self.0(relation)
377 }
378}
379
380impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
381 type Break = E;
382
383 fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
384 self.0(relation)
385 }
386}
387
388/// Invokes the provided closure on all relations (e.g. table names) present in `v`
389///
390/// # Example
391/// ```
392/// # use sqlparser::parser::Parser;
393/// # use sqlparser::dialect::GenericDialect;
394/// # use sqlparser::ast::{visit_relations};
395/// # use core::ops::ControlFlow;
396/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
397/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
398/// .unwrap();
399///
400/// // visit statements, capturing relations (table names)
401/// let mut visited = vec![];
402/// visit_relations(&statements, |relation| {
403/// visited.push(format!("RELATION: {}", relation));
404/// ControlFlow::<()>::Continue(())
405/// });
406///
407/// let expected : Vec<_> = [
408/// "RELATION: foo",
409/// "RELATION: bar",
410/// ]
411/// .into_iter().map(|s| s.to_string()).collect();
412///
413/// assert_eq!(visited, expected);
414/// ```
415pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
416where
417 V: Visit,
418 F: FnMut(&ObjectName) -> ControlFlow<E>,
419{
420 let mut visitor = RelationVisitor(f);
421 v.visit(&mut visitor)?;
422 ControlFlow::Continue(())
423}
424
425/// Invokes the provided closure with a mutable reference to all relations (e.g. table names)
426/// present in `v`.
427///
428/// When the closure mutates its argument, the new mutated relation will not be visited again.
429///
430/// # Example
431/// ```
432/// # use sqlparser::parser::Parser;
433/// # use sqlparser::dialect::GenericDialect;
434/// # use sqlparser::ast::{ObjectName, ObjectNamePart, Ident, visit_relations_mut};
435/// # use core::ops::ControlFlow;
436/// let sql = "SELECT a FROM foo";
437/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
438/// .unwrap();
439///
440/// // visit statements, renaming table foo to bar
441/// visit_relations_mut(&mut statements, |table| {
442/// table.0[0] = ObjectNamePart::Identifier(Ident::new("bar"));
443/// ControlFlow::<()>::Continue(())
444/// });
445///
446/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
447/// ```
448pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
449where
450 V: VisitMut,
451 F: FnMut(&mut ObjectName) -> ControlFlow<E>,
452{
453 let mut visitor = RelationVisitor(f);
454 v.visit(&mut visitor)?;
455 ControlFlow::Continue(())
456}
457
458struct ExprVisitor<F>(F);
459
460impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
461 type Break = E;
462
463 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
464 self.0(expr)
465 }
466}
467
468impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
469 type Break = E;
470
471 fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
472 self.0(expr)
473 }
474}
475
476/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
477///
478/// # Example
479/// ```
480/// # use sqlparser::parser::Parser;
481/// # use sqlparser::dialect::GenericDialect;
482/// # use sqlparser::ast::{visit_expressions};
483/// # use core::ops::ControlFlow;
484/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
485/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
486/// .unwrap();
487///
488/// // visit all expressions
489/// let mut visited = vec![];
490/// visit_expressions(&statements, |expr| {
491/// visited.push(format!("EXPR: {}", expr));
492/// ControlFlow::<()>::Continue(())
493/// });
494///
495/// let expected : Vec<_> = [
496/// "EXPR: a",
497/// "EXPR: x IN (SELECT y FROM bar)",
498/// "EXPR: x",
499/// "EXPR: y",
500/// ]
501/// .into_iter().map(|s| s.to_string()).collect();
502///
503/// assert_eq!(visited, expected);
504/// ```
505pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
506where
507 V: Visit,
508 F: FnMut(&Expr) -> ControlFlow<E>,
509{
510 let mut visitor = ExprVisitor(f);
511 v.visit(&mut visitor)?;
512 ControlFlow::Continue(())
513}
514
515/// Invokes the provided closure iteratively with a mutable reference to all expressions
516/// present in `v`.
517///
518/// This performs a depth-first search, so if the closure mutates the expression
519///
520/// # Example
521///
522/// ## Remove all select limits in sub-queries
523/// ```
524/// # use sqlparser::parser::Parser;
525/// # use sqlparser::dialect::GenericDialect;
526/// # use sqlparser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
527/// # use core::ops::ControlFlow;
528/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
529/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
530///
531/// // Remove all select limits in sub-queries
532/// visit_expressions_mut(&mut statements, |expr| {
533/// if let Expr::Subquery(q) = expr {
534/// q.limit_clause = None;
535/// }
536/// ControlFlow::<()>::Continue(())
537/// });
538///
539/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
540/// ```
541///
542/// ## Wrap column name in function call
543///
544/// This demonstrates how to effectively replace an expression with another more complicated one
545/// that references the original. This example avoids unnecessary allocations by using the
546/// [`std::mem`] family of functions.
547///
548/// ```
549/// # use sqlparser::parser::Parser;
550/// # use sqlparser::dialect::GenericDialect;
551/// # use sqlparser::ast::*;
552/// # use core::ops::ControlFlow;
553/// let sql = "SELECT x, y FROM t";
554/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
555///
556/// visit_expressions_mut(&mut statements, |expr| {
557/// if matches!(expr, Expr::Identifier(col_name) if col_name.value == "x") {
558/// let old_expr = std::mem::replace(expr, Expr::value(Value::Null));
559/// *expr = Expr::Function(Function {
560/// name: ObjectName::from(vec![Ident::new("f")]),
561/// uses_odbc_syntax: false,
562/// args: FunctionArguments::List(FunctionArgumentList {
563/// duplicate_treatment: None,
564/// args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
565/// clauses: vec![],
566/// }),
567/// null_treatment: None,
568/// filter: None,
569/// over: None,
570/// parameters: FunctionArguments::None,
571/// within_group: vec![],
572/// });
573/// }
574/// ControlFlow::<()>::Continue(())
575/// });
576///
577/// assert_eq!(statements[0].to_string(), "SELECT f(x), y FROM t");
578/// ```
579pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
580where
581 V: VisitMut,
582 F: FnMut(&mut Expr) -> ControlFlow<E>,
583{
584 v.visit(&mut ExprVisitor(f))?;
585 ControlFlow::Continue(())
586}
587
588struct StatementVisitor<F>(F);
589
590impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
591 type Break = E;
592
593 fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
594 self.0(statement)
595 }
596}
597
598impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
599 type Break = E;
600
601 fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
602 self.0(statement)
603 }
604}
605
606/// Invokes the provided closure iteratively with a mutable reference to all statements
607/// present in `v` (e.g. `SELECT`, `CREATE TABLE`, etc).
608///
609/// # Example
610/// ```
611/// # use sqlparser::parser::Parser;
612/// # use sqlparser::dialect::GenericDialect;
613/// # use sqlparser::ast::{visit_statements};
614/// # use core::ops::ControlFlow;
615/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
616/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
617/// .unwrap();
618///
619/// // visit all statements
620/// let mut visited = vec![];
621/// visit_statements(&statements, |stmt| {
622/// visited.push(format!("STATEMENT: {}", stmt));
623/// ControlFlow::<()>::Continue(())
624/// });
625///
626/// let expected : Vec<_> = [
627/// "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
628/// "STATEMENT: CREATE TABLE baz (q INT)"
629/// ]
630/// .into_iter().map(|s| s.to_string()).collect();
631///
632/// assert_eq!(visited, expected);
633/// ```
634pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
635where
636 V: Visit,
637 F: FnMut(&Statement) -> ControlFlow<E>,
638{
639 let mut visitor = StatementVisitor(f);
640 v.visit(&mut visitor)?;
641 ControlFlow::Continue(())
642}
643
644/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
645///
646/// # Example
647/// ```
648/// # use sqlparser::parser::Parser;
649/// # use sqlparser::dialect::GenericDialect;
650/// # use sqlparser::ast::{Statement, visit_statements_mut};
651/// # use core::ops::ControlFlow;
652/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
653/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
654///
655/// // Remove all select limits in outer statements (not in sub-queries)
656/// visit_statements_mut(&mut statements, |stmt| {
657/// if let Statement::Query(q) = stmt {
658/// q.limit_clause = None;
659/// }
660/// ControlFlow::<()>::Continue(())
661/// });
662///
663/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
664/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
665/// ```
666pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
667where
668 V: VisitMut,
669 F: FnMut(&mut Statement) -> ControlFlow<E>,
670{
671 v.visit(&mut StatementVisitor(f))?;
672 ControlFlow::Continue(())
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Statement;
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;

    /// Records, in order, a formatted line for every query, relation, table
    /// factor, expression and statement hook that fires during a walk.
    #[derive(Default)]
    struct TestVisitor {
        visited: Vec<String>,
    }

    impl Visitor for TestVisitor {
        type Break = ();

        /// Invoked for any queries that appear in the AST before visiting children
        fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        /// Invoked for any queries that appear in the AST after visiting children
        fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("PRE: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn post_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("POST: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }

        fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }
    }

    /// Parses `sql` as a single statement with the generic dialect, walks it
    /// with `visitor`, asserts the walk ran to completion and returns the
    /// parsed statement.
    fn do_visit<V: Visitor<Break = ()>>(sql: &str, visitor: &mut V) -> Statement {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let s = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        let flow = s.visit(visitor);
        assert_eq!(flow, ControlFlow::Continue(()));
        s
    }

    #[test]
    fn test_sql() {
        // Each case pairs an input statement with the exact sequence of hook
        // invocations expected while walking it.
        let tests = vec![
            (
                "SELECT * from table_name as my_table",
                vec![
                    "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
                    "PRE: QUERY: SELECT * FROM table_name AS my_table",
                    "PRE: TABLE FACTOR: table_name AS my_table",
                    "PRE: RELATION: table_name",
                    "POST: RELATION: table_name",
                    "POST: TABLE FACTOR: table_name AS my_table",
                    "POST: QUERY: SELECT * FROM table_name AS my_table",
                    "POST: STATEMENT: SELECT * FROM table_name AS my_table",
                ],
            ),
            (
                "SELECT * from t1 join t2 on t1.id = t2.t1_id",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "PRE: EXPR: t1.id = t2.t1_id",
                    "PRE: EXPR: t1.id",
                    "POST: EXPR: t1.id",
                    "PRE: EXPR: t2.t1_id",
                    "POST: EXPR: t2.t1_id",
                    "POST: EXPR: t1.id = t2.t1_id",
                    "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2)",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t3",
                    "PRE: RELATION: t3",
                    "POST: RELATION: t3",
                    "POST: TABLE FACTOR: t3",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                ],
            ),
            (
                concat!(
                    "SELECT * FROM monthly_sales ",
                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                    "ORDER BY EMPID"
                ),
                vec![
                    "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: TABLE FACTOR: monthly_sales",
                    "PRE: RELATION: monthly_sales",
                    "POST: RELATION: monthly_sales",
                    "POST: TABLE FACTOR: monthly_sales",
                    "PRE: EXPR: SUM(a.amount)",
                    "PRE: EXPR: a.amount",
                    "POST: EXPR: a.amount",
                    "POST: EXPR: SUM(a.amount)",
                    "PRE: EXPR: a.MONTH",
                    "POST: EXPR: a.MONTH",
                    "PRE: EXPR: 'JAN'",
                    "POST: EXPR: 'JAN'",
                    "PRE: EXPR: 'FEB'",
                    "POST: EXPR: 'FEB'",
                    "PRE: EXPR: 'MAR'",
                    "POST: EXPR: 'MAR'",
                    "PRE: EXPR: 'APR'",
                    "POST: EXPR: 'APR'",
                    "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: EXPR: EMPID",
                    "POST: EXPR: EMPID",
                    "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                ]
            ),
            (
                "SHOW COLUMNS FROM t1",
                vec![
                    "PRE: STATEMENT: SHOW COLUMNS FROM t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: STATEMENT: SHOW COLUMNS FROM t1",
                ],
            ),
        ];
        for (sql, expected) in tests {
            let mut visitor = TestVisitor::default();
            let _ = do_visit(sql, &mut visitor);
            let actual: Vec<_> = visitor.visited.iter().map(|x| x.as_str()).collect();
            // Name the failing case so a mismatch can be traced to its input.
            assert_eq!(actual, expected, "unexpected visit sequence for SQL: {sql}");
        }
    }

    struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes

    impl Visitor for QuickVisitor {
        type Break = ();
    }

    /// Walks a statement whose `WHERE` clause chains 1000 `OR` conditions and
    /// checks that the walk runs to completion.
    #[test]
    fn overflow() {
        let cond = (0..1000)
            .map(|n| format!("X = {n}"))
            .collect::<Vec<_>>()
            .join(" OR ");
        let sql = format!("SELECT x where {cond}");

        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
        let s = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        let mut visitor = QuickVisitor;
        let flow = s.visit(&mut visitor);
        assert_eq!(flow, ControlFlow::Continue(()));
    }
}
956
#[cfg(test)]
mod visit_mut_tests {
    use crate::ast::{Statement, Value, VisitMut, VisitorMut};
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;
    use core::ops::ControlFlow;

    /// Overwrites every visited value with a numbered placeholder string.
    #[derive(Default)]
    struct MutatorVisitor {
        /// Number of values replaced so far; used as the placeholder suffix.
        index: u64,
    }

    impl VisitorMut for MutatorVisitor {
        type Break = ();

        fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
            self.index += 1;
            let placeholder = format!("REDACTED_{}", self.index);
            *value = Value::SingleQuotedString(placeholder);
            ControlFlow::Continue(())
        }

        fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
            ControlFlow::Continue(())
        }
    }

    /// Parses `sql` as a single statement with the generic dialect, walks it
    /// mutably with `visitor`, asserts the walk ran to completion and returns
    /// the (possibly rewritten) statement.
    fn do_visit_mut<V: VisitorMut<Break = ()>>(sql: &str, visitor: &mut V) -> Statement {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let mut statement = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        assert_eq!(statement.visit(visitor), ControlFlow::Continue(()));
        statement
    }

    #[test]
    fn test_value_redact() {
        // (input SQL, SQL expected after every value has been redacted)
        let cases = [(
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('REDACTED_1', 'REDACTED_2', 'REDACTED_3', 'REDACTED_4')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
        )];

        for (sql, expected) in cases {
            let mut redactor = MutatorVisitor::default();
            let redacted = do_visit_mut(sql, &mut redactor);
            assert_eq!(redacted.to_string(), expected)
        }
    }
}