use crate::{
Request, Response,
header::{HeaderName, HeaderValue},
};
use rama_core::{
Layer, Service,
extensions::{Extension, ExtensionsRef},
telemetry::tracing,
};
use rama_utils::macros::define_inner_service_accessors;
use rama_utils::str::smol_str::ToSmolStr as _;
use rand::RngExt as _;
use uuid::Uuid;
/// Header name used by the `request_id` constructors: `request-id`.
pub(crate) const REQUEST_ID: HeaderName = HeaderName::from_static("request-id");
/// Header name used by the `x_request_id` constructors: `x-request-id`.
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
/// Strategy used by [`SetRequestId`] to produce an id for a request that does
/// not already carry the configured request-id header.
///
/// Returning `None` leaves the request as-is: no header and no [`RequestId`]
/// extension is added.
pub trait MakeRequestId: Send + Sync + 'static {
    /// Produce a request id for `request`, or `None` to skip assigning one.
    fn make_request_id<B>(&self, request: &Request<B>) -> Option<RequestId>;
}
/// A request identifier, kept as the raw header value it was read from or
/// will be written to.
///
/// [`SetRequestId`] stores it as a request extension and
/// [`PropagateRequestId`] stores it as a response extension.
#[derive(Debug, Clone, Extension)]
#[extension(tags(http))]
pub struct RequestId(HeaderValue);

impl RequestId {
    /// Wrap `header_value` as a [`RequestId`].
    pub const fn new(header_value: HeaderValue) -> Self {
        Self(header_value)
    }

    /// Borrow the underlying header value.
    pub fn header_value(&self) -> &HeaderValue {
        let Self(value) = self;
        value
    }

    /// Consume the id, returning the underlying header value.
    pub fn into_header_value(self) -> HeaderValue {
        let Self(value) = self;
        value
    }
}

impl From<HeaderValue> for RequestId {
    fn from(header_value: HeaderValue) -> Self {
        Self(header_value)
    }
}
/// [`Layer`] that wraps services in the [`SetRequestId`] middleware.
///
/// Stores the header name to read/write and the [`MakeRequestId`] strategy
/// handed to every service it produces.
#[derive(Debug, Clone)]
pub struct SetRequestIdLayer<M> {
    header_name: HeaderName,
    make_request_id: M,
}

impl<M> SetRequestIdLayer<M> {
    /// Create a layer that uses a custom `header_name`.
    pub const fn new(header_name: HeaderName, make_request_id: M) -> Self
    where
        M: MakeRequestId,
    {
        Self {
            header_name,
            make_request_id,
        }
    }

    /// Create a layer that uses the `request-id` header.
    pub const fn request_id(make_request_id: M) -> Self
    where
        M: MakeRequestId,
    {
        Self {
            header_name: REQUEST_ID,
            make_request_id,
        }
    }

    /// Create a layer that uses the `x-request-id` header.
    pub const fn x_request_id(make_request_id: M) -> Self
    where
        M: MakeRequestId,
    {
        Self {
            header_name: X_REQUEST_ID,
            make_request_id,
        }
    }
}
impl<S, M> Layer<S> for SetRequestIdLayer<M>
where
    M: Clone + MakeRequestId,
{
    type Service = SetRequestId<S, M>;

    /// Wrap `inner`, cloning this layer's header name and id generator.
    fn layer(&self, inner: S) -> Self::Service {
        let header_name = self.header_name.clone();
        let make_request_id = self.make_request_id.clone();
        SetRequestId::new(inner, header_name, make_request_id)
    }

    /// Wrap `inner`, moving this layer's configuration into the service
    /// instead of cloning it.
    fn into_layer(self, inner: S) -> Self::Service {
        let Self {
            header_name,
            make_request_id,
        } = self;
        SetRequestId::new(inner, header_name, make_request_id)
    }
}
/// Middleware that makes sure each request carries a request id.
///
/// If the configured header is present, its value is mirrored into a
/// [`RequestId`] request extension (unless one is already there). Otherwise
/// an id is generated with [`MakeRequestId`] and written to both the header
/// and the extensions.
#[derive(Debug, Clone)]
pub struct SetRequestId<S, M> {
    inner: S,
    header_name: HeaderName,
    make_request_id: M,
}

impl<S, M> SetRequestId<S, M> {
    /// Wrap `inner`, using `header_name` for ids and `make_request_id` to
    /// generate missing ones.
    pub const fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self
    where
        M: MakeRequestId,
    {
        Self {
            inner,
            header_name,
            make_request_id,
        }
    }

