reinhardt_middleware/
xframe.rs1use async_trait::async_trait;
6use hyper::header::HeaderName;
7use reinhardt_http::{Handler, Middleware, Request, Response, Result};
8use std::sync::Arc;
9
/// Framing policy written into the `X-Frame-Options` response header.
///
/// The wire token for each variant is produced by `XFrameOptions::as_str`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum XFrameOptions {
    /// Serialized as `DENY`: per RFC 7034 the page must not be rendered in any frame.
    Deny,
    /// Serialized as `SAMEORIGIN`: per RFC 7034 framing is allowed only by same-origin pages.
    SameOrigin,
}
18
19impl XFrameOptions {
20 pub fn as_str(&self) -> &'static str {
34 match self {
35 XFrameOptions::Deny => "DENY",
36 XFrameOptions::SameOrigin => "SAMEORIGIN",
37 }
38 }
39}
40
41pub struct XFrameOptionsMiddleware {
43 option: XFrameOptions,
44}
45
46impl XFrameOptionsMiddleware {
47 pub fn deny() -> Self {
87 Self {
88 option: XFrameOptions::Deny,
89 }
90 }
91 pub fn same_origin() -> Self {
131 Self {
132 option: XFrameOptions::SameOrigin,
133 }
134 }
135 pub fn new(option: XFrameOptions) -> Self {
177 Self { option }
178 }
179}
180
181impl Default for XFrameOptionsMiddleware {
182 fn default() -> Self {
183 Self::same_origin()
184 }
185}
186
187const X_FRAME_OPTIONS: HeaderName = HeaderName::from_static("x-frame-options");
188
189#[async_trait]
190impl Middleware for XFrameOptionsMiddleware {
191 async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
192 let mut response = match handler.handle(request).await {
195 Ok(resp) => resp,
196 Err(e) => Response::from(e),
197 };
198
199 if !response.headers.contains_key(&X_FRAME_OPTIONS) {
201 let header_value = match self.option {
202 XFrameOptions::Deny => hyper::header::HeaderValue::from_static("DENY"),
203 XFrameOptions::SameOrigin => hyper::header::HeaderValue::from_static("SAMEORIGIN"),
204 };
205 response.headers.insert(X_FRAME_OPTIONS, header_value);
206 }
207
208 Ok(response)
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use reinhardt_http::Error;
    use rstest::rstest;

    /// Builds an HTTP/1.1 request for `uri` with no headers and an empty body.
    fn make_request(method: Method, uri: &'static str) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Handler that answers 200 OK with the body `test`.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from(&b"test"[..])))
        }
    }

    #[tokio::test]
    async fn test_deny_option() {
        let middleware = XFrameOptionsMiddleware::deny();

        let response = middleware
            .process(make_request(Method::GET, "/test"), Arc::new(TestHandler))
            .await
            .unwrap();

        assert_eq!(response.headers.get(&X_FRAME_OPTIONS).unwrap(), "DENY");
    }

    #[tokio::test]
    async fn test_same_origin_option() {
        let middleware = XFrameOptionsMiddleware::same_origin();

        let response = middleware
            .process(make_request(Method::GET, "/test"), Arc::new(TestHandler))
            .await
            .unwrap();

        assert_eq!(
            response.headers.get(&X_FRAME_OPTIONS).unwrap(),
            "SAMEORIGIN"
        );
    }

    #[tokio::test]
    async fn test_default_is_same_origin() {
        let middleware = XFrameOptionsMiddleware::default();

        let response = middleware
            .process(make_request(Method::GET, "/test"), Arc::new(TestHandler))
            .await
            .unwrap();

        assert_eq!(
            response.headers.get(&X_FRAME_OPTIONS).unwrap(),
            "SAMEORIGIN"
        );
    }

    #[tokio::test]
    async fn test_does_not_override_existing_header() {
        /// Handler that sets `X-Frame-Options: DENY` itself.
        struct TestHandlerWithHeader;

        #[async_trait]
        impl Handler for TestHandlerWithHeader {
            async fn handle(&self, _request: Request) -> Result<Response> {
                let mut response =
                    Response::new(StatusCode::OK).with_body(Bytes::from(&b"test"[..]));
                response
                    .headers
                    .insert(X_FRAME_OPTIONS, "DENY".parse().unwrap());
                Ok(response)
            }
        }

        // Configured for SAMEORIGIN, but the handler's DENY must win.
        let middleware = XFrameOptionsMiddleware::same_origin();

        let response = middleware
            .process(
                make_request(Method::GET, "/test"),
                Arc::new(TestHandlerWithHeader),
            )
            .await
            .unwrap();

        assert_eq!(response.headers.get(&X_FRAME_OPTIONS).unwrap(), "DENY");
    }

    #[tokio::test]
    async fn test_new_constructor_with_deny() {
        let middleware = XFrameOptionsMiddleware::new(XFrameOptions::Deny);

        let response = middleware
            .process(make_request(Method::GET, "/secure"), Arc::new(TestHandler))
            .await
            .unwrap();

        assert_eq!(response.headers.get(&X_FRAME_OPTIONS).unwrap(), "DENY");
    }

    #[tokio::test]
    async fn test_new_constructor_with_same_origin() {
        let middleware = XFrameOptionsMiddleware::new(XFrameOptions::SameOrigin);

        let response = middleware
            .process(make_request(Method::GET, "/dashboard"), Arc::new(TestHandler))
            .await
            .unwrap();

        assert_eq!(
            response.headers.get(&X_FRAME_OPTIONS).unwrap(),
            "SAMEORIGIN"
        );
    }

    #[tokio::test]
    async fn test_response_body_preserved() {
        /// Handler that answers with a distinctive body.
        struct TestHandlerWithBody;

        #[async_trait]
        impl Handler for TestHandlerWithBody {
            async fn handle(&self, _request: Request) -> Result<Response> {
                Ok(Response::new(StatusCode::OK)
                    .with_body(Bytes::from(&b"custom response body"[..])))
            }
        }

        let middleware = XFrameOptionsMiddleware::deny();

        let response = middleware
            .process(
                make_request(Method::GET, "/content"),
                Arc::new(TestHandlerWithBody),
            )
            .await
            .unwrap();

        assert_eq!(response.headers.get(&X_FRAME_OPTIONS).unwrap(), "DENY");
        assert_eq!(response.body, Bytes::from(&b"custom response body"[..]));
    }

    #[tokio::test]
    async fn test_middleware_reusable_across_requests() {
        let middleware = XFrameOptionsMiddleware::deny();
        let handler = Arc::new(TestHandler);

        // The same middleware instance must apply the header on every request.
        for (method, uri) in [
            (Method::GET, "/page1"),
            (Method::POST, "/page2"),
            (Method::PUT, "/page3"),
        ] {
            let response = middleware
                .process(make_request(method, uri), handler.clone())
                .await
                .unwrap();
            assert_eq!(response.headers.get(&X_FRAME_OPTIONS).unwrap(), "DENY");
        }
    }

    /// Handler that always fails.
    struct ErrorHandler;

    #[async_trait]
    impl Handler for ErrorHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Err(Error::Http("handler error".to_string()))
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_xframe_header_applied_on_handler_error() {
        let middleware = XFrameOptionsMiddleware::new(XFrameOptions::Deny);
        let handler: Arc<dyn Handler> = Arc::new(ErrorHandler);

        let response = middleware
            .process(make_request(Method::GET, "/test"), handler)
            .await
            .unwrap();

        // The error is turned into an error response that still carries the header.
        assert!(response.status.is_client_error() || response.status.is_server_error());
        assert_eq!(response.headers.get(&X_FRAME_OPTIONS).unwrap(), "DENY");
    }
}