use std::collections::HashMap;
use std::fmt;
use crate::bytes::Bytes;
/// An HTTP status code, stored as its raw numeric value.
///
/// Any `u16` can be wrapped with [`StatusCode::from_u16`]; no range check is
/// performed, so the class predicates simply return `false` for codes outside
/// their range.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StatusCode(u16);

impl StatusCode {
    // 1xx
    pub const CONTINUE: StatusCode = StatusCode(100);
    pub const SWITCHING_PROTOCOLS: StatusCode = StatusCode(101);
    // 2xx
    pub const OK: StatusCode = StatusCode(200);
    pub const CREATED: StatusCode = StatusCode(201);
    pub const ACCEPTED: StatusCode = StatusCode(202);
    pub const NO_CONTENT: StatusCode = StatusCode(204);
    pub const PARTIAL_CONTENT: StatusCode = StatusCode(206);
    // 3xx
    pub const MOVED_PERMANENTLY: StatusCode = StatusCode(301);
    pub const FOUND: StatusCode = StatusCode(302);
    pub const SEE_OTHER: StatusCode = StatusCode(303);
    pub const NOT_MODIFIED: StatusCode = StatusCode(304);
    pub const TEMPORARY_REDIRECT: StatusCode = StatusCode(307);
    pub const PERMANENT_REDIRECT: StatusCode = StatusCode(308);
    // 4xx
    pub const BAD_REQUEST: StatusCode = StatusCode(400);
    pub const UNAUTHORIZED: StatusCode = StatusCode(401);
    pub const FORBIDDEN: StatusCode = StatusCode(403);
    pub const NOT_FOUND: StatusCode = StatusCode(404);
    pub const METHOD_NOT_ALLOWED: StatusCode = StatusCode(405);
    pub const REQUEST_TIMEOUT: StatusCode = StatusCode(408);
    pub const CONFLICT: StatusCode = StatusCode(409);
    pub const PAYLOAD_TOO_LARGE: StatusCode = StatusCode(413);
    pub const UNSUPPORTED_MEDIA_TYPE: StatusCode = StatusCode(415);
    pub const RANGE_NOT_SATISFIABLE: StatusCode = StatusCode(416);
    pub const UNPROCESSABLE_ENTITY: StatusCode = StatusCode(422);
    pub const TOO_MANY_REQUESTS: StatusCode = StatusCode(429);
    /// Non-standard 499 (presumably the nginx "client closed request" convention).
    pub const CLIENT_CLOSED_REQUEST: StatusCode = StatusCode(499);
    // 5xx
    pub const INTERNAL_SERVER_ERROR: StatusCode = StatusCode(500);
    pub const NOT_IMPLEMENTED: StatusCode = StatusCode(501);
    pub const BAD_GATEWAY: StatusCode = StatusCode(502);
    pub const SERVICE_UNAVAILABLE: StatusCode = StatusCode(503);
    pub const GATEWAY_TIMEOUT: StatusCode = StatusCode(504);

    /// Wraps an arbitrary numeric code without validation.
    #[must_use]
    pub const fn from_u16(code: u16) -> Self {
        StatusCode(code)
    }

    /// Returns the numeric value of the code.
    #[must_use]
    pub const fn as_u16(self) -> u16 {
        let StatusCode(code) = self;
        code
    }

    /// `true` for 200..=299.
    #[must_use]
    pub const fn is_success(self) -> bool {
        matches!(self.0, 200..=299)
    }

    /// `true` for 400..=499.
    #[must_use]
    pub const fn is_client_error(self) -> bool {
        matches!(self.0, 400..=499)
    }

    /// `true` for 500..=599.
    #[must_use]
    pub const fn is_server_error(self) -> bool {
        matches!(self.0, 500..=599)
    }
}

/// Formats the bare number (e.g. `200`). Width/fill flags on the outer
/// formatter are not forwarded.
impl fmt::Display for StatusCode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code = self.0;
        write!(f, "{code}")
    }
}
/// An HTTP response under construction: status, headers, cookies and body.
///
/// `Set-Cookie` values live in [`Response::set_cookies`] rather than in
/// [`Response::headers`]: a map can hold only one value per name, but every
/// cookie has to be emitted as its own header line.
#[derive(Debug, Clone)]
pub struct Response {
    /// Status code to send.
    pub status: StatusCode,
    /// Header fields other than `Set-Cookie`. Names written through the helper
    /// methods are lowercased; direct inserts may use any case, which is why
    /// the helpers compare names case-insensitively.
    pub headers: HashMap<String, String>,
    /// `Set-Cookie` values, in the order they were added.
    pub set_cookies: Vec<String>,
    /// Response body.
    pub body: Bytes,
}

impl Response {
    /// Creates a response with `status` and `body` and no headers or cookies.
    #[must_use]
    pub fn new(status: StatusCode, body: impl Into<Bytes>) -> Self {
        Self {
            status,
            headers: HashMap::with_capacity(4),
            set_cookies: Vec::new(),
            body: body.into(),
        }
    }

    /// Creates a response with `status` and an empty body.
    #[must_use]
    pub fn empty(status: StatusCode) -> Self {
        Self::new(status, Bytes::new())
    }

    /// Looks up a header value by case-insensitive name.
    ///
    /// For `set-cookie` the first queued cookie is returned. Otherwise an
    /// exact-case key wins; if only case variants exist, the variant whose key
    /// sorts first is returned, so the answer does not depend on `HashMap`
    /// iteration order.
    #[must_use]
    pub fn header_value(&self, name: &str) -> Option<&str> {
        if name.eq_ignore_ascii_case("set-cookie") {
            return self.set_cookies.first().map(String::as_str);
        }
        if let Some(value) = self.headers.get(name) {
            return Some(value.as_str());
        }
        self.headers
            .iter()
            .filter(|(key, _)| key.eq_ignore_ascii_case(name))
            .min_by(|(a, _), (b, _)| a.cmp(b))
            .map(|(_, value)| value.as_str())
    }

