use std::{
cell::{Ref, RefMut},
convert::TryFrom,
rc::Rc,
};
use actix_http::{header, uri::Uri, Extensions, Method as HttpMethod, RequestHead};
use crate::{http::header::Header, service::ServiceRequest, HttpMessage as _};
/// Request-scoped view that is handed to every [`Guard`] when it is checked.
///
/// Wraps a shared reference to the [`ServiceRequest`] under inspection and exposes read access to
/// its head, its extensions container and its typed headers.
#[derive(Debug)]
pub struct GuardContext<'a> {
// The request being matched; borrowed for `'a`, never owned by the context.
pub(crate) req: &'a ServiceRequest,
}
impl<'a> GuardContext<'a> {
/// Returns a reference to the head of the wrapped request.
#[inline]
pub fn head(&self) -> &RequestHead {
self.req.head()
}
/// Returns a shared borrow guard over the request's [`Extensions`] container.
///
/// The returned [`Ref`] is tied to the request lifetime `'a`, not to this context value.
#[inline]
pub fn req_data(&self) -> Ref<'a, Extensions> {
self.req.extensions()
}
/// Returns an exclusive borrow guard over the request's [`Extensions`] container.
///
/// NOTE(review): the `RefMut` return type suggests `RefCell`-style dynamic borrow checking, so
/// overlapping borrows would presumably panic — confirm against `extensions_mut`.
#[inline]
pub fn req_data_mut(&self) -> RefMut<'a, Extensions> {
self.req.extensions_mut()
}
/// Parses the typed header `H` out of the wrapped request.
///
/// Returns `None` when `H::parse` reports an error; the error value itself is discarded.
#[inline]
pub fn header<H: Header>(&self) -> Option<H> {
H::parse(self.req).ok()
}
}
/// Interface for request predicates ("guards").
///
/// A guard inspects a [`GuardContext`] and answers whether the request it wraps is acceptable.
pub trait Guard {
/// Returns `true` if the request described by `ctx` satisfies this guard.
fn check(&self, ctx: &GuardContext<'_>) -> bool;
}
/// Lets a reference-counted guard trait object be used wherever a [`Guard`] is expected.
impl Guard for Rc<dyn Guard> {
    fn check(&self, ctx: &GuardContext<'_>) -> bool {
        // Borrow the guard stored inside the `Rc` and forward the call to it.
        let inner: &dyn Guard = &**self;
        inner.check(ctx)
    }
}
/// Wraps a closure or function in a [`Guard`].
///
/// The returned guard's `check` calls `f` with the [`GuardContext`] and returns its result.
pub fn fn_guard<F>(f: F) -> impl Guard
where
F: Fn(&GuardContext<'_>) -> bool,
{
FnGuard(f)
}
/// Adapter that turns a predicate closure into a [`Guard`]; constructed by [`fn_guard`].
struct FnGuard<F>(F)
where
    F: Fn(&GuardContext<'_>) -> bool;

impl<F: Fn(&GuardContext<'_>) -> bool> Guard for FnGuard<F> {
    fn check(&self, ctx: &GuardContext<'_>) -> bool {
        // The guard passes exactly when the wrapped predicate returns `true`.
        let FnGuard(predicate) = self;
        predicate(ctx)
    }
}
/// Blanket implementation: every `Fn(&GuardContext<'_>) -> bool` is itself a [`Guard`].
impl<F: Fn(&GuardContext<'_>) -> bool> Guard for F {
    fn check(&self, ctx: &GuardContext<'_>) -> bool {
        // Calling the guard is simply calling the function.
        let predicate: &F = self;
        predicate(ctx)
    }
}
/// Starts an [`AnyGuard`] whose only alternative so far is `guard`.
///
/// Further alternatives are added with [`AnyGuard::or`]. The `'static` bound is required because
/// the guard is stored as a boxed trait object.
// The capitalised function name is intentional, hence the lint allowance.
#[allow(non_snake_case)]
pub fn Any<F: Guard + 'static>(guard: F) -> AnyGuard {
    let first: Box<dyn Guard> = Box::new(guard);
    AnyGuard {
        guards: vec![first],
    }
}
/// Combinator guard holding a list of alternative guards.
///
/// Created with [`Any`] and extended with [`AnyGuard::or`].
pub struct AnyGuard {
    // Alternatives in the order they were added.
    guards: Vec<Box<dyn Guard>>,
}

impl AnyGuard {
    /// Appends `guard` as a further alternative and returns the updated combinator.
    pub fn or<F: Guard + 'static>(mut self, guard: F) -> Self {
        let alternative: Box<dyn Guard> = Box::new(guard);
        self.guards.push(alternative);
        self
    }
}
impl Guard for AnyGuard {
    /// Returns `true` if at least one inner guard accepts `ctx`, `false` if none does.
    fn check(&self, ctx: &GuardContext<'_>) -> bool {
        // `any` stops at the first accepting guard, visiting guards in insertion order.
        self.guards.iter().any(|guard| guard.check(ctx))
    }
}
/// Starts an [`AllGuard`] whose only requirement so far is `guard`.
///
/// Further requirements are added with [`AllGuard::and`]. The `'static` bound is required because
/// the guard is stored as a boxed trait object.
// The capitalised function name is intentional, hence the lint allowance.
#[allow(non_snake_case)]
pub fn All<F: Guard + 'static>(guard: F) -> AllGuard {
    let first: Box<dyn Guard> = Box::new(guard);
    AllGuard {
        guards: vec![first],
    }
}
/// Combinator guard holding a list of guards that must all accept the request.
///
/// Created with [`All`] and extended with [`AllGuard::and`].
pub struct AllGuard {
    // Requirements in the order they were added.
    guards: Vec<Box<dyn Guard>>,
}

