Skip to main content

rmcp_server_kit/
error.rs

1use axum::{
2    http::StatusCode,
3    response::{IntoResponse, Response},
4};
5use thiserror::Error;
6
/// Generic MCP server error type.
///
/// Application crates should define their own error types and convert
/// from/into `McpxError` where needed.
///
/// # HTTP mapping
///
/// The [`IntoResponse`] impl in this module splits the variants in two:
///
/// * client-facing — [`Auth`](Self::Auth) (401), [`Rbac`](Self::Rbac) (403),
///   [`RateLimited`](Self::RateLimited) and
///   [`RateLimitedFor`](Self::RateLimitedFor) (429) — send their message to
///   the client as the response body;
/// * internal — every other variant — is logged with `tracing::error!` and
///   rendered as a bare `500 internal server error`, so no implementation
///   detail reaches the client.
///
/// The enum is `#[non_exhaustive]`: `match`es in other crates need a
/// wildcard arm so new variants can be added without a breaking change.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum McpxError {
    /// Configuration parsing or validation failed.
    ///
    /// Internal: rendered as a generic HTTP 500.
    #[error("configuration error: {0}")]
    Config(String),

    /// Authentication failed (bad/missing credential).
    ///
    /// Client-facing: rendered as HTTP 401 with this message as the body,
    /// so the message should not contain secrets.
    #[error("authentication failed: {0}")]
    Auth(String),

    /// Authorization (RBAC) denied the request.
    ///
    /// Client-facing: rendered as HTTP 403 with this message as the body.
    #[error("authorization denied: {0}")]
    Rbac(String),

    /// Request was rejected by a rate limiter.
    ///
    /// Client-facing: rendered as HTTP 429 with this message as the body and
    /// **no** `Retry-After` header. Prefer
    /// [`RateLimitedFor`](Self::RateLimitedFor) when the wait is known.
    #[error("rate limited: {0}")]
    RateLimited(String),

    /// Request was rejected by a rate limiter that knows the wait time.
    ///
    /// Renders as HTTP 429 with a `Retry-After` header (RFC 9110
    /// delta-seconds: the duration is rounded **up** to whole seconds,
    /// minimum `1`) and the message as a plain-text body. The legacy
    /// [`RateLimited`](Self::RateLimited) variant remains headerless.
    ///
    /// The `Display` output formats the wait with `Duration`'s `Debug`
    /// representation (e.g. `1.5s`).
    #[error("rate limited: {message} (retry after {retry_after:?})")]
    RateLimitedFor {
        /// Plain-text client-facing message (response body).
        message: String,
        /// Best-effort wait until the next request could be admitted.
        retry_after: std::time::Duration,
    },

    /// Underlying I/O error.
    ///
    /// Produced automatically by `?` via `#[from]`. Internal: rendered as a
    /// generic HTTP 500.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// JSON (de)serialization error.
    ///
    /// Produced automatically by `?` via `#[from]`. Internal: rendered as a
    /// generic HTTP 500.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// TOML parse error (configuration loading).
    ///
    /// Produced automatically by `?` via `#[from]`. Internal: rendered as a
    /// generic HTTP 500.
    #[error("TOML parse error: {0}")]
    Toml(#[from] toml::de::Error),

    /// TLS configuration failure (certificate load, key parse, rustls config).
    ///
    /// Internal: rendered as a generic HTTP 500.
    #[error("TLS error: {0}")]
    Tls(String),

    /// Server startup failure (binding, listener, runtime initialization).
    ///
    /// Internal: rendered as a generic HTTP 500.
    #[error("server startup error: {0}")]
    Startup(String),

    /// Metrics registration failure (e.g. Prometheus duplicate or invalid metric).
    ///
    /// Only exists when the `metrics` feature is enabled. Internal: rendered
    /// as a generic HTTP 500.
    #[cfg(feature = "metrics")]
    #[error("metrics error: {0}")]
    Metrics(String),
}
69
/// Convert a wait [`Duration`](std::time::Duration) into an RFC 9110
/// `Retry-After` delta-seconds value.
///
/// Any fractional second pushes the result **up** to the next whole second,
/// and the result is clamped to at least `1`: advertising `0` would invite
/// clients to retry immediately and turn a throttle into a retry storm.
/// The addition saturates, so `Duration::MAX`-sized waits cannot overflow.
fn retry_after_secs(wait: std::time::Duration) -> u64 {
    // 1 if there is a partial second left over, 0 for an exact whole second.
    let partial_second = u64::from(wait.subsec_nanos() != 0);
    wait.as_secs().saturating_add(partial_second).max(1)
}
80
81impl IntoResponse for McpxError {
82    fn into_response(self) -> Response {
83        let (status, client_msg) = match self {
84            Self::Auth(msg) => (StatusCode::UNAUTHORIZED, msg),
85            Self::Rbac(msg) => (StatusCode::FORBIDDEN, msg),
86            Self::RateLimited(msg) => (StatusCode::TOO_MANY_REQUESTS, msg),
87            Self::RateLimitedFor {
88                message,
89                retry_after,
90            } => {
91                return (
92                    StatusCode::TOO_MANY_REQUESTS,
93                    [(
94                        axum::http::header::RETRY_AFTER,
95                        retry_after_secs(retry_after).to_string(),
96                    )],
97                    message,
98                )
99                    .into_response();
100            }
101            // All remaining variants are internal - return a generic 500
102            // to avoid leaking implementation details.
103            other @ (Self::Config(_)
104            | Self::Io(_)
105            | Self::Json(_)
106            | Self::Toml(_)
107            | Self::Tls(_)
108            | Self::Startup(_)) => {
109                tracing::error!(error = %other, "internal error");
110                (
111                    StatusCode::INTERNAL_SERVER_ERROR,
112                    "internal server error".into(),
113                )
114            }
115            #[cfg(feature = "metrics")]
116            other @ Self::Metrics(_) => {
117                tracing::error!(error = %other, "internal error");
118                (
119                    StatusCode::INTERNAL_SERVER_ERROR,
120                    "internal server error".into(),
121                )
122            }
123        };
124        (status, client_msg).into_response()
125    }
126}
127
128/// Convenience `Result` alias bound to [`McpxError`].
129pub type Result<T> = std::result::Result<T, McpxError>;
130
#[cfg(test)]
mod tests {
    use axum::{http::StatusCode, response::IntoResponse};
    use http_body_util::BodyExt;

