use crate::xss::XssProtector;
use async_trait::async_trait;
use bytes::Bytes;
use hyper::StatusCode;
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Authentication status that upstream code can attach to a request's
/// extensions.
///
/// [`FlatpagesMiddleware`] reads it to decide whether a page marked
/// `registration_required` may be shown; a request without this extension
/// is treated as anonymous.
#[derive(Debug, Clone)]
pub struct AuthenticationState {
    authenticated: bool,
    user_id: Option<String>,
}

impl AuthenticationState {
    /// State for a signed-in user identified by `user_id`.
    pub fn authenticated(user_id: String) -> Self {
        let user_id = Some(user_id);
        Self {
            authenticated: true,
            user_id,
        }
    }

    /// State for a request with no signed-in user (no user id).
    pub fn anonymous() -> Self {
        let user_id = None;
        Self {
            authenticated: false,
            user_id,
        }
    }

    /// `true` only for states built with [`AuthenticationState::authenticated`].
    pub fn is_authenticated(&self) -> bool {
        let Self { authenticated, .. } = self;
        *authenticated
    }

    /// The signed-in user's id, or `None` for anonymous requests.
    pub fn user_id(&self) -> Option<&str> {
        let Self { user_id, .. } = self;
        user_id.as_deref()
    }
}
/// A static page that can be served for a fixed URL.
///
/// Pages are created with [`Flatpage::new`] (comments and registration
/// disabled) and refined with the consuming builder methods below.
#[derive(Debug, Clone, PartialEq)]
pub struct Flatpage {
    /// Exact request path the page is registered under (e.g. `"/about/"`).
    pub url: String,
    /// Page title.
    pub title: String,
    /// Page body text.
    pub content: String,
    /// Whether comments are enabled for this page.
    pub enable_comments: bool,
    /// When `true`, the page is only served to authenticated requests.
    pub registration_required: bool,
}

impl Flatpage {
    /// Builds a page with comments disabled and no registration requirement.
    pub fn new(url: String, title: String, content: String) -> Self {
        let (enable_comments, registration_required) = (false, false);
        Self {
            url,
            title,
            content,
            enable_comments,
            registration_required,
        }
    }

    /// Returns this page with `enable_comments` set to `true`.
    pub fn with_comments(self) -> Self {
        Self {
            enable_comments: true,
            ..self
        }
    }

    /// Returns this page with `registration_required` set to `true`.
    pub fn require_registration(self) -> Self {
        Self {
            registration_required: true,
            ..self
        }
    }
}
/// Thread-safe, in-memory registry of [`Flatpage`]s keyed by their `url`.
///
/// All methods take `&self`; the map is guarded by an `RwLock`. If a
/// previous lock holder panicked, the poisoned lock is recovered and the
/// map is used as-is rather than propagating the panic.
#[derive(Debug, Default)]
pub struct FlatpageStore {
    pages: RwLock<HashMap<String, Flatpage>>,
}

impl FlatpageStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    // Runs `f` with shared access to the map, recovering from lock poisoning.
    fn with_pages<R>(&self, f: impl FnOnce(&HashMap<String, Flatpage>) -> R) -> R {
        let guard = self
            .pages
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        f(&*guard)
    }

    // Runs `f` with exclusive access to the map, recovering from lock poisoning.
    fn with_pages_mut<R>(&self, f: impl FnOnce(&mut HashMap<String, Flatpage>) -> R) -> R {
        let mut guard = self
            .pages
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        f(&mut *guard)
    }

    /// Stores `page` under its `url`, replacing any page already registered there.
    pub fn register(&self, page: Flatpage) {
        let key = page.url.clone();
        self.with_pages_mut(|pages| {
            pages.insert(key, page);
        });
    }

    /// Returns a clone of the page registered at exactly `url`, if any.
    pub fn get(&self, url: &str) -> Option<Flatpage> {
        self.with_pages(|pages| pages.get(url).cloned())
    }

    /// Unregisters and returns the page at exactly `url`, if any.
    pub fn remove(&self, url: &str) -> Option<Flatpage> {
        self.with_pages_mut(|pages| pages.remove(url))
    }

    /// Clones of every registered page, in unspecified (hash-map) order.
    pub fn all(&self) -> Vec<Flatpage> {
        self.with_pages(|pages| pages.values().cloned().collect())
    }

    /// Unregisters every page.
    pub fn clear(&self) {
        self.with_pages_mut(|pages| pages.clear());
    }
}
/// Runtime switches for [`FlatpagesMiddleware`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct FlatpagesConfig {
    /// When `false`, the middleware just forwards to the inner handler.
    pub enabled: bool,
    /// Also look up `"{path}/"` when the request path has no trailing slash.
    pub append_slash: bool,
    /// Also look up the path minus its trailing slash (never for `"/"` itself).
    pub try_without_slash: bool,
}

