use crate::error::Error;
pub fn validate_sql_identifier(name: &str) -> crate::Result<()> {
if name.is_empty() {
return Err(Error::config("SQL identifier cannot be empty"));
}
if name.len() > 255 {
return Err(Error::config(format!(
"SQL identifier too long: {} chars (max 255)",
name.len()
)));
}
let mut chars = name.chars();
match chars.next() {
Some(c) if c.is_ascii_alphabetic() || c == '_' => {}
_ => {
return Err(Error::config(format!(
"Invalid SQL identifier '{}': must start with a letter or underscore",
name
)));
}
}
for c in chars {
if !c.is_ascii_alphanumeric() && c != '_' {
return Err(Error::config(format!(
"Invalid SQL identifier '{}': contains invalid character '{}'",
name, c
)));
}
}
Ok(())
}
/// Escapes `value` for use inside a single-quoted SQL string literal by doubling every `'`.
///
/// The surrounding quotes are not added. Only single quotes are rewritten; every other
/// character, including backslashes and control characters, is passed through unchanged.
/// NOTE(review): dialects that treat `\` as an escape inside string literals may need
/// additional handling — confirm against the target databases.
pub fn escape_string_literal(value: &str) -> String {
    value.replace('\'', "''")
}
pub fn validate_sql_type_name(type_name: &str) -> crate::Result<()> {
if type_name.is_empty() {
return Err(Error::config("SQL type name cannot be empty"));
}
if type_name.len() > 255 {
return Err(Error::config(format!(
"SQL type name too long: {} chars (max 255)",
type_name.len()
)));
}
for c in type_name.chars() {
if !(c.is_ascii_alphanumeric()
|| c == '_'
|| c == '('
|| c == ')'
|| c == ','
|| c == ' '
|| c == '\''
|| c == '.')
{
return Err(Error::config(format!(
"Invalid SQL type name '{}': contains invalid character '{}'",
type_name, c
)));
}
}
Ok(())
}
/// Performs a conservative deny-list check on a user-supplied SQL `WHERE` clause fragment
/// before it is spliced into a query.
///
/// The clause is rejected when it:
/// - is empty or longer than 4096 bytes;
/// - contains a NUL byte, `\n`, `\r`, `;`, `\`, or comment syntax (`--`, `/*`, `*/`);
/// - contains, as a whole word and regardless of case, one of the statement/function
///   keywords listed below (`SELECT`, `UNION`, `DROP`, `EXEC`, `SLEEP`, `PG_SLEEP`, ...);
/// - contains a word beginning with `xp_` (e.g. `xp_cmdshell`);
/// - contains `INTO OUTFILE` or `INTO DUMPFILE` with any run of whitespace between the two
///   words, or contains `LOAD_FILE` anywhere.
///
/// Word matching treats ASCII letters, digits, and `_` as word characters, so identifiers
/// such as `executor_status` or `selected` are not mistaken for keywords. `#` is not
/// rejected, which keeps operators such as `#` and `#>>` usable.
///
/// This is a deny-list, not a SQL parser: it narrows the attack surface but does not prove a
/// clause harmless. Prefer bound parameters wherever the clause shape is known in advance.
///
/// # Errors
///
/// Returns a configuration error naming the first rule the clause violates; most messages
/// include the offending clause verbatim.
pub fn validate_where_clause(clause: &str) -> crate::Result<()> {
    if clause.is_empty() {
        return Err(Error::config("WHERE clause cannot be empty"));
    }
    if clause.len() > 4096 {
        return Err(Error::config(format!(
            "WHERE clause too long: {} chars (max 4096)",
            clause.len()
        )));
    }
    if clause.contains('\0') {
        return Err(Error::config(format!(
            "WHERE clause contains prohibited null byte: {}",
            clause
        )));
    }
    if clause.contains('\n') || clause.contains('\r') {
        return Err(Error::config(format!(
            "WHERE clause contains prohibited newline: {}",
            clause
        )));
    }
    if clause.contains(';') {
        return Err(Error::config(format!(
            "WHERE clause contains prohibited character ';': {}",
            clause
        )));
    }
    if clause.contains("--") {
        return Err(Error::config(format!(
            "WHERE clause contains prohibited pattern '--': {}",
            clause
        )));
    }
    if clause.contains("/*") || clause.contains("*/") {
        return Err(Error::config(format!(
            "WHERE clause contains prohibited comment syntax: {}",
            clause
        )));
    }
    if clause.contains('\\') {
        return Err(Error::config(format!(
            "WHERE clause contains prohibited backslash escape: {}",
            clause
        )));
    }

    let upper = clause.to_uppercase();
    for keyword in &[
        "UNION",
        "SELECT",
        "INSERT",
        "UPDATE",
        "DELETE",
        "DROP",
        "ALTER",
        "CREATE",
        "TRUNCATE",
        "EXEC",
        "EXECUTE",
        "DECLARE",
        "CALL",
        "GRANT",
        "REVOKE",
        "WAITFOR",
        "BENCHMARK",
        "SLEEP",
        "PG_SLEEP",
        "PG_READ_FILE",
        "PG_LS_DIR",
        "PG_READ_BINARY_FILE",
    ] {
        if contains_word(&upper, keyword) {
            return Err(Error::config(format!(
                "WHERE clause contains prohibited keyword '{}': {}",
                keyword, clause
            )));
        }
    }
    if contains_word_prefix(&upper, "XP_") {
        return Err(Error::config(format!(
            "WHERE clause contains prohibited pattern 'xp_': {}",
            clause
        )));
    }

    // Collapse every whitespace run (space, tab, vertical tab, form feed, Unicode spaces) to a
    // single space so `INTO\tOUTFILE` or `INTO   OUTFILE` cannot slip past a literal
    // "INTO OUTFILE" substring test. Collapsing never joins two non-whitespace runs, so
    // anything the raw single-space check matched is still matched.
    let collapsed = upper.split_whitespace().collect::<Vec<_>>().join(" ");
    for keyword in &["INTO OUTFILE", "INTO DUMPFILE", "LOAD_FILE"] {
        if collapsed.contains(keyword) {
            return Err(Error::config(format!(
                "WHERE clause contains prohibited keyword '{}': {}",
                keyword, clause
            )));
        }
    }
    Ok(())
}
/// Reports whether `c` ends (or precedes) an identifier-like word.
///
/// `None` (start or end of input) counts as a boundary. So does any character that is not an
/// ASCII letter, ASCII digit, or `_`.
#[inline]
fn is_word_boundary(c: Option<char>) -> bool {
    !matches!(c, Some(ch) if ch == '_' || ch.is_ascii_alphanumeric())
}
/// Returns `true` if `word` occurs in `haystack` as a whole word.
///
/// A match counts only when the bytes immediately before and after it are word boundaries,
/// as defined by `is_word_boundary`. The comparison is byte-wise and case-sensitive, so
/// callers upper-case both sides first. An empty `word` never matches.
fn contains_word(haystack: &str, word: &str) -> bool {
    let (hay, needle) = (haystack.as_bytes(), word.as_bytes());
    if needle.is_empty() {
        return false;
    }
    // `windows` yields nothing when the needle is longer than the haystack.
    hay.windows(needle.len())
        .enumerate()
        .filter(|(_, window)| *window == needle)
        .any(|(pos, _)| {
            let prev = pos.checked_sub(1).map(|i| hay[i] as char);
            let next = hay.get(pos + needle.len()).map(|&b| b as char);
            is_word_boundary(prev) && is_word_boundary(next)
        })
}
/// Returns `true` if some word in `haystack` starts with `prefix`.
///
/// Only the byte before the match must be a word boundary (as defined by
/// `is_word_boundary`). Nothing is required after it, so `"XP_"` matches `"XP_CMDSHELL"` but
/// not `"AXP_FOO"`. The comparison is byte-wise and case-sensitive. An empty `prefix` never
/// matches, consistent with `contains_word`; previously it matched every input, including
/// the empty string.
fn contains_word_prefix(haystack: &str, prefix: &str) -> bool {
    let (hay, pat) = (haystack.as_bytes(), prefix.as_bytes());
    if pat.is_empty() {
        return false;
    }
    // `windows` yields nothing when the prefix is longer than the haystack.
    hay.windows(pat.len()).enumerate().any(|(pos, window)| {
        window == pat && is_word_boundary(pos.checked_sub(1).map(|i| hay[i] as char))
    })
}
// Unit tests for the SQL validation helpers. They pin the accepted and rejected inputs of
// each validator, including injection payloads and word-boundary false positives.
#[cfg(test)]
mod tests {
    use super::*;

