#![allow(non_snake_case)]
use std::convert::TryFrom;
use actix_http::http::{self, header, uri::Uri};
use actix_http::RequestHead;
/// A predicate over the head of an incoming request.
///
/// Implementors inspect a [`RequestHead`] and report whether the request
/// satisfies their condition.
pub trait Guard {
    /// Returns `true` when `request` satisfies this guard.
    fn check(&self, request: &RequestHead) -> bool;
}
/// Wraps a function or closure in a value implementing [`Guard`].
///
/// The returned guard's `check` calls `f` with the request head and returns
/// whatever `f` returns.
pub fn fn_guard<F: Fn(&RequestHead) -> bool>(f: F) -> impl Guard {
    FnGuard(f)
}
/// Adapter that exposes a `Fn(&RequestHead) -> bool` as a [`Guard`].
struct FnGuard<F>(F)
where
    F: Fn(&RequestHead) -> bool;

impl<F: Fn(&RequestHead) -> bool> Guard for FnGuard<F> {
    /// Calls the wrapped function with `head` and returns its result.
    fn check(&self, head: &RequestHead) -> bool {
        let predicate = &self.0;
        predicate(head)
    }
}
/// Any `Fn(&RequestHead) -> bool` can be used directly as a [`Guard`].
impl<F: Fn(&RequestHead) -> bool> Guard for F {
    /// Calls the function with `head` and returns its result.
    fn check(&self, head: &RequestHead) -> bool {
        self(head)
    }
}
/// Starts an "any of" combinator seeded with `guard`.
///
/// The resulting [`AnyGuard`] matches when at least one of its guards
/// matches; add further alternatives with [`AnyGuard::or`].
pub fn Any<F: Guard + 'static>(guard: F) -> AnyGuard {
    let first: Box<dyn Guard> = Box::new(guard);
    AnyGuard(vec![first])
}
/// Combinator that matches when at least one contained guard matches.
pub struct AnyGuard(Vec<Box<dyn Guard>>);

impl AnyGuard {
    /// Adds `guard` as a further alternative and returns the combinator.
    pub fn or<F: Guard + 'static>(self, guard: F) -> Self {
        let AnyGuard(mut guards) = self;
        guards.push(Box::new(guard));
        AnyGuard(guards)
    }
}
impl Guard for AnyGuard {
    /// Returns `true` as soon as one contained guard accepts `req`;
    /// returns `false` when none does.
    fn check(&self, req: &RequestHead) -> bool {
        self.0.iter().any(|guard| guard.check(req))
    }
}
/// Starts an "all of" combinator seeded with `guard`.
///
/// The resulting [`AllGuard`] matches only when every one of its guards
/// matches; add further requirements with [`AllGuard::and`].
pub fn All<F: Guard + 'static>(guard: F) -> AllGuard {
    let first: Box<dyn Guard> = Box::new(guard);
    AllGuard(vec![first])
}
/// Combinator that matches only when every contained guard matches.
pub struct AllGuard(Vec<Box<dyn Guard>>);

impl AllGuard {
    /// Adds `guard` as a further requirement and returns the combinator.
    pub fn and<F: Guard + 'static>(self, guard: F) -> Self {
        let AllGuard(mut guards) = self;
        guards.push(Box::new(guard));
        AllGuard(guards)
    }
}
impl Guard for AllGuard {
    /// Returns `false` as soon as one contained guard rejects `request`;
    /// returns `true` when every guard accepts it.
    fn check(&self, request: &RequestHead) -> bool {
        self.0.iter().all(|guard| guard.check(request))
    }
}
/// Builds a guard that matches exactly when `guard` does not.
pub fn Not<F: Guard + 'static>(guard: F) -> NotGuard {
    let inner: Box<dyn Guard> = Box::new(guard);
    NotGuard(inner)
}
/// Guard that inverts the result of the guard it wraps.
#[doc(hidden)]
pub struct NotGuard(Box<dyn Guard>);

impl Guard for NotGuard {
    /// Returns the negation of the wrapped guard's result for `request`.
    fn check(&self, request: &RequestHead) -> bool {
        let inner_matched = self.0.check(request);
        !inner_matched
    }
}
/// Guard that matches requests made with one specific HTTP method.
#[doc(hidden)]
pub struct MethodGuard(http::Method);

