use std::{future::Future, marker::PhantomData, sync::Arc};
use futures_util::{FutureExt, future::BoxFuture};
use super::{
After, AndThen, Around, Before, CatchAllError, CatchError, InspectAllError, InspectError, Map,
MapToResponse, ToResponse,
};
use crate::{
Error, IntoResponse, Middleware, Request, Response, Result,
error::IntoResult,
middleware::{AddData, AddDataEndpoint},
};
/// An HTTP request handler.
///
/// An endpoint receives a [`Request`] and asynchronously produces either a
/// value convertible into a [`Response`] (its [`Endpoint::Output`]) or an
/// [`Error`].
pub trait Endpoint: Send + Sync {
    /// The successful value produced by [`Endpoint::call`]; it must be
    /// convertible into an HTTP response.
    type Output: IntoResponse;

    /// Handles `req`, resolving to the endpoint's output or an error.
    fn call(&self, req: Request) -> impl Future<Output = Result<Self::Output>> + Send;

    /// Handles `req` and always resolves to a [`Response`].
    ///
    /// A successful output is converted with [`IntoResponse::into_response`];
    /// an error is converted into its own response rather than returned, so
    /// the caller never observes an `Err`.
    fn get_response(&self, req: Request) -> impl Future<Output = Response> + Send {
        async move {
            match self.call(req).await {
                Ok(output) => output.into_response(),
                Err(err) => err.into_response(),
            }
        }
    }
}
/// Endpoint adapter around a synchronous function; built by [`make_sync`].
struct SyncFnEndpoint<T, F> {
    // Records the output type `T` without storing a value of it.
    _mark: PhantomData<T>,
    // The wrapped handler function.
    f: F,
}

impl<F, T, R> Endpoint for SyncFnEndpoint<T, F>
where
    R: IntoResult<T>,
    T: IntoResponse + Sync,
    F: Fn(Request) -> R + Send + Sync,
{
    type Output = T;

    /// Invokes the wrapped function and converts its return value into a
    /// `Result` through [`IntoResult`].
    async fn call(&self, req: Request) -> Result<Self::Output> {
        let handler = &self.f;
        let outcome = handler(req);
        outcome.into_result()
    }
}
/// Endpoint adapter around an asynchronous function; built by [`make`].
struct AsyncFnEndpoint<T, F> {
    // Records the output type `T` without storing a value of it.
    _mark: PhantomData<T>,
    // The wrapped handler function, which returns a future.
    f: F,
}

impl<F, Fut, T, R> Endpoint for AsyncFnEndpoint<T, F>
where
    R: IntoResult<T>,
    T: IntoResponse + Sync,
    Fut: Future<Output = R> + Send,
    F: Fn(Request) -> Fut + Sync + Send,
{
    type Output = T;

    /// Invokes the wrapped function, awaits the returned future, and converts
    /// the resolved value into a `Result` through [`IntoResult`].
    async fn call(&self, req: Request) -> Result<Self::Output> {
        let pending = (self.f)(req);
        let resolved = pending.await;
        resolved.into_result()
    }
}
/// An endpoint that holds one of two endpoint types, chosen at runtime.
///
/// Whichever variant is held, the combined endpoint's output is a
/// [`Response`].
pub enum EitherEndpoint<A, B> {
    /// The first alternative.
    A(A),
    /// The second alternative.
    B(B),
}

impl<A, B> Endpoint for EitherEndpoint<A, B>
where
    A: Endpoint,
    B: Endpoint,
{
    type Output = Response;

    /// Forwards `req` to the held endpoint and converts its successful output
    /// into a [`Response`]; an error is returned unchanged.
    async fn call(&self, req: Request) -> Result<Self::Output> {
        let resp = match self {
            Self::A(inner) => inner.call(req).await?.into_response(),
            Self::B(inner) => inner.call(req).await?.into_response(),
        };
        Ok(resp)
    }
}
/// Creates an endpoint from a synchronous function.
///
/// `f` receives the [`Request`] and returns any value implementing
/// `IntoResult<T>`; that value is converted into the endpoint's `Result<T>`.
pub fn make_sync<F, T, R>(f: F) -> impl Endpoint<Output = T>
where
    F: Fn(Request) -> R + Send + Sync,
    T: IntoResponse + Sync,
    R: IntoResult<T>,
{
    SyncFnEndpoint {
        f,
        _mark: PhantomData,
    }
}
/// Creates an endpoint from an asynchronous function.
///
/// `f` receives the [`Request`] and returns a future. The value that future
/// resolves to must implement `IntoResult<T>` and is converted into the
/// endpoint's `Result<T>`.
pub fn make<F, Fut, T, R>(f: F) -> impl Endpoint<Output = T>
where
    F: Fn(Request) -> Fut + Send + Sync,
    Fut: Future<Output = R> + Send,
    T: IntoResponse + Sync,
    R: IntoResult<T>,
{
    AsyncFnEndpoint {
        f,
        _mark: PhantomData,
    }
}
/// A shared reference to an endpoint is itself an endpoint that forwards
/// every call to the referenced value.
impl<T: Endpoint + ?Sized> Endpoint for &T {
    type Output = T::Output;

