use self::{future::RouterFuture, request_spec::RequestSpec};
use crate::body::{boxed, Body, BoxBody, HttpBody};
use crate::BoxError;
use http::{Request, Response, StatusCode};
use std::{
convert::Infallible,
task::{Context, Poll},
};
use tower::layer::Layer;
use tower::util::ServiceExt;
use tower::{Service, ServiceBuilder};
use tower_http::map_response_body::MapResponseBodyLayer;
mod future;
mod into_make_service;
#[doc(hidden)]
pub mod request_spec;
mod route;
pub use self::{into_make_service::IntoMakeService, route::Route};
#[derive(Debug)]
pub struct Router<B = Body> {
routes: Vec<(Route<B>, RequestSpec)>,
}
impl<B> Clone for Router<B> {
fn clone(&self) -> Self {
Self {
routes: self.routes.clone(),
}
}
}
impl<B> Default for Router<B>
where
B: Send + 'static,
{
fn default() -> Self {
Self {
routes: Default::default(),
}
}
}
impl<B> Router<B>
where
B: Send + 'static,
{
#[doc(hidden)]
pub fn from_box_clone_service_iter<T>(routes: T) -> Self
where
T: IntoIterator<
Item = (
tower::util::BoxCloneService<Request<B>, Response<BoxBody>, Infallible>,
RequestSpec,
),
>,
{
let mut routes: Vec<(Route<B>, RequestSpec)> = routes
.into_iter()
.map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec))
.collect();
routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank()));
Self { routes }
}
pub fn into_make_service(self) -> IntoMakeService<Self> {
IntoMakeService::new(self)
}
pub fn layer<L, NewReqBody, NewResBody>(self, layer: L) -> Router<NewReqBody>
where
L: Layer<Route<B>>,
L::Service:
Service<Request<NewReqBody>, Response = Response<NewResBody>, Error = Infallible> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewResBody: HttpBody<Data = bytes::Bytes> + Send + 'static,
NewResBody::Error: Into<BoxError>,
{
let layer = ServiceBuilder::new()
.layer_fn(Route::new)
.layer(MapResponseBodyLayer::new(boxed))
.layer(layer);
let routes = self
.routes
.into_iter()
.map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec))
.collect();
Router { routes }
}
}
impl<B> Service<Request<B>> for Router<B>
where
B: Send + 'static,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = RouterFuture<B>;
#[inline]
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
let mut method_not_allowed = false;
for (route, request_spec) in &self.routes {
match request_spec.matches(&req) {
request_spec::Match::Yes => {
return RouterFuture::from_oneshot(route.clone().oneshot(req));
}
request_spec::Match::MethodNotAllowed => method_not_allowed = true,
request_spec::Match::No => continue,
}
}
let status_code = if method_not_allowed {
StatusCode::METHOD_NOT_ALLOWED
} else {
StatusCode::NOT_FOUND
};
RouterFuture::from_response(
Response::builder()
.status(status_code)
.body(crate::body::empty())
.unwrap(),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{body::boxed, routing::request_spec::*};
use futures_util::Future;
use http::Method;
use std::pin::Pin;
pub fn req(method: &Method, uri: &str) -> Request<()> {
Request::builder().method(method).uri(uri).body(()).unwrap()
}
#[derive(Clone)]
struct NamedEchoUriService(String);
impl<B> Service<Request<B>> for NamedEchoUriService {
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
let body = boxed(Body::from(format!("{} :: {}", self.0, req.uri().to_string())));
let fut = async { Ok(Response::builder().status(&http::StatusCode::OK).body(body).unwrap()) };
Box::pin(fut)
}
}
async fn get_body_as_str<B>(res: &mut Response<B>) -> String
where
B: http_body::Body + std::marker::Unpin,
B::Error: std::fmt::Debug,
{
let body_mut = res.body_mut();
let body_bytes = hyper::body::to_bytes(body_mut).await.unwrap();
String::from(std::str::from_utf8(&body_bytes).unwrap())
}
#[tokio::test]
async fn simple_routing() {
let request_specs: Vec<(RequestSpec, &str)> = vec![
(
RequestSpec::from_parts(
Method::GET,
vec![
PathSegment::Literal(String::from("a")),
PathSegment::Label,
PathSegment::Label,
],
Vec::new(),
),
"A",
),
(
RequestSpec::from_parts(
Method::GET,
vec![
PathSegment::Literal(String::from("mg")),
PathSegment::Greedy,
PathSegment::Literal(String::from("z")),
],
Vec::new(),
),
"MiddleGreedy",
),
(
RequestSpec::from_parts(
Method::DELETE,
Vec::new(),
vec![
QuerySegment::KeyValue(String::from("foo"), String::from("bar")),
QuerySegment::Key(String::from("baz")),
],
),
"Delete",
),
(
RequestSpec::from_parts(
Method::POST,
vec![PathSegment::Literal(String::from("query_key_only"))],
vec![QuerySegment::Key(String::from("foo"))],
),
"QueryKeyOnly",
),
];
let mut router = Router::from_box_clone_service_iter(request_specs.into_iter().map(|(spec, svc_name)| {
(
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
spec,
)
}));
let hits = vec![
("A", Method::GET, "/a/b/c"),
("MiddleGreedy", Method::GET, "/mg/a/z"),
("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"),
("Delete", Method::DELETE, "/?foo=bar&baz=quux"),
("Delete", Method::DELETE, "/?foo=bar&baz"),
("Delete", Method::DELETE, "/?foo=bar&baz=&"),
("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo"),
("QueryKeyOnly", Method::POST, "/query_key_only?foo="),
("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"),
];
for (svc_name, method, uri) in &hits {
let mut res = router.call(req(method, uri)).await.unwrap();
let actual_body = get_body_as_str(&mut res).await;
assert_eq!(format!("{} :: {}", svc_name, uri), actual_body);
}
for (_, _, uri) in hits {
let res = router.call(req(&Method::PATCH, uri)).await.unwrap();
assert_eq!(StatusCode::METHOD_NOT_ALLOWED, res.status());
}
let misses = vec![
(Method::GET, "/a"),
(Method::GET, "/a/b"),
(Method::GET, "/mg"),
(Method::GET, "/mg/q"),
(Method::GET, "/mg/z"),
(Method::GET, "/mg/a/b/z/c"),
(Method::DELETE, "/?foo=bar"),
(Method::DELETE, "/?foo=bar"),
(Method::DELETE, "/?baz=quux"),
(Method::POST, "/query_key_only?baz=quux"),
(Method::GET, "/"),
(Method::POST, "/"),
];
for (method, miss) in misses {
let res = router.call(req(&method, miss)).await.unwrap();
assert_eq!(StatusCode::NOT_FOUND, res.status());
}
}
#[tokio::test]
async fn basic_pattern_conflict_avoidance() {
let request_specs: Vec<(RequestSpec, &str)> = vec![
(
RequestSpec::from_parts(
Method::GET,
vec![PathSegment::Literal(String::from("a")), PathSegment::Label],
Vec::new(),
),
"A1",
),
(
RequestSpec::from_parts(
Method::GET,
vec![
PathSegment::Literal(String::from("a")),
PathSegment::Label,
PathSegment::Literal(String::from("a")),
],
Vec::new(),
),
"A2",
),
(
RequestSpec::from_parts(
Method::GET,
vec![PathSegment::Literal(String::from("b")), PathSegment::Greedy],
Vec::new(),
),
"B1",
),
(
RequestSpec::from_parts(
Method::GET,
vec![PathSegment::Literal(String::from("b")), PathSegment::Greedy],
vec![QuerySegment::Key(String::from("q"))],
),
"B2",
),
];
let mut router = Router::from_box_clone_service_iter(request_specs.into_iter().map(|(spec, svc_name)| {
(
tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))),
spec,
)
}));
let hits = vec![
("A1", Method::GET, "/a/foo"),
("A2", Method::GET, "/a/foo/a"),
("B1", Method::GET, "/b/foo/bar/baz"),
("B2", Method::GET, "/b/foo?q=baz"),
];
for (svc_name, method, uri) in &hits {
let mut res = router.call(req(method, uri)).await.unwrap();
let actual_body = get_body_as_str(&mut res).await;
assert_eq!(format!("{} :: {}", svc_name, uri), actual_body);
}
}
}