//! fnroute 0.1.0
//!
//! A small function router with axum-style handler extraction.
use alloc::{
    boxed::Box,
    collections::BTreeMap,
    string::{String, ToString},
    sync::Arc,
    vec::Vec,
};
use core::marker::PhantomData;

use crate::{
    BoxFuture, RouteContext, RouteError, RouteResult,
    handler::{Handler, SyncHandler},
};

/// Type-erased route handler stored behind an [`Arc`].
///
/// The closure is called with the [`RouteContext`] of the call and a state
/// value of type `S`, and returns a `BoxFuture<RouteResult<R>>`.
type BoxedHandler<S, R> = Arc<dyn Fn(RouteContext, S) -> BoxFuture<RouteResult<R>>>;

#[derive(Clone)]
pub struct Router<R, S = ()> {
    state: S,
    routes: Vec<Route<S, R>>,
    return_type: PhantomData<R>,
}

#[derive(Clone)]
struct Route<S, R> {
    pattern: Pattern,
    handler: BoxedHandler<S, R>,
}

/// A parsed route pattern: the ordered segments a path is compared against.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Pattern {
    /// Segments in the order they appear in the pattern string.
    segments: Vec<Segment>,
}

/// One `/`-separated piece of a [`Pattern`].
#[derive(Debug, Clone, PartialEq, Eq)]
enum Segment {
    /// Matches only a path segment equal to this text.
    Literal(String),
    /// Matches any single path segment and records it under this name
    /// (written as `:name` in the pattern string).
    Param(String),
}

impl<R> Router<R, ()> {
    pub fn new() -> Self {
        Self {
            state: (),
            routes: Vec::new(),
            return_type: PhantomData,
        }
    }
}

impl<R> Default for Router<R, ()> {
    fn default() -> Self {
        Self::new()
    }
}

impl<R, S> Router<R, S>
where
    S: Clone + 'static,
    R: 'static,
{
    pub fn with_state<T>(self, state: T) -> Router<R, T>
    where
        T: Clone + 'static,
    {
        Router {
            state,
            routes: Vec::new(),
            return_type: PhantomData,
        }
    }

    pub fn route<T, H>(mut self, pattern: &str, handler: H) -> Self
    where
        H: Handler<T, S, R>,
    {
        let handler = Arc::new(move |context, state| handler.clone().call(context, state));

        self.routes.push(Route {
            pattern: Pattern::parse(pattern),
            handler,
        });

        self
    }

    pub fn route_sync<T, H>(mut self, pattern: &str, handler: H) -> Self
    where
        H: SyncHandler<T, S, R>,
    {
        let handler = Arc::new(move |context, state| {
            let result = handler.clone().call_sync(context, state);
            Box::pin(async move { result }) as BoxFuture<RouteResult<R>>
        });

        self.routes.push(Route {
            pattern: Pattern::parse(pattern),
            handler,
        });

        self
    }

    pub async fn call(&self, route: impl Into<String>) -> RouteResult<R> {
        self.call_with(RouteContext::new(route)).await
    }

    pub async fn call_with(&self, mut context: RouteContext) -> RouteResult<R> {
        for route in &self.routes {
            if let Some(params) = route.pattern.matches(&context.route) {
                context.params = params;
                return (route.handler)(context, self.state.clone()).await;
            }
        }

        Err(RouteError::NotFound {
            route: context.route,
        })
    }
}

impl Pattern {
    /// Parses a `/`-separated pattern string.
    ///
    /// Empty segments are skipped. A segment that starts with `:` becomes a
    /// [`Segment::Param`] named by the text after the colon; every other
    /// segment becomes a [`Segment::Literal`].
    fn parse(path: &str) -> Self {
        let segments = split_path(path)
            .map(|raw| match raw.strip_prefix(':') {
                Some(name) => Segment::Param(name.to_string()),
                None => Segment::Literal(raw.to_string()),
            })
            .collect();

        Self { segments }
    }

    /// Compares `path` against this pattern segment by segment.
    ///
    /// Returns the captured parameters (name to matched segment) when the path
    /// has exactly as many segments as the pattern and every literal segment
    /// is equal; returns `None` otherwise. If a parameter name occurs more
    /// than once, the last captured value is the one kept.
    fn matches(&self, path: &str) -> Option<BTreeMap<String, String>> {
        let mut remaining = split_path(path);
        let mut captured = BTreeMap::new();

        for expected in &self.segments {
            // Running out of path segments means the path is too short.
            let actual = remaining.next()?;

            match expected {
                Segment::Literal(literal) if literal.as_str() == actual => {}
                Segment::Literal(_) => return None,
                Segment::Param(name) => {
                    captured.insert(name.clone(), actual.to_string());
                }
            }
        }

        // Anything left over means the path is longer than the pattern.
        if remaining.next().is_some() {
            None
        } else {
            Some(captured)
        }
    }
}

