Skip to main content

const_router/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use crate::__internal::{Handler, Route};
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{Context, Poll},
8};
9use tower_service::Service;
10
11extern crate self as const_router;
12
13pub use const_router_macros::{handler, router};
14
15/// A static string-key router.
16///
17/// Build a router with [`router!`]. Requests are matched by the value returned
18/// from [`ExtractKey`]. Missing keys are routed to the configured fallback
19/// handler.
20#[derive(Debug)]
21pub struct Router<TRequest, TResponse, TError>
22where
23    TResponse: 'static,
24    TError: 'static,
25    TRequest: 'static,
26{
27    fallback: Handler<TRequest, TResponse, TError>,
28    routes: &'static [Route<TRequest, TResponse, TError>],
29}
30
/// Supplies the string a [`Router`] uses to choose a handler for a request.
///
/// The returned key must equal one of the string literals given to
/// [`router!`] exactly for that route to be selected; any other key is sent
/// to the router's fallback handler.
pub trait ExtractKey {
    /// Borrows the routing key out of `self`.
    fn extract_key(&self) -> &str;
}
38
/// Future type produced by every router handler.
///
/// It is heap-allocated and pinned, `Send`, holds no non-`'static` borrows,
/// and resolves to the handler's `Result<TResponse, TError>`.
pub type BoxFuture<TResponse, TError> = Pin<
    Box<dyn Future<Output = Result<TResponse, TError>> + Send + 'static>,