    use super::*;

    /// Render `err` and return its status plus the full body as UTF-8.
    async fn status_of(err: McpxError) -> (StatusCode, String) {
        let resp = err.into_response();
        let status = resp.status();
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        (status, String::from_utf8(body.to_vec()).unwrap())
    }

    /// Assert that `err` renders as an opaque 500 carrying no internal detail.
    async fn assert_opaque_500(err: McpxError) {
        let (status, body) = status_of(err).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body, "internal server error",
            "must not leak internal detail"
        );
    }

    #[tokio::test]
    async fn auth_error_returns_401() {
        let (status, body) = status_of(McpxError::Auth("bad token".into())).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, "bad token");
    }

    #[tokio::test]
    async fn rbac_error_returns_403() {
        let (status, body) = status_of(McpxError::Rbac("denied".into())).await;
        assert_eq!(status, StatusCode::FORBIDDEN);
        assert_eq!(body, "denied");
    }

    #[tokio::test]
    async fn rate_limited_error_returns_429() {
        let (status, body) = status_of(McpxError::RateLimited("slow down".into())).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(body, "slow down");
    }

    #[test]
    fn legacy_rate_limited_has_no_retry_after_header() {
        let resp = McpxError::RateLimited("slow down".into()).into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(
            !resp.headers().contains_key(axum::http::header::RETRY_AFTER),
            "legacy variant must stay headerless"
        );
    }

    #[tokio::test]
    async fn rate_limited_for_sets_retry_after_header() {
        let resp = McpxError::RateLimitedFor {
            message: "slow down".into(),
            retry_after: std::time::Duration::from_millis(1500),
        }
        .into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        let header = resp
            .headers()
            .get(axum::http::header::RETRY_AFTER)
            .expect("Retry-After present")
            .to_str()
            .unwrap()
            .to_owned();
        assert_eq!(header, "2", "1.5s must round UP to 2");
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body.as_ref(), b"slow down");
    }

    #[test]
    fn rate_limited_for_zero_wait_advertises_one_second() {
        let resp = McpxError::RateLimitedFor {
            message: "slow down".into(),
            retry_after: std::time::Duration::ZERO,
        }
        .into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        let header = resp
            .headers()
            .get(axum::http::header::RETRY_AFTER)
            .and_then(|v| v.to_str().ok());
        assert_eq!(header, Some("1"), "a zero wait must never advertise 0");
    }

    #[test]
    fn retry_after_secs_rounds_up_and_never_zero() {
        use std::time::Duration;
        assert_eq!(retry_after_secs(Duration::ZERO), 1, "zero floors to 1");
        assert_eq!(retry_after_secs(Duration::from_millis(1)), 1);
        assert_eq!(retry_after_secs(Duration::from_millis(999)), 1);
        assert_eq!(retry_after_secs(Duration::from_secs(1)), 1, "exact stays");
        assert_eq!(retry_after_secs(Duration::from_millis(1001)), 2, "ceil");
        assert_eq!(retry_after_secs(Duration::from_secs(60)), 60);
        assert_eq!(
            retry_after_secs(Duration::new(u64::MAX, 1)),
            u64::MAX,
            "saturates instead of overflowing"
        );
    }

    #[tokio::test]
    async fn config_error_returns_500() {
        assert_opaque_500(McpxError::Config("bad".into())).await;
    }

    #[tokio::test]
    async fn io_error_returns_500() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "gone");
        assert_opaque_500(McpxError::from(io_err)).await;
    }

    #[tokio::test]
    async fn json_error_returns_500() {
        let json_err = serde_json::from_str::<serde_json::Value>("{not json").unwrap_err();
        assert_opaque_500(McpxError::from(json_err)).await;
    }

    #[tokio::test]
    async fn tls_error_returns_500() {
        assert_opaque_500(McpxError::Tls("bad cert".into())).await;
    }

    #[tokio::test]
    async fn startup_error_returns_500() {
        assert_opaque_500(McpxError::Startup("bind failed".into())).await;
    }

    #[cfg(feature = "metrics")]
    #[tokio::test]
    async fn metrics_error_returns_500() {
        assert_opaque_500(McpxError::Metrics("dup metric".into())).await;
    }

    #[test]
    fn display_preserves_message() {
        let err = McpxError::Auth("unauthorized".into());
        assert_eq!(err.to_string(), "authentication failed: unauthorized");

        let err = McpxError::Rbac("forbidden".into());
        assert_eq!(err.to_string(), "authorization denied: forbidden");

        let err = McpxError::RateLimited("throttled".into());
        assert_eq!(err.to_string(), "rate limited: throttled");

        let err = McpxError::RateLimitedFor {
            message: "throttled".into(),
            retry_after: std::time::Duration::from_millis(1500),
        };
        assert_eq!(
            err.to_string(),
            "rate limited: throttled (retry after 1.5s)"
        );

        let err = McpxError::Config("bad".into());
        assert_eq!(err.to_string(), "configuration error: bad");
    }
}