const-router 0.1.0

Compile-time string-key router with Tower service integration.
Documentation
#![doc = include_str!("../README.md")]

use crate::__internal::{Handler, Route};
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tower_service::Service;

extern crate self as const_router;

pub use const_router_macros::{handler, router};

/// A static string-key router.
///
/// Build a router with [`router!`]. Requests are matched by the value returned
/// from [`ExtractKey`]. Missing keys are routed to the configured fallback
/// handler.
#[derive(Debug)]
pub struct Router<TRequest, TResponse, TError>
where
    TResponse: 'static,
    TError: 'static,
    TRequest: 'static,
{
    fallback: Handler<TRequest, TResponse, TError>,
    routes: &'static [Route<TRequest, TResponse, TError>],
}

/// Supplies the string a [`Router`] uses to choose a handler.
///
/// The returned key is compared against the string literals given to
/// [`router!`]; a key that matches none of them is sent to the fallback
/// handler.
pub trait ExtractKey {
    /// Borrows the routing key out of `self`.
    fn extract_key(&self) -> &str;
}

/// The future every router handler yields.
///
/// It is heap-allocated and pinned, `Send`, and owns all of its data
/// (`'static`). It resolves to the handler's response or its error.
pub type BoxFuture<TResponse, TError> =
    Pin<Box<dyn Future<Output = Result<TResponse, TError>> + Send + 'static>>;

impl<TRequest, TResponse, TError> Router<TRequest, TResponse, TError>
where
    TRequest: ExtractKey,
{
    /// Dispatches `req` to the handler registered for its key.
    ///
    /// The key from [`ExtractKey::extract_key`] is looked up in the route
    /// table by binary search. If no route carries that key, the fallback
    /// handler receives the request instead. The chosen handler's future is
    /// returned as-is; this method does not poll it.
    pub fn handle(&self, req: TRequest) -> BoxFuture<TResponse, TError> {
        // Scope the key borrow so `req` can be moved into the handler below.
        let selected = {
            let wanted = req.extract_key();
            self.routes
                .binary_search_by_key(&wanted, |route| route.key)
                .map_or(&self.fallback, |position| &self.routes[position].handler)
        };

        let handler_fn = selected.0;
        handler_fn(req)
    }
}

impl<TRequest, TResponse, TError> Service<TRequest> for Router<TRequest, TResponse, TError>
where
    TResponse: 'static,
    TError: 'static,
    TRequest: ExtractKey + 'static,
{
    type Error = TError;
    type Future = BoxFuture<TResponse, TError>;
    type Response = TResponse;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: TRequest) -> Self::Future {
        self.handle(req)
    }
}

/// Routes [`http::Request`]s by their URI path.
///
/// Only the path component is used as the key. Scheme, authority and query
/// string play no part in route matching.
#[cfg(feature = "http")]
impl<T> ExtractKey for http::Request<T> {
    fn extract_key(&self) -> &str {
        let uri = self.uri();
        uri.path()
    }
}

/// Support items for macro-generated code.
///
/// Presumably these are the expansion targets of [`router!`] and `#[handler]`.
/// They are not part of the public API and are hidden from rustdoc.
#[doc(hidden)]
pub mod __internal {
    use super::*;

    /// One entry of a router's table: an exact-match key and its handler.
    #[derive(Debug)]
    pub struct Route<TRequest, TResponse, TError> {
        pub(crate) key: &'static str,
        pub(crate) handler: Handler<TRequest, TResponse, TError>,
    }

    /// A handler: a plain function pointer that turns a request into a
    /// [`BoxFuture`].
    #[derive(Debug)]
    pub struct Handler<TRequest, TResponse, TError>(
        pub(crate) HandlerFn<TRequest, TResponse, TError>,
    );

    type HandlerFn<TRequest, TResponse, TError> = fn(TRequest) -> BoxFuture<TResponse, TError>;

