use bytes::{Bytes, BytesMut};
use cot::error::error_impl::impl_into_cot_error;
use cot::headers::{HTML_CONTENT_TYPE, OCTET_STREAM_CONTENT_TYPE, PLAIN_TEXT_CONTENT_TYPE};
use cot::response::{RESPONSE_BUILD_FAILURE, Response};
use cot::{Body, Error, StatusCode};
use http;
#[cfg(feature = "json")]
use crate::headers::JSON_CONTENT_TYPE;
use crate::html::Html;
use crate::response::Redirect;
/// Conversion of a value into an HTTP [`Response`].
///
/// Only [`IntoResponse::into_response`] has to be implemented. The remaining
/// methods are adapters: each one wraps `self` in a small struct that
/// remembers a modification (header, content type, status, body, or
/// extension) and applies it when the wrapper itself is converted.
pub trait IntoResponse {
    /// Consumes `self` and turns it into a response.
    ///
    /// # Errors
    ///
    /// Returns an error when the value cannot be represented as a response.
    fn into_response(self) -> cot::Result<Response>;

    /// Wraps `self` so that the given header is appended to the response
    /// produced by the wrapped value.
    ///
    /// If either `key` or `value` fails to convert, the conversion error is
    /// discarded and no header is added.
    fn with_header<K, V>(self, key: K, value: V) -> WithHeader<Self>
    where
        K: TryInto<http::HeaderName>,
        V: TryInto<http::HeaderValue>,
        Self: Sized,
    {
        // Both conversions always run (name first, then value); the header is
        // kept only when both succeed.
        let header = match (key.try_into(), value.try_into()) {
            (Ok(name), Ok(header_value)) => Some((name, header_value)),
            _ => None,
        };
        WithHeader {
            inner: self,
            header,
        }
    }

    /// Wraps `self` so that the `Content-Type` header of the produced
    /// response is replaced with `content_type`.
    ///
    /// If `content_type` fails to convert into a header value, the response
    /// is left untouched.
    fn with_content_type<V>(self, content_type: V) -> WithContentType<Self>
    where
        V: TryInto<http::HeaderValue>,
        Self: Sized,
    {
        let content_type = content_type.try_into().ok();
        WithContentType {
            inner: self,
            content_type,
        }
    }

    /// Wraps `self` so that the status code of the produced response is
    /// overwritten with `status`.
    fn with_status(self, status: StatusCode) -> WithStatus<Self>
    where
        Self: Sized,
    {
        WithStatus {
            status,
            inner: self,
        }
    }

    /// Wraps `self` so that the body of the produced response is replaced
    /// with `body`.
    fn with_body(self, body: impl Into<Body>) -> WithBody<Self>
    where
        Self: Sized,
    {
        // The conversion into `Body` happens here, not when the wrapper is
        // turned into a response.
        let body: Body = body.into();
        WithBody { inner: self, body }
    }

    /// Wraps `self` so that `extension` is inserted into the extensions of
    /// the produced response.
    fn with_extension<T>(self, extension: T) -> WithExtension<Self, T>
    where
        T: Clone + Send + Sync + 'static,
        Self: Sized,
    {
        WithExtension {
            extension,
            inner: self,
        }
    }
}
/// Wrapper created by [`IntoResponse::with_header`].
///
/// Holds the wrapped value together with the header to add, or `None` when
/// the header name or value could not be converted.
#[derive(Debug)]
pub struct WithHeader<T> {
    inner: T,
    header: Option<(http::HeaderName, http::HeaderValue)>,
}

impl<T: IntoResponse> IntoResponse for WithHeader<T> {
    /// Converts the wrapped value and then appends the stored header, if any.
    ///
    /// The header is appended, so an existing header with the same name is
    /// kept alongside the new one.
    fn into_response(self) -> cot::Result<Response> {
        let Self { inner, header } = self;
        let mut response = inner.into_response()?;
        if let Some((name, value)) = header {
            response.headers_mut().append(name, value);
        }
        Ok(response)
    }
}
/// Wrapper created by [`IntoResponse::with_content_type`].
///
/// Holds the wrapped value and the content type to set, or `None` when the
/// supplied value could not be converted into a header value.
#[derive(Debug)]
pub struct WithContentType<T> {
    inner: T,
    content_type: Option<http::HeaderValue>,
}

