use std::future::Future;
use http::header::LOCATION;
use http::{StatusCode, Uri};
use crate::{async_trait, throw, Context, Result, Status};
/// A handler that receives the request context together with `next`.
///
/// Implementors must be `'static + Sync + Send`. `S` is the type parameter of
/// the `Context<S>` handed to the handler and defaults to `()`.
///
/// The `?Send` argument to `async_trait` means the future produced by
/// `handle` is not required to be `Send`.
#[async_trait(?Send)]
pub trait Middleware<'a, S = ()>: 'static + Sync + Send {
/// Handles `ctx`; `next` is a future the implementation may await.
///
/// NOTE(review): `next` presumably drives the downstream handlers of the
/// chain; confirm against the code that invokes middlewares.
async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result;
}
/// Blanket implementation: every `'static + Send + Sync` function or closure
/// of the shape `Fn(&mut Context<S>, Next) -> impl Future<Output = Result>`
/// is usable as a [`Middleware`].
#[async_trait(?Send)]
impl<'a, S, T, F> Middleware<'a, S> for T
where
    S: 'a,
    T: 'static + Send + Sync + Fn(&'a mut Context<S>, Next<'a>) -> F,
    F: 'a + Future<Output = Result>,
{
    /// Invokes the wrapped callable and awaits the future it returns.
    #[inline]
    async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
        // Create the handler's future first, then drive it to completion.
        let pending = self(ctx, next);
        pending.await
    }
}
/// A handler that receives only the request context (no `next`).
///
/// Implementors must be `'static + Sync + Send`. `S` is the type parameter of
/// the `Context<S>` handed to the handler and defaults to `()`.
///
/// The `?Send` argument to `async_trait` means the future produced by
/// `call` is not required to be `Send`.
#[async_trait(?Send)]
pub trait Endpoint<'a, S = ()>: 'static + Sync + Send {
/// Handles `ctx` and reports the outcome as a `Result`.
async fn call(&'a self, ctx: &'a mut Context<S>) -> Result;
}
/// Blanket implementation: every `'static + Send + Sync` function or closure
/// of the shape `Fn(&mut Context<S>) -> impl Future<Output = Result>`
/// is usable as an [`Endpoint`].
#[async_trait(?Send)]
impl<'a, S, T, F> Endpoint<'a, S> for T
where
    S: 'a,
    T: 'static + Send + Sync + Fn(&'a mut Context<S>) -> F,
    F: 'a + Future<Output = Result>,
{
    /// Invokes the wrapped callable and awaits the future it returns.
    #[inline]
    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
        // Create the handler's future first, then drive it to completion.
        let pending = self(ctx);
        pending.await
    }
}
/// The unit type acts as a pass-through middleware: it ignores the context
/// and simply awaits `next`.
#[async_trait(?Send)]
impl<'a, S> Middleware<'a, S> for () {
    /// Awaits `next` and returns its result unchanged.
    #[inline]
    // The `&self` receiver is dictated by the trait signature, even though
    // `()` is trivially copyable.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    async fn handle(&'a self, _ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
        next.await
    }
}
/// The unit type acts as an endpoint that does nothing and always succeeds.
#[async_trait(?Send)]
impl<'a, S> Endpoint<'a, S> for () {
    /// Leaves the context untouched and returns `Ok(())`.
    #[inline]
    // The `&self` receiver is dictated by the trait signature, even though
    // `()` is trivially copyable.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    async fn call(&'a self, _ctx: &'a mut Context<S>) -> Result {
        Ok(())
    }
}
/// A `Status` value acts as an endpoint that always fails with a copy of
/// itself.
#[async_trait(?Send)]
impl<'a, S> Endpoint<'a, S> for Status {
    /// Returns `Err` holding a clone of this status; the context is not used.
    #[inline]
    async fn call(&'a self, _ctx: &'a mut Context<S>) -> Result {
        // The endpoint is only borrowed, so each call hands out a fresh clone.
        let status = self.clone();
        Err(status)
    }
}
/// A `String` acts as an endpoint that writes a copy of itself to the
/// response.
#[async_trait(?Send)]
impl<'a, S> Endpoint<'a, S> for String {
    /// Clones the string, passes the clone to `ctx.resp.write`, and returns
    /// `Ok(())`.
    // The receiver type `&String` is fixed by the trait signature.
    #[allow(clippy::ptr_arg)]
    #[inline]
    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
        // The endpoint is only borrowed, so the response gets its own copy.
        let body: String = self.clone();
        ctx.resp.write(body);
        Ok(())
    }
}
/// A `&'static str` acts as an endpoint that writes itself to the response.
#[async_trait(?Send)]
impl<'a, S> Endpoint<'a, S> for &'static str {
    /// Passes the string slice to `ctx.resp.write` and returns `Ok(())`.
    #[inline]
    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
        // `self` is `&&'static str`; peel one reference to get the slice.
        let text: &'static str = *self;
        ctx.resp.write(text);
        Ok(())
    }
}
/// A `Uri` acts as a redirecting endpoint: it sets the `Location` header to
/// the URI and throws `308 Permanent Redirect`.
#[async_trait(?Send)]
impl<'a, S> Endpoint<'a, S> for Uri {
    /// Stores the URI's string form in the `Location` response header, then
    /// throws `StatusCode::PERMANENT_REDIRECT`.
    ///
    /// If the string form cannot be parsed into a header value, the parse
    /// error is propagated via `?` and the header is not inserted.
    #[inline]
    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
        // Parse before touching the headers; the target type is inferred
        // from the `insert` call below.
        let location = self.to_string().parse()?;
        ctx.resp.headers.insert(LOCATION, location);
        throw!(StatusCode::PERMANENT_REDIRECT)
    }
}
/// The type of the `next` argument passed to `Middleware::handle`: a mutable
/// reference to an `Unpin` future whose output is `Result`.
///
/// The `()` middleware in this file simply awaits it.
/// NOTE(review): presumably represents the remaining handlers of the chain;
/// confirm against the code that constructs it.
pub type Next<'a> = &'a mut (dyn Unpin + Future<Output = Result>);
/// Tests for the built-in `Endpoint` implementations: each one mounts a
/// value via `App::end` and inspects the response to a default request.
#[cfg(test)]
mod tests {
    use futures::{AsyncReadExt, TryStreamExt};
    use http::header::LOCATION;
    use http::{StatusCode, Uri};