    // --- validate_sql_identifier ---

    #[test]
    fn test_valid_identifiers() {
        assert!(validate_sql_identifier("users").is_ok());
        assert!(validate_sql_identifier("my_table").is_ok());
        assert!(validate_sql_identifier("_private").is_ok());
        assert!(validate_sql_identifier("a").is_ok());
        assert!(validate_sql_identifier("TABLE_123").is_ok());
        assert!(validate_sql_identifier("sp1").is_ok());
    }

    #[test]
    fn test_empty_identifier() {
        assert!(validate_sql_identifier("").is_err());
    }

    // The 255 limit is inclusive: exactly 255 is accepted, 256 is rejected.
    #[test]
    fn test_too_long_identifier() {
        let long = "a".repeat(256);
        assert!(validate_sql_identifier(&long).is_err());
        let max = "a".repeat(255);
        assert!(validate_sql_identifier(&max).is_ok());
    }

    #[test]
    fn test_starts_with_digit() {
        assert!(validate_sql_identifier("123abc").is_err());
        assert!(validate_sql_identifier("0").is_err());
    }

    #[test]
    fn test_injection_attempts() {
        assert!(validate_sql_identifier("x; DROP TABLE users--").is_err());
        assert!(validate_sql_identifier("x' OR '1'='1").is_err());
        assert!(validate_sql_identifier("x--").is_err());
        assert!(validate_sql_identifier("x()").is_err());
        // U+0435 is a Cyrillic letter that renders like Latin 'e' (homoglyph).
        assert!(validate_sql_identifier("tabl\u{0435}").is_err());
        assert!(validate_sql_identifier("user name").is_err());
        assert!(validate_sql_identifier("x\nDROP TABLE").is_err());
        assert!(validate_sql_identifier("x\0").is_err());
        // Dotted (schema-qualified) names are rejected as a whole.
        assert!(validate_sql_identifier("schema.table").is_err());
    }

