use async_trait::async_trait;
use hyper::header::HeaderName;
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use std::sync::Arc;
/// Framing policy emitted in the `X-Frame-Options` response header.
///
/// Each variant maps to exactly one header token via [`XFrameOptions::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XFrameOptions {
    /// Emits `DENY` (per RFC 7034: the page must not be displayed in any frame).
    Deny,
    /// Emits `SAMEORIGIN` (per RFC 7034: framing is allowed only by the same origin).
    SameOrigin,
}

impl XFrameOptions {
    /// Returns the header token for this policy: `"DENY"` or `"SAMEORIGIN"`.
    ///
    /// This is a `const fn`, so the token can also be computed in const/static
    /// initializers (e.g. to build a `HeaderValue` at compile time).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            XFrameOptions::Deny => "DENY",
            XFrameOptions::SameOrigin => "SAMEORIGIN",
        }
    }
}
/// Middleware that adds an `X-Frame-Options` header to outgoing responses.
///
/// The header value is taken from the configured [`XFrameOptions`] policy.
#[derive(Debug, Clone)]
pub struct XFrameOptionsMiddleware {
    /// Policy whose token is written into the header.
    option: XFrameOptions,
}
impl XFrameOptionsMiddleware {
    /// Creates a middleware that emits `X-Frame-Options: DENY`.
    #[must_use]
    pub const fn deny() -> Self {
        Self::new(XFrameOptions::Deny)
    }

    /// Creates a middleware that emits `X-Frame-Options: SAMEORIGIN`.
    #[must_use]
    pub const fn same_origin() -> Self {
        Self::new(XFrameOptions::SameOrigin)
    }

    /// Creates a middleware that emits the token for `option`.
    ///
    /// All constructors are `const fn`, so a middleware can be built in a
    /// `const` or `static` initializer.
    #[must_use]
    pub const fn new(option: XFrameOptions) -> Self {
        Self { option }
    }
}
impl Default for XFrameOptionsMiddleware {
fn default() -> Self {
Self::same_origin()
}
}
const X_FRAME_OPTIONS: HeaderName = HeaderName::from_static("x-frame-options");
/// Applies the configured `X-Frame-Options` policy to every response produced
/// by the wrapped handler.
#[async_trait]
impl Middleware for XFrameOptionsMiddleware {
    /// Runs `handler` and stamps the configured `X-Frame-Options` value on the
    /// resulting response.
    ///
    /// If the handler fails, its error is converted into a response with
    /// `Response::from`. Error responses therefore receive the header too, and
    /// this method itself always returns `Ok`. If the handler already set an
    /// `X-Frame-Options` header, that value is left untouched.
    async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
        let mut response = handler
            .handle(request)
            .await
            .unwrap_or_else(Response::from);
        // The token comes from `XFrameOptions::as_str`, so the enum stays the
        // single source of truth for emitted values. The entry API keeps any
        // handler-provided value and needs only one map lookup.
        response
            .headers
            .entry(X_FRAME_OPTIONS)
            .or_insert_with(|| hyper::header::HeaderValue::from_static(self.option.as_str()));
        Ok(response)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use reinhardt_http::Error;
    use rstest::rstest;

