1use axum::{
2 http::StatusCode,
3 response::{IntoResponse, Response},
4};
5use thiserror::Error;
6
/// Crate-wide error type.
///
/// This module's [`IntoResponse`] impl maps the client-facing variants (`Auth`, `Rbac`,
/// `RateLimited`, `RateLimitedFor`) to 4xx responses whose body is the variant's message.
/// Every other variant is logged server-side and answered with a generic `500` body,
/// so its detail never reaches the client.
///
/// `#[non_exhaustive]` forces downstream crates to keep a wildcard arm when matching,
/// so new variants can be added without a breaking change.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum McpxError {
    /// Configuration problem. Surfaced to HTTP clients only as a generic `500`.
    #[error("configuration error: {0}")]
    Config(String),

    /// Authentication failure. Rendered as `401 Unauthorized` with the message as body.
    #[error("authentication failed: {0}")]
    Auth(String),

    /// Authorization denial (presumably from role-based access checks — confirm at
    /// call sites). Rendered as `403 Forbidden` with the message as body.
    #[error("authorization denied: {0}")]
    Rbac(String),

    /// Rate-limit rejection without a known wait time. Rendered as
    /// `429 Too Many Requests` with the message as body and NO `Retry-After` header.
    #[error("rate limited: {0}")]
    RateLimited(String),

    /// Rate-limit rejection with a known wait time. Rendered as `429 Too Many Requests`
    /// with the message as body plus a `Retry-After` header in whole seconds
    /// (fractions rounded up, never less than 1).
    #[error("rate limited: {message} (retry after {retry_after:?})")]
    RateLimitedFor {
        /// Text returned to the client as the response body.
        message: String,
        /// Wait before retrying; converted to `Retry-After` seconds by `retry_after_secs`.
        retry_after: std::time::Duration,
    },

    /// I/O failure; `#[from]` lets `?` convert `std::io::Error`. Surfaced as a generic `500`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// JSON (de)serialization failure; `#[from]` converts `serde_json::Error`.
    /// Surfaced as a generic `500`.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// TOML parse failure; `#[from]` converts `toml::de::Error`. Surfaced as a generic `500`.
    #[error("TOML parse error: {0}")]
    Toml(#[from] toml::de::Error),

    /// TLS setup/handling failure. Surfaced as a generic `500`.
    #[error("TLS error: {0}")]
    Tls(String),

    /// Failure while starting the server. Surfaced as a generic `500`.
    #[error("server startup error: {0}")]
    Startup(String),

    #[cfg(feature = "metrics")]
    /// Metrics subsystem failure (only with the `metrics` feature). Surfaced as a
    /// generic `500`.
    #[error("metrics error: {0}")]
    Metrics(String),
}
69
/// Converts a wait duration into the integer seconds used for a `Retry-After` header.
///
/// Any fractional second rounds *up*, so a client honouring the header never retries
/// early. The result is clamped to at least `1`, and the round-up saturates at `u64::MAX`
/// instead of overflowing.
fn retry_after_secs(wait: std::time::Duration) -> u64 {
    // `true` -> 1, `false` -> 0: add one extra second only when there is a fraction.
    let has_fraction = wait.subsec_nanos() != 0;
    let rounded_up = wait.as_secs().saturating_add(u64::from(has_fraction));
    std::cmp::max(rounded_up, 1)
}
80
81impl IntoResponse for McpxError {
82 fn into_response(self) -> Response {
83 let (status, client_msg) = match self {
84 Self::Auth(msg) => (StatusCode::UNAUTHORIZED, msg),
85 Self::Rbac(msg) => (StatusCode::FORBIDDEN, msg),
86 Self::RateLimited(msg) => (StatusCode::TOO_MANY_REQUESTS, msg),
87 Self::RateLimitedFor {
88 message,
89 retry_after,
90 } => {
91 return (
92 StatusCode::TOO_MANY_REQUESTS,
93 [(
94 axum::http::header::RETRY_AFTER,
95 retry_after_secs(retry_after).to_string(),
96 )],
97 message,
98 )
99 .into_response();
100 }
101 other @ (Self::Config(_)
104 | Self::Io(_)
105 | Self::Json(_)
106 | Self::Toml(_)
107 | Self::Tls(_)
108 | Self::Startup(_)) => {
109 tracing::error!(error = %other, "internal error");
110 (
111 StatusCode::INTERNAL_SERVER_ERROR,
112 "internal server error".into(),
113 )
114 }
115 #[cfg(feature = "metrics")]
116 other @ Self::Metrics(_) => {
117 tracing::error!(error = %other, "internal error");
118 (
119 StatusCode::INTERNAL_SERVER_ERROR,
120 "internal server error".into(),
121 )
122 }
123 };
124 (status, client_msg).into_response()
125 }
126}
127
128pub type Result<T> = std::result::Result<T, McpxError>;
130
#[cfg(test)]
mod tests {
    use axum::{http::StatusCode, response::IntoResponse};
    use http_body_util::BodyExt;

