sqlparser/
test_utils.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/// This module contains internal utilities used for testing the library.
/// While technically public, the library's users are not supposed to rely
/// on this module, as it will change without notice.
//
// Integration tests (i.e. everything under `tests/`) import this
// via `tests/test_utils/helpers`.
19
20#[cfg(not(feature = "std"))]
21use alloc::{
22    boxed::Box,
23    string::{String, ToString},
24    vec,
25    vec::Vec,
26};
27use core::fmt::Debug;
28
29use crate::dialect::*;
30use crate::parser::{Parser, ParserError};
31use crate::tokenizer::Tokenizer;
32use crate::{ast::*, parser::ParserOptions};
33
34#[cfg(test)]
35use pretty_assertions::assert_eq;
36
37/// Tests use the methods on this struct to invoke the parser on one or
38/// multiple dialects.
39pub struct TestedDialects {
40    pub dialects: Vec<Box<dyn Dialect>>,
41    pub options: Option<ParserOptions>,
42}
43
44impl TestedDialects {
45    fn new_parser<'a>(&self, dialect: &'a dyn Dialect) -> Parser<'a> {
46        let parser = Parser::new(dialect);
47        if let Some(options) = &self.options {
48            parser.with_options(options.clone())
49        } else {
50            parser
51        }
52    }
53
54    /// Run the given function for all of `self.dialects`, assert that they
55    /// return the same result, and return that result.
56    pub fn one_of_identical_results<F, T: Debug + PartialEq>(&self, f: F) -> T
57    where
58        F: Fn(&dyn Dialect) -> T,
59    {
60        let parse_results = self.dialects.iter().map(|dialect| (dialect, f(&**dialect)));
61        parse_results
62            .fold(None, |s, (dialect, parsed)| {
63                if let Some((prev_dialect, prev_parsed)) = s {
64                    assert_eq!(
65                        prev_parsed, parsed,
66                        "Parse results with {prev_dialect:?} are different from {dialect:?}"
67                    );
68                }
69                Some((dialect, parsed))
70            })
71            .unwrap()
72            .1
73    }
74
75    pub fn run_parser_method<F, T: Debug + PartialEq>(&self, sql: &str, f: F) -> T
76    where
77        F: Fn(&mut Parser) -> T,
78    {
79        self.one_of_identical_results(|dialect| {
80            let mut parser = self.new_parser(dialect).try_with_sql(sql).unwrap();
81            f(&mut parser)
82        })
83    }
84
85    /// Parses a single SQL string into multiple statements, ensuring
86    /// the result is the same for all tested dialects.
87    pub fn parse_sql_statements(&self, sql: &str) -> Result<Vec<Statement>, ParserError> {
88        self.one_of_identical_results(|dialect| {
89            let mut tokenizer = Tokenizer::new(dialect, sql);
90            if let Some(options) = &self.options {
91                tokenizer = tokenizer.with_unescape(options.unescape);
92            }
93            let tokens = tokenizer.tokenize()?;
94            self.new_parser(dialect)
95                .with_tokens(tokens)
96                .parse_statements()
97        })
98        // To fail the `ensure_multiple_dialects_are_tested` test:
99        // Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
100    }
101
102    /// Ensures that `sql` parses as a single [Statement] for all tested
103    /// dialects.
104    ///
105    /// If `canonical` is non empty,this function additionally asserts
106    /// that:
107    ///
108    /// 1. parsing `sql` results in the same [`Statement`] as parsing
109    /// `canonical`.
110    ///
111    /// 2. re-serializing the result of parsing `sql` produces the same
112    /// `canonical` sql string
113    pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement {
114        let mut statements = self.parse_sql_statements(sql).expect(sql);
115        assert_eq!(statements.len(), 1);
116
117        if !canonical.is_empty() && sql != canonical {
118            assert_eq!(self.parse_sql_statements(canonical).unwrap(), statements);
119        }
120
121        let only_statement = statements.pop().unwrap();
122        if !canonical.is_empty() {
123            assert_eq!(canonical, only_statement.to_string())
124        }
125        only_statement
126    }
127
128    /// Ensures that `sql` parses as an [`Expr`], and that
129    /// re-serializing the parse result produces canonical
130    pub fn expr_parses_to(&self, sql: &str, canonical: &str) -> Expr {
131        let ast = self
132            .run_parser_method(sql, |parser| parser.parse_expr())
133            .unwrap();
134        assert_eq!(canonical, &ast.to_string());
135        ast
136    }
137
138    /// Ensures that `sql` parses as a single [Statement], and that
139    /// re-serializing the parse result produces the same `sql`
140    /// string (is not modified after a serialization round-trip).
141    pub fn verified_stmt(&self, sql: &str) -> Statement {
142        self.one_statement_parses_to(sql, sql)
143    }
144
145    /// Ensures that `sql` parses as a single [Query], and that
146    /// re-serializing the parse result produces the same `sql`
147    /// string (is not modified after a serialization round-trip).
148    pub fn verified_query(&self, sql: &str) -> Query {
149        match self.verified_stmt(sql) {
150            Statement::Query(query) => *query,
151            _ => panic!("Expected Query"),
152        }
153    }
154
155    /// Ensures that `sql` parses as a single [Select], and that
156    /// re-serializing the parse result produces the same `sql`
157    /// string (is not modified after a serialization round-trip).
158    pub fn verified_only_select(&self, query: &str) -> Select {
159        match *self.verified_query(query).body {
160            SetExpr::Select(s) => *s,
161            _ => panic!("Expected SetExpr::Select"),
162        }
163    }
164
165    /// Ensures that `sql` parses as a single [`Select`], and that additionally:
166    ///
167    /// 1. parsing `sql` results in the same [`Statement`] as parsing
168    /// `canonical`.
169    ///
170    /// 2. re-serializing the result of parsing `sql` produces the same
171    /// `canonical` sql string
172    pub fn verified_only_select_with_canonical(&self, query: &str, canonical: &str) -> Select {
173        let q = match self.one_statement_parses_to(query, canonical) {
174            Statement::Query(query) => *query,
175            _ => panic!("Expected Query"),
176        };
177        match *q.body {
178            SetExpr::Select(s) => *s,
179            _ => panic!("Expected SetExpr::Select"),
180        }
181    }
182
183    /// Ensures that `sql` parses as an [`Expr`], and that
184    /// re-serializing the parse result produces the same `sql`
185    /// string (is not modified after a serialization round-trip).
186    pub fn verified_expr(&self, sql: &str) -> Expr {
187        self.expr_parses_to(sql, sql)
188    }
189}
190
191pub fn all_dialects() -> TestedDialects {
192    TestedDialects {
193        dialects: vec![
194            Box::new(GenericDialect {}),
195            Box::new(PostgreSqlDialect {}),
196            Box::new(MsSqlDialect {}),
197            Box::new(AnsiDialect {}),
198            Box::new(SnowflakeDialect {}),
199            Box::new(HiveDialect {}),
200            Box::new(RedshiftSqlDialect {}),
201            Box::new(MySqlDialect {}),
202            Box::new(BigQueryDialect {}),
203            Box::new(SQLiteDialect {}),
204            Box::new(DuckDbDialect {}),
205        ],
206        options: None,
207    }
208}
209
/// Asserts that rendering each element of `actual` with [`ToString`] yields
/// exactly the strings in `expected`, in the same order.
pub fn assert_eq_vec<T: ToString>(expected: &[&str], actual: &[T]) {
    let rendered: Vec<String> = actual.iter().map(|item| item.to_string()).collect();
    assert_eq!(expected, rendered);
}
216
/// Returns the single item yielded by `v`.
///
/// # Panics
///
/// Panics if `v` yields no items or more than one item. The panic is
/// attributed to the caller (via `#[track_caller]`), so a failing test
/// reports its own line rather than a line inside this helper.
#[track_caller]
pub fn only<T>(v: impl IntoIterator<Item = T>) -> T {
    let mut iter = v.into_iter();
    match (iter.next(), iter.next()) {
        (Some(item), None) => item,
        _ => panic!("only called on collection without exactly one item"),
    }
}
225
226pub fn expr_from_projection(item: &SelectItem) -> &Expr {
227    match item {
228        SelectItem::UnnamedExpr(expr) => expr,
229        _ => panic!("Expected UnnamedExpr"),
230    }
231}
232
233pub fn alter_table_op_with_name(stmt: Statement, expected_name: &str) -> AlterTableOperation {
234    match stmt {
235        Statement::AlterTable {
236            name,
237            if_exists,
238            only: is_only,
239            operations,
240        } => {
241            assert_eq!(name.to_string(), expected_name);
242            assert!(!if_exists);
243            assert!(!is_only);
244            only(operations)
245        }
246        _ => panic!("Expected ALTER TABLE statement"),
247    }
248}
249pub fn alter_table_op(stmt: Statement) -> AlterTableOperation {
250    alter_table_op_with_name(stmt, "tab")
251}
252
253/// Creates a `Value::Number`, panic'ing if n is not a number
254pub fn number(n: &str) -> Value {
255    Value::Number(n.parse().unwrap(), false)
256}
257
258pub fn table_alias(name: impl Into<String>) -> Option<TableAlias> {
259    Some(TableAlias {
260        name: Ident::new(name),
261        columns: vec![],
262    })
263}
264
265pub fn table(name: impl Into<String>) -> TableFactor {
266    TableFactor::Table {
267        name: ObjectName(vec![Ident::new(name.into())]),
268        alias: None,
269        args: None,
270        with_hints: vec![],
271        version: None,
272        partitions: vec![],
273    }
274}
275
276pub fn join(relation: TableFactor) -> Join {
277    Join {
278        relation,
279        join_operator: JoinOperator::Inner(JoinConstraint::Natural),
280    }
281}