athena_rs 3.6.1

Hyper performant polyglot Database driver
Documentation
//! SQL sanitization utilities for removing sensitive query information from error messages.
//!
//! This module provides functions to strip SQL queries and sensitive data from error messages
//! while preserving useful metadata like constraint names, column names, and error codes.

use regex::Regex;
use std::sync::OnceLock;

/// Lazily-initialized, process-wide set of compiled regular expressions.
///
/// The cell starts empty and is filled by `get_patterns` through
/// `OnceLock::get_or_init`, so the patterns are compiled once and then shared
/// by every caller.
static SQL_PATTERNS: OnceLock<SqlPatterns> = OnceLock::new();

/// Bundle of the compiled regular expressions used by this module.
///
/// The four statement patterns are used by `sanitize_error_message` to blank
/// out query text; the remaining three are used by `extract_metadata` to pull
/// identifiers out of an error message.
struct SqlPatterns {
    /// Matches `SELECT ... FROM ...` case-insensitively, extending lazily up to
    /// and including the first `WHERE`, `LIMIT`, `ORDER` or `GROUP` keyword, or
    /// to the end of the input when none follows.
    select: Regex,
    /// Matches `INSERT INTO ...` case-insensitively, extending lazily up to and
    /// including the first `VALUES`, `SELECT` or `RETURNING` keyword, or to the
    /// end of the input when none follows.
    insert: Regex,
    /// Matches `UPDATE ... SET ...` case-insensitively, extending lazily up to
    /// and including the first `WHERE` or `RETURNING` keyword, or to the end of
    /// the input when none follows.
    update: Regex,
    /// Matches `DELETE FROM ...` case-insensitively, extending lazily up to and
    /// including the first `WHERE` or `RETURNING` keyword, or to the end of the
    /// input when none follows.
    delete: Regex,
    /// Matches the word `constraint` followed by a double-quoted name; capture
    /// group 1 is the text between the quotes.
    constraint: Regex,
    /// Matches the word `column` followed by an optionally double-quoted
    /// identifier of at least two characters (group `name`); group `tail` holds
    /// a `.` when one immediately follows the identifier, and is empty otherwise.
    column_name: Regex,
    /// Matches the word `relation` or `table` followed by an optionally
    /// double-quoted identifier; capture group 1 is the identifier.
    table_name: Regex,
}