    /// Returns `true` if a header with this case-insensitive name exists
    /// (for `set-cookie`: if at least one cookie is queued).
    #[must_use]
    pub fn has_header(&self, name: &str) -> bool {
        // `header_value` already routes `set-cookie` to `set_cookies`.
        self.header_value(name).is_some()
    }

    /// Queues a `Set-Cookie` value after stripping bytes that are illegal in
    /// a header value (CR, LF, other controls).
    pub fn append_set_cookie(&mut self, value: impl Into<String>) {
        self.set_cookies.push(sanitize_header_value(value.into()));
    }

    /// Sets a header, replacing every case variant of `name` with a single
    /// lowercased entry.
    ///
    /// The name is reduced to HTTP token characters and the value has illegal
    /// bytes stripped. `set-cookie` is appended to [`Response::set_cookies`]
    /// instead of replacing. If no character of `name` survives sanitization
    /// the call does nothing, because an empty field name is not a legal header.
    pub fn set_header(&mut self, name: impl Into<String>, value: impl Into<String>) {
        let normalized = sanitize_header_name(name.into()).to_ascii_lowercase();
        if normalized.is_empty() {
            return;
        }
        if normalized == "set-cookie" {
            self.append_set_cookie(value.into());
            return;
        }
        let sanitized_value = sanitize_header_value(value.into());
        self.headers
            .retain(|key, _| !key.eq_ignore_ascii_case(&normalized));
        self.headers.insert(normalized, sanitized_value);
    }

    /// Ensures a header is present, keeping an existing value if there is one.
    ///
    /// Existing case variants are collapsed into one lowercased entry whose
    /// value is the one [`Response::header_value`] reports (deterministic even
    /// with several variants). When none exists, `default_value` is used.
    /// `set-cookie` is handled against [`Response::set_cookies`]. As in
    /// [`Response::set_header`], the name is sanitized before the `set-cookie`
    /// check, and a name that sanitizes to nothing is ignored.
    pub fn ensure_header(&mut self, name: &str, default_value: impl Into<String>) {
        // Sanitize first so routing matches `set_header`: a name such as
        // "Set-Cookie " must not land in `headers` as "set-cookie".
        let normalized = sanitize_header_name(name.to_owned()).to_ascii_lowercase();
        if normalized.is_empty() {
            return;
        }
        if normalized == "set-cookie" {
            if self.set_cookies.is_empty() {
                self.append_set_cookie(default_value.into());
            }
            return;
        }
        let value = self
            .header_value(&normalized)
            .map_or_else(|| default_value.into(), str::to_owned);
        self.headers
            .retain(|key, _| !key.eq_ignore_ascii_case(&normalized));
        self.headers
            .insert(normalized, sanitize_header_value(value));
    }

    /// Removes every case variant of `name` and returns one removed value.
    ///
    /// The returned value is the exact lowercase key's, if present, otherwise
    /// that of the variant whose key sorts first. For `set-cookie`, all queued
    /// cookies are dropped and the first is returned.
    pub fn remove_header(&mut self, name: &str) -> Option<String> {
        if name.eq_ignore_ascii_case("set-cookie") {
            return std::mem::take(&mut self.set_cookies).into_iter().next();
        }
        let normalized = name.to_ascii_lowercase();
        let mut matching_keys: Vec<String> = self
            .headers
            .keys()
            .filter(|key| key.eq_ignore_ascii_case(name))
            .cloned()
            .collect();
        // Exact lowercase key first (`false < true`), then lexicographic.
        // HashMap keys are unique, so an unstable sort gives the same order.
        matching_keys.sort_unstable_by(|left, right| {
            (left != &normalized, left.as_str()).cmp(&(right != &normalized, right.as_str()))
        });
        let mut removed = None;
        for key in matching_keys {
            if let Some(value) = self.headers.remove(&key) {
                removed.get_or_insert(value);
            }
        }
        removed
    }

    /// Builder form of [`Response::set_header`].
    #[must_use]
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.set_header(name, value);
        self
    }
}
/// Conversion of a handler's return value into a [`Response`].
pub trait IntoResponse {
    /// Consumes `self` and produces the response to send.
    fn into_response(self) -> Response;
}

/// A `Response` is passed through unchanged.
impl IntoResponse for Response {
    fn into_response(self) -> Response {
        self
    }
}

/// A bare status becomes an empty-bodied response with that status.
impl IntoResponse for StatusCode {
    fn into_response(self) -> Response {
        Response::empty(self)
    }
}

/// An owned string becomes a `200` plain-text UTF-8 body.
impl IntoResponse for String {
    fn into_response(self) -> Response {
        let body = Bytes::from(self);
        Response::new(StatusCode::OK, body).header("content-type", "text/plain; charset=utf-8")
    }
}

/// A static string becomes a `200` plain-text UTF-8 body without copying.
impl IntoResponse for &'static str {
    fn into_response(self) -> Response {
        let body = Bytes::from_static(self.as_bytes());
        Response::new(StatusCode::OK, body).header("content-type", "text/plain; charset=utf-8")
    }
}

/// Raw bytes become a `200` `application/octet-stream` body.
impl IntoResponse for Bytes {
    fn into_response(self) -> Response {
        let response = Response::new(StatusCode::OK, self);
        response.header("content-type", "application/octet-stream")
    }
}

