1#![doc = include_str!("../README.md")]
8#![forbid(unsafe_code)]
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use small_http::{Method, Request, Response, Status};
14
15type HandlerFn<T> = fn(&Request, &T) -> Response;
19type PreLayerFn<T> = fn(&Request, &mut T) -> Option<Response>;
20type PostLayerFn<T> = fn(&Request, &mut T, Response) -> Response;
21
22struct Handler<T> {
23 handler: HandlerFn<T>,
24 pre_layers: Vec<PreLayerFn<T>>,
25 post_layers: Vec<PostLayerFn<T>>,
26}
27
28impl<T> Handler<T> {
29 fn new(
30 handler: HandlerFn<T>,
31 pre_layers: Vec<PreLayerFn<T>>,
32 post_layers: Vec<PostLayerFn<T>>,
33 ) -> Self {
34 Self {
35 handler,
36 pre_layers,
37 post_layers,
38 }
39 }
40
41 fn call(&self, req: &Request, ctx: &mut T) -> Response {
42 for pre_layer in &self.pre_layers {
43 if let Some(mut res) = pre_layer(req, ctx) {
44 for post_layer in &self.post_layers {
45 res = post_layer(req, ctx, res);
46 }
47 return res;
48 }
49 }
50 let mut res = (self.handler)(req, ctx);
51 for post_layer in &self.post_layers {
52 res = post_layer(req, ctx, res);
53 }
54 res
55 }
56}
57
/// One non-empty, `/`-separated segment of a route pattern.
enum RoutePart {
    /// A literal segment that must equal the corresponding request path segment.
    Static(String),
    /// A `:name` placeholder matching any single segment; holds `name` without the colon.
    Param(String),
}
63
64struct Route<T> {
65 methods: Vec<Method>,
66 route: String,
67 parts: Vec<RoutePart>,
68 handler: Handler<T>,
69}
70
71impl<T> Route<T> {
72 fn new(methods: Vec<Method>, route: String, handler: Handler<T>) -> Self {
73 let parts = Self::route_parse_parts(&route);
74 Self {
75 methods,
76 route,
77 parts,
78 handler,
79 }
80 }
81
82 fn route_parse_parts(route: &str) -> Vec<RoutePart> {
83 route
84 .split('/')
85 .filter(|part| !part.is_empty())
86 .map(|part| {
87 if let Some(stripped) = part.strip_prefix(':') {
88 RoutePart::Param(stripped.to_string())
89 } else {
90 RoutePart::Static(part.to_string())
91 }
92 })
93 .collect()
94 }
95
96 fn is_match(&self, path: &str) -> bool {
97 let mut path_parts = path.split('/').filter(|part| !part.is_empty());
98 for part in &self.parts {
99 match part {
100 RoutePart::Static(expected) => {
101 if let Some(actual) = path_parts.next() {
102 if actual != *expected {
103 return false;
104 }
105 } else {
106 return false;
107 }
108 }
109 RoutePart::Param(_) => {
110 if path_parts.next().is_none() {
111 return false;
112 }
113 }
114 }
115 }
116 path_parts.next().is_none()
117 }
118
119 fn match_path(&self, path: &str) -> HashMap<String, String> {
120 let mut path_parts = path.split('/').filter(|part| !part.is_empty());
121 let mut params = HashMap::new();
122 for part in &self.parts {
123 match part {
124 RoutePart::Static(_) => {
125 path_parts.next();
126 }
127 RoutePart::Param(name) => {
128 if let Some(value) = path_parts.next() {
129 params.insert(name.clone(), value.to_string());
130 }
131 }
132 }
133 }
134 params
135 }
136}
137
138pub struct RouterBuilder<T: Clone> {
141 ctx: T,
142 pre_layers: Vec<PreLayerFn<T>>,
143 post_layers: Vec<PostLayerFn<T>>,
144 routes: Vec<Route<T>>,
145 not_allowed_method_handler: Option<Handler<T>>,
146 fallback_handler: Option<Handler<T>>,
147}
148
149impl Default for RouterBuilder<()> {
150 fn default() -> Self {
151 Self::with(())
152 }
153}
154
155impl RouterBuilder<()> {
156 pub fn new() -> Self {
158 Self::default()
159 }
160}
161
162impl<T: Clone> RouterBuilder<T> {
163 pub fn with(ctx: T) -> Self {
165 Self {
166 ctx,
167 pre_layers: Vec::new(),
168 post_layers: Vec::new(),
169 routes: Vec::new(),
170 not_allowed_method_handler: None,
171 fallback_handler: None,
172 }
173 }
174
175 pub fn pre_layer(mut self, layer: PreLayerFn<T>) -> Self {
177 self.pre_layers.push(layer);
178 self
179 }
180
181 pub fn post_layer(mut self, layer: PostLayerFn<T>) -> Self {
183 self.post_layers.push(layer);
184 self
185 }
186
187 pub fn route(mut self, methods: &[Method], route: String, handler: HandlerFn<T>) -> Self {
189 self.routes.push(Route::new(
190 methods.to_vec(),
191 route,
192 Handler::new(handler, self.pre_layers.clone(), self.post_layers.clone()),
193 ));
194 self
195 }
196
197 pub fn any(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
199 self.route(
200 &[
201 Method::Get,
202 Method::Head,
203 Method::Post,
204 Method::Put,
205 Method::Delete,
206 Method::Connect,
207 Method::Options,
208 Method::Trace,
209 Method::Patch,
210 ],
211 route.into(),
212 handler,
213 )
214 }
215 pub fn get(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
217 self.route(&[Method::Get], route.into(), handler)
218 }
219
220 pub fn head(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
222 self.route(&[Method::Head], route.into(), handler)
223 }
224
225 pub fn post(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
227 self.route(&[Method::Post], route.into(), handler)
228 }
229
230 pub fn put(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
232 self.route(&[Method::Put], route.into(), handler)
233 }
234
235 pub fn delete(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
237 self.route(&[Method::Delete], route.into(), handler)
238 }
239
240 pub fn connect(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
242 self.route(&[Method::Connect], route.into(), handler)
243 }
244
245 pub fn options(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
247 self.route(&[Method::Options], route.into(), handler)
248 }
249
250 pub fn trace(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
252 self.route(&[Method::Trace], route.into(), handler)
253 }
254
255 pub fn patch(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
257 self.route(&[Method::Patch], route.into(), handler)
258 }
259
260 pub fn fallback(mut self, handler: HandlerFn<T>) -> Self {
262 self.fallback_handler = Some(Handler::new(
263 handler,
264 self.pre_layers.clone(),
265 self.post_layers.clone(),
266 ));
267 self
268 }
269
270 pub fn build(self) -> Router<T> {
272 Router(Arc::new(InnerRouter {
273 ctx: self.ctx,
274 routes: self.routes,
275 not_allowed_method_handler: self.not_allowed_method_handler.unwrap_or_else(|| {
276 Handler::new(
277 |_, _| {
278 Response::with_status(Status::MethodNotAllowed)
279 .body("405 Method Not Allowed")
280 },
281 self.pre_layers.clone(),
282 self.post_layers.clone(),
283 )
284 }),
285 fallback_handler: self.fallback_handler.unwrap_or_else(|| {
286 Handler::new(
287 |_, _| Response::with_status(Status::NotFound).body("404 Not Found"),
288 self.pre_layers.clone(),
289 self.post_layers.clone(),
290 )
291 }),
292 }))
293 }
294}
295
296struct InnerRouter<T: Clone> {
298 ctx: T,
299 routes: Vec<Route<T>>,
300 not_allowed_method_handler: Handler<T>,
301 fallback_handler: Handler<T>,
302}
303
304impl<T: Clone> InnerRouter<T> {
305 fn handle(&self, req: &Request) -> Response {
306 let mut ctx = self.ctx.clone();
307
308 let path = req.url.path();
310 for route in self.routes.iter().rev() {
311 if route.is_match(path) {
312 let mut req = req.clone();
313 req.params = route.match_path(path);
314
315 for route in self.routes.iter().filter(|r| r.route == route.route) {
317 if route.methods.contains(&req.method) {
318 return route.handler.call(&req, &mut ctx);
319 }
320 }
321
322 return self.not_allowed_method_handler.call(&req, &mut ctx);
324 }
325 }
326
327 self.fallback_handler.call(req, &mut ctx)
329 }
330}
331
332#[derive(Clone)]
335pub struct Router<T: Clone>(Arc<InnerRouter<T>>);
336
337impl<T: Clone> Router<T> {
338 pub fn handle(&self, req: &Request) -> Response {
340 self.0.handle(req)
341 }
342}
343
#[cfg(test)]
mod test {
    use small_http::Status;