    async fn call(&self, req: Request) -> Result<Self::Output> {
        (**self).call(req).await
    }
}
/// A boxed endpoint forwards every call to the boxed value.
impl<T: Endpoint + ?Sized> Endpoint for Box<T> {
    type Output = T::Output;

    async fn call(&self, req: Request) -> Result<Self::Output> {
        // `&Box<T>` deref-coerces to `&T` for the trait call.
        T::call(self, req).await
    }
}
/// An `Arc`-shared endpoint forwards every call to the shared value.
impl<T: Endpoint + ?Sized> Endpoint for Arc<T> {
    type Output = T::Output;

    async fn call(&self, req: Request) -> Result<Self::Output> {
        // `&Arc<T>` deref-coerces to `&T` for the trait call.
        T::call(self, req).await
    }
}
pub trait DynEndpoint: Send + Sync {
type Output: IntoResponse;
fn call(&self, req: Request) -> BoxFuture<'_, Result<Self::Output>>;
}
pub struct ToDynEndpoint<E>(pub E);
impl<E> DynEndpoint for ToDynEndpoint<E>
where
E: Endpoint,
{
type Output = E::Output;
#[inline]
fn call(&self, req: Request) -> BoxFuture<'_, Result<Self::Output>> {
self.0.call(req).boxed()
}
}
impl<T> Endpoint for dyn DynEndpoint<Output = T> + '_
where
T: IntoResponse,
{
type Output = T;
#[inline]
async fn call(&self, req: Request) -> Result<Self::Output> {
DynEndpoint::call(self, req).await
}
}
pub type BoxEndpoint<'a, T = Response> = Box<dyn DynEndpoint<Output = T> + 'a>;
/// Combinator methods for anything that implements [`IntoEndpoint`].
///
/// Every method consumes `self` and returns a new endpoint type. That type
/// wraps either `self` or the endpoint obtained from
/// [`IntoEndpoint::into_endpoint`].
pub trait EndpointExt: IntoEndpoint {
    /// Type-erases the endpoint into a [`BoxEndpoint`] by wrapping it in
    /// [`ToDynEndpoint`].
    fn boxed<'a>(self) -> BoxEndpoint<'a, <Self::Endpoint as Endpoint>::Output>
    where
        Self: Sized + 'a,
    {
        let erased = ToDynEndpoint(self.into_endpoint());
        Box::new(erased)
    }

    /// Applies `middleware` to the endpoint and returns the transformed
    /// endpoint.
    fn with<T>(self, middleware: T) -> T::Output
    where
        T: Middleware<Self::Endpoint>,
        Self: Sized,
    {
        let endpoint = self.into_endpoint();
        middleware.transform(endpoint)
    }

    /// Applies `middleware` only when `enable` is `true`.
    ///
    /// When `enable` is `true`, the transformed endpoint is returned as
    /// [`EitherEndpoint::B`]. When it is `false`, `self` is returned unchanged
    /// as [`EitherEndpoint::A`].
    fn with_if<T>(self, enable: bool, middleware: T) -> EitherEndpoint<Self, T::Output>
    where
        T: Middleware<Self::Endpoint>,
        Self: Sized,
    {
        if enable {
            EitherEndpoint::B(middleware.transform(self.into_endpoint()))
        } else {
            EitherEndpoint::A(self)
        }
    }

