mcp_postgres/
validation.rs1use crate::errors::MCPError;
2
3const MAX_IDENTIFIER_LEN: usize = 255;
4
5pub fn validate_identifier(name: &str, label: &str) -> Result<(), MCPError> {
6 if name.is_empty() {
7 return Err(MCPError::InvalidParams(format!(
8 "'{label}' must not be empty"
9 )));
10 }
11 if name.len() > MAX_IDENTIFIER_LEN {
12 return Err(MCPError::InvalidParams(format!(
13 "'{label}' exceeds maximum length of {MAX_IDENTIFIER_LEN} characters (got {})",
14 name.len()
15 )));
16 }
17 for ch in name.chars() {
18 if !ch.is_alphanumeric() && ch != '_' {
19 return Err(MCPError::InvalidParams(format!(
20 "'{label}' contains invalid character '{ch}' — only alphanumeric and underscore allowed"
21 )));
22 }
23 }
24 if name.starts_with(|c: char| c.is_ascii_digit()) {
25 return Err(MCPError::InvalidParams(format!(
26 "'{label}' must not start with a digit"
27 )));
28 }
29 Ok(())
30}
31
32pub fn quote_identifier(name: &str) -> String {
33 quote_ident(name)
34}
35
36const VALID_PRIVILEGES: &[&str] = &[
38 "SELECT",
39 "INSERT",
40 "UPDATE",
41 "DELETE",
42 "TRUNCATE",
43 "REFERENCES",
44 "TRIGGER",
45 "CREATE",
46 "CONNECT",
47 "TEMPORARY",
48 "TEMP",
49 "EXECUTE",
50 "USAGE",
51 "MAINTAIN",
52 "ALL",
53 "ALL PRIVILEGES",
54];
55
56pub fn validate_privilege_list(privilege: &str) -> Result<(), MCPError> {
61 let trimmed = privilege.trim();
62 if trimmed.is_empty() {
63 return Err(MCPError::InvalidParams(
64 "'privilege' must not be empty".into(),
65 ));
66 }
67 for part in trimmed.split(',') {
68 let token = part.trim().to_ascii_uppercase();
69 if !VALID_PRIVILEGES.contains(&token.as_str()) {
70 return Err(MCPError::InvalidParams(format!(
71 "Invalid privilege '{}'. Allowed: {}",
72 part.trim(),
73 VALID_PRIVILEGES.join(", ")
74 )));
75 }
76 }
77 Ok(())
78}
79
/// Quotes `name` as a double-quoted (delimited) SQL identifier.
///
/// The result is `name` wrapped in `"` characters, with every embedded `"`
/// doubled to `""` so that it cannot terminate the quoted identifier early.
/// No other character is altered and the input is not otherwise validated;
/// pair with [`validate_identifier`] when a conservative name is required.
pub fn quote_ident(name: &str) -> String {
    let escaped = name.replace('"', "\"\"");
    format!("\"{escaped}\"")
}
95
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_identifier() {
        assert!(validate_identifier("users", "table").is_ok());
        assert!(validate_identifier("user_orders_2024", "table").is_ok());
        // A leading underscore is allowed; only a leading digit is rejected.
        assert!(validate_identifier("_staging", "table").is_ok());
    }

    #[test]
    fn test_empty_identifier() {
        let err = validate_identifier("", "table").unwrap_err();
        assert!(err.to_string().contains("must not be empty"));
    }

    #[test]
    fn test_too_long_identifier() {
        // Derive lengths from the constant so this test follows the limit.
        let at_limit = "a".repeat(MAX_IDENTIFIER_LEN);
        assert!(validate_identifier(&at_limit, "table").is_ok());

        let too_long = "a".repeat(MAX_IDENTIFIER_LEN + 1);
        let err = validate_identifier(&too_long, "table").unwrap_err();
        assert!(err.to_string().contains("exceeds maximum length"));
    }

    #[test]
    fn test_invalid_char_identifier() {
        let err = validate_identifier("users; DROP TABLE", "table").unwrap_err();
        assert!(err.to_string().contains("invalid character"));
        // Quotes and dots must never reach a quoted identifier unchecked.
        assert!(validate_identifier("a\"b", "table").is_err());
        assert!(validate_identifier("schema.table", "table").is_err());
    }

    #[test]
    fn test_digit_start_identifier() {
        let err = validate_identifier("1users", "table").unwrap_err();
        assert!(err.to_string().contains("must not start with a digit"));
    }

    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("users"), "\"users\"");
        assert_eq!(quote_identifier("order_items"), "\"order_items\"");
    }

    #[test]
    fn test_quote_identifier_escapes_embedded_quotes() {
        // Doubling embedded quotes is what keeps a quoted identifier injection-safe.
        assert_eq!(quote_identifier("we\"ird"), "\"we\"\"ird\"");
        assert_eq!(quote_identifier("\""), "\"\"\"\"");
        assert_eq!(quote_identifier(""), "\"\"");
        assert_eq!(quote_identifier("x\"y"), quote_ident("x\"y"));
    }

    #[test]
    fn test_validate_privilege_list_valid() {
        assert!(validate_privilege_list("SELECT").is_ok());
        assert!(validate_privilege_list("select, insert").is_ok());
        assert!(validate_privilege_list("ALL PRIVILEGES").is_ok());
        assert!(validate_privilege_list("SELECT, UPDATE, DELETE").is_ok());
    }

    #[test]
    fn test_validate_privilege_list_rejects_injection() {
        let err = validate_privilege_list("SELECT ON pg_authid TO attacker; --").unwrap_err();
        assert!(err.to_string().contains("Invalid privilege"));
        assert!(validate_privilege_list("").is_err());
        assert!(validate_privilege_list("   ").is_err());
        assert!(validate_privilege_list("DROP").is_err());
        // A stray comma yields an empty entry, which is not a privilege.
        assert!(validate_privilege_list("SELECT,").is_err());
    }
}