#![deny(unused, missing_docs)]
use std::{
fmt,
pin::Pin,
task::{Context, Poll},
};
use actix_web::{
dev::{Payload, Service, ServiceRequest, ServiceResponse, Transform},
error::ResponseError,
http::header::{HeaderName, HeaderValue},
Error as ActixError, FromRequest, HttpMessage, HttpRequest,
};
use futures::{
future::{ok, ready, Ready},
Future,
};
use uuid::Uuid;
/// Header that carries the request ID, both on incoming requests and on responses,
/// unless another name is configured via [`RequestIdentifier::with_header`] or
/// [`RequestIdentifier::header`].
pub const DEFAULT_HEADER: &str = "x-request-id";
/// Errors produced by this crate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// No [`RequestId`] was found in the request extensions, typically because the
    /// route is not wrapped by the [`RequestIdentifier`] middleware.
    NoAssociatedId,
}
/// Policy for a request ID that the client already supplied in the configured header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IdReuse {
    /// Reuse the ID supplied by the client when one is present; otherwise generate a
    /// new ID with the configured generator.
    UseIncoming,
    /// Always generate a fresh ID, ignoring any value the client sent.
    IgnoreIncoming,
}
/// Service created by [`RequestIdentifier`]: attaches a [`RequestId`] to every request
/// and echoes the same ID in the response headers.
pub struct RequestIdMiddleware<S> {
    /// The wrapped inner service.
    service: S,
    /// Header read for incoming IDs and written on responses.
    header_name: HeaderName,
    /// Function used to create new IDs.
    id_generator: Generator,
    /// Whether an ID supplied by the client may be reused.
    use_incoming_id: IdReuse,
}
/// Function that produces a fresh request ID as a header value.
type Generator = fn() -> HeaderValue;
/// Middleware factory that assigns a request ID to every request.
///
/// Register it with `App::wrap`. The default configuration uses [`DEFAULT_HEADER`],
/// generates random version 4 UUIDs and ignores IDs supplied by the client; the
/// builder methods change any of these settings.
#[derive(Clone)]
pub struct RequestIdentifier {
    /// Name of the header carrying the ID; must be a valid HTTP header name.
    header_name: &'static str,
    /// Function used to create new IDs.
    id_generator: Generator,
    /// Whether an ID supplied by the client may be reused.
    use_incoming_id: IdReuse,
}
#[derive(Clone)]
pub struct RequestId(HeaderValue);
impl ResponseError for Error {}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
use Error::NoAssociatedId;
match self {
NoAssociatedId => write!(fmt, "NoAssociatedId"),
}
}
}
impl RequestId {
    /// Returns the ID as the raw header value it was created from.
    pub const fn header_value(&self) -> &HeaderValue {
        &self.0
    }

    /// Returns the ID as a string slice.
    ///
    /// # Panics
    ///
    /// Panics if the ID is not made of visible ASCII characters (the condition checked
    /// by `HeaderValue::to_str`). Use [`RequestId::header_value`] to inspect such an ID
    /// without panicking.
    pub fn as_str(&self) -> &str {
        self.0.to_str().expect("Non-ASCII IDs are not supported")
    }
}
impl RequestIdentifier {
    /// Creates the default configuration: the [`DEFAULT_HEADER`] header, random
    /// version 4 UUIDs, and client-supplied IDs ignored. Same as `Default::default()`.
    #[must_use]
    pub fn with_uuid() -> Self {
        Self::default()
    }

    /// Creates a configuration that uses `header_name` for the request ID, with every
    /// other setting at its default.
    ///
    /// `header_name` must be a valid HTTP header name (prefer lowercase); an invalid
    /// name makes the middleware panic when it is constructed.
    #[must_use]
    pub fn with_header(header_name: &'static str) -> Self {
        Self {
            header_name,
            ..Default::default()
        }
    }

    /// Replaces the header name used for the request ID, keeping the other settings.
    ///
    /// The same validity requirements as [`RequestIdentifier::with_header`] apply.
    #[must_use]
    pub const fn header(self, header_name: &'static str) -> Self {
        Self {
            header_name,
            ..self
        }
    }

