use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use axum::body::Body;
use http::Request;
use tower::Service;
use crate::context::{RequestContext, layer::ContextService, run_with_context};
type HttpRequest = Request<Body>;
/// Installs a per-request [`RequestContext`] around the wrapped service.
///
/// For every request, a context is built from the incoming request with
/// [`RequestContext::from_http`]. The inner service's future is then awaited through
/// [`run_with_context`] using that context.
impl<S> Service<HttpRequest> for ContextService<S>
where
    S: Service<HttpRequest> + Clone + Send + 'static,
    S::Future: Send,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <S as Service<HttpRequest>>::poll_ready(&mut self.inner, cx)
    }

    fn call(&mut self, req: HttpRequest) -> Self::Future {
        // `poll_ready` was driven on `self.inner`, so that exact instance must serve this
        // request. Put a fresh clone back into `self`, which the next `poll_ready` will
        // drive. The instance that was polled ready moves into the returned future.
        let mut ready = self.inner.clone();
        std::mem::swap(&mut self.inner, &mut ready);

        let ctx = RequestContext::from_http(&req);
        Box::pin(async move {
            let inner_future = ready.call(req);
            run_with_context(ctx, inner_future).await
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{context::get_context, context::layer::ContextLayer, http::routing::Router};

    /// A client-supplied `x-request-id` header must be exposed to the handler as the
    /// context's request id.
    #[tokio::test]
    async fn test_http_context_layer_with_request_id() {
        async fn handler() -> Result<String, crate::error::Error> {
            let ctx = get_context()?;
            Ok(ctx.request_id().to_string())
        }

        let router = Router::compose(|router| {
            router.middleware(ContextLayer).get("/", handler);
        });
        let req = Request::builder()
            .uri("/")
            .header("x-request-id", "test-123")
            .body(Body::empty())
            .unwrap();

        let response = router.build().call(req).await.unwrap();
        // Check the status first. If `get_context` failed, the handler's error would be
        // rendered into the body, and the byte comparison below would report a confusing
        // mismatch instead of the real cause.
        let status = response.status();
        assert!(status.is_success(), "unexpected status: {status}");

        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&body[..], b"test-123");
    }

    /// Without an `x-request-id` header, the layer must still provide a context with a
    /// non-empty request id.
    #[tokio::test]
    async fn test_http_context_layer_auto_generate_request_id() {
        async fn handler() -> Result<String, crate::error::Error> {
            let ctx = get_context()?;
            let request_id = ctx.request_id().to_string();
            assert!(!request_id.is_empty());
            Ok(request_id)
        }

        let router = Router::compose(|router| {
            router.middleware(ContextLayer).get("/", handler);
        });
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();

        let response = router.build().call(req).await.unwrap();
        // Without this check, an error response (e.g. `get_context` returning `Err`) could
        // also carry a non-empty body and make the emptiness assertion below pass
        // spuriously.
        let status = response.status();
        assert!(status.is_success(), "unexpected status: {status}");

        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let body_str = String::from_utf8(body.to_vec()).unwrap();
        assert!(!body_str.is_empty());
    }
}