    use super::*;

    /// Responds with a fixed greeting.
    fn home(_req: &Request, _ctx: &()) -> Response {
        Response::with_status(Status::Ok).body("Hello, World!")
    }

    /// Greets whoever the `:name` route parameter captured.
    fn hello(req: &Request, _ctx: &()) -> Response {
        let name = &req.params["name"];
        Response::with_status(Status::Ok).body(format!("Hello, {name}!"))
    }

    #[test]
    fn test_routing() {
        let router = RouterBuilder::new()
            .get("/", home)
            .get("/hello/:name", hello)
            .get("/hello/:name/i/:am/so/:deep", hello)
            .build();
        let send = |req: &Request| router.handle(req);

        // Static root route.
        let res = send(&Request::get("http://localhost/"));
        assert_eq!(res.status, Status::Ok);
        assert_eq!(res.body, b"Hello, World!");

        // Unregistered path falls through to the default 404 handler.
        let res = send(&Request::get("http://localhost/unknown"));
        assert_eq!(res.status, Status::NotFound);
        assert_eq!(res.body, b"404 Not Found");

        // Single parameter capture.
        let res = send(&Request::get("http://localhost/hello/Bassie"));
        assert_eq!(res.status, Status::Ok);
        assert_eq!(res.body, b"Hello, Bassie!");

        // Several parameters interleaved with static segments.
        let res = send(&Request::get(
            "http://localhost/hello/Bassie/i/handle/so/much",
        ));
        assert_eq!(res.status, Status::Ok);

        // Known path, unregistered method: default 405 handler.
        let res = send(&Request::options("http://localhost/"));
        assert_eq!(res.status, Status::MethodNotAllowed);
        assert_eq!(res.body, b"405 Method Not Allowed");
    }
}