use crate::linter::{Diagnostic, LintResult, Severity, Span};
const DB_COMMANDS: &[&str] = &["mysql", "psql", "sqlite3", "mariadb", "mongodb"];
const SQL_KEYWORDS: &[&str] = &[
"SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "WHERE",
];
pub fn check(source: &str) -> LintResult {
let mut result = LintResult::new();
for (line_num, line) in source.lines().enumerate() {
for db_cmd in DB_COMMANDS {
if let Some(cmd_col) = find_command(line, db_cmd) {
if contains_sql_with_variable(line) {
let span = Span::new(line_num + 1, cmd_col + 1, line_num + 1, line.len());
let diag = Diagnostic::new(
"SEC009",
Severity::Error,
format!(
"SQL injection risk in {} command - use parameterized queries",
db_cmd
),
span,
);
result.add(diag);
}
}
}
}
result
}
fn find_command(line: &str, cmd: &str) -> Option<usize> {
if let Some(pos) = line.find(cmd) {
let before_ok = if pos == 0 {
true
} else {
let char_before = line.chars().nth(pos - 1);
matches!(char_before, Some(' ' | '\t' | ';' | '&' | '|' | '('))
};
let after_idx = pos + cmd.len();
let after_ok = if after_idx >= line.len() {
true
} else {
let char_after = line.chars().nth(after_idx);
matches!(char_after, Some(' ' | '\t' | ';' | '&' | '|' | ')'))
};
if before_ok && after_ok {
return Some(pos);
}
}
None
}
fn contains_sql_with_variable(line: &str) -> bool {
let has_sql_keyword = SQL_KEYWORDS.iter().any(|kw| {
line.to_uppercase()
.split_whitespace()
.any(|word| word.contains(kw))
});
if !has_sql_keyword {
return false;
}
let has_variable_in_query = line.contains('"') || line.contains('\'');
let has_unquoted_var = line.contains('$') && !line.contains("\\$");
has_variable_in_query && has_unquoted_var
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_SEC009_detects_mysql_injection() {
let script = r#"mysql -e "SELECT * FROM users WHERE id=$USER_ID""#;
let result = check(script);
assert_eq!(result.diagnostics.len(), 1);
let diag = &result.diagnostics[0];
assert_eq!(diag.code, "SEC009");
assert_eq!(diag.severity, Severity::Error);
assert!(diag.message.contains("SQL injection"));
}
#[test]
fn test_SEC009_detects_psql_injection() {
let script = r#"psql -c "DELETE FROM logs WHERE user='$USERNAME'""#;
let result = check(script);
assert_eq!(result.diagnostics.len(), 1);
assert_eq!(result.diagnostics[0].code, "SEC009");
}
#[test]
fn test_SEC009_detects_sqlite3_injection() {
let script = r#"sqlite3 db.sqlite "INSERT INTO data VALUES ('$INPUT')""#;
let result = check(script);
assert_eq!(result.diagnostics.len(), 1);
assert_eq!(result.diagnostics[0].code, "SEC009");
}
#[test]
fn test_SEC009_no_false_positive_no_variable() {
let script = r#"mysql -e "SELECT * FROM users""#;
let result = check(script);
assert_eq!(result.diagnostics.len(), 0);
}
#[test]
fn test_SEC009_no_false_positive_escaped_variable() {
let script = r#"mysql -e "SELECT * FROM users WHERE id=\$SAFE_ID""#;
let result = check(script);
}
#[test]
fn test_SEC009_no_false_positive_comment() {
let script = "# mysql -e \"SELECT * FROM users WHERE id=$USER_ID\"";
let result = check(script);
}
#[test]
fn test_SEC009_multiple_injections() {
let script = r#"
mysql -e "SELECT * FROM users WHERE id=$USER_ID"
psql -c "DELETE FROM logs WHERE id=$LOG_ID"
"#;
let result = check(script);
assert_eq!(result.diagnostics.len(), 2);
}
#[test]
fn test_SEC009_no_auto_fix() {
let script = r#"mysql -e "SELECT * FROM users WHERE id=$USER_ID""#;
let result = check(script);
assert_eq!(result.diagnostics.len(), 1);
let diag = &result.diagnostics[0];
assert!(diag.fix.is_none(), "SEC009 should not provide auto-fix");
}
}
#[cfg(test)]
mod property_tests {
use super::*;
use proptest::prelude::*;
proptest! {
#![proptest_config(proptest::test_runner::Config::with_cases(10))]
#[test]
fn prop_sec009_never_panics(s in ".*") {
let _ = check(&s);
}
#[test]
fn prop_sec009_safe_mysql_no_vars(
table in "[a-z]{1,10}",
) {
let query = format!(r#"mysql -e "SELECT * FROM {}""#, table);
let result = check(&query);
prop_assert_eq!(result.diagnostics.len(), 0);
}
#[test]
fn prop_sec009_detects_all_db_commands(
db_idx in 0..5usize,
var_name in "[A-Z_]{1,10}",
) {
let db_cmd = match db_idx {
0 => "mysql",
1 => "psql",
2 => "sqlite3",
3 => "mariadb",
_ => "mongodb",
};
let query = format!(r#"{} -e "SELECT * FROM users WHERE id=${}_ID""#, db_cmd, var_name);
let result = check(&query);
prop_assert!(!result.diagnostics.is_empty());
prop_assert_eq!(result.diagnostics[0].code.as_str(), "SEC009");
}
}
}