actix-request-identifier 4.1.0

Middleware for actix-web to associate an ID with each request.
Documentation
#![deny(unused, missing_docs)]
//! This is an actix-web middleware to associate every request with a unique ID. This ID can be
//! used to track errors in an application.

use std::{
    fmt,
    pin::Pin,
    task::{Context, Poll},
};

use actix_web::{
    dev::{Payload, Service, ServiceRequest, ServiceResponse, Transform},
    error::ResponseError,
    http::header::{HeaderName, HeaderValue},
    Error as ActixError, FromRequest, HttpMessage, HttpRequest,
};
use futures::{
    future::{ok, ready, Ready},
    Future,
};
use uuid::Uuid;

/// Name of the header that carries the request ID when no custom header name is configured.
pub const DEFAULT_HEADER: &str = "x-request-id";

/// Errors that can be produced by this crate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// No [`RequestId`] is stored for this request (e.g. the request did not pass through the
    /// request-identifier middleware).
    NoAssociatedId,
}

/// Configuration setting that decides whether the request ID from an incoming request header
/// should be reused (when present), or whether a new ID should be generated in every case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IdReuse {
    /// Reuse the incoming request id.
    UseIncoming,
    /// Ignore the incoming request id and generate a new one, even if the request supplied an
    /// id.
    IgnoreIncoming,
}

/// Signature of an ID generator: a plain function that returns a fresh header value per call.
type Generator = fn() -> HeaderValue;

/// Middleware service built by [`RequestIdentifier`].
///
/// It wraps the inner service `S`, associates an ID with every request it forwards, and writes
/// that ID into the response header.
pub struct RequestIdMiddleware<S> {
    service: S,
    header_name: HeaderName,
    id_generator: Generator,
    use_incoming_id: IdReuse,
}

/// Middleware factory that associates a per-request unique ID with each request.
///
/// Install it with `App::wrap`; handlers can then take a [`RequestId`] argument.
pub struct RequestIdentifier {
    header_name: &'static str,
    id_generator: Generator,
    use_incoming_id: IdReuse,
}

/// Request ID that can be extracted in handlers.
#[derive(Clone)]
pub struct RequestId(HeaderValue);

impl ResponseError for Error {}

impl fmt::Display for Error {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        use Error::NoAssociatedId;
        match self {
            NoAssociatedId => write!(fmt, "NoAssociatedId"),
        }
    }
}

impl RequestId {
    /// Get the header value holding this request id.
    pub const fn header_value(&self) -> &HeaderValue {
        &self.0
    }

    /// Get a string representation of this ID.
    ///
    /// # Panics
    ///
    /// Panics if the ID contains bytes that [`HeaderValue::to_str`] rejects (anything other than
    /// visible ASCII), e.g. when a custom ID generator returned such a value.
    pub fn as_str(&self) -> &str {
        self.0.to_str().expect("Non-ASCII IDs are not supported")
    }
}

impl RequestIdentifier {
    /// Create a default middleware using [`DEFAULT_HEADER`] as the header name and UUID v4 for
    /// ID generation.
    ///
    /// This is the same as [`RequestIdentifier::default`], so IDs supplied by the client are
    /// ignored ([`IdReuse::IgnoreIncoming`]).
    #[must_use]
    pub fn with_uuid() -> Self {
        Self::default()
    }

    /// Create a middleware using a custom header name and UUID v4 for ID generation.
    ///
    /// `header_name` should be a valid, lowercase HTTP header name. It is converted into a
    /// header name when the middleware is constructed, and a name that cannot be converted
    /// causes a panic at that point.
    #[must_use]
    pub fn with_header(header_name: &'static str) -> Self {
        Self {
            header_name,
            ..Default::default()
        }
    }

    /// Change the header name for this middleware.
    ///
    /// The same requirements as for [`RequestIdentifier::with_header`] apply to `header_name`.
    #[must_use]
    pub const fn header(self, header_name: &'static str) -> Self {
        Self {
            header_name,
            ..self
        }
    }

