use async_trait::async_trait;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use crate::app::AppState;
use crate::prelude::*;
use cloudillo_types::auth_adapter;
pub use cloudillo_types::extract::{IdTag, TnIdResolver};
#[async_trait]
impl TnIdResolver for AppState {
async fn resolve_tn_id(&self, id_tag: &str) -> Result<TnId, Error> {
self.auth_adapter.read_tn_id(id_tag).await.map_err(|_| Error::PermissionDenied)
}
}
#[derive(Debug, Clone)]
pub struct Auth(pub auth_adapter::AuthCtx);
impl<S> FromRequestParts<S> for Auth
where
S: Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(auth) = parts.extensions.get::<Auth>().cloned() {
Ok(auth)
} else {
Err(Error::PermissionDenied)
}
}
}
#[derive(Debug, Clone)]
pub struct OptionalAuth(pub Option<auth_adapter::AuthCtx>);
impl<S> FromRequestParts<S> for OptionalAuth
where
S: Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let auth = parts.extensions.get::<Auth>().cloned().map(|a| a.0);
Ok(OptionalAuth(auth))
}
}
/// Per-request correlation identifier.
///
/// `RequestId::install` inserts it into the request extensions.
#[derive(Debug, Clone)]
pub struct RequestId(pub String);
/// Validates a request id supplied by the client.
///
/// Surrounding whitespace is trimmed first. The remaining text is accepted
/// only if it is between 1 and 64 bytes long and contains nothing but ASCII
/// letters, ASCII digits, `-`, `_` or `.`. Anything else yields `None`.
fn sanitize_external_id(s: &str) -> Option<String> {
	let candidate = s.trim();
	if !(1..=64).contains(&candidate.len()) {
		return None;
	}
	// Any byte of a non-ASCII character is >= 0x80 and fails this predicate,
	// so checking bytes is equivalent to checking chars here.
	let allowed = |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.');
	candidate.bytes().all(allowed).then(|| candidate.to_owned())
}
/// Generates a short identifier for requests that did not supply a usable one.
///
/// Normally this is the first (up to) 8 characters of
/// `cloudillo_types::utils::random_id()`. If that call fails or returns an
/// empty string, a warning is logged. The function then falls back to a
/// process-wide counter formatted as `seq` followed by a number zero-padded to
/// at least 5 digits.
fn random_short() -> String {
	static FALLBACK_SEQ: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

	if let Some(id) = cloudillo_types::utils::random_id().ok().filter(|id| !id.is_empty()) {
		return id.chars().take(8).collect();
	}

	// The counter only needs to be unique, not ordered against other memory,
	// so Relaxed is sufficient.
	let seq = FALLBACK_SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
	warn!("random_id() failed; using sequence fallback");
	format!("seq{seq:05}")
}
impl RequestId {
pub fn from_headers_or_random(headers: &axum::http::HeaderMap) -> Self {
let from_header = headers
.get("X-Request-ID")
.and_then(|h| h.to_str().ok())
.and_then(sanitize_external_id);
Self(from_header.unwrap_or_else(random_short))
}
pub fn short(&self) -> &str {
let s = self.0.as_str();
let end = s.char_indices().nth(4).map_or(s.len(), |(i, _)| i);
&s[..end]
}
pub fn install<B>(req: &mut axum::http::Request<B>) -> tracing::Span {
if let Some(existing) = req.extensions().get::<RequestId>() {
return tracing::span!(tracing::Level::ERROR, "request", id = %existing.short());
}
let id = Self::from_headers_or_random(req.headers());
let span = tracing::span!(tracing::Level::ERROR, "request", id = %id.short());
req.extensions_mut().insert(id);
span
}
}
#[derive(Clone, Debug)]
pub struct OptionalRequestId(pub Option<String>);
impl<S> FromRequestParts<S> for OptionalRequestId
where
S: Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let req_id = parts.extensions.get::<RequestId>().map(|r| r.0.clone());
Ok(OptionalRequestId(req_id))
}
}