use actus_controller::{Params, Verb};
use actus_reply::WebError;
use bytes::Bytes;
use http::HeaderMap;
use http_body_util::{BodyExt, LengthLimitError, Limited};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
/// An incoming HTTP request, reduced to the pieces the framework routes and
/// dispatches on.
///
/// Usually built with [`Request::from_hyper_parts`] (which leaves `body`
/// empty), completed with [`Request::collect_body`], and turned into
/// controller [`Params`] with [`Request::to_params`].
#[derive(Clone, Debug)]
pub struct Request {
    /// HTTP method as received.
    pub method: http::Method,
    /// `/`-separated segments of the URI path. `from_hyper_parts` drops empty
    /// segments and does not percent-decode them.
    pub path_parts: Vec<String>,
    /// Query-string parameters; a repeated key keeps all of its values, in order.
    pub query_params: HashMap<String, Vec<String>>,
    /// Fully buffered request body (empty until the body has been collected).
    pub body: Bytes,
    /// Request headers as received.
    pub headers: HeaderMap,
    /// Rate-limit class for this request, if one has been assigned.
    // NOTE(review): `from_hyper_parts` always sets this to `None`; presumably a
    // later routing step fills it in — confirm.
    pub rate_limit_class: Option<&'static str>,
}
/// Groups `(name, value)` pairs into a multimap.
///
/// Every value of a repeated name is kept, in the order the pairs were
/// supplied, so `[("t", "a"), ("p", "1"), ("t", "b")]` yields
/// `{"t": ["a", "b"], "p": ["1"]}`.
fn collect_pairs(pairs: Vec<(String, String)>) -> HashMap<String, Vec<String>> {
    pairs
        .into_iter()
        .fold(HashMap::new(), |mut grouped, (key, val)| {
            grouped.entry(key).or_default().push(val);
            grouped
        })
}
/// Buffers `body` into memory, refusing anything larger than `max_bytes`.
///
/// When `budget` is given, `max_bytes` permits (saturated to `u32::MAX`) are
/// taken from it up front with a non-blocking acquire. They are held until this
/// function returns, so concurrent reads cannot together reserve more than the
/// semaphore holds. The reservation is always the worst case (`max_bytes`),
/// not the body's actual size.
///
/// # Errors
///
/// - `WebError::Busy` with a one-second retry hint if the budget cannot cover
///   `max_bytes` right now.
/// - `WebError::PayloadTooLarge` if the body exceeds `max_bytes`.
/// - `WebError::BadRequest` if reading the body fails for any other reason.
async fn collect_body_capped<B>(
    body: B,
    max_bytes: usize,
    budget: Option<&Arc<Semaphore>>,
) -> Result<Bytes, WebError>
where
    B: hyper::body::Body<Data = Bytes>,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    // Bound to a named (not `_`) variable so the permits are released only
    // when this function returns, i.e. after the body has been read.
    let _reservation = budget
        .map(|semaphore| {
            let permits = u32::try_from(max_bytes).unwrap_or(u32::MAX);
            Arc::clone(semaphore)
                .try_acquire_many_owned(permits)
                .map_err(|_| WebError::Busy(Some(Duration::from_secs(1))))
        })
        .transpose()?;

    let collected = Limited::new(body, max_bytes)
        .collect()
        .await
        .map_err(|err| {
            // `Limited` reports the cap being exceeded as a boxed
            // `LengthLimitError`; everything else is a transport failure.
            if err.downcast_ref::<LengthLimitError>().is_some() {
                WebError::PayloadTooLarge
            } else {
                WebError::BadRequest(format!("could not read request body: {err}"))
            }
        })?;
    Ok(collected.to_bytes())
}
impl Request {
    /// Splits a hyper request into a body-less `Request` and its unread body.
    ///
    /// - `path_parts` holds the non-empty `/`-separated segments of the URI
    ///   path, e.g. `//users/42/` becomes `["users", "42"]`. Segments are taken
    ///   verbatim from the URI; no percent-decoding is applied here.
    /// - `query_params` keeps every value of a repeated query key, in order.
    ///   A query string that fails to parse is treated as empty.
    /// - `body` is left empty and `rate_limit_class` is `None`; use
    ///   [`Request::collect_body`] to buffer the returned body stream.
    pub fn from_hyper_parts(
        req: hyper::Request<hyper::body::Incoming>,
    ) -> (Self, hyper::body::Incoming) {
        let (parts, body) = req.into_parts();
        // Filtering out empty segments also removes those produced by leading,
        // trailing, or doubled slashes, so no separate trim is needed; filtering
        // before `map` avoids allocating strings that would be thrown away.
        let path_parts: Vec<String> = parts
            .uri
            .path()
            .split('/')
            .filter(|segment| !segment.is_empty())
            .map(String::from)
            .collect();
        let query_params = parts
            .uri
            .query()
            .map(|q| {
                // Best effort: a malformed query string yields no parameters
                // rather than rejecting the request.
                collect_pairs(
                    serde_urlencoded::from_str::<Vec<(String, String)>>(q).unwrap_or_default(),
                )
            })
            .unwrap_or_default();
        let skeleton = Self {
            method: parts.method,
            path_parts,
            query_params,
            body: Bytes::new(),
            headers: parts.headers,
            rate_limit_class: None,
        };
        (skeleton, body)
    }

