// rama_http/service/web/endpoint/extract/body/text.rs

1use super::BytesRejection;
2use crate::dep::http_body_util::BodyExt;
3use crate::service::web::extract::FromRequest;
4use crate::utils::macros::{composite_http_rejection, define_http_rejection};
5use crate::Request;
6use rama_utils::macros::impl_deref;
7
8/// Extractor to get the response body, collected as [`String`].
9#[derive(Debug, Clone)]
10pub struct Text(pub String);
11
12impl_deref!(Text: String);
13
14define_http_rejection! {
15    #[status = UNSUPPORTED_MEDIA_TYPE]
16    #[body = "Text requests must have `Content-Type: text/plain`"]
17    /// Rejection type for [`Text`]
18    /// used if the `Content-Type` header is missing
19    /// or its value is not `text/plain`.
20    pub struct InvalidTextContentType;
21}
22
23define_http_rejection! {
24    #[status = BAD_REQUEST]
25    #[body = "Failed to decode text payload"]
26    /// Rejection type used if the [`Text`]
27    /// was not valid UTF-8.
28    pub struct InvalidUtf8Text(Error);
29}
30
31composite_http_rejection! {
32    /// Rejection used for [`Text`]
33    ///
34    /// Contains one variant for each way the [`Text`] extractor
35    /// can fail.
36    pub enum TextRejection {
37        InvalidTextContentType,
38        InvalidUtf8Text,
39        BytesRejection,
40    }
41}
42
43impl FromRequest for Text {
44    type Rejection = TextRejection;
45
46    async fn from_request(req: Request) -> Result<Self, Self::Rejection> {
47        if !crate::service::web::extract::has_any_content_type(req.headers(), &[&mime::TEXT_PLAIN])
48        {
49            return Err(InvalidTextContentType.into());
50        }
51
52        let body = req.into_body();
53        match body.collect().await {
54            Ok(c) => match String::from_utf8(c.to_bytes().to_vec()) {
55                Ok(s) => Ok(Self(s)),
56                Err(err) => Err(InvalidUtf8Text::from_err(err).into()),
57            },
58            Err(err) => Err(BytesRejection::from_err(err).into()),
59        }
60    }
61}
62
#[cfg(test)]
mod test {
    use super::*;
    use crate::service::web::WebService;
    use crate::{header, Method, Request, StatusCode};
    use rama_core::{Context, Service};

    /// A `text/plain` request with a UTF-8 body reaches the handler unchanged.
    #[tokio::test]
    async fn test_text() {
        let svc = WebService::default().post("/", |Text(text): Text| async move {
            assert_eq!(text, "test");
        });

        let request = Request::builder()
            .method(Method::POST)
            .header(header::CONTENT_TYPE, "text/plain")
            .body("test".into())
            .unwrap();

        let status = svc
            .serve(Context::default(), request)
            .await
            .unwrap()
            .status();
        assert_eq!(StatusCode::OK, status);
    }

    /// Without any `Content-Type` header the extractor rejects the request.
    #[tokio::test]
    async fn test_text_missing_content_type() {
        let svc = WebService::default().post("/", |Text(_): Text| async move { StatusCode::OK });

        let request = Request::builder()
            .method(Method::POST)
            .body("test".into())
            .unwrap();

        let status = svc
            .serve(Context::default(), request)
            .await
            .unwrap()
            .status();
        assert_eq!(StatusCode::UNSUPPORTED_MEDIA_TYPE, status);
    }

    /// A non-`text/plain` `Content-Type` is rejected as well.
    #[tokio::test]
    async fn test_text_incorrect_content_type() {
        let svc = WebService::default().post("/", |Text(_): Text| async move { StatusCode::OK });

        let request = Request::builder()
            .method(Method::POST)
            .header(header::CONTENT_TYPE, "application/json")
            .body("test".into())
            .unwrap();

        let status = svc
            .serve(Context::default(), request)
            .await
            .unwrap()
            .status();
        assert_eq!(StatusCode::UNSUPPORTED_MEDIA_TYPE, status);
    }

    /// A `text/plain` body that is not valid UTF-8 yields a bad request.
    #[tokio::test]
    async fn test_text_invalid_utf8() {
        let svc = WebService::default().post("/", |Text(_): Text| async move { StatusCode::OK });

        // 0x9F is a UTF-8 continuation byte with no leading byte before it.
        let invalid_utf8 = vec![0x00_u8, 0x9F, 0x92, 0x96];
        let request = Request::builder()
            .method(Method::POST)
            .header(header::CONTENT_TYPE, "text/plain")
            .body(invalid_utf8.into())
            .unwrap();

        let status = svc
            .serve(Context::default(), request)
            .await
            .unwrap()
            .status();
        assert_eq!(StatusCode::BAD_REQUEST, status);
    }
}