use std::collections::HashMap;
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
use http_body_util::{combinators::BoxBody as HttpBoxBody, Full};
use hyper::{Method, Request, Response, StatusCode};
use matchit::Router as MatchitRouter;
use tower::Service;
use crate::BoxError;
/// Error type produced by request bodies (an alias for [`hyper::Error`]).
pub type BodyError = hyper::Error;
/// Type-erased request body accepted by the router: yields [`Bytes`] chunks and
/// fails with [`BodyError`].
pub type BoxBody = HttpBoxBody<Bytes, BodyError>;
/// A route handler.
///
/// It is called with the incoming request and the router's shared state (behind
/// an [`Arc`]). It returns a boxed `Send + 'static` future that resolves either to
/// a complete `Response<Full<Bytes>>` or to a [`BoxError`]. The router turns a
/// `BoxError` into a JSON `500 Internal Server Error` response.
///
/// The closure itself must be `Send + Sync + 'static` so it can be stored in the
/// router and shared through the outer `Arc`.
pub type Handler<S> = Arc<
dyn Fn(
Request<BoxBody>,
Arc<S>,
) -> Pin<
Box<dyn Future<Output = Result<Response<Full<Bytes>>, BoxError>> + Send + 'static>,
> + Send
+ Sync
+ 'static,
>;
/// Named path parameters captured from the matched route pattern.
///
/// Keys are parameter names and values are the matching request-path segments.
/// The router stores this value in the request extensions before calling the
/// matched handler, so handlers can read it from there.
#[derive(Debug, Default, Clone)]
pub struct PathParams(pub HashMap<String, String>);
/// One handler registered for a single HTTP method under a route pattern.
#[derive(Clone)]
struct Route<S: Clone + Send + Sync + 'static> {
    /// The HTTP method this handler answers.
    method: Method,
    /// The handler invoked when both the path and the method match.
    handler: Handler<S>,
}
/// Method-aware HTTP router built on [`matchit`] path matching.
///
/// Every distinct path pattern owns a table of handlers, with at most one handler
/// per HTTP method. The application state is kept behind an [`Arc`] and handed to
/// each handler, so cloning the router shares the state instead of copying it.
pub struct Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Path matcher whose values are indices into `tables`.
    routes: MatchitRouter<usize>,
    /// Handler tables, one per successfully registered pattern.
    tables: Vec<Vec<Route<S>>>,
    /// Exact pattern text mapped to its index in `tables`.
    ///
    /// `matchit` only looks values up by *request path*, and a pattern string can
    /// itself match a different, previously registered pattern. Two examples: a
    /// static path matched by a parameter segment, or any path matched by a
    /// catch-all. Tracking the exact pattern text here lets `route` tell
    /// "same pattern again" apart from "new pattern".
    patterns: HashMap<String, usize>,
    /// Shared application state passed to every handler.
    state: Arc<S>,
}

impl<S> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Creates an empty router that will hand `state` to every handler.
    pub fn new(state: S) -> Self {
        Self {
            routes: MatchitRouter::new(),
            tables: Vec::new(),
            patterns: HashMap::new(),
            state: Arc::new(state),
        }
    }

    /// Registers `handler` for requests with `method` on the route pattern `path`.
    ///
    /// - Registering an already-known pattern with a new method adds that method.
    /// - Registering the same pattern and method again replaces the earlier handler
    ///   (logged as a warning with the `tracing` feature).
    /// - A pattern that `matchit` rejects, for example because it conflicts with an
    ///   existing one, is skipped. With the `tracing` feature it is logged as an
    ///   error.
    #[must_use]
    pub fn route(mut self, path: &str, method: Method, handler: Handler<S>) -> Self {
        #[cfg(feature = "tracing")]
        tracing::debug!(%path, %method, "registering route");

        // Look the pattern up by its exact text. Matching it through `matchit`
        // could resolve to a different pattern that merely matches this string.
        let known_index = self.patterns.get(path).copied();
        if let Some(index) = known_index {
            let table = &mut self.tables[index];
            if let Some(existing) = table.iter_mut().find(|route| route.method == method) {
                #[cfg(feature = "tracing")]
                tracing::warn!(%path, %method, "overwriting existing route handler");
                existing.handler = handler;
            } else {
                table.push(Route { method, handler });
            }
            return self;
        }

        let index = self.tables.len();
        match self.routes.insert(path, index) {
            Ok(()) => {
                self.tables.push(vec![Route { method, handler }]);
                self.patterns.insert(path.to_owned(), index);
            }
            Err(e) => {
                #[cfg(feature = "tracing")]
                tracing::error!(%path, error = ?e, "failed to insert route");
                #[cfg(not(feature = "tracing"))]
                let _ = e;
            }
        }
        self
    }

    /// Resolves a request `method` and `path` against the routing table.
    fn resolve(&self, method: &Method, path: &str) -> Dispatch<S> {
        let matched = match self.routes.at(path) {
            Ok(matched) => matched,
            Err(_) => return Dispatch::NotFound,
        };
        // `route` pushes a table for every index it stores in `routes`.
        // `get` just keeps this hot path free of a panic branch.
        let table = match self.tables.get(*matched.value) {
            Some(table) => table,
            None => return Dispatch::NotFound,
        };
        match table.iter().find(|route| route.method == *method) {
            Some(route) => Dispatch::Found {
                route: route.clone(),
                params: PathParams(
                    matched
                        .params
                        .iter()
                        .map(|(key, value)| (key.to_owned(), value.to_owned()))
                        .collect(),
                ),
            },
            None => Dispatch::MethodNotAllowed(
                table.iter().map(|route| route.method.clone()).collect(),
            ),
        }
    }
}

