use std::ops::Deref;
use std::str::FromStr;
use crate::error::Error;
/// A parsed request header value paired with the name of the header it came from.
///
/// `Header<T>` dereferences to `T`, so the wrapped value can be used directly
/// (`*header`, or method calls on `header`).
#[derive(Debug, Clone)]
pub struct Header<T> {
    value: T,
    name: &'static str,
}

impl<T> Header<T> {
    /// Wraps `value` together with the header `name` it was read from.
    pub fn new(name: &'static str, value: T) -> Self {
        Self { value, name }
    }

    /// Returns the name of the header this value was extracted from.
    pub fn header_name(&self) -> &'static str {
        self.name
    }

    /// Consumes the wrapper and returns the parsed value.
    pub fn into_inner(self) -> T {
        self.value
    }

    /// Builds a `Header` from a header name only known at runtime.
    ///
    /// The name is interned: the first call with a given name leaks one
    /// allocation to obtain a `&'static str`, and later calls with an equal
    /// name reuse that allocation instead of leaking again. Memory use is
    /// therefore bounded by the number of *distinct* names, so do not pass
    /// attacker-controlled names through this function.
    #[doc(hidden)]
    pub fn from_string(name: String, value: T) -> Self {
        Self {
            value,
            name: intern_header_name(name),
        }
    }
}

/// Returns a `&'static str` equal to `name`, leaking at most one allocation
/// per distinct name for the lifetime of the process.
fn intern_header_name(name: String) -> &'static str {
    // A linear scan is adequate because the set of distinct header names is
    // presumably small. `Mutex::new` and `Vec::new` are `const`, so no lazy
    // initialisation is needed.
    static INTERNED: std::sync::Mutex<Vec<&'static str>> =
        std::sync::Mutex::new(Vec::new());

    // A poisoned lock only means another thread panicked while holding it.
    // The Vec is still a valid list of leaked names, so keep using it.
    let mut interned = INTERNED
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    if let Some(existing) = interned.iter().copied().find(|s| *s == name.as_str()) {
        return existing;
    }
    let leaked: &'static str = Box::leak(name.into_boxed_str());
    interned.push(leaked);
    leaked
}

impl<T> Deref for Header<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.value
    }
}
/// Conversion from the textual value of an HTTP header into a typed value.
///
/// On failure, implementors return `Err` with a human-readable reason. The
/// header extractors in this module embed that reason in their
/// `INVALID_HEADER` errors.
pub trait FromHeaderStr: Sized {
    /// Parses `s`, the decoded header value, into `Self`.
    fn from_header_str(s: &str) -> Result<Self, String>;
}

/// Header text is kept verbatim; this conversion never fails.
impl FromHeaderStr for String {
    fn from_header_str(s: &str) -> Result<Self, String> {
        let owned = String::from(s);
        Ok(owned)
    }
}
/// Parses the header text with `uuid::Uuid::parse_str`. On failure, the parse
/// error's `Display` text becomes the reason.
impl FromHeaderStr for uuid::Uuid {
    fn from_header_str(s: &str) -> Result<Self, String> {
        Self::parse_str(s).map_err(|parse_err| parse_err.to_string())
    }
}
/// Implements `FromHeaderStr` for each listed type by delegating to its
/// `FromStr` impl. The parse error's `Display` text becomes the reason.
/// Accepts a comma-separated list with an optional trailing comma.
macro_rules! impl_header_value_fromstr {
    ($($target:ty),+ $(,)?) => {
        $(
            impl FromHeaderStr for $target {
                fn from_header_str(s: &str) -> Result<Self, String> {
                    <Self as FromStr>::from_str(s).map_err(|parse_err| parse_err.to_string())
                }
            }
        )+
    };
}

// Primitive numbers and `bool` use exactly their standard `FromStr` syntax.
impl_header_value_fromstr!(
    i8, i16, i32, i64, i128, u8, u16, u32, u64, u128, f32, f64, bool,
);
pub fn extract_header<T: FromHeaderStr>(
parts: &http::request::Parts,
name: &'static str,
) -> Result<T, Error> {
let raw = parts
.headers
.get(name)
.ok_or_else(|| {
Error::new(
400,
"MISSING_HEADER",
format!("Missing required header: {name}"),
)
.with_details(serde_json::json!({ "header": name }))
})?
.to_str()
.map_err(|_| {
Error::new(
400,
"INVALID_HEADER",
format!("Header '{name}' contains non-UTF-8 bytes"),
)
.with_details(serde_json::json!({
"header": name,
"reason": "non-UTF-8 bytes",
}))
})?;
T::from_header_str(raw).map_err(|reason| {
Error::new(
400,
"INVALID_HEADER",
format!("Invalid value for header '{name}': {reason}"),
)
.with_details(serde_json::json!({
"header": name,
"reason": reason,
}))
})
}
pub fn extract_optional_header<T: FromHeaderStr>(
parts: &http::request::Parts,
name: &'static str,
) -> Result<Option<T>, Error> {
let raw = match parts.headers.get(name) {
None => return Ok(None),
Some(v) => v.to_str().map_err(|_| {
Error::new(
400,
"INVALID_HEADER",
format!("Header '{name}' contains non-UTF-8 bytes"),
)
.with_details(serde_json::json!({
"header": name,
"reason": "non-UTF-8 bytes",
}))
})?,
};
T::from_header_str(raw).map(Some).map_err(|reason| {
Error::new(
400,
"INVALID_HEADER",
format!("Invalid value for header '{name}': {reason}"),
)
.with_details(serde_json::json!({
"header": name,
"reason": reason,
}))
})
}
#[doc(hidden)]
pub use extract_header as __extract_header;
#[doc(hidden)]
pub use extract_optional_header as __extract_optional_header;
#[cfg(test)]
mod tests {
    use crate::test::TestRequest;

