rama_http/service/web/endpoint/extract/body/
csv.rs

1use super::BytesRejection;
2use crate::Request;
3use crate::dep::http_body_util::BodyExt;
4use crate::service::web::extract::FromRequest;
5use crate::utils::macros::{composite_http_rejection, define_http_rejection};
6use bytes::{Buf, Bytes};
7
8pub use crate::service::web::endpoint::response::Csv;
9
define_http_rejection! {
    #[status = UNSUPPORTED_MEDIA_TYPE]
    #[body = "Csv requests must have `Content-Type: text/csv`"]
    /// Rejection type for [`Csv`]
    /// used if the `Content-Type` header is missing
    /// or its value is not `text/csv`.
    ///
    /// Declared with the `UNSUPPORTED_MEDIA_TYPE` status and a fixed
    /// message reminding the client of the expected content type.
    pub struct InvalidCsvContentType;
}
18
define_http_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Failed to deserialize csv payload"]
    /// Rejection type used if the [`Csv`] extractor
    /// fails to deserialize the payload into the target type.
    ///
    /// Wraps the underlying deserialization error and is declared
    /// with the `BAD_REQUEST` status.
    pub struct FailedToDeserializeCsv(Error);
}
26
composite_http_rejection! {
    /// Rejection used for [`Csv`]
    ///
    /// Contains one variant for each way the [`Csv`] extractor
    /// can fail:
    ///
    /// - [`InvalidCsvContentType`]: the request is not declared as `text/csv`;
    /// - [`FailedToDeserializeCsv`]: the payload could not be deserialized
    ///   into the target type;
    /// - [`BytesRejection`]: the request body could not be collected.
    pub enum CsvRejection {
        InvalidCsvContentType,
        FailedToDeserializeCsv,
        BytesRejection,
    }
}
38
39impl<T> FromRequest for Csv<Vec<T>>
40where
41    T: serde::de::DeserializeOwned + Send + Sync + 'static,
42{
43    type Rejection = CsvRejection;
44
45    async fn from_request(req: Request) -> Result<Self, Self::Rejection> {
46        // Extracted into separate fn so it's only compiled once for all T.
47        async fn req_to_csv_bytes(req: Request) -> Result<Bytes, CsvRejection> {
48            if !crate::service::web::extract::has_any_content_type(
49                req.headers(),
50                &[&mime::TEXT_CSV],
51            ) {
52                return Err(InvalidCsvContentType.into());
53            }
54
55            let body = req.into_body();
56            let bytes = body.collect().await.map_err(BytesRejection::from_err)?;
57
58            Ok(bytes.to_bytes())
59        }
60
61        let b = req_to_csv_bytes(req).await?;
62        let mut rdr = csv::Reader::from_reader(b.reader());
63
64        let out: Result<Vec<T>, _> = rdr
65            .deserialize()
66            .map(|rec| {
67                let record: Result<T, _> = rec;
68                record
69            })
70            .collect();
71
72        match out {
73            Ok(s) => Ok(Self(s)),
74            Err(err) => Err(FailedToDeserializeCsv::from_err(err).into()),
75        }
76    }
77}
78
#[cfg(test)]
mod test {
    use super::*;
    use crate::StatusCode;
    use crate::service::web::WebService;
    use rama_core::{Context, Service};

    /// A well-formed `text/csv` body is deserialized record by record,
    /// with an empty trailing column mapped to `None`.
    #[tokio::test]
    async fn test_csv() {
        #[derive(serde::Deserialize)]
        struct Input {
            name: String,
            age: u8,
            alive: Option<bool>,
        }

        let service = WebService::default().post("/", async |Csv(body): Csv<Vec<Input>>| {
            assert_eq!(body.len(), 2);

            assert_eq!(body[0].name, "glen");
            assert_eq!(body[0].age, 42);
            assert_eq!(body[0].alive, None);

            assert_eq!(body[1].name, "adr");
            assert_eq!(body[1].age, 40);
            assert_eq!(body[1].alive, Some(true));
            StatusCode::OK
        });

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .header(
                rama_http_types::header::CONTENT_TYPE,
                "text/csv; charset=utf-8",
            )
            .body("name,age,alive\nglen,42,\nadr,40,true\n".into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// A request whose `Content-Type` is not `text/csv` is rejected.
    ///
    /// Note: despite the name, this sends a *non-CSV* content type
    /// (`text/plain`) rather than omitting the header.
    #[tokio::test]
    async fn test_csv_missing_content_type() {
        // The extractor is expected to reject the request based on its
        // content type, so these fields are never read by the handler.
        #[derive(Debug, serde::Deserialize)]
        struct Input {
            _name: String,
            _age: u8,
            _alive: Option<bool>,
        }

        let service =
            WebService::default().post("/", async |Csv(_): Csv<Vec<Input>>| StatusCode::OK);

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .header(rama_http_types::header::CONTENT_TYPE, "text/plain")
            .body(r#"{"name": "glen", "age": 42}"#.into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    /// A `text/csv` body with a record shorter than the header is rejected.
    #[tokio::test]
    async fn test_csv_invalid_body() {
        // Field names must match the CSV header (`name,age,alive`):
        // otherwise deserialization already fails on the first record
        // with a missing-field error and the short last record below
        // would never be what triggers the rejection.
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)] // fields are only populated by serde, never read
        struct Input {
            name: String,
            age: u8,
            alive: Option<bool>,
        }

        let service =
            WebService::default().post("/", async |Csv(_): Csv<Vec<Input>>| StatusCode::OK);

        let req = rama_http_types::Request::builder()
            .method(rama_http_types::Method::POST)
            .header(
                rama_http_types::header::CONTENT_TYPE,
                "text/csv; charset=utf-8",
            )
            // the missing column last line should trigger an error
            .body("name,age,alive\nglen,42,\nadr,40\n".into())
            .unwrap();
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }
}
166}