use thiserror::Error;
/// Unified error type for this MCP implementation.
///
/// Every variant carries only a `String` message. The `Display` text is generated by
/// `thiserror` and prefixes the message with a label such as `Transport error: `.
/// Foreign errors are converted by `From` impls that keep just their `Display` output, so
/// the original source error is not retained (presumably so this type can derive `Clone`).
///
/// Use [`McpError::is_recoverable`] and [`McpError::category`] to classify an error.
#[derive(Error, Debug, Clone)]
pub enum McpError {
/// Transport-level failure; built by [`McpError::transport`]. Not recoverable.
#[error("Transport error: {0}")]
Transport(String),
/// Protocol-level failure; built by [`McpError::protocol`] or [`McpError::protocol_error`].
/// Not recoverable.
#[error("Protocol error: {0}")]
Protocol(String),
/// Serialization failure; produced from `serde_json::Error` (message text only).
/// Not recoverable.
#[error("Serialization error: {0}")]
Serialization(String),
/// Invalid URI. Not recoverable; categorised as `"validation"`.
#[error("Invalid URI: {0}")]
InvalidUri(String),
/// Requested tool was not found. Not recoverable; categorised as `"not_found"`.
#[error("Tool not found: {0}")]
ToolNotFound(String),
/// Requested resource was not found. Not recoverable; categorised as `"not_found"`.
#[error("Resource not found: {0}")]
ResourceNotFound(String),
/// Requested prompt was not found. Not recoverable; categorised as `"not_found"`.
#[error("Prompt not found: {0}")]
PromptNotFound(String),
/// Connection failure; built by [`McpError::connection`] or [`McpError::connection_error`].
/// Recoverable.
#[error("Connection error: {0}")]
Connection(String),
/// Authentication failure. Not recoverable; categorised as `"auth"`.
#[error("Authentication error: {0}")]
Authentication(String),
/// Validation failure; built by [`McpError::validation`] or [`McpError::validation_error`].
/// Not recoverable.
#[error("Validation error: {0}")]
Validation(String),
/// I/O failure; produced from `std::io::Error` (message text only). Recoverable.
#[error("I/O error: {0}")]
Io(String),
/// URL parse failure; produced from `url::ParseError` (message text only).
/// Not recoverable; categorised as `"validation"`.
#[error("URL error: {0}")]
Url(String),
/// HTTP failure (only with the `http` feature); produced from `reqwest::Error`.
/// Recoverable.
#[cfg(feature = "http")]
#[error("HTTP error: {0}")]
Http(String),
/// WebSocket failure (only with the `websocket` feature); produced from
/// `tokio_tungstenite::tungstenite::Error`. Recoverable.
#[cfg(feature = "websocket")]
#[error("WebSocket error: {0}")]
WebSocket(String),
/// Schema validation failure (only with the `validation` feature).
/// Not recoverable; categorised as `"validation"`.
#[cfg(feature = "validation")]
#[error("Schema validation error: {0}")]
SchemaValidation(String),
/// Timeout; built by [`McpError::timeout`] or [`McpError::timeout_error`]. Recoverable.
#[error("Timeout error: {0}")]
Timeout(String),
/// Operation was cancelled. Not recoverable.
#[error("Operation cancelled: {0}")]
Cancelled(String),
/// Internal error; built by [`McpError::internal`]. Not recoverable.
#[error("Internal error: {0}")]
Internal(String),
}
impl From<serde_json::Error> for McpError {
fn from(err: serde_json::Error) -> Self {
McpError::Serialization(err.to_string())
}
}
impl From<std::io::Error> for McpError {
fn from(err: std::io::Error) -> Self {
McpError::Io(err.to_string())
}
}
impl From<url::ParseError> for McpError {
    /// Wraps a URL parse failure as [`McpError::Url`], keeping only the source
    /// error's `Display` text.
    fn from(source: url::ParseError) -> Self {
        let message = source.to_string();
        Self::Url(message)
    }
}
/// Convenience alias for a `Result` whose error type is [`McpError`].
pub type McpResult<T> = std::result::Result<T, McpError>;
impl McpError {
    /// Creates a [`McpError::Transport`] from any string-like message.
    pub fn transport<S: Into<String>>(message: S) -> Self {
        Self::Transport(message.into())
    }

