1#![warn(missing_docs)]
2#![deny(unconditional_recursion)]
3use http::HeaderMap;
7use http::HeaderValue;
8use http::Method;
9use http::Request;
10use http::Response;
11use http::StatusCode;
12use http::Uri;
13use std::collections::hash_map::Entry;
14use std::collections::HashMap;
15use std::time::Duration;
16use std::time::SystemTime;
17use time::format_description::well_known::Rfc2822;
18use time::OffsetDateTime;
19
/// Status codes whose responses count as storable even without explicit
/// freshness information (`Expires`, `max-age`, `s-maxage`, `public`);
/// consulted by `CachePolicy::is_storable`.
const STATUS_CODE_CACHEABLE_BY_DEFAULT: &[u16] = &[
    200, 203, 204, 206, // successful
    300, 301, 308, // permanent redirects / multiple choices
    404, 405, 410, 414, // client errors
    501, // not implemented
];
23
/// Status codes this implementation knows how to cache; a response with any
/// other status is never storable (see `CachePolicy::is_storable`).
///
/// Note that `206` is not listed here, so partial responses are never stored
/// even though it appears in `STATUS_CODE_CACHEABLE_BY_DEFAULT`.
const UNDERSTOOD_STATUSES: &[u16] = &[
    200, 203, 204, // successful
    300, 301, 302, 303, 307, 308, // redirects
    404, 405, 410, 414, // client errors
    501, // not implemented
];
28
/// Header names dropped by `CachePolicy::copy_without_hop_by_hop_headers`.
///
/// These are the connection-level (hop-by-hop) headers, plus `date`. `date` is
/// listed because the cached response gets a freshly generated `Date` (and `Age`).
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "date",
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
40
/// Stored-response headers that keep their old value when a matching
/// `304 Not Modified` refreshes the stored response in
/// `CachePolicy::after_response`.
///
/// They describe the stored body's representation, which the 304 does not
/// replace.
const EXCLUDED_FROM_REVALIDATION_UPDATE: &[&str] = &[
    "content-length",
    "content-encoding",
    "transfer-encoding",
    "content-range",
];
48
/// Parsed `Cache-Control` directives: directive name -> optional argument.
///
/// Built by `parse_cache_control`, which strips surrounding whitespace and
/// double quotes from arguments. A directive without `=` maps to `None`.
type CacheControl = HashMap<Box<str>, Option<Box<str>>>;
50
51fn parse_cache_control<'a>(headers: impl IntoIterator<Item = &'a HeaderValue>) -> CacheControl {
52 let mut cc = CacheControl::new();
53 let mut is_valid = true;
54
55 for h in headers.into_iter().filter_map(|v| v.to_str().ok()) {
56 for part in h.split(',') {
57 if part.trim().is_empty() {
59 continue;
60 }
61 let mut kv = part.splitn(2, '=');
62 let k = kv.next().unwrap().trim();
63 if k.is_empty() {
64 continue;
65 }
66 let v = kv.next().map(str::trim);
67 match cc.entry(k.into()) {
68 Entry::Occupied(e) => {
69 if e.get().as_deref() != v {
72 is_valid = false;
73 }
74 }
75 Entry::Vacant(e) => {
76 e.insert(v.map(|v| v.trim_matches('"')).map(From::from)); }
78 }
79 }
80 }
81 if !is_valid {
82 cc.insert("must-revalidate".into(), None);
83 }
84 cc
85}
86
87fn format_cache_control(cc: &CacheControl) -> String {
88 let mut out = String::new();
89 for (k, v) in cc {
90 if !out.is_empty() {
91 out.push_str(", ");
92 }
93 out.push_str(k);
94 if let Some(v) = v {
95 out.push('=');
96 let needs_quote =
97 v.is_empty() || v.as_bytes().iter().any(|b| !b.is_ascii_alphanumeric());
98 if needs_quote {
99 out.push('"');
100 }
101 out.push_str(v);
102 if needs_quote {
103 out.push('"');
104 }
105 }
106 }
107 out
108}
109
/// Configuration for [`CachePolicy`].
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CacheOptions {
    /// Whether the cache is shared (for example a proxy) rather than private to one user.
    ///
    /// When `true`:
    /// - `private` responses are not stored.
    /// - Responses to requests carrying `Authorization` are stored only if the response
    ///   has `must-revalidate`, `public` or `s-maxage`.
    /// - `s-maxage` and `proxy-revalidate` are honored.
    /// - Responses with `Set-Cookie` get no freshness lifetime unless they are marked
    ///   `public` or `immutable`.
    ///
    /// Defaults to `true`.
    pub shared: bool,
    /// Heuristic freshness for responses without explicit expiration information.
    ///
    /// The lifetime is this fraction of the time between `Last-Modified` and the
    /// response's `Date` (or the time it was received, if `Date` is missing or invalid).
    /// Defaults to `0.1`.
    pub cache_heuristic: f32,
    /// Minimum freshness lifetime for `Cache-Control: immutable` responses.
    ///
    /// It applies when the lifetime would otherwise come from `Expires`, `Last-Modified`,
    /// or nothing. An explicit `max-age`/`s-maxage` takes precedence. Defaults to 24 hours.
    pub immutable_min_time_to_live: Duration,
    /// How to treat responses with the legacy `pre-check` and `post-check` directives.
    ///
    /// When `true` and a response carries both, its `no-cache`, `no-store`,
    /// `must-revalidate`, `Expires` and `Pragma` are discarded. Defaults to `false`.
    pub ignore_cargo_cult: bool,
}
138
139impl Default for CacheOptions {
140 fn default() -> Self {
141 Self {
142 shared: true,
143 cache_heuristic: 0.1, immutable_min_time_to_live: Duration::from_secs(24 * 3600),
145 ignore_cargo_cult: false,
146 }
147 }
148}
149
/// Caching policy for one request/response pair.
///
/// Create it with [`CachePolicy::new`] or [`CachePolicy::new_options`]. Use
/// [`CachePolicy::before_request`] to check whether a later request can be
/// served from cache, and [`CachePolicy::after_response`] to refresh the
/// policy once the origin answers.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CachePolicy {
    /// Headers of the request that produced the stored response.
    #[cfg_attr(feature = "serde", serde(with = "http_serde::header_map"))]
    req: HeaderMap,
    /// Headers of the stored response. Some may have been removed when
    /// `ignore_cargo_cult` applied.
    #[cfg_attr(feature = "serde", serde(with = "http_serde::header_map"))]
    res: HeaderMap,
    /// URI of the request that produced the stored response.
    #[cfg_attr(feature = "serde", serde(with = "http_serde::uri"))]
    uri: Uri,
    /// Status code of the stored response.
    #[cfg_attr(feature = "serde", serde(with = "http_serde::status_code"))]
    status: StatusCode,
    /// Method of the request that produced the stored response.
    #[cfg_attr(feature = "serde", serde(with = "http_serde::method"))]
    method: Method,
    /// Options this policy was created with.
    opts: CacheOptions,
    /// Parsed response `Cache-Control`, plus `no-cache` implied by `Pragma`
    /// when the response has no `Cache-Control` header.
    res_cc: CacheControl,
    /// Parsed request `Cache-Control`.
    req_cc: CacheControl,
    /// When the response was received. Resident time for `age` is measured from here.
    response_time: SystemTime,
}
172
173impl CachePolicy {
174 #[inline]
177 pub fn new<Req: RequestLike, Res: ResponseLike>(req: &Req, res: &Res) -> Self {
178 let uri = req.uri();
179 let status = res.status();
180 let method = req.method().clone();
181 let res = res.headers().clone();
182 let req = req.headers().clone();
183 Self::from_details(
184 uri,
185 method,
186 status,
187 req,
188 res,
189 SystemTime::now(),
190 Default::default(),
191 )
192 }
193
194 #[inline]
198 pub fn new_options<Req: RequestLike, Res: ResponseLike>(
199 req: &Req,
200 res: &Res,
201 response_time: SystemTime,
202 opts: CacheOptions,
203 ) -> Self {
204 let uri = req.uri();
205 let status = res.status();
206 let method = req.method().clone();
207 let res = res.headers().clone();
208 let req = req.headers().clone();
209 Self::from_details(uri, method, status, req, res, response_time, opts)
210 }
211
212 fn from_details(
213 uri: Uri,
214 method: Method,
215 status: StatusCode,
216 req: HeaderMap,
217 mut res: HeaderMap,
218 response_time: SystemTime,
219 opts: CacheOptions,
220 ) -> Self {
221 let mut res_cc = parse_cache_control(res.get_all("cache-control"));
222 let req_cc = parse_cache_control(req.get_all("cache-control"));
223
224 if opts.ignore_cargo_cult
227 && res_cc.get("pre-check").is_some()
228 && res_cc.get("post-check").is_some()
229 {
230 res_cc.remove("pre-check");
231 res_cc.remove("post-check");
232 res_cc.remove("no-cache");
233 res_cc.remove("no-store");
234 res_cc.remove("must-revalidate");
235 res.insert(
236 "cache-control",
237 HeaderValue::from_str(&format_cache_control(&res_cc)).unwrap(),
238 );
239 res.remove("expires");
240 res.remove("pragma");
241 }
242
243 if !res.contains_key("cache-control")
246 && res
247 .get_str("pragma")
248 .map_or(false, |p| p.contains("no-cache"))
249 {
250 res_cc.insert("no-cache".into(), None);
251 }
252
253 Self { req, res, uri, status, method, opts, res_cc, req_cc, response_time }
254 }
255
256 pub fn is_storable(&self) -> bool {
259 !self.req_cc.contains_key("no-store") &&
261 (Method::GET == self.method ||
264 Method::HEAD == self.method ||
265 (Method::POST == self.method && self.has_explicit_expiration())) &&
266 UNDERSTOOD_STATUSES.contains(&self.status.as_u16()) &&
268 !self.res_cc.contains_key("no-store") &&
270 (!self.opts.shared || !self.res_cc.contains_key("private")) &&
272 (!self.opts.shared ||
274 !self.req.contains_key("authorization") ||
275 self.allows_storing_authenticated()) &&
276 (self.res.contains_key("expires") ||
279 self.res_cc.contains_key("max-age") ||
283 (self.opts.shared && self.res_cc.contains_key("s-maxage")) ||
284 self.res_cc.contains_key("public") ||
285 STATUS_CODE_CACHEABLE_BY_DEFAULT.contains(&self.status.as_u16()))
287 }
288
289 fn has_explicit_expiration(&self) -> bool {
290 (self.opts.shared && self.res_cc.contains_key("s-maxage"))
292 || self.res_cc.contains_key("max-age")
293 || self.res.contains_key("expires")
294 }
295
296 pub fn before_request<Req: RequestLike>(&self, req: &Req, now: SystemTime) -> BeforeRequest {
308 let req_headers = req.headers();
309
310 let (matches, may_revalidate) = self.request_matches(req);
312
313 if matches && self.satisfies_without_revalidation(req_headers, now) {
314 BeforeRequest::Fresh(self.cached_response(now))
315 } else if may_revalidate {
316 BeforeRequest::Stale {
317 request: self.revalidation_request(req),
318 matches,
319 }
320 } else {
321 BeforeRequest::Stale {
322 request: self.request_from_headers(req_headers.clone()),
323 matches,
324 }
325 }
326 }
327
328 fn satisfies_without_revalidation(&self, req_headers: &HeaderMap, now: SystemTime) -> bool {
329 let req_cc = parse_cache_control(req_headers.get_all("cache-control"));
333 if req_cc.contains_key("no-cache")
334 || req_headers
335 .get_str("pragma")
336 .map_or(false, |v| v.contains("no-cache"))
337 {
338 return false;
339 }
340
341 if let Some(max_age) = req_cc
342 .get("max-age")
343 .and_then(|v| v.as_ref())
344 .and_then(|p| p.parse().ok())
345 {
346 if self.age(now) > Duration::from_secs(max_age) {
347 return false;
348 }
349 }
350
351 if let Some(min_fresh) = req_cc
352 .get("min-fresh")
353 .and_then(|v| v.as_ref())
354 .and_then(|p| p.parse().ok())
355 {
356 if self.time_to_live(now) < Duration::from_secs(min_fresh) {
357 return false;
358 }
359 }
360
361 if self.is_stale(now) {
364 let max_stale = req_cc.get("max-stale");
366 let has_max_stale = max_stale.is_some();
367 let max_stale = max_stale
368 .and_then(|m| m.as_ref())
369 .and_then(|s| s.parse().ok());
370 let allows_stale = !self.res_cc.contains_key("must-revalidate")
371 && has_max_stale
372 && max_stale.map_or(true, |val| {
373 Duration::from_secs(val) > self.age(now) - self.max_age()
374 });
375 if !allows_stale {
376 return false;
377 }
378 }
379
380 true
381 }
382
383 fn request_matches<Req: RequestLike>(&self, req: &Req) -> (bool, bool) {
385 let matches = req.is_same_uri(&self.uri) &&
387 (self.req.get("host") == req.headers().get("host")) &&
388 self.vary_matches(req);
390 let exact_match = matches && self.method == req.method();
391
392 (exact_match, exact_match || Method::HEAD == req.method())
394 }
395
396 fn allows_storing_authenticated(&self) -> bool {
397 self.res_cc.contains_key("must-revalidate")
399 || self.res_cc.contains_key("public")
400 || self.res_cc.contains_key("s-maxage")
401 }
402
403 fn vary_matches<Req: RequestLike>(&self, req: &Req) -> bool {
404 for name in get_all_comma(self.res.get_all("vary")) {
405 if name == "*" {
407 return false;
408 }
409 let name = name.trim().to_ascii_lowercase();
410 if req.headers().get(&name) != self.req.get(&name) {
411 return false;
412 }
413 }
414 true
415 }
416
417 fn copy_without_hop_by_hop_headers(in_headers: &HeaderMap) -> HeaderMap {
418 let mut headers = HeaderMap::with_capacity(in_headers.len());
419
420 for (h, v) in in_headers
421 .iter()
422 .filter(|(h, _)| !HOP_BY_HOP_HEADERS.contains(&h.as_str()))
423 {
424 headers.insert(h.clone(), v.clone());
425 }
426
427 for name in get_all_comma(in_headers.get_all("connection")) {
429 headers.remove(name);
430 }
431
432 let new_warnings = join(
433 get_all_comma(in_headers.get_all("warning")).filter(|warning| {
434 !warning.trim_start().starts_with('1') }),
436 );
437 if new_warnings.is_empty() {
438 headers.remove("warning");
439 } else {
440 headers.insert("warning", HeaderValue::from_str(&new_warnings).unwrap());
441 }
442 headers
443 }
444
445 fn cached_response(&self, now: SystemTime) -> http::response::Parts {
453 let mut headers = Self::copy_without_hop_by_hop_headers(&self.res);
454 let age = self.age(now);
455 let day = Duration::from_secs(3600 * 24);
456
457 if age > day && !self.has_explicit_expiration() && self.max_age() > day {
460 headers.append(
461 "warning",
462 HeaderValue::from_static(r#"113 - "rfc7234 5.5.4""#),
463 );
464 }
465 let date = OffsetDateTime::from(now);
466 headers.insert(
467 "age",
468 HeaderValue::from_str(&age.as_secs().to_string()).unwrap(),
469 );
470 headers.insert(
471 "date",
472 HeaderValue::from_str(&date.format(&Rfc2822).unwrap()).unwrap(),
473 );
474
475 let mut parts = Response::builder()
476 .status(self.status)
477 .body(())
478 .unwrap()
479 .into_parts().0;
480 parts.headers = headers;
481 parts
482 }
483
484 fn raw_server_date(&self) -> SystemTime {
485 let date = self
486 .res
487 .get_str("date")
488 .and_then(|d| OffsetDateTime::parse(d, &Rfc2822).ok())
489 .and_then(|d| {
490 SystemTime::UNIX_EPOCH.checked_add(Duration::from_secs(d.unix_timestamp() as u64))
491 });
492 date.unwrap_or(self.response_time)
493 }
494
495 pub fn age(&self, now: SystemTime) -> Duration {
499 let mut age = self.age_header_value();
500
501 if let Ok(resident_time) = now.duration_since(self.response_time) {
502 age += resident_time;
503 }
504 age
505 }
506
507 fn age_header_value(&self) -> Duration {
508 Duration::from_secs(
509 self.res
510 .get_str("age")
511 .and_then(|v| v.parse().ok())
512 .unwrap_or(0),
513 )
514 }
515
516 fn max_age(&self) -> Duration {
522 if !self.is_storable() || self.res_cc.contains_key("no-cache") {
523 return Duration::from_secs(0);
524 }
525
526 if self.opts.shared
529 && (self.res.contains_key("set-cookie")
530 && !self.res_cc.contains_key("public")
531 && !self.res_cc.contains_key("immutable"))
532 {
533 return Duration::from_secs(0);
534 }
535
536 if self.res.get_str("vary").map(str::trim) == Some("*") {
537 return Duration::from_secs(0);
538 }
539
540 if self.opts.shared {
541 if self.res_cc.contains_key("proxy-revalidate") {
542 return Duration::from_secs(0);
543 }
544 if let Some(s_max) = self.res_cc.get("s-maxage").and_then(|v| v.as_ref()) {
546 return Duration::from_secs(s_max.parse().unwrap_or(0));
547 }
548 }
549
550 if let Some(max_age) = self.res_cc.get("max-age").and_then(|v| v.as_ref()) {
552 return Duration::from_secs(max_age.parse().unwrap_or(0));
553 }
554
555 let default_min_ttl = if self.res_cc.contains_key("immutable") {
556 self.opts.immutable_min_time_to_live
557 } else {
558 Duration::from_secs(0)
559 };
560
561 let server_date = self.raw_server_date();
562 if let Some(expires) = self.res.get_str("expires") {
563 return match OffsetDateTime::parse(expires, &Rfc2822) {
564 Err(_) => Duration::from_secs(0),
566 Ok(expires) => {
567 let expires = SystemTime::UNIX_EPOCH
568 + Duration::from_secs(expires.unix_timestamp().max(0) as _);
569 return default_min_ttl
570 .max(expires.duration_since(server_date).unwrap_or_default());
571 }
572 };
573 }
574
575 if let Some(last_modified) = self.res.get_str("last-modified") {
576 if let Ok(last_modified) = OffsetDateTime::parse(last_modified, &Rfc2822) {
577 let last_modified = SystemTime::UNIX_EPOCH
578 + Duration::from_secs(last_modified.unix_timestamp().max(0) as _);
579 if let Ok(diff) = server_date.duration_since(last_modified) {
580 let secs_left = diff.as_secs() as f64 * f64::from(self.opts.cache_heuristic);
581 return default_min_ttl.max(Duration::from_secs(secs_left as _));
582 }
583 }
584 }
585
586 default_min_ttl
587 }
588
589 pub fn time_to_live(&self, now: SystemTime) -> Duration {
601 self.max_age()
602 .checked_sub(self.age(now))
603 .unwrap_or_default()
604 }
605
606 pub fn is_stale(&self, now: SystemTime) -> bool {
608 self.max_age() <= self.age(now)
609 }
610
611 fn revalidation_request<Req: RequestLike>(&self, incoming_req: &Req) -> http::request::Parts {
622 let mut headers = Self::copy_without_hop_by_hop_headers(incoming_req.headers());
623
624 headers.remove("if-range");
626
627 if !self.is_storable() {
628 headers.remove("if-none-match");
630 headers.remove("if-modified-since");
631 return self.request_from_headers(headers);
632 }
633
634 if let Some(etag) = self.res.get_str("etag") {
636 let if_none = join(get_all_comma(headers.get_all("if-none-match")).chain(Some(etag)));
637 headers.insert("if-none-match", HeaderValue::from_str(&if_none).unwrap());
638 }
639
640 let forbids_weak_validators = self.method != Method::GET
642 || headers.contains_key("accept-ranges")
643 || headers.contains_key("if-match")
644 || headers.contains_key("if-unmodified-since");
645
646 if forbids_weak_validators {
649 headers.remove("if-modified-since");
650
651 let etags = join(
652 get_all_comma(headers.get_all("if-none-match"))
653 .filter(|etag| !etag.trim_start().starts_with("W/")),
654 );
655 if etags.is_empty() {
656 headers.remove("if-none-match");
657 } else {
658 headers.insert("if-none-match", HeaderValue::from_str(&etags).unwrap());
659 }
660 } else if !headers.contains_key("if-modified-since") {
661 if let Some(last_modified) = self.res.get_str("last-modified") {
662 headers.insert(
663 "if-modified-since",
664 HeaderValue::from_str(last_modified).unwrap(),
665 );
666 }
667 }
668 self.request_from_headers(headers)
669 }
670
671 fn request_from_headers(&self, headers: HeaderMap) -> http::request::Parts {
672 let mut parts = Request::builder()
673 .method(self.method.clone())
674 .uri(self.uri.clone())
675 .body(())
676 .unwrap()
677 .into_parts().0;
678 parts.headers = headers;
679 parts
680 }
681
682 pub fn after_response<Req: RequestLike, Res: ResponseLike>(
688 &self,
689 request: &Req,
690 response: &Res,
691 response_time: SystemTime,
692 ) -> AfterResponse {
693 let response_headers = response.headers();
694 let mut response_status = response.status();
695
696 let old_etag = &self.res.get_str("etag").map(str::trim);
697 let old_last_modified = response_headers.get_str("last-modified").map(str::trim);
698 let new_etag = response_headers.get_str("etag").map(str::trim);
699 let new_last_modified = response_headers.get_str("last-modified").map(str::trim);
700
701 let mut matches = false;
704 if response.status() != StatusCode::NOT_MODIFIED {
705 matches = false;
706 } else if new_etag.map_or(false, |etag| !etag.starts_with("W/")) {
707 matches = old_etag.map(|e| e.trim_start_matches("W/")) == new_etag;
711 } else if let (Some(old), Some(new)) = (old_etag, new_etag) {
712 matches = old.trim_start_matches("W/") == new.trim_start_matches("W/");
716 } else if old_last_modified.is_some() {
717 matches = old_last_modified == new_last_modified;
718 } else {
719 if old_etag.is_none()
724 && new_etag.is_none()
725 && old_last_modified.is_none()
726 && new_last_modified.is_none()
727 {
728 matches = true;
729 }
730 }
731
732 let new_response_headers = if matches {
733 let mut new_response_headers = HeaderMap::with_capacity(self.res.keys_len());
734 for (header, old_value) in &self.res {
737 let header = header.clone();
738 if let Some(new_value) = response_headers.get(&header) {
739 if !EXCLUDED_FROM_REVALIDATION_UPDATE.contains(&header.as_str()) {
740 new_response_headers.insert(header, new_value.clone());
741 continue;
742 }
743 }
744 new_response_headers.insert(header, old_value.clone());
745 }
746 response_status = self.status;
747 new_response_headers
748 } else {
749 response_headers.clone()
750 };
751
752 let new_policy = CachePolicy::from_details(
753 request.uri(),
754 request.method().clone(),
755 response_status,
756 request.headers().clone(),
757 new_response_headers,
758 response_time,
759 self.opts,
760 );
761 let new_response = new_policy.cached_response(response_time);
762
763 if matches && response.status() == StatusCode::NOT_MODIFIED {
764 AfterResponse::NotModified(new_policy, new_response)
765 } else {
766 AfterResponse::Modified(new_policy, new_response)
767 }
768 }
769}
770
771pub enum AfterResponse {
773 NotModified(CachePolicy, http::response::Parts),
775 Modified(CachePolicy, http::response::Parts),
777}
778
779fn get_all_comma<'a>(
780 all: impl IntoIterator<Item = &'a HeaderValue>,
781) -> impl Iterator<Item = &'a str> {
782 all.into_iter()
783 .filter_map(|v| v.to_str().ok())
784 .flat_map(|s| s.split(',').map(str::trim))
785}
786
787trait GetHeaderStr {
788 fn get_str(&self, k: &str) -> Option<&str>;
789}
790
791impl GetHeaderStr for HeaderMap {
792 #[inline]
793 fn get_str(&self, k: &str) -> Option<&str> {
794 self.get(k).and_then(|v| v.to_str().ok())
795 }
796}
797
/// Joins header list elements with `", "`.
///
/// Empty elements are skipped. Recipients must tolerate empty list elements, but
/// senders must not generate them (RFC 7230 §7). Previously only leading empty
/// elements were dropped, while later ones produced `a, , b`.
fn join<'a>(parts: impl Iterator<Item = &'a str>) -> String {
    let mut out = String::new();
    for part in parts.filter(|part| !part.is_empty()) {
        if !out.is_empty() {
            out.push_str(", ");
        }
        out.push_str(part);
    }
    out
}
809
810pub enum BeforeRequest {
812 Fresh(http::response::Parts),
814 Stale {
816 request: http::request::Parts,
818 matches: bool,
821 },
822}
823
824impl BeforeRequest {
825 pub fn satisfies_without_revalidation(&self) -> bool {
828 matches!(self, Self::Fresh(_))
829 }
830}
831
/// Request types that a [`CachePolicy`] can inspect.
pub trait RequestLike {
    /// Returns an owned copy of the request URI.
    fn uri(&self) -> Uri;
    /// Returns `true` if this request's URI equals `other`.
    ///
    /// Presumably separate from `uri()` so implementors can compare without
    /// building an owned `Uri`.
    fn is_same_uri(&self, other: &Uri) -> bool;
    /// Returns the request method.
    fn method(&self) -> &Method;
    /// Returns the request headers.
    fn headers(&self) -> &HeaderMap;
}
845
/// Response types that a [`CachePolicy`] can inspect.
pub trait ResponseLike {
    /// Returns the response status code.
    fn status(&self) -> StatusCode;
    /// Returns the response headers.
    fn headers(&self) -> &HeaderMap;
}
853
854impl<Body> RequestLike for Request<Body> {
855 fn uri(&self) -> Uri {
856 self.uri().clone()
857 }
858 fn is_same_uri(&self, other: &Uri) -> bool {
859 self.uri() == other
860 }
861 fn method(&self) -> &Method {
862 self.method()
863 }
864 fn headers(&self) -> &HeaderMap {
865 self.headers()
866 }
867}
868
869impl RequestLike for http::request::Parts {
870 fn uri(&self) -> Uri {
871 self.uri.clone()
872 }
873 fn is_same_uri(&self, other: &Uri) -> bool {
874 &self.uri == other
875 }
876 fn method(&self) -> &Method {
877 &self.method
878 }
879 fn headers(&self) -> &HeaderMap {
880 &self.headers
881 }
882}
883
884impl<Body> ResponseLike for Response<Body> {
885 fn status(&self) -> StatusCode {
886 self.status()
887 }
888 fn headers(&self) -> &HeaderMap {
889 self.headers()
890 }
891}
892
893impl ResponseLike for http::response::Parts {
894 fn status(&self) -> StatusCode {
895 self.status
896 }
897 fn headers(&self) -> &HeaderMap {
898 &self.headers
899 }
900}
901
#[cfg(feature = "reqwest")]
impl RequestLike for reqwest::Request {
    fn uri(&self) -> Uri {
        // reqwest stores a `Url`; re-parse its serialization as an `http::Uri`.
        let url = reqwest::Request::url(self);
        url.as_str().parse().expect("Uri and Url are incompatible!?")
    }
    fn is_same_uri(&self, other: &Uri) -> bool {
        reqwest::Request::url(self).as_str() == other
    }
    fn method(&self) -> &Method {
        reqwest::Request::method(self)
    }
    fn headers(&self) -> &HeaderMap {
        reqwest::Request::headers(self)
    }
}
917
#[cfg(feature = "reqwest")]
impl ResponseLike for reqwest::Response {
    fn status(&self) -> StatusCode {
        reqwest::Response::status(self)
    }
    fn headers(&self) -> &HeaderMap {
        reqwest::Response::headers(self)
    }
}