1use crate::middleware::v2::{Middleware, Next, NextFuture};
7use crate::request::{ElifMethod, ElifRequest};
8use crate::response::{ElifHeaderValue, ElifResponse};
9use axum::http::{HeaderMap, HeaderName, HeaderValue};
10use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12
/// An HTTP entity tag: the opaque value (stored without its surrounding quotes)
/// together with whether it is a strong or a weak validator.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ETagType {
    /// Strong validator, serialized as `"value"`.
    Strong(String),
    /// Weak validator, serialized as `W/"value"`.
    Weak(String),
}
21
22impl ETagType {
23 pub fn from_header_value(value: &str) -> Option<Self> {
25 let value = value.trim();
26 if value.starts_with("W/") {
27 if value.len() > 3 && value.starts_with("W/\"") && value.ends_with('"') {
29 let etag_value = &value[3..value.len() - 1];
30 Some(Self::Weak(etag_value.to_string()))
31 } else {
32 None
33 }
34 } else if value.starts_with('"') && value.ends_with('"') {
35 let etag_value = &value[1..value.len() - 1];
37 Some(Self::Strong(etag_value.to_string()))
38 } else {
39 None
40 }
41 }
42
43 pub fn to_header_value(&self) -> String {
45 match self {
46 Self::Strong(value) => format!("\"{}\"", value),
47 Self::Weak(value) => format!("W/\"{}\"", value),
48 }
49 }
50
51 pub fn value(&self) -> &str {
53 match self {
54 Self::Strong(value) | Self::Weak(value) => value,
55 }
56 }
57
58 pub fn matches_for_if_none_match(&self, other: &Self) -> bool {
61 self.value() == other.value()
62 }
63
64 pub fn matches_for_if_match(&self, other: &Self) -> bool {
66 match (self, other) {
67 (Self::Strong(a), Self::Strong(b)) => a == b,
68 _ => false, }
70 }
71}
72
/// How an entity tag is derived from a buffered response.
#[derive(Debug, Clone)]
pub enum ETagStrategy {
    /// Strong tag computed by hashing the response body together with response
    /// header data.
    BodyHash,
    /// Weak tag (`W/"..."`) computed by hashing the response body only.
    WeakBodyHash,
    /// Caller-supplied function receiving the body bytes and the response headers;
    /// returning `None` leaves the response without an ETag.
    Custom(fn(&[u8], &HeaderMap) -> Option<ETagType>),
}
83
84impl Default for ETagStrategy {
85 fn default() -> Self {
86 Self::BodyHash
87 }
88}
89
/// Settings controlling which responses receive an ETag and how it is computed.
#[derive(Debug, Clone)]
pub struct ETagConfig {
    /// Strategy used to compute the tag.
    pub strategy: ETagStrategy,
    /// Smallest body size, in bytes, that is tagged (inclusive).
    pub min_size: usize,
    /// Largest body size, in bytes, that is tagged (inclusive).
    pub max_size: usize,
    /// Content-type prefixes, compared case-insensitively, that are eligible for
    /// tagging. Responses without a `Content-Type` header are eligible regardless.
    pub content_types: Vec<String>,
}
102
103impl Default for ETagConfig {
104 fn default() -> Self {
105 Self {
106 strategy: ETagStrategy::default(),
107 min_size: 0,
108 max_size: 10 * 1024 * 1024, content_types: vec![
110 "text/html".to_string(),
111 "text/css".to_string(),
112 "text/javascript".to_string(),
113 "text/plain".to_string(),
114 "application/json".to_string(),
115 "application/javascript".to_string(),
116 "application/xml".to_string(),
117 "text/xml".to_string(),
118 "image/svg+xml".to_string(),
119 ],
120 }
121 }
122}
123
/// Middleware that adds an `ETag` header to eligible responses and answers the
/// request's `If-None-Match` / `If-Match` preconditions with `304 Not Modified` or
/// `412 Precondition Failed`.
#[derive(Debug)]
pub struct ETagMiddleware {
    /// Active configuration; the builder methods modify it in place.
    config: ETagConfig,
}
129
130impl ETagMiddleware {
131 pub fn new() -> Self {
133 Self {
134 config: ETagConfig::default(),
135 }
136 }
137
138 pub fn with_config(config: ETagConfig) -> Self {
140 Self { config }
141 }
142
143 pub fn strategy(mut self, strategy: ETagStrategy) -> Self {
145 self.config.strategy = strategy;
146 self
147 }
148
149 pub fn min_size(mut self, min_size: usize) -> Self {
151 self.config.min_size = min_size;
152 self
153 }
154
155 pub fn max_size(mut self, max_size: usize) -> Self {
157 self.config.max_size = max_size;
158 self
159 }
160
161 pub fn content_type(mut self, content_type: impl Into<String>) -> Self {
163 self.config.content_types.push(content_type.into());
164 self
165 }
166
167 pub fn weak(mut self) -> Self {
169 self.config.strategy = ETagStrategy::WeakBodyHash;
170 self
171 }
172
173 fn should_generate_etag(&self, headers: &HeaderMap, body_size: usize) -> bool {
175 if body_size < self.config.min_size || body_size > self.config.max_size {
177 return false;
178 }
179
180 if headers.contains_key("etag") {
182 return false;
183 }
184
185 if let Some(content_type) = headers.get("content-type") {
187 if let Ok(content_type_str) = content_type.to_str() {
188 let content_type_lower = content_type_str.to_lowercase();
189 return self
190 .config
191 .content_types
192 .iter()
193 .any(|ct| content_type_lower.starts_with(&ct.to_lowercase()));
194 }
195 }
196
197 true
199 }
200
201 fn generate_etag(&self, body: &[u8], headers: &HeaderMap) -> Option<ETagType> {
203 match &self.config.strategy {
204 ETagStrategy::BodyHash => {
205 let mut hasher = DefaultHasher::new();
206 body.hash(&mut hasher);
207 for (name, value) in headers.iter() {
209 name.as_str().hash(&mut hasher);
210 if let Ok(value_str) = value.to_str() {
211 value_str.hash(&mut hasher);
212 }
213 }
214 let hash = hasher.finish();
215 Some(ETagType::Strong(format!("{:x}", hash)))
216 }
217 ETagStrategy::WeakBodyHash => {
218 let mut hasher = DefaultHasher::new();
219 body.hash(&mut hasher);
220 let hash = hasher.finish();
221 Some(ETagType::Weak(format!("{:x}", hash)))
222 }
223 ETagStrategy::Custom(func) => func(body, headers),
224 }
225 }
226
227 fn parse_if_none_match(&self, header_value: &str) -> Vec<ETagType> {
229 let mut etags = Vec::new();
230
231 if header_value.trim() == "*" {
233 return etags; }
235
236 for etag_str in header_value.split(',') {
238 if let Some(etag) = ETagType::from_header_value(etag_str) {
239 etags.push(etag);
240 }
241 }
242
243 etags
244 }
245
246 fn parse_if_match(&self, header_value: &str) -> Vec<ETagType> {
248 let mut etags = Vec::new();
249
250 if header_value.trim() == "*" {
252 return etags; }
254
255 for etag_str in header_value.split(',') {
257 if let Some(etag) = ETagType::from_header_value(etag_str) {
258 etags.push(etag);
259 }
260 }
261
262 etags
263 }
264
265 fn check_if_none_match(&self, request_etags: &[ETagType], response_etag: &ETagType) -> bool {
267 if request_etags.is_empty() {
268 return true; }
270
271 !request_etags
273 .iter()
274 .any(|req_etag| response_etag.matches_for_if_none_match(req_etag))
275 }
276
277 fn check_if_match(&self, request_etags: &[ETagType], response_etag: &ETagType) -> bool {
279 if request_etags.is_empty() {
280 return true; }
282
283 request_etags
285 .iter()
286 .any(|req_etag| response_etag.matches_for_if_match(req_etag))
287 }
288
289 async fn process_response_with_headers(
291 &self,
292 response: ElifResponse,
293 if_none_match: Option<ElifHeaderValue>,
294 if_match: Option<ElifHeaderValue>,
295 request_method: ElifMethod,
296 ) -> ElifResponse {
297 let axum_if_none_match = if_none_match.as_ref().map(|v| v.to_axum());
299 let axum_if_match = if_match.as_ref().map(|v| v.to_axum());
300 let axum_method = request_method.to_axum();
301
302 let axum_response = response.into_axum_response();
303 let (parts, body) = axum_response.into_parts();
304
305 let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
307 Ok(bytes) => bytes,
308 Err(_) => {
309 let response =
311 axum::response::Response::from_parts(parts, axum::body::Body::empty());
312 return ElifResponse::from_axum_response(response).await;
313 }
314 };
315
316 if !self.should_generate_etag(&parts.headers, body_bytes.len()) {
318 let response =
319 axum::response::Response::from_parts(parts, axum::body::Body::from(body_bytes));
320 return ElifResponse::from_axum_response(response).await;
321 }
322
323 let etag = match self.generate_etag(&body_bytes, &parts.headers) {
325 Some(etag) => etag,
326 None => {
327 let response =
329 axum::response::Response::from_parts(parts, axum::body::Body::from(body_bytes));
330 return ElifResponse::from_axum_response(response).await;
331 }
332 };
333
334 if let Some(if_none_match) = axum_if_none_match {
338 if let Ok(if_none_match_str) = if_none_match.to_str() {
339 let request_etags = self.parse_if_none_match(if_none_match_str);
340
341 if if_none_match_str.trim() == "*" {
344 return if axum_method == axum::http::Method::GET
345 || axum_method == axum::http::Method::HEAD
346 {
347 ElifResponse::from_axum_response(
349 axum::response::Response::builder()
350 .status(axum::http::StatusCode::NOT_MODIFIED)
351 .header("etag", etag.to_header_value())
352 .body(axum::body::Body::empty())
353 .unwrap(),
354 )
355 .await
356 } else {
357 ElifResponse::from_axum_response(
359 axum::response::Response::builder()
360 .status(axum::http::StatusCode::PRECONDITION_FAILED)
361 .header("etag", etag.to_header_value())
362 .body(axum::body::Body::from(
363 serde_json::to_vec(&serde_json::json!({
364 "error": {
365 "code": "precondition_failed",
366 "message": "If-None-Match: * failed - resource exists"
367 }
368 }))
369 .unwrap_or_default(),
370 ))
371 .unwrap(),
372 )
373 .await
374 };
375 }
376
377 if !self.check_if_none_match(&request_etags, &etag) {
378 return if axum_method == axum::http::Method::GET
380 || axum_method == axum::http::Method::HEAD
381 {
382 ElifResponse::from_axum_response(
384 axum::response::Response::builder()
385 .status(axum::http::StatusCode::NOT_MODIFIED)
386 .header("etag", etag.to_header_value())
387 .body(axum::body::Body::empty())
388 .unwrap(),
389 )
390 .await
391 } else {
392 ElifResponse::from_axum_response(
394 axum::response::Response::builder()
395 .status(axum::http::StatusCode::PRECONDITION_FAILED)
396 .header("etag", etag.to_header_value())
397 .body(axum::body::Body::from(
398 serde_json::to_vec(&serde_json::json!({
399 "error": {
400 "code": "precondition_failed",
401 "message": "If-None-Match precondition failed - resource unchanged"
402 }
403 })).unwrap_or_default()
404 ))
405 .unwrap()
406 ).await
407 };
408 }
409 }
410 }
411
412 if let Some(if_match) = axum_if_match {
414 if let Ok(if_match_str) = if_match.to_str() {
415 let request_etags = self.parse_if_match(if_match_str);
416
417 if if_match_str.trim() == "*" {
419 } else if !self.check_if_match(&request_etags, &etag) {
421 return ElifResponse::from_axum_response(
423 axum::response::Response::builder()
424 .status(axum::http::StatusCode::PRECONDITION_FAILED)
425 .header("etag", etag.to_header_value())
426 .body(axum::body::Body::from(
427 serde_json::to_vec(&serde_json::json!({
428 "error": {
429 "code": "precondition_failed",
430 "message": "Request ETag does not match current resource ETag"
431 }
432 })).unwrap_or_default()
433 ))
434 .unwrap()
435 ).await;
436 }
437 }
438 }
439
440 let mut new_parts = parts;
442 new_parts.headers.insert(
443 HeaderName::from_static("etag"),
444 HeaderValue::from_str(&etag.to_header_value()).unwrap(),
445 );
446
447 if !new_parts.headers.contains_key("cache-control") {
449 new_parts.headers.insert(
450 HeaderName::from_static("cache-control"),
451 HeaderValue::from_static("private, max-age=0"),
452 );
453 }
454
455 let response =
456 axum::response::Response::from_parts(new_parts, axum::body::Body::from(body_bytes));
457
458 ElifResponse::from_axum_response(response).await
459 }
460}
461
462impl Default for ETagMiddleware {
463 fn default() -> Self {
464 Self::new()
465 }
466}
467
468impl Middleware for ETagMiddleware {
469 fn handle(&self, request: ElifRequest, next: Next) -> NextFuture<'static> {
470 let config = self.config.clone();
471
472 Box::pin(async move {
473 let if_none_match = request.header("if-none-match").cloned();
475 let if_match = request.header("if-match").cloned();
476 let method = request.method.clone();
477
478 let response = next.run(request).await;
479
480 let middleware = ETagMiddleware { config };
482 middleware
483 .process_response_with_headers(response, if_none_match, if_match, method)
484 .await
485 })
486 }
487
488 fn name(&self) -> &'static str {
489 "ETagMiddleware"
490 }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::ElifRequest;
    use crate::response::ElifResponse;

    /// Round-trips strong and weak tags through parsing and formatting, and checks
    /// that malformed values are rejected.
    #[test]
    fn test_etag_parsing() {
        // Strong tag: a quoted opaque value.
        let etag = ETagType::from_header_value("\"abc123\"").unwrap();
        assert_eq!(etag, ETagType::Strong("abc123".to_string()));
        assert_eq!(etag.to_header_value(), "\"abc123\"");

        // Weak tag: `W/` before the quoted value.
        let etag = ETagType::from_header_value("W/\"abc123\"").unwrap();
        assert_eq!(etag, ETagType::Weak("abc123".to_string()));
        assert_eq!(etag.to_header_value(), "W/\"abc123\"");

        // Unquoted and unterminated values are not entity tags.
        assert!(ETagType::from_header_value("invalid").is_none());
        assert!(ETagType::from_header_value("\"unclosed").is_none());
    }

    /// If-None-Match comparison ignores weakness; If-Match comparison requires two
    /// strong tags with equal values.
    #[test]
    fn test_etag_matching() {
        let strong1 = ETagType::Strong("abc123".to_string());
        let strong2 = ETagType::Strong("abc123".to_string());
        let strong3 = ETagType::Strong("def456".to_string());
        let weak1 = ETagType::Weak("abc123".to_string());

        // If-None-Match: same value matches whether or not a tag is weak.
        assert!(strong1.matches_for_if_none_match(&strong2));
        assert!(strong1.matches_for_if_none_match(&weak1));
        assert!(!strong1.matches_for_if_none_match(&strong3));

        // If-Match: a weak tag never matches.
        assert!(strong1.matches_for_if_match(&strong2));
        assert!(!strong1.matches_for_if_match(&weak1));
        assert!(!strong1.matches_for_if_match(&strong3));
    }

    /// Pins the default configuration values.
    #[test]
    fn test_etag_config() {
        let config = ETagConfig::default();
        assert_eq!(config.min_size, 0);
        assert_eq!(config.max_size, 10 * 1024 * 1024);
        assert!(config
            .content_types
            .contains(&"application/json".to_string()));
    }

    /// Eligibility: listed content type is tagged, an existing ETag or an unlisted
    /// content type is not.
    #[test]
    fn test_should_generate_etag() {
        let middleware = ETagMiddleware::new();

        // JSON is in the default content-type list.
        let mut headers = HeaderMap::new();
        headers.insert("content-type", "application/json".parse().unwrap());
        assert!(middleware.should_generate_etag(&headers, 1024));

        // A handler-provided ETag is never overwritten.
        headers.insert("etag", "\"existing\"".parse().unwrap());
        assert!(!middleware.should_generate_etag(&headers, 1024));

        // image/jpeg is not in the default list.
        let mut headers = HeaderMap::new();
        headers.insert("content-type", "image/jpeg".parse().unwrap());
        assert!(!middleware.should_generate_etag(&headers, 1024));
    }

    /// A plain GET returning JSON gets a 200 with an ETag header.
    #[tokio::test]
    async fn test_etag_generation() {
        let middleware = ETagMiddleware::new();

        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            crate::response::headers::ElifHeaderMap::new(),
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "message": "Hello, World!"
                }))
            })
        });

        let response = middleware.handle(request, next).await;

        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );

        // Convert to axum to inspect the headers.
        let axum_response = response.into_axum_response();
        let (parts, _) = axum_response.into_parts();
        assert!(parts.headers.contains_key("etag"));
    }

    /// Replaying the ETag from a first GET in If-None-Match on a second GET for the
    /// same content yields 304.
    #[tokio::test]
    async fn test_if_none_match_304() {
        let middleware = ETagMiddleware::new();

        // First request: obtain the ETag.
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            crate::response::headers::ElifHeaderMap::new(),
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "message": "Hello, World!"
                }))
            })
        });

        let response = middleware.handle(request, next).await;
        let axum_response = response.into_axum_response();
        let (parts, _) = axum_response.into_parts();

        let etag_header = parts.headers.get("etag").unwrap();
        let etag_value = etag_header.to_str().unwrap();

        // Second request: send that ETag back in If-None-Match.
        let mut headers = crate::response::headers::ElifHeaderMap::new();
        let header_name =
            crate::response::headers::ElifHeaderName::from_str("if-none-match").unwrap();
        let header_value = crate::response::headers::ElifHeaderValue::from_str(etag_value).unwrap();
        headers.insert(header_name, header_value);
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            headers,
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "message": "Hello, World!"
                }))
            })
        });

        let response = middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::NOT_MODIFIED
        );
    }

    /// A PUT whose If-Match tag differs from the response's tag yields 412.
    #[tokio::test]
    async fn test_if_match_412() {
        let middleware = ETagMiddleware::new();

        let mut headers = crate::response::headers::ElifHeaderMap::new();
        let header_name = crate::response::headers::ElifHeaderName::from_str("if-match").unwrap();
        let header_value =
            crate::response::headers::ElifHeaderValue::from_str("\"non-matching-etag\"").unwrap();
        headers.insert(header_name, header_value);
        let request = ElifRequest::new(
            crate::request::ElifMethod::PUT,
            "/api/data".parse().unwrap(),
            headers,
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "message": "Updated!"
                }))
            })
        });

        let response = middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::PRECONDITION_FAILED
        );
    }

    /// `If-None-Match: *` on a PUT whose handler returns a body yields 412.
    #[tokio::test]
    async fn test_if_none_match_star_put_request() {
        let middleware = ETagMiddleware::new();

        let mut headers = crate::response::headers::ElifHeaderMap::new();
        let header_name =
            crate::response::headers::ElifHeaderName::from_str("if-none-match").unwrap();
        let header_value = crate::response::headers::ElifHeaderValue::from_str("*").unwrap();
        headers.insert(header_name, header_value);
        let request = ElifRequest::new(
            crate::request::ElifMethod::PUT,
            "/api/data".parse().unwrap(),
            headers,
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "message": "Created!"
                }))
            })
        });

        let response = middleware.handle(request, next).await;
        // Non-GET/HEAD methods get 412 rather than 304.
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::PRECONDITION_FAILED
        );
    }

    /// `If-None-Match: *` on a GET whose handler returns a body yields 304.
    #[tokio::test]
    async fn test_if_none_match_star_get_request() {
        let middleware = ETagMiddleware::new();

        let mut headers = crate::response::headers::ElifHeaderMap::new();
        let header_name =
            crate::response::headers::ElifHeaderName::from_str("if-none-match").unwrap();
        let header_value = crate::response::headers::ElifHeaderValue::from_str("*").unwrap();
        headers.insert(header_name, header_value);
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            headers,
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "message": "Data"
                }))
            })
        });

        let response = middleware.handle(request, next).await;
        // GET/HEAD get 304 rather than 412.
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::NOT_MODIFIED
        );
    }

    /// A PUT carrying a matching If-None-Match tag (taken from a prior GET) yields 412.
    #[tokio::test]
    async fn test_if_none_match_etag_put_request() {
        let middleware = ETagMiddleware::new();

        // First request: GET to obtain the ETag.
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            crate::response::headers::ElifHeaderMap::new(),
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "message": "Data"
                }))
            })
        });

        let response = middleware.handle(request, next).await;
        let axum_response = response.into_axum_response();
        let (parts, _) = axum_response.into_parts();

        let etag_header = parts.headers.get("etag").unwrap();
        let etag_value = etag_header.to_str().unwrap();

        // Second request: PUT with that ETag in If-None-Match.
        let mut headers = crate::response::headers::ElifHeaderMap::new();
        let header_name =
            crate::response::headers::ElifHeaderName::from_str("if-none-match").unwrap();
        let header_value = crate::response::headers::ElifHeaderValue::from_str(etag_value).unwrap();
        headers.insert(header_name, header_value);
        let request = ElifRequest::new(
            crate::request::ElifMethod::PUT,
            "/api/data".parse().unwrap(),
            headers,
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "message": "Data"
                }))
            })
        });

        let response = middleware.handle(request, next).await;
        // Non-GET/HEAD methods get 412 rather than 304.
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::PRECONDITION_FAILED
        );
    }

    /// The `.weak()` builder produces tags serialized with the `W/` prefix.
    #[tokio::test]
    async fn test_weak_etag_strategy() {
        let middleware = ETagMiddleware::new().weak();

        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            crate::response::headers::ElifHeaderMap::new(),
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                ElifResponse::ok().json_value(serde_json::json!({
                    "message": "Hello, World!"
                }))
            })
        });

        let response = middleware.handle(request, next).await;
        let axum_response = response.into_axum_response();
        let (parts, _) = axum_response.into_parts();

        let etag_header = parts.headers.get("etag").unwrap();
        let etag_value = etag_header.to_str().unwrap();
        assert!(etag_value.starts_with("W/"));
    }

    /// Builder methods write through to the stored configuration.
    #[test]
    fn test_etag_middleware_builder() {
        let middleware = ETagMiddleware::new()
            .min_size(1024)
            .max_size(5 * 1024 * 1024)
            .content_type("application/xml")
            .weak();

        assert_eq!(middleware.config.min_size, 1024);
        assert_eq!(middleware.config.max_size, 5 * 1024 * 1024);
        assert!(middleware
            .config
            .content_types
            .contains(&"application/xml".to_string()));
        assert!(matches!(
            middleware.config.strategy,
            ETagStrategy::WeakBodyHash
        ));
    }
}