    /// Builds an HTTP/1.1 request for `uri` with an empty header map and body.
    fn build_request(method: Method, uri: &'static str) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Handler answering `200 OK` with body `test` and no extra headers.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from(&b"test"[..])))
        }
    }

    #[tokio::test]
    async fn test_deny_option() {
        let middleware = XFrameOptionsMiddleware::deny();
        let request = build_request(Method::GET, "/test");
        let response = middleware
            .process(request, Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(response.headers.get(&X_FRAME_OPTIONS).unwrap(), "DENY");
    }

    #[tokio::test]
    async fn test_same_origin_option() {
        let middleware = XFrameOptionsMiddleware::same_origin();
        let request = build_request(Method::GET, "/test");
        let response = middleware
            .process(request, Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(
            response.headers.get(&X_FRAME_OPTIONS).unwrap(),
            "SAMEORIGIN"
        );
    }

    #[tokio::test]
    async fn test_default_is_same_origin() {
        let middleware = XFrameOptionsMiddleware::default();
        let request = build_request(Method::GET, "/test");
        let response = middleware
            .process(request, Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(
            response.headers.get(&X_FRAME_OPTIONS).unwrap(),
            "SAMEORIGIN"
        );
    }

    #[tokio::test]
    async fn test_does_not_override_existing_header() {
        /// Handler that sets its own `X-Frame-Options: DENY`.
        struct TestHandlerWithHeader;

        #[async_trait]
        impl Handler for TestHandlerWithHeader {
            async fn handle(&self, _request: Request) -> Result<Response> {
                let mut preset =
                    Response::new(StatusCode::OK).with_body(Bytes::from(&b"test"[..]));
                preset
                    .headers
                    .insert(X_FRAME_OPTIONS, "DENY".parse().unwrap());
                Ok(preset)
            }
        }

        let middleware = XFrameOptionsMiddleware::same_origin();
        let request = build_request(Method::GET, "/test");
        let response = middleware
            .process(request, Arc::new(TestHandlerWithHeader))
            .await
            .unwrap();
        assert_eq!(response.headers.get(&X_FRAME_OPTIONS).unwrap(), "DENY");
    }

    #[tokio::test]
    async fn test_new_constructor_with_deny() {
        let middleware = XFrameOptionsMiddleware::new(XFrameOptions::Deny);
        let request = build_request(Method::GET, "/secure");
        let response = middleware
            .process(request, Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(response.headers.get(&X_FRAME_OPTIONS).unwrap(), "DENY");
    }

    #[tokio::test]
    async fn test_new_constructor_with_same_origin() {
        let middleware = XFrameOptionsMiddleware::new(XFrameOptions::SameOrigin);
        let request = build_request(Method::GET, "/dashboard");
        let response = middleware
            .process(request, Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(
            response.headers.get(&X_FRAME_OPTIONS).unwrap(),
            "SAMEORIGIN"
        );
    }

    #[tokio::test]
    async fn test_response_body_preserved() {
        /// Handler returning a distinctive body so we can check it survives.
        struct TestHandlerWithBody;

        #[async_trait]
        impl Handler for TestHandlerWithBody {
            async fn handle(&self, _request: Request) -> Result<Response> {
                Ok(Response::new(StatusCode::OK)
                    .with_body(Bytes::from(&b"custom response body"[..])))
            }
        }

        let middleware = XFrameOptionsMiddleware::deny();
        let request = build_request(Method::GET, "/content");
        let response = middleware
            .process(request, Arc::new(TestHandlerWithBody))
            .await
            .unwrap();
        assert_eq!(response.headers.get(&X_FRAME_OPTIONS).unwrap(), "DENY");
        assert_eq!(response.body, Bytes::from(&b"custom response body"[..]));
    }

    #[tokio::test]
    async fn test_middleware_reusable_across_requests() {
        let middleware = XFrameOptionsMiddleware::deny();
        let handler = Arc::new(TestHandler);
        let cases = [
            (Method::GET, "/page1"),
            (Method::POST, "/page2"),
            (Method::PUT, "/page3"),
        ];
        for (method, uri) in cases {
            let response = middleware
                .process(build_request(method, uri), Arc::clone(&handler))
                .await
                .unwrap();
            assert_eq!(response.headers.get(&X_FRAME_OPTIONS).unwrap(), "DENY");
        }
    }

    /// Handler that always fails, used to check headers on error responses.
    struct ErrorHandler;

    #[async_trait]
    impl Handler for ErrorHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Err(Error::Http("handler error".to_string()))
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_xframe_header_applied_on_handler_error() {
        let middleware = XFrameOptionsMiddleware::new(XFrameOptions::Deny);
        let handler: Arc<dyn Handler> = Arc::new(ErrorHandler);
        let request = build_request(Method::GET, "/test");
        let response = middleware.process(request, handler).await.unwrap();
        assert!(response.status.is_client_error() || response.status.is_server_error());
        assert_eq!(response.headers.get(&X_FRAME_OPTIONS).unwrap(), "DENY");
    }
}