1use serde_json::Value;
6use std::collections::HashMap;
7use std::fmt;
8use thiserror::Error;
9
10use super::types::JsonRpcError;
11
/// Coarse category of an [`McpError`].
///
/// The category selects the JSON-RPC numeric code used by
/// [`McpError::to_jsonrpc`] and is rendered as a stable `snake_case` name by
/// [`ErrorType::as_str`] and the `Display` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorType {
    /// Invalid input or parameters (`"validation"`).
    Validation,
    /// A requested entity could not be found (`"not_found"`).
    NotFound,
    /// The requested JSON-RPC method does not exist (`"method_not_found"`).
    MethodNotFound,
    /// The caller is not authorized (`"unauthorized"`).
    Unauthorized,
    /// Internal / unexpected failure (`"internal"`).
    Internal,
    /// The operation timed out (`"timeout"`).
    Timeout,
    /// The operation was cancelled (`"cancelled"`).
    Cancelled,
    /// A rate limit was hit (`"rate_limit"`).
    RateLimit,
    /// A conflicting state was detected (`"conflict"`).
    Conflict,
    /// A circuit breaker is open (`"circuit_open"`).
    CircuitOpen,
}

impl ErrorType {
    /// Returns the stable `snake_case` name of this category.
    ///
    /// This is the exact text produced by `Display` (and therefore by
    /// `to_string()`), available without allocating.
    pub const fn as_str(self) -> &'static str {
        match self {
            ErrorType::Validation => "validation",
            ErrorType::NotFound => "not_found",
            ErrorType::MethodNotFound => "method_not_found",
            ErrorType::Unauthorized => "unauthorized",
            ErrorType::Internal => "internal",
            ErrorType::Timeout => "timeout",
            ErrorType::Cancelled => "cancelled",
            ErrorType::RateLimit => "rate_limit",
            ErrorType::Conflict => "conflict",
            ErrorType::CircuitOpen => "circuit_open",
        }
    }
}

