mcp_protocol_sdk/core/
error.rs

//! Error types for the MCP Rust SDK
//!
//! This module defines the error types that can occur within the MCP SDK,
//! providing structured error handling with detailed context.
5
6use thiserror::Error;
7
8/// The main error type for the MCP SDK
9#[derive(Error, Debug)]
10pub enum McpError {
11    /// Transport-related errors (connection, I/O, etc.)
12    #[error("Transport error: {0}")]
13    Transport(String),
14
15    /// Protocol-level errors (invalid messages, unexpected responses, etc.)
16    #[error("Protocol error: {0}")]
17    Protocol(String),
18
19    /// JSON serialization/deserialization errors
20    #[error("Serialization error: {0}")]
21    Serialization(#[from] serde_json::Error),
22
23    /// Invalid URI format or content
24    #[error("Invalid URI: {0}")]
25    InvalidUri(String),
26
27    /// Requested tool was not found
28    #[error("Tool not found: {0}")]
29    ToolNotFound(String),
30
31    /// Requested resource was not found
32    #[error("Resource not found: {0}")]
33    ResourceNotFound(String),
34
35    /// Requested prompt was not found
36    #[error("Prompt not found: {0}")]
37    PromptNotFound(String),
38
39    /// Connection-related errors
40    #[error("Connection error: {0}")]
41    Connection(String),
42
43    /// Authentication/authorization errors
44    #[error("Authentication error: {0}")]
45    Authentication(String),
46
47    /// Input validation errors
48    #[error("Validation error: {0}")]
49    Validation(String),
50
51    /// I/O errors from the standard library
52    #[error("I/O error: {0}")]
53    Io(#[from] std::io::Error),
54
55    /// URL parsing errors
56    #[error("URL error: {0}")]
57    Url(#[from] url::ParseError),
58
59    /// HTTP-related errors when using HTTP transport
60    #[cfg(feature = "http")]
61    #[error("HTTP error: {0}")]
62    Http(String),
63
64    /// WebSocket-related errors when using WebSocket transport
65    #[cfg(feature = "websocket")]
66    #[error("WebSocket error: {0}")]
67    WebSocket(String),
68
69    /// JSON Schema validation errors
70    #[cfg(feature = "validation")]
71    #[error("Schema validation error: {0}")]
72    SchemaValidation(String),
73
74    /// Timeout errors
75    #[error("Timeout error: {0}")]
76    Timeout(String),
77
78    /// Cancellation errors
79    #[error("Operation cancelled: {0}")]
80    Cancelled(String),
81
82    /// Internal errors that shouldn't normally occur
83    #[error("Internal error: {0}")]
84    Internal(String),
85}
86
87/// Result type alias for MCP operations
88pub type McpResult<T> = Result<T, McpError>;
89
90impl McpError {
91    /// Create a new transport error
92    pub fn transport<S: Into<String>>(message: S) -> Self {
93        Self::Transport(message.into())
94    }
95
96    /// Create a new protocol error
97    pub fn protocol<S: Into<String>>(message: S) -> Self {
98        Self::Protocol(message.into())
99    }
100
101    /// Create a new validation error
102    pub fn validation<S: Into<String>>(message: S) -> Self {
103        Self::Validation(message.into())
104    }
105
106    /// Create a new connection error
107    pub fn connection<S: Into<String>>(message: S) -> Self {
108        Self::Connection(message.into())
109    }
110
111    /// Create a new internal error
112    pub fn internal<S: Into<String>>(message: S) -> Self {
113        Self::Internal(message.into())
114    }
115
116    /// Create a new IO error from std::io::Error
117    pub fn io(err: std::io::Error) -> Self {
118        Self::Io(err)
119    }
120
121    /// Create a new serialization error from serde_json::Error
122    pub fn serialization(err: serde_json::Error) -> Self {
123        Self::Serialization(err)
124    }
125
126    /// Create a new timeout error
127    pub fn timeout<S: Into<String>>(message: S) -> Self {
128        Self::Timeout(message.into())
129    }
130
131    /// Check if this error is recoverable
132    pub fn is_recoverable(&self) -> bool {
133        match self {
134            McpError::Transport(_) => false,
135            McpError::Protocol(_) => false,
136            McpError::Connection(_) => true,
137            McpError::Timeout(_) => true,
138            McpError::Validation(_) => false,
139            McpError::ToolNotFound(_) => false,
140            McpError::ResourceNotFound(_) => false,
141            McpError::PromptNotFound(_) => false,
142            McpError::Authentication(_) => false,
143            McpError::Serialization(_) => false,
144            McpError::InvalidUri(_) => false,
145            McpError::Io(_) => true,
146            McpError::Url(_) => false,
147            #[cfg(feature = "http")]
148            McpError::Http(_) => true,
149            #[cfg(feature = "websocket")]
150            McpError::WebSocket(_) => true,
151            #[cfg(feature = "validation")]
152            McpError::SchemaValidation(_) => false,
153            McpError::Cancelled(_) => false,
154            McpError::Internal(_) => false,
155        }
156    }
157
158    /// Get the error category for logging/metrics
159    pub fn category(&self) -> &'static str {
160        match self {
161            McpError::Transport(_) => "transport",
162            McpError::Protocol(_) => "protocol",
163            McpError::Connection(_) => "connection",
164            McpError::Timeout(_) => "timeout",
165            McpError::Validation(_) => "validation",
166            McpError::ToolNotFound(_) => "not_found",
167            McpError::ResourceNotFound(_) => "not_found",
168            McpError::PromptNotFound(_) => "not_found",
169            McpError::Authentication(_) => "auth",
170            McpError::Serialization(_) => "serialization",
171            McpError::InvalidUri(_) => "validation",
172            McpError::Io(_) => "io",
173            McpError::Url(_) => "validation",
174            #[cfg(feature = "http")]
175            McpError::Http(_) => "http",
176            #[cfg(feature = "websocket")]
177            McpError::WebSocket(_) => "websocket",
178            #[cfg(feature = "validation")]
179            McpError::SchemaValidation(_) => "validation",
180            McpError::Cancelled(_) => "cancelled",
181            McpError::Internal(_) => "internal",
182        }
183    }
184}
185
// With the `http` feature on, `?` can lift `reqwest` failures into
// `McpError::Http`. Only the rendered (Display) message is kept; the original
// `reqwest::Error` value is not retained.
#[cfg(feature = "http")]
impl From<reqwest::Error> for McpError {
    fn from(source: reqwest::Error) -> Self {
        let message = source.to_string();
        Self::Http(message)
    }
}
193
// With the `websocket` feature on, `?` can lift tungstenite failures into
// `McpError::WebSocket`. Only the rendered (Display) message is kept; the
// original error value is not retained.
#[cfg(feature = "websocket")]
impl From<tokio_tungstenite::tungstenite::Error> for McpError {
    fn from(source: tokio_tungstenite::tungstenite::Error) -> Self {
        let message = source.to_string();
        Self::WebSocket(message)
    }
}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_creation() {
        let error = McpError::transport("Connection failed");
        assert_eq!(error.to_string(), "Transport error: Connection failed");
        assert_eq!(error.category(), "transport");
        assert!(!error.is_recoverable());
    }

    #[test]
    fn test_error_recovery() {
        assert!(McpError::connection("timeout").is_recoverable());
        assert!(!McpError::validation("invalid input").is_recoverable());
        assert!(McpError::timeout("request timeout").is_recoverable());
    }

    #[test]
    fn test_error_categories() {
        assert_eq!(McpError::protocol("bad message").category(), "protocol");
        assert_eq!(
            McpError::ToolNotFound("missing".to_string()).category(),
            "not_found"
        );
        assert_eq!(
            McpError::Authentication("unauthorized".to_string()).category(),
            "auth"
        );
    }

    /// Categories and recoverability for variants not covered above.
    #[test]
    fn test_remaining_categories() {
        assert_eq!(
            McpError::ResourceNotFound("file:///x".to_string()).category(),
            "not_found"
        );
        assert_eq!(
            McpError::PromptNotFound("greet".to_string()).category(),
            "not_found"
        );
        assert_eq!(
            McpError::InvalidUri("::".to_string()).category(),
            "validation"
        );
        assert_eq!(
            McpError::Cancelled("user".to_string()).category(),
            "cancelled"
        );
        assert_eq!(McpError::internal("bug").category(), "internal");
        assert!(!McpError::Cancelled("user".to_string()).is_recoverable());
        assert!(!McpError::internal("bug").is_recoverable());
    }

    /// `#[from] std::io::Error` keeps the source message and is recoverable.
    #[test]
    fn test_io_error_conversion() {
        let source = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "pipe closed");
        let error = McpError::from(source);
        assert_eq!(error.to_string(), "I/O error: pipe closed");
        assert_eq!(error.category(), "io");
        assert!(error.is_recoverable());
    }

    /// `#[from] serde_json::Error` maps to the serialization category.
    #[test]
    fn test_serialization_error_conversion() {
        let source = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let error: McpError = source.into();
        assert!(error.to_string().starts_with("Serialization error: "));
        assert_eq!(error.category(), "serialization");
        assert!(!error.is_recoverable());
    }

    /// `#[from] url::ParseError` is reported as a validation error.
    #[test]
    fn test_url_error_conversion() {
        let source = url::Url::parse("not a url").unwrap_err();
        let error = McpError::from(source);
        assert!(error.to_string().starts_with("URL error: "));
        assert_eq!(error.category(), "validation");
        assert!(!error.is_recoverable());
    }
}