/// A byte vector becomes a `200` `application/octet-stream` body.
impl IntoResponse for Vec<u8> {
    fn into_response(self) -> Response {
        let body = Bytes::from(self);
        let response = Response::new(StatusCode::OK, body);
        response.header("content-type", "application/octet-stream")
    }
}

/// Unit becomes an empty `200`.
impl IntoResponse for () {
    fn into_response(self) -> Response {
        Response::empty(StatusCode::OK)
    }
}

/// `(status, inner)`: converts `inner`, then overrides its status.
impl<T: IntoResponse> IntoResponse for (StatusCode, T) {
    fn into_response(self) -> Response {
        let (status, inner) = self;
        let mut response = inner.into_response();
        response.status = status;
        response
    }
}

/// `(status, headers, inner)`: converts `inner`, overrides its status, then
/// applies each header via `set_header` (so later entries win and names and
/// values are sanitized).
impl<T: IntoResponse> IntoResponse for (StatusCode, Vec<(String, String)>, T) {
    fn into_response(self) -> Response {
        let (status, extra_headers, inner) = self;
        let mut response = inner.into_response();
        response.status = status;
        for (name, value) in extra_headers {
            response.set_header(name, value);
        }
        response
    }
}

/// `Ok` and `Err` are each converted with their own `IntoResponse` impl.
impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
    fn into_response(self) -> Response {
        self.map_or_else(E::into_response, T::into_response)
    }
}
/// Wrapper that serializes its payload as a JSON response body.
#[derive(Debug, Clone)]
pub struct Json<T>(pub T);

/// Serializes the payload with `serde_json`. Success gives a `200` with
/// `content-type: application/json`. A serialization failure gives an
/// empty `500`; the error itself is discarded.
impl<T: serde::Serialize> IntoResponse for Json<T> {
    fn into_response(self) -> Response {
        let Json(payload) = self;
        match serde_json::to_vec(&payload) {
            Ok(encoded) => Response::new(StatusCode::OK, Bytes::from(encoded))
                .header("content-type", "application/json"),
            Err(_) => Response::empty(StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}
/// Wrapper that sends its payload as a `text/html; charset=utf-8` body.
#[derive(Debug, Clone)]
pub struct Html<T>(pub T);

/// Owned HTML: the `String` is handed to `Bytes` by value (the same
/// conversion the plain `String` impl uses) instead of being copied via
/// `copy_from_slice`.
impl IntoResponse for Html<String> {
    fn into_response(self) -> Response {
        Response::new(StatusCode::OK, Bytes::from(self.0))
            .header("content-type", "text/html; charset=utf-8")
    }
}

/// Static HTML: the body borrows the static bytes directly.
impl IntoResponse for Html<&'static str> {
    fn into_response(self) -> Response {
        Response::new(StatusCode::OK, Bytes::from_static(self.0.as_bytes()))
            .header("content-type", "text/html; charset=utf-8")
    }
}
/// Reasons a redirect target is rejected by `validate_redirect_uri`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RedirectError {
    /// The URI was the empty string.
    EmptyUri,
    /// The URI is, or may be turned into, a protocol-relative `//host` URL.
    /// This covers a literal leading `//`, a percent-encoded `/` or `\` right
    /// after the leading `/`, and any byte outside visible ASCII
    /// (0x21..=0x7E), such as whitespace, controls or non-ASCII.
    ProtocolRelative,
    /// The URI contains a `\`.
    BackslashInPath,
    /// The URI has a scheme other than `http`/`https` (lowercased in
    /// `scheme`), has no scheme at all (`scheme` is empty), or has an http(s)
    /// scheme not followed by `//`.
    SchemeNotAllowed {
        scheme: String,
    },
    /// The URI is absolute and its host is empty or not in the allowlist.
    HostNotAllowed {
        host: String,
    },
}

impl fmt::Display for RedirectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmptyUri => write!(f, "redirect URI is empty"),
            Self::ProtocolRelative => write!(
                f,
                "redirect URI starts with '//' (protocol-relative — defeats naive same-origin checks)"
            ),
            Self::BackslashInPath => write!(
                f,
                "redirect URI contains a backslash (intermediaries may normalize to '/' creating a protocol-relative URL)"
            ),
            Self::SchemeNotAllowed { scheme } => write!(
                f,
                "redirect URI scheme '{scheme}' not allowed (only 'http' and 'https')"
            ),
            Self::HostNotAllowed { host } => write!(
                f,
                "redirect URI host '{host}' not in the allowed-hosts allowlist"
            ),
        }
    }
}

impl std::error::Error for RedirectError {}

