use futures_util::future::BoxFuture;
use std::sync::Arc;
use crate::{Request, Response};
pub use self::from_request::FromRequest;
pub use self::types::*;
mod from_request;
mod types;
/// Adapts an extractor-based handler into a plain request handler.
///
/// The returned closure takes ownership of a [`Request`], runs
/// `Args::from_request` against it, and then:
///
/// * on success, awaits `f(args)` and turns the produced value into a
///   [`Response`] with `into_response`; an `Err` coming out of `f` is
///   propagated to the caller through `?`;
/// * on rejection, converts the rejection into a [`Response`] and returns it
///   as `Ok`, so a failed extraction is reported as a response rather than as
///   a handler error.
///
/// `f` is kept behind an [`Arc`] so that every invocation can move its own
/// clone into the `'static` future it returns.
pub fn handler_from_extractor<Args, F, Fut, T>(
    f: F,
) -> impl Fn(crate::Request) -> BoxFuture<'static, crate::Result<Response>> + Send + Sync + 'static
where
    Args: FromRequest + Send + 'static,
    <Args as FromRequest>::Rejection: Into<Response> + Send + 'static,
    F: Fn(Args) -> Fut + Send + Sync + 'static,
    Fut: core::future::Future<Output = crate::Result<T>> + Send + 'static,
    T: crate::IntoResponse + Send + 'static,
{
    let shared = Arc::new(f);
    move |mut request: Request| {
        // Each call gets its own handle; the closure itself stays reusable.
        let handler = Arc::clone(&shared);
        Box::pin(async move {
            // Bail out early with the rejection rendered as a normal response.
            let args = match Args::from_request(&mut request).await {
                Ok(args) => args,
                Err(rejection) => return Ok(rejection.into()),
            };
            let output = handler(args).await?;
            Ok(output.into_response())
        })
    }
}
/// Like [`handler_from_extractor`], but the wrapped handler also receives the
/// [`Request`] itself as its first argument.
///
/// Extraction runs first, against a mutable borrow of the request; only after
/// it has succeeded is the request moved into `f` together with the extracted
/// `args`. The handler therefore sees the request in whatever state the
/// extractor left it.
///
/// Outcomes:
///
/// * extraction succeeds: `f(req, args)` is awaited, its `Err` is propagated
///   through `?`, and its `Ok` value is turned into a [`Response`] with
///   `into_response`;
/// * extraction is rejected: the rejection is converted into a [`Response`]
///   and returned as `Ok`; `f` is not called.
pub fn handler_from_extractor_with_request<Args, F, Fut, T>(
    f: F,
) -> impl Fn(crate::Request) -> BoxFuture<'static, crate::Result<Response>> + Send + Sync + 'static
where
    Args: FromRequest + Send + 'static,
    <Args as FromRequest>::Rejection: Into<Response> + Send + 'static,
    F: Fn(crate::Request, Args) -> Fut + Send + Sync + 'static,
    Fut: core::future::Future<Output = crate::Result<T>> + Send + 'static,
    T: crate::IntoResponse + Send + 'static,
{
    let shared = Arc::new(f);
    move |mut request: Request| {
        // Each call gets its own handle; the closure itself stays reusable.
        let handler = Arc::clone(&shared);
        Box::pin(async move {
            // The mutable borrow ends with this statement, which is what
            // allows `request` to be moved into the handler below.
            let args = match Args::from_request(&mut request).await {
                Ok(args) => args,
                Err(rejection) => return Ok(rejection.into()),
            };
            let output = handler(request, args).await?;
            Ok(output.into_response())
        })
    }
}
// Unit tests for the extractor types re-exported by this module. Every test
// builds a `Request` by hand (path params, URI, headers, body, state,
// extensions) and calls `FromRequest::from_request` on it directly, without
// going through a router or a server.
#[cfg(test)]
mod tests {
use super::*;
use crate::{Request, Response};
use headers::UserAgent;
use serde::Deserialize;
// Query-string shape shared by several tests. `#[serde(default)]` makes both
// fields optional on the wire; a missing field deserializes to 0.
#[derive(Deserialize)]
struct Page {
#[serde(default)]
page: u32,
#[serde(default)]
size: u32,
}
// `Path<T>` with a scalar target and with a struct target.
#[tokio::test]
async fn test_path_single_and_struct() {
// One path parameter extracted into a bare `i64`.
let mut req = Request::empty();
req.set_path_params(
"id".to_owned(),
crate::core::path_param::PathParam::Int64(42),
);
let Path(id): Path<i64> = Path::from_request(&mut req).await.unwrap();
assert_eq!(id, 42);
// Two named parameters extracted into a struct whose field names match the
// parameter names.
let mut req = Request::empty();
req.set_path_params(
"id".to_owned(),
crate::core::path_param::PathParam::Int64(7),
);
req.set_path_params(
"name".to_owned(),
crate::core::path_param::PathParam::from("bob".to_string()),
);
#[derive(Deserialize)]
struct U {
id: i64,
name: String,
}
let Path(u): Path<U> = Path::from_request(&mut req).await.unwrap();
assert_eq!(u.id, 7);
assert_eq!(u.name, "bob");
}
// `Query<T>` from the URI and `Json<T>` from the body.
// NOTE(review): despite the name, no `Form` extraction happens in this test;
// `Form` is exercised in `test_form_and_typed_header_and_method_uri_version_remote`.
#[tokio::test]
async fn test_query_and_json_and_form() {
// Query string deserialized into `Page`.
let mut req = Request::empty();
*req.uri_mut() = http::Uri::from_static("http://localhost/test?page=1&size=20");
let Query(p): Query<Page> = Query::from_request(&mut req).await.unwrap();
assert_eq!(p.page, 1);
assert_eq!(p.size, 20);
#[derive(Deserialize, serde::Serialize)]
struct U {
name: String,
}
// JSON body round-trip: serialize `U`, send it with a JSON content type,
// extract it back.
let mut req = Request::empty();
req.headers_mut().insert(
"content-type",
http::HeaderValue::from_static("application/json"),
);
req.replace_body(crate::core::req_body::ReqBody::Once(
serde_json::to_vec(&U {
name: "alice".into(),
})
.unwrap()
.into(),
));
let Json(u): Json<U> = Json::from_request(&mut req).await.unwrap();
assert_eq!(u.name, "alice");
}
// Tuple extractors plus the `Option<_>` and `Result<_, Response>` wrappers on
// a request that lacks the data the inner extractor needs.
#[tokio::test]
async fn test_tuple_and_option_result() {
// A 2-tuple only has to extract successfully; the values are not inspected.
let mut req = Request::empty();
req.set_path_params("id".to_owned(), crate::core::path_param::PathParam::Int(1));
*req.uri_mut() = http::Uri::from_static("http://localhost/test?page=2&size=3");
let (_a, _b): (Path<i32>, Query<Page>) =
<(Path<i32>, Query<Page>) as FromRequest>::from_request(&mut req)
.await
.unwrap();
// `Option<Path<_>>` on a request without path params: the extraction itself
// succeeds and yields `None`.
let mut req = Request::empty();
let o: Option<Path<i32>> = Option::<Path<i32>>::from_request(&mut req).await.unwrap();
assert!(o.is_none());
// `Result<Path<_>, Response>` on the same kind of request: the outer
// extraction succeeds and the failure is carried in the inner `Err`.
let mut req = Request::empty();
let r: Result<Path<i32>, Response> =
<Result<Path<i32>, Response> as FromRequest>::from_request(&mut req)
.await
.unwrap();
assert!(r.is_err());
}
// `Form`, `TypedHeader`, `Method`, `Uri`, `Version` and `RemoteAddr`, each on
// its own freshly built request (Method/Uri/Version share one).
#[tokio::test]
async fn test_form_and_typed_header_and_method_uri_version_remote() {
#[derive(Deserialize, serde::Serialize)]
struct U {
name: String,
age: u32,
}
// URL-encoded body with the matching content type.
let mut req = Request::empty();
req.headers_mut().insert(
"content-type",
http::HeaderValue::from_static("application/x-www-form-urlencoded"),
);
req.replace_body(crate::core::req_body::ReqBody::Once(
"name=Alice&age=25".as_bytes().to_vec().into(),
));
let Form(u): Form<U> = Form::from_request(&mut req).await.unwrap();
assert_eq!(u.name, "Alice");
assert_eq!(u.age, 25);
// Typed header parsed from the raw `user-agent` value.
let mut req = Request::empty();
req.headers_mut()
.insert("user-agent", http::HeaderValue::from_static("curl/8.0"));
let TypedHeader(ua): TypedHeader<UserAgent> =
TypedHeader::from_request(&mut req).await.unwrap();
assert_eq!(ua.as_str(), "curl/8.0");
// Request-line parts read back from a single request.
let mut req = Request::empty();
*req.method_mut() = http::Method::POST;
*req.uri_mut() = http::Uri::from_static("http://localhost:8080/path?q=1");
let Method(m): Method = Method::from_request(&mut req).await.unwrap();
let Uri(u): Uri = Uri::from_request(&mut req).await.unwrap();
let Version(v): Version = Version::from_request(&mut req).await.unwrap();
assert_eq!(m, http::Method::POST);
assert_eq!(u.path(), "/path");
// The version is never set above, so the test only requires that it is one
// of the common HTTP versions rather than pinning a specific default.
assert!(matches!(
v,
http::Version::HTTP_11 | http::Version::HTTP_10 | http::Version::HTTP_2
));
// Remote address set on the request and read back through the extractor.
let mut req = Request::empty();
req.set_remote(
"127.0.0.1:9090"
.parse::<crate::core::remote_addr::RemoteAddr>()
.unwrap(),
);
let RemoteAddr(addr): RemoteAddr = RemoteAddr::from_request(&mut req).await.unwrap();
assert_eq!(addr.to_string(), "127.0.0.1:9090");
}
// `State<T>` and `Extension<T>` lookups by type, and extraction through
// `RequestExt::extract`.
#[tokio::test]
async fn test_state_and_extension_and_request_ext() {
// Value stored via `state_mut()` comes back through `State`.
#[derive(Clone)]
struct CfgData(u32);
let mut req = Request::empty();
req.state_mut().insert(CfgData(9));
let State(CfgData(v)): State<CfgData> = State::from_request(&mut req).await.unwrap();
assert_eq!(v, 9);
// Value stored via `extensions_mut()` comes back through `Extension`.
#[derive(Clone)]
struct Ext(&'static str);
let mut req = Request::empty();
req.extensions_mut().insert(Ext("hello"));
let Extension(Ext(s)): Extension<Ext> = Extension::from_request(&mut req).await.unwrap();
assert_eq!(s, "hello");
// Same kind of `Path` extraction, but invoked through `RequestExt::extract`;
// the target type is chosen by the `let` annotation.
let mut req = Request::empty();
req.set_path_params("id".to_owned(), crate::core::path_param::PathParam::Int(5));
let Path(id): Path<i32> = RequestExt::extract(&mut req).await.unwrap();
assert_eq!(id, 5);
}
// 3-tuple and 4-tuple extractors, and the `Ok` side of `Result<Json<_>, Response>`.
#[tokio::test]
async fn test_tuple_triple_quad_and_result_ok() {
#[derive(Deserialize)]
struct Q {
page: u32,
}
// Three extractors drawing from path params, query string and headers.
let mut req = Request::empty();
req.set_path_params("id".to_owned(), crate::core::path_param::PathParam::Int(1));
*req.uri_mut() = http::Uri::from_static("http://localhost/test?page=3");
req.headers_mut()
.insert("user-agent", http::HeaderValue::from_static("ua"));
let (_a, _b, _c): (Path<i32>, Query<Q>, TypedHeader<UserAgent>) =
<(Path<i32>, Query<Q>, TypedHeader<UserAgent>) as FromRequest>::from_request(&mut req)
.await
.unwrap();
// Only the query component is checked; `_b` is read despite its underscore
// prefix.
assert_eq!(_b.0.page, 3);
// Same setup with `Method` appended; success of the 4-tuple is the assertion.
let mut req = Request::empty();
req.set_path_params("id".to_owned(), crate::core::path_param::PathParam::Int(1));
*req.uri_mut() = http::Uri::from_static("http://localhost/test?page=3");
req.headers_mut()
.insert("user-agent", http::HeaderValue::from_static("ua"));
let (_a, _b, _c, _d): (Path<i32>, Query<Q>, TypedHeader<UserAgent>, Method) =
<(Path<i32>, Query<Q>, TypedHeader<UserAgent>, Method) as FromRequest>::from_request(
&mut req,
)
.await
.unwrap();
#[derive(Deserialize, serde::Serialize)]
struct U {
name: String,
}
// A well-formed JSON request wrapped in `Result` yields the inner `Ok`.
let mut req = Request::empty();
req.headers_mut().insert(
"content-type",
http::HeaderValue::from_static("application/json"),
);
req.replace_body(crate::core::req_body::ReqBody::Once(
serde_json::to_vec(&U { name: "ok".into() }).unwrap().into(),
));
let r: Result<Json<U>, Response> =
<Result<Json<U>, Response> as FromRequest>::from_request(&mut req)
.await
.unwrap();
assert!(matches!(r, Ok(Json(U { name })) if name == "ok"));
}
// `Path<T>` across the `PathParam` variants used here: signed/unsigned
// integers at their extremes, plain strings, and the `Path` variant.
#[tokio::test]
async fn test_path_param_edge_cases() {
// Negative `Int` into `i32`.
let mut req = Request::empty();
req.set_path_params(
"val".to_owned(),
crate::core::path_param::PathParam::Int(-42),
);
let Path(val): Path<i32> = Path::from_request(&mut req).await.unwrap();
assert_eq!(val, -42);
// `UInt32` into `u32`.
let mut req = Request::empty();
req.set_path_params(
"val".to_owned(),
crate::core::path_param::PathParam::UInt32(123),
);
let Path(val): Path<u32> = Path::from_request(&mut req).await.unwrap();
assert_eq!(val, 123);
// Smallest `i64`.
let mut req = Request::empty();
req.set_path_params(
"val".to_owned(),
crate::core::path_param::PathParam::Int64(i64::MIN),
);
let Path(val): Path<i64> = Path::from_request(&mut req).await.unwrap();
assert_eq!(val, i64::MIN);
// Largest `u64`.
let mut req = Request::empty();
req.set_path_params(
"val".to_owned(),
crate::core::path_param::PathParam::UInt64(u64::MAX),
);
let Path(val): Path<u64> = Path::from_request(&mut req).await.unwrap();
assert_eq!(val, u64::MAX);
// String parameter built through `PathParam::from(String)`.
let mut req = Request::empty();
req.set_path_params(
"val".to_owned(),
crate::core::path_param::PathParam::from("test-string".to_string()),
);
let Path(val): Path<String> = Path::from_request(&mut req).await.unwrap();
assert_eq!(val, "test-string");
// `PathParam::Path` holding a value with slashes; it must come back intact.
let mut req = Request::empty();
req.set_path_params(
"val".to_owned(),
crate::core::path_param::PathParam::Path(crate::core::path_param::PathString::Owned(
"path/to/file".to_string(),
)),
);
let Path(val): Path<String> = Path::from_request(&mut req).await.unwrap();
assert_eq!(val, "path/to/file");
}
// `Query<T>` with no query string, percent-encoded values, and an enum field.
// One request is reused throughout; only its URI is replaced between cases.
#[tokio::test]
async fn test_query_param_edge_cases() {
// No query string at all: both `Page` fields take their serde defaults.
let mut req = Request::empty();
*req.uri_mut() = http::Uri::from_static("http://localhost/test");
let Query(params): Query<Page> = Query::from_request(&mut req).await.unwrap();
assert_eq!(params.page, 0);
assert_eq!(params.size, 0);
// `%20` must be decoded to a space, and `value` parsed as an integer.
*req.uri_mut() =
http::Uri::from_static("http://localhost/test?name=hello%20world&value=123");
#[derive(serde::Deserialize)]
struct SpecialParams {
name: String,
value: i32,
}
let Query(params): Query<SpecialParams> = Query::from_request(&mut req).await.unwrap();
assert_eq!(params.name, "hello world");
assert_eq!(params.value, 123);
// Enum field matched against its `#[serde(rename = ...)]` spelling.
*req.uri_mut() = http::Uri::from_static("http://localhost/test?status=active");
#[derive(serde::Deserialize)]
enum Status {
#[serde(rename = "active")]
Active,
#[serde(rename = "inactive")]
Inactive,
}
#[derive(serde::Deserialize)]
struct EnumParam {
status: Status,
}
let Query(params): Query<EnumParam> = Query::from_request(&mut req).await.unwrap();
assert!(matches!(params.status, Status::Active));
}
// Rejection paths of `Json` and `Form`.
#[tokio::test]
async fn test_json_and_form_error_cases() {
// Correct content type, syntactically invalid JSON body.
let mut req = Request::empty();
req.headers_mut().insert(
"content-type",
http::HeaderValue::from_static("application/json"),
);
req.replace_body(crate::core::req_body::ReqBody::Once(
b"{invalid json}".to_vec().into(),
));
let result = Json::<serde_json::Value>::from_request(&mut req).await;
assert!(result.is_err());
// Valid JSON body (`{}`) but no content-type header set on the request.
// NOTE(review): presumably the rejection comes from the missing header —
// confirm against the `Json` extractor.
let mut req = Request::empty();
req.replace_body(crate::core::req_body::ReqBody::Once(b"{}".to_vec().into()));
let result = Json::<serde_json::Value>::from_request(&mut req).await;
assert!(result.is_err());
// Form content type with a body that does not provide the required `key`
// field.
let mut req = Request::empty();
req.headers_mut().insert(
"content-type",
http::HeaderValue::from_static("application/x-www-form-urlencoded"),
);
req.replace_body(crate::core::req_body::ReqBody::Once(
b"{invalid}".to_vec().into(),
));
#[derive(serde::Deserialize, serde::Serialize)]
struct FormData {
key: String,
}
let result = Form::<FormData>::from_request(&mut req).await;
assert!(result.is_err());
}
// `Json<T>` with a nested object structure.
#[tokio::test]
async fn test_complex_struct_parsing() {
let mut req = Request::empty();
req.headers_mut().insert(
"content-type",
http::HeaderValue::from_static("application/json"),
);
let nested_data = serde_json::json!({
"user": {
"name": "Alice",
"age": 30
},
"settings": {
"theme": "dark",
"notifications": true
}
});
req.replace_body(crate::core::req_body::ReqBody::Once(
serde_json::to_vec(&nested_data).unwrap().into(),
));
// Target types mirror the JSON document above, one struct per object.
#[derive(serde::Deserialize)]
struct User {
name: String,
age: u32,
}
#[derive(serde::Deserialize)]
struct Settings {
theme: String,
notifications: bool,
}
#[derive(serde::Deserialize)]
struct ComplexData {
user: User,
settings: Settings,
}
let Json(data): Json<ComplexData> = Json::from_request(&mut req).await.unwrap();
assert_eq!(data.user.name, "Alice");
assert_eq!(data.user.age, 30);
assert_eq!(data.settings.theme, "dark");
assert!(data.settings.notifications);
}
// Optionality at two levels: `Option` fields inside a query struct, and the
// `Option<Extractor>` wrapper around `Path` and `Json`.
#[tokio::test]
async fn test_option_extractor_variations() {
// `size` is absent from the query string and becomes `None`.
let mut req = Request::empty();
*req.uri_mut() = http::Uri::from_static("http://localhost/test?page=1");
#[derive(serde::Deserialize)]
struct OptionalParams {
page: Option<u32>,
size: Option<u32>,
}
let Query(params): Query<OptionalParams> = Query::from_request(&mut req).await.unwrap();
assert_eq!(params.page, Some(1));
assert_eq!(params.size, None);
// No path params on the request.
let mut req = Request::empty();
let result = Option::<Path<i32>>::from_request(&mut req).await.unwrap();
assert!(result.is_none());
// Nothing set on the request (no content type, no body written).
let mut req = Request::empty();
let result = Option::<Json<serde_json::Value>>::from_request(&mut req)
.await
.unwrap();
assert!(result.is_none());
}
// Both sides of `Result<Path<_>, Response>`.
#[tokio::test]
async fn test_result_extractor_variations() {
// Path param present: inner `Ok` carrying the value.
let mut req = Request::empty();
req.set_path_params("id".to_owned(), crate::core::path_param::PathParam::Int(42));
let result: Result<Path<i32>, Response> =
<Result<Path<i32>, Response> as FromRequest>::from_request(&mut req)
.await
.unwrap();
assert!(matches!(result, Ok(Path(42))));
// Path param missing: the outer extraction still succeeds, inner is `Err`.
let mut req = Request::empty();
let result: Result<Path<i32>, Response> =
<Result<Path<i32>, Response> as FromRequest>::from_request(&mut req)
.await
.unwrap();
assert!(result.is_err());
}
// Two `TypedHeader`s of different header types extracted as one tuple.
#[tokio::test]
async fn test_multiple_typed_headers() {
let mut req = Request::empty();
req.headers_mut()
.insert("user-agent", http::HeaderValue::from_static("Mozilla/5.0"));
req.headers_mut().insert(
"content-type",
http::HeaderValue::from_static("application/json"),
);
let (TypedHeader(ua), TypedHeader(content_type)): (
TypedHeader<UserAgent>,
TypedHeader<headers::ContentType>,
) = <(
TypedHeader<UserAgent>,
TypedHeader<headers::ContentType>,
) as FromRequest>::from_request(&mut req)
.await
.unwrap();
assert!(ua.as_str().contains("Mozilla"));
// Prefix match rather than equality on the rendered content type.
assert!(content_type.to_string().starts_with("application/json"));
}
// Four heterogeneous extractors (path, query, header, state) in one tuple.
// NOTE(review): despite the name, the tuple is flat, not nested.
#[tokio::test]
async fn test_deeply_nested_tuples() {
let mut req = Request::empty();
req.set_path_params("id".to_owned(), crate::core::path_param::PathParam::Int(1));
*req.uri_mut() = http::Uri::from_static("http://localhost/test?page=3");
req.headers_mut()
.insert("user-agent", http::HeaderValue::from_static("ua"));
// This header is set but none of the four extractors below reads it.
req.headers_mut()
.insert("content-type", http::HeaderValue::from_static("text/html"));
// `StateData` is declared further down; items declared inside a function
// body are in scope for the whole body, so this use-before-declaration is
// valid.
req.state_mut().insert(StateData(99));
#[derive(serde::Deserialize)]
struct Q {
page: u32,
}
#[derive(Clone)]
struct StateData(u32);
// Alias keeps the `let` annotation below readable.
type FourTupleResult = Result<
(
Path<i32>,
Query<Q>,
TypedHeader<UserAgent>,
State<StateData>,
),
Response,
>;
let result: FourTupleResult = <(
Path<i32>,
Query<Q>,
TypedHeader<UserAgent>,
State<StateData>,
) as FromRequest>::from_request(&mut req)
.await;
assert!(result.is_ok());
let (Path(id), Query(q), TypedHeader(ua), State(st)) = result.unwrap();
assert_eq!(id, 1);
assert_eq!(q.page, 3);
assert!(ua.as_str().contains("ua"));
assert_eq!(st.0, 99);
}
// Path, query and method extracted together from one request.
#[tokio::test]
async fn test_mixed_extractors_with_request() {
let mut req = Request::empty();
req.set_path_params(
"id".to_owned(),
crate::core::path_param::PathParam::Int(123),
);
*req.uri_mut() = http::Uri::from_static("http://localhost:8080/test?page=1&size=10");
// This header is set but no extractor in this test reads it.
req.headers_mut()
.insert("x-test", http::HeaderValue::from_static("value"));
let (Path(path_id), Query(query_params), Method(method)): (Path<i32>, Query<Page>, Method) =
<(Path<i32>, Query<Page>, Method) as FromRequest>::from_request(&mut req)
.await
.unwrap();
assert_eq!(path_id, 123);
assert_eq!(query_params.page, 1);
assert_eq!(query_params.size, 10);
// The method is never set above, so this relies on `Request::empty()`
// producing a GET request.
assert_eq!(method, http::Method::GET);
}
// `Extension<T>` is keyed by type: two different types coexist on one request,
// and a type that was never inserted is rejected.
#[tokio::test]
async fn test_extension_with_different_types() {
#[derive(Clone)]
struct UserId(String);
#[derive(Clone)]
struct Permission(u32);
let mut req = Request::empty();
req.extensions_mut().insert(UserId("user-123".to_string()));
req.extensions_mut().insert(Permission(777));
// Both extractions run against the same request, one after the other.
let Extension(user_id): Extension<UserId> =
Extension::from_request(&mut req).await.unwrap();
let Extension(permission): Extension<Permission> =
Extension::from_request(&mut req).await.unwrap();
assert_eq!(user_id.0, "user-123");
assert_eq!(permission.0, 777);
// A type that was never inserted must produce a rejection.
#[derive(Clone)]
struct NonExistent;
let result = Extension::<NonExistent>::from_request(&mut req).await;
assert!(result.is_err());
}
}