impl AllGuard {
    /// Appends `guard` as a further requirement and returns the updated combinator.
    pub fn and<F: Guard + 'static>(mut self, guard: F) -> Self {
        let requirement: Box<dyn Guard> = Box::new(guard);
        self.guards.push(requirement);
        self
    }
}
impl Guard for AllGuard {
    /// Returns `true` only if every inner guard accepts `ctx`.
    fn check(&self, ctx: &GuardContext<'_>) -> bool {
        // `all` stops at the first rejecting guard, visiting guards in insertion order.
        self.guards.iter().all(|guard| guard.check(ctx))
    }
}
/// Guard adapter that inverts the verdict of the guard it wraps.
pub struct Not<G>(pub G);

impl<G: Guard> Guard for Not<G> {
    /// Returns `true` exactly when the wrapped guard rejects `ctx`.
    fn check(&self, ctx: &GuardContext<'_>) -> bool {
        let inner_passes = self.0.check(ctx);
        !inner_passes
    }
}
/// Creates a guard that passes only when the request method equals `method`.
// The capitalised function name is intentional, hence the lint allowance.
#[allow(non_snake_case)]
pub fn Method(method: HttpMethod) -> impl Guard {
MethodGuard(method)
}
/// Guard that compares the request's HTTP method against a fixed one.
struct MethodGuard(HttpMethod);

