mod errors;
pub(crate) mod layer;
mod request_context;
mod request_id;
use axum::extract::FromRequestParts;
use std::future::Future;
use tokio::task_local;
use crate::error::{Error, Result, TypedError};
pub use errors::*;
pub use request_context::RequestContext;
pub use request_id::*;
task_local! {
// Task-local slot holding the current request's context. It is populated for
// the lifetime of a future via `REQUEST_CONTEXT.scope(..)` (see
// `run_with_context`) and read back with `try_with` in `get_context`.
static REQUEST_CONTEXT: RequestContext;
}
/// Drives `fut` to completion with `context` installed as the current request context.
///
/// While `fut` runs, [`get_context`] returns (a clone of) `context`. Calls may be
/// nested: an inner call shadows the outer context, and the outer context becomes
/// visible again once the inner future has finished.
///
/// Returns whatever `fut` resolves to.
pub(crate) async fn run_with_context<Fut, R>(context: RequestContext, fut: Fut) -> R
where
    Fut: Future<Output = R>,
{
    // `scope` wraps `fut` so the task-local is set only while it is being polled.
    let scoped = REQUEST_CONTEXT.scope(context, fut);
    scoped.await
}
pub fn get_context() -> Result<RequestContext> {
REQUEST_CONTEXT
.try_with(|ctx| ctx.clone())
.map_err(|_| OutOfScope::error("get_context() called outside request scope"))
}
/// Axum extractor yielding the [`RequestContext`] of the in-flight request.
///
/// Neither the request parts nor the router state are inspected; the context is
/// obtained through [`get_context`]. If no context is in scope, extraction is
/// rejected with the error that [`get_context`] returns.
impl<S> FromRequestParts<S> for RequestContext
where
    S: Send + Sync,
{
    type Rejection = Error;

    async fn from_request_parts(
        _parts: &mut axum::http::request::Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let context = get_context()?;
        Ok(context)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;

    /// A context installed with `run_with_context` is visible to `get_context`.
    #[tokio::test]
    async fn test_context() {
        let ctx = RequestContext::new();
        ctx.set(Method::GET);
        run_with_context(ctx, async {
            let ctx = get_context().unwrap();
            assert_eq!(ctx.get::<Method>(), Some(Method::GET));
        })
        .await;
    }

    /// An inner scope shadows the outer one, and the outer context is restored afterwards.
    #[tokio::test]
    async fn test_nested_context() {
        let outer_ctx = RequestContext::new();
        outer_ctx.set(Method::GET);
        run_with_context(outer_ctx, async {
            let ctx = get_context().unwrap();
            assert_eq!(ctx.get::<Method>(), Some(Method::GET));

            let inner_ctx = RequestContext::new();
            inner_ctx.set(Method::POST);
            run_with_context(inner_ctx, async {
                let ctx = get_context().unwrap();
                assert_eq!(ctx.get::<Method>(), Some(Method::POST));
            })
            .await;

            // Leaving the inner scope must make the outer context visible again.
            let ctx = get_context().unwrap();
            assert_eq!(ctx.get::<Method>(), Some(Method::GET));
        })
        .await;
    }

    /// Once the scoped future has completed, the context is no longer reachable.
    #[tokio::test]
    async fn test_context_cleared_after_scope() {
        run_with_context(RequestContext::new(), async {
            assert!(get_context().is_ok());
        })
        .await;
        assert!(get_context().is_err());
    }

    /// Outside any scope, `get_context` fails with an internal `OutOfScope` error.
    #[test]
    fn test_get_context_outside() {
        use crate::error::ErrorCode;

        let Err(err) = get_context() else {
            panic!("should return error outside request context");
        };
        assert!(err.is_code(ErrorCode::Internal));
        assert!(err.is(OutOfScope));
    }
}