Skip to main content

axum_anyhow/
middleware.rs

1//! Middleware for enriching API errors with request context.
2//!
3//! This module provides a middleware layer and global hook system for automatically
4//! enriching errors with request-specific metadata like URIs, methods, headers, etc.
5
6use crate::ApiErrorBuilder;
7use axum::{
8    extract::Request,
9    http::{HeaderMap, Method, Uri},
10    response::Response,
11};
12use futures_util::future::BoxFuture;
13use std::{
14    cell::RefCell,
15    sync::Arc,
16    task::{Context, Poll},
17};
18use tower::{Layer, Service};
19
thread_local! {
    // Per-thread slot holding the enrichment context consulted by
    // `EnrichmentContext::invoke`. `None` means no context is installed on
    // the current thread, in which case builders pass through unchanged.
    static ENRICHMENT_CONTEXT: RefCell<Option<EnrichmentContext>> = const { RefCell::new(None) };
}
23
24/// Request information snapshot available to the error enricher.
25///
26/// This struct contains request metadata that can be used to enrich errors.
27#[derive(Clone, Debug)]
28pub struct RequestSnapshot {
29    /// The HTTP method of the request
30    method: Method,
31    /// The URI of the request
32    uri: Uri,
33    /// The HTTP headers of the request
34    headers: HeaderMap,
35}
36
37impl RequestSnapshot {
38    /// Returns a reference to the HTTP method of the request.
39    pub fn method(&self) -> &Method {
40        &self.method
41    }
42
43    /// Returns a reference to the URI of the request.
44    pub fn uri(&self) -> &Uri {
45        &self.uri
46    }
47
48    /// Returns a reference to the HTTP headers of the request.
49    pub fn headers(&self) -> &HeaderMap {
50        &self.headers
51    }
52
53    /// Creates a `RequestSnapshot` from an Axum `Request`.
54    ///
55    /// Extracts the method, URI, headers, and extensions from the request.
56    pub fn from_request(request: &Request) -> Self {
57        Self {
58            method: request.method().clone(),
59            uri: request.uri().clone(),
60            headers: request.headers().clone(),
61        }
62    }
63}
64
/// Shared, thread-safe callback used to enrich errors.
///
/// The callback receives an [`ApiErrorBuilder`] by value together with a
/// borrowed [`RequestSnapshot`] and returns the (possibly modified) builder.
/// It is stored behind an `Arc` so it can be cloned cheaply.
type ErrorEnricher =
    Arc<dyn Fn(ApiErrorBuilder, &RequestSnapshot) -> ApiErrorBuilder + Send + Sync + 'static>;
