rama_http/service/web/endpoint/extract/body/
json.rs1use bytes::Bytes;
2use rama_http_types::{HeaderMap, header};
3
4use super::BytesRejection;
5use crate::Request;
6use crate::dep::http_body_util::BodyExt;
7use crate::service::web::extract::{FromRequest, OptionalFromRequest};
8use crate::utils::macros::{composite_http_rejection, define_http_rejection};
9
10pub use crate::service::web::endpoint::response::Json;
11
define_http_rejection! {
    #[status = UNSUPPORTED_MEDIA_TYPE]
    #[body = "Json requests must have `Content-Type: application/json`"]
    /// Rejection type for [`Json`] used when the request's `Content-Type`
    /// header is missing or does not denote a JSON media type.
    pub struct InvalidJsonContentType;
}
20
define_http_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Failed to deserialize json payload"]
    /// Rejection type for [`Json`] used when the collected request body cannot
    /// be deserialized into the target type. This covers both malformed JSON
    /// and well-formed JSON whose shape does not match the target type.
    pub struct FailedToDeserializeJson(Error);
}
28
composite_http_rejection! {
    /// Rejection used for [`Json`].
    ///
    /// Contains one variant for each way the [`Json`] extractor can fail:
    /// an unacceptable `Content-Type`, a body that cannot be deserialized,
    /// or a failure while collecting the request body.
    pub enum JsonRejection {
        InvalidJsonContentType,
        FailedToDeserializeJson,
        BytesRejection,
    }
}
40
41impl<T> FromRequest for Json<T>
42where
43 T: serde::de::DeserializeOwned + Send + Sync + 'static,
44{
45 type Rejection = JsonRejection;
46
47 async fn from_request(req: Request) -> Result<Self, Self::Rejection> {
48 async fn extract_json_bytes(req: Request) -> Result<Bytes, JsonRejection> {
50 if !json_content_type(req.headers()) {
51 return Err(InvalidJsonContentType.into());
52 }
53
54 let body = req.into_body();
55
56 match body.collect().await {
57 Ok(c) => Ok(c.to_bytes()),
58 Err(err) => Err(BytesRejection::from_err(err).into()),
59 }
60 }
61
62 let b = extract_json_bytes(req).await?;
63 match serde_json::from_slice(&b) {
64 Ok(s) => Ok(Self(s)),
65 Err(err) => Err(FailedToDeserializeJson::from_err(err).into()),
66 }
67 }
68}
69
70impl<T> OptionalFromRequest for Json<T>
71where
72 T: serde::de::DeserializeOwned + Send + Sync + 'static,
73{
74 type Rejection = JsonRejection;
75
76 async fn from_request(req: Request) -> Result<Option<Self>, Self::Rejection> {
77 if req.headers().get(header::CONTENT_TYPE).is_some() {
78 let v = <Self as FromRequest>::from_request(req).await?;
79 Ok(Some(v))
80 } else {
81 Ok(None)
82 }
83 }
84}
85
86fn json_content_type(headers: &HeaderMap) -> bool {
87 headers
88 .get(header::CONTENT_TYPE)
89 .and_then(|content_type| content_type.to_str().ok())
90 .and_then(|content_type| content_type.parse::<mime::Mime>().ok())
91 .is_some_and(|mime| {
92 mime.type_() == "application"
93 && (mime.subtype() == "json" || mime.suffix().is_some_and(|name| name == "json"))
94 })
95}
96
#[cfg(test)]
mod test {
    //! Tests for the [`Json`] extractor, driven through a [`WebService`] route.

    use super::*;
    use crate::StatusCode;
    use crate::service::web::WebService;
    use rama_core::{Context, Service};

    /// A well-formed JSON body with a JSON `Content-Type` (including a
    /// `charset` parameter) is deserialized and handed to the handler;
    /// absent optional fields deserialize to `None`.
    #[tokio::test]
    async fn test_json() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            name: String,
            age: u8,
            alive: Option<bool>,
        }

        let service = WebService::default().post("/", async |Json(body): Json<Input>| {
            assert_eq!(body.name, "glen");
            assert_eq!(body.age, 42);
            assert_eq!(body.alive, None);
        });

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .header(
                rama_http_types::header::CONTENT_TYPE,
                "application/json; charset=utf-8",
            )
            .body(r#"{"name": "glen", "age": 42}"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// A structured-syntax suffix type such as `application/vnd.api+json`
    /// is accepted as JSON.
    #[tokio::test]
    async fn test_json_suffix_content_type() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            name: String,
            age: u8,
        }

        let service = WebService::default().post("/", async |Json(body): Json<Input>| {
            assert_eq!(body.name, "glen");
            assert_eq!(body.age, 42);
        });

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .header(
                rama_http_types::header::CONTENT_TYPE,
                "application/vnd.api+json",
            )
            .body(r#"{"name": "glen", "age": 42}"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// A non-JSON `Content-Type` is rejected with `415 Unsupported Media Type`,
    /// even when the body itself would be valid JSON.
    #[tokio::test]
    async fn test_json_wrong_content_type() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            _name: String,
            _age: u8,
            _alive: Option<bool>,
        }

        let service = WebService::default().post("/", async |Json(_): Json<Input>| StatusCode::OK);

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .header(rama_http_types::header::CONTENT_TYPE, "text/plain")
            .body(r#"{"name": "glen", "age": 42}"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    /// A request without any `Content-Type` header is rejected with
    /// `415 Unsupported Media Type` by the (non-optional) extractor.
    #[tokio::test]
    async fn test_json_missing_content_type() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            _name: String,
            _age: u8,
            _alive: Option<bool>,
        }

        let service = WebService::default().post("/", async |Json(_): Json<Input>| StatusCode::OK);

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .body(r#"{"name": "glen", "age": 42}"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    /// A body that is not valid JSON, sent with a JSON `Content-Type`,
    /// is rejected with `400 Bad Request`.
    #[tokio::test]
    async fn test_json_invalid_body_encoding() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            _name: String,
            _age: u8,
            _alive: Option<bool>,
        }

        let service = WebService::default().post("/", async |Json(_): Json<Input>| StatusCode::OK);

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .header(
                rama_http_types::header::CONTENT_TYPE,
                "application/json; charset=utf-8",
            )
            .body(r#"deal with it, or not?!"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }
}