Skip to main content

resp_async/
router.rs

1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::future::Future;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use bytes::{BufMut, Bytes, BytesMut};
9
10use crate::context::{
11    ClientId, Cmd, Command, Extensions, LocalAddr, PeerAddr, PubSubHandle, PushHandle,
12    RequestContext, State as AppState,
13};
14use crate::resp::Value;
15use crate::response::{IntoResponse, RespError, Response};
16
/// Extract a typed value from a request context.
///
/// Implementors are the argument types a handler function may take. The handler adapters in
/// this module call [`FromRequest::from_request`] once per argument, in argument order, and
/// only invoke the handler if every extraction succeeds.
pub trait FromRequest<State>: Sized {
    /// Value returned when extraction fails; it is converted into the response that is sent
    /// instead of running the handler.
    type Rejection: IntoResponse;

    /// Build `Self` from the request context and the router's shared state.
    ///
    /// The returned future must be `Send` because handler futures are boxed as `Send`.
    fn from_request(
        ctx: &mut RequestContext,
        state: &Arc<State>,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send;
}
26
/// Handler for a RESP command.
///
/// This is the type-erased form stored in the router: every registered function or closure
/// is wrapped in an adapter implementing this trait.
pub trait Handler<State>: Send + Sync + 'static {
    /// Run the handler for one request and return the future producing its response.
    ///
    /// The future is `'static`, so it owns `ctx` and its own handle to `state`.
    fn call(&self, ctx: RequestContext, state: Arc<State>) -> BoxFuture<Response>;
}
31
/// Heap-allocated, type-erased future that can be sent across threads.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
/// Function-pointer marker naming the five extractor types of `HandlerFn5`.
type HandlerMarker5<T1, T2, T3, T4, T5> = fn(T1, T2, T3, T4, T5);
/// Function-pointer marker naming the six extractor types of `HandlerFn6`.
type HandlerMarker6<T1, T2, T3, T4, T5, T6> = fn(T1, T2, T3, T4, T5, T6);
35
/// Conversion from a plain async function or closure into a type-erased [`Handler`].
///
/// `Args` is a marker tuple of the handler's argument types (`()`, `(T1,)`, `(T1, T2)`, ...).
/// It distinguishes the blanket implementations for the different arities.
pub trait IntoHandler<State, Args>: Send + Sync + 'static {
    /// Wrap `self` in the adapter for its arity and erase its concrete type.
    fn into_handler(self) -> Arc<dyn Handler<State>>;
}
39
/// Adapter around a handler function that takes no arguments.
struct HandlerFn0<F> {
    f: Arc<F>,
}

/// Adapter around a handler function taking one extracted argument.
///
/// The `PhantomData<fn(..)>` marker records the argument type without storing a value of it.
struct HandlerFn1<F, T1> {
    f: Arc<F>,
    _t1: PhantomData<fn(T1)>,
}

/// Adapter around a handler function taking two extracted arguments.
struct HandlerFn2<F, T1, T2> {
    f: Arc<F>,
    _t: PhantomData<fn(T1, T2)>,
}

/// Adapter around a handler function taking three extracted arguments.
struct HandlerFn3<F, T1, T2, T3> {
    f: Arc<F>,
    _t: PhantomData<fn(T1, T2, T3)>,
}

/// Adapter around a handler function taking four extracted arguments.
struct HandlerFn4<F, T1, T2, T3, T4> {
    f: Arc<F>,
    _t: PhantomData<fn(T1, T2, T3, T4)>,
}

/// Adapter around a handler function taking five extracted arguments.
struct HandlerFn5<F, T1, T2, T3, T4, T5> {
    f: Arc<F>,
    _t: PhantomData<HandlerMarker5<T1, T2, T3, T4, T5>>,
}