impl Guard for MethodGuard {
    /// Returns `true` when the method in the request head equals the stored method.
    fn check(&self, ctx: &GuardContext<'_>) -> bool {
        let MethodGuard(expected) = self;
        ctx.head().method == *expected
    }
}
// Generates a public, zero-argument constructor named `$method_fn` that returns a guard matching
// the `HttpMethod::$method_const` request method, together with per-method rustdoc.
//
// The doc comment is assembled from `#[doc = concat!(..)]` attributes (for the parts that mention
// the method) interleaved with plain `///` lines. The `///` lines supply the blank line, the
// heading and the code fence; without them rustdoc would merge everything into one paragraph.
macro_rules! method_guard {
    ($method_fn:ident, $method_const:ident) => {
        #[doc = concat!("Creates a guard that matches the `", stringify!($method_const), "` request method.")]
        ///
        /// # Examples
        #[doc = concat!("The route in this example will only respond to `", stringify!($method_const), "` requests.")]
        /// ```ignore
        /// web::route()
        #[doc = concat!("    .guard(guard::", stringify!($method_fn), "())")]
        ///     .to(|| HttpResponse::Ok());
        /// ```
        // The capitalised function name is intentional, hence the lint allowance.
        #[allow(non_snake_case)]
        pub fn $method_fn() -> impl Guard {
            MethodGuard(HttpMethod::$method_const)
        }
    };
}
// Instantiate one guard constructor per HTTP method constant: `Get()`, `Post()`, `Put()`, ...
// Each expands to a function returning a `MethodGuard` for the named `HttpMethod` constant.
method_guard!(Get, GET);
method_guard!(Post, POST);
method_guard!(Put, PUT);
method_guard!(Delete, DELETE);
method_guard!(Head, HEAD);
method_guard!(Options, OPTIONS);
method_guard!(Connect, CONNECT);
method_guard!(Patch, PATCH);
method_guard!(Trace, TRACE);
/// Creates a guard that passes when the request carries header `name` with a value equal to
/// `value`.
///
/// # Panics
/// Panics if `name` cannot be converted into a `HeaderName`. `HeaderValue::from_static` is
/// documented by the `http` crate to panic as well when `value` is not a valid header value.
// The capitalised function name is intentional, hence the lint allowance.
#[allow(non_snake_case)]
pub fn Header(name: &'static str, value: &'static str) -> impl Guard {
    let header_name = header::HeaderName::try_from(name).unwrap();
    let header_value = header::HeaderValue::from_static(value);
    HeaderGuard(header_name, header_value)
}
/// Guard holding the header name to look up (field 0) and the value it must equal (field 1).
struct HeaderGuard(header::HeaderName, header::HeaderValue);

impl Guard for HeaderGuard {
    /// Returns `true` when the header is present and the value returned by the lookup equals the
    /// expected one; a missing header yields `false`.
    fn check(&self, ctx: &GuardContext<'_>) -> bool {
        match ctx.head().headers.get(&self.0) {
            Some(found) => found == self.1,
            None => false,
        }
    }
}
/// Creates a [`HostGuard`] that passes when the request's host equals `host`.
///
/// No scheme requirement is set initially; add one with [`HostGuard::scheme`].
// The capitalised function name is intentional, hence the lint allowance.
#[allow(non_snake_case)]
pub fn Host(host: impl AsRef<str>) -> HostGuard {
    let host = String::from(host.as_ref());
    HostGuard { host, scheme: None }
}
/// Determines the host of a request and parses it into a [`Uri`].
///
/// The `Host` header is preferred when it is present and readable as a string; otherwise the host
/// component of the request URI is used. Returns `None` when neither source yields a host or when
/// the host text does not parse as a `Uri`.
fn get_host_uri(req: &RequestHead) -> Option<Uri> {
    let from_header = req
        .headers
        .get(header::HOST)
        .and_then(|value| value.to_str().ok());

    let host = match from_header {
        Some(host) => host,
        // No usable header: fall back to the URI, giving up if it has no host either.
        None => req.uri.host()?,
    };

    host.parse().ok()
}
/// Guard that matches on the request host and, optionally, on its scheme.
///
/// Built by [`Host`]; the type itself is hidden from the generated documentation.
#[doc(hidden)]
pub struct HostGuard {
    // Host the request must carry; compared with `==` against the parsed request host.
    host: String,
    // Scheme the request must use; `None` means the scheme is not checked.
    scheme: Option<String>,
}

