use axum::{
body::Bytes,
extract::FromRequest,
http::{HeaderMap, Method, Request, Uri},
};
use serde::de::DeserializeOwned;
use std::collections::HashMap;
/// Snapshot of an incoming HTTP request: its metadata plus a fully buffered body.
///
/// Produced by the `FromRequest` implementation for this type; every field is
/// public so handlers can read (or move out) whatever they need.
#[derive(Clone, Debug)]
pub struct Req {
    /// All request headers as received.
    pub headers: HeaderMap,
    /// Cookies parsed from the request's `Cookie` header(s).
    pub cookies: Cookies,
    /// HTTP method of the request.
    pub method: Method,
    /// Request target URI.
    pub uri: Uri,
    /// Buffered request body with decoding helpers.
    pub body: Body,
}
/// Buffered HTTP request body with helpers to decode it as UTF-8 text, JSON,
/// or URL-encoded form data. The `Default` value is an empty body.
#[derive(Clone, Debug, Default)]
pub struct Body {
    /// Raw body bytes exactly as they were read from the request.
    bytes: Bytes,
}
/// Error returned when a request body cannot be decoded in the requested format.
///
/// `message` is a human-readable description; it is also exactly what the
/// `Display` implementation prints.
#[derive(Debug)]
pub struct BodyParseError {
    /// Human-readable description of the decoding failure.
    pub message: String,
}

impl std::fmt::Display for BodyParseError {
    /// Writes `message` verbatim, with no prefix or padding.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

/// Allows use with `?`, `Box<dyn Error>`, and other standard error plumbing.
impl std::error::Error for BodyParseError {}
impl Body {
pub fn new(bytes: Bytes) -> Self {
Self { bytes }
}
pub fn is_empty(&self) -> bool {
self.bytes.is_empty()
}
pub fn as_bytes(&self) -> &Bytes {
&self.bytes
}
pub fn as_text(&self) -> Result<String, BodyParseError> {
String::from_utf8(self.bytes.to_vec()).map_err(|e| BodyParseError {
message: format!("Invalid UTF-8: {}", e),
})
}
pub fn as_json<T: DeserializeOwned>(&self) -> Result<T, BodyParseError> {
serde_json::from_slice(&self.bytes).map_err(|e| BodyParseError {
message: format!("Invalid JSON: {}", e),
})
}
pub fn as_form<T: DeserializeOwned>(&self) -> Result<T, BodyParseError> {
serde_urlencoded::from_bytes(&self.bytes).map_err(|e| BodyParseError {
message: format!("Invalid form data: {}", e),
})
}
}
/// Cookies sent by the client, keyed by cookie name.
///
/// Names are matched case-sensitively. Values are stored as sent, apart from
/// trimming surrounding whitespace: there is no percent-decoding and no quote
/// stripping.
#[derive(Debug, Clone, Default)]
pub struct Cookies {
    cookies: HashMap<String, String>,
}

impl Cookies {
    /// Parses a `Cookie` header value of the form `name1=value1; name2=value2`.
    ///
    /// - `None` (no header) yields an empty set.
    /// - Segments without `=`, or with an empty name, are ignored.
    /// - Only the first `=` separates name from value, so `a=x=y` gives `a` → `x=y`.
    /// - If a name appears more than once, the **first** occurrence wins. User
    ///   agents list cookies with more specific (longer) paths first
    ///   (RFC 6265 §5.4), so the first entry is the most specific match.
    pub fn from_header(header: Option<&str>) -> Self {
        let mut cookies = HashMap::new();
        for pair in header.unwrap_or_default().split(';') {
            let Some((name, value)) = pair.split_once('=') else {
                continue;
            };
            let name = name.trim();
            if name.is_empty() {
                continue;
            }
            // `or_insert_with` keeps an earlier (more specific) cookie of the same name.
            cookies
                .entry(name.to_owned())
                .or_insert_with(|| value.trim().to_owned());
        }
        Self { cookies }
    }

    /// Returns the value of the cookie called `name`, if present.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.cookies.get(name).map(String::as_str)
    }

    /// Returns `true` if a cookie called `name` was sent.
    pub fn has(&self, name: &str) -> bool {
        self.cookies.contains_key(name)
    }

    /// Iterates over all `(name, value)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
        self.cookies.iter().map(|(k, v)| (k.as_str(), v.as_str()))
    }
}
impl<S> FromRequest<S> for Req
where
S: Send + Sync,
{
type Rejection = std::convert::Infallible;
async fn from_request(
req: Request<axum::body::Body>,
_state: &S,
) -> Result<Self, Self::Rejection> {
let (parts, body) = req.into_parts();
let headers = parts.headers;
let cookies = Cookies::from_header(
headers
.get(axum::http::header::COOKIE)
.and_then(|v| v.to_str().ok()),
);
let method = parts.method;
let uri = parts.uri;
let bytes = axum::body::to_bytes(body, usize::MAX)
.await
.unwrap_or_default();
Ok(Req {
headers,
cookies,
method,
uri,
body: Body::new(bytes),
})
}
}