ro11y 0.3.0

Lightweight Rust observability. Hand-rolled OTLP protobuf over HTTP, built on tracing.
Documentation
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;

use http::{Request, Response};
use pin_project_lite::pin_project;
use tower::{Layer, Service};
use tracing::field::Empty;

use crate::constants::metrics;
use crate::trace_id::{generate_span_id, generate_trace_id, hex_encode};

/// Tower Layer that extracts `x-amz-cf-id` and wraps requests in a tracing span.
///
/// Creates an `info_span!("request", ...)` with HTTP method, URI, status code,
/// latency, CloudFront request ID, and trace/span IDs. Emits RED metric events
/// on response.
#[derive(Clone, Debug)]
pub struct CfRequestIdLayer;

impl<S> Layer<S> for CfRequestIdLayer {
    type Service = CfRequestIdService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        CfRequestIdService { inner }
    }
}

#[derive(Clone, Debug)]
pub struct CfRequestIdService<S> {
    inner: S,
}

impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for CfRequestIdService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
    S::Future: Send + 'static,
    S::Error: std::fmt::Display,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = CfRequestIdFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let method = req.method().to_string();
        let uri = req.uri().path().to_string();

        let request_id = req
            .headers()
            .get("x-amz-cf-id")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());

        let request_id_display = request_id.as_deref().unwrap_or("-").to_string();
        let trace_id = generate_trace_id(request_id.as_deref());
        let span_id = generate_span_id();
        let trace_id_hex = hex_encode(&trace_id);
        let span_id_hex = hex_encode(&span_id);

        let span = tracing::info_span!(
            "request",
            http.method = %method,
            http.uri = %uri,
            http.status_code = Empty,
            http.latency_ms = Empty,
            cf.request_id = %request_id_display,
            trace_id = %trace_id_hex,
            span_id = %span_id_hex,
        );

        let start = Instant::now();
        let future = self.inner.call(req);

        CfRequestIdFuture {
            inner: future,
            span,
            start,
            method,
            uri,
        }
    }
}

pin_project! {
    pub struct CfRequestIdFuture<F> {
        #[pin]
        inner: F,
        span: tracing::Span,
        start: Instant,
        method: String,
        uri: String,
    }
}

impl<F, ResBody, E> Future for CfRequestIdFuture<F>
where
    F: Future<Output = Result<Response<ResBody>, E>>,
    E: std::fmt::Display,
{
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let _guard = this.span.enter();

        match this.inner.poll(cx) {
            Poll::Ready(result) => {
                let latency_ms = this.start.elapsed().as_secs_f64() * 1000.0;

                match &result {
                    Ok(response) => {
                        let status = response.status().as_u16();
                        this.span.record("http.status_code", status);
                        this.span.record("http.latency_ms", latency_ms);

                        tracing::info!(
                            metric = metrics::REQUEST_DURATION,
                            r#type = "histogram",
                            value = latency_ms,
                            method = %this.method,
                            route = %this.uri,
                            status = status,
                        );
                        tracing::info!(
                            metric = metrics::REQUEST_COUNT,
                            r#type = "counter",
                            value = 1u64,
                            method = %this.method,
                            route = %this.uri,
                            status = status,
                        );
                        if status >= 400 {
                            tracing::info!(
                                metric = metrics::ERROR_COUNT,
                                r#type = "counter",
                                value = 1u64,
                                method = %this.method,
                                route = %this.uri,
                                status = status,
                            );
                        }
                    }
                    Err(_) => {
                        this.span.record("http.latency_ms", latency_ms);
                    }
                }

                Poll::Ready(result)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tower::ServiceExt;

    #[test]
    fn cf_request_id_layer_constructs() {
        let _layer = CfRequestIdLayer;
    }

    #[test]
    fn extract_request_id_from_header() {
        let req = Request::builder()
            .header("x-amz-cf-id", "test-cf-id-123")
            .body(())
            .unwrap();
        let request_id = req
            .headers()
            .get("x-amz-cf-id")
            .and_then(|v| v.to_str().ok());
        assert_eq!(request_id, Some("test-cf-id-123"));
    }

    #[test]
    fn missing_request_id_header() {
        let req = Request::builder().body(()).unwrap();
        let request_id = req
            .headers()
            .get("x-amz-cf-id")
            .and_then(|v| v.to_str().ok());
        assert_eq!(request_id, None);
    }

    #[tokio::test]
    async fn middleware_wraps_request_and_returns_response() {
        let svc = tower::service_fn(|_req: Request<String>| async {
            Ok::<_, std::convert::Infallible>(Response::new(String::from("ok")))
        });
        let svc = CfRequestIdLayer.layer(svc);

        let req = Request::builder()
            .header("x-amz-cf-id", "test-cf-id-123")
            .body(String::new())
            .unwrap();

        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.into_body(), "ok");
    }

    #[tokio::test]
    async fn middleware_works_without_cf_header() {
        let svc = tower::service_fn(|_req: Request<String>| async {
            Ok::<_, std::convert::Infallible>(Response::new(String::from("ok")))
        });
        let svc = CfRequestIdLayer.layer(svc);

        let req = Request::builder().body(String::new()).unwrap();

        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn middleware_strips_query_params_from_uri() {
        let svc = tower::service_fn(|_req: Request<String>| async {
            Ok::<_, std::convert::Infallible>(Response::new(String::from("ok")))
        });
        let mut svc = CfRequestIdLayer.layer(svc);

        let req = Request::builder()
            .uri("/api/resource?token=secret&foo=bar")
            .body(String::new())
            .unwrap();

        // Use poll_ready + call so we can inspect the future's uri field
        std::future::poll_fn(|cx| svc.poll_ready(cx)).await.unwrap();
        let fut = svc.call(req);
        // The uri stored in the future should be path-only
        assert_eq!(fut.uri, "/api/resource");
        assert!(!fut.uri.contains("secret"));
        let resp = fut.await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn middleware_propagates_error_status() {
        let svc = tower::service_fn(|_req: Request<String>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(404)
                    .body(String::from("not found"))
                    .unwrap(),
            )
        });
        let svc = CfRequestIdLayer.layer(svc);

        let req = Request::builder().body(String::new()).unwrap();

        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 404);
    }
}