Skip to main content

mcp_postgres/
validation.rs

1use crate::errors::MCPError;
2
/// Upper bound on the length of an identifier accepted by `validate_identifier`.
///
/// The limit is compared against `str::len`, so it is measured in UTF-8 bytes.
// NOTE(review): stock PostgreSQL builds truncate identifiers to 63 bytes
// (NAMEDATALEN - 1) — confirm that 255 is the intended limit here.
const MAX_IDENTIFIER_LEN: usize = 255;
4
5pub fn validate_identifier(name: &str, label: &str) -> Result<(), MCPError> {
6    if name.is_empty() {
7        return Err(MCPError::InvalidParams(format!("'{label}' must not be empty")));
8    }
9    if name.len() > MAX_IDENTIFIER_LEN {
10        return Err(MCPError::InvalidParams(
11            format!("'{label}' exceeds maximum length of {MAX_IDENTIFIER_LEN} characters (got {})", name.len())
12        ));
13    }
14    for ch in name.chars() {
15        if !ch.is_alphanumeric() && ch != '_' {
16            return Err(MCPError::InvalidParams(
17                format!("'{label}' contains invalid character '{ch}' — only alphanumeric and underscore allowed")
18            ));
19        }
20    }
21    if name.starts_with(|c: char| c.is_ascii_digit()) {
22        return Err(MCPError::InvalidParams(
23            format!("'{label}' must not start with a digit")
24        ));
25    }
26    Ok(())
27}
28
29pub fn quote_identifier(name: &str) -> String {
30    quote_ident(name)
31}
32
/// Quote a PostgreSQL identifier, escaping embedded double-quotes.
///
/// The result is `name` surrounded by `"` with every `"` inside it written as
/// `""`. All other characters are copied through untouched, and an empty
/// `name` yields `""` (two quote characters).
///
/// Call this helper rather than repeating the quoting expression in each module.
pub fn quote_ident(name: &str) -> String {
    // Doubling is the only escaping applied to the identifier text.
    let escaped = name.replace('"', "\"\"");
    format!("\"{escaped}\"")
}
48
#[cfg(test)]
mod tests {
    //! Unit tests for identifier validation and quoting.

    use super::*;

    /// Validates `name` expecting a rejection and returns the rendered error text.
    fn rejection_message(name: &str) -> String {
        validate_identifier(name, "table").unwrap_err().to_string()
    }

    #[test]
    fn test_valid_identifier() {
        assert!(validate_identifier("users", "table").is_ok());
        assert!(validate_identifier("user_orders_2024", "table").is_ok());
        // A leading underscore is allowed; only a leading digit is rejected.
        assert!(validate_identifier("_private", "table").is_ok());
    }

    #[test]
    fn test_empty_identifier() {
        assert!(rejection_message("").contains("must not be empty"));
    }

    #[test]
    fn test_too_long_identifier() {
        let long = "a".repeat(MAX_IDENTIFIER_LEN + 1);
        assert!(rejection_message(&long).contains("exceeds maximum length"));
    }

    #[test]
    fn test_max_length_identifier_is_accepted() {
        // The limit is inclusive: exactly MAX_IDENTIFIER_LEN bytes must pass.
        let at_limit = "a".repeat(MAX_IDENTIFIER_LEN);
        assert!(validate_identifier(&at_limit, "table").is_ok());
    }

    #[test]
    fn test_invalid_char_identifier() {
        assert!(rejection_message("users; DROP TABLE").contains("invalid character"));
    }

    #[test]
    fn test_digit_start_identifier() {
        assert!(rejection_message("1users").contains("must not start with a digit"));
    }

    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("users"), "\"users\"");
        assert_eq!(quote_identifier("order_items"), "\"order_items\"");
    }

    #[test]
    fn test_quote_identifier_escapes_embedded_quotes() {
        // Embedded double quotes must be doubled so they cannot end the quoted identifier.
        assert_eq!(quote_identifier("we\"ird"), "\"we\"\"ird\"");
        assert_eq!(quote_ident("\""), "\"\"\"\"");
        // Both entry points must agree.
        assert_eq!(quote_identifier("a\"b\"c"), quote_ident("a\"b\"c"));
    }
}