    /// Wrap `inner`, using the `request-id` header.
    pub const fn request_id(inner: S, make_request_id: M) -> Self
    where
        M: MakeRequestId,
    {
        Self {
            inner,
            header_name: REQUEST_ID,
            make_request_id,
        }
    }

    /// Wrap `inner`, using the `x-request-id` header.
    pub const fn x_request_id(inner: S, make_request_id: M) -> Self
    where
        M: MakeRequestId,
    {
        Self {
            inner,
            header_name: X_REQUEST_ID,
            make_request_id,
        }
    }

    define_inner_service_accessors!();
}
impl<S, M, ReqBody, ResBody> Service<Request<ReqBody>> for SetRequestId<S, M>
where
    S: Service<Request<ReqBody>, Output = Response<ResBody>>,
    M: MakeRequestId,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type Output = S::Output;
    type Error = S::Error;

    /// Ensure `req` carries a request id, then forward it to the inner service.
    ///
    /// - Header present: copy it into a [`RequestId`] extension, unless the
    ///   request already has one (an existing extension is never replaced).
    /// - Header absent: ask the [`MakeRequestId`] strategy for an id and, if
    ///   it returns one, store it both as an extension and as the header.
    async fn serve(&self, mut req: Request<ReqBody>) -> Result<Self::Output, Self::Error> {
        match req.headers().get(&self.header_name) {
            Some(existing) => {
                let already_tagged = req.extensions().get_ref::<RequestId>().is_some();
                if !already_tagged {
                    let id = RequestId::new(existing.clone());
                    req.extensions().insert(id);
                }
            }
            None => {
                if let Some(generated) = self.make_request_id.make_request_id(&req) {
                    req.extensions().insert(generated.clone());
                    req.headers_mut()
                        .insert(self.header_name.clone(), generated.into_header_value());
                }
            }
        }
        self.inner.serve(req).await
    }
}
/// [`Layer`] that wraps services in the [`PropagateRequestId`] middleware.
#[derive(Debug, Clone)]
pub struct PropagateRequestIdLayer {
    header_name: HeaderName,
}

impl PropagateRequestIdLayer {
    /// Create a layer that propagates a custom `header_name`.
    pub const fn new(header_name: HeaderName) -> Self {
        Self { header_name }
    }

    /// Create a layer that propagates the `request-id` header.
    pub const fn request_id() -> Self {
        Self {
            header_name: REQUEST_ID,
        }
    }

    /// Create a layer that propagates the `x-request-id` header.
    pub const fn x_request_id() -> Self {
        Self {
            header_name: X_REQUEST_ID,
        }
    }
}
impl<S> Layer<S> for PropagateRequestIdLayer {
    type Service = PropagateRequestId<S>;

    /// Wrap `inner`, cloning this layer's header name.
    fn layer(&self, inner: S) -> Self::Service {
        PropagateRequestId::new(inner, self.header_name.clone())
    }

    /// Wrap `inner`, moving this layer's header name into the service so no
    /// clone is required when the layer is consumed.
    fn into_layer(self, inner: S) -> Self::Service {
        PropagateRequestId::new(inner, self.header_name)
    }
}
/// Middleware that copies the request's id onto the response.
///
/// The configured header is read from the request before it is forwarded.
/// On the response, an already-present header is mirrored into a
/// [`RequestId`] extension (unless one exists). Otherwise the request's id,
/// if any, is written to both the header and the extensions.
#[derive(Debug, Clone)]
pub struct PropagateRequestId<S> {
    inner: S,
    header_name: HeaderName,
}

impl<S> PropagateRequestId<S> {
    /// Wrap `inner`, propagating a custom `header_name`.
    pub const fn new(inner: S, header_name: HeaderName) -> Self {
        Self { inner, header_name }
    }

    /// Wrap `inner`, propagating the `request-id` header.
    pub const fn request_id(inner: S) -> Self {
        Self {
            inner,
            header_name: REQUEST_ID,
        }
    }

    /// Wrap `inner`, propagating the `x-request-id` header.
    pub const fn x_request_id(inner: S) -> Self {
        Self {
            inner,
            header_name: X_REQUEST_ID,
        }
    }

    define_inner_service_accessors!();
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for PropagateRequestId<S>
where
    S: Service<Request<ReqBody>, Output = Response<ResBody>>,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type Output = S::Output;
    type Error = S::Error;

