Skip to main content

reinhardt_openapi/
router_wrapper.rs

//! Router wrapper that adds OpenAPI documentation endpoints
//!
//! This module provides a wrapper around any `Handler` implementation that
//! automatically serves OpenAPI documentation endpoints without modifying
//! user code.
//!
//! # Example
//!
//! ```rust,ignore
//! use reinhardt_openapi::OpenApiRouter;
//! use reinhardt_urls::routers::BasicRouter;
//!
//! fn main() -> Result<(), reinhardt_rest::openapi::SchemaError> {
//!     // Create your existing router
//!     let router = BasicRouter::new();
//!
//!     // Wrap with OpenAPI endpoints
//!     let wrapped = OpenApiRouter::wrap(router)?;
//!
//!     // The wrapped router now serves:
//!     // - /api/openapi.json (OpenAPI spec)
//!     // - /api/docs (Swagger UI)
//!     // - /api/redoc (Redoc UI)
//!     Ok(())
//! }
//! ```
26
27use async_trait::async_trait;
28use reinhardt_http::Handler;
29use reinhardt_http::{Request, Response, Result};
30use reinhardt_rest::openapi::endpoints::generate_openapi_schema;
31use reinhardt_rest::openapi::{RedocUI, SwaggerUI};
32use reinhardt_urls::prelude::Route;
33use reinhardt_urls::routers::Router;
34use std::sync::Arc;
35
36/// Type alias for the authentication guard callback.
37///
38/// The guard receives a reference to the incoming request and returns
39/// `true` if the request is authorized to access documentation endpoints,
40/// or `false` to deny access with HTTP 403 Forbidden.
41// Fixes #828
42pub type AuthGuard = Arc<dyn Fn(&Request) -> bool + Send + Sync>;
43
44/// Router wrapper that adds OpenAPI documentation endpoints
45///
46/// This wrapper intercepts requests to OpenAPI documentation paths and
47/// serves them from memory, delegating all other requests to the wrapped
48/// handler.
49///
50/// The OpenAPI schema is generated once at wrap time from the global
51/// schema registry, ensuring minimal runtime overhead.
52///
53/// Access control is supported via the `enabled` flag and an optional
54/// authentication guard callback. When `enabled` is `false`, all
55/// documentation endpoints return HTTP 404. When an auth guard is set
56/// and returns `false`, endpoints return HTTP 403.
57pub struct OpenApiRouter<H> {
58	/// Base handler to delegate to
59	inner: H,
60	/// Pre-generated OpenAPI JSON schema
61	openapi_json: Arc<String>,
62	/// Swagger UI HTML
63	swagger_html: Arc<String>,
64	/// Redoc UI HTML
65	redoc_html: Arc<String>,
66	/// Whether documentation endpoints are enabled (default: true)
67	// Fixes #828
68	enabled: bool,
69	/// Optional authentication guard for documentation endpoints
70	// Fixes #828
71	auth_guard: Option<AuthGuard>,
72}
73
74impl<H> OpenApiRouter<H> {
75	/// Wrap an existing handler with OpenAPI endpoints
76	///
77	/// This generates the OpenAPI schema from the global registry and
78	/// pre-renders the Swagger and Redoc UIs.
79	///
80	/// # Example
81	///
82	/// ```rust,ignore
83	/// use reinhardt_openapi::OpenApiRouter;
84	/// use reinhardt_urls::routers::BasicRouter;
85	///
86	/// let router = BasicRouter::new();
87	/// let wrapped = OpenApiRouter::wrap(router)?;
88	/// # Ok::<(), reinhardt_rest::openapi::SchemaError>(())
89	/// ```
90	pub fn wrap(handler: H) -> std::result::Result<Self, reinhardt_rest::openapi::SchemaError> {
91		// Generate OpenAPI schema from global registry
92		let schema = generate_openapi_schema();
93		let openapi_json = serde_json::to_string_pretty(&schema)?;
94
95		// Generate Swagger UI HTML
96		let swagger_ui = SwaggerUI::new(schema.clone());
97		let swagger_html = swagger_ui.render_html()?;
98
99		// Generate Redoc UI HTML
100		let redoc_ui = RedocUI::new(schema);
101		let redoc_html = redoc_ui.render_html()?;
102
103		Ok(Self {
104			inner: handler,
105			openapi_json: Arc::new(openapi_json),
106			swagger_html: Arc::new(swagger_html),
107			redoc_html: Arc::new(redoc_html),
108			enabled: true,
109			auth_guard: None,
110		})
111	}
112
113	/// Set whether documentation endpoints are enabled
114	///
115	/// When set to `false`, all documentation endpoints (`/api/openapi.json`,
116	/// `/api/docs`, `/api/redoc`) will return HTTP 404 Not Found.
117	///
118	/// Default is `true`.
119	///
120	/// # Example
121	///
122	/// ```rust,ignore
123	/// use reinhardt_openapi::OpenApiRouter;
124	/// use reinhardt_urls::routers::BasicRouter;
125	///
126	/// let router = BasicRouter::new();
127	/// let wrapped = OpenApiRouter::wrap(router)?.enabled(false);
128	/// ```
129	// Fixes #828
130	pub fn enabled(mut self, enabled: bool) -> Self {
131		self.enabled = enabled;
132		self
133	}
134
135	/// Set an authentication guard for documentation endpoints
136	///
137	/// The guard function receives a reference to the incoming request and
138	/// should return `true` to allow access or `false` to deny with HTTP 403
139	/// Forbidden.
140	///
141	/// The guard is only checked when `enabled` is `true`. When `enabled` is
142	/// `false`, endpoints return 404 regardless of the guard.
143	///
144	/// # Example
145	///
146	/// ```rust,ignore
147	/// use reinhardt_openapi::OpenApiRouter;
148	/// use reinhardt_urls::routers::BasicRouter;
149	///
150	/// let router = BasicRouter::new();
151	/// let wrapped = OpenApiRouter::wrap(router)?.auth_guard(|request| {
152	///     // Check for API key in header
153	///     request.headers().get("X-Api-Key")
154	///         .map(|v| v == "secret")
155	///         .unwrap_or(false)
156	/// });
157	/// ```
158	// Fixes #828
159	pub fn auth_guard(mut self, guard: impl Fn(&Request) -> bool + Send + Sync + 'static) -> Self {
160		self.auth_guard = Some(Arc::new(guard));
161		self
162	}
163
164	/// Get a reference to the wrapped handler
165	pub fn inner(&self) -> &H {
166		&self.inner
167	}
168
169	/// Check access control for documentation endpoints.
170	///
171	/// Returns `None` if access is allowed, or `Some(Response)` with the
172	/// appropriate error status if access is denied.
173	// Fixes #828
174	fn check_access(&self, request: &Request) -> Option<Response> {
175		if !self.enabled {
176			return Some(Response::not_found());
177		}
178		if let Some(ref guard) = self.auth_guard
179			&& !guard(request)
180		{
181			return Some(Response::forbidden());
182		}
183		None
184	}
185
186	/// Try to serve an OpenAPI documentation endpoint.
187	///
188	/// Returns `Some(Ok(Response))` if the request path matches an OpenAPI
189	/// endpoint and access control checks pass, `Some(Ok(denied))` if access
190	/// is denied, or `None` if the path does not match any documentation
191	/// endpoint.
192	///
193	/// Fixes #831: Deduplicate route handling between Handler and Router.
194	fn try_serve_openapi(&self, request: &Request) -> Option<Result<Response>> {
195		match request.uri.path() {
196			"/api/openapi.json" | "/api/docs" | "/api/redoc" => {
197				if let Some(denied) = self.check_access(request) {
198					return Some(Ok(denied));
199				}
200				let response = match request.uri.path() {
201					"/api/openapi.json" => {
202						let json = (*self.openapi_json).clone();
203						Response::ok()
204							.with_header("Content-Type", "application/json; charset=utf-8")
205							.with_body(json)
206					}
207					"/api/docs" => {
208						let html = (*self.swagger_html).clone();
209						Response::ok()
210							.with_header("Content-Type", "text/html; charset=utf-8")
211							.with_body(html)
212					}
213					"/api/redoc" => {
214						let html = (*self.redoc_html).clone();
215						Response::ok()
216							.with_header("Content-Type", "text/html; charset=utf-8")
217							.with_body(html)
218					}
219					_ => unreachable!(),
220				};
221				Some(Ok(Self::apply_security_headers(response)))
222			}
223			_ => None,
224		}
225	}
226
227	/// Apply security headers to documentation endpoint responses.
228	///
229	/// Adds Content-Security-Policy, X-Frame-Options, X-Content-Type-Options,
230	/// and Cache-Control headers to prevent clickjacking, MIME sniffing,
231	/// and stale cache attacks on documentation pages.
232	// Fixes #830
233	fn apply_security_headers(response: Response) -> Response {
234		response
235			.with_header(
236				"Content-Security-Policy",
237				"default-src 'none'; \
238				 script-src 'unsafe-inline' https://unpkg.com https://cdn.redoc.ly; \
239				 style-src 'unsafe-inline' https://unpkg.com; \
240				 img-src 'self' data:; \
241				 connect-src 'self'; \
242				 font-src https://fonts.gstatic.com; \
243				 frame-ancestors 'none'",
244			)
245			.with_header("X-Frame-Options", "DENY")
246			.with_header("X-Content-Type-Options", "nosniff")
247			.with_header("Cache-Control", "no-store")
248	}
249}
250
251#[async_trait]
252impl<H: Handler> Handler for OpenApiRouter<H> {
253	/// Handle requests, intercepting OpenAPI documentation paths
254	///
255	/// Requests to `/api/openapi.json`, `/api/docs`, or `/api/redoc`
256	/// are served from memory if access control checks pass. All other
257	/// requests are delegated to the wrapped handler.
258	///
259	/// Access control is enforced via the `enabled` flag and optional
260	/// auth guard. Disabled endpoints return 404, unauthorized requests
261	/// return 403.
262	async fn handle(&self, request: Request) -> Result<Response> {
263		// Fixes #831: Use shared OpenAPI serving logic
264		if let Some(response) = self.try_serve_openapi(&request) {
265			return response;
266		}
267		self.inner.handle(request).await
268	}
269}
270
271/// Router trait implementation for OpenApiRouter
272///
273/// This implementation allows OpenApiRouter to be used where Router trait
274/// is required. However, routes cannot be modified after wrapping - use
275/// `add_route()` and `include()` on the base router before wrapping.
276impl<H> Router for OpenApiRouter<H>
277where
278	H: Handler + Router,
279{
280	/// Add a route to the router
281	///
282	/// # Panics
283	///
284	/// This method always panics. Routes must be added to the base router
285	/// before wrapping with `OpenApiRouter::wrap()`.
286	fn add_route(&mut self, _route: Route) {
287		panic!(
288			"Cannot add routes to OpenApiRouter after wrapping. \
289             Add routes to the base router before calling OpenApiRouter::wrap()."
290		);
291	}
292
293	/// Include routes with a prefix
294	///
295	/// # Panics
296	///
297	/// This method always panics. Routes must be mounted in the base router
298	/// before wrapping with `OpenApiRouter::wrap()`.
299	fn mount(&mut self, _prefix: &str, _routes: Vec<Route>, _namespace: Option<String>) {
300		panic!(
301			"Cannot mount routes in OpenApiRouter after wrapping. \
302             Mount routes in the base router before calling OpenApiRouter::wrap()."
303		);
304	}
305
306	/// Route a request through the OpenAPI wrapper
307	///
308	/// OpenAPI documentation endpoints (`/api/openapi.json`, `/api/docs`,
309	/// `/api/redoc`) are handled directly if access control checks pass.
310	/// All other requests are delegated to the wrapped router's `route()`
311	/// method.
312	///
313	/// Access control is enforced via the `enabled` flag and optional
314	/// auth guard. Disabled endpoints return 404, unauthorized requests
315	/// return 403.
316	async fn route(&self, request: Request) -> Result<Response> {
317		// Fixes #831: Use shared OpenAPI serving logic
318		if let Some(response) = self.try_serve_openapi(&request) {
319			return response;
320		}
321		self.inner.route(request).await
322	}
323}
324
#[cfg(test)]
mod tests {
	use super::*;
	use hyper::StatusCode;
	use rstest::rstest;
	use std::sync::atomic::{AtomicUsize, Ordering};

