use tower::{Layer, Service};
use uuid::Uuid;
/// Name of the HTTP header that carries the request identifier.
///
/// `RequestIdService::call` writes this header on every request and
/// `get_request_id` reads it back from a header map.
const REQUEST_ID_HEADER: &str = "x-request-id";
/// Layer that wraps an inner service in a `RequestIdService`.
///
/// The layer is a stateless unit struct, so it is trivially `Copy` and
/// `Default`; the derives let it be cloned and debug-printed when it is
/// composed into a larger middleware stack.
#[derive(Debug, Clone, Copy, Default)]
pub struct RequestIdLayer;
/// Wraps any service `S` so that it is driven through `RequestIdService`.
impl<S> Layer<S> for RequestIdLayer {
    type Service = RequestIdService<S>;

    /// Takes ownership of `service` and returns it wrapped in the middleware.
    fn layer(&self, service: S) -> Self::Service {
        RequestIdService { inner: service }
    }
}
/// Middleware service that forwards requests to a wrapped inner service.
///
/// `Clone` and `Debug` are derived (conditionally on `S`) so that the wrapper
/// does not stop an otherwise cloneable or debuggable service stack from
/// being cloned or debug-printed.
#[derive(Debug, Clone)]
pub struct RequestIdService<S> {
    /// The wrapped service that receives each request.
    inner: S,
}
/// Stamps every request with a freshly generated version-4 UUID in the
/// `x-request-id` header and then hands the request to the inner service.
///
/// The response and error types are those of the inner service; only the
/// returned future is boxed.
impl<S, B> Service<http::Request<B>> for RequestIdService<S>
where
    S: Service<http::Request<B>> + Send + 'static,
    S::Future: Send + 'static,
    S::Response: Send + 'static,
    S::Error: std::fmt::Debug + Send + 'static,
    B: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>,
    >;

    /// Readiness is decided entirely by the inner service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Generates a new request ID, stores it in the request headers and
    /// forwards the request to the inner service.
    ///
    /// `HeaderMap::insert` replaces any value already stored under the same
    /// name, so an `x-request-id` supplied by the client is overwritten.
    fn call(&mut self, mut req: http::Request<B>) -> Self::Future {
        let request_id = Uuid::new_v4().to_string();
        // A UUID renders as hex digits and hyphens only, so the conversion to a
        // header value cannot fail; `expect` records that invariant.
        let header_value = http::HeaderValue::from_str(&request_id)
            .expect("a hyphenated UUID is always a valid header value");
        req.headers_mut().insert(REQUEST_ID_HEADER, header_value);
        Box::pin(self.inner.call(req))
    }
}
/// Looks up the `x-request-id` header and returns its value as an owned string.
///
/// Returns `None` when the header is absent or when `HeaderValue::to_str`
/// rejects the stored value.
pub fn get_request_id(headers: &http::HeaderMap) -> Option<String> {
    let raw_value = headers.get(REQUEST_ID_HEADER)?;
    let text = raw_value.to_str().ok()?;
    Some(text.to_owned())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The unit-struct layer can be constructed directly.
    #[test]
    fn test_request_id_layer_creation() {
        let _ = RequestIdLayer;
    }

    /// Guards against accidental changes to the header name.
    #[test]
    fn test_request_id_header_name() {
        assert_eq!(REQUEST_ID_HEADER, "x-request-id");
    }

    /// A present, ASCII header value is returned as an owned string.
    #[test]
    fn test_get_request_id() {
        let mut headers = http::HeaderMap::new();
        headers.insert(REQUEST_ID_HEADER, http::HeaderValue::from_static("test-id"));
        assert_eq!(get_request_id(&headers), Some("test-id".to_string()));
    }

    /// An empty header map yields `None`.
    #[test]
    fn test_get_request_id_missing_header() {
        let headers = http::HeaderMap::new();
        assert_eq!(get_request_id(&headers), None);
    }

    /// A header value that `to_str` rejects (a non-ASCII byte) yields `None`
    /// instead of panicking.
    #[test]
    fn test_get_request_id_non_ascii_value() {
        let mut headers = http::HeaderMap::new();
        let opaque = http::HeaderValue::from_bytes(b"\xFF")
            .expect("0xFF is a permitted header value byte");
        headers.insert(REQUEST_ID_HEADER, opaque);
        assert_eq!(get_request_id(&headers), None);
    }
}