rama_http/service/web/endpoint/extract/body/
csv.rs1use super::BytesRejection;
2use crate::Request;
3use crate::dep::http_body_util::BodyExt;
4use crate::service::web::extract::FromRequest;
5use crate::utils::macros::{composite_http_rejection, define_http_rejection};
6use bytes::{Buf, Bytes};
7
8pub use crate::service::web::endpoint::response::Csv;
9
10define_http_rejection! {
11 #[status = UNSUPPORTED_MEDIA_TYPE]
12 #[body = "Csv requests must have `Content-Type: text/csv`"]
13 pub struct InvalidCsvContentType;
17}
18
19define_http_rejection! {
20 #[status = BAD_REQUEST]
21 #[body = "Failed to deserialize csv payload"]
22 pub struct FailedToDeserializeCsv(Error);
25}
26
27composite_http_rejection! {
28 pub enum CsvRejection {
33 InvalidCsvContentType,
34 FailedToDeserializeCsv,
35 BytesRejection,
36 }
37}
38
39impl<T> FromRequest for Csv<Vec<T>>
40where
41 T: serde::de::DeserializeOwned + Send + Sync + 'static,
42{
43 type Rejection = CsvRejection;
44
45 async fn from_request(req: Request) -> Result<Self, Self::Rejection> {
46 async fn req_to_csv_bytes(req: Request) -> Result<Bytes, CsvRejection> {
48 if !crate::service::web::extract::has_any_content_type(
49 req.headers(),
50 &[&mime::TEXT_CSV],
51 ) {
52 return Err(InvalidCsvContentType.into());
53 }
54
55 let body = req.into_body();
56 let bytes = body.collect().await.map_err(BytesRejection::from_err)?;
57
58 Ok(bytes.to_bytes())
59 }
60
61 let b = req_to_csv_bytes(req).await?;
62 let mut rdr = csv::Reader::from_reader(b.reader());
63
64 let out: Result<Vec<T>, _> = rdr
65 .deserialize()
66 .map(|rec| {
67 let record: Result<T, _> = rec;
68 record
69 })
70 .collect();
71
72 match out {
73 Ok(s) => Ok(Self(s)),
74 Err(err) => Err(FailedToDeserializeCsv::from_err(err).into()),
75 }
76 }
77}
78
#[cfg(test)]
mod test {
    use super::*;
    use crate::StatusCode;
    use crate::service::web::WebService;
    use rama_core::{Context, Service};

    /// Deserialization target for the `name,age,alive` header row.
    ///
    /// The handlers never read these fields, so they carry a leading underscore to
    /// silence `dead_code`. `rename` maps them back onto the CSV column names, so a
    /// rejection is caused by the payload itself and not by a header/field mismatch.
    #[derive(Debug, serde::Deserialize)]
    struct Record {
        #[serde(rename = "name")]
        _name: String,
        #[serde(rename = "age")]
        _age: u8,
        #[serde(rename = "alive")]
        _alive: Option<bool>,
    }

    /// Builds a `POST /` request carrying `body`, setting `Content-Type` only when given.
    fn post_request(content_type: Option<&'static str>, body: &'static str) -> crate::Request {
        let mut builder =
            rama_http_types::Request::builder().method(rama_http_types::Method::POST);
        if let Some(content_type) = content_type {
            builder = builder.header(rama_http_types::header::CONTENT_TYPE, content_type);
        }
        builder.body(body.into()).unwrap()
    }

    #[tokio::test]
    async fn test_csv() {
        #[derive(serde::Deserialize)]
        struct Input {
            name: String,
            age: u8,
            alive: Option<bool>,
        }

        let service = WebService::default().post("/", async |Csv(body): Csv<Vec<Input>>| {
            assert_eq!(body.len(), 2);

            assert_eq!(body[0].name, "glen");
            assert_eq!(body[0].age, 42);
            assert_eq!(body[0].alive, None);

            assert_eq!(body[1].name, "adr");
            assert_eq!(body[1].age, 40);
            assert_eq!(body[1].alive, Some(true));
            StatusCode::OK
        });

        let req = post_request(
            Some("text/csv; charset=utf-8"),
            "name,age,alive\nglen,42,\nadr,40,true\n",
        );
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_csv_missing_content_type() {
        let service =
            WebService::default().post("/", async |Csv(_): Csv<Vec<Record>>| StatusCode::OK);

        // A non-CSV content type is rejected.
        let req = post_request(Some("text/plain"), r#"{"name": "glen", "age": 42}"#);
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);

        // So is a request without any `Content-Type`, even when the body is valid CSV.
        let req = post_request(None, "name,age,alive\nglen,42,\n");
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn test_csv_invalid_body() {
        let service =
            WebService::default().post("/", async |Csv(_): Csv<Vec<Record>>| StatusCode::OK);

        // Control: a well-formed payload is accepted, which proves `Record` matches the
        // header row and that the rejections below are caused by the payload alone.
        let req = post_request(
            Some("text/csv; charset=utf-8"),
            "name,age,alive\nglen,42,\nadr,40,true\n",
        );
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        // A row with fewer fields than the header row.
        let req = post_request(
            Some("text/csv; charset=utf-8"),
            "name,age,alive\nglen,42,\nadr,40\n",
        );
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // A field whose value cannot be parsed into the target type.
        let req = post_request(Some("text/csv; charset=utf-8"), "name,age,alive\nglen,old,\n");
        let resp = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }
}