Skip to main content

reinhardt_middleware/
xframe.rs

//! X-Frame-Options middleware.
//!
//! Provides clickjacking protection by adding an `X-Frame-Options` header to
//! responses that do not already carry one.
4
5use async_trait::async_trait;
6use hyper::header::HeaderName;
7use reinhardt_http::{Handler, Middleware, Request, Response, Result};
8use std::sync::Arc;
9
/// Directives accepted by the `X-Frame-Options` response header.
///
/// Choose [`XFrameOptions::Deny`] to forbid framing altogether, or
/// [`XFrameOptions::SameOrigin`] to allow framing only by same-origin pages.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum XFrameOptions {
	/// `DENY`: the page may not be rendered inside any frame.
	Deny,
	/// `SAMEORIGIN`: the page may be rendered inside a frame only when the
	/// embedding page has the same origin.
	SameOrigin,
}
18
19impl XFrameOptions {
20	/// Convert to header value string
21	///
22	/// # Examples
23	///
24	/// ```
25	/// use reinhardt_middleware::XFrameOptions;
26	///
27	/// let deny = XFrameOptions::Deny;
28	/// assert_eq!(deny.as_str(), "DENY");
29	///
30	/// let same_origin = XFrameOptions::SameOrigin;
31	/// assert_eq!(same_origin.as_str(), "SAMEORIGIN");
32	/// ```
33	pub fn as_str(&self) -> &'static str {
34		match self {
35			XFrameOptions::Deny => "DENY",
36			XFrameOptions::SameOrigin => "SAMEORIGIN",
37		}
38	}
39}
40
41/// X-Frame-Options middleware for clickjacking protection
42pub struct XFrameOptionsMiddleware {
43	option: XFrameOptions,
44}
45
46impl XFrameOptionsMiddleware {
47	/// Create middleware with DENY option
48	///
49	/// Prevents the page from being displayed in any frame, providing maximum clickjacking protection.
50	///
51	/// # Examples
52	///
53	/// ```
54	/// use std::sync::Arc;
55	/// use reinhardt_middleware::XFrameOptionsMiddleware;
56	/// use reinhardt_http::{Handler, Middleware, Request, Response};
57	/// use hyper::{StatusCode, Method, Version, HeaderMap};
58	/// use bytes::Bytes;
59	///
60	/// struct TestHandler;
61	///
62	/// #[async_trait::async_trait]
63	/// impl Handler for TestHandler {
64	///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
65	///         Ok(Response::new(StatusCode::OK))
66	///     }
67	/// }
68	///
69	/// # tokio_test::block_on(async {
70	/// let middleware = XFrameOptionsMiddleware::deny();
71	/// let handler = Arc::new(TestHandler);
72	///
73	/// let request = Request::builder()
74	///     .method(Method::GET)
75	///     .uri("/secure-page")
76	///     .version(Version::HTTP_11)
77	///     .headers(HeaderMap::new())
78	///     .body(Bytes::new())
79	///     .build()
80	///     .unwrap();
81	///
82	/// let response = middleware.process(request, handler).await.unwrap();
83	/// assert_eq!(response.headers.get("X-Frame-Options").unwrap(), "DENY");
84	/// # });
85	/// ```
86	pub fn deny() -> Self {
87		Self {
88			option: XFrameOptions::Deny,
89		}
90	}
91	/// Create middleware with SAMEORIGIN option
92	///
93	/// Allows the page to be framed only by pages from the same origin.
94	///
95	/// # Examples
96	///
97	/// ```
98	/// use std::sync::Arc;
99	/// use reinhardt_middleware::XFrameOptionsMiddleware;
100	/// use reinhardt_http::{Handler, Middleware, Request, Response};
101	/// use hyper::{StatusCode, Method, Version, HeaderMap};
102	/// use bytes::Bytes;
103	///
104	/// struct TestHandler;
105	///
106	/// #[async_trait::async_trait]
107	/// impl Handler for TestHandler {
108	///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
109	///         Ok(Response::new(StatusCode::OK))
110	///     }
111	/// }
112	///
113	/// # tokio_test::block_on(async {
114	/// let middleware = XFrameOptionsMiddleware::same_origin();
115	/// let handler = Arc::new(TestHandler);
116	///
117	/// let request = Request::builder()
118	///     .method(Method::GET)
119	///     .uri("/dashboard")
120	///     .version(Version::HTTP_11)
121	///     .headers(HeaderMap::new())
122	///     .body(Bytes::new())
123	///     .build()
124	///     .unwrap();
125	///
126	/// let response = middleware.process(request, handler).await.unwrap();
127	/// assert_eq!(response.headers.get("X-Frame-Options").unwrap(), "SAMEORIGIN");
128	/// # });
129	/// ```
130	pub fn same_origin() -> Self {
131		Self {
132			option: XFrameOptions::SameOrigin,
133		}
134	}
135	/// Create middleware with custom option
136	///
137	/// # Arguments
138	///
139	/// * `option` - The X-Frame-Options value to use
140	///
141	/// # Examples
142	///
143	/// ```
144	/// use std::sync::Arc;
145	/// use reinhardt_middleware::{XFrameOptionsMiddleware, XFrameOptions};
146	/// use reinhardt_http::{Handler, Middleware, Request, Response};
147	/// use hyper::{StatusCode, Method, Version, HeaderMap};
148	/// use bytes::Bytes;
149	///
150	/// struct TestHandler;
151	///
152	/// #[async_trait::async_trait]
153	/// impl Handler for TestHandler {
154	///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
155	///         Ok(Response::new(StatusCode::OK))
156	///     }
157	/// }
158	///
159	/// # tokio_test::block_on(async {
160	/// let middleware = XFrameOptionsMiddleware::new(XFrameOptions::Deny);
161	/// let handler = Arc::new(TestHandler);
162	///
163	/// let request = Request::builder()
164	///     .method(Method::GET)
165	///     .uri("/admin")
166	///     .version(Version::HTTP_11)
167	///     .headers(HeaderMap::new())
168	///     .body(Bytes::new())
169	///     .build()
170	///     .unwrap();
171	///
172	/// let response = middleware.process(request, handler).await.unwrap();
173	/// assert_eq!(response.headers.get("X-Frame-Options").unwrap(), "DENY");
174	/// # });
175	/// ```
176	pub fn new(option: XFrameOptions) -> Self {
177		Self { option }
178	}
179}
180
181impl Default for XFrameOptionsMiddleware {
182	fn default() -> Self {
183		Self::same_origin()
184	}
185}
186
187const X_FRAME_OPTIONS: HeaderName = HeaderName::from_static("x-frame-options");
188
189#[async_trait]
190impl Middleware for XFrameOptionsMiddleware {
191	async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
192		// Convert errors to responses so post-processing (e.g., security headers)
193		// always runs, even when invoked outside MiddlewareChain. (#3244)
194		let mut response = match handler.handle(request).await {
195			Ok(resp) => resp,
196			Err(e) => Response::from(e),
197		};
198
199		// Only add header if not already present
200		if !response.headers.contains_key(&X_FRAME_OPTIONS) {
201			let header_value = match self.option {
202				XFrameOptions::Deny => hyper::header::HeaderValue::from_static("DENY"),
203				XFrameOptions::SameOrigin => hyper::header::HeaderValue::from_static("SAMEORIGIN"),
204			};
205			response.headers.insert(X_FRAME_OPTIONS, header_value);
206		}
207
208		Ok(response)
209	}
210}
211
#[cfg(test)]
mod tests {
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, StatusCode, Version};
	use reinhardt_http::Error;
	use rstest::rstest;

	/// Builds an HTTP/1.1 request for `method` and `uri` with no headers and an empty body.
	fn build_request(method: Method, uri: &'static str) -> Request {
		Request::builder()
			.method(method)
			.uri(uri)
			.version(Version::HTTP_11)
			.headers(HeaderMap::new())
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// Returns the `X-Frame-Options` header of `response`, panicking if it is absent.
	fn frame_option(response: &Response) -> &hyper::header::HeaderValue {
		response.headers.get(&X_FRAME_OPTIONS).unwrap()
	}

	/// Handler that succeeds with `200 OK` and a small fixed body.
	struct TestHandler;

	#[async_trait]
	impl Handler for TestHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::new(StatusCode::OK).with_body(Bytes::from_static(b"test")))
		}
	}

	#[tokio::test]
	async fn test_deny_option() {
		let middleware = XFrameOptionsMiddleware::deny();
		let response = middleware
			.process(build_request(Method::GET, "/test"), Arc::new(TestHandler))
			.await
			.unwrap();

		assert_eq!(frame_option(&response), "DENY");
	}

	#[tokio::test]
	async fn test_same_origin_option() {
		let middleware = XFrameOptionsMiddleware::same_origin();
		let response = middleware
			.process(build_request(Method::GET, "/test"), Arc::new(TestHandler))
			.await
			.unwrap();

		assert_eq!(frame_option(&response), "SAMEORIGIN");
	}

	#[tokio::test]
	async fn test_default_is_same_origin() {
		let middleware = XFrameOptionsMiddleware::default();
		let response = middleware
			.process(build_request(Method::GET, "/test"), Arc::new(TestHandler))
			.await
			.unwrap();

		assert_eq!(frame_option(&response), "SAMEORIGIN");
	}

	#[tokio::test]
	async fn test_does_not_override_existing_header() {
		/// Handler that sets its own `X-Frame-Options: DENY`.
		struct TestHandlerWithHeader;

		#[async_trait]
		impl Handler for TestHandlerWithHeader {
			async fn handle(&self, _request: Request) -> Result<Response> {
				let mut response =
					Response::new(StatusCode::OK).with_body(Bytes::from_static(b"test"));
				response
					.headers
					.insert(X_FRAME_OPTIONS, "DENY".parse().unwrap());
				Ok(response)
			}
		}

		let middleware = XFrameOptionsMiddleware::same_origin();
		let response = middleware
			.process(
				build_request(Method::GET, "/test"),
				Arc::new(TestHandlerWithHeader),
			)
			.await
			.unwrap();

		// The handler's DENY must survive even though the middleware is SAMEORIGIN.
		assert_eq!(frame_option(&response), "DENY");
	}

	#[tokio::test]
	async fn test_new_constructor_with_deny() {
		let middleware = XFrameOptionsMiddleware::new(XFrameOptions::Deny);
		let response = middleware
			.process(build_request(Method::GET, "/secure"), Arc::new(TestHandler))
			.await
			.unwrap();

		assert_eq!(frame_option(&response), "DENY");
	}

	#[tokio::test]
	async fn test_new_constructor_with_same_origin() {
		let middleware = XFrameOptionsMiddleware::new(XFrameOptions::SameOrigin);
		let response = middleware
			.process(build_request(Method::GET, "/dashboard"), Arc::new(TestHandler))
			.await
			.unwrap();

		assert_eq!(frame_option(&response), "SAMEORIGIN");
	}

	#[tokio::test]
	async fn test_response_body_preserved() {
		/// Handler whose body must pass through the middleware unchanged.
		struct TestHandlerWithBody;

		#[async_trait]
		impl Handler for TestHandlerWithBody {
			async fn handle(&self, _request: Request) -> Result<Response> {
				Ok(Response::new(StatusCode::OK)
					.with_body(Bytes::from_static(b"custom response body")))
			}
		}

		let middleware = XFrameOptionsMiddleware::deny();
		let response = middleware
			.process(
				build_request(Method::GET, "/content"),
				Arc::new(TestHandlerWithBody),
			)
			.await
			.unwrap();

		// The header is added...
		assert_eq!(frame_option(&response), "DENY");
		// ...and the body is left untouched.
		assert_eq!(response.body, Bytes::from_static(b"custom response body"));
	}

	#[tokio::test]
	async fn test_middleware_reusable_across_requests() {
		let middleware = XFrameOptionsMiddleware::deny();
		let handler = Arc::new(TestHandler);

		// The same middleware instance must keep working across several requests.
		for (method, uri) in [
			(Method::GET, "/page1"),
			(Method::POST, "/page2"),
			(Method::PUT, "/page3"),
		] {
			let response = middleware
				.process(build_request(method, uri), handler.clone())
				.await
				.unwrap();
			assert_eq!(frame_option(&response), "DENY");
		}
	}

	/// Handler that always fails, simulating an inner handler error.
	struct ErrorHandler;

	#[async_trait]
	impl Handler for ErrorHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Err(Error::Http("handler error".to_string()))
		}
	}

	#[rstest]
	#[tokio::test]
	async fn test_xframe_header_applied_on_handler_error() {
		// Arrange
		let middleware = XFrameOptionsMiddleware::new(XFrameOptions::Deny);
		let handler: Arc<dyn Handler> = Arc::new(ErrorHandler);
		let request = build_request(Method::GET, "/test");

		// Act
		let response = middleware.process(request, handler).await.unwrap();

		// Assert: the error became a 4xx/5xx response that still carries the header.
		assert!(response.status.is_client_error() || response.status.is_server_error());
		assert_eq!(frame_option(&response), "DENY");
	}
}