/// Returns the shared, compiled pattern set, building it on first use.
///
/// Every pattern is a fixed literal, so the `unwrap` calls can only fail if one
/// of the literals below is edited into an invalid regular expression.
fn get_patterns() -> &'static SqlPatterns {
    SQL_PATTERNS.get_or_init(|| {
        // Statement patterns: case-insensitive, lazily extended to the first
        // terminator keyword (or end of input).
        let select =
            Regex::new(r"(?i)SELECT\s+[\s\S]*?FROM\s+[\s\S]*?(?:WHERE|LIMIT|ORDER|GROUP|$)")
                .unwrap();
        let insert =
            Regex::new(r"(?i)INSERT\s+INTO\s+[\s\S]*?(?:VALUES|SELECT|RETURNING|$)").unwrap();
        let update =
            Regex::new(r"(?i)UPDATE\s+[\s\S]*?SET\s+[\s\S]*?(?:WHERE|RETURNING|$)").unwrap();
        let delete = Regex::new(r"(?i)DELETE\s+FROM\s+[\s\S]*?(?:WHERE|RETURNING|$)").unwrap();

        // Identifier patterns used for metadata extraction.
        let constraint = Regex::new(r#"constraint\s+"([^"]+)""#).unwrap();
        // `column "<name>"` or `column <name>`. The identifier must be at least
        // two characters long, which keeps a one-letter alias such as the `t`
        // in `column "t"."cache_hit_ratio"` from matching at all. A `.` right
        // after the identifier is not rejected here; it is recorded in the
        // `tail` group so the caller can discard qualified references.
        let column_name =
            Regex::new(r#"column\s+"?(?P<name>[a-zA-Z_][a-zA-Z0-9_]+)"?(?P<tail>\.?)"#).unwrap();
        let table_name =
            Regex::new(r#"(?:relation|table)\s+"?([a-zA-Z_][a-zA-Z0-9_]*)"?"#).unwrap();

        SqlPatterns {
            select,
            insert,
            update,
            delete,
            constraint,
            column_name,
            table_name,
        }
    })
}

/// Identifiers pulled out of a SQL error message.
///
/// Every field is `None` when the corresponding identifier was not found, and
/// `ExtractedInfo::default()` is the all-`None` value. The type derives
/// `PartialEq`/`Eq` so results can be compared directly, e.g. in assertions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExtractedInfo {
    /// Name of the constraint mentioned in the message, if any.
    pub constraint_name: Option<String>,
    /// Name of the column mentioned in the message, if any.
    pub column_name: Option<String>,
    /// Name of the table or relation mentioned in the message, if any.
    pub table_name: Option<String>,
}

/// Removes SQL queries from an error message while preserving important metadata.
///
/// This function:
/// - Strips out full SQL statements (SELECT, INSERT, UPDATE, DELETE)
/// - Preserves constraint names, column names, and table names
/// - Removes parameter values
/// - Returns a sanitized message safe for client responses
///
/// # Examples
///
/// ```
/// use athena_rs::error::sql_sanitizer::sanitize_error_message;
///
/// let raw_error = "failed to execute select query: SELECT * FROM users WHERE email = $1: column 'email' does not exist";
/// let sanitized = sanitize_error_message(raw_error);
/// assert!(!sanitized.contains("SELECT"));
/// assert!(sanitized.contains("column"));
/// ```
pub fn sanitize_error_message(error_msg: &str) -> String {
    let patterns: &SqlPatterns = get_patterns();
    let mut sanitized: String = error_msg.to_string();

    // Remove SELECT statements
    sanitized = patterns
        .select
        .replace_all(&sanitized, "[SQL query removed]")
        .to_string();

    // Remove INSERT statements
    sanitized = patterns
        .insert
        .replace_all(&sanitized, "[SQL query removed]")
        .to_string();

    // Remove UPDATE statements
    sanitized = patterns
        .update
        .replace_all(&sanitized, "[SQL query removed]")
        .to_string();

    // Remove DELETE statements
    sanitized = patterns
        .delete
        .replace_all(&sanitized, "[SQL query removed]")
        .to_string();

    // Clean up multiple consecutive "[SQL query removed]" markers
    while sanitized.contains("[SQL query removed][SQL query removed]") {
        sanitized = sanitized.replace(
            "[SQL query removed][SQL query removed]",
            "[SQL query removed]",
        );
    }

    // Clean up common patterns left after SQL removal
    sanitized = sanitized
        .replace(": [SQL query removed]:", ":")
        .replace(": [SQL query removed]", "")
        .replace("[SQL query removed]: ", "")
        .trim()
        .to_string();

    // If the entire message was just a SQL query, provide a generic message
    if sanitized.is_empty() || sanitized == "[SQL query removed]" {
        sanitized = "Database query execution failed".to_string();
    }

    sanitized
}

/// Extracts useful metadata from an error message without exposing the full query.
///
/// Returns constraint names, column names, and table names that can be safely
/// included in error responses for debugging.
pub fn extract_metadata(error_msg: &str) -> ExtractedInfo {
    let patterns: &SqlPatterns = get_patterns();
    let mut info: ExtractedInfo = ExtractedInfo::default();

    // Extract constraint name
    if let Some(captures) = patterns.constraint.captures(error_msg)
        && let Some(constraint) = captures.get(1)
    {
        info.constraint_name = Some(constraint.as_str().to_string());
    }

    // Extract column name. Walk every match so we can skip alias captures where
    // the identifier is followed by a `.` (e.g. `column "t"."cache_hit_ratio"`
    // would otherwise emit `t`). We keep the first match whose tail is NOT a
    // dot, which aligns with what Postgres actually reports as the column
    // name. Also skip captures that exactly match common internal aliases
    // (`t`, `s`) even if the regex's minimum-length guard already filters
    // single-letter names.
    for captures in patterns.column_name.captures_iter(error_msg) {
        let Some(name) = captures.name("name") else {
            continue;
        };
        let tail: &str = captures
            .name("tail")
            .map(|m| m.as_str())
            .unwrap_or_default();
        if tail == "." {
            continue;
        }
        let value: &str = name.as_str();
        if matches!(value, "t" | "s") {
            continue;
        }
        info.column_name = Some(value.to_string());
        break;
    }

    // Extract table name
    if let Some(captures) = patterns.table_name.captures(error_msg)
        && let Some(table) = captures.get(1)
    {
        info.table_name = Some(table.as_str().to_string());
    }

    info
}

/// Runs both `sanitize_error_message` and `extract_metadata` on the same input.
///
/// Returns the sanitized message first and the extracted identifiers second.
/// The metadata is extracted from the original `error_msg`, not from the
/// sanitized text.
pub fn sanitize_and_extract(error_msg: &str) -> (String, ExtractedInfo) {
    (
        sanitize_error_message(error_msg),
        extract_metadata(error_msg),
    )
}