use bytes::Bytes;
use http::request::Parts;
use http::StatusCode;
use serde::de::DeserializeOwned;
use typeway_core::{ExtractPath, PathSpec};
use crate::response::IntoResponse;
/// An extractor that is built from the request head only.
///
/// Implementors receive a shared reference to the request [`Parts`] and
/// return either the extracted value or an error whose type implements
/// [`IntoResponse`]. The body is not passed to this trait.
#[diagnostic::on_unimplemented(
message = "`{Self}` cannot be extracted from request metadata",
label = "does not implement `FromRequestParts`",
note = "valid extractors: `Path<P>`, `State<T>`, `Query<T>`, `HeaderMap`"
)]
pub trait FromRequestParts: Sized + Send {
/// Error returned when extraction fails; must implement [`IntoResponse`].
type Error: IntoResponse;
/// Attempts to build `Self` from the request head.
///
/// # Errors
///
/// Returns `Self::Error` when the value cannot be extracted from `parts`.
fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error>;
}
/// An extractor that is built from the request head together with the
/// request body.
///
/// The body is handed over as an owned [`bytes::Bytes`] value, and the
/// extraction is asynchronous: `from_request` returns a `Send` future.
#[diagnostic::on_unimplemented(
message = "`{Self}` cannot be extracted from the request body",
label = "does not implement `FromRequest`",
note = "valid body extractors: `Json<T>`, `Bytes`, `String`, `()`"
)]
pub trait FromRequest: Sized + Send {
/// Error returned when extraction fails; must implement [`IntoResponse`].
type Error: IntoResponse;
/// Attempts to build `Self` from the request head and the body bytes.
///
/// # Errors
///
/// The returned future resolves to `Self::Error` when extraction fails.
fn from_request(
parts: &Parts,
body: bytes::Bytes,
) -> impl std::future::Future<Output = Result<Self, Self::Error>> + Send;
}
/// Extractor holding the captures (`P::Captures`) obtained by matching the
/// request path against the path specification `P`.
pub struct Path<P: PathSpec>(pub P::Captures);
/// Number of leading bytes of the request path that the [`Path`] extractor
/// skips before splitting the path into segments.
///
/// The extractor looks this value up in the request extensions and uses `0`
/// when it is absent.
// NOTE(review): presumably inserted by nested/prefixed routing — confirm against the router.
#[derive(Copy, Clone)]
pub struct PathPrefixOffset(pub usize);
impl<P> FromRequestParts for Path<P>
where
    P: PathSpec + ExtractPath + Send,
    P::Captures: Send,
{
    type Error = (StatusCode, String);

    /// Extracts the captures of path specification `P` from the request URI.
    ///
    /// If a [`PathPrefixOffset`] is present in the request extensions, that
    /// many leading bytes of the URI path are skipped first. The remaining
    /// path is split on `/`, empty segments are dropped, and the segments are
    /// handed to `P::extract`.
    ///
    /// # Errors
    ///
    /// Returns `400 Bad Request` naming `P::pattern()` when `P::extract`
    /// yields `None`.
    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
        let full_path = parts.uri.path();
        let offset = parts
            .extensions
            .get::<PathPrefixOffset>()
            .map_or(0, |o| o.0);
        // `str::get` returns `None` both when the offset lies past the end and
        // when it falls inside a multi-byte character, so a bad offset
        // degrades to an empty path instead of panicking like `&s[offset..]`.
        let path = full_path.get(offset..).unwrap_or("");
        let segs: smallvec::SmallVec<[&str; 8]> =
            path.split('/').filter(|s| !s.is_empty()).collect();
        P::extract(&segs).map(Path).ok_or_else(|| {
            (
                StatusCode::BAD_REQUEST,
                format!(
                    "failed to parse path segments for pattern: {}",
                    P::pattern()
                ),
            )
        })
    }
}
/// Extractor that yields a clone of a value of type `T` stored in the
/// request extensions; extraction fails with a message pointing at
/// `.with_state()` when no such value is present.
pub struct State<T>(pub T);
impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
type Error = (StatusCode, String);
fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
parts
.extensions
.get::<T>()
.cloned()
.map(State)
.ok_or_else(|| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!(
"state of type `{}` not found — did you call .with_state()?",
std::any::type_name::<T>()
),
)
})
}
}
/// Extractor that deserializes the URI query string into `T` using
/// `serde_urlencoded`.
pub struct Query<T>(pub T);
impl<T: DeserializeOwned + Send> FromRequestParts for Query<T> {
type Error = (StatusCode, String);
fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
let query = parts.uri.query().unwrap_or("");
serde_urlencoded::from_str::<T>(query)
.map(Query)
.map_err(|e| {
(
StatusCode::BAD_REQUEST,
format!("failed to parse query string: {e}"),
)
})
}
}
impl FromRequestParts for http::HeaderMap {
type Error = (StatusCode, String);
fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
Ok(parts.headers.clone())
}
}
/// Extractor that yields a clone of a value of type `T` stored in the
/// request extensions.
pub struct Extension<T>(pub T);
impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
type Error = (StatusCode, String);
fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
parts
.extensions
.get::<T>()
.cloned()
.map(Extension)
.ok_or_else(|| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!(
"extension of type `{}` not found in request",
std::any::type_name::<T>()
),
)
})
}
}
/// Describes a typed cookie: the cookie name to look for and how to parse
/// its raw value. Used by the [`Cookie`] extractor.
pub trait NamedCookie: Sized + Send {
/// Name of the cookie as it appears before the `=` in the `Cookie` header.
const COOKIE_NAME: &'static str;
/// Parses the raw cookie value (the text after `=`).
///
/// # Errors
///
/// Returns a message describing why the value is invalid; the [`Cookie`]
/// extractor sends it back with `400 Bad Request`.
fn from_value(value: &str) -> Result<Self, String>;
}
/// Extractor for a single cookie, identified and parsed through `T`'s
/// [`NamedCookie`] implementation.
pub struct Cookie<T>(pub T);
impl<T: NamedCookie + 'static> FromRequestParts for Cookie<T> {
    type Error = (StatusCode, String);

    /// Finds the cookie named `T::COOKIE_NAME` and parses it with
    /// `T::from_value`.
    ///
    /// Every `Cookie` header field of the request is scanned in order; the
    /// first `name=value` pair whose name equals `T::COOKIE_NAME` exactly is
    /// used. Header fields that are not visible ASCII are skipped.
    ///
    /// # Errors
    ///
    /// Returns `400 Bad Request` when no matching cookie is present, or when
    /// `T::from_value` rejects the value (its message becomes the body).
    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
        // A request may carry more than one `Cookie` field (HTTP/2 allows the
        // header to be split), so look at all of them rather than the first.
        let raw_value = parts
            .headers
            .get_all(http::header::COOKIE)
            .iter()
            .filter_map(|field| field.to_str().ok())
            .flat_map(|field| field.split(';'))
            .find_map(|pair| {
                // Requiring `=` right after the name prevents a cookie such as
                // `session_id` from matching the name `session`.
                pair.trim()
                    .strip_prefix(T::COOKIE_NAME)
                    .and_then(|rest| rest.strip_prefix('='))
            })
            .ok_or_else(|| {
                (
                    StatusCode::BAD_REQUEST,
                    format!("missing cookie: {}", T::COOKIE_NAME),
                )
            })?;
        T::from_value(raw_value)
            .map(Cookie)
            .map_err(|e| (StatusCode::BAD_REQUEST, e))
    }
}
/// All cookies of a request, keyed by cookie name.
pub struct CookieJar(pub std::collections::HashMap<String, String>);

