use crate::endpoint::WitchcraftEndpoint;
use crate::service::{Layer, Service};
use conjure_http::server::PathSegment;
use conjure_http::PathParams;
use http::{Method, Request};
use itertools::Itertools;
use regex::{Regex, RegexSet};
use std::borrow::Cow;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::fmt::Write;
use std::sync::Arc;
const DEFAULT_REGEX: &str = "[^/]+?";
pub struct Endpoint {
endpoint: Arc<dyn WitchcraftEndpoint + Sync + Send>,
regex: Regex,
literal_chars: usize,
path_params: Vec<Cow<'static, str>>,
custom_path_params: usize,
}
impl Endpoint {
pub fn new(endpoint: Box<dyn WitchcraftEndpoint + Sync + Send>) -> Self {
let mut regex = "^".to_string();
for segment in endpoint.path() {
regex.push('/');
match segment {
PathSegment::Literal(s) => regex.push_str(®ex::escape(s)),
PathSegment::Parameter {
name,
regex: segment_regex,
} => {
let segment_regex = segment_regex.as_deref().unwrap_or(DEFAULT_REGEX);
write!(regex, "(?P<{}>{})", name, segment_regex).unwrap();
}
}
}
regex.push('$');
let mut literal_chars = 0;
let mut path_params = vec![];
let mut custom_path_params = 0;
for segment in endpoint.path() {
match segment {
PathSegment::Literal(s) => literal_chars += s.len() + 1,
PathSegment::Parameter { name, regex } => {
path_params.push(name.clone());
if regex.as_deref().unwrap_or(DEFAULT_REGEX) != DEFAULT_REGEX {
custom_path_params += 1;
}
}
}
}
Endpoint {
endpoint: Arc::from(endpoint),
regex: Regex::new(®ex).unwrap(),
literal_chars,
path_params,
custom_path_params,
}
}
fn cmp_priority(&self, other: &Self) -> Ordering {
self.literal_chars
.cmp(&other.literal_chars)
.then_with(|| self.path_params.len().cmp(&other.path_params.len()))
.then_with(|| self.custom_path_params.cmp(&other.custom_path_params))
.reverse()
}
}
/// The outcome of routing a request, stored in the request's extensions for downstream layers.
pub enum Route {
    /// The method and path matched this endpoint.
    Resolved(Arc<dyn WitchcraftEndpoint + Sync + Send>),
    /// An `OPTIONS *` request, which targets the server as a whole.
    StarOptions,
    /// An `OPTIONS` request for a path with no explicit `OPTIONS` endpoint. The list holds the
    /// methods whose endpoints match the path, and may be empty.
    Options(Vec<Method>),
    /// The path matched endpoints, but only under the listed (other) methods.
    MethodNotAllowed(Vec<Method>),
    /// No endpoint under any method matched the path.
    Unresolved,
}
pub struct RoutingLayer {
endpoints: HashMap<Method, Routes>,
}
impl RoutingLayer {
pub fn new(endpoints: Vec<Box<dyn WitchcraftEndpoint + Sync + Send>>) -> Self {
let endpoints_by_method = endpoints
.into_iter()
.map(Endpoint::new)
.into_group_map_by(|e| e.endpoint.method());
RoutingLayer {
endpoints: endpoints_by_method
.into_iter()
.map(|(method, endpoints)| (method, Routes::new(endpoints)))
.collect(),
}
}
}
impl<S> Layer<S> for RoutingLayer {
type Service = RoutingService<S>;
fn layer(self, inner: S) -> Self::Service {
RoutingService {
inner,
endpoints: self.endpoints,
}
}
}
/// All endpoints registered under a single HTTP method.
///
/// `endpoints` is kept in priority order, and `set` holds their regexes at the same indices.
struct Routes {
    set: RegexSet,
    endpoints: Vec<Endpoint>,
}

impl Routes {
    /// Sorts `endpoints` by priority (stably, so ties keep input order) and compiles their path
    /// regexes into a single set.
    fn new(endpoints: Vec<Endpoint>) -> Self {
        let mut ordered = endpoints;
        ordered.sort_by(|a, b| a.cmp_priority(b));
        let patterns = ordered.iter().map(|e| e.regex.as_str());
        let set = RegexSet::new(patterns).unwrap();
        Routes {
            set,
            endpoints: ordered,
        }
    }

    /// Reports whether any endpoint under this method matches `path`.
    fn is_match(&self, path: &str) -> bool {
        self.set.is_match(path)
    }

