Skip to main content

rmcp_server_kit/
error.rs

1use axum::{
2    http::StatusCode,
3    response::{IntoResponse, Response},
4};
5use thiserror::Error;
6
7/// Generic MCP server error type.
8///
9/// Application crates should define their own error types and convert
10/// from/into `McpxError` where needed.
11///
12/// # Client-facing message invariant
13///
14/// The `String` payloads of [`Auth`](Self::Auth), [`Rbac`](Self::Rbac),
15/// [`RateLimited`](Self::RateLimited), and the `message` field of
16/// [`RateLimitedFor`](Self::RateLimitedFor) are rendered **verbatim to the
17/// HTTP client** by [`IntoResponse`]. Construction sites MUST keep these
18/// client-safe: no internal error text, source-error chains, file paths, IPs,
19/// SQL, or dependency details. Internal-only variants (`Config`, `Io`, `Json`,
20/// `Toml`, `Tls`, `Startup`, `Metrics`) are collapsed to a generic
21/// `"internal server error"` body and their detail is logged server-side only.
22/// Use [`client_message`](Self::client_message) to obtain the exact body that
23/// will be sent to the client for any variant.
24#[derive(Debug, Error)]
25#[non_exhaustive]
26pub enum McpxError {
27    /// Configuration parsing or validation failed.
28    #[error("configuration error: {0}")]
29    Config(String),
30
31    /// Authentication failed (bad/missing credential).
32    #[error("authentication failed: {0}")]
33    Auth(String),
34
35    /// Authorization (RBAC) denied the request.
36    #[error("authorization denied: {0}")]
37    Rbac(String),
38
39    /// Request was rejected by a rate limiter.
40    #[error("rate limited: {0}")]
41    RateLimited(String),
42
43    /// Request was rejected by a rate limiter that knows the wait time.
44    ///
45    /// Renders as HTTP 429 with a `Retry-After` header (RFC 9110
46    /// delta-seconds: the duration is rounded **up** to whole seconds,
47    /// minimum `1`) and the message as a plain-text body. The legacy
48    /// [`RateLimited`](Self::RateLimited) variant remains headerless.
49    #[error("rate limited: {message} (retry after {retry_after:?})")]
50    RateLimitedFor {
51        /// Plain-text client-facing message (response body).
52        message: String,
53        /// Best-effort wait until the next request could be admitted.
54        retry_after: std::time::Duration,
55    },
56
57    /// Underlying I/O error.
58    #[error("I/O error: {0}")]
59    Io(#[from] std::io::Error),
60
61    /// JSON (de)serialization error.
62    #[error("JSON error: {0}")]
63    Json(#[from] serde_json::Error),
64
65    /// TOML parse error (configuration loading).
66    #[error("TOML parse error: {0}")]
67    Toml(#[from] toml::de::Error),
68
69    /// TLS configuration failure (certificate load, key parse, rustls config).
70    #[error("TLS error: {0}")]
71    Tls(String),
72
73    /// Server startup failure (binding, listener, runtime initialization).
74    #[error("server startup error: {0}")]
75    Startup(String),
76
77    /// Metrics registration failure (e.g. Prometheus duplicate or invalid metric).
78    #[cfg(feature = "metrics")]
79    #[error("metrics error: {0}")]
80    Metrics(String),
81}
82
/// Converts a wait [`Duration`](std::time::Duration) into the delta-seconds
/// value of an RFC 9110 `Retry-After` header.
///
/// Any fractional second rounds the result **up**. The result is clamped to at
/// least `1`, because advertising `0` would invite clients to retry
/// immediately (a retry storm). The round-up saturates, so `Duration::MAX`
/// maps to `u64::MAX` instead of overflowing.
fn retry_after_secs(wait: std::time::Duration) -> u64 {
    let whole_secs = wait.as_secs();
    let has_fraction = wait.subsec_nanos() != 0;
    let ceiled = if has_fraction {
        whole_secs.saturating_add(1)
    } else {
        whole_secs
    };
    u64::max(ceiled, 1)
}
93
94impl McpxError {
95    /// The exact body this error sends to the HTTP client.
96    ///
97    /// Client-facing variants ([`Auth`](Self::Auth), [`Rbac`](Self::Rbac),
98    /// [`RateLimited`](Self::RateLimited), [`RateLimitedFor`](Self::RateLimitedFor))
99    /// return their message verbatim; all internal variants return the generic
100    /// `"internal server error"` so implementation detail never leaks on the
101    /// wire. This is the single source of truth for the client body — the
102    /// [`IntoResponse`] impl uses it — so callers can assert or reuse the
103    /// client-safe text without duplicating the mapping.
104    ///
105    /// See the type-level "Client-facing message invariant" for the contract
106    /// construction sites must uphold.
107    #[must_use]
108    pub fn client_message(&self) -> std::borrow::Cow<'_, str> {
109        use std::borrow::Cow;
110        match self {
111            Self::Auth(msg) | Self::Rbac(msg) | Self::RateLimited(msg) => Cow::Borrowed(msg),
112            Self::RateLimitedFor { message, .. } => Cow::Borrowed(message),
113            // Internal variants: never leak detail to the client.
114            Self::Config(_)
115            | Self::Io(_)
116            | Self::Json(_)
117            | Self::Toml(_)
118            | Self::Tls(_)
119            | Self::Startup(_) => Cow::Borrowed("internal server error"),
120            #[cfg(feature = "metrics")]
121            Self::Metrics(_) => Cow::Borrowed("internal server error"),
122        }
123    }
124}
125
126impl IntoResponse for McpxError {
127    fn into_response(self) -> Response {
128        let (status, client_msg) = match self {
129            Self::Auth(msg) => (StatusCode::UNAUTHORIZED, msg),
130            Self::Rbac(msg) => (StatusCode::FORBIDDEN, msg),
131            Self::RateLimited(msg) => (StatusCode::TOO_MANY_REQUESTS, msg),
132            Self::RateLimitedFor {
133                message,
134                retry_after,
135            } => {
136                return (
137                    StatusCode::TOO_MANY_REQUESTS,
138                    [(
139                        axum::http::header::RETRY_AFTER,
140                        retry_after_secs(retry_after).to_string(),
141                    )],
142                    message,
143                )
144                    .into_response();
145            }
146            // All remaining variants are internal - return a generic 500
147            // to avoid leaking implementation details.
148            other @ (Self::Config(_)
149            | Self::Io(_)
150            | Self::Json(_)
151            | Self::Toml(_)
152            | Self::Tls(_)
153            | Self::Startup(_)) => {
154                tracing::error!(error = %other, "internal error");
155                (
156                    StatusCode::INTERNAL_SERVER_ERROR,
157                    "internal server error".into(),
158                )
159            }
160            #[cfg(feature = "metrics")]
161            other @ Self::Metrics(_) => {
162                tracing::error!(error = %other, "internal error");
163                (
164                    StatusCode::INTERNAL_SERVER_ERROR,
165                    "internal server error".into(),
166                )
167            }
168        };
169        (status, client_msg).into_response()
170    }
171}
172
173/// Convenience `Result` alias bound to [`McpxError`].
174pub type Result<T> = std::result::Result<T, McpxError>;
175
#[cfg(test)]
mod tests {
    use axum::{http::StatusCode, response::IntoResponse};
    use http_body_util::BodyExt;

    use super::*;

    /// Renders `err` through `IntoResponse` and returns its status and body text.
    async fn status_of(err: McpxError) -> (StatusCode, String) {
        let resp = err.into_response();
        let status = resp.status();
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        (status, String::from_utf8(body.to_vec()).unwrap())
    }

    /// Produces a real `serde_json::Error` (unterminated object).
    fn json_error() -> serde_json::Error {
        serde_json::from_str::<serde_json::Value>("{").unwrap_err()
    }

    #[tokio::test]
    async fn auth_error_returns_401() {
        let (status, body) = status_of(McpxError::Auth("bad token".into())).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert!(body.contains("bad token"));
    }

    #[tokio::test]
    async fn rbac_error_returns_403() {
        let (status, body) = status_of(McpxError::Rbac("denied".into())).await;
        assert_eq!(status, StatusCode::FORBIDDEN);
        assert!(body.contains("denied"));
    }

    #[tokio::test]
    async fn rate_limited_error_returns_429() {
        let (status, body) = status_of(McpxError::RateLimited("slow down".into())).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        assert!(body.contains("slow down"));
    }

    #[tokio::test]
    async fn legacy_rate_limited_has_no_retry_after_header() {
        let resp = McpxError::RateLimited("slow down".into()).into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(
            !resp.headers().contains_key(axum::http::header::RETRY_AFTER),
            "legacy variant must stay headerless"
        );
    }

    #[tokio::test]
    async fn rate_limited_for_sets_retry_after_header() {
        let resp = McpxError::RateLimitedFor {
            message: "slow down".into(),
            retry_after: std::time::Duration::from_millis(1500),
        }
        .into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        let header = resp
            .headers()
            .get(axum::http::header::RETRY_AFTER)
            .expect("Retry-After present")
            .to_str()
            .unwrap()
            .to_owned();
        assert_eq!(header, "2", "1.5s must round UP to 2");
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body.as_ref(), b"slow down");
    }

    #[test]
    fn retry_after_secs_rounds_up_and_never_zero() {
        use std::time::Duration;
        assert_eq!(retry_after_secs(Duration::ZERO), 1, "zero floors to 1");
        assert_eq!(retry_after_secs(Duration::from_millis(1)), 1);
        assert_eq!(retry_after_secs(Duration::from_millis(999)), 1);
        assert_eq!(retry_after_secs(Duration::from_secs(1)), 1, "exact stays");
        assert_eq!(retry_after_secs(Duration::from_millis(1001)), 2, "ceil");
        assert_eq!(retry_after_secs(Duration::from_secs(60)), 60);
    }

    #[tokio::test]
    async fn config_error_returns_500() {
        let (status, body) = status_of(McpxError::Config("bad".into())).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body, "internal server error",
            "must not leak internal detail"
        );
    }

    #[tokio::test]
    async fn io_error_returns_500() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "gone");
        let (status, body) = status_of(McpxError::from(io_err)).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body, "internal server error",
            "must not leak internal detail"
        );
    }

    #[tokio::test]
    async fn json_error_returns_500() {
        let (status, body) = status_of(McpxError::from(json_error())).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body, "internal server error",
            "must not leak internal detail"
        );
    }

    #[tokio::test]
    async fn tls_error_returns_500() {
        let (status, body) = status_of(McpxError::Tls("bad cert".into())).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body, "internal server error",
            "must not leak internal detail"
        );
    }

    #[tokio::test]
    async fn startup_error_returns_500() {
        let (status, body) = status_of(McpxError::Startup("bind failed".into())).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body, "internal server error",
            "must not leak internal detail"
        );
    }

    #[cfg(feature = "metrics")]
    #[tokio::test]
    async fn metrics_error_returns_500() {
        let (status, body) = status_of(McpxError::Metrics("dup metric".into())).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body, "internal server error",
            "must not leak internal detail"
        );
    }

    #[test]
    fn display_preserves_message() {
        let err = McpxError::Auth("unauthorized".into());
        assert_eq!(err.to_string(), "authentication failed: unauthorized");

        let err = McpxError::Rbac("forbidden".into());
        assert_eq!(err.to_string(), "authorization denied: forbidden");

        let err = McpxError::RateLimited("throttled".into());
        assert_eq!(err.to_string(), "rate limited: throttled");
    }

    #[test]
    fn display_of_rate_limited_for_includes_wait() {
        let err = McpxError::RateLimitedFor {
            message: "throttled".into(),
            retry_after: std::time::Duration::from_millis(1500),
        };
        assert_eq!(
            err.to_string(),
            "rate limited: throttled (retry after 1.5s)"
        );
    }

    #[test]
    fn client_message_exposes_client_facing_text_and_hides_internal_detail() {
        // Client-facing variants: message passes through verbatim.
        assert_eq!(
            McpxError::Auth("bad token".into()).client_message(),
            "bad token"
        );
        assert_eq!(McpxError::Rbac("nope".into()).client_message(), "nope");
        assert_eq!(
            McpxError::RateLimited("slow down".into()).client_message(),
            "slow down"
        );
        assert_eq!(
            McpxError::RateLimitedFor {
                message: "too many".into(),
                retry_after: std::time::Duration::from_secs(1),
            }
            .client_message(),
            "too many"
        );

        // Internal variants: detail is hidden behind a generic body.
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "secret/path/leak");
        assert_eq!(
            McpxError::from(io_err).client_message(),
            "internal server error"
        );
        assert_eq!(
            McpxError::Tls("private key /etc/certs/server.key".into()).client_message(),
            "internal server error"
        );
        assert_eq!(
            McpxError::Config("bind 10.0.0.5:8443 failed".into()).client_message(),
            "internal server error"
        );
    }

    #[tokio::test]
    async fn client_message_matches_into_response_body() {
        // The accessor and the wire body must agree for every variant we test,
        // including the header-carrying `RateLimitedFor` response path.
        for err in [
            McpxError::Auth("a".into()),
            McpxError::Rbac("b".into()),
            McpxError::RateLimited("c".into()),
            McpxError::RateLimitedFor {
                message: "f".into(),
                retry_after: std::time::Duration::from_secs(3),
            },
            McpxError::Config("d".into()),
            McpxError::Tls("e".into()),
            McpxError::Startup("g".into()),
            McpxError::from(std::io::Error::new(std::io::ErrorKind::NotFound, "h")),
            McpxError::from(json_error()),
        ] {
            let expected = err.client_message().into_owned();
            let (_status, body) = status_of(err).await;
            assert_eq!(body, expected, "client_message must equal the wire body");
        }
    }
}