impl fmt::Display for ErrorType {
    /// Writes [`ErrorType::as_str`].
    ///
    /// Uses `Formatter::pad` rather than `write!` so that width, fill,
    /// alignment and precision flags (e.g. `format!("{:<12}", t)`) are honoured.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
53
/// Structured error used by the MCP layer.
///
/// Carries a coarse [`ErrorType`] category, a machine-readable `code`, a
/// human-readable `message`, free-form `details`, and an optional underlying
/// `cause`. Create one with [`McpError::builder`] or one of the convenience
/// constructors, and convert it for the wire with [`McpError::to_jsonrpc`].
///
/// `thiserror` derives `std::error::Error` here; `Display` is implemented by
/// hand rather than through an `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub struct McpError {
    /// Category of the failure; selects the numeric JSON-RPC code in `to_jsonrpc`.
    pub error_type: ErrorType,
    /// Machine-readable error code string (e.g. `"method_not_found"`, or one of
    /// the [`codes`] constants).
    pub code: String,
    /// Human-readable description; also used as the JSON-RPC error message.
    pub message: String,
    /// Extra structured context. `to_jsonrpc` copies it into the JSON-RPC `data`
    /// object, where the reserved `"type"` and `"code"` keys overwrite any
    /// same-named entries.
    pub details: HashMap<String, Value>,
    /// Optional underlying error, exposed through `std::error::Error::source`
    /// via thiserror's `#[source]` attribute.
    #[source]
    pub cause: Option<Box<dyn std::error::Error + Send + Sync>>,
}
69
70impl fmt::Display for McpError {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 write!(f, "[{}] {}: {}", self.error_type, self.code, self.message)
73 }
74}
75
76impl McpError {
77 pub fn to_jsonrpc(&self) -> JsonRpcError {
79 let code = match self.error_type {
80 ErrorType::Validation => -32602, ErrorType::NotFound => -32002, ErrorType::MethodNotFound => -32601, ErrorType::Unauthorized => -32001, ErrorType::Internal => -32603, ErrorType::Timeout => -32000, ErrorType::Cancelled => -32000, ErrorType::RateLimit => -32000, ErrorType::Conflict => -32000, ErrorType::CircuitOpen => -32000, };
91
92 let mut data = self.details.clone();
93 data.insert(
94 "type".to_string(),
95 Value::String(self.error_type.to_string()),
96 );
97 data.insert("code".to_string(), Value::String(self.code.clone()));
98
99 JsonRpcError::new(code, &self.message).with_data(serde_json::to_value(data).unwrap())
100 }
101
102 pub fn builder(error_type: ErrorType, code: impl Into<String>) -> McpErrorBuilder {
104 McpErrorBuilder {
105 error_type,
106 code: code.into(),
107 message: String::new(),
108 details: HashMap::new(),
109 cause: None,
110 }
111 }
112
113 pub fn validation(code: impl Into<String>, message: impl Into<String>) -> Self {
115 Self::builder(ErrorType::Validation, code)
116 .message(message)
117 .build()
118 }
119
120 pub fn not_found(code: impl Into<String>, message: impl Into<String>) -> Self {
122 Self::builder(ErrorType::NotFound, code)
123 .message(message)
124 .build()
125 }
126
127 pub fn internal(code: impl Into<String>, message: impl Into<String>) -> Self {
129 Self::builder(ErrorType::Internal, code)
130 .message(message)
131 .build()
132 }
133
134 pub fn method_not_found(method: impl Into<String>) -> Self {
136 let method = method.into();
137 Self::builder(ErrorType::MethodNotFound, "method_not_found")
138 .message(format!("Method not found: {}", method))
139 .detail("method", method)
140 .build()
141 }
142
143 pub fn not_implemented(message: impl Into<String>) -> Self {
145 Self::builder(ErrorType::Internal, "not_implemented")
146 .message(message)
147 .build()
148 }
149
150 pub fn rate_limited(message: impl Into<String>) -> Self {
152 Self::builder(ErrorType::RateLimit, "rate_limit_exceeded")
153 .message(message)
154 .build()
155 }
156
157 pub fn circuit_open(tool_name: impl Into<String>, message: impl Into<String>) -> Self {
159 let tool = tool_name.into();
160 Self::builder(ErrorType::CircuitOpen, "circuit_breaker_open")
161 .message(message)
162 .detail("tool", tool)
163 .build()
164 }
165}
166
167pub struct McpErrorBuilder {
169 error_type: ErrorType,
170 code: String,
171 message: String,
172 details: HashMap<String, Value>,
173 cause: Option<Box<dyn std::error::Error + Send + Sync>>,
174}
175
176impl McpErrorBuilder {
177 pub fn message(mut self, message: impl Into<String>) -> Self {
179 self.message = message.into();
180 self
181 }
182
183 pub fn detail(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
185 self.details.insert(key.into(), value.into());
186 self
187 }
188
189 pub fn cause(mut self, cause: impl std::error::Error + Send + Sync + 'static) -> Self {
191 self.cause = Some(Box::new(cause));
192 self
193 }
194
195 pub fn build(self) -> McpError {
197 McpError {
198 error_type: self.error_type,
199 code: self.code,
200 message: self.message,
201 details: self.details,
202 cause: self.cause,
203 }
204 }
205}
206
pub mod codes {
    //! Machine-readable error code strings; each constant's value equals its name.

    // --- Lookup failures -------------------------------------------------

    /// Code string `"RESOURCE_NOT_FOUND"`.
    pub const RESOURCE_NOT_FOUND: &str = "RESOURCE_NOT_FOUND";
    /// Code string `"TOOL_NOT_FOUND"`.
    pub const TOOL_NOT_FOUND: &str = "TOOL_NOT_FOUND";
    /// Code string `"PROMPT_NOT_FOUND"`.
    pub const PROMPT_NOT_FOUND: &str = "PROMPT_NOT_FOUND";

    // --- Tool invocation -------------------------------------------------

    /// Code string `"INVALID_TOOL_ARGS"`.
    pub const INVALID_TOOL_ARGS: &str = "INVALID_TOOL_ARGS";
    /// Code string `"TOOL_EXECUTION_FAILED"`.
    pub const TOOL_EXECUTION_FAILED: &str = "TOOL_EXECUTION_FAILED";

    // --- Protective limits -----------------------------------------------

