#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
#![allow(clippy::multiple_crate_versions)]
use crate::{
Error, HttpRequest,
from_request::{FromRequest, IntoHandlerError},
};
use std::{collections::BTreeMap, fmt, str::FromStr};
/// Errors produced while extracting typed values from request headers.
///
/// Every variant is turned into a bad-request handler error by the
/// `IntoHandlerError` impl for this type.
#[derive(Debug)]
pub enum HeaderError {
    /// A header the extractor requires was not present on the request.
    MissingHeader {
        /// Name of the header that was looked up.
        name: String,
    },
    /// A header was present but its value could not be parsed into the target type
    /// (in this module: a `FromStr` failure).
    ParseError {
        /// Name of the header that was read.
        name: String,
        /// The raw header value that failed to parse.
        value: String,
        /// `std::any::type_name` of the type the value was parsed into.
        target_type: &'static str,
        /// `Display` text of the underlying parse error.
        source: String,
    },
    /// A header value was judged invalid; `reason` explains why.
    ///
    /// NOTE(review): not constructed by the extractors in this module; presumably
    /// produced by validation code elsewhere — confirm.
    InvalidHeaderValue {
        /// Name of the offending header.
        name: String,
        /// The rejected raw value.
        value: String,
        /// Human-readable explanation of why the value is invalid.
        reason: String,
    },
    /// Headers could not be deserialized into a structured target type.
    ///
    /// NOTE(review): not constructed by the extractors in this module; presumably
    /// produced by a deserialization-based extractor elsewhere — confirm.
    DeserializationError {
        /// Text of the underlying deserialization error.
        source: String,
        /// Header names mapped to values (assumed to be the request's headers —
        /// confirm). Debug-formatted in full by the `Display` impl, values included.
        headers: BTreeMap<String, String>,
        /// Name of the type deserialization targeted.
        target_type: &'static str,
    },
}
impl fmt::Display for HeaderError {
    /// Renders the error as a single-line, human-readable message.
    ///
    /// This is the text passed to `Error::bad_request` when the error is
    /// converted into a handler error. Raw header values are echoed verbatim
    /// in the `ParseError`, `InvalidHeaderValue` and `DeserializationError`
    /// messages.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MissingHeader { name: header } => {
                write!(f, "Required header '{header}' is missing from the request")
            }
            Self::ParseError {
                name: header,
                value: raw,
                target_type: ty,
                source: cause,
            } => write!(
                f,
                "Failed to parse header '{header}' with value '{raw}' into type '{ty}': {cause}"
            ),
            Self::InvalidHeaderValue {
                name: header,
                value: raw,
                reason: why,
            } => write!(f, "Header '{header}' has invalid value '{raw}': {why}"),
            Self::DeserializationError {
                source: cause,
                headers: seen,
                target_type: ty,
            } => write!(
                f,
                "Failed to deserialize headers into type '{ty}': {cause}. Headers: {seen:?}"
            ),
        }
    }
}
impl std::error::Error for HeaderError {}
impl IntoHandlerError for HeaderError {
fn into_handler_error(self) -> Error {
Error::bad_request(self.to_string())
}
}
/// Typed extractor wrapper around a value taken from the request's headers.
///
/// Which header is read, and how, is decided by the `FromRequest` impl for the
/// concrete `T`. `Clone`/`PartialEq`/`Eq` are derived (bounded on `T`) so that
/// extracted values can be copied and compared without destructuring.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Header<T>(pub T);

impl<T> Header<T> {
    /// Wraps `value` in a `Header`.
    #[must_use]
    pub const fn new(value: T) -> Self {
        Self(value)
    }

