use std::{
collections::HashMap,
io::{Read, Write},
net::{TcpListener, ToSocketAddrs},
str::{FromStr, Utf8Error, from_utf8},
sync::Arc,
};
use crate::{method::HttpMethod, request::HttpRequest, response::Response};
/// A request handler that can be shared between connection threads.
///
/// An implementor turns an [`HttpRequest`] into a boxed [`Response`]. The
/// `Send + Sync` supertraits are needed because registered handlers are placed
/// behind an `Arc` and invoked from threads spawned by `HttpServer::listen`.
pub trait HandlerFn: Send + Sync {
/// Handles `req` and returns the response to write back to the client.
fn call(&self, req: HttpRequest) -> Box<dyn Response>;
}
/// Blanket implementation so that any thread-safe closure or function taking
/// an [`HttpRequest`] and returning some concrete [`Response`] type can be
/// registered as a handler without a hand-written `HandlerFn` impl.
impl<F, T> HandlerFn for F
where
    T: Response + 'static,
    F: Fn(HttpRequest) -> T + Send + Sync,
{
    /// Invokes the wrapped callable and type-erases its result.
    fn call(&self, req: HttpRequest) -> Box<dyn Response> {
        let response: T = self(req);
        Box::new(response)
    }
}
/// Signature of a middleware function: it takes ownership of the parsed
/// request and returns the request that is used from then on.
///
/// This is a plain function pointer, so capturing closures are not accepted.
pub type MiddleWareFn = fn(HttpRequest) -> HttpRequest;
/// A type-erased, heap-allocated [`HandlerFn`] as stored in the route table.
pub type Handler = Box<dyn HandlerFn + Send + Sync>;
/// A small HTTP server configured through chained, by-value builder calls
/// and started with `listen`.
#[derive(Default)]
pub struct HttpServer {
// Route table: the handler registered for each `(path, method)` key.
handlers: HashMap<(String, HttpMethod), Handler>,
// Optional function applied to every parsed request before route lookup.
middle_ware: Option<MiddleWareFn>,
// Opaque value stored by `set_state`.
// NOTE(review): nothing in this file reads this field — confirm intended use.
state: Option<Box<dyn Send + Sync>>,
}
impl HttpServer {
#[must_use]
pub fn new() -> Self {
Self {
handlers: HashMap::new(),
middle_ware: None,
state: None,
}
}
#[must_use]
pub fn add_middleware(mut self, f: fn(req: HttpRequest) -> HttpRequest) -> Self {
self.middle_ware.replace(f);
self
}
#[must_use]
pub fn route<F: HandlerFn + 'static>(
mut self,
path: impl Into<String>,
method: HttpMethod,
f: F,
) -> Self {
self.handlers.insert((path.into(), method), Box::new(f));
self
}
#[must_use]
pub fn get<F: HandlerFn + 'static>(self, path: impl Into<String>, f: F) -> Self {
self.route(path, HttpMethod::Get, f)
}
#[must_use]
pub fn post<F: HandlerFn + 'static>(self, path: impl Into<String>, f: F) -> Self {
self.route(path, HttpMethod::Post, f)
}
#[must_use]
pub fn delete<F: HandlerFn + 'static>(self, path: impl Into<String>, f: F) -> Self {
self.route(path, HttpMethod::Delete, f)
}
#[must_use]
pub fn update<F: HandlerFn + 'static>(self, path: impl Into<String>, f: F) -> Self {
self.route(path, HttpMethod::Update, f)
}
#[must_use]
pub fn put<F: HandlerFn + 'static>(self, path: impl Into<String>, f: F) -> Self {
self.route(path, HttpMethod::Put, f)
}
#[must_use]
pub fn patch<F: HandlerFn + 'static>(self, path: impl Into<String>, f: F) -> Self {
self.route(path, HttpMethod::Patch, f)
}
#[must_use]
pub fn head<F: HandlerFn + 'static>(self, path: impl Into<String>, f: F) -> Self {
self.route(path, HttpMethod::Head, f)
}
#[must_use]
pub fn options<F: HandlerFn + 'static>(self, path: impl Into<String>, f: F) -> Self {
self.route(path, HttpMethod::Options, f)
}
#[must_use]
pub fn set_state<T: Send + Sync + 'static>(mut self, state: T) -> Self {
self.state.replace(Box::new(state));
self
}
pub fn listen(self, address: impl ToSocketAddrs) -> Result<(), ServerError> {
let listener = TcpListener::bind(address)?;
let handlers = Arc::new(self.handlers);
let middle_ware = Arc::new(self.middle_ware);
for stream in listener.incoming() {
let stream = stream?;
let middle_ware = middle_ware.clone();
let handlers = handlers.clone();
let job = move || _ = handle_connection(&handlers, &middle_ware, stream);
std::thread::spawn(job);
}
Ok(())
}
}
/// Serves one request on `stream`: read, parse, apply the middleware (if any),
/// dispatch to the handler registered for the request's `(path, method)` and
/// write the handler's response back.
///
/// Only a single `read` of at most 16 KiB is performed; anything the client
/// sends beyond what that first read returns is not consumed here. When no
/// handler is registered for the key, the response built from the text
/// `"no method found"` is written instead.
///
/// # Errors
///
/// Returns an error if reading or writing the socket fails, if the received
/// bytes are not valid UTF-8, or if the request cannot be parsed.
fn handle_connection(
    handlers: &Arc<HashMap<(String, HttpMethod), Handler>>,
    middle_ware: &Arc<Option<MiddleWareFn>>,
    mut stream: std::net::TcpStream,
) -> Result<(), ServerError> {
    let mut raw = [0_u8; 4096 * 4];
    let received = stream.read(&mut raw)?;
    let text = from_utf8(&raw[..received])?;
    let parsed = HttpRequest::from_str(text)?;

    // Function pointers are `Copy`, so the middleware can be copied out of the `Arc`.
    let request = match **middle_ware {
        Some(transform) => transform(parsed),
        None => parsed,
    };

    // The key must be owned because `request` is moved into the handler below.
    let key = (request.path.clone(), request.method.clone());
    match handlers.get(&key) {
        Some(handler) => {
            let response = handler.call(request);
            let payload = response.to_response().into_bytes();
            stream.write_all(payload.as_slice())?;
        }
        None => {
            let payload = "no method found".to_response().into_bytes();
            stream.write_all(&payload)?;
        }
    }
    Ok(())
}
/// Errors produced while starting the server or serving a connection.
#[derive(Debug)]
pub enum ServerError {
    /// Bytes that were expected to be text were not valid UTF-8.
    Utf8Conversion(Utf8Error),
    /// An operating-system level I/O operation failed.
    IoError(std::io::Error),
}

impl std::fmt::Display for ServerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Utf8Conversion(err) => write!(f, "invalid UTF-8: {err}"),
            Self::IoError(err) => write!(f, "I/O error: {err}"),
        }
    }
}

// Implementing `Error` lets callers use `?` into `Box<dyn Error>` and walk
// the cause chain through `source`.
impl std::error::Error for ServerError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Utf8Conversion(err) => Some(err),
            Self::IoError(err) => Some(err),
        }
    }
}
impl From<Utf8Error> for ServerError {
fn from(value: Utf8Error) -> Self {
Self::Utf8Conversion(value)
}
}
impl From<std::io::Error> for ServerError {
fn from(value: std::io::Error) -> Self {
Self::IoError(value)
}
}