memcached-async 0.0.1

Asynchronous memcached protocol parser
Documentation
use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;

use crate::context::RequestContext;
use crate::error::Error;
use crate::response::{IntoResponse, Response};
use crate::types::Op;

/// Extract a typed value from a request context.
///
/// Types implementing this trait can be used as arguments of handler
/// functions. For each argument, the generated [`Handler`] impl calls
/// `from_request` in argument order against the same context. On success
/// the value is passed to the handler. On failure the `Rejection` is turned
/// into the response via [`IntoResponse::into_response`], and the handler
/// function is not called.
pub trait FromRequest<State>: Sized {
    /// Value produced when extraction fails; converted into the response.
    type Rejection: IntoResponse;

    /// Attempt to extract `Self` from `ctx`, with read access to the shared
    /// `state`.
    ///
    /// `ctx` is borrowed mutably and shared by all extractors of one handler
    /// call. An extractor may therefore change what later extractors see.
    /// The returned future must be `Send` so that it can live inside the
    /// boxed handler future.
    fn from_request(
        ctx: &mut RequestContext,
        state: &Arc<State>,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send;
}

/// Handler for a memcached operation.
///
/// The trait is object-safe. [`Router`] stores handlers as
/// `Arc<dyn Handler<State>>` and builds them from functions via
/// [`IntoHandler`].
pub trait Handler<State>: Send + Sync + 'static {
    /// Handle one request.
    ///
    /// Takes the request context by value and a clone of the shared state.
    /// Returns a boxed `Send + 'static` future resolving to the response.
    fn call(&self, ctx: RequestContext, state: Arc<State>) -> BoxFuture<Response>;
}

/// Boxed, type-erased, `Send + 'static` future returned by [`Handler::call`].
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
// `fn(..)` marker types naming the argument lists of the 5- and 6-argument
// handler wrappers. They are only used inside `PhantomData`.
// NOTE(review): presumably split out into aliases to keep
// `clippy::type_complexity` quiet on `HandlerFn5`/`HandlerFn6` — confirm.
type HandlerMarker5<T1, T2, T3, T4, T5> = fn(T1, T2, T3, T4, T5);
type HandlerMarker6<T1, T2, T3, T4, T5, T6> = fn(T1, T2, T3, T4, T5, T6);

/// Conversion from a handler function into a type-erased [`Handler`].
///
/// Implemented for `Fn` values returning a future whose output implements
/// `IntoResponse`. The function may take zero to six arguments, each of
/// which must implement [`FromRequest`].
///
/// `Args` is a marker for the function's argument list: `()`, `(T1,)`,
/// `(T1, T2)`, and so on. It distinguishes the blanket impls for different
/// arities. Callers never name it; it is inferred from the function passed
/// to [`Router::route`] or [`Router::fallback`].
pub trait IntoHandler<State, Args>: Send + Sync + 'static {
    /// Wrap `self` as a shareable, type-erased handler.
    fn into_handler(self) -> Arc<dyn Handler<State>>;
}

// Adapters from a handler function `F` taking N extractor arguments to
// `Handler`.
//
// - `f` is held in an `Arc` so that each `call` can clone it into the returned
//   `'static` future.
// - The `PhantomData` fields record the argument types `T1..Tn` without
//   storing any values. This lets the `Handler` impls name those types
//   through the struct's type parameters.
// - A `fn(..)` marker is `Send + Sync` whatever the argument types are, so
//   the wrappers' thread-safety depends only on `F`.

/// Wrapper for handler functions with no extractor arguments.
struct HandlerFn0<F> {
    f: Arc<F>,
}

/// Wrapper for handler functions with one extractor argument.
struct HandlerFn1<F, T1> {
    f: Arc<F>,
    _t1: PhantomData<fn(T1)>,
}

/// Wrapper for handler functions with two extractor arguments.
struct HandlerFn2<F, T1, T2> {
    f: Arc<F>,
    _t: PhantomData<fn(T1, T2)>,
}

/// Wrapper for handler functions with three extractor arguments.
struct HandlerFn3<F, T1, T2, T3> {
    f: Arc<F>,
    _t: PhantomData<fn(T1, T2, T3)>,
}

/// Wrapper for handler functions with four extractor arguments.
struct HandlerFn4<F, T1, T2, T3, T4> {
    f: Arc<F>,
    _t: PhantomData<fn(T1, T2, T3, T4)>,
}

/// Wrapper for handler functions with five extractor arguments.
struct HandlerFn5<F, T1, T2, T3, T4, T5> {
    f: Arc<F>,
    _t: PhantomData<HandlerMarker5<T1, T2, T3, T4, T5>>,
}

