nidus_http/middleware/
request_context.rs1use std::{
2 future::Future,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use axum::extract::Request;
8use tower::{Layer, Service};
9
10use crate::context::{RequestContext, header_to_string};
11
12pub fn request_context_layer() -> RequestContextLayer {
24 RequestContextLayer
25}
26
/// Tower [`Layer`] that wraps an inner service in a [`RequestContextService`].
///
/// It carries no configuration of its own, which is why it is `Copy` and
/// `Default`.
#[derive(Debug, Default, Clone, Copy)]
pub struct RequestContextLayer;
37
38impl<S> Layer<S> for RequestContextLayer {
39 type Service = RequestContextService<S>;
40
41 fn layer(&self, inner: S) -> Self::Service {
42 RequestContextService { inner }
43 }
44}
45
/// Service produced by [`RequestContextLayer`].
///
/// It inserts a `RequestContext` into each request's extensions and then
/// delegates the request to `inner`.
#[derive(Debug, Clone)]
pub struct RequestContextService<S> {
    /// The wrapped service that ultimately handles the request.
    inner: S,
}
51
52impl<S> Service<Request> for RequestContextService<S>
53where
54 S: Service<Request> + Send + 'static,
55 S::Future: Send + 'static,
56 S::Error: Send + 'static,
57{
58 type Response = S::Response;
59 type Error = S::Error;
60 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
61
62 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63 self.inner.poll_ready(cx)
64 }
65
66 fn call(&mut self, mut request: Request) -> Self::Future {
67 let (mut parts, body) = request.into_parts();
68 let request_id = parts
69 .extensions
70 .remove::<RequestContext>()
71 .map(RequestContext::into_request_id)
72 .or_else(|| header_to_string(&parts.headers, "x-request-id"))
73 .unwrap_or_else(|| "unknown".to_owned());
74 let context = RequestContext::from_parts(&parts, request_id);
75 parts.extensions.insert(context);
76 request = Request::from_parts(parts, body);
77 let future = self.inner.call(request);
78 Box::pin(future)
79 }
80}