1use std::{
4 future::Future,
5 net::{IpAddr, SocketAddr},
6};
7
8use axum::extract::FromRequestParts;
9use http::{HeaderMap, Method, request::Parts};
10use serde::Serialize;
11
/// How the client of a request presented itself, as inferred from its request headers.
///
/// The inference only looks at which credential headers are present; it does not verify them.
/// Serialized in `snake_case` (`api_key`, `authenticated`, `anonymous`), which matches the
/// strings returned by [`ClientKind::as_str`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ClientKind {
    /// The request presented an `x-api-key` header.
    ApiKey,
    /// The request presented an `Authorization` header (and no `x-api-key`).
    Authenticated,
    /// No credential header was recognised; also the starting value for a fresh
    /// `RequestContext::new`.
    Anonymous,
}
23
24impl ClientKind {
25 pub const fn as_str(self) -> &'static str {
27 match self {
28 Self::ApiKey => "api_key",
29 Self::Authenticated => "authenticated",
30 Self::Anonymous => "anonymous",
31 }
32 }
33}
34
/// Per-request metadata made available to handlers.
///
/// Built with [`RequestContext::from_parts`] (or [`RequestContext::new`] plus the `with_*`
/// builders). Handlers obtain it through the [`FromRequestParts`] implementation, which
/// clones it out of the request extensions.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RequestContext {
    /// Identifier of this request, as supplied to `new` / `from_parts`.
    request_id: String,
    /// Cross-request correlation id. `from_parts` reads `x-correlation-id` and falls back to a
    /// non-empty `request_id`.
    correlation_id: Option<String>,
    /// HTTP method of the request.
    method: Method,
    /// Matched route template (axum `MatchedPath`), when the router recorded one.
    route: Option<String>,
    /// Concrete URI path of the request.
    path: String,
    /// Trace id taken from the `traceparent` header by `from_parts`.
    trace_id: Option<String>,
    /// Span id. Neither `new` nor `from_parts` sets it.
    span_id: Option<String>,
    /// Credential style the client used, inferred from headers by `from_parts`.
    client_kind: ClientKind,
    /// Authenticated user, set via `with_user_id`.
    user_id: Option<String>,
    /// Tenant, set via `with_tenant_id`.
    tenant_id: Option<String>,
    /// Session, set via `with_session_id`.
    session_id: Option<String>,
}
76
77impl RequestContext {
78 pub fn new(request_id: impl Into<String>, method: Method, path: impl Into<String>) -> Self {
84 Self {
85 request_id: request_id.into(),
86 correlation_id: None,
87 method,
88 route: None,
89 path: path.into(),
90 trace_id: None,
91 span_id: None,
92 client_kind: ClientKind::Anonymous,
93 user_id: None,
94 tenant_id: None,
95 session_id: None,
96 }
97 }
98
99 pub fn from_parts(parts: &Parts, request_id: impl Into<String>) -> Self {
106 let request_id = request_id.into();
107 let mut context = Self::new(request_id.clone(), parts.method.clone(), parts.uri.path());
108 context.correlation_id = header_to_string(&parts.headers, "x-correlation-id")
109 .or_else(|| Some(request_id).filter(|value| !value.is_empty()));
110 context.route = parts
111 .extensions
112 .get::<axum::extract::MatchedPath>()
113 .map(|path| path.as_str().to_owned());
114 context.client_kind = infer_client_kind(&parts.headers);
115 context.trace_id = header_to_string(&parts.headers, "traceparent")
116 .and_then(|value| value.split('-').nth(1).map(str::to_owned));
117 context
118 }
119
120 pub fn request_id(&self) -> &str {
125 &self.request_id
126 }
127
128 pub(crate) fn into_request_id(self) -> String {
129 self.request_id
130 }
131
132 pub fn correlation_id(&self) -> Option<&str> {
137 self.correlation_id.as_deref()
138 }
139
140 pub const fn method(&self) -> &Method {
142 &self.method
143 }
144
145 pub fn route(&self) -> Option<&str> {
151 self.route.as_deref()
152 }
153
154 pub fn path(&self) -> &str {
156 &self.path
157 }
158
159 pub fn trace_id(&self) -> Option<&str> {
165 self.trace_id.as_deref()
166 }
167
168 pub fn span_id(&self) -> Option<&str> {
170 self.span_id.as_deref()
171 }
172
173 pub const fn client_kind(&self) -> ClientKind {
178 self.client_kind
179 }
180
181 pub fn user_id(&self) -> Option<&str> {
183 self.user_id.as_deref()
184 }
185
186 pub fn tenant_id(&self) -> Option<&str> {
188 self.tenant_id.as_deref()
189 }
190
191 pub fn session_id(&self) -> Option<&str> {
193 self.session_id.as_deref()
194 }
195
196 pub fn with_route(mut self, route: impl Into<String>) -> Self {
198 self.route = Some(route.into());
199 self
200 }
201
202 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
204 self.user_id = Some(user_id.into());
205 self
206 }
207
208 pub fn with_tenant_id(mut self, tenant_id: impl Into<String>) -> Self {
210 self.tenant_id = Some(tenant_id.into());
211 self
212 }
213
214 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
216 self.session_id = Some(session_id.into());
217 self
218 }
219}
220
221impl<S> FromRequestParts<S> for RequestContext
222where
223 S: Send + Sync,
224{
225 type Rejection = axum::http::StatusCode;
226
227 fn from_request_parts(
228 parts: &mut Parts,
229 _state: &S,
230 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
231 let context = parts.extensions.get::<Self>().cloned();
232 async move { context.ok_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR) }
233 }
234}
235
/// Opaque string key identifying the caller of a request.
///
/// Produced by identity extractors; two requests with equal identities are treated as coming
/// from the same caller.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RequestIdentity(String);

