axum_anyhow/
middleware.rs1use crate::ApiErrorBuilder;
7use axum::{
8 extract::Request,
9 http::{HeaderMap, Method, Uri},
10 response::Response,
11};
12use futures_util::future::BoxFuture;
13use std::{
14 cell::RefCell,
15 sync::Arc,
16 task::{Context, Poll},
17};
18use tower::{Layer, Service};
19
thread_local! {
    /// Per-thread slot holding the active `EnrichmentContext`, if any.
    ///
    /// Read by `EnrichmentContext::invoke`; the middleware in this module installs
    /// and removes the value. Because the slot is thread-local, only code running on
    /// the thread that installed a context can observe it.
    static ENRICHMENT_CONTEXT: RefCell<Option<EnrichmentContext>> = const { RefCell::new(None) };
}
23
24#[derive(Clone, Debug)]
28pub struct RequestSnapshot {
29 method: Method,
31 uri: Uri,
33 headers: HeaderMap,
35}
36
37impl RequestSnapshot {
38 pub fn method(&self) -> &Method {
40 &self.method
41 }
42
43 pub fn uri(&self) -> &Uri {
45 &self.uri
46 }
47
48 pub fn headers(&self) -> &HeaderMap {
50 &self.headers
51 }
52
53 pub fn from_request(request: &Request) -> Self {
57 Self {
58 method: request.method().clone(),
59 uri: request.uri().clone(),
60 headers: request.headers().clone(),
61 }
62 }
63}
64
/// Shared, thread-safe callback that receives an error builder together with the
/// snapshot of the current request and returns the builder to use from then on
/// (typically the same builder with extra data attached).
type ErrorEnricher =
    Arc<dyn Fn(ApiErrorBuilder, &RequestSnapshot) -> ApiErrorBuilder + Send + Sync + 'static>;
68
69#[derive(Clone)]
74pub(crate) struct EnrichmentContext {
75 request: RequestSnapshot,
76 enricher: ErrorEnricher,
77}
78
79impl EnrichmentContext {
80 fn new(request: RequestSnapshot, enricher: ErrorEnricher) -> Self {
82 Self { request, enricher }
83 }
84
85 fn set(self) {
87 ENRICHMENT_CONTEXT.with(|data| {
88 *data.borrow_mut() = Some(self);
89 });
90 }
91
92 fn clear() {
94 ENRICHMENT_CONTEXT.with(|data| {
95 *data.borrow_mut() = None;
96 });
97 }
98
99 fn apply(&self, builder: ApiErrorBuilder) -> ApiErrorBuilder {
101 (self.enricher)(builder, &self.request)
102 }
103
104 pub(crate) fn invoke(builder: ApiErrorBuilder) -> ApiErrorBuilder {
108 ENRICHMENT_CONTEXT.with(|data| {
109 if let Some(enrichment_ctx) = data.borrow().as_ref() {
110 enrichment_ctx.apply(builder)
111 } else {
112 builder
113 }
114 })
115 }
116}
117
118pub struct ErrorInterceptor<S> {
120 inner: S,
121 enricher: ErrorEnricher,
122}
123
124impl<S> Clone for ErrorInterceptor<S>
125where
126 S: Clone,
127{
128 fn clone(&self) -> Self {
129 Self {
130 inner: self.inner.clone(),
131 enricher: self.enricher.clone(),
132 }
133 }
134}
135
136impl<S> Service<Request> for ErrorInterceptor<S>
137where
138 S: Service<Request, Response = Response> + Send + 'static,
139 S::Future: Send + 'static,
140{
141 type Response = S::Response;
142 type Error = S::Error;
143 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
144
145 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146 self.inner.poll_ready(cx)
147 }
148
149 fn call(&mut self, request: Request) -> Self::Future {
150 let snapshot = RequestSnapshot::from_request(&request);
152 let ctx = EnrichmentContext::new(snapshot, self.enricher.clone());
153
154 let future = self.inner.call(request);
155
156 Box::pin(async move {
157 ctx.set();
159
160 let result = future.await;
162
163 EnrichmentContext::clear();
165
166 result
167 })
168 }
169}
170
171#[derive(Clone)]
200pub struct ErrorInterceptorLayer {
201 enricher: ErrorEnricher,
202}
203
204impl ErrorInterceptorLayer {
205 pub fn new<F>(enricher: F) -> Self
210 where
211 F: Fn(ApiErrorBuilder, &RequestSnapshot) -> ApiErrorBuilder + Send + Sync + 'static,
212 {
213 Self {
214 enricher: Arc::new(enricher),
215 }
216 }
217}
218
219impl<S> Layer<S> for ErrorInterceptorLayer {
220 type Service = ErrorInterceptor<S>;
221
222 fn layer(&self, inner: S) -> Self::Service {
223 ErrorInterceptor {
224 inner,
225 enricher: self.enricher.clone(),
226 }
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use serde_json::json;
    use serial_test::serial;

    /// Builds a snapshot with no headers for the given method and path.
    fn snapshot_of(method: Method, path: &str) -> RequestSnapshot {
        RequestSnapshot {
            method,
            uri: path.parse().unwrap(),
            headers: HeaderMap::default(),
        }
    }

    /// An installed context's enricher is applied when an error is built.
    #[test]
    #[serial]
    fn test_error_enricher() {
        let enricher = Arc::new(|builder: ApiErrorBuilder, req: &RequestSnapshot| {
            let method = req.method.as_str();
            let uri = req.uri.to_string();
            builder.meta(json!({ "method": method, "uri": uri }))
        });
        EnrichmentContext::new(snapshot_of(Method::GET, "/test"), enricher).set();

        let error = crate::ApiError::builder()
            .status(StatusCode::NOT_FOUND)
            .title("Not Found")
            .detail("Resource not found")
            .build();

        let meta = error.meta();
        assert!(meta.is_some());
        let meta = meta.unwrap();
        assert_eq!(meta["method"], "GET");
        assert_eq!(meta["uri"], "/test");

        EnrichmentContext::clear();
    }

    /// Without an installed context, building an error leaves `meta` unset.
    #[test]
    #[serial]
    fn test_enricher_without_context() {
        EnrichmentContext::clear();

        let error = crate::ApiError::builder()
            .status(StatusCode::BAD_REQUEST)
            .title("Bad Request")
            .detail("Invalid input")
            .build();

        assert!(error.meta().is_none());
    }

    /// `set` stores the snapshot in the thread-local slot and `clear` empties it.
    #[test]
    #[serial]
    fn test_request_data_lifecycle() {
        let passthrough = Arc::new(|builder: ApiErrorBuilder, _req: &RequestSnapshot| builder);
        EnrichmentContext::new(snapshot_of(Method::POST, "/api/users"), passthrough).set();

        ENRICHMENT_CONTEXT.with(|cell| {
            let slot = cell.borrow();
            assert!(slot.is_some());
            let stored = &slot.as_ref().unwrap().request;
            assert_eq!(stored.method, Method::POST);
            assert_eq!(stored.uri.to_string(), "/api/users");
        });

        EnrichmentContext::clear();

        let emptied = ENRICHMENT_CONTEXT.with(|cell| cell.borrow().is_none());
        assert!(emptied);
    }
}