/// Checks that `uri` is a safe redirect target.
///
/// Accepted:
/// * same-origin paths beginning with a single `/` (not `//`, `/%2f`, `/%5c`);
/// * absolute `http`/`https` URLs whose host (userinfo and port removed,
///   IPv6 brackets unwrapped) case-insensitively equals an entry of
///   `allowed_hosts`. `None` means an empty allowlist, so every absolute URL
///   is rejected.
///
/// # Errors
/// Returns the [`RedirectError`] variant describing the first check that failed.
fn validate_redirect_uri(uri: &str, allowed_hosts: Option<&[&str]>) -> Result<(), RedirectError> {
    if uri.is_empty() {
        return Err(RedirectError::EmptyUri);
    }
    // Only visible ASCII is accepted. URL parsers strip leading whitespace and
    // controls and drop tabs/newlines, which would turn e.g. " //evil" into
    // "//evil", so this is reported as a protocol-relative attempt.
    if uri.bytes().any(|b| !(0x21..=0x7E).contains(&b)) {
        return Err(RedirectError::ProtocolRelative);
    }
    // `\` is treated like `/` in http(s) URLs, so "/\evil" behaves like "//evil".
    if uri.contains('\\') {
        return Err(RedirectError::BackslashInPath);
    }
    if uri.starts_with("//") {
        return Err(RedirectError::ProtocolRelative);
    }
    if let Some(path) = uri.strip_prefix('/') {
        // A decoding intermediary would turn "/%2f..." / "/%5c..." into "//..." / "/\...".
        const ENCODED_SEPARATORS: [&str; 4] = ["%2f", "%2F", "%5c", "%5C"];
        if ENCODED_SEPARATORS
            .iter()
            .any(|&prefix| path.starts_with(prefix))
        {
            return Err(RedirectError::ProtocolRelative);
        }
        return Ok(());
    }

    let (scheme, rest) = match uri.split_once(':') {
        Some((scheme, rest)) => (scheme.to_ascii_lowercase(), rest),
        None => {
            return Err(RedirectError::SchemeNotAllowed {
                scheme: String::new(),
            });
        }
    };
    if scheme != "http" && scheme != "https" {
        return Err(RedirectError::SchemeNotAllowed { scheme });
    }
    let after_slashes = match rest.strip_prefix("//") {
        Some(after) => after,
        None => return Err(RedirectError::SchemeNotAllowed { scheme }),
    };

    // The authority ends at the first '/', '?' or '#'.
    let authority = after_slashes.split(['/', '?', '#']).next().unwrap_or("");
    // Userinfo ends at the LAST '@'. Without stripping it,
    // "https://allowed.com:x@evil.com/" would be checked as "allowed.com"
    // while the browser navigates to "evil.com".
    let host_and_port = authority.rsplit_once('@').map_or(authority, |(_, h)| h);
    let host = match host_and_port.strip_prefix('[') {
        // IPv6 literal: the host is the text inside the brackets. An unclosed
        // bracket yields an empty host, which is rejected below.
        Some(bracketed) => bracketed.split_once(']').map_or("", |(inner, _)| inner),
        None => host_and_port
            .rsplit_once(':')
            .map_or(host_and_port, |(h, _)| h),
    };
    if host.is_empty() {
        return Err(RedirectError::HostNotAllowed {
            host: String::new(),
        });
    }

    let allowed_hosts = allowed_hosts.unwrap_or(&[]);
    if allowed_hosts
        .iter()
        .any(|allowed| allowed.eq_ignore_ascii_case(host))
    {
        Ok(())
    } else {
        Err(RedirectError::HostNotAllowed {
            host: host.to_string(),
        })
    }
}
/// A redirect: a 3xx status plus the `location` it points to.
///
/// The checked constructors run the target through `validate_redirect_uri`.
/// The `external_unchecked*` constructors store the target as given.
#[derive(Debug, Clone)]
pub struct Redirect {
    status: StatusCode,
    location: String,
}

impl Redirect {
    /// Shared body of the checked constructors: validate, then pair with `status`.
    fn checked(
        status: StatusCode,
        location: String,
        allowed_hosts: Option<&[&str]>,
    ) -> Result<Self, RedirectError> {
        validate_redirect_uri(&location, allowed_hosts)?;
        Ok(Self { status, location })
    }

    /// `302 Found` to a validated target. With no allowlist, only
    /// same-origin paths such as `/login` pass.
    ///
    /// # Errors
    /// Any [`RedirectError`] reported by validation.
    pub fn to(uri: impl Into<String>) -> Result<Self, RedirectError> {
        Self::checked(StatusCode::FOUND, uri.into(), None)
    }

    /// `301 Moved Permanently` to a validated same-origin path.
    ///
    /// # Errors
    /// Any [`RedirectError`] reported by validation.
    pub fn permanent(uri: impl Into<String>) -> Result<Self, RedirectError> {
        Self::checked(StatusCode::MOVED_PERMANENTLY, uri.into(), None)
    }

    /// `307 Temporary Redirect` to a validated same-origin path.
    ///
    /// # Errors
    /// Any [`RedirectError`] reported by validation.
    pub fn temporary(uri: impl Into<String>) -> Result<Self, RedirectError> {
        Self::checked(StatusCode::TEMPORARY_REDIRECT, uri.into(), None)
    }

    /// `302 Found` to a same-origin path or an http(s) URL whose host is in
    /// `allowed_hosts` (case-insensitive).
    ///
    /// # Errors
    /// Any [`RedirectError`] reported by validation.
    pub fn to_with_allowed_hosts(
        uri: impl Into<String>,
        allowed_hosts: &[&str],
    ) -> Result<Self, RedirectError> {
        Self::checked(StatusCode::FOUND, uri.into(), Some(allowed_hosts))
    }

    /// `302 Found` to `uri` with no validation. The caller is responsible
    /// for not redirecting to attacker-controlled targets.
    #[must_use]
    pub fn external_unchecked(uri: impl Into<String>) -> Self {
        let location = uri.into();
        Self {
            status: StatusCode::FOUND,
            location,
        }
    }

    /// `301 Moved Permanently` to `uri` with no validation.
    #[must_use]
    pub fn external_unchecked_permanent(uri: impl Into<String>) -> Self {
        let location = uri.into();
        Self {
            status: StatusCode::MOVED_PERMANENTLY,
            location,
        }
    }

