use std::future::Future;
use futures_util::{FutureExt, future::Either};
use crate::{
Endpoint, EndpointExt, IntoEndpoint, Request, Response, Result, endpoint::BoxEndpoint,
error::MethodNotAllowedError, http::Method,
};
/// An endpoint that picks one of several inner endpoints based on the HTTP
/// method of the incoming request.
///
/// Build it with [`RouteMethod::new`] and the per-method builder methods
/// (`get`, `post`, ...), or start from one of the free functions in this module.
#[derive(Default)]
pub struct RouteMethod {
    /// `(method, endpoint)` pairs in the order they were registered. Dispatch
    /// scans this list front to back and uses the first pair whose method matches.
    methods: Vec<(Method, BoxEndpoint<'static>)>,
}
impl RouteMethod {
    /// Creates a `RouteMethod` with no endpoints registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `ep` to handle requests whose method equals `method`.
    ///
    /// The endpoint is adapted with `map_to_response` and boxed, so endpoints of
    /// different concrete types can live in the same table. If `method` is
    /// registered more than once, dispatch uses the earliest registration.
    #[must_use]
    pub fn method<E>(mut self, method: Method, ep: E) -> Self
    where
        E: IntoEndpoint,
        E::Endpoint: 'static,
    {
        let boxed = ep.into_endpoint().map_to_response().boxed();
        self.methods.push((method, boxed));
        self
    }

    /// Registers `ep` for `GET` requests.
    ///
    /// Without an explicit `HEAD` handler, `HEAD` requests are also answered by
    /// this endpoint, with the response body emptied.
    #[must_use]
    pub fn get<E: IntoEndpoint>(self, ep: E) -> Self
    where
        E::Endpoint: 'static,
    {
        self.method(Method::GET, ep)
    }

    /// Registers `ep` for `POST` requests.
    #[must_use]
    pub fn post<E: IntoEndpoint>(self, ep: E) -> Self
    where
        E::Endpoint: 'static,
    {
        self.method(Method::POST, ep)
    }

    /// Registers `ep` for `PUT` requests.
    #[must_use]
    pub fn put<E: IntoEndpoint>(self, ep: E) -> Self
    where
        E::Endpoint: 'static,
    {
        self.method(Method::PUT, ep)
    }

    /// Registers `ep` for `DELETE` requests.
    #[must_use]
    pub fn delete<E: IntoEndpoint>(self, ep: E) -> Self
    where
        E::Endpoint: 'static,
    {
        self.method(Method::DELETE, ep)
    }

    /// Registers `ep` for `HEAD` requests.
    ///
    /// An explicit `HEAD` handler is matched directly, so the automatic
    /// fallback to the `GET` handler is not used.
    #[must_use]
    pub fn head<E: IntoEndpoint>(self, ep: E) -> Self
    where
        E::Endpoint: 'static,
    {
        self.method(Method::HEAD, ep)
    }

    /// Registers `ep` for `OPTIONS` requests.
    #[must_use]
    pub fn options<E: IntoEndpoint>(self, ep: E) -> Self
    where
        E::Endpoint: 'static,
    {
        self.method(Method::OPTIONS, ep)
    }

    /// Registers `ep` for `CONNECT` requests.
    #[must_use]
    pub fn connect<E: IntoEndpoint>(self, ep: E) -> Self
    where
        E::Endpoint: 'static,
    {
        self.method(Method::CONNECT, ep)
    }

    /// Registers `ep` for `PATCH` requests.
    #[must_use]
    pub fn patch<E: IntoEndpoint>(self, ep: E) -> Self
    where
        E::Endpoint: 'static,
    {
        self.method(Method::PATCH, ep)
    }

    /// Registers `ep` for `TRACE` requests.
    #[must_use]
    pub fn trace<E: IntoEndpoint>(self, ep: E) -> Self
    where
        E::Endpoint: 'static,
    {
        self.method(Method::TRACE, ep)
    }
}
impl Endpoint for RouteMethod {
    type Output = Response;

