use axum::body::Body;
use axum::http::{HeaderName, HeaderValue, Method as AxumMethod, Request as AxumRequest};
use rivet_http::{Method, Request, Response};
/// Errors produced while converting an incoming axum request into a `rivet_http`
/// [`Request`] (see `from_axum_request`).
#[derive(Debug, thiserror::Error)]
pub enum MappingError {
    /// The request's HTTP method has no `rivet_http::Method` counterpart.
    ///
    /// Carries the method's string form as received, for diagnostics.
    #[error("unsupported method: {0}")]
    UnsupportedMethod(String),
    /// Collecting the request body failed.
    ///
    /// NOTE(review): `from_axum_request` maps *every* error from `axum::body::to_bytes` to
    /// this variant, not only a length-limit overrun. Other body-read failures (presumably
    /// I/O / client disconnects) are reported with this message as well.
    #[error("request body exceeded configured limit")]
    BodyLimit,
}
/// Converts an axum request into a `rivet_http` [`Request`], buffering the whole body.
///
/// The method is translated with `method_from_axum`. The URI path is copied verbatim, and
/// the query string (when the URI has one) is split into pairs by `parse_query`. Headers
/// are flattened by `map_headers`, and the body is collected into memory.
///
/// # Errors
///
/// * [`MappingError::UnsupportedMethod`] if the method has no `rivet_http` equivalent.
/// * [`MappingError::BodyLimit`] if collecting the body fails. This covers exceeding
///   `body_limit` bytes, but any other error from `axum::body::to_bytes` is reported the
///   same way.
pub async fn from_axum_request(
    req: AxumRequest<Body>,
    body_limit: usize,
) -> Result<Request, MappingError> {
    let (parts, body) = req.into_parts();

    let Some(method) = method_from_axum(&parts.method) else {
        return Err(MappingError::UnsupportedMethod(parts.method.to_string()));
    };

    let mut request = Request::new(method, parts.uri.path().to_string());

    // Only overwrite the query when one is present; otherwise keep what `Request::new` set.
    if let Some(raw_query) = parts.uri.query() {
        request.query = parse_query(raw_query);
    }

    request.headers = map_headers(parts.headers);

    let collected = match axum::body::to_bytes(body, body_limit).await {
        Ok(collected) => collected,
        Err(_) => return Err(MappingError::BodyLimit),
    };
    request.body = collected.to_vec();

    Ok(request)
}
/// Converts a `rivet_http` [`Response`] into an axum response.
///
/// If `res.status` is not a valid HTTP status code, the response is sent as
/// `500 Internal Server Error`. Each header is converted independently. A header whose name
/// is not a valid HTTP header name, or whose value is rejected by `HeaderValue::from_str`,
/// is skipped rather than failing the whole response. Every header that does convert is
/// kept, including several entries that resolve to the same header name.
pub fn to_axum_response(res: Response) -> axum::response::Response {
    let mut response = axum::response::Response::new(Body::from(res.body));
    *response.status_mut() = axum::http::StatusCode::from_u16(res.status)
        .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);

    let headers = response.headers_mut();
    for (name, value) in res.headers {
        // Best effort: an unrepresentable header is dropped instead of aborting the response.
        let (Ok(name), Ok(value)) = (
            HeaderName::from_bytes(name.as_bytes()),
            HeaderValue::from_str(&value),
        ) else {
            continue;
        };
        // `append`, not `insert`. `HeaderName::from_bytes` lowercases names, so entries that
        // differ only in case (or repeat) land on the same header. `insert` would silently
        // keep only whichever came last in iteration order.
        headers.append(name, value);
    }

    response
}
/// Translates an axum (`http` crate) method into the matching `rivet_http` [`Method`].
///
/// Only GET, POST, PUT, PATCH, DELETE, HEAD and OPTIONS are mapped. Any other method
/// (e.g. CONNECT, TRACE, or an extension method) yields `None`, which the caller reports
/// as unsupported.
fn method_from_axum(method: &AxumMethod) -> Option<Method> {
    // Standard methods always parse to the `http` constants, whose canonical spelling is
    // the uppercase token, so matching on the string form is equivalent to matching them.
    let mapped = match method.as_str() {
        "GET" => Method::Get,
        "POST" => Method::Post,
        "PUT" => Method::Put,
        "PATCH" => Method::Patch,
        "DELETE" => Method::Delete,
        "HEAD" => Method::Head,
        "OPTIONS" => Method::Options,
        _ => return None,
    };
    Some(mapped)
}
/// Flattens an axum `HeaderMap` into `rivet_http::Headers`, one entry per header name.
///
/// Keys are stored lowercase. When a header carries several values, they are all kept and
/// concatenated into a single entry in the order the map yields them:
///
/// * `cookie` values are joined with `"; "`. HTTP/2 clients may split cookies across
///   several fields, and RFC 9113 §8.2.3 requires that delimiter when recombining them.
/// * Every other header's values are joined with `","` (RFC 9110 §5.3).
///
/// Values that are not valid UTF-8 are decoded lossily, so invalid sequences become U+FFFD.
fn map_headers(headers: axum::http::HeaderMap) -> rivet_http::Headers {
    let mut mapped = rivet_http::Headers::new();
    // `iter()` rather than the consuming iterator. The consuming iterator yields `None` in
    // place of the name for every value after a header's first one. Skipping those would
    // silently drop repeated values. `iter()` yields the name with every value.
    for (name, value) in headers.iter() {
        // `HeaderName` is already lowercase; this just makes the key contract explicit.
        let key = name.as_str().to_ascii_lowercase();
        // Identical to `to_str()` for visible-ASCII values; also tolerates non-ASCII bytes.
        let value = String::from_utf8_lossy(value.as_bytes()).into_owned();
        let separator = if key == "cookie" { "; " } else { "," };
        mapped
            .entry(key)
            .and_modify(|existing| {
                existing.push_str(separator);
                existing.push_str(&value);
            })
            .or_insert(value);
    }
    mapped
}
/// Splits a raw query string (the text after `?`) into `(key, value)` pairs.
///
/// Pairs are separated by `&` and returned in their original order, duplicate keys
/// included. Empty segments (from `&&` or a trailing `&`) are skipped. Each pair is split at
/// its first `=`: a pair without `=` gets an empty value, and any later `=` stays in the
/// value. Keys and values are returned verbatim. No percent-decoding and no `+`-to-space
/// translation is applied.
// NOTE(review): confirm whether consumers of `Request::query` expect raw or decoded pairs.
fn parse_query(raw: &str) -> Vec<(String, String)> {
    raw.split('&')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let (key, value) = segment.split_once('=').unwrap_or((segment, ""));
            (key.to_owned(), value.to_owned())
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::parse_query;

    /// Builds an owned pair list from string slices to keep expectations readable.
    fn pairs(items: &[(&str, &str)]) -> Vec<(String, String)> {
        items
            .iter()
            .map(|&(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn query_parser_preserves_duplicate_keys_and_order() {
        let parsed = parse_query("a=1&a=2&b=3");
        assert_eq!(
            parsed,
            vec![
                ("a".to_string(), "1".to_string()),
                ("a".to_string(), "2".to_string()),
                ("b".to_string(), "3".to_string())
            ]
        );
    }

    #[test]
    fn query_parser_handles_empty_input_and_empty_segments() {
        assert!(parse_query("").is_empty());
        assert!(parse_query("&&").is_empty());
        assert_eq!(parse_query("a=1&&b=2&"), pairs(&[("a", "1"), ("b", "2")]));
    }

    #[test]
    fn query_parser_splits_each_pair_on_first_equals_only() {
        assert_eq!(
            parse_query("flag&k=v=w&=x&y="),
            pairs(&[("flag", ""), ("k", "v=w"), ("", "x"), ("y", "")])
        );
    }
}