impl FlatpagesConfig {
    /// Enabled configuration with both slash fallbacks turned on.
    pub fn new() -> Self {
        Self {
            enabled: true,
            append_slash: true,
            try_without_slash: true,
        }
    }

    /// Same as [`FlatpagesConfig::new`] but with the middleware switched off.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::new()
        }
    }
}

impl Default for FlatpagesConfig {
    /// Equivalent to [`FlatpagesConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Middleware that falls back to a registered [`Flatpage`] when the wrapped
/// handler answers `404 Not Found`.
pub struct FlatpagesMiddleware {
    config: FlatpagesConfig,
    store: Arc<FlatpageStore>,
}

impl FlatpagesMiddleware {
    /// Creates the middleware with a fresh, empty [`FlatpageStore`].
    pub fn new(config: FlatpagesConfig) -> Self {
        let store = Arc::new(FlatpageStore::new());
        Self { config, store }
    }

    /// Creates the middleware around an existing, possibly shared, store.
    pub fn from_arc(config: FlatpagesConfig, store: Arc<FlatpageStore>) -> Self {
        Self { config, store }
    }

    /// Borrows the page store (e.g. to register pages).
    pub fn store(&self) -> &FlatpageStore {
        &*self.store
    }

    /// Returns another handle to the shared page store.
    pub fn store_arc(&self) -> Arc<FlatpageStore> {
        Arc::clone(&self.store)
    }

    /// Looks up a page for `url`, trying in order:
    /// 1. the exact path;
    /// 2. the path with a trailing `/` appended, if `append_slash` is set and
    ///    `url` has no trailing slash;
    /// 3. the path with its trailing `/` removed, if `try_without_slash` is set
    ///    and `url` is longer than just `"/"`.
    fn try_get_page(&self, url: &str) -> Option<Flatpage> {
        let with_appended_slash = || {
            if self.config.append_slash && !url.ends_with('/') {
                self.store.get(&format!("{}/", url))
            } else {
                None
            }
        };
        let with_slash_removed = || {
            if !self.config.try_without_slash || url.len() <= 1 {
                return None;
            }
            url.strip_suffix('/')
                .and_then(|trimmed| self.store.get(trimmed))
        };
        self.store
            .get(url)
            .or_else(with_appended_slash)
            .or_else(with_slash_removed)
    }
}
#[async_trait]
impl Middleware for FlatpagesMiddleware {
    /// Delegates to `handler` and, when it answers `404 Not Found`, tries to
    /// serve a registered [`Flatpage`] for the request path instead.
    ///
    /// - When the middleware is disabled, the handler's result (including an
    ///   `Err`) is returned untouched.
    /// - Otherwise a handler `Err` is first converted into a [`Response`], so
    ///   an error that maps to `404` can still fall back to a flatpage.
    /// - Any response other than `404` is returned as-is.
    /// - If a page matches but is `registration_required` and the request has
    ///   no authenticated [`AuthenticationState`] extension, a
    ///   `401 Unauthorized` with a `WWW-Authenticate` challenge is returned.
    /// - Otherwise the page is rendered as a minimal HTML document. Its title
    ///   and content are HTML-escaped, and it is served as `200 OK` with an
    ///   explicit `text/html; charset=utf-8` content type.
    /// - If no page matches, the handler's `404` response is returned.
    async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
        if !self.config.enabled {
            return handler.handle(request).await;
        }

