use crate::checks::{Check, MigrationContext};
use crate::config::Config;
use crate::violation::Violation;
use camino::Utf8Path;
use pg_query::protobuf::node::Node as NodeEnum;
use rhai::{AST, Dynamic, Engine};
use std::sync::Arc;
/// Error recorded while loading custom check scripts from disk.
///
/// Rendered through `Display` as `{file}: {message}`.
#[derive(thiserror::Error, Debug)]
#[error("{file}: {message}")]
pub struct ScriptError {
/// Path the error refers to: either the scripts directory or a single script file.
pub file: String,
/// Human-readable description of what went wrong.
pub message: String,
}
/// A [`Check`] whose logic lives in a compiled Rhai script.
pub struct CustomCheck {
/// Name reported by [`Check::name`] and embedded in `SCRIPT ERROR` violations.
name: &'static str,
/// Engine used to evaluate `ast`; held behind an `Arc` so it can be shared.
engine: Arc<Engine>,
/// Compiled form of the script, evaluated on every `check` call.
ast: AST,
}
impl CustomCheck {
    /// Wraps an internal failure into a single `SCRIPT ERROR` violation.
    ///
    /// `err` is rendered with its `Display` implementation inside the problem
    /// text. The advice attached to the violation points at the tool itself
    /// rather than at the user's script.
    fn internal_error(&self, err: &dyn std::fmt::Display) -> Vec<Violation> {
        let operation = format!("SCRIPT ERROR: {}", self.name);
        let problem = format!("Error in custom check '{}': {err}", self.name);
        let violation = Violation::new(
            operation,
            problem,
            "This is likely a diesel-guard bug. Please report it.",
        );
        vec![violation]
    }
}
impl Check for CustomCheck {
    /// Returns the name this check was registered under.
    fn name(&self) -> &'static str {
        self.name
    }

    /// Runs the compiled script against one parsed statement.
    ///
    /// `node`, `config` and `ctx` are converted to Rhai `Dynamic` values and
    /// exposed to the script as the scope variables `node`, `config` and `ctx`.
    /// The script's return value is translated by `parse_script_result`.
    ///
    /// This method never panics on a conversion or script failure: every
    /// failure is reported as a `SCRIPT ERROR` violation instead.
    fn check(&self, node: &NodeEnum, config: &Config, ctx: &MigrationContext) -> Vec<Violation> {
        let dynamic_node = match rhai::serde::to_dynamic(node) {
            Ok(d) => d,
            Err(e) => return self.internal_error(&e),
        };
        let dynamic_config = match rhai::serde::to_dynamic(config) {
            Ok(d) => d,
            Err(e) => return self.internal_error(&e),
        };
        // Handled like the two conversions above: a serialization failure must
        // surface as a violation, not abort the whole run with a panic.
        let dynamic_ctx = match rhai::serde::to_dynamic(ctx) {
            Ok(d) => d,
            Err(e) => return self.internal_error(&e),
        };

        let mut scope = rhai::Scope::new();
        scope.push("node", dynamic_node);
        scope.push("config", dynamic_config);
        scope.push("ctx", dynamic_ctx);

        match self
            .engine
            .eval_ast_with_scope::<Dynamic>(&mut scope, &self.ast)
        {
            Ok(result) => parse_script_result(self.name, result),
            // Evaluation errors are attributed to the script author.
            Err(e) => {
                vec![Violation::new(
                    format!("SCRIPT ERROR: {}", self.name),
                    format!("Runtime error in custom check '{}': {e}", self.name),
                    "Fix the custom check script to eliminate the runtime error.",
                )]
            }
        }
    }
}
/// Translates the value returned by a custom check script into violations.
///
/// * `()` produces no violations.
/// * A map is handed to `map_to_violation` and yields at most one violation.
/// * An array is processed element by element through `map_to_violation`;
///   elements for which it returns `None` are skipped.
/// * Any other type produces a single `SCRIPT ERROR` violation naming the
///   type that was returned.
fn parse_script_result(check_name: &str, result: Dynamic) -> Vec<Violation> {
    if result.is_unit() {
        return Vec::new();
    }

    if result.is_map() {
        return map_to_violation(check_name, result).into_iter().collect();
    }

    if result.is_array() {
        let mut violations = Vec::new();
        for element in result.into_array().unwrap_or_default() {
            if let Some(violation) = map_to_violation(check_name, element) {
                violations.push(violation);
            }
        }
        return violations;
    }

    let operation = format!("SCRIPT ERROR: {check_name}");
    let problem = format!(
        "Custom check returned {}, expected (), map, or array",
        result.type_name()
    );
    vec![Violation::new(
        operation,
        problem,
        "Fix the custom check script to return a valid type.",
    )]
}
/// Converts one script-returned value into a violation.
///
/// Returns `None` when `value` is not a Rhai map. For a map, the keys
/// `operation`, `problem` and `safe_alternative` must all be present and hold
/// strings; when they are, they become the violation's fields. Otherwise a
/// `SCRIPT ERROR` violation is returned that lists every missing or non-string
/// key, in that key order.
fn map_to_violation(check_name: &str, value: Dynamic) -> Option<Violation> {
    let fields = value.try_cast::<rhai::Map>()?;

    // Reads `key` as an owned string, or `None` if it is absent or not a string.
    let read_string = |key: &str| {
        fields
            .get(key)
            .and_then(|v| v.clone().into_string().ok())
    };

    match (
        read_string("operation"),
        read_string("problem"),
        read_string("safe_alternative"),
    ) {
        (Some(operation), Some(problem), Some(safe_alternative)) => {
            Some(Violation::new(operation, problem, safe_alternative))
        }
        _ => {
            // At least one key is unusable: describe each offending key.
            let issues: Vec<String> = ["operation", "problem", "safe_alternative"]
                .iter()
                .filter_map(|key| match fields.get(*key) {
                    None => Some(format!("'{key}' is missing")),
                    Some(v) if v.clone().into_string().is_err() => Some(format!(
                        "'{key}' must be a string (got {})",
                        v.type_name()
                    )),
                    Some(_) => None,
                })
                .collect();

            Some(Violation::new(
                format!("SCRIPT ERROR: {check_name}"),
                format!(
                    "Custom check returned an invalid map: {}",
                    issues.join(", ")
                ),
                "Fix the custom check script to return all three required string keys.",
            ))
        }
    }
}
/// Builds the Rhai module that exposes selected `pg_query` enum values as
/// integer constants.
///
/// Each constant holds the `i64` value of the corresponding protobuf enum
/// variant. The constants come from four enums: `ObjectType` (`OBJECT_*`),
/// `AlterTableType` (`AT_*`), `ConstrType` (`CONSTR_*`) and `DropBehavior`
/// (`DROP_*`).
fn create_pg_constants_module() -> rhai::Module {
    use pg_query::protobuf::{AlterTableType, ConstrType, DropBehavior, ObjectType};

    // (script-visible name, enum value) pairs, registered in this order.
    let constants = [
        // ObjectType
        ("OBJECT_INDEX", ObjectType::ObjectIndex as i64),
        ("OBJECT_TABLE", ObjectType::ObjectTable as i64),
        ("OBJECT_COLUMN", ObjectType::ObjectColumn as i64),
        ("OBJECT_DATABASE", ObjectType::ObjectDatabase as i64),
        ("OBJECT_SCHEMA", ObjectType::ObjectSchema as i64),
        ("OBJECT_SEQUENCE", ObjectType::ObjectSequence as i64),
        ("OBJECT_VIEW", ObjectType::ObjectView as i64),
        ("OBJECT_FUNCTION", ObjectType::ObjectFunction as i64),
        ("OBJECT_EXTENSION", ObjectType::ObjectExtension as i64),
        ("OBJECT_TRIGGER", ObjectType::ObjectTrigger as i64),
        ("OBJECT_TYPE", ObjectType::ObjectType as i64),
        // AlterTableType
        ("AT_ADD_COLUMN", AlterTableType::AtAddColumn as i64),
        ("AT_COLUMN_DEFAULT", AlterTableType::AtColumnDefault as i64),
        ("AT_DROP_NOT_NULL", AlterTableType::AtDropNotNull as i64),
        ("AT_SET_NOT_NULL", AlterTableType::AtSetNotNull as i64),
        ("AT_DROP_COLUMN", AlterTableType::AtDropColumn as i64),
        (
            "AT_ALTER_COLUMN_TYPE",
            AlterTableType::AtAlterColumnType as i64,
        ),
        ("AT_ADD_CONSTRAINT", AlterTableType::AtAddConstraint as i64),
        (
            "AT_DROP_CONSTRAINT",
            AlterTableType::AtDropConstraint as i64,
        ),
        (
            "AT_VALIDATE_CONSTRAINT",
            AlterTableType::AtValidateConstraint as i64,
        ),
        // ConstrType
        ("CONSTR_NOTNULL", ConstrType::ConstrNotnull as i64),
        ("CONSTR_DEFAULT", ConstrType::ConstrDefault as i64),
        ("CONSTR_IDENTITY", ConstrType::ConstrIdentity as i64),
        ("CONSTR_GENERATED", ConstrType::ConstrGenerated as i64),
        ("CONSTR_CHECK", ConstrType::ConstrCheck as i64),
        ("CONSTR_PRIMARY", ConstrType::ConstrPrimary as i64),
        ("CONSTR_UNIQUE", ConstrType::ConstrUnique as i64),
        ("CONSTR_EXCLUSION", ConstrType::ConstrExclusion as i64),
        ("CONSTR_FOREIGN", ConstrType::ConstrForeign as i64),
        // DropBehavior
        ("DROP_RESTRICT", DropBehavior::DropRestrict as i64),
        ("DROP_CASCADE", DropBehavior::DropCascade as i64),
    ];

    let mut module = rhai::Module::new();
    for (name, value) in constants {
        module.set_var(name, value);
    }
    module
}
/// Creates the Rhai engine used to compile and run custom check scripts.
///
/// The engine is configured with limits on operation count, string size,
/// array size and map size, and has the constants module registered as the
/// static module `pg`.
fn create_engine() -> Engine {
    let pg_constants = create_pg_constants_module();

    let mut engine = Engine::new();
    engine
        .set_max_operations(100_000)
        .set_max_string_size(10_000)
        .set_max_array_size(1_000)
        .set_max_map_size(1_000)
        .register_static_module("pg", pg_constants.into());
    engine
}
/// Loads every `*.rhai` script in `dir` as a custom check.
///
/// Scripts are processed in file-name order. A script whose file stem is
/// disabled according to `config.is_check_enabled` is skipped without being
/// read. All loaded checks share a single engine.
///
/// # Returns
///
/// A pair `(checks, errors)`:
/// * `checks` — one boxed [`Check`] per script that compiled successfully,
///   named after the script's file stem.
/// * `errors` — one [`ScriptError`] per problem encountered: the directory
///   could not be read, a directory entry could not be read, a script file
///   could not be read, or a script failed to compile.
///
/// Loading is best-effort: a failure on one script never prevents the
/// remaining scripts from loading.
pub fn load_custom_checks(
    dir: &Utf8Path,
    config: &crate::config::Config,
) -> (Vec<Box<dyn Check>>, Vec<ScriptError>) {
    let mut checks: Vec<Box<dyn Check>> = Vec::new();
    let mut errors: Vec<ScriptError> = Vec::new();
    let engine = Arc::new(create_engine());

    let read_dir = match std::fs::read_dir(dir) {
        Ok(rd) => rd,
        Err(e) => {
            errors.push(ScriptError {
                file: dir.to_string(),
                message: format!("Failed to read directory: {e}"),
            });
            return (checks, errors);
        }
    };

    // Keep only `.rhai` entries. An entry that cannot be read is reported
    // rather than silently dropped, so a script never goes missing unnoticed.
    let mut entries = Vec::new();
    for entry in read_dir {
        match entry {
            Ok(entry) => {
                if entry.path().extension().is_some_and(|ext| ext == "rhai") {
                    entries.push(entry);
                }
            }
            Err(e) => errors.push(ScriptError {
                file: dir.to_string(),
                message: format!("Failed to read directory entry: {e}"),
            }),
        }
    }
    // Sort by file name so checks load in a deterministic order.
    entries.sort_by_key(std::fs::DirEntry::file_name);

    for entry in entries {
        let path = entry.path();
        // A stem that is missing or not valid UTF-8 falls back to "unknown".
        let stem = path
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or("unknown");

        if !config.is_check_enabled(stem) {
            continue;
        }

        let source = match std::fs::read_to_string(&path) {
            Ok(s) => s,
            Err(e) => {
                errors.push(ScriptError {
                    file: path.display().to_string(),
                    message: format!("Failed to read: {e}"),
                });
                continue;
            }
        };

        match engine.compile(&source) {
            Ok(ast) => {
                // `CustomCheck::name` is `&'static str`, so the stem is
                // intentionally leaked once per successfully compiled script.
                let name: &'static str = Box::leak(stem.to_string().into_boxed_str());
                checks.push(Box::new(CustomCheck {
                    name,
                    engine: Arc::clone(&engine),
                    ast,
                }));
            }
            Err(e) => {
                errors.push(ScriptError {
                    file: path.display().to_string(),
                    message: format!("Compilation error: {e}"),
                });
            }
        }
    }

    (checks, errors)
}
// Unit tests for custom check loading, script evaluation and result parsing.
#[cfg(test)]
mod tests {
use super::*;
use crate::checks::pg_helpers::extract_node;
use std::fs;
use tempfile::TempDir;
// Runs `script` against `sql` with the default config and default context.
fn run_script(script: &str, sql: &str) -> Vec<Violation> {
run_script_with_config(script, sql, &crate::config::Config::default())
}
// Runs `script` against `sql` with a caller-supplied config and the default context.
fn run_script_with_config(
script: &str,
sql: &str,
config: &crate::config::Config,
) -> Vec<Violation> {
run_script_with_ctx(
script,
sql,
config,
&crate::checks::MigrationContext::default(),
)
}
// Compiles `script` into a check named "test_check", parses `sql`, runs the
// check on every statement that `extract_node` yields a node for, and returns
// all violations collected across statements.
fn run_script_with_ctx(
script: &str,
sql: &str,
config: &crate::config::Config,
ctx: &crate::checks::MigrationContext,
) -> Vec<Violation> {
let engine = Arc::new(create_engine());
let ast = engine.compile(script).expect("script should compile");
let name: &'static str = Box::leak("test_check".to_string().into_boxed_str());
let check = CustomCheck { name, engine, ast };
let stmts = crate::parser::parse(sql).expect("SQL should parse");
let mut all_violations = Vec::new();
for raw_stmt in &stmts {
if let Some(node) = extract_node(raw_stmt) {
all_violations.extend(check.check(node, config, ctx));
}
}
all_violations
}
// A script that produces no value results in no violations.
#[test]
fn test_script_returns_unit_no_violations() {
let violations = run_script(
r"
// Script that always returns unit (no violation)
let stmt = node.CreateStmt;
if stmt == () { return; }
",
"CREATE INDEX idx ON t(id);",
);
assert!(violations.is_empty());
}
// A script returning a single well-formed map produces exactly one violation.
#[test]
fn test_script_returns_map_one_violation() {
let violations = run_script(
r#"
let stmt = node.IndexStmt;
if stmt == () { return; }
if !stmt.concurrent {
#{
operation: "INDEX without CONCURRENTLY",
problem: "locks table",
safe_alternative: "use CONCURRENTLY"
}
}
"#,
"CREATE INDEX idx ON users(email);",
);
assert_eq!(violations.len(), 1);
assert_eq!(violations[0].operation, "INDEX without CONCURRENTLY");
assert_eq!(violations[0].problem, "locks table");
}
// A script returning an array of maps produces one violation per map, in order.
#[test]
fn test_script_returns_array_multiple_violations() {
let violations = run_script(
r#"
let stmt = node.IndexStmt;
if stmt == () { return; }
[
#{ operation: "violation 1", problem: "p1", safe_alternative: "s1" },
#{ operation: "violation 2", problem: "p2", safe_alternative: "s2" }
]
"#,
"CREATE INDEX idx ON users(email);",
);
assert_eq!(violations.len(), 2);
assert_eq!(violations[0].operation, "violation 1");
assert_eq!(violations[1].operation, "violation 2");
}
// Returning an unsupported type (a string) yields a SCRIPT ERROR violation
// instead of panicking.
#[test]
fn test_script_invalid_return_type_no_crash() {
let violations = run_script(
r#"
"not a valid return type"
"#,
"CREATE INDEX idx ON users(email);",
);
assert_eq!(violations.len(), 1);
assert_eq!(violations[0].operation, "SCRIPT ERROR: test_check");
}
// An endless loop must terminate and be reported as a single SCRIPT ERROR.
#[test]
fn test_script_infinite_loop_hits_max_operations() {
let violations = run_script(
r"
loop { }
",
"CREATE INDEX idx ON users(email);",
);
assert_eq!(
violations.len(),
1,
"expected 1 SCRIPT ERROR, got: {violations:?}"
);
assert_eq!(violations[0].operation, "SCRIPT ERROR: test_check");
}
// A script guarding on `node.CreateStmt` reports nothing for a CREATE INDEX statement.
#[test]
fn test_script_wrong_node_type_returns_unit() {
let violations = run_script(
r#"
let stmt = node.CreateStmt;
if stmt == () { return; }
#{ operation: "found", problem: "p", safe_alternative: "s" }
"#,
"CREATE INDEX idx ON users(email);",
);
assert!(violations.is_empty());
}
// Compiling syntactically invalid Rhai returns an error.
#[test]
fn test_compilation_error_reported() {
let engine = Arc::new(create_engine());
let result = engine.compile("this is not valid rhai {{{");
assert!(result.is_err());
}
// Loading a directory with one valid script, one broken script and one
// non-`.rhai` file yields one check and one error (for the broken script).
#[test]
fn test_load_custom_checks_from_directory() {
let dir = TempDir::new().unwrap();
let dir_path = Utf8Path::from_path(dir.path()).unwrap();
fs::write(
dir.path().join("require_concurrent.rhai"),
r#"
let stmt = node.IndexStmt;
if stmt == () { return; }
if !stmt.concurrent {
#{ operation: "custom", problem: "no concurrently", safe_alternative: "use it" }
}
"#,
)
.unwrap();
fs::write(dir.path().join("broken.rhai"), "this is not valid {{{").unwrap();
fs::write(dir.path().join("notes.txt"), "not a script").unwrap();
let config = crate::config::Config::default();
let (checks, errors) = load_custom_checks(dir_path, &config);
assert_eq!(checks.len(), 1);
assert_eq!(checks[0].name(), "require_concurrent");
assert_eq!(errors.len(), 1);
assert!(errors[0].file.contains("broken.rhai"));
}
// An empty script produces no violations.
#[test]
fn test_empty_script_no_violations() {
let violations = run_script("", "CREATE INDEX idx ON users(email);");
assert!(violations.is_empty());
}
// A map lacking `safe_alternative` yields a SCRIPT ERROR naming the missing key.
#[test]
fn test_map_with_missing_keys_produces_error_violation() {
let violations = run_script(
r#"
#{ operation: "op", problem: "p" }
"#,
"CREATE INDEX idx ON users(email);",
);
assert_eq!(violations.len(), 1);
assert_eq!(violations[0].operation, "SCRIPT ERROR: test_check");
assert_eq!(
violations[0].problem,
"Custom check returned an invalid map: 'safe_alternative' is missing"
);
}
// A misspelled key (`safe_alterative`) is reported as the correct key being missing.
#[test]
fn test_map_with_misspelled_key_produces_error_violation() {
let violations = run_script(
r#"
#{ operation: "op", problem: "p", safe_alterative: "s" }
"#,
"CREATE INDEX idx ON users(email);",
);
assert_eq!(violations.len(), 1);
assert_eq!(violations[0].operation, "SCRIPT ERROR: test_check");
assert_eq!(
violations[0].problem,
"Custom check returned an invalid map: 'safe_alternative' is missing"
);
}
// Scripts can compare node fields against constants from the `pg` module.
#[test]
fn test_pg_constants_accessible_in_scripts() {
let violations = run_script(
r#"
let stmt = node.DropStmt;
if stmt == () { return; }
if stmt.remove_type == pg::OBJECT_INDEX {
#{ operation: "DROP INDEX", problem: "not concurrent", safe_alternative: "use CONCURRENTLY" }
}
"#,
"DROP INDEX idx_users_email;",
);
assert_eq!(violations.len(), 1);
assert_eq!(violations[0].operation, "DROP INDEX");
}
// `config.postgres_version` is visible to scripts: the same script returns
// early for version 14 and reports a violation for version 10.
#[test]
fn test_config_postgres_version_accessible_in_scripts() {
let config = crate::config::Config {
postgres_version: Some(14),
..Default::default()
};
let violations = run_script_with_config(
r#"
let stmt = node.IndexStmt;
if stmt == () { return; }
if config.postgres_version != () && config.postgres_version >= 14 { return; }
#{ operation: "INDEX without CONCURRENTLY", problem: "locks table", safe_alternative: "use CONCURRENTLY" }
"#,
"CREATE INDEX idx ON users(email);",
&config,
);
assert!(violations.is_empty());
let config_old = crate::config::Config {
postgres_version: Some(10),
..Default::default()
};
let violations = run_script_with_config(
r#"
let stmt = node.IndexStmt;
if stmt == () { return; }
if config.postgres_version != () && config.postgres_version >= 14 { return; }
#{ operation: "INDEX without CONCURRENTLY", problem: "locks table", safe_alternative: "use CONCURRENTLY" }
"#,
"CREATE INDEX idx ON users(email);",
&config_old,
);
assert_eq!(violations.len(), 1);
}
// Comparing a DROP INDEX statement against `pg::OBJECT_TABLE` does not match.
#[test]
fn test_pg_constants_no_match() {
let violations = run_script(
r#"
let stmt = node.DropStmt;
if stmt == () { return; }
if stmt.remove_type == pg::OBJECT_TABLE {
#{ operation: "DROP TABLE", problem: "dangerous", safe_alternative: "be careful" }
}
"#,
"DROP INDEX idx_users_email;",
);
assert!(violations.is_empty());
}
// A script whose stem is listed in `disable_checks` is not loaded and causes no error.
#[test]
fn test_load_custom_checks_respects_disable() {
let dir = TempDir::new().unwrap();
let dir_path = Utf8Path::from_path(dir.path()).unwrap();
fs::write(dir.path().join("my_check.rhai"), r"return;").unwrap();
let config = crate::config::Config {
disable_checks: vec!["my_check".to_string()],
..Default::default()
};
let (checks, errors) = load_custom_checks(dir_path, &config);
assert_eq!(checks.len(), 0);
assert_eq!(errors.len(), 0);
}
// A missing directory yields no checks and a single "Failed to read directory" error.
#[test]
fn test_load_custom_checks_nonexistent_directory() {
let dir = TempDir::new().unwrap();
let missing = dir.path().join("does_not_exist");
let dir_path = Utf8Path::from_path(&missing).unwrap();
let config = crate::config::Config::default();
let (checks, errors) = load_custom_checks(dir_path, &config);
assert_eq!(checks.len(), 0);
assert_eq!(errors.len(), 1);
assert!(errors[0].message.contains("Failed to read directory"));
}
// With `run_in_transaction` false, a script gated on that flag reports nothing.
#[test]
fn test_ctx_run_in_transaction_false_no_violation() {
let ctx = crate::checks::MigrationContext {
run_in_transaction: false,
no_transaction_hint: "",
};
let violations = run_script_with_ctx(
r#"
let stmt = node.IndexStmt;
if stmt == () { return; }
if stmt.concurrent && ctx.run_in_transaction {
#{ operation: "CONCURRENTLY in transaction", problem: "will fail", safe_alternative: ctx.no_transaction_hint }
}
"#,
"CREATE INDEX CONCURRENTLY idx ON users(email);",
&crate::config::Config::default(),
&ctx,
);
assert!(violations.is_empty());
}
// With `run_in_transaction` true, the script reports a violation whose
// `safe_alternative` is taken from `ctx.no_transaction_hint`.
#[test]
fn test_ctx_run_in_transaction_true_produces_violation() {
let ctx = crate::checks::MigrationContext {
run_in_transaction: true,
no_transaction_hint: "Add -- diesel:no-transaction to the migration file.",
};
let violations = run_script_with_ctx(
r#"
let stmt = node.IndexStmt;
if stmt == () { return; }
if stmt.concurrent && ctx.run_in_transaction {
#{
operation: "CONCURRENTLY in transaction",
problem: "will fail",
safe_alternative: ctx.no_transaction_hint
}
}
"#,
"CREATE INDEX CONCURRENTLY idx ON users(email);",
&crate::config::Config::default(),
&ctx,
);
assert_eq!(violations.len(), 1);
assert_eq!(violations[0].operation, "CONCURRENTLY in transaction");
assert!(
violations[0]
.safe_alternative
.contains("diesel:no-transaction")
);
}
// A directory named `unreadable.rhai` passes the extension filter but cannot
// be read as a file, so it is reported as a "Failed to read" error.
#[test]
fn test_load_custom_checks_unreadable_file() {
let dir = TempDir::new().unwrap();
let dir_path = Utf8Path::from_path(dir.path()).unwrap();
let script_path = dir.path().join("unreadable.rhai");
fs::create_dir(&script_path).unwrap();
let config = crate::config::Config::default();
let (checks, errors) = load_custom_checks(dir_path, &config);
assert_eq!(checks.len(), 0);
assert_eq!(errors.len(), 1);
assert!(errors[0].message.contains("Failed to read"));
}
// A non-string `operation` value is reported together with its actual type.
#[test]
fn test_map_with_non_string_operation_field() {
let violations = run_script(
r#"
#{ operation: 42, problem: "p", safe_alternative: "s" }
"#,
"CREATE INDEX idx ON users(email);",
);
assert_eq!(violations.len(), 1);
assert_eq!(violations[0].operation, "SCRIPT ERROR: test_check");
assert_eq!(
violations[0].problem,
"Custom check returned an invalid map: 'operation' must be a string (got i64)"
);
}
// A non-string `problem` value is reported as a type error for that key.
#[test]
fn test_map_with_non_string_problem_field() {
let violations = run_script(
r#"#{ operation: "op", problem: 42, safe_alternative: "s" }"#,
"CREATE INDEX idx ON users(email);",
);
assert_eq!(violations.len(), 1);
assert_eq!(violations[0].operation, "SCRIPT ERROR: test_check");
assert!(
violations[0].problem.contains("'problem' must be a string"),
"got: {}",
violations[0].problem
);
}
// A non-string `safe_alternative` value is reported as a type error for that key.
#[test]
fn test_map_with_non_string_safe_alternative_field() {
let violations = run_script(
r#"#{ operation: "op", problem: "p", safe_alternative: false }"#,
"CREATE INDEX idx ON users(email);",
);
assert_eq!(violations.len(), 1);
assert_eq!(violations[0].operation, "SCRIPT ERROR: test_check");
assert!(
violations[0]
.problem
.contains("'safe_alternative' must be a string"),
"got: {}",
violations[0].problem
);
}
// Builds a check named "test_check" around a trivial `()` script, for tests
// that call `CustomCheck` methods directly.
fn make_test_check() -> CustomCheck {
let engine = Arc::new(create_engine());
let ast = engine.compile("()").expect("script should compile");
let name: &'static str = Box::leak("test_check".to_string().into_boxed_str());
CustomCheck { name, engine, ast }
}
// `internal_error` produces one violation with the exact expected texts.
#[test]
fn test_internal_error_yields_script_error_violation() {
let check = make_test_check();
let violations = check.internal_error(&"boom");
assert_eq!(violations.len(), 1);
let v = &violations[0];
assert_eq!(v.operation, "SCRIPT ERROR: test_check");
assert_eq!(v.problem, "Error in custom check 'test_check': boom");
assert_eq!(
v.safe_alternative,
"This is likely a diesel-guard bug. Please report it."
);
}
// A runtime failure inside the script (division by zero) is reported as a
// SCRIPT ERROR violation carrying the engine's error message.
#[test]
fn test_script_runtime_error_yields_script_error_violation() {
let violations = run_script("1 / 0", "CREATE INDEX idx ON users(email);");
assert_eq!(
violations.len(),
1,
"expected 1 SCRIPT ERROR violation, got: {violations:?}"
);
let v = &violations[0];
assert_eq!(v.operation, "SCRIPT ERROR: test_check");
assert_eq!(
v.problem,
"Runtime error in custom check 'test_check': Division by zero: 1 / 0"
);
assert_eq!(
v.safe_alternative,
"Fix the custom check script to eliminate the runtime error."
);
}
// Constants from the AT_*, CONSTR_* and DROP_* groups resolve to non-unit
// values inside scripts.
#[test]
fn test_pg_alter_table_constraint_and_drop_constants_accessible() {
let violations = run_script(
r#"
let at = pg::AT_ADD_COLUMN;
let ct = pg::CONSTR_PRIMARY;
let db = pg::DROP_CASCADE;
if at == () || ct == () || db == () {
return #{ operation: "MISSING CONSTANT", problem: "a pg constant was ()", safe_alternative: "" };
}
"#,
"SELECT 1;",
);
assert!(
violations.is_empty(),
"All pg constants should be accessible, got: {violations:?}"
);
}
}