#![doc = include_str!("../README.md")]
use crate::__internal::{Handler, Route};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower_service::Service;
extern crate self as const_router;
pub use const_router_macros::{handler, router};
/// A request router backed by a `'static` route table.
///
/// Each request's key (obtained through [`ExtractKey`]) is located in `routes` with a
/// binary search: an exact match runs that route's handler, and any other key runs
/// `fallback`. Routers are normally created with the [`router!`] macro, which is expected
/// to emit the route table already sorted by key.
#[derive(Debug)]
pub struct Router<TRequest: 'static, TResponse: 'static, TError: 'static> {
    /// Handler invoked when no route key equals the request key.
    fallback: Handler<TRequest, TResponse, TError>,
    /// Route table, kept in ascending key order so it can be binary searched.
    routes: &'static [Route<TRequest, TResponse, TError>],
}
/// Extracts the routing key from a request.
///
/// [`Router`] compares the returned string for exact equality against its route keys.
/// A key that matches no route is sent to the router's fallback handler.
pub trait ExtractKey {
    /// Returns the key used to select a route for this request.
    fn extract_key(&self) -> &str;
}
/// Boxed, type-erased future produced by every handler and returned by `Router::handle`.
///
/// The future is `Send` and `'static`. It can be moved across threads and can outlive
/// the call that created it.
pub type BoxFuture<TResponse, TError> =
    Pin<Box<dyn Future<Output = Result<TResponse, TError>> + 'static + Send>>;
impl<TRequest: ExtractKey, TResponse, TError> Router<TRequest, TResponse, TError> {
    /// Dispatches `req` to the handler registered for its key.
    ///
    /// The key comes from [`ExtractKey::extract_key`] and is located in the route table
    /// by binary search. When no route key is equal to it, the fallback handler runs
    /// instead. The request is moved into the chosen handler, and the handler's future
    /// is returned without being polled here.
    pub fn handle(&self, req: TRequest) -> BoxFuture<TResponse, TError> {
        let target = {
            let wanted = req.extract_key();
            self.routes
                .binary_search_by(|route| route.key.cmp(wanted))
                .map_or(&self.fallback, |found| &self.routes[found].handler)
        };
        // The borrow of `req` taken to read its key has ended, so it can be moved now.
        (target.0)(req)
    }
}
/// Exposes [`Router`] as a `tower_service::Service`.
///
/// `poll_ready` always reports readiness, and `call` forwards to [`Router::handle`].
impl<TRequest, TResponse, TError> Service<TRequest> for Router<TRequest, TResponse, TError>
where
    TRequest: ExtractKey + 'static,
    TResponse: 'static,
    TError: 'static,
{
    type Response = TResponse;
    type Error = TError;
    type Future = BoxFuture<TResponse, TError>;

    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Routing is a lookup over a static table, so there is never backpressure to report.
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: TRequest) -> Self::Future {
        Self::handle(self, request)
    }
}
/// Routes `http` requests by the path component of their URI.
///
/// The scheme, authority, query string and fragment do not take part in matching.
#[cfg(feature = "http")]
impl<B> ExtractKey for http::Request<B> {
    fn extract_key(&self) -> &str {
        http::Request::uri(self).path()
    }
}
#[doc(hidden)]
pub mod __internal {
    //! Low-level `const` constructors for [`Router`], [`Route`] and [`Handler`].
    //!
    //! Hidden from the rendered docs. The `router!` / `#[handler]` macros are expected
    //! to expand to calls into this module. Because every constructor is a `const fn`,
    //! routers can be built in `static` initializers.

    use super::*;

    /// One entry of a router's route table: an exact key and the handler serving it.
    #[derive(Debug)]
    pub struct Route<TRequest, TResponse, TError> {
        pub(crate) key: &'static str,
        pub(crate) handler: Handler<TRequest, TResponse, TError>,
    }

