1use std::time::Duration;
16use thiserror::Error;
17
/// Top-level error type for MCP operations.
///
/// Wraps the domain-specific error enums defined alongside it (transport,
/// protocol, validation, auth, config) and adds variants for timeouts,
/// `serde_json` serialization failures, I/O failures, and internal errors.
/// The `#[from]` attributes make thiserror generate `From` impls, so `?`
/// converts the wrapped error types into `McpError` automatically.
///
/// NOTE(review): the `#[from]` variants also interpolate the inner error into
/// their Display message, and thiserror exposes the same inner error through
/// `Error::source`. Reporters that walk the source chain may therefore print
/// the inner message twice.
#[derive(Error, Debug)]
pub enum McpError {
    /// A transport-layer failure (see [`TransportError`]).
    #[error("Transport error: {0}")]
    Transport(#[from] TransportError),

    /// A protocol-level failure (see [`ProtocolError`]).
    #[error("Protocol error: {0}")]
    Protocol(#[from] ProtocolError),

    /// A validation failure (see [`ValidationError`]).
    #[error("Validation error: {0}")]
    Validation(#[from] ValidationError),

    /// An authentication/authorization failure (see [`AuthError`]).
    #[error("Authentication error: {0}")]
    Auth(#[from] AuthError),

    /// An operation exceeded its time budget; build with [`McpError::timeout`].
    #[error("Operation timed out after {duration_ms}ms: {operation}")]
    Timeout {
        /// Human-readable name of the operation that timed out.
        operation: String,
        /// The exceeded time budget, in milliseconds.
        duration_ms: u64,
    },

    /// A configuration failure (see [`ConfigError`]).
    #[error("Configuration error: {0}")]
    Config(#[from] ConfigError),

    /// A JSON (de)serialization failure reported by `serde_json`.
    #[error("Serialization error: {source}")]
    Serialization {
        /// The underlying `serde_json` error.
        #[from]
        source: serde_json::Error,
    },

    /// An I/O failure reported by the standard library.
    #[error("IO error: {source}")]
    Io {
        /// The underlying I/O error.
        #[from]
        source: std::io::Error,
    },

    /// An unexpected internal failure; build with [`McpError::internal`].
    #[error("Internal error: {message}")]
    Internal {
        /// Free-form description of what went wrong.
        message: String,
    },
}
89
/// Failures in the transport layer that carries MCP messages.
///
/// Most variants record `transport_type`, a short label for the transport in
/// use (this file uses values such as `"http"` and `"stdio"`), together with a
/// free-form `reason`. Whether a variant is worth retrying is decided by
/// [`TransportError::is_retryable`].
///
/// NOTE(review): several variants overlap in meaning (`ConnectionFailed` /
/// `ConnectionError`, `ConnectionLost` / `DisconnectedError`); confirm which
/// one new code is expected to use.
#[derive(Error, Debug, Clone)]
#[allow(missing_docs)]
pub enum TransportError {
    /// Establishing a connection to the server failed.
    #[error("Failed to connect to {transport_type} server: {reason}")]
    ConnectionFailed {
        transport_type: String,
        reason: String,
    },

    /// An established connection to the server was lost.
    #[error("Connection lost to {transport_type} server: {reason}")]
    ConnectionLost {
        transport_type: String,
        reason: String,
    },

    /// Sending a message failed.
    #[error("Failed to send message via {transport_type}: {reason}")]
    SendFailed {
        transport_type: String,
        reason: String,
    },

    /// Receiving a message failed.
    #[error("Failed to receive message via {transport_type}: {reason}")]
    ReceiveFailed {
        transport_type: String,
        reason: String,
    },

    /// The transport's configuration is invalid.
    #[error("Invalid {transport_type} configuration: {reason}")]
    InvalidConfig {
        transport_type: String,
        reason: String,
    },

    /// A process-level failure (presumably the child process of a process-based
    /// transport such as stdio — confirm against callers).
    #[error("Process error: {reason}")]
    ProcessError { reason: String },

    /// An HTTP-level failure. `From<reqwest::Error>` in this file uses
    /// `status_code == 0` when the failure carried no HTTP status.
    #[error("HTTP error: {status_code} - {reason}")]
    HttpError { status_code: u16, reason: String },

    /// A Server-Sent Events failure.
    #[error("SSE error: {reason}")]
    SseError { reason: String },

    /// A failure while streaming data.
    #[error("Streaming error: {reason}")]
    StreamingError { reason: String },

    /// An operation was attempted on a transport that is not connected.
    #[error("Transport not connected ({transport_type}): {reason}")]
    NotConnected {
        transport_type: String,
        reason: String,
    },

    /// A generic network failure.
    #[error("Network error ({transport_type}): {reason}")]
    NetworkError {
        transport_type: String,
        reason: String,
    },

    /// Encoding or decoding a message at the transport level failed.
    #[error("Serialization error ({transport_type}): {reason}")]
    SerializationError {
        transport_type: String,
        reason: String,
    },

    /// A transport operation timed out; `reason` carries the details.
    #[error("Operation timed out ({transport_type}): {reason}")]
    TimeoutError {
        transport_type: String,
        reason: String,
    },

    /// The transport became disconnected.
    #[error("Transport disconnected ({transport_type}): {reason}")]
    DisconnectedError {
        transport_type: String,
        reason: String,
    },

    /// A generic connection failure.
    #[error("Connection error ({transport_type}): {reason}")]
    ConnectionError {
        transport_type: String,
        reason: String,
    },
}
190
/// Failures while exchanging MCP / JSON-RPC protocol messages.
///
/// NOTE(review): timeouts can be expressed here (`TimeoutError`,
/// `RequestTimeout`), in `TransportError::TimeoutError`, and in
/// `McpError::Timeout`; confirm which layer should own a given timeout.
#[derive(Error, Debug, Clone)]
#[allow(missing_docs)]
pub enum ProtocolError {
    /// A message could not be interpreted as valid JSON-RPC.
    #[error("Invalid JSON-RPC message: {reason}")]
    InvalidJsonRpc { reason: String },

    /// `version` is not one of the `supported` protocol versions.
    #[error("Unsupported protocol version: {version}, supported versions: {supported:?}")]
    UnsupportedVersion {
        version: String,
        supported: Vec<String>,
    },

    /// A message ID differed from the one that was expected.
    #[error("Message ID mismatch: expected {expected}, got {actual}")]
    MessageIdMismatch { expected: String, actual: String },

    /// A message of a different type than expected was received.
    #[error("Unexpected message type: expected {expected}, got {actual}")]
    UnexpectedMessageType { expected: String, actual: String },

    /// A required `field` is absent from a message of kind `message_type`.
    #[error("Missing required field '{field}' in {message_type}")]
    MissingField { field: String, message_type: String },

    /// The method name is not valid.
    #[error("Invalid method name: {method}")]
    InvalidMethod { method: String },

    /// A server-side error with a numeric `code` (presumably a JSON-RPC error
    /// code — confirm against the code that constructs it).
    #[error("Server error {code}: {message}")]
    ServerError { code: i32, message: String },

    /// The protocol state machine was violated; `reason` describes how.
    #[error("Protocol state violation: {reason}")]
    StateViolation { reason: String },

    /// Protocol initialization failed.
    #[error("Protocol initialization failed: {reason}")]
    InitializationFailed { reason: String },

    /// An operation required an initialized protocol session.
    #[error("Protocol not initialized: {reason}")]
    NotInitialized { reason: String },

    /// A response could not be accepted.
    #[error("Invalid response: {reason}")]
    InvalidResponse { reason: String },

    /// Protocol-level configuration is invalid.
    #[error("Protocol configuration error: {reason}")]
    InvalidConfig { reason: String },

    /// The named protocol `operation` did not finish within `timeout`.
    #[error("Protocol operation '{operation}' timed out after {timeout:?}")]
    TimeoutError {
        operation: String,
        timeout: std::time::Duration,
    },

    /// A request failed; `reason` describes why.
    #[error("Request failed: {reason}")]
    RequestFailed { reason: String },

    /// A request did not finish within `timeout`.
    #[error("Request timed out after {timeout:?}")]
    RequestTimeout { timeout: Duration },
}
264
/// Failures raised when data or requested features fail validation.
#[derive(Error, Debug, Clone)]
#[allow(missing_docs)]
pub enum ValidationError {
    /// An object of kind `object_type` did not satisfy its schema.
    #[error("Schema validation failed for {object_type}: {reason}")]
    SchemaValidation { object_type: String, reason: String },

    /// The named `capability` is not supported by the server.
    #[error("Capability '{capability}' not supported by server")]
    UnsupportedCapability { capability: String },

    /// `parameter` passed to `tool` is invalid.
    #[error("Invalid parameter '{parameter}' for tool '{tool}': {reason}")]
    InvalidToolParameter {
        tool: String,
        parameter: String,
        reason: String,
    },

    /// The named `resource` is invalid.
    #[error("Invalid resource '{resource}': {reason}")]
    InvalidResource { resource: String, reason: String },

    /// The named `prompt` is invalid.
    #[error("Invalid prompt '{prompt}': {reason}")]
    InvalidPrompt { prompt: String, reason: String },

    /// A named `constraint` was violated.
    #[error("Constraint violation: {constraint} - {reason}")]
    ConstraintViolation { constraint: String, reason: String },
}
300
/// Authentication and authorization failures.
///
/// `auth_type` is a free-form label for the scheme in use (the tests in this
/// file use `"Bearer"`).
#[derive(Error, Debug, Clone)]
#[allow(missing_docs)]
pub enum AuthError {
    /// No credentials were available for `auth_type`.
    #[error("Missing authentication credentials for {auth_type}")]
    MissingCredentials { auth_type: String },

    /// Credentials for `auth_type` were rejected or malformed.
    #[error("Invalid {auth_type} credentials: {reason}")]
    InvalidCredentials { auth_type: String, reason: String },

    /// Authentication for `auth_type` has expired.
    #[error("Authentication expired for {auth_type}")]
    Expired { auth_type: String },

    /// Access was denied.
    #[error("Access denied: {reason}")]
    AccessDenied { reason: String },

    /// An OAuth failure with its error code and description.
    #[error("OAuth error: {error_code} - {description}")]
    OAuth {
        error_code: String,
        description: String,
    },

    /// A JSON Web Token failure.
    #[error("JWT error: {reason}")]
    Jwt { reason: String },
}
335
/// Failures while loading or interpreting configuration.
#[derive(Error, Debug, Clone)]
#[allow(missing_docs)]
pub enum ConfigError {
    /// No configuration file exists at `path`.
    #[error("Configuration file not found: {path}")]
    FileNotFound { path: String },

    /// The configuration at `path` could not be parsed.
    #[error("Invalid configuration format in {path}: {reason}")]
    InvalidFormat { path: String, reason: String },

    /// A required configuration `parameter` is absent.
    #[error("Missing required configuration parameter: {parameter}")]
    MissingParameter { parameter: String },

    /// `parameter` was given an unacceptable `value`.
    #[error("Invalid value for parameter '{parameter}': {value} - {reason}")]
    InvalidValue {
        parameter: String,
        value: String,
        reason: String,
    },

    /// Configuration settings contradict one another.
    #[error("Conflicting configuration: {reason}")]
    Conflict { reason: String },
}
367
/// Shorthand for `Result<T, McpError>`.
pub type McpResult<T> = Result<T, McpError>;
370
371impl McpError {
372    pub fn internal(message: impl Into<String>) -> Self {
385        Self::Internal {
386            message: message.into(),
387        }
388    }
389
390    pub fn timeout(operation: impl Into<String>, duration: std::time::Duration) -> Self {
401        Self::Timeout {
402            operation: operation.into(),
403            duration_ms: duration.as_millis() as u64,
404        }
405    }
406
407    pub fn is_retryable(&self) -> bool {
429        match self {
430            McpError::Transport(transport_err) => transport_err.is_retryable(),
431            McpError::Timeout { .. } => true,
432            McpError::Io { .. } => true,
433            McpError::Auth(_) => false,
434            McpError::Protocol(_) => false,
435            McpError::Validation(_) => false,
436            McpError::Config(_) => false,
437            McpError::Serialization { .. } => false,
438            McpError::Internal { .. } => false,
439        }
440    }
441
442    pub fn category(&self) -> &'static str {
446        match self {
447            McpError::Transport(_) => "transport",
448            McpError::Protocol(_) => "protocol",
449            McpError::Validation(_) => "validation",
450            McpError::Auth(_) => "auth",
451            McpError::Timeout { .. } => "timeout",
452            McpError::Config(_) => "config",
453            McpError::Serialization { .. } => "serialization",
454            McpError::Io { .. } => "io",
455            McpError::Internal { .. } => "internal",
456        }
457    }
458}
459
460impl TransportError {
461    pub fn is_retryable(&self) -> bool {
463        match self {
464            TransportError::ConnectionFailed { .. } => true,
465            TransportError::ConnectionLost { .. } => true,
466            TransportError::ConnectionError { .. } => true,
467            TransportError::SendFailed { .. } => true,
468            TransportError::ReceiveFailed { .. } => true,
469            TransportError::NetworkError { .. } => true,
470            TransportError::TimeoutError { .. } => true,
471            TransportError::DisconnectedError { .. } => true,
472            TransportError::HttpError { status_code, .. } => {
473                *status_code >= 500
475            }
476            TransportError::SseError { .. } => true,
477            TransportError::StreamingError { .. } => true,
478            TransportError::ProcessError { .. } => false,
479            TransportError::InvalidConfig { .. } => false,
480            TransportError::NotConnected { .. } => false,
481            TransportError::SerializationError { .. } => false,
482        }
483    }
484}
485
486impl From<reqwest::Error> for McpError {
487    fn from(err: reqwest::Error) -> Self {
488        if err.is_timeout() {
489            McpError::timeout("HTTP request", std::time::Duration::from_secs(30))
490        } else if err.is_connect() {
491            McpError::Transport(TransportError::ConnectionFailed {
492                transport_type: "http".to_string(),
493                reason: err.to_string(),
494            })
495        } else if let Some(status) = err.status() {
496            McpError::Transport(TransportError::HttpError {
497                status_code: status.as_u16(),
498                reason: err.to_string(),
499            })
500        } else {
501            McpError::Transport(TransportError::HttpError {
502                status_code: 0,
503                reason: err.to_string(),
504            })
505        }
506    }
507}
508
509impl From<url::ParseError> for McpError {
510    fn from(err: url::ParseError) -> Self {
511        McpError::Config(ConfigError::InvalidValue {
512            parameter: "url".to_string(),
513            value: err.to_string(),
514            reason: "Invalid URL format".to_string(),
515        })
516    }
517}
518
#[cfg(test)]
mod tests {
    //! Unit tests for error construction, display, retryability, and categories.
    use super::*;
    use std::time::Duration;

    /// Thirty-second budget shared by the timeout-based checks below.
    const THIRTY_SECS: Duration = Duration::from_secs(30);

    #[test]
    fn test_error_display() {
        let rendered = McpError::timeout("test operation", THIRTY_SECS).to_string();
        assert_eq!(rendered, "Operation timed out after 30000ms: test operation");
    }

    #[test]
    fn test_retryable_errors() {
        assert!(McpError::timeout("test", THIRTY_SECS).is_retryable());

        let bad_token = AuthError::InvalidCredentials {
            auth_type: "Bearer".to_string(),
            reason: "Invalid token".to_string(),
        };
        assert!(!McpError::Auth(bad_token).is_retryable());
    }

    #[test]
    fn test_error_categories() {
        assert_eq!(McpError::timeout("test", THIRTY_SECS).category(), "timeout");

        let stdio_failure = McpError::from(TransportError::ConnectionFailed {
            transport_type: "stdio".to_string(),
            reason: "Process failed".to_string(),
        });
        assert_eq!(stdio_failure.category(), "transport");
    }

    #[test]
    fn test_transport_error_retryable() {
        // (error, expected retryability) pairs.
        let cases = [
            (
                TransportError::ConnectionFailed {
                    transport_type: "stdio".to_string(),
                    reason: "Process failed".to_string(),
                },
                true,
            ),
            (
                TransportError::InvalidConfig {
                    transport_type: "stdio".to_string(),
                    reason: "Missing command".to_string(),
                },
                false,
            ),
        ];

        for (err, expected) in &cases {
            assert_eq!(
                err.is_retryable(),
                *expected,
                "unexpected retryability for {:?}",
                err
            );
        }
    }
}