use std::{
collections::HashMap,
convert::TryInto,
time::{SystemTime, UNIX_EPOCH},
};
use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use http::{HeaderMap, Method, StatusCode, header};
use serde::Serialize;
use crate::utils::http::http_date;
/// Charset applied to every freshly constructed [`HttpResponse`].
fn default_charset() -> String {
    String::from("utf-8")
}
/// Renders a `Content-Type` header value from a media type and a charset.
///
/// An empty `charset` yields the bare media type; anything else produces
/// `"<content_type>; charset=<charset>"`.
///
/// # Panics
///
/// Panics if the rendered text is not a valid header value.
fn build_content_type_value(content_type: &str, charset: &str) -> header::HeaderValue {
    let rendered = match charset {
        "" => content_type.to_owned(),
        charset => format!("{content_type}; charset={charset}"),
    };
    header::HeaderValue::from_str(&rendered).expect("valid content-type header")
}
/// Current wall-clock time as fractional seconds since the Unix epoch.
///
/// Returns `0.0` if the system clock reports a time before the epoch.
fn unix_epoch_seconds() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs_f64(),
        Err(_) => 0.0,
    }
}
/// A buffered HTTP response: status code, headers, and a fully materialized body.
///
/// `content_type` and `charset` describe the `Content-Type` header. `set_header`
/// refreshes them when `Content-Type` is set. They are also used to synthesize that
/// header on conversion into an axum response if `headers` lacks one. Writing the
/// fields directly does not update `headers`.
#[derive(Debug, Clone)]
pub struct HttpResponse {
    /// Status code sent to the client.
    pub status_code: StatusCode,
    /// Response headers; may hold repeated entries (e.g. several `Set-Cookie`).
    pub headers: HeaderMap,
    /// Raw response body.
    pub content: Bytes,
    /// Media type without parameters, e.g. `text/html`.
    pub content_type: String,
    /// Charset parameter paired with `content_type`; an empty string omits it.
    pub charset: String,
}
impl HttpResponse {
    /// Creates a `200 OK` response carrying `content`, with the default
    /// `text/html; charset=utf-8` content type.
    #[must_use]
    pub fn new(content: impl Into<Bytes>) -> Self {
        Self::with_status(StatusCode::OK, content)
    }

    /// Creates a response with the given `status` and body.
    ///
    /// The `Content-Type` header is pre-populated from `text/html` and the default
    /// charset. The `content_type`/`charset` fields are set to match it.
    #[must_use]
    pub fn with_status(status: StatusCode, content: impl Into<Bytes>) -> Self {
        let content_type = "text/html".to_string();
        let charset = default_charset();
        let mut headers = HeaderMap::new();
        headers.insert(
            header::CONTENT_TYPE,
            build_content_type_value(&content_type, &charset),
        );
        Self {
            status_code: status,
            headers,
            content: content.into(),
            content_type,
            charset,
        }
    }

    /// Sets a header, replacing every value previously stored under `key`.
    ///
    /// When `key` is `Content-Type`, the `content_type` and `charset` fields are
    /// refreshed from the new value so they stay in sync with the header.
    ///
    /// # Panics
    ///
    /// Panics if `key` cannot be converted into a valid header name, or `value`
    /// into a valid header value.
    pub fn set_header<K, V>(&mut self, key: K, value: V)
    where
        K: TryInto<header::HeaderName>,
        K::Error: std::fmt::Debug,
        V: TryInto<header::HeaderValue>,
        V::Error: std::fmt::Debug,
    {
        let key = key.try_into().expect("valid header name");
        let value = value.try_into().expect("valid header value");
        if key == header::CONTENT_TYPE {
            self.update_content_type_from_header(&value);
        }
        self.headers.insert(key, value);
    }

