// dataprof_db/security/sql_validation.rs
//! SQL validation utilities to prevent injection attacks.

3use crate::DataProfilerError;
4use std::collections::HashSet;
6/// Validate SQL identifiers (table names, column names) to prevent injection
7pub fn validate_sql_identifier(identifier: &str) -> Result<(), DataProfilerError> {
8    if identifier.trim().is_empty() {
9        return Err(DataProfilerError::sql_validation(
10            "SQL identifier cannot be empty",
11        ));
12    }
13
14    if identifier.len() > 128 {
15        return Err(DataProfilerError::sql_validation(
16            "SQL identifier too long (max 128 chars)",
17        ));
18    }
19
20    let allowed_chars: HashSet<char> =
21        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_."
22            .chars()
23            .collect();
24
25    if (identifier.starts_with('"') && identifier.ends_with('"'))
26        || (identifier.starts_with('`') && identifier.ends_with('`'))
27        || (identifier.starts_with('[') && identifier.ends_with(']'))
28    {
29        let inner = &identifier[1..identifier.len() - 1];
30        if inner.is_empty() {
31            return Err(DataProfilerError::sql_validation(
32                "Quoted identifier cannot be empty",
33            ));
34        }
35        let quote_char = identifier
36            .chars()
37            .next()
38            .ok_or_else(|| DataProfilerError::sql_validation("Invalid identifier format"))?;
39        if inner.contains(quote_char)
40            || inner.contains(';')
41            || inner.contains("--")
42            || inner.contains("/*")
43            || inner.contains("*/")
44        {
45            return Err(DataProfilerError::sql_validation(
46                "Invalid characters in quoted identifier",
47            ));
48        }
49    } else {
50        if !identifier.chars().all(|c| allowed_chars.contains(&c)) {
51            return Err(DataProfilerError::sql_validation(&format!(
52                "Invalid SQL identifier '{}': only alphanumeric, underscore, and dot allowed",
53                identifier
54            )));
55        }
56
57        if let Some(first_char) = identifier.chars().next()
58            && !first_char.is_alphabetic()
59            && first_char != '_'
60        {
61            return Err(DataProfilerError::sql_validation(
62                "SQL identifier must start with letter or underscore",
63            ));
64        }
65    }
66
67    let identifier_upper = identifier.to_uppercase();
68    let dangerous_keywords = [
69        "DROP",
70        "DELETE",
71        "INSERT",
72        "UPDATE",
73        "TRUNCATE",
74        "ALTER",
75        "CREATE",
76        "GRANT",
77        "REVOKE",
78        "EXEC",
79        "EXECUTE",
80        "UNION",
81        "--",
82        "/*",
83        "*/",
84        ";",
85        "INFORMATION_SCHEMA",
86        "SYS",
87        "MASTER",
88        "PG_",
89        "MYSQL",
90    ];
91
92    for keyword in &dangerous_keywords {
93        if identifier_upper.contains(keyword) {
94            return Err(DataProfilerError::sql_validation(&format!(
95                "SQL identifier contains dangerous keyword or pattern: {}",
96                keyword
97            )));
98        }
99    }
100
101    Ok(())
102}
104/// Validate and sanitize a basic SQL query to ensure it's a SELECT statement
105pub fn validate_base_query(query: &str) -> Result<String, DataProfilerError> {
106    let trimmed = query.trim();
107
108    if trimmed.is_empty() {
109        return Err(DataProfilerError::sql_validation("Query cannot be empty"));
110    }
111
112    if trimmed.len() > 10000 {
113        return Err(DataProfilerError::sql_validation(
114            "Query too long (max 10000 chars)",
115        ));
116    }
117
118    let query_upper = trimmed.to_uppercase();
119    if !query_upper.starts_with("SELECT") {
120        return Err(DataProfilerError::sql_validation(
121            "Only SELECT queries are allowed for sampling",
122        ));
123    }
124
125    let dangerous_patterns = [
126        "DROP",
127        "DELETE",
128        "INSERT",
129        "UPDATE",
130        "TRUNCATE",
131        "ALTER",
132        "CREATE",
133        "GRANT",
134        "REVOKE",
135        "EXEC",
136        "EXECUTE",
137        "UNION",
138        "--",
139        "/*",
140        "INFORMATION_SCHEMA",
141        "SYS",
142        "MASTER",
143        "PG_",
144        "MYSQL",
145        "WAITFOR",
146        "SLEEP",
147        "EXTRACTVALUE",
148        "LOAD_FILE",
149        "COPY",
150        "ATTACH",
151        "PROGRAM",
152        "XP_CMDSHELL",
153    ];
154
155    for pattern in &dangerous_patterns {
156        if query_upper.contains(pattern) {
157            return Err(DataProfilerError::sql_validation(&format!(
158                "Query contains dangerous SQL pattern: {}",
159                pattern
160            )));
161        }
162    }
163
164    Ok(trimmed.to_string())
165}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sql_identifier_validation() {
        // Plain, dotted and quoted identifiers that must pass.
        let accepted = [
            "users",
            "user_table",
            "schema.table",
            "\"quoted table\"",
            "`quoted_table`",
        ];
        // Empty input, injection attempts, comments and a leading digit.
        let rejected = [
            "",
            "DROP TABLE",
            "users; DROP TABLE users; --",
            "table/* comment */",
            "123invalid",
        ];

        for identifier in accepted {
            assert!(
                validate_sql_identifier(identifier).is_ok(),
                "expected {identifier:?} to be accepted"
            );
        }
        for identifier in rejected {
            assert!(
                validate_sql_identifier(identifier).is_err(),
                "expected {identifier:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_base_query_validation() {
        // Plain SELECTs, with and without surrounding whitespace.
        let accepted = ["SELECT * FROM users", "  SELECT id, name FROM products  "];
        // Empty input, a non-SELECT statement, stacked statements and UNION.
        let rejected = [
            "",
            "DROP TABLE users",
            "SELECT * FROM users; DROP TABLE users",
            "SELECT * FROM users UNION SELECT * FROM admin",
        ];

        for query in accepted {
            assert!(
                validate_base_query(query).is_ok(),
                "expected {query:?} to be accepted"
            );
        }
        for query in rejected {
            assert!(
                validate_base_query(query).is_err(),
                "expected {query:?} to be rejected"
            );
        }
    }
}