1use thiserror::Error;
16use uuid::Uuid;
17
/// Upper bound for a table name, compared against `str::len()` (a byte count).
pub const MAX_TABLE_NAME_LENGTH: usize = 255;

/// Upper bound for a column name, compared against `str::len()` (a byte count).
pub const MAX_COLUMN_NAME_LENGTH: usize = 255;

/// Upper bound for a generic identifier; `validate_data_type` uses it as the
/// limit for data type strings.
pub const MAX_IDENTIFIER_LENGTH: usize = 255;

/// Upper bound for a description, compared against `str::len()` (a byte count).
pub const MAX_DESCRIPTION_LENGTH: usize = 10000;
29
/// Errors produced by the validation helpers in this file.
///
/// Every variant carries a static, human-readable label (for example
/// `"table name"`) identifying which input was rejected; the `#[error]`
/// attributes define the rendered message.
#[derive(Debug, Clone, Error)]
pub enum ValidationError {
    /// The input was an empty string. The payload is the field label.
    #[error("{0} cannot be empty")]
    Empty(&'static str),

    /// The input was longer than permitted.
    #[error("{field} exceeds maximum length (max: {max}, got: {actual})")]
    TooLong {
        /// Label of the offending field.
        field: &'static str,
        /// Largest accepted length.
        max: usize,
        /// Length of the rejected input.
        actual: usize,
    },

    /// The input contained a character (or character sequence) that is not allowed.
    #[error("{field} contains invalid characters: {reason}")]
    InvalidCharacters { field: &'static str, reason: String },

    /// The input was structurally malformed. Payload: field label, then details.
    #[error("{0}: {1}")]
    InvalidFormat(&'static str, String),

    /// The input matched a SQL reserved word; `word` is the input as supplied.
    #[error("{field} cannot be a reserved word: {word}")]
    ReservedWord { field: &'static str, word: String },
}

/// Shorthand for a `Result` whose error type is [`ValidationError`].
pub type ValidationResult<T> = Result<T, ValidationError>;
60
61pub fn validate_table_name(name: &str) -> ValidationResult<()> {
82 if name.is_empty() {
83 return Err(ValidationError::Empty("table name"));
84 }
85
86 if name.len() > MAX_TABLE_NAME_LENGTH {
87 return Err(ValidationError::TooLong {
88 field: "table name",
89 max: MAX_TABLE_NAME_LENGTH,
90 actual: name.len(),
91 });
92 }
93
94 let first_char = name.chars().next().unwrap();
96 if !first_char.is_alphabetic() && first_char != '_' {
97 return Err(ValidationError::InvalidFormat(
98 "table name",
99 "must start with a letter or underscore".to_string(),
100 ));
101 }
102
103 for c in name.chars() {
105 if !c.is_alphanumeric() && c != '_' && c != '-' {
106 return Err(ValidationError::InvalidCharacters {
107 field: "table name",
108 reason: format!("invalid character: '{}'", c),
109 });
110 }
111 }
112
113 if is_sql_reserved_word(name) {
115 return Err(ValidationError::ReservedWord {
116 field: "table name",
117 word: name.to_string(),
118 });
119 }
120
121 Ok(())
122}
123
124pub fn validate_column_name(name: &str) -> ValidationResult<()> {
145 if name.is_empty() {
146 return Err(ValidationError::Empty("column name"));
147 }
148
149 if name.len() > MAX_COLUMN_NAME_LENGTH {
150 return Err(ValidationError::TooLong {
151 field: "column name",
152 max: MAX_COLUMN_NAME_LENGTH,
153 actual: name.len(),
154 });
155 }
156
157 let first_char = name.chars().next().unwrap();
159 if !first_char.is_alphabetic() && first_char != '_' {
160 return Err(ValidationError::InvalidFormat(
161 "column name",
162 "must start with a letter or underscore".to_string(),
163 ));
164 }
165
166 for c in name.chars() {
168 if !c.is_alphanumeric() && c != '_' && c != '-' && c != '.' {
169 return Err(ValidationError::InvalidCharacters {
170 field: "column name",
171 reason: format!("invalid character: '{}'", c),
172 });
173 }
174 }
175
176 if !name.contains('.') && is_sql_reserved_word(name) {
178 return Err(ValidationError::ReservedWord {
179 field: "column name",
180 word: name.to_string(),
181 });
182 }
183
184 Ok(())
185}
186
187pub fn validate_uuid(id: &str) -> ValidationResult<Uuid> {
198 Uuid::parse_str(id)
199 .map_err(|e| ValidationError::InvalidFormat("UUID", format!("invalid UUID format: {}", e)))
200}
201
202pub fn validate_data_type(data_type: &str) -> ValidationResult<()> {
221 if data_type.is_empty() {
222 return Err(ValidationError::Empty("data type"));
223 }
224
225 if data_type.len() > MAX_IDENTIFIER_LENGTH {
226 return Err(ValidationError::TooLong {
227 field: "data type",
228 max: MAX_IDENTIFIER_LENGTH,
229 actual: data_type.len(),
230 });
231 }
232
233 let lower = data_type.to_lowercase();
235 if lower.contains(';') || lower.contains("--") || lower.contains("/*") {
236 return Err(ValidationError::InvalidCharacters {
237 field: "data type",
238 reason: "contains SQL comment or statement separator".to_string(),
239 });
240 }
241
242 for c in data_type.chars() {
244 if !c.is_alphanumeric()
245 && c != '('
246 && c != ')'
247 && c != ','
248 && c != ' '
249 && c != '_'
250 && c != '<'
251 && c != '>'
252 && c != '['
253 && c != ']'
254 {
255 return Err(ValidationError::InvalidCharacters {
256 field: "data type",
257 reason: format!("invalid character: '{}'", c),
258 });
259 }
260 }
261
262 Ok(())
263}
264
265pub fn validate_description(desc: &str) -> ValidationResult<()> {
273 if desc.len() > MAX_DESCRIPTION_LENGTH {
274 return Err(ValidationError::TooLong {
275 field: "description",
276 max: MAX_DESCRIPTION_LENGTH,
277 actual: desc.len(),
278 });
279 }
280
281 Ok(())
282}
283
/// Wraps `name` in the identifier delimiters of the given SQL `dialect`,
/// escaping any embedded closing delimiter by doubling it.
///
/// Delimiters by dialect (matched case-insensitively):
///
/// * `mysql`, `mariadb` — backticks;
/// * `sqlserver`, `mssql` — `[` and `]` (only `]` is doubled);
/// * anything else — double quotes.
pub fn sanitize_sql_identifier(name: &str, dialect: &str) -> String {
    let dialect_key = dialect.to_lowercase();
    let (open, close) = match dialect_key.as_str() {
        "mysql" | "mariadb" => ('`', '`'),
        "sqlserver" | "mssql" => ('[', ']'),
        _ => ('"', '"'),
    };

    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push(open);
    for ch in name.chars() {
        // A closing delimiter inside the name is emitted twice so it cannot
        // terminate the quoted identifier early.
        if ch == close {
            quoted.push(close);
        }
        quoted.push(ch);
    }
    quoted.push(close);
    quoted
}
315
/// Returns a copy of `desc` with control characters removed, except that
/// newline (`\n`), tab (`\t`) and carriage return (`\r`) are kept.
pub fn sanitize_description(desc: &str) -> String {
    let mut cleaned = String::with_capacity(desc.len());
    for ch in desc.chars() {
        let keep = matches!(ch, '\n' | '\t' | '\r') || !ch.is_control();
        if keep {
            cleaned.push(ch);
        }
    }
    cleaned
}
325
/// Reports whether `word`, after lower-casing, is one of the SQL keywords,
/// literals or type names this file treats as reserved.
///
/// The comparison is against the whole word; substrings do not match.
fn is_sql_reserved_word(word: &str) -> bool {
    let folded = word.to_lowercase();
    matches!(
        folded.as_str(),
        // Statements and DDL objects.
        "select" | "from" | "where" | "insert" | "update" | "delete" | "create" | "drop"
            | "alter" | "table" | "index" | "view" | "database" | "schema"
            // Permissions and transactions.
            | "grant" | "revoke" | "commit" | "rollback" | "begin" | "end" | "transaction"
            // Constraints.
            | "primary" | "foreign" | "key" | "references" | "constraint" | "unique"
            | "check" | "default"
            // Operators and conditionals.
            | "not" | "null" | "and" | "or" | "in" | "between" | "like" | "is" | "case"
            | "when" | "then" | "else" | "as" | "on"
            // Joins.
            | "join" | "inner" | "outer" | "left" | "right" | "full" | "cross" | "natural"
            | "using"
            // Grouping, ordering and set operations.
            | "group" | "by" | "having" | "order" | "asc" | "desc" | "limit" | "offset"
            | "union" | "intersect" | "except" | "all" | "distinct" | "top"
            // Data modification and routines.
            | "values" | "set" | "into" | "exec" | "execute" | "procedure" | "function"
            | "trigger"
            // Literals and type names.
            | "true" | "false" | "int" | "integer" | "varchar" | "char" | "text"
            | "boolean" | "date" | "time" | "timestamp" | "float" | "double" | "decimal"
            | "numeric"
    )
}
425
// Unit tests for the validation and sanitisation helpers in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_table_name_valid() {
        let accepted = ["users", "user_orders", "User123", "_private", "my-table"];
        for name in accepted.iter() {
            assert!(validate_table_name(name).is_ok());
        }
    }

    #[test]
    fn test_validate_table_name_empty() {
        let outcome = validate_table_name("");
        assert!(matches!(outcome, Err(ValidationError::Empty(_))));
    }

    #[test]
    fn test_validate_table_name_too_long() {
        // 300 bytes is above the 255-byte limit.
        let long_name = "a".repeat(300);
        let outcome = validate_table_name(&long_name);
        assert!(matches!(outcome, Err(ValidationError::TooLong { .. })));
    }

    #[test]
    fn test_validate_table_name_starts_with_digit() {
        let outcome = validate_table_name("123users");
        assert!(matches!(outcome, Err(ValidationError::InvalidFormat(..))));
    }

    #[test]
    fn test_validate_table_name_invalid_chars() {
        let rejected = ["user$table", "user;table"];
        for name in rejected.iter() {
            let outcome = validate_table_name(name);
            assert!(matches!(
                outcome,
                Err(ValidationError::InvalidCharacters { .. })
            ));
        }
    }

    #[test]
    fn test_validate_table_name_reserved_word() {
        // The reserved-word check ignores case.
        let rejected = ["SELECT", "table"];
        for name in rejected.iter() {
            let outcome = validate_table_name(name);
            assert!(matches!(outcome, Err(ValidationError::ReservedWord { .. })));
        }
    }

    #[test]
    fn test_validate_column_name_valid() {
        let accepted = ["id", "user_name", "address.street", "nested.field.value"];
        for name in accepted.iter() {
            assert!(validate_column_name(name).is_ok());
        }
    }

    #[test]
    fn test_validate_data_type_valid() {
        let accepted = ["INTEGER", "VARCHAR(255)", "ARRAY<STRING>", "DECIMAL(10, 2)"];
        for data_type in accepted.iter() {
            assert!(validate_data_type(data_type).is_ok());
        }
    }

    #[test]
    fn test_validate_data_type_injection() {
        let outcome = validate_data_type("'; DROP TABLE users;--");
        assert!(matches!(
            outcome,
            Err(ValidationError::InvalidCharacters { .. })
        ));
    }

    #[test]
    fn test_validate_uuid_valid() {
        let outcome = validate_uuid("550e8400-e29b-41d4-a716-446655440000");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_uuid_invalid() {
        let rejected = ["not-a-uuid", ""];
        for id in rejected.iter() {
            assert!(validate_uuid(id).is_err());
        }
    }

    #[test]
    fn test_sanitize_sql_identifier() {
        // (name, dialect, expected quoted form)
        let cases = [
            ("users", "postgres", "\"users\""),
            ("user-table", "mysql", "`user-table`"),
            ("test", "sqlserver", "[test]"),
        ];
        for (name, dialect, expected) in cases.iter() {
            assert_eq!(sanitize_sql_identifier(name, dialect), *expected);
        }
    }

    #[test]
    fn test_sanitize_sql_identifier_escapes_quotes() {
        // (name, dialect, expected quoted form)
        let cases = [
            ("my\"table", "postgres", "\"my\"\"table\""),
            ("my`table", "mysql", "`my``table`"),
        ];
        for (name, dialect, expected) in cases.iter() {
            assert_eq!(sanitize_sql_identifier(name, dialect), *expected);
        }
    }

    #[test]
    fn test_sanitize_description() {
        // Newlines and tabs survive; the NUL byte is dropped.
        let with_control = "Hello\x00World";
        let cases = [
            ("Hello\nWorld", "Hello\nWorld"),
            ("Tab\tSeparated", "Tab\tSeparated"),
            (with_control, "HelloWorld"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(sanitize_description(input), *expected);
        }
    }
}