impl Guard for MethodGuard {
    /// Returns `true` when the request's method equals the stored method.
    fn check(&self, request: &RequestHead) -> bool {
        let MethodGuard(expected) = self;
        &request.method == expected
    }
}
/// Guard that matches requests made with the HTTP `GET` method.
pub fn Get() -> MethodGuard {
    let method = http::Method::GET;
    MethodGuard(method)
}
/// Guard that matches requests made with the HTTP `POST` method.
pub fn Post() -> MethodGuard {
    let method = http::Method::POST;
    MethodGuard(method)
}
/// Guard that matches requests made with the HTTP `PUT` method.
pub fn Put() -> MethodGuard {
    let method = http::Method::PUT;
    MethodGuard(method)
}
/// Guard that matches requests made with the HTTP `DELETE` method.
pub fn Delete() -> MethodGuard {
    let method = http::Method::DELETE;
    MethodGuard(method)
}
/// Guard that matches requests made with the HTTP `HEAD` method.
pub fn Head() -> MethodGuard {
    let method = http::Method::HEAD;
    MethodGuard(method)
}
/// Guard that matches requests made with the HTTP `OPTIONS` method.
pub fn Options() -> MethodGuard {
    let method = http::Method::OPTIONS;
    MethodGuard(method)
}
/// Guard that matches requests made with the HTTP `CONNECT` method.
pub fn Connect() -> MethodGuard {
    let method = http::Method::CONNECT;
    MethodGuard(method)
}
/// Guard that matches requests made with the HTTP `PATCH` method.
pub fn Patch() -> MethodGuard {
    let method = http::Method::PATCH;
    MethodGuard(method)
}
/// Guard that matches requests made with the HTTP `TRACE` method.
pub fn Trace() -> MethodGuard {
    let method = http::Method::TRACE;
    MethodGuard(method)
}
/// Guard that matches requests made with the given HTTP `method`.
///
/// Use this for methods that have no dedicated constructor in this module.
pub fn Method(method: http::Method) -> MethodGuard {
    MethodGuard(method)
}
/// Guard that matches requests carrying header `name` with exactly `value`.
///
/// # Panics
///
/// Panics if `name` cannot be converted into a header name.
/// NOTE(review): `HeaderValue::from_static` presumably panics as well on an
/// invalid `value` — confirm against the `http` crate documentation.
pub fn Header(name: &'static str, value: &'static str) -> HeaderGuard {
    let header_name = header::HeaderName::try_from(name).unwrap();
    let header_value = header::HeaderValue::from_static(value);
    HeaderGuard(header_name, header_value)
}
/// Guard holding a header name and the value that header must have.
#[doc(hidden)]
pub struct HeaderGuard(header::HeaderName, header::HeaderValue);

impl Guard for HeaderGuard {
    /// Returns `true` when the request has the stored header and the value
    /// looked up for it equals the stored value; `false` when the header is
    /// missing or differs.
    fn check(&self, req: &RequestHead) -> bool {
        match req.headers.get(&self.0) {
            Some(actual) => actual == self.1,
            None => false,
        }
    }
}
/// Guard that matches requests addressed to `host`.
///
/// The guard starts without a scheme constraint; add one with
/// [`HostGuard::scheme`].
pub fn Host<H: AsRef<str>>(host: H) -> HostGuard {
    let expected_host = host.as_ref().to_owned();
    HostGuard(expected_host, None)
}
fn get_host_uri(req: &RequestHead) -> Option<Uri> {
use core::str::FromStr;
req.headers
.get(header::HOST)
.and_then(|host_value| host_value.to_str().ok())
.or_else(|| req.uri.host())
.map(|host: &str| Uri::from_str(host).ok())
.and_then(|host_success| host_success)
}
/// Guard holding an expected host name and an optional expected scheme.
#[doc(hidden)]
pub struct HostGuard(String, Option<String>);

impl HostGuard {
    /// Sets the scheme the request host is compared against, replacing any
    /// previously configured scheme, and returns the guard.
    pub fn scheme<H: AsRef<str>>(self, scheme: H) -> HostGuard {
        let HostGuard(host, _) = self;
        HostGuard(host, Some(scheme.as_ref().to_owned()))
    }
}
impl Guard for HostGuard {
fn check(&self, req: &RequestHead) -> bool {
let req_host_uri = if let Some(uri) = get_host_uri(req) {
uri
} else {
return false;
};
if let Some(uri_host) = req_host_uri.host() {
if self.0 != uri_host {
return false;
}
} else {
return false;
}
if let Some(ref scheme) = self.1 {
if let Some(ref req_host_uri_scheme) = req_host_uri.scheme_str() {
return scheme == req_host_uri_scheme;
}
}
true
}
}
#[cfg(test)]
mod tests {
    use actix_http::http::{header, Method};
    use super::*;
    use crate::test::TestRequest;