    /// Creates a configuration that uses `id_generator` to create IDs, with every other
    /// setting at its default.
    ///
    /// Generated values should be visible ASCII, otherwise [`RequestId::as_str`] panics.
    #[must_use]
    pub fn with_generator(id_generator: Generator) -> Self {
        Self {
            id_generator,
            ..Default::default()
        }
    }

    /// Replaces the ID generator, keeping the other settings.
    ///
    /// Generated values should be visible ASCII, otherwise [`RequestId::as_str`] panics.
    #[must_use]
    pub fn generator(self, id_generator: Generator) -> Self {
        Self {
            id_generator,
            ..self
        }
    }

    /// Sets whether an ID supplied by the client in the request header may be reused
    /// (see [`IdReuse`]), keeping the other settings.
    #[must_use]
    pub fn use_incoming_id(self, use_incoming_id: IdReuse) -> Self {
        Self {
            use_incoming_id,
            ..self
        }
    }
}
impl Default for RequestIdentifier {
fn default() -> Self {
Self {
header_name: DEFAULT_HEADER,
id_generator: default_generator,
use_incoming_id: IdReuse::IgnoreIncoming,
}
}
}
/// Default ID generator: a random version 4 UUID in its standard string form.
fn default_generator() -> HeaderValue {
    let id = Uuid::new_v4().to_string();
    // A UUID's string form is ASCII hex digits and hyphens only, so this cannot fail.
    HeaderValue::from_str(&id).expect("a UUID string is always a valid header value")
}
impl<S, B> Transform<S, ServiceRequest> for RequestIdentifier
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError>,
    S::Future: 'static,
    B: 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Transform = RequestIdMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` in a [`RequestIdMiddleware`] carrying this configuration.
    ///
    /// # Panics
    ///
    /// Panics if the configured header name is not a valid HTTP header name.
    fn new_transform(&self, service: S) -> Self::Future {
        // `from_bytes` normalises valid names to lowercase, so e.g. "X-Request-Id" is
        // accepted; `HeaderName::from_static` panics on any uppercase character.
        let header_name = HeaderName::from_bytes(self.header_name.as_bytes())
            .unwrap_or_else(|_| {
                panic!("invalid request ID header name: {:?}", self.header_name)
            });
        ok(RequestIdMiddleware {
            service,
            header_name,
            id_generator: self.id_generator,
            use_incoming_id: self.use_incoming_id,
        })
    }
}
// The boxed `Future` associated type below trips `clippy::type_complexity`.
#[allow(clippy::type_complexity)]
impl<S, B> Service<ServiceRequest> for RequestIdMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError>,
    S::Future: 'static,
    B: 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(ctx)
    }

    /// Chooses a request ID, stores it in the request extensions as a [`RequestId`],
    /// forwards the request to the inner service, and sets the same ID on the response.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        let header_name = self.header_name.clone();
        let header_value = match self.use_incoming_id {
            // The incoming header is fully client-controlled. Reuse it only if it is
            // non-empty visible ASCII; otherwise `RequestId::as_str` would panic in
            // handlers (or the ID would be useless), so fall back to the generator.
            IdReuse::UseIncoming => req
                .headers()
                .get(&header_name)
                .filter(|v| matches!(v.to_str(), Ok(s) if !s.is_empty()))
                .cloned()
                .unwrap_or_else(self.id_generator),
            IdReuse::IgnoreIncoming => (self.id_generator)(),
        };
        req.extensions_mut().insert(RequestId(header_value.clone()));
        let fut = self.service.call(req);
        Box::pin(async move {
            let mut res = fut.await?;
            res.headers_mut().insert(header_name, header_value);
            Ok(res)
        })
    }
}
impl FromRequest for RequestId {
type Error = Error;
type Future = Ready<Result<Self, Self::Error>>;
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
ready(
req.extensions()
.get::<RequestId>()
.map(RequestId::clone)
.ok_or(Error::NoAssociatedId),
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::StatusCode, test, web, App};
    use bytes::Bytes;

    /// Handler that echoes the extracted request ID as the response body.
    async fn handler(id: RequestId) -> String {
        id.as_str().to_string()
    }

    /// Builds an initialised test service routing `GET /` to `handler`, wrapped in
    /// `$middleware`.
    macro_rules! service {
        ($middleware:expr) => {
            test::init_service(
                App::new()
                    .wrap($middleware)
                    .route("/", web::get().to(handler)),
            )
            .await
        };
    }

    /// Sends `GET /` (without any request-ID header) through `middleware`.
    async fn test_get(middleware: RequestIdentifier) -> ServiceResponse {
        let service = service!(middleware);
        test::call_service(&service, test::TestRequest::get().uri("/").to_request()).await
    }

    /// Returns response header `name` as a string, panicking if it is missing.
    fn header_str(resp: &ServiceResponse, name: &'static str) -> String {
        resp.headers()
            .get(HeaderName::from_static(name))
            .map(|v| v.to_str().unwrap().to_string())
            .unwrap_or_else(|| panic!("missing response header {}", name))
    }

    /// Consumes the response and returns its body decoded as (lossy) UTF-8.
    async fn body_string(resp: ServiceResponse) -> String {
        let body: Bytes = test::read_body(resp).await;
        String::from_utf8_lossy(&body).into_owned()
    }

    #[actix_web::test]
    async fn default_identifier() {
        let resp = test_get(RequestIdentifier::with_uuid()).await;
        let uid = header_str(&resp, DEFAULT_HEADER);
        assert_eq!(uid, body_string(resp).await);
    }

    #[actix_web::test]
    async fn deterministic_identifier() {
        let resp = test_get(RequestIdentifier::with_generator(|| {
            HeaderValue::from_static("look ma, i'm an id")
        }))
        .await;
        let uid = header_str(&resp, DEFAULT_HEADER);
        assert_eq!(uid, "look ma, i'm an id");
        assert_eq!(uid, body_string(resp).await);
    }

    #[actix_web::test]
    async fn custom_header() {
        let resp = test_get(RequestIdentifier::with_header("custom-header")).await;
        assert!(resp
            .headers()
            .get(HeaderName::from_static(DEFAULT_HEADER))
            .is_none());
        let uid = header_str(&resp, "custom-header");
        assert_eq!(uid, body_string(resp).await);
    }

    #[actix_web::test]
    async fn existing_request_id() {
        let uuid4 = Uuid::new_v4().to_string();
        let service =
            service!(RequestIdentifier::with_uuid().use_incoming_id(IdReuse::UseIncoming));
        let req = test::TestRequest::get()
            .insert_header((DEFAULT_HEADER, uuid4.as_str()))
            .uri("/")
            .to_request();
        let resp = test::call_service(&service, req).await;
        assert_eq!(header_str(&resp, DEFAULT_HEADER), uuid4);
        assert_eq!(body_string(resp).await, uuid4);
    }

    #[actix_web::test]
    async fn ignore_existing_request_id() {
        let uuid4 = Uuid::new_v4().to_string();
        let service = service!(RequestIdentifier::with_uuid()
            .generator(|| HeaderValue::from_static("0")));
        let req = test::TestRequest::get()
            .insert_header((DEFAULT_HEADER, uuid4.as_str()))
            .uri("/")
            .to_request();
        let resp = test::call_service(&service, req).await;
        assert_eq!(header_str(&resp, DEFAULT_HEADER), "0");
        assert_eq!(body_string(resp).await, "0");
    }

    /// Extracting `RequestId` on a route without the middleware fails the request.
    #[actix_web::test]
    async fn missing_middleware() {
        let service = test::init_service(App::new().route("/", web::get().to(handler))).await;
        let resp =
            test::call_service(&service, test::TestRequest::get().uri("/").to_request()).await;
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(resp
            .headers()
            .get(HeaderName::from_static(DEFAULT_HEADER))
            .is_none());
    }
}