    /// Appends a `Set-Cookie` header for `name=value` scoped to `Path=/`.
    ///
    /// With `Some(max_age)`, `Max-Age` and a matching `Expires` date are added. A
    /// non-positive `max_age` uses the Unix epoch as `Expires`. Existing
    /// `Set-Cookie` headers are kept.
    ///
    /// # Panics
    ///
    /// Panics if `name` or `value` contain characters that are not valid in a
    /// header value; neither is escaped or validated here.
    pub fn set_cookie(&mut self, name: &str, value: &str, max_age: Option<i64>) {
        use std::fmt::Write as _;

        let mut cookie = format!("{name}={value}; Path=/");
        if let Some(max_age) = max_age {
            let expires = if max_age <= 0 {
                "Thu, 01 Jan 1970 00:00:00 GMT".to_string()
            } else {
                http_date(unix_epoch_seconds() + max_age as f64)
            };
            // Formatting into a `String` cannot fail.
            let _ = write!(cookie, "; Max-Age={max_age}; Expires={expires}");
        }
        self.headers.append(
            header::SET_COOKIE,
            header::HeaderValue::from_str(&cookie).expect("valid set-cookie header"),
        );
    }

    /// Appends a `Set-Cookie` header that clears `name`.
    ///
    /// The header carries an empty value, `Max-Age=0`, an epoch `Expires` and `Path=/`.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains characters that are not valid in a header value.
    pub fn delete_cookie(&mut self, name: &str) {
        let cookie = format!("{name}=; Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Path=/");
        self.headers.append(
            header::SET_COOKIE,
            header::HeaderValue::from_str(&cookie).expect("valid set-cookie header"),
        );
    }

    /// Re-derives `content_type` and `charset` from a `Content-Type` header value.
    ///
    /// The media type is everything before the first `;`, trimmed. The charset comes
    /// from the first `charset` parameter, with surrounding quotes removed. If no
    /// charset parameter is present, `charset` keeps its previous value.
    fn update_content_type_from_header(&mut self, value: &header::HeaderValue) {
        // Values that are not visible ASCII cannot be read as text; leave fields as-is.
        let Ok(value) = value.to_str() else {
            return;
        };
        let mut parts = value.split(';').map(str::trim);
        // `split` always yields at least one item: the (possibly empty) media type.
        if let Some(media_type) = parts.next() {
            self.content_type = media_type.to_string();
        }
        // Parameter names are case-insensitive (RFC 9110 §5.6.6), so accept
        // `Charset=` / `CHARSET=` and tolerate whitespace around `=`.
        let charset = parts.find_map(|param| {
            let (name, raw) = param.split_once('=')?;
            name.trim()
                .eq_ignore_ascii_case("charset")
                .then_some(raw.trim().trim_matches('"'))
        });
        if let Some(charset) = charset {
            self.charset = charset.to_string();
        }
    }
}
impl IntoResponse for HttpResponse {
    /// Converts into an axum [`Response`].
    ///
    /// A `Content-Type` header is synthesized from `content_type`/`charset` only
    /// when `headers` does not already carry one.
    fn into_response(self) -> Response {
        let Self {
            status_code,
            mut headers,
            content,
            content_type,
            charset,
        } = self;
        if !headers.contains_key(header::CONTENT_TYPE) {
            headers.insert(
                header::CONTENT_TYPE,
                build_content_type_value(&content_type, &charset),
            );
        }
        let mut response = content.into_response();
        *response.status_mut() = status_code;
        // Replace the map wholesale so only our headers survive, discarding
        // anything the body conversion may have set.
        *response.headers_mut() = headers;
        response
    }
}
/// A response whose body is supplied as an ordered sequence of byte chunks.
#[derive(Debug)]
pub struct StreamingHttpResponse {
    /// Numeric HTTP status code.
    pub status_code: u16,
    /// Full `Content-Type` value, parameters included.
    pub content_type: String,
    /// Extra headers keyed by name.
    pub headers: HashMap<String, String>,
    /// Body chunks, in transmission order.
    pub chunks: Vec<Vec<u8>>,
}

impl StreamingHttpResponse {
    /// Builds a `200 OK`, `text/html; charset=utf-8` response with no extra headers.
    #[must_use]
    pub fn new(chunks: Vec<Vec<u8>>) -> Self {
        let status_code = StatusCode::OK.as_u16();
        let content_type = String::from("text/html; charset=utf-8");
        Self {
            status_code,
            content_type,
            headers: HashMap::default(),
            chunks,
        }
    }

    /// Returns this response with its content type replaced by `ct`.
    #[must_use]
    pub fn with_content_type(self, ct: &str) -> Self {
        Self {
            content_type: ct.to_owned(),
            ..self
        }
    }

