1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::future::Future;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use bytes::{BufMut, Bytes, BytesMut};
9
10use crate::context::{
11 ClientId, Cmd, Command, Extensions, LocalAddr, PeerAddr, PubSubHandle, PushHandle,
12 RequestContext, State as AppState,
13};
14use crate::resp::Value;
15use crate::response::{IntoResponse, RespError, Response};
16
17pub trait FromRequest<State>: Sized {
19 type Rejection: IntoResponse;
20
21 fn from_request(
22 ctx: &mut RequestContext,
23 state: &Arc<State>,
24 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send;
25}
26
27pub trait Handler<State>: Send + Sync + 'static {
29 fn call(&self, ctx: RequestContext, state: Arc<State>) -> BoxFuture<Response>;
30}
31
32type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
33type HandlerMarker5<T1, T2, T3, T4, T5> = fn(T1, T2, T3, T4, T5);
34type HandlerMarker6<T1, T2, T3, T4, T5, T6> = fn(T1, T2, T3, T4, T5, T6);
35
36pub trait IntoHandler<State, Args>: Send + Sync + 'static {
37 fn into_handler(self) -> Arc<dyn Handler<State>>;
38}
39
/// Wraps a zero-argument handler function so it can implement [`Handler`].
struct HandlerFn0<F> {
    // Shared so each `call` can move a cheap clone into its `'static` future.
    f: Arc<F>,
}

/// Wraps a handler function taking one extractor of type `T1`.
///
/// The `fn(T1)` phantom records the argument type without storing a `T1`.
/// Function-pointer types are always `Send + Sync`, so the wrapper's auto
/// traits depend only on `F`.
// NOTE(review): this is the only wrapper whose marker field is `_t1` rather
// than `_t`; the `IntoHandler` impl for one argument depends on that name.
struct HandlerFn1<F, T1> {
    f: Arc<F>,
    _t1: PhantomData<fn(T1)>,
}

/// Wraps a handler function taking extractors `T1, T2`.
struct HandlerFn2<F, T1, T2> {
    f: Arc<F>,
    _t: PhantomData<fn(T1, T2)>,
}

/// Wraps a handler function taking extractors `T1..T3`.
struct HandlerFn3<F, T1, T2, T3> {
    f: Arc<F>,
    _t: PhantomData<fn(T1, T2, T3)>,
}

/// Wraps a handler function taking extractors `T1..T4`.
struct HandlerFn4<F, T1, T2, T3, T4> {
    f: Arc<F>,
    _t: PhantomData<fn(T1, T2, T3, T4)>,
}

/// Wraps a handler function taking extractors `T1..T5`.
// The marker uses the `HandlerMarker5` alias for the `fn(T1, .., T5)` type.
struct HandlerFn5<F, T1, T2, T3, T4, T5> {
    f: Arc<F>,
    _t: PhantomData<HandlerMarker5<T1, T2, T3, T4, T5>>,
}