    /// Creates a [`McpError::Protocol`] from any string-like message.
    pub fn protocol<S: Into<String>>(message: S) -> Self {
        Self::Protocol(message.into())
    }

    /// Creates a [`McpError::Validation`] from any string-like message.
    pub fn validation<S: Into<String>>(message: S) -> Self {
        Self::Validation(message.into())
    }

    /// Creates a [`McpError::Connection`] from any string-like message.
    pub fn connection<S: Into<String>>(message: S) -> Self {
        Self::Connection(message.into())
    }

    /// Creates a [`McpError::Internal`] from any string-like message.
    pub fn internal<S: Into<String>>(message: S) -> Self {
        Self::Internal(message.into())
    }

    /// Converts an I/O error into [`McpError::Io`].
    ///
    /// Equivalent to `McpError::from(err)`. It delegates to that `From` impl so both
    /// entry points always produce the same value.
    pub fn io(err: std::io::Error) -> Self {
        Self::from(err)
    }

    /// Converts a `serde_json` error into [`McpError::Serialization`].
    ///
    /// Equivalent to `McpError::from(err)`. It delegates to that `From` impl so both
    /// entry points always produce the same value.
    pub fn serialization(err: serde_json::Error) -> Self {
        Self::from(err)
    }

    /// Creates a [`McpError::Timeout`] from any string-like message.
    pub fn timeout<S: Into<String>>(message: S) -> Self {
        Self::Timeout(message.into())
    }

    /// Alias of [`McpError::connection`].
    pub fn connection_error<S: Into<String>>(message: S) -> Self {
        Self::connection(message)
    }

    /// Alias of [`McpError::protocol`].
    pub fn protocol_error<S: Into<String>>(message: S) -> Self {
        Self::protocol(message)
    }

    /// Alias of [`McpError::validation`].
    pub fn validation_error<S: Into<String>>(message: S) -> Self {
        Self::validation(message)
    }

    /// Creates a [`McpError::Timeout`] with the fixed message `"Operation timed out"`.
    pub fn timeout_error() -> Self {
        Self::timeout("Operation timed out")
    }

    /// Reports whether this error is classified as recoverable.
    ///
    /// The following are recoverable:
    /// - `Connection`, `Timeout` and `Io`;
    /// - `Http` and `WebSocket`, when their features are enabled.
    ///
    /// Every other variant is not recoverable.
    #[must_use]
    pub fn is_recoverable(&self) -> bool {
        // The match is exhaustive on purpose (no `_` arm), so adding a variant forces
        // an explicit classification. `#[cfg]` cannot be applied to one alternative of
        // an or-pattern, so each feature-gated variant gets its own arm.
        match self {
            Self::Connection(_) | Self::Timeout(_) | Self::Io(_) => true,
            #[cfg(feature = "http")]
            Self::Http(_) => true,
            #[cfg(feature = "websocket")]
            Self::WebSocket(_) => true,
            Self::Transport(_)
            | Self::Protocol(_)
            | Self::Validation(_)
            | Self::ToolNotFound(_)
            | Self::ResourceNotFound(_)
            | Self::PromptNotFound(_)
            | Self::Authentication(_)
            | Self::Serialization(_)
            | Self::InvalidUri(_)
            | Self::Url(_)
            | Self::Cancelled(_)
            | Self::Internal(_) => false,
            #[cfg(feature = "validation")]
            Self::SchemaValidation(_) => false,
        }
    }

