use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use neco_server_core::{Method, Request, Response, StatusCode};
use crate::Extensions;
/// Pinned, heap-allocated, `Send` future; the uniform return type of every boxed handler
/// and middleware closure in this module.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
/// A [`Request`] travelling through the router together with its per-request
/// [`Extensions`].
///
/// Middleware can insert values into `extensions` that later middleware or the handler
/// read back out.
pub struct RoutedRequest {
    /// The underlying request being routed.
    pub request: Request,
    /// Per-request extension storage shared along the middleware chain.
    pub extensions: Extensions,
}

impl RoutedRequest {
    /// Wraps `request` together with a newly constructed [`Extensions`] value.
    pub fn new(request: Request) -> Self {
        let extensions = Extensions::new();
        Self {
            request,
            extensions,
        }
    }
}
/// Type-erased, shareable route handler: receives the routed request and a clone of the
/// router state and returns a boxed future yielding the response `R`.
pub type Handler<S, R = Response> = Arc<dyn Fn(RoutedRequest, S) -> BoxFuture<R> + Send + Sync>;
/// Type-erased, shareable middleware: receives the routed request, the state, and a
/// [`Next`] continuation it may `run` to invoke the rest of the chain.
pub type Middleware<S, R = Response> =
Arc<dyn Fn(RoutedRequest, S, Next<S, R>) -> BoxFuture<R> + Send + Sync>;
/// Method filter attached to a route.
#[derive(Clone)]
enum RouteMethod {
    /// Accepts only requests whose method equals the wrapped value.
    Exact(Method),
    /// Accepts every request method.
    Any,
}

impl RouteMethod {
    /// Returns `true` if a request using `method` passes this filter.
    fn matches(&self, method: &Method) -> bool {
        match self {
            Self::Any => true,
            Self::Exact(expected) => *expected == *method,
        }
    }
}
struct Route<S, R> {
method: RouteMethod,
path: String,
handler: Handler<S, R>,
middleware: Vec<Middleware<S, R>>,
}
impl<S, R> Clone for Route<S, R> {
fn clone(&self) -> Self {
Self {
method: self.method.clone(),
path: self.path.clone(),
handler: self.handler.clone(),
middleware: self.middleware.clone(),
}
}
}
/// Continuation handed to middleware.
///
/// Calling [`Next::run`] invokes the middleware at the current position, or the route
/// handler once every middleware layer has been passed.
pub struct Next<S, R = Response> {
    /// The route's full middleware stack, shared between all `Next` values of one dispatch.
    middleware: Arc<Vec<Middleware<S, R>>>,
    /// Handler invoked after the last middleware.
    handler: Handler<S, R>,
    /// Position in `middleware` of the layer that `run` will invoke.
    index: usize,
}

// Hand-written so that cloning never requires `S: Clone` or `R: Clone`; only the `Arc`
// handles and the index are duplicated.
impl<S, R> Clone for Next<S, R> {
    fn clone(&self) -> Self {
        let Self {
            middleware,
            handler,
            index,
        } = self;
        Self {
            middleware: Arc::clone(middleware),
            handler: Arc::clone(handler),
            index: *index,
        }
    }
}
impl<S, R> Next<S, R>
where
    S: Clone + Send + Sync + 'static,
    R: Send + 'static,
{
    /// Runs the remainder of the chain.
    ///
    /// If a middleware exists at the current position, it is called with a `Next` that
    /// points one layer further. Otherwise the route handler is called directly.
    pub fn run(&self, request: RoutedRequest, state: S) -> BoxFuture<R> {
        match self.middleware.get(self.index) {
            None => (self.handler)(request, state),
            Some(current) => {
                let rest = Self {
                    index: self.index + 1,
                    ..self.clone()
                };
                current(request, state, rest)
            }
        }
    }
}
/// Exact-path request router, generic over shared state `S` and response type `R`.
///
/// Configure it with the consuming builder methods (`get`, `post`, …, `middleware`,
/// `merge`), then drive it with `handle` or `handle_routed`.
pub struct Router<S, R = Response> {
    /// Shared state; each dispatched request receives a clone of it.
    state: S,
    /// Registered routes in registration order; the first match wins.
    routes: Vec<Route<S, R>>,
    /// Middleware that any route registered from now on starts out wrapped in.
    pending_middleware: Vec<Middleware<S, R>>,
}

