small_router/
lib.rs

1/*
2 * Copyright (c) 2024-2025 Bastiaan van der Plaat
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7#![doc = include_str!("../README.md")]
8#![forbid(unsafe_code)]
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use small_http::{Method, Request, Response, Status};
14
15// MARK: Handler
16
17/// Parsed path parameters
18type HandlerFn<T> = fn(&Request, &T) -> Response;
19type PreLayerFn<T> = fn(&Request, &mut T) -> Option<Response>;
20type PostLayerFn<T> = fn(&Request, &mut T, Response) -> Response;
21
22struct Handler<T> {
23    handler: HandlerFn<T>,
24    pre_layers: Vec<PreLayerFn<T>>,
25    post_layers: Vec<PostLayerFn<T>>,
26}
27
28impl<T> Handler<T> {
29    fn new(
30        handler: HandlerFn<T>,
31        pre_layers: Vec<PreLayerFn<T>>,
32        post_layers: Vec<PostLayerFn<T>>,
33    ) -> Self {
34        Self {
35            handler,
36            pre_layers,
37            post_layers,
38        }
39    }
40
41    fn call(&self, req: &Request, ctx: &mut T) -> Response {
42        for pre_layer in &self.pre_layers {
43            if let Some(mut res) = pre_layer(req, ctx) {
44                for post_layer in &self.post_layers {
45                    res = post_layer(req, ctx, res);
46                }
47                return res;
48            }
49        }
50        let mut res = (self.handler)(req, ctx);
51        for post_layer in &self.post_layers {
52            res = post_layer(req, ctx, res);
53        }
54        res
55    }
56}
57
58// MARK: Route
59enum RoutePart {
60    Static(String),
61    Param(String),
62}
63
64struct Route<T> {
65    methods: Vec<Method>,
66    route: String,
67    parts: Vec<RoutePart>,
68    handler: Handler<T>,
69}
70
71impl<T> Route<T> {
72    fn new(methods: Vec<Method>, route: String, handler: Handler<T>) -> Self {
73        let parts = Self::route_parse_parts(&route);
74        Self {
75            methods,
76            route,
77            parts,
78            handler,
79        }
80    }
81
82    fn route_parse_parts(route: &str) -> Vec<RoutePart> {
83        route
84            .split('/')
85            .filter(|part| !part.is_empty())
86            .map(|part| {
87                if let Some(stripped) = part.strip_prefix(':') {
88                    RoutePart::Param(stripped.to_string())
89                } else {
90                    RoutePart::Static(part.to_string())
91                }
92            })
93            .collect()
94    }
95
96    fn is_match(&self, path: &str) -> bool {
97        let mut path_parts = path.split('/').filter(|part| !part.is_empty());
98        for part in &self.parts {
99            match part {
100                RoutePart::Static(expected) => {
101                    if let Some(actual) = path_parts.next() {
102                        if actual != *expected {
103                            return false;
104                        }
105                    } else {
106                        return false;
107                    }
108                }
109                RoutePart::Param(_) => {
110                    if path_parts.next().is_none() {
111                        return false;
112                    }
113                }
114            }
115        }
116        path_parts.next().is_none()
117    }
118
119    fn match_path(&self, path: &str) -> HashMap<String, String> {
120        let mut path_parts = path.split('/').filter(|part| !part.is_empty());
121        let mut params = HashMap::new();
122        for part in &self.parts {
123            match part {
124                RoutePart::Static(_) => {
125                    path_parts.next();
126                }
127                RoutePart::Param(name) => {
128                    if let Some(value) = path_parts.next() {
129                        params.insert(name.clone(), value.to_string());
130                    }
131                }
132            }
133        }
134        params
135    }
136}
137
138// MARK: RouterBuilder
139/// Router builder
140pub struct RouterBuilder<T: Clone> {
141    ctx: T,
142    pre_layers: Vec<PreLayerFn<T>>,
143    post_layers: Vec<PostLayerFn<T>>,
144    routes: Vec<Route<T>>,
145    not_allowed_method_handler: Option<Handler<T>>,
146    fallback_handler: Option<Handler<T>>,
147}
148
149impl Default for RouterBuilder<()> {
150    fn default() -> Self {
151        Self::with(())
152    }
153}
154
155impl RouterBuilder<()> {
156    /// Create new router
157    pub fn new() -> Self {
158        Self::default()
159    }
160}
161
162impl<T: Clone> RouterBuilder<T> {
163    /// Create new router with context
164    pub fn with(ctx: T) -> Self {
165        Self {
166            ctx,
167            pre_layers: Vec::new(),
168            post_layers: Vec::new(),
169            routes: Vec::new(),
170            not_allowed_method_handler: None,
171            fallback_handler: None,
172        }
173    }
174
175    /// Add pre layer
176    pub fn pre_layer(mut self, layer: PreLayerFn<T>) -> Self {
177        self.pre_layers.push(layer);
178        self
179    }
180
181    /// Add post layer
182    pub fn post_layer(mut self, layer: PostLayerFn<T>) -> Self {
183        self.post_layers.push(layer);
184        self
185    }
186
187    /// Add route
188    pub fn route(mut self, methods: &[Method], route: String, handler: HandlerFn<T>) -> Self {
189        self.routes.push(Route::new(
190            methods.to_vec(),
191            route,
192            Handler::new(handler, self.pre_layers.clone(), self.post_layers.clone()),
193        ));
194        self
195    }
196
197    /// Add route for any method
198    pub fn any(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
199        self.route(
200            &[
201                Method::Get,
202                Method::Head,
203                Method::Post,
204                Method::Put,
205                Method::Delete,
206                Method::Connect,
207                Method::Options,
208                Method::Trace,
209                Method::Patch,
210            ],
211            route.into(),
212            handler,
213        )
214    }
215    /// Add route for GET method
216    pub fn get(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
217        self.route(&[Method::Get], route.into(), handler)
218    }
219
220    /// Add route for HEAD method
221    pub fn head(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
222        self.route(&[Method::Head], route.into(), handler)
223    }
224
225    /// Add route for POST method
226    pub fn post(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
227        self.route(&[Method::Post], route.into(), handler)
228    }
229
230    /// Add route for PUT method
231    pub fn put(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
232        self.route(&[Method::Put], route.into(), handler)
233    }
234
235    /// Add route for DELETE method
236    pub fn delete(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
237        self.route(&[Method::Delete], route.into(), handler)
238    }
239
240    /// Add route for CONNECT method
241    pub fn connect(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
242        self.route(&[Method::Connect], route.into(), handler)
243    }
244
245    /// Add route for OPTIONS method
246    pub fn options(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
247        self.route(&[Method::Options], route.into(), handler)
248    }
249
250    /// Add route for TRACE method
251    pub fn trace(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
252        self.route(&[Method::Trace], route.into(), handler)
253    }
254
255    /// Add route for PATCH method
256    pub fn patch(self, route: impl Into<String>, handler: HandlerFn<T>) -> Self {
257        self.route(&[Method::Patch], route.into(), handler)
258    }
259
260    /// Set fallback handler
261    pub fn fallback(mut self, handler: HandlerFn<T>) -> Self {
262        self.fallback_handler = Some(Handler::new(
263            handler,
264            self.pre_layers.clone(),
265            self.post_layers.clone(),
266        ));
267        self
268    }
269
270    /// Build router
271    pub fn build(self) -> Router<T> {
272        Router(Arc::new(InnerRouter {
273            ctx: self.ctx,
274            routes: self.routes,
275            not_allowed_method_handler: self.not_allowed_method_handler.unwrap_or_else(|| {
276                Handler::new(
277                    |_, _| {
278                        Response::with_status(Status::MethodNotAllowed)
279                            .body("405 Method Not Allowed")
280                    },
281                    self.pre_layers.clone(),
282                    self.post_layers.clone(),
283                )
284            }),
285            fallback_handler: self.fallback_handler.unwrap_or_else(|| {
286                Handler::new(
287                    |_, _| Response::with_status(Status::NotFound).body("404 Not Found"),
288                    self.pre_layers.clone(),
289                    self.post_layers.clone(),
290                )
291            }),
292        }))
293    }
294}
295
296// MARK: InnerRouter
297struct InnerRouter<T: Clone> {
298    ctx: T,
299    routes: Vec<Route<T>>,
300    not_allowed_method_handler: Handler<T>,
301    fallback_handler: Handler<T>,
302}
303
304impl<T: Clone> InnerRouter<T> {
305    fn handle(&self, req: &Request) -> Response {
306        let mut ctx = self.ctx.clone();
307
308        // Match routes
309        let path = req.url.path();
310        for route in self.routes.iter().rev() {
311            if route.is_match(path) {
312                let mut req = req.clone();
313                req.params = route.match_path(path);
314
315                // Find matching route by method
316                for route in self.routes.iter().filter(|r| r.route == route.route) {
317                    if route.methods.contains(&req.method) {
318                        return route.handler.call(&req, &mut ctx);
319                    }
320                }
321
322                // Or run not allowed method handler
323                return self.not_allowed_method_handler.call(&req, &mut ctx);
324            }
325        }
326
327        // Or run fallback handler
328        self.fallback_handler.call(req, &mut ctx)
329    }
330}
331
332// MARK: Router
333/// Router
334#[derive(Clone)]
335pub struct Router<T: Clone>(Arc<InnerRouter<T>>);
336
337impl<T: Clone> Router<T> {
338    /// Handle request
339    pub fn handle(&self, req: &Request) -> Response {
340        self.0.handle(req)
341    }
342}
343
// MARK: Tests
#[cfg(test)]
mod test {
    use small_http::Status;

    use super::*;

    /// Always greets the world.
    fn home(_req: &Request, _ctx: &()) -> Response {
        Response::with_status(Status::Ok).body("Hello, World!")
    }

    /// Greets the captured `name` path parameter.
    fn hello(req: &Request, _ctx: &()) -> Response {
        let name = req
            .params
            .get("name")
            .expect("route must capture a `name` parameter");
        Response::with_status(Status::Ok).body(format!("Hello, {name}!"))
    }

    /// Router shared by the routing assertions.
    fn build_router() -> Router<()> {
        RouterBuilder::new()
            .get("/", home)
            .get("/hello/:name", hello)
            .get("/hello/:name/i/:am/so/:deep", hello)
            .build()
    }

    #[test]
    fn test_routing() {
        let router = build_router();

        // Static root route
        let home_res = router.handle(&Request::get("http://localhost/"));
        assert_eq!(home_res.status, Status::Ok);
        assert_eq!(home_res.body, b"Hello, World!");

        // Unknown path reaches the default fallback
        let missing_res = router.handle(&Request::get("http://localhost/unknown"));
        assert_eq!(missing_res.status, Status::NotFound);
        assert_eq!(missing_res.body, b"404 Not Found");

        // Single captured parameter
        let param_res = router.handle(&Request::get("http://localhost/hello/Bassie"));
        assert_eq!(param_res.status, Status::Ok);
        assert_eq!(param_res.body, b"Hello, Bassie!");

        // Several captured parameters interleaved with static segments
        let deep_res = router.handle(&Request::get(
            "http://localhost/hello/Bassie/i/handle/so/much",
        ));
        assert_eq!(deep_res.status, Status::Ok);

        // Known path, unsupported method
        let method_res = router.handle(&Request::options("http://localhost/"));
        assert_eq!(method_res.status, Status::MethodNotAllowed);
        assert_eq!(method_res.body, b"405 Method Not Allowed");
    }
}