Skip to main content

memcached_async/
router.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use crate::context::RequestContext;
8use crate::error::Error;
9use crate::response::{IntoResponse, Response};
10use crate::types::Op;
11
12/// Extract a typed value from a request context.
13pub trait FromRequest<State>: Sized {
14    type Rejection: IntoResponse;
15
16    fn from_request(
17        ctx: &mut RequestContext,
18        state: &Arc<State>,
19    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send;
20}
21
22/// Handler for a memcached operation.
23pub trait Handler<State>: Send + Sync + 'static {
24    fn call(&self, ctx: RequestContext, state: Arc<State>) -> BoxFuture<Response>;
25}
26
/// Heap-allocated, `Send`, `'static` future: the return type of `Handler::call`.
type BoxFuture<Out> = Pin<Box<dyn Future<Output = Out> + Send + 'static>>;
/// Marker signature stored in `HandlerFn5`'s `PhantomData`.
/// NOTE(review): presumably aliased only to keep that field type short
/// (e.g. for `clippy::type_complexity`); confirm.
type HandlerMarker5<A, B, C, D, E> = fn(A, B, C, D, E);
/// Marker signature stored in `HandlerFn6`'s `PhantomData` (same purpose as `HandlerMarker5`).
type HandlerMarker6<A, B, C, D, E, F> = fn(A, B, C, D, E, F);
30
31pub trait IntoHandler<State, Args>: Send + Sync + 'static {
32    fn into_handler(self) -> Arc<dyn Handler<State>>;
33}
34
/// Handler wrapper for a zero-argument function; there are no extractor types to record.
struct HandlerFn0<F> {
    /// The wrapped function, shared (via `Arc`) with every future produced by `call`.
    f: Arc<F>,
}
38
/// Handler wrapper for a one-argument function.
///
/// `PhantomData<fn(T1)>` records the extractor type without storing a `T1`. Function
/// pointer types are `Send + Sync` for any argument type, so the wrapper's thread-safety
/// depends only on `F`.
struct HandlerFn1<F, T1> {
    /// The wrapped function, shared (via `Arc`) with every future produced by `call`.
    f: Arc<F>,
    _t1: PhantomData<fn(T1)>,
}
43
/// Handler wrapper for a two-argument function.
///
/// The `fn(T1, T2)` phantom marker records the extractor types without storing them and
/// without affecting `Send`/`Sync`.
struct HandlerFn2<F, T1, T2> {
    /// The wrapped function, shared (via `Arc`) with every future produced by `call`.
    f: Arc<F>,
    _t: PhantomData<fn(T1, T2)>,
}
48
/// Handler wrapper for a three-argument function.
///
/// The `fn(T1, T2, T3)` phantom marker records the extractor types without storing them
/// and without affecting `Send`/`Sync`.
struct HandlerFn3<F, T1, T2, T3> {
    /// The wrapped function, shared (via `Arc`) with every future produced by `call`.
    f: Arc<F>,
    _t: PhantomData<fn(T1, T2, T3)>,
}
53
/// Handler wrapper for a four-argument function.
///
/// The `fn(T1, T2, T3, T4)` phantom marker records the extractor types without storing
/// them and without affecting `Send`/`Sync`.
struct HandlerFn4<F, T1, T2, T3, T4> {
    /// The wrapped function, shared (via `Arc`) with every future produced by `call`.
    f: Arc<F>,
    _t: PhantomData<fn(T1, T2, T3, T4)>,
}
58
59struct HandlerFn5<F, T1, T2, T3, T4, T5> {
60    f: Arc<F>,
61    _t: PhantomData<HandlerMarker5<T1, T2, T3, T4, T5>>,
62}
63
64struct HandlerFn6<F, T1, T2, T3, T4, T5, T6> {
65    f: Arc<F>,
66    _t: PhantomData<HandlerMarker6<T1, T2, T3, T4, T5, T6>>,
67}
68
/// Implements [`Handler`] for a `HandlerFnN` wrapper whose function takes one argument per
/// listed type parameter, e.g. `impl_handler!(HandlerFn2, T1, T2)`.
///
/// The generated `call` runs each extractor in order. The first rejection becomes the
/// response; otherwise the wrapped function is called with every extracted value.
macro_rules! impl_handler {
    ($name:ident, $( $arg:ident ),* ) => {
        // Extracted values are bound to locals named after their type parameters
        // (`T1`, `T2`, ...), which are not snake_case.
        #[allow(non_snake_case)]
        impl<State, F, Fut, R, $( $arg ),*> Handler<State> for $name<F, $( $arg ),*>
        where
            State: Send + Sync + 'static,
            $( $arg: FromRequest<State> + Send + 'static, )*
            R: IntoResponse,
            Fut: Future<Output = R> + Send + 'static,
            F: Fn($( $arg ),*) -> Fut + Send + Sync + 'static,
        {
            fn call(&self, mut ctx: RequestContext, state: Arc<State>) -> BoxFuture<Response> {
                let handler = Arc::clone(&self.f);
                let fut = async move {
                    $(
                        let $arg = match <$arg as FromRequest<State>>::from_request(&mut ctx, &state).await {
                            Ok(extracted) => extracted,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*
                    handler($( $arg ),*).await.into_response()
                };
                Box::pin(fut)
            }
        }
    };
}
95
96impl<State, F, Fut, R> Handler<State> for HandlerFn0<F>
97where
98    F: Send + Sync + 'static + Fn() -> Fut,
99    Fut: Future<Output = R> + Send + 'static,
100    R: IntoResponse,
101    State: Send + Sync + 'static,
102{
103    fn call(&self, ctx: RequestContext, _state: Arc<State>) -> BoxFuture<Response> {
104        let f = Arc::clone(&self.f);
105        Box::pin(async move {
106            let _ = ctx;
107            f().await.into_response()
108        })
109    }
110}
111
112impl_handler!(HandlerFn1, T1);
113impl_handler!(HandlerFn2, T1, T2);
114impl_handler!(HandlerFn3, T1, T2, T3);
115impl_handler!(HandlerFn4, T1, T2, T3, T4);
116impl_handler!(HandlerFn5, T1, T2, T3, T4, T5);
117impl_handler!(HandlerFn6, T1, T2, T3, T4, T5, T6);
118
119impl<State, F, Fut, R> IntoHandler<State, ()> for F
120where
121    F: Send + Sync + 'static + Fn() -> Fut,
122    Fut: Future<Output = R> + Send + 'static,
123    R: IntoResponse,
124    State: Send + Sync + 'static,
125{
126    fn into_handler(self) -> Arc<dyn Handler<State>> {
127        Arc::new(HandlerFn0 { f: Arc::new(self) })
128    }
129}
130
131impl<State, F, Fut, R, T1> IntoHandler<State, (T1,)> for F
132where
133    F: Send + Sync + 'static + Fn(T1) -> Fut,
134    Fut: Future<Output = R> + Send + 'static,
135    R: IntoResponse,
136    T1: FromRequest<State> + Send + 'static,
137    State: Send + Sync + 'static,
138{
139    fn into_handler(self) -> Arc<dyn Handler<State>> {
140        Arc::new(HandlerFn1 {
141            f: Arc::new(self),
142            _t1: PhantomData,
143        })
144    }
145}
146
147impl<State, F, Fut, R, T1, T2> IntoHandler<State, (T1, T2)> for F
148where
149    F: Send + Sync + 'static + Fn(T1, T2) -> Fut,
150    Fut: Future<Output = R> + Send + 'static,
151    R: IntoResponse,
152    T1: FromRequest<State> + Send + 'static,
153    T2: FromRequest<State> + Send + 'static,
154    State: Send + Sync + 'static,
155{
156    fn into_handler(self) -> Arc<dyn Handler<State>> {
157        Arc::new(HandlerFn2 {
158            f: Arc::new(self),
159            _t: PhantomData,
160        })
161    }
162}
163
164impl<State, F, Fut, R, T1, T2, T3> IntoHandler<State, (T1, T2, T3)> for F
165where
166    F: Send + Sync + 'static + Fn(T1, T2, T3) -> Fut,
167    Fut: Future<Output = R> + Send + 'static,
168    R: IntoResponse,
169    T1: FromRequest<State> + Send + 'static,
170    T2: FromRequest<State> + Send + 'static,
171    T3: FromRequest<State> + Send + 'static,
172    State: Send + Sync + 'static,
173{
174    fn into_handler(self) -> Arc<dyn Handler<State>> {
175        Arc::new(HandlerFn3 {
176            f: Arc::new(self),
177            _t: PhantomData,
178        })
179    }
180}
181
182impl<State, F, Fut, R, T1, T2, T3, T4> IntoHandler<State, (T1, T2, T3, T4)> for F
183where
184    F: Send + Sync + 'static + Fn(T1, T2, T3, T4) -> Fut,
185    Fut: Future<Output = R> + Send + 'static,
186    R: IntoResponse,
187    T1: FromRequest<State> + Send + 'static,
188    T2: FromRequest<State> + Send + 'static,
189    T3: FromRequest<State> + Send + 'static,
190    T4: FromRequest<State> + Send + 'static,
191    State: Send + Sync + 'static,
192{
193    fn into_handler(self) -> Arc<dyn Handler<State>> {
194        Arc::new(HandlerFn4 {
195            f: Arc::new(self),
196            _t: PhantomData,
197        })
198    }
199}
200
201impl<State, F, Fut, R, T1, T2, T3, T4, T5> IntoHandler<State, (T1, T2, T3, T4, T5)> for F
202where
203    F: Send + Sync + 'static + Fn(T1, T2, T3, T4, T5) -> Fut,
204    Fut: Future<Output = R> + Send + 'static,
205    R: IntoResponse,
206    T1: FromRequest<State> + Send + 'static,
207    T2: FromRequest<State> + Send + 'static,
208    T3: FromRequest<State> + Send + 'static,
209    T4: FromRequest<State> + Send + 'static,
210    T5: FromRequest<State> + Send + 'static,
211    State: Send + Sync + 'static,
212{
213    fn into_handler(self) -> Arc<dyn Handler<State>> {
214        Arc::new(HandlerFn5 {
215            f: Arc::new(self),
216            _t: PhantomData,
217        })
218    }
219}
220
221impl<State, F, Fut, R, T1, T2, T3, T4, T5, T6> IntoHandler<State, (T1, T2, T3, T4, T5, T6)> for F
222where
223    F: Send + Sync + 'static + Fn(T1, T2, T3, T4, T5, T6) -> Fut,
224    Fut: Future<Output = R> + Send + 'static,
225    R: IntoResponse,
226    T1: FromRequest<State> + Send + 'static,
227    T2: FromRequest<State> + Send + 'static,
228    T3: FromRequest<State> + Send + 'static,
229    T4: FromRequest<State> + Send + 'static,
230    T5: FromRequest<State> + Send + 'static,
231    T6: FromRequest<State> + Send + 'static,
232    State: Send + Sync + 'static,
233{
234    fn into_handler(self) -> Arc<dyn Handler<State>> {
235        Arc::new(HandlerFn6 {
236            f: Arc::new(self),
237            _t: PhantomData,
238        })
239    }
240}
241
242/// Router mapping operations to handlers.
243pub struct Router<State> {
244    state: Arc<State>,
245    routes: HashMap<Op, Arc<dyn Handler<State>>>,
246    fallback: Arc<dyn Handler<State>>,
247}
248
249impl<State> Router<State>
250where
251    State: Send + Sync + 'static,
252{
253    pub fn from_state(state: State) -> Self {
254        Self {
255            state: Arc::new(state),
256            routes: HashMap::new(),
257            fallback: default_fallback(),
258        }
259    }
260
261    pub fn route<H, Args>(mut self, op: Op, handler: H) -> Self
262    where
263        H: IntoHandler<State, Args>,
264    {
265        self.routes.insert(op, handler.into_handler());
266        self
267    }
268
269    pub fn fallback<H, Args>(mut self, handler: H) -> Self
270    where
271        H: IntoHandler<State, Args>,
272    {
273        self.fallback = handler.into_handler();
274        self
275    }
276
277    pub fn state(&self) -> Arc<State> {
278        Arc::clone(&self.state)
279    }
280
281    pub async fn call(&self, ctx: RequestContext) -> Response {
282        let handler = self.routes.get(&ctx.request.op).unwrap_or(&self.fallback);
283        handler.call(ctx, Arc::clone(&self.state)).await
284    }
285}
286
287impl<State> Default for Router<State>
288where
289    State: Default + Send + Sync + 'static,
290{
291    fn default() -> Self {
292        Self::from_state(State::default())
293    }
294}
295
296fn default_fallback<State>() -> Arc<dyn Handler<State>>
297where
298    State: Send + Sync + 'static,
299{
300    Arc::new(FallbackHandler)
301}
302
303struct FallbackHandler;
304
305impl<State> Handler<State> for FallbackHandler
306where
307    State: Send + Sync + 'static,
308{
309    fn call(&self, ctx: RequestContext, _state: Arc<State>) -> BoxFuture<Response> {
310        Box::pin(async move {
311            match ctx.request.op {
312                Op::Noop | Op::MetaNoop | Op::Quit => Response::Noop,
313                _ => Response::Error(Error::unknown("unknown command")),
314            }
315        })
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Extensions, RequestContext};
    use crate::extract::Key;
    use crate::response::Stored;
    use crate::types::{Op, Protocol, ReplyMode, Request, RequestMeta};
    use bytes::Bytes;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Wraps `request` in a context with fixed ASCII metadata and loopback addresses, so
    /// each test only has to describe the request itself.
    fn context_for(request: Request) -> RequestContext {
        RequestContext {
            request,
            meta: RequestMeta {
                protocol: Protocol::Ascii,
                reply: ReplyMode::Always,
                opaque: None,
                return_key: false,
                opcode: 0,
            },
            peer_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1234),
            local_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 11211),
            client_id: 1,
            extensions: Extensions::default(),
        }
    }

    #[tokio::test]
    async fn router_dispatches_handler() {
        async fn get(Key(_key): Key) -> Stored {
            Stored
        }

        let router = Router::from_state(()).route(Op::Get, get);

        let mut req = Request::new(Op::Get);
        req.key = Some(Bytes::from_static(b"alpha"));

        let response = router.call(context_for(req)).await;
        assert!(matches!(response, Response::Stored));
    }

    #[tokio::test]
    async fn router_fallback_unknown() {
        let router = Router::from_state(());
        let response = router.call(context_for(Request::new(Op::Unknown))).await;
        assert!(matches!(response, Response::Error(_)));
    }

    #[tokio::test]
    async fn router_fallback_answers_noop() {
        let router = Router::from_state(());
        let response = router.call(context_for(Request::new(Op::Noop))).await;
        assert!(matches!(response, Response::Noop));
    }

    #[tokio::test]
    async fn router_custom_fallback_replaces_default() {
        async fn catch_all() -> Stored {
            Stored
        }

        let router = Router::from_state(()).fallback(catch_all);
        let response = router.call(context_for(Request::new(Op::Unknown))).await;
        assert!(matches!(response, Response::Stored));
    }
}
382}