    /// `Header` matches only when the named header has the expected value.
    #[test]
    fn test_header() {
        let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked")
            .to_http_request();
        let head = req.head();

        assert!(Header("transfer-encoding", "chunked").check(head));
        assert!(!Header("transfer-encoding", "other").check(head));
        assert!(!Header("content-type", "other").check(head));
    }

    /// Host taken from a scheme-less `Host` header.
    #[test]
    fn test_host() {
        let host_header = header::HeaderValue::from_static("www.rust-lang.org");
        let req = TestRequest::default()
            .header(header::HOST, host_header)
            .to_http_request();
        let head = req.head();

        assert!(Host("www.rust-lang.org").check(head));
        // The header has no scheme, so the scheme constraint is not enforced.
        assert!(Host("www.rust-lang.org").scheme("https").check(head));

        assert!(!Host("blog.rust-lang.org").check(head));
        assert!(!Host("blog.rust-lang.org").scheme("https").check(head));
        assert!(!Host("crates.io").check(head));
        assert!(!Host("localhost").check(head));
    }

    /// Host taken from a `Host` header that includes a scheme.
    #[test]
    fn test_host_scheme() {
        let host_header = header::HeaderValue::from_static("https://www.rust-lang.org");
        let req = TestRequest::default()
            .header(header::HOST, host_header)
            .to_http_request();
        let head = req.head();

        assert!(Host("www.rust-lang.org").scheme("https").check(head));
        assert!(Host("www.rust-lang.org").check(head));

        assert!(!Host("www.rust-lang.org").scheme("http").check(head));
        assert!(!Host("blog.rust-lang.org").check(head));
        assert!(!Host("blog.rust-lang.org").scheme("https").check(head));
        assert!(!Host("crates.io").scheme("https").check(head));
        assert!(!Host("localhost").check(head));
    }

    /// Host taken from the request URI when no `Host` header is set.
    #[test]
    fn test_host_without_header() {
        let req = TestRequest::default()
            .uri("www.rust-lang.org")
            .to_http_request();
        let head = req.head();

        assert!(Host("www.rust-lang.org").check(head));
        assert!(Host("www.rust-lang.org").scheme("https").check(head));

        assert!(!Host("blog.rust-lang.org").check(head));
        assert!(!Host("blog.rust-lang.org").scheme("https").check(head));
        assert!(!Host("crates.io").check(head));
        assert!(!Host("localhost").check(head));
    }

    /// Each method guard accepts its own method and rejects a different one.
    #[test]
    fn test_methods() {
        let get_req = TestRequest::default().to_http_request();
        let post_req = TestRequest::default()
            .method(Method::POST)
            .to_http_request();

        assert!(Get().check(get_req.head()));
        assert!(!Get().check(post_req.head()));
        assert!(Post().check(post_req.head()));
        assert!(!Post().check(get_req.head()));

        // Remaining methods: must match their own request and reject the
        // default (GET) request.
        let cases = vec![
            (Method::PUT, Put()),
            (Method::DELETE, Delete()),
            (Method::HEAD, Head()),
            (Method::OPTIONS, Options()),
            (Method::CONNECT, Connect()),
            (Method::PATCH, Patch()),
            (Method::TRACE, Trace()),
        ];
        for (method, guard) in cases {
            let matching = TestRequest::default().method(method).to_http_request();
            assert!(guard.check(matching.head()));
            assert!(!guard.check(get_req.head()));
        }
    }

    /// `Not`, `All` and `Any` combine guards as negation, conjunction and
    /// disjunction.
    #[test]
    fn test_preds() {
        let trace_req = TestRequest::default()
            .method(Method::TRACE)
            .to_http_request();
        let head = trace_req.head();

        assert!(Not(Get()).check(head));
        assert!(!Not(Trace()).check(head));

        assert!(All(Trace()).and(Trace()).check(head));
        assert!(!All(Get()).and(Trace()).check(head));

        assert!(Any(Get()).or(Trace()).check(head));
        assert!(!Any(Get()).or(Get()).check(head));
    }
}