use std::{
future::Future,
io,
net::{SocketAddr, ToSocketAddrs},
pin::Pin,
};
use futures::FutureExt;
use hooch::{
net::{HoochTcpListener, HoochTcpStream},
spawner::Spawner,
};
use crate::{
request::HttpRequest, response::HttpResponse, HttpMethod, HttpResponseBuilder, Params, Uri,
};
/// Boxed, pinned, `Send` future that resolves to a [`Middleware`] decision.
type MiddlewareFuture = Pin<Box<dyn Future<Output = Middleware> + Send>>;
/// Type-erased middleware callable: receives the request and a `SocketAddr`
/// (the peer address obtained from `accept` when invoked by `HoochApp`) and
/// returns a [`MiddlewareFuture`].
type MiddlewareFn = Box<dyn Fn(HttpRequest<'static>, SocketAddr) -> MiddlewareFuture + Send + Sync>;
/// Boxed, pinned, `Send` future that resolves to the [`HttpResponse`] of a route handler.
type RouterFuture = Pin<Box<dyn Future<Output = HttpResponse> + Send>>;
/// Type-erased route handler: receives the request and the [`Params`] produced
/// by matching the request URI, and returns a [`RouterFuture`].
type RouterFn = Box<dyn Fn(HttpRequest<'static>, Params<'static>) -> RouterFuture + Send + Sync>;
/// Outcome of running a single middleware function.
#[derive(Debug)]
pub enum Middleware {
/// Hand the (possibly modified) request on to the next middleware, or to
/// route matching once all middleware has run.
Continue(HttpRequest<'static>),
/// Stop processing: the contained response is written to the client and no
/// further middleware or route handler is invoked.
ShortCircuit(HttpResponse),
}
/// A registered route: a handler plus the method and path it answers to.
pub struct Route {
// Type-erased async handler invoked when both path and method match.
fut: RouterFn,
// HTTP method the request must carry for this route to be selected.
method: HttpMethod,
// Path pattern handed to `Uri::is_match` when looking for a matching route.
path: &'static str,
}
/// Collects the listen address, middleware and routes used to construct a `HoochApp`.
pub struct HoochAppBuilder {
// Address the built application will bind to.
addr: SocketAddr,
// Middleware functions, stored in registration order.
middleware: Vec<MiddlewareFn>,
// Routes, stored in registration order.
router: Vec<Route>,
}
impl HoochAppBuilder {
pub fn new(addr: impl ToSocketAddrs) -> io::Result<Self> {
let addr = addr
.to_socket_addrs()?
.next()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no address resolved"))?;
Ok(Self {
addr,
middleware: Vec::new(),
router: Vec::new(),
})
}
pub fn add_middleware<Fut, F>(&mut self, middleware: F)
where
Fut: Future<Output = Middleware> + Send + 'static,
F: Fn(HttpRequest<'static>, SocketAddr) -> Fut + Send + Sync + 'static,
{
self.middleware.push(Box::new(move |req, socket| {
Box::pin(middleware(req, socket))
}));
}
pub fn add_route<FutRoute, FnRoute>(
&mut self,
path: &'static str,
method: HttpMethod,
route: FnRoute,
) where
FnRoute: Fn(HttpRequest<'static>, Params<'static>) -> FutRoute + Sync + Send + 'static,
FutRoute: Future<Output = HttpResponse> + Send + 'static,
{
let route = Route {
fut: Box::new(move |req, params| route(req, params).boxed()),
method,
path,
};
self.router.push(route);
}
pub fn build(self) -> HoochApp {
let middleware_ptr: &'static Vec<MiddlewareFn> = Box::leak(Box::new(self.middleware));
let route_ptr: &'static Vec<Route> = Box::leak(Box::new(self.router));
HoochApp {
addr: self.addr,
middleware: middleware_ptr,
routes: route_ptr,
}
}
}
/// A built application, ready to be run with `serve`.
pub struct HoochApp {
// Address `serve` binds its listener to.
addr: SocketAddr,
// Middleware list; `'static` because the builder leaks it in `build`.
middleware: &'static Vec<MiddlewareFn>,
// Route table; `'static` because the builder leaks it in `build`.
routes: &'static Vec<Route>,
}
impl HoochApp {
    /// Binds a listener on the configured address and serves connections.
    ///
    /// Each accepted connection is handled in its own spawned task. The accept
    /// loop stops (and this future completes) on the first `accept` error,
    /// which is logged to stderr.
    ///
    /// # Panics
    ///
    /// Panics if the listener cannot be bound to the configured address.
    pub async fn serve(&self) {
        let listener = HoochTcpListener::bind(self.addr)
            .await
            .expect("failed to bind listener to the configured address");
        let middleware_ptr: &'static Vec<MiddlewareFn> = self.middleware;
        let route_ptr: &'static Vec<Route> = self.routes;
        loop {
            match listener.accept().await {
                Ok((stream, socket)) => {
                    println!("Received connection from {:?}", socket);
                    Spawner::spawn(async move {
                        Self::handle_stream(stream, socket, middleware_ptr, route_ptr).await;
                    });
                }
                Err(err) => {
                    // Same stop-on-error behavior as before, but no longer silent.
                    eprintln!("Failed to accept connection, stopping server: {:?}", err);
                    break;
                }
            }
        }
    }

    /// Reads one request from `stream`, runs it through the middleware chain and
    /// the route table, and writes exactly one response.
    ///
    /// * `stream` - the accepted connection.
    /// * `socket_addr` - the peer address, passed to every middleware.
    /// * `middleware_fns` - middleware, run in order; any of them may short-circuit.
    /// * `routes` - routes, checked in order; the first one matching both path
    ///   and method handles the request. If none matches, a "not found" response
    ///   is sent.
    ///
    /// Only a single `read` of at most 100 KiB is performed, so a request that
    /// is larger or arrives in several segments is parsed from whatever that
    /// first read returned.
    async fn handle_stream(
        mut stream: HoochTcpStream,
        socket_addr: SocketAddr,
        middleware_fns: &'static [MiddlewareFn],
        routes: &'static [Route],
    ) {
        let mut buffer = [0; 1024 * 100];
        let bytes_read = match stream.read(&mut buffer).await {
            // The peer closed the connection without sending anything.
            Ok(0) => return,
            Ok(bytes_read) => bytes_read,
            Err(err) => {
                // A failing peer must not panic the task; drop the connection instead.
                eprintln!("Failed to read from {:?}: {:?}", socket_addr, err);
                return;
            }
        };
        let http_request = HttpRequest::from_bytes(&buffer[..bytes_read]);
        // SAFETY (NOTE(review)): this extends the request's borrow of `buffer` to
        // `'static` so it can be passed to the type-erased handlers. `buffer` is
        // a local of this future and is not touched again, so this holds only
        // if no middleware or route handler keeps the request (or anything
        // borrowed from it) alive after this function returns. The type system
        // does not enforce that - confirm, or move to an owned request type.
        let mut http_request: HttpRequest<'static> = unsafe { std::mem::transmute(http_request) };
        for mid in middleware_fns.iter() {
            match mid(http_request, socket_addr).await {
                Middleware::Continue(req) => {
                    http_request = req;
                }
                Middleware::ShortCircuit(response) => {
                    return Self::handle_http_response(response, stream).await;
                }
            }
        }
        for route in routes.iter() {
            // SAFETY (NOTE(review)): detaches the `Uri` borrow from `http_request`
            // so the request can be moved into the handler while `param` is still
            // alive. This presumably relies on `Params` borrowing from the
            // underlying request bytes rather than from the `Uri` value that
            // moves with `http_request` - verify against `Uri::is_match`.
            let uri: &Uri<'static> = unsafe { std::mem::transmute(http_request.uri()) };
            if let Some(param) = uri.is_match(route.path) {
                if route.method == http_request.method() {
                    let response = (route.fut)(http_request, param).await;
                    return Self::handle_http_response(response, stream).await;
                }
            }
        }
        Self::handle_http_response(HttpResponseBuilder::not_found().build(), stream).await;
    }

    /// Serializes `http_response` and writes it to `stream`.
    ///
    /// A write failure is logged to stderr and otherwise ignored: the
    /// connection is dropped when `stream` goes out of scope.
    async fn handle_http_response(http_response: HttpResponse, mut stream: HoochTcpStream) {
        // `size_of_val` is only the in-memory size of the response value, not its
        // serialized length; it serves purely as an initial capacity hint.
        let buffer = Vec::with_capacity(std::mem::size_of_val(&http_response));
        let buffer = http_response.serialize(buffer);
        // NOTE(review): if `write` can perform a short write, the unwritten
        // remainder is not retried here - confirm against hooch's contract.
        if let Err(err) = stream.write(&buffer).await {
            eprintln!("Failed to write response: {:?}", err);
        }
    }
}