#![allow(dead_code)]
use polyglot_sql::dialects::{Dialect, DialectType};
use polyglot_sql::generator::{Generator, GeneratorConfig};
use polyglot_sql::parser::Parser;
/// Reports whether `sql` contains a newline that sits outside of
/// single-quoted string literals, `--` line comments and `/* ... */`
/// block comments.
///
/// Inside a string literal a doubled quote (`''`) is treated as an escaped
/// quote and does not end the literal. The newline that terminates a line
/// comment only ends the comment; it is not itself reported.
fn has_formatting_newline(sql: &str) -> bool {
    /// Lexical context the scanner is currently in.
    #[derive(Clone, Copy)]
    enum Scan {
        Code,
        Quoted,
        LineComment,
        BlockComment,
    }

    let mut state = Scan::Code;
    let mut stream = sql.chars().peekable();

    while let Some(ch) = stream.next() {
        let upcoming = stream.peek().copied();
        match state {
            Scan::LineComment => {
                if ch == '\n' {
                    state = Scan::Code;
                }
            }
            Scan::BlockComment => {
                if ch == '*' && upcoming == Some('/') {
                    // Swallow the '/' so it is not re-examined as code.
                    stream.next();
                    state = Scan::Code;
                }
            }
            Scan::Quoted => {
                if ch == '\'' {
                    if upcoming == Some('\'') {
                        // Doubled quote: an escaped quote, stay in the literal.
                        stream.next();
                    } else {
                        state = Scan::Code;
                    }
                }
            }
            Scan::Code => match (ch, upcoming) {
                ('-', Some('-')) => {
                    stream.next();
                    state = Scan::LineComment;
                }
                ('/', Some('*')) => {
                    stream.next();
                    state = Scan::BlockComment;
                }
                ('\'', _) => state = Scan::Quoted,
                ('\n', _) => return true,
                _ => {}
            },
        }
    }
    false
}
pub fn identity_test(sql: &str) -> Result<(), String> {
let ast = Parser::parse_sql(sql).map_err(|e| format!("Parse error: {}", e))?;
if ast.is_empty() {
return Err("No statements parsed".to_string());
}
let output = Generator::sql(&ast[0]).map_err(|e| format!("Generate error: {}", e))?;
if output != sql {
return Err(format!(
"Mismatch:\n input: {}\n output: {}",
sql, output
));
}
Ok(())
}
pub fn identity_test_with_expected(sql: &str, expected: Option<&str>) -> Result<(), String> {
let ast = Parser::parse_sql(sql).map_err(|e| format!("Parse error: {}", e))?;
if ast.is_empty() {
return Err("No statements parsed".to_string());
}
let output = Generator::sql(&ast[0]).map_err(|e| format!("Generate error: {}", e))?;
let expected_output = expected.unwrap_or(sql);
if output != expected_output {
return Err(format!(
"Mismatch:\n input: {}\n expected: {}\n output: {}",
sql, expected_output, output
));
}
Ok(())
}
/// Parses `sql` with the given dialect, transforms and regenerates every
/// statement with that same dialect, and compares the statements joined by
/// `"; "` against `expected` (or against `sql` when `expected` is `None`).
///
/// Generator overrides are derived from the expectation:
/// * pretty printing is switched on when the expectation contains a newline;
/// * identifier quoting is forced when an explicit expectation contains more
///   `"` / `` ` `` characters than the input, except for the dialects listed
///   in the exemption below;
/// * keywords are left lowercase for ClickHouse when the expectation starts
///   (after leading whitespace) with an ASCII lowercase letter.
///
/// # Errors
///
/// Returns a descriptive message when parsing, transforming or generating
/// fails, when no statement is parsed, or when the output differs from the
/// expectation.
pub fn dialect_identity_test(
    sql: &str,
    expected: Option<&str>,
    dialect: DialectType,
) -> Result<(), String> {
    let engine = Dialect::get(dialect);
    let statements = engine
        .parse(sql)
        .map_err(|err| format!("Parse error: {}", err))?;
    if statements.is_empty() {
        return Err("No statements parsed".to_string());
    }

    let wanted = expected.unwrap_or(sql);

    // NOTE(review): this looks for any newline, whereas `transpile_test` and
    // `normalization_test` use `has_formatting_newline`; confirm whether the
    // difference is intentional.
    let pretty = wanted.contains('\n');

    // Dialects for which the quote-count heuristic is never applied.
    let quoting_exempt = matches!(
        dialect,
        DialectType::StarRocks
            | DialectType::Exasol
            | DialectType::TSQL
            | DialectType::Fabric
            | DialectType::BigQuery
            | DialectType::Snowflake
            | DialectType::ClickHouse
            | DialectType::Databricks
            | DialectType::Spark
            | DialectType::Hive
    );
    let count_quotes = |text: &str| text.matches('"').count() + text.matches('`').count();
    let force_quotes = match expected {
        Some(exp) if !quoting_exempt => count_quotes(exp) > count_quotes(sql),
        _ => false,
    };

    let lowercase_keywords = matches!(dialect, DialectType::ClickHouse)
        && wanted
            .trim_start()
            .starts_with(|c: char| c.is_ascii_lowercase());

    // Stops at the first statement that fails to transform or generate.
    let rendered = statements
        .iter()
        .map(|stmt| {
            let rewritten = engine
                .transform(stmt.clone())
                .map_err(|err| format!("Transform error: {}", err))?;
            engine
                .generate_with_overrides(&rewritten, |config| {
                    // Only ever switch options on/off explicitly; otherwise
                    // the dialect's own defaults stay in effect.
                    if pretty {
                        config.pretty = true;
                    }
                    if force_quotes {
                        config.always_quote_identifiers = true;
                    }
                    if lowercase_keywords {
                        config.uppercase_keywords = false;
                    }
                })
                .map_err(|err| format!("Generate error: {}", err))
        })
        .collect::<Result<Vec<_>, String>>()?;

    let output = rendered.join("; ");
    if output != wanted {
        Err(format!(
            "Mismatch:\n input: {}\n expected: {}\n output: {}",
            sql, wanted, output
        ))
    } else {
        Ok(())
    }
}
/// Transpiles `sql` from the `source` dialect to the `target` dialect and
/// compares the first resulting statement against `expected`.
///
/// Pretty-printing transpile options are used when `expected` contains a
/// formatting newline as reported by `has_formatting_newline`.
///
/// # Errors
///
/// Returns a descriptive message when transpilation fails, when it yields
/// no statements, or when the first statement differs from `expected`.
pub fn transpile_test(
    sql: &str,
    source: DialectType,
    target: DialectType,
    expected: &str,
) -> Result<(), String> {
    let from = Dialect::get(source);

    let results = if has_formatting_newline(expected) {
        from.transpile_with(sql, target, polyglot_sql::TranspileOptions::pretty())
            .map_err(|err| format!("Transpile error: {}", err))?
    } else {
        from.transpile(sql, target)
            .map_err(|err| format!("Transpile error: {}", err))?
    };

    match results.first() {
        None => Err("No statements transpiled".to_string()),
        Some(actual) if *actual != expected => Err(format!(
            "Mismatch:\n input: {} ({:?} -> {:?})\n expected: {}\n actual: {}",
            sql, source, target, expected, actual
        )),
        Some(_) => Ok(()),
    }
}
/// Maps a dialect name to its [`DialectType`].
///
/// The name is lowercased before matching, so the lookup is
/// case-insensitive. The empty string maps to the generic dialect, and
/// several dialects accept aliases (for example `"postgres"` and
/// `"postgresql"`). Returns `None` for any unrecognised name.
pub fn parse_dialect(name: &str) -> Option<DialectType> {
    let key = name.to_lowercase();
    let dialect = match key.as_str() {
        "athena" => DialectType::Athena,
        "bigquery" => DialectType::BigQuery,
        "clickhouse" => DialectType::ClickHouse,
        "cockroach" | "cockroachdb" => DialectType::CockroachDB,
        "databricks" => DialectType::Databricks,
        "arrow-datafusion" | "arrow_datafusion" | "datafusion" => DialectType::DataFusion,
        "doris" => DialectType::Doris,
        "dremio" => DialectType::Dremio,
        "drill" => DialectType::Drill,
        "druid" => DialectType::Druid,
        "duckdb" => DialectType::DuckDB,
        "dune" => DialectType::Dune,
        "exasol" => DialectType::Exasol,
        "fabric" => DialectType::Fabric,
        "" | "generic" => DialectType::Generic,
        "hive" => DialectType::Hive,
        "materialize" => DialectType::Materialize,
        "mysql" => DialectType::MySQL,
        "oracle" => DialectType::Oracle,
        "postgres" | "postgresql" => DialectType::PostgreSQL,
        "presto" => DialectType::Presto,
        "redshift" => DialectType::Redshift,
        "risingwave" => DialectType::RisingWave,
        "memsql" | "singlestore" => DialectType::SingleStore,
        "snowflake" => DialectType::Snowflake,
        "solr" => DialectType::Solr,
        "spark" => DialectType::Spark,
        "sqlite" => DialectType::SQLite,
        "starrocks" => DialectType::StarRocks,
        "teradata" => DialectType::Teradata,
        "tidb" => DialectType::TiDB,
        "trino" => DialectType::Trino,
        "mssql" | "sqlserver" | "tsql" => DialectType::TSQL,
        _ => return None,
    };
    Some(dialect)
}
pub fn normalization_test(sql: &str, expected: &str) -> Result<(), String> {
let ast = Parser::parse_sql(sql).map_err(|e| format!("Parse error: {}", e))?;
if ast.is_empty() {
return Err("No statements parsed".to_string());
}
let output = if has_formatting_newline(expected) {
let config = GeneratorConfig {
pretty: true,
..Default::default()
};
let mut gen = Generator::with_config(config);
gen.generate(&ast[0])
.map_err(|e| format!("Generate error: {}", e))?
} else {
Generator::sql(&ast[0]).map_err(|e| format!("Generate error: {}", e))?
};
if output != expected {
return Err(format!(
"Mismatch:\n input: {}\n expected: {}\n output: {}",
sql, expected, output
));
}
Ok(())
}
/// Asserts that parsing `sql` fails.
///
/// The dialect-specific parser is used when `dialect` is `Some`; otherwise
/// the default parser is used.
///
/// # Errors
///
/// Returns a message when parsing unexpectedly succeeds. The message
/// includes the SQL regenerated from the first parsed statement, or an
/// empty string when there is no statement or regeneration fails.
pub fn parser_error_test(sql: &str, dialect: Option<DialectType>) -> Result<(), String> {
    let parsed = match dialect {
        Some(kind) => Dialect::get(kind).parse(sql),
        None => Parser::parse_sql(sql),
    };

    // A parse error is the success case for this helper.
    let statements = match parsed {
        Err(_) => return Ok(()),
        Ok(statements) => statements,
    };

    // Best effort only: this text is purely diagnostic.
    let generated = statements
        .first()
        .map_or_else(String::new, |stmt| Generator::sql(stmt).unwrap_or_default());

    Err(format!(
        "Expected parse error for SQL: {}\n but got: {}",
        sql, generated
    ))
}
pub fn pretty_test(input: &str, expected: &str) -> Result<(), String> {
let ast = Parser::parse_sql(input).map_err(|e| format!("Parse error: {}", e))?;
if ast.is_empty() {
return Err("No statements parsed".to_string());
}
let output = Generator::pretty_sql(&ast[0]).map_err(|e| format!("Generate error: {}", e))?;
let output_normalized = output.trim();
let expected_normalized = expected.trim();
if output_normalized != expected_normalized {
return Err(format!(
"Mismatch:\n input:\n{}\n expected:\n{}\n output:\n{}",
input, expected_normalized, output_normalized
));
}
Ok(())
}