sqlparser/ast/visitor.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
19
20use crate::ast::{Expr, ObjectName, Query, Select, Statement, TableFactor, ValueWithSpan};
21use core::ops::ControlFlow;
22
23/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
24/// recursively visiting parsed SQL statements.
25///
26/// # Note
27///
28/// This trait should be automatically derived for sqlparser AST nodes
29/// using the [Visit](sqlparser_derive::Visit) proc macro.
30///
31/// ```text
32/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
33/// ```
34pub trait Visit {
35 /// Visit this node with the provided [`Visitor`].
36 ///
37 /// Implementations should call the appropriate visitor hooks to traverse
38 /// child nodes and return a `ControlFlow` value to allow early exit.
39 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
40}
41
42/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
43/// recursively visiting parsed SQL statements.
44///
45/// # Note
46///
47/// This trait should be automatically derived for sqlparser AST nodes
48/// using the [VisitMut](sqlparser_derive::VisitMut) proc macro.
49///
50/// ```text
51/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
52/// ```
53pub trait VisitMut {
54 /// Mutably visit this node with the provided [`VisitorMut`].
55 ///
56 /// Implementations should call the appropriate mutable visitor hooks to
57 /// traverse and allow in-place mutation of child nodes. Returning a
58 /// `ControlFlow` value permits early termination of the traversal.
59 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
60}
61
62impl<T: Visit> Visit for Option<T> {
63 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
64 if let Some(s) = self {
65 s.visit(visitor)?;
66 }
67 ControlFlow::Continue(())
68 }
69}
70
71impl<T: Visit> Visit for Vec<T> {
72 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
73 for v in self {
74 v.visit(visitor)?;
75 }
76 ControlFlow::Continue(())
77 }
78}
79
80impl<T: Visit> Visit for Box<T> {
81 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
82 T::visit(self, visitor)
83 }
84}
85
86impl<T: VisitMut> VisitMut for Option<T> {
87 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
88 if let Some(s) = self {
89 s.visit(visitor)?;
90 }
91 ControlFlow::Continue(())
92 }
93}
94
95impl<T: VisitMut> VisitMut for Vec<T> {
96 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
97 for v in self {
98 v.visit(visitor)?;
99 }
100 ControlFlow::Continue(())
101 }
102}
103
104impl<T: VisitMut> VisitMut for Box<T> {
105 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
106 T::visit(self, visitor)
107 }
108}
109
110macro_rules! visit_noop {
111 ($($t:ty),+) => {
112 $(impl Visit for $t {
113 fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
114 ControlFlow::Continue(())
115 }
116 })+
117 $(impl VisitMut for $t {
118 fn visit<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
119 ControlFlow::Continue(())
120 }
121 })+
122 };
123}
124
125visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
126
127#[cfg(feature = "bigdecimal")]
128visit_noop!(bigdecimal::BigDecimal);
129
130/// A visitor that can be used to walk an AST tree.
131///
132/// `pre_visit_` methods are invoked before visiting all children of the
133/// node and `post_visit_` methods are invoked after visiting all
134/// children of the node.
135///
136/// # See also
137///
138/// These methods provide a more concise way of visiting nodes of a certain type:
139/// * [visit_relations]
140/// * [visit_expressions]
141/// * [visit_statements]
142///
143/// # Example
144/// ```
145/// # use sqlparser::parser::Parser;
146/// # use sqlparser::dialect::GenericDialect;
147/// # use sqlparser::ast::{Visit, Visitor, ObjectName, Expr};
148/// # use core::ops::ControlFlow;
149/// // A structure that records statements and relations
150/// #[derive(Default)]
151/// struct V {
152/// visited: Vec<String>,
153/// }
154///
155/// // Visit relations and exprs before children are visited (depth first walk)
156/// // Note you can also visit statements and visit exprs after children have been visited
157/// impl Visitor for V {
158/// type Break = ();
159///
160/// fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
161/// self.visited.push(format!("PRE: RELATION: {}", relation));
162/// ControlFlow::Continue(())
163/// }
164///
165/// fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
166/// self.visited.push(format!("PRE: EXPR: {}", expr));
167/// ControlFlow::Continue(())
168/// }
169/// }
170///
171/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
172/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
173/// .unwrap();
174///
175/// // Drive the visitor through the AST
176/// let mut visitor = V::default();
177/// statements.visit(&mut visitor);
178///
179/// // The visitor has visited statements and expressions in pre-traversal order
180/// let expected : Vec<_> = [
181/// "PRE: EXPR: a",
182/// "PRE: RELATION: foo",
183/// "PRE: EXPR: x IN (SELECT y FROM bar)",
184/// "PRE: EXPR: x",
185/// "PRE: EXPR: y",
186/// "PRE: RELATION: bar",
187/// ]
188/// .into_iter().map(|s| s.to_string()).collect();
189///
190/// assert_eq!(visitor.visited, expected);
191/// ```
192pub trait Visitor {
193 /// Type returned when the recursion returns early.
194 ///
195 /// Important note: The `Break` type should be kept as small as possible to prevent
196 /// stack overflow during recursion. If you need to return an error, consider
197 /// boxing it with `Box` to minimize stack usage.
198 type Break;
199
200 /// Invoked for any queries that appear in the AST before visiting children
201 fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
202 ControlFlow::Continue(())
203 }
204
205 /// Invoked for any queries that appear in the AST after visiting children
206 fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
207 ControlFlow::Continue(())
208 }
209
210 /// Invoked for any [Select] that appear in the AST before visiting children
211 fn pre_visit_select(&mut self, _select: &Select) -> ControlFlow<Self::Break> {
212 ControlFlow::Continue(())
213 }
214
215 /// Invoked for any [Select] that appear in the AST after visiting children
216 fn post_visit_select(&mut self, _select: &Select) -> ControlFlow<Self::Break> {
217 ControlFlow::Continue(())
218 }
219
220 /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
221 fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
222 ControlFlow::Continue(())
223 }
224
225 /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
226 fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
227 ControlFlow::Continue(())
228 }
229
230 /// Invoked for any table factors that appear in the AST before visiting children
231 fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
232 ControlFlow::Continue(())
233 }
234
235 /// Invoked for any table factors that appear in the AST after visiting children
236 fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
237 ControlFlow::Continue(())
238 }
239
240 /// Invoked for any expressions that appear in the AST before visiting children
241 fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
242 ControlFlow::Continue(())
243 }
244
245 /// Invoked for any expressions that appear in the AST
246 fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
247 ControlFlow::Continue(())
248 }
249
250 /// Invoked for any statements that appear in the AST before visiting children
251 fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
252 ControlFlow::Continue(())
253 }
254
255 /// Invoked for any statements that appear in the AST after visiting children
256 fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
257 ControlFlow::Continue(())
258 }
259
260 /// Invoked for any Value that appear in the AST before visiting children
261 fn pre_visit_value(&mut self, _value: &ValueWithSpan) -> ControlFlow<Self::Break> {
262 ControlFlow::Continue(())
263 }
264
265 /// Invoked for any Value that appear in the AST after visiting children
266 fn post_visit_value(&mut self, _value: &ValueWithSpan) -> ControlFlow<Self::Break> {
267 ControlFlow::Continue(())
268 }
269}
270
271/// A visitor that can be used to mutate an AST tree.
272///
273/// `pre_visit_` methods are invoked before visiting all children of the
274/// node and `post_visit_` methods are invoked after visiting all
275/// children of the node.
276///
277/// # See also
278///
279/// These methods provide a more concise way of visiting nodes of a certain type:
280/// * [visit_relations_mut]
281/// * [visit_expressions_mut]
282/// * [visit_statements_mut]
283///
284/// # Example
285/// ```
286/// # use sqlparser::parser::Parser;
287/// # use sqlparser::dialect::GenericDialect;
288/// # use sqlparser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
289/// # use core::ops::ControlFlow;
290///
291/// // A visitor that replaces "to_replace" with "replaced" in all expressions
292/// struct Replacer;
293///
294/// // Visit each expression after its children have been visited
295/// impl VisitorMut for Replacer {
296/// type Break = ();
297///
298/// fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
299/// if let Expr::Identifier(Ident{ value, ..}) = expr {
300/// *value = value.replace("to_replace", "replaced")
301/// }
302/// ControlFlow::Continue(())
303/// }
304/// }
305///
306/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
307/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
308///
309/// // Drive the visitor through the AST
310/// statements.visit(&mut Replacer);
311///
312/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
313/// ```
314pub trait VisitorMut {
315 /// Type returned when the recursion returns early.
316 ///
317 /// Important note: The `Break` type should be kept as small as possible to prevent
318 /// stack overflow during recursion. If you need to return an error, consider
319 /// boxing it with `Box` to minimize stack usage.
320 type Break;
321
322 /// Invoked for any queries that appear in the AST before visiting children
323 fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
324 ControlFlow::Continue(())
325 }
326
327 /// Invoked for any queries that appear in the AST after visiting children
328 fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
329 ControlFlow::Continue(())
330 }
331
332 /// Invoked for any [Select] that appear in the AST before visiting children
333 fn pre_visit_select(&mut self, _select: &mut Select) -> ControlFlow<Self::Break> {
334 ControlFlow::Continue(())
335 }
336
337 /// Invoked for any [Select] that appear in the AST after visiting children
338 fn post_visit_select(&mut self, _select: &mut Select) -> ControlFlow<Self::Break> {
339 ControlFlow::Continue(())
340 }
341
342 /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
343 fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
344 ControlFlow::Continue(())
345 }
346
347 /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
348 fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
349 ControlFlow::Continue(())
350 }
351
352 /// Invoked for any table factors that appear in the AST before visiting children
353 fn pre_visit_table_factor(
354 &mut self,
355 _table_factor: &mut TableFactor,
356 ) -> ControlFlow<Self::Break> {
357 ControlFlow::Continue(())
358 }
359
360 /// Invoked for any table factors that appear in the AST after visiting children
361 fn post_visit_table_factor(
362 &mut self,
363 _table_factor: &mut TableFactor,
364 ) -> ControlFlow<Self::Break> {
365 ControlFlow::Continue(())
366 }
367
368 /// Invoked for any expressions that appear in the AST before visiting children
369 fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
370 ControlFlow::Continue(())
371 }
372
373 /// Invoked for any expressions that appear in the AST
374 fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
375 ControlFlow::Continue(())
376 }
377
378 /// Invoked for any statements that appear in the AST before visiting children
379 fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
380 ControlFlow::Continue(())
381 }
382
383 /// Invoked for any statements that appear in the AST after visiting children
384 fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
385 ControlFlow::Continue(())
386 }
387
388 /// Invoked for any value that appear in the AST before visiting children
389 fn pre_visit_value(&mut self, _value: &mut ValueWithSpan) -> ControlFlow<Self::Break> {
390 ControlFlow::Continue(())
391 }
392
393 /// Invoked for any statements that appear in the AST after visiting children
394 fn post_visit_value(&mut self, _value: &mut ValueWithSpan) -> ControlFlow<Self::Break> {
395 ControlFlow::Continue(())
396 }
397}
398
399struct RelationVisitor<F>(F);
400
401impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
402 type Break = E;
403
404 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
405 self.0(relation)
406 }
407}
408
409impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
410 type Break = E;
411
412 fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
413 self.0(relation)
414 }
415}
416
417/// Invokes the provided closure on all relations (e.g. table names) present in `v`
418///
419/// # Example
420/// ```
421/// # use sqlparser::parser::Parser;
422/// # use sqlparser::dialect::GenericDialect;
423/// # use sqlparser::ast::{visit_relations};
424/// # use core::ops::ControlFlow;
425/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
426/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
427/// .unwrap();
428///
429/// // visit statements, capturing relations (table names)
430/// let mut visited = vec![];
431/// visit_relations(&statements, |relation| {
432/// visited.push(format!("RELATION: {}", relation));
433/// ControlFlow::<()>::Continue(())
434/// });
435///
436/// let expected : Vec<_> = [
437/// "RELATION: foo",
438/// "RELATION: bar",
439/// ]
440/// .into_iter().map(|s| s.to_string()).collect();
441///
442/// assert_eq!(visited, expected);
443/// ```
444pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
445where
446 V: Visit,
447 F: FnMut(&ObjectName) -> ControlFlow<E>,
448{
449 let mut visitor = RelationVisitor(f);
450 v.visit(&mut visitor)?;
451 ControlFlow::Continue(())
452}
453
454/// Invokes the provided closure with a mutable reference to all relations (e.g. table names)
455/// present in `v`.
456///
457/// When the closure mutates its argument, the new mutated relation will not be visited again.
458///
459/// # Example
460/// ```
461/// # use sqlparser::parser::Parser;
462/// # use sqlparser::dialect::GenericDialect;
463/// # use sqlparser::ast::{ObjectName, ObjectNamePart, Ident, visit_relations_mut};
464/// # use core::ops::ControlFlow;
465/// let sql = "SELECT a FROM foo";
466/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
467/// .unwrap();
468///
469/// // visit statements, renaming table foo to bar
470/// visit_relations_mut(&mut statements, |table| {
471/// table.0[0] = ObjectNamePart::Identifier(Ident::new("bar"));
472/// ControlFlow::<()>::Continue(())
473/// });
474///
475/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
476/// ```
477pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
478where
479 V: VisitMut,
480 F: FnMut(&mut ObjectName) -> ControlFlow<E>,
481{
482 let mut visitor = RelationVisitor(f);
483 v.visit(&mut visitor)?;
484 ControlFlow::Continue(())
485}
486
487struct ExprVisitor<F>(F);
488
489impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
490 type Break = E;
491
492 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
493 self.0(expr)
494 }
495}
496
497impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
498 type Break = E;
499
500 fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
501 self.0(expr)
502 }
503}
504
505/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
506///
507/// # Example
508/// ```
509/// # use sqlparser::parser::Parser;
510/// # use sqlparser::dialect::GenericDialect;
511/// # use sqlparser::ast::{visit_expressions};
512/// # use core::ops::ControlFlow;
513/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
514/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
515/// .unwrap();
516///
517/// // visit all expressions
518/// let mut visited = vec![];
519/// visit_expressions(&statements, |expr| {
520/// visited.push(format!("EXPR: {}", expr));
521/// ControlFlow::<()>::Continue(())
522/// });
523///
524/// let expected : Vec<_> = [
525/// "EXPR: a",
526/// "EXPR: x IN (SELECT y FROM bar)",
527/// "EXPR: x",
528/// "EXPR: y",
529/// ]
530/// .into_iter().map(|s| s.to_string()).collect();
531///
532/// assert_eq!(visited, expected);
533/// ```
534pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
535where
536 V: Visit,
537 F: FnMut(&Expr) -> ControlFlow<E>,
538{
539 let mut visitor = ExprVisitor(f);
540 v.visit(&mut visitor)?;
541 ControlFlow::Continue(())
542}
543
544/// Invokes the provided closure iteratively with a mutable reference to all expressions
545/// present in `v`.
546///
547/// This performs a depth-first search, so if the closure mutates the expression
548///
549/// # Example
550///
551/// ## Remove all select limits in sub-queries
552/// ```
553/// # use sqlparser::parser::Parser;
554/// # use sqlparser::dialect::GenericDialect;
555/// # use sqlparser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
556/// # use core::ops::ControlFlow;
557/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
558/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
559///
560/// // Remove all select limits in sub-queries
561/// visit_expressions_mut(&mut statements, |expr| {
562/// if let Expr::Subquery(q) = expr {
563/// q.limit_clause = None;
564/// }
565/// ControlFlow::<()>::Continue(())
566/// });
567///
568/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
569/// ```
570///
571/// ## Wrap column name in function call
572///
573/// This demonstrates how to effectively replace an expression with another more complicated one
574/// that references the original. This example avoids unnecessary allocations by using the
575/// [`std::mem`] family of functions.
576///
577/// ```
578/// # use sqlparser::parser::Parser;
579/// # use sqlparser::dialect::GenericDialect;
580/// # use sqlparser::ast::*;
581/// # use core::ops::ControlFlow;
582/// let sql = "SELECT x, y FROM t";
583/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
584///
585/// visit_expressions_mut(&mut statements, |expr| {
586/// if matches!(expr, Expr::Identifier(col_name) if col_name.value == "x") {
587/// let old_expr = std::mem::replace(expr, Expr::value(Value::Null));
588/// *expr = Expr::Function(Function {
589/// name: ObjectName::from(vec![Ident::new("f")]),
590/// uses_odbc_syntax: false,
591/// args: FunctionArguments::List(FunctionArgumentList {
592/// duplicate_treatment: None,
593/// args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
594/// clauses: vec![],
595/// }),
596/// null_treatment: None,
597/// filter: None,
598/// over: None,
599/// parameters: FunctionArguments::None,
600/// within_group: vec![],
601/// });
602/// }
603/// ControlFlow::<()>::Continue(())
604/// });
605///
606/// assert_eq!(statements[0].to_string(), "SELECT f(x), y FROM t");
607/// ```
608pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
609where
610 V: VisitMut,
611 F: FnMut(&mut Expr) -> ControlFlow<E>,
612{
613 v.visit(&mut ExprVisitor(f))?;
614 ControlFlow::Continue(())
615}
616
617struct StatementVisitor<F>(F);
618
619impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
620 type Break = E;
621
622 fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
623 self.0(statement)
624 }
625}
626
627impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
628 type Break = E;
629
630 fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
631 self.0(statement)
632 }
633}
634
635/// Invokes the provided closure iteratively with a mutable reference to all statements
636/// present in `v` (e.g. `SELECT`, `CREATE TABLE`, etc).
637///
638/// # Example
639/// ```
640/// # use sqlparser::parser::Parser;
641/// # use sqlparser::dialect::GenericDialect;
642/// # use sqlparser::ast::{visit_statements};
643/// # use core::ops::ControlFlow;
644/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
645/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
646/// .unwrap();
647///
648/// // visit all statements
649/// let mut visited = vec![];
650/// visit_statements(&statements, |stmt| {
651/// visited.push(format!("STATEMENT: {}", stmt));
652/// ControlFlow::<()>::Continue(())
653/// });
654///
655/// let expected : Vec<_> = [
656/// "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
657/// "STATEMENT: CREATE TABLE baz (q INT)"
658/// ]
659/// .into_iter().map(|s| s.to_string()).collect();
660///
661/// assert_eq!(visited, expected);
662/// ```
663pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
664where
665 V: Visit,
666 F: FnMut(&Statement) -> ControlFlow<E>,
667{
668 let mut visitor = StatementVisitor(f);
669 v.visit(&mut visitor)?;
670 ControlFlow::Continue(())
671}
672
673/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
674///
675/// # Example
676/// ```
677/// # use sqlparser::parser::Parser;
678/// # use sqlparser::dialect::GenericDialect;
679/// # use sqlparser::ast::{Statement, visit_statements_mut};
680/// # use core::ops::ControlFlow;
681/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
682/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
683///
684/// // Remove all select limits in outer statements (not in sub-queries)
685/// visit_statements_mut(&mut statements, |stmt| {
686/// if let Statement::Query(q) = stmt {
687/// q.limit_clause = None;
688/// }
689/// ControlFlow::<()>::Continue(())
690/// });
691///
692/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
693/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
694/// ```
695pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
696where
697 V: VisitMut,
698 F: FnMut(&mut Statement) -> ControlFlow<E>,
699{
700 v.visit(&mut StatementVisitor(f))?;
701 ControlFlow::Continue(())
702}
703
704#[cfg(test)]
705mod tests {
706 use super::*;
707 use crate::ast::Statement;
708 use crate::dialect::GenericDialect;
709 use crate::parser::Parser;
710 use crate::tokenizer::Tokenizer;
711
712 #[derive(Default)]
713 struct TestVisitor {
714 visited: Vec<String>,
715 }
716
717 impl Visitor for TestVisitor {
718 type Break = ();
719
720 /// Invoked for any queries that appear in the AST before visiting children
721 fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
722 self.visited.push(format!("PRE: QUERY: {query}"));
723 ControlFlow::Continue(())
724 }
725
726 /// Invoked for any queries that appear in the AST after visiting children
727 fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
728 self.visited.push(format!("POST: QUERY: {query}"));
729 ControlFlow::Continue(())
730 }
731
732 fn pre_visit_select(&mut self, select: &Select) -> ControlFlow<Self::Break> {
733 self.visited.push(format!("PRE: SELECT: {select}"));
734 ControlFlow::Continue(())
735 }
736
737 fn post_visit_select(&mut self, select: &Select) -> ControlFlow<Self::Break> {
738 self.visited.push(format!("POST: SELECT: {select}"));
739 ControlFlow::Continue(())
740 }
741
742 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
743 self.visited.push(format!("PRE: RELATION: {relation}"));
744 ControlFlow::Continue(())
745 }
746
747 fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
748 self.visited.push(format!("POST: RELATION: {relation}"));
749 ControlFlow::Continue(())
750 }
751
752 fn pre_visit_table_factor(
753 &mut self,
754 table_factor: &TableFactor,
755 ) -> ControlFlow<Self::Break> {
756 self.visited
757 .push(format!("PRE: TABLE FACTOR: {table_factor}"));
758 ControlFlow::Continue(())
759 }
760
761 fn post_visit_table_factor(
762 &mut self,
763 table_factor: &TableFactor,
764 ) -> ControlFlow<Self::Break> {
765 self.visited
766 .push(format!("POST: TABLE FACTOR: {table_factor}"));
767 ControlFlow::Continue(())
768 }
769
770 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
771 self.visited.push(format!("PRE: EXPR: {expr}"));
772 ControlFlow::Continue(())
773 }
774
775 fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
776 self.visited.push(format!("POST: EXPR: {expr}"));
777 ControlFlow::Continue(())
778 }
779
780 fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
781 self.visited.push(format!("PRE: STATEMENT: {statement}"));
782 ControlFlow::Continue(())
783 }
784
785 fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
786 self.visited.push(format!("POST: STATEMENT: {statement}"));
787 ControlFlow::Continue(())
788 }
789 }
790
791 fn do_visit<V: Visitor<Break = ()>>(sql: &str, visitor: &mut V) -> Statement {
792 let dialect = GenericDialect {};
793 let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
794 let s = Parser::new(&dialect)
795 .with_tokens(tokens)
796 .parse_statement()
797 .unwrap();
798
799 let flow = s.visit(visitor);
800 assert_eq!(flow, ControlFlow::Continue(()));
801 s
802 }
803
804 #[test]
805 fn test_sql() {
806 let tests = vec![
807 (
808 "SELECT * from table_name as my_table",
809 vec![
810 "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
811 "PRE: QUERY: SELECT * FROM table_name AS my_table",
812 "PRE: SELECT: SELECT * FROM table_name AS my_table",
813 "PRE: TABLE FACTOR: table_name AS my_table",
814 "PRE: RELATION: table_name",
815 "POST: RELATION: table_name",
816 "POST: TABLE FACTOR: table_name AS my_table",
817 "POST: SELECT: SELECT * FROM table_name AS my_table",
818 "POST: QUERY: SELECT * FROM table_name AS my_table",
819 "POST: STATEMENT: SELECT * FROM table_name AS my_table",
820 ],
821 ),
822 (
823 "SELECT * from t1 join t2 on t1.id = t2.t1_id",
824 vec![
825 "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
826 "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
827 "PRE: SELECT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
828 "PRE: TABLE FACTOR: t1",
829 "PRE: RELATION: t1",
830 "POST: RELATION: t1",
831 "POST: TABLE FACTOR: t1",
832 "PRE: TABLE FACTOR: t2",
833 "PRE: RELATION: t2",
834 "POST: RELATION: t2",
835 "POST: TABLE FACTOR: t2",
836 "PRE: EXPR: t1.id = t2.t1_id",
837 "PRE: EXPR: t1.id",
838 "POST: EXPR: t1.id",
839 "PRE: EXPR: t2.t1_id",
840 "POST: EXPR: t2.t1_id",
841 "POST: EXPR: t1.id = t2.t1_id",
842 "POST: SELECT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
843 "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
844 "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
845 ],
846 ),
847 (
848 "SELECT * from t1 where EXISTS(SELECT column from t2)",
849 vec![
850 "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
851 "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
852 "PRE: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
853 "PRE: TABLE FACTOR: t1",
854 "PRE: RELATION: t1",
855 "POST: RELATION: t1",
856 "POST: TABLE FACTOR: t1",
857 "PRE: EXPR: EXISTS (SELECT column FROM t2)",
858 "PRE: QUERY: SELECT column FROM t2",
859 "PRE: SELECT: SELECT column FROM t2",
860 "PRE: EXPR: column",
861 "POST: EXPR: column",
862 "PRE: TABLE FACTOR: t2",
863 "PRE: RELATION: t2",
864 "POST: RELATION: t2",
865 "POST: TABLE FACTOR: t2",
866 "POST: SELECT: SELECT column FROM t2",
867 "POST: QUERY: SELECT column FROM t2",
868 "POST: EXPR: EXISTS (SELECT column FROM t2)",
869 "POST: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
870 "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
871 "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
872 ],
873 ),
874 (
875 "SELECT * from t1 where EXISTS(SELECT column from t2)",
876 vec![
877 "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
878 "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
879 "PRE: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
880 "PRE: TABLE FACTOR: t1",
881 "PRE: RELATION: t1",
882 "POST: RELATION: t1",
883 "POST: TABLE FACTOR: t1",
884 "PRE: EXPR: EXISTS (SELECT column FROM t2)",
885 "PRE: QUERY: SELECT column FROM t2",
886 "PRE: SELECT: SELECT column FROM t2",
887 "PRE: EXPR: column",
888 "POST: EXPR: column",
889 "PRE: TABLE FACTOR: t2",
890 "PRE: RELATION: t2",
891 "POST: RELATION: t2",
892 "POST: TABLE FACTOR: t2",
893 "POST: SELECT: SELECT column FROM t2",
894 "POST: QUERY: SELECT column FROM t2",
895 "POST: EXPR: EXISTS (SELECT column FROM t2)",
896 "POST: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
897 "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
898 "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
899 ],
900 ),
901 (
902 "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
903 vec![
904 "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
905 "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
906 "PRE: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
907 "PRE: TABLE FACTOR: t1",
908 "PRE: RELATION: t1",
909 "POST: RELATION: t1",
910 "POST: TABLE FACTOR: t1",
911 "PRE: EXPR: EXISTS (SELECT column FROM t2)",
912 "PRE: QUERY: SELECT column FROM t2",
913 "PRE: SELECT: SELECT column FROM t2",
914 "PRE: EXPR: column",
915 "POST: EXPR: column",
916 "PRE: TABLE FACTOR: t2",
917 "PRE: RELATION: t2",
918 "POST: RELATION: t2",
919 "POST: TABLE FACTOR: t2",
920 "POST: SELECT: SELECT column FROM t2",
921 "POST: QUERY: SELECT column FROM t2",
922 "POST: EXPR: EXISTS (SELECT column FROM t2)",
923 "POST: SELECT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
924 "PRE: SELECT: SELECT * FROM t3",
925 "PRE: TABLE FACTOR: t3",
926 "PRE: RELATION: t3",
927 "POST: RELATION: t3",
928 "POST: TABLE FACTOR: t3",
929 "POST: SELECT: SELECT * FROM t3",
930 "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
931 "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
932 ],
933 ),
934 (
935 concat!(
936 "SELECT * FROM monthly_sales ",
937 "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
938 "ORDER BY EMPID"
939 ),
940 vec![
941 "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
942 "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
943 "PRE: SELECT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
944 "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
945 "PRE: TABLE FACTOR: monthly_sales",
946 "PRE: RELATION: monthly_sales",
947 "POST: RELATION: monthly_sales",
948 "POST: TABLE FACTOR: monthly_sales",
949 "PRE: EXPR: SUM(a.amount)",
950 "PRE: EXPR: a.amount",
951 "POST: EXPR: a.amount",
952 "POST: EXPR: SUM(a.amount)",
953 "PRE: EXPR: a.MONTH",
954 "POST: EXPR: a.MONTH",
955 "PRE: EXPR: 'JAN'",
956 "POST: EXPR: 'JAN'",
957 "PRE: EXPR: 'FEB'",
958 "POST: EXPR: 'FEB'",
959 "PRE: EXPR: 'MAR'",
960 "POST: EXPR: 'MAR'",
961 "PRE: EXPR: 'APR'",
962 "POST: EXPR: 'APR'",
963 "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
964 "POST: SELECT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
965 "PRE: EXPR: EMPID",
966 "POST: EXPR: EMPID",
967 "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
968 "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
969 ]
970 ),
971 (
972 "SHOW COLUMNS FROM t1",
973 vec![
974 "PRE: STATEMENT: SHOW COLUMNS FROM t1",
975 "PRE: RELATION: t1",
976 "POST: RELATION: t1",
977 "POST: STATEMENT: SHOW COLUMNS FROM t1",
978 ],
979 ),
980 ];
981 for (sql, expected) in tests {
982 let mut visitor = TestVisitor::default();
983 let _ = do_visit(sql, &mut visitor);
984 let actual: Vec<_> = visitor.visited.iter().map(|x| x.as_str()).collect();
985 assert_eq!(actual, expected)
986 }
987 }
988
989 struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes
990
991 impl Visitor for QuickVisitor {
992 type Break = ();
993 }
994
/// Visits a statement whose WHERE clause chains 1000 comparisons with `OR`
/// and checks that the traversal finishes with `ControlFlow::Continue`.
///
/// NOTE(review): the test name suggests this guards against stack overflow
/// on long expression chains — confirm against the visitor implementation.
#[test]
fn overflow() {
    // Builds "X = 0 OR X = 1 OR ... OR X = 999".
    let predicates: Vec<String> = (0..1000).map(|n| format!("X = {n}")).collect();
    let sql = format!("SELECT x where {}", predicates.join(" OR "));

    let dialect = GenericDialect {};
    let tokens = Tokenizer::new(&dialect, sql.as_str())
        .tokenize()
        .unwrap();
    let statement = Parser::new(&dialect)
        .with_tokens(tokens)
        .parse_statement()
        .unwrap();

    // `QuickVisitor` keeps no state, so a temporary is enough.
    let flow = statement.visit(&mut QuickVisitor {});
    assert_eq!(flow, ControlFlow::Continue(()));
}
1014}
1015
#[cfg(test)]
mod visit_mut_tests {
    use crate::ast::{Statement, Value, ValueWithSpan, VisitMut, VisitorMut};
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;
    use core::ops::ControlFlow;

