use axum::{
body::Body,
extract::State,
http::{HeaderName, HeaderValue, Request},
middleware::Next,
response::Response,
};
use tracing::debug;
use uuid::Uuid;
use crate::HttpServerState;
/// Axum middleware that adds the configured production headers to every response.
///
/// The inner handler runs first. Then, for each `(name, value)` pair in
/// `state.production_headers` (when configured):
/// - the header is skipped if the handler already set it, so handler-provided
///   values always take precedence;
/// - `{{uuid}}`, `{{now}}` and `{{timestamp}}` placeholders in the value are
///   expanded via `expand_templates`;
/// - the header is inserted, or a warning is logged if the name or the expanded
///   value is not a valid HTTP header.
///
/// Invalid entries never fail the request; they are skipped with a warning.
pub async fn production_headers_middleware(
    State(state): State<HttpServerState>,
    req: Request<Body>,
    next: Next,
) -> Response<Body> {
    let mut response = next.run(req).await;
    if let Some(headers) = &state.production_headers {
        for (key, value) in headers.iter() {
            let header_name = match key.parse::<HeaderName>() {
                Ok(name) => name,
                Err(err) => {
                    tracing::warn!("Failed to parse production header name: {} ({})", key, err);
                    continue;
                }
            };
            // Check before expanding: no point generating UUIDs/timestamps (or
            // warning about an invalid value) for a header we would not insert.
            if response.headers().contains_key(&header_name) {
                continue;
            }
            let expanded_value = expand_templates(value);
            match expanded_value.parse::<HeaderValue>() {
                Ok(header_value) => {
                    response.headers_mut().insert(header_name, header_value);
                    debug!("Added production header: {} = {}", key, expanded_value);
                }
                Err(err) => {
                    tracing::warn!(
                        "Failed to parse production header value: {} = {} ({})",
                        key,
                        expanded_value,
                        err
                    );
                }
            }
        }
    }
    response
}
fn expand_templates(value: &str) -> String {
let mut result = value.to_string();
if result.contains("{{uuid}}") {
let uuid = Uuid::new_v4().to_string();
result = result.replace("{{uuid}}", &uuid);
}
if result.contains("{{now}}") {
let now = chrono::Utc::now().to_rfc3339();
result = result.replace("{{now}}", &now);
}
if result.contains("{{timestamp}}") {
let timestamp = chrono::Utc::now().timestamp().to_string();
result = result.replace("{{timestamp}}", ×tamp);
}
result
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expand_uuid_template() {
        let expanded = expand_templates("{{uuid}}");
        assert!(!expanded.contains("{{uuid}}"));
        assert_eq!(expanded.len(), 36);
        // A length check alone would accept any 36-char string; require a real UUID.
        assert!(Uuid::parse_str(&expanded).is_ok(), "not a UUID: {}", expanded);
    }

    #[test]
    fn test_expand_now_template() {
        let expanded = expand_templates("{{now}}");
        assert!(!expanded.contains("{{now}}"));
        assert!(expanded.len() > 15);
        assert!(expanded.contains('T'));
        assert!(
            chrono::DateTime::parse_from_rfc3339(&expanded).is_ok(),
            "not RFC 3339: {}",
            expanded
        );
    }

    #[test]
    fn test_expand_timestamp_template() {
        let before = chrono::Utc::now().timestamp();
        let expanded = expand_templates("{{timestamp}}");
        let after = chrono::Utc::now().timestamp();
        assert!(!expanded.contains("{{timestamp}}"));
        let ts: i64 = expanded.parse().expect("timestamp should be an integer");
        assert!(
            (before..=after).contains(&ts),
            "timestamp {} outside [{}, {}]",
            ts,
            before,
            after
        );
    }

    #[test]
    fn test_expand_multiple_templates() {
        let expanded = expand_templates("Request-{{uuid}} at {{timestamp}}");
        assert!(!expanded.contains("{{uuid}}"));
        assert!(!expanded.contains("{{timestamp}}"));
        let rest = expanded
            .strip_prefix("Request-")
            .expect("literal prefix must be preserved");
        let (uuid, ts) = rest.split_once(" at ").expect("literal separator must be preserved");
        assert!(Uuid::parse_str(uuid).is_ok(), "not a UUID: {}", uuid);
        assert!(ts.parse::<i64>().is_ok(), "not a timestamp: {}", ts);
    }

    #[test]
    fn test_repeated_placeholder_replaced_everywhere() {
        let expanded = expand_templates("{{uuid}}/{{uuid}}");
        let (first, second) = expanded.split_once('/').expect("separator preserved");
        // Every occurrence is replaced with the same generated UUID.
        assert_eq!(first, second);
        assert!(Uuid::parse_str(first).is_ok());
    }

    #[test]
    fn test_unknown_placeholder_left_untouched() {
        assert_eq!(expand_templates("{{unknown}}"), "{{unknown}}");
    }

    #[test]
    fn test_no_templates() {
        let value = "Static header value";
        assert_eq!(expand_templates(value), value);
    }
}