1use crate::error::ModelError;
9use std::collections::HashSet;
10
/// Every character `validate_identifier` accepts anywhere in an identifier:
/// ASCII lowercase and uppercase letters, ASCII digits, `_` and `$`.
const ALLOWED_IDENTIFIER_CHARS: &str = concat!(
    "abcdefghijklmnopqrstuvwxyz",
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
    "0123456789",
    "_$"
);
13
/// Words that `validate_identifier` refuses to accept as identifiers.
///
/// Entries are upper-case; callers upper-case the candidate before looking it
/// up, so the comparison is case-insensitive.
///
/// NOTE(review): this is a targeted deny-list, not the complete reserved-word
/// list of any SQL dialect (e.g. `ORDER` and `TABLE` are absent).
static SQL_KEYWORDS: &[&str] = &[
    // Statement and clause keywords.
    "SELECT", "INSERT", "UPDATE", "DELETE",
    "FROM", "WHERE", "JOIN", "UNION",
    "DROP", "CREATE", "ALTER", "GRANT",
    "REVOKE", "TRUNCATE", "EXEC", "EXECUTE",
    "DECLARE",
    // Built-in functions.
    "CAST", "CONVERT", "SUBSTRING", "ASCII",
    "CHAR", "NCHAR",
    // Identity and session names.
    "SYSTEM", "USER", "SESSION_USER", "CURRENT_USER",
    "SUSER_NAME", "IS_MEMBER",
];
21
/// Quotes `identifier` as a double-quoted SQL identifier.
///
/// Each embedded `"` is doubled (`"` becomes `""`). The result is then wrapped
/// in a leading and a trailing `"`. No other character is altered, so the
/// output is always exactly the input with quotes doubled, plus two.
pub fn escape_identifier(identifier: &str) -> String {
    // Two bytes for the surrounding quotes; doubled quotes may grow it further.
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            // Double the quote so it is read as a literal `"` inside the name.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
50
51pub fn validate_identifier(identifier: &str) -> Result<(), ModelError> {
59 if identifier.is_empty() {
61 return Err(ModelError::Validation(
62 "Identifier cannot be empty".to_string()
63 ));
64 }
65
66 if identifier.len() > 63 {
68 return Err(ModelError::Validation(
69 format!("Identifier '{}' is too long (max 63 characters)", identifier)
70 ));
71 }
72
73 for c in identifier.chars() {
75 if !ALLOWED_IDENTIFIER_CHARS.contains(c) {
76 return Err(ModelError::Validation(
77 format!("Identifier '{}' contains invalid character '{}'", identifier, c)
78 ));
79 }
80 }
81
82 if identifier.chars().next().unwrap().is_ascii_digit() {
84 return Err(ModelError::Validation(
85 format!("Identifier '{}' cannot start with a number", identifier)
86 ));
87 }
88
89 let upper_identifier = identifier.to_uppercase();
91 if SQL_KEYWORDS.contains(&upper_identifier.as_str()) {
92 return Err(ModelError::Validation(
93 format!("Identifier '{}' is a reserved SQL keyword", identifier)
94 ));
95 }
96
97 Ok(())
98}
99
100pub fn validate_query_pattern(sql: &str) -> Result<(), ModelError> {
108 let sql_upper = sql.to_uppercase();
109
110 let semicolon_positions: Vec<_> = sql.match_indices(';').collect();
112 if semicolon_positions.len() > 1 ||
113 (semicolon_positions.len() == 1 && semicolon_positions[0].0 != sql.trim().len() - 1) {
114 return Err(ModelError::Validation(
115 "Multiple SQL statements not allowed".to_string()
116 ));
117 }
118
119 let dangerous_patterns = [
121 "EXEC ", "EXECUTE ", "SP_", "XP_", "OPENROWSET", "OPENDATASOURCE",
122 "BULK INSERT", "BCP ", "SQLCMD", "OSQL", "ISQL",
123 "UNION ALL SELECT", "UNION SELECT", "'; --", "'/*", "*/'",
124 "INFORMATION_SCHEMA", "SYS.", "SYSOBJECTS", "SYSCOLUMNS"
125 ];
126
127 for pattern in &dangerous_patterns {
128 if sql_upper.contains(pattern) {
129 return Err(ModelError::Validation(
130 format!("Query contains potentially dangerous pattern: {}", pattern)
131 ));
132 }
133 }
134
135 Ok(())
136}
137
138pub fn validate_parameter(value: &str) -> Result<(), ModelError> {
149 if value.len() > 65536 { return Err(ModelError::Validation(
152 "Parameter value too large (max 64KB)".to_string()
153 ));
154 }
155
156 Ok(())
160}
161
162pub struct IdentifierWhitelist {
167 allowed: HashSet<String>,
168}
169
170impl IdentifierWhitelist {
171 pub fn new(allowed_identifiers: Vec<&str>) -> Self {
173 let allowed = allowed_identifiers.into_iter()
174 .map(|s| s.to_string())
175 .collect();
176 Self { allowed }
177 }
178
179 pub fn validate(&self, identifier: &str) -> Result<(), ModelError> {
181 if self.allowed.contains(identifier) {
182 Ok(())
183 } else {
184 Err(ModelError::Validation(
185 format!("Identifier '{}' is not in the allowed whitelist", identifier)
186 ))
187 }
188 }
189
190 pub fn escape_if_allowed(&self, identifier: &str) -> Result<String, ModelError> {
192 self.validate(identifier)?;
193 Ok(escape_identifier(identifier))
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escape_identifier() {
        let cases = [
            ("user_table", "\"user_table\""),
            ("table\"name", "\"table\"\"name\""),
            ("simple", "\"simple\""),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(escape_identifier(input), expected, "escaping {:?}", input);
        }
    }

    #[test]
    fn test_validate_identifier() {
        let accepted = ["user_table", "table1", "_private"];
        for name in accepted.iter() {
            assert!(validate_identifier(name).is_ok(), "{:?} should be accepted", name);
        }

        // Empty, leading digit, illegal characters, and a keyword in both cases.
        let rejected = ["", "1table", "table-name", "table name", "SELECT", "select"];
        for name in rejected.iter() {
            assert!(validate_identifier(name).is_err(), "{:?} should be rejected", name);
        }
    }

    #[test]
    fn test_validate_query_pattern() {
        let safe = ["SELECT * FROM users", "INSERT INTO users VALUES ($1, $2)"];
        for sql in safe.iter() {
            assert!(validate_query_pattern(sql).is_ok(), "{:?} should be accepted", sql);
        }

        // Stacked statements, UNION-based injection, and dynamic execution.
        let dangerous = [
            "SELECT * FROM users; DROP TABLE users",
            "SELECT * FROM users UNION SELECT * FROM secrets",
            "EXEC sp_executesql 'SELECT * FROM users'",
        ];
        for sql in dangerous.iter() {
            assert!(validate_query_pattern(sql).is_err(), "{:?} should be rejected", sql);
        }
    }

    #[test]
    fn test_validate_parameter() {
        // Only the size is checked, so SQL-looking text and NUL bytes are
        // accepted; parameters are not inspected for content.
        let values = [
            "normal value",
            "123",
            "user@example.com",
            "'; DROP TABLE users; --",
            "UNION SELECT",
            "value with \0 null byte",
        ];
        for value in values.iter() {
            assert!(validate_parameter(value).is_ok(), "{:?} should be accepted", value);
        }
    }

    #[test]
    fn test_identifier_whitelist() {
        let whitelist = IdentifierWhitelist::new(vec!["users", "posts", "comments"]);

        for name in ["users", "posts"].iter() {
            assert!(whitelist.validate(name).is_ok(), "{:?} should be allowed", name);
        }
        assert!(whitelist.validate("admin_table").is_err());

        let escaped = whitelist.escape_if_allowed("users").unwrap();
        assert_eq!(escaped, "\"users\"");
        assert!(whitelist.escape_if_allowed("hacker_table").is_err());
    }
}