use std::collections::BTreeSet;
use http::Method;
use crate::blueprint::router::method_guard::inner::method_to_bitset;
use crate::router::{AllowedMethods, MethodAllowList};
/// Describes which HTTP methods are allowed.
///
/// Build one with [`MethodGuard::from_iter`], with the `From<Method>` conversion, or by picking
/// one of the predefined constants (e.g. [`GET`], [`ANY`], [`ANY_WITH_EXTENSIONS`]). Combine
/// guards with [`MethodGuard::or`] and query them with [`MethodGuard::allows`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MethodGuard {
    // Private representation: either "anything goes" or a bitset of well-known methods plus
    // a set of extension methods (see the `inner` module).
    inner: inner::MethodGuard<'static>,
}
impl MethodGuard {
pub fn from_iter(allowed_methods: impl IntoIterator<Item = Method>) -> Self {
let mut bitset = 0;
let mut extensions = BTreeSet::new();
for method in allowed_methods {
let method = inner::Method::from(method);
if let Some(bit) = method_to_bitset(&method) {
bitset |= bit;
} else {
extensions.insert(method);
}
}
MethodGuard {
inner: inner::MethodGuard::Some(inner::SomeMethodGuard { bitset, extensions }),
}
}
pub fn or(self, other: MethodGuard) -> Self {
MethodGuard {
inner: self.inner.or(other.inner),
}
}
pub fn allows(&self, method: &Method) -> bool {
self.allows_(&inner::Method::from(method))
}
fn allows_(&self, method: &inner::Method) -> bool {
match &self.inner {
inner::MethodGuard::Any => true,
inner::MethodGuard::Some(inner::SomeMethodGuard { bitset, extensions }) => {
if let Some(bit) = method_to_bitset(method) {
*bitset & bit != 0
} else {
extensions.contains(method)
}
}
}
}
pub fn allowed_methods(&self) -> AllowedMethods {
match &self.inner {
inner::MethodGuard::Any => AllowedMethods::All,
inner::MethodGuard::Some(inner::SomeMethodGuard {
bitset: _,
extensions,
}) => {
let methods = extensions
.iter()
.cloned()
.chain(
[
inner::Method::GET,
inner::Method::POST,
inner::Method::PATCH,
inner::Method::OPTIONS,
inner::Method::PUT,
inner::Method::DELETE,
inner::Method::TRACE,
inner::Method::HEAD,
inner::Method::CONNECT,
]
.into_iter()
.filter(|method| self.allows_(method)),
)
.map(Method::from);
AllowedMethods::Some(MethodAllowList::from_iter(methods))
}
}
}
}
impl From<Method> for MethodGuard {
fn from(val: Method) -> Self {
let method = inner::Method::from(val);
let inner = if let Some(bit) = method_to_bitset(&method) {
inner::MethodGuard::Some(inner::SomeMethodGuard {
bitset: bit,
extensions: BTreeSet::new(),
})
} else {
let mut extensions = BTreeSet::new();
extensions.insert(method);
inner::MethodGuard::Some(inner::SomeMethodGuard {
bitset: 0,
extensions,
})
};
MethodGuard { inner }
}
}
pub const ANY: MethodGuard = MethodGuard { inner: inner::ANY };
pub const ANY_WITH_EXTENSIONS: MethodGuard = MethodGuard {
inner: inner::ANY_WITH_EXTENSIONS,
};
pub const GET: MethodGuard = MethodGuard { inner: inner::GET };
pub const POST: MethodGuard = MethodGuard { inner: inner::POST };
pub const PATCH: MethodGuard = MethodGuard {
inner: inner::PATCH,
};
pub const OPTIONS: MethodGuard = MethodGuard {
inner: inner::OPTIONS,
};
pub const PUT: MethodGuard = MethodGuard { inner: inner::PUT };
pub const DELETE: MethodGuard = MethodGuard {
inner: inner::DELETE,
};
pub const TRACE: MethodGuard = MethodGuard {
inner: inner::TRACE,
};
pub const HEAD: MethodGuard = MethodGuard { inner: inner::HEAD };
pub const CONNECT: MethodGuard = MethodGuard {
inner: inner::CONNECT,
};
mod inner {
#![allow(clippy::upper_case_acronyms)]
use std::borrow::Cow;
use std::collections::BTreeSet;
use std::str::FromStr;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub(super) enum MethodGuard<'a> {
Any,
Some(SomeMethodGuard<'a>),
}
#[derive(
Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq, PartialOrd, Ord,
)]
pub(super) enum Method<'a> {
GET,
POST,
PATCH,
OPTIONS,
PUT,
DELETE,
TRACE,
HEAD,
CONNECT,
Custom(Cow<'a, str>),
}
impl<'a> Method<'a> {
pub(super) fn into_owned(self) -> Method<'static> {
match self {
Method::GET => Method::GET,
Method::POST => Method::POST,
Method::PATCH => Method::PATCH,
Method::OPTIONS => Method::OPTIONS,
Method::PUT => Method::PUT,
Method::DELETE => Method::DELETE,
Method::TRACE => Method::TRACE,
Method::HEAD => Method::HEAD,
Method::CONNECT => Method::CONNECT,
Method::Custom(c) => Method::Custom(Cow::Owned(c.into_owned())),
}
}
}
impl<'a> From<&'a http::Method> for Method<'a> {
fn from(method: &'a http::Method) -> Method<'a> {
match method {
&http::Method::GET => Method::GET,
&http::Method::POST => Method::POST,
&http::Method::PATCH => Method::PATCH,
&http::Method::OPTIONS => Method::OPTIONS,
&http::Method::PUT => Method::PUT,
&http::Method::DELETE => Method::DELETE,
&http::Method::TRACE => Method::TRACE,
&http::Method::HEAD => Method::HEAD,
&http::Method::CONNECT => Method::CONNECT,
m => Method::Custom(Cow::Borrowed(m.as_str())),
}
}
}
impl From<http::Method> for Method<'static> {
fn from(value: http::Method) -> Self {
<&http::Method as Into<Method<'_>>>::into(&value).into_owned()
}
}
impl<'a> From<Method<'a>> for http::Method {
fn from(value: Method) -> Self {
match value {
Method::GET => http::Method::GET,
Method::POST => http::Method::POST,
Method::PATCH => http::Method::PATCH,
Method::OPTIONS => http::Method::OPTIONS,
Method::PUT => http::Method::PUT,
Method::DELETE => http::Method::DELETE,
Method::TRACE => http::Method::TRACE,
Method::HEAD => http::Method::HEAD,
Method::CONNECT => http::Method::CONNECT,
Method::Custom(c) => http::Method::from_str(c.as_ref()).unwrap(),
}
}
}
impl<'a> MethodGuard<'a> {
pub(super) fn or(self, other: MethodGuard<'a>) -> Self {
match (self, other) {
(MethodGuard::Any, _) | (_, MethodGuard::Any) => MethodGuard::Any,
(MethodGuard::Some(this), MethodGuard::Some(other)) => {
MethodGuard::Some(this.or(other))
}
}
}
const fn from_bits(bitset: u16) -> Self {
MethodGuard::Some(SomeMethodGuard {
bitset,
extensions: BTreeSet::new(),
})
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub(super) struct SomeMethodGuard<'a> {
pub(super) bitset: u16,
pub(super) extensions: BTreeSet<Method<'a>>,
}
impl<'a> SomeMethodGuard<'a> {
pub(super) fn or(mut self, other: SomeMethodGuard<'a>) -> Self {
self.bitset |= other.bitset;
self.extensions.extend(other.extensions);
self
}
}
pub(super) const fn method_to_bitset(method: &Method) -> Option<u16> {
match method {
&Method::GET
| &Method::POST
| &Method::PATCH
| &Method::OPTIONS
| &Method::PUT
| &Method::DELETE
| &Method::TRACE
| &Method::HEAD
| &Method::CONNECT => Some(_method_to_bitset(method)),
_ => None,
}
}
const fn _method_to_bitset(method: &Method) -> u16 {
match *method {
Method::GET => 0b0000_0001_0000_0000,
Method::POST => 0b0000_0000_1000_0000,
Method::PATCH => 0b0000_0000_0100_0000,
Method::OPTIONS => 0b0000_0000_0010_0000,
Method::PUT => 0b0000_0000_0001_0000,
Method::DELETE => 0b0000_0000_0000_1000,
Method::TRACE => 0b0000_0000_0000_0100,
Method::HEAD => 0b0000_0000_0000_0010,
Method::CONNECT => 0b0000_0000_0000_0001,
Method::Custom(_) => panic!(),
}
}
pub(super) const GET: MethodGuard = MethodGuard::from_bits(_method_to_bitset(&Method::GET));
pub(super) const POST: MethodGuard = MethodGuard::from_bits(_method_to_bitset(&Method::POST));
pub(super) const PATCH: MethodGuard = MethodGuard::from_bits(_method_to_bitset(&Method::PATCH));
pub(super) const OPTIONS: MethodGuard =
MethodGuard::from_bits(_method_to_bitset(&Method::OPTIONS));
pub(super) const PUT: MethodGuard = MethodGuard::from_bits(_method_to_bitset(&Method::PUT));
pub(super) const DELETE: MethodGuard =
MethodGuard::from_bits(_method_to_bitset(&Method::DELETE));
pub(super) const TRACE: MethodGuard = MethodGuard::from_bits(_method_to_bitset(&Method::TRACE));
pub(super) const HEAD: MethodGuard = MethodGuard::from_bits(_method_to_bitset(&Method::HEAD));
pub(super) const CONNECT: MethodGuard =
MethodGuard::from_bits(_method_to_bitset(&Method::CONNECT));
pub(super) const ANY: MethodGuard = MethodGuard::from_bits(0b0000_0001_1111_1111);
pub(super) const ANY_WITH_EXTENSIONS: MethodGuard = MethodGuard::Any;
}