1use glob::glob;
2use lazy_static::lazy_static;
3use quote::format_ident;
4use regex::Regex;
5use std::cmp::Ordering;
6
mod utils {
    /// Splits a `/`-separated path into its segments.
    ///
    /// At most one leading `/` is dropped before splitting, so `"/api/posts"`
    /// and `"api/posts"` both yield `["api", "posts"]`. No other normalisation
    /// happens: empty segments (from `//` or a trailing `/`) are kept, and the
    /// empty string yields a single empty segment.
    pub fn get_segments(p: &str) -> Vec<&str> {
        let relative = match p.strip_prefix('/') {
            Some(rest) => rest,
            None => p,
        };

        relative.split('/').collect()
    }
}
13
14use utils::get_segments;
15
lazy_static! {
    // Matches a single dynamic placeholder such as `[id]`: a bracket pair
    // around one or more characters that are neither `/` nor `.`.
    static ref DYNAMIC_ROUTE_REGEX: Regex = Regex::new(r"\[[^/\.]+\]").unwrap();
    // Matches a catch-all placeholder such as `[...slug]`: `[`, three dots,
    // one or more non-whitespace characters, `]`. This pattern also matches
    // inside an optional catch-all (`[[...slug]]`), so callers that need to
    // tell the two apart must test the optional pattern first.
    static ref DYNAMIC_CATCH_ALL_REGEX: Regex = Regex::new(r"\[\.{3}\S+\]").unwrap();
    // Matches an optional catch-all placeholder such as `[[...slug]]`: the
    // same shape as the catch-all above but wrapped in double brackets.
    static ref DYNAMIC_OPTIONAL_CATCH_ALL_REGEX: Regex = Regex::new(r"\[{2}\.{3}\S+\]{2}").unwrap();
}
24
/// Classification of a route, derived from the bracket placeholders found in
/// its path (see `impl From<&str> for Route`).
#[derive(Debug, PartialEq, PartialOrd)]
pub enum RouteKind {
    /// The path contains none of the placeholder forms below.
    Static,
    /// The path contains at least one `[name]` placeholder and no catch-all.
    Dynamic,
    /// The path contains a `[...name]` placeholder (and no optional catch-all).
    CatchAll,
    /// The path contains a `[[...name]]` placeholder.
    OptionalCatchAll,
}
32
/// A single route, built from the path of the source file that implements it.
#[derive(Debug)]
pub struct Route {
    /// How the route's path is classified (static, dynamic, catch-all, ...).
    pub kind: RouteKind,
    /// The file path exactly as it was passed to `Route::from`.
    pub module_file: String,
    /// Identifier derived from the file path: a leading `/` and a trailing
    /// `.rs` are dropped, `/`, `[`, `]` and `-` become `_`, and `...` becomes
    /// `___`.
    pub module_name: syn::Ident,
    /// The file path with a trailing `.rs` removed; this is what request paths
    /// are compared against.
    pub path: String,
    /// The `/`-separated segments of `path`. `Route::from` sets this to `None`
    /// for static routes and `Some` for every other kind.
    pub segments: Option<Vec<String>>,
}
41
42impl Ord for Route {
43 fn cmp(&self, other: &Self) -> Ordering {
44 match self.kind {
45 RouteKind::Static => match other.kind {
47 RouteKind::Static => other.path.len().cmp(&self.path.len()),
48 _ => Ordering::Less,
49 },
50 RouteKind::Dynamic => match other.kind {
52 RouteKind::Static => Ordering::Greater,
53 RouteKind::Dynamic => match self.segments {
54 Some(ref s) => match other.segments {
55 Some(ref o) => {
56 if s.len() == o.len() {
59 let s_pos = s
60 .iter()
61 .rev()
62 .position(|ss| ss.starts_with('[') && ss.ends_with(']'));
63
64 let o_pos = o
65 .iter()
66 .rev()
67 .position(|os| os.starts_with('[') && os.ends_with(']'));
68
69 return o_pos.cmp(&s_pos);
70 }
71
72 o.len().cmp(&s.len())
73 }
74 None => Ordering::Greater,
75 },
76 None => Ordering::Equal,
77 },
78 RouteKind::CatchAll | RouteKind::OptionalCatchAll => Ordering::Less,
79 },
80 RouteKind::CatchAll | RouteKind::OptionalCatchAll => match other.kind {
82 RouteKind::Static => Ordering::Greater,
83 RouteKind::Dynamic => Ordering::Greater,
84 RouteKind::CatchAll | RouteKind::OptionalCatchAll => match self.segments {
85 Some(ref s) => match other.segments {
86 Some(ref o) => o.len().cmp(&s.len()),
87 None => Ordering::Greater,
88 },
89 None => Ordering::Equal,
90 },
91 },
92 }
93 }
94}
95
96impl PartialOrd for Route {
97 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
98 Some(self.cmp(other))
99 }
100}
101
// Marker impl: declares that `==` on `Route` (implemented in
// `impl PartialEq for Route`) is a full equivalence relation.
impl Eq for Route {}
103
104impl PartialEq for Route {
105 fn eq(&self, other: &Self) -> bool {
106 self.kind == other.kind
107 }
108}
109
110impl From<&str> for Route {
111 fn from(file_path: &str) -> Self {
112 let file_path = file_path.to_string();
113 let route = file_path.strip_suffix(".rs").unwrap_or(&file_path);
114
115 let module_name = file_path.strip_prefix('/').unwrap_or(&file_path);
116 let module_name = module_name.replace('/', "_");
117
118 let module_name = module_name.replace('[', "_");
119 let module_name = module_name.replace(']', "_");
120 let module_name = module_name.replace("...", "___");
121
122 let module_name = module_name.replace('-', "_");
123 let module_name = module_name.strip_suffix(".rs").unwrap_or(&module_name);
124
125 let get_route_kind = |r: &str| -> RouteKind {
127 if DYNAMIC_ROUTE_REGEX.is_match(r) {
128 match DYNAMIC_OPTIONAL_CATCH_ALL_REGEX.is_match(r) {
129 true => return RouteKind::OptionalCatchAll,
130 false => match DYNAMIC_CATCH_ALL_REGEX.is_match(r) {
132 true => return RouteKind::CatchAll,
133 false => return RouteKind::Dynamic,
134 },
135 }
136 }
137
138 if DYNAMIC_OPTIONAL_CATCH_ALL_REGEX.is_match(r) {
139 return RouteKind::OptionalCatchAll;
140 }
141
142 if DYNAMIC_CATCH_ALL_REGEX.is_match(r) {
143 return RouteKind::CatchAll;
144 }
145 RouteKind::Static
146 };
147
148 let route_kind = get_route_kind(route);
149
150 let segments = match route_kind {
151 RouteKind::Static => None,
152 RouteKind::Dynamic => Some(get_segments(route)),
153 RouteKind::CatchAll => Some(get_segments(route)),
154 RouteKind::OptionalCatchAll => Some(get_segments(route)),
155 };
156
157 let segments = segments.map(|s| s.iter().map(|s| s.to_string()).collect::<Vec<_>>());
158
159 Route {
160 kind: route_kind,
161 module_file: file_path.to_owned(),
163 module_name: format_ident!("{}", module_name.to_owned()),
164 path: route.to_owned(),
165 segments,
166 }
167 }
168}
169
/// A collection of [`Route`]s that request paths can be resolved against.
pub struct Router {
    /// The known routes. `Router::new` and `Router::from` sort this list with
    /// `Route`'s `Ord` implementation, and `Router::call` returns the first
    /// entry that matches.
    pub routes: Vec<Route>,
}
173
174impl Default for Router {
175 fn default() -> Self {
176 Self::new("api/**/*.rs")
177 }
178}
179
180impl From<Vec<&str>> for Router {
181 fn from(raw_paths: Vec<&str>) -> Self {
182 let mut routes: Vec<Route> = raw_paths.into_iter().map(Route::from).collect();
183 routes.sort();
184 Router { routes }
185 }
186}
187
188impl Router {
189 pub fn new(file_pattern: &str) -> Self {
190 let mut routes = glob(file_pattern)
191 .expect("Failed to read glob pattern")
192 .filter_map(|e| e.ok())
193 .map(|raw_path| {
194 let path = raw_path.to_str().unwrap();
195 Route::from(path)
196 })
197 .collect::<Vec<_>>();
198
199 routes.sort();
200 Router { routes }
201 }
202
203 pub fn call(&self, req_path: &str) -> Option<&Route> {
204 if let Some(optional_catch_all) = self.routes.iter().find(|r| {
206 let dynamic_optional_catch_all_exp = Regex::new(r"\[{2}\.{3}\S+\]{2}").unwrap();
207 let optional_catchall_route =
208 dynamic_optional_catch_all_exp.replace_all(r.path.as_str(), "");
209 let optional_catchall_route = optional_catchall_route.trim_end_matches('/');
210
211 r.kind == RouteKind::OptionalCatchAll && req_path == optional_catchall_route
212 }) {
213 return Some(optional_catch_all);
214 };
215
216 let result = self.routes.iter().find(|route| {
217 match route.kind {
218 RouteKind::Static => route.path == req_path,
219 RouteKind::Dynamic => {
220 let path_segements = get_segments(req_path);
221 match route.segments {
223 None => false,
224 Some(ref route_segments) => {
225 if route_segments.len() != path_segements.len() {
226 return false;
227 }
228
229 route_segments.iter().enumerate().all(|(i, rs)| {
230 (rs.contains('[') && rs.contains(']')) || rs == path_segements[i]
231 })
232 }
233 }
234 }
235 RouteKind::OptionalCatchAll => {
236 let optional_catchall_prefix =
238 DYNAMIC_OPTIONAL_CATCH_ALL_REGEX.replace_all(route.path.as_str(), "");
239 req_path.starts_with(optional_catchall_prefix.as_ref())
240 }
241 RouteKind::CatchAll => {
242 let catchall_prefix =
244 DYNAMIC_CATCH_ALL_REGEX.replace_all(route.path.as_str(), "");
245 req_path.starts_with(catchall_prefix.as_ref())
246 }
247 }
248 });
249
250 result
251 }
252}
253
// Snapshot tests for `Router::call`. The statement order and expression text
// are significant: insta numbers the snapshots of a test function in the order
// the assertions run.
#[cfg(test)]
mod tests {
    use super::Router;

