use std::fs;
use std::path::{Path, PathBuf};
/// One regression case: a SQL script paired with the file holding its expected transcript.
#[derive(Debug, Clone)]
pub struct RegressCase {
    /// Case identifier; `discover_cases` sets it to the SQL file's stem.
    pub name: String,
    /// Path of the SQL script whose statements are executed.
    pub sql_path: PathBuf,
    /// Path of the expected-output file; `discover_cases` uses `<expected_dir>/<name>.out`.
    pub expected_path: PathBuf,
}
/// Aggregated outcome of running a set of regression cases.
#[derive(Debug, Default)]
pub struct RegressReport {
    /// Names of the cases whose transcript matched the expected output.
    pub passed: Vec<String>,
    /// One entry per case whose transcript differed from the expected output.
    pub failed: Vec<RegressFailure>,
}

/// Details about a single failing case.
#[derive(Debug)]
pub struct RegressFailure {
    /// Name of the failing case.
    pub name: String,
    /// Human-readable rendering of how `expected` and `actual` differ.
    pub diff: String,
    /// Contents of the expected-output file (may be empty if that file was missing).
    pub expected: String,
    /// Transcript actually produced by the run.
    pub actual: String,
}

impl RegressReport {
    /// Number of cases recorded as passed.
    pub fn pass_count(&self) -> usize {
        let passed = &self.passed;
        passed.len()
    }

    /// Number of cases recorded as failed.
    pub fn fail_count(&self) -> usize {
        let failed = &self.failed;
        failed.len()
    }

