use std::collections::HashMap;
use std::io::{BufRead, BufReader, Write};
use std::net::{Shutdown, TcpListener, TcpStream};
use std::str::FromStr;
use std::thread;
/// HTTP request methods recognized by the request-line parser.
///
/// A request whose method is not one of these variants (for example `HEAD`) fails to
/// parse and is answered with `400 Bad Request`.
// The upper-case variant names are part of the public API; renaming them to satisfy
// `clippy::upper_case_acronyms` would break every caller that matches on them.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HttpMethod {
    GET,
    POST,
    PUT,
    PATCH,
    DELETE,
    OPTIONS,
}
impl FromStr for HttpMethod {
type Err = ();
fn from_str(input: &str) -> Result<HttpMethod, ()> {
match input {
"GET" => Ok(HttpMethod::GET),
"POST" => Ok(HttpMethod::POST),
"PUT" => Ok(HttpMethod::PUT),
"PATCH" => Ok(HttpMethod::PATCH),
"DELETE" => Ok(HttpMethod::DELETE),
"OPTIONS" => Ok(HttpMethod::OPTIONS),
_ => Err(())
}
}
}
/// A parsed HTTP request line, as handed to the server callback.
///
/// Only the method and the request target are captured; request headers and any
/// request body are not exposed through this type.
pub struct HttpRequest {
    /// Method taken from the first token of the request line.
    pub method: HttpMethod,
    /// Request target taken verbatim from the second token of the request line
    /// (e.g. `/index.html?x=1`); it is not decoded or normalized.
    pub path: String,
}
/// A response produced by the server callback and serialized back to the client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpResponse {
    /// Numeric status code written into the status line (e.g. `200`). It is not
    /// validated before being sent.
    pub status_code: i32,
    /// Reason phrase written after the status code (e.g. `"OK"`).
    pub status_code_message: String,
    /// Header fields, each emitted as `name: value`. Their order on the wire follows
    /// `HashMap` iteration order and is therefore unspecified.
    pub headers: HashMap<String, String>,
    /// Response body. When it is non-empty, a matching `Content-Length` header is
    /// generated for it.
    pub body: String,
}
pub fn start(addr: &str, callback: fn(HttpRequest) -> HttpResponse) -> Result<(), &str> {
let listener = TcpListener::bind(addr);
if listener.is_err() {
return Err("Could not bind listener to address");
}
let listener = listener.unwrap();
for stream in listener.incoming() {
let stream = stream;
if stream.is_err() {
return Err("Error when trying to get the stream instance.");
}
let stream = stream.unwrap();
process_request(stream, callback);
}
return Ok(());
}
pub fn start_multithreaded(addr: &str, callback: fn(HttpRequest) -> HttpResponse) -> Result<(), &str> {
let listener = TcpListener::bind(addr);
if listener.is_err() {
return Err("Could not bind listener to address");
}
let listener = listener.unwrap();
for stream in listener.incoming() {
let stream = stream;
if stream.is_err() {
return Err("Error when trying to get the stream instance.");
}
let stream = stream.unwrap();
thread::spawn(move || process_request(stream, callback));
}
return Ok(());
}
fn process_request(mut stream: TcpStream, callback: fn(HttpRequest) -> HttpResponse) {
let mut buffer = String::new();
let mut reader = BufReader::new(&mut stream);
reader.read_line(&mut buffer).expect("Can't read line");
let http_request = analyze_request_data(buffer.clone());
let http_response: HttpResponse;
if http_request.is_err() {
println!("{}", buffer);
http_response = HttpResponse
{
status_code: 400,
status_code_message: "Bad Request".to_string(),
headers: Default::default(),
body: "".to_string(),
};
} else {
let http_request = http_request.unwrap();
http_response = callback(http_request);
}
let generated_response = generate_response(http_response);
let _ = stream.write(generated_response.as_bytes());
let _ = stream.shutdown(Shutdown::Both);
drop(stream);
}
/// Parses the HTTP request line (e.g. `GET /index.html HTTP/1.1`) at the start of
/// `request_payload`.
///
/// Only the first line is inspected. Its terminator may be `\r\n` or a bare `\n`, and
/// tokens are separated by single spaces. Tokens after the third are ignored.
///
/// # Errors
/// Returns a static description if:
/// - the payload contains no line;
/// - the request line has fewer than three space-separated parts;
/// - the method is not a known [`HttpMethod`];
/// - the version is not exactly `HTTP/1.1`.
fn analyze_request_data(request_payload: String) -> Result<HttpRequest, &'static str> {
    // `lines()` strips both `\r\n` and bare `\n`. Splitting on "\r\n" alone left a `\n`
    // glued to the version token for LF-only clients, which made them always fail the
    // version check.
    let request_line = request_payload.lines().next().ok_or("Request is empty.")?;

    let mut parts = request_line.split(' ');
    // Tuple fields are evaluated left to right, so this takes the first three tokens in order.
    let (method, path, version) = match (parts.next(), parts.next(), parts.next()) {
        (Some(method), Some(path), Some(version)) => (method, path, version),
        _ => return Err("Protocol line is invalid! It should contain minimum 3 elements"),
    };

    let method =
        HttpMethod::from_str(method).map_err(|_| "Protocol line contains invalid method.")?;
    if version != "HTTP/1.1" {
        return Err("Protocol line contains invalid protocol version");
    }

    Ok(HttpRequest {
        method,
        path: path.to_string(),
    })
}
/// Serializes `http_response` into a raw HTTP/1.1 response: the status line, one
/// `name: value` line per header, a blank line, and then the body, byte-for-byte with
/// nothing appended.
///
/// When the body is non-empty, a `Content-Length` header equal to its UTF-8 byte length
/// is inserted, replacing any entry stored under that exact key. Header order follows
/// `HashMap` iteration and is unspecified.
fn generate_response(mut http_response: HttpResponse) -> String {
    if !http_response.body.is_empty() {
        // `String::len` is the byte length, which is what Content-Length counts.
        let body_len = http_response.body.len().to_string();
        http_response
            .headers
            .insert("Content-Length".to_string(), body_len);
    }

    let mut raw = format!(
        "HTTP/1.1 {} {}\r\n",
        http_response.status_code, http_response.status_code_message
    );
    for (name, value) in &http_response.headers {
        raw.push_str(name);
        raw.push_str(": ");
        raw.push_str(value);
        raw.push_str("\r\n");
    }
    raw.push_str("\r\n");
    // Nothing may follow the body. The old trailing "\r\n" was not covered by
    // Content-Length, and for an empty body (no Content-Length at all) the client,
    // reading until close, received it as a bogus 2-byte body.
    raw.push_str(&http_response.body);
    raw
}