// Manual impl: only `S` needs to be `Clone`; values of `R` are never cloned.
impl<S: Clone, R> Clone for Router<S, R> {
    fn clone(&self) -> Self {
        Router {
            state: S::clone(&self.state),
            routes: self.routes.clone(),
            pending_middleware: self.pending_middleware.clone(),
        }
    }
}
impl<S, R> Router<S, R>
where
S: Clone + Send + Sync + 'static,
R: From<Response> + Send + 'static,
{
pub fn new(state: S) -> Self {
Self {
state,
routes: Vec::new(),
pending_middleware: Vec::new(),
}
}
pub fn get<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
where
F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = R> + Send + 'static,
{
self.route(RouteMethod::Exact(Method::Get), path, handler)
}
pub fn post<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
where
F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = R> + Send + 'static,
{
self.route(RouteMethod::Exact(Method::Post), path, handler)
}
pub fn put<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
where
F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = R> + Send + 'static,
{
self.route(RouteMethod::Exact(Method::Put), path, handler)
}
pub fn delete<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
where
F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = R> + Send + 'static,
{
self.route(RouteMethod::Exact(Method::Delete), path, handler)
}
pub fn patch<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
where
F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = R> + Send + 'static,
{
self.route(RouteMethod::Exact(Method::Patch), path, handler)
}
pub fn head<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
where
F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = R> + Send + 'static,
{
self.route(RouteMethod::Exact(Method::Head), path, handler)
}
pub fn options<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
where
F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = R> + Send + 'static,
{
self.route(RouteMethod::Exact(Method::Options), path, handler)
}
pub fn any<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
where
F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = R> + Send + 'static,
{
self.route(RouteMethod::Any, path, handler)
}
pub fn on<F, Fut>(self, method: Method, path: impl Into<String>, handler: F) -> Self
where
F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = R> + Send + 'static,
{
self.route(RouteMethod::Exact(method), path, handler)
}
fn route<F, Fut>(mut self, method: RouteMethod, path: impl Into<String>, handler: F) -> Self
where
F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = R> + Send + 'static,
{
let handler: Handler<S, R> =
Arc::new(move |request, state| Box::pin(handler(request, state)));
self.routes.push(Route {
method,
path: path.into(),
handler,
middleware: self.pending_middleware.clone(),
});
self
}
pub fn middleware<F, Fut>(mut self, middleware: F) -> Self
where
F: Fn(RoutedRequest, S, Next<S, R>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = R> + Send + 'static,
{
let middleware: Middleware<S, R> =
Arc::new(move |request, state, next| Box::pin(middleware(request, state, next)));
for route in &mut self.routes {
route.middleware.push(middleware.clone());
}
self.pending_middleware.push(middleware);
self
}
pub fn merge(mut self, other: Self) -> Self {
self.routes.extend(other.routes);
self
}
pub async fn handle(&self, request: Request) -> R {
self.dispatch_routed(RoutedRequest::new(request)).await
}
pub async fn handle_routed(&self, request: RoutedRequest) -> R {
self.dispatch_routed(request).await
}
async fn dispatch_routed(&self, request: RoutedRequest) -> R {
let path_exists = self
.routes
.iter()
.any(|route| route.path == request.request.path);
let route = match self.routes.iter().find(|route| {
route.path == request.request.path && route.method.matches(&request.request.method)
}) {
Some(route) => route,
None if path_exists => return not_found_or_method::<R>(StatusCode::METHOD_NOT_ALLOWED),
None => return not_found_or_method::<R>(StatusCode::NOT_FOUND),
};
let next = Next {
middleware: Arc::new(route.middleware.clone()),
handler: route.handler.clone(),
index: 0,
};
next.run(request, self.state.clone()).await
}
}
/// Builds a `Response` carrying `status` and converts it into the router's response type
/// `R`.
///
/// Used for the router's 404 / 405 fallbacks when no route matches.
fn not_found_or_method<R>(status: StatusCode) -> R
where
    R: From<Response>,
{
    let response = Response::new(status);
    R::from(response)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    #[derive(Clone)]
    struct TestState {
        prefix: &'static str,
    }

    /// Waker that ignores wake-ups; `block_on` simply re-polls in a loop instead.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Minimal test executor: polls `future` until it completes.
    ///
    /// Built on the safe `Wake` trait instead of a hand-rolled `RawWaker`, so no `unsafe`
    /// is required.
    fn block_on<F>(future: F) -> F::Output
    where
        F: Future,
    {
        let waker = Waker::from(Arc::new(NoopWaker));
        let mut future = Box::pin(future);
        let mut context = Context::from_waker(&waker);
        loop {
            match Pin::as_mut(&mut future).poll(&mut context) {
                Poll::Ready(value) => return value,
                Poll::Pending => std::thread::yield_now(),
            }
        }
    }

    #[test]
    fn router_dispatches_exact_method_and_path() {
        let router =
            Router::new(TestState { prefix: "echo:" }).get("/echo", |request, state| async move {
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(&request.request.body);
                Response::new(StatusCode::OK).with_body(body)
            });
        let response = block_on(
            router.handle(Request::new(Method::Get, "/echo").with_body(b"hello".to_vec())),
        );
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"echo:hello");
    }

    #[test]
    fn router_dispatches_custom_method_route() {
        let router = Router::new(TestState { prefix: "patch:" }).on(
            Method::Other("PATCH".into()),
            "/echo",
            |request, state| async move {
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(&request.request.body);
                Response::new(StatusCode::OK).with_body(body)
            },
        );
        let response = block_on(
            router.handle(Request::new(Method::Other("PATCH".into()), "/echo").with_body(b"ok")),
        );
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"patch:ok");
    }

    #[test]
    fn router_dispatches_put_route() {
        let router =
            Router::new(TestState { prefix: "put:" }).put("/item", |request, state| async move {
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(&request.request.body);
                Response::new(StatusCode::OK).with_body(body)
            });
        let response = block_on(router.handle(Request::new(Method::Put, "/item").with_body(b"ok")));
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"put:ok");
    }

    #[test]
    fn router_returns_method_not_allowed_when_path_exists() {
        let router = Router::new(TestState { prefix: "x" })
            .get("/echo", |_request, _state| async move {
                Response::new(StatusCode::OK)
            });
        let response = block_on(router.handle(Request::new(Method::Post, "/echo")));
        assert_eq!(response.status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn router_returns_not_found_for_unknown_path() {
        let router = Router::new(TestState { prefix: "x" })
            .get("/echo", |_request, _state| async move {
                Response::new(StatusCode::OK)
            });
        let response = block_on(router.handle(Request::new(Method::Get, "/missing")));
        assert_eq!(response.status, StatusCode::NOT_FOUND);
    }

    #[test]
    fn any_route_accepts_every_method() {
        let router = Router::new(TestState { prefix: "any:" })
            .any("/x", |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            });
        for method in vec![Method::Get, Method::Post, Method::Delete] {
            let response = block_on(router.handle(Request::new(method, "/x")));
            assert_eq!(response.status, StatusCode::OK);
            assert_eq!(response.body, b"any:");
        }
    }

    #[test]
    fn middleware_wraps_handler() {
        let router = Router::new(TestState { prefix: "core:" })
            .get("/x", |_request, _state| async move {
                Response::new(StatusCode::OK).with_body(b"body".to_vec())
            })
            .middleware(|mut request, state, next| async move {
                request.extensions.insert::<u64>(7);
                let mut response = next.run(request, state).await;
                response.headers.insert("x-middleware", "yes");
                response
            });
        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.headers.get("X-Middleware"), Some("yes"));
    }

    #[test]
    fn middleware_extensions_reach_handler() {
        let router = Router::new(TestState { prefix: "ext:" })
            .get("/x", |mut request, state| async move {
                let marker = request.extensions.remove::<u64>().unwrap_or_default();
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(marker.to_string().as_bytes());
                Response::new(StatusCode::OK).with_body(body)
            })
            .middleware(|mut request, state, next| async move {
                request.extensions.insert::<u64>(7);
                next.run(request, state).await
            });
        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"ext:7");
    }

    #[test]
    fn middleware_applies_to_routes_added_after_layer() {
        let router = Router::new(TestState { prefix: "late:" })
            .middleware(|request, state, next| async move {
                let mut response: Response = next.run(request, state).await;
                response.headers.insert("x-layered", "yes");
                response
            })
            .get("/x", |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            });
        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.headers.get("x-layered"), Some("yes"));
    }

    #[test]
    fn merged_router_does_not_leak_middleware_to_later_routes() {
        let public = Router::new(TestState { prefix: "public:" }).get(
            "/public",
            |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            },
        );
        let protected = Router::new(TestState { prefix: "auth:" })
            .get("/protected", |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            })
            .middleware(|request, state, next| async move {
                let mut response = next.run(request, state).await;
                response.headers.insert("x-auth", "yes");
                response
            });
        let router = public
            .merge(protected)
            .get("/later", |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            });
        let protected_response = block_on(router.handle(Request::new(Method::Get, "/protected")));
        assert_eq!(protected_response.headers.get("x-auth"), Some("yes"));
        let later_response = block_on(router.handle(Request::new(Method::Get, "/later")));
        assert_eq!(later_response.headers.get("x-auth"), None);
    }

    #[test]
    fn merge_prefers_routes_registered_on_self() {
        let first = Router::new(TestState { prefix: "first:" })
            .get("/x", |_request, _state| async move {
                Response::new(StatusCode::OK).with_body(b"first".to_vec())
            });
        let second = Router::new(TestState { prefix: "second:" })
            .get("/x", |_request, _state| async move {
                Response::new(StatusCode::OK).with_body(b"second".to_vec())
            });
        let router = first.merge(second);
        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"first");
    }
}