    use super::*;

    /// Renders `err` into an HTTP response and returns its status and UTF-8 body.
    async fn status_and_body(err: McpxError) -> (StatusCode, String) {
        let resp = err.into_response();
        let status = resp.status();
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        (status, String::from_utf8(body.to_vec()).unwrap())
    }

    /// Asserts that `err` renders as a `500` whose body carries no internal detail.
    async fn assert_opaque_500(err: McpxError) {
        let (status, body) = status_and_body(err).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body, "internal server error",
            "must not leak internal detail"
        );
    }

    #[tokio::test]
    async fn auth_error_returns_401() {
        let (status, body) = status_and_body(McpxError::Auth("bad token".into())).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert!(body.contains("bad token"));
    }

    #[tokio::test]
    async fn rbac_error_returns_403() {
        let (status, body) = status_and_body(McpxError::Rbac("denied".into())).await;
        assert_eq!(status, StatusCode::FORBIDDEN);
        assert!(body.contains("denied"));
    }

    #[tokio::test]
    async fn rate_limited_error_returns_429() {
        let (status, body) = status_and_body(McpxError::RateLimited("slow down".into())).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        assert!(body.contains("slow down"));
    }

    // Only inspects status and headers; `into_response` is synchronous, so no async
    // runtime is needed.
    #[test]
    fn legacy_rate_limited_has_no_retry_after_header() {
        let resp = McpxError::RateLimited("slow down".into()).into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(
            !resp.headers().contains_key(axum::http::header::RETRY_AFTER),
            "legacy variant must stay headerless"
        );
    }

    #[tokio::test]
    async fn rate_limited_for_sets_retry_after_header() {
        let resp = McpxError::RateLimitedFor {
            message: "slow down".into(),
            retry_after: std::time::Duration::from_millis(1500),
        }
        .into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        let header = resp
            .headers()
            .get(axum::http::header::RETRY_AFTER)
            .expect("Retry-After present")
            .to_str()
            .unwrap()
            .to_owned();
        assert_eq!(header, "2", "1.5s must round UP to 2");
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body.as_ref(), b"slow down");
    }

    #[test]
    fn retry_after_secs_rounds_up_and_never_zero() {
        use std::time::Duration;
        assert_eq!(retry_after_secs(Duration::ZERO), 1, "zero floors to 1");
        assert_eq!(retry_after_secs(Duration::from_millis(1)), 1);
        assert_eq!(retry_after_secs(Duration::from_millis(999)), 1);
        assert_eq!(retry_after_secs(Duration::from_secs(1)), 1, "exact stays");
        assert_eq!(retry_after_secs(Duration::from_millis(1001)), 2, "ceil");
        assert_eq!(retry_after_secs(Duration::from_secs(60)), 60);
    }

    #[tokio::test]
    async fn config_error_returns_500() {
        assert_opaque_500(McpxError::Config("bad".into())).await;
    }

    #[tokio::test]
    async fn io_error_returns_500() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "gone");
        assert_opaque_500(McpxError::from(io_err)).await;
    }

    #[tokio::test]
    async fn json_error_returns_500() {
        let json_err = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        assert_opaque_500(McpxError::from(json_err)).await;
    }

    #[tokio::test]
    async fn tls_error_returns_500() {
        assert_opaque_500(McpxError::Tls("bad cert".into())).await;
    }

    #[tokio::test]
    async fn startup_error_returns_500() {
        assert_opaque_500(McpxError::Startup("bind failed".into())).await;
    }

    #[cfg(feature = "metrics")]
    #[tokio::test]
    async fn metrics_error_returns_500() {
        assert_opaque_500(McpxError::Metrics("dup metric".into())).await;
    }

    #[test]
    fn display_preserves_message() {
        let err = McpxError::Auth("unauthorized".into());
        assert_eq!(err.to_string(), "authentication failed: unauthorized");

        let err = McpxError::Rbac("forbidden".into());
        assert_eq!(err.to_string(), "authorization denied: forbidden");

        let err = McpxError::RateLimited("throttled".into());
        assert_eq!(err.to_string(), "rate limited: throttled");

        // `Duration`'s Debug form (1.5s) is embedded in the message.
        let err = McpxError::RateLimitedFor {
            message: "throttled".into(),
            retry_after: std::time::Duration::from_millis(1500),
        };
        assert_eq!(
            err.to_string(),
            "rate limited: throttled (retry after 1.5s)"
        );
    }
}