    /// Type-erased request handler: a plain function pointer returning a boxed future.
    #[derive(Debug)]
    pub struct Handler<TRequest, TResponse, TError>(
        pub(crate) HandlerFn<TRequest, TResponse, TError>,
    );

    /// Signature every handler is erased to.
    type HandlerFn<TRequest, TResponse, TError> = fn(TRequest) -> BoxFuture<TResponse, TError>;

    /// Builds a [`Router`] from a fallback handler and a route table.
    ///
    /// `routes` must be sorted in ascending byte-wise order of their keys (the order
    /// defined by `str`'s `Ord`), because [`Router::handle`] finds keys with a binary
    /// search. Duplicate keys are not rejected here.
    ///
    /// # Panics
    ///
    /// Panics if `routes` is not sorted. In a `const` or `static` initializer this
    /// surfaces as a compile-time error instead of silent misrouting at runtime.
    pub const fn new_router<TRequest, TResponse, TError>(
        fallback: Handler<TRequest, TResponse, TError>,
        routes: &'static [Route<TRequest, TResponse, TError>],
    ) -> Router<TRequest, TResponse, TError>
    where
        TResponse: 'static,
        TError: 'static,
        TRequest: 'static,
    {
        // `binary_search_by` returns an unspecified result on an unsorted slice, so an
        // out-of-order table would send some requests to the wrong handler or to the
        // fallback. Iterators are unavailable in `const fn`, hence the manual loop.
        let mut i = 1;
        while i < routes.len() {
            if !key_le(routes[i - 1].key, routes[i].key) {
                panic!("const_router: route keys must be sorted in ascending order");
            }
            i += 1;
        }
        Router { fallback, routes }
    }

    /// `const` equivalent of `a <= b` for `str`: lexicographic over the UTF-8 bytes.
    const fn key_le(a: &str, b: &str) -> bool {
        let (a, b) = (a.as_bytes(), b.as_bytes());
        let mut i = 0;
        while i < a.len() && i < b.len() {
            if a[i] != b[i] {
                return a[i] < b[i];
            }
            i += 1;
        }
        // One key is a prefix of the other (or they are equal): the shorter sorts first.
        a.len() <= b.len()
    }

    /// Pairs a route key with its handler.
    ///
    /// The key is compared for exact equality with [`ExtractKey::extract_key`].
    pub const fn new_route<TRequest, TResponse, TError>(
        key: &'static str,
        handler: Handler<TRequest, TResponse, TError>,
    ) -> Route<TRequest, TResponse, TError> {
        Route { key, handler }
    }

    /// Wraps a handler function pointer in a [`Handler`].
    pub const fn new_handler<TRequest, TResponse, TError>(
        handler_fn: HandlerFn<TRequest, TResponse, TError>,
    ) -> Handler<TRequest, TResponse, TError> {
        Handler(handler_fn)
    }
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for routing, fallback handling and the tower `Service` impl.

    use super::*;
    use std::{
        fmt::Debug,
        task::{Context, Poll, Waker},
    };
    use tower_service::Service;

    #[derive(Debug)]
    struct Request {
        key: String,
    }

    impl ExtractKey for Request {
        fn extract_key(&self) -> &str {
            &self.key
        }
    }

    type TestRouter = Router<Request, String, &'static str>;

