use bytes::Bytes;
use futures::stream::BoxStream;
use http::{HeaderMap, HeaderValue, StatusCode};
use tokio::io::AsyncReadExt;
/// An HTTP response described by its status code, headers and body.
///
/// All fields are public, so a response can be adjusted after it has been
/// built by one of the constructors.
pub struct PingoraWebHttpResponse {
/// Status code of the response.
pub status: StatusCode,
/// Headers of the response.
pub headers: HeaderMap,
/// Payload of the response, either held in memory or streamed.
pub body: Body,
}
impl PingoraWebHttpResponse {
pub fn new(status: StatusCode) -> Self {
Self {
status,
headers: HeaderMap::new(),
body: Body::Bytes(Bytes::new()),
}
}
pub fn text<S: Into<String>>(status: StatusCode, body: S) -> Self {
let mut res = Self::new(status);
res.headers.insert(
http::header::CONTENT_TYPE,
HeaderValue::from_static("text/plain; charset=utf-8"),
);
let bytes = body.into().into_bytes();
res.body = Body::Bytes(Bytes::from(bytes));
res
}
pub fn empty(status: StatusCode) -> Self {
let mut res = Self::new(status);
res.body = Body::Bytes(Bytes::new());
res
}
pub fn html<S: Into<String>>(status: StatusCode, body: S) -> Self {
let mut res = Self::new(status);
res.headers.insert(
http::header::CONTENT_TYPE,
HeaderValue::from_static("text/html; charset=utf-8"),
);
let bytes = body.into().into_bytes();
res.body = Body::Bytes(Bytes::from(bytes));
res
}
pub fn bytes(status: StatusCode, body: impl Into<Bytes>) -> Self {
let mut res = Self::new(status);
res.body = Body::Bytes(body.into());
res
}
pub fn json(status: StatusCode, value: impl serde::Serialize) -> Self {
let mut res = Self::new(status);
res.headers.insert(
http::header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
match serde_json::to_vec(&value) {
Ok(bytes) => {
res.body = Body::Bytes(Bytes::from(bytes));
res
}
Err(_) => {
res.status = StatusCode::INTERNAL_SERVER_ERROR;
res.body = Body::Bytes(Bytes::new());
res
}
}
}
pub fn stream_file<P: AsRef<std::path::Path>>(status: StatusCode, path: P) -> Self {
let mut res = Self::new(status);
let ct = mime_guess::from_path(path.as_ref()).first_or_octet_stream();
let content_type = {
let s = ct.to_string();
if s.starts_with("text/") {
format!("{}; charset=utf-8", s)
} else {
s
}
};
let _ = res.headers.insert(
http::header::CONTENT_TYPE,
HeaderValue::from_str(&content_type)
.unwrap_or(HeaderValue::from_static("application/octet-stream")),
);
if let Ok(meta) = std::fs::metadata(path.as_ref()) {
let len_s = meta.len().to_string();
let _ = res.headers.insert(
http::header::CONTENT_LENGTH,
HeaderValue::from_str(&len_s).unwrap_or(HeaderValue::from_static("0")),
);
}
let pathbuf = path.as_ref().to_path_buf();
let stream = futures::stream::unfold(
Some((None::<tokio::fs::File>, pathbuf)),
|state| async move {
let (opt_file, path) = state?;
let mut file = match opt_file {
Some(f) => f,
None => match tokio::fs::File::open(&path).await {
Ok(f) => f,
Err(_) => return None,
},
};
let mut buf = vec![0u8; 64 * 1024];
match file.read(&mut buf).await {
Ok(0) => None,
Ok(n) => {
buf.truncate(n);
Some((Bytes::from(buf), Some((Some(file), path))))
}
Err(_) => None,
}
},
);
res.body = Body::Stream(Box::pin(stream));
res
}
pub fn stream(status: StatusCode, stream: BoxStream<'static, Bytes>) -> Self {
let mut res = Self::new(status);
res.body = Body::Stream(stream);
res
}
pub fn set_header<K, V>(&mut self, k: K, v: V)
where
K: TryInto<http::HeaderName>,
V: TryInto<HeaderValue>,
K::Error: std::fmt::Debug,
V::Error: std::fmt::Debug,
{
if let (Ok(key), Ok(value)) = (k.try_into(), v.try_into()) {
self.headers.insert(key, value);
}
}
pub fn header<K, V>(mut self, k: K, v: V) -> Self
where
K: TryInto<http::HeaderName>,
V: TryInto<HeaderValue>,
K::Error: std::fmt::Debug,
V::Error: std::fmt::Debug,
{
self.set_header(k, v);
self
}
pub fn ok<S: Into<String>>(body: S) -> Self {
Self::text(StatusCode::OK, body)
}
pub fn created(value: impl serde::Serialize) -> Self {
Self::json(StatusCode::CREATED, value)
}
pub fn no_content() -> Self {
Self::empty(StatusCode::NO_CONTENT)
}
pub fn bad_request<S: Into<String>>(message: S) -> Self {
Self::json(
StatusCode::BAD_REQUEST,
serde_json::json!({
"error": "Bad Request",
"message": message.into()
}),
)
}
pub fn unauthorized<S: Into<String>>(message: S) -> Self {
Self::json(
StatusCode::UNAUTHORIZED,
serde_json::json!({
"error": "Unauthorized",
"message": message.into()
}),
)
}
pub fn forbidden<S: Into<String>>(message: S) -> Self {
Self::json(
StatusCode::FORBIDDEN,
serde_json::json!({
"error": "Forbidden",
"message": message.into()
}),
)
}
pub fn not_found<S: Into<String>>(message: S) -> Self {
Self::json(
StatusCode::NOT_FOUND,
serde_json::json!({
"error": "Not Found",
"message": message.into()
}),
)
}
pub fn internal_error<S: Into<String>>(message: S) -> Self {
Self::json(
StatusCode::INTERNAL_SERVER_ERROR,
serde_json::json!({
"error": "Internal Server Error",
"message": message.into()
}),
)
}
pub fn redirect<S: Into<String>>(url: S, permanent: bool) -> Self {
let status = if permanent {
StatusCode::MOVED_PERMANENTLY
} else {
StatusCode::FOUND
};
Self::empty(status).header("Location", url.into())
}
pub fn redirect_to<S: Into<String>>(url: S) -> Self {
Self::redirect(url, false)
}
pub fn redirect_permanent<S: Into<String>>(url: S) -> Self {
Self::redirect(url, true)
}
}
/// Payload of a [`PingoraWebHttpResponse`].
pub enum Body {
/// A complete payload held in memory.
Bytes(Bytes),
/// A payload delivered as a stream of `Bytes` chunks.
Stream(BoxStream<'static, Bytes>),
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Reads the `Content-Type` header of `res` as text, if present and printable.
    fn content_type(res: &PingoraWebHttpResponse) -> Option<&str> {
        res.headers
            .get(http::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok())
    }

    /// `json` sets the status, the JSON content type and the serialized body.
    #[test]
    fn json_builds_response() {
        let payload = json!({"a": 1, "b": "x"});
        let response = PingoraWebHttpResponse::json(StatusCode::OK, &payload);

        assert_eq!(response.status.as_u16(), 200);
        assert_eq!(content_type(&response), Some("application/json"));

        let expected = serde_json::to_vec(&payload).unwrap();
        match response.body {
            Body::Bytes(actual) => assert_eq!(actual.as_ref(), expected.as_slice()),
            Body::Stream(_) => panic!("expected bytes body"),
        }
    }

    /// `html`, `empty` and `bytes` never add a `Content-Length` header themselves.
    #[test]
    fn html_and_empty_and_bytes() {
        let page = PingoraWebHttpResponse::html(StatusCode::OK, "<h1>ok</h1>");
        assert_eq!(page.status.as_u16(), 200);
        assert_eq!(content_type(&page), Some("text/html; charset=utf-8"));
        assert!(!page.headers.contains_key(http::header::CONTENT_LENGTH));

        let blank = PingoraWebHttpResponse::empty(StatusCode::NO_CONTENT);
        assert_eq!(blank.status.as_u16(), 204);
        assert!(!blank.headers.contains_key(http::header::CONTENT_LENGTH));

        let raw = PingoraWebHttpResponse::bytes(StatusCode::CREATED, Bytes::from(vec![1, 2, 3]));
        assert_eq!(raw.status.as_u16(), 201);
        assert!(!raw.headers.contains_key("content-length"));
        match raw.body {
            Body::Bytes(actual) => assert_eq!(actual.as_ref(), &[1, 2, 3]),
            Body::Stream(_) => panic!("expected bytes body"),
        }
    }

    /// `text` and `stream` set only the headers they are documented to set.
    #[test]
    fn response_constructors() {
        use futures::StreamExt;

        let plain = PingoraWebHttpResponse::text(StatusCode::OK, "hello world");
        assert_eq!(
            plain.headers.get(http::header::CONTENT_TYPE).unwrap(),
            &HeaderValue::from_static("text/plain; charset=utf-8")
        );
        assert!(!plain.headers.contains_key("content-length"));

        let chunks = futures::stream::iter(vec![
            Bytes::from_static(b"chunk1"),
            Bytes::from_static(b"chunk2"),
        ]);
        let streamed = PingoraWebHttpResponse::stream(StatusCode::OK, chunks.boxed());
        assert!(!streamed.headers.contains_key(http::header::CONTENT_LENGTH));
        assert!(!streamed.headers.contains_key(http::header::TRANSFER_ENCODING));
    }

    /// A header set by the caller is stored exactly as given.
    #[test]
    fn manual_headers_not_overridden() {
        let mut response = PingoraWebHttpResponse::text(StatusCode::OK, "hello");
        response.set_header("content-length", "999");

        let stored = response.headers.get(http::header::CONTENT_LENGTH).unwrap();
        assert_eq!(stored, &HeaderValue::from_static("999"));
    }

    /// The shortcut constructors produce the expected status codes and headers.
    #[test]
    fn convenience_methods() {
        assert_eq!(PingoraWebHttpResponse::ok("Success").status.as_u16(), 200);
        assert_eq!(PingoraWebHttpResponse::no_content().status.as_u16(), 204);

        let missing = PingoraWebHttpResponse::not_found("Resource not found");
        assert_eq!(missing.status.as_u16(), 404);
        assert_eq!(
            missing.headers.get(http::header::CONTENT_TYPE).unwrap(),
            &HeaderValue::from_static("application/json")
        );

        let temporary = PingoraWebHttpResponse::redirect_to("/login");
        assert_eq!(temporary.status.as_u16(), 302);
        assert_eq!(
            temporary.headers.get("location").unwrap(),
            &HeaderValue::from_static("/login")
        );

        let permanent = PingoraWebHttpResponse::redirect_permanent("/new-url");
        assert_eq!(permanent.status.as_u16(), 301);
    }
}