	/// Inner handler that answers every request with a fixed 200 body, so
	/// tests can tell delegated responses apart from documentation responses.
	struct DummyHandler;

	#[async_trait]
	impl Handler for DummyHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::new(StatusCode::OK).with_body("Hello from inner handler"))
		}
	}

	/// Read a response header as a string, or `""` if it is absent or not
	/// representable as a `&str`.
	fn header_value(response: &Response, name: &str) -> String {
		response
			.headers
			.get(name)
			.and_then(|v| v.to_str().ok())
			.unwrap_or("")
			.to_string()
	}

	#[rstest]
	#[tokio::test]
	async fn test_openapi_json_endpoint() {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap();

		// Act
		let request = Request::builder().uri("/api/openapi.json").build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert
		assert_eq!(response.status, StatusCode::OK);
		let body_str = String::from_utf8(response.body.to_vec()).unwrap();
		assert!(body_str.contains("openapi"));
		assert!(body_str.contains("3.")); // OpenAPI version (3.0 or 3.1)
	}

	#[rstest]
	#[tokio::test]
	async fn test_swagger_docs_endpoint() {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap();

		// Act
		let request = Request::builder().uri("/api/docs").build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert
		assert_eq!(response.status, StatusCode::OK);
		let body_str = String::from_utf8(response.body.to_vec()).unwrap();
		assert!(body_str.contains("swagger-ui"));
	}

	#[rstest]
	#[tokio::test]
	async fn test_redoc_docs_endpoint() {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap();

		// Act
		let request = Request::builder().uri("/api/redoc").build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert
		assert_eq!(response.status, StatusCode::OK);
		let body_str = String::from_utf8(response.body.to_vec()).unwrap();
		assert!(body_str.contains("redoc"));
	}

	#[rstest]
	#[tokio::test]
	async fn test_delegation_to_inner_handler() {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap();

		// Act
		let request = Request::builder().uri("/some/other/path").build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert
		assert_eq!(response.status, StatusCode::OK);
		let body_str = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body_str, "Hello from inner handler");
	}

	#[rstest]
	#[case("/api/docs/")]
	#[case("/api/openapi.json/extra")]
	#[case("/API/DOCS")]
	#[case("/api")]
	#[tokio::test]
	async fn test_near_miss_paths_delegate_to_inner_handler(#[case] path: &str) {
		// Arrange: matching is exact and case-sensitive
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap();

		// Act
		let request = Request::builder().uri(path).build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert
		assert_eq!(response.status, StatusCode::OK);
		let body_str = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body_str, "Hello from inner handler");
	}

	#[rstest]
	#[tokio::test]
	async fn test_query_string_does_not_prevent_match() {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap();

		// Act
		let request = Request::builder()
			.uri("/api/openapi.json?download=1")
			.build()
			.unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert: served by the wrapper, not the inner handler
		assert_eq!(response.status, StatusCode::OK);
		assert!(header_value(&response, "Content-Type").contains("application/json"));
	}

	// Fixes #828: Access control tests

	#[rstest]
	#[case("/api/openapi.json")]
	#[case("/api/docs")]
	#[case("/api/redoc")]
	#[tokio::test]
	async fn test_disabled_endpoints_return_404(#[case] path: &str) {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap().enabled(false);

		// Act
		let request = Request::builder().uri(path).build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert
		assert_eq!(response.status, StatusCode::NOT_FOUND);
	}

	#[rstest]
	#[tokio::test]
	async fn test_disabled_does_not_affect_other_routes() {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap().enabled(false);

		// Act
		let request = Request::builder().uri("/some/other/path").build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert
		assert_eq!(response.status, StatusCode::OK);
		let body_str = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body_str, "Hello from inner handler");
	}

	#[rstest]
	#[case("/api/openapi.json")]
	#[case("/api/docs")]
	#[case("/api/redoc")]
	#[tokio::test]
	async fn test_auth_guard_rejects_unauthorized(#[case] path: &str) {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler)
			.unwrap()
			.auth_guard(|_request| false);

		// Act
		let request = Request::builder().uri(path).build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert
		assert_eq!(response.status, StatusCode::FORBIDDEN);
	}

	#[rstest]
	#[case("/api/openapi.json")]
	#[case("/api/docs")]
	#[case("/api/redoc")]
	#[tokio::test]
	async fn test_auth_guard_allows_authorized(#[case] path: &str) {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler)
			.unwrap()
			.auth_guard(|_request| true);

		// Act
		let request = Request::builder().uri(path).build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert
		assert_eq!(response.status, StatusCode::OK);
	}

	#[rstest]
	#[tokio::test]
	async fn test_auth_guard_does_not_affect_other_routes() {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler)
			.unwrap()
			.auth_guard(|_request| false);

		// Act
		let request = Request::builder().uri("/some/other/path").build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert
		assert_eq!(response.status, StatusCode::OK);
		let body_str = String::from_utf8(response.body.to_vec()).unwrap();
		assert_eq!(body_str, "Hello from inner handler");
	}

	#[rstest]
	#[tokio::test]
	async fn test_auth_guard_not_invoked_for_other_routes() {
		// Arrange: count guard invocations
		let calls = Arc::new(AtomicUsize::new(0));
		let counter = Arc::clone(&calls);
		let wrapped = OpenApiRouter::wrap(DummyHandler)
			.unwrap()
			.auth_guard(move |_request| {
				counter.fetch_add(1, Ordering::SeqCst);
				true
			});

		// Act
		let request = Request::builder().uri("/some/other/path").build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert: delegated without consulting the guard
		assert_eq!(response.status, StatusCode::OK);
		assert_eq!(calls.load(Ordering::SeqCst), 0);
	}

	#[rstest]
	#[tokio::test]
	async fn test_auth_guard_not_invoked_when_disabled() {
		// Arrange: count guard invocations
		let calls = Arc::new(AtomicUsize::new(0));
		let counter = Arc::clone(&calls);
		let wrapped = OpenApiRouter::wrap(DummyHandler)
			.unwrap()
			.enabled(false)
			.auth_guard(move |_request| {
				counter.fetch_add(1, Ordering::SeqCst);
				true
			});

		// Act
		let request = Request::builder().uri("/api/docs").build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert: disabled short-circuits before the guard runs
		assert_eq!(response.status, StatusCode::NOT_FOUND);
		assert_eq!(calls.load(Ordering::SeqCst), 0);
	}

	#[rstest]
	#[case("/api/openapi.json")]
	#[case("/api/docs")]
	#[case("/api/redoc")]
	#[tokio::test]
	async fn test_disabled_takes_precedence_over_auth_guard(#[case] path: &str) {
		// Arrange: enabled=false should return 404 even with a passing auth guard
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler)
			.unwrap()
			.enabled(false)
			.auth_guard(|_request| true);

		// Act
		let request = Request::builder().uri(path).build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert: Should be 404 (disabled), not 200 (auth passed)
		assert_eq!(response.status, StatusCode::NOT_FOUND);
	}

	// Fixes #830: security headers on documentation responses

	#[rstest]
	#[case("/api/openapi.json")]
	#[case("/api/docs")]
	#[case("/api/redoc")]
	#[tokio::test]
	async fn test_docs_endpoints_set_security_headers(#[case] path: &str) {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap();

		// Act
		let request = Request::builder().uri(path).build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert
		assert_eq!(response.status, StatusCode::OK);
		assert_eq!(header_value(&response, "X-Frame-Options"), "DENY");
		assert_eq!(header_value(&response, "X-Content-Type-Options"), "nosniff");
		assert_eq!(header_value(&response, "Cache-Control"), "no-store");
		let csp = header_value(&response, "Content-Security-Policy");
		assert!(csp.starts_with("default-src 'none';"), "unexpected CSP: {}", csp);
		assert!(csp.contains("frame-ancestors 'none'"), "unexpected CSP: {}", csp);
	}

	#[rstest]
	#[tokio::test]
	async fn test_openapi_json_response_body_is_valid_openapi_json() {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap();

		// Act
		let request = Request::builder().uri("/api/openapi.json").build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert: body is valid JSON with an openapi version field starting with "3."
		assert_eq!(response.status, StatusCode::OK);
		let body_bytes = response.body.to_vec();
		let json: serde_json::Value =
			serde_json::from_slice(&body_bytes).expect("Response body should be valid JSON");
		let openapi_version = json["openapi"]
			.as_str()
			.expect("JSON should have an 'openapi' string field");
		assert!(
			openapi_version.starts_with("3."),
			"openapi field should start with '3.', got: {}",
			openapi_version
		);
	}

	#[rstest]
	#[tokio::test]
	async fn test_openapi_json_response_content_type_header() {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap();

		// Act
		let request = Request::builder().uri("/api/openapi.json").build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert: Content-Type header contains application/json
		assert_eq!(response.status, StatusCode::OK);
		let content_type = response
			.headers
			.get("Content-Type")
			.and_then(|v| v.to_str().ok())
			.unwrap_or("");
		assert!(
			content_type.contains("application/json"),
			"Content-Type should contain 'application/json', got: {}",
			content_type
		);
	}

	#[rstest]
	#[tokio::test]
	async fn test_swagger_docs_response_body_contains_swagger_ui_marker() {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap();

		// Act
		let request = Request::builder().uri("/api/docs").build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert: HTML body contains the swagger-ui marker
		assert_eq!(response.status, StatusCode::OK);
		let body_str = String::from_utf8(response.body.to_vec()).unwrap();
		assert!(
			body_str.contains("swagger-ui"),
			"Swagger docs HTML should contain 'swagger-ui'"
		);
	}

	#[rstest]
	#[tokio::test]
	async fn test_redoc_docs_response_body_contains_redoc_marker() {
		// Arrange
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap();

		// Act
		let request = Request::builder().uri("/api/redoc").build().unwrap();
		let response = wrapped.handle(request).await.unwrap();

		// Assert: HTML body contains the redoc marker (case-insensitive)
		assert_eq!(response.status, StatusCode::OK);
		let body_str = String::from_utf8(response.body.to_vec())
			.unwrap()
			.to_lowercase();
		assert!(
			body_str.contains("redoc"),
			"Redoc docs HTML should contain 'redoc' (case-insensitive)"
		);
	}

	#[rstest]
	#[tokio::test]
	async fn test_auth_guard_inspects_request_headers() {
		// Arrange: Guard checks for a specific header value
		let handler = DummyHandler;
		let wrapped = OpenApiRouter::wrap(handler).unwrap().auth_guard(|request| {
			request
				.headers
				.get("X-Docs-Token")
				.and_then(|v| v.to_str().ok())
				.map(|v| v == "valid-token")
				.unwrap_or(false)
		});

		// Act: Request without token
		let request_no_token = Request::builder().uri("/api/docs").build().unwrap();
		let response_no_token = wrapped.handle(request_no_token).await.unwrap();

		// Assert: Should be forbidden
		assert_eq!(response_no_token.status, StatusCode::FORBIDDEN);

		// Act: Request with valid token
		let request_valid = Request::builder()
			.uri("/api/docs")
			.header("X-Docs-Token", "valid-token")
			.build()
			.unwrap();
		let response_valid = wrapped.handle(request_valid).await.unwrap();

		// Assert: Should be OK
		assert_eq!(response_valid.status, StatusCode::OK);

		// Act: Request with invalid token
		let request_invalid = Request::builder()
			.uri("/api/docs")
			.header("X-Docs-Token", "wrong-token")
			.build()
			.unwrap();
		let response_invalid = wrapped.handle(request_invalid).await.unwrap();

		// Assert: Should be forbidden
		assert_eq!(response_invalid.status, StatusCode::FORBIDDEN);
	}
}