use async_trait::async_trait;
use reinhardt_http::Handler;
use reinhardt_http::{Request, Response, Result};
use reinhardt_rest::openapi::endpoints::generate_openapi_schema;
use reinhardt_rest::openapi::{RedocUI, SwaggerUI};
use reinhardt_urls::prelude::Route;
use reinhardt_urls::routers::Router;
use std::sync::Arc;
/// Shared, thread-safe predicate that inspects a borrowed [`Request`] and
/// returns a `bool`.
///
/// [`OpenApiRouter`] optionally stores one of these; when it returns `false`
/// for a documentation request the router answers with `Response::forbidden()`.
pub type AuthGuard = Arc<dyn Fn(&Request) -> bool + Send + Sync>;
/// Handler wrapper that answers three fixed documentation paths
/// (`/api/openapi.json`, `/api/docs`, `/api/redoc`) from content rendered once
/// at construction time, and forwards every other request to the wrapped
/// handler `H`.
pub struct OpenApiRouter<H> {
/// Wrapped handler that receives every request not addressed to a
/// documentation path.
inner: H,
/// Pretty-printed JSON serialization of the generated OpenAPI schema.
openapi_json: Arc<String>,
/// HTML produced by `SwaggerUI::render_html` for the generated schema.
swagger_html: Arc<String>,
/// HTML produced by `RedocUI::render_html` for the generated schema.
redoc_html: Arc<String>,
/// When `false`, the documentation paths answer with `Response::not_found()`.
/// Set to `true` by `wrap`.
enabled: bool,
/// Optional access predicate; a `false` result yields `Response::forbidden()`.
/// `None` (the value set by `wrap`) means no access check is performed.
auth_guard: Option<AuthGuard>,
}
impl<H> OpenApiRouter<H> {
/// Wraps `handler`, rendering the OpenAPI JSON, the Swagger UI page and the
/// Redoc page once up front so that requests are later served from memory.
///
/// The returned router has the documentation endpoints enabled and no
/// access guard installed.
///
/// # Errors
///
/// Returns a `SchemaError` if serializing the schema to JSON fails or if
/// either UI fails to render its HTML.
pub fn wrap(handler: H) -> std::result::Result<Self, reinhardt_rest::openapi::SchemaError> {
    let schema = generate_openapi_schema();

    // Same order as before: JSON first, then Swagger, then Redoc, so the first
    // failing step is the one reported.
    let spec_json = serde_json::to_string_pretty(&schema)?;
    let swagger_page = SwaggerUI::new(schema.clone()).render_html()?;
    // The schema is no longer needed afterwards, so Redoc takes ownership.
    let redoc_page = RedocUI::new(schema).render_html()?;

    Ok(Self {
        inner: handler,
        openapi_json: Arc::new(spec_json),
        swagger_html: Arc::new(swagger_page),
        redoc_html: Arc::new(redoc_page),
        enabled: true,
        auth_guard: None,
    })
}
pub fn enabled(mut self, enabled: bool) -> Self {
self.enabled = enabled;
self
}
pub fn auth_guard(mut self, guard: impl Fn(&Request) -> bool + Send + Sync + 'static) -> Self {
self.auth_guard = Some(Arc::new(guard));
self
}
pub fn inner(&self) -> &H {
&self.inner
}
/// Decides whether a documentation request may be served.
///
/// Returns `Some(response)` carrying the refusal, or `None` when access is
/// allowed. The disabled check runs first, so a disabled router reports
/// "not found" without ever calling the guard.
fn check_access(&self, request: &Request) -> Option<Response> {
    if !self.enabled {
        return Some(Response::not_found());
    }

    match &self.auth_guard {
        // A guard is installed and it rejected this request.
        Some(guard) if !guard(request) => Some(Response::forbidden()),
        // Either no guard is installed or the guard accepted the request.
        _ => None,
    }
}
/// Serves one of the documentation endpoints if the request path names one.
///
/// Returns `None` when the path is not `/api/openapi.json`, `/api/docs` or
/// `/api/redoc`, so the caller can forward the request to the wrapped
/// handler. Otherwise returns `Some(Ok(response))`, where the response is
/// either the refusal produced by `check_access` (returned as-is) or the
/// pre-rendered document with its `Content-Type` and the security headers
/// added by `apply_security_headers`.
fn try_serve_openapi(&self, request: &Request) -> Option<Result<Response>> {
    // Single dispatch table: each path is listed exactly once together with
    // the document and content type it serves, so there is no second match
    // that has to be kept in sync and no `unreachable!()` arm.
    let (document, content_type) = match request.uri.path() {
        "/api/openapi.json" => (&self.openapi_json, "application/json; charset=utf-8"),
        "/api/docs" => (&self.swagger_html, "text/html; charset=utf-8"),
        "/api/redoc" => (&self.redoc_html, "text/html; charset=utf-8"),
        _ => return None,
    };

    // Access is checked only for documentation paths and before the body is
    // copied, so refused requests never pay for the clone.
    if let Some(denied) = self.check_access(request) {
        return Some(Ok(denied));
    }

    let response = Response::ok()
        .with_header("Content-Type", content_type)
        .with_body(String::clone(document));

    Some(Ok(Self::apply_security_headers(response)))
}
/// Adds the fixed set of security-related headers to a documentation response.
///
/// The headers are applied in this order: `Content-Security-Policy`,
/// `X-Frame-Options`, `X-Content-Type-Options`, `Cache-Control`.
///
/// NOTE(review): `font-src` allows `https://fonts.gstatic.com` while
/// `style-src` does not list `https://fonts.googleapis.com`. Confirm against
/// the rendered HTML whether a Google Fonts stylesheet is linked.
fn apply_security_headers(response: Response) -> Response {
    // One directive per line; the pieces concatenate to a single header value.
    const CONTENT_SECURITY_POLICY: &str = concat!(
        "default-src 'none'; ",
        "script-src 'unsafe-inline' https://unpkg.com https://cdn.redoc.ly; ",
        "style-src 'unsafe-inline' https://unpkg.com; ",
        "img-src 'self' data:; ",
        "connect-src 'self'; ",
        "font-src https://fonts.gstatic.com; ",
        "frame-ancestors 'none'",
    );

    let security_headers = [
        ("Content-Security-Policy", CONTENT_SECURITY_POLICY),
        ("X-Frame-Options", "DENY"),
        ("X-Content-Type-Options", "nosniff"),
        ("Cache-Control", "no-store"),
    ];

    security_headers
        .into_iter()
        .fold(response, |acc, (name, value)| acc.with_header(name, value))
}
}
/// `Handler` implementation: documentation paths are answered locally, every
/// other request is forwarded to the wrapped handler.
#[async_trait]
impl<H> Handler for OpenApiRouter<H>
where
    H: Handler,
{
    /// Handles `request`, giving the documentation endpoints priority over the
    /// wrapped handler.
    async fn handle(&self, request: Request) -> Result<Response> {
        // Resolved synchronously and returned before any `.await`, so the
        // lookup result is never held across the await point below.
        if let Some(docs_outcome) = self.try_serve_openapi(&request) {
            return docs_outcome;
        }

        self.inner.handle(request).await
    }
}
/// `Router` implementation for wrapped routers.
///
/// Route registration is rejected after wrapping: both mutating methods panic.
/// Only `route` is functional, and it gives the documentation endpoints
/// priority over the wrapped router.
impl<H: Handler + Router> Router for OpenApiRouter<H> {
    /// Always panics; routes must be registered on the base router before it
    /// is wrapped.
    ///
    /// # Panics
    ///
    /// Unconditionally.
    fn add_route(&mut self, _route: Route) {
        panic!(
            "Cannot add routes to OpenApiRouter after wrapping. \
            Add routes to the base router before calling OpenApiRouter::wrap()."
        );
    }

