// mcp_postgres/errors.rs

use thiserror::Error;
2
3#[derive(Error, Debug)]
4pub enum MCPError {
5    #[error("Parse error: {0}")]
6    ParseError(String),
7
8    #[error("Method not found: {0}")]
9    MethodNotFound(String),
10
11    #[error("Invalid params: {0}")]
12    InvalidParams(String),
13
14    #[error("Database error: {0}")]
15    DatabaseError(tokio_postgres::Error),
16
17    #[error("Connection pool error: {0}")]
18    PoolError(String),
19
20    #[error("IO error: {0}")]
21    IoError(#[from] std::io::Error),
22
23    #[error("JSON error: {0}")]
24    JsonError(#[from] serde_json::Error),
25}
26
27/// Converts tokio_postgres errors, upgrading permission-denied (42501)
28/// to `InvalidParams` so the user gets a clear message instead of
29/// a generic "db error".  Tools use `?` on DB calls — this `From`
30/// fires automatically.
31impl From<tokio_postgres::Error> for MCPError {
32    fn from(e: tokio_postgres::Error) -> Self {
33        if let Some(db_err) = e.as_db_error()
34            && *db_err.code() == tokio_postgres::error::SqlState::INSUFFICIENT_PRIVILEGE
35        {
36            MCPError::InvalidParams(format!(
37                "Permission denied: database user lacks required privileges. ({})",
38                db_err.message()
39            ))
40        } else {
41            MCPError::DatabaseError(e)
42        }
43    }
44}
45
46impl MCPError {
47    pub const fn error_code(&self) -> i64 {
48        match self {
49            MCPError::ParseError(_) => -32700,
50            MCPError::MethodNotFound(_) => -32601,
51            MCPError::InvalidParams(_) => -32602,
52            MCPError::DatabaseError(_) => -32000,
53            MCPError::PoolError(_) => -32001,
54            MCPError::IoError(_) => -32003,
55            MCPError::JsonError(_) => -32700,
56        }
57    }
58}
59
60pub type Result<T> = std::result::Result<T, MCPError>;
61
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_error_code() {
        let err = MCPError::ParseError("bad json".into());
        assert_eq!(err.error_code(), -32700);
    }

    #[test]
    fn test_method_not_found_code() {
        let err = MCPError::MethodNotFound("unknown".into());
        assert_eq!(err.error_code(), -32601);
    }

    #[test]
    fn test_invalid_params_code() {
        let err = MCPError::InvalidParams("missing field".into());
        assert_eq!(err.error_code(), -32602);
    }

    #[test]
    fn test_database_error_code() {
        // An unknown connection-string option makes `Config` parsing fail with
        // a genuine `tokio_postgres::Error`, without needing a server. It is
        // not a server-side `DbError`, so the `From` impl must keep it as
        // `DatabaseError` rather than upgrading it to `InvalidParams`.
        let pg_err = "definitely_not_an_option=1"
            .parse::<tokio_postgres::Config>()
            .unwrap_err();
        let err = MCPError::from(pg_err);
        assert!(matches!(err, MCPError::DatabaseError(_)));
        assert_eq!(err.error_code(), -32000);
        assert!(err.to_string().starts_with("Database error: "));
    }

    #[test]
    fn test_pool_error_code() {
        let err = MCPError::PoolError("timeout".into());
        assert_eq!(err.error_code(), -32001);
    }

    #[test]
    fn test_io_error_code() {
        let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
        let err = MCPError::from(io_err);
        assert_eq!(err.error_code(), -32003);
    }

    #[test]
    fn test_json_error_code() {
        let json_err = serde_json::from_str::<()>("invalid").unwrap_err();
        let err = MCPError::from(json_err);
        assert_eq!(err.error_code(), -32700);
    }

    #[test]
    fn test_parse_error_display() {
        let err = MCPError::ParseError("bad token".into());
        assert_eq!(err.to_string(), "Parse error: bad token");
    }

    #[test]
    fn test_method_not_found_display() {
        let err = MCPError::MethodNotFound("foo".into());
        assert_eq!(err.to_string(), "Method not found: foo");
    }

    #[test]
    fn test_invalid_params_display() {
        let err = MCPError::InvalidParams("missing x".into());
        assert_eq!(err.to_string(), "Invalid params: missing x");
    }

    #[test]
    fn test_pool_error_display() {
        let err = MCPError::PoolError("exhausted".into());
        assert_eq!(err.to_string(), "Connection pool error: exhausted");
    }

    #[test]
    fn test_debug_format() {
        let err = MCPError::InvalidParams("bad".into());
        let debug = format!("{:?}", err);
        assert!(debug.contains("InvalidParams"));
        assert!(debug.contains("bad"));
    }

    #[test]
    fn test_result_type() {
        let ok: Result<i32> = Ok(42);
        assert!(ok.is_ok());
        let err: Result<i32> = Err(MCPError::PoolError("fail".into()));
        assert!(err.is_err());
    }

    #[test]
    fn test_debug_output_is_json_serializable() {
        // The Debug rendering can be embedded in a JSON payload as a string.
        let err = MCPError::MethodNotFound("test".into());
        let json_err = serde_json::to_value(format!("{:?}", err)).unwrap();
        assert!(json_err.as_str().unwrap().contains("MethodNotFound"));
    }
}