#[cfg(feature = "openapi")]
pub mod openapi;
use std::fmt::{Debug, Formatter};
use cot_core::handler::{BoxRequestHandler, into_box_request_handler};
use crate::error::MethodNotAllowed;
use crate::request::Request;
use crate::response::Response;
use crate::{Method, RequestHandler};
/// A request handler that dispatches to a different handler depending on the
/// HTTP method of the incoming request.
///
/// Handlers are registered per method with the builder-style methods (`get`,
/// `post`, ...). A request whose method has no registered handler is passed to
/// the fallback handler, which by default fails with a [`MethodNotAllowed`]
/// error. A `HEAD` request is served by the `GET` handler when no `HEAD`
/// handler is registered but a `GET` handler is.
#[derive(Debug)]
#[must_use]
pub struct MethodRouter {
// Per-method handler slots plus the fallback; request dispatch is delegated to it.
inner: InnerMethodRouter<InnerHandler>,
}
// Generates a builder method named `$name` that stores `handler` in the
// `$name` slot of the inner router, i.e. registers it for HTTP method `$method`.
//
// The rustdoc of each generated method is stitched together from the `///`
// lines and the `#[doc = concat!(..)]` attributes below, in source order, so
// the attributes must stay inside the fenced example they belong to.
macro_rules! define_method {
($name:ident => $method:ident) => {
#[doc = concat!("Set a handler for the [`",
stringify!($method),
"`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Methods/",
stringify!($method),
") HTTP method.")]
///
/// A handler previously set for the same method is replaced.
///
/// # Examples
///
/// ```
/// use cot::html::Html;
/// use cot::router::method::MethodRouter;
/// use cot::{Method, RequestHandler, StatusCode};
///
/// async fn test_handler(method: Method) -> Html {
///     Html::new(method.as_str())
/// }
///
/// # #[tokio::main]
/// # async fn main() -> cot::Result<()> {
#[doc = concat!(
"let method_router = MethodRouter::new().",
stringify!($name),
"(test_handler);"
)]
///
#[doc = concat!(
"# let request = cot::test::TestRequestBuilder::with_method(\"/\", cot::Method::",
stringify!($method),
")"
)]
/// #     .build();
/// # let response = method_router.handle(request).await?;
/// # assert_eq!(response.status(), StatusCode::OK);
/// # Ok(())
/// # }
/// ```
pub fn $name<HandlerParams, H>(mut self, handler: H) -> Self
where
HandlerParams: 'static,
H: RequestHandler<HandlerParams> + Send + Sync + 'static,
{
// Overwrites whatever handler was stored in this slot before.
self.inner.$name = Some(InnerHandler::new(handler));
self
}
};
}
impl Default for MethodRouter {
fn default() -> Self {
Self::new()
}
}
impl MethodRouter {
    /// Create a router with no per-method handlers registered.
    ///
    /// Until handlers are added, every request ends up in the default
    /// fallback, which fails with a [`MethodNotAllowed`] error.
    pub fn new() -> Self {
        let inner = InnerMethodRouter::new();
        Self { inner }
    }

    // One builder method per HTTP method; see `define_method!`.
    define_method!(get => GET);
    define_method!(head => HEAD);
    define_method!(delete => DELETE);
    define_method!(options => OPTIONS);
    define_method!(patch => PATCH);
    define_method!(post => POST);
    define_method!(put => PUT);
    define_method!(trace => TRACE);
    define_method!(connect => CONNECT);

