use crate::drivers::postgresql::raw_sql::normalize_sql_query;
/// Validates that `sql` is a single, read-only `SELECT ... COUNT(...) ... FROM ...` query.
///
/// The input is first passed through `normalize_sql_query`, and every check below runs on that
/// normalized text. This is a conservative lexical filter, not a SQL parser.
///
/// Keyword checks match whole words, so they are not defeated by missing spaces around a
/// keyword (e.g. `COUNT(*)INTO t`, `UNION(SELECT ...)`, `'x'UNION ...`). They also do not
/// misfire on longer identifiers that merely contain a keyword (e.g. `subgroup`).
///
/// Words inside `'...'` string literals and `"..."` quoted identifiers are ignored, so a filter
/// such as `WHERE action = 'delete'` is accepted. Quoted text is NOT skipped, and any keyword
/// anywhere is rejected, when the text contains:
/// - a backslash or `$` (escape-string, Unicode-escape or dollar-quoting forms whose extent
///   this lexer does not model), or
/// - an unterminated quote.
///
/// # Errors
///
/// Returns `Err` with a human-readable message when the normalized query:
/// - is empty;
/// - contains `;` (possible multiple statements);
/// - contains `--` or `/*` (SQL comments);
/// - does not begin with the word `SELECT`;
/// - has no `COUNT` word followed (after optional whitespace) by `(`;
/// - has no `FROM` word;
/// - contains a forbidden construct: `INSERT`, `UPDATE`, `DELETE`, `DROP`, `ALTER`, `CREATE`,
///   `TRUNCATE`, `EXECUTE`, `GRANT`, `REVOKE`, `COPY`, `INTO`, `UNION`, `GROUP BY`, `HAVING`
///   or `WINDOW`. The first one found in this order is reported.
pub fn validate_count_sql(sql: &str) -> Result<(), String> {
    // Each entry is a run of consecutive words; list order decides which match is reported.
    const FORBIDDEN_CONSTRUCTS: [&[&str]; 16] = [
        &["insert"],
        &["update"],
        &["delete"],
        &["drop"],
        &["alter"],
        &["create"],
        &["truncate"],
        &["execute"],
        &["grant"],
        &["revoke"],
        &["copy"],
        &["into"],
        &["union"],
        &["group", "by"],
        &["having"],
        &["window"],
    ];

    let normalized: String = normalize_sql_query(sql);
    if normalized.is_empty() {
        return Err("Query cannot be empty".to_string());
    }
    // These raw-text checks deliberately look inside quoted text too. They fail closed, and
    // they guarantee the quote masking below never has to reason about comments.
    if normalized.contains(';') {
        return Err("Multiple statements are not allowed".to_string());
    }
    if normalized.contains("--") || normalized.contains("/*") {
        return Err("SQL comments are not allowed in count queries".to_string());
    }

    let lower: String = normalized.to_ascii_lowercase();
    // The statement must open with the bare SELECT keyword. Run this on the unmasked text so
    // a leading quote or parenthesis is still rejected.
    if !starts_with_sql_word(&lower, "select") {
        return Err("Count query must start with SELECT".to_string());
    }

    let code: String = mask_quoted_sql(&lower);
    if !has_count_call(&code) {
        return Err("Query must include COUNT(...)".to_string());
    }

    let words: Vec<&str> = code
        .split(|c: char| !is_sql_word_char(c))
        // PostgreSQL before 15 lexes `1union` as `1` followed by `union`, so leading digits
        // must not hide a keyword. Real identifiers never start with a digit.
        .map(|word| word.trim_start_matches(|c: char| c.is_ascii_digit()))
        .filter(|word| !word.is_empty())
        .collect();
    if !words.contains(&"from") {
        return Err("Query must include a FROM clause".to_string());
    }
    for construct in FORBIDDEN_CONSTRUCTS {
        if words.windows(construct.len()).any(|window| window == construct) {
            return Err(format!(
                "Query contains forbidden construct: {}",
                construct.join(" ")
            ));
        }
    }
    Ok(())
}

/// Returns `true` for characters that can be part of an unquoted SQL word (keyword or
/// identifier): letters, digits, `_` and `$`.
fn is_sql_word_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_' || c == '$'
}

/// Returns `true` when `text`, ignoring leading whitespace, begins with the whole word `word`.
fn starts_with_sql_word(text: &str, word: &str) -> bool {
    match text.trim_start().strip_prefix(word) {
        Some(rest) => !rest.starts_with(is_sql_word_char),
        None => false,
    }
}

/// Returns `true` when `text` contains the whole word `count` followed, after optional
/// whitespace, by `(`.
fn has_count_call(text: &str) -> bool {
    text.match_indices("count").any(|(start, matched)| {
        let whole_word = !text[..start].ends_with(is_sql_word_char);
        let called = text[start + matched.len()..].trim_start().starts_with('(');
        whole_word && called
    })
}

/// Replaces every character of `'...'` string literals and `"..."` quoted identifiers
/// (quotes included) with a space, so words inside them are not mistaken for keywords.
///
/// Doubled quotes (`''`, `""`) need no special case: they close one quoted run and open the
/// next, so the whole span stays masked.
///
/// Backslash escapes (`E'..'`, `U&'..'`, non-standard-conforming strings) and `$`-quoting can
/// change where a literal really ends. When either character is present, or a quote is left
/// unterminated, the text is returned unchanged so every word is checked.
fn mask_quoted_sql(text: &str) -> String {
    if text.contains('\\') || text.contains('$') {
        return text.to_string();
    }
    let mut masked = String::with_capacity(text.len());
    let mut open_quote: Option<char> = None;
    for c in text.chars() {
        match open_quote {
            None if c == '\'' || c == '"' => {
                open_quote = Some(c);
                masked.push(' ');
            }
            None => masked.push(c),
            Some(quote) => {
                if c == quote {
                    open_quote = None;
                }
                masked.push(' ');
            }
        }
    }
    if open_quote.is_some() {
        return text.to_string();
    }
    masked
}
#[cfg(test)]
mod tests {
    use super::validate_count_sql;

    /// Asserts that `sql` is refused by the validator, naming the query on failure.
    fn assert_rejected(sql: &str) {
        assert!(
            validate_count_sql(sql).is_err(),
            "expected validate_count_sql to reject: {sql}"
        );
    }

    #[test]
    fn accepts_simple_count() {
        let sql = "SELECT COUNT(*) AS count FROM public.api_key_auth_log";
        assert_eq!(validate_count_sql(sql), Ok(()));
    }

    #[test]
    fn rejects_semicolon_separated() {
        assert_rejected("SELECT COUNT(*) FROM t; SELECT 1");
    }

    #[test]
    fn rejects_insert() {
        assert_rejected("INSERT INTO t SELECT COUNT(*) FROM u");
    }

    #[test]
    fn rejects_group_by() {
        assert_rejected("SELECT COUNT(*) FROM t GROUP BY id");
    }

    #[test]
    fn rejects_union() {
        assert_rejected("SELECT COUNT(*) FROM t UNION SELECT COUNT(*) FROM u");
    }
}