1use crate::error::ModelError;
9use std::collections::HashSet;
10
/// Characters that may appear in an identifier accepted by `validate_identifier`:
/// ASCII lowercase and uppercase letters, ASCII digits, underscore and dollar sign.
///
/// Restrictions on the *first* character are enforced separately by the validator.
const ALLOWED_IDENTIFIER_CHARS: &str = concat!(
    "abcdefghijklmnopqrstuvwxyz",
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
    "0123456789",
    "_$"
);
14
/// Upper-case words that `validate_identifier` refuses to accept as identifiers.
///
/// Identifiers are compared against these case-insensitively.
/// NOTE(review): this is a hand-picked subset, not a complete reserved-word list
/// for any particular SQL dialect.
static SQL_KEYWORDS: &[&str] = &[
    // Query and data-manipulation statements / clauses.
    "SELECT", "INSERT", "UPDATE", "DELETE", "FROM", "WHERE", "JOIN", "UNION",
    // Schema and privilege changes.
    "DROP", "CREATE", "ALTER", "GRANT", "REVOKE", "TRUNCATE",
    // Dynamic execution and variable declaration.
    "EXEC", "EXECUTE", "DECLARE",
    // Type-conversion and string functions.
    "CAST", "CONVERT", "SUBSTRING", "ASCII", "CHAR", "NCHAR",
    // Identity / session introspection.
    "SYSTEM", "USER", "SESSION_USER", "CURRENT_USER", "SUSER_NAME", "IS_MEMBER",
];
47
/// Wraps `identifier` in double quotes, producing a quoted SQL identifier.
///
/// Every embedded `"` is doubled so the result is always exactly one quoted
/// identifier token: `table"name` becomes `"table""name"`. An empty input yields `""`.
pub fn escape_identifier(identifier: &str) -> String {
    // Two extra bytes for the surrounding quotes; doubled quotes may grow it further.
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            // Doubling is the standard way to embed a quote inside a quoted identifier.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
76
77pub fn validate_identifier(identifier: &str) -> Result<(), ModelError> {
85 if identifier.is_empty() {
87 return Err(ModelError::Validation(
88 "Identifier cannot be empty".to_string(),
89 ));
90 }
91
92 if identifier.len() > 63 {
94 return Err(ModelError::Validation(format!(
95 "Identifier '{}' is too long (max 63 characters)",
96 identifier
97 )));
98 }
99
100 for c in identifier.chars() {
102 if !ALLOWED_IDENTIFIER_CHARS.contains(c) {
103 return Err(ModelError::Validation(format!(
104 "Identifier '{}' contains invalid character '{}'",
105 identifier, c
106 )));
107 }
108 }
109
110 if identifier.chars().next().unwrap().is_ascii_digit() {
112 return Err(ModelError::Validation(format!(
113 "Identifier '{}' cannot start with a number",
114 identifier
115 )));
116 }
117
118 let upper_identifier = identifier.to_uppercase();
120 if SQL_KEYWORDS.contains(&upper_identifier.as_str()) {
121 return Err(ModelError::Validation(format!(
122 "Identifier '{}' is a reserved SQL keyword",
123 identifier
124 )));
125 }
126
127 Ok(())
128}
129
130pub fn validate_query_pattern(sql: &str) -> Result<(), ModelError> {
138 let sql_upper = sql.to_uppercase();
139
140 let semicolon_positions: Vec<_> = sql.match_indices(';').collect();
142 if semicolon_positions.len() > 1
143 || (semicolon_positions.len() == 1 && semicolon_positions[0].0 != sql.trim().len() - 1)
144 {
145 return Err(ModelError::Validation(
146 "Multiple SQL statements not allowed".to_string(),
147 ));
148 }
149
150 let dangerous_patterns = [
152 "EXEC ",
153 "EXECUTE ",
154 "SP_",
155 "XP_",
156 "OPENROWSET",
157 "OPENDATASOURCE",
158 "BULK INSERT",
159 "BCP ",
160 "SQLCMD",
161 "OSQL",
162 "ISQL",
163 "UNION ALL SELECT",
164 "UNION SELECT",
165 "'; --",
166 "'/*",
167 "*/'",
168 "INFORMATION_SCHEMA",
169 "SYS.",
170 "SYSOBJECTS",
171 "SYSCOLUMNS",
172 ];
173
174 for pattern in &dangerous_patterns {
175 if sql_upper.contains(pattern) {
176 return Err(ModelError::Validation(format!(
177 "Query contains potentially dangerous pattern: {}",
178 pattern
179 )));
180 }
181 }
182
183 Ok(())
184}
185
186pub fn validate_parameter(value: &str) -> Result<(), ModelError> {
197 if value.len() > 65536 {
199 return Err(ModelError::Validation(
201 "Parameter value too large (max 64KB)".to_string(),
202 ));
203 }
204
205 Ok(())
209}
210
211pub struct IdentifierWhitelist {
216 allowed: HashSet<String>,
217}
218
219impl IdentifierWhitelist {
220 pub fn new(allowed_identifiers: Vec<&str>) -> Self {
222 let allowed = allowed_identifiers
223 .into_iter()
224 .map(|s| s.to_string())
225 .collect();
226 Self { allowed }
227 }
228
229 pub fn validate(&self, identifier: &str) -> Result<(), ModelError> {
231 if self.allowed.contains(identifier) {
232 Ok(())
233 } else {
234 Err(ModelError::Validation(format!(
235 "Identifier '{}' is not in the allowed whitelist",
236 identifier
237 )))
238 }
239 }
240
241 pub fn escape_if_allowed(&self, identifier: &str) -> Result<String, ModelError> {
243 self.validate(identifier)?;
244 Ok(escape_identifier(identifier))
245 }
246}
247
#[cfg(test)]
mod tests {
    // Unit tests for identifier escaping/validation, query screening,
    // parameter size limits and the identifier whitelist.
    use super::*;

    #[test]
    fn test_escape_identifier() {
        let cases = [
            ("user_table", "\"user_table\""),
            ("table\"name", "\"table\"\"name\""),
            ("simple", "\"simple\""),
        ];
        for &(raw, quoted) in &cases {
            assert_eq!(escape_identifier(raw), quoted);
        }
    }

    #[test]
    fn test_validate_identifier() {
        let accepted = ["user_table", "table1", "_private"];
        let rejected = [
            "",           // empty
            "1table",     // leading digit
            "table-name", // character outside the allowed set
            "table name", // whitespace
            "SELECT",     // reserved keyword
            "select",     // keywords match case-insensitively
        ];

        for &name in &accepted {
            assert!(
                validate_identifier(name).is_ok(),
                "expected {:?} to be accepted",
                name
            );
        }
        for &name in &rejected {
            assert!(
                validate_identifier(name).is_err(),
                "expected {:?} to be rejected",
                name
            );
        }
    }

    #[test]
    fn test_validate_query_pattern() {
        let accepted = [
            "SELECT * FROM users",
            "INSERT INTO users VALUES ($1, $2)",
        ];
        let rejected = [
            "SELECT * FROM users; DROP TABLE users",
            "SELECT * FROM users UNION SELECT * FROM secrets",
            "EXEC sp_executesql 'SELECT * FROM users'",
        ];

        for &sql in &accepted {
            assert!(
                validate_query_pattern(sql).is_ok(),
                "expected {:?} to be accepted",
                sql
            );
        }
        for &sql in &rejected {
            assert!(
                validate_query_pattern(sql).is_err(),
                "expected {:?} to be rejected",
                sql
            );
        }
    }

    #[test]
    fn test_validate_parameter() {
        // Only size is checked, so injection-looking text and NUL bytes are accepted.
        let values = [
            "normal value",
            "123",
            "user@example.com",
            "'; DROP TABLE users; --",
            "UNION SELECT",
            "value with \0 null byte",
        ];
        for &value in &values {
            assert!(
                validate_parameter(value).is_ok(),
                "expected {:?} to be accepted",
                value
            );
        }
    }

    #[test]
    fn test_identifier_whitelist() {
        let whitelist = IdentifierWhitelist::new(vec!["users", "posts", "comments"]);

        for &name in &["users", "posts"] {
            assert!(whitelist.validate(name).is_ok());
        }
        assert!(whitelist.validate("admin_table").is_err());

        assert_eq!(whitelist.escape_if_allowed("users").unwrap(), "\"users\"");
        assert!(whitelist.escape_if_allowed("hacker_table").is_err());
    }
}