pub use my_threadpool::ThreadPool;
use std::collections::HashMap;
use std::net::TcpListener;
use std::net::TcpStream;
use std::net::Shutdown;
use std::io::prelude::*;
use std::sync::{Arc, Mutex};
/// Error type returned by [`HttpRequest`]: a fixed, human-readable description of the step that failed.
pub type ErrorMessage = &'static str;
/// HTTP request methods the server can route on.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RequestType {
    Get,
    Post,
    Put,
    Patch,
    Delete,
}

impl RequestType {
    /// Returns the method token as written on an HTTP request line (e.g. `"GET"`).
    fn stringify(&self) -> &str {
        let token = match *self {
            Self::Get => "GET",
            Self::Post => "POST",
            Self::Put => "PUT",
            Self::Patch => "PATCH",
            Self::Delete => "DELETE",
        };
        token
    }
}
/// Route-table key component: an HTTP method paired with a path.
#[derive(Clone, Hash, Eq, PartialEq)]
struct Route {
    req_type: RequestType,
    route: String,
}

/// Route-table key distinguishing exact routes (`Route`) from prefix routes (`AllRoute`).
#[derive(Clone, Hash, Eq, PartialEq)]
enum RouteIndex {
    Route(Route),
    AllRoute(Route),
}
/// Switches for the per-request trace printed in [`Mode::Debug`].
///
/// `Default` turns every switch off, so callers can enable only what they need:
/// `DebugOptions { show_request_body: true, ..Default::default() }`.
#[derive(PartialEq, Eq, Clone, Hash, Debug, Copy, Default)]
pub struct DebugOptions {
    /// Append the response body to the trace.
    pub show_response_body: bool,
    /// Log the request target including its query string instead of the bare path.
    pub show_request_query: bool,
    /// Append the raw request body to the trace.
    pub show_request_body: bool,
    /// Middleware tracing switch (output produced by `WebServer::listen`).
    pub show_middleware: bool,
    /// Middleware request-modification tracing switch (output produced by `WebServer::listen`).
    pub show_middleware_request_changes: bool,
}

/// Server operating mode.
#[derive(PartialEq, Eq, Clone, Hash, Debug, Copy)]
pub enum Mode {
    /// Serve requests without printing a per-request trace.
    Default,
    /// Print a per-request trace controlled by the given options.
    Debug(DebugOptions),
}

