1use anyhow::Context;
2use hyper::{Body, Request};
3use serde::de::DeserializeOwned;
4use serde_json::Value;
5use std::collections::HashMap;
6
/// Upper bound on the number of form/query values `append_pairs` accumulates
/// (tracked through its `count` argument) before it fails with
/// "too many form values".
const MAX_FORM_PARAM_COUNT: usize = 2048;
8const MAX_MEMORY: u64 = 32 << 20; const MAX_BODY_LEN: usize = 8 << 20; pub async fn read_json<T: DeserializeOwned>(req: &mut Request<Body>) -> anyhow::Result<T> {
14 let bytes = hyper::body::to_bytes(req.body_mut())
15 .await
16 .context("read request body")?;
17 if bytes.len() > MAX_BODY_LEN {
18 anyhow::bail!(
19 "body too large: {} bytes (limit {})",
20 bytes.len(),
21 MAX_BODY_LEN
22 );
23 }
24 let val: T = serde_json::from_slice(&bytes).context("deserialize json body")?;
25 *req.body_mut() = Body::from(bytes);
26 Ok(val)
27}
28
29pub async fn parse_json_body<T: DeserializeOwned>(
32 req: &mut Request<Body>,
33) -> anyhow::Result<Option<T>> {
34 if !with_json_body(req) {
35 return Ok(None);
36 }
37 let bytes = hyper::body::to_bytes(req.body_mut())
38 .await
39 .context("read request body")?;
40 if bytes.len() > MAX_BODY_LEN {
41 anyhow::bail!("request body too large");
42 }
43 let val: T = serde_json::from_slice(&bytes).context("deserialize json body")?;
44 *req.body_mut() = Body::from(bytes);
45 Ok(Some(val))
46}
47
48pub async fn get_form_values(
51 req: &mut Request<Body>,
52) -> anyhow::Result<HashMap<String, Vec<String>>> {
53 let mut params: HashMap<String, Vec<String>> = HashMap::new();
54 let mut count = 0usize;
55
56 if let Some(q) = req.uri().query() {
58 append_pairs(q, &mut params, &mut count)?;
59 }
60
61 if let Some(ct) = req.headers().get(http::header::CONTENT_TYPE)
63 && let Ok(ct) = ct.to_str()
64 && ct.contains("application/x-www-form-urlencoded")
65 {
66 let bytes = hyper::body::to_bytes(req.body_mut())
67 .await
68 .context("read form body")?;
69 if bytes.len() > MAX_MEMORY as usize {
70 anyhow::bail!("form body too large");
71 }
72 append_pairs(std::str::from_utf8(&bytes)?, &mut params, &mut count)?;
73 *req.body_mut() = Body::from(bytes);
75 }
76
77 Ok(params)
78}
79
80pub async fn parse_form<T: DeserializeOwned>(req: &mut Request<Body>) -> anyhow::Result<T> {
83 let params = get_form_values(req).await?;
84 let json = form_map_to_json(params);
85 serde_json::from_value(json).context("deserialize form to struct")
86}
87
88pub fn parse_path<T: DeserializeOwned>(req: &Request<Body>) -> anyhow::Result<T> {
91 let params: HashMap<String, Vec<String>> = req
92 .extensions()
93 .get::<crate::router::params::PathParams>()
94 .map(|p| {
95 p.params
96 .iter()
97 .map(|(k, v)| (k.clone(), vec![v.clone()]))
98 .collect()
99 })
100 .unwrap_or_default();
101 let json = form_map_to_json(params);
102 serde_json::from_value(json).context("deserialize path params to struct")
103}
104
105pub async fn parse<T: DeserializeOwned>(req: &mut Request<Body>) -> anyhow::Result<T> {
107 if matches!(req.method(), &http::Method::GET | &http::Method::HEAD) {
108 let path_map: HashMap<String, Vec<String>> = req
109 .extensions()
110 .get::<crate::router::params::PathParams>()
111 .map(|p| {
112 p.params
113 .iter()
114 .map(|(k, v)| (k.clone(), vec![v.clone()]))
115 .collect()
116 })
117 .unwrap_or_default();
118 let query_map = get_query_values(req)?;
119
120 let mut merged = serde_json::Map::new();
121 merge_map(&mut merged, form_map_to_json(path_map));
122 merge_map(&mut merged, form_map_to_json(query_map));
123 return serde_json::from_value(serde_json::Value::Object(merged))
124 .context("parse merged request (GET/HEAD path+query)");
125 }
126
127 let path_map: HashMap<String, Vec<String>> = req
128 .extensions()
129 .get::<crate::router::params::PathParams>()
130 .map(|p| {
131 p.params
132 .iter()
133 .map(|(k, v)| (k.clone(), vec![v.clone()]))
134 .collect()
135 })
136 .unwrap_or_default();
137 let form_map = get_form_values(req).await?;
138 let json_body: Option<serde_json::Value> = parse_json_body(req).await?;
139
140 let mut merged = serde_json::Map::new();
141 merge_map(&mut merged, form_map_to_json(path_map));
142 merge_map(&mut merged, form_map_to_json(form_map));
143 if let Some(serde_json::Value::Object(obj)) = json_body {
144 for (k, v) in obj {
145 merged.insert(k, v);
146 }
147 }
148
149 serde_json::from_value(serde_json::Value::Object(merged)).context("parse merged request")
150}
151
152fn get_query_values(req: &Request<Body>) -> anyhow::Result<HashMap<String, Vec<String>>> {
153 let mut params: HashMap<String, Vec<String>> = HashMap::new();
154 let mut count = 0usize;
155 if let Some(q) = req.uri().query() {
156 append_pairs(q, &mut params, &mut count)?;
157 }
158 Ok(params)
159}
160
/// Splits a `;`-separated header value into `key=value` pairs.
///
/// Each field is trimmed as a whole and split at its first `=`; fields
/// without `=` (including empty ones) are skipped. Whitespace directly around
/// the `=` is kept as part of the key or value. When a key repeats, the last
/// occurrence wins.
pub fn parse_header(header_value: &str) -> HashMap<String, String> {
    header_value
        .split(';')
        .map(str::trim)
        .filter_map(|field| field.split_once('='))
        .map(|(key, value)| (key.to_string(), value.to_string()))
        .collect()
}
176
177pub fn get_remote_addr(req: &Request<Body>) -> String {
180 if let Some(v) = req.headers().get("x-forwarded-for")
181 && let Ok(s) = v.to_str()
182 && !s.is_empty()
183 {
184 return s.to_string();
185 }
186 if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
187 return addr.to_string();
188 }
189 "unknown".to_string()
190}
191
192fn append_pairs(
193 raw: &str,
194 params: &mut HashMap<String, Vec<String>>,
195 count: &mut usize,
196) -> anyhow::Result<()> {
197 let pairs: Vec<(String, String)> =
198 serde_urlencoded::from_bytes(raw.as_bytes()).context("parse urlencoded form")?;
199 for (mut k_owned, v) in pairs {
200 if v.is_empty() {
201 continue;
202 }
203 if *count >= MAX_FORM_PARAM_COUNT {
204 anyhow::bail!("too many form values");
205 }
206 if k_owned.ends_with("[]") {
207 k_owned.truncate(k_owned.len() - 2);
208 }
209 let values: Vec<String> = v
210 .split(',')
211 .filter(|s| !s.is_empty())
212 .map(|s| s.to_string())
213 .collect();
214 if values.is_empty() {
215 continue;
216 }
217 *count += values.len();
218 params.entry(k_owned).or_default().extend(values);
219 }
220 Ok(())
221}
222
223fn form_map_to_json(map: HashMap<String, Vec<String>>) -> Value {
224 let mut obj = serde_json::Map::new();
225 for (k, vals) in map {
226 if vals.len() == 1 {
227 obj.insert(k, str_to_json_value(&vals[0]));
228 } else {
229 obj.insert(
230 k,
231 Value::Array(vals.into_iter().map(|s| str_to_json_value(&s)).collect()),
232 );
233 }
234 }
235 Value::Object(obj)
236}
237
238fn str_to_json_value(s: &str) -> Value {
239 if let Ok(n) = s.parse::<i64>() {
240 return Value::Number(n.into());
241 }
242 Value::String(s.to_string())
243}
244
245fn merge_map(target: &mut serde_json::Map<String, Value>, src: Value) {
246 if let Value::Object(obj) = src {
247 for (k, v) in obj {
248 target.insert(k, v);
249 }
250 }
251}
252
253fn with_json_body(req: &Request<Body>) -> bool {
254 req.headers()
255 .get(http::header::CONTENT_TYPE)
256 .and_then(|v| v.to_str().ok())
257 .map(|ct| ct.contains("application/json"))
258 .unwrap_or(false)
259 && req
260 .headers()
261 .get(http::header::CONTENT_LENGTH)
262 .and_then(|v| v.to_str().ok())
263 .and_then(|s| s.parse::<usize>().ok())
264 .unwrap_or(0)
265 > 0
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;
    use serde::Deserialize;
    use tokio::runtime::Runtime;

    /// Target type for the JSON body test.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct Payload {
        name: String,
        id: u32,
    }

    /// Target type for the form / merged-request tests. Every field defaults,
    /// so a test only has to supply the keys it cares about.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct FormPayload {
        #[serde(default)]
        names: Vec<String>,
        #[serde(default)]
        ids: Vec<i32>,
        #[serde(default)]
        tags: Vec<String>,
        #[serde(default)]
        name: String,
    }

    /// Creates a fresh Tokio runtime used to drive the async helpers from
    /// synchronous `#[test]` functions.
    fn runtime() -> Runtime {
        Runtime::new().unwrap()
    }

    /// `read_json` deserializes the body and leaves the same bytes readable
    /// on the request afterwards.
    #[test]
    fn read_json_should_parse_and_preserve_body() {
        runtime().block_on(async {
            let body = r#"{"name":"alice","id":7}"#;
            let mut req = Request::builder()
                .method(Method::POST)
                .uri("/test")
                .body(Body::from(body))
                .unwrap();

            let parsed: Payload = read_json(&mut req).await.unwrap();
            assert_eq!(
                parsed,
                Payload {
                    name: "alice".into(),
                    id: 7
                }
            );

            // The body must still be available after parsing.
            let bytes = hyper::body::to_bytes(req.into_body()).await.unwrap();
            assert_eq!(bytes, body);
        });
    }

    /// Repeated keys, `key[]` keys and comma-separated values all accumulate
    /// under the same key; the empty `empty=` value is dropped.
    #[test]
    fn get_form_values_should_support_array_notations() {
        runtime().block_on(async {
            let mut req = Request::builder()
                .method(Method::POST)
                .uri("/api?names=alice&names=bob&names[]=carol&names=dave,erin")
                .header(
                    http::header::CONTENT_TYPE,
                    "application/x-www-form-urlencoded",
                )
                .body(Body::from("ids=1&ids=2&empty=&ids[]=3"))
                .unwrap();
            let params = get_form_values(&mut req).await.unwrap();
            assert_eq!(
                params.get("names").unwrap(),
                &vec![
                    "alice".to_string(),
                    "bob".to_string(),
                    "carol".to_string(),
                    "dave".to_string(),
                    "erin".to_string()
                ]
            );
            assert_eq!(
                params.get("ids").unwrap(),
                &vec!["1".to_string(), "2".to_string(), "3".to_string()]
            );
        });
    }

    /// `parse_form` combines query values (first) with body values (second)
    /// and deserializes the result into a struct.
    #[test]
    fn parse_form_should_fill_struct() {
        runtime().block_on(async {
            let mut req = Request::builder()
                .method(Method::POST)
                .uri("/api?names=alice,bob&ids=1&ids=2")
                .header(
                    http::header::CONTENT_TYPE,
                    "application/x-www-form-urlencoded",
                )
                .body(Body::from("tags=r&tags=go&names=carol&ids=3"))
                .unwrap();
            let fp: FormPayload = parse_form(&mut req).await.unwrap();
            assert_eq!(fp.names, vec!["alice", "bob", "carol"]);
            assert_eq!(fp.ids, vec![1, 2, 3]);
            assert_eq!(fp.tags, vec!["r", "go"]);
            // `name` was never supplied, so the serde default applies.
            assert_eq!(fp.name, "");
        });
    }

    /// For GET requests `parse` takes values from the query string and does
    /// not read the body, even when the body is declared as JSON.
    #[test]
    fn parse_should_use_query_only_for_get() {
        runtime().block_on(async {
            #[derive(Deserialize, Debug, PartialEq)]
            struct Q {
                a: i64,
            }

            let mut req = Request::builder()
                .method(Method::GET)
                .uri("/path?a=1")
                .header(http::header::CONTENT_TYPE, "application/json")
                .body(Body::from(r#"{"a":"2"}"#))
                .unwrap();

            let parsed: Q = parse(&mut req).await.unwrap();
            // The query value wins; the JSON body's "a" is not consulted.
            assert_eq!(parsed.a, 1);
            let bytes = hyper::body::to_bytes(req.into_body()).await.unwrap();
            assert_eq!(bytes, r#"{"a":"2"}"#);
        });
    }

    /// For non-GET requests `parse` merges path parameters with query and
    /// form values. The request is form-encoded, so no JSON body is involved.
    #[test]
    fn parse_should_merge_path_form_json() {
        runtime().block_on(async {
            let mut req = Request::builder()
                .method(Method::POST)
                .uri("/users/42?tags=a,b&ids=1&ids=2")
                .header(
                    http::header::CONTENT_TYPE,
                    "application/x-www-form-urlencoded",
                )
                .body(Body::from("tags=c&name=bob"))
                .unwrap();
            // Install path parameters directly as a request extension.
            let mut params = std::collections::HashMap::new();
            params.insert("id".to_string(), "42".to_string());
            req.extensions_mut()
                .insert(crate::router::params::PathParams { params });

            let merged: FormPayload = parse(&mut req).await.unwrap();
            assert_eq!(merged.ids, vec![1, 2]);
            assert_eq!(merged.tags, vec!["a", "b", "c"]);
            assert_eq!(merged.name, "bob");
            assert_eq!(merged.names, Vec::<String>::new());
        });
    }

    /// `parse_header` splits on `;` and trims the whitespace around fields.
    #[test]
    fn parse_header_should_split_pairs() {
        let m = parse_header("a=1; b=2; c=3");
        assert_eq!(m.get("a").unwrap(), "1");
        assert_eq!(m.get("b").unwrap(), "2");
        assert_eq!(m.get("c").unwrap(), "3");
    }

    /// When `x-forwarded-for` is present and non-empty, `get_remote_addr`
    /// returns it.
    #[test]
    fn get_remote_addr_should_prefer_header() {
        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header("x-forwarded-for", "1.1.1.1")
            .body(Body::empty())
            .unwrap();
        assert_eq!(get_remote_addr(&req), "1.1.1.1");
    }
}