    /// Total number of body bytes across all chunks.
    #[must_use]
    pub fn total_size(&self) -> usize {
        self.chunks.iter().fold(0, |total, chunk| total + chunk.len())
    }
}
/// An in-memory file to be served, plus how the client should present it.
#[derive(Debug)]
pub struct FileResponse {
    /// File bytes.
    pub content: Vec<u8>,
    /// Name suggested to the client in `Content-Disposition`.
    pub filename: String,
    /// Media type of `content`; defaults to `application/octet-stream`.
    pub content_type: String,
    /// `true` for `attachment` disposition, `false` for `inline`.
    pub as_attachment: bool,
}

impl FileResponse {
    /// Creates an inline `application/octet-stream` file response.
    #[must_use]
    pub fn new(content: Vec<u8>, filename: &str) -> Self {
        Self {
            content,
            filename: filename.to_string(),
            content_type: "application/octet-stream".to_string(),
            as_attachment: false,
        }
    }

    /// Marks the file to be downloaded (`attachment`) rather than displayed inline.
    #[must_use]
    pub fn as_attachment(mut self) -> Self {
        self.as_attachment = true;
        self
    }

    /// Builds the `Content-Disposition` header value for this file.
    ///
    /// The disposition type is `attachment` when [`as_attachment`](Self::as_attachment)
    /// was requested, otherwise `inline`.
    ///
    /// A filename made only of printable ASCII is emitted as a quoted string, with `\`
    /// and `"` backslash-escaped. Any other filename (non-ASCII, or containing control
    /// characters) uses the RFC 6266 / RFC 8187 form `filename*=UTF-8''<percent-encoded>`.
    /// Either way the result is a well-formed, visible-ASCII header value.
    #[must_use]
    pub fn content_disposition(&self) -> String {
        let disposition = if self.as_attachment {
            "attachment"
        } else {
            "inline"
        };
        let printable = self
            .filename
            .bytes()
            .all(|byte| byte == b' ' || byte.is_ascii_graphic());
        if printable {
            let escaped = Self::escape_quoted_string(&self.filename);
            format!("{disposition}; filename=\"{escaped}\"")
        } else {
            let encoded = Self::percent_encode_ext_value(&self.filename);
            format!("{disposition}; filename*=UTF-8''{encoded}")
        }
    }

    /// Backslash-escapes `\` and `"` so `value` can sit inside an HTTP quoted-string.
    fn escape_quoted_string(value: &str) -> String {
        let mut escaped = String::with_capacity(value.len());
        for ch in value.chars() {
            if matches!(ch, '\\' | '"') {
                escaped.push('\\');
            }
            escaped.push(ch);
        }
        escaped
    }

    /// Percent-encodes the UTF-8 bytes of `value`, keeping only RFC 8187 `attr-char`s.
    fn percent_encode_ext_value(value: &str) -> String {
        use std::fmt::Write as _;

        let mut encoded = String::with_capacity(value.len() * 3);
        for byte in value.bytes() {
            if byte.is_ascii_alphanumeric() || b"!#$&+-.^_`|~".contains(&byte) {
                encoded.push(char::from(byte));
            } else {
                // Formatting into a `String` cannot fail.
                let _ = write!(encoded, "%{byte:02X}");
            }
        }
        encoded
    }
}
/// Constructor namespace for JSON responses; [`JsonResponse::new`] yields a plain
/// [`HttpResponse`].
pub struct JsonResponse;

impl JsonResponse {
    /// Serializes `data` to JSON and wraps it in a `200 OK` [`HttpResponse`] with
    /// `Content-Type: application/json`.
    ///
    /// # Panics
    ///
    /// Panics if `data` fails to serialize.
    // Named `new` to mirror the other response constructors, even though it
    // returns `HttpResponse` rather than `Self`.
    #[allow(clippy::new_ret_no_self)]
    #[must_use]
    pub fn new(data: impl Serialize) -> HttpResponse {
        let body = serde_json::to_vec(&data).expect("json serialization should succeed");
        let mut json_response = HttpResponse::new(body);
        // `set_header` updates both the header map and the `content_type` field.
        json_response.set_header(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static("application/json"),
        );
        json_response
    }
}
/// Builds an empty-bodied `302 Found` response whose `Location` header points at `url`.
///
/// Bytes of `url` that are not printable, non-space ASCII are percent-encoded as
/// UTF-8. This covers non-ASCII characters, spaces and control characters. IRIs
/// such as `/café` therefore become valid URI references instead of panicking when
/// the header value is built. Existing `%XX` escapes are left untouched.
#[must_use]
pub fn redirect(url: &str) -> HttpResponse {
    let mut response = HttpResponse::with_status(StatusCode::FOUND, Bytes::new());
    response.set_header(header::LOCATION, encode_location(url));
    response
}

