mcp_postgres/
validation.rs1use crate::errors::MCPError;
2
3const MAX_IDENTIFIER_LEN: usize = 255;
4
5pub fn validate_identifier(name: &str, label: &str) -> Result<(), MCPError> {
6 if name.is_empty() {
7 return Err(MCPError::InvalidParams(format!("'{label}' must not be empty")));
8 }
9 if name.len() > MAX_IDENTIFIER_LEN {
10 return Err(MCPError::InvalidParams(
11 format!("'{label}' exceeds maximum length of {MAX_IDENTIFIER_LEN} characters (got {})", name.len())
12 ));
13 }
14 for ch in name.chars() {
15 if !ch.is_alphanumeric() && ch != '_' {
16 return Err(MCPError::InvalidParams(
17 format!("'{label}' contains invalid character '{ch}' — only alphanumeric and underscore allowed")
18 ));
19 }
20 }
21 if name.starts_with(|c: char| c.is_ascii_digit()) {
22 return Err(MCPError::InvalidParams(
23 format!("'{label}' must not start with a digit")
24 ));
25 }
26 Ok(())
27}
28
29pub fn quote_identifier(name: &str) -> String {
30 quote_ident(name)
31}
32
/// Wraps `name` in double quotes for use as a PostgreSQL delimited identifier.
///
/// Every `"` inside `name` is doubled (`"` becomes `""`), which is how a double
/// quote is written inside a quoted identifier. All other characters are copied
/// through unchanged, and no validation is performed.
pub fn quote_ident(name: &str) -> String {
    let escaped = name.replace('"', "\"\"");
    format!("\"{escaped}\"")
}
48
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_identifier() {
        assert!(validate_identifier("users", "table").is_ok());
        assert!(validate_identifier("user_orders_2024", "table").is_ok());
        // A leading underscore is allowed; only a leading digit is rejected.
        assert!(validate_identifier("_staging", "table").is_ok());
    }

    #[test]
    fn test_empty_identifier() {
        let err = validate_identifier("", "table").unwrap_err();
        assert!(err.to_string().contains("must not be empty"));
    }

    #[test]
    fn test_error_mentions_label() {
        let err = validate_identifier("", "schema").unwrap_err();
        assert!(err.to_string().contains("'schema'"));
    }

    #[test]
    fn test_too_long_identifier() {
        let long = "a".repeat(256);
        let err = validate_identifier(&long, "table").unwrap_err();
        assert!(err.to_string().contains("exceeds maximum length"));
    }

    #[test]
    fn test_invalid_char_identifier() {
        let err = validate_identifier("users; DROP TABLE", "table").unwrap_err();
        assert!(err.to_string().contains("invalid character"));
    }

    #[test]
    fn test_rejects_separators_and_quotes() {
        // Schema-qualified names, pre-quoted names, hyphens and whitespace must
        // all be rejected so callers cannot smuggle structure into an identifier.
        for name in ["public.users", "\"users\"", "user-orders", "users "] {
            assert!(
                validate_identifier(name, "table").is_err(),
                "unexpectedly accepted {name:?}"
            );
        }
    }

    #[test]
    fn test_digit_start_identifier() {
        let err = validate_identifier("1users", "table").unwrap_err();
        assert!(err.to_string().contains("must not start with a digit"));
    }

    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("users"), "\"users\"");
        assert_eq!(quote_identifier("order_items"), "\"order_items\"");
    }

    #[test]
    fn test_quote_ident_doubles_embedded_quotes() {
        assert_eq!(quote_ident("we\"ird"), "\"we\"\"ird\"");
        assert_eq!(quote_ident("\""), "\"\"\"\"");
        assert_eq!(quote_ident(""), "\"\"");
    }

    #[test]
    fn test_quote_identifier_matches_quote_ident() {
        for name in ["users", "", "Mixed\"Case"] {
            assert_eq!(quote_identifier(name), quote_ident(name));
        }
    }
}