Skip to main content

mcp_postgres/
validation.rs

1use crate::errors::MCPError;
2
3const MAX_IDENTIFIER_LEN: usize = 255;
4
5pub fn validate_identifier(name: &str, label: &str) -> Result<(), MCPError> {
6    if name.is_empty() {
7        return Err(MCPError::InvalidParams(format!(
8            "'{label}' must not be empty"
9        )));
10    }
11    if name.len() > MAX_IDENTIFIER_LEN {
12        return Err(MCPError::InvalidParams(format!(
13            "'{label}' exceeds maximum length of {MAX_IDENTIFIER_LEN} characters (got {})",
14            name.len()
15        )));
16    }
17    for ch in name.chars() {
18        if !ch.is_alphanumeric() && ch != '_' {
19            return Err(MCPError::InvalidParams(format!(
20                "'{label}' contains invalid character '{ch}' — only alphanumeric and underscore allowed"
21            )));
22        }
23    }
24    if name.starts_with(|c: char| c.is_ascii_digit()) {
25        return Err(MCPError::InvalidParams(format!(
26            "'{label}' must not start with a digit"
27        )));
28    }
29    Ok(())
30}
31
32pub fn quote_identifier(name: &str) -> String {
33    quote_ident(name)
34}
35
/// Quote a PostgreSQL identifier, escaping embedded double-quotes.
///
/// The result is `name` surrounded by `"` with each `"` inside it written twice, so the value
/// can be spliced into SQL as a quoted identifier.
///
/// Note: PostgreSQL does not allow the character with code zero in quoted identifiers; this
/// function does not check for it.
///
/// Use this instead of duplicating `format!("\"{}\"", s.replace('"', "\"\""))` in every module.
pub fn quote_ident(name: &str) -> String {
    let escaped = name.replace('"', "\"\"");
    format!("\"{escaped}\"")
}
51
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_identifier() {
        assert!(validate_identifier("users", "table").is_ok());
        assert!(validate_identifier("user_orders_2024", "table").is_ok());
        assert!(validate_identifier("_private", "table").is_ok());
    }

    #[test]
    fn test_empty_identifier() {
        let err = validate_identifier("", "table").unwrap_err();
        assert!(err.to_string().contains("must not be empty"));
    }

    // Pin the boundary against the constant itself so the tests track any change to the limit.
    #[test]
    fn test_max_length_identifier_is_accepted() {
        let exact = "a".repeat(MAX_IDENTIFIER_LEN);
        assert!(validate_identifier(&exact, "table").is_ok());
    }

    #[test]
    fn test_too_long_identifier() {
        let long = "a".repeat(MAX_IDENTIFIER_LEN + 1);
        let err = validate_identifier(&long, "table").unwrap_err();
        assert!(err.to_string().contains("exceeds maximum length"));
    }

    #[test]
    fn test_invalid_char_identifier() {
        let err = validate_identifier("users; DROP TABLE", "table").unwrap_err();
        assert!(err.to_string().contains("invalid character ';'"));
    }

    #[test]
    fn test_digit_start_identifier() {
        let err = validate_identifier("1users", "table").unwrap_err();
        assert!(err.to_string().contains("must not start with a digit"));
    }

    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("users"), "\"users\"");
        assert_eq!(quote_identifier("order_items"), "\"order_items\"");
        assert_eq!(quote_identifier(""), "\"\"");
    }

    // The escaping branch is the only non-trivial logic in quoting; exercise it directly.
    #[test]
    fn test_quote_identifier_escapes_embedded_quotes() {
        assert_eq!(quote_identifier("a\"b"), "\"a\"\"b\"");
        assert_eq!(quote_identifier("\""), "\"\"\"\"");
        assert_eq!(quote_ident("we\"ird"), quote_identifier("we\"ird"));
    }
}