use std::{
collections::{HashMap, hash_map::Drain},
error::Error,
fmt,
future::Future,
str::FromStr,
};
use http::uri::Uri;
use motore::{layer::Layer, service::Service};
use crate::{context::ServerContext, request::Request, response::Response};
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(super) struct RouteId(u32);
impl RouteId {
fn next() -> Self {
use std::sync::atomic::{AtomicU32, Ordering};
static ID: AtomicU32 = AtomicU32::new(0);
let id = ID.fetch_add(1, Ordering::Relaxed);
if id == u32::MAX {
panic!("Over `u32::MAX` routes created. If you need this, please file an issue.");
}
Self(id)
}
}
#[derive(Default)]
pub(super) struct Matcher {
matches: HashMap<String, RouteId>,
router: matchit::Router<RouteId>,
}
impl Matcher {
pub fn insert<R>(&mut self, uri: R) -> Result<RouteId, MatcherError>
where
R: Into<String>,
{
let route_id = RouteId::next();
self.insert_with_id(uri, route_id)?;
Ok(route_id)
}
pub fn insert_with_id<R>(&mut self, uri: R, route_id: RouteId) -> Result<(), MatcherError>
where
R: Into<String>,
{
let uri = uri.into();
if self.matches.insert(uri.clone(), route_id).is_some() {
return Err(MatcherError::UriConflict(uri));
}
self.router
.insert(uri, route_id)
.map_err(MatcherError::RouterInsertError)?;
Ok(())
}
pub fn at<'a>(
&'a self,
path: &'a str,
) -> Result<matchit::Match<'a, 'a, &'a RouteId>, MatcherError> {
self.router.at(path).map_err(MatcherError::RouterMatchError)
}
pub fn drain(&mut self) -> Drain<'_, String, RouteId> {
self.matches.drain()
}
}
#[derive(Debug)]
pub(super) enum MatcherError {
UriConflict(String),
RouterInsertError(matchit::InsertError),
RouterMatchError(matchit::MatchError),
}
impl fmt::Display for MatcherError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::UriConflict(uri) => write!(f, "URI conflict: {uri}"),
Self::RouterInsertError(err) => write!(f, "router insert error: {err}"),
Self::RouterMatchError(err) => write!(f, "router match error: {err}"),
}
}
}
impl Error for MatcherError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
Self::UriConflict(_) => None,
Self::RouterInsertError(e) => Some(e),
Self::RouterMatchError(e) => Some(e),
}
}
}
pub(super) struct StripPrefixLayer;
impl<S> Layer<S> for StripPrefixLayer {
type Service = StripPrefix<S>;
fn layer(self, inner: S) -> Self::Service {
StripPrefix { inner }
}
}
pub(super) const NEST_CATCH_PARAM: &str = "{*__priv_nest_catch_param}";
pub(super) const NEST_CATCH_PARAM_NAME: &str = "__priv_nest_catch_param";
pub(super) struct StripPrefix<S> {
inner: S,
}
impl<S, B, E> Service<ServerContext, Request<B>> for StripPrefix<S>
where
S: Service<ServerContext, Request<B>, Response = Response, Error = E>,
{
type Response = Response;
type Error = E;
fn call(
&self,
cx: &mut ServerContext,
mut req: Request<B>,
) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send {
let mut uri = String::from("/");
if cx
.params()
.last()
.is_some_and(|(k, _)| k == NEST_CATCH_PARAM_NAME)
{
uri += cx.params_mut().pop().unwrap().1.as_str();
};
if let Some(query) = req.uri().query() {
uri.push('?');
uri.push_str(query);
}
*req.uri_mut() = Uri::from_str(&uri).expect("infallible: stripped uri is invalid");
self.inner.call(cx, req)
}
}