rama_http/service/web/endpoint/extract/body/form.rs

1use super::BytesRejection;
2use crate::dep::http_body_util::BodyExt;
3use crate::service::web::extract::FromRequest;
4use crate::utils::macros::{composite_http_rejection, define_http_rejection};
5use crate::{Method, Request};
6
7pub use crate::response::Form;
8
define_http_rejection! {
    #[status = UNSUPPORTED_MEDIA_TYPE]
    #[body = "Form requests must have `Content-Type: application/x-www-form-urlencoded`"]
    /// Rejection type for [`Form`],
    /// used when the form is read from the request body and the
    /// `Content-Type` header is missing or its value is not
    /// `application/x-www-form-urlencoded`.
    pub struct InvalidFormContentType;
}
17
18define_http_rejection! {
19    #[status = BAD_REQUEST]
20    #[body = "Failed to deserialize form"]
21    /// Rejection type used if the [`Form`]
22    /// deserialize the form into the target type.
23    pub struct FailedToDeserializeForm(Error);
24}
25
composite_http_rejection! {
    /// Rejection used for [`Form`].
    ///
    /// Contains one variant for each way the [`Form`] extractor
    /// can fail: an unexpected `Content-Type` ([`InvalidFormContentType`]),
    /// a payload that cannot be deserialized into the target type
    /// ([`FailedToDeserializeForm`]), or an error while collecting the
    /// request body ([`BytesRejection`]).
    pub enum FormRejection {
        InvalidFormContentType,
        FailedToDeserializeForm,
        BytesRejection,
    }
}
37
38impl<T> FromRequest for Form<T>
39where
40    T: serde::de::DeserializeOwned + Send + Sync + 'static,
41{
42    type Rejection = FormRejection;
43
44    async fn from_request(req: Request) -> Result<Self, Self::Rejection> {
45        if req.method() == Method::GET {
46            let query = req.uri().query().unwrap_or_default();
47            let value = match serde_html_form::from_bytes(query.as_bytes()) {
48                Ok(value) => value,
49                Err(err) => return Err(FailedToDeserializeForm::from_err(err).into()),
50            };
51            Ok(Form(value))
52        } else {
53            if !crate::service::web::extract::has_any_content_type(
54                req.headers(),
55                &[&mime::APPLICATION_WWW_FORM_URLENCODED],
56            ) {
57                return Err(InvalidFormContentType.into());
58            }
59
60            let body = req.into_body();
61            match body.collect().await {
62                Ok(c) => {
63                    let value = match serde_html_form::from_bytes(&c.to_bytes()) {
64                        Ok(value) => value,
65                        Err(err) => return Err(FailedToDeserializeForm::from_err(err).into()),
66                    };
67                    Ok(Form(value))
68                }
69                Err(err) => Err(BytesRejection::from_err(err).into()),
70            }
71        }
72    }
73}
74
#[cfg(test)]
mod test {
    use super::*;
    use crate::service::web::WebService;
    use crate::{Body, Method, Request, StatusCode};
    use rama_core::{Context, Service};

    #[tokio::test]
    async fn test_form_post_form_urlencoded() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            name: String,
            age: u8,
        }

        let service = WebService::default().post("/", |Form(body): Form<Input>| async move {
            assert_eq!(body.name, "Devan");
            assert_eq!(body.age, 29);
        });

        let req = Request::builder()
            .uri("/")
            .method(Method::POST)
            .header("content-type", "application/x-www-form-urlencoded")
            .body(r#"name=Devan&age=29"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_form_post_form_urlencoded_missing_data_fail() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)] // fields only exist to drive deserialization
        struct Input {
            name: String,
            age: u8,
        }

        let service =
            WebService::default().post("/", |Form(_): Form<Input>| async move { StatusCode::OK });

        let req = Request::builder()
            .uri("/")
            .method(Method::POST)
            .header("content-type", "application/x-www-form-urlencoded")
            .body(r#"age=29"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_form_post_missing_content_type_fail() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)] // fields only exist to drive deserialization
        struct Input {
            name: String,
            age: u8,
        }

        let service =
            WebService::default().post("/", |Form(_): Form<Input>| async move { StatusCode::OK });

        // A well-formed body is still rejected when no `Content-Type` is sent.
        let req = Request::builder()
            .uri("/")
            .method(Method::POST)
            .body(r#"name=Devan&age=29"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn test_form_post_wrong_content_type_fail() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)] // fields only exist to drive deserialization
        struct Input {
            name: String,
            age: u8,
        }

        let service =
            WebService::default().post("/", |Form(_): Form<Input>| async move { StatusCode::OK });

        let req = Request::builder()
            .uri("/")
            .method(Method::POST)
            .header("content-type", "application/json")
            .body(r#"name=Devan&age=29"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn test_form_get_form_urlencoded_fail() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)] // fields only exist to drive deserialization
        struct Input {
            name: String,
            age: u8,
        }

        let service =
            WebService::default().get("/", |Form(_): Form<Input>| async move { StatusCode::OK });

        // GET forms are read from the query string, so a body is ignored.
        let req = Request::builder()
            .uri("/")
            .method(Method::GET)
            .header("content-type", "application/x-www-form-urlencoded")
            .body(r#"name=Devan&age=29"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_form_get() {
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            name: String,
            age: u8,
        }

        let service = WebService::default().get("/", |Form(body): Form<Input>| async move {
            assert_eq!(body.name, "Devan");
            assert_eq!(body.age, 29);
        });

        let req = Request::builder()
            .uri("/?name=Devan&age=29")
            .method(Method::GET)
            .body(Body::empty())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_form_get_fail_missing_data() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)] // fields only exist to drive deserialization
        struct Input {
            name: String,
            age: u8,
        }

        let service =
            WebService::default().get("/", |Form(_): Form<Input>| async move { StatusCode::OK });

        let req = Request::builder()
            .uri("/?name=Devan")
            .method(Method::GET)
            .body(Body::empty())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }
}