    /// Returns a short label grouping related variants.
    ///
    /// Several variants share a label:
    /// - `"validation"` covers `Validation`, `InvalidUri`, `Url` and `SchemaValidation`;
    /// - `"not_found"` covers the three `*NotFound` variants.
    #[must_use]
    pub fn category(&self) -> &'static str {
        // Exhaustive on purpose; see `is_recoverable` for why cfg'd variants have
        // their own arms.
        match self {
            Self::Transport(_) => "transport",
            Self::Protocol(_) => "protocol",
            Self::Connection(_) => "connection",
            Self::Timeout(_) => "timeout",
            Self::Validation(_) | Self::InvalidUri(_) | Self::Url(_) => "validation",
            #[cfg(feature = "validation")]
            Self::SchemaValidation(_) => "validation",
            Self::ToolNotFound(_) | Self::ResourceNotFound(_) | Self::PromptNotFound(_) => {
                "not_found"
            }
            Self::Authentication(_) => "auth",
            Self::Serialization(_) => "serialization",
            Self::Io(_) => "io",
            #[cfg(feature = "http")]
            Self::Http(_) => "http",
            #[cfg(feature = "websocket")]
            Self::WebSocket(_) => "websocket",
            Self::Cancelled(_) => "cancelled",
            Self::Internal(_) => "internal",
        }
    }
}
#[cfg(feature = "http")]
impl From<reqwest::Error> for McpError {
    /// Wraps an HTTP client failure as [`McpError::Http`], keeping only the source
    /// error's `Display` text. Available only with the `http` feature.
    fn from(source: reqwest::Error) -> Self {
        let message = source.to_string();
        Self::Http(message)
    }
}
#[cfg(feature = "websocket")]
impl From<tokio_tungstenite::tungstenite::Error> for McpError {
    /// Wraps a WebSocket failure as [`McpError::WebSocket`], keeping only the source
    /// error's `Display` text. Available only with the `websocket` feature.
    fn from(source: tokio_tungstenite::tungstenite::Error) -> Self {
        let message = source.to_string();
        Self::WebSocket(message)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_creation() {
        let error = McpError::transport("Connection failed");
        assert_eq!(error.to_string(), "Transport error: Connection failed");
        assert_eq!(error.category(), "transport");
        assert!(!error.is_recoverable());
    }

    #[test]
    fn test_error_recovery() {
        assert!(McpError::connection("timeout").is_recoverable());
        assert!(!McpError::validation("invalid input").is_recoverable());
        assert!(McpError::timeout("request timeout").is_recoverable());
    }

    #[test]
    fn test_error_categories() {
        assert_eq!(McpError::protocol("bad message").category(), "protocol");
        assert_eq!(
            McpError::ToolNotFound("missing".to_string()).category(),
            "not_found"
        );
        assert_eq!(
            McpError::Authentication("unauthorized".to_string()).category(),
            "auth"
        );
    }

    /// The `*_error` aliases must build exactly the same variants as the primary
    /// constructors.
    #[test]
    fn test_alias_constructors() {
        assert!(matches!(
            McpError::connection_error("a"),
            McpError::Connection(ref m) if m == "a"
        ));
        assert!(matches!(
            McpError::protocol_error("b"),
            McpError::Protocol(ref m) if m == "b"
        ));
        assert!(matches!(
            McpError::validation_error("c"),
            McpError::Validation(ref m) if m == "c"
        ));
        assert_eq!(
            McpError::timeout_error().to_string(),
            "Timeout error: Operation timed out"
        );
    }

    /// Converting `std::io::Error` keeps its message and yields a recoverable `Io` error.
    #[test]
    fn test_io_conversion() {
        let source = std::io::Error::new(std::io::ErrorKind::NotFound, "boom");
        let error = McpError::from(source);
        assert_eq!(error.to_string(), "I/O error: boom");
        assert_eq!(error.category(), "io");
        assert!(error.is_recoverable());

        let via_helper = McpError::io(std::io::Error::new(std::io::ErrorKind::NotFound, "boom"));
        assert_eq!(via_helper.to_string(), "I/O error: boom");
    }

    /// Converting a `serde_json` error yields a non-recoverable `Serialization` error.
    #[test]
    fn test_serde_json_conversion() {
        let source = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let error: McpError = source.into();
        assert!(matches!(error, McpError::Serialization(_)));
        assert!(error.to_string().starts_with("Serialization error: "));
        assert_eq!(error.category(), "serialization");
        assert!(!error.is_recoverable());
    }

    /// Converting a URL parse error yields a `Url` error in the `validation` category.
    #[test]
    fn test_url_conversion() {
        let source = url::Url::parse("not a url").unwrap_err();
        let error = McpError::from(source);
        assert!(matches!(error, McpError::Url(_)));
        assert!(error.to_string().starts_with("URL error: "));
        assert_eq!(error.category(), "validation");
        assert!(!error.is_recoverable());
    }
}