use crate::error::Error;
use bytes::Bytes;
use http::{HeaderValue, StatusCode, header};
use http_body_util::{BodyExt, Full, combinators::BoxBody};
use serde::Serialize;
/// Error carried by response body frames.
///
/// It wraps a human-readable message, and `Display` prints that message verbatim.
#[derive(Debug)]
pub struct BodyError(String);

impl BodyError {
    /// Builds a body error from any string-like message.
    pub(crate) fn new(message: impl Into<String>) -> Self {
        let text: String = message.into();
        BodyError(text)
    }
}

impl std::fmt::Display for BodyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let BodyError(text) = self;
        f.write_str(text)
    }
}

impl std::error::Error for BodyError {}

/// Lets infallible bodies (such as `Full<Bytes>`) have their error type mapped
/// into `BodyError`. The conversion can never actually run.
impl From<std::convert::Infallible> for BodyError {
    fn from(never: std::convert::Infallible) -> Self {
        match never {}
    }
}
/// Type-erased response body shared by every `Response` built in this module.
///
/// Wraps a boxed `http_body::Body` of `Bytes` frames whose errors are `BodyError`.
/// Buffered bodies (`JcBody::full`) and streaming bodies (`JcBody::stream`) therefore
/// share one concrete type.
pub struct JcBody(BoxBody<Bytes, BodyError>);

impl JcBody {
    /// A body that yields `bytes` as a single data frame and then ends.
    pub fn full(bytes: impl Into<Bytes>) -> Self {
        // `Full` cannot fail (its error is `Infallible`). Map it into `BodyError` so it can be boxed.
        Self(Full::new(bytes.into()).map_err(BodyError::from).boxed())
    }

    /// A body with no data.
    pub fn empty() -> Self {
        Self::full(Bytes::new())
    }

    /// Boxes any `Send + Sync` body of `Bytes` frames, converting its errors into `BodyError`.
    pub fn stream<B>(body: B) -> Self
    where
        B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
        B::Error: Into<BodyError>,
    {
        Self(body.map_err(Into::into).boxed())
    }
}

/// Formats as `JcBody { size_hint_lower, size_hint_upper, .. }`.
///
/// The boxed inner body is opaque, so only its size hint is shown. Implementing
/// `Debug` here also makes `Response` (`http::Response<JcBody>`) printable with `{:?}`.
impl std::fmt::Debug for JcBody {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully qualified because the `Body` trait is not imported into this module's scope.
        let hint = http_body::Body::size_hint(&self.0);
        f.debug_struct("JcBody")
            .field("size_hint_lower", &hint.lower())
            .field("size_hint_upper", &hint.upper())
            .finish_non_exhaustive()
    }
}

/// Every `Body` method delegates to the boxed inner body.
impl http_body::Body for JcBody {
    type Data = Bytes;
    type Error = BodyError;

    fn poll_frame(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        std::pin::Pin::new(&mut self.0).poll_frame(cx)
    }

    fn is_end_stream(&self) -> bool {
        self.0.is_end_stream()
    }

    fn size_hint(&self) -> http_body::SizeHint {
        self.0.size_hint()
    }
}
/// The concrete response type produced by `IntoResponse::into_response`.
pub type Response = http::Response<JcBody>;
/// Conversion of a value (typically a handler's return value) into a `Response`.
pub trait IntoResponse {
/// Consumes `self` and renders it as a complete HTTP response.
fn into_response(self) -> Response;
}
/// Renders the wrapped value as a `200 OK` `application/json` response.
pub struct Json<T>(pub T);
/// Renders the wrapped value as a `201 Created` `application/json` response.
pub struct Created<T>(pub T);
/// Renders an empty `204 No Content` response.
pub struct NoContent;
/// A redirect: an empty body, a 3xx status, and a `Location` header.
pub struct Redirect {
    status: StatusCode,
    location: String,
}

impl Redirect {
    /// Shared constructor behind the status-specific builders below.
    fn with_status(status: StatusCode, location: impl Into<String>) -> Self {
        let location = location.into();
        Self { status, location }
    }

    /// `302 Found` redirect to `location`.
    pub fn to(location: impl Into<String>) -> Self {
        Self::with_status(StatusCode::FOUND, location)
    }

