mcpkit_rocket/
error.rs

//! Error types for the Rocket MCP integration.
2
3use rocket::Request;
4use rocket::http::Status;
5use rocket::response::{self, Responder};
6use thiserror::Error;
7
8/// Errors that can occur during MCP request handling.
9#[derive(Debug, Error)]
10pub enum RocketError {
11    /// Invalid JSON-RPC message format.
12    #[error("Invalid message: {0}")]
13    InvalidMessage(String),
14
15    /// Unsupported protocol version.
16    #[error("Unsupported protocol version: {0}")]
17    UnsupportedVersion(String),
18
19    /// Session not found.
20    #[error("Session not found: {0}")]
21    SessionNotFound(String),
22
23    /// JSON serialization error.
24    #[error("Serialization error: {0}")]
25    Serialization(#[from] serde_json::Error),
26
27    /// Internal server error.
28    #[error("Internal error: {0}")]
29    Internal(String),
30}
31
32impl<'r> Responder<'r, 'static> for RocketError {
33    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> {
34        let status = match &self {
35            RocketError::InvalidMessage(_) => Status::BadRequest,
36            RocketError::UnsupportedVersion(_) => Status::BadRequest,
37            RocketError::SessionNotFound(_) => Status::NotFound,
38            RocketError::Serialization(_) => Status::InternalServerError,
39            RocketError::Internal(_) => Status::InternalServerError,
40        };
41
42        Err(status)
43    }
44}
45
46impl RocketError {
47    /// Get the HTTP status code for this error.
48    #[must_use]
49    pub fn status(&self) -> Status {
50        match self {
51            RocketError::InvalidMessage(_) => Status::BadRequest,
52            RocketError::UnsupportedVersion(_) => Status::BadRequest,
53            RocketError::SessionNotFound(_) => Status::NotFound,
54            RocketError::Serialization(_) => Status::InternalServerError,
55            RocketError::Internal(_) => Status::InternalServerError,
56        }
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use rocket::local::blocking::Client;

    #[test]
    fn test_invalid_message_error() {
        let error = RocketError::InvalidMessage("bad json".to_string());
        assert_eq!(error.to_string(), "Invalid message: bad json");
        assert_eq!(error.status(), Status::BadRequest);
    }

    #[test]
    fn test_unsupported_version_error() {
        let error = RocketError::UnsupportedVersion("1.0.0".to_string());
        assert_eq!(error.to_string(), "Unsupported protocol version: 1.0.0");
        assert_eq!(error.status(), Status::BadRequest);
    }

    #[test]
    fn test_session_not_found_error() {
        let error = RocketError::SessionNotFound("abc-123".to_string());
        assert_eq!(error.to_string(), "Session not found: abc-123");
        assert_eq!(error.status(), Status::NotFound);
    }

    #[test]
    fn test_serialization_error() {
        let json_err = serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
        let error = RocketError::Serialization(json_err);
        assert!(error.to_string().starts_with("Serialization error:"));
        assert_eq!(error.status(), Status::InternalServerError);
    }

    #[test]
    fn test_internal_error() {
        let error = RocketError::Internal("something went wrong".to_string());
        assert_eq!(error.to_string(), "Internal error: something went wrong");
        assert_eq!(error.status(), Status::InternalServerError);
    }

    #[test]
    fn test_from_serde_json_error() {
        let json_err = serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
        let error: RocketError = json_err.into();
        assert!(matches!(error, RocketError::Serialization(_)));
    }

    #[test]
    fn test_error_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<RocketError>();
    }

    #[test]
    fn test_error_debug_format() {
        let error = RocketError::InvalidMessage("test".to_string());
        let debug = format!("{error:?}");
        assert!(debug.contains("InvalidMessage"));
        assert!(debug.contains("test"));
    }

    /// The `Responder` impl must fail with exactly the status reported by
    /// `status()` for every variant; previously nothing exercised it.
    #[test]
    fn test_responder_fails_with_status() {
        let client = Client::debug_with(Vec::new()).expect("valid rocket instance");
        let request = client.get("/");
        let json_err = serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
        let errors = vec![
            RocketError::InvalidMessage("bad".to_string()),
            RocketError::UnsupportedVersion("0.0.1".to_string()),
            RocketError::SessionNotFound("missing".to_string()),
            RocketError::Serialization(json_err),
            RocketError::Internal("boom".to_string()),
        ];
        for error in errors {
            let expected = error.status();
            // `LocalRequest` derefs to `Request`, which `respond_to` expects.
            let outcome = error.respond_to(&request);
            assert_eq!(outcome.err(), Some(expected));
        }
    }
}