use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use crate::{
configuration::Store,
endpoint::{BoxedEndpoint, Endpoint},
Middleware,
};
use path_table::{PathTable, RouteMatch};
/// A routing table mapping paths (and, per path, HTTP methods) to endpoints.
///
/// Besides the path table, a router keeps its own middleware stack and a
/// router-level configuration store.
pub struct Router<Data> {
    /// Path table whose nodes may hold a `ResourceData` (endpoints keyed by method plus
    /// the middleware stack that applies to them).
    table: PathTable<ResourceData<Data>>,
    /// Middleware registered on this router so far, in registration order. Copied into
    /// resources created afterwards, and used for requests that match no endpoint.
    middleware_base: Vec<Arc<dyn Middleware<Data> + Send + Sync>>,
    /// Router-level configuration, merged into every endpoint's own store by
    /// `apply_default_config`.
    pub(crate) store_base: Store,
}
/// Outcome of routing a single request.
pub(crate) struct RouteResult<'a, Data> {
    /// Endpoint selected for the request; the caller-supplied default handler when
    /// nothing matched.
    pub(crate) endpoint: &'a EndpointData<Data>,
    /// Path match produced by the table, or `None` when routing fell back to the
    /// default handler.
    pub(crate) params: Option<RouteMatch<'a>>,
    /// Middleware stack to run for this request.
    pub(crate) middleware: &'a [Arc<dyn Middleware<Data> + Send + Sync>],
}
/// Pick the endpoint of a matched resource for `method`.
///
/// A `HEAD` request on a resource that has no explicit `HEAD` endpoint is served by
/// the resource's `GET` endpoint. Returns `None` when the resource has no endpoint for
/// the effective method, so the caller can fall back to a default handler.
fn route_match_success<'a, Data>(
    route: &'a ResourceData<Data>,
    route_match: RouteMatch<'a>,
    method: &http::Method,
) -> Option<RouteResult<'a, Data>> {
    let serve_head_with_get =
        *method == http::Method::HEAD && !route.endpoints.contains_key(&http::Method::HEAD);
    let candidate = if serve_head_with_get {
        route.endpoints.get(&http::Method::GET)
    } else {
        route.endpoints.get(method)
    };
    candidate.map(|endpoint| RouteResult {
        endpoint,
        params: Some(route_match),
        middleware: &route.middleware,
    })
}
/// Build the routing result used when no endpoint matches the request.
///
/// `endpoint` is the caller's fallback handler and `middleware` the stack to run
/// before it. No path parameters are attached (`params` is `None`).
fn route_match_failure<'a, Data>(
    endpoint: &'a EndpointData<Data>,
    middleware: &'a [Arc<dyn Middleware<Data> + Send + Sync>],
) -> RouteResult<'a, Data> {
    // `middleware` already has the exact slice type the field needs; no reborrow required.
    RouteResult {
        endpoint,
        params: None,
        middleware,
    }
}
impl<Data: Clone + Send + Sync + 'static> Router<Data> {
    /// Return a `Resource` builder for `path` in this router's table.
    ///
    /// A resource first created through the returned builder starts with a copy of
    /// this router's current middleware stack.
    pub fn at<'a>(&'a mut self, path: &'a str) -> Resource<'a, Data> {
        Resource {
            table: self.table.setup_table(path),
            middleware_base: &self.middleware_base,
        }
    }

    /// Create an empty router: no routes, no middleware, empty configuration.
    pub(crate) fn new() -> Router<Data> {
        let table = PathTable::new();
        let middleware_base = Vec::new();
        let store_base = Store::new();
        Router {
            table,
            middleware_base,
            store_base,
        }
    }

    /// Append `middleware` to this router's stack and return `self` for chaining.
    ///
    /// The middleware is added to every resource already in the table. It is also
    /// recorded so that resources created later receive it too.
    pub fn middleware(&mut self, middleware: impl Middleware<Data> + 'static) -> &mut Self {
        let shared = Arc::new(middleware);
        for existing in self.table.iter_mut() {
            existing.middleware.push(Arc::clone(&shared));
        }
        self.middleware_base.push(shared);
        self
    }

    /// Store a router-level configuration item and return `self` for chaining.
    ///
    /// Router-level items are merged into endpoint stores by `apply_default_config`.
    pub fn config<T>(&mut self, item: T) -> &mut Self
    where
        T: Any + Debug + Clone + Send + Sync,
    {
        self.store_base.write(item);
        self
    }

    /// Resolve `path` and `method` to an endpoint, its path match and its middleware.
    ///
    /// Falls back to `default_handler`, with this router's base middleware, when the
    /// path is unknown or has no endpoint for the method.
    pub(crate) fn route<'a>(
        &'a self,
        path: &'a str,
        method: &http::Method,
        default_handler: &'a Arc<EndpointData<Data>>,
    ) -> RouteResult<'a, Data> {
        self.table
            .route(path)
            .and_then(|(resource, matched)| route_match_success(resource, matched, method))
            .unwrap_or_else(|| route_match_failure(default_handler, &self.middleware_base))
    }
}
impl<Data> Router<Data> {
    /// Merge the router-level configuration into the store of every endpoint in the table.
    pub(crate) fn apply_default_config(&mut self) {
        let defaults = &self.store_base;
        for resource in self.table.iter_mut() {
            resource.endpoints.values_mut().for_each(|endpoint_data| {
                endpoint_data.store.merge(defaults);
            });
        }
    }

