mcp_postgres/
validation.rs1use crate::errors::MCPError;
2
3const MAX_IDENTIFIER_LEN: usize = 255;
4
5pub fn validate_identifier(name: &str, label: &str) -> Result<(), MCPError> {
6 if name.is_empty() {
7 return Err(MCPError::InvalidParams(format!(
8 "'{label}' must not be empty"
9 )));
10 }
11 if name.len() > MAX_IDENTIFIER_LEN {
12 return Err(MCPError::InvalidParams(format!(
13 "'{label}' exceeds maximum length of {MAX_IDENTIFIER_LEN} characters (got {})",
14 name.len()
15 )));
16 }
17 for ch in name.chars() {
18 if !ch.is_alphanumeric() && ch != '_' {
19 return Err(MCPError::InvalidParams(format!(
20 "'{label}' contains invalid character '{ch}' — only alphanumeric and underscore allowed"
21 )));
22 }
23 }
24 if name.starts_with(|c: char| c.is_ascii_digit()) {
25 return Err(MCPError::InvalidParams(format!(
26 "'{label}' must not start with a digit"
27 )));
28 }
29 Ok(())
30}
31
32pub fn quote_identifier(name: &str) -> String {
33 quote_ident(name)
34}
35
/// Quotes `name` as a PostgreSQL delimited identifier.
///
/// Every `"` in `name` is replaced with `""`, and the result is wrapped in a
/// pair of double quotes. All other characters are copied unchanged. `name`
/// is not validated here; callers that need validation should call
/// `validate_identifier` first.
pub fn quote_ident(name: &str) -> String {
    let escaped = name.replace('"', "\"\"");
    format!("\"{escaped}\"")
}
51
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_identifier() {
        assert!(validate_identifier("users", "table").is_ok());
        assert!(validate_identifier("user_orders_2024", "table").is_ok());
        // A leading underscore is allowed; only a leading digit is rejected.
        assert!(validate_identifier("_internal", "table").is_ok());
    }

    #[test]
    fn test_empty_identifier() {
        let err = validate_identifier("", "table").unwrap_err();
        assert!(err.to_string().contains("must not be empty"));
    }

    #[test]
    fn test_too_long_identifier() {
        let long = "a".repeat(256);
        let err = validate_identifier(&long, "table").unwrap_err();
        assert!(err.to_string().contains("exceeds maximum length"));
    }

    #[test]
    fn test_length_boundary() {
        // Exactly at the limit is accepted; one byte over is rejected.
        assert!(validate_identifier(&"a".repeat(MAX_IDENTIFIER_LEN), "table").is_ok());
        let err =
            validate_identifier(&"a".repeat(MAX_IDENTIFIER_LEN + 1), "table").unwrap_err();
        assert!(err.to_string().contains("exceeds maximum length"));
    }

    #[test]
    fn test_invalid_char_identifier() {
        let err = validate_identifier("users; DROP TABLE", "table").unwrap_err();
        assert!(err.to_string().contains("invalid character"));
        // The first offending character is the one reported.
        assert!(err.to_string().contains("';'"));
    }

    #[test]
    fn test_rejects_embedded_quote() {
        assert!(validate_identifier("us\"ers", "table").is_err());
    }

    #[test]
    fn test_digit_start_identifier() {
        let err = validate_identifier("1users", "table").unwrap_err();
        assert!(err.to_string().contains("must not start with a digit"));
    }

    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("users"), "\"users\"");
        assert_eq!(quote_identifier("order_items"), "\"order_items\"");
    }

    #[test]
    fn test_quote_ident_escapes_embedded_quotes() {
        // Doubling `"` is what keeps a quoted identifier from terminating early.
        assert_eq!(quote_ident("a\"b"), "\"a\"\"b\"");
        assert_eq!(quote_ident("\""), "\"\"\"\"");
        assert_eq!(quote_ident(""), "\"\"");
    }
}