neco-server-core 0.1.0

Core HTTP primitives for neco-server.
Documentation
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use crate::{Method, Request, Response, StatusCode};

/// Heap-allocated, type-erased future that may be moved across threads.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;

/// Route handler function type.
///
/// Called with the request and a clone of the router state; resolves to the response.
pub type Handler<S> = Arc<dyn Fn(Request, S) -> BoxFuture<Response> + Sync + Send>;

/// Middleware function type.
///
/// Called with the request, a clone of the router state, and the [`Next`] continuation
/// that runs the remainder of the chain.
pub type Middleware<S> = Arc<dyn Fn(Request, S, Next<S>) -> BoxFuture<Response> + Sync + Send>;

/// Method constraint attached to a registered route.
#[derive(Clone)]
enum RouteMethod {
    /// Accepts only requests whose method compares equal to this one.
    Exact(Method),
    /// Accepts requests of every method.
    Any,
}

impl RouteMethod {
    /// Reports whether a request using `method` satisfies this constraint.
    fn matches(&self, method: &Method) -> bool {
        match *self {
            Self::Any => true,
            Self::Exact(ref expected) => *expected == *method,
        }
    }
}

/// A single registered route: an exact path, a method constraint, and the handler to run.
#[derive(Clone)]
struct Route<S> {
    /// Path compared for exact equality against the request path.
    path: String,
    /// Method constraint checked alongside the path.
    method: RouteMethod,
    /// Type-erased handler invoked when both path and method match.
    handler: Handler<S>,
}

/// Middleware continuation.
///
/// Each middleware receives one of these and calls [`Next::run`] to invoke the rest of
/// the chain; after the last middleware, the route handler is called.
#[derive(Clone)]
pub struct Next<S> {
    /// Position in `middleware` of the next middleware to invoke.
    index: usize,
    /// Middleware chain for the current dispatch, shared by every continuation.
    middleware: Arc<Vec<Middleware<S>>>,
    /// Route handler invoked once the chain is exhausted.
    handler: Handler<S>,
}

impl<S> Next<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Runs the next middleware or the final handler.
    pub fn run(&self, request: Request, state: S) -> BoxFuture<Response> {
        if let Some(middleware) = self.middleware.get(self.index).cloned() {
            let next = Self {
                middleware: self.middleware.clone(),
                handler: self.handler.clone(),
                index: self.index + 1,
            };
            middleware(request, state, next)
        } else {
            (self.handler)(request, state)
        }
    }
}

/// Fixed-path router with middleware chain.
#[derive(Clone)]
pub struct Router<S> {
    state: S,
    routes: Vec<Route<S>>,
    middleware: Vec<Middleware<S>>,
}