    /// Returns the highest-priority endpoint matching `path`, if any.
    fn route(&self, path: &str) -> Option<&Endpoint> {
        // Matches are yielded in ascending index order, which is priority order.
        let best = self.set.matches(path).into_iter().next()?;
        Some(&self.endpoints[best])
    }
}
/// The service produced by [`RoutingLayer`]: it annotates each request with its [`Route`] and
/// forwards it to `inner`.
pub struct RoutingService<S> {
    inner: S,
    endpoints: HashMap<Method, Routes>,
}

impl<S> RoutingService<S> {
    /// Lists every method with an endpoint matching `path`, sorted by method name.
    fn supported_methods(&self, path: &str) -> Vec<Method> {
        let mut methods = Vec::new();
        for (method, routes) in &self.endpoints {
            if routes.is_match(path) {
                methods.push(method.clone());
            }
        }
        methods.sort_by(|a, b| a.as_str().cmp(b.as_str()));
        methods
    }
}
impl<S, B> Service<Request<B>> for RoutingService<S>
where
S: Service<Request<B>> + Sync,
B: Send,
{
type Response = S::Response;
async fn call(&self, mut req: Request<B>) -> Self::Response {
let (route, endpoint) = if req.method() == Method::OPTIONS && req.uri() == "*" {
(Route::StarOptions, None)
} else {
match self
.endpoints
.get(req.method())
.and_then(|r| r.route(req.uri().path()))
{
Some(endpoint) => (Route::Resolved(endpoint.endpoint.clone()), Some(endpoint)),
None if req.method() == Method::OPTIONS => (
Route::Options(self.supported_methods(req.uri().path())),
None,
),
None => {
let methods = self.supported_methods(req.uri().path());
if methods.is_empty() {
(Route::Unresolved, None)
} else {
(Route::MethodNotAllowed(methods), None)
}
}
}
};
if let Some(endpoint) = endpoint {
if !endpoint.path_params.is_empty() {
let captures = endpoint.regex.captures(req.uri().path()).unwrap();
let mut path_params = PathParams::new();
for name in &endpoint.path_params {
path_params.insert(&**name, captures.name(name).unwrap().as_str());
}
req.extensions_mut().insert(path_params);
}
}
req.extensions_mut().insert(route);
self.inner.call(req).await
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::health::endpoint_500s::EndpointHealth;
    use crate::server::RawBody;
    use crate::service::endpoint_metrics::EndpointMetrics;
    use crate::service::handler::BodyWriteAborted;
    use crate::service::test_util::service_fn;
    use async_trait::async_trait;
    use bytes::Bytes;
    use conjure_http::server::EndpointMetadata;
    use http::Response;
    use http_body::combinators::BoxBody;

    /// A stub endpoint that only exposes routing metadata; its handler is never invoked.
    struct TestEndpoint {
        method: Method,
        path: Vec<PathSegment>,
        name: &'static str,
    }

    impl EndpointMetadata for TestEndpoint {
        fn method(&self) -> Method {
            self.method.clone()
        }

        fn path(&self) -> &[PathSegment] {
            &self.path
        }

        fn template(&self) -> &str {
            ""
        }

        fn service_name(&self) -> &str {
            ""
        }

        fn name(&self) -> &str {
            self.name
        }

        fn deprecated(&self) -> Option<&str> {
            None
        }
    }

    #[async_trait]
    impl WitchcraftEndpoint for TestEndpoint {
        fn metrics(&self) -> Option<&EndpointMetrics> {
            None
        }

        fn health(&self) -> Option<&Arc<EndpointHealth>> {
            None
        }

        async fn handle(&self, _: Request<RawBody>) -> Response<BoxBody<Bytes, BodyWriteAborted>> {
            unimplemented!()
        }
    }

    /// Boxes a stub endpoint with the given method, path template, and name.
    fn endpoint(
        method: Method,
        path: Vec<PathSegment>,
        name: &'static str,
    ) -> Box<dyn WitchcraftEndpoint + Sync + Send> {
        Box::new(TestEndpoint { method, path, name })
    }

    /// A literal path segment.
    fn literal(value: &'static str) -> PathSegment {
        PathSegment::Literal(Cow::Borrowed(value))
    }

    /// A path parameter segment, optionally restricted by a custom regex.
    fn param(name: &'static str, regex: Option<&'static str>) -> PathSegment {
        PathSegment::Parameter {
            name: Cow::Borrowed(name),
            regex: regex.map(Cow::Borrowed),
        }
    }

    /// Builds a request with the given method, URI, and body.
    fn request<B>(method: Method, uri: &'static str, body: B) -> Request<B> {
        Request::builder().method(method).uri(uri).body(body).unwrap()
    }

    /// Returns the name of the endpoint the request resolved to, panicking on any other route.
    fn resolved_name<B>(req: &Request<B>) -> &str {
        match req.extensions().get::<Route>() {
            Some(Route::Resolved(endpoint)) => endpoint.name(),
            _ => panic!("bad route"),
        }
    }

    #[tokio::test]
    async fn empty() {
        let service = RoutingLayer::new(vec![]).layer(service_fn(|req| async { req }));

        let req = service.call(request(Method::OPTIONS, "*", ())).await;
        match req.extensions().get::<Route>() {
            Some(Route::StarOptions) => {}
            _ => panic!("bad route"),
        }

        let req = service
            .call(request(Method::OPTIONS, "/foo/bar?a=b", ()))
            .await;
        match req.extensions().get::<Route>() {
            Some(Route::Options(methods)) => assert!(methods.is_empty()),
            _ => panic!("bad route"),
        }

        let req = service.call(request(Method::GET, "/foo/bar?a=b", ())).await;
        match req.extensions().get::<Route>() {
            Some(Route::Unresolved) => {}
            _ => panic!("bad route"),
        }
    }

    #[tokio::test]
    async fn nonempty() {
        let service = RoutingLayer::new(vec![
            endpoint(Method::GET, vec![literal("foo"), literal("bar")], "a"),
            endpoint(Method::POST, vec![literal("foo"), param("arg", None)], "b"),
        ])
        .layer(service_fn(|req| async { req }));

        let req = service
            .call(request(Method::OPTIONS, "/foo/bar?a=b", ()))
            .await;
        match req.extensions().get::<Route>() {
            Some(Route::Options(methods)) => assert_eq!(*methods, [Method::GET, Method::POST]),
            _ => panic!("bad route"),
        }

        let req = service.call(request(Method::GET, "/foo/bar?a=b", ())).await;
        assert_eq!(resolved_name(&req), "a");

        let req = service.call(request(Method::POST, "/foo/bar?a=b", ())).await;
        assert_eq!(resolved_name(&req), "b");
        assert_eq!(&req.extensions().get::<PathParams>().unwrap()["arg"], "bar");

        let req = service.call(request(Method::PUT, "/foo/bar?a=b", ())).await;
        match req.extensions().get::<Route>() {
            Some(Route::MethodNotAllowed(methods)) => {
                assert_eq!(*methods, [Method::GET, Method::POST])
            }
            _ => panic!("bad route"),
        }

        let req = service
            .call(request(Method::POST, "/foo/bar/baz?a=b", ()))
            .await;
        match req.extensions().get::<Route>() {
            Some(Route::Unresolved) => {}
            _ => panic!("bad route"),
        }
    }

    #[tokio::test]
    async fn custom_regex() {
        let service = RoutingLayer::new(vec![endpoint(
            Method::GET,
            vec![literal("foo"), param("arg", Some(".*"))],
            "a",
        )])
        .layer(service_fn(|req: Request<hyper::Body>| async { req }));

        let req = service
            .call(request(Method::GET, "/foo/bar/baz?a=b", hyper::Body::empty()))
            .await;
        assert_eq!(resolved_name(&req), "a");
        assert_eq!(
            &req.extensions().get::<PathParams>().unwrap()["arg"],
            "bar/baz"
        );
    }

    #[tokio::test]
    async fn ambiguity() {
        let service = RoutingLayer::new(vec![
            endpoint(Method::GET, vec![literal("foo"), literal("bar")], "a"),
            endpoint(Method::GET, vec![literal("foo"), param("arg", None)], "b"),
        ])
        .layer(service_fn(|req: Request<hyper::Body>| async { req }));

        // The all-literal template is more specific, so it must win over the parameterized one.
        let req = service
            .call(request(Method::GET, "/foo/bar?a=b", hyper::Body::empty()))
            .await;
        assert_eq!(resolved_name(&req), "a");
    }

    #[tokio::test]
    async fn absolute_form() {
        let service = RoutingLayer::new(vec![endpoint(Method::GET, vec![literal("foo")], "a")])
            .layer(service_fn(|req: Request<hyper::Body>| async { req }));

        let req = service
            .call(request(
                Method::GET,
                "https://foobar.com/foo?a=b",
                hyper::Body::empty(),
            ))
            .await;
        assert_eq!(resolved_name(&req), "a");
    }
}