1use crate::error::Error;
5use bytes::Bytes;
6use http::{HeaderValue, StatusCode, header};
7use http_body_util::{BodyExt, Full, combinators::BoxBody};
8use serde::Serialize;
9
/// Error type carried by every [`JcBody`] frame.
///
/// It wraps a single human-readable message: `Display` writes that message
/// verbatim, and the derived `Debug` shows the tuple form `BodyError("...")`.
#[derive(Debug)]
pub struct BodyError(String);

impl BodyError {
    /// Builds a body error from anything convertible into a `String`.
    pub(crate) fn new(message: impl Into<String>) -> Self {
        let text: String = message.into();
        BodyError(text)
    }
}

impl std::fmt::Display for BodyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write_str` writes the raw text, so width/fill flags are ignored.
        let BodyError(text) = self;
        f.write_str(text)
    }
}

impl std::error::Error for BodyError {}

/// Lets infallible bodies be mapped into `BodyError`; the conversion can never run.
impl From<std::convert::Infallible> for BodyError {
    fn from(never: std::convert::Infallible) -> Self {
        match never {}
    }
}
35
36pub struct JcBody(BoxBody<Bytes, BodyError>);
41
42impl JcBody {
43 pub fn full(bytes: impl Into<Bytes>) -> Self {
45 Self(Full::new(bytes.into()).map_err(BodyError::from).boxed())
46 }
47
48 pub fn empty() -> Self {
50 Self::full(Bytes::new())
51 }
52
53 pub fn stream<B>(body: B) -> Self
57 where
58 B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
59 B::Error: Into<BodyError>,
60 {
61 Self(body.map_err(Into::into).boxed())
62 }
63}
64
65impl http_body::Body for JcBody {
66 type Data = Bytes;
67 type Error = BodyError;
68
69 fn poll_frame(
70 mut self: std::pin::Pin<&mut Self>,
71 cx: &mut std::task::Context<'_>,
72 ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
73 std::pin::Pin::new(&mut self.0).poll_frame(cx)
74 }
75
76 fn is_end_stream(&self) -> bool {
77 self.0.is_end_stream()
78 }
79
80 fn size_hint(&self) -> http_body::SizeHint {
81 self.0.size_hint()
82 }
83}
84
/// The HTTP response type produced by every [`IntoResponse`] implementation.
pub type Response = http::Response<JcBody>;

/// Conversion of a handler's return value into a complete [`Response`].
pub trait IntoResponse {
    /// Consumes `self` and renders it as a response.
    fn into_response(self) -> Response;
}

/// Serializes the wrapped value as JSON with status `200 OK` and
/// `content-type: application/json`; a serialization failure is rendered as an
/// internal error response instead.
pub struct Json<T>(pub T);

/// Like [`Json`], but with status `201 Created`.
pub struct Created<T>(pub T);

/// An empty-bodied `204 No Content` response.
pub struct NoContent;
102
103pub struct Redirect {
107 status: StatusCode,
108 location: String,
109}
110
111impl Redirect {
112 pub fn to(location: impl Into<String>) -> Self {
115 Self {
116 status: StatusCode::FOUND,
117 location: location.into(),
118 }
119 }
120
121 pub fn see_other(location: impl Into<String>) -> Self {
123 Self {
124 status: StatusCode::SEE_OTHER,
125 location: location.into(),
126 }
127 }
128
129 pub fn temporary(location: impl Into<String>) -> Self {
131 Self {
132 status: StatusCode::TEMPORARY_REDIRECT,
133 location: location.into(),
134 }
135 }
136
137 pub fn permanent(location: impl Into<String>) -> Self {
139 Self {
140 status: StatusCode::PERMANENT_REDIRECT,
141 location: location.into(),
142 }
143 }
144}
145
146impl IntoResponse for Redirect {
147 fn into_response(self) -> Response {
148 let value = match HeaderValue::from_str(&self.location) {
152 Ok(v) => v,
153 Err(_) => {
154 return Error::internal("redirect location is not a valid header value")
155 .into_response();
156 }
157 };
158 let mut r = http::Response::new(JcBody::empty());
159 *r.status_mut() = self.status;
160 r.headers_mut().insert(header::LOCATION, value);
161 r
162 }
163}
164
165fn full(status: StatusCode, content_type: &'static str, body: impl Into<Bytes>) -> Response {
166 let mut r = http::Response::new(JcBody::full(body));
167 *r.status_mut() = status;
168 r.headers_mut()
169 .insert(header::CONTENT_TYPE, HeaderValue::from_static(content_type));
170 r
171}
172
173fn json_body<T: Serialize>(status: StatusCode, value: &T) -> Response {
174 match serde_json::to_vec(value) {
175 Ok(bytes) => full(status, "application/json", bytes),
176 Err(e) => Error::internal(format!("response serialization failed: {e}")).into_response(),
177 }
178}
179
180impl IntoResponse for Response {
181 fn into_response(self) -> Response {
182 self
183 }
184}
185
186impl IntoResponse for &'static str {
187 fn into_response(self) -> Response {
188 full(
189 StatusCode::OK,
190 "text/plain; charset=utf-8",
191 self.as_bytes().to_vec(),
192 )
193 }
194}
195
196impl IntoResponse for String {
197 fn into_response(self) -> Response {
198 full(
199 StatusCode::OK,
200 "text/plain; charset=utf-8",
201 self.into_bytes(),
202 )
203 }
204}
205
206impl IntoResponse for StatusCode {
207 fn into_response(self) -> Response {
208 let mut r = http::Response::new(JcBody::empty());
209 *r.status_mut() = self;
210 r
211 }
212}
213
214impl<T: Serialize> IntoResponse for Json<T> {
215 fn into_response(self) -> Response {
216 json_body(StatusCode::OK, &self.0)
217 }
218}
219
220impl<T: Serialize> IntoResponse for Created<T> {
221 fn into_response(self) -> Response {
222 json_body(StatusCode::CREATED, &self.0)
223 }
224}
225
226impl IntoResponse for NoContent {
227 fn into_response(self) -> Response {
228 let mut r = http::Response::new(JcBody::empty());
229 *r.status_mut() = StatusCode::NO_CONTENT;
230 r
231 }
232}
233
234impl<T: IntoResponse> IntoResponse for (StatusCode, T) {
239 fn into_response(self) -> Response {
240 let (status, inner) = self;
241 let mut r = inner.into_response();
242 *r.status_mut() = status;
243 r
244 }
245}
246
247#[derive(Serialize)]
248struct ErrorBody<'a> {
249 code: &'a str,
250 message: &'a str,
251 #[serde(skip_serializing_if = "Option::is_none")]
252 details: Option<&'a serde_json::Value>,
253}
254
255impl IntoResponse for Error {
256 fn into_response(self) -> Response {
257 json_body(
258 self.status(),
259 &ErrorBody {
260 code: self.code(),
261 message: self.message(),
262 details: self.details(),
263 },
264 )
265 }
266}
267
268impl<T: IntoResponse> IntoResponse for crate::Result<T> {
269 fn into_response(self) -> Response {
270 match self {
271 Ok(v) => v.into_response(),
272 Err(e) => e.into_response(),
273 }
274 }
275}
276
277pub struct StreamBody {
281 stream: std::pin::Pin<
282 Box<dyn futures_core::Stream<Item = Result<Bytes, Error>> + Send + Sync + 'static>,
283 >,
284 content_type: HeaderValue,
285 attachment: Option<HeaderValue>,
286 frame_timeout: std::time::Duration,
287}
288
289impl StreamBody {
290 pub const DEFAULT_FRAME_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
292
293 pub fn new<S>(stream: S) -> Self
296 where
297 S: futures_core::Stream<Item = Result<Bytes, Error>> + Send + Sync + 'static,
298 {
299 Self {
300 stream: Box::pin(stream),
301 content_type: HeaderValue::from_static("application/octet-stream"),
302 attachment: None,
303 frame_timeout: Self::DEFAULT_FRAME_TIMEOUT,
304 }
305 }
306
307 pub fn channel() -> (Self, BodySender) {
311 let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, Error>>(16);
313 (Self::new(ReceiverStream(rx)), BodySender(tx))
314 }
315
316 pub fn content_type(mut self, value: &str) -> Self {
319 self.content_type =
320 HeaderValue::from_str(value).expect("content_type must be a valid header value");
321 self
322 }
323
324 pub fn attachment(mut self, filename: &str) -> Self {
330 let safe: String = filename
331 .chars()
332 .filter(|c| *c != '"' && *c != '\\' && !c.is_control())
333 .collect();
334 self.attachment = Some(
335 HeaderValue::from_str(&format!("attachment; filename=\"{safe}\""))
336 .expect("sanitized filename is a valid header value"),
337 );
338 self
339 }
340
341 pub fn frame_timeout(mut self, timeout: std::time::Duration) -> Self {
344 self.frame_timeout = timeout;
345 self
346 }
347}
348
349pub struct BodySender(tokio::sync::mpsc::Sender<Result<Bytes, Error>>);
351
352impl BodySender {
353 pub async fn send(&self, chunk: impl Into<Bytes>) -> bool {
355 self.0.send(Ok(chunk.into())).await.is_ok()
356 }
357 pub async fn fail(self, error: Error) -> bool {
360 self.0.send(Err(error)).await.is_ok()
361 }
362}
363
364struct ReceiverStream(tokio::sync::mpsc::Receiver<Result<Bytes, Error>>);
366
367impl futures_core::Stream for ReceiverStream {
368 type Item = Result<Bytes, Error>;
369 fn poll_next(
370 mut self: std::pin::Pin<&mut Self>,
371 cx: &mut std::task::Context<'_>,
372 ) -> std::task::Poll<Option<Self::Item>> {
373 self.0.poll_recv(cx)
374 }
375}
376
377struct TimedFrames {
381 stream: std::pin::Pin<
382 Box<dyn futures_core::Stream<Item = Result<Bytes, Error>> + Send + Sync + 'static>,
383 >,
384 timeout: std::time::Duration,
385 sleep: Option<std::pin::Pin<Box<tokio::time::Sleep>>>,
386}
387
388impl http_body::Body for TimedFrames {
389 type Data = Bytes;
390 type Error = BodyError;
391
392 fn poll_frame(
393 mut self: std::pin::Pin<&mut Self>,
394 cx: &mut std::task::Context<'_>,
395 ) -> std::task::Poll<Option<Result<http_body::Frame<Bytes>, BodyError>>> {
396 use std::future::Future;
397 use std::task::Poll;
398 match self.stream.as_mut().poll_next(cx) {
399 Poll::Ready(Some(Ok(chunk))) => {
400 self.sleep = None;
401 Poll::Ready(Some(Ok(http_body::Frame::data(chunk))))
402 }
403 Poll::Ready(Some(Err(e))) => {
404 self.sleep = None;
405 Poll::Ready(Some(Err(BodyError::new(format!(
406 "response stream failed: {e}"
407 )))))
408 }
409 Poll::Ready(None) => Poll::Ready(None),
410 Poll::Pending => {
411 let timeout = self.timeout;
412 let sleep = self
413 .sleep
414 .get_or_insert_with(|| Box::pin(tokio::time::sleep(timeout)));
415 match sleep.as_mut().poll(cx) {
416 Poll::Ready(()) => {
417 self.sleep = None;
418 Poll::Ready(Some(Err(BodyError::new(
419 "response stream timed out producing the next chunk",
420 ))))
421 }
422 Poll::Pending => Poll::Pending,
423 }
424 }
425 }
426 }
427}
428
429impl IntoResponse for StreamBody {
430 fn into_response(self) -> Response {
431 let body = JcBody::stream(TimedFrames {
432 stream: self.stream,
433 timeout: self.frame_timeout,
434 sleep: None,
435 });
436 let mut r = http::Response::new(body);
437 r.headers_mut()
438 .insert(header::CONTENT_TYPE, self.content_type);
439 if let Some(disposition) = self.attachment {
440 r.headers_mut()
441 .insert(header::CONTENT_DISPOSITION, disposition);
442 }
443 r
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects a buffered (non-streaming) response body into a UTF-8 string.
    fn body_of(r: Response) -> String {
        let collected = futures_executor_lite(r.into_body());
        String::from_utf8(collected.to_vec()).unwrap()
    }

    /// Polls `body.collect()` exactly once with a no-op waker.
    ///
    /// Only suitable for bodies that are complete on the first poll, such as
    /// those built with `JcBody::full`/`JcBody::empty`; anything else panics.
    fn futures_executor_lite(body: JcBody) -> Bytes {
        let fut = body.collect();
        let mut fut = Box::pin(fut);
        let waker = std::task::Waker::noop();
        let mut cx = std::task::Context::from_waker(waker);
        match fut.as_mut().poll(&mut cx) {
            std::task::Poll::Ready(Ok(c)) => c.to_bytes(),
            _ => panic!("buffered body was not immediately ready"),
        }
    }

    #[test]
    fn str_becomes_200_text() {
        let r = "hello".into_response();
        assert_eq!(r.status(), StatusCode::OK);
        assert_eq!(
            r.headers()[header::CONTENT_TYPE],
            "text/plain; charset=utf-8"
        );
        assert_eq!(body_of(r), "hello");
    }

    #[test]
    fn json_wrapper_sets_content_type() {
        #[derive(Serialize)]
        struct Todo {
            id: u32,
        }
        let r = Json(Todo { id: 7 }).into_response();
        assert_eq!(r.status(), StatusCode::OK);
        assert_eq!(r.headers()[header::CONTENT_TYPE], "application/json");
        assert_eq!(body_of(r), r#"{"id":7}"#);
    }

    #[test]
    fn created_is_201_and_no_content_is_204() {
        #[derive(Serialize)]
        struct T {
            ok: bool,
        }
        assert_eq!(
            Created(T { ok: true }).into_response().status(),
            StatusCode::CREATED
        );
        let r = NoContent.into_response();
        assert_eq!(r.status(), StatusCode::NO_CONTENT);
        assert_eq!(body_of(r), "");
    }

    #[test]
    fn errors_render_code_and_message_json() {
        let r = Error::not_found().into_response();
        assert_eq!(r.status(), StatusCode::NOT_FOUND);
        assert_eq!(body_of(r), r#"{"code":"JC0404","message":"not found"}"#);
    }

    #[test]
    fn error_details_appear_in_the_body_only_when_present() {
        let r = Error::not_found().into_response();
        assert_eq!(body_of(r), r#"{"code":"JC0404","message":"not found"}"#);
        let r = Error::unprocessable("validation failed")
            .with_details(serde_json::json!([{ "field": "t" }]))
            .into_response();
        assert_eq!(
            body_of(r),
            r#"{"code":"JC0422","message":"validation failed","details":[{"field":"t"}]}"#
        );
    }

    #[test]
    fn result_renders_ok_or_err() {
        let ok: crate::Result<&'static str> = Ok("fine");
        assert_eq!(ok.into_response().status(), StatusCode::OK);
        let err: crate::Result<&'static str> = Err(Error::bad_request("x"));
        assert_eq!(err.into_response().status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn redirect_to_is_302_with_location_and_empty_body() {
        let r = Redirect::to("/x").into_response();
        assert_eq!(r.status(), StatusCode::FOUND);
        assert_eq!(r.headers()[header::LOCATION], "/x");
        assert_eq!(body_of(r), "");
    }

    #[test]
    fn redirect_constructors_set_their_status_and_location() {
        // Check the exact `Location` value for each constructor, not just that
        // the header is present.
        let cases = [
            (Redirect::see_other("/a"), StatusCode::SEE_OTHER, "/a"),
            (Redirect::temporary("/b"), StatusCode::TEMPORARY_REDIRECT, "/b"),
            (Redirect::permanent("/c"), StatusCode::PERMANENT_REDIRECT, "/c"),
        ];
        for (redirect, status, location) in cases {
            let r = redirect.into_response();
            assert_eq!(r.status(), status);
            assert_eq!(r.headers()[header::LOCATION], location);
        }
    }

    #[test]
    fn redirect_with_invalid_location_is_a_non_panicking_500() {
        // A newline cannot appear in a header value, so rendering must fall back
        // to an error response instead of panicking.
        let r = Redirect::to("/bad\nlocation").into_response();
        assert_eq!(r.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(r.headers().get(header::LOCATION).is_none());
    }

    #[test]
    fn status_tuple_overrides_status_keeping_the_json_body() {
        #[derive(Serialize)]
        struct Summary {
            queued: u32,
        }
        let r = (StatusCode::ACCEPTED, Json(Summary { queued: 3 })).into_response();
        assert_eq!(r.status(), StatusCode::ACCEPTED);
        assert_eq!(r.headers()[header::CONTENT_TYPE], "application/json");
        assert_eq!(body_of(r), r#"{"queued":3}"#);
    }

    #[test]
    fn status_tuple_overrides_status_keeping_the_text_body() {
        let r = (StatusCode::ACCEPTED, "queued").into_response();
        assert_eq!(r.status(), StatusCode::ACCEPTED);
        assert_eq!(
            r.headers()[header::CONTENT_TYPE],
            "text/plain; charset=utf-8"
        );
        assert_eq!(body_of(r), "queued");
    }

    #[tokio::test]
    async fn boxed_bodies_stream_and_collect() {
        // A minimal multi-frame body with an infallible error type.
        struct Chunks(std::collections::VecDeque<Bytes>);
        impl http_body::Body for Chunks {
            type Data = Bytes;
            type Error = std::convert::Infallible;
            fn poll_frame(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Option<Result<http_body::Frame<Bytes>, Self::Error>>> {
                std::task::Poll::Ready(self.0.pop_front().map(|b| Ok(http_body::Frame::data(b))))
            }
        }
        let body = JcBody::stream(Chunks(
            [Bytes::from("ab"), Bytes::from("cd")].into_iter().collect(),
        ));
        use http_body_util::BodyExt;
        let collected = body.collect().await.unwrap().to_bytes();
        assert_eq!(collected, Bytes::from("abcd"));
    }

    #[tokio::test]
    async fn stream_body_streams_with_content_type_and_disposition() {
        let (body, tx) = StreamBody::channel();
        let send = async move {
            assert!(tx.send("a,b\n").await);
            assert!(tx.send("1,2\n").await);
        };
        let r = body
            .content_type("text/csv")
            .attachment("export.csv")
            .into_response();
        assert_eq!(r.status(), StatusCode::OK);
        assert_eq!(r.headers()[header::CONTENT_TYPE], "text/csv");
        assert_eq!(
            r.headers()[header::CONTENT_DISPOSITION],
            "attachment; filename=\"export.csv\""
        );
        let (_, collected) = tokio::join!(send, r.into_body().collect());
        assert_eq!(collected.unwrap().to_bytes(), Bytes::from("a,b\n1,2\n"));
    }

    #[tokio::test]
    async fn attachment_strips_characters_that_could_break_the_header() {
        // Quotes, backslashes and CR/LF must not survive into the quoted filename.
        let (body, _tx) = StreamBody::channel();
        let r = body.attachment("re\"port\\\r\n.csv").into_response();
        assert_eq!(
            r.headers()[header::CONTENT_DISPOSITION],
            "attachment; filename=\"report.csv\""
        );
    }

    #[tokio::test(start_paused = true)]
    async fn stream_body_frame_timeout_errors_the_body() {
        // A stream that never yields, so only the frame timeout can end the body.
        struct Never;
        impl futures_core::Stream for Never {
            type Item = Result<Bytes, Error>;
            fn poll_next(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Option<Self::Item>> {
                std::task::Poll::Pending
            }
        }
        let body = StreamBody::new(Never)
            .frame_timeout(std::time::Duration::from_millis(100))
            .into_response()
            .into_body();
        use http_body_util::BodyExt;
        let err = body
            .collect()
            .await
            .expect_err("stall must error, not end cleanly");
        assert!(err.to_string().contains("timed out"), "{err}");
    }

    #[tokio::test]
    async fn channel_fail_surfaces_as_a_body_error_carrying_the_message() {
        let (body, tx) = StreamBody::channel();
        let produce = async move {
            assert!(tx.send("first chunk").await, "client present");
            assert!(tx.fail(Error::internal("boom")).await, "fail delivered");
        };
        let response = body.into_response();
        use http_body_util::BodyExt;
        let (_, collected) = tokio::join!(produce, response.into_body().collect());
        let err = collected.expect_err("a failed producer must error the body, not end cleanly");
        assert!(
            err.to_string().contains("boom"),
            "the propagated message must survive to the body error: {err}"
        );
    }

    #[tokio::test]
    async fn stream_body_composes_through_a_real_handler_dispatch() {
        use crate::prelude::*;
        async fn export() -> Result<StreamBody> {
            let (body, tx) = StreamBody::channel();
            tokio::spawn(async move {
                tx.send("id,name\n").await;
                tx.send("1,ada\n").await;
            });
            Ok(body.content_type("text/csv"))
        }
        let t = App::new().route("/export", get(export)).into_test();
        let r = t.get("/export").await;
        assert_eq!(r.status(), StatusCode::OK);
        assert_eq!(r.headers()[header::CONTENT_TYPE], "text/csv");
        assert_eq!(r.text(), "id,name\n1,ada\n");
    }
}