use axum::{
extract::{FromRequestParts, OptionalFromRequestParts, Request},
http::{request::Parts, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
/// Caller identity attached to a request.
///
/// In this module it is built by `user_context_middleware` from identity headers,
/// stored in the request extensions, and read back by the `FromRequestParts`
/// (required, 401 when missing) and `OptionalFromRequestParts` (optional) impls.
#[derive(Debug, Clone)]
pub struct RequestContext {
/// Parsed from the `X-User-Id` header.
pub user_id: i64,
/// Parsed from the `X-Tenant-Id` header; `None` when absent or not a valid `i64`.
pub tenant_id: Option<i64>,
/// Value of the `X-Username` header; empty string when absent or unreadable.
pub username: String,
}
pub async fn user_context_middleware(request: Request, next: Next) -> Response {
let mut request = request;
let user_id = request
.headers()
.get("X-User-Id")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<i64>().ok());
if let Some(user_id) = user_id {
let tenant_id = request
.headers()
.get("X-Tenant-Id")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<i64>().ok());
let username = request
.headers()
.get("X-Username")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let ctx = RequestContext {
user_id,
tenant_id,
username,
};
request.extensions_mut().insert(ctx);
}
next.run(request).await
}
/// Required extractor for [`RequestContext`].
///
/// Yields a clone of the context found in the request extensions (for example one
/// inserted by `user_context_middleware`). When none is present, the request is
/// rejected with `401 Unauthorized` and the body "Missing user context".
impl<S> FromRequestParts<S> for RequestContext
where
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let Some(found) = parts.extensions.get::<Self>() else {
            return Err((StatusCode::UNAUTHORIZED, "Missing user context").into_response());
        };
        Ok(found.clone())
    }
}
/// Optional extractor for [`RequestContext`], used when a handler takes
/// `Option<RequestContext>`.
///
/// Never rejects: it yields `Some(clone)` when a context is stored in the request
/// extensions and `None` otherwise.
impl<S> OptionalFromRequestParts<S> for RequestContext
where
    S: Send + Sync,
{
    type Rejection = std::convert::Infallible;

    async fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> Result<Option<Self>, Self::Rejection> {
        let stored: Option<&Self> = parts.extensions.get();
        Ok(stored.cloned())
    }
}