use sqlparser::ast::{AlterTableOperation, Statement};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
/// Derives "down" (rollback) SQL from "up" migration SQL.
///
/// The input is parsed with sqlparser's [`GenericDialect`]; each supported
/// statement (`CREATE TABLE`, and `ALTER TABLE` with `ADD COLUMN` /
/// `RENAME COLUMN` operations) is mapped to the statement that undoes it.
#[derive(Debug)]
pub struct SqlReverser {
    /// Dialect handed to the SQL parser on every call to `reverse`.
    dialect: GenericDialect,
}
impl SqlReverser {
    /// Creates a reverser that parses its input with [`GenericDialect`].
    pub fn new() -> Self {
        Self {
            dialect: GenericDialect {},
        }
    }

    /// Produces the SQL that undoes `up_sql`.
    ///
    /// Every statement in `up_sql` is reversed individually and the results
    /// are emitted in the opposite order of the input, joined with `";\n"`
    /// (no trailing semicolon), so that later changes are rolled back before
    /// the earlier changes they may depend on.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a human-readable message when `up_sql` cannot be
    /// parsed, or when it contains a statement or `ALTER TABLE` operation
    /// that this reverser does not support. The first unsupported item in
    /// input order is the one reported.
    pub fn reverse(&self, up_sql: &str) -> Result<String, String> {
        let statements = Parser::parse_sql(&self.dialect, up_sql)
            .map_err(|e| format!("Failed to parse SQL: {}", e))?;

        // Reverse each statement in input order first so that the error for
        // the first unsupported statement is the one surfaced to the caller.
        let mut down_statements = statements
            .iter()
            .map(|stmt| self.reverse_statement(stmt))
            .collect::<Result<Vec<_>, _>>()?;

        // A rollback must undo the most recent change first: e.g. a table
        // created last (possibly referencing earlier ones) is dropped first.
        down_statements.reverse();

        Ok(down_statements.join(";\n"))
    }

    /// Maps a single parsed statement to the SQL text that undoes it.
    ///
    /// `CREATE TABLE` becomes `DROP TABLE IF EXISTS`; `ALTER TABLE` is
    /// delegated to `reverse_alter_table`. Any other statement is an error.
    fn reverse_statement(&self, stmt: &Statement) -> Result<String, String> {
        match stmt {
            Statement::CreateTable { name, .. } => Ok(format!("DROP TABLE IF EXISTS {}", name)),
            Statement::AlterTable {
                name, operations, ..
            } => self.reverse_alter_table(&name.to_string(), operations),
            _ => Err(format!("Unsupported statement type: {:?}", stmt)),
        }
    }

    /// Builds the `ALTER TABLE` statements that undo `operations` on
    /// `table_name`.
    ///
    /// One statement is generated per operation; they are returned joined
    /// with `";\n"`, with the last operation undone first. Only `ADD COLUMN`
    /// and `RENAME COLUMN` are supported; anything else is an error.
    fn reverse_alter_table(
        &self,
        table_name: &str,
        operations: &[AlterTableOperation],
    ) -> Result<String, String> {
        let mut down_operations = Vec::with_capacity(operations.len());

        for op in operations {
            let down = match op {
                // Format the identifier through `Display` (not `.value`) so a
                // quoted column name keeps its quotes, matching how the table
                // name and renamed columns are rendered.
                AlterTableOperation::AddColumn { column_def, .. } => format!(
                    "ALTER TABLE {} DROP COLUMN {}",
                    table_name, column_def.name
                ),
                AlterTableOperation::RenameColumn {
                    old_column_name,
                    new_column_name,
                } => format!(
                    "ALTER TABLE {} RENAME COLUMN {} TO {}",
                    table_name, new_column_name, old_column_name
                ),
                _ => return Err(format!("Unsupported alter operation: {:?}", op)),
            };
            down_operations.push(down);
        }

        // Undo in the opposite order: for `ADD COLUMN a, RENAME COLUMN a TO b`
        // the column must be renamed back to `a` before `a` can be dropped.
        down_operations.reverse();

        Ok(down_operations.join(";\n"))
    }
}
impl Default for SqlReverser {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single `CREATE TABLE` is undone by dropping that table.
    #[test]
    fn test_reverse_create_table() {
        let up_sql = "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(255));";
        let down_sql = SqlReverser::new().reverse(up_sql).unwrap();
        assert_eq!(down_sql, "DROP TABLE IF EXISTS users");
    }

    /// An added column is undone by dropping it from the same table.
    #[test]
    fn test_reverse_add_column() {
        let up_sql = "ALTER TABLE users ADD COLUMN email VARCHAR(255);";
        let down_sql = SqlReverser::new().reverse(up_sql).unwrap();
        assert_eq!(down_sql, "ALTER TABLE users DROP COLUMN email");
    }

    /// A column rename is undone by renaming in the opposite direction.
    #[test]
    fn test_reverse_rename_column() {
        let up_sql = "ALTER TABLE users RENAME COLUMN name TO full_name;";
        let down_sql = SqlReverser::new().reverse(up_sql).unwrap();
        assert_eq!(down_sql, "ALTER TABLE users RENAME COLUMN full_name TO name");
    }

    /// Statements with no known inverse are rejected rather than skipped.
    #[test]
    fn test_reverse_unsupported_statement() {
        let err = SqlReverser::new().reverse("DROP TABLE users;").unwrap_err();
        assert!(err.starts_with("Unsupported statement type"));
    }

    /// Input that is not SQL surfaces the parser failure.
    #[test]
    fn test_reverse_invalid_sql() {
        let err = SqlReverser::new().reverse("THIS IS NOT SQL").unwrap_err();
        assert!(err.starts_with("Failed to parse SQL"));
    }
}