/// Wrapper for handler functions with six extractor arguments.
struct HandlerFn6<F, T1, T2, T3, T4, T5, T6> {
    f: Arc<F>,
    _t: PhantomData<HandlerMarker6<T1, T2, T3, T4, T5, T6>>,
}

/// Generates the `Handler` impl for a `HandlerFnN` wrapper whose function
/// takes the listed extractor types.
///
/// The generated `call` runs each extractor's `from_request` in argument
/// order against one shared, mutable context:
/// - the first rejection short-circuits and becomes the response;
/// - otherwise the wrapped function is awaited with the extracted values, and
///   its output is converted into a `Response`.
///
/// Each extracted value is bound to a local named after its own type
/// parameter, so `$ty` can be reused for both. That is why the impl allows
/// `non_snake_case`.
macro_rules! impl_handler {
    ($name:ident, $( $ty:ident ),* ) => {
        #[allow(non_snake_case)]
        impl<State, F, Fut, R, $( $ty ),*> Handler<State> for $name<F, $( $ty ),*>
        where
            State: Send + Sync + 'static,
            $( $ty: FromRequest<State> + Send + 'static, )*
            R: IntoResponse,
            Fut: Future<Output = R> + Send + 'static,
            F: Send + Sync + 'static + Fn($( $ty ),*) -> Fut,
        {
            fn call(&self, mut ctx: RequestContext, state: Arc<State>) -> BoxFuture<Response> {
                let handler = Arc::clone(&self.f);
                let fut = async move {
                    $(
                        let $ty = match $ty::from_request(&mut ctx, &state).await {
                            Err(rejected) => return rejected.into_response(),
                            Ok(extracted) => extracted,
                        };
                    )*
                    let output = handler($( $ty ),*).await;
                    output.into_response()
                };
                Box::pin(fut)
            }
        }
    };
}

/// `Handler` for functions that take no extractor arguments.
///
/// Neither the context nor the state is inspected. The context is still
/// moved into the returned future and released only after the handler's
/// future completes. This matches the extractor-based handlers generated by
/// `impl_handler!`, whose futures own `ctx` for their whole run.
impl<State, F, Fut, R> Handler<State> for HandlerFn0<F>
where
    F: Send + Sync + 'static + Fn() -> Fut,
    Fut: Future<Output = R> + Send + 'static,
    R: IntoResponse,
    State: Send + Sync + 'static,
{
    fn call(&self, ctx: RequestContext, _state: Arc<State>) -> BoxFuture<Response> {
        let f = Arc::clone(&self.f);
        Box::pin(async move {
            let response = f().await.into_response();
            // Explicitly moving `ctx` into this future (a `_` binding would
            // neither move nor drop it) keeps it alive until the handler has
            // finished, like in the multi-argument handlers.
            drop(ctx);
            response
        })
    }
}

// `Handler` impls for one to six extractor arguments. The zero-argument case
// is written out by hand for `HandlerFn0`, presumably because the macro body
// would leave `ctx` and `state` unused there.
impl_handler!(HandlerFn1, T1);
impl_handler!(HandlerFn2, T1, T2);
impl_handler!(HandlerFn3, T1, T2, T3);
impl_handler!(HandlerFn4, T1, T2, T3, T4);
impl_handler!(HandlerFn5, T1, T2, T3, T4, T5);
impl_handler!(HandlerFn6, T1, T2, T3, T4, T5, T6);

impl<State, F, Fut, R> IntoHandler<State, ()> for F
where
    F: Send + Sync + 'static + Fn() -> Fut,
    Fut: Future<Output = R> + Send + 'static,
    R: IntoResponse,
    State: Send + Sync + 'static,
{
    fn into_handler(self) -> Arc<dyn Handler<State>> {
        Arc::new(HandlerFn0 { f: Arc::new(self) })
    }
}

impl<State, F, Fut, R, T1> IntoHandler<State, (T1,)> for F
where
    F: Send + Sync + 'static + Fn(T1) -> Fut,
    Fut: Future<Output = R> + Send + 'static,
    R: IntoResponse,
    T1: FromRequest<State> + Send + 'static,
    State: Send + Sync + 'static,
{
    fn into_handler(self) -> Arc<dyn Handler<State>> {
        Arc::new(HandlerFn1 {
            f: Arc::new(self),
            _t1: PhantomData,
        })
    }
}

impl<State, F, Fut, R, T1, T2> IntoHandler<State, (T1, T2)> for F
where
    F: Send + Sync + 'static + Fn(T1, T2) -> Fut,
    Fut: Future<Output = R> + Send + 'static,
    R: IntoResponse,
    T1: FromRequest<State> + Send + 'static,
    T2: FromRequest<State> + Send + 'static,
    State: Send + Sync + 'static,
{
    fn into_handler(self) -> Arc<dyn Handler<State>> {
        Arc::new(HandlerFn2 {
            f: Arc::new(self),
            _t: PhantomData,
        })
    }
}

impl<State, F, Fut, R, T1, T2, T3> IntoHandler<State, (T1, T2, T3)> for F
where
    F: Send + Sync + 'static + Fn(T1, T2, T3) -> Fut,
    Fut: Future<Output = R> + Send + 'static,
    R: IntoResponse,
    T1: FromRequest<State> + Send + 'static,
    T2: FromRequest<State> + Send + 'static,
    T3: FromRequest<State> + Send + 'static,
    State: Send + Sync + 'static,
{
    fn into_handler(self) -> Arc<dyn Handler<State>> {
        Arc::new(HandlerFn3 {
            f: Arc::new(self),
            _t: PhantomData,
        })
    }
}

impl<State, F, Fut, R, T1, T2, T3, T4> IntoHandler<State, (T1, T2, T3, T4)> for F
where
    F: Send + Sync + 'static + Fn(T1, T2, T3, T4) -> Fut,
    Fut: Future<Output = R> + Send + 'static,
    R: IntoResponse,
    T1: FromRequest<State> + Send + 'static,
    T2: FromRequest<State> + Send + 'static,
    T3: FromRequest<State> + Send + 'static,
    T4: FromRequest<State> + Send + 'static,
    State: Send + Sync + 'static,
{
    fn into_handler(self) -> Arc<dyn Handler<State>> {
        Arc::new(HandlerFn4 {
            f: Arc::new(self),
            _t: PhantomData,
        })
    }
}

impl<State, F, Fut, R, T1, T2, T3, T4, T5> IntoHandler<State, (T1, T2, T3, T4, T5)> for F
where
    F: Send + Sync + 'static + Fn(T1, T2, T3, T4, T5) -> Fut,
    Fut: Future<Output = R> + Send + 'static,
    R: IntoResponse,
    T1: FromRequest<State> + Send + 'static,
    T2: FromRequest<State> + Send + 'static,
    T3: FromRequest<State> + Send + 'static,
    T4: FromRequest<State> + Send + 'static,
    T5: FromRequest<State> + Send + 'static,
    State: Send + Sync + 'static,
{
    fn into_handler(self) -> Arc<dyn Handler<State>> {
        Arc::new(HandlerFn5 {
            f: Arc::new(self),
            _t: PhantomData,
        })
    }
}

impl<State, F, Fut, R, T1, T2, T3, T4, T5, T6> IntoHandler<State, (T1, T2, T3, T4, T5, T6)> for F
where
    F: Send + Sync + 'static + Fn(T1, T2, T3, T4, T5, T6) -> Fut,
    Fut: Future<Output = R> + Send + 'static,
    R: IntoResponse,
    T1: FromRequest<State> + Send + 'static,
    T2: FromRequest<State> + Send + 'static,
    T3: FromRequest<State> + Send + 'static,
    T4: FromRequest<State> + Send + 'static,
    T5: FromRequest<State> + Send + 'static,
    T6: FromRequest<State> + Send + 'static,
    State: Send + Sync + 'static,
{
    fn into_handler(self) -> Arc<dyn Handler<State>> {
        Arc::new(HandlerFn6 {
            f: Arc::new(self),
            _t: PhantomData,
        })
    }
}

/// Router mapping operations to handlers.
///
/// Each request is dispatched on its `Op` to the handler registered with
/// [`Router::route`]. Ops without a route go to the fallback handler. Every
/// handler call receives its own `Arc` clone of the shared state.
pub struct Router<State> {
    // Shared application state, handed to each handler call as an `Arc`.
    state: Arc<State>,
    // At most one handler per operation; later registrations replace earlier ones.
    routes: HashMap<Op, Arc<dyn Handler<State>>>,
    // Handler used for any op that has no entry in `routes`.
    fallback: Arc<dyn Handler<State>>,
}

impl<State> Router<State>
where
    State: Send + Sync + 'static,
{
    /// Create a router with no routes that owns `state`.
    ///
    /// Until [`Router::fallback`] is called, unrouted requests go to the
    /// built-in fallback:
    /// - `Op::Noop`, `Op::MetaNoop` and `Op::Quit` get `Response::Noop`;
    /// - every other op gets an "unknown command" error.
    #[must_use]
    pub fn from_state(state: State) -> Self {
        Self {
            state: Arc::new(state),
            routes: HashMap::new(),
            fallback: default_fallback(),
        }
    }

    /// Register `handler` for `op` and return the updated router.
    ///
    /// Any handler previously registered for the same `op` is replaced.
    #[must_use = "`route` returns the updated router; dropping it discards the registration"]
    pub fn route<H, Args>(mut self, op: Op, handler: H) -> Self
    where
        H: IntoHandler<State, Args>,
    {
        self.routes.insert(op, handler.into_handler());
        self
    }

    /// Replace the handler used for ops without a registered route.
    ///
    /// This replaces the built-in fallback entirely, including its
    /// `Noop`/`MetaNoop`/`Quit` handling.
    #[must_use = "`fallback` returns the updated router; dropping it discards the new fallback"]
    pub fn fallback<H, Args>(mut self, handler: H) -> Self
    where
        H: IntoHandler<State, Args>,
    {
        self.fallback = handler.into_handler();
        self
    }

    /// Return a new shared handle to the router's state.
    #[must_use]
    pub fn state(&self) -> Arc<State> {
        Arc::clone(&self.state)
    }

    /// Dispatch `ctx` and await the response.
    ///
    /// The handler is the one registered for `ctx.request.op`, or the
    /// fallback when there is none.
    pub async fn call(&self, ctx: RequestContext) -> Response {
        let handler = self.routes.get(&ctx.request.op).unwrap_or(&self.fallback);
        handler.call(ctx, Arc::clone(&self.state)).await
    }
}

impl<State> Default for Router<State>
where
    State: Default + Send + Sync + 'static,
{
    fn default() -> Self {
        Self::from_state(State::default())
    }
}

fn default_fallback<State>() -> Arc<dyn Handler<State>>
where
    State: Send + Sync + 'static,
{
    Arc::new(FallbackHandler)
}

struct FallbackHandler;

impl<State> Handler<State> for FallbackHandler
where
    State: Send + Sync + 'static,
{
    fn call(&self, ctx: RequestContext, _state: Arc<State>) -> BoxFuture<Response> {
        Box::pin(async move {
            match ctx.request.op {
                Op::Noop | Op::MetaNoop | Op::Quit => Response::Noop,
                _ => Response::Error(Error::unknown("unknown command")),
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Extensions, RequestContext};
    use crate::extract::Key;
    use crate::response::Stored;
    use crate::types::{Op, Protocol, ReplyMode, Request, RequestMeta};
    use bytes::Bytes;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Wrap `request` in the fixed context used by every test.
    ///
    /// The context describes an ASCII client connected from 127.0.0.1:1234
    /// to 127.0.0.1:11211, with client id 1 and empty extensions.
    fn ctx_for(request: Request) -> RequestContext {
        RequestContext {
            request,
            meta: RequestMeta {
                protocol: Protocol::Ascii,
                reply: ReplyMode::Always,
                opaque: None,
                return_key: false,
                opcode: 0,
            },
            peer_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1234),
            local_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 11211),
            client_id: 1,
            extensions: Extensions::default(),
        }
    }

    #[tokio::test]
    async fn router_dispatches_handler() {
        async fn get(Key(_key): Key) -> Stored {
            Stored
        }

        let router = Router::from_state(()).route(Op::Get, get);

        let mut req = Request::new(Op::Get);
        req.key = Some(Bytes::from_static(b"alpha"));

        let response = router.call(ctx_for(req)).await;
        assert!(matches!(response, Response::Stored));
    }

    #[tokio::test]
    async fn router_fallback_unknown() {
        let router = Router::from_state(());
        let response = router.call(ctx_for(Request::new(Op::Unknown))).await;
        assert!(matches!(response, Response::Error(_)));
    }

    #[tokio::test]
    async fn router_fallback_answers_noop() {
        let router = Router::from_state(());
        let response = router.call(ctx_for(Request::new(Op::Noop))).await;
        assert!(matches!(response, Response::Noop));
    }

    #[tokio::test]
    async fn router_custom_fallback_replaces_default() {
        async fn always_stored() -> Stored {
            Stored
        }

        let router = Router::from_state(()).fallback(always_stored);
        let response = router.call(ctx_for(Request::new(Op::Unknown))).await;
        assert!(matches!(response, Response::Stored));
    }
}