    /// `true` when no case failed, which includes the case where nothing ran at all.
    pub fn is_green(&self) -> bool {
        self.fail_count() == 0
    }
}
/// Lists the regression cases found directly inside `sql_dir`.
///
/// Every entry whose extension is exactly `sql` becomes a case named after its
/// file stem, paired with `<expected_dir>/<stem>.out` (which need not exist).
/// Entries whose stem is missing or not valid UTF-8 are skipped. Entries are not
/// checked to be regular files. The result is sorted by case name.
///
/// # Errors
/// Propagates any error from reading `sql_dir` or iterating its entries.
pub fn discover_cases(sql_dir: &Path, expected_dir: &Path) -> std::io::Result<Vec<RegressCase>> {
    let mut found = Vec::new();
    for dir_entry in fs::read_dir(sql_dir)? {
        let sql_path = dir_entry?.path();
        let has_sql_ext = matches!(
            sql_path.extension().and_then(|ext| ext.to_str()),
            Some("sql")
        );
        if !has_sql_ext {
            continue;
        }
        // The stem doubles as the case name, so it must be non-empty UTF-8.
        let name = match sql_path.file_stem().and_then(|stem| stem.to_str()) {
            Some(stem) if !stem.is_empty() => stem.to_owned(),
            _ => continue,
        };
        let expected_path = expected_dir.join(format!("{name}.out"));
        found.push(RegressCase {
            name,
            sql_path,
            expected_path,
        });
    }
    found.sort_by(|lhs, rhs| lhs.name.cmp(&rhs.name));
    Ok(found)
}
/// Runs one regression case and compares its transcript with the expected output.
///
/// The SQL script is read and each statement is passed to `executor`. The resulting
/// canonical transcript is compared byte-for-byte with the contents of
/// `case.expected_path`. Returns `Ok(None)` on a match and `Ok(Some(failure))`
/// otherwise.
///
/// A missing expected file is treated as empty output, so a newly added case fails
/// with its whole transcript shown in the diff.
///
/// # Errors
/// Returns an I/O error if the SQL file cannot be read. Also returns one if the
/// expected file exists but cannot be read (permissions, invalid UTF-8, not a file,
/// ...). In both situations the executor is not invoked.
pub fn run_case<F>(
    case: &RegressCase,
    executor: F,
) -> Result<Option<RegressFailure>, std::io::Error>
where
    F: FnMut(&str) -> String,
{
    let sql = fs::read_to_string(&case.sql_path)?;
    // Load the golden file before executing anything, so a read failure is
    // reported without having run (possibly side-effecting) statements first.
    let expected = match fs::read_to_string(&case.expected_path) {
        Ok(text) => text,
        // No golden file yet: compare against empty output so the case fails
        // and the diff shows everything it produced.
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => String::new(),
        // Any other read failure is a real error, not "empty expected output".
        Err(err) => return Err(err),
    };
    let actual = run_sql_to_canonical(&sql, executor);
    if actual == expected {
        return Ok(None);
    }
    Ok(Some(RegressFailure {
        name: case.name.clone(),
        diff: render_diff(&expected, &actual),
        expected,
        actual,
    }))
}
pub fn run_suite<F>(cases: &[RegressCase], mut executor: F) -> Result<RegressReport, std::io::Error>
where
F: FnMut(&str) -> String,
{
let mut report = RegressReport::default();
for case in cases {
match run_case(case, &mut executor)? {
None => report.passed.push(case.name.clone()),
Some(failure) => report.failed.push(failure),
}
}
Ok(report)
}
/// Executes each statement in `sql` and builds the canonical transcript.
///
/// Statements come from `split_statements`, are trimmed, and blank ones are
/// skipped. For each remaining statement, in file order, the transcript gets:
/// - a `-- <statement>` echo line;
/// - the executor's output, terminated with a newline if it was not already;
/// - one empty separator line.
fn run_sql_to_canonical<F>(sql: &str, mut executor: F) -> String
where
    F: FnMut(&str) -> String,
{
    let mut transcript = String::new();
    let statements = split_statements(sql);
    let non_blank = statements
        .iter()
        .map(|raw| raw.trim())
        .filter(|stmt| !stmt.is_empty());
    for statement in non_blank {
        transcript.push_str("-- ");
        transcript.push_str(statement);
        transcript.push('\n');
        let output = executor(statement);
        transcript.push_str(&output);
        // Close the output's last line if needed, then add the blank separator.
        transcript.push_str(if output.ends_with('\n') { "\n" } else { "\n\n" });
    }
    transcript
}
/// Splits a SQL script into statements at top-level `;` terminators.
///
/// A `;` is ignored while inside any of the following, so text like `-- don't`
/// or `'a;b'` does not break a statement apart:
/// - a single-quoted literal;
/// - a double-quoted identifier;
/// - a `--` line comment;
/// - a `/* ... */` block comment.
///
/// Inside quotes, a backslash escapes exactly one following character, so `'\''`
/// stays open while `'\\'` closes. Doubled quotes (`'it''s'`) also work because
/// each quote toggles the literal state.
///
/// Comment text is kept in the returned statements. The terminating `;` is
/// dropped, and pieces that are only whitespace are not emitted on their own.
/// Dollar-quoted bodies (`$$ ... $$`) and nested block comments are not recognised.
fn split_statements(sql: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut current = String::new();
    let mut in_single = false;
    let mut in_double = false;
    let mut in_line_comment = false;
    let mut in_block_comment = false;
    // Set after an unescaped backslash inside quotes: the next char is literal.
    let mut escaped = false;
    let mut chars = sql.chars().peekable();
    while let Some(ch) = chars.next() {
        if in_line_comment {
            current.push(ch);
            in_line_comment = ch != '\n';
            continue;
        }
        if in_block_comment {
            current.push(ch);
            if ch == '*' && chars.peek() == Some(&'/') {
                current.push('/');
                chars.next();
                in_block_comment = false;
            }
            continue;
        }
        if escaped {
            current.push(ch);
            escaped = false;
            continue;
        }
        let in_quote = in_single || in_double;
        match ch {
            '\\' if in_quote => {
                escaped = true;
                current.push(ch);
            }
            '\'' if !in_double => {
                in_single = !in_single;
                current.push(ch);
            }
            '"' if !in_single => {
                in_double = !in_double;
                current.push(ch);
            }
            '-' if !in_quote && chars.peek() == Some(&'-') => {
                in_line_comment = true;
                current.push(ch);
            }
            '/' if !in_quote && chars.peek() == Some(&'*') => {
                in_block_comment = true;
                current.push(ch);
                current.push('*');
                chars.next();
            }
            ';' if !in_quote => {
                if !current.trim().is_empty() {
                    out.push(std::mem::take(&mut current));
                }
            }
            _ => current.push(ch),
        }
    }
    if !current.trim().is_empty() {
        out.push(current);
    }
    out
}
/// Renders a positional, line-by-line comparison of `expected` and `actual`.
///
/// Line `i` of one side is compared with line `i` of the other:
/// - equal lines are emitted with a two-space prefix;
/// - differing lines become a `- expected` / `+ actual` pair;
/// - lines present on only one side are emitted alone with `- ` or `+ `.
///
/// This is not an LCS diff, so one inserted or deleted line marks every later line
/// as changed.
///
/// `str::lines` strips `\n` / `\r\n` terminators and ignores a missing final newline.
/// Inputs that differ only in those ways would therefore yield no `-`/`+` lines, so a
/// note is appended in that case rather than returning an apparently clean diff.
fn render_diff(expected: &str, actual: &str) -> String {
    const LINE_ENDING_NOTE: &str =
        "(no line-level differences: outputs differ only in line endings or a trailing newline)\n";
    let exp_lines: Vec<&str> = expected.lines().collect();
    let act_lines: Vec<&str> = actual.lines().collect();
    let mut out = String::new();
    let mut any_line_differs = false;
    for i in 0..exp_lines.len().max(act_lines.len()) {
        match (exp_lines.get(i), act_lines.get(i)) {
            (Some(e), Some(a)) if e == a => {
                // Two-char prefix keeps context lines aligned with "- " / "+ " lines.
                out.push_str("  ");
                out.push_str(e);
                out.push('\n');
            }
            (e, a) => {
                any_line_differs = true;
                if let Some(e) = e {
                    out.push_str("- ");
                    out.push_str(e);
                    out.push('\n');
                }
                if let Some(a) = a {
                    out.push_str("+ ");
                    out.push_str(a);
                    out.push('\n');
                }
            }
        }
    }
    if !any_line_differs && expected != actual {
        out.push_str(LINE_ENDING_NOTE);
    }
    out
}
/// Renders a result set as an aligned text table for use in a transcript.
///
/// The output contains, in order:
/// - a header row of column names joined by ` | `;
/// - a rule of dashes joined by `-+-`;
/// - one line per row;
/// - a footer `(N row)` / `(N rows)`.
///
/// Returns an empty string when `columns` is empty.
///
/// Widths are counted in chars (Unicode scalar values), not bytes, so non-ASCII
/// text lines up. Double-width or combining characters can still misalign, since
/// that would need a display-width table. Cells beyond the last column are emitted
/// unpadded, and rows shorter than `columns` simply end early.
pub fn format_result(columns: &[String], rows: &[Vec<String>]) -> String {
    if columns.is_empty() {
        return String::new();
    }
    // Each column is as wide as its header or its widest cell, in chars.
    let mut widths: Vec<usize> = columns.iter().map(|c| c.chars().count()).collect();
    for row in rows {
        // `zip` stops at the shorter side, so surplus cells don't affect widths.
        for (width, cell) in widths.iter_mut().zip(row) {
            *width = (*width).max(cell.chars().count());
        }
    }
    let mut out = String::new();
    for (i, (col, &width)) in columns.iter().zip(&widths).enumerate() {
        if i > 0 {
            out.push_str(" | ");
        }
        out.push_str(&pad_right(col, width));
    }
    out.push('\n');
    for (i, &width) in widths.iter().enumerate() {
        if i > 0 {
            out.push_str("-+-");
        }
        out.push_str(&"-".repeat(width));
    }
    out.push('\n');
    for row in rows {
        for (i, cell) in row.iter().enumerate() {
            if i > 0 {
                out.push_str(" | ");
            }
            // A cell with no matching column has no target width: leave it unpadded.
            let width = widths.get(i).copied().unwrap_or(0);
            out.push_str(&pad_right(cell, width));
        }
        out.push('\n');
    }
    let plural = if rows.len() == 1 { "" } else { "s" };
    out.push_str(&format!("({} row{plural})\n", rows.len()));
    out
}

/// Left-aligns `s` in a field `width` chars wide by appending spaces.
///
/// Length is counted in chars, not bytes. Strings already at least `width`
/// chars long are returned unchanged.
fn pad_right(s: &str, width: usize) -> String {
    let fill = width.saturating_sub(s.chars().count());
    let mut out = String::with_capacity(s.len() + fill);
    out.push_str(s);
    for _ in 0..fill {
        out.push(' ');
    }
    out
}