    /// Wraps the endpoint with the [`AddData`] middleware carrying `data`.
    fn data<T>(self, data: T) -> AddDataEndpoint<Self::Endpoint, T>
    where
        T: Clone + Send + Sync + 'static,
        Self: Sized,
    {
        let middleware = AddData::new(data);
        self.with(middleware)
    }

    /// Like [`EndpointExt::data`], but only when `data` is `Some`.
    ///
    /// `Some(data)` yields [`EitherEndpoint::A`] wrapping an
    /// [`AddDataEndpoint`]. `None` yields [`EitherEndpoint::B`] holding `self`
    /// unchanged.
    fn data_opt<T>(
        self,
        data: Option<T>,
    ) -> EitherEndpoint<AddDataEndpoint<Self::Endpoint, T>, Self>
    where
        T: Clone + Send + Sync + 'static,
        Self: Sized,
    {
        if let Some(value) = data {
            EitherEndpoint::A(AddData::new(value).transform(self.into_endpoint()))
        } else {
            EitherEndpoint::B(self)
        }
    }

    /// Preprocesses each incoming [`Request`] with the async function `f`
    /// before the endpoint sees it. `f` may also fail with an error; see
    /// [`Before`].
    fn before<F, Fut>(self, f: F) -> Before<Self, F>
    where
        F: Fn(Request) -> Fut + Send + Sync,
        Fut: Future<Output = Result<Request>> + Send,
        Self: Sized,
    {
        Before::new(self, f)
    }

    /// Post-processes the endpoint's result with the async function `f`.
    ///
    /// `f` receives the result, whether success or error, and its own result
    /// becomes the new output; see [`After`].
    fn after<F, Fut, T>(self, f: F) -> After<Self::Endpoint, F>
    where
        F: Fn(Result<<Self::Endpoint as Endpoint>::Output>) -> Fut + Send + Sync,
        Fut: Future<Output = Result<T>> + Send,
        T: IntoResponse,
        Self: Sized,
    {
        let endpoint = self.into_endpoint();
        After::new(endpoint, f)
    }

