1#![doc = include_str!("../README.md")]
2
3use crate::__internal::{Handler, Route};
4use std::{
5 future::Future,
6 pin::Pin,
7 task::{Context, Poll},
8};
9use tower_service::Service;
10
11extern crate self as const_router;
12
13pub use const_router_macros::{handler, router};
14
15#[derive(Debug)]
21pub struct Router<TRequest, TResponse, TError>
22where
23 TResponse: 'static,
24 TError: 'static,
25 TRequest: 'static,
26{
27 fallback: Handler<TRequest, TResponse, TError>,
28 routes: &'static [Route<TRequest, TResponse, TError>],
29}
30
/// Supplies the routing key of a request.
///
/// [`Router::handle`] calls this once per request and looks the returned
/// string up among the router's route keys. A route matches only when its key
/// is exactly equal to the returned string. With the `http` feature enabled,
/// `http::Request` implements this trait by returning its URI path.
pub trait ExtractKey {
    /// Returns the key used to select a route for this request.
    fn extract_key(&self) -> &str;
}
38
/// Heap-allocated, pinned future produced by every handler and returned by
/// [`Router::handle`].
///
/// It resolves to `Result<TResponse, TError>` and is `Send + 'static`, so it
/// can be handed to executors that require those bounds.
pub type BoxFuture<TResponse, TError> = Pin<
    Box<dyn Future<Output = Result<TResponse, TError>> + Send + 'static>,