/// Percent-encodes every byte of `url` outside the printable-ASCII range `0x21..=0x7E`.
fn encode_location(url: &str) -> String {
    use std::fmt::Write as _;

    let mut encoded = String::with_capacity(url.len());
    for byte in url.bytes() {
        if byte.is_ascii_graphic() {
            encoded.push(char::from(byte));
        } else {
            // Formatting into a `String` cannot fail.
            let _ = write!(encoded, "%{byte:02X}");
        }
    }
    encoded
}
/// Builds an empty-bodied `404 Not Found` response.
#[must_use]
pub fn not_found() -> HttpResponse {
    let empty_body = Bytes::new();
    HttpResponse::with_status(StatusCode::NOT_FOUND, empty_body)
}
/// Builds an empty-bodied `500 Internal Server Error` response.
#[must_use]
pub fn server_error() -> HttpResponse {
    let empty_body = Bytes::new();
    HttpResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, empty_body)
}
/// Builds an empty-bodied `400 Bad Request` response.
#[must_use]
pub fn bad_request() -> HttpResponse {
    let empty_body = Bytes::new();
    HttpResponse::with_status(StatusCode::BAD_REQUEST, empty_body)
}
/// Builds an empty-bodied `403 Forbidden` response.
#[must_use]
pub fn forbidden() -> HttpResponse {
    let empty_body = Bytes::new();
    HttpResponse::with_status(StatusCode::FORBIDDEN, empty_body)
}
/// Builds an empty-bodied `405 Method Not Allowed` response.
///
/// The `Allow` header lists `permitted` in the given order, separated by `", "`.
#[must_use]
pub fn method_not_allowed(permitted: &[Method]) -> HttpResponse {
    let method_names: Vec<&str> = permitted.iter().map(Method::as_str).collect();
    let allow_header = method_names.join(", ");
    let mut response = HttpResponse::with_status(StatusCode::METHOD_NOT_ALLOWED, Bytes::new());
    response.set_header(header::ALLOW, allow_header);
    response
}
/// Newtype around a `302 Found` response built by [`redirect`].
#[derive(Debug, Clone)]
pub struct HttpResponseRedirect(pub HttpResponse);

impl HttpResponseRedirect {
    /// Creates a redirect whose `Location` header points at `url`.
    #[must_use]
    pub fn new(url: &str) -> Self {
        let inner = redirect(url);
        Self(inner)
    }
}

impl IntoResponse for HttpResponseRedirect {
    fn into_response(self) -> Response {
        let Self(inner) = self;
        inner.into_response()
    }
}
/// Newtype around a `400 Bad Request` response.
#[derive(Debug, Clone)]
pub struct HttpResponseBadRequest(pub HttpResponse);

impl HttpResponseBadRequest {
    /// Creates a `400 Bad Request` response with the given body.
    #[must_use]
    pub fn new(content: impl Into<Bytes>) -> Self {
        let inner = HttpResponse::with_status(StatusCode::BAD_REQUEST, content);
        Self(inner)
    }
}

impl IntoResponse for HttpResponseBadRequest {
    fn into_response(self) -> Response {
        let Self(inner) = self;
        inner.into_response()
    }
}
/// Newtype around a `404 Not Found` response.
#[derive(Debug, Clone)]
pub struct HttpResponseNotFound(pub HttpResponse);

impl HttpResponseNotFound {
    /// Creates a `404 Not Found` response with the given body.
    #[must_use]
    pub fn new(content: impl Into<Bytes>) -> Self {
        let inner = HttpResponse::with_status(StatusCode::NOT_FOUND, content);
        Self(inner)
    }
}

impl IntoResponse for HttpResponseNotFound {
    fn into_response(self) -> Response {
        let Self(inner) = self;
        inner.into_response()
    }
}
/// Newtype around a `403 Forbidden` response.
#[derive(Debug, Clone)]
pub struct HttpResponseForbidden(pub HttpResponse);

