use tower_http::{
request_id::{PropagateRequestIdLayer, SetRequestIdLayer},
sensitive_headers::SetSensitiveRequestHeadersLayer,
};
use crate::ids::MakeTypedRequestId;
/// Correlation/tracing header names that are meant to be carried across
/// service hops.
///
/// NOTE(review): the layer helpers in this module only propagate
/// `x-request-id`; presumably this list is consumed elsewhere — confirm.
pub const PROPAGATE_HEADERS: &[&str] = &[
    "x-request-id",
    "x-trace-id",
    "x-span-id",
    "x-correlation-id",
    "x-client-id",
];

/// Header names whose values carry credentials or session tokens and should
/// therefore be marked sensitive (hidden from `Debug`-based logging).
///
/// `proxy-authorization` carries credentials exactly like `authorization`
/// (RFC 9110 §11.7.2), so it is masked as well.
///
/// NOTE: `set-cookie` is a response header; a request-side masking layer only
/// affects it if a client sends it on a request.
pub const SENSITIVE_HEADERS: &[&str] = &[
    "authorization",
    "proxy-authorization",
    "cookie",
    "set-cookie",
    "x-api-key",
    "x-auth-token",
];
/// Settings for request tracking: request-id handling, header propagation and
/// sensitive-header masking.
///
/// Start from [`RequestTrackingConfig::new`] (identical to `Default`) and adjust
/// with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestTrackingConfig {
    /// Whether request ids are enabled (default: `true`).
    pub request_id_enabled: bool,
    /// Name of the request-id header (default: `"x-request-id"`).
    ///
    /// NOTE(review): `request_id_layer` / `request_id_propagation_layer` in this
    /// module hard-code `x-request-id` and do not read this field.
    pub request_id_header: String,
    /// Whether header propagation is enabled (default: `true`).
    pub propagate_headers: bool,
    /// Whether sensitive-header masking is enabled (default: `true`).
    pub mask_sensitive_headers: bool,
}

impl Default for RequestTrackingConfig {
    /// Everything enabled, using `x-request-id` as the request-id header.
    fn default() -> Self {
        Self {
            request_id_enabled: true,
            request_id_header: "x-request-id".to_string(),
            propagate_headers: true,
            mask_sensitive_headers: true,
        }
    }
}

impl RequestTrackingConfig {
    /// Creates a config with the default settings (same as [`Default::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables or disables request ids.
    // `#[must_use]`: these builders consume `self`, so dropping the result
    // silently discards the whole configuration.
    #[must_use]
    pub fn with_request_id(mut self, enabled: bool) -> Self {
        self.request_id_enabled = enabled;
        self
    }

    /// Sets the name of the request-id header.
    #[must_use]
    pub fn with_request_id_header(mut self, header: impl Into<String>) -> Self {
        self.request_id_header = header.into();
        self
    }

    /// Enables or disables header propagation.
    #[must_use]
    pub fn with_header_propagation(mut self, enabled: bool) -> Self {
        self.propagate_headers = enabled;
        self
    }

    /// Enables or disables sensitive-header masking.
    #[must_use]
    pub fn with_sensitive_header_masking(mut self, enabled: bool) -> Self {
        self.mask_sensitive_headers = enabled;
        self
    }
}
/// Builds the tower-http layer that sets an `x-request-id` header on incoming
/// requests, generating new ids with [`MakeTypedRequestId`].
///
/// The header name is fixed to `x-request-id`; this helper does not consult
/// `RequestTrackingConfig::request_id_header`.
pub fn request_id_layer() -> SetRequestIdLayer<MakeTypedRequestId> {
    let id_maker = MakeTypedRequestId;
    SetRequestIdLayer::x_request_id(id_maker)
}
/// Builds the tower-http layer that propagates the `x-request-id` header from
/// the request onto the response.
///
/// Like `request_id_layer`, the header name is hard-coded to `x-request-id`.
pub fn request_id_propagation_layer() -> PropagateRequestIdLayer {
    PropagateRequestIdLayer::x_request_id()
}
/// Builds a layer that marks every header listed in [`SENSITIVE_HEADERS`] as
/// sensitive on incoming requests.
///
/// # Panics
///
/// Panics if an entry of [`SENSITIVE_HEADERS`] is not a valid header name. The
/// list is a compile-time constant, so a panic here indicates a programming
/// error in that list, not a runtime condition.
pub fn sensitive_headers_layer() -> SetSensitiveRequestHeadersLayer {
    // `SetSensitiveRequestHeadersLayer::new` accepts any
    // `IntoIterator<Item = HeaderName>` and collects it internally, so the parsed
    // names are passed lazily instead of being buffered in an extra `Vec` first.
    SetSensitiveRequestHeadersLayer::new(
        SENSITIVE_HEADERS
            .iter()
            .map(|name| name.parse().expect("valid header name")),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = RequestTrackingConfig::default();
        assert!(config.request_id_enabled);
        assert!(config.propagate_headers);
        assert!(config.mask_sensitive_headers);
        assert_eq!(config.request_id_header, "x-request-id");
    }

    #[test]
    fn test_builder_pattern() {
        let config = RequestTrackingConfig::new()
            .with_request_id(false)
            .with_request_id_header("x-custom-id")
            .with_header_propagation(false);
        assert!(!config.request_id_enabled);
        assert_eq!(config.request_id_header, "x-custom-id");
        assert!(!config.propagate_headers);
    }

    #[test]
    fn test_sensitive_header_masking_toggle() {
        let config = RequestTrackingConfig::new().with_sensitive_header_masking(false);
        assert!(!config.mask_sensitive_headers);
        // The other settings must keep their defaults.
        assert!(config.request_id_enabled);
        assert!(config.propagate_headers);
        assert_eq!(config.request_id_header, "x-request-id");
    }

    #[test]
    fn test_propagate_headers_constant() {
        assert!(PROPAGATE_HEADERS.contains(&"x-request-id"));
        assert!(PROPAGATE_HEADERS.contains(&"x-trace-id"));
    }

    #[test]
    fn test_sensitive_headers_constant() {
        assert!(SENSITIVE_HEADERS.contains(&"authorization"));
        assert!(SENSITIVE_HEADERS.contains(&"x-api-key"));
    }

    #[test]
    fn test_header_constants_are_lowercase_tokens() {
        for name in PROPAGATE_HEADERS.iter().chain(SENSITIVE_HEADERS) {
            assert!(!name.is_empty());
            assert!(
                name.bytes()
                    .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'-'),
                "unexpected header name: {name}"
            );
        }
    }

    #[test]
    fn test_layer_constructors_do_not_panic() {
        // Guards the `expect` in `sensitive_headers_layer`: every SENSITIVE_HEADERS
        // entry must parse as a header name.
        let _ = sensitive_headers_layer();
        let _ = request_id_layer();
        let _ = request_id_propagation_layer();
    }
}