use std::{
future::Future,
net::{IpAddr, SocketAddr},
};
use axum::extract::FromRequestParts;
use http::{HeaderMap, Method, request::Parts};
use serde::Serialize;
/// Coarse classification of the caller behind a request.
///
/// Serialized through serde with `snake_case` variant names.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ClientKind {
/// Caller classified by its `x-api-key` header.
ApiKey,
/// Caller classified by its `Authorization` header.
Authenticated,
/// Caller with neither credential header; also the kind assigned by `RequestContext::new`.
Anonymous,
}
impl ClientKind {
pub const fn as_str(self) -> &'static str {
match self {
Self::ApiKey => "api_key",
Self::Authenticated => "authenticated",
Self::Anonymous => "anonymous",
}
}
}
/// Per-request metadata: identifiers, routing information and caller details.
///
/// All fields are private; they are read through the accessor methods on this type.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RequestContext {
// Identifier of this request, supplied by whoever constructs the context.
request_id: String,
// Correlation id; `from_parts` reads it from `x-correlation-id` and otherwise
// falls back to a non-empty request id.
correlation_id: Option<String>,
// HTTP method of the request.
method: Method,
// Matched route, when known (axum `MatchedPath` in `from_parts`, or `with_route`).
route: Option<String>,
// Path component of the request URI.
path: String,
// Trace id; `from_parts` derives it from the `traceparent` header.
trace_id: Option<String>,
// Optional span id, exposed through `span_id()`.
span_id: Option<String>,
// Caller classification; `Anonymous` unless inferred from headers by `from_parts`.
client_kind: ClientKind,
// Optional user id, set through `with_user_id`.
user_id: Option<String>,
// Optional tenant id, set through `with_tenant_id`.
tenant_id: Option<String>,
// Optional session id, set through `with_session_id`.
session_id: Option<String>,
}
impl RequestContext {
    /// Creates a context holding only the mandatory fields.
    ///
    /// Every optional field starts as `None` and the client kind starts as
    /// [`ClientKind::Anonymous`].
    pub fn new(request_id: impl Into<String>, method: Method, path: impl Into<String>) -> Self {
        Self {
            request_id: request_id.into(),
            correlation_id: None,
            method,
            route: None,
            path: path.into(),
            trace_id: None,
            span_id: None,
            client_kind: ClientKind::Anonymous,
            user_id: None,
            tenant_id: None,
            session_id: None,
        }
    }

    /// Builds a context from the head of an incoming request.
    ///
    /// * `correlation_id` is taken from the `x-correlation-id` header, falling
    ///   back to `request_id` when that is non-empty.
    /// * `route` is the axum `MatchedPath` extension, if one is present.
    /// * `client_kind` is inferred from the request headers.
    /// * `trace_id` is the trace-id field of the `traceparent` header, kept only
    ///   when it is well-formed (see `parse_trace_id`).
    ///
    /// The span, user, tenant and session ids are left unset.
    pub fn from_parts(parts: &Parts, request_id: impl Into<String>) -> Self {
        let request_id = request_id.into();
        let mut context = Self::new(request_id.clone(), parts.method.clone(), parts.uri.path());
        context.correlation_id = header_to_string(&parts.headers, "x-correlation-id")
            .or_else(|| Some(request_id).filter(|value| !value.is_empty()));
        context.route = parts
            .extensions
            .get::<axum::extract::MatchedPath>()
            .map(|path| path.as_str().to_owned());
        context.client_kind = infer_client_kind(&parts.headers);
        // The header is caller-controlled, so only a valid trace id is accepted.
        context.trace_id = header_to_string(&parts.headers, "traceparent")
            .and_then(|value| Self::parse_trace_id(&value).map(str::to_owned));
        context
    }

    /// Extracts the trace-id field (second `-`-separated field) from a
    /// `traceparent` header value.
    ///
    /// Returns `None` unless the field is exactly 32 lowercase hexadecimal
    /// digits and not all zeros, which is what W3C Trace Context requires of a
    /// valid trace id.
    fn parse_trace_id(traceparent: &str) -> Option<&str> {
        let trace_id = traceparent.split('-').nth(1)?;
        let well_formed = trace_id.len() == 32
            && trace_id
                .bytes()
                .all(|byte| matches!(byte, b'0'..=b'9' | b'a'..=b'f'))
            && trace_id.bytes().any(|byte| byte != b'0');
        well_formed.then_some(trace_id)
    }