impl CookieJar {
    /// Returns the value stored under `name`, or `None` when the jar has no
    /// cookie with that name.
    pub fn get(&self, name: &str) -> Option<&str> {
        let Self(cookies) = self;
        cookies.get(name).map(String::as_str)
    }
}
impl FromRequestParts for CookieJar {
    type Error = (StatusCode, String);

    /// Collects every `name=value` pair from the request's `Cookie` header
    /// fields into a [`CookieJar`].
    ///
    /// All `Cookie` fields are read, in order. Pairs without `=` and header
    /// fields that are not visible ASCII are ignored. When a name occurs more
    /// than once, the last occurrence is the one kept. Always succeeds; a
    /// request without cookies yields an empty jar.
    fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
        // A request may carry more than one `Cookie` field (HTTP/2 allows the
        // header to be split), so gather pairs from all of them.
        let map = parts
            .headers
            .get_all(http::header::COOKIE)
            .iter()
            .filter_map(|field| field.to_str().ok())
            .flat_map(|field| field.split(';'))
            .filter_map(|pair| {
                let (name, value) = pair.trim().split_once('=')?;
                Some((name.to_string(), value.to_string()))
            })
            .collect();
        Ok(CookieJar(map))
    }
}
impl FromRequestParts for http::Method {
type Error = (StatusCode, String);
fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
Ok(parts.method.clone())
}
}
impl FromRequestParts for http::Uri {
type Error = (StatusCode, String);
fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
Ok(parts.uri.clone())
}
}
/// Extractor for a single required header, identified and parsed through
/// `T`'s [`NamedHeader`] implementation.
pub struct Header<T>(pub T);
/// Describes a typed header: the header name to look up and how to parse
/// its value. Used by the [`Header`] extractor.
pub trait NamedHeader: Sized + Send {
/// Name of the header to look up in the request's header map.
const HEADER_NAME: &'static str;
/// Parses the header value after it has been converted to `&str`.
///
/// # Errors
///
/// Returns a message describing why the value is invalid; the [`Header`]
/// extractor sends it back with `400 Bad Request`.
fn from_value(value: &str) -> Result<Self, String>;
}
impl<T: NamedHeader + 'static> FromRequestParts for Header<T> {
type Error = (StatusCode, String);
fn from_request_parts(parts: &Parts) -> Result<Self, Self::Error> {
let value = parts
.headers
.get(T::HEADER_NAME)
.ok_or_else(|| {
(
StatusCode::BAD_REQUEST,
format!("missing required header: {}", T::HEADER_NAME),
)
})?
.to_str()
.map_err(|_| {
(
StatusCode::BAD_REQUEST,
format!("invalid header value for: {}", T::HEADER_NAME),
)
})?;
T::from_value(value)
.map(Header)
.map_err(|e| (StatusCode::BAD_REQUEST, e))
}
}
impl<T: DeserializeOwned + Send> FromRequest for crate::response::Json<T> {
    type Error = (StatusCode, String);