    /// Look up a router-level configuration item of type `T`, if one has been stored.
    pub(crate) fn get_item<T>(&self) -> Option<&T>
    where
        T: Any + Debug + Clone + Send + Sync,
    {
        let store = &self.store_base;
        store.read()
    }
}
/// An endpoint registered for one method on a resource, plus its own configuration.
pub struct EndpointData<Data> {
    /// The handler, with its concrete type erased behind `BoxedEndpoint`.
    pub(crate) endpoint: BoxedEndpoint<Data>,
    /// Endpoint-level configuration; router-level defaults are merged into it by
    /// `Router::apply_default_config`.
    pub(crate) store: Store,
}
impl<Data> EndpointData<Data> {
    /// Store `item` in this endpoint's own configuration and return `self` for chaining.
    pub fn config<T>(&mut self, item: T) -> &mut Self
    where
        T: Any + Debug + Clone + Send + Sync,
    {
        let store = &mut self.store;
        store.write(item);
        self
    }
}
/// Builder handle for a single path, returned by `Router::at`.
pub struct Resource<'a, Data> {
    /// The owning router's path-table node for this path.
    table: &'a mut PathTable<ResourceData<Data>>,
    /// The owning router's middleware stack. It is copied into a newly created resource
    /// and inherited by routers nested at this path.
    middleware_base: &'a Vec<Arc<dyn Middleware<Data> + Send + Sync>>,
}
/// Contents of a path-table node: endpoints keyed by HTTP method, plus the
/// middleware stack used for requests routed to any of them.
struct ResourceData<Data> {
    endpoints: HashMap<http::Method, EndpointData<Data>>,
    middleware: Vec<Arc<dyn Middleware<Data> + Send + Sync>>,
}
impl<'a, Data> Resource<'a, Data> {
pub fn nest(self, builder: impl FnOnce(&mut Router<Data>)) {
let mut subrouter = Router {
table: PathTable::new(),
middleware_base: self.middleware_base.clone(),
store_base: Store::new(),
};
builder(&mut subrouter);
subrouter.apply_default_config();
*self.table = subrouter.table;
}
pub fn method<T: Endpoint<Data, U>, U>(
&mut self,
method: http::Method,
ep: T,
) -> &mut EndpointData<Data> {
let resource = self.table.resource_mut();
if resource.is_none() {
let new_resource = ResourceData {
endpoints: HashMap::new(),
middleware: self.middleware_base.clone(),
};
*resource = Some(new_resource);
}
let resource = resource.as_mut().unwrap();
let entry = resource.endpoints.entry(method);
if let std::collections::hash_map::Entry::Occupied(ep) = entry {
panic!("A {} endpoint already exists for this path", ep.key())
}
let endpoint = EndpointData {
endpoint: BoxedEndpoint::new(ep),
store: Store::new(),
};
entry.or_insert(endpoint)
}
pub fn get<T: Endpoint<Data, U>, U>(&mut self, ep: T) -> &mut EndpointData<Data> {
self.method(http::Method::GET, ep)
}
pub fn head<T: Endpoint<Data, U>, U>(&mut self, ep: T) -> &mut EndpointData<Data> {
self.method(http::Method::HEAD, ep)
}
pub fn put<T: Endpoint<Data, U>, U>(&mut self, ep: T) -> &mut EndpointData<Data> {
self.method(http::Method::PUT, ep)
}
pub fn post<T: Endpoint<Data, U>, U>(&mut self, ep: T) -> &mut EndpointData<Data> {
self.method(http::Method::POST, ep)
}
pub fn delete<T: Endpoint<Data, U>, U>(&mut self, ep: T) -> &mut EndpointData<Data> {
self.method(http::Method::DELETE, ep)
}
pub fn options<T: Endpoint<Data, U>, U>(&mut self, ep: T) -> &mut EndpointData<Data> {
self.method(http::Method::OPTIONS, ep)
}
pub fn connect<T: Endpoint<Data, U>, U>(&mut self, ep: T) -> &mut EndpointData<Data> {
self.method(http::Method::CONNECT, ep)
}
pub fn patch<T: Endpoint<Data, U>, U>(&mut self, ep: T) -> &mut EndpointData<Data> {
self.method(http::Method::PATCH, ep)
}
pub fn trace<T: Endpoint<Data, U>, U>(&mut self, ep: T) -> &mut EndpointData<Data> {
self.method(http::Method::TRACE, ep)
}
}
#[cfg(test)]
mod tests {
    use futures::{executor::block_on, future::FutureObj};
    use super::*;
    use crate::{body::Body, middleware::RequestContext, AppData, Response};