    /// Builds a [`Router`] from a fallback handler and a route table.
    ///
    /// # Panics
    ///
    /// Panics if `routes` is not sorted by key in non-decreasing byte order.
    /// `Router::handle` binary-searches the table and would otherwise silently
    /// misroute requests. When this runs in a `static`/`const` initializer,
    /// the panic surfaces as a compile-time error.
    pub const fn new_router<TRequest, TResponse, TError>(
        fallback: Handler<TRequest, TResponse, TError>,
        routes: &'static [Route<TRequest, TResponse, TError>],
    ) -> Router<TRequest, TResponse, TError>
    where
        TResponse: 'static,
        TError: 'static,
        TRequest: 'static,
    {
        // Iterators are not usable in `const fn`, hence the manual loop.
        let mut index = 1;
        while index < routes.len() {
            assert!(
                !key_less_than(routes[index].key, routes[index - 1].key),
                "const_router: routes must be sorted by key"
            );
            index += 1;
        }

        Router { fallback, routes }
    }

    /// Pairs a route key with the handler that serves it.
    pub const fn new_route<TRequest, TResponse, TError>(
        key: &'static str,
        handler: Handler<TRequest, TResponse, TError>,
    ) -> Route<TRequest, TResponse, TError> {
        Route { key, handler }
    }

    /// Wraps a handler function pointer.
    pub const fn new_handler<TRequest, TResponse, TError>(
        handler_fn: HandlerFn<TRequest, TResponse, TError>,
    ) -> Handler<TRequest, TResponse, TError> {
        Handler(handler_fn)
    }

    /// `a < b` under byte-wise lexicographic order, which is the order
    /// `Ord for str` uses, written so it can run in `const` contexts.
    const fn key_less_than(a: &str, b: &str) -> bool {
        let (a, b) = (a.as_bytes(), b.as_bytes());
        let mut i = 0;
        while i < a.len() && i < b.len() {
            if a[i] != b[i] {
                return a[i] < b[i];
            }
            i += 1;
        }
        // One key is a prefix of the other: the shorter one sorts first.
        a.len() < b.len()
    }
}

#[cfg(test)]
mod tests {
    // Unit tests for key matching, fallback dispatch, the Tower `Service`
    // impl and (with the `http` feature) URI-path key extraction.
    use super::*;
    use std::{
        fmt::Debug,
        task::{Context, Poll, Waker},
    };
    use tower_service::Service;

    // Minimal request type whose routing key is an owned `String`.
    #[derive(Debug)]
    struct Request {
        key: String,
    }

    impl ExtractKey for Request {
        fn extract_key(&self) -> &str {
            &self.key
        }
    }

    // Router shape shared by every test: `String` responses, static-str errors.
    type TestRouter = Router<Request, String, &'static str>;