>;
42
43impl<TRequest, TResponse, TError> Router<TRequest, TResponse, TError>
44where
45    TRequest: ExtractKey,
46{
47    /// Routes a request to the matching handler or to the fallback handler.
48    pub fn handle(&self, req: TRequest) -> BoxFuture<TResponse, TError> {
49        let key = req.extract_key();
50
51        let handler = match self.routes.binary_search_by(|route| route.key.cmp(key)) {
52            Ok(index) => &self.routes[index].handler,
53            Err(_) => &self.fallback,
54        };
55
56        (handler.0)(req)
57    }
58}
59
60impl<TRequest, TResponse, TError> Service<TRequest> for Router<TRequest, TResponse, TError>
61where
62    TResponse: 'static,
63    TError: 'static,
64    TRequest: ExtractKey + 'static,
65{
66    type Error = TError;
67    type Future = BoxFuture<TResponse, TError>;
68    type Response = TResponse;
69
70    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
71        Poll::Ready(Ok(()))
72    }
73
74    fn call(&mut self, req: TRequest) -> Self::Future {
75        self.handle(req)
76    }
77}
78
/// Routes `http` requests by the path component of their URI.
#[cfg(feature = "http")]
impl<T> ExtractKey for http::Request<T> {
    fn extract_key(&self) -> &str {
        let uri = self.uri();
        uri.path()
    }
}
85
86#[doc(hidden)]
87pub mod __internal {
88    use super::*;
89
90    #[derive(Debug)]
91    pub struct Route<TRequest, TResponse, TError> {
92        pub(crate) key: &'static str,
93        pub(crate) handler: Handler<TRequest, TResponse, TError>,
94    }
95
96    #[derive(Debug)]
97    pub struct Handler<TRequest, TResponse, TError>(
98        pub(crate) HandlerFn<TRequest, TResponse, TError>,
99    );
100
101    type HandlerFn<TRequest, TResponse, TError> = fn(TRequest) -> BoxFuture<TResponse, TError>;
102
103    pub const fn new_router<TRequest, TResponse, TError>(
104        fallback: Handler<TRequest, TResponse, TError>,
105        routes: &'static [Route<TRequest, TResponse, TError>],
106    ) -> Router<TRequest, TResponse, TError>
107    where
108        TResponse: 'static,
109        TError: 'static,
110        TRequest: 'static,
111    {
112        Router { fallback, routes }
113    }
114
115    pub const fn new_route<TRequest, TResponse, TError>(
116        key: &'static str,
117        handler: Handler<TRequest, TResponse, TError>,
118    ) -> Route<TRequest, TResponse, TError> {
119        Route { key, handler }
120    }
121
122    pub const fn new_handler<TRequest, TResponse, TError>(
123        handler_fn: HandlerFn<TRequest, TResponse, TError>,
124    ) -> Handler<TRequest, TResponse, TError> {
125        Handler(handler_fn)
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        fmt::Debug,
        task::{Context, Poll, Waker},
    };
    use tower_service::Service;

    /// Test request whose routing key is stored verbatim.
    #[derive(Debug)]
    struct Request {
        key: String,
    }

    impl ExtractKey for Request {
        fn extract_key(&self) -> &str {
            self.key.as_str()
        }
    }

    /// Router type shared by every test in this module.
    type TestRouter = Router<Request, String, &'static str>;

    // Handlers under test: synchronous, async, request-consuming, generic,
    // failing, and a zero-argument fallback.

    #[handler]
    fn alpha_handler(_req: Request) -> Result<String, &'static str> {
        Ok("alpha".to_owned())
    }

    #[handler]
    async fn async_handler(_req: Request) -> Result<String, &'static str> {
        Ok("async".to_owned())
    }

    #[handler]
    fn echo_handler(req: Request) -> Result<String, &'static str> {
        Ok(req.key)
    }

    #[handler]
    fn generic_handler<T, const N: usize>(_req: Request) -> Result<String, &'static str>
    where
        T: Default,
    {
        let _ = T::default();

        Ok(format!("generic-{N}"))
    }

    #[handler]
    fn error_handler(_req: Request) -> Result<String, &'static str> {
        Err("route failed")
    }

    #[handler]
    fn fallback_handler() -> Result<String, &'static str> {
        Ok("fallback".to_owned())
    }

    // Routes are deliberately listed out of order to exercise the macro's
    // sorting.
    static ROUTER: TestRouter = router! {
        fallback_handler,
        "/generic" => generic_handler::<usize, 7>,
        "/error" => error_handler,
        "/async" => async_handler,
        "/echo" => echo_handler,
        "/alpha" => alpha_handler,
    };

    static FALLBACK_ONLY_ROUTER: TestRouter = router! {
        fallback_handler,
    };

    /// Builds a test request carrying `key`.
    fn request(key: impl Into<String>) -> Request {
        let key: String = key.into();
        Request { key }
    }

    /// Polls a handler future once and returns its output.
    ///
    /// Every handler here resolves immediately, so `Pending` is a test failure.
    fn ready<T, E>(mut future: BoxFuture<T, E>) -> Result<T, E> {
        let mut cx = Context::from_waker(Waker::noop());
        let Poll::Ready(result) = future.as_mut().poll(&mut cx) else {
            panic!("handler future did not complete");
        };
        result
    }

    /// Routes `key` through `router` and resolves the resulting future.
    fn route_result(router: &TestRouter, key: &str) -> Result<String, &'static str> {
        let pending = router.handle(request(key));
        ready(pending)
    }

    /// Asserts that routing `key` through `router` produces `expected`.
    fn assert_route(router: &TestRouter, key: &str, expected: Result<&'static str, &'static str>) {
        let expected = expected.map(String::from);
        assert_eq!(route_result(router, key), expected);
    }

    /// Asserts that a readiness poll reported `Ready(Ok(()))`.
    fn assert_ready<E>(poll: Poll<Result<(), E>>)
    where
        E: Debug + PartialEq,
    {
        let expected: Poll<Result<(), E>> = Poll::Ready(Ok(()));
        assert_eq!(poll, expected);
    }

    #[test]
    fn router_macro_matches_sorted_routes_and_fallback() {
        let cases: [(&str, Result<&'static str, &'static str>); 7] = [
            ("/alpha", Ok("alpha")),
            ("/async", Ok("async")),
            ("/echo", Ok("/echo")),
            ("/generic", Ok("generic-7")),
            ("/error", Err("route failed")),
            ("/missing", Ok("fallback")),
            ("", Ok("fallback")),
        ];

        for (key, expected) in cases {
            assert_route(&ROUTER, key, expected);
        }
    }

    #[test]
    fn fallback_only_router_routes_every_request_to_fallback() {
        ["", "/", "/alpha", "/missing"]
            .iter()
            .for_each(|key| assert_route(&FALLBACK_ONLY_ROUTER, key, Ok("fallback")));
    }

    #[test]
    fn router_implements_tower_service() {
        static ROUTES: [Route<Request, String, &'static str>; 1] =
            [__internal::new_route("/alpha", alpha_handler())];
        let mut router = __internal::new_router(fallback_handler(), &ROUTES);
        let mut cx = Context::from_waker(Waker::noop());

        assert_ready(Service::poll_ready(&mut router, &mut cx));

        let hit = ready(Service::call(&mut router, request("/alpha")));
        assert_eq!(hit, Ok("alpha".to_owned()));

        let miss = ready(Service::call(&mut router, request("/unknown")));
        assert_eq!(miss, Ok("fallback".to_owned()));
    }

    #[cfg(feature = "http")]
    #[test]
    fn http_request_extracts_uri_path() {
        let cases = [
            ("https://example.com/static?ignored=true", "/static"),
            ("/nested/path#fragment", "/nested/path"),
            ("*", "*"),
        ];

        for (uri, expected_path) in cases {
            let req = http::Request::builder()
                .uri(uri)
                .body(())
                .expect("request should build");

            assert_eq!(req.extract_key(), expected_path);
        }
    }
}