reqwest_drive/
cache_middleware.rs

1use async_trait::async_trait;
2 // Binary serialization
3use bytes::Bytes;
4use chrono::{DateTime, Utc};
5use http::{Extensions, HeaderMap, HeaderValue, StatusCode};
6use reqwest::{Request, Response};
7use reqwest_middleware::{Middleware, Next, Result};
8use serde::{Deserialize, Serialize};
9use simd_r_drive::DataStore;
10use std::path::Path;
11use std::sync::Arc;
12use std::time::{Duration, SystemTime, UNIX_EPOCH}; // For parsing `Expires` headers
13
/// Defines the caching policy for storing and retrieving responses.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CachePolicy {
    /// Time-to-live given to a cached response when `respect_headers` is
    /// `false`, or when the response carries no usable caching headers.
    pub default_ttl: Duration,
    /// Determines whether the TTL should be derived from the response's
    /// `Cache-Control: max-age` / `Expires` headers before falling back to
    /// `default_ttl`.
    pub respect_headers: bool,
    /// Optional override for caching specific HTTP status codes.
    /// - If `None`, only success responses (`2xx`) are cached.
    /// - If `Some(Vec<u16>)`, only the specified status codes are cached.
    pub cache_status_override: Option<Vec<u16>>,
}

impl Default for CachePolicy {
    /// A one-day default TTL, header-derived TTLs enabled, and only `2xx`
    /// responses cached.
    fn default() -> Self {
        Self {
            default_ttl: Duration::from_secs(60 * 60 * 24), // 1 day
            respect_headers: true,
            cache_status_override: None, // cache only 2xx responses
        }
    }
}
36
37/// Represents a cached HTTP response.
38#[derive(Serialize, Deserialize)]
39struct CachedResponse {
40    /// HTTP status code of the cached response.
41    status: u16,
42    /// HTTP headers stored as key-value pairs, where values are raw bytes.
43    headers: Vec<(String, Vec<u8>)>,
44    /// Response body stored as raw bytes.
45    body: Vec<u8>,
46    /// Unix timestamp (in milliseconds) indicating when the cache entry expires.
47    expiration_timestamp: u64,
48}
49
50/// Provides an HTTP cache layer backed by a `SIMD R Drive` data store.
51#[derive(Clone)]
52pub struct DriveCache {
53    store: Arc<DataStore>,
54    policy: CachePolicy, // Configurable policy
55}
56
57impl DriveCache {
58    /// Creates a new cache backed by a file-based data store.
59    ///
60    /// # Arguments
61    ///
62    /// * `cache_storage_file` - Path to the file where cached responses are stored.
63    /// * `policy` - Configuration specifying cache expiration behavior.
64    ///
65    /// # Panics
66    ///
67    /// This function will panic if the `DataStore` fails to initialize.
68    pub fn new(cache_storage_file: &Path, policy: CachePolicy) -> Self {
69        Self {
70            store: Arc::new(DataStore::open(cache_storage_file).unwrap()),
71            policy,
72        }
73    }
74
75    /// Creates a new cache using an existing `Arc<DataStore>`.
76    ///
77    /// This allows sharing the cache store across multiple components.
78    ///
79    /// # Arguments
80    ///
81    /// * `store` - A shared `Arc<DataStore>` instance.
82    /// * `policy` - Cache expiration configuration.
83    pub fn with_drive_arc(store: Arc<DataStore>, policy: CachePolicy) -> Self {
84        Self {
85            store,
86            policy,
87        }
88    }
89
90    /// Checks whether a request is cached and still valid.
91    ///
92    /// This method retrieves the cache entry associated with the request
93    /// and determines if it is still within its valid TTL.
94    ///
95    /// # Arguments
96    ///
97    /// * `req` - The HTTP request to check for a cached response.
98    ///
99    /// # Returns
100    ///
101    /// Returns `true` if the request has a valid cached response; otherwise, `false`.
102    pub async fn is_cached(&self, req: &Request) -> bool {
103        let store = self.store.as_ref();
104
105        let cache_key = self.generate_cache_key(req);
106        let cache_key_bytes = cache_key.as_bytes();
107
108        // let store = self.store.read().await;
109        if let Some(entry_handle) = store.read(cache_key_bytes) {
110            eprintln!("Entry handle: {:?}", entry_handle);
111
112            if let Ok(cached) = bincode::deserialize::<CachedResponse>(entry_handle.as_slice()) {
113                let now = SystemTime::now()
114                    .duration_since(UNIX_EPOCH)
115                    .expect("Time went backwards")
116                    .as_millis() as u64;
117
118                // Extract TTL based on the policy (either from headers or default)
119                let ttl = if self.policy.respect_headers {
120                    // Convert headers back to HeaderMap to extract TTL
121                    let mut headers = HeaderMap::new();
122                    for (k, v) in cached.headers.iter() {
123                        if let Ok(header_name) = k.parse::<http::HeaderName>() {
124                            if let Ok(header_value) = HeaderValue::from_bytes(v) {
125                                headers.insert(header_name, header_value);
126                            }
127                        }
128                    }
129                    Self::extract_ttl(&headers, &self.policy)
130                } else {
131                    self.policy.default_ttl
132                };
133
134                let expected_expiration = cached.expiration_timestamp + ttl.as_millis() as u64;
135
136                // If expired, remove from cache
137                if now >= expected_expiration {
138                    // eprintln!("Determined cache is expired. now - expected_expiration: {:?}", now - expected_expiration);
139                    eprintln!(
140                        "Cache expires at: {}",
141                        chrono::DateTime::from_timestamp_millis(expected_expiration as i64)
142                            .unwrap()
143                    );
144                    eprintln!(
145                        "Expiration timestamp: {}",
146                        chrono::DateTime::from_timestamp_millis(cached.expiration_timestamp as i64)
147                            .unwrap()
148                    );
149                    eprintln!(
150                        "Now: {}",
151                        chrono::DateTime::from_timestamp_millis(now as i64).unwrap()
152                    );
153
154                    // TODO: Rename API method to `delete`
155                    store.delete_entry(cache_key_bytes).ok();
156                    return false;
157                }
158
159                return true;
160            }
161        }
162        false
163    }
164
165    /// Generates a cache key based on the request method, URL, and relevant headers.
166    ///
167    /// The generated key is used to uniquely identify cached responses.
168    ///
169    /// # Arguments
170    ///
171    /// * `req` - The HTTP request for which to generate a cache key.
172    ///
173    /// # Returns
174    ///
175    /// A string representing the cache key.
176    fn generate_cache_key(&self, req: &Request) -> String {
177        let method = req.method();
178        let url = req.url().as_str();
179        let headers = req.headers();
180
181        let relevant_headers = ["accept", "authorization"];
182        let header_string = relevant_headers
183            .iter()
184            .filter_map(|h| headers.get(*h))
185            .map(|v| v.to_str().unwrap_or_default())
186            .collect::<Vec<_>>()
187            .join(",");
188
189        format!("{} {} {}", method, url, header_string)
190    }
191
192    /// Extracts the TTL from HTTP headers or falls back to the default TTL.
193    ///
194    /// # Arguments
195    ///
196    /// * `headers` - The HTTP headers to inspect.
197    /// * `policy` - The cache policy specifying TTL behavior.
198    ///
199    /// # Returns
200    ///
201    /// A `Duration` indicating the cache expiration time.
202    fn extract_ttl(headers: &HeaderMap, policy: &CachePolicy) -> Duration {
203        if !policy.respect_headers {
204            return policy.default_ttl;
205        }
206
207        // Check `Cache-Control: max-age=N`
208        if let Some(cache_control) = headers.get("cache-control") {
209            if let Ok(cache_control) = cache_control.to_str() {
210                for directive in cache_control.split(',') {
211                    if let Some(max_age) = directive.trim().strip_prefix("max-age=") {
212                        if let Ok(seconds) = max_age.parse::<u64>() {
213                            return Duration::from_secs(seconds);
214                        }
215                    }
216                }
217            }
218        }
219
220        // Check `Expires`
221        if let Some(expires) = headers.get("expires") {
222            if let Ok(expires) = expires.to_str() {
223                if let Ok(expiry_time) = DateTime::parse_from_rfc2822(expires) {
224                    if let Some(duration) =
225                        expiry_time.timestamp().checked_sub(Utc::now().timestamp())
226                    {
227                        if duration > 0 {
228                            return Duration::from_secs(duration as u64);
229                        }
230                    }
231                }
232            }
233        }
234
235        // Fallback to default TTL
236        policy.default_ttl
237    }
238}
239
240#[async_trait]
241impl Middleware for DriveCache {
242     /// Intercepts HTTP requests to apply caching behavior.
243    ///
244    /// This method first checks if a valid cached response exists for the incoming request.
245    /// - If a cached response is found and still valid, it is returned immediately.
246    /// - If no cache entry exists, the request is forwarded to the next middleware or backend.
247    /// - If a response is received, it is cached according to the defined `CachePolicy`.
248    ///
249    /// This middleware **only caches GET and HEAD requests**. Other HTTP methods are passed through without caching.
250    ///
251    /// # Arguments
252    ///
253    /// * `req` - The incoming HTTP request.
254    /// * `extensions` - A mutable reference to request extensions, which may store metadata.
255    /// * `next` - The next middleware in the processing chain.
256    ///
257    /// # Returns
258    ///
259    /// A `Result<Response, reqwest_middleware::Error>` that contains either:
260    /// - A cached response (if available).
261    /// - A fresh response from the backend, which is then cached (if applicable).
262    ///
263    /// # Behavior
264    ///
265    /// - If the request is **already cached and valid**, returns the cached response.
266    /// - If **no cache is found**, the request is sent to the backend, and the response is cached.
267    /// - If **the cache has expired**, the old entry is deleted, and a fresh request is made.
268    async fn handle(
269        &self,
270        req: Request,
271        extensions: &mut Extensions,
272        next: Next<'_>,
273    ) -> Result<Response> {
274        let cache_key = self.generate_cache_key(&req);
275
276        eprintln!("Handle cache key: {}", cache_key);
277
278        let store = self.store.as_ref();
279        let cache_key_bytes = cache_key.as_bytes();
280
281        if req.method() == "GET" || req.method() == "HEAD" {
282            // Use is_cached() to determine if the cache should be used
283            if self.is_cached(&req).await {
284                // let store = self.store.read().await;
285                if let Some(entry_handle) = store.read(cache_key_bytes) {
286                    if let Ok(cached) =
287                        bincode::deserialize::<CachedResponse>(entry_handle.as_slice())
288                    {
289                        let mut headers = HeaderMap::new();
290                        for (k, v) in cached.headers {
291                            if let Ok(header_name) = k.parse::<http::HeaderName>() {
292                                if let Ok(header_value) = HeaderValue::from_bytes(&v) {
293                                    headers.insert(header_name, header_value);
294                                }
295                            }
296                        }
297                        let status = StatusCode::from_u16(cached.status).unwrap_or(StatusCode::OK);
298                        return Ok(build_response(status, headers, Bytes::from(cached.body)));
299                    }
300                }
301            }
302
303            let response = next.run(req, extensions).await?;
304            let status = response.status();
305            let headers = response.headers().clone();
306            let body = response.bytes().await?.to_vec();
307
308            let ttl = Self::extract_ttl(&headers, &self.policy);
309            let expiration_timestamp = SystemTime::now()
310                .duration_since(UNIX_EPOCH)
311                .expect("Time went backwards")
312                .as_millis() as u64
313                + ttl.as_millis() as u64;
314
315            let body_clone = body.clone(); // Fix: Clone before moving
316
317            // Determine whether to cache the response
318            let should_cache = match &self.policy.cache_status_override {
319                Some(status_codes) => status_codes.contains(&status.as_u16()), // Use the override
320                None => status.is_success(), // Default: Cache only success responses (2xx)
321            };
322
323            if should_cache {
324                let serialized = bincode::serialize(&CachedResponse {
325                    status: status.as_u16(),
326                    headers: headers
327                        .iter()
328                        .map(|(k, v)| (k.to_string(), v.as_bytes().to_vec()))
329                        .collect(),
330                    body, // Move the original body here
331                    expiration_timestamp,
332                })
333                .expect("Serialization failed");
334    
335                {
336                    let store = self.store.as_ref();
337    
338                    eprintln!("Writing cache with key: {}", cache_key);
339                    store.write(cache_key_bytes, serialized.as_slice()).ok();
340                }
341            }
342            
343            return Ok(build_response(status, headers, Bytes::from(body_clone)));
344        }
345
346        next.run(req, extensions).await
347    }
348}
349
350/// Constructs a `reqwest::Response` from a given status code, headers, and body.
351///
352/// This function is used to rebuild an HTTP response from cached data,
353/// ensuring that it correctly retains headers and status information.
354///
355/// # Arguments
356///
357/// * `status` - The HTTP status code of the response.
358/// * `headers` - A `HeaderMap` containing response headers.
359/// * `body` - A `Bytes` object containing the response body.
360///
361/// # Returns
362///
363/// A `reqwest::Response` representing the reconstructed HTTP response.
364///
365/// # Panics
366///
367/// This function will panic if the response body fails to be constructed.
368fn build_response(status: StatusCode, headers: HeaderMap, body: Bytes) -> Response {
369    let mut response_builder = http::Response::builder().status(status);
370
371    for (key, value) in headers.iter() {
372        response_builder = response_builder.header(key, value);
373    }
374
375    let http_response = response_builder
376        .body(body)
377        .expect("Failed to create HTTP response");
378
379    Response::from(http_response)
380}