    /// Returns the request id.
    pub fn request_id(&self) -> &str {
        &self.request_id
    }

    /// Consumes the context and returns the owned request id.
    pub(crate) fn into_request_id(self) -> String {
        self.request_id
    }

    /// Returns the correlation id, if any.
    pub fn correlation_id(&self) -> Option<&str> {
        self.correlation_id.as_deref()
    }

    /// Returns the HTTP method.
    pub const fn method(&self) -> &Method {
        &self.method
    }

    /// Returns the matched route, if any.
    pub fn route(&self) -> Option<&str> {
        self.route.as_deref()
    }

    /// Returns the request path.
    pub fn path(&self) -> &str {
        &self.path
    }

    /// Returns the trace id, if any.
    pub fn trace_id(&self) -> Option<&str> {
        self.trace_id.as_deref()
    }

    /// Returns the span id, if any.
    pub fn span_id(&self) -> Option<&str> {
        self.span_id.as_deref()
    }

    /// Returns the inferred client kind.
    pub const fn client_kind(&self) -> ClientKind {
        self.client_kind
    }

    /// Returns the user id, if any.
    pub fn user_id(&self) -> Option<&str> {
        self.user_id.as_deref()
    }

    /// Returns the tenant id, if any.
    pub fn tenant_id(&self) -> Option<&str> {
        self.tenant_id.as_deref()
    }

    /// Returns the session id, if any.
    pub fn session_id(&self) -> Option<&str> {
        self.session_id.as_deref()
    }

    /// Sets the matched route and returns the updated context.
    pub fn with_route(mut self, route: impl Into<String>) -> Self {
        self.route = Some(route.into());
        self
    }

    /// Sets the trace id and returns the updated context.
    ///
    /// The value is stored as given; no format validation is applied here.
    pub fn with_trace_id(mut self, trace_id: impl Into<String>) -> Self {
        self.trace_id = Some(trace_id.into());
        self
    }

    /// Sets the span id and returns the updated context.
    ///
    /// Without this builder the span id could never be populated, so
    /// `span_id()` always returned `None`.
    pub fn with_span_id(mut self, span_id: impl Into<String>) -> Self {
        self.span_id = Some(span_id.into());
        self
    }

    /// Sets the user id and returns the updated context.
    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
        self.user_id = Some(user_id.into());
        self
    }

    /// Sets the tenant id and returns the updated context.
    pub fn with_tenant_id(mut self, tenant_id: impl Into<String>) -> Self {
        self.tenant_id = Some(tenant_id.into());
        self
    }

    /// Sets the session id and returns the updated context.
    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
        self.session_id = Some(session_id.into());
        self
    }
}
impl<S> FromRequestParts<S> for RequestContext
where
    S: Send + Sync,
{
    type Rejection = axum::http::StatusCode;

    /// Hands back a clone of the `RequestContext` stored in the request extensions.
    ///
    /// Resolves to `500 Internal Server Error` when no context is stored there.
    /// NOTE(review): presumably a middleware inserts the context earlier in the
    /// stack — confirm against the router setup.
    fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
        // The lookup happens eagerly; the returned future only carries the result.
        let outcome = match parts.extensions.get::<Self>() {
            Some(stored) => Ok(stored.clone()),
            None => Err(axum::http::StatusCode::INTERNAL_SERVER_ERROR),
        };
        std::future::ready(outcome)
    }
}
/// Opaque string key naming the party a request is attributed to.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RequestIdentity(String);

impl RequestIdentity {
    /// Wraps any value convertible into a `String`.
    pub fn new(value: impl Into<String>) -> Self {
        let key: String = value.into();
        Self(key)
    }

