use crate::logging;
use crate::service::request_id::RequestId;
use crate::service::unverified_jwt::UnverifiedJwt;
use crate::service::{Layer, Service};
use http::Request;
use witchcraft_log::mdc;
/// A [`Layer`] that wraps an inner service in a [`WitchcraftMdcService`].
///
/// The resulting service reads `UnverifiedJwt` and `RequestId` from the request
/// extensions and the current zipkin trace context, so whatever provides those
/// values has to run before it.
pub struct WitchcraftMdcLayer;
// Wiring only: the layer carries no state of its own, it just decorates `S`.
impl<S> Layer<S> for WitchcraftMdcLayer {
    type Service = WitchcraftMdcService<S>;

    /// Consumes the layer and returns `inner` wrapped in a [`WitchcraftMdcService`].
    fn layer(self, inner: S) -> Self::Service {
        WitchcraftMdcService::<S> { inner }
    }
}
/// Service wrapper produced by [`WitchcraftMdcLayer`].
///
/// Its `Service` implementation inserts request-scoped identifiers into the MDC
/// and then forwards the request to `inner`.
pub struct WitchcraftMdcService<S> {
    // The wrapped service that receives the request after the MDC is populated.
    inner: S,
}
impl<S, B> Service<Request<B>> for WitchcraftMdcService<S>
where
    S: Service<Request<B>> + Sync,
    B: Send,
{
    type Response = S::Response;

    /// Populates the MDC from the request and the current trace context, then
    /// delegates to the wrapped service and returns its response unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `zipkin::current()` yields no trace context, or if the request
    /// extensions do not contain a `RequestId`.
    async fn call(&self, req: Request<B>) -> Self::Response {
        // All borrows of `req` end inside the helper, before the request is moved
        // into the inner service.
        record_request_mdc(&req);
        self.inner.call(req).await
    }
}

/// Copies the request-scoped identifiers into the MDC.
///
/// JWT-derived entries are written only when an `UnverifiedJwt` extension is
/// present; the trace ID and request ID are always written.
fn record_request_mdc<B>(req: &Request<B>) {
    let extensions = req.extensions();

    // Identity fields come from the unverified token, if one was attached.
    if let Some(token) = extensions.get::<UnverifiedJwt>() {
        mdc::insert_safe(logging::mdc::UID_KEY, token.unverified_user_id());

        if let Some(sid) = token.unverified_session_id() {
            mdc::insert_safe(logging::mdc::SID_KEY, sid);
        }

        if let Some(tid) = token.unverified_token_id() {
            mdc::insert_safe(logging::mdc::TOKEN_ID_KEY, tid);
        }

        if let Some(oid) = token.unverified_organization_id() {
            mdc::insert_safe(logging::mdc::ORG_ID_KEY, oid);
        }
    }

    // Trace information; the sampled flag is only recorded when it has been decided.
    let trace = zipkin::current().expect("zipkin trace not initialized");
    mdc::insert_safe(logging::mdc::TRACE_ID_KEY, trace.trace_id().to_string());
    if let Some(flag) = trace.sampled() {
        mdc::insert_safe(logging::SAMPLED_KEY, flag);
    }

    // The request ID is mandatory: its absence is treated as a broken invariant.
    let id = extensions
        .get::<RequestId>()
        .expect("RequestId missing from request extensions");
    mdc::insert_safe(logging::REQUEST_ID_KEY, id.to_string());
}