reqwest_drive/cache_middleware.rs
1use async_trait::async_trait;
2 // Binary serialization
3use bytes::Bytes;
4use chrono::{DateTime, Utc};
5use http::{Extensions, HeaderMap, HeaderValue, StatusCode};
6use reqwest::{Request, Response};
7use reqwest_middleware::{Middleware, Next, Result};
8use serde::{Deserialize, Serialize};
9use simd_r_drive::DataStore;
10use std::path::Path;
11use std::sync::Arc;
12use std::time::{Duration, SystemTime, UNIX_EPOCH}; // For parsing `Expires` headers
13
/// Defines the caching policy for storing and retrieving responses.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CachePolicy {
    /// Time-to-live applied to a stored response when `respect_headers` is `false`, or when
    /// the response carries neither a `Cache-Control: max-age` directive nor an `Expires`
    /// header.
    pub default_ttl: Duration,
    /// Determines whether cache expiration should respect HTTP headers.
    /// - If `true`, a response's `Cache-Control: max-age` (preferred) or `Expires` header
    ///   determines its TTL, with `default_ttl` as the fallback.
    /// - If `false`, `default_ttl` is always used.
    pub respect_headers: bool,
    /// Optional override for caching specific HTTP status codes.
    /// - If `None`, only success responses (`2xx`) are cached.
    /// - If `Some(Vec<u16>)`, only the specified status codes are cached.
    pub cache_status_override: Option<Vec<u16>>,
}
26
27impl Default for CachePolicy {
28 fn default() -> Self {
29 Self {
30 default_ttl: Duration::from_secs(60 * 60 * 24), // Default 1 day TTL
31 respect_headers: true, // Use headers if available
32 cache_status_override: None, // Default behavior: Cache only 2xx responses
33 }
34 }
35}
36
/// Represents a cached HTTP response as it is persisted in the data store.
///
/// Entries are encoded with `bincode` and written to the `DataStore` under the request's
/// cache key. bincode does not store field names, so changing the order or types of these
/// fields will make previously written entries fail to deserialize. Such entries are then
/// treated as cache misses by the readers in this file.
#[derive(Serialize, Deserialize)]
struct CachedResponse {
    /// HTTP status code of the cached response.
    status: u16,
    /// HTTP headers stored as `(name, raw value bytes)` pairs. A header with several values
    /// appears once per value.
    headers: Vec<(String, Vec<u8>)>,
    /// Response body stored as raw bytes.
    body: Vec<u8>,
    /// Unix timestamp (in milliseconds) indicating when the cache entry expires.
    expiration_timestamp: u64,
}
49
50/// Provides an HTTP cache layer backed by a `SIMD R Drive` data store.
51#[derive(Clone)]
52pub struct DriveCache {
53 store: Arc<DataStore>,
54 policy: CachePolicy, // Configurable policy
55}
56
57impl DriveCache {
58 /// Creates a new cache backed by a file-based data store.
59 ///
60 /// # Arguments
61 ///
62 /// * `cache_storage_file` - Path to the file where cached responses are stored.
63 /// * `policy` - Configuration specifying cache expiration behavior.
64 ///
65 /// # Panics
66 ///
67 /// This function will panic if the `DataStore` fails to initialize.
68 pub fn new(cache_storage_file: &Path, policy: CachePolicy) -> Self {
69 Self {
70 store: Arc::new(DataStore::open(cache_storage_file).unwrap()),
71 policy,
72 }
73 }
74
75 /// Creates a new cache using an existing `Arc<DataStore>`.
76 ///
77 /// This allows sharing the cache store across multiple components.
78 ///
79 /// # Arguments
80 ///
81 /// * `store` - A shared `Arc<DataStore>` instance.
82 /// * `policy` - Cache expiration configuration.
83 pub fn with_drive_arc(store: Arc<DataStore>, policy: CachePolicy) -> Self {
84 Self {
85 store,
86 policy,
87 }
88 }
89
90 /// Checks whether a request is cached and still valid.
91 ///
92 /// This method retrieves the cache entry associated with the request
93 /// and determines if it is still within its valid TTL.
94 ///
95 /// # Arguments
96 ///
97 /// * `req` - The HTTP request to check for a cached response.
98 ///
99 /// # Returns
100 ///
101 /// Returns `true` if the request has a valid cached response; otherwise, `false`.
102 pub async fn is_cached(&self, req: &Request) -> bool {
103 let store = self.store.as_ref();
104
105 let cache_key = self.generate_cache_key(req);
106 let cache_key_bytes = cache_key.as_bytes();
107
108 // let store = self.store.read().await;
109 if let Some(entry_handle) = store.read(cache_key_bytes) {
110 eprintln!("Entry handle: {:?}", entry_handle);
111
112 if let Ok(cached) = bincode::deserialize::<CachedResponse>(entry_handle.as_slice()) {
113 let now = SystemTime::now()
114 .duration_since(UNIX_EPOCH)
115 .expect("Time went backwards")
116 .as_millis() as u64;
117
118 // Extract TTL based on the policy (either from headers or default)
119 let ttl = if self.policy.respect_headers {
120 // Convert headers back to HeaderMap to extract TTL
121 let mut headers = HeaderMap::new();
122 for (k, v) in cached.headers.iter() {
123 if let Ok(header_name) = k.parse::<http::HeaderName>() {
124 if let Ok(header_value) = HeaderValue::from_bytes(v) {
125 headers.insert(header_name, header_value);
126 }
127 }
128 }
129 Self::extract_ttl(&headers, &self.policy)
130 } else {
131 self.policy.default_ttl
132 };
133
134 let expected_expiration = cached.expiration_timestamp + ttl.as_millis() as u64;
135
136 // If expired, remove from cache
137 if now >= expected_expiration {
138 // eprintln!("Determined cache is expired. now - expected_expiration: {:?}", now - expected_expiration);
139 eprintln!(
140 "Cache expires at: {}",
141 chrono::DateTime::from_timestamp_millis(expected_expiration as i64)
142 .unwrap()
143 );
144 eprintln!(
145 "Expiration timestamp: {}",
146 chrono::DateTime::from_timestamp_millis(cached.expiration_timestamp as i64)
147 .unwrap()
148 );
149 eprintln!(
150 "Now: {}",
151 chrono::DateTime::from_timestamp_millis(now as i64).unwrap()
152 );
153
154 // TODO: Rename API method to `delete`
155 store.delete_entry(cache_key_bytes).ok();
156 return false;
157 }
158
159 return true;
160 }
161 }
162 false
163 }
164
165 /// Generates a cache key based on the request method, URL, and relevant headers.
166 ///
167 /// The generated key is used to uniquely identify cached responses.
168 ///
169 /// # Arguments
170 ///
171 /// * `req` - The HTTP request for which to generate a cache key.
172 ///
173 /// # Returns
174 ///
175 /// A string representing the cache key.
176 fn generate_cache_key(&self, req: &Request) -> String {
177 let method = req.method();
178 let url = req.url().as_str();
179 let headers = req.headers();
180
181 let relevant_headers = ["accept", "authorization"];
182 let header_string = relevant_headers
183 .iter()
184 .filter_map(|h| headers.get(*h))
185 .map(|v| v.to_str().unwrap_or_default())
186 .collect::<Vec<_>>()
187 .join(",");
188
189 format!("{} {} {}", method, url, header_string)
190 }
191
192 /// Extracts the TTL from HTTP headers or falls back to the default TTL.
193 ///
194 /// # Arguments
195 ///
196 /// * `headers` - The HTTP headers to inspect.
197 /// * `policy` - The cache policy specifying TTL behavior.
198 ///
199 /// # Returns
200 ///
201 /// A `Duration` indicating the cache expiration time.
202 fn extract_ttl(headers: &HeaderMap, policy: &CachePolicy) -> Duration {
203 if !policy.respect_headers {
204 return policy.default_ttl;
205 }
206
207 // Check `Cache-Control: max-age=N`
208 if let Some(cache_control) = headers.get("cache-control") {
209 if let Ok(cache_control) = cache_control.to_str() {
210 for directive in cache_control.split(',') {
211 if let Some(max_age) = directive.trim().strip_prefix("max-age=") {
212 if let Ok(seconds) = max_age.parse::<u64>() {
213 return Duration::from_secs(seconds);
214 }
215 }
216 }
217 }
218 }
219
220 // Check `Expires`
221 if let Some(expires) = headers.get("expires") {
222 if let Ok(expires) = expires.to_str() {
223 if let Ok(expiry_time) = DateTime::parse_from_rfc2822(expires) {
224 if let Some(duration) =
225 expiry_time.timestamp().checked_sub(Utc::now().timestamp())
226 {
227 if duration > 0 {
228 return Duration::from_secs(duration as u64);
229 }
230 }
231 }
232 }
233 }
234
235 // Fallback to default TTL
236 policy.default_ttl
237 }
238}
239
240#[async_trait]
241impl Middleware for DriveCache {
242 /// Intercepts HTTP requests to apply caching behavior.
243 ///
244 /// This method first checks if a valid cached response exists for the incoming request.
245 /// - If a cached response is found and still valid, it is returned immediately.
246 /// - If no cache entry exists, the request is forwarded to the next middleware or backend.
247 /// - If a response is received, it is cached according to the defined `CachePolicy`.
248 ///
249 /// This middleware **only caches GET and HEAD requests**. Other HTTP methods are passed through without caching.
250 ///
251 /// # Arguments
252 ///
253 /// * `req` - The incoming HTTP request.
254 /// * `extensions` - A mutable reference to request extensions, which may store metadata.
255 /// * `next` - The next middleware in the processing chain.
256 ///
257 /// # Returns
258 ///
259 /// A `Result<Response, reqwest_middleware::Error>` that contains either:
260 /// - A cached response (if available).
261 /// - A fresh response from the backend, which is then cached (if applicable).
262 ///
263 /// # Behavior
264 ///
265 /// - If the request is **already cached and valid**, returns the cached response.
266 /// - If **no cache is found**, the request is sent to the backend, and the response is cached.
267 /// - If **the cache has expired**, the old entry is deleted, and a fresh request is made.
268 async fn handle(
269 &self,
270 req: Request,
271 extensions: &mut Extensions,
272 next: Next<'_>,
273 ) -> Result<Response> {
274 let cache_key = self.generate_cache_key(&req);
275
276 eprintln!("Handle cache key: {}", cache_key);
277
278 let store = self.store.as_ref();
279 let cache_key_bytes = cache_key.as_bytes();
280
281 if req.method() == "GET" || req.method() == "HEAD" {
282 // Use is_cached() to determine if the cache should be used
283 if self.is_cached(&req).await {
284 // let store = self.store.read().await;
285 if let Some(entry_handle) = store.read(cache_key_bytes) {
286 if let Ok(cached) =
287 bincode::deserialize::<CachedResponse>(entry_handle.as_slice())
288 {
289 let mut headers = HeaderMap::new();
290 for (k, v) in cached.headers {
291 if let Ok(header_name) = k.parse::<http::HeaderName>() {
292 if let Ok(header_value) = HeaderValue::from_bytes(&v) {
293 headers.insert(header_name, header_value);
294 }
295 }
296 }
297 let status = StatusCode::from_u16(cached.status).unwrap_or(StatusCode::OK);
298 return Ok(build_response(status, headers, Bytes::from(cached.body)));
299 }
300 }
301 }
302
303 let response = next.run(req, extensions).await?;
304 let status = response.status();
305 let headers = response.headers().clone();
306 let body = response.bytes().await?.to_vec();
307
308 let ttl = Self::extract_ttl(&headers, &self.policy);
309 let expiration_timestamp = SystemTime::now()
310 .duration_since(UNIX_EPOCH)
311 .expect("Time went backwards")
312 .as_millis() as u64
313 + ttl.as_millis() as u64;
314
315 let body_clone = body.clone(); // Fix: Clone before moving
316
317 // Determine whether to cache the response
318 let should_cache = match &self.policy.cache_status_override {
319 Some(status_codes) => status_codes.contains(&status.as_u16()), // Use the override
320 None => status.is_success(), // Default: Cache only success responses (2xx)
321 };
322
323 if should_cache {
324 let serialized = bincode::serialize(&CachedResponse {
325 status: status.as_u16(),
326 headers: headers
327 .iter()
328 .map(|(k, v)| (k.to_string(), v.as_bytes().to_vec()))
329 .collect(),
330 body, // Move the original body here
331 expiration_timestamp,
332 })
333 .expect("Serialization failed");
334
335 {
336 let store = self.store.as_ref();
337
338 eprintln!("Writing cache with key: {}", cache_key);
339 store.write(cache_key_bytes, serialized.as_slice()).ok();
340 }
341 }
342
343 return Ok(build_response(status, headers, Bytes::from(body_clone)));
344 }
345
346 next.run(req, extensions).await
347 }
348}
349
350/// Constructs a `reqwest::Response` from a given status code, headers, and body.
351///
352/// This function is used to rebuild an HTTP response from cached data,
353/// ensuring that it correctly retains headers and status information.
354///
355/// # Arguments
356///
357/// * `status` - The HTTP status code of the response.
358/// * `headers` - A `HeaderMap` containing response headers.
359/// * `body` - A `Bytes` object containing the response body.
360///
361/// # Returns
362///
363/// A `reqwest::Response` representing the reconstructed HTTP response.
364///
365/// # Panics
366///
367/// This function will panic if the response body fails to be constructed.
368fn build_response(status: StatusCode, headers: HeaderMap, body: Bytes) -> Response {
369 let mut response_builder = http::Response::builder().status(status);
370
371 for (key, value) in headers.iter() {
372 response_builder = response_builder.header(key, value);
373 }
374
375 let http_response = response_builder
376 .body(body)
377 .expect("Failed to create HTTP response");
378
379 Response::from(http_response)
380}