>;
42
43impl<TRequest, TResponse, TError> Router<TRequest, TResponse, TError>
44where
45 TRequest: ExtractKey,
46{
47 pub fn handle(&self, req: TRequest) -> BoxFuture<TResponse, TError> {
49 let key = req.extract_key();
50
51 let handler = match self.routes.binary_search_by(|route| route.key.cmp(key)) {
52 Ok(index) => &self.routes[index].handler,
53 Err(_) => &self.fallback,
54 };
55
56 (handler.0)(req)
57 }
58}
59
60impl<TRequest, TResponse, TError> Service<TRequest> for Router<TRequest, TResponse, TError>
61where
62 TResponse: 'static,
63 TError: 'static,
64 TRequest: ExtractKey + 'static,
65{
66 type Error = TError;
67 type Future = BoxFuture<TResponse, TError>;
68 type Response = TResponse;
69
70 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
71 Poll::Ready(Ok(()))
72 }
73
74 fn call(&mut self, req: TRequest) -> Self::Future {
75 self.handle(req)
76 }
77}
78
/// Routes `http::Request`s by the path component of their URI.
///
/// Only `Uri::path` is returned, so the scheme, authority and query string do
/// not take part in matching. No percent-decoding or normalization is
/// performed here.
#[cfg(feature = "http")]
impl<B> ExtractKey for http::Request<B> {
    fn extract_key(&self) -> &str {
        let uri = self.uri();
        uri.path()
    }
}
85
86#[doc(hidden)]
87pub mod __internal {
88 use super::*;
89
90 #[derive(Debug)]
91 pub struct Route<TRequest, TResponse, TError> {
92 pub(crate) key: &'static str,
93 pub(crate) handler: Handler<TRequest, TResponse, TError>,
94 }
95
96 #[derive(Debug)]
97 pub struct Handler<TRequest, TResponse, TError>(
98 pub(crate) HandlerFn<TRequest, TResponse, TError>,
99 );
100
101 type HandlerFn<TRequest, TResponse, TError> = fn(TRequest) -> BoxFuture<TResponse, TError>;
102
103 pub const fn new_router<TRequest, TResponse, TError>(
104 fallback: Handler<TRequest, TResponse, TError>,
105 routes: &'static [Route<TRequest, TResponse, TError>],
106 ) -> Router<TRequest, TResponse, TError>
107 where
108 TResponse: 'static,
109 TError: 'static,
110 TRequest: 'static,
111 {
112 Router { fallback, routes }
113 }
114
115 pub const fn new_route<TRequest, TResponse, TError>(
116 key: &'static str,
117 handler: Handler<TRequest, TResponse, TError>,
118 ) -> Route<TRequest, TResponse, TError> {
119 Route { key, handler }
120 }
121
122 pub const fn new_handler<TRequest, TResponse, TError>(
123 handler_fn: HandlerFn<TRequest, TResponse, TError>,
124 ) -> Handler<TRequest, TResponse, TError> {
125 Handler(handler_fn)
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        fmt::Debug,
        task::{Context, Poll, Waker},
    };
    use tower_service::Service;

    /// Test request whose routing key is an owned string.
    #[derive(Debug)]
    struct Request {
        key: String,
    }

    impl ExtractKey for Request {
        fn extract_key(&self) -> &str {
            self.key.as_str()
        }
    }

    type TestRouter = Router<Request, String, &'static str>;

    // Handler shapes exercised below: sync, async, request-consuming, generic
    // (type + const parameter), error-returning, and a fallback that takes no
    // request argument. These are kept verbatim because their meaning depends
    // on how `#[handler]` expands them.

    #[handler]
    fn alpha_handler(_req: Request) -> Result<String, &'static str> {
        Ok("alpha".to_owned())
    }

    #[handler]
    async fn async_handler(_req: Request) -> Result<String, &'static str> {
        Ok("async".to_owned())
    }

    #[handler]
    fn echo_handler(req: Request) -> Result<String, &'static str> {
        Ok(req.key)
    }

    #[handler]
    fn generic_handler<T, const N: usize>(_req: Request) -> Result<String, &'static str>
    where
        T: Default,
    {
        let _ = T::default();

        Ok(format!("generic-{N}"))
    }

    #[handler]
    fn error_handler(_req: Request) -> Result<String, &'static str> {
        Err("route failed")
    }

    #[handler]
    fn fallback_handler() -> Result<String, &'static str> {
        Ok("fallback".to_owned())
    }

    // Routes are deliberately declared out of key order; lookups still have to
    // resolve every one of them.
    static ROUTER: TestRouter = router! {
        fallback_handler,
        "/generic" => generic_handler::<usize, 7>,
        "/error" => error_handler,
        "/async" => async_handler,
        "/echo" => echo_handler,
        "/alpha" => alpha_handler,
    };

    static FALLBACK_ONLY_ROUTER: TestRouter = router! {
        fallback_handler,
    };

    /// Builds a test request carrying `key`.
    fn request(key: impl Into<String>) -> Request {
        let key = key.into();
        Request { key }
    }

    /// Polls `future` exactly once with a no-op waker and returns its output.
    ///
    /// Panics if the future is still pending after that single poll.
    fn ready<T, E>(mut future: BoxFuture<T, E>) -> Result<T, E> {
        let mut cx = Context::from_waker(Waker::noop());
        let Poll::Ready(result) = future.as_mut().poll(&mut cx) else {
            panic!("handler future did not complete");
        };
        result
    }

    /// Routes a request for `key` through `router` and drives it to completion.
    fn route_result(router: &TestRouter, key: &str) -> Result<String, &'static str> {
        let future = router.handle(request(key));
        ready(future)
    }

    /// Asserts that routing `key` through `router` yields `expected`.
    fn assert_route(
        router: &TestRouter,
        key: &str,
        expected: Result<&'static str, &'static str>,
    ) {
        let expected: Result<String, &'static str> = expected.map(String::from);
        assert_eq!(route_result(router, key), expected);
    }

    /// Asserts that `poll` reports readiness without an error.
    fn assert_ready<E>(poll: Poll<Result<(), E>>)
    where
        E: Debug + PartialEq,
    {
        let expected: Poll<Result<(), E>> = Poll::Ready(Ok(()));
        assert_eq!(poll, expected);
    }

    #[test]
    fn router_macro_matches_sorted_routes_and_fallback() {
        let cases = [
            ("/alpha", Ok("alpha")),
            ("/async", Ok("async")),
            ("/echo", Ok("/echo")),
            ("/generic", Ok("generic-7")),
            ("/error", Err("route failed")),
            ("/missing", Ok("fallback")),
            ("", Ok("fallback")),
        ];

        for &(key, expected) in cases.iter() {
            assert_route(&ROUTER, key, expected);
        }
    }

    #[test]
    fn fallback_only_router_routes_every_request_to_fallback() {
        let keys = ["", "/", "/alpha", "/missing"];

        for key in keys.iter().copied() {
            assert_route(&FALLBACK_ONLY_ROUTER, key, Ok("fallback"));
        }
    }

    #[test]
    fn router_implements_tower_service() {
        static ROUTES: [Route<Request, String, &'static str>; 1] =
            [__internal::new_route("/alpha", alpha_handler())];

        let mut router = __internal::new_router(fallback_handler(), &ROUTES);
        let mut cx = Context::from_waker(Waker::noop());

        assert_ready(Service::poll_ready(&mut router, &mut cx));

        let routed = ready(Service::call(&mut router, request("/alpha")));
        assert_eq!(routed, Ok("alpha".to_owned()));

        let unrouted = ready(Service::call(&mut router, request("/unknown")));
        assert_eq!(unrouted, Ok("fallback".to_owned()));
    }

    #[cfg(feature = "http")]
    #[test]
    fn http_request_extracts_uri_path() {
        let cases = [
            ("https://example.com/static?ignored=true", "/static"),
            ("/nested/path#fragment", "/nested/path"),
            ("*", "*"),
        ];

        for &(uri, expected_path) in cases.iter() {
            let req = http::Request::builder()
                .uri(uri)
                .body(())
                .expect("request should build");

            assert_eq!(req.extract_key(), expected_path);
        }
    }
}