use std::collections::HashMap;
use std::fmt;
use crate::bytes::Bytes;
/// An HTTP status code.
///
/// This is a thin newtype over the raw numeric code. No range validation is
/// performed, so any `u16` (including non-standard codes) can be represented.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StatusCode(u16);
impl StatusCode {
    // 1xx informational
    pub const CONTINUE: Self = Self(100);
    pub const SWITCHING_PROTOCOLS: Self = Self(101);
    // 2xx success
    pub const OK: Self = Self(200);
    pub const CREATED: Self = Self(201);
    pub const ACCEPTED: Self = Self(202);
    pub const NO_CONTENT: Self = Self(204);
    // 3xx redirection
    pub const MOVED_PERMANENTLY: Self = Self(301);
    pub const FOUND: Self = Self(302);
    pub const SEE_OTHER: Self = Self(303);
    pub const NOT_MODIFIED: Self = Self(304);
    pub const TEMPORARY_REDIRECT: Self = Self(307);
    pub const PERMANENT_REDIRECT: Self = Self(308);
    // 4xx client errors
    pub const BAD_REQUEST: Self = Self(400);
    pub const UNAUTHORIZED: Self = Self(401);
    pub const FORBIDDEN: Self = Self(403);
    pub const NOT_FOUND: Self = Self(404);
    pub const METHOD_NOT_ALLOWED: Self = Self(405);
    pub const CONFLICT: Self = Self(409);
    pub const PAYLOAD_TOO_LARGE: Self = Self(413);
    pub const UNSUPPORTED_MEDIA_TYPE: Self = Self(415);
    pub const UNPROCESSABLE_ENTITY: Self = Self(422);
    pub const TOO_MANY_REQUESTS: Self = Self(429);
    /// Non-standard code (popularised by nginx) for a client that hung up early.
    pub const CLIENT_CLOSED_REQUEST: Self = Self(499);
    // 5xx server errors
    pub const INTERNAL_SERVER_ERROR: Self = Self(500);
    pub const NOT_IMPLEMENTED: Self = Self(501);
    pub const BAD_GATEWAY: Self = Self(502);
    pub const SERVICE_UNAVAILABLE: Self = Self(503);
    pub const GATEWAY_TIMEOUT: Self = Self(504);

    /// Wraps an arbitrary numeric code without validation.
    #[must_use]
    pub const fn from_u16(code: u16) -> Self {
        Self(code)
    }

    /// Returns the raw numeric code.
    #[must_use]
    pub const fn as_u16(self) -> u16 {
        let Self(code) = self;
        code
    }

    /// `true` for codes in `200..=299`.
    #[must_use]
    pub const fn is_success(self) -> bool {
        matches!(self.0, 200..=299)
    }

    /// `true` for codes in `400..=499`.
    #[must_use]
    pub const fn is_client_error(self) -> bool {
        matches!(self.0, 400..=499)
    }

    /// `true` for codes in `500..=599`.
    #[must_use]
    pub const fn is_server_error(self) -> bool {
        matches!(self.0, 500..=599)
    }
}
impl fmt::Display for StatusCode {
    /// Writes the bare numeric code (e.g. `404`). Width/fill flags of the
    /// outer formatter are intentionally not forwarded.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code = self.as_u16();
        write!(f, "{}", code)
    }
}
/// An HTTP response: status, header map and body.
///
/// Header names written through [`Response::set_header`],
/// [`Response::ensure_header`] and [`Response::header`] are stored lowercased,
/// with CR and LF stripped from both name and value. `headers` is public, so
/// entries inserted directly bypass that normalization; the lookup helpers
/// therefore compare names ASCII case-insensitively.
#[derive(Debug, Clone)]
pub struct Response {
    /// Status code of the response.
    pub status: StatusCode,
    /// Header map keyed by header name.
    pub headers: HashMap<String, String>,
    /// Response payload.
    pub body: Bytes,
}
impl Response {
    /// Creates a response with `status`, `body` and no headers.
    #[must_use]
    pub fn new(status: StatusCode, body: impl Into<Bytes>) -> Self {
        Self {
            status,
            // Small pre-allocation: typical responses carry only a few headers.
            headers: HashMap::with_capacity(4),
            body: body.into(),
        }
    }

    /// Creates a response with `status` and an empty body.
    #[must_use]
    pub fn empty(status: StatusCode) -> Self {
        Self::new(status, Bytes::new())
    }

