// athena_rs 2.12.1
//
// Database gateway API
// Documentation
//! Validation for COUNT-only SQL accepted by [`super::sql_count_query`].

use crate::drivers::postgresql::raw_sql::normalize_sql_query;

/// Splits `text` into maximal runs of identifier characters (alphanumerics and `_`), each paired
/// with the byte offset at which the run starts.
///
/// Every other character (whitespace, parentheses, quotes, dots, operators) acts as a separator,
/// so `union(`, `union\n` and ` union ` all yield the word `union`, while `shop_window` stays a
/// single word that is not mistaken for the keyword `window`.
fn split_sql_words(text: &str) -> Vec<(usize, &str)> {
    let mut words: Vec<(usize, &str)> = Vec::new();
    let mut run_start: Option<usize> = None;

    for (idx, ch) in text.char_indices() {
        let is_word_char = ch.is_alphanumeric() || ch == '_';
        match (is_word_char, run_start) {
            (true, None) => run_start = Some(idx),
            (false, Some(begin)) => {
                words.push((begin, &text[begin..idx]));
                run_start = None;
            }
            _ => {}
        }
    }
    if let Some(begin) = run_start {
        words.push((begin, &text[begin..]));
    }

    words
}

/// Returns an error message if `sql` is not a single read-only `COUNT` query suitable for row counting.
///
/// The query is first passed through `normalize_sql_query`; all checks run on that result.
/// Keywords are matched case-insensitively (ASCII) as whole words, where a word is a run of
/// alphanumerics/underscores delimited by any other character. Matching on words rather than on
/// space-padded substrings means `UNION(SELECT ...)`, `HAVING(...)` or tab/newline-separated
/// keywords are detected, and identifiers that merely contain a keyword (`shop_window`,
/// `discount(...)`) are not confused with it.
///
/// This is a lexical heuristic, not a SQL parser: it fails closed, so a forbidden keyword is
/// rejected even when it appears inside a string literal or as a quoted/qualified identifier.
///
/// # Errors
///
/// Returns `Err` with a human-readable message when the normalized query:
/// - is empty,
/// - contains `;`, `--` or `/*`,
/// - does not begin with the word `SELECT`,
/// - has no `COUNT(` call or no `FROM` keyword,
/// - contains any forbidden construct (data-modifying / DDL keywords, `INTO`, `UNION`,
///   `GROUP BY`, `HAVING`, `WINDOW`, ...).
pub fn validate_count_sql(sql: &str) -> Result<(), String> {
    let normalized: String = normalize_sql_query(sql);
    if normalized.is_empty() {
        return Err("Query cannot be empty".to_string());
    }
    if normalized.contains(';') {
        return Err("Multiple statements are not allowed".to_string());
    }
    if normalized.contains("--") || normalized.contains("/*") {
        return Err("SQL comments are not allowed in count queries".to_string());
    }

    let lower: String = normalized.to_ascii_lowercase();
    let spans: Vec<(usize, &str)> = split_sql_words(&lower);
    let words: Vec<&str> = spans.iter().map(|&(_, word)| word).collect();

    // The very first thing in the query (ignoring leading whitespace) must be the word SELECT;
    // `selectcount(...)` or a leading parenthesis does not qualify.
    let starts_with_select = matches!(
        spans.first(),
        Some(&(start, word)) if word == "select" && lower[..start].trim().is_empty()
    );
    if !starts_with_select {
        return Err("Count query must start with SELECT".to_string());
    }

    // Require the word `count` itself followed by `(` (whitespace in between is allowed), so
    // functions such as `discount(` or `account(` do not satisfy the check.
    let has_count_call = spans.iter().any(|&(start, word)| {
        word == "count" && lower[start + word.len()..].trim_start().starts_with('(')
    });
    if !has_count_call {
        return Err("Query must include COUNT(...)".to_string());
    }
    if !words.contains(&"from") {
        return Err("Query must include a FROM clause".to_string());
    }

    // Checked in this order; the first construct found is the one reported.
    let forbidden_constructs: [&str; 16] = [
        "insert", "update", "delete", "drop", "alter", "create", "truncate", "execute", "grant",
        "revoke", "copy", "into", "union", "group by", "having", "window",
    ];

    for construct in forbidden_constructs {
        // Multi-word constructs ("group by") must appear as consecutive words.
        let needle: Vec<&str> = construct.split(' ').collect();
        let found = words
            .windows(needle.len())
            .any(|window| window == needle.as_slice());
        if found {
            return Err(format!("Query contains forbidden construct: {construct}"));
        }
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::validate_count_sql;

    /// Runs the validator and reports whether the query was accepted.
    fn is_accepted(query: &str) -> bool {
        validate_count_sql(query).is_ok()
    }

    /// Fails the calling test unless the validator refuses `query`.
    #[track_caller]
    fn expect_rejection(query: &str) {
        assert!(!is_accepted(query));
    }

    /// A plain aliased `COUNT(*)` over a schema-qualified table passes validation.
    #[test]
    fn accepts_simple_count() {
        let query = "SELECT COUNT(*) AS count FROM public.api_key_auth_log";
        assert!(is_accepted(query));
    }

    /// A second statement after `;` is refused.
    #[test]
    fn rejects_semicolon_separated() {
        expect_rejection("SELECT COUNT(*) FROM t; SELECT 1");
    }

    /// A statement that does not begin with `SELECT` is refused.
    #[test]
    fn rejects_insert() {
        expect_rejection("INSERT INTO t SELECT COUNT(*) FROM u");
    }

    /// Grouped counts are refused.
    #[test]
    fn rejects_group_by() {
        expect_rejection("SELECT COUNT(*) FROM t GROUP BY id");
    }

    /// Combining two counts with `UNION` is refused.
    #[test]
    fn rejects_union() {
        expect_rejection("SELECT COUNT(*) FROM t UNION SELECT COUNT(*) FROM u");
    }
}