use crate::error::BoxError;
use crate::Context;
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
pub trait Service<S, Request>: Sized + Send + Sync + 'static {
type Response: Send + 'static;
type Error: Send + Sync + 'static;
fn serve(
&self,
ctx: Context<S>,
req: Request,
) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_;
fn boxed(self) -> BoxService<S, Request, Self::Response, Self::Error> {
BoxService {
inner: Box::new(self),
}
}
}
impl<S, Request> Service<S, Request> for ()
where
    S: Clone + Send + Sync + 'static,
    Request: Send + 'static,
{
    type Response = Request;
    type Error = Infallible;

    /// Identity service: the context is accepted but unused, and the request is
    /// handed back unchanged as the response. It never fails.
    async fn serve(&self, _ctx: Context<S>, req: Request) -> Result<Self::Response, Self::Error> {
        let echoed: Request = req;
        Ok(echoed)
    }
}
impl<S, State, Request> Service<State, Request> for std::sync::Arc<S>
where
    S: Service<State, Request>,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Forward the call to the service shared behind the `Arc`.
    #[inline]
    fn serve(
        &self,
        ctx: Context<State>,
        req: Request,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
        // Deref-coerce `&Arc<S>` to `&S` before dispatching.
        let shared: &S = self;
        shared.serve(ctx, req)
    }
}
impl<S, State, Request> Service<State, Request> for &'static S
where
    S: Service<State, Request>,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Forward the call to the referenced `'static` service.
    #[inline]
    fn serve(
        &self,
        ctx: Context<State>,
        req: Request,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
        let target: &'static S = *self;
        target.serve(ctx, req)
    }
}
impl<S, State, Request> Service<State, Request> for Box<S>
where
    S: Service<State, Request>,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Forward the call to the boxed service.
    #[inline]
    fn serve(
        &self,
        ctx: Context<State>,
        req: Request,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
        <S as Service<State, Request>>::serve(&**self, ctx, req)
    }
}
/// Pinned, heap-allocated, `Send` future returned by [`DynService::serve_box`].
///
/// Naming this type keeps `serve_box`'s signature readable without silencing
/// `clippy::type_complexity`.
type BoxServeFuture<'a, Response, Error> =
    Pin<Box<dyn Future<Output = Result<Response, Error>> + Send + 'a>>;

/// Object-safe counterpart of [`Service`], used to build [`BoxService`].
///
/// [`Service`] is `Sized` and its `serve` returns an opaque `impl Future`, so it
/// cannot be used as a trait object. `serve_box` returns the future boxed instead.
trait DynService<S, Request> {
    /// Response type of the erased service.
    type Response;
    /// Error type of the erased service.
    type Error;

