// nucleus_http/routes.rs

1use crate::{
2    http::{self, Header, Method, MimeType},
3    request::Request,
4    response::{IntoResponse, Response},
5    state::{FromRequest, State},
6};
7use async_trait::async_trait;
8use enum_map::{enum_map, EnumMap};
9use std::{
10    collections::HashMap,
11    future::Future,
12    path::{Path, PathBuf},
13    sync::Arc,
14    vec,
15};
16use tokio::sync::RwLock;
17
/// A request handler the router can invoke for `RouteResolver::Function` routes.
///
/// Implementors must be `Send + Sync + 'static` because they are stored behind an
/// `Arc<Box<dyn RequestResolver<S>>>` inside the shared route table.
#[async_trait]
pub trait RequestResolver<S>: Send + Sync + 'static {
    /// Produces the response for `request`, given the router's `state`.
    async fn resolve(&self, state: State<S>, request: Request) -> Response;
}
22
23#[async_trait]
24impl<F, P, O, E, Fut> RequestResolver<P> for F
25where
26    O: IntoResponse,
27    E: IntoResponse,
28    Fut: Future<Output = Result<O, E>> + Send + 'static,
29    F: Fn(P, Request) -> Fut + Send + Sync + 'static,
30    P: FromRequest<P> + Send + Sync + 'static,
31{
32    async fn resolve(&self, state: State<P>, request: Request) -> Response {
33        let result = (self)(P::from_request(state, request.clone()), request).await;
34        match result {
35            Ok(r) => r.into_response(),
36            Err(e) => e.into_response(),
37        }
38    }
39}
40
/// How a matched route produces its response.
pub enum RouteResolver<S> {
    /// Serve the file at `file_path`, joined onto the document root passed to
    /// `Router::route`.
    Static { file_path: String },
    /// Answer with a `MOVED_PERMANENTLY` response whose `Location` header is the
    /// contained string.
    Redirect(String),
    /// Delegate to a handler; it receives a clone of the router state and the request.
    Function(Arc<Box<dyn RequestResolver<S>>>),
    /// Answer `OK` with a copy of the embedded bytes and the given MIME type.
    Embed(&'static [u8], MimeType),
}
47
/// A single routing table entry: an HTTP method plus a path, bound to a resolver.
pub struct Route<S> {
    /// HTTP method this route answers; selects the per-method map in the router.
    method: Method,
    /// Key the route is stored under. A trailing `*` segment (e.g. `/api/*`) is
    /// looked up by the router as a wildcard for paths below that prefix.
    path: String,
    /// What to do once the route has been selected.
    resolver: RouteResolver<S>,
}
53
/// Shared, lock-protected route table: one `path -> Route` map per HTTP method.
pub type Routes<R> = Arc<RwLock<EnumMap<Method, HashMap<String, Route<R>>>>>;
55
/// Maps incoming requests to responses.
///
/// Cloning a router clones the `Arc` around the route table, so clones share routes.
#[derive(Clone)]
pub struct Router<S> {
    /// Route table shared between clones of this router.
    routes: Routes<S>,
    /// State cloned and handed to function routes on every request.
    state: State<S>,
    /// Headers added only to responses whose MIME type equals the paired `MimeType`.
    mime_headers: Vec<(MimeType, Header)>,
    /// Headers added to every response.
    default_headers: Vec<Header>,
}
63
64impl<S> Router<S>
65where
66    S: Clone + Send + Sync + 'static,
67{
68    #[tracing::instrument(level = "debug", skip(state))]
69    pub fn new(state: S) -> Self {
70        let map = enum_map! {
71            crate::routes::Method::GET | crate::routes::Method::POST => HashMap::new(),
72        };
73        Router {
74            routes: Arc::new(RwLock::new(map)),
75            state: State(state),
76            mime_headers: vec![],
77            default_headers: Header::new_server(), // default server headers. server sw name
78        }
79    }
80
81    #[tracing::instrument(level = "debug", skip(self, route))]
82    pub async fn add_route(&mut self, route: Route<S>) {
83        let mut routes_locked = self.routes.write().await;
84        routes_locked[*route.method()].insert(route.path.clone(), route);
85    }
86
87    #[tracing::instrument(level = "debug", skip(self))]
88    pub fn routes(&self) -> Routes<S> {
89        Arc::clone(&self.routes)
90    }
91
92    /// Add header to all responses
93    #[tracing::instrument(level = "debug", skip(self))]
94    pub fn add_default_header(&mut self, header: Header) {
95        self.default_headers.push(header);
96    }
97
98    /// Add a header that will be added to every response of this mime type
99    #[tracing::instrument(level = "debug", skip(self))]
100    pub fn add_mime_header(&mut self, header: Header, mime: MimeType) {
101        self.mime_headers.push((mime, header));
102    }
103
104    /// Add default and mime headers to req
105    #[tracing::instrument(level = "debug", skip(self))]
106    fn push_headers(&self, response: &mut Response) {
107        //FIXME: Do we need to worry about duplicates ?
108        //add default headers first then mime specific ones
109        for header in &self.default_headers {
110            response.add_header(header);
111        }
112
113        let mime = response.mime();
114        for (key, header) in &self.mime_headers {
115            if key == &mime {
116                response.add_header(header);
117            }
118        }
119    }
120
121    pub fn new_routes() -> Routes<S> {
122        let map = enum_map! {
123            crate::routes::Method::GET | crate::routes::Method::POST => HashMap::new(),
124        };
125        Arc::new(RwLock::new(map))
126    }
127
128    #[tracing::instrument(level = "debug", skip(self, doc_root))]
129    pub async fn route(&self, request: &Request, doc_root: impl AsRef<Path>) -> Response {
130        let routes = self.routes();
131        let routes_locked = &routes.read().await[*request.method()];
132        let mut matching_route = None;
133
134        //look for route mathcing requested URL
135        if let Some(route) = routes_locked.get(request.path()) {
136            //found exact route match
137            matching_route = Some(route);
138        } else {
139            // go through ancestors appending * on the end and see if we have any matches
140            let path = Path::new(request.path());
141            if let Some(parent) = path.parent() {
142                let ancestors = parent.ancestors();
143                for a in ancestors {
144                    if let Some(globed) = a.join("*").to_str() {
145                        if let Some(route) = routes_locked.get(globed) {
146                            matching_route = Some(route);
147                        }
148                    }
149                }
150            } else {
151                //no parent so its root, check for catch all bare *
152                if let Some(route) = routes_locked.get("*") {
153                    matching_route = Some(route);
154                }
155            }
156        }
157
158        //serve specific route if we match
159        if let Some(route) = matching_route {
160            tracing::debug!("Found matching route");
161            match route.resolver() {
162                RouteResolver::Static { file_path } => {
163                    let path = doc_root.as_ref().join(file_path);
164                    let mut res = Self::get_file(path).await;
165                    self.push_headers(&mut res);
166                    res
167                }
168                RouteResolver::Redirect(redirect_to) => {
169                    let mut response = Response::new(
170                        http::StatusCode::MOVED_PERMANENTLY,
171                        vec![],
172                        MimeType::PlainText,
173                    );
174                    self.push_headers(&mut response);
175                    response.add_header(("Location", redirect_to));
176                    response
177                }
178                RouteResolver::Function(resolver) => {
179                    let resolver = resolver.clone();
180                    let mut response = resolver
181                        .resolve(self.state.clone(), request.to_owned())
182                        .await;
183                    self.push_headers(&mut response);
184                    response
185                }
186                RouteResolver::Embed(body, mime_type) => {
187                    let mut response =
188                        Response::new(http::StatusCode::OK, body.to_vec(), *mime_type);
189                    self.push_headers(&mut response);
190                    response
191                }
192            }
193        } else {
194            tracing::debug!("Trying static file serve");
195            let mut file_path = PathBuf::from(request.path());
196            if file_path.is_absolute() {
197                if let Ok(path) = file_path.strip_prefix("/") {
198                    file_path = path.to_path_buf();
199                } else {
200                    let mut response =
201                        Response::error(http::StatusCode::NOT_FOUND, "File Not Found".into());
202                    self.push_headers(&mut response);
203                    return response;
204                }
205            }
206            let final_path = doc_root.as_ref().join(file_path);
207            let mut response = Self::get_file(final_path).await;
208            self.push_headers(&mut response);
209            response
210        }
211    }
212
213    #[tracing::instrument(level = "debug")]
214    async fn get_file(path: PathBuf) -> Response {
215        match tokio::fs::read(&path).await {
216            Ok(contents) => {
217                let mime: MimeType = path.into();
218                Response::new(http::StatusCode::OK, contents, mime)
219            }
220            Err(err) => {
221                tracing::warn!("static load error:{}", err.to_string());
222                match err.kind() {
223                    std::io::ErrorKind::PermissionDenied => {
224                        Response::error(http::StatusCode::FORBIDDEN, "Permission Denied".into())
225                    }
226                    _ => Response::error(
227                        http::StatusCode::NOT_FOUND,
228                        "Static File Not Found".into(),
229                    ),
230                }
231            }
232        }
233    }
234}
235
236impl<S> Route<S>
237where
238    S: Clone + Send + Sync + 'static,
239{
240    /// Route that redirects to another URL
241    pub fn redirect(path: &str, redirect_url: &str) -> Self {
242        let method = Method::GET;
243        Route {
244            path: path.to_string(),
245            resolver: RouteResolver::Redirect(redirect_url.to_string()),
246            method,
247        }
248    }
249
250    /// Reroutes all traffic to url
251    pub fn redirect_all(redirect_url: &str) -> Self {
252        let method = Method::GET;
253        Route {
254            path: "*".to_string(),
255            resolver: RouteResolver::Redirect(redirect_url.to_string()),
256            method,
257        }
258    }
259
260    pub fn get<R>(path: &str, func: R) -> Self
261    where
262        R: RequestResolver<S>,
263    {
264        let method = Method::GET;
265        let resolver = RouteResolver::Function(Arc::new(Box::new(func)));
266        Route {
267            path: path.to_string(),
268            resolver,
269            method,
270        }
271    }
272
273    pub fn post<R>(path: &str, func: R) -> Self
274    where
275        R: RequestResolver<S>,
276    {
277        let method = Method::POST;
278        let resolver = RouteResolver::Function(Arc::new(Box::new(func)));
279        Route {
280            path: path.to_string(),
281            resolver,
282            method,
283        }
284    }
285
286    /// use include_bytes! to load a file as static
287    /// when this route is requested the static data is return with the passes mime type
288    pub fn embed(path: &str, body: &'static [u8], mime: MimeType) -> Self {
289        let method = Method::GET;
290        let resolver = RouteResolver::Embed(body, mime);
291        Route {
292            method,
293            path: path.into(),
294            resolver,
295        }
296    }
297
298    /// Static file map
299    /// Allows remapping a route to a file
300    /// {file_path} is a is relative path to static file (without leading /) that will be joined
301    /// with vhost root dir to serve
302    /// eg. path = / file_path = index.html will remap all "/" requests to index.html
303    pub fn get_static(path: &str, file_path: &str) -> Self {
304        let method = Method::GET;
305        let resolver = RouteResolver::Static {
306            file_path: file_path.to_string(),
307        };
308        Route {
309            path: path.to_string(),
310            resolver,
311            method,
312        }
313    }
314
315    pub fn method(&self) -> &Method {
316        &self.method
317    }
318
319    pub fn resolver(&self) -> &RouteResolver<S> {
320        &self.resolver
321    }
322
323    pub fn path(&self) -> &str {
324        &self.path
325    }
326
327    ///FIXME: this should be a little more robust and look for wild cards only if not route is
328    ///defined.
329    ///as well as look for redirect all paths first and default to them
330    ///for now if you want redirect all that should be the only route on the server
331    pub fn matches_request(&self, request: &Request) -> bool {
332        if request.method() != self.method() {
333            return false;
334        }
335
336        // Check for exact match or if route is wild card
337        let request_path = request.path();
338        let route_path = self.path();
339
340        request_path == route_path || route_path == "*"
341    }
342}
343
/// Builds a `Route::embed` route that serves a file compiled into the binary.
///
/// * `$route_path` - path the route is registered under.
/// * `$file_path` - string literal handed to `include_bytes!`; the MIME type is
///   derived from the same path via `PathBuf::from($file_path).into()`.
///
/// The expansion names `Route` and `PathBuf` by absolute path, so callers do not
/// need either of them in scope. A trailing comma is accepted.
#[macro_export]
macro_rules! embed_route {
    ($route_path:expr, $file_path:expr $(,)?) => {
        //embed file
        $crate::routes::Route::embed(
            $route_path,
            include_bytes!($file_path),
            ::std::path::PathBuf::from($file_path).into(),
        )
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `embed_route!` must produce a GET route carrying the file bytes and an HTML
    /// MIME type. Nothing here awaits, so a plain `#[test]` is enough.
    #[test]
    fn create_embedded_html_route() {
        let route: Route<()> = embed_route!("/test", "../index.html");
        assert_eq!(route.path, "/test", "route path incorrect");
        assert_eq!(route.method, Method::GET, "route method incorrect");
        if let RouteResolver::Embed(body, mime) = route.resolver {
            assert_eq!(
                include_bytes!("../index.html"),
                body,
                "embedded body incorect"
            );
            assert_eq!(MimeType::HTML, mime);
        } else {
            panic!("wrong route type");
        }
    }

    /// The same file must come back both through the doc-root fallback
    /// (`/index.html`) and through the static remap registered for `/`.
    #[tokio::test]
    async fn route_static_file() {
        let request =
            Request::from_string("GET /index.html HTTP/1.1\r\nHost: localhost\r\n\r\n".to_owned())
                .unwrap();
        let mut router = Router::new(());
        router.add_route(Route::get_static("/", "index.html")).await;

        let file = tokio::fs::read_to_string("./index.html").await.unwrap();
        let mut expected = Response::from(file);
        router.push_headers(&mut expected);
        assert_eq!(http::StatusCode::OK, expected.status());

        // No route is registered for /index.html, so this goes through the fallback.
        let response = router.route(&request, "./").await;
        assert_eq!(expected, response);

        // "/" is remapped to index.html by the static route.
        let request =
            Request::from_string("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n".to_owned()).unwrap();
        let response = router.route(&request, "./").await;
        assert_eq!(expected, response);
    }

    /// Handler that ignores its inputs and always answers "hello".
    async fn hello(_: (), _: Request) -> Result<String, String> {
        Ok("hello".to_owned())
    }

    /// An exact-path function route returns the handler's output.
    #[tokio::test]
    async fn route_basic() {
        let request =
            Request::from_string("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n".to_owned()).unwrap();
        let mut router = Router::new(());
        router.add_route(Route::get("/", hello)).await;

        let mut expected = Response::from("hello");
        router.push_headers(&mut expected);
        assert_eq!(http::StatusCode::OK, expected.status());

        let response = router.route(&request, "./").await;
        assert_eq!(expected, response);
    }

    /// Handler that echoes the requested path.
    async fn dynamic(_: (), req: Request) -> Result<String, String> {
        Ok(format!("Hello {}", req.path()))
    }

    /// A `/*` wildcard route receives requests for paths directly below the root.
    #[tokio::test]
    async fn route_dynamic() {
        let mut router = Router::new(());
        router.add_route(Route::get("/*", dynamic)).await;

        let mut expected = Response::from("Hello /bob");
        router.push_headers(&mut expected);
        assert_eq!(http::StatusCode::OK, expected.status());

        let request =
            Request::from_string("GET /bob HTTP/1.1\r\nHost: localhost\r\n\r\n".to_owned())
                .unwrap();
        let response = router.route(&request, "./").await;
        assert_eq!(expected, response);
    }
}
437}