use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use http::{Request, Response};
use pin_project_lite::pin_project;
use tower::{Layer, Service};
use tracing::field::Empty;
use crate::constants::metrics;
use crate::trace_id::{generate_span_id, generate_trace_id, hex_encode};
#[derive(Clone, Debug)]
pub struct CfRequestIdLayer;
impl<S> Layer<S> for CfRequestIdLayer {
type Service = CfRequestIdService<S>;
fn layer(&self, inner: S) -> Self::Service {
CfRequestIdService { inner }
}
}
#[derive(Clone, Debug)]
pub struct CfRequestIdService<S> {
inner: S,
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for CfRequestIdService<S>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
S::Future: Send + 'static,
S::Error: std::fmt::Display,
ReqBody: Send + 'static,
ResBody: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = CfRequestIdFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let method = req.method().to_string();
let uri = req.uri().path().to_string();
let request_id = req
.headers()
.get("x-amz-cf-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let request_id_display = request_id.as_deref().unwrap_or("-").to_string();
let trace_id = generate_trace_id(request_id.as_deref());
let span_id = generate_span_id();
let trace_id_hex = hex_encode(&trace_id);
let span_id_hex = hex_encode(&span_id);
let span = tracing::info_span!(
"request",
http.method = %method,
http.uri = %uri,
http.status_code = Empty,
http.latency_ms = Empty,
cf.request_id = %request_id_display,
trace_id = %trace_id_hex,
span_id = %span_id_hex,
);
let start = Instant::now();
let future = self.inner.call(req);
CfRequestIdFuture {
inner: future,
span,
start,
method,
uri,
}
}
}
pin_project! {
pub struct CfRequestIdFuture<F> {
#[pin]
inner: F,
span: tracing::Span,
start: Instant,
method: String,
uri: String,
}
}
impl<F, ResBody, E> Future for CfRequestIdFuture<F>
where
F: Future<Output = Result<Response<ResBody>, E>>,
E: std::fmt::Display,
{
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let _guard = this.span.enter();
match this.inner.poll(cx) {
Poll::Ready(result) => {
let latency_ms = this.start.elapsed().as_secs_f64() * 1000.0;
match &result {
Ok(response) => {
let status = response.status().as_u16();
this.span.record("http.status_code", status);
this.span.record("http.latency_ms", latency_ms);
tracing::info!(
metric = metrics::REQUEST_DURATION,
r#type = "histogram",
value = latency_ms,
method = %this.method,
route = %this.uri,
status = status,
);
tracing::info!(
metric = metrics::REQUEST_COUNT,
r#type = "counter",
value = 1u64,
method = %this.method,
route = %this.uri,
status = status,
);
if status >= 400 {
tracing::info!(
metric = metrics::ERROR_COUNT,
r#type = "counter",
value = 1u64,
method = %this.method,
route = %this.uri,
status = status,
);
}
}
Err(_) => {
this.span.record("http.latency_ms", latency_ms);
}
}
Poll::Ready(result)
}
Poll::Pending => Poll::Pending,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tower::ServiceExt;
#[test]
fn cf_request_id_layer_constructs() {
let _layer = CfRequestIdLayer;
}
#[test]
fn extract_request_id_from_header() {
let req = Request::builder()
.header("x-amz-cf-id", "test-cf-id-123")
.body(())
.unwrap();
let request_id = req
.headers()
.get("x-amz-cf-id")
.and_then(|v| v.to_str().ok());
assert_eq!(request_id, Some("test-cf-id-123"));
}
#[test]
fn missing_request_id_header() {
let req = Request::builder().body(()).unwrap();
let request_id = req
.headers()
.get("x-amz-cf-id")
.and_then(|v| v.to_str().ok());
assert_eq!(request_id, None);
}
#[tokio::test]
async fn middleware_wraps_request_and_returns_response() {
let svc = tower::service_fn(|_req: Request<String>| async {
Ok::<_, std::convert::Infallible>(Response::new(String::from("ok")))
});
let svc = CfRequestIdLayer.layer(svc);
let req = Request::builder()
.header("x-amz-cf-id", "test-cf-id-123")
.body(String::new())
.unwrap();
let resp = svc.oneshot(req).await.unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.into_body(), "ok");
}
#[tokio::test]
async fn middleware_works_without_cf_header() {
let svc = tower::service_fn(|_req: Request<String>| async {
Ok::<_, std::convert::Infallible>(Response::new(String::from("ok")))
});
let svc = CfRequestIdLayer.layer(svc);
let req = Request::builder().body(String::new()).unwrap();
let resp = svc.oneshot(req).await.unwrap();
assert_eq!(resp.status(), 200);
}
#[tokio::test]
async fn middleware_strips_query_params_from_uri() {
let svc = tower::service_fn(|_req: Request<String>| async {
Ok::<_, std::convert::Infallible>(Response::new(String::from("ok")))
});
let mut svc = CfRequestIdLayer.layer(svc);
let req = Request::builder()
.uri("/api/resource?token=secret&foo=bar")
.body(String::new())
.unwrap();
std::future::poll_fn(|cx| svc.poll_ready(cx)).await.unwrap();
let fut = svc.call(req);
assert_eq!(fut.uri, "/api/resource");
assert!(!fut.uri.contains("secret"));
let resp = fut.await.unwrap();
assert_eq!(resp.status(), 200);
}
#[tokio::test]
async fn middleware_propagates_error_status() {
let svc = tower::service_fn(|_req: Request<String>| async {
Ok::<_, std::convert::Infallible>(
Response::builder()
.status(404)
.body(String::from("not found"))
.unwrap(),
)
});
let svc = CfRequestIdLayer.layer(svc);
let req = Request::builder().body(String::new()).unwrap();
let resp = svc.oneshot(req).await.unwrap();
assert_eq!(resp.status(), 404);
}
}