    use super::*;

    /// Request parts carrying a single `name: value` header.
    fn parts_with_header(name: &str, value: &str) -> http::request::Parts {
        let (parts, _) = TestRequest::get("/").header(name, value).into_parts();
        parts
    }

    /// Request parts without any header added by the test.
    fn parts_without_header() -> http::request::Parts {
        TestRequest::get("/").into_parts().0
    }

    /// Request parts whose `name` header holds bytes that are not valid UTF-8.
    fn parts_with_non_utf8_header(name: &'static str) -> http::request::Parts {
        let mut parts = parts_without_header();
        // 0xFF is a legal header-value byte (obs-text) but never valid UTF-8.
        let value = http::HeaderValue::from_bytes(b"abc\xff")
            .expect("0xFF is a valid header-value byte");
        parts.headers.insert(name, value);
        parts
    }

    #[test]
    fn test_extract_string_header_present() {
        let parts = parts_with_header("x-request-id", "abc-123");
        let v = extract_header::<String>(&parts, "x-request-id").unwrap();
        assert_eq!(v, "abc-123");
    }

    #[test]
    fn test_extract_string_header_missing_returns_400() {
        let parts = parts_without_header();
        let err = extract_header::<String>(&parts, "x-request-id").unwrap_err();
        assert_eq!(err.status(), 400);
        assert_eq!(err.code(), "MISSING_HEADER");
        let details = err.details().unwrap();
        assert_eq!(details["header"], "x-request-id");
    }

    #[test]
    fn test_extract_header_non_utf8_returns_400() {
        let parts = parts_with_non_utf8_header("x-binary");
        let err = extract_header::<String>(&parts, "x-binary").unwrap_err();
        assert_eq!(err.status(), 400);
        assert_eq!(err.code(), "INVALID_HEADER");
        let details = err.details().unwrap();
        assert_eq!(details["header"], "x-binary");
        assert_eq!(details["reason"], "non-UTF-8 bytes");
    }

    #[test]
    fn test_extract_u64_header_valid() {
        let parts = parts_with_header("x-retry-count", "3");
        let v = extract_header::<u64>(&parts, "x-retry-count").unwrap();
        assert_eq!(v, 3);
    }

    #[test]
    fn test_extract_u64_header_malformed_returns_400() {
        let parts = parts_with_header("x-retry-count", "not-a-number");
        let err = extract_header::<u64>(&parts, "x-retry-count").unwrap_err();
        assert_eq!(err.status(), 400);
        assert_eq!(err.code(), "INVALID_HEADER");
        let details = err.details().unwrap();
        assert_eq!(details["header"], "x-retry-count");
        assert!(details["reason"].is_string());
    }

    #[test]
    fn test_extract_uuid_header_valid() {
        let id = uuid::Uuid::new_v4();
        let parts = parts_with_header("x-correlation-id", &id.to_string());
        let v = extract_header::<uuid::Uuid>(&parts, "x-correlation-id").unwrap();
        assert_eq!(v, id);
    }

    #[test]
    fn test_extract_uuid_header_malformed() {
        let parts = parts_with_header("x-correlation-id", "not-a-uuid");
        let err = extract_header::<uuid::Uuid>(&parts, "x-correlation-id").unwrap_err();
        assert_eq!(err.status(), 400);
        assert_eq!(err.code(), "INVALID_HEADER");
    }

    #[test]
    fn test_optional_header_present() {
        let parts = parts_with_header("x-request-id", "abc");
        let result = extract_optional_header::<String>(&parts, "x-request-id").unwrap();
        assert_eq!(result.as_deref(), Some("abc"));
    }

    #[test]
    fn test_optional_header_absent_returns_none() {
        let parts = parts_without_header();
        let result = extract_optional_header::<String>(&parts, "x-request-id").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_optional_header_malformed_returns_400() {
        let parts = parts_with_header("x-count", "not-a-number");
        let err = extract_optional_header::<u32>(&parts, "x-count").unwrap_err();
        assert_eq!(err.status(), 400);
        assert_eq!(err.code(), "INVALID_HEADER");
    }

    #[test]
    fn test_optional_header_non_utf8_returns_400() {
        let parts = parts_with_non_utf8_header("x-binary");
        let err = extract_optional_header::<String>(&parts, "x-binary").unwrap_err();
        assert_eq!(err.status(), 400);
        assert_eq!(err.code(), "INVALID_HEADER");
        let details = err.details().unwrap();
        assert_eq!(details["reason"], "non-UTF-8 bytes");
    }

    #[test]
    fn test_header_into_inner() {
        let h = Header::new("x-foo", "bar".to_string());
        assert_eq!(h.into_inner(), "bar");
    }

    #[test]
    fn test_header_deref() {
        let h = Header::new("x-count", 42u64);
        assert_eq!(*h, 42);
    }

    #[test]
    fn test_header_name() {
        let h = Header::new("x-request-id", "id".to_string());
        assert_eq!(h.header_name(), "x-request-id");
    }

    #[test]
    fn test_header_from_string() {
        let h = Header::from_string("x-dynamic".to_string(), 7u8);
        assert_eq!(h.header_name(), "x-dynamic");
        assert_eq!(h.into_inner(), 7);
    }

    #[test]
    fn test_header_value_bool() {
        assert!(bool::from_header_str("true").unwrap());
        assert!(!bool::from_header_str("false").unwrap());
        assert!(bool::from_header_str("yes").is_err());
    }

    #[test]
    fn test_header_value_f64() {
        assert_eq!(f64::from_header_str("1.5").unwrap(), 1.5_f64);
        assert!(f64::from_header_str("abc").is_err());
    }
}