impl<T: IntoResponse> IntoResponse for WithContentType<T> {
    /// Converts the wrapped value and then sets `Content-Type`, replacing any
    /// value the inner response already carried.
    fn into_response(self) -> cot::Result<Response> {
        let Self {
            inner,
            content_type,
        } = self;
        let mut response = inner.into_response()?;
        if let Some(value) = content_type {
            response
                .headers_mut()
                .insert(http::header::CONTENT_TYPE, value);
        }
        Ok(response)
    }
}
/// Wrapper created by [`IntoResponse::with_status`].
#[derive(Debug)]
pub struct WithStatus<T> {
    inner: T,
    status: StatusCode,
}

impl<T: IntoResponse> IntoResponse for WithStatus<T> {
    /// Converts the wrapped value and then overwrites its status code.
    fn into_response(self) -> cot::Result<Response> {
        let Self { inner, status } = self;
        let mut response = inner.into_response()?;
        *response.status_mut() = status;
        Ok(response)
    }
}
/// Wrapper created by [`IntoResponse::with_body`].
#[derive(Debug)]
pub struct WithBody<T> {
    inner: T,
    body: Body,
}

impl<T: IntoResponse> IntoResponse for WithBody<T> {
    /// Converts the wrapped value and then replaces its body; status and
    /// headers of the inner response are not modified here.
    fn into_response(self) -> cot::Result<Response> {
        let Self { inner, body } = self;
        let mut response = inner.into_response()?;
        *response.body_mut() = body;
        Ok(response)
    }
}
/// Wrapper created by [`IntoResponse::with_extension`].
#[derive(Debug)]
pub struct WithExtension<T, D> {
    inner: T,
    extension: D,
}