    /// Code string `"RATE_LIMIT_EXCEEDED"`.
    pub const RATE_LIMIT_EXCEEDED: &str = "RATE_LIMIT_EXCEEDED";
    /// Code string `"CIRCUIT_BREAKER_OPEN"`.
    pub const CIRCUIT_BREAKER_OPEN: &str = "CIRCUIT_BREAKER_OPEN";
}
224
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_builder() {
        let error = McpError::builder(ErrorType::Validation, "TEST_ERROR")
            .message("Test error message")
            .detail("field", "value")
            .build();

        assert_eq!(error.error_type, ErrorType::Validation);
        assert_eq!(error.code, "TEST_ERROR");
        assert_eq!(error.message, "Test error message");
        assert_eq!(
            error.details.get("field").unwrap(),
            &Value::String("value".to_string())
        );
    }

    #[test]
    fn test_to_jsonrpc() {
        let error = McpError::validation("INVALID_PARAM", "Parameter X is invalid");
        let json_err = error.to_jsonrpc();

        assert_eq!(json_err.code, -32602);
        assert_eq!(json_err.message, "Parameter X is invalid");
    }

    /// Pins the numeric JSON-RPC code chosen for each non-validation category.
    #[test]
    fn test_jsonrpc_code_mapping() {
        assert_eq!(McpError::not_found("X", "x").to_jsonrpc().code, -32002);
        assert_eq!(McpError::method_not_found("m").to_jsonrpc().code, -32601);
        assert_eq!(McpError::internal("X", "x").to_jsonrpc().code, -32603);
        assert_eq!(
            McpError::builder(ErrorType::Unauthorized, "X")
                .build()
                .to_jsonrpc()
                .code,
            -32001
        );
        assert_eq!(McpError::rate_limited("slow").to_jsonrpc().code, -32000);
        assert_eq!(
            McpError::circuit_open("tool", "open").to_jsonrpc().code,
            -32000
        );
    }

    #[test]
    fn test_convenience_constructors() {
        let err = McpError::not_found("TEST", "not found");
        assert_eq!(err.error_type, ErrorType::NotFound);

        let err = McpError::internal("TEST", "internal");
        assert_eq!(err.error_type, ErrorType::Internal);
    }

    /// Covers the constructors that fix their own code string and/or details.
    #[test]
    fn test_specialised_constructors() {
        let err = McpError::method_not_found("tools/unknown");
        assert_eq!(err.error_type, ErrorType::MethodNotFound);
        assert_eq!(err.code, "method_not_found");
        assert_eq!(err.message, "Method not found: tools/unknown");
        assert_eq!(
            err.details.get("method"),
            Some(&Value::String("tools/unknown".to_string()))
        );

        let err = McpError::circuit_open("search", "too many failures");
        assert_eq!(err.error_type, ErrorType::CircuitOpen);
        assert_eq!(err.code, "circuit_breaker_open");
        assert_eq!(err.message, "too many failures");
        assert_eq!(
            err.details.get("tool"),
            Some(&Value::String("search".to_string()))
        );

        let err = McpError::rate_limited("slow down");
        assert_eq!(err.error_type, ErrorType::RateLimit);
        assert_eq!(err.code, "rate_limit_exceeded");

        let err = McpError::not_implemented("later");
        assert_eq!(err.error_type, ErrorType::Internal);
        assert_eq!(err.code, "not_implemented");
    }

    #[test]
    fn test_display_format() {
        let err = McpError::validation("BAD_INPUT", "name is required");
        assert_eq!(err.to_string(), "[validation] BAD_INPUT: name is required");
    }

    /// The builder's `cause` must surface through `std::error::Error::source`.
    #[test]
    fn test_cause_is_exposed_as_source() {
        use std::error::Error as _;

        let error = McpError::builder(ErrorType::Internal, "WRAPPED")
            .message("outer")
            .cause(std::fmt::Error)
            .build();
        let source = error
            .source()
            .expect("cause should be reported as source()");
        assert_eq!(source.to_string(), std::fmt::Error.to_string());

        let bare = McpError::internal("NO_CAUSE", "no cause");
        assert!(bare.source().is_none());
    }
}