impl Default for Mode {
    /// Returns [`Mode::Default`].
    fn default() -> Self {
        Mode::Default
    }
}
/// A parsed HTTP request, as handed to route handlers and middleware.
#[derive(Clone, Debug, PartialEq)]
pub struct Request {
    /// Request target from the request line, query string included (e.g. `/items?id=7`).
    pub route: String,
    /// Value of the first `Host` header, surrounding whitespace trimmed.
    pub host: String,
    /// Request method.
    pub request_type: RequestType,
    /// `name=value` pairs from the query string.
    pub query: HashMap<String, String>,
    /// `name=value` pairs from the first `Cookie` header.
    pub cookie: HashMap<String, String>,
    /// Everything after the header block, up to the first NUL byte.
    pub body: String,
}
impl Request {
    /// Parses a raw HTTP/1.x request.
    ///
    /// Returns `None` when the request line lacks a supported method (`GET`, `POST`,
    /// `PUT`, `PATCH`, `DELETE`) or a target, or when there is no `Host` header.
    ///
    /// Header names are matched case-insensitively; only the first `Host` and the
    /// first `Cookie` header are used. Query and cookie pairs without `=` are
    /// skipped, and values may themselves contain `=`. A request with no blank line
    /// after its headers is treated as having an empty body.
    fn new(req: String) -> Option<Self> {
        // Headers end at the first blank line; everything after it is the body.
        let (head, body) = req.split_once("\r\n\r\n").unwrap_or((req.as_str(), ""));

        // `lines()` also strips the trailing '\r' of each CRLF-terminated line.
        let mut lines = head.lines();
        let mut request_line = lines.next()?.split(' ');
        let request_type = match request_line.next()? {
            "GET" => RequestType::Get,
            "POST" => RequestType::Post,
            "PUT" => RequestType::Put,
            "PATCH" => RequestType::Patch,
            "DELETE" => RequestType::Delete,
            _ => return None,
        };
        let route = request_line.next()?;

        let mut query = HashMap::new();
        if let Some((_, query_string)) = route.split_once('?') {
            for pair in query_string.split('&') {
                if let Some((name, value)) = pair.split_once('=') {
                    query.insert(name.to_string(), value.to_string());
                }
            }
        }

        let mut host = None;
        let mut cookie_header = None;
        for line in lines {
            let (name, value) = match line.split_once(':') {
                Some(header) => header,
                None => continue,
            };
            if host.is_none() && name.eq_ignore_ascii_case("Host") {
                host = Some(value.trim());
            } else if cookie_header.is_none() && name.eq_ignore_ascii_case("Cookie") {
                cookie_header = Some(value);
            }
        }
        let host = host?.to_string();

        let mut cookie = HashMap::new();
        for pair in cookie_header.unwrap_or("").split(';') {
            if let Some((name, value)) = pair.trim().split_once('=') {
                cookie.insert(name.to_string(), value.to_string());
            }
        }

        Some(Request {
            query,
            route: route.to_string(),
            host,
            request_type,
            cookie,
            // The caller's read buffer may be zero-padded; keep only the bytes before the first NUL.
            body: body.split('\0').next().unwrap_or("").to_string(),
        })
    }
}
/// How a registered route produces its reply.
#[derive(Clone)]
enum ResponseType {
    /// A fixed body registered with `on_static`.
    Static(String),
    /// A handler registered with `on`/`on_all`, returning the body and status code.
    Function(Arc<dyn Fn(Request) -> (String, ResponseCode) + Send + Sync + 'static>),
}
/// Status a handler returns together with its response body.
///
/// Each variant is sent as the matching `HTTP/1.1` status line
/// (e.g. `Rc404` → `HTTP/1.1 404 NOT FOUND`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ResponseCode {
    /// `200 OK`
    Rc200,
    /// `201 CREATED`
    Rc201,
    /// `204 NO CONTENT`
    Rc204,
    /// `304 NOT MODIFIED`
    Rc304,
    /// `400 BAD REQUEST`
    Rc400,
    /// `404 NOT FOUND`
    Rc404,
    /// `429 TOO MANY REQUESTS`
    Rc429,
    /// `500 INTERNAL SERVER ERROR`
    Rc500,
}
/// A small multi-threaded HTTP/1.1 server.
///
/// Register handlers with `on`, `on_static`, `on_all` and `middleware`, then call
/// `listen` to accept connections and serve each one on the worker pool.
pub struct WebServer {
    // Worker pool that runs one job per accepted connection.
    pool: ThreadPool,
    // Body sent with `404 NOT FOUND` when no route matches.
    not_found: String,
    // Registered handlers keyed by method + path (exact or prefix).
    routes: HashMap<RouteIndex, ResponseType>,
    // Human-readable log of route registrations, printed by `listen`.
    existing_routes: String,
    // Middleware keyed by route string.
    middleware: HashMap<String, Arc<dyn Fn(Request) -> Request + Send + Sync + 'static>>,
    // Human-readable log of middleware registrations, printed by `listen`.
    existing_middleware: String,
    // Normal or debug (per-request trace) mode.
    mode: Mode,
}
/// Sends `request` to `request.host` over a fresh TCP connection and returns the
/// raw HTTP response (status line, headers and body) as text.
///
/// The request is written as HTTP/1.1 with `Connection: close`, and the response
/// is read until the server closes the connection. `POST`, `PUT` and `PATCH`
/// requests carry `request.body` with a `Content-Length` header; other methods are
/// sent without a body. `request.query` and `request.cookie` are not sent.
///
/// # Errors
/// Returns a static description if connecting, writing, or reading fails
/// (including a response that is not valid UTF-8).
#[allow(non_snake_case)] // public name kept for backward compatibility
pub fn HttpRequest(request: Request) -> Result<String, ErrorMessage> {
    // Tolerate surrounding whitespace, e.g. a trailing '\r' left by a parsed Host header.
    let host = request.host.trim();
    let mut stream = TcpStream::connect(host).map_err(|_| "Could not connect")?;

    let method = request.request_type.stringify();
    let mut req = format!(
        "{} {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n",
        method, request.route, host
    );
    if matches!(method, "POST" | "PUT" | "PATCH") {
        req.push_str(&format!("Content-Length: {}\r\n\r\n{}", request.body.len(), request.body));
    } else {
        req.push_str("\r\n");
    }

    stream
        .write_all(req.as_bytes())
        .map_err(|_| "Could not write the bytes for the request")?;
    let mut res = String::new();
    stream
        .read_to_string(&mut res)
        .map_err(|_| "Could not read the response")?;
    // The full response has already been read and the peer has closed its side.
    // On some platforms shutdown can then fail with `NotConnected`; that must not
    // turn a successful exchange into an error. Dropping `stream` closes it anyway.
    let _ = stream.shutdown(Shutdown::Both);
    Ok(res)
}
fn create_response(n: &ResponseType, data: String, req: Option<Request>) -> String {
match n {
ResponseType::Static(n) => format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}", n.len(), n),
ResponseType::Function(n) => {
let resp = n(match req {
Some(n) => n,
_ => {
Request::new(data.to_string()).unwrap()
}
});
let res: String = match resp.1 {
ResponseCode::Rc200 => format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}", resp.0.len(), resp.0),
ResponseCode::Rc201 => format!("HTTP/1.1 201 CREATED\r\nContent-Length: {}\r\n\r\n{}", resp.0.len(), resp.0),
ResponseCode::Rc204 => format!("HTTP/1.1 204 NO CONTENT\r\n\r\n\r\n"),
ResponseCode::Rc304 => format!("HTTP/1.1 304 NOT MODIFIED\r\nContent-Length: {}\r\n\r\n{}", resp.0.len(), resp.0),
ResponseCode::Rc400 => format!("HTTP/1.1 400 BAD REQUEST\r\nContent-Length: {}\r\n\r\n{}", resp.0.len(), resp.0),
ResponseCode::Rc404 => format!("HTTP/1.1 404 NOT FOUND\r\nContent-Length: {}\r\n\r\n{}", resp.0.len(), resp.0),
ResponseCode::Rc429 => format!("HTTP/1.1 429 TOO MANY REQUESTS\r\nContent-Length: {}\r\n\r\n{}", resp.0.len(), resp.0),
ResponseCode::Rc500 => format!("HTTP/1.1 500 INTERNAL SERVER ERROR\r\nContent-Length: {}\r\n\r\n{}", resp.0.len(), resp.0),
};
res
}
}
}
impl WebServer {
pub fn new() -> Self {
WebServer {
pool: ThreadPool::new(4),
not_found: String::from("<h1>404 NOT FOUND</h1>"),
routes: HashMap::new(),
middleware: HashMap::new(),
existing_middleware: String::new(),
existing_routes: String::new(),
mode: Mode::Default,
}
}
pub fn from_threads(threads: usize) -> Self {
WebServer {
pool: ThreadPool::new(threads),
not_found: String::from("<h1>404 NOT FOUND</h1>"),
routes: HashMap::new(),
middleware: HashMap::new(),
existing_middleware: String::new(),
existing_routes: String::new(),
mode: Mode::Default,
}
}
pub fn on_static(&mut self, req_type: RequestType, route: &str, res: &str) {
self.routes.insert(RouteIndex::Route(Route {
req_type,
route: route.to_string(),
}), ResponseType::Static(res.to_string()));
self.existing_routes.push_str(&format!(
"\n --> Static response for {} requests on route '{}'",
req_type.stringify(),
route
));
}
pub fn on<F: Fn(Request) -> (String, ResponseCode) + Send + Sync + 'static>(&mut self, req_type: RequestType, route: &str, res: F) {
self.routes.insert(RouteIndex::Route(Route {
req_type,
route: route.to_string(),
}), ResponseType::Function(Arc::new(res)));
self.existing_routes.push_str(&format!(
"\n --> response for {} requests on route '{}'",
req_type.stringify(),
route
));
}
pub fn on_all<F: Fn(Request) -> (String, ResponseCode) + Send + Sync + 'static>(&mut self, req_type: RequestType, route: &str, res: F) {
self.routes.insert(RouteIndex::AllRoute(Route {
req_type,
route: route.to_string(),
}), ResponseType::Function(Arc::new(res)));
self.existing_routes.push_str(&format!(
"\n --> response for All {} requests that start with '{}'",
req_type.stringify(),
route
));
}
pub fn not_found(&mut self, html: &str) {
self.not_found = html.to_string();
self.existing_routes.push_str(&format!(
"\n --> Set default 404 response to '{}'",
html
));
}
pub fn mode(&mut self, mode: Mode) {
self.mode = mode;
}
pub fn listen(self, addr: &str) -> Result<(), ()> {
println!("--------- Route Handlers ---------\n{}\n\n--------- Middleware ---------\n{}\n\n\nListening on http://{}\n\n",
self.existing_routes,
self.existing_middleware,
addr
);
let pool = Arc::new(Mutex::new(self.pool));
let pool2 = pool.clone();
ctrlc::set_handler(move || {
println!("----- Starting Gracefull Shutdown -----\n");
pool2.lock().unwrap().shutdown();
std::process::exit(0);
}).unwrap();
let listener = match TcpListener::bind(addr) {
Ok(n) => n,
Err(_) => {
return Err(());
}
};
for stream in listener.incoming() {
let mut stream = match stream {
Ok(n) => n,
Err(_) => continue,
};
let routes = self.routes.clone();
let middleware = self.middleware.clone();
let not_found = self.not_found.clone();
pool.lock().unwrap().execute(move || {
let mut options = None;
if let Mode::Debug(mode_options) = self.mode {
options = Some(mode_options);
}
let mut debug_str = String::new();
let mut data = [0; 4096];
stream.read(&mut data).unwrap();
let mut data = String::from_utf8_lossy(&data);
let req_raw = data.clone();
let mut request = None;
let response;
let req = data.split("\n").collect::<Vec<&str>>()[0];
let split = req.split(" ").collect::<Vec<&str>>();
debug_str.push_str(split[0]);
debug_str.push(' ');
if let Some(options) = options {
if options.show_request_query {
debug_str.push_str(split[1]);
} else {
debug_str.push_str(split[1].split("?").collect::<Vec<&str>>()[0]);
}
} else {
debug_str.push_str(split[1].split("?").collect::<Vec<&str>>()[0]);
}
if let Some(options) = options {
if options.show_request_body {
debug_str.push_str("\n -> request body '");
if let Some(n) = req_raw.split("\r\n\r\n").collect::<Vec<&str>>().get(1) {
debug_str.push_str(n);
};
debug_str.push_str("'")
}
}
let s = split[1].split("?").collect::<Vec<&str>>()[0].to_string();
let mut s2 = &mut s.split("/").collect::<Vec<&str>>();
let s2 = &s2[1..].iter().map(|x| format!("/{}", x)).collect::<Vec<String>>();
let mut route2 = vec![];
for route in &s2[..] {
route2.push(route.to_string());
if let Some(n) = middleware.get(route) {
let original_req = match request {
Some(n) => n,
_ => {
Request::new(data.to_string()).unwrap()
}
};
let finished_request = n(original_req.clone());
if let Some(options) = options {
if options.show_middleware_request_changes {
debug_str.push_str("\n -> response went through middleware: ");
debug_str.push_str(&route2.join(""));
}
if options.show_middleware {
let request = finished_request.clone();
if request != original_req {
if request.route != original_req.route {
debug_str.push_str(&format!(
"\n -> redirected request from '{}' to '{}'"
,original_req.route,
request.route
));
}
if request.request_type != original_req.request_type {
debug_str.push_str(&format!(
"\n -> modified request type from {} to {}"
,original_req.request_type.stringify(),
request.request_type.stringify()
));
}
if request.body != original_req.body {
debug_str.push_str(&format!(
"\n -> modified request body from '{}' to '{}'"
,original_req.body,
request.body
));
}
if request.host != original_req.host {
debug_str.push_str(&format!(
"\n -> modified request host from '{}' to '{}'"
,original_req.host,
request.host
));
}
if request.query != original_req.query {
debug_str.push_str(&format!(
"\n -> modified request query from {:?} to {:?}"
,original_req.query,
request.query
));
}
if request.cookie != original_req.cookie {
debug_str.push_str(&format!(
"\n -> modified request cookie from {:?} to {:?}"
,original_req.cookie,
request.cookie
));
}
}
}
}
request = Some(finished_request);
}
}
let req_route = match request {
Some(n) => {
request = Some(n.clone());
n.route
},
None => split[1].split("?").collect::<Vec<&str>>()[0].to_string()
};
let req_type = match request {
Some(n) => {
request = Some(n.clone());
n.request_type
},
None => match split[0] {
"GET" => {
RequestType::Get
},
"POST" => {
RequestType::Post
},
"PUT" => {
RequestType::Put
},
"PATCH" => {
RequestType::Patch
},
"DELETE" => {
RequestType::Delete
},
_ => {
let contents = "";
response = format!("HTTP/1.1 400 BAD REQUEST\r\nContent-Length: {}\r\n\r\n{}", contents.len(), contents);
return;
}
}
};
match routes.get(&&RouteIndex::Route(Route { req_type, route: req_route.clone() })) {
Some(n) => {
response = create_response(n, data.into(), request);
debug_str.push_str(&format!("\n -> resoponded with {}", response.split("\r\n").collect::<Vec<&str>>()[0]));
debug_str.push_str("\n -> resoponse successful");
},
None => {
let s = req_route;
let s2 = s.split("/").collect::<Vec<&str>>();
let mut route = String::from("/");
let route2 = Arc::new(Mutex::new(s2.clone()));
for _ in s2 {
let mut route2 = route2.lock().unwrap();
if let Some(_) = routes.get(&&RouteIndex::AllRoute(Route { req_type, route: route2.join("/").to_string() })) {
route = (*route2).clone().join("/");
break;
}
route2.pop();
}
match routes.get(&&RouteIndex::AllRoute(Route { req_type, route: route.clone() })) {
Some(n) => {
debug_str.push_str("\n -> Route was found on an on_all route: ");
debug_str.push_str(&route);
response = create_response(n, data.into(), request);
debug_str.push_str(&format!("\n -> resoponded with {}", response.split("\r\n").collect::<Vec<&str>>()[0].split("HTTP/1.1 ").collect::<Vec<&str>>()[0]));
debug_str.push_str("\n -> resoponse successful");
},
None => {
let contents = not_found;
response = format!("HTTP/1.1 404 NOT FOUND\r\nContent-Length: {}\r\n\r\n{}", contents.len(), contents);
debug_str.push_str("\n -> resoponded with HTTP/1.1 404 NOT FOUND");
debug_str.push_str("\n -> resoponse successfull");
}
}
}
}
if let Some(options) = options {
if options.show_response_body {
debug_str.push_str("\n -> response body: '");
debug_str.push_str(response.split("\r\n\r\n").collect::<Vec<&str>>()[1]);
debug_str.push_str("'");
}
}
stream.write(response.as_bytes()).unwrap();
stream.flush().unwrap();
if let Some(_) = options {
println!("{debug_str}\n\n");
}
return;
});
}
Ok(())
}
pub fn middleware<F: Fn(Request) -> Request + Send + Sync + 'static>(&mut self, route: &str, res: F) {
self.middleware.insert(route.to_string(), Arc::new(res));
self.existing_middleware.push_str(&format!("\n --> middleware for route '{}'", route));
}
}