    /// Builds a router from a mix of static, dynamic, catch-all and optional
    /// catch-all route files and snapshots the result of resolving a series of
    /// request paths against it.
    #[test]
    fn dynamic_routing() {
        let router = Router::from(vec![
            "api/posts.rs",
            "api/[id].rs",
            "api/posts/[id].rs",
            "api/[...id].rs",
            "api/nested/posts.rs",
            "api/nested/[id].rs",
            "api/nested/posts/[id].rs",
            "api/nested/[...id].rs",
            "api/optional/posts.rs",
            "api/optional/[id].rs",
            "api/optional/posts/[id].rs",
            "api/optional.rs",
            "api/optional/[[...id]].rs",
            "api/deep/nested/[id]/comments/[cid].rs",
            "api/other/[ab]/[cd]/ef.rs",
            "api/foo/[d]/bar/baz/[f].rs",
            "api/github/[owner]/[release]/baz/[f].rs",
            "api/github/[owner]/[name]/releases/[release].rs",
            "api/github/[owner]/[name]/releases/all.rs",
            "api/github/[owner]/[name]/releases/latest.rs",
            "api/github/[owner]/[name]/tags/[...all].rs",
            "api/github/[owner]/[name]/tags/latest.rs",
        ]);

        // Requests directly under `api/`.
        insta::assert_debug_snapshot!(router.call("api/posts"));
        insta::assert_debug_snapshot!(router.call("api/[id]"));
        insta::assert_debug_snapshot!(router.call("api/posts/[id]"));
        insta::assert_debug_snapshot!(router.call("api"));
        insta::assert_debug_snapshot!(router.call("api/root/catch/all/route"));
        // Requests under `api/nested/`.
        insta::assert_debug_snapshot!(router.call("api/nested/posts"));
        insta::assert_debug_snapshot!(router.call("api/nested/[id]"));
        insta::assert_debug_snapshot!(router.call("api/nested/posts/[id]"));
        insta::assert_debug_snapshot!(router.call("api/nested"));
        insta::assert_debug_snapshot!(router.call("api/nested/catch/all/route"));
        // Requests under `api/optional/`, where an optional catch-all exists.
        insta::assert_debug_snapshot!(router.call("api/optional/posts"));
        insta::assert_debug_snapshot!(router.call("api/optional/[id]"));
        insta::assert_debug_snapshot!(router.call("api/optional/posts/[id]"));
        insta::assert_debug_snapshot!(router.call("api/optional"));
        insta::assert_debug_snapshot!(router.call("api/optional/catch/all/route"));
        // Routes with several dynamic segments, plus a deep unmatched path.
        insta::assert_debug_snapshot!(router.call("api/deep/nested/[id]/comments/[cid]"));
        insta::assert_debug_snapshot!(router.call("api/should/be/caught/by/root/catch/all"));
        insta::assert_debug_snapshot!(router.call("api/other/[ab]/[cd]/ef"));
        insta::assert_debug_snapshot!(router.call("api/foo/[d]/bar/baz/[f]"));
        // Requests with concrete values where the routes have dynamic segments.
        insta::assert_debug_snapshot!(router.call("api/github/ecklf/rust-at-home/releases/foo"));
        insta::assert_debug_snapshot!(router.call("api/github/ecklf/rust-at-home/releases/latest"));
        insta::assert_debug_snapshot!(router.call("api/github/ecklf/rust-at-home/releases/all"));
        insta::assert_debug_snapshot!(router.call("api/github/ecklf/rust-at-home/tags/v0.1.0"));
        insta::assert_debug_snapshot!(router.call("api/github/ecklf/rust-at-home/tags/latest"));
    }
}
316
#[cfg(test)]
mod route_tests {
    use super::{Route, RouteKind};