    /// Dispatches `req` to the endpoint registered for its method.
    ///
    /// Resolution order:
    /// 1. the first registered endpoint whose method equals the request's method;
    /// 2. for an unmatched `HEAD` request, the request is rewritten to `GET` and
    ///    dispatched again, and the resulting response has its body replaced by
    ///    an empty one;
    /// 3. otherwise, a `MethodNotAllowedError`.
    ///
    /// # Errors
    ///
    /// Returns `MethodNotAllowedError` when no endpoint matches (including a
    /// `HEAD` request when no `GET` endpoint exists either), and propagates any
    /// error produced by the selected endpoint.
    fn call(&self, mut req: Request) -> impl Future<Output = Result<Self::Output>> + Send {
        let target = self
            .methods
            .iter()
            .find_map(|(registered, ep)| (registered == req.method()).then_some(ep));

        match target {
            Some(ep) => Either::Left(ep.call(req)),
            // The fallback re-enters `self.call`, whose future type is this very
            // opaque type, so it must be boxed to keep the type finite.
            None if req.method() == Method::HEAD => Either::Right(Either::Left(
                async move {
                    req.set_method(Method::GET);
                    let mut resp = self.call(req).await?;
                    resp.set_body(());
                    Ok(resp)
                }
                .boxed(),
            )),
            None => Either::Right(Either::Right(async { Err(MethodNotAllowedError.into()) })),
        }
    }
}
pub fn get<E>(ep: E) -> RouteMethod
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
RouteMethod::new().get(ep)
}
pub fn post<E>(ep: E) -> RouteMethod
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
RouteMethod::new().post(ep)
}
pub fn put<E>(ep: E) -> RouteMethod
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
RouteMethod::new().put(ep)
}
pub fn delete<E>(ep: E) -> RouteMethod
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
RouteMethod::new().delete(ep)
}
pub fn head<E>(ep: E) -> RouteMethod
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
RouteMethod::new().head(ep)
}
pub fn options<E>(ep: E) -> RouteMethod
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
RouteMethod::new().options(ep)
}
pub fn connect<E>(ep: E) -> RouteMethod
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
RouteMethod::new().connect(ep)
}
pub fn patch<E>(ep: E) -> RouteMethod
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
RouteMethod::new().patch(ep)
}
pub fn trace<E>(ep: E) -> RouteMethod
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
RouteMethod::new().trace(ep)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{handler, http::StatusCode, test::TestClient};

    /// A `RouteMethod` with nothing registered rejects requests with 405.
    #[tokio::test]
    async fn method_not_allowed() {
        let resp = TestClient::new(RouteMethod::new()).get("/").send().await;
        resp.assert_status(StatusCode::METHOD_NOT_ALLOWED);
    }

    /// Every method is dispatched to its registered endpoint, both through the
    /// generic `method` builder and through the per-method shortcuts.
    #[tokio::test]
    async fn route_method() {
        #[handler(internal)]
        fn index() -> &'static str {
            "hello"
        }

        for method in &[
            Method::GET,
            Method::POST,
            Method::DELETE,
            Method::PUT,
            Method::HEAD,
            Method::OPTIONS,
            Method::CONNECT,
            Method::PATCH,
            Method::TRACE,
        ] {
            let route = RouteMethod::new().method(method.clone(), index).post(index);
            let resp = TestClient::new(route)
                .request(method.clone(), "/")
                .send()
                .await;
            resp.assert_status_is_ok();
            resp.assert_text("hello").await;
        }

        macro_rules! test_method {
            ($(($id:ident, $method:ident)),*) => {
                $(
                let route = RouteMethod::new().$id(index).post(index);
                let resp = TestClient::new(route).request(Method::$method, "/").send().await;
                resp.assert_status_is_ok();
                resp.assert_text("hello").await;
                )*
            };
        }

        test_method!(
            (get, GET),
            (post, POST),
            (delete, DELETE),
            (put, PUT),
            (head, HEAD),
            (options, OPTIONS),
            (connect, CONNECT),
            (patch, PATCH),
            (trace, TRACE)
        );
    }

    /// `HEAD` without an explicit handler falls back to the `GET` handler and
    /// the body is emptied.
    #[tokio::test]
    async fn head_method() {
        #[handler(internal)]
        fn index() -> &'static str {
            "hello"
        }

        let route = RouteMethod::new().get(index);
        let resp = TestClient::new(route).head("/").send().await;
        resp.assert_status_is_ok();
        resp.assert_text("").await;
    }

    /// The `HEAD` fallback only applies when a `GET` handler exists; without
    /// one, the re-dispatched `GET` is rejected and the result is 405.
    #[tokio::test]
    async fn head_without_get_is_not_allowed() {
        #[handler(internal)]
        fn index() -> &'static str {
            "hello"
        }

        let route = RouteMethod::new().post(index);
        let resp = TestClient::new(route).head("/").send().await;
        resp.assert_status(StatusCode::METHOD_NOT_ALLOWED);
    }

    /// An explicitly registered `HEAD` handler is used instead of the `GET`
    /// fallback, and its body is passed through unchanged.
    #[tokio::test]
    async fn explicit_head_overrides_get_fallback() {
        #[handler(internal)]
        fn get_index() -> &'static str {
            "from get"
        }

        #[handler(internal)]
        fn head_index() -> &'static str {
            "from head"
        }

        let route = RouteMethod::new().get(get_index).head(head_index);
        let resp = TestClient::new(route).head("/").send().await;
        resp.assert_status_is_ok();
        resp.assert_text("from head").await;
    }
}