use crate::checks::CheckRegistry;
use crate::config::Config;
use crate::error::Result;
use crate::parser::SqlParser;
use crate::violation::Violation;
use camino::{Utf8Path, Utf8PathBuf};
use std::fs;
use walkdir::WalkDir;
/// Runs the registered safety checks against SQL supplied as a string, a
/// single file, or a directory of migrations.
///
/// Construct it with `SafetyChecker::new` (configuration obtained through
/// `Config::load`) or `SafetyChecker::with_config` (explicit configuration).
pub struct SafetyChecker {
/// Parser that turns raw SQL text into statements plus parse metadata.
parser: SqlParser,
/// Checks to run over the parsed statements; built from `config`.
registry: CheckRegistry,
/// Settings consulted during directory traversal (which migration
/// directories to check, and whether `down.sql` files are included).
config: Config,
}
impl SafetyChecker {
pub fn new() -> Self {
let config = Config::load().unwrap_or_else(|e| {
eprintln!("Warning: Failed to load config: {}. Using defaults.", e);
Config::default()
});
Self::with_config(config)
}
pub fn with_config(config: Config) -> Self {
Self {
parser: SqlParser::new(),
registry: CheckRegistry::with_config(&config),
config,
}
}
pub fn check_sql(&self, sql: &str) -> Result<Vec<Violation>> {
let parsed = self.parser.parse_with_metadata(sql)?;
let violations = self.registry.check_statements_with_context(
&parsed.statements,
&parsed.sql,
&parsed.ignore_ranges,
);
Ok(violations)
}
pub fn check_file(&self, path: &Utf8Path) -> Result<Vec<Violation>> {
let sql = fs::read_to_string(path)?;
self.check_sql(&sql)
}
pub fn check_directory(&self, dir: &Utf8Path) -> Result<Vec<(String, Vec<Violation>)>> {
let files_to_check = self.collect_files(dir);
files_to_check
.iter()
.map(|file_path| {
let violations = self.check_file(file_path)?;
if violations.is_empty() {
Ok(None)
} else {
Ok(Some((file_path.to_string(), violations)))
}
})
.collect::<Result<Vec<_>>>()
.map(|results| results.into_iter().flatten().collect())
}
fn collect_files(&self, dir: &Utf8Path) -> Vec<Utf8PathBuf> {
let mut entries: Vec<_> = WalkDir::new(dir)
.max_depth(1)
.min_depth(1)
.into_iter()
.filter_map(|e| e.ok())
.collect();
entries.sort_by(|a, b| a.path().cmp(b.path()));
entries
.into_iter()
.flat_map(|entry| {
let Some(path) = Utf8Path::from_path(entry.path()) else {
return vec![];
};
if entry.file_type().is_dir() {
self.process_migration_directory(path)
} else if path.extension() == Some("sql") {
vec![path.to_owned()]
} else {
vec![]
}
})
.collect()
}
fn process_migration_directory(&self, path: &Utf8Path) -> Vec<Utf8PathBuf> {
let dir_name = match path.file_name() {
Some(name) => name,
None => return vec![],
};
if !self.config.should_check_migration(dir_name) {
return vec![];
}
let mut files = vec![];
let up_sql = path.join("up.sql");
if up_sql.exists() {
files.push(up_sql);
}
if self.config.check_down {
let down_sql = path.join("down.sql");
if down_sql.exists() {
files.push(down_sql);
}
}
files
}
pub fn check_path(&self, path: &Utf8Path) -> Result<Vec<(String, Vec<Violation>)>> {
if path.is_dir() {
self.check_directory(path)
} else {
let violations = self.check_file(path)?;
if violations.is_empty() {
Ok(vec![])
} else {
Ok(vec![(path.to_string(), violations)])
}
}
}
}
impl Default for SafetyChecker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Statement used by the tests that exercise the add-column check.
    const ADD_COLUMN_WITH_DEFAULT: &str =
        "ALTER TABLE users ADD COLUMN admin BOOLEAN DEFAULT FALSE;";

    /// Runs `sql` through `checker` and returns the number of violations.
    /// Panics if the check itself fails.
    fn violation_count(checker: &SafetyChecker, sql: &str) -> usize {
        checker.check_sql(sql).unwrap().len()
    }

    // NOTE(review): the first two tests build the checker with `new()`, which
    // goes through `Config::load()`; presumably their outcome can depend on
    // whatever configuration is present in the environment — confirm.

    #[test]
    fn test_check_safe_sql() {
        let checker = SafetyChecker::new();
        let found = violation_count(
            &checker,
            "ALTER TABLE users ADD COLUMN email VARCHAR(255);",
        );
        assert_eq!(found, 0);
    }

    #[test]
    fn test_check_unsafe_sql() {
        let checker = SafetyChecker::new();
        let found = violation_count(&checker, ADD_COLUMN_WITH_DEFAULT);
        assert_eq!(found, 1);
    }

    #[test]
    fn test_with_disabled_checks() {
        // With the add-column check disabled, the same statement is accepted.
        let checker = SafetyChecker::with_config(Config {
            disable_checks: vec!["AddColumnCheck".to_string()],
            ..Config::default()
        });
        let found = violation_count(&checker, ADD_COLUMN_WITH_DEFAULT);
        assert_eq!(found, 0);
    }
}