rest/http/request.rs

1use anyhow::Context;
2use hyper::{Body, Request};
3use serde::de::DeserializeOwned;
4use serde_json::Value;
5use std::collections::HashMap;
6
/// Upper bound on the number of values accepted from the URL query string and
/// the urlencoded form body combined (enforced by `append_pairs`).
const MAX_FORM_PARAM_COUNT: usize = 2048;
/// Maximum size in bytes (32 MiB) of an `application/x-www-form-urlencoded`
/// body read by `get_form_values`.
const MAX_MEMORY: u64 = 32 << 20;
/// Maximum size in bytes (8 MiB) of a JSON body read by `read_json` and
/// `parse_json_body`.
const MAX_BODY_LEN: usize = 8 << 20;
10
11/// Read JSON body into `T` and keep body available by re-inserting bytes.
12/// Read JSON body as `T`, and put bytes back so the body stays readable.
13pub async fn read_json<T: DeserializeOwned>(req: &mut Request<Body>) -> anyhow::Result<T> {
14    let bytes = hyper::body::to_bytes(req.body_mut())
15        .await
16        .context("read request body")?;
17    if bytes.len() > MAX_BODY_LEN {
18        anyhow::bail!(
19            "body too large: {} bytes (limit {})",
20            bytes.len(),
21            MAX_BODY_LEN
22        );
23    }
24    let val: T = serde_json::from_slice(&bytes).context("deserialize json body")?;
25    *req.body_mut() = Body::from(bytes);
26    Ok(val)
27}
28
29/// Parse `application/json` body with size limit, return `None` if not json.
30/// Parse JSON body with size limit; return None if not JSON.
31pub async fn parse_json_body<T: DeserializeOwned>(
32    req: &mut Request<Body>,
33) -> anyhow::Result<Option<T>> {
34    if !with_json_body(req) {
35        return Ok(None);
36    }
37    let bytes = hyper::body::to_bytes(req.body_mut())
38        .await
39        .context("read request body")?;
40    if bytes.len() > MAX_BODY_LEN {
41        anyhow::bail!("request body too large");
42    }
43    let val: T = serde_json::from_slice(&bytes).context("deserialize json body")?;
44    *req.body_mut() = Body::from(bytes);
45    Ok(Some(val))
46}
47
48/// Parse URL query + x-www-form-urlencoded body into map, supporting [] / comma array syntax.
49/// Parse query and form (application/x-www-form-urlencoded), supports []/comma arrays.
50pub async fn get_form_values(
51    req: &mut Request<Body>,
52) -> anyhow::Result<HashMap<String, Vec<String>>> {
53    let mut params: HashMap<String, Vec<String>> = HashMap::new();
54    let mut count = 0usize;
55
56    // query part
57    if let Some(q) = req.uri().query() {
58        append_pairs(q, &mut params, &mut count)?;
59    }
60
61    // body part (only for form content type)
62    if let Some(ct) = req.headers().get(http::header::CONTENT_TYPE)
63        && let Ok(ct) = ct.to_str()
64        && ct.contains("application/x-www-form-urlencoded")
65    {
66        let bytes = hyper::body::to_bytes(req.body_mut())
67            .await
68            .context("read form body")?;
69        if bytes.len() > MAX_MEMORY as usize {
70            anyhow::bail!("form body too large");
71        }
72        append_pairs(std::str::from_utf8(&bytes)?, &mut params, &mut count)?;
73        // reinsert body for downstream use
74        *req.body_mut() = Body::from(bytes);
75    }
76
77    Ok(params)
78}
79
80/// Parse form (query + urlencoded body) into `T`.
81/// Parse form (query + urlencoded body) into `T`.
82pub async fn parse_form<T: DeserializeOwned>(req: &mut Request<Body>) -> anyhow::Result<T> {
83    let params = get_form_values(req).await?;
84    let json = form_map_to_json(params);
85    serde_json::from_value(json).context("deserialize form to struct")
86}
87
88/// Parse path params (from extensions) into `T`.
89/// Parse path params (from extensions) into `T`.
90pub fn parse_path<T: DeserializeOwned>(req: &Request<Body>) -> anyhow::Result<T> {
91    let params: HashMap<String, Vec<String>> = req
92        .extensions()
93        .get::<crate::router::params::PathParams>()
94        .map(|p| {
95            p.params
96                .iter()
97                .map(|(k, v)| (k.clone(), vec![v.clone()]))
98                .collect()
99        })
100        .unwrap_or_default();
101    let json = form_map_to_json(params);
102    serde_json::from_value(json).context("deserialize path params to struct")
103}
104
105/// Merge path + form + json into one struct.
106pub async fn parse<T: DeserializeOwned>(req: &mut Request<Body>) -> anyhow::Result<T> {
107    if matches!(req.method(), &http::Method::GET | &http::Method::HEAD) {
108        let path_map: HashMap<String, Vec<String>> = req
109            .extensions()
110            .get::<crate::router::params::PathParams>()
111            .map(|p| {
112                p.params
113                    .iter()
114                    .map(|(k, v)| (k.clone(), vec![v.clone()]))
115                    .collect()
116            })
117            .unwrap_or_default();
118        let query_map = get_query_values(req)?;
119
120        let mut merged = serde_json::Map::new();
121        merge_map(&mut merged, form_map_to_json(path_map));
122        merge_map(&mut merged, form_map_to_json(query_map));
123        return serde_json::from_value(serde_json::Value::Object(merged))
124            .context("parse merged request (GET/HEAD path+query)");
125    }
126
127    let path_map: HashMap<String, Vec<String>> = req
128        .extensions()
129        .get::<crate::router::params::PathParams>()
130        .map(|p| {
131            p.params
132                .iter()
133                .map(|(k, v)| (k.clone(), vec![v.clone()]))
134                .collect()
135        })
136        .unwrap_or_default();
137    let form_map = get_form_values(req).await?;
138    let json_body: Option<serde_json::Value> = parse_json_body(req).await?;
139
140    let mut merged = serde_json::Map::new();
141    merge_map(&mut merged, form_map_to_json(path_map));
142    merge_map(&mut merged, form_map_to_json(form_map));
143    if let Some(serde_json::Value::Object(obj)) = json_body {
144        for (k, v) in obj {
145            merged.insert(k, v);
146        }
147    }
148
149    serde_json::from_value(serde_json::Value::Object(merged)).context("parse merged request")
150}
151
152fn get_query_values(req: &Request<Body>) -> anyhow::Result<HashMap<String, Vec<String>>> {
153    let mut params: HashMap<String, Vec<String>> = HashMap::new();
154    let mut count = 0usize;
155    if let Some(q) = req.uri().query() {
156        append_pairs(q, &mut params, &mut count)?;
157    }
158    Ok(params)
159}
160
/// Parses `;`-separated `key=value` attributes into a map.
///
/// An example input is the parameter list of a `Content-Type` header, such as
/// `text/html; charset=utf-8`.
///
/// Each field is split at its first `=`, and whitespace around both the key and
/// the value is trimmed, so `a = 1` yields `("a", "1")`. Fields without an `=`
/// (such as a bare media type) and empty fields are skipped. Quotes are not
/// removed from values. A repeated key keeps its last value.
pub fn parse_header(header_value: &str) -> HashMap<String, String> {
    header_value
        .split(';')
        .filter_map(|field| field.split_once('='))
        .map(|(key, value)| (key.trim().to_string(), value.trim().to_string()))
        .collect()
}
176
177/// Get remote addr, prefer X-Forwarded-For then extension-provided peer addr.
178/// Get remote addr: prefer X-Forwarded-For, then peer addr in extensions.
179pub fn get_remote_addr(req: &Request<Body>) -> String {
180    if let Some(v) = req.headers().get("x-forwarded-for")
181        && let Ok(s) = v.to_str()
182        && !s.is_empty()
183    {
184        return s.to_string();
185    }
186    if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
187        return addr.to_string();
188    }
189    "unknown".to_string()
190}
191
192fn append_pairs(
193    raw: &str,
194    params: &mut HashMap<String, Vec<String>>,
195    count: &mut usize,
196) -> anyhow::Result<()> {
197    let pairs: Vec<(String, String)> =
198        serde_urlencoded::from_bytes(raw.as_bytes()).context("parse urlencoded form")?;
199    for (mut k_owned, v) in pairs {
200        if v.is_empty() {
201            continue;
202        }
203        if *count >= MAX_FORM_PARAM_COUNT {
204            anyhow::bail!("too many form values");
205        }
206        if k_owned.ends_with("[]") {
207            k_owned.truncate(k_owned.len() - 2);
208        }
209        let values: Vec<String> = v
210            .split(',')
211            .filter(|s| !s.is_empty())
212            .map(|s| s.to_string())
213            .collect();
214        if values.is_empty() {
215            continue;
216        }
217        *count += values.len();
218        params.entry(k_owned).or_default().extend(values);
219    }
220    Ok(())
221}
222
223fn form_map_to_json(map: HashMap<String, Vec<String>>) -> Value {
224    let mut obj = serde_json::Map::new();
225    for (k, vals) in map {
226        if vals.len() == 1 {
227            obj.insert(k, str_to_json_value(&vals[0]));
228        } else {
229            obj.insert(
230                k,
231                Value::Array(vals.into_iter().map(|s| str_to_json_value(&s)).collect()),
232            );
233        }
234    }
235    Value::Object(obj)
236}
237
238fn str_to_json_value(s: &str) -> Value {
239    if let Ok(n) = s.parse::<i64>() {
240        return Value::Number(n.into());
241    }
242    Value::String(s.to_string())
243}
244
245fn merge_map(target: &mut serde_json::Map<String, Value>, src: Value) {
246    if let Value::Object(obj) = src {
247        for (k, v) in obj {
248            target.insert(k, v);
249        }
250    }
251}
252
253fn with_json_body(req: &Request<Body>) -> bool {
254    req.headers()
255        .get(http::header::CONTENT_TYPE)
256        .and_then(|v| v.to_str().ok())
257        .map(|ct| ct.contains("application/json"))
258        .unwrap_or(false)
259        && req
260            .headers()
261            .get(http::header::CONTENT_LENGTH)
262            .and_then(|v| v.to_str().ok())
263            .and_then(|s| s.parse::<usize>().ok())
264            .unwrap_or(0)
265            > 0
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;
    use serde::Deserialize;
    use tokio::runtime::Runtime;

    /// Target type for JSON-body tests.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct Payload {
        name: String,
        id: u32,
    }

    /// Target type for form/merge tests; every field defaults so partial input works.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct FormPayload {
        #[serde(default)]
        names: Vec<String>,
        #[serde(default)]
        ids: Vec<i32>,
        #[serde(default)]
        tags: Vec<String>,
        #[serde(default)]
        name: String,
    }

    /// Builds a fresh runtime so each synchronous test can `block_on` its async body.
    fn runtime() -> Runtime {
        Runtime::new().unwrap()
    }

    #[test]
    fn read_json_should_parse_and_preserve_body() {
        runtime().block_on(async {
            let body = r#"{"name":"alice","id":7}"#;
            let mut req = Request::builder()
                .method(Method::POST)
                .uri("/test")
                .body(Body::from(body))
                .unwrap();

            let parsed: Payload = read_json(&mut req).await.unwrap();
            assert_eq!(
                parsed,
                Payload {
                    name: "alice".into(),
                    id: 7
                }
            );

            // body should still be readable
            let bytes = hyper::body::to_bytes(req.into_body()).await.unwrap();
            assert_eq!(bytes, body);
        });
    }

    #[test]
    fn parse_json_body_should_parse_sized_json() {
        runtime().block_on(async {
            let body = r#"{"name":"bob","id":3}"#;
            let mut req = Request::builder()
                .method(Method::POST)
                .uri("/")
                .header(http::header::CONTENT_TYPE, "application/json")
                .header(http::header::CONTENT_LENGTH, body.len().to_string())
                .body(Body::from(body))
                .unwrap();

            let parsed: Option<Payload> = parse_json_body(&mut req).await.unwrap();
            assert_eq!(
                parsed,
                Some(Payload {
                    name: "bob".into(),
                    id: 3
                })
            );
        });
    }

    #[test]
    fn parse_json_body_should_skip_non_json_content_type() {
        runtime().block_on(async {
            let mut req = Request::builder()
                .method(Method::POST)
                .uri("/")
                .header(http::header::CONTENT_TYPE, "text/plain")
                .header(http::header::CONTENT_LENGTH, "2")
                .body(Body::from("{}"))
                .unwrap();

            let parsed: Option<Value> = parse_json_body(&mut req).await.unwrap();
            assert!(parsed.is_none());
        });
    }

    #[test]
    fn get_form_values_should_support_array_notations() {
        runtime().block_on(async {
            let mut req = Request::builder()
                .method(Method::POST)
                .uri("/api?names=alice&names=bob&names[]=carol&names=dave,erin")
                .header(
                    http::header::CONTENT_TYPE,
                    "application/x-www-form-urlencoded",
                )
                .body(Body::from("ids=1&ids=2&empty=&ids[]=3"))
                .unwrap();
            let params = get_form_values(&mut req).await.unwrap();
            assert_eq!(
                params.get("names").unwrap(),
                &vec![
                    "alice".to_string(),
                    "bob".to_string(),
                    "carol".to_string(),
                    "dave".to_string(),
                    "erin".to_string()
                ]
            );
            assert_eq!(
                params.get("ids").unwrap(),
                &vec!["1".to_string(), "2".to_string(), "3".to_string()]
            );
            // keys whose value is empty are dropped entirely
            assert!(!params.contains_key("empty"));
        });
    }

    #[test]
    fn get_form_values_should_restore_form_body() {
        runtime().block_on(async {
            let body = "a=1&b=";
            let mut req = Request::builder()
                .method(Method::POST)
                .uri("/api")
                .header(
                    http::header::CONTENT_TYPE,
                    "application/x-www-form-urlencoded",
                )
                .body(Body::from(body))
                .unwrap();

            let params = get_form_values(&mut req).await.unwrap();
            assert_eq!(params.get("a").unwrap(), &vec!["1".to_string()]);
            assert!(!params.contains_key("b"));

            // body should still be readable for downstream handlers
            let bytes = hyper::body::to_bytes(req.into_body()).await.unwrap();
            assert_eq!(bytes, body);
        });
    }

    #[test]
    fn parse_form_should_fill_struct() {
        runtime().block_on(async {
            let mut req = Request::builder()
                .method(Method::POST)
                .uri("/api?names=alice,bob&ids=1&ids=2")
                .header(
                    http::header::CONTENT_TYPE,
                    "application/x-www-form-urlencoded",
                )
                .body(Body::from("tags=r&tags=go&names=carol&ids=3"))
                .unwrap();
            let fp: FormPayload = parse_form(&mut req).await.unwrap();
            assert_eq!(fp.names, vec!["alice", "bob", "carol"]);
            assert_eq!(fp.ids, vec![1, 2, 3]);
            assert_eq!(fp.tags, vec!["r", "go"]);
            assert_eq!(fp.name, "");
        });
    }

    #[test]
    fn parse_should_use_query_only_for_get() {
        runtime().block_on(async {
            #[derive(Deserialize, Debug, PartialEq)]
            struct Q {
                a: i64,
            }

            let mut req = Request::builder()
                .method(Method::GET)
                .uri("/path?a=1")
                .header(http::header::CONTENT_TYPE, "application/json")
                .body(Body::from(r#"{"a":"2"}"#))
                .unwrap();

            let parsed: Q = parse(&mut req).await.unwrap();
            assert_eq!(parsed.a, 1);
            // body should remain available (unchanged)
            let bytes = hyper::body::to_bytes(req.into_body()).await.unwrap();
            assert_eq!(bytes, r#"{"a":"2"}"#);
        });
    }

    #[test]
    fn parse_should_merge_path_form_json() {
        runtime().block_on(async {
            let mut req = Request::builder()
                .method(Method::POST)
                .uri("/users/42?tags=a,b&ids=1&ids=2")
                .header(
                    http::header::CONTENT_TYPE,
                    "application/x-www-form-urlencoded",
                )
                .body(Body::from("tags=c&name=bob"))
                .unwrap();
            // simulate path params injected by router
            let mut params = std::collections::HashMap::new();
            params.insert("id".to_string(), "42".to_string());
            req.extensions_mut()
                .insert(crate::router::params::PathParams { params });

            let merged: FormPayload = parse(&mut req).await.unwrap();
            assert_eq!(merged.ids, vec![1, 2]);
            assert_eq!(merged.tags, vec!["a", "b", "c"]);
            assert_eq!(merged.name, "bob");
            assert_eq!(merged.names, Vec::<String>::new());
        });
    }

    #[test]
    fn parse_header_should_split_pairs() {
        let m = parse_header("a=1; b=2; c=3");
        assert_eq!(m.get("a").unwrap(), "1");
        assert_eq!(m.get("b").unwrap(), "2");
        assert_eq!(m.get("c").unwrap(), "3");
    }

    #[test]
    fn get_remote_addr_should_prefer_header() {
        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header("x-forwarded-for", "1.1.1.1")
            .body(Body::empty())
            .unwrap();
        assert_eq!(get_remote_addr(&req), "1.1.1.1");
    }

    #[test]
    fn get_remote_addr_should_fall_back_to_peer_then_unknown() {
        let mut req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Body::empty())
            .unwrap();
        let peer: std::net::SocketAddr = "10.0.0.1:8080".parse().unwrap();
        req.extensions_mut().insert(peer);
        assert_eq!(get_remote_addr(&req), "10.0.0.1:8080");

        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Body::empty())
            .unwrap();
        assert_eq!(get_remote_addr(&req), "unknown");
    }
}