impl<T, D> IntoResponse for WithExtension<T, D>
where
    T: IntoResponse,
    D: Clone + Send + Sync + 'static,
{
    /// Converts the wrapped value and then inserts the stored extension into
    /// the response's extensions.
    fn into_response(self) -> cot::Result<Response> {
        let Self { inner, extension } = self;
        let mut response = inner.into_response()?;
        response.extensions_mut().insert(extension);
        Ok(response)
    }
}
/// Implements [`IntoResponse`] for `$ty` by converting the value into a
/// [`Body`] and attaching `$mime` as the `Content-Type` header.
///
/// `$ty` must be convertible with `Body::from`, and `$mime` must be accepted
/// by [`IntoResponse::with_header`] as a header value.
macro_rules! impl_into_response_for_type_and_mime {
    ($ty:ty, $mime:expr) => {
        impl IntoResponse for $ty {
            fn into_response(self) -> cot::Result<Response> {
                let body = Body::from(self);
                let typed = body.with_header(http::header::CONTENT_TYPE, $mime);
                typed.into_response()
            }
        }
    };
}
/// The unit type converts into a response with an empty body.
impl IntoResponse for () {
    fn into_response(self) -> cot::Result<Response> {
        let empty = Body::empty();
        empty.into_response()
    }
}
impl<R, E> IntoResponse for Result<R, E>
where
R: IntoResponse,
E: Into<Error>,
{
fn into_response(self) -> cot::Result<Response> {
match self {
Ok(value) => value.into_response(),
Err(err) => Err(err.into()),
}
}
}
/// An [`Error`] never becomes a response here: it is returned as the `Err`
/// variant so that the caller can handle it.
impl IntoResponse for Error {
    fn into_response(self) -> cot::Result<Response> {
        let error = self;
        Err(error)
    }
}
/// A [`Response`] is already in its final form and is returned unchanged.
impl IntoResponse for Response {
    fn into_response(self) -> cot::Result<Response> {
        let response = self;
        Ok(response)
    }
}
// Textual types: the value becomes the response body and the `Content-Type`
// header is set to `PLAIN_TEXT_CONTENT_TYPE`.
impl_into_response_for_type_and_mime!(&'static str, PLAIN_TEXT_CONTENT_TYPE);
impl_into_response_for_type_and_mime!(String, PLAIN_TEXT_CONTENT_TYPE);
/// A boxed string slice is turned into a `String` and converted the same way
/// a `String` is.
impl IntoResponse for Box<str> {
    fn into_response(self) -> cot::Result<Response> {
        let owned = String::from(self);
        owned.into_response()
    }
}
// Raw byte types: the value becomes the response body and the `Content-Type`
// header is set to `OCTET_STREAM_CONTENT_TYPE`.
impl_into_response_for_type_and_mime!(&'static [u8], OCTET_STREAM_CONTENT_TYPE);
impl_into_response_for_type_and_mime!(Vec<u8>, OCTET_STREAM_CONTENT_TYPE);
impl_into_response_for_type_and_mime!(Bytes, OCTET_STREAM_CONTENT_TYPE);
/// A reference to a static byte array is viewed as a `&'static [u8]` and
/// converted through that implementation.
impl<const N: usize> IntoResponse for &'static [u8; N] {
    fn into_response(self) -> cot::Result<Response> {
        let bytes: &'static [u8] = self.as_slice();
        bytes.into_response()
    }
}
/// An owned byte array is copied into a `Vec<u8>` and converted through that
/// implementation.
impl<const N: usize> IntoResponse for [u8; N] {
    fn into_response(self) -> cot::Result<Response> {
        let bytes: Vec<u8> = self.to_vec();
        bytes.into_response()
    }
}
/// A boxed byte slice is turned into a `Vec<u8>` and converted through that
/// implementation.
impl IntoResponse for Box<[u8]> {
    fn into_response(self) -> cot::Result<Response> {
        let bytes: Vec<u8> = Vec::from(self);
        bytes.into_response()
    }
}
/// A mutable byte buffer is frozen into `Bytes` and converted through that
/// implementation.
impl IntoResponse for BytesMut {
    fn into_response(self) -> cot::Result<Response> {
        let frozen = self.freeze();
        frozen.into_response()
    }
}
/// A bare status code converts into a response with an empty body and that
/// status.
impl IntoResponse for StatusCode {
    fn into_response(self) -> cot::Result<Response> {
        // Start from the unit response and override only its status.
        let mut response = ().into_response()?;
        *response.status_mut() = self;
        Ok(response)
    }
}
/// A header map converts into a response with an empty body whose headers
/// are exactly the given map.
impl IntoResponse for http::HeaderMap {
    fn into_response(self) -> cot::Result<Response> {
        let mut response = ().into_response()?;
        *response.headers_mut() = self;
        Ok(response)
    }
}
/// An extensions map converts into a response with an empty body whose
/// extensions are exactly the given map.
impl IntoResponse for http::Extensions {
    fn into_response(self) -> cot::Result<Response> {
        let mut response = ().into_response()?;
        *response.extensions_mut() = self;
        Ok(response)
    }
}
/// A response head is paired with an empty body to form a complete response.
impl IntoResponse for crate::response::ResponseHead {
    fn into_response(self) -> cot::Result<Response> {
        let body = Body::empty();
        let response = Response::from_parts(self, body);
        Ok(response)
    }
}
/// HTML content converts like its inner value, with the `Content-Type`
/// header then replaced by `HTML_CONTENT_TYPE`.
impl IntoResponse for Html {
    fn into_response(self) -> cot::Result<Response> {
        // The inner value sets its own content type first; it is overridden
        // below.
        let inner_response = self.0.into_response();
        inner_response
            .with_content_type(HTML_CONTENT_TYPE)
            .into_response()
    }
}
/// JSON payloads are serialized with `serde_json` and sent with the
/// `Content-Type` header set to `JSON_CONTENT_TYPE`.
///
/// # Errors
///
/// Returns an error wrapping the `serde_path_to_error` failure when the
/// wrapped value cannot be serialized.
#[cfg(feature = "json")]
impl<D: serde::Serialize> IntoResponse for cot::json::Json<D> {
    fn into_response(self) -> cot::Result<Response> {
        // Initial capacity of the output buffer; it grows as needed.
        const DEFAULT_JSON_SIZE: usize = 128;

        let mut buf = Vec::with_capacity(DEFAULT_JSON_SIZE);
        let mut serializer = serde_json::Serializer::new(&mut buf);
        serde_path_to_error::serialize(&self.0, &mut serializer).map_err(JsonSerializeError)?;

        // The serialized bytes go straight into the body. Round-tripping
        // through `String` would only add a redundant UTF-8 validation pass
        // (and a panic site) before ending up as the same bytes. The default
        // content type set by the `Vec<u8>` conversion is replaced below.
        buf.with_content_type(JSON_CONTENT_TYPE).into_response()
    }
}
/// Error raised when serializing a JSON response payload fails.
///
/// Wraps the `serde_path_to_error` error so that the message can include the
/// path of the value that failed to serialize.
#[cfg(feature = "json")]
#[derive(Debug, thiserror::Error)]
#[error("JSON serialization error: {0}")]
struct JsonSerializeError(serde_path_to_error::Error<serde_json::Error>);
// NOTE(review): presumably makes `JsonSerializeError` convertible into
// `cot::Error` with an internal-server-error status — confirm against the
// macro definition.
#[cfg(feature = "json")]
impl_into_cot_error!(JsonSerializeError, INTERNAL_SERVER_ERROR);
/// A body converts into a response that carries it; no `Content-Type` header
/// is added by this implementation.
impl IntoResponse for Body {
    fn into_response(self) -> cot::Result<Response> {
        let response = Response::new(self);
        Ok(response)
    }
}
/// A redirect converts into a `303 See Other` response with an empty body
/// and the target stored in the `Location` header.
///
/// # Panics
///
/// Panics with `RESPONSE_BUILD_FAILURE` if the response builder reports an
/// error.
// NOTE(review): the only caller-influenced input to the builder is the
// redirect target, so a target that is not a valid header value would
// presumably trigger this panic — confirm what `Redirect` accepts.
impl IntoResponse for Redirect {
    fn into_response(self) -> cot::Result<Response> {
        let builder = http::Response::builder().status(StatusCode::SEE_OTHER);
        let builder = builder.header(http::header::LOCATION, self.0);
        let built = builder.body(Body::empty());
        Ok(built.expect(RESPONSE_BUILD_FAILURE))
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests covering every `IntoResponse` implementation in this
    //! module as well as the `with_*` adapter methods.

    use bytes::{Bytes, BytesMut};
    use cot::response::Response;
    use cot::{Body, StatusCode};
    use http::{self, HeaderMap, HeaderValue};

    use super::*;
    use crate::error::NotFound;
    use crate::html::Html;

    /// `()` yields `200 OK` with no headers and an empty body.
    #[cot::test]
    async fn test_unit_into_response() {
        let resp = ().into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().is_empty());
        let payload = resp.into_body().into_bytes().await.unwrap();
        assert_eq!(payload.len(), 0);
    }

    /// `Ok(value)` converts exactly like `value` itself.
    #[cot::test]
    async fn test_result_ok_into_response() {
        let outcome: Result<&'static str, Error> = Ok("hello");
        let resp = outcome.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "text/plain; charset=utf-8"
        );
        assert_eq!(resp.into_body().into_bytes().await.unwrap(), "hello");
    }

    /// `Err(error)` is propagated as an error rather than a response.
    #[cot::test]
    async fn test_result_err_into_response() {
        let not_found = Error::from(NotFound::with_message("test"));
        let outcome: Result<&'static str, Error> = Err(not_found);

        let converted = outcome.into_response();

        assert!(converted.is_err());
        assert!(converted.err().unwrap().to_string().contains("test"));
    }

    /// A `Response` passes through unchanged.
    #[cot::test]
    async fn test_response_into_response() {
        let original = Response::new(Body::from("test"));
        let resp = original.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.into_body().into_bytes().await.unwrap(), "test");
    }

    /// `&'static str` is served as plain text.
    #[cot::test]
    async fn test_static_str_into_response() {
        let resp = "hello world".into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "text/plain; charset=utf-8"
        );
        assert_eq!(
            resp.into_body().into_bytes().await.unwrap(),
            "hello world"
        );
    }

    /// `String` is served as plain text.
    #[cot::test]
    async fn test_string_into_response() {
        let text = String::from("hello string");
        let resp = text.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "text/plain; charset=utf-8"
        );
        assert_eq!(
            resp.into_body().into_bytes().await.unwrap(),
            "hello string"
        );
    }

    /// `Box<str>` is served as plain text.
    #[cot::test]
    async fn test_box_str_into_response() {
        let boxed: Box<str> = "hello box".into();
        let resp = boxed.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "text/plain; charset=utf-8"
        );
        assert_eq!(resp.into_body().into_bytes().await.unwrap(), "hello box");
    }

    /// `&'static [u8]` is served as an octet stream.
    #[cot::test]
    async fn test_static_u8_slice_into_response() {
        let raw: &'static [u8] = b"hello bytes";
        let resp = raw.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            resp.into_body().into_bytes().await.unwrap(),
            "hello bytes"
        );
    }

    /// `Vec<u8>` is served as an octet stream.
    #[cot::test]
    async fn test_vec_u8_into_response() {
        let raw: Vec<u8> = vec![1, 2, 3];
        let resp = raw.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            resp.into_body().into_bytes().await.unwrap(),
            Bytes::from(vec![1, 2, 3])
        );
    }

    /// `Bytes` is served as an octet stream.
    #[cot::test]
    async fn test_bytes_into_response() {
        let raw = Bytes::from_static(b"hello bytes obj");
        let resp = raw.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            resp.into_body().into_bytes().await.unwrap(),
            "hello bytes obj"
        );
    }

    /// `&'static [u8; N]` is served as an octet stream.
    #[cot::test]
    async fn test_static_u8_array_into_response() {
        let raw: &'static [u8; 5] = b"array";
        let resp = raw.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "application/octet-stream"
        );
        assert_eq!(resp.into_body().into_bytes().await.unwrap(), "array");
    }

    /// `[u8; N]` is served as an octet stream.
    #[cot::test]
    async fn test_u8_array_into_response() {
        let raw: [u8; 3] = [4, 5, 6];
        let resp = raw.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            resp.into_body().into_bytes().await.unwrap(),
            Bytes::from(vec![4, 5, 6])
        );
    }

    /// `Box<[u8]>` is served as an octet stream.
    #[cot::test]
    async fn test_box_u8_slice_into_response() {
        let raw: Box<[u8]> = Box::new([7, 8, 9]);
        let resp = raw.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            resp.into_body().into_bytes().await.unwrap(),
            Bytes::from(vec![7, 8, 9])
        );
    }

    /// `BytesMut` is served as an octet stream.
    #[cot::test]
    async fn test_bytes_mut_into_response() {
        let mut buffer = BytesMut::with_capacity(10);
        buffer.extend_from_slice(b"mutable");
        let resp = buffer.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "application/octet-stream"
        );
        assert_eq!(resp.into_body().into_bytes().await.unwrap(), "mutable");
    }

    /// A bare `StatusCode` yields that status, no headers, and an empty body.
    #[cot::test]
    async fn test_status_code_into_response() {
        let resp = StatusCode::NOT_FOUND.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        assert!(resp.headers().is_empty());
        let payload = resp.into_body().into_bytes().await.unwrap();
        assert_eq!(payload.len(), 0);
    }

    /// A `HeaderMap` becomes the headers of an otherwise empty response.
    #[cot::test]
    async fn test_header_map_into_response() {
        let mut header_map = HeaderMap::new();
        header_map.insert("X-Test", HeaderValue::from_static("value"));
        let resp = header_map.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("X-Test").unwrap(), "value");
        let payload = resp.into_body().into_bytes().await.unwrap();
        assert_eq!(payload.len(), 0);
    }

    /// `Extensions` become the extensions of an otherwise empty response.
    #[cot::test]
    async fn test_extensions_into_response() {
        let mut ext = http::Extensions::new();
        ext.insert("My Extension Data");
        let resp = ext.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().is_empty());
        assert_eq!(
            resp.extensions().get::<&str>(),
            Some(&"My Extension Data")
        );
        let payload = resp.into_body().into_bytes().await.unwrap();
        assert_eq!(payload.len(), 0);
    }

    /// A response head keeps its status, headers, and extensions and gains
    /// an empty body.
    #[cot::test]
    async fn test_parts_into_response() {
        let mut source = Response::new(Body::empty());
        *source.status_mut() = StatusCode::ACCEPTED;
        source
            .headers_mut()
            .insert("X-From-Parts", HeaderValue::from_static("yes"));
        source.extensions_mut().insert(123usize);
        let (head, _) = source.into_parts();

        let rebuilt = head.into_response().unwrap();

        assert_eq!(rebuilt.status(), StatusCode::ACCEPTED);
        assert_eq!(rebuilt.headers().get("X-From-Parts").unwrap(), "yes");
        assert_eq!(rebuilt.extensions().get::<usize>(), Some(&123));
        let payload = rebuilt.into_body().into_bytes().await.unwrap();
        assert_eq!(payload.len(), 0);
    }

    /// `Html` is served with the HTML content type.
    #[cot::test]
    async fn test_html_into_response() {
        let page = Html::new("<h1>Test</h1>");
        let resp = page.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "text/html; charset=utf-8"
        );
        assert_eq!(
            resp.into_body().into_bytes().await.unwrap(),
            "<h1>Test</h1>"
        );
    }

    /// A raw `Body` is sent without any content type.
    #[cot::test]
    async fn test_body_into_response() {
        let raw_body = Body::from("body test");
        let resp = raw_body.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get(http::header::CONTENT_TYPE), None);
        assert_eq!(resp.into_body().into_bytes().await.unwrap(), "body test");
    }

    /// `with_header` adds the requested header.
    #[cot::test]
    async fn test_with_header() {
        let wrapped = "test".with_header("X-Custom", "HeaderValue");
        let resp = wrapped.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("X-Custom").unwrap(), "HeaderValue");
        assert_eq!(resp.into_body().into_bytes().await.unwrap(), "test");
    }

    /// `with_content_type` overrides the default content type.
    #[cot::test]
    async fn test_with_content_type() {
        let wrapped = "test".with_content_type("application/json");
        let resp = wrapped.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "application/json"
        );
        assert_eq!(resp.into_body().into_bytes().await.unwrap(), "test");
    }

    /// `with_status` overrides the status but keeps headers and body.
    #[cot::test]
    async fn test_with_status() {
        let wrapped = "test".with_status(StatusCode::CREATED);
        let resp = wrapped.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::CREATED);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "text/plain; charset=utf-8"
        );
        assert_eq!(resp.into_body().into_bytes().await.unwrap(), "test");
    }

    /// `with_body` replaces the body but keeps the status.
    #[cot::test]
    async fn test_with_body() {
        let wrapped = StatusCode::ACCEPTED.with_body("new body");
        let resp = wrapped.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::ACCEPTED);
        assert_eq!(resp.into_body().into_bytes().await.unwrap(), "new body");
    }

    /// `with_extension` stores the extension on the response.
    #[cot::test]
    async fn test_with_extension() {
        #[derive(Clone, Debug, PartialEq)]
        struct MyExt(String);

        let wrapped = "test".with_extension(MyExt("data".to_string()));
        let resp = wrapped.into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.extensions().get::<MyExt>(),
            Some(&MyExt("data".to_string()))
        );
        assert_eq!(resp.into_body().into_bytes().await.unwrap(), "test");
    }

    /// A serializable struct wrapped in `Json` is sent as compact JSON.
    #[cfg(feature = "json")]
    #[cot::test]
    async fn test_json_struct_into_response() {
        use serde::Serialize;

        #[derive(Serialize, PartialEq, Debug)]
        struct TestData {
            name: String,
            value: i32,
        }

        let payload = TestData {
            name: "test".to_string(),
            value: 123,
        };
        let resp = cot::json::Json(payload).into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            JSON_CONTENT_TYPE
        );
        let expected = r#"{"name":"test","value":123}"#;
        let actual = resp.into_body().into_bytes().await.unwrap();
        assert_eq!(actual, expected.as_bytes());
    }

    /// A map wrapped in `Json` is sent as a JSON object.
    #[cfg(feature = "json")]
    #[cot::test]
    async fn test_json_hashmap_into_response() {
        use std::collections::HashMap;

        let payload = HashMap::from([("key", "value")]);
        let resp = cot::json::Json(payload).into_response().unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            JSON_CONTENT_TYPE
        );
        let expected = r#"{"key":"value"}"#;
        let actual = resp.into_body().into_bytes().await.unwrap();
        assert_eq!(actual, expected.as_bytes());
    }
}