    /// `303 See Other` redirect to `location`.
    pub fn see_other(location: impl Into<String>) -> Self {
        Self::with_status(StatusCode::SEE_OTHER, location)
    }

    /// `307 Temporary Redirect` to `location`.
    pub fn temporary(location: impl Into<String>) -> Self {
        Self::with_status(StatusCode::TEMPORARY_REDIRECT, location)
    }

    /// `308 Permanent Redirect` to `location`.
    pub fn permanent(location: impl Into<String>) -> Self {
        Self::with_status(StatusCode::PERMANENT_REDIRECT, location)
    }
}
impl IntoResponse for Redirect {
fn into_response(self) -> Response {
let value = match HeaderValue::from_str(&self.location) {
Ok(v) => v,
Err(_) => {
return Error::internal("redirect location is not a valid header value")
.into_response();
}
};
let mut r = http::Response::new(JcBody::empty());
*r.status_mut() = self.status;
r.headers_mut().insert(header::LOCATION, value);
r
}
}
fn full(status: StatusCode, content_type: &'static str, body: impl Into<Bytes>) -> Response {
let mut r = http::Response::new(JcBody::full(body));
*r.status_mut() = status;
r.headers_mut()
.insert(header::CONTENT_TYPE, HeaderValue::from_static(content_type));
r
}
fn json_body<T: Serialize>(status: StatusCode, value: &T) -> Response {
match serde_json::to_vec(value) {
Ok(bytes) => full(status, "application/json", bytes),
Err(e) => Error::internal(format!("response serialization failed: {e}")).into_response(),
}
}
/// An already-built `Response` is returned unchanged.
impl IntoResponse for Response {
fn into_response(self) -> Response {
self
}
}
impl IntoResponse for &'static str {
fn into_response(self) -> Response {
full(
StatusCode::OK,
"text/plain; charset=utf-8",
self.as_bytes().to_vec(),
)
}
}
impl IntoResponse for String {
    /// Renders an owned string as a `200 OK` `text/plain; charset=utf-8` response.
    /// The string itself is handed to the body.
    fn into_response(self) -> Response {
        full(StatusCode::OK, "text/plain; charset=utf-8", self)
    }
}
impl IntoResponse for StatusCode {
fn into_response(self) -> Response {
let mut r = http::Response::new(JcBody::empty());
*r.status_mut() = self;
r
}
}
impl<T: Serialize> IntoResponse for Json<T> {
fn into_response(self) -> Response {
json_body(StatusCode::OK, &self.0)
}
}
impl<T: Serialize> IntoResponse for Created<T> {
fn into_response(self) -> Response {
json_body(StatusCode::CREATED, &self.0)
}
}
impl IntoResponse for NoContent {
fn into_response(self) -> Response {
let mut r = http::Response::new(JcBody::empty());
*r.status_mut() = StatusCode::NO_CONTENT;
r
}
}
impl<T: IntoResponse> IntoResponse for (StatusCode, T) {
fn into_response(self) -> Response {
let (status, inner) = self;
let mut r = inner.into_response();
*r.status_mut() = status;
r
}
}
/// JSON shape of an error response: `code` and `message`, plus `details` when present.
#[derive(Serialize)]
struct ErrorBody<'a> {
    code: &'a str,
    message: &'a str,
    // Omitted from the output entirely (rather than written as `null`) when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    details: Option<&'a serde_json::Value>,
}