    /// Mutating visitor that overwrites every value it is handed with a
    /// numbered `REDACTED_<n>` single-quoted string.
    #[derive(Default)]
    struct MutatorVisitor {
        /// Count of values rewritten so far; also the suffix of the most
        /// recently written placeholder.
        index: u64,
    }

    impl VisitorMut for MutatorVisitor {
        type Break = ();

        /// Bumps the counter and replaces the value in place, then continues.
        fn pre_visit_value(&mut self, value: &mut ValueWithSpan) -> ControlFlow<Self::Break> {
            self.index += 1;
            let placeholder = format!("REDACTED_{}", self.index);
            value.value = Value::SingleQuotedString(placeholder);
            ControlFlow::Continue(())
        }

        /// Nothing to do after a value has been visited; always continues.
        fn post_visit_value(&mut self, _value: &mut ValueWithSpan) -> ControlFlow<Self::Break> {
            ControlFlow::Continue(())
        }
    }

    /// Parses `sql` as a single statement with [`GenericDialect`], runs
    /// `visitor` over it mutably, and returns the (possibly rewritten)
    /// statement.
    ///
    /// Panics if tokenizing or parsing fails, or if the traversal does not
    /// end with `ControlFlow::Continue`.
    fn do_visit_mut<V>(sql: &str, visitor: &mut V) -> Statement
    where
        V: VisitorMut<Break = ()>,
    {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let mut statement = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        // Only `VisitMut` is imported in this module, so this is the mutable visit.
        assert_eq!(statement.visit(visitor), ControlFlow::Continue(()));
        statement
    }

    /// The four string literals in the PIVOT `IN` list are replaced by
    /// `REDACTED_1` .. `REDACTED_4`; the rest of the statement is rendered
    /// unchanged.
    #[test]
    fn test_value_redact() {
        // (input SQL, expected rendering after redaction)
        let cases = [(
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('REDACTED_1', 'REDACTED_2', 'REDACTED_3', 'REDACTED_4')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
        )];

        for (sql, expected) in cases {
            // A fresh visitor per case restarts the numbering at 1.
            let mutated = do_visit_mut(sql, &mut MutatorVisitor::default());
            assert_eq!(mutated.to_string(), expected);
        }
    }
}