mcp_protocol_sdk/core/
error.rs

//! Error types for the MCP Rust SDK.
//!
//! This module defines every error type that can occur within the MCP SDK,
//! providing structured error handling with detailed context.
5
6use thiserror::Error;
7
8/// The main error type for the MCP SDK
9#[derive(Error, Debug, Clone)]
10pub enum McpError {
11    /// Transport-related errors (connection, I/O, etc.)
12    #[error("Transport error: {0}")]
13    Transport(String),
14
15    /// Protocol-level errors (invalid messages, unexpected responses, etc.)
16    #[error("Protocol error: {0}")]
17    Protocol(String),
18
19    /// JSON serialization/deserialization errors
20    #[error("Serialization error: {0}")]
21    Serialization(String),
22
23    /// Invalid URI format or content
24    #[error("Invalid URI: {0}")]
25    InvalidUri(String),
26
27    /// Requested tool was not found
28    #[error("Tool not found: {0}")]
29    ToolNotFound(String),
30
31    /// Requested resource was not found
32    #[error("Resource not found: {0}")]
33    ResourceNotFound(String),
34
35    /// Requested prompt was not found
36    #[error("Prompt not found: {0}")]
37    PromptNotFound(String),
38
39    /// Connection-related errors
40    #[error("Connection error: {0}")]
41    Connection(String),
42
43    /// Authentication/authorization errors
44    #[error("Authentication error: {0}")]
45    Authentication(String),
46
47    /// Input validation errors
48    #[error("Validation error: {0}")]
49    Validation(String),
50
51    /// I/O errors from the standard library
52    #[error("I/O error: {0}")]
53    Io(String),
54
55    /// URL parsing errors
56    #[error("URL error: {0}")]
57    Url(String),
58
59    /// HTTP-related errors when using HTTP transport
60    #[cfg(feature = "http")]
61    #[error("HTTP error: {0}")]
62    Http(String),
63
64    /// WebSocket-related errors when using WebSocket transport
65    #[cfg(feature = "websocket")]
66    #[error("WebSocket error: {0}")]
67    WebSocket(String),
68
69    /// JSON Schema validation errors
70    #[cfg(feature = "validation")]
71    #[error("Schema validation error: {0}")]
72    SchemaValidation(String),
73
74    /// Timeout errors
75    #[error("Timeout error: {0}")]
76    Timeout(String),
77
78    /// Cancellation errors
79    #[error("Operation cancelled: {0}")]
80    Cancelled(String),
81
82    /// Internal errors that shouldn't normally occur
83    #[error("Internal error: {0}")]
84    Internal(String),
85}
86
// Manual `From` implementations: each source error is flattened to its `Display`
// text, so `McpError` stores only `String`s and can keep deriving `Clone`.
88impl From<serde_json::Error> for McpError {
89    fn from(err: serde_json::Error) -> Self {
90        McpError::Serialization(err.to_string())
91    }
92}
93
94impl From<std::io::Error> for McpError {
95    fn from(err: std::io::Error) -> Self {
96        McpError::Io(err.to_string())
97    }
98}
99
100impl From<url::ParseError> for McpError {
101    fn from(err: url::ParseError) -> Self {
102        McpError::Url(err.to_string())
103    }
104}
105
106/// Result type alias for MCP operations
107pub type McpResult<T> = Result<T, McpError>;
108
109impl McpError {
110    /// Create a new transport error
111    pub fn transport<S: Into<String>>(message: S) -> Self {
112        Self::Transport(message.into())
113    }
114
115    /// Create a new protocol error
116    pub fn protocol<S: Into<String>>(message: S) -> Self {
117        Self::Protocol(message.into())
118    }
119
120    /// Create a new validation error
121    pub fn validation<S: Into<String>>(message: S) -> Self {
122        Self::Validation(message.into())
123    }
124
125    /// Create a new connection error
126    pub fn connection<S: Into<String>>(message: S) -> Self {
127        Self::Connection(message.into())
128    }
129
130    /// Create a new internal error
131    pub fn internal<S: Into<String>>(message: S) -> Self {
132        Self::Internal(message.into())
133    }
134
135    /// Create a new IO error from std::io::Error
136    pub fn io(err: std::io::Error) -> Self {
137        Self::Io(err.to_string())
138    }
139
140    /// Create a new serialization error from serde_json::Error
141    pub fn serialization(err: serde_json::Error) -> Self {
142        Self::Serialization(err.to_string())
143    }
144
145    /// Create a new timeout error
146    pub fn timeout<S: Into<String>>(message: S) -> Self {
147        Self::Timeout(message.into())
148    }
149
150    /// Create a connection error (compatibility method)
151    pub fn connection_error<S: Into<String>>(message: S) -> Self {
152        Self::Connection(message.into())
153    }
154
155    /// Create a protocol error (compatibility method)
156    pub fn protocol_error<S: Into<String>>(message: S) -> Self {
157        Self::Protocol(message.into())
158    }
159
160    /// Create a validation error (compatibility method)
161    pub fn validation_error<S: Into<String>>(message: S) -> Self {
162        Self::Validation(message.into())
163    }
164
165    /// Create a timeout error (compatibility method)
166    pub fn timeout_error() -> Self {
167        Self::Timeout("Operation timed out".to_string())
168    }
169
170    /// Check if this error is recoverable
171    pub fn is_recoverable(&self) -> bool {
172        match self {
173            McpError::Transport(_) => false,
174            McpError::Protocol(_) => false,
175            McpError::Connection(_) => true,
176            McpError::Timeout(_) => true,
177            McpError::Validation(_) => false,
178            McpError::ToolNotFound(_) => false,
179            McpError::ResourceNotFound(_) => false,
180            McpError::PromptNotFound(_) => false,
181            McpError::Authentication(_) => false,
182            McpError::Serialization(_) => false,
183            McpError::InvalidUri(_) => false,
184            McpError::Io(_) => true,
185            McpError::Url(_) => false,
186            #[cfg(feature = "http")]
187            McpError::Http(_) => true,
188            #[cfg(feature = "websocket")]
189            McpError::WebSocket(_) => true,
190            #[cfg(feature = "validation")]
191            McpError::SchemaValidation(_) => false,
192            McpError::Cancelled(_) => false,
193            McpError::Internal(_) => false,
194        }
195    }
196
197    /// Get the error category for logging/metrics
198    pub fn category(&self) -> &'static str {
199        match self {
200            McpError::Transport(_) => "transport",
201            McpError::Protocol(_) => "protocol",
202            McpError::Connection(_) => "connection",
203            McpError::Timeout(_) => "timeout",
204            McpError::Validation(_) => "validation",
205            McpError::ToolNotFound(_) => "not_found",
206            McpError::ResourceNotFound(_) => "not_found",
207            McpError::PromptNotFound(_) => "not_found",
208            McpError::Authentication(_) => "auth",
209            McpError::Serialization(_) => "serialization",
210            McpError::InvalidUri(_) => "validation",
211            McpError::Io(_) => "io",
212            McpError::Url(_) => "validation",
213            #[cfg(feature = "http")]
214            McpError::Http(_) => "http",
215            #[cfg(feature = "websocket")]
216            McpError::WebSocket(_) => "websocket",
217            #[cfg(feature = "validation")]
218            McpError::SchemaValidation(_) => "validation",
219            McpError::Cancelled(_) => "cancelled",
220            McpError::Internal(_) => "internal",
221        }
222    }
223}
224
// HTTP transport support: flatten `reqwest` failures into `McpError::Http`.
// Only compiled when the `http` feature is enabled.
#[cfg(feature = "http")]
impl From<reqwest::Error> for McpError {
    /// Wraps the `Display` text of a `reqwest::Error` in [`McpError::Http`].
    fn from(source: reqwest::Error) -> Self {
        let message = source.to_string();
        Self::Http(message)
    }
}
232
// WebSocket transport support: flatten `tungstenite` failures into
// `McpError::WebSocket`. Only compiled when the `websocket` feature is enabled.
#[cfg(feature = "websocket")]
impl From<tokio_tungstenite::tungstenite::Error> for McpError {
    /// Wraps the `Display` text of a `tungstenite::Error` in
    /// [`McpError::WebSocket`].
    fn from(source: tokio_tungstenite::tungstenite::Error) -> Self {
        let message = source.to_string();
        Self::WebSocket(message)
    }
}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// A transport error renders with its prefix, reports the "transport"
    /// category, and is not recoverable.
    #[test]
    fn test_error_creation() {
        let transport_err = McpError::transport("Connection failed");
        let rendered = transport_err.to_string();

        assert_eq!(rendered, "Transport error: Connection failed");
        assert_eq!(transport_err.category(), "transport");
        assert!(!transport_err.is_recoverable());
    }

    /// Connection and timeout errors are recoverable; validation errors are not.
    #[test]
    fn test_error_recovery() {
        let connection = McpError::connection("timeout");
        let validation = McpError::validation("invalid input");
        let timeout = McpError::timeout("request timeout");

        assert!(connection.is_recoverable());
        assert!(!validation.is_recoverable());
        assert!(timeout.is_recoverable());
    }

    /// Spot-checks the category label of a few variants.
    #[test]
    fn test_error_categories() {
        let cases = [
            (McpError::protocol("bad message"), "protocol"),
            (McpError::ToolNotFound("missing".to_string()), "not_found"),
            (McpError::Authentication("unauthorized".to_string()), "auth"),
        ];

        for (error, expected) in &cases {
            assert_eq!(error.category(), *expected);
        }
    }
}