Skip to main content

cloudillo_core/
extract.rs

1// SPDX-FileCopyrightText: Szilárd Hajba
2// SPDX-License-Identifier: LGPL-3.0-or-later
3
4//! Custom extractors for Cloudillo-specific data
5
6use async_trait::async_trait;
7use axum::extract::FromRequestParts;
8use axum::http::request::Parts;
9
10use crate::app::AppState;
11use crate::prelude::*;
12use cloudillo_types::auth_adapter;
13
14// Re-export IdTag and TnIdResolver from cloudillo-types
15pub use cloudillo_types::extract::{IdTag, TnIdResolver};
16
17// Implement TnIdResolver for AppState so TnId can be extracted from requests.
18// The blanket impl `TnIdResolver for Arc<T>` in cloudillo-types makes this
19// work for `App = Arc<AppState>` automatically.
20#[async_trait]
21impl TnIdResolver for AppState {
22	async fn resolve_tn_id(&self, id_tag: &str) -> Result<TnId, Error> {
23		self.auth_adapter.read_tn_id(id_tag).await.map_err(|_| Error::PermissionDenied)
24	}
25}
26
27// Auth //
28//******//
29#[derive(Debug, Clone)]
30pub struct Auth(pub auth_adapter::AuthCtx);
31
32impl<S> FromRequestParts<S> for Auth
33where
34	S: Send + Sync,
35{
36	type Rejection = Error;
37
38	async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
39		if let Some(auth) = parts.extensions.get::<Auth>().cloned() {
40			Ok(auth)
41		} else {
42			Err(Error::PermissionDenied)
43		}
44	}
45}
46
47// OptionalAuth //
48//***************//
49/// Optional auth extractor that doesn't fail if auth is missing
50#[derive(Debug, Clone)]
51pub struct OptionalAuth(pub Option<auth_adapter::AuthCtx>);
52
53impl<S> FromRequestParts<S> for OptionalAuth
54where
55	S: Send + Sync,
56{
57	type Rejection = Error;
58
59	async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
60		let auth = parts.extensions.get::<Auth>().cloned().map(|a| a.0);
61		Ok(OptionalAuth(auth))
62	}
63}
64
// RequestId //
//***********//
/// Per-request identifier used to correlate log output for tracing and debugging.
#[derive(Debug, Clone)]
pub struct RequestId(pub String);
70
/// Accept a client-supplied `X-Request-ID` value only if it is safe to log.
///
/// Surrounding whitespace is trimmed first. The trimmed value must be 1 to 64
/// bytes long and consist solely of ASCII letters, digits, `-`, `_` or `.`.
/// Anything else (including CR/LF or other whitespace inside the value) yields
/// `None`, and the caller then generates a random id instead. This keeps log
/// injection and unbounded ids out of log lines and response headers.
fn sanitize_external_id(s: &str) -> Option<String> {
	let candidate = s.trim();
	// Length is checked first so oversized values are rejected without a scan.
	if !(1..=64).contains(&candidate.len()) {
		return None;
	}
	// Any non-ASCII character contributes bytes >= 0x80, which fail this test.
	let allowed = |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.');
	candidate.bytes().all(allowed).then(|| candidate.to_owned())
}
85
86/// Generate an 8-char random id, with a deterministic sequence fallback if
87/// `random_id()` ever fails so that two concurrent failed-random requests do
88/// not produce indistinguishable log streams.
89fn random_short() -> String {
90	match cloudillo_types::utils::random_id() {
91		Ok(s) if !s.is_empty() => s.chars().take(8).collect(),
92		_ => {
93			use std::sync::atomic::{AtomicU64, Ordering};
94			static CTR: AtomicU64 = AtomicU64::new(0);
95			let n = CTR.fetch_add(1, Ordering::Relaxed);
96			warn!("random_id() failed; using sequence fallback");
97			format!("seq{n:05}")
98		}
99	}
100}
101
102impl RequestId {
103	/// Read `X-Request-ID` from `headers` (validated), or generate an 8-char
104	/// random id when the header is absent or rejected.
105	pub fn from_headers_or_random(headers: &axum::http::HeaderMap) -> Self {
106		let from_header = headers
107			.get("X-Request-ID")
108			.and_then(|h| h.to_str().ok())
109			.and_then(sanitize_external_id);
110		Self(from_header.unwrap_or_else(random_short))
111	}
112
113	/// Short 4-char form for log prefixes. Stable for a given full id.
114	/// Not cryptographically distinct — purely a visual aid.
115	pub fn short(&self) -> &str {
116		let s = self.0.as_str();
117		let end = s.char_indices().nth(4).map_or(s.len(), |(i, _)| i);
118		&s[..end]
119	}
120
121	/// Ensure a `RequestId` extension is present on `req` and return the
122	/// `request` span carrying its short form. Single source of truth for the
123	/// span name and field name shared between the HTTPS transport closure
124	/// (`webserver.rs`) and the request-id middleware.
125	///
126	/// The span is created at `Level::ERROR` so it remains active even when
127	/// the global filter is set to `warn` or `error` — the level on the span
128	/// gates only span creation, not event filtering inside it.
129	pub fn install<B>(req: &mut axum::http::Request<B>) -> tracing::Span {
130		if let Some(existing) = req.extensions().get::<RequestId>() {
131			return tracing::span!(tracing::Level::ERROR, "request", id = %existing.short());
132		}
133		let id = Self::from_headers_or_random(req.headers());
134		let span = tracing::span!(tracing::Level::ERROR, "request", id = %id.short());
135		req.extensions_mut().insert(id);
136		span
137	}
138}
139
140/// Optional Request ID extractor - always succeeds, returns None if not available
141#[derive(Clone, Debug)]
142pub struct OptionalRequestId(pub Option<String>);
143
144impl<S> FromRequestParts<S> for OptionalRequestId
145where
146	S: Send + Sync,
147{
148	type Rejection = Error;
149
150	async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
151		let req_id = parts.extensions.get::<RequestId>().map(|r| r.0.clone());
152		Ok(OptionalRequestId(req_id))
153	}
154}
155
156// vim: ts=4