    /// Returns the value of header `name`, compared ASCII case-insensitively.
    ///
    /// A key equal to `name` byte-for-byte is preferred. Otherwise, when several
    /// case variants exist (only possible via direct `headers` mutation), the
    /// lexicographically smallest key wins so the result does not depend on
    /// `HashMap` iteration order.
    #[must_use]
    pub fn header_value(&self, name: &str) -> Option<&str> {
        if let Some(value) = self.headers.get(name) {
            return Some(value.as_str());
        }
        self.headers
            .iter()
            .filter(|(key, _)| key.eq_ignore_ascii_case(name))
            .min_by(|(a, _), (b, _)| a.cmp(b))
            .map(|(_, value)| value.as_str())
    }

    /// Returns `true` if a header named `name` (any ASCII case) is present.
    #[must_use]
    pub fn has_header(&self, name: &str) -> bool {
        self.header_value(name).is_some()
    }

    /// Sets header `name` to `value`, replacing every existing case variant.
    ///
    /// The name is stored lowercased. CR and LF are stripped from name and
    /// value so callers cannot smuggle additional header lines.
    pub fn set_header(&mut self, name: impl Into<String>, value: impl Into<String>) {
        let normalized = sanitize_header_name(name.into()).to_ascii_lowercase();
        // Drop every spelling of this header (the canonical key is re-inserted
        // below), so the map never holds two case variants of one name.
        self.headers
            .retain(|key, _| !key.eq_ignore_ascii_case(&normalized));
        self.headers
            .insert(normalized, sanitize_header_value(value.into()));
    }

    /// Inserts header `name` with `default_value` unless it is already present.
    ///
    /// If a header with the same sanitized name (any ASCII case) exists, its
    /// value is kept, re-sanitized, and stored under the canonical lowercase
    /// name; other case variants are removed.
    pub fn ensure_header(&mut self, name: &str, default_value: impl Into<String>) {
        let normalized = sanitize_header_name(name.to_owned()).to_ascii_lowercase();
        // Search with the *sanitized* name: stored keys never contain CR/LF, so
        // looking up the raw `name` would miss an existing header and the
        // default would silently overwrite it.
        let value = match self.remove_header(&normalized) {
            Some(existing) => existing,
            None => default_value.into(),
        };
        self.headers
            .insert(normalized, sanitize_header_value(value));
    }

    /// Removes every case variant of header `name` and returns one removed value.
    ///
    /// The value stored under the all-lowercase key is preferred; otherwise the
    /// value of the lexicographically smallest matching key is returned.
    /// Returns `None` when no variant was present.
    pub fn remove_header(&mut self, name: &str) -> Option<String> {
        let canonical = name.to_ascii_lowercase();
        let mut matching_keys: Vec<String> = self
            .headers
            .keys()
            .filter(|key| key.eq_ignore_ascii_case(name))
            .cloned()
            .collect();
        // Canonical key first, then byte order; keys are unique so an unstable
        // sort is deterministic here.
        matching_keys.sort_unstable_by(|left, right| {
            (left != &canonical, left).cmp(&(right != &canonical, right))
        });
        let mut removed = None;
        for key in matching_keys {
            if let Some(value) = self.headers.remove(&key) {
                removed.get_or_insert(value);
            }
        }
        removed
    }

