Skip to main content

reinhardt_middleware/
conditional.rs

1//! Conditional GET Middleware
2//!
3//! Handles ETags and Last-Modified headers for conditional GET requests.
4
5use async_trait::async_trait;
6use bytes::Bytes;
7use chrono::{DateTime, Utc};
8use hyper::header::{
9	ETAG, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH, IF_UNMODIFIED_SINCE, LAST_MODIFIED,
10};
11use hyper::{Method, StatusCode};
12use reinhardt_http::{Handler, Middleware, Request, Response, Result};
13use sha2::{Digest, Sha256};
14use std::sync::Arc;
15
/// Conditional GET middleware
///
/// Implements HTTP conditional requests for `GET` and `HEAD` using the `ETag` and
/// `Last-Modified` validators of the handler's response:
/// - `If-None-Match` (ETag-based) and `If-Modified-Since` (Last-Modified-based) can turn
///   a successful response into `304 Not Modified`
/// - `If-Match` and `If-Unmodified-Since` can turn it into `412 Precondition Failed`
///
/// The type is cheap to clone: it only carries its configuration flag.
#[derive(Debug, Clone)]
pub struct ConditionalGetMiddleware {
	/// Whether an `ETag` is derived from the response body when the handler set none
	generate_etag: bool,
}
26
27impl ConditionalGetMiddleware {
28	/// Create a new ConditionalGetMiddleware
29	///
30	/// By default, automatic ETag generation is enabled for responses.
31	///
32	/// # Examples
33	///
34	/// ```
35	/// use std::sync::Arc;
36	/// use reinhardt_middleware::ConditionalGetMiddleware;
37	/// use reinhardt_http::{Handler, Middleware, Request, Response};
38	/// use hyper::{StatusCode, Method, Version, HeaderMap};
39	/// use bytes::Bytes;
40	///
41	/// struct TestHandler;
42	///
43	/// #[async_trait::async_trait]
44	/// impl Handler for TestHandler {
45	///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
46	///         Ok(Response::new(StatusCode::OK).with_body(Bytes::from("content")))
47	///     }
48	/// }
49	///
50	/// # tokio_test::block_on(async {
51	/// let middleware = ConditionalGetMiddleware::new();
52	/// let handler = Arc::new(TestHandler);
53	///
54	/// let request = Request::builder()
55	///     .method(Method::GET)
56	///     .uri("/api/resource")
57	///     .version(Version::HTTP_11)
58	///     .headers(HeaderMap::new())
59	///     .body(Bytes::new())
60	///     .build()
61	///     .unwrap();
62	///
63	/// let response = middleware.process(request, handler).await.unwrap();
64	/// assert_eq!(response.status, StatusCode::OK);
65	/// assert!(response.headers.contains_key(hyper::header::ETAG));
66	/// # });
67	/// ```
68	pub fn new() -> Self {
69		Self {
70			generate_etag: true,
71		}
72	}
73	/// Create middleware without automatic ETag generation
74	///
75	/// Use this when you want to handle ETags manually or rely only on Last-Modified headers.
76	///
77	/// # Examples
78	///
79	/// ```
80	/// use std::sync::Arc;
81	/// use reinhardt_middleware::ConditionalGetMiddleware;
82	/// use reinhardt_http::{Handler, Middleware, Request, Response};
83	/// use hyper::{StatusCode, Method, Version, HeaderMap};
84	/// use bytes::Bytes;
85	///
86	/// struct TestHandler;
87	///
88	/// #[async_trait::async_trait]
89	/// impl Handler for TestHandler {
90	///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
91	///         let mut response = Response::new(StatusCode::OK).with_body(Bytes::from("content"));
92	///         response.headers.insert(
93	///             hyper::header::LAST_MODIFIED,
94	///             "Wed, 21 Oct 2015 07:28:00 GMT".parse().unwrap()
95	///         );
96	///         Ok(response)
97	///     }
98	/// }
99	///
100	/// # tokio_test::block_on(async {
101	/// let middleware = ConditionalGetMiddleware::without_etag();
102	/// let handler = Arc::new(TestHandler);
103	///
104	/// let request = Request::builder()
105	///     .method(Method::GET)
106	///     .uri("/api/resource")
107	///     .version(Version::HTTP_11)
108	///     .headers(HeaderMap::new())
109	///     .body(Bytes::new())
110	///     .build()
111	///     .unwrap();
112	///
113	/// let response = middleware.process(request, handler).await.unwrap();
114	/// assert_eq!(response.status, StatusCode::OK);
115	/// assert!(!response.headers.contains_key(hyper::header::ETAG));
116	/// assert!(response.headers.contains_key(hyper::header::LAST_MODIFIED));
117	/// # });
118	/// ```
119	pub fn without_etag() -> Self {
120		Self {
121			generate_etag: false,
122		}
123	}
124
125	/// Generate an ETag from response body
126	fn generate_etag_from_body(&self, body: &[u8]) -> String {
127		let mut hasher = Sha256::new();
128		hasher.update(body);
129		let result = hasher.finalize();
130		format!("\"{}\"", hex::encode(&result[..16]))
131	}
132
133	/// Parse If-None-Match header
134	fn parse_if_none_match(&self, value: &str) -> Vec<String> {
135		value.split(',').map(|s| s.trim().to_string()).collect()
136	}
137
138	/// Check if ETag matches
139	fn etag_matches(&self, etag: &str, if_none_match: &[String]) -> bool {
140		if_none_match
141			.iter()
142			.any(|inm| inm == "*" || inm == etag || inm.trim_matches('"') == etag.trim_matches('"'))
143	}
144
145	/// Parse HTTP date
146	fn parse_http_date(&self, value: &str) -> Option<DateTime<Utc>> {
147		httpdate::parse_http_date(value).ok().map(DateTime::from)
148	}
149}
150
151impl Default for ConditionalGetMiddleware {
152	fn default() -> Self {
153		Self::new()
154	}
155}
156
157#[async_trait]
158impl Middleware for ConditionalGetMiddleware {
159	async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
160		// Store request headers for later use
161		let if_none_match = request.headers.get(IF_NONE_MATCH).cloned();
162		let if_modified_since = request.headers.get(IF_MODIFIED_SINCE).cloned();
163		let if_match = request.headers.get(IF_MATCH).cloned();
164		let if_unmodified_since = request.headers.get(IF_UNMODIFIED_SINCE).cloned();
165		let method = request.method.clone();
166
167		// Call the handler first
168		let mut response = handler.handle(request).await?;
169
170		// Only process GET and HEAD requests
171		if method != Method::GET && method != Method::HEAD {
172			return Ok(response);
173		}
174
175		// Only process successful responses
176		if !response.status.is_success() {
177			return Ok(response);
178		}
179
180		// Generate ETag if not present and configured to do so
181		let etag = if self.generate_etag && !response.headers.contains_key(ETAG) {
182			let generated = self.generate_etag_from_body(&response.body);
183			if let Ok(etag_value) = generated.parse() {
184				response.headers.insert(ETAG, etag_value);
185				Some(generated)
186			} else {
187				// ETag value could not be parsed as a valid header value;
188				// treat as if no ETag was generated
189				None
190			}
191		} else {
192			response
193				.headers
194				.get(ETAG)
195				.and_then(|v| v.to_str().ok())
196				.map(|s| s.to_string())
197		};
198
199		// Get Last-Modified if present
200		let last_modified = response
201			.headers
202			.get(LAST_MODIFIED)
203			.and_then(|v| v.to_str().ok())
204			.and_then(|s| self.parse_http_date(s));
205
206		// Check If-None-Match (ETag)
207		if let Some(if_none_match) = if_none_match
208			&& let (Ok(inm_str), Some(etag_value)) = (if_none_match.to_str(), etag.as_ref())
209		{
210			let inm_list = self.parse_if_none_match(inm_str);
211			if self.etag_matches(etag_value, &inm_list) {
212				// Return 304 Not Modified
213				let mut not_modified = Response::new(StatusCode::NOT_MODIFIED);
214
215				// Copy relevant headers
216				if let Some(etag_header) = response.headers.get(ETAG) {
217					not_modified.headers.insert(ETAG, etag_header.clone());
218				}
219				if let Some(lm_header) = response.headers.get(LAST_MODIFIED) {
220					not_modified
221						.headers
222						.insert(LAST_MODIFIED, lm_header.clone());
223				}
224
225				return Ok(not_modified);
226			}
227		}
228
229		// Check If-Modified-Since (Last-Modified)
230		if let Some(if_modified_since) = if_modified_since
231			&& let (Ok(ims_str), Some(lm)) = (if_modified_since.to_str(), last_modified)
232			&& let Some(ims) = self.parse_http_date(ims_str)
233		{
234			// If resource hasn't been modified since the given date
235			if lm <= ims {
236				// Return 304 Not Modified
237				let mut not_modified = Response::new(StatusCode::NOT_MODIFIED);
238
239				// Copy relevant headers
240				if let Some(etag_header) = response.headers.get(ETAG) {
241					not_modified.headers.insert(ETAG, etag_header.clone());
242				}
243				if let Some(lm_header) = response.headers.get(LAST_MODIFIED) {
244					not_modified
245						.headers
246						.insert(LAST_MODIFIED, lm_header.clone());
247				}
248
249				return Ok(not_modified);
250			}
251		}
252
253		// Check If-Match (for safe methods, should match)
254		if let Some(if_match) = if_match
255			&& let (Ok(im_str), Some(etag_value)) = (if_match.to_str(), etag.as_ref())
256		{
257			let im_list = self.parse_if_none_match(im_str);
258			if !self.etag_matches(etag_value, &im_list) && !im_list.contains(&"*".to_string()) {
259				// Return 412 Precondition Failed
260				return Ok(Response::new(StatusCode::PRECONDITION_FAILED)
261					.with_body(Bytes::from(&b"Precondition Failed"[..])));
262			}
263		}
264
265		// Check If-Unmodified-Since
266		if let Some(if_unmodified_since) = if_unmodified_since
267			&& let (Ok(ius_str), Some(lm)) = (if_unmodified_since.to_str(), last_modified)
268			&& let Some(ius) = self.parse_http_date(ius_str)
269		{
270			// If resource has been modified since the given date
271			if lm > ius {
272				// Return 412 Precondition Failed
273				return Ok(Response::new(StatusCode::PRECONDITION_FAILED)
274					.with_body(Bytes::from(&b"Precondition Failed"[..])));
275			}
276		}
277
278		Ok(response)
279	}
280}
281
#[cfg(test)]
mod tests {
	use super::*;
	use hyper::{HeaderMap, Version};