    /// Middleware that forwards the request to the next stage unchanged.
    fn passthrough_middleware<Data: Clone + Send>(
        ctx: RequestContext<Data>,
    ) -> FutureObj<Response> {
        ctx.next()
    }

    /// Route a request for `path`/`method` through `router`. Then run the selected
    /// middleware chain and endpoint with `Data::default()` as the app data.
    ///
    /// Requests that match no endpoint go to a stand-in default handler returning
    /// `404 Not Found`. As written, this always returns `Some`.
    async fn simulate_request<'a, Data: Default + Clone + Send + Sync + 'static>(
        router: &'a Router<Data>,
        path: &'a str,
        method: &'a http::Method,
    ) -> Option<Response> {
        let default_handler = Arc::new(EndpointData {
            endpoint: BoxedEndpoint::new(async || http::status::StatusCode::NOT_FOUND),
            store: Store::new(),
        });
        let RouteResult {
            endpoint,
            params,
            middleware,
        } = router.route(path, method, &default_handler);
        let data = Data::default();
        // The synthetic request carries only the method; routing has already used `path`,
        // so no URI is set here.
        let req = http::Request::builder()
            .method(method)
            .body(Body::empty())
            .unwrap();
        let ctx = RequestContext {
            app_data: data,
            req,
            params,
            endpoint,
            next_middleware: middleware,
        };
        let res = await!(ctx.next());
        Some(res.map(Into::into))
    }

