1use std::sync::Arc;
2
3use rivet_http::{Method, Request, Response};
4
5pub type RouteHandler = Arc<dyn Fn(Request) -> Response + Send + Sync + 'static>;
6
7fn default_handler(_req: Request) -> Response {
8 Response::ok("ok")
9}
10
11#[derive(Clone)]
12pub struct Entry {
13 pub method: Method,
14 pub path: String,
15 handler: RouteHandler,
16}
17
18impl core::fmt::Debug for Entry {
19 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
20 f.debug_struct("Entry")
21 .field("method", &self.method)
22 .field("path", &self.path)
23 .finish()
24 }
25}
26
27impl Entry {
28 pub fn invoke(&self, req: Request) -> Response {
29 (self.handler)(req)
30 }
31}
32
33#[derive(Debug, Default, Clone)]
34pub struct Registry {
35 routes: Vec<Entry>,
36}
37
38#[derive(Debug, Clone)]
39pub enum Match<'a> {
40 Matched {
41 route: &'a Entry,
42 head_fallback: bool,
43 },
44 MethodNotAllowed {
45 allow: Vec<Method>,
46 },
47 NotFound,
48}
49
50#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
51pub enum RouteRegistryError {
52 #[error("duplicate route: {method:?} {path}")]
53 Duplicate { method: Method, path: String },
54}
55
56impl Registry {
57 pub fn new() -> Self {
58 Self { routes: Vec::new() }
59 }
60
61 pub fn add_route(
62 &mut self,
63 method: Method,
64 path: impl Into<String>,
65 ) -> Result<(), RouteRegistryError> {
66 self.add_route_with_handler(method, path, default_handler)
67 }
68
69 pub fn add_route_with_handler<F>(
70 &mut self,
71 method: Method,
72 path: impl Into<String>,
73 handler: F,
74 ) -> Result<(), RouteRegistryError>
75 where
76 F: Fn(Request) -> Response + Send + Sync + 'static,
77 {
78 let path = path.into();
79
80 if self
81 .routes
82 .iter()
83 .any(|r| r.method == method && r.path == path)
84 {
85 return Err(RouteRegistryError::Duplicate { method, path });
86 }
87
88 self.routes.push(Entry {
89 method,
90 path,
91 handler: Arc::new(handler),
92 });
93 Ok(())
94 }
95
96 pub fn get<F>(&mut self, path: impl Into<String>, handler: F) -> Result<(), RouteRegistryError>
97 where
98 F: Fn(Request) -> Response + Send + Sync + 'static,
99 {
100 self.add_route_with_handler(Method::Get, path, handler)
101 }
102
103 pub fn post<F>(&mut self, path: impl Into<String>, handler: F) -> Result<(), RouteRegistryError>
104 where
105 F: Fn(Request) -> Response + Send + Sync + 'static,
106 {
107 self.add_route_with_handler(Method::Post, path, handler)
108 }
109
110 pub fn put<F>(&mut self, path: impl Into<String>, handler: F) -> Result<(), RouteRegistryError>
111 where
112 F: Fn(Request) -> Response + Send + Sync + 'static,
113 {
114 self.add_route_with_handler(Method::Put, path, handler)
115 }
116
117 pub fn patch<F>(
118 &mut self,
119 path: impl Into<String>,
120 handler: F,
121 ) -> Result<(), RouteRegistryError>
122 where
123 F: Fn(Request) -> Response + Send + Sync + 'static,
124 {
125 self.add_route_with_handler(Method::Patch, path, handler)
126 }
127
128 pub fn delete<F>(
129 &mut self,
130 path: impl Into<String>,
131 handler: F,
132 ) -> Result<(), RouteRegistryError>
133 where
134 F: Fn(Request) -> Response + Send + Sync + 'static,
135 {
136 self.add_route_with_handler(Method::Delete, path, handler)
137 }
138
139 pub fn head<F>(&mut self, path: impl Into<String>, handler: F) -> Result<(), RouteRegistryError>
140 where
141 F: Fn(Request) -> Response + Send + Sync + 'static,
142 {
143 self.add_route_with_handler(Method::Head, path, handler)
144 }
145
146 pub fn options<F>(
147 &mut self,
148 path: impl Into<String>,
149 handler: F,
150 ) -> Result<(), RouteRegistryError>
151 where
152 F: Fn(Request) -> Response + Send + Sync + 'static,
153 {
154 self.add_route_with_handler(Method::Options, path, handler)
155 }
156
157 pub fn routes(&self) -> &[Entry] {
158 &self.routes
159 }
160
161 pub fn match_request(&self, method: &Method, path: &str) -> Match<'_> {
162 let path_matches: Vec<&Entry> = self.routes.iter().filter(|r| r.path == path).collect();
163
164 if path_matches.is_empty() {
165 return Match::NotFound;
166 }
167
168 if let Some(route) = path_matches.iter().find(|r| r.method == *method) {
169 return Match::Matched {
170 route,
171 head_fallback: false,
172 };
173 }
174
175 if *method == Method::Head {
176 if let Some(route) = path_matches.iter().find(|r| r.method == Method::Get) {
177 return Match::Matched {
178 route,
179 head_fallback: true,
180 };
181 }
182 }
183
184 let mut allow = Vec::new();
185 for route in path_matches {
186 if !allow.contains(&route.method) {
187 allow.push(route.method.clone());
188 }
189 }
190
191 Match::MethodNotAllowed { allow }
192 }
193
194 pub fn len(&self) -> usize {
195 self.routes.len()
196 }
197
198 pub fn is_empty(&self) -> bool {
199 self.routes.is_empty()
200 }
201}