    // Each punctuation character, appended to an otherwise valid name, must be rejected.
    #[test]
    fn test_special_chars_rejected() {
        for ch in &[
            '.', '-', '@', '#', '$', '!', '%', '&', '*', '[', ']', '"', '`',
        ] {
            let name = format!("a{}", ch);
            assert!(
                validate_sql_identifier(&name).is_err(),
                "Should reject '{}'",
                name
            );
        }
    }

    // --- escape_string_literal ---

    #[test]
    fn test_escape_no_quotes() {
        assert_eq!(escape_string_literal("users"), "users");
        assert_eq!(escape_string_literal("my_table"), "my_table");
    }

    #[test]
    fn test_escape_single_quotes() {
        assert_eq!(escape_string_literal("don't"), "don''t");
        assert_eq!(escape_string_literal("'hello'"), "''hello''");
    }

    // Only quotes are doubled; the rest of the payload is left as-is, because it is harmless
    // once it cannot terminate the literal.
    #[test]
    fn test_escape_injection_attempt() {
        assert_eq!(
            escape_string_literal("x'; DROP TABLE users--"),
            "x''; DROP TABLE users--"
        );
        assert_eq!(escape_string_literal("' OR '1'='1"), "'' OR ''1''=''1");
    }

    #[test]
    fn test_escape_empty_string() {
        assert_eq!(escape_string_literal(""), "");
    }

    // --- validate_sql_type_name ---

