1use crate::{
2 http::{self, Header, Method, MimeType},
3 request::Request,
4 response::{IntoResponse, Response},
5 state::{FromRequest, State},
6};
7use async_trait::async_trait;
8use enum_map::{enum_map, EnumMap};
9use std::{
10 collections::HashMap,
11 future::Future,
12 path::{Path, PathBuf},
13 sync::Arc,
14 vec,
15};
16use tokio::sync::RwLock;
17
/// An asynchronous request handler that a [`Route`] can dispatch to.
///
/// `S` is the router's shared-state type. The blanket implementation directly
/// below covers any `Fn(P, Request) -> Fut` whose future yields `Result<O, E>`
/// with `O: IntoResponse` and `E: IntoResponse`, so plain async functions can be
/// registered through `Route::get` / `Route::post`.
#[async_trait]
pub trait RequestResolver<S>: Send + Sync + 'static {
    /// Handles `request` and produces the response to send.
    ///
    /// `state` is the router's shared state; `Router::route` passes a clone of it.
    async fn resolve(&self, state: State<S>, request: Request) -> Response;
}
22
23#[async_trait]
24impl<F, P, O, E, Fut> RequestResolver<P> for F
25where
26 O: IntoResponse,
27 E: IntoResponse,
28 Fut: Future<Output = Result<O, E>> + Send + 'static,
29 F: Fn(P, Request) -> Fut + Send + Sync + 'static,
30 P: FromRequest<P> + Send + Sync + 'static,
31{
32 async fn resolve(&self, state: State<P>, request: Request) -> Response {
33 let result = (self)(P::from_request(state, request.clone()), request).await;
34 match result {
35 Ok(r) => r.into_response(),
36 Err(e) => e.into_response(),
37 }
38 }
39}
40
/// How a matched [`Route`] produces its response.
pub enum RouteResolver<S> {
    /// Serve the file at `file_path`, joined onto the router's document root.
    Static { file_path: String },
    /// Reply `301 Moved Permanently` with the contained URL as the `Location` header.
    Redirect(String),
    /// Call a user-supplied handler with a clone of the router state and the request.
    Function(Arc<Box<dyn RequestResolver<S>>>),
    /// Reply `200 OK` with a `'static` byte slice (e.g. from `include_bytes!` via
    /// `embed_route!`) tagged with the given MIME type.
    Embed(&'static [u8], MimeType),
}
47
/// A single routing rule: requests using `method` whose path matches `path` are
/// answered by `resolver`.
pub struct Route<S> {
    /// HTTP method this route is registered under.
    method: Method,
    /// Exact request path, or a wildcard pattern such as `/dir/*` or the catch-all `*`.
    path: String,
    /// What to do when the route matches.
    resolver: RouteResolver<S>,
}
53
/// Shared route table: for each [`Method`], a map from route path to [`Route`],
/// behind an `Arc` and an async `RwLock` so several routers can share and update it.
pub type Routes<R> = Arc<RwLock<EnumMap<Method, HashMap<String, Route<R>>>>>;
55
/// Dispatches requests to registered [`Route`]s, falling back to serving files
/// from a document root when no route matches.
///
/// Clones share the same route table, because `routes` is an `Arc`.
#[derive(Clone)]
pub struct Router<S> {
    /// Per-method route table, shared between clones.
    routes: Routes<S>,
    /// Shared state; a clone is passed to every `RouteResolver::Function` handler.
    state: State<S>,
    /// Headers added only to responses whose MIME type equals the paired `MimeType`.
    mime_headers: Vec<(MimeType, Header)>,
    /// Headers added to every response.
    default_headers: Vec<Header>,
}
63
64impl<S> Router<S>
65where
66 S: Clone + Send + Sync + 'static,
67{
68 #[tracing::instrument(level = "debug", skip(state))]
69 pub fn new(state: S) -> Self {
70 let map = enum_map! {
71 crate::routes::Method::GET | crate::routes::Method::POST => HashMap::new(),
72 };
73 Router {
74 routes: Arc::new(RwLock::new(map)),
75 state: State(state),
76 mime_headers: vec![],
77 default_headers: Header::new_server(), }
79 }
80
81 #[tracing::instrument(level = "debug", skip(self, route))]
82 pub async fn add_route(&mut self, route: Route<S>) {
83 let mut routes_locked = self.routes.write().await;
84 routes_locked[*route.method()].insert(route.path.clone(), route);
85 }
86
87 #[tracing::instrument(level = "debug", skip(self))]
88 pub fn routes(&self) -> Routes<S> {
89 Arc::clone(&self.routes)
90 }
91
92 #[tracing::instrument(level = "debug", skip(self))]
94 pub fn add_default_header(&mut self, header: Header) {
95 self.default_headers.push(header);
96 }
97
98 #[tracing::instrument(level = "debug", skip(self))]
100 pub fn add_mime_header(&mut self, header: Header, mime: MimeType) {
101 self.mime_headers.push((mime, header));
102 }
103
104 #[tracing::instrument(level = "debug", skip(self))]
106 fn push_headers(&self, response: &mut Response) {
107 for header in &self.default_headers {
110 response.add_header(header);
111 }
112
113 let mime = response.mime();
114 for (key, header) in &self.mime_headers {
115 if key == &mime {
116 response.add_header(header);
117 }
118 }
119 }
120
121 pub fn new_routes() -> Routes<S> {
122 let map = enum_map! {
123 crate::routes::Method::GET | crate::routes::Method::POST => HashMap::new(),
124 };
125 Arc::new(RwLock::new(map))
126 }
127
128 #[tracing::instrument(level = "debug", skip(self, doc_root))]
129 pub async fn route(&self, request: &Request, doc_root: impl AsRef<Path>) -> Response {
130 let routes = self.routes();
131 let routes_locked = &routes.read().await[*request.method()];
132 let mut matching_route = None;
133
134 if let Some(route) = routes_locked.get(request.path()) {
136 matching_route = Some(route);
138 } else {
139 let path = Path::new(request.path());
141 if let Some(parent) = path.parent() {
142 let ancestors = parent.ancestors();
143 for a in ancestors {
144 if let Some(globed) = a.join("*").to_str() {
145 if let Some(route) = routes_locked.get(globed) {
146 matching_route = Some(route);
147 }
148 }
149 }
150 } else {
151 if let Some(route) = routes_locked.get("*") {
153 matching_route = Some(route);
154 }
155 }
156 }
157
158 if let Some(route) = matching_route {
160 tracing::debug!("Found matching route");
161 match route.resolver() {
162 RouteResolver::Static { file_path } => {
163 let path = doc_root.as_ref().join(file_path);
164 let mut res = Self::get_file(path).await;
165 self.push_headers(&mut res);
166 res
167 }
168 RouteResolver::Redirect(redirect_to) => {
169 let mut response = Response::new(
170 http::StatusCode::MOVED_PERMANENTLY,
171 vec![],
172 MimeType::PlainText,
173 );
174 self.push_headers(&mut response);
175 response.add_header(("Location", redirect_to));
176 response
177 }
178 RouteResolver::Function(resolver) => {
179 let resolver = resolver.clone();
180 let mut response = resolver
181 .resolve(self.state.clone(), request.to_owned())
182 .await;
183 self.push_headers(&mut response);
184 response
185 }
186 RouteResolver::Embed(body, mime_type) => {
187 let mut response =
188 Response::new(http::StatusCode::OK, body.to_vec(), *mime_type);
189 self.push_headers(&mut response);
190 response
191 }
192 }
193 } else {
194 tracing::debug!("Trying static file serve");
195 let mut file_path = PathBuf::from(request.path());
196 if file_path.is_absolute() {
197 if let Ok(path) = file_path.strip_prefix("/") {
198 file_path = path.to_path_buf();
199 } else {
200 let mut response =
201 Response::error(http::StatusCode::NOT_FOUND, "File Not Found".into());
202 self.push_headers(&mut response);
203 return response;
204 }
205 }
206 let final_path = doc_root.as_ref().join(file_path);
207 let mut response = Self::get_file(final_path).await;
208 self.push_headers(&mut response);
209 response
210 }
211 }
212
213 #[tracing::instrument(level = "debug")]
214 async fn get_file(path: PathBuf) -> Response {
215 match tokio::fs::read(&path).await {
216 Ok(contents) => {
217 let mime: MimeType = path.into();
218 Response::new(http::StatusCode::OK, contents, mime)
219 }
220 Err(err) => {
221 tracing::warn!("static load error:{}", err.to_string());
222 match err.kind() {
223 std::io::ErrorKind::PermissionDenied => {
224 Response::error(http::StatusCode::FORBIDDEN, "Permission Denied".into())
225 }
226 _ => Response::error(
227 http::StatusCode::NOT_FOUND,
228 "Static File Not Found".into(),
229 ),
230 }
231 }
232 }
233 }
234}
235
236impl<S> Route<S>
237where
238 S: Clone + Send + Sync + 'static,
239{
240 pub fn redirect(path: &str, redirect_url: &str) -> Self {
242 let method = Method::GET;
243 Route {
244 path: path.to_string(),
245 resolver: RouteResolver::Redirect(redirect_url.to_string()),
246 method,
247 }
248 }
249
250 pub fn redirect_all(redirect_url: &str) -> Self {
252 let method = Method::GET;
253 Route {
254 path: "*".to_string(),
255 resolver: RouteResolver::Redirect(redirect_url.to_string()),
256 method,
257 }
258 }
259
260 pub fn get<R>(path: &str, func: R) -> Self
261 where
262 R: RequestResolver<S>,
263 {
264 let method = Method::GET;
265 let resolver = RouteResolver::Function(Arc::new(Box::new(func)));
266 Route {
267 path: path.to_string(),
268 resolver,
269 method,
270 }
271 }
272
273 pub fn post<R>(path: &str, func: R) -> Self
274 where
275 R: RequestResolver<S>,
276 {
277 let method = Method::POST;
278 let resolver = RouteResolver::Function(Arc::new(Box::new(func)));
279 Route {
280 path: path.to_string(),
281 resolver,
282 method,
283 }
284 }
285
286 pub fn embed(path: &str, body: &'static [u8], mime: MimeType) -> Self {
289 let method = Method::GET;
290 let resolver = RouteResolver::Embed(body, mime);
291 Route {
292 method,
293 path: path.into(),
294 resolver,
295 }
296 }
297
298 pub fn get_static(path: &str, file_path: &str) -> Self {
304 let method = Method::GET;
305 let resolver = RouteResolver::Static {
306 file_path: file_path.to_string(),
307 };
308 Route {
309 path: path.to_string(),
310 resolver,
311 method,
312 }
313 }
314
315 pub fn method(&self) -> &Method {
316 &self.method
317 }
318
319 pub fn resolver(&self) -> &RouteResolver<S> {
320 &self.resolver
321 }
322
323 pub fn path(&self) -> &str {
324 &self.path
325 }
326
327 pub fn matches_request(&self, request: &Request) -> bool {
332 if request.method() != self.method() {
333 return false;
334 }
335
336 let request_path = request.path();
338 let route_path = self.path();
339
340 request_path == route_path || route_path == "*"
341 }
342}
343
/// Builds a GET [`Route`] at `$route_path` that serves the file `$file_path`,
/// embedded at compile time with `include_bytes!`. The MIME type is derived
/// from the file path.
///
/// `$file_path` must be a string literal (an `include_bytes!` requirement). It
/// is resolved relative to the source file that invokes the macro. `Route` must
/// be in scope at the call site. `PathBuf` and `include_bytes!` are referenced
/// by absolute path, so callers do not need to import them.
#[macro_export]
macro_rules! embed_route {
    ($route_path:expr, $file_path:expr) => {
        Route::embed(
            $route_path,
            ::std::include_bytes!($file_path),
            ::std::path::PathBuf::from($file_path).into(),
        )
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::virtual_host::VirtualHost;

    /// `embed_route!` yields a GET `Embed` route carrying the file's bytes and
    /// the MIME type inferred from its `.html` extension.
    #[tokio::test]
    async fn create_embedded_html_route() {
        let route: Route<()> = embed_route!("/test", "../index.html");
        assert_eq!(route.path, "/test", "route path incorrect");
        assert_eq!(route.method, Method::GET, "route method incorrect");
        match route.resolver {
            RouteResolver::Embed(body, mime) => {
                assert_eq!(
                    include_bytes!("../index.html"),
                    body,
                    "embedded body incorect"
                );
                assert_eq!(MimeType::HTML, mime);
            }
            _ => panic!("wrong route type"),
        }
    }

    /// `/index.html` (file fallback) and `/` (explicit static route) both return
    /// `./index.html` with the router's headers applied.
    #[tokio::test]
    async fn route_static_file() {
        let mut router = Router::new(());
        router.add_route(Route::get_static("/", "index.html")).await;

        let contents = tokio::fs::read_to_string("./index.html").await.unwrap();
        let mut expected = Response::from(contents);
        router.push_headers(&mut expected);
        assert_eq!(http::StatusCode::OK, expected.status());

        let raw_requests = [
            "GET /index.html HTTP/1.1\r\nHost: localhost\r\n\r\n",
            "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n",
        ];
        for raw in raw_requests {
            let request = Request::from_string(raw.to_owned()).unwrap();
            let response = router.route(&request, "./").await;
            assert_eq!(expected, response);
        }
    }

    /// Handler for `route_basic`: always answers `"hello"`.
    async fn hello(_: (), _: Request) -> Result<String, String> {
        Ok(String::from("hello"))
    }

    /// An exact-path GET route dispatches to its handler.
    #[tokio::test]
    async fn route_basic() {
        let mut router = Router::new(());
        router.add_route(Route::get("/", hello)).await;

        let mut expected = Response::from("hello");
        router.push_headers(&mut expected);
        assert_eq!(http::StatusCode::OK, expected.status());

        let request =
            Request::from_string("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n".to_owned()).unwrap();
        assert_eq!(expected, router.route(&request, "./").await);
    }

    /// Handler for `route_dynamic`: echoes the request path.
    async fn dynamic(_: (), req: Request) -> Result<String, String> {
        let greeting = format!("Hello {}", req.path());
        Ok(greeting)
    }

    /// A `/*` wildcard route catches `/bob` and the handler sees the real path.
    #[tokio::test]
    async fn route_dynamic() {
        let mut router = Router::new(());
        router.add_route(Route::get("/*", dynamic)).await;

        let mut expected = Response::from("Hello /bob");
        router.push_headers(&mut expected);
        assert_eq!(http::StatusCode::OK, expected.status());

        let request =
            Request::from_string("GET /bob HTTP/1.1\r\nHost: localhost\r\n\r\n".to_owned())
                .unwrap();
        assert_eq!(expected, router.route(&request, "./").await);
    }
}