    /// Forward `req` to the inner service and make sure the response carries
    /// a request id.
    ///
    /// An id header already set on the response wins: it is only mirrored
    /// into a [`RequestId`] extension (when none exists). Otherwise the id
    /// read from the request, if any, is written to the response header and
    /// extensions. Errors from the inner service are returned unchanged.
    async fn serve(&self, req: Request<ReqBody>) -> Result<Self::Output, Self::Error> {
        // Snapshot the incoming id before `req` is moved into the inner service.
        let incoming = req
            .headers()
            .get(&self.header_name)
            .map(|value| RequestId::new(value.clone()));

        let mut response = self.inner.serve(req).await?;

        match response.headers().get(&self.header_name) {
            Some(existing) => {
                if response.extensions().get_ref::<RequestId>().is_none() {
                    let id = RequestId::new(existing.clone());
                    response.extensions().insert(id);
                }
            }
            None => {
                if let Some(id) = incoming {
                    let value = id.header_value().clone();
                    response
                        .headers_mut()
                        .insert(self.header_name.clone(), value);
                    response.extensions().insert(id);
                }
            }
        }

        Ok(response)
    }
}
#[derive(Debug, Clone, Copy, Default)]
pub struct MakeRequestUuid;
impl MakeRequestId for MakeRequestUuid {
fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
let request_id = Uuid::new_v4()
.to_smolstr()
.parse()
.inspect_err(|err| {
tracing::debug!("failed to parse UUID4 as RequestId: {err}");
})
.ok()?;
Some(RequestId::new(request_id))
}
}
/// [`MakeRequestId`] strategy that produces 21-character random ids drawn
/// from a 64-symbol URL-safe alphabet (nanoid style).
#[derive(Debug, Clone, Copy, Default)]
pub struct MakeRequestNanoid;

impl MakeRequestId for MakeRequestNanoid {
    /// Always returns a freshly generated id.
    fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
        Some(make_nano_id().into())
    }
}
/// Generate a random 21-character id from a 64-symbol URL-safe alphabet
/// (the default nanoid alphabet and length).
///
/// Each output symbol is derived from its own random byte. The alphabet size
/// is a power of two that divides 256, so masking a uniformly random byte with
/// `ALPHABET_LEN - 1` picks every symbol with equal probability. No bytes are
/// rejected, so a single RNG draw of `ID_LEN` bytes is always enough.
fn make_nano_id() -> HeaderValue {
    const ALPHABET_LEN: usize = 64;
    const ALPHABET: [u8; ALPHABET_LEN] =
        *b"_-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const ID_LEN: usize = 21;
    // The masking below is only uniform when the alphabet size divides 256;
    // fail the build if the alphabet is ever changed to break that.
    const _: () = assert!(ALPHABET_LEN.is_power_of_two() && ALPHABET_LEN <= 256);
    // Fits in a `u8` thanks to the assertion above.
    const MASK: u8 = (ALPHABET_LEN - 1) as u8;

    let random: [u8; ID_LEN] = rand::rng().random();
    let id = random.map(|byte| ALPHABET[usize::from(byte & MASK)]);

    // SAFETY: every byte of `id` is taken from `ALPHABET`, which only contains
    // visible ASCII (`_`, `-`, digits and letters), all of which are valid
    // bytes in an HTTP header value.
    unsafe { HeaderValue::from_maybe_shared_unchecked(id) }
}
// Unit tests for the request-id middleware and the bundled id generators.
#[cfg(test)]
mod tests {
    use crate::layer::set_header;
    use crate::{Body, Response};
    use rama_core::Layer;
    use rama_core::service::service_fn;
    use std::{
        convert::Infallible,
        sync::{
            Arc,
            atomic::{AtomicU64, Ordering},
        },
    };
    use super::*;