    /// `307 Temporary Redirect` to `uri` with no validation.
    #[must_use]
    pub fn external_unchecked_temporary(uri: impl Into<String>) -> Self {
        let location = uri.into();
        Self {
            status: StatusCode::TEMPORARY_REDIRECT,
            location,
        }
    }
}
/// Builds an empty-bodied response with the redirect's status and a
/// `location` header.
///
/// The location is normalized so that only visible ASCII (0x21..=0x7E) is
/// emitted:
/// * bytes >= 0x80 are percent-encoded (`%XX`, uppercase hex), so UTF-8
///   locations from the `external_unchecked*` constructors keep their meaning
///   instead of losing characters;
/// * space, C0 controls and DEL are dropped, so they cannot inject header lines.
///
/// Validated redirects already consist only of 0x21..=0x7E, so for them this
/// is a no-op.
impl IntoResponse for Redirect {
    fn into_response(self) -> Response {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut location = String::with_capacity(self.location.len());
        for byte in self.location.bytes() {
            match byte {
                0x21..=0x7E => location.push(char::from(byte)),
                0x80..=0xFF => {
                    location.push('%');
                    location.push(char::from(HEX[usize::from(byte >> 4)]));
                    location.push(char::from(HEX[usize::from(byte & 0x0F)]));
                }
                // Space, C0 controls, DEL.
                _ => {}
            }
        }
        Response::empty(self.status).header("location", location)
    }
}
/// Removes every byte that may not appear in an HTTP field value.
///
/// Kept: HTAB, 0x20..=0x7E and all bytes >= 0x80 (so UTF-8 passes through).
/// Dropped: the other C0 controls (including CR and LF) and DEL. Only ASCII
/// characters are ever removed, so the result remains valid UTF-8 and no
/// re-validation is needed.
fn sanitize_header_value(mut value: String) -> String {
    // Non-ASCII chars consist solely of bytes >= 0x80, all of which are legal.
    value.retain(|ch| !ch.is_ascii() || is_valid_header_value_byte(ch as u8));
    value
}

