rama_http/service/web/endpoint/extract/body/json.rs

1use bytes::Bytes;
2use rama_http_types::{HeaderMap, header};
3
4use super::BytesRejection;
5use crate::Request;
6use crate::dep::http_body_util::BodyExt;
7use crate::service::web::extract::{FromRequest, OptionalFromRequest};
8use crate::utils::macros::{composite_http_rejection, define_http_rejection};
9
10pub use crate::service::web::endpoint::response::Json;
11
// Responds with HTTP 415 when the request does not declare a JSON media type.
define_http_rejection! {
    #[status = UNSUPPORTED_MEDIA_TYPE]
    #[body = "Json requests must have `Content-Type: application/json`"]
    /// Rejection type for [`Json`]
    /// used if the `Content-Type` header is missing,
    /// cannot be parsed as a media type,
    /// or is not a JSON media type: either `application/json`
    /// or an `application/*+json` type (e.g. `application/vnd.api+json`).
    ///
    /// When `Json` is extracted optionally, a missing `Content-Type`
    /// header yields `None` instead of this rejection.
    pub struct InvalidJsonContentType;
}
20
// Responds with HTTP 400 when the body cannot be deserialized as JSON into the target type.
define_http_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Failed to deserialize json payload"]
    /// Rejection type used if the [`Json`] extractor fails to
    /// deserialize the payload into the target type.
    ///
    /// Carries the underlying `serde_json` error, which covers both
    /// syntactically invalid JSON and JSON that does not match the
    /// shape of the target type.
    pub struct FailedToDeserializeJson(Error);
}
28
composite_http_rejection! {
    /// Rejection used for [`Json`]
    ///
    /// Contains one variant for each way the [`Json`] extractor
    /// can fail:
    ///
    /// - [`InvalidJsonContentType`]: the `Content-Type` header does not
    ///   describe a JSON media type;
    /// - [`FailedToDeserializeJson`]: the body is not valid JSON for the
    ///   target type;
    /// - [`BytesRejection`]: collecting the request body failed.
    pub enum JsonRejection {
        InvalidJsonContentType,
        FailedToDeserializeJson,
        BytesRejection,
    }
}
40
41impl<T> FromRequest for Json<T>
42where
43    T: serde::de::DeserializeOwned + Send + Sync + 'static,
44{
45    type Rejection = JsonRejection;
46
47    async fn from_request(req: Request) -> Result<Self, Self::Rejection> {
48        // Extracted into separate fn so it's only compiled once for all T.
49        async fn extract_json_bytes(req: Request) -> Result<Bytes, JsonRejection> {
50            if !json_content_type(req.headers()) {
51                return Err(InvalidJsonContentType.into());
52            }
53
54            let body = req.into_body();
55
56            match body.collect().await {
57                Ok(c) => Ok(c.to_bytes()),
58                Err(err) => Err(BytesRejection::from_err(err).into()),
59            }
60        }
61
62        let b = extract_json_bytes(req).await?;
63        match serde_json::from_slice(&b) {
64            Ok(s) => Ok(Self(s)),
65            Err(err) => Err(FailedToDeserializeJson::from_err(err).into()),
66        }
67    }
68}
69
70impl<T> OptionalFromRequest for Json<T>
71where
72    T: serde::de::DeserializeOwned + Send + Sync + 'static,
73{
74    type Rejection = JsonRejection;
75
76    async fn from_request(req: Request) -> Result<Option<Self>, Self::Rejection> {
77        if req.headers().get(header::CONTENT_TYPE).is_some() {
78            let v = <Self as FromRequest>::from_request(req).await?;
79            Ok(Some(v))
80        } else {
81            Ok(None)
82        }
83    }
84}
85
86fn json_content_type(headers: &HeaderMap) -> bool {
87    headers
88        .get(header::CONTENT_TYPE)
89        .and_then(|content_type| content_type.to_str().ok())
90        .and_then(|content_type| content_type.parse::<mime::Mime>().ok())
91        .is_some_and(|mime| {
92            mime.type_() == "application"
93                && (mime.subtype() == "json" || mime.suffix().is_some_and(|name| name == "json"))
94        })
95}
96
#[cfg(test)]
mod test {
    use super::*;
    use crate::StatusCode;
    use crate::service::web::WebService;
    use rama_core::{Context, Service};

    // Happy path: `application/json` with a charset parameter deserializes
    // the body, and absent optional fields become `None`.
    #[tokio::test]
    async fn test_json() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            name: String,
            age: u8,
            alive: Option<bool>,
        }

        let service = WebService::default().post("/", async |Json(body): Json<Input>| {
            assert_eq!(body.name, "glen");
            assert_eq!(body.age, 42);
            assert_eq!(body.alive, None);
        });

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .header(
                rama_http_types::header::CONTENT_TYPE,
                "application/json; charset=utf-8",
            )
            .body(r#"{"name": "glen", "age": 42}"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // `application/*+json` structured-syntax types must be accepted as JSON.
    #[tokio::test]
    async fn test_json_structured_syntax_suffix() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            name: String,
            age: u8,
        }

        let service = WebService::default().post("/", async |Json(body): Json<Input>| {
            assert_eq!(body.name, "glen");
            assert_eq!(body.age, 42);
        });

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .header(
                rama_http_types::header::CONTENT_TYPE,
                "application/vnd.api+json",
            )
            .body(r#"{"name": "glen", "age": 42}"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // No `Content-Type` header at all: the required extractor rejects with 415.
    #[tokio::test]
    async fn test_json_missing_content_type() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            _name: String,
            _age: u8,
            _alive: Option<bool>,
        }

        let service = WebService::default().post("/", async |Json(_): Json<Input>| StatusCode::OK);

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .body(r#"{"name": "glen", "age": 42}"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    // A non-JSON `Content-Type` is rejected with 415, even if the body is valid JSON.
    #[tokio::test]
    async fn test_json_invalid_content_type() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            _name: String,
            _age: u8,
            _alive: Option<bool>,
        }

        let service = WebService::default().post("/", async |Json(_): Json<Input>| StatusCode::OK);

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .header(rama_http_types::header::CONTENT_TYPE, "text/plain")
            .body(r#"{"name": "glen", "age": 42}"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    // Correct content type but a body that is not JSON: rejected with 400.
    #[tokio::test]
    async fn test_json_invalid_body_encoding() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            _name: String,
            _age: u8,
            _alive: Option<bool>,
        }

        let service = WebService::default().post("/", async |Json(_): Json<Input>| StatusCode::OK);

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .header(
                rama_http_types::header::CONTENT_TYPE,
                "application/json; charset=utf-8",
            )
            .body(r#"deal with it, or not?!"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }
}