    /// Wraps the endpoint with the async function `f`.
    ///
    /// `f` receives the endpoint (shared via [`Arc`]) together with the
    /// request, and decides how to invoke it; see [`Around`].
    fn around<F, Fut, R>(self, f: F) -> Around<Self::Endpoint, F>
    where
        F: Fn(Arc<Self::Endpoint>, Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<R>> + Send + 'static,
        R: IntoResponse,
        Self: Sized,
    {
        let endpoint = self.into_endpoint();
        Around::new(endpoint, f)
    }

    /// Converts the endpoint's output into a [`Response`]; see
    /// [`MapToResponse`].
    fn map_to_response(self) -> MapToResponse<Self::Endpoint>
    where
        Self: Sized,
    {
        let endpoint = self.into_endpoint();
        MapToResponse::new(endpoint)
    }

    /// Wraps the endpoint in [`ToResponse`]. Compare
    /// [`EndpointExt::map_to_response`], and consult [`ToResponse`] for how
    /// errors are handled.
    fn to_response(self) -> ToResponse<Self::Endpoint>
    where
        Self: Sized,
    {
        let endpoint = self.into_endpoint();
        ToResponse::new(endpoint)
    }

    /// Transforms each successful output with the async function `f`.
    ///
    /// Errors from the endpoint are passed through without calling `f`; see
    /// [`Map`].
    fn map<F, Fut, R, R2>(self, f: F) -> Map<Self::Endpoint, F>
    where
        F: Fn(R) -> Fut + Send + Sync,
        Fut: Future<Output = R2> + Send,
        R: IntoResponse,
        R2: IntoResponse,
        Self: Sized,
        Self::Endpoint: Endpoint<Output = R> + Sized,
    {
        let endpoint = self.into_endpoint();
        Map::new(endpoint, f)
    }

    /// Chains the fallible async function `f` onto each successful output.
    ///
    /// Errors from the endpoint are passed through without calling `f`; see
    /// [`AndThen`].
    fn and_then<F, Fut, R, R2>(self, f: F) -> AndThen<Self::Endpoint, F>
    where
        F: Fn(R) -> Fut + Send + Sync,
        Fut: Future<Output = Result<R2>> + Send,
        R: IntoResponse,
        R2: IntoResponse,
        Self: Sized,
        Self::Endpoint: Endpoint<Output = R> + Sized,
    {
        let endpoint = self.into_endpoint();
        AndThen::new(endpoint, f)
    }

    /// Handles errors with the async function `f`, which maps an [`Error`]
    /// to a value implementing [`IntoResponse`]; see [`CatchAllError`].
    fn catch_all_error<F, Fut, R>(self, f: F) -> CatchAllError<Self, F, R>
    where
        F: Fn(Error) -> Fut + Send + Sync,
        Fut: Future<Output = R> + Send,
        R: IntoResponse + Send,
        Self: Sized + Sync,
    {
        CatchAllError::new(self, f)
    }

    /// Like [`EndpointExt::catch_all_error`], but `f` takes an error of the
    /// concrete type `ErrType`. See [`CatchError`] for how errors of other
    /// types are treated.
    fn catch_error<F, Fut, R, ErrType>(self, f: F) -> CatchError<Self, F, R, ErrType>
    where
        F: Fn(ErrType) -> Fut + Send + Sync,
        Fut: Future<Output = R> + Send,
        R: IntoResponse + Send + Sync,
        ErrType: std::error::Error + Send + Sync + 'static,
        Self: Sized,
    {
        CatchError::new(self, f)
    }

    /// Registers `f`, which receives a reference to an [`Error`] and returns
    /// nothing. See [`InspectAllError`] for when it is invoked.
    fn inspect_all_err<F>(self, f: F) -> InspectAllError<Self, F>
    where
        F: Fn(&Error) + Send + Sync,
        Self: Sized,
    {
        InspectAllError::new(self, f)
    }

    /// Like [`EndpointExt::inspect_all_err`], but `f` takes a reference to an
    /// error of the concrete type `ErrType`; see [`InspectError`].
    fn inspect_err<F, ErrType>(self, f: F) -> InspectError<Self, F, ErrType>
    where
        F: Fn(&ErrType) + Send + Sync,
        ErrType: std::error::Error + Send + Sync + 'static,
        Self: Sized,
    {
        InspectError::new(self, f)
    }
}
/// Every [`IntoEndpoint`] value gets the [`EndpointExt`] combinators.
impl<T> EndpointExt for T where T: IntoEndpoint {}
/// Conversion into an [`Endpoint`].
///
/// Every [`Endpoint`] converts into itself. Other types can implement this
/// trait to describe how they build an endpoint.
pub trait IntoEndpoint {
    /// The endpoint type produced by the conversion.
    type Endpoint: Endpoint;