    /// Replace the fallback handler.
    ///
    /// The fallback is invoked for every request that no per-method handler
    /// picked up. Calling this again replaces the previously set fallback.
    pub fn fallback<HandlerParams, H>(self, handler: H) -> Self
    where
        HandlerParams: 'static,
        H: RequestHandler<HandlerParams> + Send + Sync + 'static,
    {
        let Self { mut inner } = self;
        inner.fallback = InnerHandler::new(handler);
        Self { inner }
    }
}
impl RequestHandler for MethodRouter {
    /// Handle `request` by delegating to the inner per-method dispatch table.
    fn handle(&self, request: Request) -> impl Future<Output = cot::Result<Response>> + Send {
        let Self { inner } = self;
        inner.handle(request)
    }
}
/// Storage and dispatch logic behind [`MethodRouter`], generic over the
/// handler type `T` kept in the per-method slots.
#[derive(Debug)]
#[must_use]
struct InnerMethodRouter<T> {
// One optional handler per HTTP method; `None` means no handler is
// registered for that method.
pub(self) get: Option<T>,
pub(self) head: Option<T>,
pub(self) delete: Option<T>,
pub(self) options: Option<T>,
pub(self) patch: Option<T>,
pub(self) post: Option<T>,
pub(self) put: Option<T>,
pub(self) trace: Option<T>,
// Unlike the slots above, CONNECT always stores an `InnerHandler`, whatever `T` is.
// NOTE(review): presumably because another instantiation of `T` (e.g. in the
// `openapi` submodule) cannot represent CONNECT -- confirm before changing to `Option<T>`.
pub(self) connect: Option<InnerHandler>,
// Invoked when no per-method slot handled the request; always present.
pub(self) fallback: InnerHandler,
}
impl<T> InnerMethodRouter<T> {
    /// Create a table with every method slot empty and `default_fallback`
    /// installed as the fallback handler.
    pub(crate) fn new() -> Self {
        let fallback = InnerHandler::new(default_fallback);

        Self {
            fallback,
            connect: None,
            trace: None,
            put: None,
            post: None,
            patch: None,
            options: None,
            delete: None,
            head: None,
            get: None,
        }
    }
}
impl<T: RequestHandler + Send + Sync> RequestHandler for InnerMethodRouter<T> {
async fn handle(&self, request: Request) -> cot::Result<Response> {
macro_rules! handle_method {
($name:ident => $method:ident) => {
if request.method() == Method::$method {
if let Some(handler) = &self.$name {
return handler.handle(request).await;
}
}
};
}
handle_method!(get => GET);
handle_method!(head => HEAD);
handle_method!(delete => DELETE);
handle_method!(options => OPTIONS);
handle_method!(patch => PATCH);
handle_method!(post => POST);
handle_method!(put => PUT);
handle_method!(trace => TRACE);
handle_method!(connect => CONNECT);
if request.method() == Method::HEAD {
if let Some(handler) = &self.get {
return handler.handle(request).await;
}
}
self.fallback.handle(request).await
}
}
/// Type-erased request handler: a boxed [`BoxRequestHandler`] trait object
/// that is `Send + Sync`.
struct InnerHandler(Box<dyn BoxRequestHandler + Send + Sync>);
impl InnerHandler {
    /// Wrap any request handler into a type-erased `InnerHandler`.
    ///
    /// The handler is first converted with `into_box_request_handler` and the
    /// result is then boxed as a trait object.
    fn new<HandlerParams, H>(handler: H) -> Self
    where
        HandlerParams: 'static,
        H: RequestHandler<HandlerParams> + Send + Sync + 'static,
    {
        let erased = into_box_request_handler(handler);
        Self(Box::new(erased))
    }
}
impl Debug for InnerHandler {
    /// Formats as the opaque `InnerHandler(..)`; the wrapped handler is never
    /// printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("InnerHandler(..)")
    }
}
impl RequestHandler for InnerHandler {
    /// Forward `request` to the boxed handler and return its future.
    fn handle(&self, request: Request) -> impl Future<Output = cot::Result<Response>> + Send {
        let Self(boxed) = self;
        boxed.handle(request)
    }
}
// Generates a free function named `$name` that builds a `MethodRouter` with a
// single handler registered for HTTP method `$method`.
//
// The rustdoc of each generated function is stitched together from the `///`
// lines and the `#[doc = concat!(..)]` attributes below, in source order, so
// the attributes must stay inside the fenced example they belong to.
macro_rules! define_method_router {
($name:ident => $method:ident) => {
#[doc = concat!(
"Create a new [`MethodRouter`] with a [`",
stringify!($method),
"`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Methods/",
stringify!($method),
") handler."
)]
///
#[doc = concat!(
"This is a shorthand to calling [`MethodRouter::new`] and then [`MethodRouter::",
stringify!($name),
"`]."
)]
///
/// # Examples
///
/// ```
/// use cot::html::Html;
#[doc = concat!("use cot::router::method::", stringify!($name), ";")]
/// use cot::{Method, RequestHandler, StatusCode};
///
/// async fn test_handler(method: Method) -> Html {
///     Html::new(method.as_str())
/// }
///
/// # #[tokio::main]
/// # async fn main() -> cot::Result<()> {
#[doc = concat!("let method_router = ", stringify!($name), "(test_handler);")]
///
#[doc = concat!(
"# let request = cot::test::TestRequestBuilder::with_method(\"/\", cot::Method::",
stringify!($method),
")"
)]
/// #     .build();
/// # let response = method_router.handle(request).await?;
/// # assert_eq!(response.status(), StatusCode::OK);
/// # Ok(())
/// # }
/// ```
pub fn $name<HandlerParams, H>(handler: H) -> MethodRouter
where
HandlerParams: 'static,
H: RequestHandler<HandlerParams> + Send + Sync + 'static,
{
MethodRouter::new().$name(handler)
}
};
}
// One constructor function per HTTP method, mirroring the builder methods on
// `MethodRouter`.
define_method_router!(get => GET);
define_method_router!(head => HEAD);
define_method_router!(delete => DELETE);
define_method_router!(options => OPTIONS);
define_method_router!(patch => PATCH);
define_method_router!(post => POST);
define_method_router!(put => PUT);
define_method_router!(trace => TRACE);
define_method_router!(connect => CONNECT);
async fn default_fallback(method: Method) -> crate::Error {
MethodNotAllowed::new(method).into()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::StatusCode;
use crate::html::Html;
use crate::test::TestRequestBuilder;
// Handler shared by all tests: responds with the request's method name as the
// body, so a test can tell which method the handler was invoked for.
async fn test_handler(method: Method) -> Html {
Html::new(method.as_str())
}
// The Debug output must not expose the wrapped handler.
#[test]
fn inner_handler_debug() {
let handler = InnerHandler::new(test_handler);
let debug_str = format!("{handler:?}");
assert_eq!(debug_str, "InnerHandler(..)");
}
// A router without any handlers rejects requests with `MethodNotAllowed`.
#[cot::test]
async fn method_router_fallback() {
let router = MethodRouter::new();
let request = TestRequestBuilder::get("/").build();
let response = router.handle(request).await.unwrap_err();
let inner = response.inner();
assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
assert!(inner.is::<MethodNotAllowed>());
}
// Same check as above, but for a router built through `Default`.
#[cot::test]
async fn method_router_default_fallback() {
let router = MethodRouter::default();
let request = TestRequestBuilder::get("/").build();
let response = router.handle(request).await.unwrap_err();
let inner = response.inner();
assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
assert!(inner.is::<MethodNotAllowed>());
}
// A custom fallback replaces the default one and handles unmatched requests.
#[cot::test]
async fn method_router_custom_fallback() {
let router = MethodRouter::new().fallback(test_handler);
let request = TestRequestBuilder::get("/").build();
let response = router.handle(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.into_body().into_bytes().await.unwrap(), "GET");
}
// A GET-only router serves GET and rejects the other methods.
#[cot::test]
async fn method_router_get() {
let router = get(test_handler);
let request = TestRequestBuilder::get("/").build();
let response = router.handle(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
// HEAD is left out on purpose: a HEAD request is served by the GET handler
// (covered by `method_router_default_head`).
let methods = [
Method::DELETE,
Method::OPTIONS,
Method::PATCH,
Method::POST,
Method::PUT,
Method::TRACE,
Method::CONNECT,
];
for method in methods {
let request = TestRequestBuilder::with_method("/", method).build();
let response = router.handle(request).await.unwrap_err();
let inner = response.inner();
assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
assert!(inner.is::<MethodNotAllowed>());
}
}
// Generates a test checking that the router built by the free function
// `$constructor_name` answers a `$method_name` request with 200 OK.
macro_rules! test_method_router {
($test_name:ident, $constructor_name:ident, $method_name:ident) => {
#[cot::test]
async fn $test_name() {
let router = $constructor_name(test_handler);
let request = TestRequestBuilder::with_method("/", Method::$method_name).build();
let response = router.handle(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
};
}
test_method_router!(method_router_head, head, HEAD);
test_method_router!(method_router_delete, delete, DELETE);
test_method_router!(method_router_options, options, OPTIONS);
test_method_router!(method_router_patch, patch, PATCH);
test_method_router!(method_router_post, post, POST);
test_method_router!(method_router_put, put, PUT);
test_method_router!(method_router_trace, trace, TRACE);
test_method_router!(method_router_connect, connect, CONNECT);
// HEAD is rejected on an empty router, but served by the GET handler when
// only a GET handler is registered.
#[cot::test]
async fn method_router_default_head() {
let router = MethodRouter::new();
let request = TestRequestBuilder::with_method("/", Method::HEAD).build();
let response = router.handle(request).await.unwrap_err();
let inner = response.inner();
assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
assert!(inner.is::<MethodNotAllowed>());
let router = get(test_handler);
let request = TestRequestBuilder::with_method("/", Method::HEAD).build();
let response = router.handle(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
// With a handler registered for every method, each request reaches the handler
// for its own method; the echoed body proves which method was dispatched.
#[cot::test]
async fn method_router_multiple() {
let router = MethodRouter::new()
.get(test_handler)
.head(test_handler)
.delete(test_handler)
.options(test_handler)
.patch(test_handler)
.post(test_handler)
.put(test_handler)
.trace(test_handler)
.connect(test_handler);
for (method, expected_string) in [
(Method::GET, "GET"),
(Method::HEAD, "HEAD"),
(Method::DELETE, "DELETE"),
(Method::OPTIONS, "OPTIONS"),
(Method::PATCH, "PATCH"),
(Method::POST, "POST"),
(Method::PUT, "PUT"),
(Method::TRACE, "TRACE"),
(Method::CONNECT, "CONNECT"),
] {
let request = TestRequestBuilder::with_method("/", method).build();
let response = router.handle(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
response.into_body().into_bytes().await.unwrap(),
expected_string
);
}
}
}