1use serde::{Deserialize, Serialize};
16use thiserror::Error;
17use uuid::Uuid;
18
/// Maximum table-name length, in UTF-8 bytes, accepted by [`validate_table_name`].
pub const MAX_TABLE_NAME_LENGTH: usize = 255;

/// Maximum column-name length, in UTF-8 bytes, accepted by [`validate_column_name`].
pub const MAX_COLUMN_NAME_LENGTH: usize = 255;

/// Maximum generic identifier length, in UTF-8 bytes (used by [`validate_data_type`]).
pub const MAX_IDENTIFIER_LENGTH: usize = 255;

/// Maximum description length, in UTF-8 bytes, accepted by [`validate_description`].
pub const MAX_DESCRIPTION_LENGTH: usize = 10_000;
30
31#[derive(Debug, Clone, Error, Serialize, Deserialize)]
33pub enum ValidationError {
34 #[error("{0} cannot be empty")]
36 Empty(&'static str),
37
38 #[error("{field} exceeds maximum length (max: {max}, got: {actual})")]
40 TooLong {
41 field: &'static str,
42 max: usize,
43 actual: usize,
44 },
45
46 #[error("{field} contains invalid characters: {reason}")]
48 InvalidCharacters { field: &'static str, reason: String },
49
50 #[error("{0}: {1}")]
52 InvalidFormat(&'static str, String),
53
54 #[error("{field} cannot be a reserved word: {word}")]
56 ReservedWord { field: &'static str, word: String },
57}
58
59pub type ValidationResult<T> = Result<T, ValidationError>;
61
62pub fn validate_table_name(name: &str) -> ValidationResult<()> {
83 if name.is_empty() {
84 return Err(ValidationError::Empty("table name"));
85 }
86
87 if name.len() > MAX_TABLE_NAME_LENGTH {
88 return Err(ValidationError::TooLong {
89 field: "table name",
90 max: MAX_TABLE_NAME_LENGTH,
91 actual: name.len(),
92 });
93 }
94
95 let first_char = name.chars().next().unwrap();
97 if !first_char.is_alphabetic() && first_char != '_' {
98 return Err(ValidationError::InvalidFormat(
99 "table name",
100 "must start with a letter or underscore".to_string(),
101 ));
102 }
103
104 for c in name.chars() {
106 if !c.is_alphanumeric() && c != '_' && c != '-' {
107 return Err(ValidationError::InvalidCharacters {
108 field: "table name",
109 reason: format!("invalid character: '{}'", c),
110 });
111 }
112 }
113
114 if is_sql_reserved_word(name) {
116 return Err(ValidationError::ReservedWord {
117 field: "table name",
118 word: name.to_string(),
119 });
120 }
121
122 Ok(())
123}
124
125pub fn validate_column_name(name: &str) -> ValidationResult<()> {
146 if name.is_empty() {
147 return Err(ValidationError::Empty("column name"));
148 }
149
150 if name.len() > MAX_COLUMN_NAME_LENGTH {
151 return Err(ValidationError::TooLong {
152 field: "column name",
153 max: MAX_COLUMN_NAME_LENGTH,
154 actual: name.len(),
155 });
156 }
157
158 let first_char = name.chars().next().unwrap();
160 if !first_char.is_alphabetic() && first_char != '_' {
161 return Err(ValidationError::InvalidFormat(
162 "column name",
163 "must start with a letter or underscore".to_string(),
164 ));
165 }
166
167 for c in name.chars() {
169 if !c.is_alphanumeric() && c != '_' && c != '-' && c != '.' {
170 return Err(ValidationError::InvalidCharacters {
171 field: "column name",
172 reason: format!("invalid character: '{}'", c),
173 });
174 }
175 }
176
177 if !name.contains('.') && is_sql_reserved_word(name) {
179 return Err(ValidationError::ReservedWord {
180 field: "column name",
181 word: name.to_string(),
182 });
183 }
184
185 Ok(())
186}
187
188pub fn validate_uuid(id: &str) -> ValidationResult<Uuid> {
199 Uuid::parse_str(id)
200 .map_err(|e| ValidationError::InvalidFormat("UUID", format!("invalid UUID format: {}", e)))
201}
202
203pub fn validate_data_type(data_type: &str) -> ValidationResult<()> {
222 if data_type.is_empty() {
223 return Err(ValidationError::Empty("data type"));
224 }
225
226 if data_type.len() > MAX_IDENTIFIER_LENGTH {
227 return Err(ValidationError::TooLong {
228 field: "data type",
229 max: MAX_IDENTIFIER_LENGTH,
230 actual: data_type.len(),
231 });
232 }
233
234 let lower = data_type.to_lowercase();
236 if lower.contains(';') || lower.contains("--") || lower.contains("/*") {
237 return Err(ValidationError::InvalidCharacters {
238 field: "data type",
239 reason: "contains SQL comment or statement separator".to_string(),
240 });
241 }
242
243 for c in data_type.chars() {
245 if !c.is_alphanumeric()
246 && c != '('
247 && c != ')'
248 && c != ','
249 && c != ' '
250 && c != '_'
251 && c != '<'
252 && c != '>'
253 && c != '['
254 && c != ']'
255 {
256 return Err(ValidationError::InvalidCharacters {
257 field: "data type",
258 reason: format!("invalid character: '{}'", c),
259 });
260 }
261 }
262
263 Ok(())
264}
265
266pub fn validate_description(desc: &str) -> ValidationResult<()> {
274 if desc.len() > MAX_DESCRIPTION_LENGTH {
275 return Err(ValidationError::TooLong {
276 field: "description",
277 max: MAX_DESCRIPTION_LENGTH,
278 actual: desc.len(),
279 });
280 }
281
282 Ok(())
283}
284
/// Wraps `name` in the identifier delimiters of `dialect`.
///
/// The dialect is matched case-insensitively:
/// - `"mysql"` / `"mariadb"`: backticks;
/// - `"sqlserver"` / `"mssql"`: square brackets;
/// - anything else: ANSI double quotes.
///
/// Every occurrence of the closing delimiter inside `name` is doubled
/// (`"` → `""`, `` ` `` → ``` `` ```, `]` → `]]`), so the contents of `name`
/// cannot terminate the quoted identifier early.
pub fn sanitize_sql_identifier(name: &str, dialect: &str) -> String {
    let (open, close) = match dialect.to_lowercase().as_str() {
        "mysql" | "mariadb" => ('`', '`'),
        "sqlserver" | "mssql" => ('[', ']'),
        _ => ('"', '"'),
    };

    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push(open);
    for ch in name.chars() {
        quoted.push(ch);
        if ch == close {
            // Doubling is the escape for an embedded closing delimiter.
            quoted.push(close);
        }
    }
    quoted.push(close);
    quoted
}
316
/// Returns a copy of `desc` with control characters removed.
///
/// Line feed, carriage return and tab are kept. "Control" means
/// `char::is_control` (Unicode general category `Cc`). Format characters such
/// as zero-width or bidi-override code points are not `Cc` and are kept.
pub fn sanitize_description(desc: &str) -> String {
    let mut cleaned = String::with_capacity(desc.len());
    for ch in desc.chars() {
        let keep = matches!(ch, '\n' | '\t' | '\r') || !ch.is_control();
        if keep {
            cleaned.push(ch);
        }
    }
    cleaned
}
326
/// Returns `true` if `word` is one of a fixed set of SQL keywords and common
/// type names.
///
/// Matching uses `str::to_lowercase`, so it is case-insensitive (including
/// Unicode case folding, e.g. KELVIN SIGN lowercases to `k`). The list is a
/// conservative cross-dialect subset, not any single database's reserved-word
/// table.
fn is_sql_reserved_word(word: &str) -> bool {
    let folded = word.to_lowercase();
    matches!(
        folded.as_str(),
        // statements, DDL/DCL and transactions
        "select" | "from" | "where" | "insert" | "update" | "delete"
            | "create" | "drop" | "alter" | "table" | "index" | "view"
            | "database" | "schema" | "grant" | "revoke"
            | "commit" | "rollback" | "begin" | "end" | "transaction"
            // constraints
            | "primary" | "foreign" | "key" | "references" | "constraint"
            | "unique" | "check" | "default"
            // operators and predicates
            | "not" | "null" | "and" | "or" | "in" | "between" | "like" | "is"
            | "case" | "when" | "then" | "else" | "as"
            // joins
            | "on" | "join" | "inner" | "outer" | "left" | "right" | "full"
            | "cross" | "natural" | "using"
            // query clauses and set operations
            | "group" | "by" | "having" | "order" | "asc" | "desc" | "limit"
            | "offset" | "union" | "intersect" | "except" | "all" | "distinct"
            | "top" | "values" | "set" | "into"
            // routines and boolean literals
            | "exec" | "execute" | "procedure" | "function" | "trigger"
            | "true" | "false"
            // common type names
            | "int" | "integer" | "varchar" | "char" | "text" | "boolean" | "date"
            | "time" | "timestamp" | "float" | "double" | "decimal" | "numeric"
    )
}
426
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_table_name_valid() {
        assert!(validate_table_name("users").is_ok());
        assert!(validate_table_name("user_orders").is_ok());
        assert!(validate_table_name("User123").is_ok());
        assert!(validate_table_name("_private").is_ok());
        assert!(validate_table_name("my-table").is_ok());
    }

    #[test]
    fn test_validate_table_name_empty() {
        assert!(matches!(
            validate_table_name(""),
            Err(ValidationError::Empty(_))
        ));
    }

    #[test]
    fn test_validate_table_name_too_long() {
        let long_name = "a".repeat(300);
        assert!(matches!(
            validate_table_name(&long_name),
            Err(ValidationError::TooLong { .. })
        ));
    }

    #[test]
    fn test_validate_table_name_starts_with_digit() {
        assert!(matches!(
            validate_table_name("123users"),
            Err(ValidationError::InvalidFormat(..))
        ));
    }

    #[test]
    fn test_validate_table_name_invalid_chars() {
        assert!(matches!(
            validate_table_name("user$table"),
            Err(ValidationError::InvalidCharacters { .. })
        ));
        assert!(matches!(
            validate_table_name("user;table"),
            Err(ValidationError::InvalidCharacters { .. })
        ));
    }

    #[test]
    fn test_validate_table_name_reserved_word() {
        assert!(matches!(
            validate_table_name("SELECT"),
            Err(ValidationError::ReservedWord { .. })
        ));
        assert!(matches!(
            validate_table_name("table"),
            Err(ValidationError::ReservedWord { .. })
        ));
    }

    #[test]
    fn test_validate_column_name_valid() {
        assert!(validate_column_name("id").is_ok());
        assert!(validate_column_name("user_name").is_ok());
        assert!(validate_column_name("address.street").is_ok());
        assert!(validate_column_name("nested.field.value").is_ok());
    }

    #[test]
    fn test_validate_column_name_invalid() {
        assert!(matches!(
            validate_column_name(""),
            Err(ValidationError::Empty(_))
        ));
        assert!(matches!(
            validate_column_name("1col"),
            Err(ValidationError::InvalidFormat(..))
        ));
        assert!(matches!(
            validate_column_name("col$"),
            Err(ValidationError::InvalidCharacters { .. })
        ));
        assert!(matches!(
            validate_column_name("Select"),
            Err(ValidationError::ReservedWord { .. })
        ));
        assert!(matches!(
            validate_column_name(&"c".repeat(MAX_COLUMN_NAME_LENGTH + 1)),
            Err(ValidationError::TooLong { .. })
        ));
    }

    #[test]
    fn test_validate_data_type_valid() {
        assert!(validate_data_type("INTEGER").is_ok());
        assert!(validate_data_type("VARCHAR(255)").is_ok());
        assert!(validate_data_type("ARRAY<STRING>").is_ok());
        assert!(validate_data_type("DECIMAL(10, 2)").is_ok());
    }

    #[test]
    fn test_validate_data_type_more_valid() {
        assert!(validate_data_type("INT[]").is_ok());
        assert!(validate_data_type("TIMESTAMP WITH TIME ZONE").is_ok());
        assert!(validate_data_type("MAP<STRING, ARRAY<INT>>").is_ok());
    }

    #[test]
    fn test_validate_data_type_injection() {
        assert!(matches!(
            validate_data_type("'; DROP TABLE users;--"),
            Err(ValidationError::InvalidCharacters { .. })
        ));
    }

    #[test]
    fn test_validate_data_type_invalid() {
        assert!(matches!(
            validate_data_type(""),
            Err(ValidationError::Empty(_))
        ));
        assert!(matches!(
            validate_data_type("INT /* x */"),
            Err(ValidationError::InvalidCharacters { .. })
        ));
        assert!(matches!(
            validate_data_type("VARCHAR(255)'"),
            Err(ValidationError::InvalidCharacters { .. })
        ));
        assert!(matches!(
            validate_data_type(&"A".repeat(MAX_IDENTIFIER_LENGTH + 1)),
            Err(ValidationError::TooLong { .. })
        ));
    }

    #[test]
    fn test_validate_description_length_boundary() {
        assert!(validate_description("").is_ok());
        assert!(validate_description(&"d".repeat(MAX_DESCRIPTION_LENGTH)).is_ok());
        assert!(matches!(
            validate_description(&"d".repeat(MAX_DESCRIPTION_LENGTH + 1)),
            Err(ValidationError::TooLong { max, actual, .. })
                if max == MAX_DESCRIPTION_LENGTH && actual == MAX_DESCRIPTION_LENGTH + 1
        ));
    }

    #[test]
    fn test_validate_uuid_valid() {
        assert!(validate_uuid("550e8400-e29b-41d4-a716-446655440000").is_ok());
    }

    #[test]
    fn test_validate_uuid_invalid() {
        assert!(validate_uuid("not-a-uuid").is_err());
        assert!(validate_uuid("").is_err());
    }

    #[test]
    fn test_validate_uuid_round_trips() {
        let id = "550e8400-e29b-41d4-a716-446655440000";
        assert_eq!(validate_uuid(id).unwrap().to_string(), id);
        assert!(matches!(
            validate_uuid("not-a-uuid"),
            Err(ValidationError::InvalidFormat("UUID", _))
        ));
    }

    #[test]
    fn test_sanitize_sql_identifier() {
        assert_eq!(sanitize_sql_identifier("users", "postgres"), "\"users\"");
        assert_eq!(
            sanitize_sql_identifier("user-table", "mysql"),
            "`user-table`"
        );
        assert_eq!(sanitize_sql_identifier("test", "sqlserver"), "[test]");
    }

    #[test]
    fn test_sanitize_sql_identifier_escapes_quotes() {
        assert_eq!(
            sanitize_sql_identifier("my\"table", "postgres"),
            "\"my\"\"table\""
        );
        assert_eq!(sanitize_sql_identifier("my`table", "mysql"), "`my``table`");
    }

    #[test]
    fn test_sanitize_sql_identifier_sqlserver_escapes_closing_bracket() {
        assert_eq!(sanitize_sql_identifier("my]table", "mssql"), "[my]]table]");
        // Only the closing delimiter needs escaping inside brackets.
        assert_eq!(sanitize_sql_identifier("my[table", "sqlserver"), "[my[table]");
    }

    #[test]
    fn test_sanitize_sql_identifier_dialect_is_case_insensitive() {
        assert_eq!(sanitize_sql_identifier("t", "MySQL"), "`t`");
        assert_eq!(sanitize_sql_identifier("t", "MSSQL"), "[t]");
        assert_eq!(sanitize_sql_identifier("t", "unknown"), "\"t\"");
    }

    #[test]
    fn test_sanitize_description() {
        assert_eq!(sanitize_description("Hello\nWorld"), "Hello\nWorld");
        assert_eq!(sanitize_description("Tab\tSeparated"), "Tab\tSeparated");
        let with_control = "Hello\x00World";
        assert_eq!(sanitize_description(with_control), "HelloWorld");
    }

    #[test]
    fn test_sanitize_description_strips_other_control_chars() {
        assert_eq!(sanitize_description("a\u{7}b\u{7f}c\r\n"), "abc\r\n");
    }

    #[test]
    fn test_is_sql_reserved_word_case_insensitive() {
        assert!(is_sql_reserved_word("select"));
        assert!(is_sql_reserved_word("SeLeCt"));
        assert!(!is_sql_reserved_word("users"));
        assert!(!is_sql_reserved_word(""));
    }

    #[test]
    fn test_validation_error_display() {
        assert_eq!(
            ValidationError::Empty("table name").to_string(),
            "table name cannot be empty"
        );
        assert_eq!(
            ValidationError::TooLong {
                field: "description",
                max: 10,
                actual: 11,
            }
            .to_string(),
            "description exceeds maximum length (max: 10, got: 11)"
        );
    }
}