    /// Create a middleware using [`DEFAULT_HEADER`] as the header name and IDs as generated by
    /// `id_generator`.
    ///
    /// `id_generator` should return a unique, visible-ASCII ID each time it is invoked;
    /// [`RequestId::as_str`] panics on IDs containing other bytes.
    #[must_use]
    pub fn with_generator(id_generator: Generator) -> Self {
        Self {
            id_generator,
            ..Default::default()
        }
    }

    /// Change the ID generator for this middleware.
    ///
    /// The same requirements as for [`RequestIdentifier::with_generator`] apply.
    #[must_use]
    pub fn generator(self, id_generator: Generator) -> Self {
        Self {
            id_generator,
            ..self
        }
    }

    /// Change the behavior for incoming request id headers.
    ///
    /// With [`IdReuse::UseIncoming`], a request header with the configured name is reused as the
    /// request's ID where possible; otherwise a new ID is generated. With
    /// [`IdReuse::IgnoreIncoming`] (the default), the incoming header is ignored and an ID is
    /// always generated.
    #[must_use]
    pub const fn use_incoming_id(self, use_incoming_id: IdReuse) -> Self {
        Self {
            use_incoming_id,
            ..self
        }
    }
}

impl Default for RequestIdentifier {
    fn default() -> Self {
        Self {
            header_name: DEFAULT_HEADER,
            id_generator: default_generator,
            use_incoming_id: IdReuse::IgnoreIncoming,
        }
    }
}

/// Built-in ID generator: returns a freshly generated random (v4) UUID as a header value.
fn default_generator() -> HeaderValue {
    let rendered = Uuid::new_v4().to_string();
    // The textual UUID form is ASCII-only (hex digits and hyphens), so converting it into a
    // header value cannot fail.
    HeaderValue::from_str(rendered.as_str()).unwrap()
}