impl RequestIdentity {
    /// Wraps any string-like value as an identity.
    pub fn new(value: impl Into<String>) -> Self {
        let owned: String = value.into();
        Self(owned)
    }

    /// Borrows the identity as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
251
252pub trait IdentityExtractor: Clone + Send + Sync + 'static {
254 fn extract(&self, parts: &Parts) -> Option<RequestIdentity>;
256}
257
258impl<F> IdentityExtractor for F
259where
260 F: Fn(&Parts) -> Option<RequestIdentity> + Clone + Send + Sync + 'static,
261{
262 fn extract(&self, parts: &Parts) -> Option<RequestIdentity> {
263 self(parts)
264 }
265}
266
267pub fn context_identity() -> impl IdentityExtractor {
269 |parts: &Parts| {
270 if let Some(context) = parts.extensions.get::<RequestContext>()
271 && let Some(value) = context.user_id().or_else(|| context.tenant_id())
272 {
273 return Some(RequestIdentity::new(value.to_owned()));
274 }
275 header_to_string(&parts.headers, "x-api-key").map(RequestIdentity::new)
276 }
277}
278
279pub fn api_key_identity() -> impl IdentityExtractor {
281 |parts: &Parts| header_to_string(&parts.headers, "x-api-key").map(RequestIdentity::new)
282}
283
284pub fn client_ip_identity() -> impl IdentityExtractor {
291 |parts: &Parts| {
292 peer_ip(parts)
293 .map(|ip| RequestIdentity::new(ip.to_string()))
294 .or_else(|| Some(RequestIdentity::new("anonymous")))
295 }
296}
297
298pub fn trusted_proxy_client_ip_identity(
306 trusted_proxies: impl IntoIterator<Item = IpAddr>,
307) -> impl IdentityExtractor {
308 let trusted_proxies = trusted_proxies.into_iter().collect::<Vec<_>>();
309 move |parts: &Parts| {
310 peer_ip(parts)
311 .map(|peer| {
312 if trusted_proxies.contains(&peer)
313 && let Some(forwarded_ip) = forwarded_for_ip(&parts.headers)
314 {
315 RequestIdentity::new(forwarded_ip.to_string())
316 } else {
317 RequestIdentity::new(peer.to_string())
318 }
319 })
320 .or_else(|| Some(RequestIdentity::new("anonymous")))
321 }
322}
323
324fn peer_ip(parts: &Parts) -> Option<IpAddr> {
325 parts
326 .extensions
327 .get::<axum::extract::ConnectInfo<SocketAddr>>()
328 .map(|connect| connect.0.ip())
329}
330
331fn forwarded_for_ip(headers: &HeaderMap) -> Option<IpAddr> {
332 header_to_string(headers, "x-forwarded-for").and_then(|value| {
333 value
334 .split(',')
335 .next()
336 .map(str::trim)
337 .filter(|value| !value.is_empty())
338 .and_then(|value| value.parse().ok())
339 })
340}
341
342pub(crate) fn header_to_string(headers: &HeaderMap, name: &'static str) -> Option<String> {
343 headers
344 .get(name)
345 .and_then(|value| value.to_str().ok())
346 .filter(|value| !value.is_empty())
347 .map(str::to_owned)
348}
349
350fn infer_client_kind(headers: &HeaderMap) -> ClientKind {
351 if headers.contains_key("x-api-key") {
352 ClientKind::ApiKey
353 } else if headers.contains_key(http::header::AUTHORIZATION) {
354 ClientKind::Authenticated
355 } else {
356 ClientKind::Anonymous
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use http::Request;

    #[test]
    fn request_context_can_consume_request_id() {
        let request_id = RequestContext::new("req-123", Method::GET, "/users").into_request_id();

        assert_eq!(request_id, "req-123");
    }

    #[test]
    fn client_ip_identity_ignores_forwarded_headers_without_peer_info() {
        // A forwarded header on its own must not be trusted when no peer address is known.
        let parts = request_parts(None, Some("203.0.113.10"));
        let extractor = client_ip_identity();

        let identity = extractor.extract(&parts).unwrap();
        assert_eq!(identity.as_str(), "anonymous");
    }

    #[test]
    fn trusted_proxy_client_ip_identity_uses_forwarded_header_from_trusted_peer() {
        let parts = request_parts(Some("127.0.0.1:5000"), Some("203.0.113.10, 10.0.0.5"));
        let extractor = trusted_proxy_client_ip_identity([IpAddr::from([127u8, 0, 0, 1])]);

        let identity = extractor.extract(&parts).unwrap();
        assert_eq!(identity.as_str(), "203.0.113.10");
    }

    #[test]
    fn trusted_proxy_client_ip_identity_ignores_forwarded_header_from_untrusted_peer() {
        let parts = request_parts(Some("127.0.0.1:5000"), Some("203.0.113.10"));
        let extractor = trusted_proxy_client_ip_identity([IpAddr::from([10u8, 0, 0, 1])]);

        let identity = extractor.extract(&parts).unwrap();
        assert_eq!(identity.as_str(), "127.0.0.1");
    }

    /// Builds request parts for `/`. Optionally adds an `X-Forwarded-For` header and a
    /// `ConnectInfo` peer address.
    fn request_parts(peer: Option<&str>, forwarded_for: Option<&str>) -> Parts {
        let builder = match forwarded_for {
            Some(value) => Request::builder().uri("/").header("x-forwarded-for", value),
            None => Request::builder().uri("/"),
        };
        let (mut parts, ()) = builder.body(()).unwrap().into_parts();
        if let Some(socket) = peer.map(|text| text.parse::<SocketAddr>().unwrap()) {
            parts.extensions.insert(axum::extract::ConnectInfo(socket));
        }
        parts
    }
}