/// Yields the non-empty `/`-separated segments of `path`, in order.
///
/// Leading, trailing and repeated slashes only ever produce empty segments,
/// and those are filtered out, so `"/a//b/"` yields `"a"` then `"b"`.
fn split_path(path: &str) -> impl Iterator<Item = &str> {
    path.split('/').filter(|segment| !segment.is_empty())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Input, Param, Params, State};
    use std::{
        future::Future,
        pin::Pin,
        task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
    };

    /// A pattern with two `:name` segments captures both values from a path of
    /// the same length and rejects a path with fewer segments.
    #[test]
    fn pattern_matches_named_params() {
        let pattern = Pattern::parse("users/:id/posts/:post_id");

        let captured = pattern.matches("users/42/posts/7").unwrap();
        assert_eq!(captured["id"], "42");
        assert_eq!(captured["post_id"], "7");

        let too_short = pattern.matches("users/42");
        assert!(too_short.is_none());
    }

    // An async handler that takes no arguments can be registered and called.
    #[test]
    fn routes_zero_argument_handlers() {
        block_on(async {
            // Returns `&'static str` while the router's output type is `String`;
            // the conversion presumably happens in the `Handler` impl (not
            // visible in this file).
            async fn hello() -> &'static str {
                "hello"
            }

            let app = Router::<String>::new().route("hello", hello);
            let output = app.call("hello").await.unwrap();

            assert_eq!(output, "hello");
        });
    }

    // A handler can take a route parameter, a caller-supplied input value and
    // the router state in the same call.
    #[test]
    fn extracts_param_input_and_state() {
        block_on(async {
            // Payload attached to the context through `RouteContext::input`.
            #[derive(Clone)]
            struct Payload {
                suffix: String,
            }

            // `Param`, `Input` and `State` are extractor types from the crate
            // root; how each is filled in is defined outside this file.
            async fn user(
                Param(id): Param<u64>,
                Input(payload): Input<Payload>,
                State(prefix): State<String>,
            ) -> String {
                format!("{prefix}:{id}:{}", payload.suffix)
            }

            // The state is set before the route is added, so the handler is
            // registered against `String` state.
            let app = Router::<String>::new()
                .with_state("user".to_string())
                .route("users/:id", user);

            let output = app
                .call_with(RouteContext::new("users/9").input(Payload {
                    suffix: "saved".to_string(),
                }))
                .await
                .unwrap();

            assert_eq!(output, "user:9:saved");
        });
    }

    // A handler can mix a typed single parameter, the whole parameter map, the
    // state and the raw `RouteContext` in one signature.
    #[test]
    fn supports_many_parameters() {
        block_on(async {
            async fn many(
                Param(id): Param<u64>,
                Params(params): Params,
                State(prefix): State<String>,
                context: RouteContext,
            ) -> String {
                format!("{prefix}:{id}:{}:{}", params["id"], context.route)
            }

            let app = Router::<String>::new()
                .with_state("many".to_string())
                .route("items/:id", many);

            let output = app.call("items/7").await.unwrap();

            // prefix : typed id : id from the map : route string from the context
            assert_eq!(output, "many:7:7:items/7");
        });
    }

    // A handler with nine arguments is accepted, and the same extractor type
    // may appear more than once in a signature.
    #[test]
    fn supports_more_than_eight_parameters() {
        block_on(async {
            async fn many(
                RouteContext { route, .. }: RouteContext,
                State(prefix): State<String>,
                Params(params): Params,
                Param(id): Param<u64>,
                RouteContext { params: p2, .. }: RouteContext,
                State(prefix2): State<String>,
                Params(params2): Params,
                Param(id2): Param<u64>,
                RouteContext { route: route2, .. }: RouteContext,
            ) -> String {
                format!(
                    "{prefix}:{prefix2}:{id}:{id2}:{}:{}:{}:{}:{}",
                    params["id"], params2["id"], p2["id"], route, route2
                )
            }

            let app = Router::<String>::new()
                .with_state("wide".to_string())
                .route("wide/:id", many);

            let output = app.call("wide/11").await.unwrap();

            // Every repeated extractor observes the same state, params and route.
            assert_eq!(output, "wide:wide:11:11:11:11:11:wide/11:wide/11");
        });
    }

    // A plain (non-async) function registered through `route_sync` is
    // dispatched the same way as an async handler.
    #[test]
    fn routes_sync_handlers_on_a_separate_branch() {
        block_on(async {
            fn show(Param(id): Param<u64>) -> String {
                format!("sync:{id}")
            }

            let app = Router::<String>::new().route_sync("sync/:id", show);
            let output = app.call("sync/12").await.unwrap();

            assert_eq!(output, "sync:12");
        });
    }

    // A closure that moves its captured value out when called can still be
    // used as a handler: `route_sync` clones the handler before each call, so
    // the call consumes the clone rather than the stored closure.
    #[test]
    fn handler_can_consume_a_cloned_capture() {
        block_on(async {
            let prefix = String::from("owned");
            let app = Router::<String>::new().route_sync("consume", move || prefix);

            let output = app.call("consume").await.unwrap();

            assert_eq!(output, "owned");
        });
    }

    /// Minimal test executor: drives `future` to completion on the current
    /// thread and returns its output.
    ///
    /// The future is polled with a waker that does nothing, so instead of
    /// waiting to be woken the loop yields the thread and polls again whenever
    /// the future reports `Pending`.
    fn block_on<F: Future>(future: F) -> F::Output {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut task = Box::pin(future);

        loop {
            if let Poll::Ready(output) = task.as_mut().poll(&mut cx) {
                return output;
            }
            std::thread::yield_now();
        }
    }

    /// Builds a [`Waker`] whose clone, wake and drop operations do nothing.
    ///
    /// Enough for `block_on`, which re-polls in a loop and never relies on
    /// being woken.
    fn noop_waker() -> Waker {
        // One vtable shared by every waker produced here.
        static VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw, ignore, ignore, ignore);

        fn raw() -> RawWaker {
            RawWaker::new(std::ptr::null(), &VTABLE)
        }

        unsafe fn clone_raw(_: *const ()) -> RawWaker {
            raw()
        }

        // Shared by `wake`, `wake_by_ref` and `drop`: none of them has any work to do.
        unsafe fn ignore(_: *const ()) {}

        // SAFETY: the data pointer is null and no vtable function ever
        // dereferences it; the functions hold no resources, so the `RawWaker`
        // contract is upheld trivially.
        unsafe { Waker::from_raw(raw()) }
    }
}