    #[test]
    fn test_valid_type_names() {
        assert!(validate_sql_type_name("INT").is_ok());
        assert!(validate_sql_type_name("BIGINT").is_ok());
        assert!(validate_sql_type_name("VARCHAR(255)").is_ok());
        assert!(validate_sql_type_name("DECIMAL(10,2)").is_ok());
        assert!(validate_sql_type_name("INT UNSIGNED").is_ok());
        assert!(validate_sql_type_name("DOUBLE PRECISION").is_ok());
        // Quoted member lists require `'` and `,` to be allowed.
        assert!(validate_sql_type_name("ENUM('a','b','c')").is_ok());
        assert!(validate_sql_type_name("SET('x','y')").is_ok());
        // `.` is in the allowed character set, so this form is accepted.
        assert!(validate_sql_type_name("NUMERIC(10.2)").is_ok());
        assert!(validate_sql_type_name("timestamp").is_ok());
        assert!(validate_sql_type_name("TINYINT(1)").is_ok());
    }

    #[test]
    fn test_empty_type_name() {
        assert!(validate_sql_type_name("").is_err());
    }

    #[test]
    fn test_type_name_injection_attempts() {
        assert!(validate_sql_type_name("INT; DROP TABLE users--").is_err());
        assert!(validate_sql_type_name("INT`; DROP TABLE").is_err());
        assert!(validate_sql_type_name("INT\nDROP TABLE").is_err());
        assert!(validate_sql_type_name("INT\0").is_err());
        assert!(validate_sql_type_name("INT--comment").is_err());
    }

    #[test]
    fn test_type_name_too_long() {
        let long = "A".repeat(256);
        assert!(validate_sql_type_name(&long).is_err());
    }

    // --- validate_where_clause ---

    #[test]
    fn test_valid_where_clauses() {
        assert!(validate_where_clause("status = 'active'").is_ok());
        assert!(validate_where_clause("id > 0 AND deleted = false").is_ok());
        assert!(validate_where_clause("age BETWEEN 18 AND 65").is_ok());
        assert!(validate_where_clause("name LIKE '%test%'").is_ok());
        assert!(validate_where_clause("id IN (1, 2, 3)").is_ok());
    }

    // Covers statement stacking, comments, backslash escapes, subqueries, stored procedures,
    // time-based probes, file I/O, NUL bytes and newlines.
    #[test]
    fn test_where_clause_injection_attacks() {
        assert!(validate_where_clause("1=1; DROP TABLE users").is_err());
        assert!(validate_where_clause("1=1 -- bypass").is_err());
        assert!(validate_where_clause("1=1 /* comment */").is_err());
        assert!(validate_where_clause("name = '\\' OR 1=1").is_err());
        assert!(validate_where_clause("1=1 UNION SELECT * FROM passwords").is_err());
        assert!(validate_where_clause("1=1 union select * from passwords").is_err());
        assert!(validate_where_clause("id = (SELECT MAX(id) FROM users)").is_err());
        assert!(validate_where_clause("1=1; EXEC sp_help").is_err());
        assert!(validate_where_clause("EXECUTE xp_cmdshell 'dir'").is_err());
        assert!(validate_where_clause("xp_cmdshell('dir')").is_err());
        assert!(validate_where_clause("1=1 AND SLEEP(5)").is_err());
        assert!(validate_where_clause("1=1 AND BENCHMARK(1000000, SHA1('test'))").is_err());
        assert!(validate_where_clause("1=1; WAITFOR DELAY '0:0:5'").is_err());
        assert!(validate_where_clause("1=1 AND PG_SLEEP(5)").is_err());
        assert!(validate_where_clause("1=1 INTO OUTFILE '/tmp/data'").is_err());
        assert!(validate_where_clause("1=1 INTO DUMPFILE '/tmp/data'").is_err());
        assert!(validate_where_clause("LOAD_FILE('/etc/passwd')").is_err());
        assert!(validate_where_clause("name = 'test\0").is_err());
        assert!(validate_where_clause("1=1\nDROP TABLE users").is_err());
    }

    #[test]
    fn test_where_clause_empty() {
        assert!(validate_where_clause("").is_err());
    }