    use crate::{status, App, Request};

    /// Body text shared by the string-based endpoint tests.
    const HELLO: &str = "Hello, world";

    /// A `Status` endpoint yields a response carrying that status code.
    #[tokio::test]
    async fn status_endpoint() {
        let app = App::new().end(status!(StatusCode::BAD_REQUEST));
        let svc = app.http_service();
        let response = svc.serve(Request::default()).await;
        assert_eq!(StatusCode::BAD_REQUEST, response.status);
    }

    /// An owned `String` endpoint writes its contents as the response body.
    #[tokio::test]
    async fn string_endpoint() {
        let app = App::new().end(HELLO.to_owned());
        let svc = app.http_service();
        let response = svc.serve(Request::default()).await;

        // Drain the body stream into a string for comparison.
        let mut text = String::new();
        response
            .body
            .into_async_read()
            .read_to_string(&mut text)
            .await
            .unwrap();
        assert_eq!(HELLO, text);
    }

    /// A `&'static str` endpoint writes its contents as the response body.
    #[tokio::test]
    async fn static_slice_endpoint() {
        let app = App::new().end(HELLO);
        let svc = app.http_service();
        let response = svc.serve(Request::default()).await;

        // Drain the body stream into a string for comparison.
        let mut text = String::new();
        response
            .body
            .into_async_read()
            .read_to_string(&mut text)
            .await
            .unwrap();
        assert_eq!(HELLO, text);
    }

    /// A `Uri` endpoint answers with 308 and the URI in the `Location` header.
    #[tokio::test]
    async fn redirect_endpoint() {
        let target: Uri = "/target".parse().unwrap();
        let app = App::new().end(target);
        let svc = app.http_service();
        let response = svc.serve(Request::default()).await;
        assert_eq!(StatusCode::PERMANENT_REDIRECT, response.status);
        assert_eq!("/target", response.headers[LOCATION].to_str().unwrap());
    }
}