use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use crate::context::RequestContext;
use crate::error::Error;
use crate::response::{IntoResponse, Response};
use crate::types::Op;
/// Asynchronous extractor: a type that can be built from the in-flight request.
///
/// `State` is the application state type that is shared with handlers.
pub trait FromRequest<State>: Sized {
/// Error produced when extraction fails; it must be convertible into a response.
type Rejection: IntoResponse;
/// Attempts to build `Self` from the request context and the shared state.
///
/// The context is borrowed mutably, so an implementation is allowed to modify it.
/// The returned future is required to be `Send`.
fn from_request(
ctx: &mut RequestContext,
state: &Arc<State>,
) -> impl Future<Output = Result<Self, Self::Rejection>> + Send;
}
/// Type-erased request handler that can be stored behind `Arc<dyn Handler<State>>`.
pub trait Handler<State>: Send + Sync + 'static {
/// Consumes the request context and a handle to the shared state and returns a
/// boxed future that resolves to the response.
fn call(&self, ctx: RequestContext, state: Arc<State>) -> BoxFuture<Response>;
}
/// Pinned, heap-allocated, `Send + 'static` future yielding `T`.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
/// Function-pointer type used as the `PhantomData` parameter of `HandlerFn5`.
type HandlerMarker5<T1, T2, T3, T4, T5> = fn(T1, T2, T3, T4, T5);
/// Function-pointer type used as the `PhantomData` parameter of `HandlerFn6`.
type HandlerMarker6<T1, T2, T3, T4, T5, T6> = fn(T1, T2, T3, T4, T5, T6);
/// Conversion of a value into a shared, type-erased `Handler`.
///
/// `Args` is a marker parameter: this file implements the trait for functions of
/// zero to six arguments, using `()` or a tuple of the argument types as `Args`
/// so that the impls for different arities are distinct.
pub trait IntoHandler<State, Args>: Send + Sync + 'static {
/// Consumes `self` and returns it wrapped as an `Arc<dyn Handler<State>>`.
fn into_handler(self) -> Arc<dyn Handler<State>>;
}
/// Wrapper that lets a zero-argument function implement `Handler`.
struct HandlerFn0<F> {
/// The wrapped function; kept in an `Arc` so each call can clone a handle to it.
f: Arc<F>,
}
/// Wrapper that lets a one-argument function implement `Handler`.
struct HandlerFn1<F, T1> {
/// The wrapped function; kept in an `Arc` so each call can clone a handle to it.
f: Arc<F>,
/// Records the argument type without storing a value of it. A function-pointer
/// marker is `Send + Sync` whatever `T1` is.
_t1: PhantomData<fn(T1)>,
}
/// Wrapper that lets a two-argument function implement `Handler`.
struct HandlerFn2<F, T1, T2> {
/// The wrapped function; kept in an `Arc` so each call can clone a handle to it.
f: Arc<F>,
/// Records the argument types without storing values of them.
_t: PhantomData<fn(T1, T2)>,
}
/// Wrapper that lets a three-argument function implement `Handler`.
struct HandlerFn3<F, T1, T2, T3> {
/// The wrapped function; kept in an `Arc` so each call can clone a handle to it.
f: Arc<F>,
/// Records the argument types without storing values of them.
_t: PhantomData<fn(T1, T2, T3)>,
}
/// Wrapper that lets a four-argument function implement `Handler`.
struct HandlerFn4<F, T1, T2, T3, T4> {
/// The wrapped function; kept in an `Arc` so each call can clone a handle to it.
f: Arc<F>,
/// Records the argument types without storing values of them.
_t: PhantomData<fn(T1, T2, T3, T4)>,
}
/// Wrapper that lets a five-argument function implement `Handler`.
struct HandlerFn5<F, T1, T2, T3, T4, T5> {
/// The wrapped function; kept in an `Arc` so each call can clone a handle to it.
f: Arc<F>,
/// Records the argument types (via the `HandlerMarker5` alias) without storing
/// values of them.
_t: PhantomData<HandlerMarker5<T1, T2, T3, T4, T5>>,
}
/// Wrapper that lets a six-argument function implement `Handler`.
struct HandlerFn6<F, T1, T2, T3, T4, T5, T6> {
/// The wrapped function; kept in an `Arc` so each call can clone a handle to it.
f: Arc<F>,
/// Records the argument types (via the `HandlerMarker6` alias) without storing
/// values of them.
_t: PhantomData<HandlerMarker6<T1, T2, T3, T4, T5, T6>>,
}
// Generates `Handler<State>` for one of the `HandlerFnN` wrappers.
//
// Invocation shape: the wrapper struct name followed by the wrapper's argument
// type parameters. The generated `call` runs each argument's extractor in
// declaration order, returns the first rejection as the response, and otherwise
// awaits the wrapped function with the extracted values.
macro_rules! impl_handler {
    ($wrapper:ident, $( $ty:ident ),* ) => {
        // Each type parameter name is reused as the local variable holding the
        // extracted value, hence the non-snake-case bindings below.
        #[allow(non_snake_case)]
        impl<State, F, Fut, R, $( $ty ),*> Handler<State> for $wrapper<F, $( $ty ),*>
        where
            State: Send + Sync + 'static,
            F: Fn($( $ty ),*) -> Fut + Send + Sync + 'static,
            Fut: Future<Output = R> + Send + 'static,
            R: IntoResponse,
            $( $ty: FromRequest<State> + Send + 'static, )*
        {
            fn call(&self, mut ctx: RequestContext, state: Arc<State>) -> BoxFuture<Response> {
                let callee = Arc::clone(&self.f);
                Box::pin(async move {
                    $(
                        let $ty = match $ty::from_request(&mut ctx, &state).await {
                            Err(rejected) => return rejected.into_response(),
                            Ok(extracted) => extracted,
                        };
                    )*
                    let output = callee($( $ty ),*).await;
                    output.into_response()
                })
            }
        }
    };
}
/// `Handler` for zero-argument functions: no extractors run, the wrapped
/// function is awaited and its output converted into a response.
impl<State, F, Fut, R> Handler<State> for HandlerFn0<F>
where
    State: Send + Sync + 'static,
    F: Fn() -> Fut + Send + Sync + 'static,
    Fut: Future<Output = R> + Send + 'static,
    R: IntoResponse,
{
    fn call(&self, ctx: RequestContext, _state: Arc<State>) -> BoxFuture<Response> {
        let callee = Arc::clone(&self.f);
        Box::pin(async move {
            // There is nothing to extract, so the context is deliberately unused.
            let _ = ctx;
            let output = callee().await;
            output.into_response()
        })
    }
}
// Generate the `Handler` impls for the one- to six-argument wrappers; the
// zero-argument wrapper `HandlerFn0` has a hand-written impl instead.
impl_handler!(HandlerFn1, T1);
impl_handler!(HandlerFn2, T1, T2);
impl_handler!(HandlerFn3, T1, T2, T3);
impl_handler!(HandlerFn4, T1, T2, T3, T4);
impl_handler!(HandlerFn5, T1, T2, T3, T4, T5);
impl_handler!(HandlerFn6, T1, T2, T3, T4, T5, T6);
impl<State, F, Fut, R> IntoHandler<State, ()> for F
where
F: Send + Sync + 'static + Fn() -> Fut,
Fut: Future<Output = R> + Send + 'static,
R: IntoResponse,
State: Send + Sync + 'static,
{
fn into_handler(self) -> Arc<dyn Handler<State>> {
Arc::new(HandlerFn0 { f: Arc::new(self) })
}
}
impl<State, F, Fut, R, T1> IntoHandler<State, (T1,)> for F
where
F: Send + Sync + 'static + Fn(T1) -> Fut,
Fut: Future<Output = R> + Send + 'static,
R: IntoResponse,
T1: FromRequest<State> + Send + 'static,
State: Send + Sync + 'static,
{
fn into_handler(self) -> Arc<dyn Handler<State>> {
Arc::new(HandlerFn1 {
f: Arc::new(self),
_t1: PhantomData,
})
}
}
impl<State, F, Fut, R, T1, T2> IntoHandler<State, (T1, T2)> for F
where
F: Send + Sync + 'static + Fn(T1, T2) -> Fut,
Fut: Future<Output = R> + Send + 'static,
R: IntoResponse,
T1: FromRequest<State> + Send + 'static,
T2: FromRequest<State> + Send + 'static,
State: Send + Sync + 'static,
{
fn into_handler(self) -> Arc<dyn Handler<State>> {
Arc::new(HandlerFn2 {
f: Arc::new(self),
_t: PhantomData,
})
}
}
impl<State, F, Fut, R, T1, T2, T3> IntoHandler<State, (T1, T2, T3)> for F
where
F: Send + Sync + 'static + Fn(T1, T2, T3) -> Fut,
Fut: Future<Output = R> + Send + 'static,
R: IntoResponse,
T1: FromRequest<State> + Send + 'static,
T2: FromRequest<State> + Send + 'static,
T3: FromRequest<State> + Send + 'static,
State: Send + Sync + 'static,
{
fn into_handler(self) -> Arc<dyn Handler<State>> {
Arc::new(HandlerFn3 {
f: Arc::new(self),
_t: PhantomData,
})
}
}
impl<State, F, Fut, R, T1, T2, T3, T4> IntoHandler<State, (T1, T2, T3, T4)> for F
where
F: Send + Sync + 'static + Fn(T1, T2, T3, T4) -> Fut,
Fut: Future<Output = R> + Send + 'static,
R: IntoResponse,
T1: FromRequest<State> + Send + 'static,
T2: FromRequest<State> + Send + 'static,
T3: FromRequest<State> + Send + 'static,
T4: FromRequest<State> + Send + 'static,
State: Send + Sync + 'static,
{
fn into_handler(self) -> Arc<dyn Handler<State>> {
Arc::new(HandlerFn4 {
f: Arc::new(self),
_t: PhantomData,
})
}
}
impl<State, F, Fut, R, T1, T2, T3, T4, T5> IntoHandler<State, (T1, T2, T3, T4, T5)> for F
where
F: Send + Sync + 'static + Fn(T1, T2, T3, T4, T5) -> Fut,
Fut: Future<Output = R> + Send + 'static,
R: IntoResponse,
T1: FromRequest<State> + Send + 'static,
T2: FromRequest<State> + Send + 'static,
T3: FromRequest<State> + Send + 'static,
T4: FromRequest<State> + Send + 'static,
T5: FromRequest<State> + Send + 'static,
State: Send + Sync + 'static,
{
fn into_handler(self) -> Arc<dyn Handler<State>> {
Arc::new(HandlerFn5 {
f: Arc::new(self),
_t: PhantomData,
})
}
}
impl<State, F, Fut, R, T1, T2, T3, T4, T5, T6> IntoHandler<State, (T1, T2, T3, T4, T5, T6)> for F
where
F: Send + Sync + 'static + Fn(T1, T2, T3, T4, T5, T6) -> Fut,
Fut: Future<Output = R> + Send + 'static,
R: IntoResponse,
T1: FromRequest<State> + Send + 'static,
T2: FromRequest<State> + Send + 'static,
T3: FromRequest<State> + Send + 'static,
T4: FromRequest<State> + Send + 'static,
T5: FromRequest<State> + Send + 'static,
T6: FromRequest<State> + Send + 'static,
State: Send + Sync + 'static,
{
fn into_handler(self) -> Arc<dyn Handler<State>> {
Arc::new(HandlerFn6 {
f: Arc::new(self),
_t: PhantomData,
})
}
}
/// Dispatches requests to handlers keyed by operation.
pub struct Router<State> {
/// Application state; a cloned handle is passed to the handler on every call.
state: Arc<State>,
/// Registered handlers, keyed by the request's `Op`.
routes: HashMap<Op, Arc<dyn Handler<State>>>,
/// Handler used when no route is registered for the request's `Op`.
fallback: Arc<dyn Handler<State>>,
}
impl<State> Router<State>
where
    State: Send + Sync + 'static,
{
    /// Creates a router that owns `state`, has no routes, and uses the default
    /// fallback handler.
    pub fn from_state(state: State) -> Self {
        let shared = Arc::new(state);
        Self {
            state: shared,
            routes: HashMap::new(),
            fallback: default_fallback(),
        }
    }

    /// Registers `handler` for `op`, replacing any handler previously
    /// registered for the same operation, and returns the router.
    pub fn route<H, Args>(mut self, op: Op, handler: H) -> Self
    where
        H: IntoHandler<State, Args>,
    {
        let erased = handler.into_handler();
        self.routes.insert(op, erased);
        self
    }

    /// Replaces the fallback handler and returns the router.
    pub fn fallback<H, Args>(mut self, handler: H) -> Self
    where
        H: IntoHandler<State, Args>,
    {
        let erased = handler.into_handler();
        self.fallback = erased;
        self
    }

    /// Returns a new handle to the shared state.
    pub fn state(&self) -> Arc<State> {
        let shared = &self.state;
        Arc::clone(shared)
    }

    /// Runs the handler registered for the request's operation, or the
    /// fallback handler when none is registered, and returns its response.
    pub async fn call(&self, ctx: RequestContext) -> Response {
        let handler = match self.routes.get(&ctx.request.op) {
            Some(registered) => registered,
            None => &self.fallback,
        };
        let shared = Arc::clone(&self.state);
        handler.call(ctx, shared).await
    }
}
/// A default router is built around `State::default()`.
impl<State> Default for Router<State>
where
    State: Default + Send + Sync + 'static,
{
    fn default() -> Self {
        let initial_state = State::default();
        Self::from_state(initial_state)
    }
}
/// Builds the fallback handler a new router starts with: a type-erased
/// `FallbackHandler`.
fn default_fallback<State>() -> Arc<dyn Handler<State>>
where
    State: Send + Sync + 'static,
{
    let handler: Arc<dyn Handler<State>> = Arc::new(FallbackHandler);
    handler
}
/// Stateless handler used as a router's fallback until another one is set.
struct FallbackHandler;
impl<State> Handler<State> for FallbackHandler
where
State: Send + Sync + 'static,
{
fn call(&self, ctx: RequestContext, _state: Arc<State>) -> BoxFuture<Response> {
Box::pin(async move {
match ctx.request.op {
Op::Noop | Op::MetaNoop | Op::Quit => Response::Noop,
_ => Response::Error(Error::unknown("unknown command")),
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Extensions, RequestContext};
    use crate::extract::Key;
    use crate::response::Stored;
    use crate::types::{Op, Protocol, ReplyMode, Request, RequestMeta};
    use bytes::Bytes;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Wraps `request` in a context with the fixed metadata, loopback
    /// addresses and client id shared by every test in this module.
    fn context_for(request: Request) -> RequestContext {
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        RequestContext {
            request,
            meta: RequestMeta {
                protocol: Protocol::Ascii,
                reply: ReplyMode::Always,
                opaque: None,
                return_key: false,
                opcode: 0,
            },
            peer_addr: SocketAddr::new(loopback, 1234),
            local_addr: SocketAddr::new(loopback, 11211),
            client_id: 1,
            extensions: Extensions::default(),
        }
    }

    /// A handler registered for `Op::Get` is the one that produces the response.
    #[tokio::test]
    async fn router_dispatches_handler() {
        async fn get(Key(_key): Key) -> Stored {
            Stored
        }

        let router = Router::from_state(()).route(Op::Get, get);

        let mut request = Request::new(Op::Get);
        request.key = Some(Bytes::from_static(b"alpha"));

        let response = router.call(context_for(request)).await;
        assert!(matches!(response, Response::Stored));
    }

    /// With no routes registered, an unknown operation yields an error response.
    #[tokio::test]
    async fn router_fallback_unknown() {
        let router = Router::from_state(());
        let request = Request::new(Op::Unknown);

        let response = router.call(context_for(request)).await;
        assert!(matches!(response, Response::Error(_)));
    }
}