    /// Always panics; routes must be mounted on the base router before it is
    /// wrapped.
    ///
    /// # Panics
    ///
    /// Unconditionally.
    fn mount(&mut self, _prefix: &str, _routes: Vec<Route>, _namespace: Option<String>) {
        panic!(
            "Cannot mount routes in OpenApiRouter after wrapping. \
            Mount routes in the base router before calling OpenApiRouter::wrap()."
        );
    }

    /// Routes `request`: documentation paths are answered locally, everything
    /// else goes to the wrapped router's `route`.
    async fn route(&self, request: Request) -> Result<Response> {
        // Returned before any `.await`, so nothing extra is held across it.
        if let Some(docs_outcome) = self.try_serve_openapi(&request) {
            return docs_outcome;
        }

        self.inner.route(request).await
    }
}
#[cfg(test)]
mod tests {
use super::*;
use hyper::StatusCode;
use rstest::rstest;
/// Stand-in inner handler: answers every request with status 200 and a fixed
/// body so tests can tell delegated responses apart from documentation ones.
struct DummyHandler;

#[async_trait]
impl Handler for DummyHandler {
    async fn handle(&self, _request: Request) -> Result<Response> {
        let reply = Response::new(StatusCode::OK).with_body("Hello from inner handler");
        Ok(reply)
    }
}
/// `/api/openapi.json` answers 200 with a body that mentions `openapi` and a
/// `3.` version marker.
#[rstest]
#[tokio::test]
async fn test_openapi_json_endpoint() {
    let router = OpenApiRouter::wrap(DummyHandler).unwrap();
    let req = Request::builder().uri("/api/openapi.json").build().unwrap();

    let resp = router.handle(req).await.unwrap();

    assert_eq!(resp.status, StatusCode::OK);
    let text = String::from_utf8(resp.body.to_vec()).unwrap();
    assert!(text.contains("openapi"));
    assert!(text.contains("3."));
}
/// `/api/docs` answers 200 with HTML that mentions `swagger-ui`.
#[rstest]
#[tokio::test]
async fn test_swagger_docs_endpoint() {
    let router = OpenApiRouter::wrap(DummyHandler).unwrap();
    let req = Request::builder().uri("/api/docs").build().unwrap();

    let resp = router.handle(req).await.unwrap();

    assert_eq!(resp.status, StatusCode::OK);
    let html = String::from_utf8(resp.body.to_vec()).unwrap();
    assert!(html.contains("swagger-ui"));
}
/// `/api/redoc` answers 200 with HTML that mentions `redoc`.
#[rstest]
#[tokio::test]
async fn test_redoc_docs_endpoint() {
    let router = OpenApiRouter::wrap(DummyHandler).unwrap();
    let req = Request::builder().uri("/api/redoc").build().unwrap();

    let resp = router.handle(req).await.unwrap();

    assert_eq!(resp.status, StatusCode::OK);
    let html = String::from_utf8(resp.body.to_vec()).unwrap();
    assert!(html.contains("redoc"));
}
/// A path that is not a documentation endpoint reaches the wrapped handler.
#[rstest]
#[tokio::test]
async fn test_delegation_to_inner_handler() {
    let router = OpenApiRouter::wrap(DummyHandler).unwrap();
    let req = Request::builder().uri("/some/other/path").build().unwrap();

    let resp = router.handle(req).await.unwrap();

    assert_eq!(resp.status, StatusCode::OK);
    let text = String::from_utf8(resp.body.to_vec()).unwrap();
    assert_eq!(text, "Hello from inner handler");
}
/// With the endpoints disabled, each documentation path answers 404.
#[rstest]
#[case("/api/openapi.json")]
#[case("/api/docs")]
#[case("/api/redoc")]
#[tokio::test]
async fn test_disabled_endpoints_return_404(#[case] path: &str) {
    let router = OpenApiRouter::wrap(DummyHandler).unwrap().enabled(false);
    let req = Request::builder().uri(path).build().unwrap();

    let resp = router.handle(req).await.unwrap();

    assert_eq!(resp.status, StatusCode::NOT_FOUND);
}
/// Disabling the documentation endpoints leaves delegation to the wrapped
/// handler untouched.
#[rstest]
#[tokio::test]
async fn test_disabled_does_not_affect_other_routes() {
    let router = OpenApiRouter::wrap(DummyHandler).unwrap().enabled(false);
    let req = Request::builder().uri("/some/other/path").build().unwrap();

    let resp = router.handle(req).await.unwrap();

    assert_eq!(resp.status, StatusCode::OK);
    let text = String::from_utf8(resp.body.to_vec()).unwrap();
    assert_eq!(text, "Hello from inner handler");
}
#[rstest]
#[case("/api/openapi.json")]
#[case("/api/docs")]
#[case("/api/redoc")]
#[tokio::test]
async fn test_auth_guard_rejects_unauthorized(#[case] path: &str) {
let handler = DummyHandler;
let wrapped = OpenApiRouter::wrap(handler)
.unwrap()
.auth_guard(|_request| false);
let request = Request::builder().uri(path).build().unwrap();
let response = wrapped.handle(request).await.unwrap();
assert_eq!(response.status, StatusCode::FORBIDDEN);
}
#[rstest]
#[case("/api/openapi.json")]
#[case("/api/docs")]
#[case("/api/redoc")]
#[tokio::test]
async fn test_auth_guard_allows_authorized(#[case] path: &str) {
let handler = DummyHandler;
let wrapped = OpenApiRouter::wrap(handler)
.unwrap()
.auth_guard(|_request| true);
let request = Request::builder().uri(path).build().unwrap();
let response = wrapped.handle(request).await.unwrap();
assert_eq!(response.status, StatusCode::OK);
}
#[rstest]
#[tokio::test]
async fn test_auth_guard_does_not_affect_other_routes() {
let handler = DummyHandler;
let wrapped = OpenApiRouter::wrap(handler)
.unwrap()
.auth_guard(|_request| false);
let request = Request::builder().uri("/some/other/path").build().unwrap();
let response = wrapped.handle(request).await.unwrap();
assert_eq!(response.status, StatusCode::OK);
let body_str = String::from_utf8(response.body.to_vec()).unwrap();
assert_eq!(body_str, "Hello from inner handler");
}
#[rstest]
#[case("/api/openapi.json")]
#[case("/api/docs")]
#[case("/api/redoc")]
#[tokio::test]
async fn test_disabled_takes_precedence_over_auth_guard(#[case] path: &str) {
let handler = DummyHandler;
let wrapped = OpenApiRouter::wrap(handler)
.unwrap()
.enabled(false)
.auth_guard(|_request| true);
let request = Request::builder().uri(path).build().unwrap();
let response = wrapped.handle(request).await.unwrap();
assert_eq!(response.status, StatusCode::NOT_FOUND);
}
/// The JSON endpoint's body parses as JSON and carries an `openapi` string
/// field whose value starts with `3.`.
#[rstest]
#[tokio::test]
async fn test_openapi_json_response_body_is_valid_openapi_json() {
    let router = OpenApiRouter::wrap(DummyHandler).unwrap();
    let req = Request::builder().uri("/api/openapi.json").build().unwrap();

    let resp = router.handle(req).await.unwrap();
    assert_eq!(resp.status, StatusCode::OK);

    let raw = resp.body.to_vec();
    let document = serde_json::from_slice::<serde_json::Value>(&raw)
        .expect("Response body should be valid JSON");
    let version = document["openapi"]
        .as_str()
        .expect("JSON should have an 'openapi' string field");

    assert!(
        version.starts_with("3."),
        "openapi field should start with '3.', got: {}",
        version
    );
}
/// The JSON endpoint declares an `application/json` content type.
#[rstest]
#[tokio::test]
async fn test_openapi_json_response_content_type_header() {
    let router = OpenApiRouter::wrap(DummyHandler).unwrap();
    let req = Request::builder().uri("/api/openapi.json").build().unwrap();

    let resp = router.handle(req).await.unwrap();
    assert_eq!(resp.status, StatusCode::OK);

    // A missing or non-text header is treated as empty so the assertion below
    // reports it.
    let header_value = resp.headers.get("Content-Type");
    let content_type = header_value
        .and_then(|value| value.to_str().ok())
        .unwrap_or_default();

    assert!(
        content_type.contains("application/json"),
        "Content-Type should contain 'application/json', got: {}",
        content_type
    );
}
/// The Swagger page's HTML contains the `swagger-ui` marker.
#[rstest]
#[tokio::test]
async fn test_swagger_docs_response_body_contains_swagger_ui_marker() {
    let router = OpenApiRouter::wrap(DummyHandler).unwrap();
    let req = Request::builder().uri("/api/docs").build().unwrap();

    let resp = router.handle(req).await.unwrap();
    assert_eq!(resp.status, StatusCode::OK);

    let html = String::from_utf8(resp.body.to_vec()).unwrap();
    let has_marker = html.contains("swagger-ui");
    assert!(has_marker, "Swagger docs HTML should contain 'swagger-ui'");
}
/// The Redoc page's HTML contains `redoc`, compared case-insensitively.
#[rstest]
#[tokio::test]
async fn test_redoc_docs_response_body_contains_redoc_marker() {
    let router = OpenApiRouter::wrap(DummyHandler).unwrap();
    let req = Request::builder().uri("/api/redoc").build().unwrap();

    let resp = router.handle(req).await.unwrap();
    assert_eq!(resp.status, StatusCode::OK);

    let html = String::from_utf8(resp.body.to_vec()).unwrap();
    // Lower-case the whole page so the marker matches regardless of casing.
    let lowered = html.to_lowercase();
    assert!(
        lowered.contains("redoc"),
        "Redoc docs HTML should contain 'redoc' (case-insensitive)"
    );
}
/// The guard sees the real request: only a request carrying the expected
/// `X-Docs-Token` header value is served; a missing or wrong token yields 403.
#[rstest]
#[tokio::test]
async fn test_auth_guard_inspects_request_headers() {
    let router = OpenApiRouter::wrap(DummyHandler).unwrap().auth_guard(|request| {
        let token = request
            .headers
            .get("X-Docs-Token")
            .and_then(|value| value.to_str().ok());
        matches!(token, Some("valid-token"))
    });

    // No token at all.
    let anonymous = Request::builder().uri("/api/docs").build().unwrap();
    let anonymous_resp = router.handle(anonymous).await.unwrap();
    assert_eq!(anonymous_resp.status, StatusCode::FORBIDDEN);

    // Correct token.
    let authorized = Request::builder()
        .uri("/api/docs")
        .header("X-Docs-Token", "valid-token")
        .build()
        .unwrap();
    let authorized_resp = router.handle(authorized).await.unwrap();
    assert_eq!(authorized_resp.status, StatusCode::OK);

    // Token present but wrong.
    let impostor = Request::builder()
        .uri("/api/docs")
        .header("X-Docs-Token", "wrong-token")
        .build()
        .unwrap();
    let impostor_resp = router.handle(impostor).await.unwrap();
    assert_eq!(impostor_resp.status, StatusCode::FORBIDDEN);
}
}