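// prism_mcp_rs/core/error.rs

//! Error types for the MCP Rust SDK.
//!
//! This module defines all error types that can occur within the MCP SDK.
//! Every variant of [`McpError`] carries a human-readable message, and helper
//! methods classify errors as recoverable or not, and by category for logging
//! and metrics.
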
6use thiserror::Error;
8/// The main error type for the MCP SDK
9#[derive(Error, Debug, Clone)]
10pub enum McpError {
11    /// Transport-related errors (connection, I/O, etc.)
12    #[error("Transport error: {0}")]
13    Transport(String),
14
15    /// Protocol-level errors (invalid messages, unexpected responses, etc.)
16    #[error("Protocol error: {0}")]
17    Protocol(String),
18
19    /// JSON serialization/deserialization errors
20    #[error("Serialization error: {0}")]
21    Serialization(String),
22
23    /// Invalid URI format or content
24    #[error("Invalid URI: {0}")]
25    InvalidUri(String),
26
27    /// Requested tool was not found
28    #[error("Tool not found: {0}")]
29    ToolNotFound(String),
30
31    /// Requested resource was not found
32    #[error("Resource not found: {0}")]
33    ResourceNotFound(String),
34
35    /// Requested prompt was not found
36    #[error("Prompt not found: {0}")]
37    PromptNotFound(String),
38
39    /// Method not found (JSON-RPC error)
40    #[error("Method not found: {0}")]
41    MethodNotFound(String),
42
43    /// Invalid parameters (JSON-RPC error)
44    #[error("Invalid parameters: {0}")]
45    InvalidParams(String),
46
47    /// Connection-related errors
48    #[error("Connection error: {0}")]
49    Connection(String),
50
51    /// Authentication/authorization errors
52    #[error("Authentication error: {0}")]
53    Authentication(String),
54
55    /// OAuth 2.1 authorization errors
56    #[error("Authorization error: {0}")]
57    Auth(String),
58
59    /// Input validation errors
60    #[error("Validation error: {0}")]
61    Validation(String),
62
63    /// I/O errors from the standard library
64    #[error("I/O error: {0}")]
65    Io(String),
66
67    /// URL parsing errors
68    #[error("URL error: {0}")]
69    Url(String),
70
71    /// HTTP-related errors when using HTTP transport
72    #[cfg(feature = "http")]
73    #[error("HTTP error: {0}")]
74    Http(String),
75
76    /// WebSocket-related errors when using WebSocket transport
77    #[cfg(feature = "websocket")]
78    #[error("WebSocket error: {0}")]
79    WebSocket(String),
80
81    /// JSON Schema validation errors
82    #[error("Schema validation error: {0}")]
83    SchemaValidation(String),
84    /// Timeout errors
85    #[error("Timeout error: {0}")]
86    Timeout(String),
87
88    /// Cancellation errors
89    #[error("Operation cancelled: {0}")]
90    Cancelled(String),
91
92    /// Internal errors that shouldn't normally occur
93    #[error("Internal error: {0}")]
94    Internal(String),
95}
96
97// Manual From implementations for types that don't implement Clone
98impl From<serde_json::Error> for McpError {
99    fn from(err: serde_json::Error) -> Self {
100        McpError::Serialization(err.to_string())
101    }
102}
103
104impl From<std::io::Error> for McpError {
105    fn from(err: std::io::Error) -> Self {
106        McpError::Io(err.to_string())
107    }
108}
109
110impl From<url::ParseError> for McpError {
111    fn from(err: url::ParseError) -> Self {
112        McpError::Url(err.to_string())
113    }
114}
115
116/// Result type alias for MCP operations
117pub type McpResult<T> = Result<T, McpError>;
118
119impl McpError {
120    /// Create a new transport error
121    pub fn transport<S: Into<String>>(message: S) -> Self {
122        Self::Transport(message.into())
123    }
124
125    /// Create a new protocol error
126    pub fn protocol<S: Into<String>>(message: S) -> Self {
127        Self::Protocol(message.into())
128    }
129
130    /// Create a new validation error
131    pub fn validation<S: Into<String>>(message: S) -> Self {
132        Self::Validation(message.into())
133    }
134
135    /// Create a new connection error
136    pub fn connection<S: Into<String>>(message: S) -> Self {
137        Self::Connection(message.into())
138    }
139
140    /// Create a new internal error
141    pub fn internal<S: Into<String>>(message: S) -> Self {
142        Self::Internal(message.into())
143    }
144
145    /// Create a new IO error from std::io::Error
146    pub fn io(err: std::io::Error) -> Self {
147        Self::Io(err.to_string())
148    }
149
150    /// Create a new serialization error from serde_json::Error
151    pub fn serialization(err: serde_json::Error) -> Self {
152        Self::Serialization(err.to_string())
153    }
154
155    /// Create a new timeout error
156    pub fn timeout<S: Into<String>>(message: S) -> Self {
157        Self::Timeout(message.into())
158    }
159
160    /// Create a connection error (compatibility method)
161    pub fn connection_error<S: Into<String>>(message: S) -> Self {
162        Self::Connection(message.into())
163    }
164
165    /// Create a protocol error (compatibility method)
166    pub fn protocol_error<S: Into<String>>(message: S) -> Self {
167        Self::Protocol(message.into())
168    }
169
170    /// Create a validation error (compatibility method)
171    pub fn validation_error<S: Into<String>>(message: S) -> Self {
172        Self::Validation(message.into())
173    }
174
175    /// Create a timeout error (compatibility method)
176    pub fn timeout_error() -> Self {
177        Self::Timeout("Operation timed out".to_string())
178    }
179
180    /// Check if this error is recoverable
181    pub fn is_recoverable(&self) -> bool {
182        match self {
183            McpError::Transport(_) => false,
184            McpError::Protocol(_) => false,
185            McpError::Connection(_) => true,
186            McpError::Timeout(_) => true,
187            McpError::Validation(_) => false,
188            McpError::ToolNotFound(_) => false,
189            McpError::ResourceNotFound(_) => false,
190            McpError::PromptNotFound(_) => false,
191            McpError::MethodNotFound(_) => false,
192            McpError::InvalidParams(_) => false,
193            McpError::Authentication(_) => false,
194            McpError::Serialization(_) => false,
195            McpError::InvalidUri(_) => false,
196            McpError::Io(_) => true,
197            McpError::Url(_) => false,
198            #[cfg(feature = "http")]
199            McpError::Http(_) => true,
200            #[cfg(feature = "websocket")]
201            McpError::WebSocket(_) => true,
202            McpError::SchemaValidation(_) => false,
203            McpError::Cancelled(_) => false,
204            McpError::Auth(_) => false,
205            McpError::Internal(_) => false,
206        }
207    }
208
209    /// Get the error category for logging/metrics
210    pub fn category(&self) -> &'static str {
211        match self {
212            McpError::Transport(_) => "transport",
213            McpError::Protocol(_) => "protocol",
214            McpError::Connection(_) => "connection",
215            McpError::Timeout(_) => "timeout",
216            McpError::Validation(_) => "validation",
217            McpError::ToolNotFound(_) => "not_found",
218            McpError::ResourceNotFound(_) => "not_found",
219            McpError::PromptNotFound(_) => "not_found",
220            McpError::MethodNotFound(_) => "not_found",
221            McpError::InvalidParams(_) => "validation",
222            McpError::Authentication(_) => "auth",
223            McpError::Serialization(_) => "serialization",
224            McpError::InvalidUri(_) => "validation",
225            McpError::Io(_) => "io",
226            McpError::Url(_) => "validation",
227            #[cfg(feature = "http")]
228            McpError::Http(_) => "http",
229            #[cfg(feature = "websocket")]
230            McpError::WebSocket(_) => "websocket",
231            McpError::SchemaValidation(_) => "validation",
232            McpError::Cancelled(_) => "cancelled",
233            McpError::Auth(_) => "auth",
234            McpError::Internal(_) => "internal",
235        }
236    }
237}
238
// With the `http` feature, `reqwest` failures become `McpError::Http`.
// Only the rendered message is kept.
#[cfg(feature = "http")]
impl From<reqwest::Error> for McpError {
    fn from(source: reqwest::Error) -> McpError {
        let message = source.to_string();
        McpError::Http(message)
    }
}

// With the `websocket` feature, tungstenite failures become `McpError::WebSocket`.
// Only the rendered message is kept.
#[cfg(feature = "websocket")]
impl From<tokio_tungstenite::tungstenite::Error> for McpError {
    fn from(source: tokio_tungstenite::tungstenite::Error) -> McpError {
        let message = source.to_string();
        McpError::WebSocket(message)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_creation() {
        let error = McpError::transport("Connection failed");
        assert_eq!(error.to_string(), "Transport error: Connection failed");
        assert_eq!(error.category(), "transport");
        assert!(!error.is_recoverable());
    }

    #[test]
    fn test_error_recovery() {
        assert!(McpError::connection("timeout").is_recoverable());
        assert!(!McpError::validation("invalid input").is_recoverable());
        assert!(McpError::timeout("request timeout").is_recoverable());
        assert!(!McpError::internal("bug").is_recoverable());
    }

    #[test]
    fn test_error_categories() {
        assert_eq!(McpError::protocol("bad message").category(), "protocol");
        assert_eq!(
            McpError::ToolNotFound("missing".to_string()).category(),
            "not_found"
        );
        assert_eq!(
            McpError::Authentication("unauthorized".to_string()).category(),
            "auth"
        );
        assert_eq!(McpError::Auth("forbidden".to_string()).category(), "auth");
        assert_eq!(
            McpError::InvalidParams("bad".to_string()).category(),
            "validation"
        );
    }

    // The `*_error` helpers are compatibility aliases and must behave exactly
    // like the primary constructors.
    #[test]
    fn test_compat_constructors() {
        assert_eq!(
            McpError::connection_error("x").to_string(),
            McpError::connection("x").to_string()
        );
        assert_eq!(
            McpError::protocol_error("x").to_string(),
            McpError::protocol("x").to_string()
        );
        assert_eq!(
            McpError::validation_error("x").to_string(),
            McpError::validation("x").to_string()
        );
        assert_eq!(
            McpError::timeout_error().to_string(),
            "Timeout error: Operation timed out"
        );
    }

    #[test]
    fn test_from_io_error() {
        let err: McpError =
            std::io::Error::new(std::io::ErrorKind::NotFound, "missing file").into();
        assert!(matches!(&err, McpError::Io(msg) if msg == "missing file"));
        assert_eq!(err.category(), "io");
        assert!(err.is_recoverable());

        let via_ctor =
            McpError::io(std::io::Error::new(std::io::ErrorKind::NotFound, "missing file"));
        assert_eq!(via_ctor.to_string(), "I/O error: missing file");
    }

    #[test]
    fn test_from_serde_json_error() {
        let json_err = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let err = McpError::from(json_err);
        assert!(matches!(err, McpError::Serialization(_)));
        assert_eq!(err.category(), "serialization");
        assert!(!err.is_recoverable());

        let json_err = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        assert!(matches!(
            McpError::serialization(json_err),
            McpError::Serialization(_)
        ));
    }

    #[test]
    fn test_from_url_parse_error() {
        let url_err = url::Url::parse("not a url").unwrap_err();
        let err = McpError::from(url_err);
        assert!(matches!(err, McpError::Url(_)));
        assert_eq!(err.category(), "validation");
        assert!(!err.is_recoverable());
    }
}