    /// Deserializes the request body as JSON into `T`.
    ///
    /// The request head is not consulted.
    ///
    /// # Errors
    ///
    /// Returns `400 Bad Request` carrying the `serde_json` error message when
    /// the body is not valid JSON for `T`.
    async fn from_request(_parts: &Parts, body: bytes::Bytes) -> Result<Self, Self::Error> {
        match serde_json::from_slice::<T>(&body) {
            Ok(decoded) => Ok(crate::response::Json(decoded)),
            Err(e) => Err((StatusCode::BAD_REQUEST, format!("invalid JSON: {e}"))),
        }
    }
}
impl FromRequest for Bytes {
    type Error = (StatusCode, String);

    /// Hands the request body back unchanged. Always succeeds; the request
    /// head is not consulted.
    async fn from_request(_parts: &Parts, body: bytes::Bytes) -> Result<Self, Self::Error> {
        let payload = body;
        Ok(payload)
    }
}
impl FromRequest for String {
    type Error = (StatusCode, String);

    /// Interprets the request body as UTF-8 text.
    ///
    /// The request head is not consulted.
    ///
    /// # Errors
    ///
    /// Returns `400 Bad Request` when the body is not valid UTF-8.
    async fn from_request(_parts: &Parts, body: bytes::Bytes) -> Result<Self, Self::Error> {
        match String::from_utf8(body.to_vec()) {
            Ok(text) => Ok(text),
            Err(e) => Err((
                StatusCode::BAD_REQUEST,
                format!("request body is not valid UTF-8: {e}"),
            )),
        }
    }
}
impl FromRequest for () {
    type Error = (StatusCode, String);

    /// Ignores both the request head and the body. Always succeeds.
    async fn from_request(_parts: &Parts, _body: bytes::Bytes) -> Result<Self, Self::Error> {
        let unit = ();
        Ok(unit)
    }
}