impl HttpResponseForbidden {
    /// Creates a `403 Forbidden` response with the given body.
    #[must_use]
    pub fn new(content: impl Into<Bytes>) -> Self {
        let inner = HttpResponse::with_status(StatusCode::FORBIDDEN, content);
        Self(inner)
    }
}

impl IntoResponse for HttpResponseForbidden {
    fn into_response(self) -> Response {
        let Self(inner) = self;
        inner.into_response()
    }
}
/// Newtype around a `405 Method Not Allowed` response built by [`method_not_allowed`].
#[derive(Debug, Clone)]
pub struct HttpResponseNotAllowed(pub HttpResponse);

impl HttpResponseNotAllowed {
    /// Creates a `405` response whose `Allow` header lists `permitted`.
    #[must_use]
    pub fn new(permitted: &[Method]) -> Self {
        let inner = method_not_allowed(permitted);
        Self(inner)
    }
}

impl IntoResponse for HttpResponseNotAllowed {
    fn into_response(self) -> Response {
        let Self(inner) = self;
        inner.into_response()
    }
}
/// Newtype around a `500 Internal Server Error` response.
#[derive(Debug, Clone)]
pub struct HttpResponseServerError(pub HttpResponse);

impl HttpResponseServerError {
    /// Creates a `500 Internal Server Error` response with the given body.
    #[must_use]
    pub fn new(content: impl Into<Bytes>) -> Self {
        let inner = HttpResponse::with_status(StatusCode::INTERNAL_SERVER_ERROR, content);
        Self(inner)
    }
}

impl IntoResponse for HttpResponseServerError {
    fn into_response(self) -> Response {
        let Self(inner) = self;
        inner.into_response()
    }
}
#[cfg(test)]
mod tests {
    use axum::body::to_bytes;
    use serde_json::{Value, json};

    use super::*;

    /// Drives `response` through `IntoResponse` and collects status, headers and body.
    fn into_parts(response: impl IntoResponse) -> (StatusCode, HeaderMap, Bytes) {
        crate::runtime::block_on(async {
            let response = response.into_response();
            let status = response.status();
            let headers = response.headers().clone();
            let body = to_bytes(response.into_body(), usize::MAX)
                .await
                .expect("body should be readable");
            (status, headers, body)
        })
    }

