elif_orm/
security.rs

1//! Security utilities for SQL injection prevention
2//!
3//! This module provides functions for:
4//! - Escaping SQL identifiers (table names, column names)
5//! - Validating identifier names
6//! - Query pattern validation
7
8use crate::error::ModelError;
9use std::collections::HashSet;
10
/// Every character that may appear in a validated SQL identifier:
/// ASCII lowercase letters, ASCII uppercase letters, ASCII digits, `_` and `$`.
const ALLOWED_IDENTIFIER_CHARS: &str = concat!(
    "abcdefghijklmnopqrstuvwxyz",
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
    "0123456789",
    "_$",
);
14
/// Reserved words that `validate_identifier` refuses to accept as identifiers.
///
/// All entries are upper case; the validator compares candidates against them
/// case-insensitively.
#[rustfmt::skip]
static SQL_KEYWORDS: &[&str] = &[
    // Data manipulation and query clauses
    "SELECT", "INSERT", "UPDATE", "DELETE", "FROM", "WHERE", "JOIN", "UNION",
    // Schema and privilege statements
    "DROP", "CREATE", "ALTER", "GRANT", "REVOKE", "TRUNCATE",
    // Procedural statements
    "EXEC", "EXECUTE", "DECLARE",
    // Conversion and string functions
    "CAST", "CONVERT", "SUBSTRING", "ASCII", "CHAR", "NCHAR",
    // System and session information
    "SYSTEM", "USER", "SESSION_USER", "CURRENT_USER", "SUSER_NAME", "IS_MEMBER",
];
47
/// Quote a SQL identifier (table name, column name, etc.) as a delimited identifier.
///
/// The result is the input wrapped in double quotes, with every double quote
/// inside the input written twice. No other validation (length, character set,
/// reserved words) is performed here.
///
/// # Arguments
/// * `identifier` - The raw identifier text
///
/// # Returns
/// * The double-quoted identifier
///
/// # Examples
/// ```
/// use elif_orm::security::escape_identifier;
///
/// assert_eq!(escape_identifier("user_table"), "\"user_table\"");
/// assert_eq!(escape_identifier("table\"name"), "\"table\"\"name\"");
/// ```
pub fn escape_identifier(identifier: &str) -> String {
    // Room for the input plus the two surrounding quotes; embedded quotes
    // simply grow the buffer when they occur.
    let mut quoted = String::with_capacity(identifier.len() + 2);

    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            // A literal double quote is represented by two consecutive quotes.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');

    quoted
}
76
77/// Validate that an identifier is safe for use in SQL
78///
79/// # Arguments
80/// * `identifier` - The identifier to validate
81///
82/// # Returns
83/// * Ok(()) if valid, Err(ModelError) if invalid
84pub fn validate_identifier(identifier: &str) -> Result<(), ModelError> {
85    // Check for empty identifier
86    if identifier.is_empty() {
87        return Err(ModelError::Validation(
88            "Identifier cannot be empty".to_string(),
89        ));
90    }
91
92    // Check length (PostgreSQL limit is 63 characters)
93    if identifier.len() > 63 {
94        return Err(ModelError::Validation(format!(
95            "Identifier '{}' is too long (max 63 characters)",
96            identifier
97        )));
98    }
99
100    // Check for allowed characters only
101    for c in identifier.chars() {
102        if !ALLOWED_IDENTIFIER_CHARS.contains(c) {
103            return Err(ModelError::Validation(format!(
104                "Identifier '{}' contains invalid character '{}'",
105                identifier, c
106            )));
107        }
108    }
109
110    // Check that it doesn't start with a number
111    if identifier.chars().next().unwrap().is_ascii_digit() {
112        return Err(ModelError::Validation(format!(
113            "Identifier '{}' cannot start with a number",
114            identifier
115        )));
116    }
117
118    // Check against SQL keywords (case insensitive)
119    let upper_identifier = identifier.to_uppercase();
120    if SQL_KEYWORDS.contains(&upper_identifier.as_str()) {
121        return Err(ModelError::Validation(format!(
122            "Identifier '{}' is a reserved SQL keyword",
123            identifier
124        )));
125    }
126
127    Ok(())
128}
129
130/// Validate query pattern to prevent dangerous SQL constructs
131///
132/// # Arguments
133/// * `sql` - The SQL query to validate
134///
135/// # Returns
136/// * Ok(()) if safe, Err(ModelError) if potentially dangerous
137pub fn validate_query_pattern(sql: &str) -> Result<(), ModelError> {
138    let sql_upper = sql.to_uppercase();
139
140    // Check for multiple statements (semicolon not at the end)
141    let semicolon_positions: Vec<_> = sql.match_indices(';').collect();
142    if semicolon_positions.len() > 1
143        || (semicolon_positions.len() == 1 && semicolon_positions[0].0 != sql.trim().len() - 1)
144    {
145        return Err(ModelError::Validation(
146            "Multiple SQL statements not allowed".to_string(),
147        ));
148    }
149
150    // Check for dangerous patterns
151    let dangerous_patterns = [
152        "EXEC ",
153        "EXECUTE ",
154        "SP_",
155        "XP_",
156        "OPENROWSET",
157        "OPENDATASOURCE",
158        "BULK INSERT",
159        "BCP ",
160        "SQLCMD",
161        "OSQL",
162        "ISQL",
163        "UNION ALL SELECT",
164        "UNION SELECT",
165        "'; --",
166        "'/*",
167        "*/'",
168        "INFORMATION_SCHEMA",
169        "SYS.",
170        "SYSOBJECTS",
171        "SYSCOLUMNS",
172    ];
173
174    for pattern in &dangerous_patterns {
175        if sql_upper.contains(pattern) {
176            return Err(ModelError::Validation(format!(
177                "Query contains potentially dangerous pattern: {}",
178                pattern
179            )));
180        }
181    }
182
183    Ok(())
184}
185
186/// Validate parameter value to prevent injection through parameters
187///
188/// With the escape-focused approach, parameter validation is minimal since
189/// parameters are properly parameterized and escaped by the database driver.
190///
191/// # Arguments
192/// * `value` - The parameter value to validate
193///
194/// # Returns
195/// * Ok(()) if safe, Err(ModelError) if potentially dangerous
196pub fn validate_parameter(value: &str) -> Result<(), ModelError> {
197    // Check for extremely long parameters (potential DoS)
198    if value.len() > 65536 {
199        // 64KB limit
200        return Err(ModelError::Validation(
201            "Parameter value too large (max 64KB)".to_string(),
202        ));
203    }
204
205    // With proper parameterization, most content is safe
206    // Only reject if there are genuine protocol-level risks
207
208    Ok(())
209}
210
211/// Create a whitelist-based identifier validator
212///
213/// This creates a validator that only allows specific identifiers from a predefined list.
214/// Useful for table names and column names that should be strictly controlled.
215pub struct IdentifierWhitelist {
216    allowed: HashSet<String>,
217}
218
219impl IdentifierWhitelist {
220    /// Create a new whitelist validator
221    pub fn new(allowed_identifiers: Vec<&str>) -> Self {
222        let allowed = allowed_identifiers
223            .into_iter()
224            .map(|s| s.to_string())
225            .collect();
226        Self { allowed }
227    }
228
229    /// Validate that an identifier is in the whitelist
230    pub fn validate(&self, identifier: &str) -> Result<(), ModelError> {
231        if self.allowed.contains(identifier) {
232            Ok(())
233        } else {
234            Err(ModelError::Validation(format!(
235                "Identifier '{}' is not in the allowed whitelist",
236                identifier
237            )))
238        }
239    }
240
241    /// Get escaped identifier if it's in the whitelist
242    pub fn escape_if_allowed(&self, identifier: &str) -> Result<String, ModelError> {
243        self.validate(identifier)?;
244        Ok(escape_identifier(identifier))
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Quoting wraps the name in double quotes and doubles embedded quotes.
    #[test]
    fn test_escape_identifier() {
        let cases = [
            ("user_table", "\"user_table\""),
            ("table\"name", "\"table\"\"name\""),
            ("simple", "\"simple\""),
        ];
        for &(input, expected) in &cases {
            assert_eq!(escape_identifier(input), expected);
        }
    }

    /// Plain names pass; empty, digit-leading, punctuated and keyword names fail.
    #[test]
    fn test_validate_identifier() {
        let accepted = ["user_table", "table1", "_private"];
        for &name in &accepted {
            assert!(
                validate_identifier(name).is_ok(),
                "expected {:?} to be accepted",
                name
            );
        }

        let rejected = ["", "1table", "table-name", "table name", "SELECT", "select"];
        for &name in &rejected {
            assert!(
                validate_identifier(name).is_err(),
                "expected {:?} to be rejected",
                name
            );
        }
    }

    /// Ordinary statements pass; stacked statements, UNION SELECT and EXEC fail.
    #[test]
    fn test_validate_query_pattern() {
        let safe = ["SELECT * FROM users", "INSERT INTO users VALUES ($1, $2)"];
        for &sql in &safe {
            assert!(
                validate_query_pattern(sql).is_ok(),
                "expected {:?} to be accepted",
                sql
            );
        }

        let dangerous = [
            "SELECT * FROM users; DROP TABLE users",
            "SELECT * FROM users UNION SELECT * FROM secrets",
            "EXEC sp_executesql 'SELECT * FROM users'",
        ];
        for &sql in &dangerous {
            assert!(
                validate_query_pattern(sql).is_err(),
                "expected {:?} to be rejected",
                sql
            );
        }
    }

    /// Parameter content is never rejected: SQL-like text and NUL bytes are fine
    /// because parameters are bound rather than interpolated.
    #[test]
    fn test_validate_parameter() {
        let accepted = [
            "normal value",
            "123",
            "user@example.com",
            "'; DROP TABLE users; --",
            "UNION SELECT",
            "value with \0 null byte",
        ];
        for &value in &accepted {
            assert!(
                validate_parameter(value).is_ok(),
                "expected {:?} to be accepted",
                value
            );
        }
    }

    /// Only listed names validate, and only listed names are escaped.
    #[test]
    fn test_identifier_whitelist() {
        let whitelist = IdentifierWhitelist::new(vec!["users", "posts", "comments"]);

        for &name in &["users", "posts"] {
            assert!(whitelist.validate(name).is_ok());
        }
        assert!(whitelist.validate("admin_table").is_err());

        let escaped = whitelist.escape_if_allowed("users").unwrap();
        assert_eq!(escaped, "\"users\"");
        assert!(whitelist.escape_if_allowed("hacker_table").is_err());
    }
}