    /// Consumes the wrapper and returns the extracted value.
    #[must_use]
    pub fn into_inner(self) -> T {
        self.0
    }
}
/// Looks up `header_name` on `req` and parses its value into `T` via `FromStr`.
///
/// # Errors
///
/// * [`HeaderError::MissingHeader`] when the request has no such header.
/// * [`HeaderError::ParseError`] when `T::from_str` rejects the value; the
///   error records the raw value, `std::any::type_name::<T>()` and the parse
///   error's `Display` text.
fn extract_single_header<T>(req: &HttpRequest, header_name: &str) -> Result<T, HeaderError>
where
    T: FromStr,
    T::Err: fmt::Display,
{
    req.header(header_name)
        .ok_or_else(|| HeaderError::MissingHeader {
            name: header_name.to_string(),
        })
        .and_then(|raw| {
            raw.parse::<T>().map_err(|cause| HeaderError::ParseError {
                name: header_name.to_string(),
                value: raw.to_string(),
                target_type: std::any::type_name::<T>(),
                source: cause.to_string(),
            })
        })
}
/// Fetches every header in `header_names`, in order, as owned strings.
///
/// On success the returned vector has exactly one entry per requested name,
/// in the same order.
///
/// # Errors
///
/// Returns [`HeaderError::MissingHeader`] for the first name (in order) that
/// is absent from the request; later names are not looked up.
fn extract_tuple_headers(
    req: &HttpRequest,
    header_names: &[&str],
) -> Result<Vec<String>, HeaderError> {
    // Collecting into `Result<Vec<_>, _>` stops at the first `Err`, matching a
    // sequential lookup that bails on the first missing header.
    header_names
        .iter()
        .map(|&wanted| {
            req.header(wanted)
                .ok_or_else(|| HeaderError::MissingHeader {
                    name: wanted.to_string(),
                })
                .map(String::from)
        })
        .collect()
}
impl FromRequest for Header<String> {
    type Error = HeaderError;
    type Future = std::future::Ready<Result<Self, Self::Error>>;

    /// Extracts the raw value of the `authorization` header.
    ///
    /// # Errors
    ///
    /// Returns [`HeaderError::MissingHeader`] when `authorization` is absent.
    /// Parsing into `String` is infallible, so no `ParseError` can occur.
    fn from_request_sync(req: &HttpRequest) -> Result<Self, Self::Error> {
        extract_single_header::<String>(req, "authorization").map(Self)
    }

    /// Runs the synchronous extraction immediately and wraps the outcome in an
    /// already-completed future.
    fn from_request_async(req: HttpRequest) -> Self::Future {
        let extracted = Self::from_request_sync(&req);
        std::future::ready(extracted)
    }
}
impl FromRequest for Header<u64> {
    type Error = HeaderError;
    type Future = std::future::Ready<Result<Self, Self::Error>>;

    /// Reads the `content-length` header and parses it as a `u64`.
    ///
    /// # Errors
    ///
    /// * [`HeaderError::MissingHeader`] when `content-length` is absent.
    /// * [`HeaderError::ParseError`] when the value is not a valid `u64`.
    fn from_request_sync(req: &HttpRequest) -> Result<Self, Self::Error> {
        extract_single_header::<u64>(req, "content-length").map(Self)
    }

    /// Runs the synchronous extraction immediately and wraps the outcome in an
    /// already-completed future.
    fn from_request_async(req: HttpRequest) -> Self::Future {
        let extracted = Self::from_request_sync(&req);
        std::future::ready(extracted)
    }
}
impl FromRequest for Header<bool> {
    type Error = HeaderError;
    type Future = std::future::Ready<Result<Self, Self::Error>>;

    /// Reports whether the request carries an `upgrade` header.
    ///
    /// Only presence matters; the header's value is ignored. This extractor
    /// never returns an error.
    fn from_request_sync(req: &HttpRequest) -> Result<Self, Self::Error> {
        let has_upgrade = req.header("upgrade").is_some();
        Ok(Self(has_upgrade))
    }

    /// Runs the synchronous extraction immediately and wraps the outcome in an
    /// already-completed future.
    fn from_request_async(req: HttpRequest) -> Self::Future {
        let extracted = Self::from_request_sync(&req);
        std::future::ready(extracted)
    }
}
impl FromRequest for Header<(String, String)> {
    type Error = HeaderError;
    type Future = std::future::Ready<Result<Self, Self::Error>>;

    /// Extracts `(authorization, content-type)` as owned strings.
    ///
    /// # Errors
    ///
    /// Returns [`HeaderError::MissingHeader`] naming the first absent header,
    /// checked in the order `authorization`, `content-type`.
    fn from_request_sync(req: &HttpRequest) -> Result<Self, Self::Error> {
        const ONE_PER_NAME: &str = "extract_tuple_headers yields one value per requested header";
        // Move the strings out of the vector instead of indexing and cloning
        // them. On success the helper returns exactly one value per requested
        // name, so `next()` cannot run dry here.
        let mut values =
            extract_tuple_headers(req, &["authorization", "content-type"])?.into_iter();
        let authorization = values.next().expect(ONE_PER_NAME);
        let content_type = values.next().expect(ONE_PER_NAME);
        Ok(Self((authorization, content_type)))
    }

