use tide::{Middleware, Next, Request};
use helmet_core::Helmet as HelmetCore;
pub use helmet_core::*;
/// Tide middleware that applies a fixed set of security headers to every response.
///
/// Build it with [`Helmet::new`] plus [`Helmet::add`], take the preset from
/// [`Helmet::default`], or convert an existing `helmet_core` configuration via `From`.
#[derive(Debug, Clone)]
pub struct Helmet {
    /// `(header name, header value)` pairs, applied in this order.
    headers: Vec<(&'static str, String)>,
}
impl Default for Helmet {
    /// Builds the middleware from `HelmetCore::default()`, i.e. whatever header
    /// preset `helmet_core` ships as its default configuration.
    fn default() -> Self {
        HelmetCore::default().into()
    }
}
impl Helmet {
    /// Builds the middleware from `HelmetCore::new()`.
    ///
    /// Configure headers by chaining [`Helmet::add`]. To start from
    /// `helmet_core`'s default preset instead, use [`Helmet::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::from(HelmetCore::new())
    }

    /// Appends `header` to the set applied to every response and returns the
    /// updated builder.
    ///
    /// Headers are applied in the order they were added. The middleware writes
    /// them with `insert_header`, so if the same name is added twice the later
    /// value is the one that ends up on the response.
    // `add` mirrors the `helmet_core` builder API; clippy would otherwise ask
    // for a `std::ops::Add` impl, which does not fit a by-value builder that
    // accepts any `Into<Header>`.
    #[allow(clippy::should_implement_trait)]
    #[must_use = "`add` returns the updated builder; dropping it discards the header"]
    pub fn add(mut self, header: impl Into<helmet_core::Header>) -> Self {
        self.headers.push(header.into());
        self
    }
}
impl From<HelmetCore> for Helmet {
    /// Converts a `helmet_core` configuration into the tide middleware.
    ///
    /// `core` is taken by value, so its header list is moved over as-is
    /// (same order, same values) instead of cloning every value `String`.
    fn from(core: HelmetCore) -> Self {
        Self {
            headers: core.headers,
        }
    }
}
#[async_trait::async_trait]
impl<State> Middleware<State> for Helmet
where
    State: Clone + Send + Sync + 'static,
{
    /// Runs the remainder of the middleware/endpoint chain, then writes every
    /// configured header onto the resulting response with `insert_header`
    /// (which replaces any value already set under the same name).
    async fn handle(&self, req: Request<State>, next: Next<'_, State>) -> tide::Result {
        let mut response = next.run(req).await;
        self.headers.iter().for_each(|(header_name, header_value)| {
            response.insert_header(*header_name, header_value.as_str());
        });
        Ok(response)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mounts `middleware` on a fresh app with a single `GET /` route and
    /// returns the response to one request against that route.
    async fn respond_to_root(middleware: Helmet) -> tide::http::Response {
        let mut server = tide::new();
        server.with(middleware);
        server.at("/").get(|_| async { Ok("Hello, world!") });
        let url = tide::http::Url::parse("http://localhost/").unwrap();
        let request = tide::http::Request::new(tide::http::Method::Get, url);
        server.respond(request).await.unwrap()
    }

    /// Returns the value of header `name` on `response` as a string slice, if present.
    fn header_value<'a>(
        response: &'a tide::http::Response,
        name: &'static str,
    ) -> Option<&'a str> {
        response.header(name).map(|values| values.as_str())
    }

    #[async_std::test]
    async fn test_helmet() {
        let helmet = Helmet::new()
            .add(helmet_core::XContentTypeOptions::nosniff())
            .add(helmet_core::XFrameOptions::same_origin())
            .add(helmet_core::XXSSProtection::on().mode_block());
        let response = respond_to_root(helmet).await;

        assert_eq!(response.status(), tide::StatusCode::Ok);
        for (name, expected) in [
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "SAMEORIGIN"),
            ("X-XSS-Protection", "1; mode=block"),
        ] {
            assert_eq!(header_value(&response, name), Some(expected), "{}", name);
        }
    }

    #[async_std::test]
    async fn test_helmet_default() {
        let response = respond_to_root(Helmet::default()).await;

        assert_eq!(response.status(), tide::StatusCode::Ok);
        for (name, expected) in [
            ("X-Frame-Options", "SAMEORIGIN"),
            ("X-XSS-Protection", "0"),
            ("Referrer-Policy", "no-referrer"),
        ] {
            assert_eq!(header_value(&response, name), Some(expected), "{}", name);
        }
    }
}