use std::sync::{Arc, RwLock};
use super::RequestId;
use axum::extract::Request;
use http::Extensions;
/// Per-request state that can be attached to and read back from a request.
///
/// Cloning is cheap and does not copy the stored values: every clone shares the
/// same extension storage through an `Arc`. A value stored with `set` on one
/// clone is therefore visible through `get` on all the others.
#[derive(Clone)]
pub struct RequestContext {
    /// Type-keyed values attached to the request. They sit behind a lock so they
    /// can be written through a shared `&self` reference.
    inner: Arc<RwLock<Extensions>>,
    /// Identifier for this request. The HTTP/gRPC constructors read it from
    /// `x-request-id` when that value is usable.
    request_id: RequestId,
}
impl RequestContext {
#[cfg(test)]
pub(crate) fn new() -> Self {
Self {
request_id: RequestId::new(),
inner: Arc::new(RwLock::new(Extensions::new())),
}
}
pub(crate) fn from_http(req: &Request) -> Self {
let request_id = req
.headers()
.get("x-request-id")
.and_then(|v| v.to_str().ok())
.map(RequestId::from)
.unwrap_or_default();
let mut inner = Extensions::new();
inner.insert(req.method().clone());
inner.insert(req.uri().clone());
Self {
request_id,
inner: Arc::new(RwLock::new(inner)),
}
}
#[cfg(feature = "grpc")]
pub(crate) fn from_grpc(metadata: &tonic::metadata::MetadataMap) -> Self {
let request_id = metadata
.get("x-request-id")
.and_then(|v| v.to_str().ok())
.map(RequestId::from)
.unwrap_or_default();
let inner = Extensions::new();
Self {
request_id,
inner: Arc::new(RwLock::new(inner)),
}
}
pub fn request_id(&self) -> &str {
&self.request_id
}
pub fn set<T: Send + Sync + Clone + 'static>(&self, value: T) {
self.inner.write().unwrap().insert(value);
}
pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
self.inner.read().unwrap().get::<T>().cloned()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use http::Method;

    #[derive(Clone, Debug, PartialEq)]
    struct UserId(u64);

    // Nothing here awaits, so plain `#[test]` is enough; no async runtime is needed.
    #[test]
    fn test_request_context() {
        let ctx = RequestContext::new();
        ctx.set(Method::GET);
        ctx.set::<UserId>(UserId(42));
        assert_eq!(ctx.get::<Method>(), Some(Method::GET));
        assert_eq!(ctx.get::<UserId>(), Some(UserId(42)));
    }

    #[test]
    fn missing_value_is_none() {
        let ctx = RequestContext::new();
        assert_eq!(ctx.get::<UserId>(), None);
    }

    #[test]
    fn set_overwrites_previous_value_of_same_type() {
        let ctx = RequestContext::new();
        ctx.set(UserId(1));
        ctx.set(UserId(2));
        assert_eq!(ctx.get::<UserId>(), Some(UserId(2)));
    }

    #[test]
    fn clones_share_storage() {
        let ctx = RequestContext::new();
        let other = ctx.clone();
        other.set(UserId(7));
        assert_eq!(ctx.get::<UserId>(), Some(UserId(7)));
    }

    #[test]
    fn from_http_records_method_and_uri() {
        let req = http::Request::builder()
            .method(Method::POST)
            .uri("/users/7?x=1")
            .body(Body::empty())
            .unwrap();
        let ctx = RequestContext::from_http(&req);
        assert_eq!(ctx.get::<Method>(), Some(Method::POST));
        assert_eq!(
            ctx.get::<http::Uri>().map(|u| u.path().to_owned()),
            Some("/users/7".to_owned())
        );
    }
}