volo_grpc/server/
router.rs1use std::{
2 fmt,
3 sync::atomic::{AtomicU32, Ordering},
4};
5
6use http_body::Body as HttpBody;
7use motore::{BoxCloneService, Service};
8use rustc_hash::FxHashMap;
9use volo::Unwrap;
10
11use super::NamedService;
12use crate::{Request, Response, Status, body::BoxBody, context::ServerContext};
13
/// Opaque identifier linking a path registered in the matcher to its service.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct RouteId(u32);

impl RouteId {
    /// Allocates a new route ID from a single process-wide counter.
    ///
    /// IDs are handed out sequentially starting at `0`.
    ///
    /// # Panics
    ///
    /// Panics once `u32::MAX` IDs have been handed out. The counter is left
    /// saturated instead of wrapping back to `0`. Every later call therefore
    /// panics too, rather than silently reusing an ID that is already taken.
    fn next() -> Self {
        static ID: AtomicU32 = AtomicU32::new(0);

        // `Relaxed` is enough: the atomic read-modify-write alone guarantees
        // that each caller observes a distinct value, and no other memory is
        // published through this counter.
        //
        // `checked_add` refuses to advance past `u32::MAX`. A plain
        // `fetch_add` would wrap the counter to `0` and hand out duplicates
        // after the panic.
        let id = ID
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                current.checked_add(1)
            })
            .unwrap_or_else(|_| {
                panic!("Over `u32::MAX` routes created. If you need this, please file an issue.")
            });
        Self(id)
    }
}
28
29#[derive(Default)]
30pub struct Router<B = BoxBody> {
31 routes:
32 FxHashMap<RouteId, BoxCloneService<ServerContext, Request<B>, Response<BoxBody>, Status>>,
33 node: matchit::Router<RouteId>,
34}
35
36impl<B> Clone for Router<B> {
37 fn clone(&self) -> Self {
38 Self {
39 routes: self.routes.clone(),
40 node: self.node.clone(),
41 }
42 }
43}
44
45impl<B> Router<B>
46where
47 B: HttpBody + 'static,
48{
49 pub fn new() -> Self {
50 Self {
51 routes: Default::default(),
52 node: Default::default(),
53 }
54 }
55
56 pub fn add_service<S>(mut self, service: S) -> Self
57 where
58 S: Service<ServerContext, Request<B>, Response = Response<BoxBody>, Error = Status>
59 + NamedService
60 + Clone
61 + Send
62 + Sync
63 + 'static,
64 {
65 let path = format!("/{}/{{*rest}}", S::NAME);
66
67 if path.is_empty() {
68 panic!("[VOLO] Paths must start with a `/`. Use \"/\" for root routes");
69 } else if !path.starts_with('/') {
70 panic!("[VOLO] Paths must start with a `/`");
71 }
72
73 let id = RouteId::next();
74
75 self.set_node(path, id);
76
77 self.routes.insert(id, BoxCloneService::new(service));
78
79 self
80 }
81
82 #[track_caller]
83 fn set_node(&mut self, path: String, id: RouteId) {
84 if let Err(err) = self.node.insert(path, id) {
85 panic!("[VOLO] Invalid route: {err}");
86 }
87 }
88}
89
90impl<B> Service<ServerContext, Request<B>> for Router<B>
91where
92 B: HttpBody + Send,
93{
94 type Response = Response<BoxBody>;
95 type Error = Status;
96
97 async fn call(
98 &self,
99 cx: &mut ServerContext,
100 req: Request<B>,
101 ) -> Result<Self::Response, Self::Error> {
102 let path = cx.rpc_info.method();
103 match self.node.at(path) {
104 Ok(match_) => {
105 let id = match_.value;
106 let route = self.routes.get(id).volo_unwrap().clone();
107 route.call(cx, req).await
108 }
109 Err(err) => Err(Status::unimplemented(err.to_string())),
110 }
111 }
112}
113
114impl<B> fmt::Debug for Router<B> {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 f.debug_struct("Router")
117 .field("routes", &self.routes)
118 .finish()
119 }
120}