reinhardt_openapi/
router_wrapper.rs1use async_trait::async_trait;
28use reinhardt_http::Handler;
29use reinhardt_http::{Request, Response, Result};
30use reinhardt_rest::openapi::endpoints::generate_openapi_schema;
31use reinhardt_rest::openapi::{RedocUI, SwaggerUI};
32use reinhardt_urls::prelude::Route;
33use reinhardt_urls::routers::Router;
34use std::sync::Arc;
35
36pub type AuthGuard = Arc<dyn Fn(&Request) -> bool + Send + Sync>;
43
44pub struct OpenApiRouter<H> {
58 inner: H,
60 openapi_json: Arc<String>,
62 swagger_html: Arc<String>,
64 redoc_html: Arc<String>,
66 enabled: bool,
69 auth_guard: Option<AuthGuard>,
72}
73
74impl<H> OpenApiRouter<H> {
75 pub fn wrap(handler: H) -> std::result::Result<Self, reinhardt_rest::openapi::SchemaError> {
91 let schema = generate_openapi_schema();
93 let openapi_json = serde_json::to_string_pretty(&schema)?;
94
95 let swagger_ui = SwaggerUI::new(schema.clone());
97 let swagger_html = swagger_ui.render_html()?;
98
99 let redoc_ui = RedocUI::new(schema);
101 let redoc_html = redoc_ui.render_html()?;
102
103 Ok(Self {
104 inner: handler,
105 openapi_json: Arc::new(openapi_json),
106 swagger_html: Arc::new(swagger_html),
107 redoc_html: Arc::new(redoc_html),
108 enabled: true,
109 auth_guard: None,
110 })
111 }
112
113 pub fn enabled(mut self, enabled: bool) -> Self {
131 self.enabled = enabled;
132 self
133 }
134
135 pub fn auth_guard(mut self, guard: impl Fn(&Request) -> bool + Send + Sync + 'static) -> Self {
160 self.auth_guard = Some(Arc::new(guard));
161 self
162 }
163
164 pub fn inner(&self) -> &H {
166 &self.inner
167 }
168
169 fn check_access(&self, request: &Request) -> Option<Response> {
175 if !self.enabled {
176 return Some(Response::not_found());
177 }
178 if let Some(ref guard) = self.auth_guard
179 && !guard(request)
180 {
181 return Some(Response::forbidden());
182 }
183 None
184 }
185
186 fn try_serve_openapi(&self, request: &Request) -> Option<Result<Response>> {
195 match request.uri.path() {
196 "/api/openapi.json" | "/api/docs" | "/api/redoc" => {
197 if let Some(denied) = self.check_access(request) {
198 return Some(Ok(denied));
199 }
200 let response = match request.uri.path() {
201 "/api/openapi.json" => {
202 let json = (*self.openapi_json).clone();
203 Response::ok()
204 .with_header("Content-Type", "application/json; charset=utf-8")
205 .with_body(json)
206 }
207 "/api/docs" => {
208 let html = (*self.swagger_html).clone();
209 Response::ok()
210 .with_header("Content-Type", "text/html; charset=utf-8")
211 .with_body(html)
212 }
213 "/api/redoc" => {
214 let html = (*self.redoc_html).clone();
215 Response::ok()
216 .with_header("Content-Type", "text/html; charset=utf-8")
217 .with_body(html)
218 }
219 _ => unreachable!(),
220 };
221 Some(Ok(Self::apply_security_headers(response)))
222 }
223 _ => None,
224 }
225 }
226
227 fn apply_security_headers(response: Response) -> Response {
234 response
235 .with_header(
236 "Content-Security-Policy",
237 "default-src 'none'; \
238 script-src 'unsafe-inline' https://unpkg.com https://cdn.redoc.ly; \
239 style-src 'unsafe-inline' https://unpkg.com; \
240 img-src 'self' data:; \
241 connect-src 'self'; \
242 font-src https://fonts.gstatic.com; \
243 frame-ancestors 'none'",
244 )
245 .with_header("X-Frame-Options", "DENY")
246 .with_header("X-Content-Type-Options", "nosniff")
247 .with_header("Cache-Control", "no-store")
248 }
249}
250
251#[async_trait]
252impl<H: Handler> Handler for OpenApiRouter<H> {
253 async fn handle(&self, request: Request) -> Result<Response> {
263 if let Some(response) = self.try_serve_openapi(&request) {
265 return response;
266 }
267 self.inner.handle(request).await
268 }
269}
270
271impl<H> Router for OpenApiRouter<H>
277where
278 H: Handler + Router,
279{
280 fn add_route(&mut self, _route: Route) {
287 panic!(
288 "Cannot add routes to OpenApiRouter after wrapping. \
289 Add routes to the base router before calling OpenApiRouter::wrap()."
290 );
291 }
292
293 fn mount(&mut self, _prefix: &str, _routes: Vec<Route>, _namespace: Option<String>) {
300 panic!(
301 "Cannot mount routes in OpenApiRouter after wrapping. \
302 Mount routes in the base router before calling OpenApiRouter::wrap()."
303 );
304 }
305
306 async fn route(&self, request: Request) -> Result<Response> {
317 if let Some(response) = self.try_serve_openapi(&request) {
319 return response;
320 }
321 self.inner.route(request).await
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::StatusCode;
    use rstest::rstest;

    /// Inner handler that answers every request with a fixed 200 body, so tests can tell
    /// delegated responses apart from documentation responses.
    struct DummyHandler;

    #[async_trait]
    impl Handler for DummyHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK).with_body("Hello from inner handler"))
        }
    }

    /// Reads a response header as a string, or `""` when it is absent or not visible ASCII.
    fn header_value<'a>(response: &'a Response, name: &str) -> &'a str {
        response
            .headers
            .get(name)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
    }

    #[rstest]
    #[tokio::test]
    async fn test_openapi_json_endpoint() {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler).unwrap();

        let request = Request::builder().uri("/api/openapi.json").build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        let body_str = String::from_utf8(response.body.to_vec()).unwrap();
        assert!(body_str.contains("openapi"));
        assert!(body_str.contains("3."));
    }

    #[rstest]
    #[tokio::test]
    async fn test_swagger_docs_endpoint() {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler).unwrap();

        let request = Request::builder().uri("/api/docs").build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        let body_str = String::from_utf8(response.body.to_vec()).unwrap();
        assert!(body_str.contains("swagger-ui"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_redoc_docs_endpoint() {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler).unwrap();

        let request = Request::builder().uri("/api/redoc").build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        let body_str = String::from_utf8(response.body.to_vec()).unwrap();
        assert!(body_str.contains("redoc"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_delegation_to_inner_handler() {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler).unwrap();

        let request = Request::builder().uri("/some/other/path").build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        let body_str = String::from_utf8(response.body.to_vec()).unwrap();
        assert_eq!(body_str, "Hello from inner handler");
    }

    #[rstest]
    #[case("/api/openapi.json")]
    #[case("/api/docs")]
    #[case("/api/redoc")]
    #[tokio::test]
    async fn test_disabled_endpoints_return_404(#[case] path: &str) {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler).unwrap().enabled(false);

        let request = Request::builder().uri(path).build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::NOT_FOUND);
    }

    #[rstest]
    #[tokio::test]
    async fn test_disabled_does_not_affect_other_routes() {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler).unwrap().enabled(false);

        let request = Request::builder().uri("/some/other/path").build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        let body_str = String::from_utf8(response.body.to_vec()).unwrap();
        assert_eq!(body_str, "Hello from inner handler");
    }

    #[rstest]
    #[case("/api/openapi.json")]
    #[case("/api/docs")]
    #[case("/api/redoc")]
    #[tokio::test]
    async fn test_auth_guard_rejects_unauthorized(#[case] path: &str) {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler)
            .unwrap()
            .auth_guard(|_request| false);

        let request = Request::builder().uri(path).build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::FORBIDDEN);
    }

    #[rstest]
    #[case("/api/openapi.json")]
    #[case("/api/docs")]
    #[case("/api/redoc")]
    #[tokio::test]
    async fn test_auth_guard_allows_authorized(#[case] path: &str) {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler)
            .unwrap()
            .auth_guard(|_request| true);

        let request = Request::builder().uri(path).build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
    }

    #[rstest]
    #[tokio::test]
    async fn test_auth_guard_does_not_affect_other_routes() {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler)
            .unwrap()
            .auth_guard(|_request| false);

        let request = Request::builder().uri("/some/other/path").build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        let body_str = String::from_utf8(response.body.to_vec()).unwrap();
        assert_eq!(body_str, "Hello from inner handler");
    }

    #[rstest]
    #[case("/api/openapi.json")]
    #[case("/api/docs")]
    #[case("/api/redoc")]
    #[tokio::test]
    async fn test_disabled_takes_precedence_over_auth_guard(#[case] path: &str) {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler)
            .unwrap()
            .enabled(false)
            .auth_guard(|_request| true);

        let request = Request::builder().uri(path).build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::NOT_FOUND);
    }

    #[rstest]
    #[tokio::test]
    async fn test_openapi_json_response_body_is_valid_openapi_json() {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler).unwrap();

        let request = Request::builder().uri("/api/openapi.json").build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        let body_bytes = response.body.to_vec();
        let json: serde_json::Value =
            serde_json::from_slice(&body_bytes).expect("Response body should be valid JSON");
        let openapi_version = json["openapi"]
            .as_str()
            .expect("JSON should have an 'openapi' string field");
        assert!(
            openapi_version.starts_with("3."),
            "openapi field should start with '3.', got: {}",
            openapi_version
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_openapi_json_response_content_type_header() {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler).unwrap();

        let request = Request::builder().uri("/api/openapi.json").build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        let content_type = response
            .headers
            .get("Content-Type")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("");
        assert!(
            content_type.contains("application/json"),
            "Content-Type should contain 'application/json', got: {}",
            content_type
        );
    }

    #[rstest]
    #[case("/api/docs")]
    #[case("/api/redoc")]
    #[tokio::test]
    async fn test_html_endpoints_content_type_header(#[case] path: &str) {
        let wrapped = OpenApiRouter::wrap(DummyHandler).unwrap();

        let request = Request::builder().uri(path).build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        let content_type = header_value(&response, "Content-Type");
        assert!(
            content_type.starts_with("text/html"),
            "Content-Type should start with 'text/html', got: {}",
            content_type
        );
    }

    #[rstest]
    #[case("/api/openapi.json")]
    #[case("/api/docs")]
    #[case("/api/redoc")]
    #[tokio::test]
    async fn test_documentation_responses_carry_security_headers(#[case] path: &str) {
        let wrapped = OpenApiRouter::wrap(DummyHandler).unwrap();

        let request = Request::builder().uri(path).build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(header_value(&response, "X-Frame-Options"), "DENY");
        assert_eq!(header_value(&response, "X-Content-Type-Options"), "nosniff");
        assert_eq!(header_value(&response, "Cache-Control"), "no-store");
        let csp = header_value(&response, "Content-Security-Policy");
        assert!(csp.contains("default-src 'none'"), "unexpected CSP: {}", csp);
        assert!(csp.contains("frame-ancestors 'none'"), "unexpected CSP: {}", csp);
    }

    #[rstest]
    #[tokio::test]
    async fn test_delegated_responses_are_not_modified() {
        let wrapped = OpenApiRouter::wrap(DummyHandler).unwrap();

        let request = Request::builder().uri("/some/other/path").build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(header_value(&response, "Content-Security-Policy"), "");
        assert_eq!(header_value(&response, "X-Frame-Options"), "");
    }

    #[rstest]
    #[tokio::test]
    async fn test_swagger_docs_response_body_contains_swagger_ui_marker() {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler).unwrap();

        let request = Request::builder().uri("/api/docs").build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        let body_str = String::from_utf8(response.body.to_vec()).unwrap();
        assert!(
            body_str.contains("swagger-ui"),
            "Swagger docs HTML should contain 'swagger-ui'"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_redoc_docs_response_body_contains_redoc_marker() {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler).unwrap();

        let request = Request::builder().uri("/api/redoc").build().unwrap();
        let response = wrapped.handle(request).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        let body_str = String::from_utf8(response.body.to_vec())
            .unwrap()
            .to_lowercase();
        assert!(
            body_str.contains("redoc"),
            "Redoc docs HTML should contain 'redoc' (case-insensitive)"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_auth_guard_inspects_request_headers() {
        let handler = DummyHandler;
        let wrapped = OpenApiRouter::wrap(handler).unwrap().auth_guard(|request| {
            request
                .headers
                .get("X-Docs-Token")
                .and_then(|v| v.to_str().ok())
                .map(|v| v == "valid-token")
                .unwrap_or(false)
        });

        let request_no_token = Request::builder().uri("/api/docs").build().unwrap();
        let response_no_token = wrapped.handle(request_no_token).await.unwrap();

        assert_eq!(response_no_token.status, StatusCode::FORBIDDEN);

        let request_valid = Request::builder()
            .uri("/api/docs")
            .header("X-Docs-Token", "valid-token")
            .build()
            .unwrap();
        let response_valid = wrapped.handle(request_valid).await.unwrap();

        assert_eq!(response_valid.status, StatusCode::OK);

        let request_invalid = Request::builder()
            .uri("/api/docs")
            .header("X-Docs-Token", "wrong-token")
            .build()
            .unwrap();
        let response_invalid = wrapped.handle(request_invalid).await.unwrap();

        assert_eq!(response_invalid.status, StatusCode::FORBIDDEN);
    }
}