use http::header::{HeaderMap, HeaderName, HeaderValue};
use salvo::handler::Handler;
use salvo::{async_trait, Depot, FlowCtrl, Request, Response};
use helmet_core::Helmet as HelmetCore;
pub use helmet_core::*;
/// Builder for a set of security response headers, wrapping
/// `helmet_core::Helmet` so it can be turned into a Salvo handler.
///
/// `Default` is derived, so `Helmet::default()` delegates to the wrapped core
/// type's `Default` impl, while `Helmet::new()` calls the core type's `new()`.
/// NOTE(review): presumably `default()` yields a recommended header set and
/// `new()` an empty one — confirm against `helmet_core`.
#[derive(Default)]
pub struct Helmet(HelmetCore);
impl Helmet {
pub fn new() -> Self {
Self(HelmetCore::new())
}
#[allow(clippy::should_implement_trait)]
pub fn add(self, header: impl Into<helmet_core::Header>) -> Self {
Self(self.0.add(header))
}
pub fn into_handler(self) -> Result<HelmetHandler, HelmetError> {
self.try_into()
}
}
/// Salvo handler that holds a ready-to-use map of response headers.
///
/// `Debug` is derived so the handler can be logged or inspected, and `Clone`
/// so a single built handler can be attached to several routers; both are
/// provided by the underlying `HeaderMap`.
#[derive(Debug, Clone)]
pub struct HelmetHandler {
    /// Headers written onto each response by the `Handler` implementation.
    headers: HeaderMap,
}
impl TryFrom<Helmet> for HelmetHandler {
type Error = HelmetError;
fn try_from(helmet: Helmet) -> Result<Self, Self::Error> {
let mut headers = HeaderMap::new();
for header in helmet.0.headers.iter() {
let name = HeaderName::try_from(header.0)
.map_err(|_| HelmetError::InvalidHeaderName(header.0.to_string()))?;
let value = HeaderValue::from_str(&header.1)
.map_err(|_| HelmetError::InvalidHeaderValue(header.1.clone()))?;
headers.insert(name, value);
}
Ok(Self { headers })
}
}
#[async_trait]
impl Handler for HelmetHandler {
    /// Runs the remaining handlers first, then writes the configured headers
    /// onto the response.
    ///
    /// Because the headers are inserted after `call_next` returns, a header
    /// with the same name set by a later handler is replaced by the
    /// configured value.
    async fn handle(
        &self,
        req: &mut Request,
        depot: &mut Depot,
        res: &mut Response,
        ctrl: &mut FlowCtrl,
    ) {
        ctrl.call_next(req, depot, res).await;

        let outgoing = res.headers_mut();
        for (key, val) in &self.headers {
            outgoing.insert(key.clone(), val.clone());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use salvo::prelude::*;
    use salvo::test::TestClient;

    /// Asserts that the response carries header `$name` with exactly the
    /// string value `$expected`.
    macro_rules! assert_header {
        ($res:expr, $name:expr, $expected:expr) => {
            assert_eq!(
                $res.headers().get($name).map(|v| v.to_str().unwrap()),
                Some($expected)
            );
        };
    }

    /// Minimal endpoint used as the routed target in every test.
    #[handler]
    async fn index() -> &'static str {
        "Hello, world!"
    }

    /// Explicitly added headers must show up on the response.
    #[tokio::test]
    async fn test_helmet() {
        let hoop = Helmet::new()
            .add(helmet_core::XContentTypeOptions::nosniff())
            .add(helmet_core::XFrameOptions::same_origin())
            .add(helmet_core::XXSSProtection::on().mode_block())
            .into_handler()
            .unwrap();
        let router = Router::with_hoop(hoop).get(index);
        let service = Service::new(router);

        let res = TestClient::get("http://localhost/").send(&service).await;

        assert_eq!(res.status_code, Some(StatusCode::OK));
        assert_header!(res, "X-Content-Type-Options", "nosniff");
        assert_header!(res, "X-Frame-Options", "SAMEORIGIN");
        assert_header!(res, "X-XSS-Protection", "1; mode=block");
    }

    /// `Helmet::default()` must yield the expected default header values.
    #[tokio::test]
    async fn test_helmet_default() {
        let handler: HelmetHandler = Helmet::default().try_into().unwrap();
        let router = Router::with_hoop(handler).get(index);
        let service = Service::new(router);

        let res = TestClient::get("http://localhost/").send(&service).await;

        assert_eq!(res.status_code, Some(StatusCode::OK));
        assert_header!(res, "X-Frame-Options", "SAMEORIGIN");
        assert_header!(res, "X-XSS-Protection", "0");
        assert_header!(res, "Referrer-Policy", "no-referrer");
    }
}