use std::collections::HashSet;
use hyper::{Method, StatusCode};
/// The result reported by a route that did not match a request.
///
/// It pairs a status code with the set of HTTP methods that are allowed, and can be merged
/// with other non-matches through `intersection` and `union`.
#[derive(Clone)]
pub struct RouteNonMatch {
    /// Status code associated with this non-match.
    status: StatusCode,
    /// Allowed methods; returned as a sorted `Vec<Method>` by `deconstruct`.
    allow: MethodSet,
}
impl RouteNonMatch {
    /// Creates a non-match carrying `status` and the default `MethodSet` as its allow list.
    pub fn new(status: StatusCode) -> RouteNonMatch {
        let allow = MethodSet::default();
        RouteNonMatch { status, allow }
    }

    /// Replaces the allow list with exactly the methods in `allow`, keeping the status.
    pub fn with_allow_list(mut self, allow: &[Method]) -> RouteNonMatch {
        self.allow = MethodSet::from(allow);
        self
    }

    /// Merges two non-matches. The status is whichever of the two has higher precedence
    /// (see `higher_precedence_status`), and only methods allowed by *both* are kept.
    pub fn intersection(self, other: RouteNonMatch) -> RouteNonMatch {
        RouteNonMatch {
            status: higher_precedence_status(self.status, other.status),
            allow: self.allow.intersection(other.allow),
        }
    }

    /// Merges two non-matches. The status is whichever of the two has higher precedence
    /// (see `higher_precedence_status`), and methods allowed by *either* are kept.
    pub fn union(self, other: RouteNonMatch) -> RouteNonMatch {
        RouteNonMatch {
            status: higher_precedence_status(self.status, other.status),
            allow: self.allow.union(other.allow),
        }
    }

