under/router/
mod.rs

1mod pattern;
2mod route;
3mod service;
4
5pub(crate) use self::pattern::Pattern;
6pub use self::route::Path;
7pub(crate) use self::route::Route;
8use crate::endpoint::Endpoint;
9use crate::middleware::Middleware;
10use crate::{Request, Response};
11use std::pin::Pin;
12use std::sync::Arc;
13use tokio::sync::watch;
14
15/// An HTTP router.
16///
17/// This contains a set of paths, and the [`Endpoint`]s they point
18/// to.  This expects a root, `/`, and all paths placed in this router are
19/// expected to be based off of this root.  Ultimately, this is an array of
20/// routes, where each route is a path, a method, and an endpoint.  If
21/// the incoming request matches on the path and method, then the last route
22/// inserted that matches will have its endpoint run.  So, assuming that you
23/// have the following routes defined:
24///
25/// ```text
26/// // ...
27/// POST /user/{id} -> endpoint_user_id
28/// POST /user/@me -> endpoint_user_me
29/// // ...
30/// ```
31///
32/// Even though the former route can match `/user/@me`, the latter route will
33/// always be picked instead, as it is defined _after_ the former.
34///
35/// # Internals
36///
37/// Internally, the router uses a regular expression matcher to convert the
38/// given paths (e.g. `/user/{id}`) into a regular expression
39/// (`^/user/(?P<id>[^/]+)`).  It does this segment-by-segment in the path, and
40/// is rather strict about what the names of a placeholder component can be
41/// (only alphabetical).  This is compiled into a `RegexSet`, which, when run
42/// against a given path, will return a list of routes that the path matches.
43/// Because we don't have to fool around with matching every route, the timing
44/// is `O(n)`, with `n` being the length of the input path.  After the
45/// `RegexSet` match, we again match against the route to collect the pattern
46/// matchers (e.g. `{some}` and `{value:path}`), before returning both.  This
47/// information is included as a part of the request.
48pub struct Router {
49    regex: regex::RegexSet,
50    routes: Vec<Arc<Route>>,
51    middleware: Vec<Pin<Box<dyn Middleware>>>,
52    fallback: Option<Pin<Box<dyn Endpoint>>>,
53    terminate: Option<watch::Receiver<bool>>,
54}
55
56impl Default for Router {
57    fn default() -> Self {
58        Router {
59            regex: regex::RegexSet::empty(),
60            middleware: vec![],
61            routes: vec![],
62            fallback: None,
63            terminate: None,
64        }
65    }
66}
67
68impl Router {
69    /// Prepares the router, constructing the routes.
70    ///
71    /// This is automatically called when listening using [`Router::listen`].
72    /// However, you may want to use the router before that point for e.g.
73    /// testing, and so this must be called before any requests are routed
74    /// (or, if any routes are changed).  If this is not called, you will
75    /// receive only 500 errors.
76    #[allow(clippy::missing_panics_doc)]
77    pub fn prepare(&mut self) {
78        let patterns = self
79            .routes
80            .iter()
81            .map(|route| route.pattern.regex().as_str());
82        // This shouldn't panic, because the patterns were already validated
83        // (e.g. if any of them were invalid, we would have already panicked).
84        let set = regex::RegexSet::new(patterns).unwrap();
85        self.regex = set;
86    }
87
88    pub(crate) fn routes(&self) -> &[Arc<Route>] {
89        &self.routes[..]
90    }
91
92    /// Creates a [`Path`] at the provided prefix.  See [`Path::at`] for more.
93    pub fn at<P: AsRef<str>>(&mut self, prefix: P) -> Path<'_> {
94        Path::new(join_paths("", prefix.as_ref()), &mut self.routes)
95    }
96
97    /// Creates a [`Path`] at the provided prefix, and executes the provided
98    /// closure with it.  See [`Path::under`] for more.
99    pub fn under<P: AsRef<str>, F: FnOnce(&mut Path<'_>)>(
100        &mut self,
101        prefix: P,
102        build: F,
103    ) -> &mut Self {
104        let mut path = Path::new(join_paths("", prefix.as_ref()), &mut self.routes);
105        build(&mut path);
106        self
107    }
108
109    /// Appends middleware to the router.  Each middleware is executed in the
110    /// order that it is appended to the router (i.e., the first middleware
111    /// inserted executes first).
112    ///
113    /// # Examples
114    /// ```rust
115    /// let mut http = under::http();
116    /// http.with(under::middleware::TraceMiddleware::new())
117    ///     .with(under::middleware::StateMiddleware::new(123u32));
118    /// ```
119    pub fn with<M: Middleware>(&mut self, middleware: M) -> &mut Self {
120        self.middleware.push(Box::pin(middleware));
121        self
122    }
123
124    /// Sets a fallback endpoint.  If there exists no other endpoint in the
125    /// router that could potentially respond to the request, it will first
126    /// attempt to execute this fallback endpoint, before instead returning
127    /// an empty 500 error.
128    ///
129    /// # Examples
130    /// ```rust
131    /// # use under::*;
132    /// # #[tokio::main] async fn main() -> Result<(), anyhow::Error> {
133    /// let mut http = under::http();
134    /// http.at("/foo").get(under::endpoints::simple(Response::empty_204));
135    /// http.fallback(under::endpoints::simple(Response::empty_404));
136    /// http.prepare();
137    /// let response = http.handle(Request::get("/foo")?).await?;
138    /// assert_eq!(response.status(), http::StatusCode::NO_CONTENT);
139    /// let response = http.handle(Request::get("/bar")?).await?;
140    /// assert_eq!(response.status(), http::StatusCode::NOT_FOUND);
141    /// # Ok(())
142    /// # }
143    pub fn fallback<E: Endpoint>(&mut self, endpoint: E) -> &mut Self {
144        self.fallback = Some(Box::pin(endpoint));
145        self
146    }
147
148    /// A channel to handle the termination singal.  By default, the router does
149    /// not terminate, at least not gracefully, even in the face of
150    /// SIGINT/SIGTERM.  This allows you to signal to the router when it should
151    /// terminate, and it will gracefully shut down, letting all current
152    /// requests finish before exiting.  Note that the return type is not
153    /// `Clone`, and dropping the sender will not terminate the router.
154    ///
155    /// Note this only applies to the router when listening, and not when
156    /// handling a single request.
157    pub fn termination_signal(&mut self) -> watch::Sender<bool> {
158        let (tx, rx) = watch::channel(false);
159        self.terminate = Some(rx);
160        tx
161    }
162
163    /// Handles a one-off request to the router.  This is equivalent to pinning
164    /// the router (with [`Pin::new`], since the Router is `Unpin`), before
165    /// calling [`crate::Endpoint::apply`].
166    ///
167    /// # Errors
168    /// This will error if any middleware or endpoint errors.  Note that this
169    /// does not error if the router was not prepared before calling this
170    /// method.
171    pub async fn handle(&self, request: Request) -> Result<Response, anyhow::Error> {
172        Pin::new(self).apply(request).await
173    }
174
175    pub(crate) fn lookup(&self, path: &str, method: &http::Method) -> Option<Arc<Route>> {
176        self.regex
177            .matches(path)
178            .into_iter()
179            .map(|i| &self.routes[i])
180            .filter(|r| r.matches(method))
181            .next_back()
182            .cloned()
183    }
184
185    fn fallback_endpoint(&self) -> Option<Pin<&dyn Endpoint>> {
186        self.fallback.as_ref().map(Pin::as_ref)
187    }
188}
189
190#[async_trait]
191impl crate::Endpoint for Router {
192    async fn apply(self: Pin<&Self>, mut request: Request) -> Result<Response, anyhow::Error> {
193        let route = self.lookup(request.uri().path(), request.method());
194        if let Some(route) = route.clone() {
195            // This should most always be a `Some`, because the route's path
196            // would 100% match the uri's path.
197            if let Some(fragment) =
198                crate::request::fragment::Fragment::new(request.uri().path(), &route)
199            {
200                request.extensions_mut().insert(fragment);
201            }
202            request.extensions_mut().insert(route);
203        }
204
205        let endpoint = {
206            let route_endpoint = || route.as_ref().map(|e| e.endpoint().as_ref());
207            let fallback_endpoint = || self.fallback_endpoint();
208            route_endpoint()
209                .or_else(fallback_endpoint)
210                .unwrap_or_else(default_endpoint)
211        };
212        log::trace!("{} {} --> {:?}", request.method(), request.uri(), endpoint);
213        let next = crate::middleware::Next::new(&self.middleware[..], endpoint);
214        next.apply(request).await
215    }
216}
217
218impl std::fmt::Debug for Router {
219    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220        f.debug_struct("Router")
221            .field("regex", &self.regex)
222            .field("routes", &self.routes)
223            .finish()
224    }
225}
226
227lazy_static::lazy_static! {
228    static ref DEFAULT_ENDPOINT: crate::endpoints::SyncEndpoint<fn(Request) -> Response> = crate::endpoints::SyncEndpoint::new(|_| Response::empty_500());
229    static ref DEFAULT_ENDPOINT_PIN: Pin<&'static (dyn Endpoint + Unpin + 'static)> = Pin::new(&*DEFAULT_ENDPOINT);
230}
231
232// 'r can be anything _up to and including_ 'static, and this makes it play
233// nice with unwrap_or_else.
234pub(crate) fn default_endpoint<'r>() -> Pin<&'r dyn Endpoint> {
235    *DEFAULT_ENDPOINT_PIN
236}
237
/// Joins `extend` onto `base` with exactly one `/` between them.
///
/// `base` *MUST* be either `""` or start with `"/"`.  A single leading `/`
/// on `extend` is merged with a trailing `/` on `base`.  If neither side
/// provides a separator, one is inserted.  Only one slash is collapsed, so
/// `("/a/", "//b")` yields `"/a//b"`.
fn join_paths(base: &str, extend: &str) -> String {
    // Reserve room for both parts plus a possibly inserted separator, so the
    // buffer never has to grow (the old capacity was one byte short).
    let mut joined = String::with_capacity(base.len() + extend.len() + 1);
    joined.push_str(base);

    match (base.ends_with('/'), extend.strip_prefix('/')) {
        // Both sides supply a `/`; keep only the one from `base`.
        (true, Some(rest)) => joined.push_str(rest),
        // Neither side supplies a `/`; insert one.
        (false, None) => {
            joined.push('/');
            joined.push_str(extend);
        }
        // Exactly one side supplies the `/`.
        (true, None) | (false, Some(_)) => joined.push_str(extend),
    }

    joined
}
259
#[cfg(test)]
mod test {
    use super::*;
    use crate::request::Request;
    use crate::response::Response;
    use crate::UnderError;