    /// Runs the synchronous extraction immediately and wraps the outcome in an
    /// already-completed future.
    fn from_request_async(req: HttpRequest) -> Self::Future {
        std::future::ready(Self::from_request_sync(&req))
    }
}
impl FromRequest for Header<(String, String, String)> {
    type Error = HeaderError;
    type Future = std::future::Ready<Result<Self, Self::Error>>;

    /// Extracts `(authorization, content-type, user-agent)` as owned strings.
    ///
    /// # Errors
    ///
    /// Returns [`HeaderError::MissingHeader`] naming the first absent header,
    /// checked in the order `authorization`, `content-type`, `user-agent`.
    fn from_request_sync(req: &HttpRequest) -> Result<Self, Self::Error> {
        const ONE_PER_NAME: &str = "extract_tuple_headers yields one value per requested header";
        // Move the strings out of the vector instead of indexing and cloning
        // them. On success the helper returns exactly one value per requested
        // name, so `next()` cannot run dry here.
        let mut values =
            extract_tuple_headers(req, &["authorization", "content-type", "user-agent"])?
                .into_iter();
        let authorization = values.next().expect(ONE_PER_NAME);
        let content_type = values.next().expect(ONE_PER_NAME);
        let user_agent = values.next().expect(ONE_PER_NAME);
        Ok(Self((authorization, content_type, user_agent)))
    }

    /// Runs the synchronous extraction immediately and wraps the outcome in an
    /// already-completed future.
    fn from_request_async(req: HttpRequest) -> Self::Future {
        std::future::ready(Self::from_request_sync(&req))
    }
}
#[cfg(all(test, feature = "simulator"))]
mod tests {
    use super::*;
    use crate::HttpRequest;
    // This module only compiles with the `simulator` feature, so the simulator
    // types are always available and need no additional `cfg` gate.
    use crate::simulator::{SimulationRequest, SimulationStub};

    /// Builds a simulated `GET /test` request carrying `headers` as
    /// `(name, value)` pairs.
    fn create_test_request_with_headers(headers: &[(&str, &str)]) -> HttpRequest {
        let mut sim_req = SimulationRequest::new(crate::Method::Get, "/test");
        for (name, value) in headers {
            sim_req = sim_req.with_header(*name, *value);
        }
        HttpRequest::new(SimulationStub::new(sim_req))
    }

    #[test]
    fn test_single_header_string_extraction() {
        let http_req = create_test_request_with_headers(&[("authorization", "Bearer token123")]);
        let header = Header::<String>::from_request_sync(&http_req)
            .expect("authorization header should be extracted");
        assert_eq!(header.0, "Bearer token123");
    }

    #[test]
    fn test_single_header_u64_extraction() {
        let http_req = create_test_request_with_headers(&[("content-length", "1024")]);
        let header = Header::<u64>::from_request_sync(&http_req)
            .expect("content-length header should parse as u64");
        assert_eq!(header.0, 1024);
    }

    #[test]
    fn test_single_header_u64_missing() {
        let http_req = create_test_request_with_headers(&[]);
        match Header::<u64>::from_request_sync(&http_req) {
            Err(HeaderError::MissingHeader { name }) => assert_eq!(name, "content-length"),
            other => panic!("Expected MissingHeader error, got {other:?}"),
        }
    }

    #[test]
    fn test_single_header_bool_extraction() {
        let http_req = create_test_request_with_headers(&[("upgrade", "websocket")]);
        let header = Header::<bool>::from_request_sync(&http_req)
            .expect("bool header extraction never fails");
        assert!(header.0);
    }

    #[test]
    fn test_single_header_bool_extraction_missing() {
        let http_req = create_test_request_with_headers(&[]);
        let header = Header::<bool>::from_request_sync(&http_req)
            .expect("bool header extraction never fails");
        assert!(!header.0);
    }

    #[test]
    fn test_tuple_header_extraction() {
        let http_req = create_test_request_with_headers(&[
            ("authorization", "Bearer token123"),
            ("content-type", "application/json"),
        ]);
        let (auth, ct) = Header::<(String, String)>::from_request_sync(&http_req)
            .expect("both headers are present")
            .0;
        assert_eq!(auth, "Bearer token123");
        assert_eq!(ct, "application/json");
    }