    // `x-request-id` round trip through SetRequestId + PropagateRequestId.
    // Generated ids come from the shared `Counter` (0, 1, ...). A client-supplied
    // id is echoed back unchanged and does not consume a counter value, which is
    // why the fourth request gets "2". The assertions depend on this exact order.
    #[tokio::test]
    async fn basic() {
        let svc = (
            SetRequestIdLayer::x_request_id(Counter::default()),
            PropagateRequestIdLayer::x_request_id(),
        )
            .into_layer(service_fn(handler));
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "0");
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "1");
        // Client-provided id: kept as-is, counter untouched.
        let req = Request::builder()
            .header("x-request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");
        // The response also exposes the id as a `RequestId` extension.
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(req).await.unwrap();
        assert_eq!(res.extensions().get_ref::<RequestId>().unwrap().0, "2");
    }

    // Same scenario as `basic`, but using the `request-id` header constructors.
    #[tokio::test]
    async fn basic_with_request_id() {
        let svc = (
            SetRequestIdLayer::request_id(Counter::default()),
            PropagateRequestIdLayer::request_id(),
        )
            .into_layer(service_fn(handler));
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(req).await.unwrap();
        assert_eq!(res.headers()["request-id"], "0");
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(req).await.unwrap();
        assert_eq!(res.headers()["request-id"], "1");
        let req = Request::builder()
            .header("request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(req).await.unwrap();
        assert_eq!(res.headers()["request-id"], "foo");
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(req).await.unwrap();
        assert_eq!(res.extensions().get_ref::<RequestId>().unwrap().0, "2");
    }

    // Another middleware (`SetResponseHeaderLayer::overriding`) also sets
    // `x-request-id` on the response. The response must keep that header, and
    // the `RequestId` extension must carry its value.
    #[tokio::test]
    async fn other_middleware_setting_request_id_on_response() {
        let svc = (
            SetRequestIdLayer::x_request_id(Counter::default()),
            PropagateRequestIdLayer::x_request_id(),
            set_header::SetResponseHeaderLayer::overriding(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_static("foo"),
            ),
        )
            .into_layer(service_fn(handler));
        let req = Request::builder()
            .header("x-request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");
        assert_eq!(res.extensions().get_ref::<RequestId>().unwrap().0, "foo");
    }

    // Deterministic id generator for tests: yields "0", "1", "2", ... shared
    // across clones through the `Arc`.
    #[derive(Clone, Default)]
    struct Counter(Arc<AtomicU64>);

    impl MakeRequestId for Counter {
        fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
            // `fetch_add` returns the previous value, so the first id is "0".
            let id =
                HeaderValue::from_str(&self.0.fetch_add(1, Ordering::AcqRel).to_string()).unwrap();
            Some(RequestId::new(id))
        }
    }

    // Inner service that always answers with an empty 200-style response.
    async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::empty()))
    }

    // `MakeRequestUuid` must emit a header value that parses back as a `Uuid`.
    #[tokio::test]
    async fn uuid() {
        let svc = (
            SetRequestIdLayer::x_request_id(MakeRequestUuid),
            PropagateRequestIdLayer::x_request_id(),
        )
            .into_layer(service_fn(handler));
        let req = Request::builder().body(Body::empty()).unwrap();
        let mut res = svc.serve(req).await.unwrap();
        let id = res.headers_mut().remove("x-request-id").unwrap();
        id.to_str().unwrap().parse::<Uuid>().unwrap();
    }

    // `MakeRequestNanoid` must emit a 21-character header value.
    #[tokio::test]
    async fn nanoid() {
        let svc = (
            SetRequestIdLayer::x_request_id(MakeRequestNanoid),
            PropagateRequestIdLayer::x_request_id(),
        )
            .into_layer(service_fn(handler));
        let req = Request::builder().body(Body::empty()).unwrap();
        let mut res = svc.serve(req).await.unwrap();
        let id = res.headers_mut().remove("x-request-id").unwrap();
        assert_eq!(id.to_str().unwrap().chars().count(), 21);
    }

    // Entropy sanity check for `make_nano_id` over 2000 samples.
    // - Fewer than 5 ids may have a suffix (length 5..=10) equal to the prefix
    //   of the same length. This presumably guards against a past bug that
    //   reused random bytes within one id; the assert message calls it the
    //   "entropy-bug mirror pattern".
    // - All ids must be distinct.
    #[test]
    fn nanoid_no_intra_id_mirror() {
        // True when, for some k in 11..=16, id[k..21] repeats id[0..21 - k].
        fn has_mirror(id: &[u8]) -> bool {
            (11..=16).any(|k| id[k..21] == id[0..21 - k])
        }
        use ahash::HashSet;
        let samples = 2_000;
        let mut mirrors = 0;
        let mut ids: HashSet<Vec<u8>> = HashSet::default();
        for _ in 0..samples {
            let hv = make_nano_id();
            let id = hv.as_bytes();
            assert_eq!(id.len(), 21);
            if has_mirror(id) {
                mirrors += 1;
            }
            ids.insert(id.to_vec());
        }
        assert!(
            mirrors < 5,
            "{mirrors}/{samples} ids exhibit the entropy-bug mirror pattern",
        );
        assert_eq!(ids.len(), samples, "duplicate ids generated");
    }
}