    /// Buffers `body` into `self.body`, enforcing `max_body_bytes`.
    ///
    /// When `inflight_budget` is given, `max_body_bytes` permits are reserved
    /// from it (without waiting) for the duration of the read.
    ///
    /// # Errors
    ///
    /// Returns the request unchanged together with the error, so the caller
    /// can still build a response from it:
    /// - `WebError::Busy` if the in-flight budget cannot cover the reservation;
    /// - `WebError::PayloadTooLarge` if the body exceeds `max_body_bytes`;
    /// - `WebError::BadRequest` if the body stream fails.
    pub async fn collect_body(
        mut self,
        body: hyper::body::Incoming,
        max_body_bytes: usize,
        inflight_budget: Option<&Arc<Semaphore>>,
    ) -> Result<Self, (Self, WebError)> {
        match collect_body_capped(body, max_body_bytes, inflight_budget).await {
            Ok(body_bytes) => {
                self.body = body_bytes;
                Ok(self)
            }
            Err(e) => Err((self, e)),
        }
    }

    /// Convenience wrapper: [`Request::from_hyper_parts`] followed by
    /// [`Request::collect_body`], with the same errors as the latter.
    pub async fn from_hyper(
        req: hyper::Request<hyper::body::Incoming>,
        max_body_bytes: usize,
        inflight_budget: Option<&Arc<Semaphore>>,
    ) -> Result<Self, (Self, WebError)> {
        let (skeleton, body) = Self::from_hyper_parts(req);
        skeleton
            .collect_body(body, max_body_bytes, inflight_budget)
            .await
    }

    /// Converts the request into controller [`Params`].
    ///
    /// Query parameters always contribute. A non-empty body is interpreted by
    /// its `Content-Type` media type. Parameters such as `charset` are ignored,
    /// and matching is case-insensitive:
    /// - `application/json` or any `+json` type: parsed as JSON and passed on
    ///   as the JSON body;
    /// - `application/x-www-form-urlencoded`: its pairs are appended after any
    ///   query values for the same key;
    /// - anything else: left opaque.
    ///
    /// The raw body bytes and all readable headers are passed on in every case.
    ///
    /// # Errors
    ///
    /// `WebError::BadRequest` if a non-empty body is malformed JSON or form
    /// data, has no `Content-Type` header, or has a `Content-Type` header whose
    /// value is not visible ASCII.
    pub fn to_params(&self) -> Result<Params, WebError> {
        let mut all_params = self.query_params.clone();
        let raw_content_type = self.headers.get(http::header::CONTENT_TYPE);
        let content_type = raw_content_type
            .and_then(|v| v.to_str().ok())
            .map(Self::media_type);
        let json_body = if self.body.is_empty() {
            None
        } else {
            match content_type.as_deref() {
                Some(ct) if is_json_content_type(ct) => {
                    Some(serde_json::from_slice(&self.body).map_err(|e| {
                        WebError::BadRequest(format!("body is not valid JSON: {e}"))
                    })?)
                }
                Some("application/x-www-form-urlencoded") => {
                    let form_pairs: Vec<(String, String)> =
                        serde_urlencoded::from_bytes(&self.body).map_err(|e| {
                            WebError::BadRequest(format!("body is not valid form-urlencoded: {e}"))
                        })?;
                    // Append (not overwrite) so query values come first.
                    for (name, value) in form_pairs {
                        all_params.entry(name).or_default().push(value);
                    }
                    None
                }
                Some(_) => None,
                // The header exists but `to_str` rejected it; say so instead of
                // claiming the header is missing.
                None if raw_content_type.is_some() => {
                    return Err(WebError::BadRequest(
                        "Content-Type header value must be visible ASCII".into(),
                    ));
                }
                None => {
                    return Err(WebError::BadRequest(
                        "non-empty request body requires a Content-Type header".into(),
                    ));
                }
            }
        };
        Ok(Params::new(
            method_to_verb(&self.method),
            all_params,
            json_body,
            self.body.clone(),
            self.header_multimap(),
        ))
    }

