use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use rama_core::Layer;
use super::origin::{Origins, parse_trusted_origin};
use super::service::Csrf;
use super::{BypassFn, ConfigError, DebugFn, DefaultResponseForProtectionError};
use crate::Method;
use rama_net::uri::Uri;
/// Configuration holder that produces [`Csrf`] services.
///
/// The type parameter `T` is the type of the stored rejection response and
/// defaults to [`DefaultResponseForProtectionError`]. The type is
/// `#[must_use]` because its builder methods consume and return the layer.
#[must_use]
#[derive(Clone)]
pub struct CsrfLayer<T = DefaultResponseForProtectionError> {
    /// Optional shared predicate installed by
    /// [`CsrfLayer::with_insecure_bypass`]; `None` until that method is called.
    // NOTE(review): presumably requests matching the predicate skip the
    // protection checks — confirm against the `Csrf` service implementation.
    insecure_bypass: Option<Arc<BypassFn>>,
    /// Value handed to every produced [`Csrf`] service as its rejection response.
    rejection_response: T,
    /// Origins registered through [`CsrfLayer::add_trusted_origin`].
    trusted_origins: Origins,
}
impl Default for CsrfLayer {
    /// Builds a layer with no bypass predicate, the
    /// [`DefaultResponseForProtectionError`] rejection response and a
    /// default-constructed [`Origins`] value.
    fn default() -> Self {
        let trusted_origins = Origins::default();
        Self {
            insecure_bypass: None,
            rejection_response: DefaultResponseForProtectionError,
            trusted_origins,
        }
    }
}
impl<T> Debug for CsrfLayer<T> {
    /// Formats the layer as a `CsrfLayer { .. }` debug struct.
    ///
    /// No `T: Debug` bound is required: the rejection response is always
    /// rendered through the `DebugFn` placeholder, and the bypass predicate is
    /// rendered as `Some(DebugFn)` / `None` depending on whether one is set.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // The predicate itself is not printed; only its presence is kept.
        let bypass_marker = self.insecure_bypass.as_ref().map(|_| DebugFn);

        let mut out = f.debug_struct("CsrfLayer");
        out.field("insecure_bypass", &bypass_marker);
        out.field("trusted_origins", &self.trusted_origins);
        out.field("rejection_response", &DebugFn);
        out.finish()
    }
}
impl CsrfLayer {
    /// Creates a layer with the default configuration.
    ///
    /// Equivalent to calling [`Default::default`] for `CsrfLayer`.
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
impl<T> CsrfLayer<T> {
pub fn add_trusted_origin<S: AsRef<str>>(mut self, origin: S) -> Result<Self, ConfigError> {
let origin = parse_trusted_origin(origin.as_ref())?;
self.trusted_origins.insert(origin);
Ok(self)
}
pub fn with_insecure_bypass<F>(mut self, predicate: F) -> Self
where
F: Fn(&Method, &Uri) -> bool + Send + Sync + 'static,
{
self.insecure_bypass = Some(Arc::new(predicate));
self
}
pub fn with_rejection_response<R>(self, rejection_response: R) -> CsrfLayer<R>
where
R: Clone,
{
CsrfLayer {
insecure_bypass: self.insecure_bypass,
trusted_origins: self.trusted_origins,
rejection_response,
}
}
}
impl<S, T> Layer<S> for CsrfLayer<T>
where
    T: Clone,
{
    type Service = Csrf<S, T>;

    /// Wraps `inner` in a [`Csrf`] service configured with clones of this
    /// layer's bypass predicate, rejection response and trusted origins.
    fn layer(&self, inner: S) -> Self::Service {
        let bypass = self.insecure_bypass.clone();
        let rejection = self.rejection_response.clone();
        let origins = self.trusted_origins.clone();
        Csrf::new(inner, bypass, rejection, origins)
    }

    /// Wraps `inner` in a [`Csrf`] service, moving this layer's configuration
    /// into it instead of cloning.
    fn into_layer(self, inner: S) -> Self::Service {
        let Self {
            insecure_bypass,
            rejection_response,
            trusted_origins,
        } = self;
        Csrf::new(inner, insecure_bypass, rejection_response, trusted_origins)
    }
}