1use axum::{
2 http::StatusCode,
3 response::{IntoResponse, Response},
4};
5use thiserror::Error;
6
/// Error type shared by this crate.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute
/// and includes the full inner detail, so it is suitable for server-side logs.
/// What a remote client sees is decided separately by
/// [`McpxError::client_message`] and the `IntoResponse` impl:
/// `Auth`, `Rbac`, `RateLimited` and `RateLimitedFor` expose their message
/// verbatim, and every other variant is reduced to `"internal server error"`.
///
/// The enum is `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum McpxError {
    /// Configuration problem.
    ///
    /// Rendered to HTTP clients as `500 Internal Server Error` with a generic
    /// body; the inner string is only logged.
    #[error("configuration error: {0}")]
    Config(String),

    /// Authentication failure.
    ///
    /// Rendered as `401 Unauthorized`. The message is sent to the client
    /// unchanged, so it should not contain secrets.
    #[error("authentication failed: {0}")]
    Auth(String),

    /// Authorization (role-based access) denial.
    ///
    /// Rendered as `403 Forbidden`. The message is sent to the client unchanged.
    #[error("authorization denied: {0}")]
    Rbac(String),

    /// Rate-limit rejection without a known wait time.
    ///
    /// Rendered as `429 Too Many Requests` with no `Retry-After` header.
    /// The message is sent to the client unchanged.
    #[error("rate limited: {0}")]
    RateLimited(String),

    /// Rate-limit rejection that knows how long the client should wait.
    ///
    /// Rendered as `429 Too Many Requests` plus a `Retry-After` header.
    #[error("rate limited: {message} (retry after {retry_after:?})")]
    RateLimitedFor {
        /// Text sent to the client as the response body.
        message: String,
        /// Suggested wait; emitted in `Retry-After` as whole seconds, rounded
        /// up, and never less than 1.
        retry_after: std::time::Duration,
    },

    /// Wrapped I/O error; `#[from]` lets `?` convert `std::io::Error` directly.
    ///
    /// Rendered as a generic `500`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Wrapped `serde_json` error; `#[from]` lets `?` convert it directly.
    ///
    /// Rendered as a generic `500`.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// Wrapped TOML deserialization error; `#[from]` lets `?` convert it directly.
    ///
    /// Rendered as a generic `500`.
    #[error("TOML parse error: {0}")]
    Toml(#[from] toml::de::Error),

    /// TLS-related failure.
    ///
    /// Rendered as a generic `500`; the inner string is only logged.
    #[error("TLS error: {0}")]
    Tls(String),

    /// Failure while starting the server.
    ///
    /// Rendered as a generic `500`; the inner string is only logged.
    #[error("server startup error: {0}")]
    Startup(String),

    /// Metrics-subsystem failure. Only exists with the `metrics` feature.
    ///
    /// Rendered as a generic `500`; the inner string is only logged.
    #[cfg(feature = "metrics")]
    #[error("metrics error: {0}")]
    Metrics(String),
}
82
/// Converts a wait duration into the whole-second value used for `Retry-After`.
///
/// Any fractional second is rounded *up*, so clients never retry too early.
/// The increment saturates at `u64::MAX`. The result is clamped to at least
/// `1`, so a zero wait still tells the client to back off briefly.
fn retry_after_secs(wait: std::time::Duration) -> u64 {
    let whole_secs = wait.as_secs();
    let rounded_up = if wait.subsec_nanos() == 0 {
        whole_secs
    } else {
        whole_secs.saturating_add(1)
    };
    std::cmp::max(rounded_up, 1)
}
93
94impl McpxError {
95 #[must_use]
108 pub fn client_message(&self) -> std::borrow::Cow<'_, str> {
109 use std::borrow::Cow;
110 match self {
111 Self::Auth(msg) | Self::Rbac(msg) | Self::RateLimited(msg) => Cow::Borrowed(msg),
112 Self::RateLimitedFor { message, .. } => Cow::Borrowed(message),
113 Self::Config(_)
115 | Self::Io(_)
116 | Self::Json(_)
117 | Self::Toml(_)
118 | Self::Tls(_)
119 | Self::Startup(_) => Cow::Borrowed("internal server error"),
120 #[cfg(feature = "metrics")]
121 Self::Metrics(_) => Cow::Borrowed("internal server error"),
122 }
123 }
124}
125
126impl IntoResponse for McpxError {
127 fn into_response(self) -> Response {
128 let (status, client_msg) = match self {
129 Self::Auth(msg) => (StatusCode::UNAUTHORIZED, msg),
130 Self::Rbac(msg) => (StatusCode::FORBIDDEN, msg),
131 Self::RateLimited(msg) => (StatusCode::TOO_MANY_REQUESTS, msg),
132 Self::RateLimitedFor {
133 message,
134 retry_after,
135 } => {
136 return (
137 StatusCode::TOO_MANY_REQUESTS,
138 [(
139 axum::http::header::RETRY_AFTER,
140 retry_after_secs(retry_after).to_string(),
141 )],
142 message,
143 )
144 .into_response();
145 }
146 other @ (Self::Config(_)
149 | Self::Io(_)
150 | Self::Json(_)
151 | Self::Toml(_)
152 | Self::Tls(_)
153 | Self::Startup(_)) => {
154 tracing::error!(error = %other, "internal error");
155 (
156 StatusCode::INTERNAL_SERVER_ERROR,
157 "internal server error".into(),
158 )
159 }
160 #[cfg(feature = "metrics")]
161 other @ Self::Metrics(_) => {
162 tracing::error!(error = %other, "internal error");
163 (
164 StatusCode::INTERNAL_SERVER_ERROR,
165 "internal server error".into(),
166 )
167 }
168 };
169 (status, client_msg).into_response()
170 }
171}
172
173pub type Result<T> = std::result::Result<T, McpxError>;
175
#[cfg(test)]
mod tests {
    use axum::{http::StatusCode, response::IntoResponse};
    use http_body_util::BodyExt;