/// Adapter around a handler function taking six extracted arguments.
struct HandlerFn6<F, T1, T2, T3, T4, T5, T6> {
    f: Arc<F>,
    _t: PhantomData<HandlerMarker6<T1, T2, T3, T4, T5, T6>>,
}
73
/// Implements [`Handler`] for one `HandlerFnN` adapter.
///
/// Takes the adapter name followed by its extractor type parameters. The generated `call`
/// extracts each argument in order, answers with the first rejection if one occurs, and
/// otherwise awaits the wrapped function and converts its output into a response.
macro_rules! impl_handler {
    ($name:ident, $( $ty:ident ),* ) => {
        // The type parameter names are reused as the local binding names below.
        #[allow(non_snake_case)]
        impl<State, F, Fut, R, $( $ty ),*> Handler<State> for $name<F, $( $ty ),*>
        where
            F: Send + Sync + 'static + Fn($( $ty ),*) -> Fut,
            Fut: Future<Output = R> + Send + 'static,
            R: IntoResponse,
            $( $ty: FromRequest<State> + Send + 'static, )*
            State: Send + Sync + 'static,
        {
            fn call(&self, mut ctx: RequestContext, state: Arc<State>) -> BoxFuture<Response> {
                let handler = Arc::clone(&self.f);
                let fut = async move {
                    log_handler_start(&ctx);
                    $(
                        let $ty = match $ty::from_request(&mut ctx, &state).await {
                            Ok(extracted) => extracted,
                            Err(rejection) => {
                                // A failed extractor short-circuits: its rejection is the reply.
                                let rejected = rejection.into_response();
                                log_handler_result(&ctx, &rejected);
                                return rejected;
                            }
                        };
                    )*

                    let response = handler($( $ty ),*).await.into_response();
                    log_handler_result(&ctx, &response);
                    response
                };
                Box::pin(fut)
            }
        }
    };
}
108
109impl<State, F, Fut, R> Handler<State> for HandlerFn0<F>
110where
111    F: Send + Sync + 'static + Fn() -> Fut,
112    Fut: Future<Output = R> + Send + 'static,
113    R: IntoResponse,
114    State: Send + Sync + 'static,
115{
116    fn call(&self, ctx: RequestContext, _state: Arc<State>) -> BoxFuture<Response> {
117        let f = Arc::clone(&self.f);
118        Box::pin(async move {
119            log_handler_start(&ctx);
120            let response = f().await.into_response();
121            log_handler_result(&ctx, &response);
122            response
123        })
124    }
125}
126
// Generate `Handler` implementations for the adapters of arity 1 through 6.
// The zero-argument adapter, `HandlerFn0`, has its own hand-written implementation.
impl_handler!(HandlerFn1, T1);
impl_handler!(HandlerFn2, T1, T2);
impl_handler!(HandlerFn3, T1, T2, T3);
impl_handler!(HandlerFn4, T1, T2, T3, T4);
impl_handler!(HandlerFn5, T1, T2, T3, T4, T5);
impl_handler!(HandlerFn6, T1, T2, T3, T4, T5, T6);
133
134impl<State, F, Fut, R> IntoHandler<State, ()> for F
135where
136    F: Send + Sync + 'static + Fn() -> Fut,
137    Fut: Future<Output = R> + Send + 'static,
138    R: IntoResponse,
139    State: Send + Sync + 'static,
140{
141    fn into_handler(self) -> Arc<dyn Handler<State>> {
142        Arc::new(HandlerFn0 { f: Arc::new(self) })
143    }
144}
145
146impl<State, F, Fut, R, T1> IntoHandler<State, (T1,)> for F
147where
148    F: Send + Sync + 'static + Fn(T1) -> Fut,
149    Fut: Future<Output = R> + Send + 'static,
150    R: IntoResponse,
151    T1: FromRequest<State> + Send + 'static,
152    State: Send + Sync + 'static,
153{
154    fn into_handler(self) -> Arc<dyn Handler<State>> {
155        Arc::new(HandlerFn1 {
156            f: Arc::new(self),
157            _t1: PhantomData,
158        })
159    }
160}
161
162impl<State, F, Fut, R, T1, T2> IntoHandler<State, (T1, T2)> for F
163where
164    F: Send + Sync + 'static + Fn(T1, T2) -> Fut,
165    Fut: Future<Output = R> + Send + 'static,
166    R: IntoResponse,
167    T1: FromRequest<State> + Send + 'static,
168    T2: FromRequest<State> + Send + 'static,
169    State: Send + Sync + 'static,
170{
171    fn into_handler(self) -> Arc<dyn Handler<State>> {
172        Arc::new(HandlerFn2 {
173            f: Arc::new(self),
174            _t: PhantomData,
175        })
176    }
177}
178
179impl<State, F, Fut, R, T1, T2, T3> IntoHandler<State, (T1, T2, T3)> for F
180where
181    F: Send + Sync + 'static + Fn(T1, T2, T3) -> Fut,
182    Fut: Future<Output = R> + Send + 'static,
183    R: IntoResponse,
184    T1: FromRequest<State> + Send + 'static,
185    T2: FromRequest<State> + Send + 'static,
186    T3: FromRequest<State> + Send + 'static,
187    State: Send + Sync + 'static,
188{
189    fn into_handler(self) -> Arc<dyn Handler<State>> {
190        Arc::new(HandlerFn3 {
191            f: Arc::new(self),
192            _t: PhantomData,
193        })
194    }
195}
196
197impl<State, F, Fut, R, T1, T2, T3, T4> IntoHandler<State, (T1, T2, T3, T4)> for F
198where
199    F: Send + Sync + 'static + Fn(T1, T2, T3, T4) -> Fut,
200    Fut: Future<Output = R> + Send + 'static,
201    R: IntoResponse,
202    T1: FromRequest<State> + Send + 'static,
203    T2: FromRequest<State> + Send + 'static,
204    T3: FromRequest<State> + Send + 'static,
205    T4: FromRequest<State> + Send + 'static,
206    State: Send + Sync + 'static,
207{
208    fn into_handler(self) -> Arc<dyn Handler<State>> {
209        Arc::new(HandlerFn4 {
210            f: Arc::new(self),
211            _t: PhantomData,
212        })
213    }
214}
215
216impl<State, F, Fut, R, T1, T2, T3, T4, T5> IntoHandler<State, (T1, T2, T3, T4, T5)> for F
217where
218    F: Send + Sync + 'static + Fn(T1, T2, T3, T4, T5) -> Fut,
219    Fut: Future<Output = R> + Send + 'static,
220    R: IntoResponse,
221    T1: FromRequest<State> + Send + 'static,
222    T2: FromRequest<State> + Send + 'static,
223    T3: FromRequest<State> + Send + 'static,
224    T4: FromRequest<State> + Send + 'static,
225    T5: FromRequest<State> + Send + 'static,
226    State: Send + Sync + 'static,
227{
228    fn into_handler(self) -> Arc<dyn Handler<State>> {
229        Arc::new(HandlerFn5 {
230            f: Arc::new(self),
231            _t: PhantomData,
232        })
233    }
234}
235
236impl<State, F, Fut, R, T1, T2, T3, T4, T5, T6> IntoHandler<State, (T1, T2, T3, T4, T5, T6)> for F
237where
238    F: Send + Sync + 'static + Fn(T1, T2, T3, T4, T5, T6) -> Fut,
239    Fut: Future<Output = R> + Send + 'static,
240    R: IntoResponse,
241    T1: FromRequest<State> + Send + 'static,
242    T2: FromRequest<State> + Send + 'static,
243    T3: FromRequest<State> + Send + 'static,
244    T4: FromRequest<State> + Send + 'static,
245    T5: FromRequest<State> + Send + 'static,
246    T6: FromRequest<State> + Send + 'static,
247    State: Send + Sync + 'static,
248{
249    fn into_handler(self) -> Arc<dyn Handler<State>> {
250        Arc::new(HandlerFn6 {
251            f: Arc::new(self),
252            _t: PhantomData,
253        })
254    }
255}
256
/// Router mapping command names to handlers.
///
/// The routes and state live behind a single `Arc`, so cloning a router shares them rather
/// than copying them.
pub struct Router<State = ()> {
    inner: Arc<RouterInner<State>>,
}
261
262impl<State> Clone for Router<State> {
263    fn clone(&self) -> Self {
264        Self {
265            inner: Arc::clone(&self.inner),
266        }
267    }
268}
269
270impl<State> Default for Router<State>
271where
272    State: Default + Send + Sync + 'static,
273{
274    fn default() -> Self {
275        Self::new()
276    }
277}
278
/// Shared contents of a [`Router`].
struct RouterInner<State> {
    /// Application state; each handler invocation receives its own `Arc` handle to it.
    state: Arc<State>,
    /// Handlers keyed by command name as produced by `normalize_command_key`.
    routes: HashMap<Bytes, Arc<dyn Handler<State>>>,
}
283
284impl<State> Router<State>
285where
286    State: Default + Send + Sync + 'static,
287{
288    /// Create a new router using `State::default()`.
289    pub fn new() -> Self {
290        Self {
291            inner: Arc::new(RouterInner {
292                state: Arc::new(State::default()),
293                routes: HashMap::new(),
294            }),
295        }
296    }
297}
298
299impl<State> Router<State>
300where
301    State: Send + Sync + 'static,
302{
303    /// Create a new router with the provided state.
304    pub fn from_state(state: State) -> Self {
305        Self {
306            inner: Arc::new(RouterInner {
307                state: Arc::new(state),
308                routes: HashMap::new(),
309            }),
310        }
311    }
312
313    /// Replace the state while keeping existing routes.
314    pub fn with_state(self, state: State) -> Self {
315        let mut inner = self.into_inner();
316        inner.state = Arc::new(state);
317        Self {
318            inner: Arc::new(inner),
319        }
320    }
321
322    /// Register a command handler.
323    ///
324    /// The handler must be an async function or `Fn` that returns a `Send + 'static` future.
325    pub fn route<H, Args>(self, command: &'static str, handler: H) -> Self
326    where
327        H: IntoHandler<State, Args>,
328    {
329        let mut inner = self.into_inner();
330        inner
331            .routes
332            .insert(normalize_command_key(command), handler.into_handler());
333        Self {
334            inner: Arc::new(inner),
335        }
336    }
337
338    pub(crate) fn state(&self) -> Arc<State> {
339        Arc::clone(&self.inner.state)
340    }
341
342    pub(crate) fn call(&self, ctx: RequestContext) -> BoxFuture<Response> {
343        let Some(handler) = self.inner.routes.get(&ctx.command.name_upper).cloned() else {
344            return Box::pin(async move {
345                RespError::invalid_data(format!(
346                    "ERR unknown command '{}'",
347                    display_command_name(&ctx.command.name)
348                ))
349                .into_response()
350            });
351        };
352        handler.call(ctx, self.state())
353    }
354
355    fn into_inner(self) -> RouterInner<State> {
356        Arc::try_unwrap(self.inner).unwrap_or_else(|arc| RouterInner {
357            state: Arc::clone(&arc.state),
358            routes: arc.routes.clone(),
359        })
360    }
361}
362
363fn normalize_command_key(command: &str) -> Bytes {
364    let bytes = command.as_bytes();
365    let mut needs = false;
366    for &b in bytes {
367        if b.is_ascii_lowercase() {
368            needs = true;
369            break;
370        }
371    }
372    if !needs {
373        return Bytes::copy_from_slice(command.as_bytes());
374    }
375    let mut buf = BytesMut::with_capacity(bytes.len());
376    for &b in bytes {
377        buf.put_u8(b.to_ascii_uppercase());
378    }
379    buf.freeze()
380}
381
/// Render a command name for error replies and log lines.
///
/// Delegates to `display_bytes`.
fn display_command_name(bytes: &Bytes) -> String {
    display_bytes(bytes)
}
385
386fn display_bytes(bytes: &Bytes) -> String {
387    match std::str::from_utf8(bytes) {
388        Ok(s) => s.to_owned(),
389        Err(_) => format!("0x{}", hex_bytes(bytes)),
390    }
391}
392
393fn hex_bytes(bytes: &Bytes) -> String {
394    const HEX: &[u8; 16] = b"0123456789abcdef";
395    let mut out = String::with_capacity(bytes.len() * 2);
396    for &b in bytes.iter() {
397        out.push(HEX[(b >> 4) as usize] as char);
398        out.push(HEX[(b & 0x0f) as usize] as char);
399    }
400    out
401}
402
403fn log_handler_start(ctx: &RequestContext) {
404    if log::log_enabled!(log::Level::Debug) {
405        let name = display_command_name(&ctx.command.name_upper);
406        log::debug!(
407            target: "handler",
408            "start id={} cmd={} args={}",
409            ctx.client_id,
410            name,
411            ctx.command.args.len()
412        );
413    }
414}
415
416fn log_handler_result(ctx: &RequestContext, response: &Response) {
417    if !log::log_enabled!(log::Level::Debug) {
418        return;
419    }
420    if let Value::Error(msg) = response {
421        let name = display_command_name(&ctx.command.name_upper);
422        let detail = display_bytes(msg);
423        log::debug!(
424            target: "handler",
425            "error id={} cmd={} msg={}",
426            ctx.client_id,
427            name,
428            detail
429        );
430    }
431}
432
433impl<State> FromRequest<State> for Cmd
434where
435    State: Send + Sync + 'static,
436{
437    type Rejection = Infallible;
438
439    async fn from_request(
440        ctx: &mut RequestContext,
441        _state: &Arc<State>,
442    ) -> Result<Self, Self::Rejection> {
443        Ok(Cmd(ctx.command.clone()))
444    }
445}
446
447impl<T> FromRequest<T> for AppState<T>
448where
449    T: Send + Sync + 'static,
450{
451    type Rejection = Infallible;
452
453    async fn from_request(
454        _ctx: &mut RequestContext,
455        state: &Arc<T>,
456    ) -> Result<Self, Self::Rejection> {
457        Ok(AppState(Arc::clone(state)))
458    }
459}
460
461impl<State> FromRequest<State> for PeerAddr
462where
463    State: Send + Sync + 'static,
464{
465    type Rejection = Infallible;
466
467    async fn from_request(
468        ctx: &mut RequestContext,
469        _state: &Arc<State>,
470    ) -> Result<Self, Self::Rejection> {
471        Ok(PeerAddr(ctx.peer_addr))
472    }
473}
474
475impl<State> FromRequest<State> for LocalAddr
476where
477    State: Send + Sync + 'static,
478{
479    type Rejection = Infallible;
480
481    async fn from_request(
482        ctx: &mut RequestContext,
483        _state: &Arc<State>,
484    ) -> Result<Self, Self::Rejection> {
485        Ok(LocalAddr(ctx.local_addr))
486    }
487}
488
489impl<State> FromRequest<State> for ClientId
490where
491    State: Send + Sync + 'static,
492{
493    type Rejection = Infallible;
494
495    async fn from_request(
496        ctx: &mut RequestContext,
497        _state: &Arc<State>,
498    ) -> Result<Self, Self::Rejection> {
499        Ok(ClientId(ctx.client_id))
500    }
501}
502
503impl<State> FromRequest<State> for Extensions
504where
505    State: Send + Sync + 'static,
506{
507    type Rejection = Infallible;
508
509    async fn from_request(
510        ctx: &mut RequestContext,
511        _state: &Arc<State>,
512    ) -> Result<Self, Self::Rejection> {
513        Ok(ctx.extensions.clone())
514    }
515}
516
517impl<State> FromRequest<State> for PushHandle
518where
519    State: Send + Sync + 'static,
520{
521    type Rejection = Infallible;
522
523    async fn from_request(
524        ctx: &mut RequestContext,
525        _state: &Arc<State>,
526    ) -> Result<Self, Self::Rejection> {
527        Ok(ctx.push.clone())
528    }
529}
530
531impl<State> FromRequest<State> for PubSubHandle
532where
533    State: Send + Sync + 'static,
534{
535    type Rejection = Infallible;
536
537    async fn from_request(
538        ctx: &mut RequestContext,
539        _state: &Arc<State>,
540    ) -> Result<Self, Self::Rejection> {
541        Ok(ctx.pubsub.clone())
542    }
543}
544
545impl<State> FromRequest<State> for Command
546where
547    State: Send + Sync + 'static,
548{
549    type Rejection = Infallible;
550
551    async fn from_request(
552        ctx: &mut RequestContext,
553        _state: &Arc<State>,
554    ) -> Result<Self, Self::Rejection> {
555        Ok(ctx.command.clone())
556    }
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    use crate::Value;
    use bytes::Bytes;
    use tokio::sync::mpsc;

    /// Build a request context around `cmd` with fixed loopback addresses and client id 1.
    fn make_ctx(cmd: Command) -> RequestContext {
        // The receiving halves are dropped here; these tests never push anything.
        let (push_tx, _push_rx) = mpsc::channel(1);
        let (close_tx, _close_rx) = mpsc::channel(1);
        let push = PushHandle::new(push_tx, close_tx);
        let pubsub = PubSubHandle::new(Arc::new(AtomicUsize::new(0)));
        RequestContext {
            command: cmd,
            peer_addr: "127.0.0.1:1".parse().unwrap(),
            local_addr: "127.0.0.1:2".parse().unwrap(),
            client_id: 1,
            extensions: Extensions::default(),
            push,
            pubsub,
        }
    }

    /// Build an argument-less command with the given name.
    fn command(name: &'static [u8]) -> Command {
        Command::new(Bytes::from_static(name), Vec::new())
    }

    /// Zero-argument handler that always answers `PONG`.
    async fn ping() -> Value {
        Value::Simple(Bytes::from_static(b"PONG"))
    }

    #[tokio::test]
    async fn route_dispatches() {
        let router: Router<()> = Router::new().route("PING", ping);
        let reply = router.call(make_ctx(command(b"PING"))).await;
        assert_eq!(reply, Value::Simple(Bytes::from_static(b"PONG")));
    }

    #[tokio::test]
    async fn unknown_command_returns_error() {
        let router: Router<()> = Router::new();
        let reply = router.call(make_ctx(command(b"NOPE"))).await;
        assert!(matches!(reply, Value::Error(_)));
    }

    #[tokio::test]
    async fn route_accepts_capturing_closure() {
        let pong = Bytes::from_static(b"PONG");
        let handler = move || {
            let pong = pong.clone();
            async move { Value::Simple(pong) }
        };

        let router: Router<()> = Router::new().route("PING", handler);
        let reply = router.call(make_ctx(command(b"PING"))).await;
        assert_eq!(reply, Value::Simple(Bytes::from_static(b"PONG")));
    }

    #[tokio::test]
    async fn state_extractor_works() {
        async fn handler(AppState(state): AppState<u64>) -> Value {
            Value::Integer(*state as i64)
        }

        let router = Router::from_state(5u64).route("GET", handler);
        let reply = router.call(make_ctx(command(b"GET"))).await;
        assert_eq!(reply, Value::Integer(5));
    }
}