use crate::HeaderField;
use candid::{
types::{Serializer, Type, TypeInner},
CandidType, Deserialize,
};
pub use http::StatusCode;
use serde::Deserializer;
use std::{borrow::Cow, fmt::Debug};
/// Newtype around [`StatusCode`] so that `CandidType` and `Deserialize` can be
/// implemented for it in this crate.
///
/// On the Candid wire it is represented as a `nat16` holding the numeric status code.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct StatusCodeWrapper(StatusCode);
impl CandidType for StatusCodeWrapper {
fn _ty() -> Type {
TypeInner::Nat16.into()
}
fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
where
S: Serializer,
{
self.0.as_u16().idl_serialize(serializer)
}
}
impl<'de> Deserialize<'de> for StatusCodeWrapper {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
u16::deserialize(deserializer).and_then(|status_code| {
StatusCode::from_u16(status_code)
.map(Into::into)
.map_err(|_| serde::de::Error::custom("Invalid HTTP Status Code."))
})
}
}
impl From<StatusCode> for StatusCodeWrapper {
fn from(status_code: StatusCode) -> Self {
Self(status_code)
}
}
/// An HTTP response: status code, headers, body bytes and an optional `upgrade` flag.
///
/// Usually assembled through [`HttpResponse::builder`] or one of the status-specific
/// constructors such as [`HttpResponse::ok`]. Equality ignores header order.
#[derive(Clone, CandidType, Deserialize)]
pub struct HttpResponse<'a> {
/// Response status code, encoded for Candid as a `nat16`.
status_code: StatusCodeWrapper,
/// Response headers, kept in the order they were supplied or added.
headers: Vec<HeaderField>,
/// Response body bytes, either borrowed for `'a` or owned.
body: Cow<'a, [u8]>,
/// Optional upgrade flag; not carried over into `HttpUpdateResponse`.
// NOTE(review): presumably asks the caller to retry the request as an update call — confirm.
upgrade: Option<bool>,
}
impl<'a> HttpResponse<'a> {
pub fn ok(
body: impl Into<Cow<'a, [u8]>>,
headers: Vec<(String, String)>,
) -> HttpResponseBuilder<'a> {
Self::builder()
.with_status_code(StatusCode::OK)
.with_body(body)
.with_headers(headers)
}
pub fn created(
body: impl Into<Cow<'a, [u8]>>,
headers: Vec<(String, String)>,
) -> HttpResponseBuilder<'a> {
Self::builder()
.with_status_code(StatusCode::CREATED)
.with_body(body)
.with_headers(headers)
}
pub fn no_content(headers: Vec<(String, String)>) -> HttpResponseBuilder<'a> {
Self::builder()
.with_status_code(StatusCode::NO_CONTENT)
.with_headers(headers)
}
pub fn moved_permanently(
location: impl Into<String>,
headers: Vec<(String, String)>,
) -> HttpResponseBuilder<'a> {
let headers = headers
.into_iter()
.chain(std::iter::once(("Location".into(), location.into())))
.collect();
Self::builder()
.with_status_code(StatusCode::MOVED_PERMANENTLY)
.with_headers(headers)
}
pub fn not_modified(headers: Vec<(String, String)>) -> HttpResponseBuilder<'a> {
Self::builder()
.with_status_code(StatusCode::NOT_MODIFIED)
.with_headers(headers)
}
pub fn temporary_redirect(
location: impl Into<String>,
headers: Vec<(String, String)>,
) -> HttpResponseBuilder<'a> {
let headers = headers
.into_iter()
.chain(std::iter::once(("Location".into(), location.into())))
.collect();
Self::builder()
.with_status_code(StatusCode::TEMPORARY_REDIRECT)
.with_headers(headers)
}
pub fn bad_request(
body: impl Into<Cow<'a, [u8]>>,
headers: Vec<(String, String)>,
) -> HttpResponseBuilder<'a> {
Self::builder()
.with_status_code(StatusCode::BAD_REQUEST)
.with_body(body)
.with_headers(headers)
}
pub fn unauthorized(
body: impl Into<Cow<'a, [u8]>>,
headers: Vec<(String, String)>,
) -> HttpResponseBuilder<'a> {
Self::builder()
.with_status_code(StatusCode::UNAUTHORIZED)
.with_body(body)
.with_headers(headers)
}
pub fn forbidden(
body: impl Into<Cow<'a, [u8]>>,
headers: Vec<(String, String)>,
) -> HttpResponseBuilder<'a> {
Self::builder()
.with_status_code(StatusCode::FORBIDDEN)
.with_body(body)
.with_headers(headers)
}
pub fn not_found(
body: impl Into<Cow<'a, [u8]>>,
headers: Vec<(String, String)>,
) -> HttpResponseBuilder<'a> {
Self::builder()
.with_status_code(StatusCode::NOT_FOUND)
.with_body(body)
.with_headers(headers)
}
pub fn method_not_allowed(
body: impl Into<Cow<'a, [u8]>>,
headers: Vec<(String, String)>,
) -> HttpResponseBuilder<'a> {
Self::builder()
.with_status_code(StatusCode::METHOD_NOT_ALLOWED)
.with_body(body)
.with_headers(headers)
}
pub fn too_many_requests(
body: impl Into<Cow<'a, [u8]>>,
headers: Vec<(String, String)>,
) -> HttpResponseBuilder<'a> {
Self::builder()
.with_status_code(StatusCode::TOO_MANY_REQUESTS)
.with_body(body)
.with_headers(headers)
}
pub fn internal_server_error(
body: impl Into<Cow<'a, [u8]>>,
headers: Vec<(String, String)>,
) -> HttpResponseBuilder<'a> {
Self::builder()
.with_status_code(StatusCode::INTERNAL_SERVER_ERROR)
.with_body(body)
.with_headers(headers)
}
#[inline]
pub fn builder() -> HttpResponseBuilder<'a> {
HttpResponseBuilder::new()
}
#[inline]
pub fn status_code(&self) -> StatusCode {
self.status_code.0
}
#[inline]
pub fn headers(&self) -> &[HeaderField] {
&self.headers
}
#[inline]
pub fn headers_mut(&mut self) -> &mut Vec<HeaderField> {
&mut self.headers
}
#[inline]
pub fn add_header(&mut self, header: HeaderField) {
self.headers.push(header);
}
#[inline]
pub fn body(&self) -> &[u8] {
&self.body
}
#[inline]
pub fn upgrade(&self) -> Option<bool> {
self.upgrade
}
}
/// Incrementally assembles an [`HttpResponse`] or [`HttpUpdateResponse`].
///
/// Fields that are never set keep their `Default` value. The exception is the status
/// code: `build` and `build_update` turn an unset status code into `200 OK`.
///
/// Every `with_*` method consumes and returns the builder. Discarding the returned value
/// is therefore always a mistake, which `#[must_use]` turns into a compiler warning.
#[must_use = "a response builder does nothing unless `build` or `build_update` is called"]
#[derive(Debug, Clone, Default)]
pub struct HttpResponseBuilder<'a> {
    /// Status code; `None` until `with_status_code` is called.
    status_code: Option<StatusCodeWrapper>,
    /// Headers; `with_headers` replaces the whole list.
    headers: Vec<HeaderField>,
    /// Response body bytes.
    body: Cow<'a, [u8]>,
    /// Optional `upgrade` flag; only carried into `HttpResponse`, not `HttpUpdateResponse`.
    upgrade: Option<bool>,
}
impl<'a> HttpResponseBuilder<'a> {
pub fn new() -> Self {
Self::default()
}
pub fn with_status_code(mut self, status_code: StatusCode) -> Self {
self.status_code = Some(status_code.into());
self
}
pub fn with_headers(mut self, headers: Vec<HeaderField>) -> Self {
self.headers = headers;
self
}
pub fn with_body(mut self, body: impl Into<Cow<'a, [u8]>>) -> Self {
self.body = body.into();
self
}
pub fn with_upgrade(mut self, upgrade: bool) -> Self {
self.upgrade = Some(upgrade);
self
}
pub fn build(self) -> HttpResponse<'a> {
HttpResponse {
status_code: self.status_code.unwrap_or(StatusCode::OK.into()),
headers: self.headers,
body: self.body,
upgrade: self.upgrade,
}
}
pub fn build_update(self) -> HttpUpdateResponse<'a> {
HttpUpdateResponse {
status_code: self.status_code.unwrap_or(StatusCode::OK.into()),
headers: self.headers,
body: self.body,
}
}
}
impl<'a> From<HttpResponse<'a>> for HttpResponseBuilder<'a> {
fn from(response: HttpResponse<'a>) -> Self {
Self {
status_code: Some(response.status_code),
headers: response.headers,
body: response.body,
upgrade: response.upgrade,
}
}
}
/// Two responses are equal when their status code, body and `upgrade` flag match and
/// they carry the same headers regardless of order (the same multiset of name/value
/// pairs).
impl PartialEq for HttpResponse<'_> {
    fn eq(&self, other: &Self) -> bool {
        // Check the cheap fields first. The order-insensitive header comparison below
        // allocates and sorts, so it only runs when everything else already matches.
        // Differing header counts can never be equal multisets.
        if self.status_code != other.status_code
            || self.upgrade != other.upgrade
            || self.headers.len() != other.headers.len()
            || self.body != other.body
        {
            return false;
        }

        // Fast path: identical order implies identical multisets.
        if self.headers == other.headers {
            return true;
        }

        // Sort copies so header order does not matter. For `(String, String)`, elements
        // that compare equal are identical, so an unstable sort gives the same result as
        // a stable one.
        let mut a_headers = self.headers.to_vec();
        a_headers.sort_unstable();
        let mut b_headers = other.headers.to_vec();
        b_headers.sort_unstable();
        a_headers == b_headers
    }
}
/// Debug output that truncates long bodies.
///
/// The body is pre-formatted into a `String`, so it appears quoted. Bodies longer than
/// 100 bytes show only their first 100 bytes followed by `...`.
impl Debug for HttpResponse<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const MAX_BODY_LEN: usize = 100;

        let formatted_body = if self.body.len() <= MAX_BODY_LEN {
            format!("{:?}", &self.body)
        } else {
            format!("{:?}...", &self.body[..MAX_BODY_LEN])
        };

        let mut out = f.debug_struct("HttpResponse");
        out.field("status_code", &self.status_code);
        out.field("headers", &self.headers);
        out.field("body", &formatted_body);
        out.field("upgrade", &self.upgrade);
        out.finish()
    }
}
/// An HTTP response without the `upgrade` flag that [`HttpResponse`] carries.
///
/// Produced by [`HttpResponseBuilder::build_update`] or by converting an
/// [`HttpResponse`]. Its derived `PartialEq` compares headers in stored order.
#[derive(Clone, Debug, CandidType, Deserialize, PartialEq, Eq)]
pub struct HttpUpdateResponse<'a> {
/// Response status code, encoded for Candid as a `nat16`.
status_code: StatusCodeWrapper,
/// Response headers, kept in the order they were supplied or added.
headers: Vec<HeaderField>,
/// Response body bytes, either borrowed for `'a` or owned.
body: Cow<'a, [u8]>,
}
impl<'a> HttpUpdateResponse<'a> {
#[inline]
pub fn status_code(&self) -> StatusCode {
self.status_code.0
}
#[inline]
pub fn headers(&self) -> &[HeaderField] {
&self.headers
}
#[inline]
pub fn headers_mut(&mut self) -> &mut Vec<HeaderField> {
&mut self.headers
}
#[inline]
pub fn add_header(&mut self, header: HeaderField) {
self.headers.push(header);
}
#[inline]
pub fn body(&self) -> &[u8] {
&self.body
}
}
impl<'a> From<HttpResponse<'a>> for HttpUpdateResponse<'a> {
fn from(response: HttpResponse<'a>) -> Self {
Self {
status_code: response.status_code,
headers: response.headers,
body: response.body,
}
}
}