/// Whether `b` may appear in an HTTP field value: HTAB, 0x20..=0x7E, or >= 0x80.
#[inline]
const fn is_valid_header_value_byte(b: u8) -> bool {
    matches!(b, 0x09 | 0x20..=0x7E | 0x80..=0xFF)
}
/// Reduces `name` to the characters allowed in an HTTP field-name token.
///
/// Keeps ASCII letters, digits and ``!#$%&'*+-.^_`|~``. Every other
/// character is dropped, including CR, LF, `:`, space and all non-ASCII.
/// Case is preserved.
fn sanitize_header_name(name: String) -> String {
    const TOKEN_PUNCTUATION: &str = "!#$%&'*+-.^_`|~";
    let mut token = String::with_capacity(name.len());
    for ch in name.chars() {
        if ch.is_ascii_alphanumeric() || TOKEN_PUNCTUATION.contains(ch) {
            token.push(ch);
        }
    }
    token
}
#[cfg(test)]
mod tests {
    #![allow(
        clippy::pedantic,
        clippy::nursery,
        clippy::expect_fun_call,
        clippy::map_unwrap_or,
        clippy::cast_possible_wrap,
        clippy::future_not_send
    )]
    use super::*;

    #[test]
    fn status_code_into_response() {
        let resp = StatusCode::NOT_FOUND.into_response();
        assert_eq!(resp.status, StatusCode::NOT_FOUND);
        assert!(resp.body.is_empty());
    }

    #[test]
    fn string_into_response() {
        let resp = "hello".into_response();
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(
            resp.headers.get("content-type").unwrap(),
            "text/plain; charset=utf-8"
        );
    }

    #[test]
    fn json_into_response() {
        let resp = Json(serde_json::json!({"ok": true})).into_response();
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(
            resp.headers.get("content-type").unwrap(),
            "application/json"
        );
        assert!(!resp.body.is_empty());
    }

    #[test]
    fn html_into_response() {
        let resp = Html("<h1>Hello</h1>").into_response();
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(
            resp.headers.get("content-type").unwrap(),
            "text/html; charset=utf-8"
        );
    }

    #[test]
    fn redirect_into_response() {
        let resp = Redirect::to("/login")
            .expect("relative path must validate")
            .into_response();
        assert_eq!(resp.status, StatusCode::FOUND);
        assert_eq!(resp.headers.get("location").unwrap(), "/login");
    }

    #[test]
    fn redirect_to_rejects_external_uri_by_default() {
        let err = Redirect::to("https://attacker.com/phish").unwrap_err();
        assert!(
            matches!(err, RedirectError::HostNotAllowed { .. }),
            "external https URL must be rejected, got {err:?}"
        );
        let err = Redirect::to("http://attacker.com").unwrap_err();
        assert!(matches!(err, RedirectError::HostNotAllowed { .. }));
        assert!(Redirect::permanent("https://attacker.com").is_err());
        assert!(Redirect::temporary("https://attacker.com").is_err());
    }

    #[test]
    fn redirect_to_rejects_protocol_relative_url() {
        let err = Redirect::to("//attacker.com/phish").unwrap_err();
        assert!(
            matches!(err, RedirectError::ProtocolRelative),
            "//... URL must be rejected as ProtocolRelative, got {err:?}"
        );
    }

    #[test]
    fn redirect_to_rejects_backslash_path() {
        let err = Redirect::to("/\\attacker.com/phish").unwrap_err();
        assert!(
            matches!(err, RedirectError::BackslashInPath),
            "backslash in path must be rejected, got {err:?}"
        );
    }

    #[test]
    fn redirect_to_rejects_non_http_schemes() {
        for uri in &[
            "javascript:alert(1)",
            "data:text/html,<script>alert(1)</script>",
            "file:///etc/passwd",
            "ftp://attacker.com/",
        ] {
            let err = Redirect::to(*uri).unwrap_err();
            assert!(
                matches!(err, RedirectError::SchemeNotAllowed { .. }),
                "{uri} must be rejected as SchemeNotAllowed, got {err:?}"
            );
        }
    }

    #[test]
    fn redirect_to_rejects_empty_uri() {
        let err = Redirect::to("").unwrap_err();
        assert!(matches!(err, RedirectError::EmptyUri));
    }

    #[test]
    fn redirect_to_accepts_well_formed_relative_paths() {
        for uri in &[
            "/",
            "/login",
            "/path/with/multiple/segments",
            "/path?with=query",
            "/path#fragment",
            "/path?next=/another",
        ] {
            assert!(
                Redirect::to(*uri).is_ok(),
                "relative path {uri} must validate"
            );
        }
    }

    #[test]
    fn redirect_to_with_allowed_hosts_accepts_listed_rejects_others() {
        let allowed = &["example.com", "auth.example.com"];
        assert!(Redirect::to_with_allowed_hosts("https://example.com/path", allowed).is_ok());
        assert!(
            Redirect::to_with_allowed_hosts(
                "https://auth.example.com/oauth/callback?code=xyz",
                allowed
            )
            .is_ok()
        );
        assert!(Redirect::to_with_allowed_hosts("HTTPS://EXAMPLE.COM/", allowed).is_ok());
        assert!(Redirect::to_with_allowed_hosts("/local-path", allowed).is_ok());
        let err =
            Redirect::to_with_allowed_hosts("https://attacker.com/phish", allowed).unwrap_err();
        assert!(matches!(err, RedirectError::HostNotAllowed { .. }));
        let err =
            Redirect::to_with_allowed_hosts("https://evil.example.com/", allowed).unwrap_err();
        assert!(matches!(err, RedirectError::HostNotAllowed { .. }));
        let err = Redirect::to_with_allowed_hosts("//example.com/path", allowed).unwrap_err();
        assert!(matches!(err, RedirectError::ProtocolRelative));
    }

    #[test]
    fn redirect_external_unchecked_accepts_arbitrary_uri() {
        let r = Redirect::external_unchecked("https://anywhere.example/path?q=1");
        assert_eq!(r.status, StatusCode::FOUND);
        assert_eq!(r.location, "https://anywhere.example/path?q=1");
        let r = Redirect::external_unchecked_permanent("https://moved.example/");
        assert_eq!(r.status, StatusCode::MOVED_PERMANENTLY);
        let r = Redirect::external_unchecked_temporary("https://temp.example/");
        assert_eq!(r.status, StatusCode::TEMPORARY_REDIRECT);
    }

    #[test]
    fn tuple_status_override() {
        let resp = (StatusCode::CREATED, "done").into_response();
        assert_eq!(resp.status, StatusCode::CREATED);
    }

    #[test]
    fn response_header_helpers_are_case_insensitive() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.headers
            .insert("Content-Type".to_string(), "text/plain".to_string());
        assert_eq!(resp.header_value("content-type"), Some("text/plain"));
        assert_eq!(resp.header_value("CONTENT-TYPE"), Some("text/plain"));
        assert!(resp.has_header("content-type"));
    }

    #[test]
    fn response_set_header_canonicalizes_existing_case_variant() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.headers
            .insert("X-Trace-Id".to_string(), "old".to_string());
        resp.set_header("x-trace-id", "new");
        assert_eq!(resp.headers.get("x-trace-id"), Some(&"new".to_string()));
        assert!(!resp.headers.contains_key("X-Trace-Id"));
    }

    #[test]
    fn response_ensure_header_preserves_existing_value_and_canonicalizes_name() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.headers
            .insert("Server".to_string(), "custom".to_string());
        resp.ensure_header("server", "fallback");
        assert_eq!(resp.headers.get("server"), Some(&"custom".to_string()));
        assert!(!resp.headers.contains_key("Server"));
    }

    #[test]
    fn response_remove_header_clears_case_variants() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.headers.insert("Server".to_string(), "one".to_string());
        resp.headers.insert("server".to_string(), "two".to_string());
        let removed = resp.remove_header("SERVER");
        assert_eq!(removed.as_deref(), Some("two"));
        assert!(!resp.has_header("server"));
        assert!(resp.headers.is_empty());
    }

    #[test]
    fn result_ok_response() {
        let resp: Result<&str, StatusCode> = Ok("success");
        let r = resp.into_response();
        assert_eq!(r.status, StatusCode::OK);
    }

    #[test]
    fn result_err_response() {
        let resp: Result<&str, StatusCode> = Err(StatusCode::BAD_REQUEST);
        let r = resp.into_response();
        assert_eq!(r.status, StatusCode::BAD_REQUEST);
    }

    #[test]
    fn status_code_properties() {
        assert!(StatusCode::OK.is_success());
        assert!(!StatusCode::OK.is_client_error());
        assert!(StatusCode::NOT_FOUND.is_client_error());
        assert!(StatusCode::INTERNAL_SERVER_ERROR.is_server_error());
    }

    #[test]
    fn status_code_debug_clone_copy_hash_display() {
        use std::collections::HashSet;
        let sc = StatusCode::OK;
        let dbg = format!("{sc:?}");
        assert!(dbg.contains("StatusCode"), "{dbg}");
        assert!(dbg.contains("200"), "{dbg}");
        let copied = sc;
        let cloned = sc;
        assert_eq!(copied, cloned);
        let display = format!("{sc}");
        assert_eq!(display, "200");
        let mut set = HashSet::new();
        set.insert(sc);
        assert!(set.contains(&StatusCode::OK));
    }

    #[test]
    fn response_debug_clone() {
        let resp = Response::new(StatusCode::OK, Bytes::from_static(b"hi"));
        let dbg = format!("{resp:?}");
        assert!(dbg.contains("Response"), "{dbg}");
        let cloned = resp.clone();
        assert_eq!(cloned.status, resp.status);
        assert_eq!(cloned.headers, resp.headers);
        assert_eq!(cloned.set_cookies, resp.set_cookies);
    }

    #[test]
    fn redirect_debug_clone() {
        let r = Redirect::to("/home").expect("relative path must validate");
        let dbg = format!("{r:?}");
        assert!(dbg.contains("Redirect"), "{dbg}");
        let cloned = r.clone();
        let dbg2 = format!("{cloned:?}");
        assert_eq!(dbg, dbg2);
    }

    #[test]
    fn set_header_strips_crlf_from_value() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.set_header("x-test", "value\r\nEvil-Header: injected");
        assert_eq!(
            resp.headers.get("x-test").unwrap(),
            "valueEvil-Header: injected"
        );
    }

    #[test]
    fn set_header_strips_bare_lf_from_value() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.set_header("x-test", "line1\nline2");
        assert_eq!(resp.headers.get("x-test").unwrap(), "line1line2");
    }

    #[test]
    fn set_header_strips_bare_cr_from_value() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.set_header("x-test", "line1\rline2");
        assert_eq!(resp.headers.get("x-test").unwrap(), "line1line2");
    }

    #[test]
    fn builder_header_strips_crlf() {
        let resp = Response::empty(StatusCode::OK).header("x-test", "safe\r\nX-Injected: oops");
        assert_eq!(resp.headers.get("x-test").unwrap(), "safeX-Injected: oops");
    }

    #[test]
    fn ensure_header_strips_crlf_from_default() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.ensure_header("x-test", "default\r\nEvil: yes");
        assert_eq!(resp.headers.get("x-test").unwrap(), "defaultEvil: yes");
    }

    #[test]
    fn tuple_headers_strip_crlf() {
        let resp = (
            StatusCode::OK,
            vec![("x-test".to_string(), "a\r\nb".to_string())],
            "body",
        )
            .into_response();
        assert_eq!(resp.headers.get("x-test").unwrap(), "ab");
    }

    #[test]
    fn set_header_strips_crlf_from_name() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.set_header("x-test\r\nEvil-Header: injected", "value");
        // ':' and ' ' are not token characters, so they are dropped along with
        // CR/LF, and the surviving name is lowercased.
        assert!(resp.headers.contains_key("x-testevil-headerinjected"));
        assert!(
            !resp
                .headers
                .keys()
                .any(|k| k.contains('\r') || k.contains('\n'))
        );
    }

    #[test]
    fn ensure_header_strips_crlf_from_name() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.ensure_header("x-test\r\nEvil:", "value");
        assert!(
            !resp
                .headers
                .keys()
                .any(|k| k.contains('\r') || k.contains('\n'))
        );
    }

    #[test]
    fn tuple_headers_strip_crlf_from_name() {
        let resp = (
            StatusCode::OK,
            vec![("x-test\r\nEvil:".to_string(), "value".to_string())],
            "body",
        )
            .into_response();
        assert!(
            !resp
                .headers
                .keys()
                .any(|k| k.contains('\r') || k.contains('\n'))
        );
    }

    #[test]
    fn clean_header_value_passes_through_unchanged() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.set_header("x-test", "normal-value");
        assert_eq!(resp.headers.get("x-test").unwrap(), "normal-value");
    }

    #[test]
    fn set_cookie_appends_instead_of_overwriting() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.set_header("set-cookie", "csrf=abc123; HttpOnly");
        resp.set_header("set-cookie", "session=def456; HttpOnly; Secure");
        assert_eq!(resp.set_cookies.len(), 2, "both cookies must survive");
        assert_eq!(resp.set_cookies[0], "csrf=abc123; HttpOnly");
        assert_eq!(resp.set_cookies[1], "session=def456; HttpOnly; Secure");
        assert!(!resp.headers.contains_key("set-cookie"));
        assert_eq!(
            resp.header_value("Set-Cookie"),
            Some("csrf=abc123; HttpOnly"),
        );
        assert!(resp.has_header("set-cookie"));
    }

    #[test]
    fn append_set_cookie_strips_crlf_from_value() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.append_set_cookie("session=abc\r\nX-Injected: yes");
        assert_eq!(resp.set_cookies.len(), 1);
        assert!(!resp.set_cookies[0].contains('\r'));
        assert!(!resp.set_cookies[0].contains('\n'));
    }

    #[test]
    fn remove_set_cookie_drains_all_queued_cookies() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.append_set_cookie("a=1");
        resp.append_set_cookie("b=2");
        let dropped = resp.remove_header("Set-Cookie");
        assert_eq!(dropped.as_deref(), Some("a=1"));
        assert!(resp.set_cookies.is_empty(), "no cookies should remain");
    }

    #[test]
    fn json_html_debug_clone() {
        let j = Json(42);
        let dbg = format!("{j:?}");
        assert!(dbg.contains("Json"), "{dbg}");
        let jc = j.clone();
        assert_eq!(format!("{jc:?}"), dbg);
        let h = Html("hello");
        let dbg2 = format!("{h:?}");
        assert!(dbg2.contains("Html"), "{dbg2}");
        let hc = h.clone();
        assert_eq!(format!("{hc:?}"), dbg2);
    }

    #[test]
    fn _5jtjo0_strips_nul_byte_from_header_value() {
        let raw = String::from("alice\u{0000}forged-header: value");
        let cleaned = sanitize_header_value(raw);
        assert!(!cleaned.contains('\u{0000}'));
        assert_eq!(cleaned, "aliceforged-header: value");
    }

    #[test]
    fn _5jtjo0_strips_c0_control_bytes() {
        // Every C0 control except HTAB, followed by printable text.
        let raw: String = (0x01u8..=0x1F)
            .filter(|b| *b != 0x09)
            .map(|b| b as char)
            .collect::<String>()
            + "trailing";
        let cleaned = sanitize_header_value(raw);
        assert_eq!(cleaned, "trailing");
    }

    #[test]
    fn _5jtjo0_preserves_htab_space_printable_ascii() {
        let raw = String::from("\tHello, World! 123 -_+=()[];,./?\\:");
        let cleaned = sanitize_header_value(raw.clone());
        assert_eq!(cleaned, raw);
    }

    #[test]
    fn _5jtjo0_preserves_obs_text_utf8_passthrough() {
        let raw = String::from("café résumé日本語");
        let cleaned = sanitize_header_value(raw.clone());
        assert_eq!(cleaned, raw);
    }

    #[test]
    fn _5jtjo0_strips_crlf_legacy_behavior_preserved() {
        let raw = String::from("first\r\nforged-header: bad");
        let cleaned = sanitize_header_value(raw);
        assert_eq!(cleaned, "firstforged-header: bad");
    }

    #[test]
    fn _5jtjo0_strips_del_byte() {
        let raw = String::from("hello\u{007F}world");
        let cleaned = sanitize_header_value(raw);
        assert_eq!(cleaned, "helloworld");
    }

    #[test]
    fn n5b94b_redirect_sanitization_matches_validation_strictness() {
        let redirect = Redirect::external_unchecked("http://example.com/path\x01\x1F");
        let response = redirect.into_response();
        let location = response.headers.get("location").unwrap();
        assert!(!location.contains('\x01'));
        assert!(!location.contains('\x1F'));
        assert_eq!(location, "http://example.com/path");
    }

    #[test]
    fn n5b94b_header_name_sanitization_consistency() {
        let mut resp = Response::new(StatusCode::OK, "test");
        resp.set_header("x-test\r\n-header\x01", "value");
        let headers: Vec<_> = resp.headers.keys().collect();
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0], "x-test-header");
    }

    #[test]
    fn n5b94b_header_case_normalization_atomic() {
        let mut resp = Response::new(StatusCode::OK, "test");
        resp.headers
            .insert("X-Test".to_string(), "value1".to_string());
        resp.headers
            .insert("x-TEST".to_string(), "value2".to_string());
        resp.headers
            .insert("X-test".to_string(), "value3".to_string());
        resp.set_header("x-test", "final");
        let test_headers: Vec<_> = resp
            .headers
            .iter()
            .filter(|(k, _)| k.eq_ignore_ascii_case("x-test"))
            .collect();
        assert_eq!(
            test_headers.len(),
            1,
            "All case variants should be removed atomically"
        );
        assert_eq!(test_headers[0].0, "x-test");
        assert_eq!(test_headers[0].1, "final");
    }

    #[test]
    fn n5b94b_ensure_header_atomic_check_and_set() {
        let mut resp = Response::new(StatusCode::OK, "test");
        resp.headers
            .insert("X-Custom".to_string(), "existing".to_string());
        resp.ensure_header("x-custom", "default");
        let custom_headers: Vec<_> = resp
            .headers
            .iter()
            .filter(|(k, _)| k.eq_ignore_ascii_case("x-custom"))
            .collect();
        assert_eq!(
            custom_headers.len(),
            1,
            "Should be exactly one header after ensure"
        );
        assert_eq!(custom_headers[0].0, "x-custom");
        assert_eq!(custom_headers[0].1, "existing");
    }

    #[test]
    fn oms1b7_rejects_protocol_relative() {
        let err = Redirect::to("//attacker.com/path").unwrap_err();
        assert!(matches!(err, RedirectError::ProtocolRelative));
    }

    #[test]
    fn oms1b7_rejects_leading_whitespace_then_protocol_relative() {
        let err = Redirect::to(" //attacker.com").unwrap_err();
        assert!(matches!(err, RedirectError::ProtocolRelative), "{err:?}");
    }

    #[test]
    fn oms1b7_rejects_leading_tab_then_protocol_relative() {
        let err = Redirect::to("\t//attacker.com").unwrap_err();
        assert!(matches!(err, RedirectError::ProtocolRelative), "{err:?}");
    }

    #[test]
    fn oms1b7_rejects_leading_crlf() {
        let err = Redirect::to("\r\n//attacker.com").unwrap_err();
        assert!(matches!(err, RedirectError::ProtocolRelative), "{err:?}");
    }

    #[test]
    fn oms1b7_rejects_percent_encoded_double_slash() {
        let err = Redirect::to("/%2fattacker.com").unwrap_err();
        assert!(matches!(err, RedirectError::ProtocolRelative), "{err:?}");
        let err = Redirect::to("/%2Fattacker.com").unwrap_err();
        assert!(matches!(err, RedirectError::ProtocolRelative), "{err:?}");
    }

    #[test]
    fn oms1b7_rejects_percent_encoded_backslash_after_slash() {
        let err = Redirect::to("/%5cattacker.com").unwrap_err();
        assert!(matches!(err, RedirectError::ProtocolRelative), "{err:?}");
    }

    #[test]
    fn oms1b7_accepts_legitimate_relative_paths() {
        assert!(Redirect::to("/login").is_ok());
        assert!(Redirect::to("/api/v1/foo?x=1&y=2").is_ok());
        assert!(Redirect::to("/path#anchor").is_ok());
    }

    #[test]
    fn oms1b7_rejects_null_byte_in_uri() {
        let err = Redirect::to("/safe\u{0000}//attacker.com").unwrap_err();
        assert!(matches!(err, RedirectError::ProtocolRelative), "{err:?}");
    }
}