    // Each handler below yields a distinct result, so a test can tell which
    // one a request was dispatched to.
    #[handler]
    fn alpha_handler(_req: Request) -> Result<String, &'static str> {
        Ok("alpha".to_owned())
    }

    // Exercises `#[handler]` on an `async fn`.
    #[handler]
    async fn async_handler(_req: Request) -> Result<String, &'static str> {
        Ok("async".to_owned())
    }

    // Echoes the routing key back, proving the request itself reaches the handler.
    #[handler]
    fn echo_handler(req: Request) -> Result<String, &'static str> {
        Ok(req.key)
    }

    // Exercises type and const generic parameters; instantiated with a
    // turbofish inside `router!` below.
    #[handler]
    fn generic_handler<T, const N: usize>(_req: Request) -> Result<String, &'static str>
    where
        T: Default,
    {
        let _ = T::default();

        Ok(format!("generic-{N}"))
    }

    // Checks that a handler's `Err` is passed through unchanged.
    #[handler]
    fn error_handler(_req: Request) -> Result<String, &'static str> {
        Err("route failed")
    }

    // Takes no request parameter. NOTE(review): `#[handler]` apparently
    // accepts this form and presumably discards the request — confirm in
    // const_router_macros.
    #[handler]
    fn fallback_handler() -> Result<String, &'static str> {
        Ok("fallback".to_owned())
    }

    // Keys are deliberately listed out of order. `Router::handle` binary-
    // searches the table, so this presumably relies on `router!` sorting it.
    static ROUTER: TestRouter = router! {
        fallback_handler,
        "/generic" => generic_handler::<usize, 7>,
        "/error" => error_handler,
        "/async" => async_handler,
        "/echo" => echo_handler,
        "/alpha" => alpha_handler,
    };

    // Degenerate router with an empty route table.
    static FALLBACK_ONLY_ROUTER: TestRouter = router! {
        fallback_handler,
    };

    fn request(key: impl Into<String>) -> Request {
        Request { key: key.into() }
    }

    // Polls `future` once with a no-op waker and returns its output. Panics
    // if it is still pending; none of the handlers above ever await.
    fn ready<T, E>(mut future: BoxFuture<T, E>) -> Result<T, E> {
        let waker = Waker::noop();
        let mut cx = Context::from_waker(waker);

        match future.as_mut().poll(&mut cx) {
            Poll::Ready(result) => result,
            Poll::Pending => panic!("handler future did not complete"),
        }
    }

    // Routes a request carrying `key` through `router` and resolves the result.
    fn route_result(router: &TestRouter, key: &str) -> Result<String, &'static str> {
        ready(router.handle(request(key)))
    }

    // Asserts the routed result; `Ok` payloads are compared as owned strings.
    fn assert_route(router: &TestRouter, key: &str, expected: Result<&'static str, &'static str>) {
        assert_eq!(route_result(router, key), expected.map(str::to_owned));
    }

    // Asserts that a `poll_ready` result reports immediate readiness.
    fn assert_ready<E>(poll: Poll<Result<(), E>>)
    where
        E: Debug + PartialEq,
    {
        assert_eq!(poll, Poll::Ready(Ok(())));
    }

    #[test]
    fn router_macro_matches_sorted_routes_and_fallback() {
        // Every registered key, plus an unknown key and the empty key, both of
        // which must reach the fallback.
        for (key, expected) in [
            ("/alpha", Ok("alpha")),
            ("/async", Ok("async")),
            ("/echo", Ok("/echo")),
            ("/generic", Ok("generic-7")),
            ("/error", Err("route failed")),
            ("/missing", Ok("fallback")),
            ("", Ok("fallback")),
        ] {
            assert_route(&ROUTER, key, expected);
        }
    }

    #[test]
    fn fallback_only_router_routes_every_request_to_fallback() {
        for key in ["", "/", "/alpha", "/missing"] {
            assert_route(&FALLBACK_ONLY_ROUTER, key, Ok("fallback"));
        }
    }

    #[test]
    fn router_implements_tower_service() {
        // Built with the `__internal` constructors instead of `router!`.
        // `alpha_handler()` and `fallback_handler()` are called with no
        // arguments and their results are passed where a `Handler` is
        // expected, so `#[handler]` evidently turns each fn into a
        // `Handler`-returning function.
        static ROUTES: [Route<Request, String, &'static str>; 1] =
            [__internal::new_route("/alpha", alpha_handler())];
        let mut router = __internal::new_router(fallback_handler(), &ROUTES);
        let waker = Waker::noop();
        let mut cx = Context::from_waker(waker);

        assert_ready(Service::poll_ready(&mut router, &mut cx));
        assert_eq!(
            ready(Service::call(&mut router, request("/alpha"))),
            Ok("alpha".to_owned())
        );
        assert_eq!(
            ready(Service::call(&mut router, request("/unknown"))),
            Ok("fallback".to_owned())
        );
    }

    #[cfg(feature = "http")]
    #[test]
    fn http_request_extracts_uri_path() {
        // The key is the URI path only; scheme, authority, query and fragment
        // are dropped.
        for (uri, path) in [
            ("https://example.com/static?ignored=true", "/static"),
            ("/nested/path#fragment", "/nested/path"),
            ("*", "*"),
        ] {
            let request = http::Request::builder()
                .uri(uri)
                .body(())
                .expect("request should build");

            assert_eq!(request.extract_key(), path);
        }
    }
}