    /// Borrows the wrapped key.
    pub fn as_str(&self) -> &str {
        let Self(key) = self;
        key.as_str()
    }
}
/// Strategy for deriving a [`RequestIdentity`] from the head of a request.
///
/// Implementors must be `Clone + Send + Sync + 'static`.
pub trait IdentityExtractor: Clone + Send + Sync + 'static {
/// Returns the identity for `parts`, or `None` when none can be determined.
fn extract(&self, parts: &Parts) -> Option<RequestIdentity>;
}
/// Lets any suitable closure or function act as an [`IdentityExtractor`].
impl<F> IdentityExtractor for F
where
    F: Fn(&Parts) -> Option<RequestIdentity> + Clone + Send + Sync + 'static,
{
    /// Invokes the wrapped callable with `parts` and returns its result unchanged.
    fn extract(&self, parts: &Parts) -> Option<RequestIdentity> {
        let callback: &F = self;
        callback(parts)
    }
}
/// Extractor that keys requests by the identity recorded in their `RequestContext`.
///
/// Resolution order: the context's user id, then its tenant id, then the
/// `x-api-key` header. Yields `None` when none of these is available.
pub fn context_identity() -> impl IdentityExtractor {
    |parts: &Parts| {
        parts
            .extensions
            .get::<RequestContext>()
            .and_then(|ctx| ctx.user_id().or_else(|| ctx.tenant_id()))
            .map(RequestIdentity::new)
            // No context, or a context without user/tenant: fall back to the API key.
            .or_else(|| header_to_string(&parts.headers, "x-api-key").map(RequestIdentity::new))
    }
}
/// Extractor that keys requests by the value of their `x-api-key` header.
///
/// Yields `None` when `header_to_string` finds no usable value for that header.
pub fn api_key_identity() -> impl IdentityExtractor {
    |parts: &Parts| {
        let api_key = header_to_string(&parts.headers, "x-api-key")?;
        Some(RequestIdentity::new(api_key))
    }
}
/// Extractor that keys requests by the connection's peer IP address.
///
/// Forwarding headers are not consulted. When no peer address is available
/// the shared identity `"anonymous"` is used, so this extractor always yields
/// `Some`.
pub fn client_ip_identity() -> impl IdentityExtractor {
    |parts: &Parts| {
        let key = match peer_ip(parts) {
            Some(ip) => ip.to_string(),
            None => String::from("anonymous"),
        };
        Some(RequestIdentity::new(key))
    }
}
/// Extractor that keys requests by client IP, honouring `x-forwarded-for`
/// only when the connection comes from a trusted proxy.
///
/// * Peer is in `trusted_proxies` and a forwarded address parses: the
///   forwarded address is the identity.
/// * Otherwise the peer address is the identity.
/// * No peer address available: the shared identity `"anonymous"`.
///
/// Trust is decided on canonical addresses, so an IPv4-mapped IPv6 peer
/// (`::ffff:a.b.c.d`, as reported by dual-stack listeners) matches its plain
/// IPv4 entry in `trusted_proxies`, and vice versa. The identity string itself
/// is still rendered from the address exactly as observed.
///
/// NOTE(review): the forwarded address is whatever `forwarded_for_ip` returns
/// (the left-most `x-forwarded-for` entry in this file). That entry is
/// client-controlled unless the trusted proxy overwrites the header — confirm
/// the deployment does so.
pub fn trusted_proxy_client_ip_identity(
    trusted_proxies: impl IntoIterator<Item = IpAddr>,
) -> impl IdentityExtractor {
    let trusted_proxies = trusted_proxies
        .into_iter()
        .map(|proxy| proxy.to_canonical())
        .collect::<Vec<_>>();
    move |parts: &Parts| {
        let Some(peer) = peer_ip(parts) else {
            return Some(RequestIdentity::new("anonymous"));
        };
        let client = if trusted_proxies.contains(&peer.to_canonical()) {
            // A missing or unparsable forwarded address falls back to the peer.
            forwarded_for_ip(&parts.headers).unwrap_or(peer)
        } else {
            peer
        };
        Some(RequestIdentity::new(client.to_string()))
    }
}
/// Returns the IP of the remote socket recorded in the `ConnectInfo<SocketAddr>`
/// extension, or `None` when that extension is absent.
fn peer_ip(parts: &Parts) -> Option<IpAddr> {
    let axum::extract::ConnectInfo(remote) = parts
        .extensions
        .get::<axum::extract::ConnectInfo<SocketAddr>>()?;
    Some(remote.ip())
}
/// Parses the first comma-separated entry of the `x-forwarded-for` header as an IP.
///
/// Returns `None` when the header is unusable, the first entry is blank after
/// trimming, or it does not parse as an `IpAddr`.
fn forwarded_for_ip(headers: &HeaderMap) -> Option<IpAddr> {
    let raw = header_to_string(headers, "x-forwarded-for")?;
    let first_hop = raw.split(',').next()?.trim();
    if first_hop.is_empty() {
        return None;
    }
    first_hop.parse().ok()
}
/// Returns the value of header `name` as an owned `String`.
///
/// Yields `None` when the header is missing, when `HeaderValue::to_str`
/// rejects the value, or when the value is empty.
pub(crate) fn header_to_string(headers: &HeaderMap, name: &'static str) -> Option<String> {
    let text = headers.get(name)?.to_str().ok()?;
    if text.is_empty() {
        None
    } else {
        Some(text.to_owned())
    }
}
/// Classifies the caller from the credential headers on the request.
///
/// `x-api-key` takes precedence over `Authorization`; with neither, the caller
/// is `Anonymous`. A header only counts when it carries a non-empty value that
/// `HeaderValue::to_str` accepts — the same rule `header_to_string` applies —
/// so a blank `x-api-key` no longer yields `ApiKey` while the API-key identity
/// extractors see no key at all.
fn infer_client_kind(headers: &HeaderMap) -> ClientKind {
    let carries_value = |value: Option<&http::HeaderValue>| {
        value
            .and_then(|raw| raw.to_str().ok())
            .is_some_and(|text| !text.is_empty())
    };
    if carries_value(headers.get("x-api-key")) {
        ClientKind::ApiKey
    } else if carries_value(headers.get(http::header::AUTHORIZATION)) {
        ClientKind::Authenticated
    } else {
        ClientKind::Anonymous
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::Request;

    /// Builds request parts for `/` with an optional peer socket address
    /// (stored as `ConnectInfo`) and an optional `x-forwarded-for` header.
    fn request_parts(peer: Option<&str>, forwarded_for: Option<&str>) -> Parts {
        let builder = match forwarded_for {
            Some(value) => Request::builder().uri("/").header("x-forwarded-for", value),
            None => Request::builder().uri("/"),
        };
        let (mut parts, ()) = builder.body(()).unwrap().into_parts();
        if let Some(addr) = peer.map(|raw| raw.parse::<SocketAddr>().unwrap()) {
            parts.extensions.insert(axum::extract::ConnectInfo(addr));
        }
        parts
    }

    #[test]
    fn request_context_can_consume_request_id() {
        let request_id = RequestContext::new("req-123", Method::GET, "/users").into_request_id();
        assert_eq!(request_id, "req-123");
    }

    #[test]
    fn client_ip_identity_ignores_forwarded_headers_without_peer_info() {
        // No ConnectInfo: the forwarded header must not be used as the identity.
        let parts = request_parts(None, Some("203.0.113.10"));
        let identity = client_ip_identity().extract(&parts);
        assert_eq!(identity, Some(RequestIdentity::new("anonymous")));
    }

    #[test]
    fn trusted_proxy_client_ip_identity_uses_forwarded_header_from_trusted_peer() {
        let extractor = trusted_proxy_client_ip_identity([IpAddr::from([127, 0, 0, 1])]);
        let parts = request_parts(Some("127.0.0.1:5000"), Some("203.0.113.10, 10.0.0.5"));
        assert_eq!(extractor.extract(&parts).unwrap().as_str(), "203.0.113.10");
    }

    #[test]
    fn trusted_proxy_client_ip_identity_ignores_forwarded_header_from_untrusted_peer() {
        // The peer (127.0.0.1) is not in the trusted list, so its own address wins.
        let extractor = trusted_proxy_client_ip_identity([IpAddr::from([10, 0, 0, 1])]);
        let parts = request_parts(Some("127.0.0.1:5000"), Some("203.0.113.10"));
        assert_eq!(extractor.extract(&parts).unwrap().as_str(), "127.0.0.1");
    }
}