use serde::{Deserialize, Serialize};
/// Boxed dynamic error that is `Send + Sync`.
///
/// Used in this file as the type of [`ToolError::source`].
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// Pre-defined JSON-RPC 2.0 error codes, stored as their `i32` wire values.
///
/// The enum is `#[non_exhaustive]`, so `match` expressions outside this crate
/// must include a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
#[non_exhaustive]
pub enum ErrorCode {
/// Parse error (`-32700`).
ParseError = -32700,
/// Invalid request (`-32600`).
InvalidRequest = -32600,
/// Method not found (`-32601`).
MethodNotFound = -32601,
/// Invalid parameters (`-32602`).
InvalidParams = -32602,
/// Internal error (`-32603`).
InternalError = -32603,
}
/// MCP-specific error codes, stored as their `i32` wire values.
///
/// The enum is `#[non_exhaustive]`, so `match` expressions outside this crate
/// must include a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
#[non_exhaustive]
pub enum McpErrorCode {
/// Connection closed (`-32000`).
ConnectionClosed = -32000,
/// Request timed out (`-32001`).
RequestTimeout = -32001,
/// Resource not found (`-32002`); see [`JsonRpcError::resource_not_found`].
ResourceNotFound = -32002,
/// Already subscribed (`-32003`); see [`JsonRpcError::already_subscribed`].
AlreadySubscribed = -32003,
/// Not subscribed (`-32004`); see [`JsonRpcError::not_subscribed`].
NotSubscribed = -32004,
/// Session not found or expired (`-32005`); see [`JsonRpcError::session_not_found`].
SessionNotFound = -32005,
/// A session id header is required (`-32006`); see [`JsonRpcError::session_required`].
SessionRequired = -32006,
/// Forbidden (`-32007`).
Forbidden = -32007,
/// URL elicitation required (`-32042`).
UrlElicitationRequired = -32042,
}
impl McpErrorCode {
    /// Returns the numeric wire value of this code.
    ///
    /// This is a `const fn`, so the value can also be used to initialise
    /// `const` and `static` items.
    #[must_use]
    pub const fn code(self) -> i32 {
        // The enum is `#[repr(i32)]`, so the cast yields the declared discriminant.
        self as i32
    }
}
impl ErrorCode {
    /// Returns the numeric wire value of this code.
    ///
    /// This is a `const fn`, so the value can also be used to initialise
    /// `const` and `static` items.
    #[must_use]
    pub const fn code(self) -> i32 {
        // The enum is `#[repr(i32)]`, so the cast yields the declared discriminant.
        self as i32
    }
}
/// Error object with a numeric code, a message and an optional data payload.
///
/// Serializable and deserializable through `serde`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
/// Numeric error code; the constructors in this file fill it from
/// [`ErrorCode`] or [`McpErrorCode`].
pub code: i32,
/// Human-readable description of the error.
pub message: String,
/// Optional additional payload; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
}
impl JsonRpcError {
    /// Builds an error from a standard [`ErrorCode`] and a message, with no `data`.
    pub fn new(code: ErrorCode, message: impl Into<String>) -> Self {
        let code = code.code();
        let message = message.into();
        Self {
            code,
            message,
            data: None,
        }
    }

    /// Sets the `data` payload (replacing any existing one) and returns the error.
    pub fn with_data(self, data: serde_json::Value) -> Self {
        Self {
            data: Some(data),
            ..self
        }
    }