    use super::*;

    /// Renders `err` through [`IntoResponse`] and returns the status code
    /// together with the body decoded as UTF-8.
    async fn status_and_body(err: McpxError) -> (StatusCode, String) {
        let resp = err.into_response();
        let status = resp.status();
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        (status, String::from_utf8(body.to_vec()).unwrap())
    }

    /// Asserts that `err` renders as a `500` whose body is the generic
    /// placeholder, i.e. no internal detail reaches the client.
    async fn assert_opaque_500(err: McpxError) {
        let (status, body) = status_and_body(err).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body, "internal server error",
            "must not leak internal detail"
        );
    }

    /// A genuine `serde_json::Error`, produced by parsing truncated JSON.
    fn json_error() -> serde_json::Error {
        serde_json::from_str::<serde_json::Value>("{").unwrap_err()
    }

    #[tokio::test]
    async fn auth_error_returns_401() {
        let (status, body) = status_and_body(McpxError::Auth("bad token".into())).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert!(body.contains("bad token"));
    }

    #[tokio::test]
    async fn rbac_error_returns_403() {
        let (status, body) = status_and_body(McpxError::Rbac("denied".into())).await;
        assert_eq!(status, StatusCode::FORBIDDEN);
        assert!(body.contains("denied"));
    }

    #[tokio::test]
    async fn rate_limited_error_returns_429() {
        let (status, body) = status_and_body(McpxError::RateLimited("slow down".into())).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        assert!(body.contains("slow down"));
    }

    #[tokio::test]
    async fn legacy_rate_limited_has_no_retry_after_header() {
        let resp = McpxError::RateLimited("slow down".into()).into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(
            !resp.headers().contains_key(axum::http::header::RETRY_AFTER),
            "legacy variant must stay headerless"
        );
    }

    #[tokio::test]
    async fn rate_limited_for_sets_retry_after_header() {
        let resp = McpxError::RateLimitedFor {
            message: "slow down".into(),
            retry_after: std::time::Duration::from_millis(1500),
        }
        .into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        let header = resp
            .headers()
            .get(axum::http::header::RETRY_AFTER)
            .expect("Retry-After present")
            .to_str()
            .unwrap()
            .to_owned();
        assert_eq!(header, "2", "1.5s must round UP to 2");
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body.as_ref(), b"slow down");
    }

    #[test]
    fn retry_after_secs_rounds_up_and_never_zero() {
        use std::time::Duration;
        assert_eq!(retry_after_secs(Duration::ZERO), 1, "zero floors to 1");
        assert_eq!(retry_after_secs(Duration::from_millis(1)), 1);
        assert_eq!(retry_after_secs(Duration::from_millis(999)), 1);
        assert_eq!(retry_after_secs(Duration::from_secs(1)), 1, "exact stays");
        assert_eq!(retry_after_secs(Duration::from_millis(1001)), 2, "ceil");
        assert_eq!(retry_after_secs(Duration::from_secs(60)), 60);
        assert_eq!(
            retry_after_secs(Duration::new(u64::MAX, 1)),
            u64::MAX,
            "rounding up saturates instead of overflowing"
        );
    }

    #[tokio::test]
    async fn config_error_returns_500() {
        assert_opaque_500(McpxError::Config("bad".into())).await;
    }

    #[tokio::test]
    async fn io_error_returns_500() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "gone");
        assert_opaque_500(McpxError::from(io_err)).await;
    }

    #[tokio::test]
    async fn json_error_returns_500() {
        assert_opaque_500(McpxError::from(json_error())).await;
    }

    #[tokio::test]
    async fn tls_error_returns_500() {
        assert_opaque_500(McpxError::Tls("bad cert".into())).await;
    }

    #[tokio::test]
    async fn startup_error_returns_500() {
        assert_opaque_500(McpxError::Startup("bind failed".into())).await;
    }

    #[cfg(feature = "metrics")]
    #[tokio::test]
    async fn metrics_error_returns_500() {
        assert_opaque_500(McpxError::Metrics("dup metric".into())).await;
    }

    #[test]
    fn display_preserves_message() {
        let err = McpxError::Auth("unauthorized".into());
        assert_eq!(err.to_string(), "authentication failed: unauthorized");

        let err = McpxError::Rbac("forbidden".into());
        assert_eq!(err.to_string(), "authorization denied: forbidden");

        let err = McpxError::RateLimited("throttled".into());
        assert_eq!(err.to_string(), "rate limited: throttled");

        let err = McpxError::RateLimitedFor {
            message: "throttled".into(),
            retry_after: std::time::Duration::from_secs(2),
        };
        assert_eq!(err.to_string(), "rate limited: throttled (retry after 2s)");
    }

    #[test]
    fn client_message_exposes_client_facing_text_and_hides_internal_detail() {
        assert_eq!(
            McpxError::Auth("bad token".into()).client_message(),
            "bad token"
        );
        assert_eq!(McpxError::Rbac("nope".into()).client_message(), "nope");
        assert_eq!(
            McpxError::RateLimited("slow down".into()).client_message(),
            "slow down"
        );
        assert_eq!(
            McpxError::RateLimitedFor {
                message: "too many".into(),
                retry_after: std::time::Duration::from_secs(1),
            }
            .client_message(),
            "too many"
        );

        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "secret/path/leak");
        assert_eq!(
            McpxError::from(io_err).client_message(),
            "internal server error"
        );
        assert_eq!(
            McpxError::Tls("private key /etc/certs/server.key".into()).client_message(),
            "internal server error"
        );
        assert_eq!(
            McpxError::Config("bind 10.0.0.5:8443 failed".into()).client_message(),
            "internal server error"
        );
        assert_eq!(
            McpxError::Startup("listener on /run/secret.sock".into()).client_message(),
            "internal server error"
        );
    }

    #[tokio::test]
    async fn client_message_matches_into_response_body() {
        for err in [
            McpxError::Auth("a".into()),
            McpxError::Rbac("b".into()),
            McpxError::RateLimited("c".into()),
            McpxError::RateLimitedFor {
                message: "f".into(),
                retry_after: std::time::Duration::from_secs(3),
            },
            McpxError::Config("d".into()),
            McpxError::Tls("e".into()),
            McpxError::Startup("g".into()),
            McpxError::from(std::io::Error::new(std::io::ErrorKind::NotFound, "h")),
            McpxError::from(json_error()),
        ] {
            let expected = err.client_message().into_owned();
            let (_status, body) = status_and_body(err).await;
            assert_eq!(body, expected, "client_message must equal the wire body");
        }
    }
}