use std::time::Duration;
use axum::http::HeaderMap;
use tonic::metadata::MetadataMap;
/// Builds outgoing gRPC metadata from an incoming HTTP request's headers.
///
/// Only header names listed in `forwarded_headers` are copied (first value per
/// name). Pairs whose name or value cannot be represented as ASCII metadata
/// are skipped. Afterwards W3C trace context (`traceparent`/`tracestate`) is
/// attached, either forwarded from the request or freshly synthesized.
pub fn http_headers_to_grpc_metadata(
    headers: &HeaderMap,
    forwarded_headers: &[String],
) -> MetadataMap {
    let mut metadata = MetadataMap::new();

    // Pair each allow-listed name with its value, skipping names the request lacks.
    let present = forwarded_headers
        .iter()
        .filter_map(|name| headers.get(name.as_str()).map(|value| (name.as_str(), value)));
    for (name, value) in present {
        insert_ascii(&mut metadata, name, value.as_bytes());
    }

    inject_trace_context(&mut metadata, headers);
    metadata
}
/// Inserts `key: value` into `metadata`, replacing any existing entry.
///
/// Nothing is inserted unless both the key parses as an ASCII metadata key and
/// the value converts to an ASCII metadata value. Failures are skipped silently
/// on purpose, so a single odd header never aborts the request.
fn insert_ascii(metadata: &mut MetadataMap, key: &str, value: &[u8]) {
    use tonic::metadata::{Ascii, AsciiMetadataValue, MetadataKey};

    let key = key.parse::<MetadataKey<Ascii>>();
    let value = AsciiMetadataValue::try_from(value);
    if let (Ok(k), Ok(v)) = (key, value) {
        metadata.insert(k, v);
    }
}
/// Ensures `metadata` carries W3C trace context for the outgoing gRPC call.
///
/// - If the request has a valid `traceparent`, it is forwarded unchanged.
///   All incoming `tracestate` header lines are forwarded along with it,
///   combined into one comma-separated value as the W3C Trace Context spec
///   recommends.
/// - Otherwise a new trace is started. Any `traceparent`/`tracestate` already
///   copied into `metadata` (e.g. via the forwarded-header allow-list) is
///   removed first, because the caller's vendor state does not belong to the
///   new trace. A fresh `traceparent` is then inserted. If random bytes are
///   unavailable, no `traceparent` is set at all.
fn inject_trace_context(metadata: &mut MetadataMap, headers: &HeaderMap) {
    let incoming = headers
        .get("traceparent")
        .and_then(|v| v.to_str().ok())
        .filter(|tp| is_valid_traceparent(tp));

    if let Some(tp) = incoming {
        insert_ascii(metadata, "traceparent", tp.as_bytes());
        // Multiple tracestate field lines are legal; merge them rather than
        // silently keeping only the first.
        let parts: Vec<&[u8]> = headers
            .get_all("tracestate")
            .iter()
            .map(|v| v.as_bytes())
            .collect();
        if !parts.is_empty() {
            insert_ascii(metadata, "tracestate", &parts.join(&b','));
        }
        return;
    }

    // Restarting the trace: drop anything inherited from the caller so an
    // invalid traceparent or an unrelated tracestate is never propagated.
    metadata.remove("traceparent");
    metadata.remove("tracestate");
    if let Some(tp) = new_traceparent() {
        insert_ascii(metadata, "traceparent", tp.as_bytes());
    }
}
/// Returns `true` if `tp` is a well-formed W3C `traceparent` value.
///
/// Fields are `version-traceid-parentid-flags`. Each field must be lowercase
/// hex of width 2, 32, 16 and 2 respectively; uppercase hex is invalid per the
/// spec's `HEXDIGLC` grammar. Version `ff` is forbidden. Version `00` must have
/// exactly four fields, while later versions may append extra `-`-separated
/// fields for forward compatibility. All-zero trace-id and parent-id are
/// rejected.
fn is_valid_traceparent(tp: &str) -> bool {
    let is_lower_hex = |s: &str, len: usize| {
        s.len() == len && s.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
    };
    let not_all_zero = |s: &str| s.bytes().any(|b| b != b'0');

    let mut fields = tp.split('-');
    let (version, trace_id, parent_id, flags) =
        match (fields.next(), fields.next(), fields.next(), fields.next()) {
            (Some(v), Some(t), Some(p), Some(f)) => (v, t, p, f),
            _ => return false,
        };
    // Version 00 defines exactly four fields; trailing data is only tolerated
    // for future versions.
    if version == "00" && fields.next().is_some() {
        return false;
    }

    is_lower_hex(version, 2)
        && version != "ff"
        && is_lower_hex(trace_id, 32)
        && not_all_zero(trace_id)
        && is_lower_hex(parent_id, 16)
        && not_all_zero(parent_id)
        && is_lower_hex(flags, 2)
}
/// Synthesizes a version-00 `traceparent` with random ids and the sampled flag.
///
/// Draws 24 random bytes: the first 16 form the trace-id and the remaining 8
/// the parent/span-id, both hex-encoded. Returns `None` if the OS random
/// source fails.
fn new_traceparent() -> Option<String> {
    let mut entropy = [0u8; 24];
    if getrandom::fill(&mut entropy).is_err() {
        return None;
    }
    let (trace_bytes, span_bytes) = entropy.split_at(16);
    Some(format!("00-{}-{}-01", hex(trace_bytes), hex(span_bytes)))
}
/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
fn hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(DIGITS[usize::from(byte >> 4)] as char);
        out.push(DIGITS[usize::from(byte & 0x0f)] as char);
    }
    out
}
pub fn apply_request_deadline<T>(
request: &mut tonic::Request<T>,
headers: &HeaderMap,
) -> Option<Duration> {
let timeout = headers
.get("grpc-timeout")
.and_then(|v| v.to_str().ok())
.and_then(parse_grpc_timeout)?;
request.set_timeout(timeout);
Some(timeout)
}
/// Parses a gRPC `grpc-timeout` header value such as `"5S"` or `"100m"`.
///
/// The format is `TimeoutValue TimeoutUnit`. `TimeoutValue` is 1 to 8 ASCII
/// digits (no sign). `TimeoutUnit` is one of `H`, `M`, `S`, `m`, `u`, `n`:
/// hours, minutes, seconds, milliseconds, microseconds, nanoseconds.
/// Surrounding whitespace is trimmed. Returns `None` for malformed input and
/// for a zero duration.
fn parse_grpc_timeout(value: &str) -> Option<Duration> {
    let value = value.trim();
    // Work on bytes so a trailing multi-byte character cannot cause a
    // non-char-boundary slice panic.
    let (&unit, digits) = value.as_bytes().split_last()?;
    // `str::parse::<u64>` would accept a leading '+', which the gRPC grammar
    // forbids, so require plain digits explicitly.
    if digits.is_empty() || digits.len() > 8 || !digits.iter().all(u8::is_ascii_digit) {
        return None;
    }
    // All bytes before `unit` are ASCII digits, so this slice is on a char
    // boundary. At most 8 digits means `n * 3600` cannot overflow u64.
    let n: u64 = value[..digits.len()].parse().ok()?;
    let dur = match unit {
        b'H' => Duration::from_secs(n * 3600),
        b'M' => Duration::from_secs(n * 60),
        b'S' => Duration::from_secs(n),
        b'm' => Duration::from_millis(n),
        b'u' => Duration::from_micros(n),
        b'n' => Duration::from_nanos(n),
        _ => return None,
    };
    if dur.is_zero() {
        return None;
    }
    Some(dur)
}
#[cfg(test)]
mod tests {
    //! Unit tests for header forwarding, trace-context injection and
    //! `grpc-timeout` handling.
    use super::*;
    use axum::http::HeaderValue;