    /// Converts `path` into a [`Route`] and checks its kind, its path and its
    /// segments (`None` means the route must not carry any segments).
    fn assert_route(path: &str, expected_kind: RouteKind, expected_segments: Option<&[&str]>) {
        let route = Route::from(path);

        assert_eq!(route.kind, expected_kind);
        assert_eq!(route.path, path);

        match expected_segments {
            None => assert!(route.segments.is_none()),
            Some(expected) => {
                let actual = route.segments.expect("route should have segments");
                assert_eq!(actual, expected);
            }
        }
    }

    #[test]
    fn it_creates_static_route() {
        assert_route("api/handler", RouteKind::Static, None);
    }

    #[test]
    fn it_creates_dynamic_route() {
        assert_route("api/[dyn]", RouteKind::Dynamic, Some(&["api", "[dyn]"]));
    }

    #[test]
    fn it_creates_complex_dynamic_route() {
        assert_route(
            "api/[dyn]/handler/[dyn2]",
            RouteKind::Dynamic,
            Some(&["api", "[dyn]", "handler", "[dyn2]"]),
        );
    }

    #[test]
    fn it_creates_catch_all_route() {
        assert_route(
            "api/[...all]",
            RouteKind::CatchAll,
            Some(&["api", "[...all]"]),
        );
    }

    #[test]
    fn it_creates_complex_catch_all_route() {
        assert_route(
            "api/[dyn]/handler/[...all]",
            RouteKind::CatchAll,
            Some(&["api", "[dyn]", "handler", "[...all]"]),
        );
    }

    #[test]
    fn it_creates_optional_catch_all_route() {
        assert_route(
            "api/[[...all]]",
            RouteKind::OptionalCatchAll,
            Some(&["api", "[[...all]]"]),
        );
    }

    #[test]
    fn it_creates_complex_optional_catch_all_route() {
        assert_route(
            "api/[dyn]/handler/[[...all]]",
            RouteKind::OptionalCatchAll,
            Some(&["api", "[dyn]", "handler", "[[...all]]"]),
        );
    }
}