impl HostGuard {
    /// Sets the scheme the request must use and returns the updated guard.
    ///
    /// The requirement is only enforced when a scheme can be read from the request's host URI;
    /// if none is available the guard still passes on a host match.
    pub fn scheme<H: AsRef<str>>(mut self, scheme: H) -> HostGuard {
        let scheme = scheme.as_ref().to_owned();
        self.scheme = Some(scheme);
        self
    }
}
impl Guard for HostGuard {
    /// Returns `true` when the request host equals the configured host and, if a scheme is both
    /// configured and present on the request's host URI, the two schemes are equal.
    fn check(&self, ctx: &GuardContext<'_>) -> bool {
        // Without a parseable host there is nothing to match against.
        let host_uri = match get_host_uri(ctx.head()) {
            None => return false,
            Some(parsed) => parsed,
        };

        // A missing or different host rejects the request outright.
        if host_uri.host() != Some(self.host.as_str()) {
            return false;
        }

        match (self.scheme.as_deref(), host_uri.scheme_str()) {
            // Both sides know a scheme: they must agree.
            (Some(required), Some(actual)) => required == actual,
            // No scheme required, or the request's scheme is unknown: the host match suffices.
            _ => true,
        }
    }
}
#[cfg(test)]
mod tests {
    use actix_http::{header, Method};

    use super::*;
    use crate::test::TestRequest;

    /// `Header` passes only when both the header name and its value match.
    #[test]
    fn header_match() {
        let srv_req = TestRequest::default()
            .insert_header((header::TRANSFER_ENCODING, "chunked"))
            .to_srv_request();
        let ctx = srv_req.guard_ctx();

        assert!(Header("transfer-encoding", "chunked").check(&ctx));
        assert!(!Header("transfer-encoding", "other").check(&ctx));
        assert!(!Header("content-type", "chunked").check(&ctx));
        assert!(!Header("content-type", "other").check(&ctx));
    }

    /// Host taken from a scheme-less `Host` header.
    #[test]
    fn host_from_header() {
        let srv_req = TestRequest::default()
            .insert_header((
                header::HOST,
                header::HeaderValue::from_static("www.rust-lang.org"),
            ))
            .to_srv_request();
        let ctx = srv_req.guard_ctx();

        // Matching host, with and without a scheme requirement.
        assert!(Host("www.rust-lang.org").check(&ctx));
        assert!(Host("www.rust-lang.org").scheme("https").check(&ctx));

        // Every other host is rejected.
        assert!(!Host("blog.rust-lang.org").check(&ctx));
        assert!(!Host("blog.rust-lang.org").scheme("https").check(&ctx));
        assert!(!Host("crates.io").check(&ctx));
        assert!(!Host("localhost").check(&ctx));
    }

    /// Host taken from the request URI when no `Host` header is set.
    #[test]
    fn host_without_header() {
        let srv_req = TestRequest::default()
            .uri("www.rust-lang.org")
            .to_srv_request();
        let ctx = srv_req.guard_ctx();

        // Matching host, with and without a scheme requirement.
        assert!(Host("www.rust-lang.org").check(&ctx));
        assert!(Host("www.rust-lang.org").scheme("https").check(&ctx));

        // Every other host is rejected.
        assert!(!Host("blog.rust-lang.org").check(&ctx));
        assert!(!Host("blog.rust-lang.org").scheme("https").check(&ctx));
        assert!(!Host("crates.io").check(&ctx));
        assert!(!Host("localhost").check(&ctx));
    }

    /// Host header that also carries a scheme.
    #[test]
    fn host_scheme() {
        let srv_req = TestRequest::default()
            .insert_header((
                header::HOST,
                header::HeaderValue::from_static("https://www.rust-lang.org"),
            ))
            .to_srv_request();
        let ctx = srv_req.guard_ctx();

        // Same host: accepted unless a different scheme is demanded.
        assert!(Host("www.rust-lang.org").scheme("https").check(&ctx));
        assert!(Host("www.rust-lang.org").check(&ctx));
        assert!(!Host("www.rust-lang.org").scheme("http").check(&ctx));

        // Different hosts are rejected regardless of scheme.
        assert!(!Host("blog.rust-lang.org").check(&ctx));
        assert!(!Host("blog.rust-lang.org").scheme("https").check(&ctx));
        assert!(!Host("crates.io").scheme("https").check(&ctx));
        assert!(!Host("localhost").check(&ctx));
    }

