use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::{Method, Request, Response, StatusCode};
/// Boxed, pinned future that is `Send + 'static` and resolves to `T`.
///
/// Used to erase the concrete future types returned by handlers and middleware.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
/// Type-erased, reference-counted route handler.
///
/// Called with the request and a state value of type `S`; returns a boxed
/// future that resolves to the [`Response`].
pub type Handler<S> = Arc<dyn Fn(Request, S) -> BoxFuture<Response> + Send + Sync>;
/// Type-erased, reference-counted middleware.
///
/// Called with the request, a state value of type `S`, and a [`Next`] handle
/// that continues the chain; returns a boxed future that resolves to the
/// [`Response`].
pub type Middleware<S> = Arc<dyn Fn(Request, S, Next<S>) -> BoxFuture<Response> + Send + Sync>;
/// Method constraint carried by a registered route.
#[derive(Clone)]
enum RouteMethod {
/// Accepts only requests whose method equals the wrapped [`Method`].
Exact(Method),
/// Accepts requests regardless of their method.
Any,
}
impl RouteMethod {
    /// Reports whether a request using `method` satisfies this constraint.
    ///
    /// `Any` accepts every method; `Exact` requires equality with the stored
    /// method.
    fn matches(&self, method: &Method) -> bool {
        match self {
            Self::Any => true,
            Self::Exact(required) => required == method,
        }
    }
}
/// A single registered route: a method constraint, a path, and the handler
/// to run when both match.
#[derive(Clone)]
struct Route<S> {
/// Which request methods this route accepts.
method: RouteMethod,
/// Path compared with `==` against the request path during dispatch.
path: String,
/// Type-erased handler invoked once the middleware chain is exhausted.
handler: Handler<S>,
}
/// Continuation handed to each middleware; calling [`Next::run`] proceeds to
/// the following middleware or, once none remain, to the route handler.
#[derive(Clone)]
pub struct Next<S> {
/// The complete middleware stack, shared between every step of the chain.
middleware: Arc<Vec<Middleware<S>>>,
/// Handler executed when `index` is past the end of `middleware`.
handler: Handler<S>,
/// Position in `middleware` of the entry that `run` will invoke.
index: usize,
}
impl<S> Next<S>
where
S: Clone + Send + Sync + 'static,
{
pub fn run(&self, request: Request, state: S) -> BoxFuture<Response> {
if let Some(middleware) = self.middleware.get(self.index).cloned() {
let next = Self {
middleware: self.middleware.clone(),
handler: self.handler.clone(),
index: self.index + 1,
};
middleware(request, state, next)
} else {
(self.handler)(request, state)
}
}
}
#[derive(Clone)]
pub struct Router<S> {
state: S,
routes: Vec<Route<S>>,
middleware: Vec<Middleware<S>>,
}
impl<S> Router<S>
where
S: Clone + Send + Sync + 'static,
{
pub fn new(state: S) -> Self {
Self {
state,
routes: Vec::new(),
middleware: Vec::new(),
}
}
pub fn get<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
where
F: Fn(Request, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
self.route(RouteMethod::Exact(Method::Get), path, handler)
}
pub fn post<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
where
F: Fn(Request, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
self.route(RouteMethod::Exact(Method::Post), path, handler)
}
pub fn any<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
where
F: Fn(Request, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
self.route(RouteMethod::Any, path, handler)
}
fn route<F, Fut>(mut self, method: RouteMethod, path: impl Into<String>, handler: F) -> Self
where
F: Fn(Request, S) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
let handler: Handler<S> = Arc::new(move |request, state| Box::pin(handler(request, state)));
self.routes.push(Route {
method,
path: path.into(),
handler,
});
self
}
pub fn middleware<F, Fut>(mut self, middleware: F) -> Self
where
F: Fn(Request, S, Next<S>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
let middleware: Middleware<S> =
Arc::new(move |request, state, next| Box::pin(middleware(request, state, next)));
self.middleware.push(middleware);
self
}
pub fn merge(mut self, other: Self) -> Self {
self.routes.extend(other.routes);
self.middleware.extend(other.middleware);
self
}
pub async fn handle(&self, request: Request) -> Response {
let path_exists = self.routes.iter().any(|route| route.path == request.path);
let route = match self
.routes
.iter()
.find(|route| route.path == request.path && route.method.matches(&request.method))
{
Some(route) => route,
None if path_exists => return Response::new(StatusCode::METHOD_NOT_ALLOWED),
None => return Response::new(StatusCode::NOT_FOUND),
};
let next = Next {
middleware: Arc::new(self.middleware.clone()),
handler: route.handler.clone(),
index: 0,
};
next.run(request, self.state.clone()).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

    /// Minimal application state handed to the handlers under test.
    #[derive(Clone)]
    struct TestState {
        prefix: &'static str,
    }

    /// Drives `future` to completion on the current thread.
    ///
    /// Polls in a loop with a waker that does nothing, yielding the thread
    /// whenever the future is pending. Sufficient for the futures in these
    /// tests; not a general-purpose executor.
    fn block_on<F>(future: F) -> F::Output
    where
        F: Future,
    {
        fn noop_clone(_: *const ()) -> RawWaker {
            RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
        }
        fn noop(_: *const ()) {}
        static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);

        // SAFETY: every vtable entry ignores the (null) data pointer and owns
        // no resources, so the functions are trivially thread-safe and
        // clone/wake/drop cannot violate the `RawWaker` contract.
        let waker = unsafe { Waker::from_raw(noop_clone(std::ptr::null())) };
        let mut future = Box::pin(future);
        let mut context = Context::from_waker(&waker);
        loop {
            match future.as_mut().poll(&mut context) {
                Poll::Ready(value) => return value,
                Poll::Pending => std::thread::yield_now(),
            }
        }
    }

    #[test]
    fn method_as_str_returns_expected_tokens() {
        assert_eq!(Method::Get.as_str(), "GET");
        assert_eq!(Method::Post.as_str(), "POST");
        assert_eq!(Method::Other("PATCH".into()).as_str(), "PATCH");
    }

    #[test]
    fn status_code_round_trip() {
        assert_eq!(StatusCode::from_u16(204).as_u16(), 204);
    }

    #[test]
    fn header_map_is_case_insensitive() {
        let mut headers = crate::HeaderMap::new();
        headers.insert("Content-Type", "application/json");
        assert_eq!(headers.get("content-type"), Some("application/json"));
    }

    #[test]
    fn extensions_store_typed_values() {
        let mut extensions = crate::Extensions::new();
        extensions.insert::<u64>(7);
        assert_eq!(extensions.get::<u64>(), Some(&7));
    }

    #[test]
    fn request_builder_sets_optional_fields() {
        let request = crate::Request::new(Method::Get, "/x")
            .with_query("a=1")
            .with_body(b"hello".to_vec());
        assert_eq!(request.query.as_deref(), Some("a=1"));
        assert_eq!(request.body, b"hello");
    }

    #[test]
    fn response_builder_sets_body() {
        let response = crate::Response::new(StatusCode::OK).with_body(b"ok".to_vec());
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"ok");
    }

    #[test]
    fn router_dispatches_exact_method_and_path() {
        let router =
            Router::new(TestState { prefix: "echo:" }).get("/echo", |request, state| async move {
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(&request.body);
                Response::new(StatusCode::OK).with_body(body)
            });
        let response = block_on(
            router.handle(Request::new(Method::Get, "/echo").with_body(b"hello".to_vec())),
        );
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"echo:hello");
    }

    #[test]
    fn router_returns_method_not_allowed_when_path_exists() {
        let router = Router::new(TestState { prefix: "x" })
            .get("/echo", |_request, _state| async move {
                Response::new(StatusCode::OK)
            });
        let response = block_on(router.handle(Request::new(Method::Post, "/echo")));
        assert_eq!(response.status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn router_returns_not_found_for_unknown_path() {
        let router = Router::new(TestState { prefix: "x" })
            .get("/echo", |_request, _state| async move {
                Response::new(StatusCode::OK)
            });
        let response = block_on(router.handle(Request::new(Method::Get, "/missing")));
        assert_eq!(response.status, StatusCode::NOT_FOUND);
    }

    #[test]
    fn any_route_accepts_every_method() {
        let router = Router::new(TestState { prefix: "x" })
            .any("/any", |_request, _state| async move {
                Response::new(StatusCode::OK)
            });
        let get = block_on(router.handle(Request::new(Method::Get, "/any")));
        assert_eq!(get.status, StatusCode::OK);
        let post = block_on(router.handle(Request::new(Method::Post, "/any")));
        assert_eq!(post.status, StatusCode::OK);
        let patch = block_on(router.handle(Request::new(Method::Other("PATCH".into()), "/any")));
        assert_eq!(patch.status, StatusCode::OK);
    }

    #[test]
    fn middleware_wraps_handler() {
        let router = Router::new(TestState { prefix: "core:" })
            .get("/x", |_request, _state| async move {
                Response::new(StatusCode::OK).with_body(b"body".to_vec())
            })
            .middleware(|request, state, next| async move {
                let mut response = next.run(request, state).await;
                response.headers.insert("x-middleware", "yes");
                response
            });
        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.headers.get("X-Middleware"), Some("yes"));
    }

    #[test]
    fn middleware_runs_in_registration_order() {
        // Each layer appends its tag after the inner chain returns, so the
        // first-registered (outermost) layer must write last.
        let router = Router::new(TestState { prefix: "x" })
            .get("/x", |_request, _state| async move {
                Response::new(StatusCode::OK).with_body(b"h".to_vec())
            })
            .middleware(|request, state, next| async move {
                let response = next.run(request, state).await;
                let mut body: Vec<u8> = Vec::new();
                body.extend_from_slice(&response.body);
                body.push(b'a');
                response.with_body(body)
            })
            .middleware(|request, state, next| async move {
                let response = next.run(request, state).await;
                let mut body: Vec<u8> = Vec::new();
                body.extend_from_slice(&response.body);
                body.push(b'b');
                response.with_body(body)
            });
        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"hba");
    }
}