    /// Returns header `name` as text.
    ///
    /// Panics with `"<name> header"` if the header is absent, or with
    /// `"utf-8 header value"` if it is not visible ASCII.
    fn header_str<'a>(headers: &'a HeaderMap, name: &str) -> &'a str {
        headers
            .get(name)
            .unwrap_or_else(|| panic!("{name} header"))
            .to_str()
            .expect("utf-8 header value")
    }

    #[test]
    fn http_response_new_sets_defaults() {
        let (status, headers, body) = into_parts(HttpResponse::new("hello"));
        assert_eq!(status, StatusCode::OK);
        assert_eq!(
            header_str(&headers, "content-type"),
            "text/html; charset=utf-8"
        );
        assert_eq!(body, Bytes::from_static(b"hello"));
    }

    #[test]
    fn http_response_with_status_and_set_header_work() {
        let mut response = HttpResponse::with_status(StatusCode::CREATED, "created");
        response.set_header("x-request-id", "abc123");
        response.set_header(header::CONTENT_TYPE, "text/plain; charset=utf-8");

        let (status, headers, body) = into_parts(response);
        assert_eq!(status, StatusCode::CREATED);
        assert_eq!(header_str(&headers, "x-request-id"), "abc123");
        assert_eq!(
            header_str(&headers, "content-type"),
            "text/plain; charset=utf-8"
        );
        assert_eq!(body, Bytes::from_static(b"created"));
    }

    #[test]
    fn cookie_helpers_append_set_cookie_headers() {
        let mut response = HttpResponse::new(Bytes::new());
        response.set_cookie("sessionid", "abc", Some(60));
        response.delete_cookie("sessionid");

        let cookies: Vec<String> = response
            .headers
            .get_all(header::SET_COOKIE)
            .iter()
            .map(|value| value.to_str().expect("utf-8 header value").to_owned())
            .collect();
        assert_eq!(cookies.len(), 2);
        assert!(cookies[0].starts_with("sessionid=abc; Path=/; Max-Age=60; Expires="));
        assert_eq!(
            cookies[1],
            "sessionid=; Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Path=/"
        );
    }

    #[test]
    fn json_response_serializes_data() {
        let (status, headers, body) = into_parts(JsonResponse::new(json!({"ok": true})));
        assert_eq!(status, StatusCode::OK);
        assert_eq!(header_str(&headers, "content-type"), "application/json");
        let parsed: Value = serde_json::from_slice(&body).expect("valid json body");
        assert_eq!(parsed, json!({"ok": true}));
    }

    #[test]
    fn redirect_helper_and_wrapper_set_location() {
        let (status, headers, body) = into_parts(redirect("/login"));
        assert_eq!(status, StatusCode::FOUND);
        assert_eq!(header_str(&headers, "location"), "/login");
        assert!(body.is_empty());

        let (status, headers, _) = into_parts(HttpResponseRedirect::new("/admin"));
        assert_eq!(status, StatusCode::FOUND);
        assert_eq!(header_str(&headers, "location"), "/admin");
    }

    #[test]
    fn status_helpers_and_wrappers_preserve_status_codes() {
        let helpers = [
            (not_found(), StatusCode::NOT_FOUND),
            (bad_request(), StatusCode::BAD_REQUEST),
            (forbidden(), StatusCode::FORBIDDEN),
            (server_error(), StatusCode::INTERNAL_SERVER_ERROR),
        ];
        for (response, expected_status) in helpers {
            let (status, _, body) = into_parts(response);
            assert_eq!(status, expected_status);
            assert!(body.is_empty());
        }

        let wrapped_statuses = [
            (
                into_parts(HttpResponseNotFound::new("missing")).0,
                StatusCode::NOT_FOUND,
            ),
            (
                into_parts(HttpResponseBadRequest::new("bad")).0,
                StatusCode::BAD_REQUEST,
            ),
            (
                into_parts(HttpResponseForbidden::new("nope")).0,
                StatusCode::FORBIDDEN,
            ),
            (
                into_parts(HttpResponseServerError::new("boom")).0,
                StatusCode::INTERNAL_SERVER_ERROR,
            ),
        ];
        for (actual, expected) in wrapped_statuses {
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn method_not_allowed_helper_and_wrapper_set_allow_header() {
        let permitted = [Method::GET, Method::POST];

        let (status, headers, body) = into_parts(method_not_allowed(&permitted));
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(header_str(&headers, "allow"), "GET, POST");
        assert!(body.is_empty());

        let (status, headers, _) = into_parts(HttpResponseNotAllowed::new(&permitted));
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(header_str(&headers, "allow"), "GET, POST");
    }

    #[test]
    fn streaming_http_response_defaults_to_ok_html_and_empty_headers() {
        let response = StreamingHttpResponse::new(vec![b"hello".to_vec(), b"world".to_vec()]);
        assert_eq!(response.status_code, StatusCode::OK.as_u16());
        assert_eq!(response.content_type, "text/html; charset=utf-8");
        assert!(response.headers.is_empty());
        assert_eq!(response.chunks.len(), 2);
    }

    #[test]
    fn streaming_http_response_tracks_total_size_and_content_type() {
        let response = StreamingHttpResponse::new(vec![b"chunk".to_vec(), b"ed".to_vec()])
            .with_content_type("text/plain");
        assert_eq!(response.content_type, "text/plain");
        assert_eq!(response.total_size(), 7);
    }

    #[test]
    fn file_response_defaults_to_inline_disposition() {
        let response = FileResponse::new(b"report".to_vec(), "report.txt");
        assert_eq!(response.content, b"report".to_vec());
        assert_eq!(response.filename, "report.txt");
        assert_eq!(response.content_type, "application/octet-stream");
        assert!(!response.as_attachment);
        assert_eq!(
            response.content_disposition(),
            "inline; filename=\"report.txt\""
        );
    }

    #[test]
    fn file_response_uses_attachment_disposition_when_requested() {
        let response = FileResponse::new(Vec::new(), "archive.tar.gz").as_attachment();
        assert!(response.as_attachment);
        assert_eq!(
            response.content_disposition(),
            "attachment; filename=\"archive.tar.gz\""
        );
    }
}