    /// Allow-list resembling a typical gateway configuration.
    fn default_headers() -> Vec<String> {
        vec![
            "authorization".into(),
            "dpop".into(),
            "x-request-id".into(),
            "x-forwarded-for".into(),
            "x-forwarded-proto".into(),
            "x-real-ip".into(),
            "accept-language".into(),
            "user-agent".into(),
            "idempotency-key".into(),
        ]
    }

    #[test]
    fn test_authorization_forwarded() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", HeaderValue::from_static("Bearer tok123"));
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert_eq!(
            meta.get("authorization").unwrap().to_str().unwrap(),
            "Bearer tok123"
        );
    }

    #[test]
    fn test_multiple_headers_forwarded() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", HeaderValue::from_static("Bearer tok"));
        headers.insert("x-request-id", HeaderValue::from_static("req-42"));
        headers.insert("accept-language", HeaderValue::from_static("en-US"));
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert_eq!(
            meta.get("authorization").unwrap().to_str().unwrap(),
            "Bearer tok"
        );
        assert_eq!(
            meta.get("x-request-id").unwrap().to_str().unwrap(),
            "req-42"
        );
        assert_eq!(
            meta.get("accept-language").unwrap().to_str().unwrap(),
            "en-US"
        );
    }

    #[test]
    fn test_unknown_headers_not_forwarded() {
        let mut headers = HeaderMap::new();
        headers.insert("x-custom-header", HeaderValue::from_static("value"));
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert!(meta.get("x-custom-header").is_none());
    }

    #[test]
    fn test_custom_forwarded_headers() {
        let mut headers = HeaderMap::new();
        headers.insert("x-custom-header", HeaderValue::from_static("value"));
        let forwarded = vec!["x-custom-header".to_string()];
        let meta = http_headers_to_grpc_metadata(&headers, &forwarded);
        assert_eq!(
            meta.get("x-custom-header").unwrap().to_str().unwrap(),
            "value"
        );
    }

    #[test]
    fn test_empty_headers_still_inject_traceparent() {
        let headers = HeaderMap::new();
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        let tp = meta.get("traceparent").unwrap().to_str().unwrap();
        assert!(is_valid_traceparent(tp), "bad traceparent: {tp}");
        assert!(meta.get("authorization").is_none());
    }

    #[test]
    fn traceparent_is_forwarded_when_present() {
        let mut headers = HeaderMap::new();
        let incoming = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        headers.insert("traceparent", HeaderValue::from_static(incoming));
        headers.insert("tracestate", HeaderValue::from_static("vendor=value"));
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert_eq!(meta.get("traceparent").unwrap().to_str().unwrap(), incoming);
        assert_eq!(
            meta.get("tracestate").unwrap().to_str().unwrap(),
            "vendor=value"
        );
    }

    #[test]
    fn invalid_traceparent_does_not_forward_tracestate() {
        // The caller's tracestate belongs to a trace we are discarding.
        let mut headers = HeaderMap::new();
        let zeros = "00-00000000000000000000000000000000-0000000000000000-01";
        headers.insert("traceparent", HeaderValue::from_static(zeros));
        headers.insert("tracestate", HeaderValue::from_static("vendor=value"));
        let meta = http_headers_to_grpc_metadata(&headers, &[]);
        assert!(meta.get("tracestate").is_none());
    }

    #[test]
    fn synthesized_traceparent_is_unique_per_call() {
        let headers = HeaderMap::new();
        let a = http_headers_to_grpc_metadata(&headers, &[]);
        let b = http_headers_to_grpc_metadata(&headers, &[]);
        assert_ne!(
            a.get("traceparent").unwrap().to_str().unwrap(),
            b.get("traceparent").unwrap().to_str().unwrap()
        );
    }

    #[test]
    fn grpc_timeout_parses_each_unit() {
        assert_eq!(parse_grpc_timeout("5S"), Some(Duration::from_secs(5)));
        assert_eq!(parse_grpc_timeout("100m"), Some(Duration::from_millis(100)));
        assert_eq!(parse_grpc_timeout("2M"), Some(Duration::from_secs(120)));
        assert_eq!(parse_grpc_timeout("1H"), Some(Duration::from_secs(3600)));
        assert_eq!(parse_grpc_timeout("250u"), Some(Duration::from_micros(250)));
        assert_eq!(parse_grpc_timeout("9n"), Some(Duration::from_nanos(9)));
    }

    #[test]
    fn grpc_timeout_trims_surrounding_whitespace() {
        assert_eq!(parse_grpc_timeout(" 5S "), Some(Duration::from_secs(5)));
    }

    #[test]
    fn grpc_timeout_rejects_malformed() {
        assert_eq!(parse_grpc_timeout(""), None);
        assert_eq!(parse_grpc_timeout("S"), None);
        assert_eq!(parse_grpc_timeout("10X"), None);
        assert_eq!(parse_grpc_timeout("abcS"), None);
    }

    #[test]
    fn grpc_timeout_rejects_zero_duration() {
        assert_eq!(parse_grpc_timeout("0S"), None);
        assert_eq!(parse_grpc_timeout("0m"), None);
        assert_eq!(parse_grpc_timeout("0n"), None);
    }

    #[test]
    fn grpc_timeout_enforces_8_digit_limit() {
        assert_eq!(
            parse_grpc_timeout("99999999S"),
            Some(Duration::from_secs(99_999_999))
        );
        assert_eq!(parse_grpc_timeout("999999999S"), None);
    }

    #[test]
    fn versioned_traceparent_is_forwarded() {
        let incoming = "01-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let mut headers = HeaderMap::new();
        headers.insert("traceparent", HeaderValue::from_static(incoming));
        let meta = http_headers_to_grpc_metadata(&headers, &[]);
        assert_eq!(meta.get("traceparent").unwrap().to_str().unwrap(), incoming);
    }

    #[test]
    fn ff_version_traceparent_is_rejected() {
        let invalid = "ff-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let mut headers = HeaderMap::new();
        headers.insert("traceparent", HeaderValue::from_static(invalid));
        let meta = http_headers_to_grpc_metadata(&headers, &[]);
        let tp = meta.get("traceparent").unwrap().to_str().unwrap();
        assert_ne!(tp, invalid);
        assert!(is_valid_traceparent(tp));
    }

    #[test]
    fn malformed_or_zero_traceparent_is_not_forwarded() {
        let zeros = "00-00000000000000000000000000000000-0000000000000000-01";
        let mut headers = HeaderMap::new();
        headers.insert("traceparent", HeaderValue::from_static(zeros));
        let meta = http_headers_to_grpc_metadata(&headers, &[]);
        let tp = meta.get("traceparent").unwrap().to_str().unwrap();
        assert_ne!(tp, zeros);
        assert!(
            is_valid_traceparent(tp),
            "synthesized traceparent invalid: {tp}"
        );
    }

    #[test]
    fn apply_request_deadline_sets_timeout_from_header() {
        let mut headers = HeaderMap::new();
        headers.insert("grpc-timeout", HeaderValue::from_static("3S"));
        let mut req = tonic::Request::new(());
        assert_eq!(
            apply_request_deadline(&mut req, &headers),
            Some(Duration::from_secs(3))
        );
        // `set_timeout` must have actually been applied to the request.
        assert!(req.metadata().get("grpc-timeout").is_some());
    }

    #[test]
    fn apply_request_deadline_noop_without_header() {
        let headers = HeaderMap::new();
        let mut req = tonic::Request::new(());
        assert_eq!(apply_request_deadline(&mut req, &headers), None);
        assert!(req.metadata().get("grpc-timeout").is_none());
    }

    #[test]
    fn apply_request_deadline_ignores_malformed_header() {
        let mut headers = HeaderMap::new();
        headers.insert("grpc-timeout", HeaderValue::from_static("abc"));
        let mut req = tonic::Request::new(());
        assert_eq!(apply_request_deadline(&mut req, &headers), None);
        assert!(req.metadata().get("grpc-timeout").is_none());
    }

    #[test]
    fn test_dpop_forwarded() {
        let mut headers = HeaderMap::new();
        headers.insert("dpop", HeaderValue::from_static("eyJ0eXAiOiJkcG9wK2p3dCJ9"));
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert_eq!(
            meta.get("dpop").unwrap().to_str().unwrap(),
            "eyJ0eXAiOiJkcG9wK2p3dCJ9"
        );
    }
}