    /// Placeholder endpoint.  These tests only call `lookup`, so it is never
    /// invoked (it would panic via `unimplemented!`).
    // Kept `async` so it can be registered as an endpoint, even though it
    // never awaits anything.
    #[allow(clippy::unused_async)]
    async fn simple_endpoint(_: Request) -> Result<Response, UnderError> {
        unimplemented!()
    }

    /// Builds a prepared router with a literal root, a literal path, a
    /// single-segment placeholder, and a `path` placeholder, all GET-only.
    fn simple_router() -> Router {
        let mut router = Router::default();
        router.at("/").get(simple_endpoint);
        router.at("/alpha").get(simple_endpoint);
        router.at("/beta/{id}").get(simple_endpoint);
        router.at("/gamma/{all:path}").get(simple_endpoint);
        router.prepare();
        router
    }

    #[test]
    fn test_join_paths() {
        assert_eq!(join_paths("", "/id"), "/id");
        assert_eq!(join_paths("", "id"), "/id");
        assert_eq!(join_paths("/user", "/id"), "/user/id");
        assert_eq!(join_paths("/user/", "/id"), "/user/id");
        assert_eq!(join_paths("/user/", "id"), "/user/id");
        assert_eq!(join_paths("", ""), "/");
        assert_eq!(join_paths("/", "/"), "/");
    }