    /// Returns the bare media type of a `Content-Type` value: everything
    /// before the first `;`, trimmed and ASCII-lower-cased
    /// (`"Application/JSON; charset=utf-8"` becomes `"application/json"`).
    fn media_type(value: &str) -> String {
        value
            .split(';')
            .next()
            .unwrap_or("")
            .trim()
            .to_ascii_lowercase()
    }

    /// Groups header values by lower-cased name, keeping every value of a
    /// repeated header in the order `HeaderMap::iter` yields them. Values that
    /// `HeaderValue::to_str` rejects are skipped.
    fn header_multimap(&self) -> HashMap<String, Vec<String>> {
        let mut headers: HashMap<String, Vec<String>> = HashMap::new();
        for (name, value) in self.headers.iter() {
            if let Ok(v) = value.to_str() {
                headers
                    .entry(name.as_str().to_ascii_lowercase())
                    .or_default()
                    .push(v.to_string());
            }
        }
        headers
    }
}
/// Reports whether a media type denotes JSON: exactly `application/json`, or
/// any type with the structured `+json` suffix (e.g.
/// `application/vnd.example+json`).
///
/// Expects an already lower-cased type with parameters stripped; the
/// comparison itself is case-sensitive.
fn is_json_content_type(ct: &str) -> bool {
    let has_json_suffix = ct.ends_with("+json");
    has_json_suffix || matches!(ct, "application/json")
}
/// Maps an HTTP method onto the controller's [`Verb`].
///
/// Any method not listed below (e.g. `TRACE`, `CONNECT`, or an extension
/// method) falls back to `Verb::GET`.
// NOTE(review): because of the GET fallback, such requests are dispatched as if
// they were GETs — confirm that is intended rather than rejecting them.
fn method_to_verb(method: &http::Method) -> Verb {
    match *method {
        http::Method::GET => Verb::GET,
        http::Method::POST => Verb::POST,
        http::Method::PUT => Verb::PUT,
        http::Method::DELETE => Verb::DELETE,
        http::Method::PATCH => Verb::PATCH,
        http::Method::HEAD => Verb::HEAD,
        http::Method::OPTIONS => Verb::OPTIONS,
        _ => Verb::GET,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderValue;

    /// Builds a POST request with an optional `Content-Type` and the given body.
    fn req(content_type: Option<&str>, body: &[u8]) -> Request {
        let mut headers = HeaderMap::new();
        if let Some(ct) = content_type {
            headers.insert(
                http::header::CONTENT_TYPE,
                HeaderValue::from_str(ct).unwrap(),
            );
        }
        Request {
            method: http::Method::POST,
            path_parts: vec!["whatever".into()],
            query_params: HashMap::new(),
            body: Bytes::copy_from_slice(body),
            headers,
            rate_limit_class: None,
        }
    }

    #[test]
    fn empty_body_no_content_type_is_ok() {
        let params = req(None, b"").to_params().expect("empty + no CT is fine");
        assert!(params.body_bytes().is_empty());
    }

    #[test]
    fn json_body_is_parsed() {
        let params = req(Some("application/json"), br#"{"x":1}"#)
            .to_params()
            .expect("valid JSON");
        let v = params.json_body().expect("body present");
        assert_eq!(v["x"], 1);
        assert_eq!(params.body_bytes().as_ref(), br#"{"x":1}"#);
    }

    #[test]
    fn json_body_with_charset_param_is_parsed() {
        let params = req(Some("application/json; charset=utf-8"), br#"{"x":1}"#)
            .to_params()
            .expect("valid JSON with charset");
        assert_eq!(params.json_body().unwrap()["x"], 1);
    }

    #[test]
    fn vendor_plus_json_subtype_is_parsed() {
        let params = req(Some("application/vnd.example+json"), br#"{"x":1}"#)
            .to_params()
            .expect("valid +json");
        assert_eq!(params.json_body().unwrap()["x"], 1);
    }

    #[test]
    fn malformed_json_is_a_400() {
        match req(Some("application/json"), b"not json").to_params() {
            Err(WebError::BadRequest(msg)) => assert!(msg.contains("valid JSON"), "msg = {msg:?}"),
            Err(other) => panic!("expected BadRequest, got {other:?}"),
            Ok(_) => panic!("must reject malformed JSON"),
        }
    }

    #[test]
    fn form_body_merges_into_query() {
        let params = req(
            Some("application/x-www-form-urlencoded"),
            b"username=alice&kind=admin",
        )
        .to_params()
        .expect("valid form");
        assert!(params.json_body().is_err());
        assert_eq!(params.body_bytes().as_ref(), b"username=alice&kind=admin");
        // The test name promises a merge; check the form fields actually arrived.
        assert_eq!(params.get_all("username").unwrap(), ["alice"]);
        assert_eq!(params.get_all("kind").unwrap(), ["admin"]);
    }

    #[test]
    fn binary_body_round_trips() {
        let zip_bytes = b"PK\x03\x04\x00\x00fake-zip";
        let params = req(Some("application/zip"), zip_bytes)
            .to_params()
            .expect("opaque CT is fine");
        assert!(params.json_body().is_err());
        assert_eq!(params.body_bytes().as_ref(), zip_bytes);
    }

    #[test]
    fn non_empty_body_without_content_type_is_a_400() {
        match req(None, b"some payload").to_params() {
            Err(WebError::BadRequest(msg)) => {
                assert!(msg.contains("Content-Type"), "msg = {msg:?}")
            }
            Err(other) => panic!("expected BadRequest, got {other:?}"),
            Ok(_) => panic!("must reject non-empty body without Content-Type"),
        }
    }

    #[test]
    fn collect_pairs_keeps_repeated_keys_in_order() {
        let m = collect_pairs(vec![
            ("tag".into(), "a".into()),
            ("page".into(), "1".into()),
            ("tag".into(), "b".into()),
        ]);
        assert_eq!(m.get("tag").unwrap(), &["a".to_string(), "b".to_string()]);
        assert_eq!(m.get("page").unwrap(), &["1".to_string()]);
    }

    #[tokio::test]
    async fn body_within_limit_round_trips() {
        let body = http_body_util::Full::new(Bytes::from_static(b"hello"));
        assert_eq!(
            collect_body_capped(body, 1024, None).await.unwrap(),
            Bytes::from_static(b"hello")
        );
    }

    #[tokio::test]
    async fn body_over_limit_is_413() {
        let body = http_body_util::Full::new(Bytes::from(vec![0u8; 2048]));
        match collect_body_capped(body, 1024, None).await {
            Err(WebError::PayloadTooLarge) => {}
            other => panic!("expected PayloadTooLarge, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn body_refused_when_inflight_budget_exhausted() {
        let budget = Arc::new(Semaphore::new(100));
        let hold = budget.clone().try_acquire_many_owned(80).expect("acquire");
        let body = http_body_util::Full::new(Bytes::from_static(b"hello"));
        match collect_body_capped(body, 80, Some(&budget)).await {
            Err(WebError::Busy(Some(d))) => assert_eq!(d, Duration::from_secs(1)),
            other => panic!("expected Busy, got {other:?}"),
        }
        drop(hold);
        let body = http_body_util::Full::new(Bytes::from_static(b"hello"));
        assert!(collect_body_capped(body, 80, Some(&budget)).await.is_ok());
    }

    #[test]
    fn duplicate_request_headers_survive_into_params() {
        let mut headers = HeaderMap::new();
        headers.append(
            http::HeaderName::from_static("forwarded"),
            HeaderValue::from_static("for=1.2.3.4"),
        );
        headers.append(
            http::HeaderName::from_static("forwarded"),
            HeaderValue::from_static("for=10.0.0.1"),
        );
        let request = Request {
            method: http::Method::GET,
            path_parts: vec![],
            query_params: HashMap::new(),
            body: Bytes::new(),
            headers,
            rate_limit_class: None,
        };
        let params = request.to_params().expect("ok");
        assert_eq!(params.header("Forwarded"), Some("for=1.2.3.4"));
        assert_eq!(
            params.header_all("Forwarded"),
            ["for=1.2.3.4", "for=10.0.0.1"]
        );
    }

    #[test]
    fn form_body_appends_to_query_multimap_instead_of_overwriting() {
        let mut query = HashMap::new();
        query.insert("tag".to_string(), vec!["from-query".to_string()]);
        let request = Request {
            method: http::Method::POST,
            path_parts: vec!["whatever".into()],
            query_params: query,
            // Plain ASCII `&`: a non-ASCII character in a byte-string literal
            // does not compile.
            body: Bytes::from_static(b"tag=from-body&note=hi"),
            headers: {
                let mut h = HeaderMap::new();
                h.insert(
                    http::header::CONTENT_TYPE,
                    HeaderValue::from_static("application/x-www-form-urlencoded"),
                );
                h
            },
            rate_limit_class: None,
        };
        let params = request.to_params().expect("valid form");
        assert_eq!(params.get_all("tag").unwrap(), ["from-query", "from-body"]);
        assert_eq!(params.require("tag").unwrap(), "from-query");
        assert_eq!(params.get_all("note").unwrap(), ["hi"]);
    }
}