    /// Each per-method guard accepts its own method and rejects a GET/POST counter-example.
    #[test]
    fn method_guards() {
        let get_req = TestRequest::get().to_srv_request();
        let post_req = TestRequest::post().to_srv_request();
        let get_ctx = get_req.guard_ctx();
        let post_ctx = post_req.guard_ctx();

        assert!(Get().check(&get_ctx));
        assert!(!Get().check(&post_ctx));
        assert!(Post().check(&post_ctx));
        assert!(!Post().check(&get_ctx));

        let put_req = TestRequest::put().to_srv_request();
        assert!(Put().check(&put_req.guard_ctx()));
        assert!(!Put().check(&get_ctx));

        let patch_req = TestRequest::patch().to_srv_request();
        assert!(Patch().check(&patch_req.guard_ctx()));
        assert!(!Patch().check(&get_ctx));

        let delete_req = TestRequest::delete().to_srv_request();
        assert!(Delete().check(&delete_req.guard_ctx()));
        assert!(!Delete().check(&get_ctx));

        let head_req = TestRequest::default().method(Method::HEAD).to_srv_request();
        assert!(Head().check(&head_req.guard_ctx()));
        assert!(!Head().check(&get_ctx));

        let options_req = TestRequest::default()
            .method(Method::OPTIONS)
            .to_srv_request();
        assert!(Options().check(&options_req.guard_ctx()));
        assert!(!Options().check(&get_ctx));

        let connect_req = TestRequest::default()
            .method(Method::CONNECT)
            .to_srv_request();
        assert!(Connect().check(&connect_req.guard_ctx()));
        assert!(!Connect().check(&get_ctx));

        let trace_req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();
        assert!(Trace().check(&trace_req.guard_ctx()));
        assert!(!Trace().check(&get_ctx));
    }

    /// `Any` passes when at least one alternative passes.
    #[test]
    fn aggregate_any() {
        let trace_req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();
        let ctx = trace_req.guard_ctx();

        assert!(Any(Trace()).check(&ctx));
        assert!(Any(Trace()).or(Get()).check(&ctx));
        assert!(!Any(Get()).or(Get()).check(&ctx));
    }

    /// `All` passes only when every requirement passes.
    #[test]
    fn aggregate_all() {
        let trace_req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();
        let ctx = trace_req.guard_ctx();

        assert!(All(Trace()).check(&ctx));
        assert!(All(Trace()).and(Trace()).check(&ctx));
        assert!(!All(Trace()).and(Get()).check(&ctx));
    }

    /// `Not` inverts a guard, and a double `Not` restores the original verdict.
    #[test]
    fn nested_not() {
        let default_req = TestRequest::default().to_srv_request();
        let ctx = default_req.guard_ctx();

        let plain = Get();
        assert!(plain.check(&ctx));

        let negated = Not(plain);
        assert!(!negated.check(&ctx));

        let double_negated = Not(negated);
        assert!(double_negated.check(&ctx));
    }

    /// A closure that captures its environment works through `fn_guard`.
    #[test]
    fn function_guard() {
        let domain = String::from("rust-lang.org");
        let guard = fn_guard(|ctx| ctx.head().uri.host().unwrap().ends_with(&domain));

        let matching_req = TestRequest::default()
            .uri("blog.rust-lang.org")
            .to_srv_request();
        assert!(guard.check(&matching_req.guard_ctx()));

        let other_req = TestRequest::default().uri("crates.io").to_srv_request();
        assert!(!guard.check(&other_req.guard_ctx()));
    }

    /// Combinators compose: `All(Not(Any(Not(Trace()))))` behaves like `Trace()`.
    #[test]
    fn mega_nesting() {
        let guard = fn_guard(|ctx| All(Not(Any(Not(Trace())))).check(ctx));

        let default_req = TestRequest::default().to_srv_request();
        assert!(!guard.check(&default_req.guard_ctx()));

        let trace_req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();
        assert!(guard.check(&trace_req.guard_ctx()));
    }
}