rama_http/service/web/endpoint/extract/body/text.rs

1use super::BytesRejection;
2use crate::Request;
3use crate::dep::http_body_util::BodyExt;
4use crate::service::web::extract::FromRequest;
5use crate::utils::macros::{composite_http_rejection, define_http_rejection};
6use rama_utils::macros::impl_deref;
7
8/// Extractor to get the response body, collected as [`String`].
9#[derive(Debug, Clone)]
10pub struct Text(pub String);
11
12impl_deref!(Text: String);
13
// Returned by `Text::from_request` when the content-type check fails,
// before the body is read.
define_http_rejection! {
    #[status = UNSUPPORTED_MEDIA_TYPE]
    #[body = "Text requests must have `Content-Type: text/plain`"]
    /// Rejection type for [`Text`]
    /// used if the `Content-Type` header is missing
    /// or its value is not `text/plain`.
    ///
    /// Responds with `415 Unsupported Media Type`.
    pub struct InvalidTextContentType;
}
22
23define_http_rejection! {
24    #[status = BAD_REQUEST]
25    #[body = "Failed to decode text payload"]
26    /// Rejection type used if the [`Text`]
27    /// was not valid UTF-8.
28    pub struct InvalidUtf8Text(Error);
29}
30
composite_http_rejection! {
    /// Rejection used for [`Text`].
    ///
    /// Contains one variant for each way the [`Text`] extractor
    /// can fail:
    ///
    /// - [`InvalidTextContentType`]: the request is not `text/plain`;
    /// - [`InvalidUtf8Text`]: the collected body is not valid UTF-8;
    /// - [`BytesRejection`]: collecting the request body failed.
    pub enum TextRejection {
        InvalidTextContentType,
        InvalidUtf8Text,
        BytesRejection,
    }
}
42
43impl FromRequest for Text {
44    type Rejection = TextRejection;
45
46    async fn from_request(req: Request) -> Result<Self, Self::Rejection> {
47        if !crate::service::web::extract::has_any_content_type(req.headers(), &[&mime::TEXT_PLAIN])
48        {
49            return Err(InvalidTextContentType.into());
50        }
51
52        let body = req.into_body();
53        match body.collect().await {
54            Ok(c) => match String::from_utf8(c.to_bytes().to_vec()) {
55                Ok(s) => Ok(Self(s)),
56                Err(err) => Err(InvalidUtf8Text::from_err(err).into()),
57            },
58            Err(err) => Err(BytesRejection::from_err(err).into()),
59        }
60    }
61}
62
#[cfg(test)]
mod test {
    use super::*;
    use crate::service::web::WebService;
    use crate::{Method, Request, StatusCode, header};
    use rama_core::{Context, Service};

    /// Sends `req` to a `POST /` route whose handler only extracts [`Text`]
    /// and answers `200 OK`, returning the status of the response.
    async fn status_of_text_route(req: Request) -> StatusCode {
        let svc = WebService::default().post("/", async |Text(_): Text| StatusCode::OK);
        svc.serve(Context::default(), req).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_text() {
        let svc = WebService::default().post("/", async |Text(body): Text| {
            assert_eq!(body, "test");
        });

        let req = Request::builder()
            .method(Method::POST)
            .header(header::CONTENT_TYPE, "text/plain")
            .body("test".into())
            .unwrap();
        let status = svc.serve(Context::default(), req).await.unwrap().status();
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_text_missing_content_type() {
        let req = Request::builder()
            .method(Method::POST)
            .body("test".into())
            .unwrap();
        assert_eq!(
            status_of_text_route(req).await,
            StatusCode::UNSUPPORTED_MEDIA_TYPE
        );
    }

    #[tokio::test]
    async fn test_text_incorrect_content_type() {
        let req = Request::builder()
            .method(Method::POST)
            .header(header::CONTENT_TYPE, "application/json")
            .body("test".into())
            .unwrap();
        assert_eq!(
            status_of_text_route(req).await,
            StatusCode::UNSUPPORTED_MEDIA_TYPE
        );
    }

    #[tokio::test]
    async fn test_text_invalid_utf8() {
        // 0x9F 0x92 0x96 are UTF-8 continuation bytes without a valid lead byte.
        let req = Request::builder()
            .method(Method::POST)
            .header(header::CONTENT_TYPE, "text/plain")
            .body(vec![0, 159, 146, 150].into())
            .unwrap();
        assert_eq!(status_of_text_route(req).await, StatusCode::BAD_REQUEST);
    }
}
122}