    /// Serve `req` with `ctx`, returning the future pinned on the heap.
    fn serve_box(
        &self,
        ctx: Context<S>,
        req: Request,
    ) -> BoxServeFuture<'_, Self::Response, Self::Error>;
}
impl<S, Request, T> DynService<S, Request> for T
where
    T: Service<S, Request>,
{
    type Response = T::Response;
    type Error = T::Error;

    /// Every [`Service`] is usable through `DynService`: call `serve` and box
    /// the resulting future so it can be returned as a trait object.
    fn serve_box(
        &self,
        ctx: Context<S>,
        req: Request,
    ) -> Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + '_>> {
        let future = <T as Service<S, Request>>::serve(self, ctx, req);
        Box::pin(future)
    }
}
pub struct BoxService<S, Request, Response, Error> {
inner:
Box<dyn DynService<S, Request, Response = Response, Error = Error> + Send + Sync + 'static>,
}
impl<S, Request, Response, Error> BoxService<S, Request, Response, Error> {
    /// Erase the concrete type of `service`, storing it behind a trait object.
    pub fn new<T>(service: T) -> Self
    where
        T: Service<S, Request, Response = Response, Error = Error>,
    {
        // The unsizing coercion to the trait object happens at the field initializer.
        let inner = Box::new(service);
        Self { inner }
    }
}
impl<S, Request, Response, Error> std::fmt::Debug for BoxService<S, Request, Response, Error> {
    /// Writes only the type name: the erased inner service has no `Debug` bound.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("BoxService")
    }
}
impl<S, Request, Response, Error> Service<S, Request> for BoxService<S, Request, Response, Error>
where
    S: 'static,
    Request: 'static,
    Response: Send + 'static,
    Error: Send + Sync + 'static,
{
    type Response = Response;
    type Error = Error;

    /// Dispatch through the trait object; the returned future is the boxed
    /// future produced by `serve_box`.
    fn serve(
        &self,
        ctx: Context<S>,
        req: Request,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
        let erased = &*self.inner;
        erased.serve_box(ctx, req)
    }
}
// Implements `Service` for one of the either-style enums in `crate::combinators`.
//
// `$id` is the enum's name and `$param` is its list of type parameters. The match
// arm below is `$id::$param(s)`, so each parameter also names the variant that
// holds a service of that type. Every variant's service must share the same
// `Response` type, and each error type must be `Into<BoxError>`. The combined
// service reports all errors as `BoxError`.
macro_rules! impl_service_either {
($id:ident, $($param:ident),+ $(,)?) => {
impl<$($param),+, State, Request, Response> Service<State, Request> for crate::combinators::$id<$($param),+>
where
$(
$param: Service<State, Request, Response = Response, Error: Into<BoxError>>,
)+
Request: Send + 'static,
State: Clone + Send + Sync + 'static,
Response: Send + 'static,
{
type Response = Response;
type Error = BoxError;
async fn serve(&self, ctx: Context<State>, req: Request) -> Result<Self::Response, Self::Error> {
match self {
$(
// Delegate to whichever variant is active, converting its error into `BoxError`.
crate::combinators::$id::$param(s) => s.serve(ctx, req).await.map_err(Into::into),
)+
}
}
}
};
}
// `impl_either!` is defined in `crate::combinators`. It is handed the name of the
// macro above, and presumably invokes it once per either-enum. TODO: confirm there.
crate::combinators::impl_either!(impl_service_either);
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;

    /// Test service that adds its stored value to the request.
    #[derive(Debug)]
    struct AddSvc(usize);

    impl Service<(), usize> for AddSvc {
        type Response = usize;
        type Error = Infallible;

        async fn serve(
            &self,
            _ctx: Context<()>,
            req: usize,
        ) -> Result<Self::Response, Self::Error> {
            Ok(self.0 + req)
        }
    }

    /// Test service that multiplies the request by its stored value.
    #[derive(Debug)]
    struct MulSvc(usize);

    impl Service<(), usize> for MulSvc {
        type Response = usize;
        type Error = Infallible;

        async fn serve(
            &self,
            _ctx: Context<()>,
            req: usize,
        ) -> Result<Self::Response, Self::Error> {
            Ok(self.0 * req)
        }
    }

    #[test]
    fn assert_send() {
        use rama_utils::test_helpers::*;

        assert_send::<AddSvc>();
        assert_send::<MulSvc>();
        assert_send::<BoxService<(), (), (), ()>>();
    }

    #[test]
    fn assert_sync() {
        use rama_utils::test_helpers::*;

        assert_sync::<AddSvc>();
        assert_sync::<MulSvc>();
        assert_sync::<BoxService<(), (), (), ()>>();
    }

    #[tokio::test]
    async fn add_svc() {
        let svc = AddSvc(1);
        let ctx = Context::default();
        let response = svc.serve(ctx, 1).await.unwrap();
        assert_eq!(response, 2);
    }

    #[tokio::test]
    async fn static_dispatch() {
        let services = vec![AddSvc(1), AddSvc(2), AddSvc(3)];
        let ctx = Context::default();
        for (i, svc) in services.into_iter().enumerate() {
            let response = svc.serve(ctx.clone(), i).await.unwrap();
            assert_eq!(response, i * 2 + 1);
        }
    }

    #[tokio::test]
    async fn dynamic_dispatch() {
        let services = vec![
            AddSvc(1).boxed(),
            AddSvc(2).boxed(),
            AddSvc(3).boxed(),
            MulSvc(4).boxed(),
            MulSvc(5).boxed(),
        ];
        let ctx = Context::default();
        for (i, svc) in services.into_iter().enumerate() {
            let response = svc.serve(ctx.clone(), i).await.unwrap();
            if i < 3 {
                assert_eq!(response, i * 2 + 1);
            } else {
                assert_eq!(response, i * (i + 1));
            }
        }
    }

    #[tokio::test]
    async fn service_arc() {
        let svc = std::sync::Arc::new(AddSvc(1));
        let ctx = Context::default();
        let response = svc.serve(ctx, 1).await.unwrap();
        assert_eq!(response, 2);
    }

    #[tokio::test]
    async fn box_service_arc() {
        let svc = std::sync::Arc::new(AddSvc(1)).boxed();
        let ctx = Context::default();
        let response = svc.serve(ctx, 1).await.unwrap();
        assert_eq!(response, 2);
    }

    /// The `()` service hands the request back unchanged.
    #[tokio::test]
    async fn unit_service_echoes_request() {
        let ctx: Context<()> = Context::default();
        let response = ().serve(ctx, 42usize).await.unwrap();
        assert_eq!(response, 42);
    }

    /// Method syntax on a `Box<AddSvc>` resolves to the `Box<S>` impl.
    #[tokio::test]
    async fn service_box() {
        let svc = Box::new(AddSvc(1));
        let ctx = Context::default();
        let response = svc.serve(ctx, 1).await.unwrap();
        assert_eq!(response, 2);
    }

    #[tokio::test]
    async fn service_static_ref() {
        static SVC: AddSvc = AddSvc(1);
        let svc: &'static AddSvc = &SVC;
        // Call through the `&'static S` impl explicitly; method syntax on `svc`
        // would resolve to `AddSvc`'s own impl instead.
        let response = <&'static AddSvc as Service<(), usize>>::serve(&svc, Context::default(), 1)
            .await
            .unwrap();
        assert_eq!(response, 2);
    }

    #[tokio::test]
    async fn box_service_new() {
        let svc: BoxService<(), usize, usize, Infallible> = BoxService::new(MulSvc(3));
        let ctx = Context::default();
        let response = svc.serve(ctx, 2).await.unwrap();
        assert_eq!(response, 6);
    }

    /// `BoxService`'s `Debug` output is just the type name.
    #[test]
    fn box_service_debug() {
        let svc = AddSvc(1).boxed();
        assert_eq!(format!("{svc:?}"), "BoxService");
    }
}