1use async_trait::async_trait;
6use bytes::Bytes;
7use chrono::{DateTime, Utc};
8use hyper::header::{
9 ETAG, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH, IF_UNMODIFIED_SINCE, LAST_MODIFIED,
10};
11use hyper::{Method, StatusCode};
12use reinhardt_http::{Handler, Middleware, Request, Response, Result};
13use sha2::{Digest, Sha256};
14use std::sync::Arc;
15
/// Middleware that applies HTTP conditional-request handling (`If-None-Match`,
/// `If-Modified-Since`, `If-Match`, `If-Unmodified-Since`) to successful `GET` and
/// `HEAD` responses, optionally synthesizing an `ETag` from the response body.
#[derive(Debug, Clone)]
pub struct ConditionalGetMiddleware {
    /// When `true`, responses that lack an `ETag` get one derived from their body.
    generate_etag: bool,
}
26
27impl ConditionalGetMiddleware {
28 pub fn new() -> Self {
69 Self {
70 generate_etag: true,
71 }
72 }
73 pub fn without_etag() -> Self {
120 Self {
121 generate_etag: false,
122 }
123 }
124
125 fn generate_etag_from_body(&self, body: &[u8]) -> String {
127 let mut hasher = Sha256::new();
128 hasher.update(body);
129 let result = hasher.finalize();
130 format!("\"{}\"", hex::encode(&result[..16]))
131 }
132
133 fn parse_if_none_match(&self, value: &str) -> Vec<String> {
135 value.split(',').map(|s| s.trim().to_string()).collect()
136 }
137
138 fn etag_matches(&self, etag: &str, if_none_match: &[String]) -> bool {
140 if_none_match
141 .iter()
142 .any(|inm| inm == "*" || inm == etag || inm.trim_matches('"') == etag.trim_matches('"'))
143 }
144
145 fn parse_http_date(&self, value: &str) -> Option<DateTime<Utc>> {
147 httpdate::parse_http_date(value).ok().map(DateTime::from)
148 }
149}
150
151impl Default for ConditionalGetMiddleware {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
157#[async_trait]
158impl Middleware for ConditionalGetMiddleware {
159 async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
160 let if_none_match = request.headers.get(IF_NONE_MATCH).cloned();
162 let if_modified_since = request.headers.get(IF_MODIFIED_SINCE).cloned();
163 let if_match = request.headers.get(IF_MATCH).cloned();
164 let if_unmodified_since = request.headers.get(IF_UNMODIFIED_SINCE).cloned();
165 let method = request.method.clone();
166
167 let mut response = handler.handle(request).await?;
169
170 if method != Method::GET && method != Method::HEAD {
172 return Ok(response);
173 }
174
175 if !response.status.is_success() {
177 return Ok(response);
178 }
179
180 let etag = if self.generate_etag && !response.headers.contains_key(ETAG) {
182 let generated = self.generate_etag_from_body(&response.body);
183 if let Ok(etag_value) = generated.parse() {
184 response.headers.insert(ETAG, etag_value);
185 Some(generated)
186 } else {
187 None
190 }
191 } else {
192 response
193 .headers
194 .get(ETAG)
195 .and_then(|v| v.to_str().ok())
196 .map(|s| s.to_string())
197 };
198
199 let last_modified = response
201 .headers
202 .get(LAST_MODIFIED)
203 .and_then(|v| v.to_str().ok())
204 .and_then(|s| self.parse_http_date(s));
205
206 if let Some(if_none_match) = if_none_match
208 && let (Ok(inm_str), Some(etag_value)) = (if_none_match.to_str(), etag.as_ref())
209 {
210 let inm_list = self.parse_if_none_match(inm_str);
211 if self.etag_matches(etag_value, &inm_list) {
212 let mut not_modified = Response::new(StatusCode::NOT_MODIFIED);
214
215 if let Some(etag_header) = response.headers.get(ETAG) {
217 not_modified.headers.insert(ETAG, etag_header.clone());
218 }
219 if let Some(lm_header) = response.headers.get(LAST_MODIFIED) {
220 not_modified
221 .headers
222 .insert(LAST_MODIFIED, lm_header.clone());
223 }
224
225 return Ok(not_modified);
226 }
227 }
228
229 if let Some(if_modified_since) = if_modified_since
231 && let (Ok(ims_str), Some(lm)) = (if_modified_since.to_str(), last_modified)
232 && let Some(ims) = self.parse_http_date(ims_str)
233 {
234 if lm <= ims {
236 let mut not_modified = Response::new(StatusCode::NOT_MODIFIED);
238
239 if let Some(etag_header) = response.headers.get(ETAG) {
241 not_modified.headers.insert(ETAG, etag_header.clone());
242 }
243 if let Some(lm_header) = response.headers.get(LAST_MODIFIED) {
244 not_modified
245 .headers
246 .insert(LAST_MODIFIED, lm_header.clone());
247 }
248
249 return Ok(not_modified);
250 }
251 }
252
253 if let Some(if_match) = if_match
255 && let (Ok(im_str), Some(etag_value)) = (if_match.to_str(), etag.as_ref())
256 {
257 let im_list = self.parse_if_none_match(im_str);
258 if !self.etag_matches(etag_value, &im_list) && !im_list.contains(&"*".to_string()) {
259 return Ok(Response::new(StatusCode::PRECONDITION_FAILED)
261 .with_body(Bytes::from(&b"Precondition Failed"[..])));
262 }
263 }
264
265 if let Some(if_unmodified_since) = if_unmodified_since
267 && let (Ok(ius_str), Some(lm)) = (if_unmodified_since.to_str(), last_modified)
268 && let Some(ius) = self.parse_http_date(ius_str)
269 {
270 if lm > ius {
272 return Ok(Response::new(StatusCode::PRECONDITION_FAILED)
274 .with_body(Bytes::from(&b"Precondition Failed"[..])));
275 }
276 }
277
278 Ok(response)
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::header::HeaderName;
    use hyper::{HeaderMap, Version};

    /// Stub handler answering `200 OK` with a fixed body and, optionally, an `ETag`
    /// and a `Last-Modified` header.
    struct TestHandler {
        body: &'static str,
        with_etag: Option<String>,
        with_last_modified: Option<DateTime<Utc>>,
    }

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            let mut response = Response::new(StatusCode::OK).with_body(self.body.as_bytes());

            if let Some(etag) = self.with_etag.as_deref() {
                response.headers.insert(ETAG, etag.parse().unwrap());
            }

            if let Some(stamp) = self.with_last_modified {
                let formatted = httpdate::fmt_http_date(stamp.into());
                response
                    .headers
                    .insert(LAST_MODIFIED, formatted.parse().unwrap());
            }

            Ok(response)
        }
    }

    /// Wraps a `TestHandler` serving `"test response"` with the given validators.
    fn stub_handler(
        etag: Option<&str>,
        last_modified: Option<DateTime<Utc>>,
    ) -> Arc<TestHandler> {
        Arc::new(TestHandler {
            body: "test response",
            with_etag: etag.map(String::from),
            with_last_modified: last_modified,
        })
    }

    /// Builds an HTTP/1.1 request for `/test` with an empty body.
    fn build_request(method: Method, headers: HeaderMap) -> Request {
        Request::builder()
            .method(method)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Returns a header map holding exactly one `name: value` entry.
    fn single_header(name: HeaderName, value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(name, value.parse().unwrap());
        headers
    }

    /// Passes `request` through `middleware` in front of `handler`, unwrapping the result.
    async fn run(
        middleware: ConditionalGetMiddleware,
        request: Request,
        handler: Arc<TestHandler>,
    ) -> Response {
        middleware.process(request, handler).await.unwrap()
    }

    #[tokio::test]
    async fn test_generates_etag() {
        let response = run(
            ConditionalGetMiddleware::new(),
            build_request(Method::GET, HeaderMap::new()),
            stub_handler(None, None),
        )
        .await;

        assert!(response.headers.contains_key(ETAG));
    }

    #[tokio::test]
    async fn test_if_none_match_returns_304() {
        let etag = "\"abc123\"";
        let response = run(
            ConditionalGetMiddleware::new(),
            build_request(Method::GET, single_header(IF_NONE_MATCH, etag)),
            stub_handler(Some(etag), None),
        )
        .await;

        assert_eq!(response.status, StatusCode::NOT_MODIFIED);
        assert_eq!(response.body.len(), 0);
    }

    #[tokio::test]
    async fn test_if_modified_since_returns_304() {
        let last_modified = Utc::now() - chrono::Duration::days(1);
        let since = httpdate::fmt_http_date((last_modified + chrono::Duration::hours(1)).into());
        let response = run(
            ConditionalGetMiddleware::new(),
            build_request(Method::GET, single_header(IF_MODIFIED_SINCE, &since)),
            stub_handler(None, Some(last_modified)),
        )
        .await;

        assert_eq!(response.status, StatusCode::NOT_MODIFIED);
    }

    #[tokio::test]
    async fn test_if_match_fails_returns_412() {
        let response = run(
            ConditionalGetMiddleware::new(),
            build_request(Method::GET, single_header(IF_MATCH, "\"xyz789\"")),
            stub_handler(Some("\"abc123\""), None),
        )
        .await;

        assert_eq!(response.status, StatusCode::PRECONDITION_FAILED);
    }

    #[tokio::test]
    async fn test_middleware_wont_overwrite_etag() {
        let custom_etag = "\"custom-etag\"";
        let response = run(
            ConditionalGetMiddleware::new(),
            build_request(Method::GET, HeaderMap::new()),
            stub_handler(Some(custom_etag), None),
        )
        .await;

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(
            response.headers.get(ETAG).unwrap().to_str().unwrap(),
            custom_etag
        );
    }

    #[tokio::test]
    async fn test_if_none_match_and_different_etag() {
        let response = run(
            ConditionalGetMiddleware::new(),
            build_request(
                Method::GET,
                single_header(IF_NONE_MATCH, "\"different-etag\""),
            ),
            stub_handler(Some("\"abc123\""), None),
        )
        .await;

        assert_eq!(response.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_if_modified_since_and_last_modified_in_the_future() {
        let last_modified = Utc::now();
        let since = httpdate::fmt_http_date((last_modified - chrono::Duration::hours(1)).into());
        let response = run(
            ConditionalGetMiddleware::new(),
            build_request(Method::GET, single_header(IF_MODIFIED_SINCE, &since)),
            stub_handler(None, Some(last_modified)),
        )
        .await;

        assert_eq!(response.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_no_etag_on_post_request() {
        let response = run(
            ConditionalGetMiddleware::new(),
            build_request(Method::POST, HeaderMap::new()),
            stub_handler(None, None),
        )
        .await;

        assert!(!response.headers.contains_key(ETAG));
    }

    #[tokio::test]
    async fn test_without_etag_generation() {
        let response = run(
            ConditionalGetMiddleware::without_etag(),
            build_request(Method::GET, HeaderMap::new()),
            stub_handler(None, None),
        )
        .await;

        assert!(!response.headers.contains_key(ETAG));
    }
}