    /// Splits the non-match into its status and its allow list (sorted by method name).
    pub(super) fn deconstruct(self) -> (StatusCode, Vec<Method>) {
        let RouteNonMatch { status, allow } = self;
        (status, Vec::<Method>::from(allow))
    }
}
impl From<RouteNonMatch> for StatusCode {
fn from(val: RouteNonMatch) -> StatusCode {
val.status
}
}
/// Chooses which of two non-match statuses wins when results are merged by
/// `RouteNonMatch::intersection` / `RouteNonMatch::union`.
///
/// Precedence, from weakest to strongest:
///
/// 1. `404 Not Found`: yields to any other status.
/// 2. `405 Method Not Allowed`: yields to anything except 404.
/// 3. `406 Not Acceptable`: yields to anything except 404 and 405.
/// 4. Any other status. Client errors (4xx) beat server errors (5xx), which beat
///    non-error statuses. Between two statuses of the same class, `lhs` wins.
fn higher_precedence_status(lhs: StatusCode, rhs: StatusCode) -> StatusCode {
    match (lhs, rhs) {
        // 404 is the weakest signal: prefer whatever the other side reported.
        (StatusCode::NOT_FOUND, _) => rhs,
        (_, StatusCode::NOT_FOUND) => lhs,
        // 405 loses to any status other than 404.
        (StatusCode::METHOD_NOT_ALLOWED, _) => rhs,
        (_, StatusCode::METHOD_NOT_ALLOWED) => lhs,
        // 406 loses to any status other than 404/405.
        (StatusCode::NOT_ACCEPTABLE, _) => rhs,
        (_, StatusCode::NOT_ACCEPTABLE) => lhs,
        // Errors should win over non-errors, with client errors ranked first.
        (_, _) if lhs.is_client_error() => lhs,
        (_, _) if rhs.is_client_error() => rhs,
        // Without these two arms, a non-error `lhs` would beat a server-error `rhs`.
        (_, _) if lhs.is_server_error() => lhs,
        (_, _) if rhs.is_server_error() => rhs,
        (_, _) => lhs,
    }
}
/// A set of HTTP methods.
///
/// The nine methods with dedicated `bool` fields are tracked as flags. Every other
/// (extension) method is stored in `other`.
#[derive(Clone)]
struct MethodSet {
    connect: bool,
    delete: bool,
    get: bool,
    head: bool,
    options: bool,
    patch: bool,
    post: bool,
    put: bool,
    trace: bool,
    /// Methods that have no dedicated flag above.
    other: HashSet<Method>,
}
impl MethodSet {
fn intersection(self, other: MethodSet) -> MethodSet {
MethodSet {
connect: self.connect && other.connect,
delete: self.delete && other.delete,
get: self.get && other.get,
head: self.head && other.head,
options: self.options && other.options,
patch: self.patch && other.patch,
post: self.post && other.post,
put: self.put && other.put,
trace: self.trace && other.trace,
other: self.other.intersection(&other.other).cloned().collect(),
}
}
fn union(self, other: MethodSet) -> MethodSet {
MethodSet {
connect: self.connect || other.connect,
delete: self.delete || other.delete,
get: self.get || other.get,
head: self.head || other.head,
options: self.options || other.options,
patch: self.patch || other.patch,
post: self.post || other.post,
put: self.put || other.put,
trace: self.trace || other.trace,
other: self.other.union(&other.other).cloned().collect(),
}
}
}
impl Default for MethodSet {
fn default() -> MethodSet {
MethodSet {
connect: false,
delete: true,
get: true,
head: true,
options: true,
patch: true,
post: true,
put: true,
trace: false,
other: HashSet::default(),
}
}
}
impl<'a> From<&'a [Method]> for MethodSet {
    /// Builds a set containing exactly the given methods. Repeated entries are harmless.
    fn from(methods: &'a [Method]) -> MethodSet {
        let mut set = MethodSet {
            connect: false,
            delete: false,
            get: false,
            head: false,
            options: false,
            patch: false,
            post: false,
            put: false,
            trace: false,
            other: HashSet::new(),
        };
        for method in methods {
            // Select the flag for well-known methods; anything else goes into `other`.
            let flag = match *method {
                Method::CONNECT => &mut set.connect,
                Method::DELETE => &mut set.delete,
                Method::GET => &mut set.get,
                Method::HEAD => &mut set.head,
                Method::OPTIONS => &mut set.options,
                Method::PATCH => &mut set.patch,
                Method::POST => &mut set.post,
                Method::PUT => &mut set.put,
                Method::TRACE => &mut set.trace,
                _ => {
                    set.other.insert(method.clone());
                    continue;
                }
            };
            *flag = true;
        }
        set
    }
}
impl From<MethodSet> for Vec<Method> {
fn from(method_set: MethodSet) -> Vec<Method> {
let methods_with_flags: [(Method, bool); 9] = [
(Method::CONNECT, method_set.connect),
(Method::DELETE, method_set.delete),
(Method::GET, method_set.get),
(Method::HEAD, method_set.head),
(Method::OPTIONS, method_set.options),
(Method::PATCH, method_set.patch),
(Method::POST, method_set.post),
(Method::PUT, method_set.put),
(Method::TRACE, method_set.trace),
];
let mut result = methods_with_flags
.iter()
.filter_map(|&(ref method, flag)| if flag { Some(method.clone()) } else { None })
.chain(method_set.other.into_iter())
.collect::<Vec<Method>>();
result.sort_unstable_by(|a, b| a.as_ref().cmp(b.as_ref()));
result
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::{Method, StatusCode};

    /// The allow list carried by `RouteNonMatch::new` when no list is set, in the sorted
    /// order produced by `deconstruct`.
    fn default_allow_list() -> Vec<Method> {
        vec![
            Method::DELETE,
            Method::GET,
            Method::HEAD,
            Method::OPTIONS,
            Method::PATCH,
            Method::POST,
            Method::PUT,
        ]
    }

    /// A 405 non-match that only allows `GET`.
    fn get_only_405() -> RouteNonMatch {
        RouteNonMatch::new(StatusCode::METHOD_NOT_ALLOWED).with_allow_list(&[Method::GET])
    }

    /// A 406 non-match allowing `GET`, `PATCH` and `POST`.
    fn not_acceptable() -> RouteNonMatch {
        RouteNonMatch::new(StatusCode::NOT_ACCEPTABLE)
            .with_allow_list(&[Method::GET, Method::PATCH, Method::POST])
    }

    /// A 405 non-match allowing `GET`, `POST` and `OPTIONS`.
    fn method_not_allowed() -> RouteNonMatch {
        RouteNonMatch::new(StatusCode::METHOD_NOT_ALLOWED).with_allow_list(&[
            Method::GET,
            Method::POST,
            Method::OPTIONS,
        ])
    }

    /// Parses an extension method name that is known to be valid.
    fn extension(name: &[u8]) -> Method {
        Method::from_bytes(name).unwrap()
    }

    #[test]
    fn intersection_tests() {
        // Two untouched 404s keep their status and the full default allow list.
        let (status, allow_list) = RouteNonMatch::new(StatusCode::NOT_FOUND)
            .intersection(RouteNonMatch::new(StatusCode::NOT_FOUND))
            .deconstruct();
        assert_eq!(status, StatusCode::NOT_FOUND);
        assert_eq!(allow_list, default_allow_list());

        // 405 outranks 404, and only GET is allowed by both sides.
        let (status, allow_list) = RouteNonMatch::new(StatusCode::NOT_FOUND)
            .intersection(get_only_405())
            .deconstruct();
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(allow_list, vec![Method::GET]);

        // 406 outranks 405; GET and POST are the methods both sides allow.
        let (status, allow_list) = not_acceptable()
            .intersection(method_not_allowed())
            .deconstruct();
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
        assert_eq!(allow_list, vec![Method::GET, Method::POST]);
    }

    #[test]
    fn union_tests() {
        // Two untouched 404s keep their status and the full default allow list.
        let (status, allow_list) = RouteNonMatch::new(StatusCode::NOT_FOUND)
            .union(RouteNonMatch::new(StatusCode::NOT_FOUND))
            .deconstruct();
        assert_eq!(status, StatusCode::NOT_FOUND);
        assert_eq!(allow_list, default_allow_list());

        // 405 outranks 404; GET is already in the default list, so the union is the default.
        let (status, allow_list) = RouteNonMatch::new(StatusCode::NOT_FOUND)
            .union(get_only_405())
            .deconstruct();
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(allow_list, default_allow_list());

        // 406 outranks 405; every method allowed by either side is kept.
        let (status, allow_list) = not_acceptable().union(method_not_allowed()).deconstruct();
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
        assert_eq!(
            allow_list,
            vec![Method::GET, Method::OPTIONS, Method::PATCH, Method::POST]
        );
    }

    #[test]
    fn deconstruct_tests() {
        // Input deliberately out of order; output must be sorted by method name.
        let input = [
            Method::DELETE,
            Method::GET,
            Method::HEAD,
            Method::OPTIONS,
            Method::PATCH,
            Method::POST,
            Method::PUT,
            Method::CONNECT,
            Method::TRACE,
            extension(b"PROPFIND"),
            extension(b"PROPSET"),
        ];
        let expected = vec![
            Method::CONNECT,
            Method::DELETE,
            Method::GET,
            Method::HEAD,
            Method::OPTIONS,
            Method::PATCH,
            Method::POST,
            extension(b"PROPFIND"),
            extension(b"PROPSET"),
            Method::PUT,
            Method::TRACE,
        ];

        let (_, allow_list) = RouteNonMatch::new(StatusCode::NOT_FOUND)
            .with_allow_list(&input)
            .deconstruct();
        assert_eq!(allow_list, expected);
    }
}