impl IntoResponse for Error {
    /// Renders the error as JSON with the error's own HTTP status.
    fn into_response(self) -> Response {
        let status = self.status();
        let payload = ErrorBody {
            code: self.code(),
            message: self.message(),
            details: self.details(),
        };
        json_body(status, &payload)
    }
}
impl<T: IntoResponse> IntoResponse for crate::Result<T> {
    /// Renders the success value, or renders the error when the result is `Err`.
    fn into_response(self) -> Response {
        self.map_or_else(|err| err.into_response(), |value| value.into_response())
    }
}
/// A streaming response body fed by a fallible stream of byte chunks.
///
/// Create one with `StreamBody::new` or `StreamBody::channel`. Then adjust the headers
/// and the per-frame timeout with the builder methods.
pub struct StreamBody {
/// Source of body chunks. An `Err` item is surfaced as a body error.
stream: std::pin::Pin<
Box<dyn futures_core::Stream<Item = Result<Bytes, Error>> + Send + Sync + 'static>,
>,
/// Value written to the `Content-Type` header.
content_type: HeaderValue,
/// Optional value for the `Content-Disposition` header.
attachment: Option<HeaderValue>,
/// Longest time the body waits for the next chunk before erroring.
frame_timeout: std::time::Duration,
}
impl StreamBody {
    /// Per-frame stall limit (30 seconds) used unless `frame_timeout` overrides it.
    pub const DEFAULT_FRAME_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    /// Wraps a stream of byte chunks as a streaming response body.
    ///
    /// Defaults: `Content-Type: application/octet-stream`, no `Content-Disposition`,
    /// and `DEFAULT_FRAME_TIMEOUT` as the wait limit for each frame.
    pub fn new<S>(stream: S) -> Self
    where
        S: futures_core::Stream<Item = Result<Bytes, Error>> + Send + Sync + 'static,
    {
        Self {
            stream: Box::pin(stream),
            content_type: HeaderValue::from_static("application/octet-stream"),
            attachment: None,
            frame_timeout: Self::DEFAULT_FRAME_TIMEOUT,
        }
    }

    /// Creates a body fed by a bounded channel (capacity 16), plus the sender that feeds it.
    ///
    /// The body ends when the sender is dropped. `BodySender::fail` ends it with an error.
    pub fn channel() -> (Self, BodySender) {
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, Error>>(16);
        (Self::new(ReceiverStream(rx)), BodySender(tx))
    }

    /// Overrides the `Content-Type` header.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not a valid HTTP header value, for example if it contains
    /// a newline or another control character.
    pub fn content_type(mut self, value: &str) -> Self {
        self.content_type =
            HeaderValue::from_str(value).expect("content_type must be a valid header value");
        self
    }

    /// Marks the response as a download named `filename` via `Content-Disposition`.
    ///
    /// `"`, `\` and control characters are stripped first, so the name cannot escape
    /// the quoted string or inject extra header lines.
    ///
    /// * An ASCII name produces `attachment; filename="<name>"`.
    /// * A name containing non-ASCII characters also gets an RFC 6266 / RFC 8187
    ///   `filename*=UTF-8''<percent-encoded name>` parameter carrying the exact name.
    ///   The plain `filename` fallback has each non-ASCII character replaced by `_`,
    ///   for clients that ignore `filename*`.
    pub fn attachment(mut self, filename: &str) -> Self {
        /// Percent-encodes every byte outside RFC 8187's `attr-char` set.
        fn encode_ext_value(value: &str) -> String {
            use std::fmt::Write as _;
            let mut out = String::with_capacity(value.len() * 3);
            for byte in value.bytes() {
                if byte.is_ascii_alphanumeric() || b"!#$&+-.^_`|~".contains(&byte) {
                    out.push(char::from(byte));
                } else {
                    // Writing into a `String` cannot fail.
                    let _ = write!(out, "%{byte:02X}");
                }
            }
            out
        }

        let safe: String = filename
            .chars()
            .filter(|c| *c != '"' && *c != '\\' && !c.is_control())
            .collect();
        let disposition = if safe.is_ascii() {
            format!("attachment; filename=\"{safe}\"")
        } else {
            // Raw UTF-8 in `filename="…"` is obs-text that clients decode inconsistently.
            // Send an ASCII fallback plus the exact name in the `filename*` ext-value.
            let fallback: String = safe
                .chars()
                .map(|c| if c.is_ascii() { c } else { '_' })
                .collect();
            format!(
                "attachment; filename=\"{fallback}\"; filename*=UTF-8''{}",
                encode_ext_value(&safe)
            )
        };
        // The value is now printable ASCII with no control characters, so parsing cannot fail.
        self.attachment = Some(
            HeaderValue::from_str(&disposition)
                .expect("sanitized filename is a valid header value"),
        );
        self
    }

    /// Sets how long the body may wait for its next chunk before it errors.
    pub fn frame_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.frame_timeout = timeout;
        self
    }
}
/// Producer half returned by `StreamBody::channel`. Each call pushes one item to the body.
pub struct BodySender(tokio::sync::mpsc::Sender<Result<Bytes, Error>>);