    /// Shorthand for [`ErrorCode::ParseError`] with the given message.
    pub fn parse_error(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::ParseError, message)
    }

    /// Shorthand for [`ErrorCode::InvalidRequest`] with the given message.
    pub fn invalid_request(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::InvalidRequest, message)
    }

    /// [`ErrorCode::MethodNotFound`] with a message naming `method`.
    pub fn method_not_found(method: &str) -> Self {
        let message = format!("Method not found: {method}");
        Self::new(ErrorCode::MethodNotFound, message)
    }

    /// Shorthand for [`ErrorCode::InvalidParams`] with the given message.
    pub fn invalid_params(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::InvalidParams, message)
    }

    /// Shorthand for [`ErrorCode::InternalError`] with the given message.
    pub fn internal_error(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::InternalError, message)
    }

    /// Builds an error from an [`McpErrorCode`] and a message, with no `data`.
    pub fn mcp_error(code: McpErrorCode, message: impl Into<String>) -> Self {
        let code = code.code();
        let message = message.into();
        Self {
            code,
            message,
            data: None,
        }
    }

    /// Shorthand for [`McpErrorCode::ConnectionClosed`] with the given message.
    pub fn connection_closed(message: impl Into<String>) -> Self {
        Self::mcp_error(McpErrorCode::ConnectionClosed, message)
    }

    /// Shorthand for [`McpErrorCode::RequestTimeout`] with the given message.
    pub fn request_timeout(message: impl Into<String>) -> Self {
        Self::mcp_error(McpErrorCode::RequestTimeout, message)
    }

    /// [`McpErrorCode::ResourceNotFound`] with a message naming `uri`.
    pub fn resource_not_found(uri: &str) -> Self {
        let message = format!("Resource not found: {uri}");
        Self::mcp_error(McpErrorCode::ResourceNotFound, message)
    }

    /// [`McpErrorCode::AlreadySubscribed`] with a message naming `uri`.
    pub fn already_subscribed(uri: &str) -> Self {
        let message = format!("Already subscribed to: {uri}");
        Self::mcp_error(McpErrorCode::AlreadySubscribed, message)
    }

    /// [`McpErrorCode::NotSubscribed`] with a message naming `uri`.
    pub fn not_subscribed(uri: &str) -> Self {
        let message = format!("Not subscribed to: {uri}");
        Self::mcp_error(McpErrorCode::NotSubscribed, message)
    }

    /// [`McpErrorCode::SessionNotFound`] with a fixed message that names no session.
    pub fn session_not_found() -> Self {
        let message = "Session not found or expired. Please re-initialize the connection.";
        Self::mcp_error(McpErrorCode::SessionNotFound, message)
    }

    /// [`McpErrorCode::SessionNotFound`] with a message that quotes `session_id`.
    pub fn session_not_found_with_id(session_id: &str) -> Self {
        let message = format!(
            "Session '{session_id}' not found or expired. Please re-initialize the connection."
        );
        Self::mcp_error(McpErrorCode::SessionNotFound, message)
    }

    /// [`McpErrorCode::SessionRequired`] with a fixed message about the missing header.
    pub fn session_required() -> Self {
        let message = "MCP-Session-Id header is required for this request.";
        Self::mcp_error(McpErrorCode::SessionRequired, message)
    }

    /// Shorthand for [`McpErrorCode::Forbidden`] with the given message.
    pub fn forbidden(message: impl Into<String>) -> Self {
        Self::mcp_error(McpErrorCode::Forbidden, message)
    }

    /// Shorthand for [`McpErrorCode::UrlElicitationRequired`] with the given message.
    pub fn url_elicitation_required(message: impl Into<String>) -> Self {
        Self::mcp_error(McpErrorCode::UrlElicitationRequired, message)
    }
}
/// Error raised by a tool, optionally tagged with the tool's name and an
/// underlying cause.
#[derive(Debug)]
pub struct ToolError {
/// Name of the tool that failed, when known; included in the `Display` output.
pub tool: Option<String>,
/// Human-readable description of the failure.
pub message: String,
/// Underlying cause, returned from `std::error::Error::source`.
pub source: Option<BoxError>,
}
impl std::fmt::Display for ToolError {
    /// Writes `Tool '<name>' error: <message>` when a tool name is set,
    /// otherwise `Tool error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = &self.message;
        match self.tool.as_deref() {
            Some(name) => write!(f, "Tool '{name}' error: {message}"),
            None => write!(f, "Tool error: {message}"),
        }
    }
}
impl std::error::Error for ToolError {
    /// Returns the wrapped cause, or `None` when no source was attached.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // The typed binding drops the `Send + Sync` bounds from the boxed trait object.
        let cause: &(dyn std::error::Error + 'static) = self.source.as_deref()?;
        Some(cause)
    }
}
impl ToolError {
pub fn new(message: impl Into<String>) -> Self {
Self {
tool: None,
message: message.into(),
source: None,
}
}
pub fn with_tool(tool: impl Into<String>, message: impl Into<String>) -> Self {
Self {
tool: Some(tool.into()),
message: message.into(),
source: None,
}
}
pub fn with_source(mut self, source: impl std::error::Error + Send + Sync + 'static) -> Self {
self.source = Some(Box::new(source));
self
}
}
/// Top-level error type of this module.
///
/// `Display` and `std::error::Error` are derived with `thiserror`. The enum is
/// `#[non_exhaustive]`, so `match` expressions outside this crate must include
/// a wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
/// A JSON-RPC error object.
// NOTE(review): rendered with `{:?}` (Debug output of the struct); no
// `Display` impl for `JsonRpcError` is visible in this file.
#[error("JSON-RPC error: {0:?}")]
JsonRpc(JsonRpcError),
/// A `serde_json` failure; `#[from]` provides `From<serde_json::Error>`.
#[error("Serialization error: {0}")]
Serialization(#[from] serde_json::Error),
/// A tool failure, displayed using the `ToolError` message; `#[from]`
/// provides `From<ToolError>`.
#[error("{0}")]
Tool(#[from] ToolError),
/// A transport failure described by a message.
#[error("Transport error: {0}")]
Transport(String),
/// The session has expired.
#[error("Session expired")]
SessionExpired,
/// An internal failure described by a message.
#[error("Internal error: {0}")]
Internal(String),
}
impl Error {
    /// Builds an [`Error::Tool`] carrying only a message.
    pub fn tool(message: impl Into<String>) -> Self {
        let inner = ToolError::new(message);
        Self::Tool(inner)
    }

    /// Builds an [`Error::Tool`] tagged with the name of the failing tool.
    pub fn tool_with_name(tool: impl Into<String>, message: impl Into<String>) -> Self {
        let inner = ToolError::with_tool(tool, message);
        Self::Tool(inner)
    }

    /// Builds an [`Error::Tool`] whose message is the `Display` text of `err`.
    ///
    /// Only the rendered text is kept; `err` itself is not stored as a source.
    pub fn tool_from<E: std::fmt::Display>(err: E) -> Self {
        let message = err.to_string();
        Self::Tool(ToolError::new(message))
    }

    /// Builds an [`Error::Tool`] whose message is `"<context>: <err>"`.
    pub fn tool_context<E: std::fmt::Display>(context: impl Into<String>, err: E) -> Self {
        let context: String = context.into();
        let message = format!("{context}: {err}");
        Self::Tool(ToolError::new(message))
    }

    /// Builds an [`Error::JsonRpc`] wrapping [`JsonRpcError::invalid_params`].
    pub fn invalid_params(message: impl Into<String>) -> Self {
        let rpc = JsonRpcError::invalid_params(message);
        Self::JsonRpc(rpc)
    }

    /// Builds an [`Error::JsonRpc`] wrapping [`JsonRpcError::internal_error`].
    ///
    /// Note that this yields the `JsonRpc` variant, not [`Error::Internal`].
    pub fn internal(message: impl Into<String>) -> Self {
        let rpc = JsonRpcError::internal_error(message);
        Self::JsonRpc(rpc)
    }
}
/// Extension methods for turning a `Result`'s error into this module's [`Error`].
pub trait ResultExt<T> {
/// Converts the error side into an [`Error`]; an `Ok` value is passed through.
fn tool_err(self) -> std::result::Result<T, Error>;
/// Like [`ResultExt::tool_err`], but takes a `context` to be combined with the error.
fn tool_context(self, context: impl Into<String>) -> std::result::Result<T, Error>;
}
impl<T, E: std::fmt::Display> ResultExt<T> for std::result::Result<T, E> {
fn tool_err(self) -> std::result::Result<T, Error> {
self.map_err(Error::tool_from)
}
fn tool_context(self, context: impl Into<String>) -> std::result::Result<T, Error> {
self.map_err(|e| Error::tool_context(context, e))
}
}
impl From<JsonRpcError> for Error {
fn from(err: JsonRpcError) -> Self {
Error::JsonRpc(err)
}
}
impl From<std::convert::Infallible> for Error {
    /// Never actually called: `Infallible` has no values, so the empty `match`
    /// is exhaustive.
    fn from(never: std::convert::Infallible) -> Self {
        match never {}
    }
}
/// Shorthand for `std::result::Result` with this module's [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
// Unit tests for `BoxError`, `ToolError::with_source` and the `ResultExt` helpers.
#[cfg(test)]
mod tests {
use super::*;
// An `std::io::Error` converts into `BoxError` and keeps its message.
#[test]
fn test_box_error_from_io_error() {
let io_err = std::io::Error::other("disk full");
let boxed: BoxError = io_err.into();
assert_eq!(boxed.to_string(), "disk full");
}
// A string slice converts into `BoxError` and keeps its text.
#[test]
fn test_box_error_from_string() {
let err: BoxError = "something went wrong".into();
assert_eq!(err.to_string(), "something went wrong");
}
// Compile-time check: `BoxError` satisfies `Send + Sync`.
#[test]
fn test_box_error_is_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<BoxError>();
}
// `with_source` stores the cause in the `source` field.
#[test]
fn test_tool_error_source_uses_box_error() {
let io_err = std::io::Error::other("timeout");
let tool_err = ToolError::new("failed").with_source(io_err);
assert!(tool_err.source.is_some());
assert_eq!(tool_err.source.unwrap().to_string(), "timeout");
}
// `tool_err` maps an error into `Error::Tool` and keeps its message text.
#[test]
fn test_result_ext_tool_err() {
let result: std::result::Result<(), std::io::Error> =
Err(std::io::Error::other("disk full"));
let err = result.tool_err().unwrap_err();
assert!(matches!(err, Error::Tool(_)));
assert!(err.to_string().contains("disk full"));
}
// `tool_context` keeps both the context and the original error text.
#[test]
fn test_result_ext_tool_context() {
let result: std::result::Result<(), std::io::Error> =
Err(std::io::Error::other("connection refused"));
let err = result.tool_context("database query failed").unwrap_err();
assert!(matches!(err, Error::Tool(_)));
assert!(err.to_string().contains("database query failed"));
assert!(err.to_string().contains("connection refused"));
}
// `Ok` values pass through both helpers unchanged.
#[test]
fn test_result_ext_ok_passes_through() {
let result: std::result::Result<i32, std::io::Error> = Ok(42);
assert_eq!(result.tool_err().unwrap(), 42);
let result: std::result::Result<i32, std::io::Error> = Ok(42);
assert_eq!(result.tool_context("should not appear").unwrap(), 42);
}
}