impl<S> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Creates an empty router bound to a clonable state value.
    pub fn new(state: S) -> Self {
        Self {
            state,
            routes: Vec::new(),
            middleware: Vec::new(),
        }
    }

    /// Registers a GET route.
    pub fn get<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
    where
        F: Fn(Request, S) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route(RouteMethod::Exact(Method::Get), path, handler)
    }

    /// Registers a POST route.
    pub fn post<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
    where
        F: Fn(Request, S) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route(RouteMethod::Exact(Method::Post), path, handler)
    }

    /// Registers a route for any method.
    pub fn any<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
    where
        F: Fn(Request, S) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route(RouteMethod::Any, path, handler)
    }

    fn route<F, Fut>(mut self, method: RouteMethod, path: impl Into<String>, handler: F) -> Self
    where
        F: Fn(Request, S) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        let handler: Handler<S> = Arc::new(move |request, state| Box::pin(handler(request, state)));
        self.routes.push(Route {
            method,
            path: path.into(),
            handler,
        });
        self
    }

    /// Adds a middleware to the end of the chain.
    pub fn middleware<F, Fut>(mut self, middleware: F) -> Self
    where
        F: Fn(Request, S, Next<S>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        let middleware: Middleware<S> =
            Arc::new(move |request, state, next| Box::pin(middleware(request, state, next)));
        self.middleware.push(middleware);
        self
    }

    /// Merges routes and middleware from another router with the same state.
    pub fn merge(mut self, other: Self) -> Self {
        self.routes.extend(other.routes);
        self.middleware.extend(other.middleware);
        self
    }

    /// Dispatches a request entirely in-process.
    pub async fn handle(&self, request: Request) -> Response {
        let path_exists = self.routes.iter().any(|route| route.path == request.path);
        let route = match self
            .routes
            .iter()
            .find(|route| route.path == request.path && route.method.matches(&request.method))
        {
            Some(route) => route,
            None if path_exists => return Response::new(StatusCode::METHOD_NOT_ALLOWED),
            None => return Response::new(StatusCode::NOT_FOUND),
        };

        let next = Next {
            middleware: Arc::new(self.middleware.clone()),
            handler: route.handler.clone(),
            index: 0,
        };
        next.run(request, self.state.clone()).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    #[derive(Clone)]
    struct TestState {
        prefix: &'static str,
    }

    /// Waker that ignores wake-ups; `block_on` simply re-polls in a loop instead.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Minimal test executor: polls `future` to completion on the current thread.
    ///
    /// Uses the safe `Wake` trait, so no hand-written `RawWakerVTable` or `unsafe` is needed.
    fn block_on<F>(future: F) -> F::Output
    where
        F: Future,
    {
        let waker = Waker::from(Arc::new(NoopWaker));
        let mut context = Context::from_waker(&waker);
        let mut future = Box::pin(future);

        loop {
            match future.as_mut().poll(&mut context) {
                Poll::Ready(value) => return value,
                Poll::Pending => std::thread::yield_now(),
            }
        }
    }

    #[test]
    fn method_as_str_returns_expected_tokens() {
        assert_eq!(Method::Get.as_str(), "GET");
        assert_eq!(Method::Post.as_str(), "POST");
        assert_eq!(Method::Other("PATCH".into()).as_str(), "PATCH");
    }

    #[test]
    fn status_code_round_trip() {
        assert_eq!(StatusCode::from_u16(204).as_u16(), 204);
    }

    #[test]
    fn header_map_is_case_insensitive() {
        let mut headers = crate::HeaderMap::new();
        headers.insert("Content-Type", "application/json");
        assert_eq!(headers.get("content-type"), Some("application/json"));
    }

    #[test]
    fn extensions_store_typed_values() {
        let mut extensions = crate::Extensions::new();
        extensions.insert::<u64>(7);
        assert_eq!(extensions.get::<u64>(), Some(&7));
    }

    #[test]
    fn request_builder_sets_optional_fields() {
        let request = crate::Request::new(Method::Get, "/x")
            .with_query("a=1")
            .with_body(b"hello".to_vec());
        assert_eq!(request.query.as_deref(), Some("a=1"));
        assert_eq!(request.body, b"hello");
    }

    #[test]
    fn response_builder_sets_body() {
        let response = crate::Response::new(StatusCode::OK).with_body(b"ok".to_vec());
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"ok");
    }

    #[test]
    fn router_dispatches_exact_method_and_path() {
        let router =
            Router::new(TestState { prefix: "echo:" }).get("/echo", |request, state| async move {
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(&request.body);
                Response::new(StatusCode::OK).with_body(body)
            });

        let response = block_on(
            router.handle(Request::new(Method::Get, "/echo").with_body(b"hello".to_vec())),
        );

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"echo:hello");
    }

    #[test]
    fn router_returns_method_not_allowed_when_path_exists() {
        let router = Router::new(TestState { prefix: "x" })
            .get("/echo", |_request, _state| async move {
                Response::new(StatusCode::OK)
            });

        let response = block_on(router.handle(Request::new(Method::Post, "/echo")));
        assert_eq!(response.status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn router_returns_not_found_for_unknown_path() {
        let router = Router::new(TestState { prefix: "x" })
            .get("/echo", |_request, _state| async move {
                Response::new(StatusCode::OK)
            });

        let response = block_on(router.handle(Request::new(Method::Get, "/missing")));
        assert_eq!(response.status, StatusCode::NOT_FOUND);
    }

    #[test]
    fn any_route_accepts_every_method() {
        let router = Router::new(TestState { prefix: "x" })
            .any("/any", |_request, _state| async move {
                Response::new(StatusCode::OK)
            });

        for method in vec![Method::Get, Method::Post, Method::Other("PATCH".into())] {
            let response = block_on(router.handle(Request::new(method, "/any")));
            assert_eq!(response.status, StatusCode::OK);
        }
    }

    #[test]
    fn merge_combines_routes_from_both_routers() {
        let left = Router::new(TestState { prefix: "x" })
            .get("/a", |_request, _state| async move {
                Response::new(StatusCode::OK)
            });
        let right = Router::new(TestState { prefix: "x" })
            .post("/b", |_request, _state| async move {
                Response::new(StatusCode::OK)
            });
        let router = left.merge(right);

        let a = block_on(router.handle(Request::new(Method::Get, "/a")));
        let b = block_on(router.handle(Request::new(Method::Post, "/b")));
        assert_eq!(a.status, StatusCode::OK);
        assert_eq!(b.status, StatusCode::OK);
    }

    #[test]
    fn middleware_wraps_handler() {
        let router = Router::new(TestState { prefix: "core:" })
            .get("/x", |_request, _state| async move {
                Response::new(StatusCode::OK).with_body(b"body".to_vec())
            })
            .middleware(|request, state, next| async move {
                let mut response = next.run(request, state).await;
                response.headers.insert("x-middleware", "yes");
                response
            });

        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.headers.get("X-Middleware"), Some("yes"));
    }

    #[test]
    fn first_registered_middleware_runs_outermost() {
        let router = Router::new(TestState { prefix: "x" })
            .get("/x", |_request, _state| async move {
                Response::new(StatusCode::OK)
            })
            .middleware(|request, state, next| async move {
                let mut response = next.run(request, state).await;
                // The inner middleware must already have touched the response.
                let inner_ran_first = response.headers.get("x-inner") == Some("yes");
                response
                    .headers
                    .insert("x-outer-saw-inner", if inner_ran_first { "yes" } else { "no" });
                response
            })
            .middleware(|request, state, next| async move {
                let mut response = next.run(request, state).await;
                response.headers.insert("x-inner", "yes");
                response
            });

        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.headers.get("x-outer-saw-inner"), Some("yes"));
    }
}