impl<S, B> Transform<S, ServiceRequest> for RequestIdentifier
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError>,
    S::Future: 'static,
    B: 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Transform = RequestIdMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` in a [`RequestIdMiddleware`] configured from this factory.
    ///
    /// # Panics
    ///
    /// Panics if the configured header name is not a valid HTTP header name.
    fn new_transform(&self, service: S) -> Self::Future {
        // `HeaderName::from_static` panics on any upper-case character, so a configuration such
        // as `with_header("X-Request-ID")` would abort start-up. `from_bytes` lowercases the name
        // instead (header names are case-insensitive) and only fails for names that are invalid
        // regardless of case.
        let header_name = HeaderName::from_bytes(self.header_name.as_bytes())
            .unwrap_or_else(|err| {
                panic!("invalid request id header name {:?}: {}", self.header_name, err)
            });

        ok(RequestIdMiddleware {
            service,
            header_name,
            id_generator: self.id_generator,
            use_incoming_id: self.use_incoming_id,
        })
    }
}

// The boxed `Future` associated type below trips clippy's `type_complexity` lint.
#[allow(clippy::type_complexity)]
impl<S, B> Service<ServiceRequest> for RequestIdMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError>,
    S::Future: 'static,
    B: 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(ctx)
    }

    /// Chooses an ID for `req` and stores it as a [`RequestId`] request extension, so handlers
    /// can extract it. It then forwards the request and writes the ID into the response header.
    ///
    /// If the inner service returns an error, that error is propagated unchanged and no header
    /// is added.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        let header_name = self.header_name.clone();

        // An incoming ID is client-controlled. Reuse it only if it is non-empty and visible
        // ASCII: `RequestId::as_str` calls `to_str().expect(..)`, so any other value would make
        // every handler that reads the ID as a string panic. Rejected values are replaced by a
        // freshly generated ID.
        let incoming = match self.use_incoming_id {
            IdReuse::UseIncoming => req
                .headers()
                .get(&header_name)
                .filter(|value| !value.is_empty() && value.to_str().is_ok())
                .cloned(),
            IdReuse::IgnoreIncoming => None,
        };
        let header_value = incoming.unwrap_or_else(self.id_generator);

        // make the id available as an extractor in route handlers
        req.extensions_mut().insert(RequestId(header_value.clone()));

        let fut = self.service.call(req);
        Box::pin(async move {
            let mut res = fut.await?;
            res.headers_mut().insert(header_name, header_value);
            Ok(res)
        })
    }
}

impl FromRequest for RequestId {
    type Error = Error;
    type Future = Ready<Result<Self, Self::Error>>;

    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
        ready(
            req.extensions()
                .get::<RequestId>()
                .map(RequestId::clone)
                .ok_or(Error::NoAssociatedId),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{test, web, App};
    use bytes::Bytes;

    /// Handler under test: echoes the extracted request ID back as the response body.
    async fn handler(id: RequestId) -> String {
        String::from(id.as_str())
    }

    // `test::init_service` returns a type that is too complicated to name in a function
    // signature, so building the test service is wrapped in a macro instead.
    macro_rules! service {
        ($middleware:expr) => {
            test::init_service(
                App::new()
                    .wrap($middleware)
                    .route("/", web::get().to(handler)),
            )
            .await
        };
    }

    /// Sends a bare `GET /` through an app wrapped in `middleware`.
    async fn test_get(middleware: RequestIdentifier) -> ServiceResponse {
        let app = service!(middleware);
        let request = test::TestRequest::get().uri("/").to_request();
        test::call_service(&app, request).await
    }

    /// Value of the response header `name`; panics if it is absent or not visible ASCII.
    fn header_text(resp: &ServiceResponse, name: &'static str) -> String {
        let value = resp.headers().get(HeaderName::from_static(name)).unwrap();
        value.to_str().unwrap().to_owned()
    }

    /// Consumes `resp` and returns its body decoded (lossily) as UTF-8.
    async fn body_text(resp: ServiceResponse) -> String {
        let raw: Bytes = test::read_body(resp).await;
        String::from_utf8_lossy(&raw).into_owned()
    }

    #[actix_web::test]
    async fn default_identifier() {
        let resp = test_get(RequestIdentifier::with_uuid()).await;
        let uid = header_text(&resp, DEFAULT_HEADER);
        let body = body_text(resp).await;
        assert_eq!(uid, body);
    }

    #[actix_web::test]
    async fn deterministic_identifier() {
        let middleware =
            RequestIdentifier::with_generator(|| HeaderValue::from_static("look ma, i'm an id"));
        let resp = test_get(middleware).await;
        let uid = header_text(&resp, DEFAULT_HEADER);
        let body = body_text(resp).await;
        assert_eq!(uid, body);
    }

    #[actix_web::test]
    async fn custom_header() {
        let resp = test_get(RequestIdentifier::with_header("custom-header")).await;
        let default_header = resp.headers().get(HeaderName::from_static(DEFAULT_HEADER));
        assert!(default_header.is_none());
        let uid = header_text(&resp, "custom-header");
        let body = body_text(resp).await;
        assert_eq!(uid, body);
    }

    #[actix_web::test]
    async fn existing_request_id() {
        let uuid4 = Uuid::new_v4().to_string();
        let app = service!(RequestIdentifier::with_uuid().use_incoming_id(IdReuse::UseIncoming));
        let req = test::TestRequest::get()
            .insert_header((DEFAULT_HEADER, uuid4.as_str()))
            .uri("/")
            .to_request();
        let resp = test::call_service(&app, req).await;
        let uid = header_text(&resp, DEFAULT_HEADER);
        assert_eq!(uid, uuid4);
        let body = body_text(resp).await;
        assert_eq!(body, uuid4);
    }

    #[actix_web::test]
    async fn ignore_existing_request_id() {
        let uuid4 = Uuid::new_v4().to_string();
        // A deterministic generator makes it observable whether the supplied id was ignored.
        let middleware =
            RequestIdentifier::with_uuid().generator(|| HeaderValue::from_static("0"));
        let app = service!(middleware);
        let req = test::TestRequest::get()
            .insert_header((DEFAULT_HEADER, uuid4.as_str()))
            .uri("/")
            .to_request();
        let resp = test::call_service(&app, req).await;
        let uid = header_text(&resp, DEFAULT_HEADER);
        assert_eq!(uid, "0");
        let body = body_text(resp).await;
        assert_eq!(body, "0");
    }
}