        // `request` is moved into the handler, so capture everything the
        // fallback needs before delegating.
        let path = request.uri.path().to_string();
        let is_authenticated = request
            .extensions
            .get::<AuthenticationState>()
            .map(|auth| auth.is_authenticated())
            .unwrap_or(false);

        let response = match handler.handle(request).await {
            Ok(resp) => resp,
            Err(e) => Response::from(e),
        };
        if response.status != StatusCode::NOT_FOUND {
            return Ok(response);
        }

        let page = match self.try_get_page(&path) {
            Some(page) => page,
            None => return Ok(response),
        };

        if page.registration_required && !is_authenticated {
            return Ok(Response::new(StatusCode::UNAUTHORIZED)
                .with_header("WWW-Authenticate", "Basic realm=\"Restricted\"")
                .with_header("Content-Type", "text/plain; charset=utf-8")
                .with_body(Bytes::from("Authentication required to view this page")));
        }

        // Title and content are escaped, so they render as literal text and
        // cannot inject markup or script into the page.
        let escaped_title = XssProtector::escape_for_html_body(&page.title);
        let escaped_content = XssProtector::escape_for_html_body(&page.content);
        let html = format!(
            r#"<!DOCTYPE html>
<html>
<head>
<title>{}</title>
</head>
<body>
{}
</body>
</html>"#,
            escaped_title, escaped_content
        );
        // Declare the media type and charset explicitly. Without them clients
        // must sniff the body, and under `X-Content-Type-Options: nosniff`
        // the page would not be rendered as HTML.
        Ok(Response::new(StatusCode::OK)
            .with_header("Content-Type", "text/html; charset=utf-8")
            .with_body(Bytes::from(html)))
    }
}
#[cfg(test)]
mod tests {
    //! The middleware HTML-escapes page titles and content. Body assertions
    //! therefore look for the escaped form (`&lt;`, `&gt;`, `&amp;`, …), not
    //! for the raw markup.
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};
    use rstest::rstest;

    /// Handler stub that always answers with a fixed status and the body
    /// `"handler response"`.
    struct TestHandler {
        status: StatusCode,
    }

    impl TestHandler {
        fn ok() -> Self {
            Self {
                status: StatusCode::OK,
            }
        }

        fn not_found() -> Self {
            Self {
                status: StatusCode::NOT_FOUND,
            }
        }
    }

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(self.status).with_body(Bytes::from("handler response")))
        }
    }

    /// Builds an empty-bodied HTTP/1.1 `GET` request for `uri`.
    fn get_request(uri: &'static str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Shorthand for `Flatpage::new` with borrowed strings.
    fn page(url: &str, title: &str, content: &str) -> Flatpage {
        Flatpage::new(url.to_string(), title.to_string(), content.to_string())
    }

    #[tokio::test]
    async fn test_basic_flatpage() {
        let middleware = Arc::new(FlatpagesMiddleware::new(FlatpagesConfig::new()));
        middleware
            .store
            .register(page("/about/", "About Us", "<h1>About Us</h1>"));
        let handler = Arc::new(TestHandler::not_found());
        let response = middleware
            .process(get_request("/about/"), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
        let body = String::from_utf8_lossy(&response.body);
        assert!(body.contains("About Us"));
        assert!(body.contains("&lt;h1&gt;About Us&lt;/h1&gt;"));
    }

    #[tokio::test]
    async fn test_flatpage_with_trailing_slash() {
        let middleware = Arc::new(FlatpagesMiddleware::new(FlatpagesConfig::new()));
        middleware
            .store
            .register(page("/contact/", "Contact", "<p>Contact us</p>"));
        let handler = Arc::new(TestHandler::not_found());
        let response = middleware
            .process(get_request("/contact"), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
        let body = String::from_utf8_lossy(&response.body);
        assert!(body.contains("Contact"));
    }

    #[tokio::test]
    async fn test_flatpage_without_trailing_slash() {
        let middleware = Arc::new(FlatpagesMiddleware::new(FlatpagesConfig::new()));
        middleware
            .store
            .register(page("/faq", "FAQ", "<p>Frequently Asked Questions</p>"));
        let handler = Arc::new(TestHandler::not_found());
        let response = middleware
            .process(get_request("/faq/"), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
        let body = String::from_utf8_lossy(&response.body);
        assert!(body.contains("FAQ"));
    }

    #[tokio::test]
    async fn test_no_flatpage_found() {
        let middleware = Arc::new(FlatpagesMiddleware::new(FlatpagesConfig::new()));
        let handler = Arc::new(TestHandler::not_found());
        let response = middleware
            .process(get_request("/nonexistent/"), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::NOT_FOUND);
        let body = String::from_utf8_lossy(&response.body);
        assert_eq!(body, "handler response");
    }

    #[tokio::test]
    async fn test_non_404_response_passthrough() {
        let middleware = Arc::new(FlatpagesMiddleware::new(FlatpagesConfig::new()));
        middleware
            .store
            .register(page("/about/", "About", "<h1>About</h1>"));
        let handler = Arc::new(TestHandler::ok());
        let response = middleware
            .process(get_request("/about/"), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
        let body = String::from_utf8_lossy(&response.body);
        assert_eq!(body, "handler response");
    }

    #[tokio::test]
    async fn test_disabled_middleware() {
        let middleware = Arc::new(FlatpagesMiddleware::new(FlatpagesConfig::disabled()));
        middleware
            .store
            .register(page("/about/", "About", "<h1>About</h1>"));
        let handler = Arc::new(TestHandler::not_found());
        let response = middleware
            .process(get_request("/about/"), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_flatpage_store_operations() {
        let store = FlatpageStore::new();
        let page1 = page("/page1/", "Page 1", "<p>Content 1</p>");
        store.register(page1.clone());
        let retrieved = store.get("/page1/").unwrap();
        assert_eq!(retrieved, page1);
        store.register(page("/page2/", "Page 2", "<p>Content 2</p>"));
        let all = store.all();
        assert_eq!(all.len(), 2);
        let removed = store.remove("/page1/").unwrap();
        assert_eq!(removed, page1);
        assert!(store.get("/page1/").is_none());
        store.clear();
        assert_eq!(store.all().len(), 0);
    }

    #[tokio::test]
    async fn test_flatpage_with_comments() {
        let page = page("/test/", "Test", "<p>Test</p>").with_comments();
        assert!(page.enable_comments);
    }

    #[tokio::test]
    async fn test_flatpage_require_registration() {
        let page = page("/test/", "Test", "<p>Test</p>").require_registration();
        assert!(page.registration_required);
    }

    #[tokio::test]
    async fn test_exact_match_priority() {
        let middleware = Arc::new(FlatpagesMiddleware::new(FlatpagesConfig::new()));
        middleware
            .store
            .register(page("/test/", "With Slash", "<p>With slash</p>"));
        middleware
            .store
            .register(page("/test", "Without Slash", "<p>Without slash</p>"));
        let handler = Arc::new(TestHandler::not_found());

        let response = middleware
            .process(get_request("/test/"), handler.clone())
            .await
            .unwrap();
        let body = String::from_utf8_lossy(&response.body);
        assert!(body.contains("With Slash"));

        let response = middleware
            .process(get_request("/test"), handler)
            .await
            .unwrap();
        let body = String::from_utf8_lossy(&response.body);
        assert!(body.contains("Without Slash"));
    }

    #[tokio::test]
    async fn test_append_slash_disabled() {
        let mut config = FlatpagesConfig::new();
        config.append_slash = false;
        let middleware = Arc::new(FlatpagesMiddleware::new(config));
        middleware
            .store
            .register(page("/test/", "Test", "<p>Test</p>"));
        let handler = Arc::new(TestHandler::not_found());
        let response = middleware
            .process(get_request("/test"), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_registration_required_authenticated_user() {
        let middleware = Arc::new(FlatpagesMiddleware::new(FlatpagesConfig::new()));
        let mut protected = page("/protected/", "Protected Page", "<p>Protected Content</p>");
        protected.registration_required = true;
        middleware.store.register(protected);
        let handler = Arc::new(TestHandler::not_found());
        let request = get_request("/protected/");
        request
            .extensions
            .insert(AuthenticationState::authenticated("user123".to_string()));
        let response = middleware.process(request, handler).await.unwrap();
        assert_eq!(response.status, StatusCode::OK);
        let body = String::from_utf8_lossy(&response.body);
        assert!(body.contains("Protected Content"));
    }

    #[tokio::test]
    async fn test_registration_required_anonymous_user() {
        let middleware = Arc::new(FlatpagesMiddleware::new(FlatpagesConfig::new()));
        let mut protected = page("/protected/", "Protected Page", "<p>Protected Content</p>");
        protected.registration_required = true;
        middleware.store.register(protected);
        let handler = Arc::new(TestHandler::not_found());
        let response = middleware
            .process(get_request("/protected/"), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::UNAUTHORIZED);
        let body = String::from_utf8_lossy(&response.body);
        assert!(body.contains("Authentication required"));
        assert!(response.headers.contains_key("WWW-Authenticate"));
    }

    #[tokio::test]
    async fn test_no_registration_required_anonymous_user() {
        let middleware = Arc::new(FlatpagesMiddleware::new(FlatpagesConfig::new()));
        middleware
            .store
            .register(page("/public/", "Public Page", "<p>Public Content</p>"));
        let handler = Arc::new(TestHandler::not_found());
        let response = middleware
            .process(get_request("/public/"), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
        let body = String::from_utf8_lossy(&response.body);
        assert!(body.contains("Public Content"));
    }

    #[rstest]
    #[case::script_tag_in_title("<script>alert('xss')</script>", "Safe content", "<script>", false)]
    #[case::script_tag_in_content("Safe Title", "<script>alert('xss')</script>", "<script>", false)]
    #[case::img_onerror_in_content(
        "Safe Title",
        r#"<img src=x onerror="alert('xss')">"#,
        "<img",
        false
    )]
    #[case::event_handler_in_title(
        r#"" onmouseover="alert('xss')"#,
        "Safe content",
        r#"onmouseover=""#,
        false
    )]
    #[case::ampersand_escaped("Tom & Jerry", "A & B", "&amp;", true)]
    #[tokio::test]
    async fn test_flatpage_xss_prevention(
        #[case] title: &str,
        #[case] content: &str,
        #[case] pattern: &str,
        #[case] should_contain: bool,
    ) {
        let middleware = Arc::new(FlatpagesMiddleware::new(FlatpagesConfig::new()));
        middleware.store.register(page("/xss-test/", title, content));
        let handler = Arc::new(TestHandler::not_found());
        let response = middleware
            .process(get_request("/xss-test/"), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
        let body = String::from_utf8_lossy(&response.body);
        assert_eq!(
            body.contains(pattern),
            should_contain,
            "Body should {} contain '{}'. Body: {}",
            if should_contain { "" } else { "NOT" },
            pattern,
            body
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_flatpage_xss_full_escape_verification() {
        let middleware = Arc::new(FlatpagesMiddleware::new(FlatpagesConfig::new()));
        middleware.store.register(page(
            "/escape-test/",
            "<b>Title</b> & 'quotes' \"double\"",
            "<script>alert(1)</script>",
        ));
        let handler = Arc::new(TestHandler::not_found());
        let response = middleware
            .process(get_request("/escape-test/"), handler)
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
        let body = String::from_utf8_lossy(&response.body);
        assert!(body.contains("&lt;b&gt;Title&lt;/b&gt;"));
        assert!(body.contains("&amp;"));
        // Quote entities are checked against the escaper itself, so the test
        // does not pin a specific numeric/named entity choice.
        let escaped_single = XssProtector::escape_for_html_body("'quotes'");
        assert!(body.contains(&*escaped_single));
        let escaped_double = XssProtector::escape_for_html_body("\"double\"");
        assert!(body.contains(&*escaped_double));
        assert!(!body.contains("\"double\""));
        assert!(body.contains("&lt;script&gt;alert(1)&lt;/script&gt;"));
        assert!(!body.contains("<script>"));
        assert!(!body.contains("</script>"));
    }
}