1use std::{sync::Arc, time::SystemTime};
4
5use bytes::Bytes;
6use chashmap_async::CHashMap;
7pub use http_cache_semantics::CacheOptions;
8use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy, RequestLike};
9use reqwest::Url;
10use reqwest_middleware::Middleware;
11
/// One cached response: the HTTP caching policy together with the body it describes.
#[derive(Debug)]
struct CacheEntry {
    /// Policy built from the original request/response pair; consulted via
    /// `before_request` / `after_response` to decide whether the body can be reused.
    policy: CachePolicy,
    /// The buffered response body returned to callers when the entry is reused.
    response: Bytes,
}
20
21impl CacheEntry {
22 pub fn new(policy: CachePolicy, response: Bytes) -> Self {
24 Self { policy, response }
25 }
26}
27
/// `reqwest_middleware` middleware that keeps an in-memory cache of responses,
/// honouring HTTP caching semantics via `http_cache_semantics::CachePolicy`.
#[derive(Default)]
pub struct CacheMiddleware {
    /// Cached entries, keyed by request URL (the fragment is stripped before lookup).
    cache: Arc<CHashMap<Url, CacheEntry>>,
    /// Options passed to `CachePolicy::new_options` when a fresh response is evaluated.
    options: CacheOptions,
}
38
39impl CacheMiddleware {
40 #[must_use]
42 pub fn new() -> Self {
43 Self::default()
44 }
45
46 #[must_use]
48 pub fn with_options(options: CacheOptions) -> Self {
49 Self { cache: Arc::new(CHashMap::new()), options }
50 }
51}
52
53impl std::fmt::Debug for CacheMiddleware {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.debug_struct("CacheMiddleware")
56 .field("cache", &format!("<{} entries>", self.cache.len()))
57 .field("options", &self.options)
58 .finish()
59 }
60}
61
62#[async_trait::async_trait]
63impl Middleware for CacheMiddleware {
64 async fn handle(
65 &self,
66 mut req: reqwest::Request,
67 extensions: &mut task_local_extensions::Extensions,
68 next: reqwest_middleware::Next<'_>,
69 ) -> reqwest_middleware::Result<reqwest::Response> {
70 let mut url = req.url().clone();
73 url.set_fragment(None);
74
75 if let Some(mut cache) = self.cache.get_mut(&url).await {
76 let before = cache.policy.before_request(&req, SystemTime::now());
78 match before {
79 BeforeRequest::Fresh(parts) => {
80 let response = http::Response::from_parts(parts, cache.response.clone());
82 return Ok(response.into());
83 }
84 BeforeRequest::Stale { request: parts, matches } => {
85 *req.headers_mut() = parts.headers.clone();
87 let response = next.run(req, extensions).await?;
88 let after = cache.policy.after_response(&parts, &response, SystemTime::now());
89 match after {
90 AfterResponse::NotModified(policy, parts) => {
91 if matches {
93 cache.policy = policy;
94 }
95 let response =
96 http::Response::from_parts(parts, cache.response.clone());
97 return Ok(response.into());
98 }
99 AfterResponse::Modified(policy, parts) => {
100 if matches {
102 cache.policy = policy;
103 }
104 let body = response.bytes().await?;
105 cache.response = body;
106 let response =
107 http::Response::from_parts(parts, cache.response.clone());
108 return Ok(response.into());
109 }
110 }
111 }
112 }
113 }
114 #[allow(clippy::expect_used)]
117 let (mut parts, _) = http::Request::builder()
118 .uri(req.uri())
119 .method(req.method().clone())
120 .version(req.version())
121 .body(())
122 .expect("Builder used correctly")
123 .into_parts();
124 parts.headers = req.headers().clone();
127 let response = next.run(req, extensions).await?;
128 let policy = CachePolicy::new_options(&parts, &response, SystemTime::now(), self.options);
129 if policy.is_storable() {
130 let response = reqwest_to_http(response).await?;
131 let cache = CacheEntry::new(policy, response.body().clone());
132 self.cache
133 .alter(url, |entry| async move {
134 match entry {
135 None => Some(cache),
136 Some(entry) => {
137 let time = SystemTime::now();
139 if entry.policy.age(time) > cache.policy.age(time) {
140 Some(cache)
141 } else {
142 Some(entry)
143 }
144 }
145 }
146 })
147 .await;
148 return Ok(response.into());
149 }
150 Ok(response)
151 }
152}
153
154async fn reqwest_to_http(
156 mut response: reqwest::Response,
157) -> reqwest::Result<http::Response<Bytes>> {
158 let mut http = http::Response::new(Bytes::new());
159 *http.status_mut() = response.status();
160 *http.version_mut() = response.version();
161 std::mem::swap(http.headers_mut(), response.headers_mut());
162 *http.body_mut() = response.bytes().await?;
163 Ok(http)
164}