rama_http/service/web/endpoint/extract/body/
form.rs

1use bytes::Bytes;
2
3use super::BytesRejection;
4use crate::dep::http_body_util::BodyExt;
5use crate::service::web::extract::FromRequest;
6use crate::utils::macros::{composite_http_rejection, define_http_rejection};
7use crate::{Method, Request};
8
9pub use crate::service::web::endpoint::response::Form;
10
11define_http_rejection! {
12    #[status = UNSUPPORTED_MEDIA_TYPE]
13    #[body = "Form requests must have `Content-Type: application/x-www-form-urlencoded`"]
14    /// Rejection type for the [`Form`] extractor,
15    /// returned when the `Content-Type` header is missing or its value is not
16    /// `application/x-www-form-urlencoded` (only checked for non-`GET` requests).
17    pub struct InvalidFormContentType;
18}
19
20define_http_rejection! {
21    #[status = BAD_REQUEST]
22    #[body = "Failed to deserialize form"]
23    /// Rejection type used if the [`Form`] extractor fails to
24    /// deserialize the form data into the target type.
25    pub struct FailedToDeserializeForm(Error);
26}
27
28composite_http_rejection! {
29    /// Rejection used for the [`Form`] extractor.
30    ///
31    /// Contains one variant for each way the [`Form`] extractor
32    /// can fail.
33    pub enum FormRejection {
        // Non-`GET` request that did not pass the form content-type check.
34        InvalidFormContentType,
        // Query string or body could not be deserialized into the target type.
35        FailedToDeserializeForm,
        // Collecting the request body into bytes failed.
36        BytesRejection,
37    }
38}
39
40impl<T> FromRequest for Form<T>
41where
42    T: serde::de::DeserializeOwned + Send + Sync + 'static,
43{
44    type Rejection = FormRejection;
45
46    async fn from_request(req: Request) -> Result<Self, Self::Rejection> {
47        // Extracted into separate fn so it's only compiled once for all T.
48        async fn extract_form_body_bytes(req: Request) -> Result<Bytes, FormRejection> {
49            if !crate::service::web::extract::has_any_content_type(
50                req.headers(),
51                &[&mime::APPLICATION_WWW_FORM_URLENCODED],
52            ) {
53                return Err(InvalidFormContentType.into());
54            }
55
56            let body = req.into_body();
57            let bytes = body.collect().await.map_err(BytesRejection::from_err)?;
58
59            Ok(bytes.to_bytes())
60        }
61
62        if req.method() == Method::GET {
63            let query = req.uri().query().unwrap_or_default();
64            let value = match serde_html_form::from_bytes(query.as_bytes()) {
65                Ok(value) => value,
66                Err(err) => return Err(FailedToDeserializeForm::from_err(err).into()),
67            };
68            Ok(Form(value))
69        } else {
70            let b = extract_form_body_bytes(req).await?;
71            Ok(Form(match serde_html_form::from_bytes(&b) {
72                Ok(value) => value,
73                Err(err) => return Err(FailedToDeserializeForm::from_err(err).into()),
74            }))
75        }
76    }
77}
78
#[cfg(test)]
mod test {
    use super::*;
    use crate::service::web::WebService;
    use crate::{Body, Method, Request, StatusCode};
    use rama_core::{Context, Service};

    /// A `POST` with a complete URL-encoded body reaches the handler with the
    /// deserialized values.
    #[tokio::test]
    async fn test_form_post_form_urlencoded() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            name: String,
            age: u8,
        }

        let svc = WebService::default().post("/", async |Form(input): Form<Input>| {
            assert_eq!(input.name, "Devan");
            assert_eq!(input.age, 29);
        });

        let request = Request::builder()
            .method(Method::POST)
            .uri("/")
            .header("content-type", "application/x-www-form-urlencoded")
            .body(r#"name=Devan&age=29"#.into())
            .unwrap();

        let response = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// A `POST` body that lacks a required field is rejected with `400`.
    #[tokio::test]
    async fn test_form_post_form_urlencoded_missing_data_fail() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)]
        struct Input {
            name: String,
            age: u8,
        }

        let svc = WebService::default().post("/", async |Form(_): Form<Input>| StatusCode::OK);

        let request = Request::builder()
            .method(Method::POST)
            .uri("/")
            .header("content-type", "application/x-www-form-urlencoded")
            .body(r#"age=29"#.into())
            .unwrap();

        let response = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    /// A `GET` whose form data sits in the body (and not in the query string)
    /// is rejected with `400`.
    #[tokio::test]
    async fn test_form_get_form_urlencoded_fail() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)]
        struct Input {
            name: String,
            age: u8,
        }

        let svc = WebService::default().get("/", async |Form(_): Form<Input>| StatusCode::OK);

        let request = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header("content-type", "application/x-www-form-urlencoded")
            .body(r#"name=Devan&age=29"#.into())
            .unwrap();

        let response = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    /// A `GET` with all fields present in the query string reaches the handler
    /// with the deserialized values.
    #[tokio::test]
    async fn test_form_get() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            name: String,
            age: u8,
        }

        let svc = WebService::default().get("/", async |Form(input): Form<Input>| {
            assert_eq!(input.name, "Devan");
            assert_eq!(input.age, 29);
        });

        let request = Request::builder()
            .method(Method::GET)
            .uri("/?name=Devan&age=29")
            .body(Body::empty())
            .unwrap();

        let response = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// A `GET` whose query string lacks a required field is rejected with `400`.
    #[tokio::test]
    async fn test_form_get_fail_missing_data() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)]
        struct Input {
            name: String,
            age: u8,
        }

        let svc = WebService::default().get("/", async |Form(_): Form<Input>| StatusCode::OK);

        let request = Request::builder()
            .method(Method::GET)
            .uri("/?name=Devan")
            .body(Body::empty())
            .unwrap();

        let response = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }
}