/// Outcome of resolving a request against the routing table.
enum Dispatch<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// A handler is registered for both the path and the method.
    Found { route: Route<S>, params: PathParams },
    /// The path matched, but only for these methods.
    MethodNotAllowed(Vec<Method>),
    /// No registered pattern matches the path.
    NotFound,
}

impl<S> Clone for Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        Self {
            routes: self.routes.clone(),
            tables: self.tables.clone(),
            patterns: self.patterns.clone(),
            state: Arc::clone(&self.state),
        }
    }
}

impl<S> Service<Request<BoxBody>> for Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = Infallible;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// The router has no backpressure of its own and is always ready.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Dispatches `req` to the matching handler.
    ///
    /// On a match, the captured [`PathParams`] and the shared `Arc<S>` state are
    /// inserted into the request extensions before the handler is called, and a
    /// handler error becomes a JSON 500 response. Otherwise the response is a
    /// JSON 405 (with an `Allow` header) when only the method is wrong, or a JSON
    /// 404 when no pattern matches the path.
    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        // Resolve synchronously so the future owns everything it needs and does
        // not borrow `self`.
        let dispatch = self.resolve(req.method(), req.uri().path());
        let state = Arc::clone(&self.state);

        Box::pin(async move {
            let response = match dispatch {
                Dispatch::Found { route, params } => {
                    #[cfg(feature = "tracing")]
                    tracing::debug!(path = %req.uri().path(), params = ?params, "dispatching request");
                    let mut req = req;
                    req.extensions_mut().insert(params);
                    req.extensions_mut().insert(Arc::clone(&state));
                    match (route.handler)(req, state).await {
                        Ok(response) => response,
                        Err(err) => error_response(err),
                    }
                }
                Dispatch::MethodNotAllowed(methods) => method_not_allowed_response(methods),
                Dispatch::NotFound => not_found_response(),
            };
            Ok(response)
        })
    }
}
/// Builds the JSON `404 Not Found` response used when no route pattern matches.
///
/// If the JSON body cannot be serialised or the response cannot be assembled,
/// this returns the plain-text fallback response for the same status instead.
fn not_found_response() -> Response<Full<Bytes>> {
    let payload = serde_json::json!({
        "success": false,
        "errors": [{
            "code": "NOT_FOUND",
            "message": "Route not found",
        }]
    });

    serde_json::to_vec(&payload)
        .ok()
        .and_then(|body| {
            Response::builder()
                .status(StatusCode::NOT_FOUND)
                .header(hyper::header::CONTENT_TYPE, "application/json")
                .body(Full::from(Bytes::from(body)))
                .ok()
        })
        .unwrap_or_else(|| fallback_text_response(StatusCode::NOT_FOUND))
}
/// Builds the JSON `405 Method Not Allowed` response for a path that matched
/// only under other methods.
///
/// `methods` are the methods registered for the path. They are listed in the
/// `Allow` header, comma-separated, in the given order. If serialisation or
/// response assembly fails, the plain-text fallback for 405 is returned instead.
fn method_not_allowed_response(methods: Vec<Method>) -> Response<Full<Bytes>> {
    let status = StatusCode::METHOD_NOT_ALLOWED;
    let method_names: Vec<&str> = methods.iter().map(Method::as_str).collect();
    let allow_header = method_names.join(", ");

    let payload = serde_json::json!({
        "success": false,
        "errors": [{
            "code": "METHOD_NOT_ALLOWED",
            "message": "The requested method is not allowed for this path",
        }]
    });

    serde_json::to_vec(&payload)
        .ok()
        .and_then(|bytes| {
            Response::builder()
                .status(status)
                .header(hyper::header::CONTENT_TYPE, "application/json")
                .header(hyper::header::ALLOW, allow_header)
                .body(Full::from(Bytes::from(bytes)))
                .ok()
        })
        .unwrap_or_else(|| fallback_text_response(status))
}
/// Converts a handler error into a generic JSON `500 Internal Server Error`.
///
/// The error details are never sent to the client. With the `tracing` feature
/// they are logged at error level; otherwise they are discarded. If serialisation
/// or response assembly fails, the plain-text fallback for 500 is returned.
fn error_response(err: BoxError) -> Response<Full<Bytes>> {
    #[cfg(feature = "tracing")]
    tracing::error!(error = ?err, "handler error");
    // Without tracing, only reference the error to keep the binding "used".
    #[cfg(not(feature = "tracing"))]
    let _ = &err;

    let status = StatusCode::INTERNAL_SERVER_ERROR;
    let payload = serde_json::json!({
        "success": false,
        "errors": [{
            "code": "INTERNAL_ERROR",
            "message": "An internal error occurred",
        }]
    });

    serde_json::to_vec(&payload)
        .ok()
        .and_then(|body| {
            Response::builder()
                .status(status)
                .header(hyper::header::CONTENT_TYPE, "application/json")
                .body(Full::from(Bytes::from(body)))
                .ok()
        })
        .unwrap_or_else(|| fallback_text_response(status))
}
/// Plain-text response of last resort, used when a JSON error body cannot be built.
///
/// The body is the canonical reason phrase of `status` (for example `Not Found`
/// or `Method Not Allowed`), so 404 and 405 fallbacks no longer claim to be
/// internal server errors. Statuses without a canonical phrase get `error`.
///
/// # Panics
///
/// Never in practice. The status is already a valid [`StatusCode`] and the
/// header name, header value and body are static, so the builder cannot fail.
fn fallback_text_response(status: StatusCode) -> Response<Full<Bytes>> {
    let reason: &'static str = status.canonical_reason().unwrap_or("error");
    Response::builder()
        .status(status)
        .header(hyper::header::CONTENT_TYPE, "text/plain; charset=utf-8")
        .body(Full::from(Bytes::from_static(reason.as_bytes())))
        .expect("static response")
}