    /// Number of middleware the router selects for `path`/`method`, including the
    /// no-match fallback case. As written, this always returns `Some`.
    fn route_middleware_count<Data: Clone + Send + Sync + 'static>(
        router: &Router<Data>,
        path: &str,
        method: &http::Method,
    ) -> Option<usize> {
        let default_handler = Arc::new(EndpointData {
            endpoint: BoxedEndpoint::new(async || http::status::StatusCode::NOT_FOUND),
            store: Store::new(),
        });
        let route_result = router.route(path, method, &default_handler);
        Some(route_result.middleware.len())
    }

    /// Each static path serves a body equal to the path itself.
    #[test]
    fn simple_static() {
        let mut router: Router<()> = Router::new();
        router.at("/").get(async || "/");
        router.at("/foo").get(async || "/foo");
        router.at("/foo/bar").get(async || "/foo/bar");
        for path in &["/", "/foo", "/foo/bar"] {
            let res =
                if let Some(res) = block_on(simulate_request(&router, path, &http::Method::GET)) {
                    res
                } else {
                    panic!("Routing of path `{}` failed", path);
                };
            let body =
                block_on(res.into_body().read_to_vec()).expect("Reading body should succeed");
            assert_eq!(&*body, path.as_bytes());
        }
    }

    /// Nested routers serve their routes under the mount prefix. Paths with no
    /// endpoint (including mount points without a `/` route) yield a client error.
    #[test]
    fn nested_static() {
        let mut router: Router<()> = Router::new();
        router.at("/a").get(async || "/a");
        router.at("/b").nest(|router| {
            router.at("/").get(async || "/b");
            router.at("/a").get(async || "/b/a");
            router.at("/b").get(async || "/b/b");
            router.at("/c").nest(|router| {
                router.at("/a").get(async || "/b/c/a");
                router.at("/b").get(async || "/b/c/b");
            });
            router.at("/d").get(async || "/b/d");
        });
        router.at("/a/a").nest(|router| {
            router.at("/a").get(async || "/a/a/a");
            router.at("/b").get(async || "/a/a/b");
        });
        router.at("/a/b").nest(|router| {
            router.at("/").get(async || "/a/b");
        });
        for failing_path in &["/", "/a/a", "/a/b/a"] {
            if let Some(res) = block_on(simulate_request(&router, failing_path, &http::Method::GET))
            {
                if !res.status().is_client_error() {
                    panic!(
                        "Should have returned a client error when router cannot match with path {}",
                        failing_path
                    );
                }
            } else {
                panic!("Should have received a response from {}", failing_path);
            };
        }
        for path in &[
            "/a", "/a/a/a", "/a/a/b", "/a/b", "/b", "/b/a", "/b/b", "/b/c/a", "/b/c/b", "/b/d",
        ] {
            let res =
                if let Some(res) = block_on(simulate_request(&router, path, &http::Method::GET)) {
                    res
                } else {
                    panic!("Routing of path `{}` failed", path);
                };
            let body =
                block_on(res.into_body().read_to_vec()).expect("Reading body should succeed");
            assert_eq!(&*body, path.as_bytes());
        }
    }

    /// Endpoints for different methods on one path can come from a nested router
    /// (GET) and from the parent (POST).
    #[test]
    fn multiple_methods() {
        let mut router: Router<()> = Router::new();
        router.at("/a").nest(|router| {
            router.at("/b").get(async || "/a/b GET");
        });
        router.at("/a/b").post(async || "/a/b POST");
        for (path, method) in &[("/a/b", http::Method::GET), ("/a/b", http::Method::POST)] {
            let res = if let Some(res) = block_on(simulate_request(&router, path, &method)) {
                res
            } else {
                panic!("Routing of {} `{}` failed", method, path);
            };
            let body =
                block_on(res.into_body().read_to_vec()).expect("Reading body should succeed");
            assert_eq!(&*body, format!("{} {}", path, method).as_bytes());
        }
    }

    /// Registering a second GET on `/a/b` (the first came through a nested router) panics.
    #[test]
    #[should_panic]
    fn duplicate_endpoint_fails() {
        let mut router: Router<()> = Router::new();
        router.at("/a").nest(|router| {
            router.at("/b").get(async || "");
        });
        router.at("/a/b").get(async || "duplicate");
    }

    /// A root route gets the router's single middleware. A nested route also gets the
    /// middleware added inside the nested router, even though that middleware was
    /// added after the route.
    #[test]
    fn simple_middleware() {
        let mut router: Router<()> = Router::new();
        router.middleware(passthrough_middleware);
        router.at("/").get(async || "/");
        router.at("/b").nest(|router| {
            router.at("/").get(async || "/b");
            router.middleware(passthrough_middleware);
        });
        assert_eq!(
            route_middleware_count(&router, "/", &http::Method::GET),
            Some(1)
        );
        assert_eq!(
            route_middleware_count(&router, "/b", &http::Method::GET),
            Some(2)
        );
    }

    /// Each `Pusher` appends its id to the app data, so every endpoint observes the
    /// order in which middleware ran and returns 200 only if that order is as expected.
    /// `Pusher(2)` is added after `/` and `/a` exist and still reaches both. `/b` is
    /// nested after `Pusher(2)`, so its inner `Pusher(1)` runs after it.
    #[test]
    fn middleware_apply_order() {
        #[derive(Default, Clone, Debug)]
        struct Data(Vec<usize>);
        struct Pusher(usize);
        impl Middleware<Data> for Pusher {
            fn handle<'a>(&'a self, mut ctx: RequestContext<'a, Data>) -> FutureObj<'a, Response> {
                FutureObj::new(Box::new(
                    async move {
                        ctx.app_data.0.push(self.0);
                        await!(ctx.next())
                    },
                ))
            }
        }
        let mut router: Router<Data> = Router::new();
        router.middleware(Pusher(0));
        router.at("/").get(async move |data: AppData<Data>| {
            if (data.0).0 == [0, 2] {
                http::StatusCode::OK
            } else {
                http::StatusCode::INTERNAL_SERVER_ERROR
            }
        });
        router.at("/a").nest(|router| {
            router.at("/").get(async move |data: AppData<Data>| {
                if (data.0).0 == [0, 1, 2] {
                    http::StatusCode::OK
                } else {
                    http::StatusCode::INTERNAL_SERVER_ERROR
                }
            });
            router.middleware(Pusher(1));
        });
        router.middleware(Pusher(2));
        router.at("/b").nest(|router| {
            router.at("/").get(async move |data: AppData<Data>| {
                if (data.0).0 == [0, 2, 1] {
                    http::StatusCode::OK
                } else {
                    http::StatusCode::INTERNAL_SERVER_ERROR
                }
            });
            router.middleware(Pusher(1));
        });
        for path in &["/", "/a", "/b"] {
            let res = block_on(simulate_request(&router, path, &http::Method::GET)).unwrap();
            assert_eq!(res.status(), 200);
        }
    }

    /// Router-level config is the default. Endpoint-level config takes precedence.
    #[test]
    fn configuration() {
        use crate::ExtractConfiguration;
        async fn endpoint(
            ExtractConfiguration(x): ExtractConfiguration<&'static str>,
        ) -> &'static str {
            x.unwrap()
        }
        let mut router: Router<()> = Router::new();
        router.config("foo");
        router.at("/").get(endpoint);
        router.at("/bar").get(endpoint).config("bar");
        router.apply_default_config();
        let res = block_on(simulate_request(&router, "/", &http::Method::GET)).unwrap();
        let body = block_on(res.into_body().read_to_vec()).unwrap();
        assert_eq!(&*body, &*b"foo");
        let res = block_on(simulate_request(&router, "/bar", &http::Method::GET)).unwrap();
        let body = block_on(res.into_body().read_to_vec()).unwrap();
        assert_eq!(&*body, &*b"bar");
    }

    /// Nested-router config takes precedence over the parent's. Endpoint-level config
    /// takes precedence over both.
    #[test]
    fn configuration_nested() {
        use crate::ExtractConfiguration;
        async fn endpoint(
            ExtractConfiguration(x): ExtractConfiguration<&'static str>,
        ) -> &'static str {
            x.unwrap()
        }
        let mut router: Router<()> = Router::new();
        router.config("foo");
        router.at("/").get(endpoint);
        router.at("/bar").nest(|router| {
            router.config("bar");
            router.at("/").get(endpoint);
            router.at("/baz").get(endpoint).config("baz");
        });
        router.apply_default_config();
        let res = block_on(simulate_request(&router, "/", &http::Method::GET)).unwrap();
        let body = block_on(res.into_body().read_to_vec()).unwrap();
        assert_eq!(&*body, &*b"foo");
        let res = block_on(simulate_request(&router, "/bar", &http::Method::GET)).unwrap();
        let body = block_on(res.into_body().read_to_vec()).unwrap();
        assert_eq!(&*body, &*b"bar");
        let res = block_on(simulate_request(&router, "/bar/baz", &http::Method::GET)).unwrap();
        let body = block_on(res.into_body().read_to_vec()).unwrap();
        assert_eq!(&*body, &*b"baz");
    }

    /// Router config set after a route was registered still reaches that route, because
    /// defaults are merged only when `apply_default_config` runs.
    #[test]
    fn configuration_order() {
        use crate::ExtractConfiguration;
        async fn endpoint(
            ExtractConfiguration(x): ExtractConfiguration<&'static str>,
        ) -> &'static str {
            x.unwrap()
        }
        let mut router: Router<()> = Router::new();
        router.at("/").get(endpoint);
        router.config("foo");
        router.at("/bar").get(endpoint).config("bar");
        router.apply_default_config();
        let res = block_on(simulate_request(&router, "/", &http::Method::GET)).unwrap();
        let body = block_on(res.into_body().read_to_vec()).unwrap();
        assert_eq!(&*body, &*b"foo");
        let res = block_on(simulate_request(&router, "/bar", &http::Method::GET)).unwrap();
        let body = block_on(res.into_body().read_to_vec()).unwrap();
        assert_eq!(&*body, &*b"bar");
    }
}