use std::{collections::HashMap, sync::Arc};
use http::{Method, StatusCode};
use http_body_util::Full;
use hyper::body::Bytes;
use matchit::Match;
use crate::{
Flow::{self, Continue, Exit},
FromRequest, Handler, IntoResponse, Request, Response,
body::Empty,
router::endpoint::{Endpoint, EndpointHandler, MethodHandler},
};
mod endpoint;
/// Opaque identifier linking a path registered in the `matchit` tree to its endpoint.
///
/// Ids are handed out sequentially by `Router::get_next_id`, starting from `RouteId(0)`
/// (which is also the `Default` value). The type is a tiny `u32` newtype, so it is `Copy`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Default)]
struct RouteId(u32);
#[derive(Default)]
pub struct Router<S> {
inner: matchit::Router<RouteId>,
routes: HashMap<RouteId, Endpoint<S>>,
next_id: RouteId,
path_to_id: HashMap<String, RouteId>,
}
impl<S> Router<S>
where
S: Clone + Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
inner: matchit::Router::new(),
routes: HashMap::new(),
next_id: RouteId(0),
path_to_id: HashMap::new(),
}
}
fn get_next_id(&mut self) -> RouteId {
let id = self.next_id.clone();
self.next_id = RouteId(id.0 + 1);
id
}
pub fn post<I, O>(
self,
path: &str,
handler: impl EndpointHandler<I, S, Output = O, Future: Send> + Clone + Send + Sync + 'static,
) -> Self
where
I: FromRequest<Output = I> + Send + 'static,
<I as FromRequest>::Error: Send,
<I as FromRequest>::from_request(..): Send,
O: IntoResponse + 'static,
{
self.insert(path, Method::POST, handler)
}
pub fn put<I, O>(
self,
path: &str,
handler: impl EndpointHandler<I, S, Output = O, Future: Send> + Clone + Send + Sync + 'static,
) -> Self
where
I: FromRequest<Output = I> + Send + 'static,
<I as FromRequest>::Error: Send,
<I as FromRequest>::from_request(..): Send,
O: IntoResponse + 'static,
{
self.insert(path, Method::PUT, handler)
}
pub fn get<I, O>(
self,
path: &str,
handler: impl EndpointHandler<I, S, Output = O, Future: Send> + Clone + Send + Sync + 'static,
) -> Self
where
I: FromRequest<Output = I> + Send + 'static,
<I as FromRequest>::Error: Send,
<I as FromRequest>::from_request(..): Send,
O: IntoResponse + 'static,
{
self.insert(path, Method::GET, handler)
}
pub fn patch<I, O>(
self,
path: &str,
handler: impl EndpointHandler<I, S, Output = O, Future: Send> + Clone + Send + Sync + 'static,
) -> Self
where
I: FromRequest<Output = I> + Send + 'static,
<I as FromRequest>::Error: Send,
<I as FromRequest>::from_request(..): Send,
O: IntoResponse + 'static,
{
self.insert(path, Method::PATCH, handler)
}
pub fn delete<I, O>(
self,
path: &str,
handler: impl EndpointHandler<I, S, Output = O, Future: Send> + Clone + Send + Sync + 'static,
) -> Self
where
I: FromRequest<Output = I> + Send + 'static,
<I as FromRequest>::Error: Send,
<I as FromRequest>::from_request(..): Send,
O: IntoResponse + 'static,
{
self.insert(path, Method::DELETE, handler)
}
pub fn insert<I, O, H>(mut self, path: &str, method: Method, handler: H) -> Self
where
H: EndpointHandler<I, S, Output = O> + Clone + Send + Sync + 'static,
H::Future: Send,
I: FromRequest<Output = I> + Send + 'static,
<I as FromRequest>::Error: Send,
<I as FromRequest>::from_request(..): Send,
O: IntoResponse + 'static,
{
let id = match self.path_to_id.get(path) {
Some(existing_id) => existing_id.clone(),
None => {
let new_id = self.get_next_id();
self.inner.insert(path, new_id.clone()).unwrap();
self.path_to_id.insert(path.to_string(), new_id.clone());
new_id
}
};
let method_handler: Arc<MethodHandler<S>> = Arc::new(move |request, state| {
let handler = Arc::new(handler.clone());
let future = I::from_request(request);
Box::pin(async move {
match future.await {
Ok(input) => match handler.handle(input, state).await {
Continue(output) => Continue(output.into_response()),
Exit(exception) => Exit(exception.into_response()),
},
Err(_) => {
let mut response = Empty.into_response();
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Exit(response)
}
}
})
});
let endpoint = self.routes.entry(id).or_insert_with(|| Endpoint {
methods: HashMap::new(),
});
if endpoint
.methods
.insert(method.clone(), method_handler)
.is_some()
{
panic!("Route `{path}` already has handler for method `{method}`");
}
self
}
}
impl<S> Handler<Request, S> for Router<S>
where
S: Clone + Send + Sync + 'static,
{
type Output = Response;
type Exception = Response;
type Future = impl Future<Output = Flow<Self::Output, Self::Exception>> + Send;
fn handle(&self, mut req: Request, state: S) -> Self::Future {
let (method, path) = (req.method(), req.uri().path().to_string());
let result = {
let mut response = Response::new(Full::new(Bytes::new()));
*response.status_mut() = StatusCode::NOT_FOUND;
if let Ok(Match {
value: route_id,
params,
}) = self.inner.at(&path)
{
if let Some(endpoint) = self.routes.get(route_id) {
if let Some(handler) = endpoint.methods.get(method) {
let params = params
.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect::<HashMap<String, String>>();
Ok((handler.clone(), params))
} else {
*response.status_mut() = StatusCode::METHOD_NOT_ALLOWED;
Err(response)
}
} else {
Err(response)
}
} else {
Err(response)
}
};
async move {
match result {
Ok((h, params)) => {
req.extensions_mut().insert(params);
h(req, state).await
}
Err(res) => Exit(res),
}
}
}
}