/// Wraps a handler function taking extractors `T1..T6`.
// The marker uses the `HandlerMarker6` alias for the `fn(T1, .., T6)` type.
struct HandlerFn6<F, T1, T2, T3, T4, T5, T6> {
    f: Arc<F>,
    _t: PhantomData<HandlerMarker6<T1, T2, T3, T4, T5, T6>>,
}
73
/// Generates a [`Handler`] impl for `$name`, a `HandlerFnN` wrapper around a
/// function taking the extractor types `$ty...`.
///
/// The generated `call`:
/// 1. logs the start of the request via `log_handler_start`;
/// 2. runs each extractor's `FromRequest::from_request` in argument order;
///    on the first `Err`, converts the rejection into a response, logs it via
///    `log_handler_result` and returns it without calling the function;
/// 3. calls the wrapped function with the extracted values, converts its
///    output with `IntoResponse`, logs the result and returns it.
macro_rules! impl_handler {
    ($name:ident, $( $ty:ident ),* ) => {
        // Extracted values are bound to variables that reuse the type
        // parameter names (`let T1 = ...`). Values and types live in separate
        // namespaces, so this compiles, but it trips `non_snake_case`.
        #[allow(non_snake_case)]
        impl<State, F, Fut, R, $( $ty ),*> Handler<State> for $name<F, $( $ty ),*>
        where
            F: Send + Sync + 'static + Fn($( $ty ),*) -> Fut,
            Fut: Future<Output = R> + Send + 'static,
            R: IntoResponse,
            $( $ty: FromRequest<State> + Send + 'static, )*
            State: Send + Sync + 'static,
        {
            fn call(&self, mut ctx: RequestContext, state: Arc<State>) -> BoxFuture<Response> {
                // Clone the Arc so the `'static` boxed future owns its own handle to `f`.
                let f = Arc::clone(&self.f);
                Box::pin(async move {
                    log_handler_start(&ctx);
                    $(
                        let $ty = match $ty::from_request(&mut ctx, &state).await {
                            Ok(value) => value,
                            Err(rejection) => {
                                // Short-circuit: the rejection becomes the reply.
                                let response = rejection.into_response();
                                log_handler_result(&ctx, &response);
                                return response;
                            }
                        };
                    )*

                    let response = f($( $ty ),*).await.into_response();
                    log_handler_result(&ctx, &response);
                    response
                })
            }
        }
    };
}
108
109impl<State, F, Fut, R> Handler<State> for HandlerFn0<F>
110where
111 F: Send + Sync + 'static + Fn() -> Fut,
112 Fut: Future<Output = R> + Send + 'static,
113 R: IntoResponse,
114 State: Send + Sync + 'static,
115{
116 fn call(&self, ctx: RequestContext, _state: Arc<State>) -> BoxFuture<Response> {
117 let f = Arc::clone(&self.f);
118 Box::pin(async move {
119 log_handler_start(&ctx);
120 let response = f().await.into_response();
121 log_handler_result(&ctx, &response);
122 response
123 })
124 }
125}
126
127impl_handler!(HandlerFn1, T1);
128impl_handler!(HandlerFn2, T1, T2);
129impl_handler!(HandlerFn3, T1, T2, T3);
130impl_handler!(HandlerFn4, T1, T2, T3, T4);
131impl_handler!(HandlerFn5, T1, T2, T3, T4, T5);
132impl_handler!(HandlerFn6, T1, T2, T3, T4, T5, T6);
133
134impl<State, F, Fut, R> IntoHandler<State, ()> for F
135where
136 F: Send + Sync + 'static + Fn() -> Fut,
137 Fut: Future<Output = R> + Send + 'static,
138 R: IntoResponse,
139 State: Send + Sync + 'static,
140{
141 fn into_handler(self) -> Arc<dyn Handler<State>> {
142 Arc::new(HandlerFn0 { f: Arc::new(self) })
143 }
144}
145
146impl<State, F, Fut, R, T1> IntoHandler<State, (T1,)> for F
147where
148 F: Send + Sync + 'static + Fn(T1) -> Fut,
149 Fut: Future<Output = R> + Send + 'static,
150 R: IntoResponse,
151 T1: FromRequest<State> + Send + 'static,
152 State: Send + Sync + 'static,
153{
154 fn into_handler(self) -> Arc<dyn Handler<State>> {
155 Arc::new(HandlerFn1 {
156 f: Arc::new(self),
157 _t1: PhantomData,
158 })
159 }
160}
161
162impl<State, F, Fut, R, T1, T2> IntoHandler<State, (T1, T2)> for F
163where
164 F: Send + Sync + 'static + Fn(T1, T2) -> Fut,
165 Fut: Future<Output = R> + Send + 'static,
166 R: IntoResponse,
167 T1: FromRequest<State> + Send + 'static,
168 T2: FromRequest<State> + Send + 'static,
169 State: Send + Sync + 'static,
170{
171 fn into_handler(self) -> Arc<dyn Handler<State>> {
172 Arc::new(HandlerFn2 {
173 f: Arc::new(self),
174 _t: PhantomData,
175 })
176 }
177}
178
179impl<State, F, Fut, R, T1, T2, T3> IntoHandler<State, (T1, T2, T3)> for F
180where
181 F: Send + Sync + 'static + Fn(T1, T2, T3) -> Fut,
182 Fut: Future<Output = R> + Send + 'static,
183 R: IntoResponse,
184 T1: FromRequest<State> + Send + 'static,
185 T2: FromRequest<State> + Send + 'static,
186 T3: FromRequest<State> + Send + 'static,
187 State: Send + Sync + 'static,
188{
189 fn into_handler(self) -> Arc<dyn Handler<State>> {
190 Arc::new(HandlerFn3 {
191 f: Arc::new(self),
192 _t: PhantomData,
193 })
194 }
195}
196
197impl<State, F, Fut, R, T1, T2, T3, T4> IntoHandler<State, (T1, T2, T3, T4)> for F
198where
199 F: Send + Sync + 'static + Fn(T1, T2, T3, T4) -> Fut,
200 Fut: Future<Output = R> + Send + 'static,
201 R: IntoResponse,
202 T1: FromRequest<State> + Send + 'static,
203 T2: FromRequest<State> + Send + 'static,
204 T3: FromRequest<State> + Send + 'static,
205 T4: FromRequest<State> + Send + 'static,
206 State: Send + Sync + 'static,
207{
208 fn into_handler(self) -> Arc<dyn Handler<State>> {
209 Arc::new(HandlerFn4 {
210 f: Arc::new(self),
211 _t: PhantomData,
212 })
213 }
214}
215
216impl<State, F, Fut, R, T1, T2, T3, T4, T5> IntoHandler<State, (T1, T2, T3, T4, T5)> for F
217where
218 F: Send + Sync + 'static + Fn(T1, T2, T3, T4, T5) -> Fut,
219 Fut: Future<Output = R> + Send + 'static,
220 R: IntoResponse,
221 T1: FromRequest<State> + Send + 'static,
222 T2: FromRequest<State> + Send + 'static,
223 T3: FromRequest<State> + Send + 'static,
224 T4: FromRequest<State> + Send + 'static,
225 T5: FromRequest<State> + Send + 'static,
226 State: Send + Sync + 'static,
227{
228 fn into_handler(self) -> Arc<dyn Handler<State>> {
229 Arc::new(HandlerFn5 {
230 f: Arc::new(self),
231 _t: PhantomData,
232 })
233 }
234}
235
236impl<State, F, Fut, R, T1, T2, T3, T4, T5, T6> IntoHandler<State, (T1, T2, T3, T4, T5, T6)> for F
237where
238 F: Send + Sync + 'static + Fn(T1, T2, T3, T4, T5, T6) -> Fut,
239 Fut: Future<Output = R> + Send + 'static,
240 R: IntoResponse,
241 T1: FromRequest<State> + Send + 'static,
242 T2: FromRequest<State> + Send + 'static,
243 T3: FromRequest<State> + Send + 'static,
244 T4: FromRequest<State> + Send + 'static,
245 T5: FromRequest<State> + Send + 'static,
246 T6: FromRequest<State> + Send + 'static,
247 State: Send + Sync + 'static,
248{
249 fn into_handler(self) -> Arc<dyn Handler<State>> {
250 Arc::new(HandlerFn6 {
251 f: Arc::new(self),
252 _t: PhantomData,
253 })
254 }
255}
256
257pub struct Router<State = ()> {
259 inner: Arc<RouterInner<State>>,
260}
261
262impl<State> Clone for Router<State> {
263 fn clone(&self) -> Self {
264 Self {
265 inner: Arc::clone(&self.inner),
266 }
267 }
268}
269
270impl<State> Default for Router<State>
271where
272 State: Default + Send + Sync + 'static,
273{
274 fn default() -> Self {
275 Self::new()
276 }
277}
278
279struct RouterInner<State> {
280 state: Arc<State>,
281 routes: HashMap<Bytes, Arc<dyn Handler<State>>>,
282}
283
284impl<State> Router<State>
285where
286 State: Default + Send + Sync + 'static,
287{
288 pub fn new() -> Self {
290 Self {
291 inner: Arc::new(RouterInner {
292 state: Arc::new(State::default()),
293 routes: HashMap::new(),
294 }),
295 }
296 }
297}
298
299impl<State> Router<State>
300where
301 State: Send + Sync + 'static,
302{
303 pub fn from_state(state: State) -> Self {
305 Self {
306 inner: Arc::new(RouterInner {
307 state: Arc::new(state),
308 routes: HashMap::new(),
309 }),
310 }
311 }
312
313 pub fn with_state(self, state: State) -> Self {
315 let mut inner = self.into_inner();
316 inner.state = Arc::new(state);
317 Self {
318 inner: Arc::new(inner),
319 }
320 }
321
322 pub fn route<H, Args>(self, command: &'static str, handler: H) -> Self
326 where
327 H: IntoHandler<State, Args>,
328 {
329 let mut inner = self.into_inner();
330 inner
331 .routes
332 .insert(normalize_command_key(command), handler.into_handler());
333 Self {
334 inner: Arc::new(inner),
335 }
336 }
337
338 pub(crate) fn state(&self) -> Arc<State> {
339 Arc::clone(&self.inner.state)
340 }
341
342 pub(crate) fn call(&self, ctx: RequestContext) -> BoxFuture<Response> {
343 let Some(handler) = self.inner.routes.get(&ctx.command.name_upper).cloned() else {
344 return Box::pin(async move {
345 RespError::invalid_data(format!(
346 "ERR unknown command '{}'",
347 display_command_name(&ctx.command.name)
348 ))
349 .into_response()
350 });
351 };
352 handler.call(ctx, self.state())
353 }
354
355 fn into_inner(self) -> RouterInner<State> {
356 Arc::try_unwrap(self.inner).unwrap_or_else(|arc| RouterInner {
357 state: Arc::clone(&arc.state),
358 routes: arc.routes.clone(),
359 })
360 }
361}
362
363fn normalize_command_key(command: &str) -> Bytes {
364 let bytes = command.as_bytes();
365 let mut needs = false;
366 for &b in bytes {
367 if b.is_ascii_lowercase() {
368 needs = true;
369 break;
370 }
371 }
372 if !needs {
373 return Bytes::copy_from_slice(command.as_bytes());
374 }
375 let mut buf = BytesMut::with_capacity(bytes.len());
376 for &b in bytes {
377 buf.put_u8(b.to_ascii_uppercase());
378 }
379 buf.freeze()
380}
381
/// Renders a client-supplied command name for error replies and logs.
///
/// The name comes straight from the client and may contain CR/LF. Echoing it
/// verbatim into a RESP simple error (`-ERR ...\r\n`) would corrupt the
/// reply framing, so line breaks are replaced with spaces (as Redis does).
/// Non-UTF-8 names are shown as `0x`-prefixed hex.
fn display_command_name(bytes: &[u8]) -> String {
    let name = display_bytes(bytes);
    if name.contains(['\r', '\n']) {
        name.replace(['\r', '\n'], " ")
    } else {
        name
    }
}

/// Returns `bytes` as a `String` if they are valid UTF-8, otherwise as
/// `0x` followed by lowercase hex digits.
///
/// Accepts any byte slice; `&Bytes` arguments coerce via `Deref`.
fn display_bytes(bytes: &[u8]) -> String {
    match std::str::from_utf8(bytes) {
        Ok(s) => s.to_owned(),
        Err(_) => format!("0x{}", hex_bytes(bytes)),
    }
}

/// Encodes `bytes` as lowercase hex, two digits per byte.
fn hex_bytes(bytes: &[u8]) -> String {
    const HEX: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(HEX[(b >> 4) as usize] as char);
        out.push(HEX[(b & 0x0f) as usize] as char);
    }
    out
}
402
403fn log_handler_start(ctx: &RequestContext) {
404 if log::log_enabled!(log::Level::Debug) {
405 let name = display_command_name(&ctx.command.name_upper);
406 log::debug!(
407 target: "handler",
408 "start id={} cmd={} args={}",
409 ctx.client_id,
410 name,
411 ctx.command.args.len()
412 );
413 }
414}
415
416fn log_handler_result(ctx: &RequestContext, response: &Response) {
417 if !log::log_enabled!(log::Level::Debug) {
418 return;
419 }
420 if let Value::Error(msg) = response {
421 let name = display_command_name(&ctx.command.name_upper);
422 let detail = display_bytes(msg);
423 log::debug!(
424 target: "handler",
425 "error id={} cmd={} msg={}",
426 ctx.client_id,
427 name,
428 detail
429 );
430 }
431}
432
433impl<State> FromRequest<State> for Cmd
434where
435 State: Send + Sync + 'static,
436{
437 type Rejection = Infallible;
438
439 async fn from_request(
440 ctx: &mut RequestContext,
441 _state: &Arc<State>,
442 ) -> Result<Self, Self::Rejection> {
443 Ok(Cmd(ctx.command.clone()))
444 }
445}
446
447impl<T> FromRequest<T> for AppState<T>
448where
449 T: Send + Sync + 'static,
450{
451 type Rejection = Infallible;
452
453 async fn from_request(
454 _ctx: &mut RequestContext,
455 state: &Arc<T>,
456 ) -> Result<Self, Self::Rejection> {
457 Ok(AppState(Arc::clone(state)))
458 }
459}
460
461impl<State> FromRequest<State> for PeerAddr
462where
463 State: Send + Sync + 'static,
464{
465 type Rejection = Infallible;
466
467 async fn from_request(
468 ctx: &mut RequestContext,
469 _state: &Arc<State>,
470 ) -> Result<Self, Self::Rejection> {
471 Ok(PeerAddr(ctx.peer_addr))
472 }
473}
474
475impl<State> FromRequest<State> for LocalAddr
476where
477 State: Send + Sync + 'static,
478{
479 type Rejection = Infallible;
480
481 async fn from_request(
482 ctx: &mut RequestContext,
483 _state: &Arc<State>,
484 ) -> Result<Self, Self::Rejection> {
485 Ok(LocalAddr(ctx.local_addr))
486 }
487}
488
489impl<State> FromRequest<State> for ClientId
490where
491 State: Send + Sync + 'static,
492{
493 type Rejection = Infallible;
494
495 async fn from_request(
496 ctx: &mut RequestContext,
497 _state: &Arc<State>,
498 ) -> Result<Self, Self::Rejection> {
499 Ok(ClientId(ctx.client_id))
500 }
501}
502
503impl<State> FromRequest<State> for Extensions
504where
505 State: Send + Sync + 'static,
506{
507 type Rejection = Infallible;
508
509 async fn from_request(
510 ctx: &mut RequestContext,
511 _state: &Arc<State>,
512 ) -> Result<Self, Self::Rejection> {
513 Ok(ctx.extensions.clone())
514 }
515}
516
517impl<State> FromRequest<State> for PushHandle
518where
519 State: Send + Sync + 'static,
520{
521 type Rejection = Infallible;
522
523 async fn from_request(
524 ctx: &mut RequestContext,
525 _state: &Arc<State>,
526 ) -> Result<Self, Self::Rejection> {
527 Ok(ctx.push.clone())
528 }
529}
530
531impl<State> FromRequest<State> for PubSubHandle
532where
533 State: Send + Sync + 'static,
534{
535 type Rejection = Infallible;
536
537 async fn from_request(
538 ctx: &mut RequestContext,
539 _state: &Arc<State>,
540 ) -> Result<Self, Self::Rejection> {
541 Ok(ctx.pubsub.clone())
542 }
543}
544
545impl<State> FromRequest<State> for Command
546where
547 State: Send + Sync + 'static,
548{
549 type Rejection = Infallible;
550
551 async fn from_request(
552 ctx: &mut RequestContext,
553 _state: &Arc<State>,
554 ) -> Result<Self, Self::Rejection> {
555 Ok(ctx.command.clone())
556 }
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    use crate::Value;
    use bytes::Bytes;
    use tokio::sync::mpsc;

    /// Builds a `RequestContext` for `cmd` from client 1 on fixed loopback addresses.
    fn make_ctx(cmd: Command) -> RequestContext {
        // The receivers are dropped when this function returns; none of these
        // tests push or close through the handle.
        let (push_tx, _push_rx) = mpsc::channel(1);
        let (close_tx, _close_rx) = mpsc::channel(1);
        let peer_addr = "127.0.0.1:1".parse().unwrap();
        let local_addr = "127.0.0.1:2".parse().unwrap();
        RequestContext {
            command: cmd,
            peer_addr,
            local_addr,
            client_id: 1,
            extensions: Extensions::default(),
            push: PushHandle::new(push_tx, close_tx),
            pubsub: PubSubHandle::new(Arc::new(AtomicUsize::new(0))),
        }
    }

    /// Builds a command called `name` with no arguments.
    fn bare_command(name: &'static [u8]) -> Command {
        Command::new(Bytes::from_static(name), Vec::new())
    }

    async fn ping() -> Value {
        Value::Simple(Bytes::from_static(b"PONG"))
    }

    #[tokio::test]
    async fn route_dispatches() {
        let app: Router<()> = Router::new().route("PING", ping);
        let resp = app.call(make_ctx(bare_command(b"PING"))).await;
        assert_eq!(resp, Value::Simple(Bytes::from_static(b"PONG")));
    }

    #[tokio::test]
    async fn unknown_command_returns_error() {
        let app: Router<()> = Router::new();
        let resp = app.call(make_ctx(bare_command(b"NOPE"))).await;
        assert!(matches!(resp, Value::Error(_)));
    }

    #[tokio::test]
    async fn route_accepts_capturing_closure() {
        let payload = Bytes::from_static(b"PONG");
        let handler = move || {
            let reply = payload.clone();
            async move { Value::Simple(reply) }
        };

        let app: Router<()> = Router::new().route("PING", handler);
        let resp = app.call(make_ctx(bare_command(b"PING"))).await;
        assert_eq!(resp, Value::Simple(Bytes::from_static(b"PONG")));
    }

    #[tokio::test]
    async fn state_extractor_works() {
        async fn handler(AppState(state): AppState<u64>) -> Value {
            Value::Integer(*state as i64)
        }

        let app = Router::from_state(5u64).route("GET", handler);
        let resp = app.call(make_ctx(bare_command(b"GET"))).await;
        assert_eq!(resp, Value::Integer(5));
    }
}