elif_orm/
security.rs

1//! Security utilities for SQL injection prevention
2//! 
3//! This module provides functions for:
4//! - Escaping SQL identifiers (table names, column names)
5//! - Validating identifier names
6//! - Query pattern validation
7
8use crate::error::ModelError;
9use std::collections::HashSet;
10
/// The complete set of characters a validated SQL identifier may contain:
/// ASCII lowercase letters, ASCII uppercase letters, ASCII digits,
/// underscore and dollar sign.
const ALLOWED_IDENTIFIER_CHARS: &str = concat!(
    "abcdefghijklmnopqrstuvwxyz",
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
    "0123456789",
    "_$",
);
13
/// Keywords that are not accepted as bare identifiers.
///
/// Entries are stored in upper case; `validate_identifier` compares
/// candidates against this list without regard to ASCII case.
static SQL_KEYWORDS: &[&str] = &[
    // Statement and clause keywords
    "SELECT",
    "INSERT",
    "UPDATE",
    "DELETE",
    "FROM",
    "WHERE",
    "JOIN",
    "UNION",
    // Schema, privilege and execution keywords
    "DROP",
    "CREATE",
    "ALTER",
    "GRANT",
    "REVOKE",
    "TRUNCATE",
    "EXEC",
    "EXECUTE",
    // Declaration, conversion and string helpers
    "DECLARE",
    "CAST",
    "CONVERT",
    "SUBSTRING",
    "ASCII",
    "CHAR",
    "NCHAR",
    // System and identity names
    "SYSTEM",
    "USER",
    "SESSION_USER",
    "CURRENT_USER",
    "SUSER_NAME",
    "IS_MEMBER",
];
21
/// Quote a SQL identifier (table name, column name, etc.) for use in a query.
///
/// The identifier is wrapped in double quotes, and every double quote that
/// appears inside it is written twice so it cannot terminate the quoted
/// name early. No validation is performed: every input is accepted and
/// quoted as-is.
///
/// # Arguments
/// * `identifier` - The raw identifier text
///
/// # Returns
/// * The double-quoted identifier with embedded quotes doubled
///
/// # Examples
/// ```
/// use elif_orm::security::escape_identifier;
///
/// assert_eq!(escape_identifier("user_table"), "\"user_table\"");
/// assert_eq!(escape_identifier("table\"name"), "\"table\"\"name\"");
/// ```
pub fn escape_identifier(identifier: &str) -> String {
    // Room for the text plus the two enclosing quotes; the buffer grows on
    // demand if embedded quotes have to be doubled.
    let mut quoted = String::with_capacity(identifier.len() + 2);

    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            // Emit the extra quote that turns `"` into `""`.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');

    quoted
}
50
51/// Validate that an identifier is safe for use in SQL
52/// 
53/// # Arguments
54/// * `identifier` - The identifier to validate
55/// 
56/// # Returns
57/// * Ok(()) if valid, Err(ModelError) if invalid
58pub fn validate_identifier(identifier: &str) -> Result<(), ModelError> {
59    // Check for empty identifier
60    if identifier.is_empty() {
61        return Err(ModelError::Validation(
62            "Identifier cannot be empty".to_string()
63        ));
64    }
65
66    // Check length (PostgreSQL limit is 63 characters)
67    if identifier.len() > 63 {
68        return Err(ModelError::Validation(
69            format!("Identifier '{}' is too long (max 63 characters)", identifier)
70        ));
71    }
72
73    // Check for allowed characters only
74    for c in identifier.chars() {
75        if !ALLOWED_IDENTIFIER_CHARS.contains(c) {
76            return Err(ModelError::Validation(
77                format!("Identifier '{}' contains invalid character '{}'", identifier, c)
78            ));
79        }
80    }
81
82    // Check that it doesn't start with a number
83    if identifier.chars().next().unwrap().is_ascii_digit() {
84        return Err(ModelError::Validation(
85            format!("Identifier '{}' cannot start with a number", identifier)
86        ));
87    }
88
89    // Check against SQL keywords (case insensitive)
90    let upper_identifier = identifier.to_uppercase();
91    if SQL_KEYWORDS.contains(&upper_identifier.as_str()) {
92        return Err(ModelError::Validation(
93            format!("Identifier '{}' is a reserved SQL keyword", identifier)
94        ));
95    }
96
97    Ok(())
98}
99
100/// Validate query pattern to prevent dangerous SQL constructs
101/// 
102/// # Arguments
103/// * `sql` - The SQL query to validate
104/// 
105/// # Returns
106/// * Ok(()) if safe, Err(ModelError) if potentially dangerous
107pub fn validate_query_pattern(sql: &str) -> Result<(), ModelError> {
108    let sql_upper = sql.to_uppercase();
109    
110    // Check for multiple statements (semicolon not at the end)
111    let semicolon_positions: Vec<_> = sql.match_indices(';').collect();
112    if semicolon_positions.len() > 1 || 
113       (semicolon_positions.len() == 1 && semicolon_positions[0].0 != sql.trim().len() - 1) {
114        return Err(ModelError::Validation(
115            "Multiple SQL statements not allowed".to_string()
116        ));
117    }
118
119    // Check for dangerous patterns
120    let dangerous_patterns = [
121        "EXEC ", "EXECUTE ", "SP_", "XP_", "OPENROWSET", "OPENDATASOURCE",
122        "BULK INSERT", "BCP ", "SQLCMD", "OSQL", "ISQL",
123        "UNION ALL SELECT", "UNION SELECT", "'; --", "'/*", "*/'",
124        "INFORMATION_SCHEMA", "SYS.", "SYSOBJECTS", "SYSCOLUMNS"
125    ];
126
127    for pattern in &dangerous_patterns {
128        if sql_upper.contains(pattern) {
129            return Err(ModelError::Validation(
130                format!("Query contains potentially dangerous pattern: {}", pattern)
131            ));
132        }
133    }
134
135    Ok(())
136}
137
138/// Validate parameter value to prevent injection through parameters
139/// 
140/// With the escape-focused approach, parameter validation is minimal since
141/// parameters are properly parameterized and escaped by the database driver.
142/// 
143/// # Arguments
144/// * `value` - The parameter value to validate
145/// 
146/// # Returns
147/// * Ok(()) if safe, Err(ModelError) if potentially dangerous
148pub fn validate_parameter(value: &str) -> Result<(), ModelError> {
149    // Check for extremely long parameters (potential DoS)
150    if value.len() > 65536 { // 64KB limit
151        return Err(ModelError::Validation(
152            "Parameter value too large (max 64KB)".to_string()
153        ));
154    }
155
156    // With proper parameterization, most content is safe
157    // Only reject if there are genuine protocol-level risks
158    
159    Ok(())
160}
161
162/// Create a whitelist-based identifier validator
163/// 
164/// This creates a validator that only allows specific identifiers from a predefined list.
165/// Useful for table names and column names that should be strictly controlled.
166pub struct IdentifierWhitelist {
167    allowed: HashSet<String>,
168}
169
170impl IdentifierWhitelist {
171    /// Create a new whitelist validator
172    pub fn new(allowed_identifiers: Vec<&str>) -> Self {
173        let allowed = allowed_identifiers.into_iter()
174            .map(|s| s.to_string())
175            .collect();
176        Self { allowed }
177    }
178
179    /// Validate that an identifier is in the whitelist
180    pub fn validate(&self, identifier: &str) -> Result<(), ModelError> {
181        if self.allowed.contains(identifier) {
182            Ok(())
183        } else {
184            Err(ModelError::Validation(
185                format!("Identifier '{}' is not in the allowed whitelist", identifier)
186            ))
187        }
188    }
189
190    /// Get escaped identifier if it's in the whitelist
191    pub fn escape_if_allowed(&self, identifier: &str) -> Result<String, ModelError> {
192        self.validate(identifier)?;
193        Ok(escape_identifier(identifier))
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Quoting wraps the name in double quotes and doubles embedded quotes.
    #[test]
    fn test_escape_identifier() {
        let cases = [
            ("user_table", "\"user_table\""),
            ("table\"name", "\"table\"\"name\""),
            ("simple", "\"simple\""),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(escape_identifier(input), *expected);
        }
    }

    /// Well-formed names pass; empty, malformed and keyword names fail.
    #[test]
    fn test_validate_identifier() {
        let accepted = ["user_table", "table1", "_private"];
        for name in accepted.iter() {
            assert!(
                validate_identifier(name).is_ok(),
                "expected {:?} to be accepted",
                name
            );
        }

        let rejected = ["", "1table", "table-name", "table name", "SELECT", "select"];
        for name in rejected.iter() {
            assert!(
                validate_identifier(name).is_err(),
                "expected {:?} to be rejected",
                name
            );
        }
    }

    /// Plain single statements pass; stacked statements and listed patterns fail.
    #[test]
    fn test_validate_query_pattern() {
        let accepted = [
            "SELECT * FROM users",
            "INSERT INTO users VALUES ($1, $2)",
        ];
        for sql in accepted.iter() {
            assert!(
                validate_query_pattern(sql).is_ok(),
                "expected {:?} to be accepted",
                sql
            );
        }

        let rejected = [
            "SELECT * FROM users; DROP TABLE users",
            "SELECT * FROM users UNION SELECT * FROM secrets",
            "EXEC sp_executesql 'SELECT * FROM users'",
        ];
        for sql in rejected.iter() {
            assert!(
                validate_query_pattern(sql).is_err(),
                "expected {:?} to be rejected",
                sql
            );
        }
    }

    /// Parameter content is never inspected, so every sample here passes.
    #[test]
    fn test_validate_parameter() {
        let accepted = [
            "normal value",
            "123",
            "user@example.com",
            // Parameters with SQL-like content are OK since they'll be parameterized
            "'; DROP TABLE users; --",
            "UNION SELECT",
            // With escape-focused approach, null bytes are also OK
            "value with \0 null byte",
        ];

        for value in accepted.iter() {
            assert!(
                validate_parameter(value).is_ok(),
                "expected {:?} to be accepted",
                value
            );
        }
    }

    /// Only registered names validate, and only those are escaped.
    #[test]
    fn test_identifier_whitelist() {
        let whitelist = IdentifierWhitelist::new(vec!["users", "posts", "comments"]);

        for name in ["users", "posts"].iter() {
            assert!(
                whitelist.validate(name).is_ok(),
                "expected {:?} to be whitelisted",
                name
            );
        }
        assert!(whitelist.validate("admin_table").is_err());

        let escaped = whitelist.escape_if_allowed("users").unwrap();
        assert_eq!(escaped, "\"users\"");
        assert!(whitelist.escape_if_allowed("hacker_table").is_err());
    }
}