    /// Consumes `self` and returns the endpoint.
    fn into_endpoint(self) -> Self::Endpoint;
}

/// Identity conversion: an endpoint is already its own endpoint.
impl<E> IntoEndpoint for E
where
    E: Endpoint,
{
    type Endpoint = E;

    fn into_endpoint(self) -> E {
        self
    }
}
// Unit tests for the endpoint constructors and the `EndpointExt` combinators.
#[cfg(test)]
mod test {
use http::{HeaderValue, Uri};
use crate::{
Endpoint, EndpointExt, Error, IntoEndpoint, Request, Route,
endpoint::{make, make_sync},
get, handler,
http::{Method, StatusCode},
middleware::SetHeader,
test::TestClient,
web::Data,
};
// `make` wraps an async closure; `map_to_response` turns its `String` output
// into a response whose body is checked.
#[tokio::test]
async fn test_make() {
let ep = make(|req| async move { format!("method={}", req.method()) }).map_to_response();
let mut resp = ep
.call(Request::builder().method(Method::DELETE).finish())
.await
.unwrap();
assert_eq!(
resp.take_body().into_string().await.unwrap(),
"method=DELETE"
);
}
// `before` rewrites the request method to POST before the endpoint echoes it.
#[tokio::test]
async fn test_before() {
assert_eq!(
make_sync(|req| req.method().to_string())
.before(|mut req| async move {
req.set_method(Method::POST);
Ok(req)
})
.call(Request::default())
.await
.unwrap(),
"POST"
);
}
// `after` replaces the endpoint's successful output "abc" with "def".
#[tokio::test]
async fn test_after() {
assert_eq!(
make_sync(|_| "abc")
.after(|_| async { Ok::<_, Error>("def") })
.call(Request::default())
.await
.unwrap(),
"def"
);
}
// `map_to_response` converts a `&str` output into a response with that body.
#[tokio::test]
async fn test_map_to_response() {
assert_eq!(
make_sync(|_| "abc")
.map_to_response()
.call(Request::default())
.await
.unwrap()
.take_body()
.into_string()
.await
.unwrap(),
"abc"
);
}
// `and_then` appends to a successful output. An endpoint error (400) never
// reaches the closure, and `get_response` turns it into a 400 response.
#[tokio::test]
async fn test_and_then() {
assert_eq!(
make_sync(|_| "abc")
.and_then(|resp| async move { Ok(resp.to_string() + "def") })
.call(Request::default())
.await
.unwrap(),
"abcdef"
);
let resp = make_sync(|_| Err::<String, _>(Error::from_status(StatusCode::BAD_REQUEST)))
.and_then(|resp| async move { Ok(resp + "def") })
.get_response(Request::default())
.await;
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
// `map` transforms a successful output. An endpoint error (400) is passed
// through and rendered by `get_response`.
#[tokio::test]
async fn test_map() {
assert_eq!(
make_sync(|_| "abc")
.map(|resp| async move { resp.to_string() + "def" })
.call(Request::default())
.await
.unwrap(),
"abcdef"
);
let resp = make_sync(|_| Err::<String, _>(Error::from_status(StatusCode::BAD_REQUEST)))
.map(|resp| async move { resp.to_string() + "def" })
.get_response(Request::default())
.await;
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
// `around` sets the body to "a", the inner endpoint appends "b", and the
// wrapper appends "c" to the inner result.
#[tokio::test]
async fn test_around() {
let ep = make(|req| async move { req.into_body().into_string().await.unwrap() + "b" });
assert_eq!(
ep.around(|ep, mut req| async move {
req.set_body("a");
let resp = ep.call(req).await?;
Ok(resp + "c")
})
.call(Request::default())
.await
.unwrap(),
"abc"
);
}
// `with_if(true, ..)` applies the header middleware; `with_if(false, ..)`
// leaves the response without the header.
#[tokio::test]
async fn test_with_if() {
let resp = make_sync(|_| ())
.with_if(true, SetHeader::new().appending("a", 1))
.call(Request::default())
.await
.unwrap();
assert_eq!(
resp.headers().get("a"),
Some(&HeaderValue::from_static("1"))
);
let resp = make_sync(|_| ())
.with_if(false, SetHeader::new().appending("a", 1))
.call(Request::default())
.await
.unwrap();
assert_eq!(resp.headers().get("a"), None);
}
// A custom `IntoEndpoint` factory (not itself an `Endpoint`) can be nested
// under "/api", and both of its routes are reachable.
#[tokio::test]
async fn test_into_endpoint() {
struct MyEndpointFactory;
impl IntoEndpoint for MyEndpointFactory {
type Endpoint = Route;
fn into_endpoint(self) -> Self::Endpoint {
Route::new()
.at("/a", get(make_sync(|_| "a")))
.at("/b", get(make_sync(|_| "b")))
}
}
let app = Route::new().nest("/api", MyEndpointFactory);
assert_eq!(
app.call(Request::builder().uri(Uri::from_static("/api/a")).finish())
.await
.unwrap()
.take_body()
.into_string()
.await
.unwrap(),
"a"
);
assert_eq!(
app.call(Request::builder().uri(Uri::from_static("/api/b")).finish())
.await
.unwrap()
.take_body()
.into_string()
.await
.unwrap(),
"b"
);
}
// `data_opt(Some(100))` makes the data visible to the handler.
// `data_opt(None)` leaves the handler's optional extractor empty.
#[tokio::test]
async fn test_data_opt() {
#[handler(internal)]
async fn index(data: Option<Data<&i32>>) -> String {
match data.as_deref() {
Some(value) => format!("{value}"),
None => "none".to_string(),
}
}
let cli = TestClient::new(index.data_opt(Some(100)));
let resp = cli.get("/").send().await;
resp.assert_status_is_ok();
resp.assert_text("100").await;
let cli = TestClient::new(index.data_opt(None::<i32>));
let resp = cli.get("/").send().await;
resp.assert_status_is_ok();
resp.assert_text("none").await;
}
}