    #[test]
    fn test_where_clause_too_long() {
        let long = "a".repeat(4097);
        assert!(validate_where_clause(&long).is_err());
    }

    // Identifiers that merely contain a keyword as a substring must not be flagged.
    #[test]
    fn test_where_clause_word_boundary_no_false_positives() {
        assert!(validate_where_clause("executor_status = 'running'").is_ok());
        assert!(validate_where_clause("execution_count > 0").is_ok());
        assert!(validate_where_clause("selectivity > 0.5").is_ok());
        assert!(validate_where_clause("selected = true").is_ok());
        assert!(validate_where_clause("preselected = true").is_ok());
        assert!(validate_where_clause("reunionist = 'alice'").is_ok());
        assert!(validate_where_clause("asleep = false").is_ok());
    }

    // Keywords delimited by spaces, parentheses or start/end of input must still be flagged.
    #[test]
    fn test_where_clause_word_boundary_still_catches_keywords() {
        assert!(validate_where_clause("1=1 UNION ALL").is_err());
        assert!(validate_where_clause("(SELECT 1)").is_err());
        assert!(validate_where_clause("EXEC sp_help").is_err());
        assert!(validate_where_clause("id=1 AND SLEEP(5)").is_err());
        assert!(validate_where_clause("EXECUTE('cmd')").is_err());
        assert!(validate_where_clause("1=1; DROP TABLE users").is_err());
        assert!(validate_where_clause("ALTER TABLE users").is_err());
        assert!(validate_where_clause("CREATE TABLE evil").is_err());
        assert!(validate_where_clause("TRUNCATE TABLE users").is_err());
        assert!(validate_where_clause("INSERT INTO users").is_err());
        assert!(validate_where_clause("UPDATE users SET x=1").is_err());
        assert!(validate_where_clause("DELETE FROM users").is_err());
        assert!(validate_where_clause("GRANT ALL ON users").is_err());
        assert!(validate_where_clause("REVOKE ALL ON users").is_err());
    }

    // `#` is not on the deny-list, so operators such as `#` and `#>>` remain usable.
    #[test]
    fn test_where_clause_hash_comment_allowed() {
        assert!(validate_where_clause("flags # 4 > 0").is_ok());
        assert!(validate_where_clause("data #>> '{key}'").is_ok());
    }

    // DDL/DML keywords embedded in longer identifiers must not be flagged.
    #[test]
    fn test_where_clause_ddl_word_boundary_no_false_positives() {
        assert!(validate_where_clause("droplet_count > 0").is_ok());
        assert!(validate_where_clause("created_at > '2024-01-01'").is_ok());
        assert!(validate_where_clause("alteration = 'none'").is_ok());
        assert!(validate_where_clause("undeleted = true").is_ok());
        assert!(validate_where_clause("inserted = false").is_ok());
        assert!(validate_where_clause("updated_at IS NOT NULL").is_ok());
        assert!(validate_where_clause("revoked = false").is_ok());
        assert!(validate_where_clause("granted = true").is_ok());
    }

    // --- contains_word ---

    // `_` counts as a word character; whitespace (including tab) and punctuation do not.
    #[test]
    fn test_contains_word_boundaries() {
        assert!(contains_word("HELLO WORLD", "HELLO"));
        assert!(contains_word("HELLO WORLD", "WORLD"));
        assert!(contains_word("(EXEC)", "EXEC"));
        assert!(!contains_word("EXECUTOR", "EXEC"));
        assert!(!contains_word("PRESELECT", "SELECT"));
        assert!(!contains_word("SELECTIVITY", "SELECT"));
        assert!(contains_word("SELECT", "SELECT"));
        assert!(contains_word(" SELECT ", "SELECT"));
        assert!(!contains_word("A_EXEC_B", "EXEC"));
        assert!(!contains_word("DROP_TABLE", "DROP"));
        assert!(!contains_word("", "EXEC"));
        assert!(!contains_word("EX", "EXEC"));
        assert!(contains_word("EXEC\tSOMETHING", "EXEC"));
        assert!(contains_word("X=EXEC", "EXEC"));
    }
}