#[cfg(not(feature = "std"))]
use alloc::{
boxed::Box,
string::{String, ToString},
vec,
vec::Vec,
};
use core::fmt::Debug;
use crate::dialect::*;
use crate::parser::{Parser, ParserError};
use crate::tokenizer::{Token, Tokenizer};
use crate::{ast::*, parser::ParserOptions};
#[cfg(test)]
use pretty_assertions::assert_eq;
/// A set of SQL dialects that the test helpers run the same input against.
///
/// Helpers built on `one_of_identical_results` evaluate each dialect in turn
/// and assert that all dialects produce equal results.
pub struct TestedDialects {
/// Dialects to exercise, visited in this order by the helpers.
pub dialects: Vec<Box<dyn Dialect>>,
/// Options applied to every parser the helpers create; its `unescape` flag
/// is also passed to every tokenizer. `None` leaves the defaults untouched.
pub options: Option<ParserOptions>,
/// Recursion limit applied to every parser the helpers create; `None`
/// leaves the parser's default untouched.
pub recursion_limit: Option<usize>,
}
impl TestedDialects {
pub fn new(dialects: Vec<Box<dyn Dialect>>) -> Self {
Self {
dialects,
options: None,
recursion_limit: None,
}
}
pub fn new_with_options(dialects: Vec<Box<dyn Dialect>>, options: ParserOptions) -> Self {
Self {
dialects,
options: Some(options),
recursion_limit: None,
}
}
pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
self.recursion_limit = Some(recursion_limit);
self
}
fn new_parser<'a>(&self, dialect: &'a dyn Dialect) -> Parser<'a> {
let parser = Parser::new(dialect);
let parser = if let Some(options) = &self.options {
parser.with_options(options.clone())
} else {
parser
};
let parser = if let Some(recursion_limit) = &self.recursion_limit {
parser.with_recursion_limit(*recursion_limit)
} else {
parser
};
parser
}
pub fn one_of_identical_results<F, T: Debug + PartialEq>(&self, f: F) -> T
where
F: Fn(&dyn Dialect) -> T,
{
let parse_results = self.dialects.iter().map(|dialect| (dialect, f(&**dialect)));
parse_results
.fold(None, |s, (dialect, parsed)| {
if let Some((prev_dialect, prev_parsed)) = s {
assert_eq!(
prev_parsed, parsed,
"Parse results with {prev_dialect:?} are different from {dialect:?}"
);
}
Some((dialect, parsed))
})
.expect("tested dialects cannot be empty")
.1
}
pub fn run_parser_method<F, T: Debug + PartialEq>(&self, sql: &str, f: F) -> T
where
F: Fn(&mut Parser) -> T,
{
self.one_of_identical_results(|dialect| {
let mut parser = self.new_parser(dialect).try_with_sql(sql).unwrap();
f(&mut parser)
})
}
pub fn parse_sql_statements(&self, sql: &str) -> Result<Vec<Statement>, ParserError> {
self.one_of_identical_results(|dialect| {
let mut tokenizer = Tokenizer::new(dialect, sql);
if let Some(options) = &self.options {
tokenizer = tokenizer.with_unescape(options.unescape);
}
let tokens = tokenizer.tokenize()?;
self.new_parser(dialect)
.with_tokens(tokens)
.parse_statements()
})
}
pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement {
let mut statements = self.parse_sql_statements(sql).expect(sql);
assert_eq!(statements.len(), 1);
if !canonical.is_empty() && sql != canonical {
assert_eq!(self.parse_sql_statements(canonical).unwrap(), statements);
}
let only_statement = statements.pop().unwrap();
if !canonical.is_empty() {
assert_eq!(canonical, only_statement.to_string())
}
only_statement
}
pub fn statements_parse_to(&self, sql: &str, canonical: &str) -> Vec<Statement> {
let statements = self.parse_sql_statements(sql).expect(sql);
if !canonical.is_empty() && sql != canonical {
assert_eq!(self.parse_sql_statements(canonical).unwrap(), statements);
} else {
assert_eq!(
sql,
statements
.iter()
.map(|s| s.to_string())
.collect::<Vec<_>>()
.join("; ")
);
}
statements
}
pub fn expr_parses_to(&self, sql: &str, canonical: &str) -> Expr {
let ast = self
.run_parser_method(sql, |parser| parser.parse_expr())
.unwrap();
assert_eq!(canonical, &ast.to_string());
ast
}
pub fn verified_stmt(&self, sql: &str) -> Statement {
self.one_statement_parses_to(sql, sql)
}
pub fn verified_query(&self, sql: &str) -> Query {
match self.verified_stmt(sql) {
Statement::Query(query) => *query,
_ => panic!("Expected Query"),
}
}
pub fn verified_query_with_canonical(&self, query: &str, canonical: &str) -> Query {
match self.one_statement_parses_to(query, canonical) {
Statement::Query(query) => *query,
_ => panic!("Expected Query"),
}
}
pub fn verified_only_select(&self, query: &str) -> Select {
match *self.verified_query(query).body {
SetExpr::Select(s) => *s,
_ => panic!("Expected SetExpr::Select"),
}
}
pub fn verified_only_select_with_canonical(&self, query: &str, canonical: &str) -> Select {
let q = match self.one_statement_parses_to(query, canonical) {
Statement::Query(query) => *query,
_ => panic!("Expected Query"),
};
match *q.body {
SetExpr::Select(s) => *s,
_ => panic!("Expected SetExpr::Select"),
}
}
pub fn verified_expr(&self, sql: &str) -> Expr {
self.expr_parses_to(sql, sql)
}
pub fn tokenizes_to(&self, sql: &str, expected: Vec<Token>) {
if self.dialects.is_empty() {
panic!("No dialects to test");
}
self.dialects.iter().for_each(|dialect| {
let mut tokenizer = Tokenizer::new(&**dialect, sql);
if let Some(options) = &self.options {
tokenizer = tokenizer.with_unescape(options.unescape);
}
let tokens = tokenizer.tokenize().unwrap();
assert_eq!(expected, tokens, "Tokenized differently for {dialect:?}");
});
}
}
pub fn all_dialects() -> TestedDialects {
TestedDialects::new(vec![
Box::new(GenericDialect {}),
Box::new(PostgreSqlDialect {}),
Box::new(MsSqlDialect {}),
Box::new(AnsiDialect {}),
Box::new(SnowflakeDialect {}),
Box::new(HiveDialect {}),
Box::new(RedshiftSqlDialect {}),
Box::new(MySqlDialect {}),
Box::new(BigQueryDialect {}),
Box::new(SQLiteDialect {}),
Box::new(DuckDbDialect {}),
Box::new(DatabricksDialect {}),
Box::new(ClickHouseDialect {}),
Box::new(OracleDialect {}),
])
}
pub fn all_dialects_with_options(options: ParserOptions) -> TestedDialects {
TestedDialects::new_with_options(all_dialects().dialects, options)
}
pub fn all_dialects_where<F>(predicate: F) -> TestedDialects
where
F: Fn(&dyn Dialect) -> bool,
{
let mut dialects = all_dialects();
dialects.dialects.retain(|d| predicate(&**d));
dialects
}
pub fn all_dialects_except<F>(except: F) -> TestedDialects
where
F: Fn(&dyn Dialect) -> bool,
{
all_dialects_where(|d| !except(d))
}
/// Asserts that rendering each item of `actual` with `to_string` yields
/// exactly `expected`, with the same length and order.
pub fn assert_eq_vec<T: ToString>(expected: &[&str], actual: &[T]) {
    let rendered: Vec<String> = actual.iter().map(|item| item.to_string()).collect();
    assert_eq!(expected, rendered);
}
/// Returns the single item yielded by `v`.
///
/// # Panics
///
/// If `v` yields no items, or more than one item. At most two items are
/// pulled from the iterator.
pub fn only<T>(v: impl IntoIterator<Item = T>) -> T {
    let mut items = v.into_iter();
    let first = items.next();
    let second = items.next();
    match (first, second) {
        (Some(item), None) => item,
        _ => panic!("only called on collection without exactly one item"),
    }
}
/// Returns the expression of an unaliased projection item.
///
/// # Panics
///
/// If `item` is not `SelectItem::UnnamedExpr`.
pub fn expr_from_projection(item: &SelectItem) -> &Expr {
    if let SelectItem::UnnamedExpr(expr) = item {
        expr
    } else {
        panic!("Expected UnnamedExpr")
    }
}
/// Extracts the single operation of an `ALTER TABLE` statement.
///
/// Also asserts the following:
/// - the table renders as `expected_name`;
/// - `if_exists` and `only` are not set;
/// - `table_type` is `None`.
///
/// # Panics
///
/// If `stmt` is not `Statement::AlterTable`, if any assertion above fails,
/// or if there is not exactly one operation.
pub fn alter_table_op_with_name(stmt: Statement, expected_name: &str) -> AlterTableOperation {
    let alter_table = if let Statement::AlterTable(alter_table) = stmt {
        alter_table
    } else {
        panic!("Expected ALTER TABLE statement")
    };
    // The `assert!` expressions are kept verbatim: their text is part of the
    // panic message.
    assert_eq!(alter_table.name.to_string(), expected_name);
    assert!(!alter_table.if_exists);
    assert!(!alter_table.only);
    assert_eq!(alter_table.table_type, None);
    only(alter_table.operations)
}
/// Same as [`alter_table_op_with_name`] with the table name fixed to `tab`.
pub fn alter_table_op(stmt: Statement) -> AlterTableOperation {
    const EXPECTED_TABLE: &str = "tab";
    alter_table_op_with_name(stmt, EXPECTED_TABLE)
}
/// Builds a `Value::Number` by parsing `n`, with the boolean flag set to
/// `false`.
///
/// # Panics
///
/// If `n` does not parse into the number type `Value::Number` stores.
pub fn number(n: &str) -> Value {
    let parsed = n.parse().unwrap();
    Value::Number(parsed, false)
}
/// Wraps `s` in a `Value::SingleQuotedString`.
pub fn single_quoted_string(s: impl Into<String>) -> Value {
    let text: String = s.into();
    Value::SingleQuotedString(text)
}
/// Builds a column-less `TableAlias` named `name`.
///
/// `explicit` is stored as-is in the alias's `explicit` field.
pub fn table_alias(explicit: bool, name: impl Into<String>) -> Option<TableAlias> {
    let alias = TableAlias {
        name: Ident::new(name),
        columns: vec![],
        explicit,
    };
    Some(alias)
}
pub fn table(name: impl Into<String>) -> TableFactor {
TableFactor::Table {
name: ObjectName::from(vec![Ident::new(name.into())]),
alias: None,
args: None,
with_hints: vec![],
version: None,
partitions: vec![],
with_ordinality: false,
json_path: None,
sample: None,
index_hints: vec![],
}
}
/// Builds a `TableFactor::Table` for `name`.
///
/// Every other field takes its empty value: `None`, `false`, or an empty
/// `Vec`.
pub fn table_from_name(name: ObjectName) -> TableFactor {
    TableFactor::Table {
        name,
        alias: None,
        args: None,
        version: None,
        json_path: None,
        sample: None,
        with_ordinality: false,
        with_hints: vec![],
        partitions: vec![],
        index_hints: vec![],
    }
}
/// Builds a `TableFactor::Table` for the single-part name `name`, aliased as
/// `alias`.
///
/// `with_as_keyword` becomes the alias's `explicit` flag. All other fields
/// take their empty values.
pub fn table_with_alias(
    name: impl Into<String>,
    with_as_keyword: bool,
    alias: impl Into<String>,
) -> TableFactor {
    let name = ObjectName::from(vec![Ident::new(name)]);
    let alias = table_alias(with_as_keyword, alias);
    TableFactor::Table {
        name,
        alias,
        args: None,
        version: None,
        json_path: None,
        sample: None,
        with_ordinality: false,
        with_hints: vec![],
        partitions: vec![],
        index_hints: vec![],
    }
}
/// Wraps `relation` in a non-global `Join`.
///
/// The join operator is `JoinOperator::Join(JoinConstraint::Natural)`.
pub fn join(relation: TableFactor) -> Join {
    let join_operator = JoinOperator::Join(JoinConstraint::Natural);
    Join {
        relation,
        join_operator,
        global: false,
    }
}
pub fn call(function: &str, args: impl IntoIterator<Item = Expr>) -> Expr {
Expr::Function(Function {
name: ObjectName::from(vec![Ident::new(function)]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
args: args
.into_iter()
.map(|arg| FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)))
.collect(),
clauses: vec![],
}),
filter: None,
null_treatment: None,
over: None,
within_group: vec![],
})
}
pub fn index_column(stmt: Statement) -> Expr {
match stmt {
Statement::CreateIndex(CreateIndex { columns, .. }) => {
columns.first().unwrap().column.expr.clone()
}
Statement::CreateTable(CreateTable { constraints, .. }) => {
match constraints.first().unwrap() {
TableConstraint::Index(constraint) => {
constraint.columns.first().unwrap().column.expr.clone()
}
TableConstraint::Unique(constraint) => {
constraint.columns.first().unwrap().column.expr.clone()
}
TableConstraint::PrimaryKey(constraint) => {
constraint.columns.first().unwrap().column.expr.clone()
}
TableConstraint::FulltextOrSpatial(constraint) => {
constraint.columns.first().unwrap().column.expr.clone()
}
_ => panic!("Expected an index, unique, primary, full text, or spatial constraint (foreign key does not support general key part expressions)"),
}
}
Statement::AlterTable(alter_table) => match alter_table.operations.first().unwrap() {
AlterTableOperation::AddConstraint { constraint, .. } => {
match constraint {
TableConstraint::Index(constraint) => {
constraint.columns.first().unwrap().column.expr.clone()
}
TableConstraint::Unique(constraint) => {
constraint.columns.first().unwrap().column.expr.clone()
}
TableConstraint::PrimaryKey(constraint) => {
constraint.columns.first().unwrap().column.expr.clone()
}
TableConstraint::FulltextOrSpatial(constraint) => {
constraint.columns.first().unwrap().column.expr.clone()
}
_ => panic!("Expected an index, unique, primary, full text, or spatial constraint (foreign key does not support general key part expressions)"),
}
}
_ => panic!("Expected a constraint"),
},
_ => panic!("Expected CREATE INDEX, ALTER TABLE, or CREATE TABLE, got: {stmt:?}"),
}
}