    /// Builder-style form of [`Response::set_header`].
    #[must_use]
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.set_header(name, value);
        self
    }
}
/// Conversion of a handler's return value into a [`Response`].
pub trait IntoResponse {
    /// Consumes `self` and produces the response to send.
    fn into_response(self) -> Response;
}
impl IntoResponse for Response {
    /// A response is already a response; it is returned unchanged.
    fn into_response(self) -> Response {
        let unchanged = self;
        unchanged
    }
}
impl IntoResponse for StatusCode {
    /// A bare status code becomes a body-less response with that status.
    fn into_response(self) -> Response {
        let status = self;
        Response::empty(status)
    }
}
impl IntoResponse for String {
    /// Owned text is served as `200 OK` UTF-8 plain text.
    fn into_response(self) -> Response {
        let body = Bytes::from(self);
        let mut response = Response::new(StatusCode::OK, body);
        response.set_header("content-type", "text/plain; charset=utf-8");
        response
    }
}
impl IntoResponse for &'static str {
    /// Static text is served as `200 OK` UTF-8 plain text without copying.
    fn into_response(self) -> Response {
        let body = Bytes::from_static(self.as_bytes());
        let mut response = Response::new(StatusCode::OK, body);
        response.set_header("content-type", "text/plain; charset=utf-8");
        response
    }
}
impl IntoResponse for Bytes {
    /// Raw bytes are served as `200 OK` `application/octet-stream`.
    fn into_response(self) -> Response {
        let mut response = Response::new(StatusCode::OK, self);
        response.set_header("content-type", "application/octet-stream");
        response
    }
}
impl IntoResponse for Vec<u8> {
    /// An owned byte buffer is served as `200 OK` `application/octet-stream`.
    fn into_response(self) -> Response {
        let body = Bytes::from(self);
        let mut response = Response::new(StatusCode::OK, body);
        response.set_header("content-type", "application/octet-stream");
        response
    }
}
impl IntoResponse for () {
    /// Unit maps to an empty `200 OK` response.
    fn into_response(self) -> Response {
        let status = StatusCode::OK;
        Response::empty(status)
    }
}
impl<T: IntoResponse> IntoResponse for (StatusCode, T) {
fn into_response(self) -> Response {
let mut resp = self.1.into_response();
resp.status = self.0;
resp
}
}
impl<T: IntoResponse> IntoResponse for (StatusCode, Vec<(String, String)>, T) {
fn into_response(self) -> Response {
let mut resp = self.2.into_response();
resp.status = self.0;
for (k, v) in self.1 {
resp.set_header(k, v);
}
resp
}
}
impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
    /// Delegates to whichever side of the `Result` is present.
    fn into_response(self) -> Response {
        self.map_or_else(E::into_response, T::into_response)
    }
}
/// Wrapper that serializes its contents as a JSON response body.
#[derive(Debug, Clone)]
pub struct Json<T>(pub T);
impl<T: serde::Serialize> IntoResponse for Json<T> {
    /// Serializes the payload with `serde_json`.
    ///
    /// On success the body is served as `200 OK` `application/json`; if
    /// serialization fails the error is discarded and an empty
    /// `500 Internal Server Error` is returned instead.
    fn into_response(self) -> Response {
        match serde_json::to_vec(&self.0) {
            Ok(encoded) => {
                let mut response = Response::new(StatusCode::OK, Bytes::from(encoded));
                response.set_header("content-type", "application/json");
                response
            }
            Err(_) => Response::empty(StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}
/// Wrapper that serves its contents as a `text/html` response body.
#[derive(Debug, Clone)]
pub struct Html<T>(pub T);
impl IntoResponse for Html<String> {
    /// Serves owned markup as `200 OK` UTF-8 `text/html`.
    ///
    /// The `String`'s buffer is moved into the body (as the plain `String`
    /// impl does) rather than copied byte-for-byte.
    fn into_response(self) -> Response {
        Response::new(StatusCode::OK, Bytes::from(self.0))
            .header("content-type", "text/html; charset=utf-8")
    }
}
impl IntoResponse for Html<&'static str> {
fn into_response(self) -> Response {
Response::new(StatusCode::OK, Bytes::from_static(self.0.as_bytes()))
.header("content-type", "text/html; charset=utf-8")
}
}
/// A body-less response that points the client at another URI.
#[derive(Debug, Clone)]
pub struct Redirect {
    status: StatusCode,
    location: String,
}
impl Redirect {
    /// Shared constructor used by the public redirect flavours.
    fn with_status(status: StatusCode, uri: impl Into<String>) -> Self {
        let location = uri.into();
        Self { status, location }
    }

    /// `302 Found` redirect to `uri`.
    #[must_use]
    pub fn to(uri: impl Into<String>) -> Self {
        Self::with_status(StatusCode::FOUND, uri)
    }

    /// `301 Moved Permanently` redirect to `uri`.
    #[must_use]
    pub fn permanent(uri: impl Into<String>) -> Self {
        Self::with_status(StatusCode::MOVED_PERMANENTLY, uri)
    }

    /// `307 Temporary Redirect` redirect to `uri`.
    #[must_use]
    pub fn temporary(uri: impl Into<String>) -> Self {
        Self::with_status(StatusCode::TEMPORARY_REDIRECT, uri)
    }
}
impl IntoResponse for Redirect {
    /// Builds an empty-bodied response carrying the redirect status and a
    /// `location` header.
    fn into_response(self) -> Response {
        // `header` -> `set_header` already strips CR/LF from the value, so the
        // location is passed through directly instead of being copied and
        // sanitized a second time.
        Response::empty(self.status).header("location", self.location)
    }
}
/// Removes every CR and LF character from a header value, preventing header
/// injection / response splitting. Other characters are left untouched.
fn sanitize_header_value(value: String) -> String {
    let mut cleaned = value;
    cleaned.retain(|ch| !matches!(ch, '\r' | '\n'));
    cleaned
}
/// Removes every CR and LF character from a header name. The input is
/// returned as-is (no reallocation) when it contains neither.
fn sanitize_header_name(name: String) -> String {
    let is_line_break = |ch: char| ch == '\r' || ch == '\n';
    if !name.contains(is_line_break) {
        return name;
    }
    name.chars().filter(|&ch| !is_line_break(ch)).collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn status_code_into_response() {
        let resp = StatusCode::NOT_FOUND.into_response();
        assert_eq!(resp.status, StatusCode::NOT_FOUND);
        assert!(resp.body.is_empty());
    }
    #[test]
    fn string_into_response() {
        let resp = "hello".into_response();
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(
            resp.headers.get("content-type").unwrap(),
            "text/plain; charset=utf-8"
        );
    }
    #[test]
    fn json_into_response() {
        let resp = Json(serde_json::json!({"ok": true})).into_response();
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(
            resp.headers.get("content-type").unwrap(),
            "application/json"
        );
        assert!(!resp.body.is_empty());
    }
    #[test]
    fn html_into_response() {
        let resp = Html("<h1>Hello</h1>").into_response();
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(
            resp.headers.get("content-type").unwrap(),
            "text/html; charset=utf-8"
        );
    }
    #[test]
    fn redirect_into_response() {
        let resp = Redirect::to("/login").into_response();
        assert_eq!(resp.status, StatusCode::FOUND);
        assert_eq!(resp.headers.get("location").unwrap(), "/login");
    }
    #[test]
    fn tuple_status_override() {
        let resp = (StatusCode::CREATED, "done").into_response();
        assert_eq!(resp.status, StatusCode::CREATED);
    }
    #[test]
    fn response_header_helpers_are_case_insensitive() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.headers
            .insert("Content-Type".to_string(), "text/plain".to_string());
        assert_eq!(resp.header_value("content-type"), Some("text/plain"));
        assert_eq!(resp.header_value("CONTENT-TYPE"), Some("text/plain"));
        assert!(resp.has_header("content-type"));
    }
    #[test]
    fn response_set_header_canonicalizes_existing_case_variant() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.headers
            .insert("X-Trace-Id".to_string(), "old".to_string());
        resp.set_header("x-trace-id", "new");
        assert_eq!(resp.headers.get("x-trace-id"), Some(&"new".to_string()));
        assert!(!resp.headers.contains_key("X-Trace-Id"));
    }
    #[test]
    fn response_ensure_header_preserves_existing_value_and_canonicalizes_name() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.headers
            .insert("Server".to_string(), "custom".to_string());
        resp.ensure_header("server", "fallback");
        assert_eq!(resp.headers.get("server"), Some(&"custom".to_string()));
        assert!(!resp.headers.contains_key("Server"));
    }
    #[test]
    fn response_remove_header_clears_case_variants() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.headers.insert("Server".to_string(), "one".to_string());
        resp.headers.insert("server".to_string(), "two".to_string());
        let removed = resp.remove_header("SERVER");
        assert_eq!(removed.as_deref(), Some("two"));
        assert!(!resp.has_header("server"));
        assert!(resp.headers.is_empty());
    }
    #[test]
    fn result_ok_response() {
        let resp: Result<&str, StatusCode> = Ok("success");
        let r = resp.into_response();
        assert_eq!(r.status, StatusCode::OK);
    }
    #[test]
    fn result_err_response() {
        let resp: Result<&str, StatusCode> = Err(StatusCode::BAD_REQUEST);
        let r = resp.into_response();
        assert_eq!(r.status, StatusCode::BAD_REQUEST);
    }
    #[test]
    fn status_code_properties() {
        assert!(StatusCode::OK.is_success());
        assert!(!StatusCode::OK.is_client_error());
        assert!(StatusCode::NOT_FOUND.is_client_error());
        assert!(StatusCode::INTERNAL_SERVER_ERROR.is_server_error());
    }
    #[test]
    fn status_code_debug_clone_copy_hash_display() {
        use std::collections::HashSet;
        let sc = StatusCode::OK;
        let dbg = format!("{sc:?}");
        assert!(dbg.contains("StatusCode"), "{dbg}");
        assert!(dbg.contains("200"), "{dbg}");
        let copied = sc;
        // UFCS call so the `Clone` impl is exercised explicitly (a plain
        // `let` binding only uses `Copy`).
        let cloned = Clone::clone(&sc);
        assert_eq!(copied, cloned);
        let display = format!("{sc}");
        assert_eq!(display, "200");
        let mut set = HashSet::new();
        set.insert(sc);
        assert!(set.contains(&StatusCode::OK));
    }
    #[test]
    fn response_debug_clone() {
        let resp = Response::new(StatusCode::OK, Bytes::from_static(b"hi")).header("x-a", "1");
        let dbg = format!("{resp:?}");
        assert!(dbg.contains("Response"), "{dbg}");
        let cloned = resp.clone();
        assert_eq!(cloned.status, resp.status);
        assert_eq!(cloned.headers, resp.headers);
    }
    #[test]
    fn redirect_debug_clone() {
        let r = Redirect::to("/home");
        let dbg = format!("{r:?}");
        assert!(dbg.contains("Redirect"), "{dbg}");
        let cloned = r.clone();
        assert_eq!(format!("{cloned:?}"), format!("{r:?}"));
    }
    #[test]
    fn set_header_strips_crlf_from_value() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.set_header("x-test", "value\r\nEvil-Header: injected");
        assert_eq!(
            resp.headers.get("x-test").unwrap(),
            "valueEvil-Header: injected"
        );
    }
    #[test]
    fn set_header_strips_bare_lf_from_value() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.set_header("x-test", "line1\nline2");
        assert_eq!(resp.headers.get("x-test").unwrap(), "line1line2");
    }
    #[test]
    fn set_header_strips_bare_cr_from_value() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.set_header("x-test", "line1\rline2");
        assert_eq!(resp.headers.get("x-test").unwrap(), "line1line2");
    }
    #[test]
    fn builder_header_strips_crlf() {
        let resp = Response::empty(StatusCode::OK).header("x-test", "safe\r\nX-Injected: oops");
        assert_eq!(resp.headers.get("x-test").unwrap(), "safeX-Injected: oops");
    }
    #[test]
    fn ensure_header_strips_crlf_from_default() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.ensure_header("x-test", "default\r\nEvil: yes");
        assert_eq!(resp.headers.get("x-test").unwrap(), "defaultEvil: yes");
    }
    #[test]
    fn tuple_headers_strip_crlf() {
        let resp = (
            StatusCode::OK,
            vec![("x-test".to_string(), "a\r\nb".to_string())],
            "body",
        )
            .into_response();
        assert_eq!(resp.headers.get("x-test").unwrap(), "ab");
    }
    #[test]
    fn set_header_strips_crlf_from_name() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.set_header("x-test\r\nEvil-Header: injected", "value");
        assert!(resp.headers.contains_key("x-testevil-header: injected"));
        assert!(
            !resp
                .headers
                .keys()
                .any(|k| k.contains('\r') || k.contains('\n'))
        );
    }
    #[test]
    fn ensure_header_strips_crlf_from_name() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.ensure_header("x-test\r\nEvil:", "value");
        assert!(
            !resp
                .headers
                .keys()
                .any(|k| k.contains('\r') || k.contains('\n'))
        );
    }
    #[test]
    fn tuple_headers_strip_crlf_from_name() {
        let resp = (
            StatusCode::OK,
            vec![("x-test\r\nEvil:".to_string(), "value".to_string())],
            "body",
        )
            .into_response();
        assert!(
            !resp
                .headers
                .keys()
                .any(|k| k.contains('\r') || k.contains('\n'))
        );
    }
    #[test]
    fn clean_header_value_passes_through_unchanged() {
        let mut resp = Response::empty(StatusCode::OK);
        resp.set_header("x-test", "normal-value");
        assert_eq!(resp.headers.get("x-test").unwrap(), "normal-value");
    }
    #[test]
    fn json_html_debug_clone() {
        let j = Json(42);
        let dbg = format!("{j:?}");
        assert!(dbg.contains("Json"), "{dbg}");
        let jc = j.clone();
        assert_eq!(format!("{jc:?}"), format!("{j:?}"));
        let h = Html("hello");
        let dbg2 = format!("{h:?}");
        assert!(dbg2.contains("Html"), "{dbg2}");
        let hc = h.clone();
        assert_eq!(format!("{hc:?}"), dbg2);
    }
}