1use std::{collections::HashMap, convert::Infallible, ffi::OsStr, fs, path::Path, sync::Arc};
2
3use bytes::Bytes;
4use http_body_util::{BodyExt, Full};
5use hyper::{Method, Uri};
6use tokio::sync::{
7 mpsc::{self, Sender},
8 oneshot,
9};
10
11use crate::{
12 errors::{default_error_page, StatusCode},
13 request::{Catch, Endpoint},
14 uri::index,
15};
16
/// Requests understood by the background task started in `Router::serve_routes`.
///
/// Every variant carries a oneshot sender on which the task sends its answer.
#[derive(Debug)]
pub enum Command {
    /// Ask for the route registered under `method` whose path matches `path`.
    Get {
        /// HTTP method used to select the list of candidate routes.
        method: Method,
        /// Request path that is matched against the registered route paths.
        path: String,
        /// Receives the matching route, or `None` when nothing matches.
        response: oneshot::Sender<Option<Route>>,
    },
    /// Ask for the error handler registered for `code`.
    Error {
        /// Status code whose handler is wanted.
        code: u16,
        /// Receives the handler for `code`, otherwise the handler registered
        /// under code `0`, otherwise `None`.
        response: oneshot::Sender<Option<ErrorHandler>>,
    },
}
30
/// A registered endpoint, shared behind an `Arc` so it can be handed out to
/// request handlers by cloning the pointer.
#[derive(Debug, Clone)]
pub struct Route(pub Arc<dyn Endpoint>);
33
/// A registered error catcher, shared behind an `Arc` so it can be handed out
/// to request handlers by cloning the pointer.
#[derive(Debug, Clone)]
pub struct ErrorHandler(pub Arc<dyn Catch>);
36
/// Collects endpoints, error handlers and the static-asset prefix, and
/// dispatches incoming requests to them.
#[derive(Clone)]
pub struct Router {
    /// Sender side of the command channel; `None` until `serve_routes` is called.
    channel: Option<Sender<Command>>,
    /// Registered routes grouped by HTTP method.
    router: HashMap<Method, Vec<Route>>,
    /// Registered error handlers keyed by status code.
    catch: HashMap<u16, ErrorHandler>,
    /// Prefix prepended to the request path when looking for a static file.
    assets: String,
}
44impl Router {
45 pub fn new() -> Self {
46 Router {
47 channel: None,
48 router: HashMap::new(),
49 catch: HashMap::new(),
50 assets: "assets/".to_string(),
51 }
52 }
53
54 pub fn assets(&mut self, path: String) {
55 self.assets = path;
56 }
57
58 pub fn catch(&mut self, catch: Arc<dyn Catch>) {
59 if !self.catch.contains_key(&catch.code()) {
60 self.catch.insert(catch.code(), ErrorHandler(catch));
61 }
62 }
63
64 pub fn route(&mut self, route: Arc<dyn Endpoint>) {
65 for method in route.methods() {
66 if !self.router.contains_key(&method) {
67 self.router.insert(method.clone(), Vec::new());
68 }
69 self.router
70 .get_mut(&method)
71 .unwrap()
72 .push(Route(route.clone()));
73 }
74 }
75
76 pub fn serve_routes(&mut self) {
81 let (tx, mut rx) = mpsc::channel::<Command>(32);
82 let router = self.router.clone();
83 let catch = self.catch.clone();
84
85 tokio::spawn(async move {
86 'watcher: while let Some(cmd) = rx.recv().await {
87 use Command::*;
88
89 match cmd {
90 Get {
91 method,
92 path,
93 response,
94 } => {
95 match router.get(&method) {
96 Some(data) => {
97 match index(
98 &path,
99 &data.iter().map(|r| r.0.path()).collect::<Vec<String>>(),
100 ) {
101 Some(index) => {
102 response.send(Some(data[index].clone())).unwrap();
103 continue 'watcher;
104 }
105 _ => {}
106 }
107 }
108 _ => {}
109 };
110 response.send(None).unwrap();
111 }
112 Error { code, response } => {
113 if catch.contains_key(&code) {
114 response
115 .send(catch.get(&code).map(|eh| eh.clone()))
116 .unwrap()
117 } else if catch.contains_key(&0) {
118 response.send(catch.get(&0).map(|eh| eh.clone())).unwrap()
119 } else {
120 response.send(None).unwrap()
121 }
122 }
123 }
124 }
125 });
126
127 self.channel = Some(tx);
128 }
129
130 async fn error(
131 &self,
132 uri: &Uri,
133 method: &Method,
134 body: &Vec<u8>,
135 code: u16,
136 reason: String,
137 channel: Sender<Command>,
138 ) -> std::result::Result<hyper::Response<Full<Bytes>>, Infallible> {
139 let (error_tx, error_rx) = oneshot::channel();
140 match channel
141 .send(Command::Error {
142 code: code.clone(),
143 response: error_tx,
144 })
145 .await
146 {
147 Ok(_) => {}
148 Err(error) => eprintln!("{:?}", error),
149 };
150
151 match error_rx.await.unwrap() {
152 Some(ErrorHandler(handler)) => {
153 match handler.execute(
154 code.clone(),
155 StatusCode::from(code.clone()).message(),
156 reason.clone(),
157 ) {
158 Ok(response) => {
159 Router::log_request(
160 &uri.path().to_string(),
161 &method.clone(),
162 &response.status().into(),
163 );
164 Ok(response)
165 }
166 Err((code, reason)) => {
167 Router::log_request(&uri.path().to_string(), method, &code);
168 Ok(default_error_page(
169 &code,
170 &reason,
171 method,
172 uri,
173 std::str::from_utf8(body).unwrap_or("").to_string(),
174 ))
175 }
176 }
177 }
178 None => {
179 Router::log_request(&uri.path().to_string(), method, &code);
180 Ok(default_error_page(
181 &code,
182 &reason,
183 method,
184 uri,
185 std::str::from_utf8(body).unwrap_or("").to_string(),
186 ))
187 }
188 }
189 }
190
191 fn log_request(path: &String, method: &Method, status: &u16) {
192 #[cfg(debug_assertions)]
193 eprintln!(
194 " {}(\x1b[3{}m{}\x1b[39m) \x1b[32m{:?}\x1b[0m",
195 method,
196 match status {
197 100..=199 => 6,
198 200..=299 => 2,
199 300..=399 => 5,
200 400..=499 => 1,
201 500..=599 => 3,
202 _ => 7,
203 },
204 status,
205 path
206 );
207 }
208
209 pub async fn parse(
210 &self,
211 request: hyper::Request<hyper::body::Incoming>,
212 ) -> Result<hyper::Response<Full<Bytes>>, Infallible> {
213 let mut uri = request.uri().clone();
215 let method = request.method().clone();
216 let _headers = request.headers().clone();
218 let mut body = request.collect().await.unwrap().to_bytes().to_vec();
219
220 let (endpoint_tx, endpoint_rx) = oneshot::channel();
221 match &self.channel {
222 Some(channel) => {
223 let path = format!("{}{}", self.assets, uri.path());
224 let path = Path::new(&path);
225 if let Some(extension) = path.extension().and_then(OsStr::to_str) {
226 match fs::read_to_string(path) {
227 Ok(text) => {
228 Router::log_request(&uri.path().to_string(), &method, &200);
229 let mut builder = hyper::Response::builder().status(200);
230
231 match mime_guess::from_ext(extension).first() {
232 Some(mime) => {
233 builder = builder.header("Content-Type", mime.to_string())
234 }
235 _ => {}
236 };
237
238 return Ok(builder.body(Full::new(Bytes::from(text))).unwrap());
239 }
240 _ => {
241 Router::log_request(&uri.path().to_string(), &method, &404);
242 return Ok(default_error_page(
243 &404,
244 &"File not found".to_string(),
245 &method,
246 &uri,
247 std::str::from_utf8(body.as_slice())
248 .unwrap_or("")
249 .to_string(),
250 ));
251 }
252 }
253 }
254
255 match channel
256 .send(Command::Get {
257 method: method.clone(),
258 path: uri.path().to_string(),
259 response: endpoint_tx,
260 })
261 .await
262 {
263 Ok(_) => {}
264 Err(error) => eprintln!("{}", error),
265 };
266
267 match endpoint_rx.await.unwrap() {
268 Some(Route(endpoint)) => match endpoint.execute(&method, &mut uri, &mut body) {
269 Ok(response) => {
270 Router::log_request(
271 &uri.path().to_string(),
272 &method,
273 &response.status().into(),
274 );
275 Ok(response)
276 }
277 Err((code, reason)) => {
278 self.error(&uri, &method, &body, code, reason, channel.clone())
279 .await
280 }
281 },
282 None => {
283 self.error(
284 &uri,
285 &method,
286 &body,
287 404,
288 "Page not found in router".to_string(),
289 channel.clone(),
290 )
291 .await
292 }
293 }
294 }
295 _ => panic!("Unable to communicate with router"),
296 }
297 }
298}