68
69/// Context for enriching errors with request information.
70///
71/// This struct combines the request context with the enricher callback,
72/// making it easier to pass both pieces of data together through the middleware.
73#[derive(Clone)]
74pub(crate) struct EnrichmentContext {
75    request: RequestSnapshot,
76    enricher: ErrorEnricher,
77}
78
79impl EnrichmentContext {
80    /// Creates a new `EnrichmentContext` with the given context and enricher.
81    fn new(request: RequestSnapshot, enricher: ErrorEnricher) -> Self {
82        Self { request, enricher }
83    }
84
85    /// Installs this enrichment context as the current thread-local data.
86    fn set(self) {
87        ENRICHMENT_CONTEXT.with(|data| {
88            *data.borrow_mut() = Some(self);
89        });
90    }
91
92    /// Removes the current thread-local enrichment context.
93    fn clear() {
94        ENRICHMENT_CONTEXT.with(|data| {
95            *data.borrow_mut() = None;
96        });
97    }
98
99    /// Applies the enricher to the given builder.
100    fn apply(&self, builder: ApiErrorBuilder) -> ApiErrorBuilder {
101        (self.enricher)(builder, &self.request)
102    }
103
104    /// Invokes the error enricher if one is set and request context is available.
105    ///
106    /// This is called internally by `ApiErrorBuilder::build()`.
107    pub(crate) fn invoke(builder: ApiErrorBuilder) -> ApiErrorBuilder {
108        ENRICHMENT_CONTEXT.with(|data| {
109            if let Some(enrichment_ctx) = data.borrow().as_ref() {
110                enrichment_ctx.apply(builder)
111            } else {
112                builder
113            }
114        })
115    }
116}
117
118/// Service that captures request context and makes it available for error enrichment.
119pub struct ErrorInterceptor<S> {
120    inner: S,
121    enricher: ErrorEnricher,
122}
123
124impl<S> Clone for ErrorInterceptor<S>
125where
126    S: Clone,
127{
128    fn clone(&self) -> Self {
129        Self {
130            inner: self.inner.clone(),
131            enricher: self.enricher.clone(),
132        }
133    }
134}
135
136impl<S> Service<Request> for ErrorInterceptor<S>
137where
138    S: Service<Request, Response = Response> + Send + 'static,
139    S::Future: Send + 'static,
140{
141    type Response = S::Response;
142    type Error = S::Error;
143    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
144
145    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146        self.inner.poll_ready(cx)
147    }
148
149    fn call(&mut self, request: Request) -> Self::Future {
150        // Capture request context
151        let snapshot = RequestSnapshot::from_request(&request);
152        let ctx = EnrichmentContext::new(snapshot, self.enricher.clone());
153
154        let future = self.inner.call(request);
155
156        Box::pin(async move {
157            // Install enrichment context for this task
158            ctx.set();
159
160            // Call the inner service
161            let result = future.await;
162
163            // Remove enrichment context after request completes
164            EnrichmentContext::clear();
165
166            result
167        })
168    }
169}
170
171/// Middleware layer that enables error enrichment with request context.
172///
173/// This layer captures request information (method, URI, headers) and makes it available
174/// to the error enricher callback.
175///
176/// # Example
177///
178/// ```rust
179/// use axum::Router;
180/// use axum_anyhow::ErrorInterceptorLayer;
181/// use serde_json::json;
182///
183/// // Create the layer with an enricher
184/// let enricher_layer = ErrorInterceptorLayer::new(|builder, ctx| {
185///     builder.meta(json!({
186///         "method": ctx.method().as_str(),
187///         "uri": ctx.uri().to_string(),
188///         "user_agent": ctx.headers()
189///             .get("user-agent")
190///             .and_then(|v| v.to_str().ok())
191///             .unwrap_or("unknown"),
192///     }))
193/// });
194///
195/// // Apply the middleware
196/// let app: Router = Router::new()
197///     .layer(enricher_layer);
198/// ```
199#[derive(Clone)]
200pub struct ErrorInterceptorLayer {
201    enricher: ErrorEnricher,
202}
203
204impl ErrorInterceptorLayer {
205    /// Creates a new `ErrorInterceptorLayer` with the given enricher function.
206    ///
207    /// The enricher will be called for every error created during request handling,
208    /// allowing you to add request-specific metadata.
209    pub fn new<F>(enricher: F) -> Self
210    where
211        F: Fn(ApiErrorBuilder, &RequestSnapshot) -> ApiErrorBuilder + Send + Sync + 'static,
212    {
213        Self {
214            enricher: Arc::new(enricher),
215        }
216    }
217}
218
219impl<S> Layer<S> for ErrorInterceptorLayer {
220    type Service = ErrorInterceptor<S>;
221
222    fn layer(&self, inner: S) -> Self::Service {
223        ErrorInterceptor {
224            inner,
225            enricher: self.enricher.clone(),
226        }
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use serde_json::json;
    use serial_test::serial;

    /// Builds a header-less snapshot for the given method and path.
    fn snapshot_of(method: Method, path: &str) -> RequestSnapshot {
        RequestSnapshot {
            method,
            uri: path.parse().unwrap(),
            headers: HeaderMap::default(),
        }
    }

    /// An installed context causes built errors to carry the enricher's metadata.
    #[test]
    #[serial]
    fn test_error_enricher() {
        // Enricher that records the request method and URI as error metadata.
        let record_request = Arc::new(|builder: ApiErrorBuilder, req: &RequestSnapshot| {
            builder.meta(json!({
                "method": req.method().as_str(),
                "uri": req.uri().to_string(),
            }))
        });
        EnrichmentContext::new(snapshot_of(Method::GET, "/test"), record_request).set();

        let error = crate::ApiError::builder()
            .status(StatusCode::NOT_FOUND)
            .title("Not Found")
            .detail("Resource not found")
            .build();

        // The metadata produced by the enricher must be present on the error.
        assert!(error.meta().is_some());
        let meta = error.meta().unwrap();
        assert_eq!(meta["method"], "GET");
        assert_eq!(meta["uri"], "/test");

        EnrichmentContext::clear();
    }

    /// Without an installed context, built errors carry no metadata.
    #[test]
    #[serial]
    fn test_enricher_without_context() {
        // Start from an empty slot.
        EnrichmentContext::clear();

        let error = crate::ApiError::builder()
            .status(StatusCode::BAD_REQUEST)
            .title("Bad Request")
            .detail("Invalid input")
            .build();

        assert!(error.meta().is_none());
    }

    /// `set` fills the thread-local slot and `clear` empties it again.
    #[test]
    #[serial]
    fn test_request_data_lifecycle() {
        let passthrough = Arc::new(|builder: ApiErrorBuilder, _req: &RequestSnapshot| builder);
        let snapshot = snapshot_of(Method::POST, "/api/users");

        EnrichmentContext::new(snapshot.clone(), passthrough).set();

        // The slot now holds the snapshot that was just installed.
        ENRICHMENT_CONTEXT.with(|slot| {
            let installed = slot.borrow();
            assert!(installed.is_some());
            let stored = &installed.as_ref().unwrap().request;
            assert_eq!(stored.method, Method::POST);
            assert_eq!(stored.uri.to_string(), "/api/users");
        });

        EnrichmentContext::clear();

        // After clearing, the slot is empty.
        ENRICHMENT_CONTEXT.with(|slot| {
            assert!(slot.borrow().is_none());
        });
    }
}