1use std::collections::HashMap;
2use std::future::Future;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use crate::context::RequestContext;
8use crate::error::Error;
9use crate::response::{IntoResponse, Response};
10use crate::types::Op;
11
12pub trait FromRequest<State>: Sized {
14 type Rejection: IntoResponse;
15
16 fn from_request(
17 ctx: &mut RequestContext,
18 state: &Arc<State>,
19 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send;
20}
21
22pub trait Handler<State>: Send + Sync + 'static {
24 fn call(&self, ctx: RequestContext, state: Arc<State>) -> BoxFuture<Response>;
25}
26
/// Heap-allocated, pinned, `Send` future with a `'static` lifetime that
/// resolves to `T`.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;

/// Function-pointer marker over five argument types; used as the
/// `PhantomData` parameter of `HandlerFn5`.
type HandlerMarker5<T1, T2, T3, T4, T5> = fn(T1, T2, T3, T4, T5);

/// Function-pointer marker over six argument types; used as the
/// `PhantomData` parameter of `HandlerFn6`.
type HandlerMarker6<T1, T2, T3, T4, T5, T6> = fn(T1, T2, T3, T4, T5, T6);
30
31pub trait IntoHandler<State, Args>: Send + Sync + 'static {
32 fn into_handler(self) -> Arc<dyn Handler<State>>;
33}
34
/// Adapter that stores a zero-argument handler function behind an `Arc`.
struct HandlerFn0<F> {
    /// The wrapped function, shared so each call can clone a handle to it.
    f: Arc<F>,
}
38
/// Adapter that stores a one-argument handler function behind an `Arc`.
struct HandlerFn1<F, T1> {
    /// The wrapped function, shared so each call can clone a handle to it.
    f: Arc<F>,
    /// Records the argument type `T1` without storing a value of it.
    _t1: PhantomData<fn(T1)>,
}
43
/// Adapter that stores a two-argument handler function behind an `Arc`.
struct HandlerFn2<F, T1, T2> {
    /// The wrapped function, shared so each call can clone a handle to it.
    f: Arc<F>,
    /// Records the argument types without storing values of them.
    _t: PhantomData<fn(T1, T2)>,
}
48
/// Adapter that stores a three-argument handler function behind an `Arc`.
struct HandlerFn3<F, T1, T2, T3> {
    /// The wrapped function, shared so each call can clone a handle to it.
    f: Arc<F>,
    /// Records the argument types without storing values of them.
    _t: PhantomData<fn(T1, T2, T3)>,
}
53
/// Adapter that stores a four-argument handler function behind an `Arc`.
struct HandlerFn4<F, T1, T2, T3, T4> {
    /// The wrapped function, shared so each call can clone a handle to it.
    f: Arc<F>,
    /// Records the argument types without storing values of them.
    _t: PhantomData<fn(T1, T2, T3, T4)>,
}
58
59struct HandlerFn5<F, T1, T2, T3, T4, T5> {
60 f: Arc<F>,
61 _t: PhantomData<HandlerMarker5<T1, T2, T3, T4, T5>>,
62}
63
64struct HandlerFn6<F, T1, T2, T3, T4, T5, T6> {
65 f: Arc<F>,
66 _t: PhantomData<HandlerMarker6<T1, T2, T3, T4, T5, T6>>,
67}
68
/// Generates the [`Handler`] impl for one `HandlerFnN` adapter.
///
/// `$name` is the adapter struct and every `$ty` is one extractor type
/// parameter. The generated `call` runs the extractors in the order listed,
/// answers with the first rejection (converted to a response) if any
/// extraction fails, and otherwise awaits the wrapped function with the
/// extracted values and converts its output to a response.
macro_rules! impl_handler {
    ($name:ident, $( $ty:ident ),* ) => {
        // Each extracted value is bound to a local named after its type
        // parameter (`T1`, `T2`, ...), which is not snake_case.
        #[allow(non_snake_case)]
        impl<State, F, Fut, R, $( $ty ),*> Handler<State> for $name<F, $( $ty ),*>
        where
            State: Send + Sync + 'static,
            $( $ty: FromRequest<State> + Send + 'static, )*
            F: Fn($( $ty ),*) -> Fut + Send + Sync + 'static,
            Fut: Future<Output = R> + Send + 'static,
            R: IntoResponse,
        {
            fn call(&self, mut ctx: RequestContext, state: Arc<State>) -> BoxFuture<Response> {
                // Clone the shared function handle so the future owns it.
                let handler = Arc::clone(&self.f);
                Box::pin(async move {
                    $(
                        let $ty = match <$ty as FromRequest<State>>::from_request(&mut ctx, &state)
                            .await
                        {
                            // A failed extraction short-circuits the whole call.
                            Err(rejection) => return rejection.into_response(),
                            Ok(extracted) => extracted,
                        };
                    )*
                    let outcome = handler($( $ty ),*).await;
                    outcome.into_response()
                })
            }
        }
    };
}
95
96impl<State, F, Fut, R> Handler<State> for HandlerFn0<F>
97where
98 F: Send + Sync + 'static + Fn() -> Fut,
99 Fut: Future<Output = R> + Send + 'static,
100 R: IntoResponse,
101 State: Send + Sync + 'static,
102{
103 fn call(&self, ctx: RequestContext, _state: Arc<State>) -> BoxFuture<Response> {
104 let f = Arc::clone(&self.f);
105 Box::pin(async move {
106 let _ = ctx;
107 f().await.into_response()
108 })
109 }
110}
111
112impl_handler!(HandlerFn1, T1);
113impl_handler!(HandlerFn2, T1, T2);
114impl_handler!(HandlerFn3, T1, T2, T3);
115impl_handler!(HandlerFn4, T1, T2, T3, T4);
116impl_handler!(HandlerFn5, T1, T2, T3, T4, T5);
117impl_handler!(HandlerFn6, T1, T2, T3, T4, T5, T6);
118
119impl<State, F, Fut, R> IntoHandler<State, ()> for F
120where
121 F: Send + Sync + 'static + Fn() -> Fut,
122 Fut: Future<Output = R> + Send + 'static,
123 R: IntoResponse,
124 State: Send + Sync + 'static,
125{
126 fn into_handler(self) -> Arc<dyn Handler<State>> {
127 Arc::new(HandlerFn0 { f: Arc::new(self) })
128 }
129}
130
131impl<State, F, Fut, R, T1> IntoHandler<State, (T1,)> for F
132where
133 F: Send + Sync + 'static + Fn(T1) -> Fut,
134 Fut: Future<Output = R> + Send + 'static,
135 R: IntoResponse,
136 T1: FromRequest<State> + Send + 'static,
137 State: Send + Sync + 'static,
138{
139 fn into_handler(self) -> Arc<dyn Handler<State>> {
140 Arc::new(HandlerFn1 {
141 f: Arc::new(self),
142 _t1: PhantomData,
143 })
144 }
145}
146
147impl<State, F, Fut, R, T1, T2> IntoHandler<State, (T1, T2)> for F
148where
149 F: Send + Sync + 'static + Fn(T1, T2) -> Fut,
150 Fut: Future<Output = R> + Send + 'static,
151 R: IntoResponse,
152 T1: FromRequest<State> + Send + 'static,
153 T2: FromRequest<State> + Send + 'static,
154 State: Send + Sync + 'static,
155{
156 fn into_handler(self) -> Arc<dyn Handler<State>> {
157 Arc::new(HandlerFn2 {
158 f: Arc::new(self),
159 _t: PhantomData,
160 })
161 }
162}
163
164impl<State, F, Fut, R, T1, T2, T3> IntoHandler<State, (T1, T2, T3)> for F
165where
166 F: Send + Sync + 'static + Fn(T1, T2, T3) -> Fut,
167 Fut: Future<Output = R> + Send + 'static,
168 R: IntoResponse,
169 T1: FromRequest<State> + Send + 'static,
170 T2: FromRequest<State> + Send + 'static,
171 T3: FromRequest<State> + Send + 'static,
172 State: Send + Sync + 'static,
173{
174 fn into_handler(self) -> Arc<dyn Handler<State>> {
175 Arc::new(HandlerFn3 {
176 f: Arc::new(self),
177 _t: PhantomData,
178 })
179 }
180}
181
182impl<State, F, Fut, R, T1, T2, T3, T4> IntoHandler<State, (T1, T2, T3, T4)> for F
183where
184 F: Send + Sync + 'static + Fn(T1, T2, T3, T4) -> Fut,
185 Fut: Future<Output = R> + Send + 'static,
186 R: IntoResponse,
187 T1: FromRequest<State> + Send + 'static,
188 T2: FromRequest<State> + Send + 'static,
189 T3: FromRequest<State> + Send + 'static,
190 T4: FromRequest<State> + Send + 'static,
191 State: Send + Sync + 'static,
192{
193 fn into_handler(self) -> Arc<dyn Handler<State>> {
194 Arc::new(HandlerFn4 {
195 f: Arc::new(self),
196 _t: PhantomData,
197 })
198 }
199}
200
201impl<State, F, Fut, R, T1, T2, T3, T4, T5> IntoHandler<State, (T1, T2, T3, T4, T5)> for F
202where
203 F: Send + Sync + 'static + Fn(T1, T2, T3, T4, T5) -> Fut,
204 Fut: Future<Output = R> + Send + 'static,
205 R: IntoResponse,
206 T1: FromRequest<State> + Send + 'static,
207 T2: FromRequest<State> + Send + 'static,
208 T3: FromRequest<State> + Send + 'static,
209 T4: FromRequest<State> + Send + 'static,
210 T5: FromRequest<State> + Send + 'static,
211 State: Send + Sync + 'static,
212{
213 fn into_handler(self) -> Arc<dyn Handler<State>> {
214 Arc::new(HandlerFn5 {
215 f: Arc::new(self),
216 _t: PhantomData,
217 })
218 }
219}
220
221impl<State, F, Fut, R, T1, T2, T3, T4, T5, T6> IntoHandler<State, (T1, T2, T3, T4, T5, T6)> for F
222where
223 F: Send + Sync + 'static + Fn(T1, T2, T3, T4, T5, T6) -> Fut,
224 Fut: Future<Output = R> + Send + 'static,
225 R: IntoResponse,
226 T1: FromRequest<State> + Send + 'static,
227 T2: FromRequest<State> + Send + 'static,
228 T3: FromRequest<State> + Send + 'static,
229 T4: FromRequest<State> + Send + 'static,
230 T5: FromRequest<State> + Send + 'static,
231 T6: FromRequest<State> + Send + 'static,
232 State: Send + Sync + 'static,
233{
234 fn into_handler(self) -> Arc<dyn Handler<State>> {
235 Arc::new(HandlerFn6 {
236 f: Arc::new(self),
237 _t: PhantomData,
238 })
239 }
240}
241
242pub struct Router<State> {
244 state: Arc<State>,
245 routes: HashMap<Op, Arc<dyn Handler<State>>>,
246 fallback: Arc<dyn Handler<State>>,
247}
248
249impl<State> Router<State>
250where
251 State: Send + Sync + 'static,
252{
253 pub fn from_state(state: State) -> Self {
254 Self {
255 state: Arc::new(state),
256 routes: HashMap::new(),
257 fallback: default_fallback(),
258 }
259 }
260
261 pub fn route<H, Args>(mut self, op: Op, handler: H) -> Self
262 where
263 H: IntoHandler<State, Args>,
264 {
265 self.routes.insert(op, handler.into_handler());
266 self
267 }
268
269 pub fn fallback<H, Args>(mut self, handler: H) -> Self
270 where
271 H: IntoHandler<State, Args>,
272 {
273 self.fallback = handler.into_handler();
274 self
275 }
276
277 pub fn state(&self) -> Arc<State> {
278 Arc::clone(&self.state)
279 }
280
281 pub async fn call(&self, ctx: RequestContext) -> Response {
282 let handler = self.routes.get(&ctx.request.op).unwrap_or(&self.fallback);
283 handler.call(ctx, Arc::clone(&self.state)).await
284 }
285}
286
287impl<State> Default for Router<State>
288where
289 State: Default + Send + Sync + 'static,
290{
291 fn default() -> Self {
292 Self::from_state(State::default())
293 }
294}
295
296fn default_fallback<State>() -> Arc<dyn Handler<State>>
297where
298 State: Send + Sync + 'static,
299{
300 Arc::new(FallbackHandler)
301}
302
/// Stateless handler used when a request's operation has no registered route.
struct FallbackHandler;
304
305impl<State> Handler<State> for FallbackHandler
306where
307 State: Send + Sync + 'static,
308{
309 fn call(&self, ctx: RequestContext, _state: Arc<State>) -> BoxFuture<Response> {
310 Box::pin(async move {
311 match ctx.request.op {
312 Op::Noop | Op::MetaNoop | Op::Quit => Response::Noop,
313 _ => Response::Error(Error::unknown("unknown command")),
314 }
315 })
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Extensions, RequestContext};
    use crate::extract::Key;
    use crate::response::Stored;
    use crate::types::{Op, Protocol, ReplyMode, Request, RequestMeta};
    use bytes::Bytes;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Wraps `request` in a context with fixed ASCII metadata, loopback
    /// peer/local addresses and empty extensions.
    fn context_for(request: Request) -> RequestContext {
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        RequestContext {
            request,
            meta: RequestMeta {
                protocol: Protocol::Ascii,
                reply: ReplyMode::Always,
                opaque: None,
                return_key: false,
                opcode: 0,
            },
            peer_addr: SocketAddr::new(loopback, 1234),
            local_addr: SocketAddr::new(loopback, 11211),
            client_id: 1,
            extensions: Extensions::default(),
        }
    }

    /// A registered route is chosen for its operation and its extractor runs.
    #[tokio::test]
    async fn router_dispatches_handler() {
        async fn get(Key(_key): Key) -> Stored {
            Stored
        }

        let router = Router::from_state(()).route(Op::Get, get);

        let mut request = Request::new(Op::Get);
        request.key = Some(Bytes::from_static(b"alpha"));

        let response = router.call(context_for(request)).await;
        assert!(matches!(response, Response::Stored));
    }

    /// With no routes, an unknown operation falls through to the default
    /// fallback, which answers with an error.
    #[tokio::test]
    async fn router_fallback_unknown() {
        let router = Router::from_state(());
        let request = Request::new(Op::Unknown);

        let response = router.call(context_for(request)).await;
        assert!(matches!(response, Response::Error(_)));
    }
}