    #[test]
    fn test_build() {
        simple_router();
    }

    #[test]
    fn test_basic_match() {
        let router = simple_router();
        let result = router
            .lookup("/", &http::Method::GET)
            .expect("`/` should match the root route");
        assert_eq!("/", &result.path);
    }

    #[test]
    fn test_simple_match() {
        let router = simple_router();
        let result = router
            .lookup("/beta/4444", &http::Method::GET)
            .expect("`/beta/4444` should match `/beta/{id}`");
        assert_eq!("/beta/{id}", &result.path);
    }

    #[test]
    fn test_multi_match() {
        let router = simple_router();
        let result = router
            .lookup("/gamma/a/b/c", &http::Method::GET)
            .expect("`/gamma/a/b/c` should match `/gamma/{all:path}`");
        assert_eq!("/gamma/{all:path}", &result.path);
    }

    #[test]
    fn test_missing_match() {
        let router = simple_router();
        let result = router.lookup("/omega/aaa", &http::Method::GET);
        assert!(result.is_none());
    }

    #[test]
    fn test_correct_method() {
        let router = simple_router();
        let result = router.lookup("/alpha", &http::Method::POST);
        assert!(result.is_none());
    }

    #[test]
    fn test_last_inserted_route_wins() {
        let mut router = Router::default();
        router.at("/beta/{id}").get(simple_endpoint);
        router.at("/beta/fixed").get(simple_endpoint);
        router.prepare();
        let result = router
            .lookup("/beta/fixed", &http::Method::GET)
            .expect("`/beta/fixed` should match both routes");
        assert_eq!("/beta/fixed", &result.path);
    }
}