impl BodySender {
    /// Queues `chunk` as the next body frame, waiting while the channel is full.
    ///
    /// Returns `false` if the receiving side has been dropped.
    pub async fn send(&self, chunk: impl Into<Bytes>) -> bool {
        let item = Ok(chunk.into());
        match self.0.send(item).await {
            Ok(()) => true,
            Err(_) => false,
        }
    }

    /// Pushes `error` as the final item, consuming the sender.
    ///
    /// Returns `false` if the receiving side has been dropped.
    pub async fn fail(self, error: Error) -> bool {
        let Self(tx) = self;
        tx.send(Err(error)).await.is_ok()
    }
}
/// Adapts a tokio `mpsc::Receiver` into a `futures_core::Stream`.
struct ReceiverStream(tokio::sync::mpsc::Receiver<Result<Bytes, Error>>);

impl futures_core::Stream for ReceiverStream {
    type Item = Result<Bytes, Error>;

    /// Forwards to `Receiver::poll_recv`.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let Self(receiver) = &mut *self;
        receiver.poll_recv(cx)
    }
}
/// Adapts a chunk stream into an `http_body::Body`. The body errors if the stream
/// stays pending for longer than `timeout` while the next frame is awaited.
struct TimedFrames {
// Pinned, boxed source of chunks (moved in from `StreamBody`).
stream: std::pin::Pin<
Box<dyn futures_core::Stream<Item = Result<Bytes, Error>> + Send + Sync + 'static>,
>,
// Maximum time to wait for any single chunk.
timeout: std::time::Duration,
// Deadline timer for the frame currently awaited; `None` while no wait is in progress.
sleep: Option<std::pin::Pin<Box<tokio::time::Sleep>>>,
}
impl http_body::Body for TimedFrames {
type Data = Bytes;
type Error = BodyError;
/// Yields the next chunk as a data frame. A producer error, or a wait longer
/// than `timeout`, is turned into a `BodyError`.
fn poll_frame(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<http_body::Frame<Bytes>, BodyError>>> {
use std::future::Future;
use std::task::Poll;
// Poll the source first, so a chunk that is already available wins over an expired timer.
match self.stream.as_mut().poll_next(cx) {
// A chunk arrived: clear the timer so the next frame gets a fresh `timeout`.
Poll::Ready(Some(Ok(chunk))) => {
self.sleep = None;
Poll::Ready(Some(Ok(http_body::Frame::data(chunk))))
}
// Producer-reported error: surface it as a body error that keeps its message.
Poll::Ready(Some(Err(e))) => {
self.sleep = None;
Poll::Ready(Some(Err(BodyError::new(format!(
"response stream failed: {e}"
)))))
}
// Source exhausted: end the body.
// NOTE(review): a caller that keeps polling after this re-polls the source. That is
// harmless for the mpsc-backed stream; confirm for other sources.
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => {
// Copied out first so the closure below does not borrow `self` while
// `self.sleep` is mutably borrowed.
let timeout = self.timeout;
// Arm the timer on the first pending poll for this frame and reuse it on later
// polls, so the deadline is measured from when the wait began.
let sleep = self
.sleep
.get_or_insert_with(|| Box::pin(tokio::time::sleep(timeout)));
match sleep.as_mut().poll(cx) {
// Deadline passed before a chunk arrived: error the body.
Poll::Ready(()) => {
self.sleep = None;
Poll::Ready(Some(Err(BodyError::new(
"response stream timed out producing the next chunk",
))))
}
// Both the stream and the timer returned Pending with `cx`, so either can wake us.
Poll::Pending => Poll::Pending,
}
}
}
}
}
impl IntoResponse for StreamBody {
fn into_response(self) -> Response {
let body = JcBody::stream(TimedFrames {
stream: self.stream,
timeout: self.frame_timeout,
sleep: None,
});
let mut r = http::Response::new(body);
r.headers_mut()
.insert(header::CONTENT_TYPE, self.content_type);
if let Some(disposition) = self.attachment {
r.headers_mut()
.insert(header::CONTENT_DISPOSITION, disposition);
}
r
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects an already-buffered response body into a UTF-8 string.
    fn body_of(r: Response) -> String {
        let collected = futures_executor_lite(r.into_body());
        String::from_utf8(collected.to_vec()).unwrap()
    }

    /// Polls `body.collect()` exactly once with a no-op waker.
    ///
    /// Only suitable for bodies whose frames are all immediately available
    /// (`JcBody::full` / `JcBody::empty`). Anything else panics.
    fn futures_executor_lite(body: JcBody) -> Bytes {
        let fut = body.collect();
        let mut fut = Box::pin(fut);
        let waker = std::task::Waker::noop();
        let mut cx = std::task::Context::from_waker(waker);
        match fut.as_mut().poll(&mut cx) {
            std::task::Poll::Ready(Ok(c)) => c.to_bytes(),
            _ => panic!("buffered body was not immediately ready"),
        }
    }

    #[test]
    fn str_becomes_200_text() {
        let r = "hello".into_response();
        assert_eq!(r.status(), StatusCode::OK);
        assert_eq!(
            r.headers()[header::CONTENT_TYPE],
            "text/plain; charset=utf-8"
        );
        assert_eq!(body_of(r), "hello");
    }

    #[test]
    fn json_wrapper_sets_content_type() {
        #[derive(Serialize)]
        struct Todo {
            id: u32,
        }
        let r = Json(Todo { id: 7 }).into_response();
        assert_eq!(r.status(), StatusCode::OK);
        assert_eq!(r.headers()[header::CONTENT_TYPE], "application/json");
        assert_eq!(body_of(r), r#"{"id":7}"#);
    }

    #[test]
    fn created_is_201_and_no_content_is_204() {
        #[derive(Serialize)]
        struct T {
            ok: bool,
        }
        assert_eq!(
            Created(T { ok: true }).into_response().status(),
            StatusCode::CREATED
        );
        let r = NoContent.into_response();
        assert_eq!(r.status(), StatusCode::NO_CONTENT);
        assert_eq!(body_of(r), "");
    }

    #[test]
    fn errors_render_code_and_message_json() {
        let r = Error::not_found().into_response();
        assert_eq!(r.status(), StatusCode::NOT_FOUND);
        assert_eq!(body_of(r), r#"{"code":"JC0404","message":"not found"}"#);
    }

    #[test]
    fn error_details_appear_in_the_body_only_when_present() {
        let r = Error::not_found().into_response();
        assert_eq!(body_of(r), r#"{"code":"JC0404","message":"not found"}"#);
        let r = Error::unprocessable("validation failed")
            .with_details(serde_json::json!([{ "field": "t" }]))
            .into_response();
        assert_eq!(
            body_of(r),
            r#"{"code":"JC0422","message":"validation failed","details":[{"field":"t"}]}"#
        );
    }

    #[test]
    fn result_renders_ok_or_err() {
        let ok: crate::Result<&'static str> = Ok("fine");
        assert_eq!(ok.into_response().status(), StatusCode::OK);
        let err: crate::Result<&'static str> = Err(Error::bad_request("x"));
        assert_eq!(err.into_response().status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn redirect_to_is_302_with_location_and_empty_body() {
        let r = Redirect::to("/x").into_response();
        assert_eq!(r.status(), StatusCode::FOUND);
        assert_eq!(r.headers()[header::LOCATION], "/x");
        assert_eq!(body_of(r), "");
    }

    #[test]
    fn redirect_constructors_set_their_status_and_location() {
        // Each constructor must produce its own status, the exact location passed in,
        // and an empty body. Checking only that a `Location` key exists would miss
        // a wrong value.
        let cases = [
            (Redirect::see_other("/a"), StatusCode::SEE_OTHER, "/a"),
            (Redirect::temporary("/b"), StatusCode::TEMPORARY_REDIRECT, "/b"),
            (Redirect::permanent("/c"), StatusCode::PERMANENT_REDIRECT, "/c"),
        ];
        for (redirect, status, location) in cases {
            let r = redirect.into_response();
            assert_eq!(r.status(), status);
            assert_eq!(r.headers()[header::LOCATION], location);
            assert_eq!(body_of(r), "");
        }
    }

    #[test]
    fn redirect_with_invalid_location_is_a_non_panicking_500() {
        let r = Redirect::to("/bad\nlocation").into_response();
        assert_eq!(r.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(r.headers().get(header::LOCATION).is_none());
    }

    #[test]
    fn status_tuple_overrides_status_keeping_the_json_body() {
        #[derive(Serialize)]
        struct Summary {
            queued: u32,
        }
        let r = (StatusCode::ACCEPTED, Json(Summary { queued: 3 })).into_response();
        assert_eq!(r.status(), StatusCode::ACCEPTED);
        assert_eq!(r.headers()[header::CONTENT_TYPE], "application/json");
        assert_eq!(body_of(r), r#"{"queued":3}"#);
    }

    #[test]
    fn status_tuple_overrides_status_keeping_the_text_body() {
        let r = (StatusCode::ACCEPTED, "queued").into_response();
        assert_eq!(r.status(), StatusCode::ACCEPTED);
        assert_eq!(
            r.headers()[header::CONTENT_TYPE],
            "text/plain; charset=utf-8"
        );
        assert_eq!(body_of(r), "queued");
    }

    #[tokio::test]
    async fn boxed_bodies_stream_and_collect() {
        struct Chunks(std::collections::VecDeque<Bytes>);
        impl http_body::Body for Chunks {
            type Data = Bytes;
            type Error = std::convert::Infallible;
            fn poll_frame(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Option<Result<http_body::Frame<Bytes>, Self::Error>>> {
                std::task::Poll::Ready(self.0.pop_front().map(|b| Ok(http_body::Frame::data(b))))
            }
        }
        let body = JcBody::stream(Chunks(
            [Bytes::from("ab"), Bytes::from("cd")].into_iter().collect(),
        ));
        let collected = body.collect().await.unwrap().to_bytes();
        assert_eq!(collected, Bytes::from("abcd"));
    }

    #[tokio::test]
    async fn stream_body_streams_with_content_type_and_disposition() {
        let (body, tx) = StreamBody::channel();
        let send = async move {
            assert!(tx.send("a,b\n").await);
            assert!(tx.send("1,2\n").await);
        };
        let r = body
            .content_type("text/csv")
            .attachment("export.csv")
            .into_response();
        assert_eq!(r.status(), StatusCode::OK);
        assert_eq!(r.headers()[header::CONTENT_TYPE], "text/csv");
        assert_eq!(
            r.headers()[header::CONTENT_DISPOSITION],
            "attachment; filename=\"export.csv\""
        );
        let (_, collected) = tokio::join!(send, r.into_body().collect());
        assert_eq!(collected.unwrap().to_bytes(), Bytes::from("a,b\n1,2\n"));
    }

    #[tokio::test(start_paused = true)]
    async fn stream_body_frame_timeout_errors_the_body() {
        struct Never;
        impl futures_core::Stream for Never {
            type Item = Result<Bytes, Error>;
            fn poll_next(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Option<Self::Item>> {
                std::task::Poll::Pending
            }
        }
        let body = StreamBody::new(Never)
            .frame_timeout(std::time::Duration::from_millis(100))
            .into_response()
            .into_body();
        let err = body
            .collect()
            .await
            .expect_err("stall must error, not end cleanly");
        assert!(err.to_string().contains("timed out"), "{err}");
    }

    #[tokio::test]
    async fn channel_fail_surfaces_as_a_body_error_carrying_the_message() {
        let (body, tx) = StreamBody::channel();
        let produce = async move {
            assert!(tx.send("first chunk").await, "client present");
            assert!(tx.fail(Error::internal("boom")).await, "fail delivered");
        };
        let response = body.into_response();
        let (_, collected) = tokio::join!(produce, response.into_body().collect());
        let err = collected.expect_err("a failed producer must error the body, not end cleanly");
        assert!(
            err.to_string().contains("boom"),
            "the propagated message must survive to the body error: {err}"
        );
    }

    #[tokio::test]
    async fn stream_body_composes_through_a_real_handler_dispatch() {
        use crate::prelude::*;
        async fn export() -> Result<StreamBody> {
            let (body, tx) = StreamBody::channel();
            tokio::spawn(async move {
                tx.send("id,name\n").await;
                tx.send("1,ada\n").await;
            });
            Ok(body.content_type("text/csv"))
        }
        let t = App::new().route("/export", get(export)).into_test();
        let r = t.get("/export").await;
        assert_eq!(r.status(), StatusCode::OK);
        assert_eq!(r.headers()[header::CONTENT_TYPE], "text/csv");
        assert_eq!(r.text(), "id,name\n1,ada\n");
    }
}