	/// Handler stub answering `200 OK` with a fixed body and optional validators.
	struct TestHandler {
		body: &'static str,
		with_etag: Option<String>,
		with_last_modified: Option<DateTime<Utc>>,
	}

	#[async_trait]
	impl Handler for TestHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			let mut reply = Response::new(StatusCode::OK).with_body(self.body.as_bytes());

			if let Some(tag) = self.with_etag.as_deref() {
				reply.headers.insert(ETAG, tag.parse().unwrap());
			}

			if let Some(modified) = self.with_last_modified {
				let formatted = httpdate::fmt_http_date(modified.into());
				reply
					.headers
					.insert(LAST_MODIFIED, formatted.parse().unwrap());
			}

			Ok(reply)
		}
	}

	/// Build a handler returning "test response" with the given validators.
	fn handler_with(
		etag: Option<&str>,
		last_modified: Option<DateTime<Utc>>,
	) -> Arc<dyn Handler> {
		Arc::new(TestHandler {
			body: "test response",
			with_etag: etag.map(|tag| tag.to_string()),
			with_last_modified: last_modified,
		})
	}

	/// Build an HTTP/1.1 request for `/test` with an empty body.
	fn request_with(method: Method, headers: HeaderMap) -> Request {
		Request::builder()
			.method(method)
			.uri("/test")
			.version(Version::HTTP_11)
			.headers(headers)
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	#[tokio::test]
	async fn test_generates_etag() {
		let middleware = ConditionalGetMiddleware::new();
		let request = request_with(Method::GET, HeaderMap::new());

		let response = middleware
			.process(request, handler_with(None, None))
			.await
			.unwrap();

		assert!(response.headers.contains_key(ETAG));
	}

	#[tokio::test]
	async fn test_if_none_match_returns_304() {
		let middleware = ConditionalGetMiddleware::new();
		let etag = "\"abc123\"";

		let mut headers = HeaderMap::new();
		headers.insert(IF_NONE_MATCH, etag.parse().unwrap());
		let request = request_with(Method::GET, headers);

		let response = middleware
			.process(request, handler_with(Some(etag), None))
			.await
			.unwrap();

		assert_eq!(response.status, StatusCode::NOT_MODIFIED);
		assert_eq!(response.body.len(), 0);
	}

	#[tokio::test]
	async fn test_if_modified_since_returns_304() {
		let middleware = ConditionalGetMiddleware::new();
		let last_modified = Utc::now() - chrono::Duration::days(1);

		// Client copy is one hour newer than the resource
		let ims_str = httpdate::fmt_http_date((last_modified + chrono::Duration::hours(1)).into());
		let mut headers = HeaderMap::new();
		headers.insert(IF_MODIFIED_SINCE, ims_str.parse().unwrap());
		let request = request_with(Method::GET, headers);

		let response = middleware
			.process(request, handler_with(None, Some(last_modified)))
			.await
			.unwrap();

		assert_eq!(response.status, StatusCode::NOT_MODIFIED);
	}

	#[tokio::test]
	async fn test_if_match_fails_returns_412() {
		let middleware = ConditionalGetMiddleware::new();

		let mut headers = HeaderMap::new();
		headers.insert(IF_MATCH, "\"xyz789\"".parse().unwrap());
		let request = request_with(Method::GET, headers);

		let response = middleware
			.process(request, handler_with(Some("\"abc123\""), None))
			.await
			.unwrap();

		assert_eq!(response.status, StatusCode::PRECONDITION_FAILED);
	}

	#[tokio::test]
	async fn test_middleware_wont_overwrite_etag() {
		let middleware = ConditionalGetMiddleware::new();
		let custom_etag = "\"custom-etag\"";
		let request = request_with(Method::GET, HeaderMap::new());

		let response = middleware
			.process(request, handler_with(Some(custom_etag), None))
			.await
			.unwrap();

		assert_eq!(response.status, StatusCode::OK);
		assert_eq!(
			response.headers.get(ETAG).unwrap().to_str().unwrap(),
			custom_etag
		);
	}

	#[tokio::test]
	async fn test_if_none_match_and_different_etag() {
		let middleware = ConditionalGetMiddleware::new();

		let mut headers = HeaderMap::new();
		headers.insert(IF_NONE_MATCH, "\"different-etag\"".parse().unwrap());
		let request = request_with(Method::GET, headers);

		let response = middleware
			.process(request, handler_with(Some("\"abc123\""), None))
			.await
			.unwrap();

		assert_eq!(response.status, StatusCode::OK);
	}

	#[tokio::test]
	async fn test_if_modified_since_and_last_modified_in_the_future() {
		let middleware = ConditionalGetMiddleware::new();
		let last_modified = Utc::now();

		// Client copy is one hour older than the resource
		let ims_str = httpdate::fmt_http_date((last_modified - chrono::Duration::hours(1)).into());
		let mut headers = HeaderMap::new();
		headers.insert(IF_MODIFIED_SINCE, ims_str.parse().unwrap());
		let request = request_with(Method::GET, headers);

		let response = middleware
			.process(request, handler_with(None, Some(last_modified)))
			.await
			.unwrap();

		assert_eq!(response.status, StatusCode::OK);
	}

	#[tokio::test]
	async fn test_no_etag_on_post_request() {
		let middleware = ConditionalGetMiddleware::new();
		let request = request_with(Method::POST, HeaderMap::new());

		let response = middleware
			.process(request, handler_with(None, None))
			.await
			.unwrap();

		// ETag should not be generated for POST requests
		assert!(!response.headers.contains_key(ETAG));
	}

	#[tokio::test]
	async fn test_without_etag_generation() {
		let middleware = ConditionalGetMiddleware::without_etag();
		let request = request_with(Method::GET, HeaderMap::new());

		let response = middleware
			.process(request, handler_with(None, None))
			.await
			.unwrap();

		// ETag should not be generated when disabled
		assert!(!response.headers.contains_key(ETAG));
	}
}