    #[handler]
    fn alpha_handler(_req: Request) -> Result<String, &'static str> {
        Ok("alpha".to_owned())
    }

    #[handler]
    async fn async_handler(_req: Request) -> Result<String, &'static str> {
        Ok("async".to_owned())
    }

    #[handler]
    fn echo_handler(req: Request) -> Result<String, &'static str> {
        Ok(req.key)
    }

    #[handler]
    fn generic_handler<T, const N: usize>(_req: Request) -> Result<String, &'static str>
    where
        T: Default,
    {
        let _ = T::default();
        Ok(format!("generic-{N}"))
    }

    #[handler]
    fn error_handler(_req: Request) -> Result<String, &'static str> {
        Err("route failed")
    }

    #[handler]
    fn fallback_handler() -> Result<String, &'static str> {
        Ok("fallback".to_owned())
    }

    // Routes are deliberately declared out of order to exercise the macro's sorting.
    static ROUTER: TestRouter = router! {
        fallback_handler,
        "/generic" => generic_handler::<usize, 7>,
        "/error" => error_handler,
        "/async" => async_handler,
        "/echo" => echo_handler,
        "/alpha" => alpha_handler,
    };

    static FALLBACK_ONLY_ROUTER: TestRouter = router! {
        fallback_handler,
    };

    fn request(key: impl Into<String>) -> Request {
        Request { key: key.into() }
    }

    /// Polls `future` once and returns its output; every test handler completes immediately.
    fn ready<T, E>(mut future: BoxFuture<T, E>) -> Result<T, E> {
        let waker = Waker::noop();
        let mut cx = Context::from_waker(waker);
        match future.as_mut().poll(&mut cx) {
            Poll::Ready(result) => result,
            Poll::Pending => panic!("handler future did not complete"),
        }
    }

    fn route_result(router: &TestRouter, key: &str) -> Result<String, &'static str> {
        ready(router.handle(request(key)))
    }

    fn assert_route(router: &TestRouter, key: &str, expected: Result<&'static str, &'static str>) {
        assert_eq!(route_result(router, key), expected.map(str::to_owned));
    }

    fn assert_ready<E>(poll: Poll<Result<(), E>>)
    where
        E: Debug + PartialEq,
    {
        assert_eq!(poll, Poll::Ready(Ok(())));
    }

    #[test]
    fn router_macro_matches_sorted_routes_and_fallback() {
        for (key, expected) in [
            ("/alpha", Ok("alpha")),
            ("/async", Ok("async")),
            ("/echo", Ok("/echo")),
            ("/generic", Ok("generic-7")),
            ("/error", Err("route failed")),
            ("/missing", Ok("fallback")),
            ("", Ok("fallback")),
        ] {
            assert_route(&ROUTER, key, expected);
        }
    }

    /// Keys that land at the binary search's boundaries or differ from a route key by
    /// a single detail must all fall through to the fallback handler.
    #[test]
    fn router_requires_exact_key_match() {
        for key in [
            "/",         // sorts before every route key
            "/zzz",      // sorts after every route key
            "/b",        // falls between "/async" and "/echo"
            "/alph",     // strict prefix of a route key
            "/alphabet", // a route key is a strict prefix of it
            "/alpha/",   // trailing slash is not normalised
            "/ALPHA",    // matching is case-sensitive
        ] {
            assert_route(&ROUTER, key, Ok("fallback"));
        }
    }

    #[test]
    fn fallback_only_router_routes_every_request_to_fallback() {
        for key in ["", "/", "/alpha", "/missing"] {
            assert_route(&FALLBACK_ONLY_ROUTER, key, Ok("fallback"));
        }
    }

    #[test]
    fn router_implements_tower_service() {
        static ROUTES: [Route<Request, String, &'static str>; 1] =
            [__internal::new_route("/alpha", alpha_handler())];
        let mut router = __internal::new_router(fallback_handler(), &ROUTES);
        let waker = Waker::noop();
        let mut cx = Context::from_waker(waker);
        assert_ready(Service::poll_ready(&mut router, &mut cx));
        assert_eq!(
            ready(Service::call(&mut router, request("/alpha"))),
            Ok("alpha".to_owned())
        );
        assert_eq!(
            ready(Service::call(&mut router, request("/unknown"))),
            Ok("fallback".to_owned())
        );
    }

    #[cfg(feature = "http")]
    #[test]
    fn http_request_extracts_uri_path() {
        for (uri, path) in [
            ("https://example.com/static?ignored=true", "/static"),
            ("/nested/path#fragment", "/nested/path"),
            ("*", "*"),
        ] {
            let request = http::Request::builder()
                .uri(uri)
                .body(())
                .expect("request should build");
            assert_eq!(request.extract_key(), path);
        }
    }
}