Skip to main content

mcp_postgres/
validation.rs

1use crate::errors::MCPError;
2
// NOTE(review): stock PostgreSQL builds truncate identifiers to 63 bytes
// (NAMEDATALEN - 1); confirm that 255 is the intended cap here.
/// Largest identifier length, in UTF-8 bytes, that `validate_identifier`
/// accepts (the check compares against `str::len`).
const MAX_IDENTIFIER_LEN: usize = 255;
4
5pub fn validate_identifier(name: &str, label: &str) -> Result<(), MCPError> {
6    if name.is_empty() {
7        return Err(MCPError::InvalidParams(format!(
8            "'{label}' must not be empty"
9        )));
10    }
11    if name.len() > MAX_IDENTIFIER_LEN {
12        return Err(MCPError::InvalidParams(format!(
13            "'{label}' exceeds maximum length of {MAX_IDENTIFIER_LEN} characters (got {})",
14            name.len()
15        )));
16    }
17    for ch in name.chars() {
18        if !ch.is_alphanumeric() && ch != '_' {
19            return Err(MCPError::InvalidParams(format!(
20                "'{label}' contains invalid character '{ch}' — only alphanumeric and underscore allowed"
21            )));
22        }
23    }
24    if name.starts_with(|c: char| c.is_ascii_digit()) {
25        return Err(MCPError::InvalidParams(format!(
26            "'{label}' must not start with a digit"
27        )));
28    }
29    Ok(())
30}
31
/// Quote `name` as a PostgreSQL identifier.
///
/// Thin alias that forwards to `quote_ident` and returns its result unchanged.
pub fn quote_identifier(name: &str) -> String {
    quote_ident(name)
}
35
// NOTE(review): `SET` and `ALTER SYSTEM` (parameter privileges in newer
// PostgreSQL releases) are not listed — confirm the omission is intentional.
/// PostgreSQL privilege keywords accepted by GRANT/REVOKE.
///
/// Entries are upper-case ASCII; `validate_privilege_list` matches tokens
/// against them case-insensitively and prints them, in this order, in its
/// error message. `"ALL PRIVILEGES"` is a single entry, so it only matches
/// when the two words are separated by exactly one space.
const VALID_PRIVILEGES: &[&str] = &[
    "SELECT",
    "INSERT",
    "UPDATE",
    "DELETE",
    "TRUNCATE",
    "REFERENCES",
    "TRIGGER",
    "CREATE",
    "CONNECT",
    "TEMPORARY",
    "TEMP",
    "EXECUTE",
    "USAGE",
    "MAINTAIN",
    "ALL",
    "ALL PRIVILEGES",
];
55
56/// Validate a privilege specification for GRANT/REVOKE. Accepts a
57/// comma-separated list of known privilege keywords (case-insensitive), e.g.
58/// `"SELECT, INSERT"` or `"ALL PRIVILEGES"`. Rejects anything else so the
59/// value can be safely interpolated into the statement.
60pub fn validate_privilege_list(privilege: &str) -> Result<(), MCPError> {
61    let trimmed = privilege.trim();
62    if trimmed.is_empty() {
63        return Err(MCPError::InvalidParams(
64            "'privilege' must not be empty".into(),
65        ));
66    }
67    for part in trimmed.split(',') {
68        let token = part.trim().to_ascii_uppercase();
69        if !VALID_PRIVILEGES.contains(&token.as_str()) {
70            return Err(MCPError::InvalidParams(format!(
71                "Invalid privilege '{}'. Allowed: {}",
72                part.trim(),
73                VALID_PRIVILEGES.join(", ")
74            )));
75        }
76    }
77    Ok(())
78}
79
/// Wrap `name` in double quotes for use as a PostgreSQL identifier, doubling
/// every embedded `"` so it cannot terminate the quoted form early.
///
/// This is the shared helper for identifier quoting; call it rather than
/// repeating `format!("\"{}\"", s.replace('"', "\"\""))` in each module.
pub fn quote_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    // Splitting on `"` and re-joining with `""` doubles each embedded quote.
    let mut segments = name.split('"');
    if let Some(head) = segments.next() {
        quoted.push_str(head);
    }
    for segment in segments {
        quoted.push_str("\"\"");
        quoted.push_str(segment);
    }
    quoted.push('"');
    quoted
}
95
/// Unit tests for the identifier and privilege validators and the quoting
/// helpers defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_identifier() {
        assert!(validate_identifier("users", "table").is_ok());
        assert!(validate_identifier("user_orders_2024", "table").is_ok());
    }

    #[test]
    fn test_leading_underscore_identifier() {
        assert!(validate_identifier("_private", "table").is_ok());
    }

    #[test]
    fn test_empty_identifier() {
        let err = validate_identifier("", "table").unwrap_err();
        assert!(err.to_string().contains("must not be empty"));
    }

    #[test]
    fn test_identifier_at_max_length_is_accepted() {
        // The limit is inclusive: exactly MAX_IDENTIFIER_LEN bytes is valid.
        let max = "a".repeat(MAX_IDENTIFIER_LEN);
        assert!(validate_identifier(&max, "table").is_ok());
    }

    #[test]
    fn test_too_long_identifier() {
        let long = "a".repeat(MAX_IDENTIFIER_LEN + 1);
        let err = validate_identifier(&long, "table").unwrap_err();
        assert!(err.to_string().contains("exceeds maximum length"));
    }

    #[test]
    fn test_invalid_char_identifier() {
        let err = validate_identifier("users; DROP TABLE", "table").unwrap_err();
        assert!(err.to_string().contains("invalid character"));
    }

    #[test]
    fn test_digit_start_identifier() {
        let err = validate_identifier("1users", "table").unwrap_err();
        assert!(err.to_string().contains("must not start with a digit"));
    }

    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("users"), "\"users\"");
        assert_eq!(quote_identifier("order_items"), "\"order_items\"");
    }

    #[test]
    fn test_quote_ident_escapes_embedded_quotes() {
        // Embedded double quotes must be doubled, not passed through.
        assert_eq!(quote_ident("we\"ird"), "\"we\"\"ird\"");
        assert_eq!(quote_ident("\""), "\"\"\"\"");
        assert_eq!(quote_ident(""), "\"\"");
        // The alias must produce exactly what the helper produces.
        assert_eq!(quote_identifier("a\"b"), quote_ident("a\"b"));
    }

    #[test]
    fn test_validate_privilege_list_valid() {
        assert!(validate_privilege_list("SELECT").is_ok());
        assert!(validate_privilege_list("select, insert").is_ok());
        assert!(validate_privilege_list("ALL PRIVILEGES").is_ok());
        assert!(validate_privilege_list("SELECT, UPDATE, DELETE").is_ok());
    }

    #[test]
    fn test_validate_privilege_list_rejects_injection() {
        let err = validate_privilege_list("SELECT ON pg_authid TO attacker; --").unwrap_err();
        assert!(err.to_string().contains("Invalid privilege"));
        assert!(validate_privilege_list("").is_err());
        assert!(validate_privilege_list("DROP").is_err());
    }

    #[test]
    fn test_validate_privilege_list_rejects_blank_and_empty_items() {
        assert!(validate_privilege_list("   ").is_err());
        assert!(validate_privilege_list("SELECT,").is_err());
        assert!(validate_privilege_list("SELECT,,INSERT").is_err());
    }
}
152}