    #[test]
    fn test_triple_header_extraction() {
        let http_req = create_test_request_with_headers(&[
            ("authorization", "Bearer token123"),
            ("content-type", "application/json"),
            ("user-agent", "TestAgent/1.0"),
        ]);
        let (auth, ct, ua) = Header::<(String, String, String)>::from_request_sync(&http_req)
            .expect("all three headers are present")
            .0;
        assert_eq!(auth, "Bearer token123");
        assert_eq!(ct, "application/json");
        assert_eq!(ua, "TestAgent/1.0");
    }

    #[test]
    fn test_missing_header_error() {
        let http_req = create_test_request_with_headers(&[]);
        match Header::<String>::from_request_sync(&http_req) {
            Err(HeaderError::MissingHeader { name }) => assert_eq!(name, "authorization"),
            other => panic!("Expected MissingHeader error, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_error() {
        let http_req = create_test_request_with_headers(&[("content-length", "invalid")]);
        match Header::<u64>::from_request_sync(&http_req) {
            Err(HeaderError::ParseError {
                name,
                value,
                target_type,
                ..
            }) => {
                assert_eq!(name, "content-length");
                assert_eq!(value, "invalid");
                assert_eq!(target_type, "u64");
            }
            other => panic!("Expected ParseError, got {other:?}"),
        }
    }

    #[test]
    fn test_tuple_missing_header_error() {
        let http_req = create_test_request_with_headers(&[("authorization", "Bearer token123")]);
        match Header::<(String, String)>::from_request_sync(&http_req) {
            Err(HeaderError::MissingHeader { name }) => assert_eq!(name, "content-type"),
            other => panic!("Expected MissingHeader error, got {other:?}"),
        }
    }

    #[test]
    fn test_triple_missing_header_error() {
        let http_req = create_test_request_with_headers(&[
            ("authorization", "Bearer token123"),
            ("content-type", "application/json"),
        ]);
        match Header::<(String, String, String)>::from_request_sync(&http_req) {
            Err(HeaderError::MissingHeader { name }) => assert_eq!(name, "user-agent"),
            other => panic!("Expected MissingHeader error, got {other:?}"),
        }
    }

    #[test]
    fn test_header_error_display() {
        let error = HeaderError::MissingHeader {
            name: "authorization".to_string(),
        };
        assert_eq!(
            error.to_string(),
            "Required header 'authorization' is missing from the request"
        );
        let error = HeaderError::ParseError {
            name: "content-length".to_string(),
            value: "invalid".to_string(),
            target_type: "u64",
            source: "invalid digit found in string".to_string(),
        };
        assert_eq!(
            error.to_string(),
            "Failed to parse header 'content-length' with value 'invalid' into type 'u64': invalid digit found in string"
        );
    }

    #[test]
    fn test_header_into_inner() {
        let header = Header::new("test_value".to_string());
        assert_eq!(header.into_inner(), "test_value");
    }

    #[test]
    fn test_header_error_display_invalid_header_value() {
        let error = HeaderError::InvalidHeaderValue {
            name: "x-custom-header".to_string(),
            value: "invalid\x00value".to_string(),
            reason: "contains null byte".to_string(),
        };
        let display = error.to_string();
        assert_eq!(
            display,
            "Header 'x-custom-header' has invalid value 'invalid\x00value': contains null byte"
        );
    }

    #[test]
    fn test_header_error_display_deserialization_error() {
        let mut headers = BTreeMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());
        headers.insert("authorization".to_string(), "Bearer token".to_string());
        let error = HeaderError::DeserializationError {
            source: "missing field `user_agent`".to_string(),
            headers,
            target_type: "RequestHeaders",
        };
        let display = error.to_string();
        assert!(display.contains("Failed to deserialize headers into type 'RequestHeaders'"));
        assert!(display.contains("missing field `user_agent`"));
        assert!(display.contains("authorization"));
        assert!(display.contains("content-type"));
    }

    #[test]
    fn test_header_error_into_handler_error() {
        let error = HeaderError::MissingHeader {
            name: "x-api-key".to_string(),
        };
        let handler_error = error.into_handler_error();
        match handler_error {
            crate::Error::Http { status_code, .. } => {
                assert_eq!(status_code, switchy_http_models::StatusCode::BadRequest);
            }
        }
    }

    #[test]
    fn test_header_error_is_std_error() {
        let error: &dyn std::error::Error = &HeaderError::MissingHeader {
            name: "test".to_string(),
        };
        assert!(error.source().is_none());
        assert!(!error.to_string().is_empty());
    }
}