Skip to main content

reqwest_drive/
cache_middleware.rs

1use async_trait::async_trait;
2// Binary serialization
3use bitcode::{Decode, Encode};
4use bytes::Bytes;
5use cache_manager::{CacheRoot, ProcessScopedCacheGroup};
6use chrono::{DateTime, Utc};
7use http::{Extensions, HeaderMap, HeaderValue, StatusCode};
8use reqwest::{Request, Response};
9use reqwest_middleware::{Middleware, Next, Result};
10use simd_r_drive::traits::{DataStoreReader, DataStoreWriter};
11use simd_r_drive::{DataStore, compute_hash};
12use std::io;
13use std::path::Path;
14use std::sync::Arc;
15use std::time::{Duration, SystemTime, UNIX_EPOCH}; // For parsing `Expires` headers
16
/// Per-request control for bypassing cache behavior.
///
/// When set to `CacheBypass(true)` in request extensions, the cache middleware
/// will skip both cache reads and cache writes for that request.
///
/// This is useful when you want a one-off fresh fetch while still reusing the
/// same client, cache store, and throttle middleware stack.
///
/// The default value is `CacheBypass(false)`, which leaves caching enabled.
/// The flag is comparable by value, so `flag == CacheBypass(true)` works.
///
/// # Example
///
/// ```rust
/// use reqwest_drive::{CacheBypass, CachePolicy, ThrottlePolicy, init_cache_with_throttle};
/// use reqwest_middleware::ClientBuilder;
/// use tempfile::tempdir;
///
/// # #[tokio::main]
/// # async fn main() {
/// let temp_dir = tempdir().unwrap();
/// let cache_path = temp_dir.path().join("cache_storage.bin");
///
/// let (cache, throttle) = init_cache_with_throttle(
///     &cache_path,
///     CachePolicy::default(),
///     ThrottlePolicy::default(),
/// );
///
/// let client = ClientBuilder::new(reqwest::Client::new())
///     .with_arc(cache)
///     .with_arc(throttle)
///     .build();
///
/// let mut request = client.get("https://example.com");
/// request.extensions().insert(CacheBypass(true));
/// let _ = request.send().await;
/// # }
/// ```
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CacheBypass(pub bool);
55
/// Per-request control for busting and refreshing cache behavior.
///
/// When set to `CacheBust(true)` in request extensions, the cache middleware
/// skips cache reads for that request, forces a fresh network fetch, and then
/// writes the new response back to cache (subject to `CachePolicy`).
///
/// This is useful when you want to refresh a stale entry and make future
/// non-busted requests use the updated cached response.
///
/// The default value is `CacheBust(false)`, which leaves cache reads enabled.
/// The flag is comparable by value, so `flag == CacheBust(true)` works.
///
/// # Example
///
/// ```rust
/// use reqwest_drive::{CacheBust, CachePolicy, ThrottlePolicy, init_cache_with_throttle};
/// use reqwest_middleware::ClientBuilder;
/// use tempfile::tempdir;
///
/// # #[tokio::main]
/// # async fn main() {
/// let temp_dir = tempdir().unwrap();
/// let cache_path = temp_dir.path().join("cache_storage.bin");
///
/// let (cache, throttle) = init_cache_with_throttle(
///     &cache_path,
///     CachePolicy::default(),
///     ThrottlePolicy::default(),
/// );
///
/// let client = ClientBuilder::new(reqwest::Client::new())
///     .with_arc(cache)
///     .with_arc(throttle)
///     .build();
///
/// let mut request = client.get("https://example.com");
/// request.extensions().insert(CacheBust(true));
/// let _ = request.send().await;
/// # }
/// ```
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CacheBust(pub bool);
95
/// Defines the caching policy for storing and retrieving responses.
#[derive(Clone, Debug)]
pub struct CachePolicy {
    /// Time-to-live used when no header-derived TTL applies: either
    /// `respect_headers` is `false`, or the response carries no usable
    /// `Cache-Control: max-age` / `Expires` header.
    pub default_ttl: Duration,
    /// Determines whether cache expiration should respect HTTP headers.
    pub respect_headers: bool,
    /// Optional override for caching specific HTTP status codes.
    /// - If `None`, only success responses (`2xx`) are cached.
    /// - If `Some(Vec<u16>)`, only the specified status codes are cached.
    pub cache_status_override: Option<Vec<u16>>,
}
108
109impl Default for CachePolicy {
110    fn default() -> Self {
111        Self {
112            default_ttl: Duration::from_secs(60 * 60 * 24), // Default 1 day TTL
113            respect_headers: true,                          // Use headers if available
114            cache_status_override: None, // Default behavior: Cache only 2xx responses
115        }
116    }
117}
118
/// Represents a cached HTTP response.
///
/// Instances are serialized with `bitcode` and stored in the data store under
/// the request's cache key.
#[derive(Encode, Decode)]
struct CachedResponse {
    /// HTTP status code of the cached response.
    status: u16,
    /// HTTP headers stored as key-value pairs, where values are raw bytes.
    headers: Vec<(String, Vec<u8>)>,
    /// Response body stored as raw bytes.
    body: Vec<u8>,
    /// Unix timestamp (in milliseconds) indicating when the cache entry expires.
    ///
    /// The middleware writes this as `now + ttl` at store time.
    /// NOTE(review): `DriveCache::is_cached` adds the TTL to this value again
    /// before comparing it with the current time, so the effective lifetime
    /// looks longer than the TTL — confirm whether that is intended.
    expiration_timestamp: u64,
}
131
/// Provides an HTTP cache layer backed by a `SIMD R Drive` data store.
///
/// Cloning a `DriveCache` clones the `Arc` handles, so all clones share the
/// same underlying store.
///
/// ## Concurrency model
///
/// - Thread-safe for concurrent access within a single process.
/// - Not multi-process safe for concurrent access to the same backing file.
///
/// If multiple processes need caching, use process-level coordination
/// (e.g., external locking/ownership) or separate cache files per process.
#[derive(Clone)]
pub struct DriveCache {
    /// Shared handle to the backing data store that holds serialized responses.
    store: Arc<DataStore>,
    /// Expiration and status-code rules applied when reading and writing entries.
    policy: CachePolicy, // Configurable policy
    /// Set only by `new_process_scoped`; never read in this file.
    /// NOTE(review): presumably held so the process-scoped cache group outlives
    /// the store — confirm against the `cache_manager` documentation.
    _process_scoped_group: Option<Arc<ProcessScopedCacheGroup>>,
}
147
148impl DriveCache {
149    /// Creates a new cache backed by a file-based data store.
150    ///
151    /// # Arguments
152    ///
153    /// * `cache_storage_file` - Path to the file where cached responses are stored.
154    /// * `policy` - Configuration specifying cache expiration behavior.
155    ///
156    /// # Concurrency
157    ///
158    /// The cache is thread-safe within a process, but the backing file should
159    /// not be shared for concurrent reads/writes across multiple processes.
160    ///
161    /// # Panics
162    ///
163    /// This function will panic if the `DataStore` fails to initialize.
164    pub fn new(cache_storage_file: &Path, policy: CachePolicy) -> Self {
165        Self {
166            store: Arc::new(DataStore::open(cache_storage_file).unwrap()),
167            policy,
168            _process_scoped_group: None,
169        }
170    }
171
172    /// Creates a new cache using discovered `.cache` root and a process-scoped storage bin.
173    ///
174    /// The cache group is derived from this crate name (`reqwest-drive`), and the entry
175    /// file is created under a process/thread scoped subdirectory so callers do not need
176    /// to manually provide a cache path.
177    ///
178    /// # Errors
179    ///
180    /// Returns an error if discovery or process-scoped directory/file initialization fails.
181    pub fn new_process_scoped(policy: CachePolicy) -> io::Result<Self> {
182        let cache_root = CacheRoot::from_discovery()?;
183        let scoped_group = Arc::new(ProcessScopedCacheGroup::new(
184            &cache_root,
185            env!("CARGO_PKG_NAME"),
186        )?);
187        let cache_storage_file = scoped_group.touch_thread_entry("cache_storage.bin")?;
188        let store = DataStore::open(&cache_storage_file).map_err(|err| {
189            io::Error::other(format!(
190                "failed to open DataStore at {}: {err}",
191                cache_storage_file.display()
192            ))
193        })?;
194
195        Ok(Self {
196            store: Arc::new(store),
197            policy,
198            _process_scoped_group: Some(scoped_group),
199        })
200    }
201
202    /// Creates a new cache using an existing `Arc<DataStore>`.
203    ///
204    /// This allows sharing the cache store across multiple components.
205    ///
206    /// # Arguments
207    ///
208    /// * `store` - A shared `Arc<DataStore>` instance.
209    /// * `policy` - Cache expiration configuration.
210    ///
211    /// # Concurrency
212    ///
213    /// This is thread-safe within a process. Avoid concurrent multi-process
214    /// access to the same underlying store/file.
215    pub fn with_drive_arc(store: Arc<DataStore>, policy: CachePolicy) -> Self {
216        Self {
217            store,
218            policy,
219            _process_scoped_group: None,
220        }
221    }
222
223    /// Checks whether a request is cached and still valid.
224    ///
225    /// This method retrieves the cache entry associated with the request
226    /// and determines if it is still within its valid TTL.
227    ///
228    /// # Arguments
229    ///
230    /// * `req` - The HTTP request to check for a cached response.
231    ///
232    /// # Returns
233    ///
234    /// Returns `true` if the request has a valid cached response; otherwise, `false`.
235    pub async fn is_cached(&self, req: &Request) -> bool {
236        let store = self.store.as_ref();
237
238        let cache_key = self.generate_cache_key(req);
239        let cache_key_bytes = cache_key.as_bytes();
240
241        // let store = self.store.read().await;
242        if let Ok(Some(entry_handle)) = store.read(cache_key_bytes) {
243            tracing::debug!("Entry handle: {:?}", entry_handle);
244
245            if let Ok(cached) = bitcode::decode::<CachedResponse>(entry_handle.as_slice()) {
246                let now = SystemTime::now()
247                    .duration_since(UNIX_EPOCH)
248                    .expect("Time went backwards")
249                    .as_millis() as u64;
250
251                // Extract TTL based on the policy (either from headers or default)
252                let ttl = if self.policy.respect_headers {
253                    // Convert headers back to HeaderMap to extract TTL
254                    let mut headers = HeaderMap::new();
255                    for (k, v) in cached.headers.iter() {
256                        if let Ok(header_name) = k.parse::<http::HeaderName>()
257                            && let Ok(header_value) = HeaderValue::from_bytes(v)
258                        {
259                            headers.insert(header_name, header_value);
260                        }
261                    }
262                    Self::extract_ttl(&headers, &self.policy)
263                } else {
264                    self.policy.default_ttl
265                };
266
267                let expected_expiration = cached.expiration_timestamp + ttl.as_millis() as u64;
268
269                // If expired, remove from cache
270                if now >= expected_expiration {
271                    // tracing::debug!("Determined cache is expired. now - expected_expiration: {:?}", now - expected_expiration);
272                    tracing::debug!(
273                        "Cache expires at: {}",
274                        chrono::DateTime::from_timestamp_millis(expected_expiration as i64)
275                            .unwrap()
276                    );
277                    tracing::debug!(
278                        "Expiration timestamp: {}",
279                        chrono::DateTime::from_timestamp_millis(cached.expiration_timestamp as i64)
280                            .unwrap()
281                    );
282                    tracing::debug!(
283                        "Now: {}",
284                        chrono::DateTime::from_timestamp_millis(now as i64).unwrap()
285                    );
286
287                    store.delete(cache_key_bytes).ok();
288                    return false;
289                }
290
291                return true;
292            }
293        }
294        false
295    }
296
297    /// Generates a cache key based on request method, canonicalized URL, and relevant headers.
298    ///
299    /// The generated key is used to uniquely identify cached responses.
300    ///
301    /// Key strategy:
302    /// - Includes request method.
303    /// - Canonicalizes URL query parameters by sorting them by key/value.
304    /// - Includes selected representation-affecting headers.
305    /// - Hashes sensitive header values (e.g. Authorization) before adding them to key material.
306    ///
307    /// # Arguments
308    ///
309    /// * `req` - The HTTP request for which to generate a cache key.
310    ///
311    /// # Returns
312    ///
313    /// A string representing the cache key.
314    fn generate_cache_key(&self, req: &Request) -> String {
315        let method = req.method();
316        let url = Self::canonicalize_url(req.url());
317        let headers = req.headers();
318
319        let relevant_headers = [
320            "accept",
321            "accept-language",
322            "content-type",
323            "authorization",
324            "x-api-key",
325        ];
326
327        let header_string = relevant_headers
328            .iter()
329            .filter_map(|name| {
330                headers.get(*name).map(|value| {
331                    let value_str = if Self::is_sensitive_header(name) {
332                        format!("h:{:016x}", compute_hash(value.as_bytes()))
333                    } else {
334                        value.to_str().unwrap_or_default().to_string()
335                    };
336
337                    format!("{}={}", name, value_str)
338                })
339            })
340            .collect::<Vec<_>>()
341            .join("&");
342
343        format!("{} {} {}", method, url, header_string)
344    }
345
346    fn canonicalize_url(url: &reqwest::Url) -> String {
347        let mut normalized = url.clone();
348
349        let mut query_pairs = url
350            .query_pairs()
351            .map(|(k, v)| (k.into_owned(), v.into_owned()))
352            .collect::<Vec<_>>();
353
354        if !query_pairs.is_empty() {
355            query_pairs.sort_by(|(k1, v1), (k2, v2)| k1.cmp(k2).then_with(|| v1.cmp(v2)));
356
357            {
358                let mut serializer = normalized.query_pairs_mut();
359                serializer.clear();
360                for (key, value) in query_pairs.iter() {
361                    serializer.append_pair(key, value);
362                }
363            }
364        }
365
366        normalized.to_string()
367    }
368
369    fn is_sensitive_header(name: &str) -> bool {
370        matches!(
371            name,
372            "authorization" | "proxy-authorization" | "cookie" | "x-api-key"
373        )
374    }
375
376    /// Extracts the TTL from HTTP headers or falls back to the default TTL.
377    ///
378    /// # Arguments
379    ///
380    /// * `headers` - The HTTP headers to inspect.
381    /// * `policy` - The cache policy specifying TTL behavior.
382    ///
383    /// # Returns
384    ///
385    /// A `Duration` indicating the cache expiration time.
386    fn extract_ttl(headers: &HeaderMap, policy: &CachePolicy) -> Duration {
387        if !policy.respect_headers {
388            return policy.default_ttl;
389        }
390
391        if let Some(cache_control) = headers.get("cache-control")
392            && let Ok(cache_control) = cache_control.to_str()
393        {
394            for directive in cache_control.split(',') {
395                if let Some(max_age) = directive.trim().strip_prefix("max-age=")
396                    && let Ok(seconds) = max_age.parse::<u64>()
397                {
398                    return Duration::from_secs(seconds);
399                }
400            }
401        }
402
403        if let Some(expires) = headers.get("expires")
404            && let Ok(expires) = expires.to_str()
405            && let Ok(expiry_time) = DateTime::parse_from_rfc2822(expires)
406            && let Some(duration) = expiry_time.timestamp().checked_sub(Utc::now().timestamp())
407            && duration > 0
408        {
409            return Duration::from_secs(duration as u64);
410        }
411
412        policy.default_ttl
413    }
414}
415
416#[async_trait]
417impl Middleware for DriveCache {
418    /// Intercepts HTTP requests to apply caching behavior.
419    ///
420    /// This method first checks if a valid cached response exists for the incoming request.
421    /// - If a cached response is found and still valid, it is returned immediately.
422    /// - If no cache entry exists, the request is forwarded to the next middleware or backend.
423    /// - If a response is received, it is cached according to the defined `CachePolicy`.
424    ///
425    /// This middleware **only caches GET and HEAD requests**. Other HTTP methods are passed through without caching.
426    ///
427    /// # Arguments
428    ///
429    /// * `req` - The incoming HTTP request.
430    /// * `extensions` - A mutable reference to request extensions, which may store metadata.
431    /// * `next` - The next middleware in the processing chain.
432    ///
433    /// # Returns
434    ///
435    /// A `Result<Response, reqwest_middleware::Error>` that contains either:
436    /// - A cached response (if available).
437    /// - A fresh response from the backend, which is then cached (if applicable).
438    ///
439    /// # Behavior
440    ///
441    /// - If the request is **already cached and valid**, returns the cached response.
442    /// - If **no cache is found**, the request is sent to the backend, and the response is cached.
443    /// - If **the cache has expired**, the old entry is deleted, and a fresh request is made.
444    async fn handle(
445        &self,
446        req: Request,
447        extensions: &mut Extensions,
448        next: Next<'_>,
449    ) -> Result<Response> {
450        let bypass_cache = extensions
451            .get::<CacheBypass>()
452            .map(|flag| flag.0)
453            .unwrap_or(false);
454        let bust_cache = extensions
455            .get::<CacheBust>()
456            .map(|flag| flag.0)
457            .unwrap_or(false);
458
459        let cache_key = self.generate_cache_key(&req);
460
461        tracing::debug!("Handle cache key: {}", cache_key);
462
463        let store = self.store.as_ref();
464        let cache_key_bytes = cache_key.as_bytes();
465
466        if req.method() == "GET" || req.method() == "HEAD" {
467            if !bypass_cache
468                && !bust_cache
469                && self.is_cached(&req).await
470                && let Ok(Some(entry_handle)) = store.read(cache_key_bytes)
471                && let Ok(cached) = bitcode::decode::<CachedResponse>(entry_handle.as_slice())
472            {
473                let mut headers = HeaderMap::new();
474                for (k, v) in cached.headers {
475                    if let Ok(header_name) = k.parse::<http::HeaderName>()
476                        && let Ok(header_value) = HeaderValue::from_bytes(&v)
477                    {
478                        headers.insert(header_name, header_value);
479                    }
480                }
481                let status = StatusCode::from_u16(cached.status).unwrap_or(StatusCode::OK);
482                return Ok(build_response(status, headers, Bytes::from(cached.body)));
483            }
484
485            let response = next.run(req, extensions).await?;
486            let status = response.status();
487            let headers = response.headers().clone();
488            let body = response.bytes().await?.to_vec();
489
490            let ttl = Self::extract_ttl(&headers, &self.policy);
491            let expiration_timestamp = SystemTime::now()
492                .duration_since(UNIX_EPOCH)
493                .expect("Time went backwards")
494                .as_millis() as u64
495                + ttl.as_millis() as u64;
496
497            let body_clone = body.clone();
498
499            let should_cache = match &self.policy.cache_status_override {
500                Some(status_codes) => status_codes.contains(&status.as_u16()),
501                None => status.is_success(),
502            };
503
504            if should_cache && !bypass_cache {
505                let serialized = bitcode::encode(&CachedResponse {
506                    status: status.as_u16(),
507                    headers: headers
508                        .iter()
509                        .map(|(k, v)| (k.to_string(), v.as_bytes().to_vec()))
510                        .collect(),
511                    body,
512                    expiration_timestamp,
513                });
514
515                tracing::debug!("Writing cache with key: {}", cache_key);
516                store.write(cache_key_bytes, serialized.as_slice()).ok();
517            }
518
519            return Ok(build_response(status, headers, Bytes::from(body_clone)));
520        }
521
522        next.run(req, extensions).await
523    }
524}
525
526/// Constructs a `reqwest::Response` from a given status code, headers, and body.
527///
528/// This function is used to rebuild an HTTP response from cached data,
529/// ensuring that it correctly retains headers and status information.
530///
531/// # Arguments
532///
533/// * `status` - The HTTP status code of the response.
534/// * `headers` - A `HeaderMap` containing response headers.
535/// * `body` - A `Bytes` object containing the response body.
536///
537/// # Returns
538///
539/// A `reqwest::Response` representing the reconstructed HTTP response.
540///
541/// # Panics
542///
543/// This function will panic if the response body fails to be constructed.
544fn build_response(status: StatusCode, headers: HeaderMap, body: Bytes) -> Response {
545    let mut response_builder = http::Response::builder().status(status);
546
547    for (key, value) in headers.iter() {
548        response_builder = response_builder.header(key, value);
549    }
550
551    let http_response = response_builder
552        .body(body)
553        .expect("Failed to create HTTP response");
554
555    Response::from(http_response)
556}
557
558#[cfg(test)]
559mod tests {
560    use super::*;
561    use rand::rngs::StdRng;
562    use rand::{RngExt, SeedableRng};
563    use reqwest::Method;
564    use std::collections::{HashMap, HashSet};
565    use std::time::{SystemTime, UNIX_EPOCH};
566    use tempfile::TempDir;
567
568    fn build_request(method: Method, url: &str, headers: &[(&str, Option<&str>)]) -> Request {
569        // Construct `reqwest::Request` directly rather than building a
570        // `reqwest::Client` per-iteration. Creating a `Client` repeatedly in
571        // tight loops was the dominant cost on some CI runners; building the
572        // `Request` directly avoids that overhead while remaining functionally
573        // equivalent for these key-generation tests.
574        let mut request = Request::new(
575            method,
576            reqwest::Url::parse(url).expect("failed to parse request URL"),
577        );
578
579        for (name, value) in headers {
580            if let Some(value) = value {
581                let header_name = http::header::HeaderName::from_bytes(name.as_bytes())
582                    .expect("invalid header name");
583                let header_value =
584                    http::header::HeaderValue::from_str(value).expect("invalid header value");
585                request.headers_mut().insert(header_name, header_value);
586            }
587        }
588
589        request
590    }
591
    /// Builds a `DriveCache` with the default policy, backed by a file named
    /// `cache_key_matrix.bin` inside a freshly created temporary directory.
    ///
    /// NOTE(review): `temp_dir` is a local that is not returned, so it is
    /// dropped when this function returns while the cache is still in use.
    /// `tempfile::TempDir` normally removes its directory on drop — confirm the
    /// store tolerates its backing file being unlinked, or return the guard
    /// alongside the cache.
    fn build_cache_for_tests() -> DriveCache {
        let temp_dir = TempDir::new().expect("failed to create temp dir");
        let cache_path = temp_dir.path().join("cache_key_matrix.bin");
        DriveCache::new(&cache_path, CachePolicy::default())
    }
597
598    fn random_token(rng: &mut StdRng, min_len: usize, max_len: usize) -> String {
599        let alphabet = b"abcdefghijklmnopqrstuvwxyz0123456789";
600        let token_len = rng.random_range(min_len..=max_len);
601
602        (0..token_len)
603            .map(|_| {
604                let index = rng.random_range(0..alphabet.len());
605                alphabet[index] as char
606            })
607            .collect()
608    }
609
    /// Builds a pseudo-random `Request` for cache-key fuzzing.
    ///
    /// The request gets a random method, a two-segment path on
    /// `https://example.test`, between 0 and 6 query pairs, and an independent
    /// coin flip for each of the `accept`, `accept-language`, `content-type`,
    /// `authorization` and `x-api-key` headers.
    ///
    /// Everything is drawn from `rng`, so the order of the draws below fixes
    /// the sequence of requests produced for a given seed; reordering the
    /// statements changes the generated inputs.
    fn build_random_request(rng: &mut StdRng) -> Request {
        let methods = [
            Method::GET,
            Method::HEAD,
            Method::POST,
            Method::PUT,
            Method::PATCH,
            Method::DELETE,
        ];

        let method = methods[rng.random_range(0..methods.len())].clone();
        // Path: two random segments of 3 to 10 characters each.
        let mut url = format!(
            "https://example.test/{}/{}",
            random_token(rng, 3, 10),
            random_token(rng, 3, 10)
        );

        // Query string: zero to six `key=value` pairs; values may be empty.
        let query_pair_count = rng.random_range(0..=6);
        if query_pair_count > 0 {
            url.push('?');
            for query_index in 0..query_pair_count {
                if query_index > 0 {
                    url.push('&');
                }

                let query_key = random_token(rng, 1, 8);
                let query_value = random_token(rng, 0, 12);
                url.push_str(&query_key);
                url.push('=');
                url.push_str(&query_value);
            }
        }

        let mut request = Request::new(
            method,
            reqwest::Url::parse(&url).expect("failed to parse randomized URL"),
        );

        // Each header below is included with probability one half.
        if rng.random::<bool>() {
            let accept_values = ["application/json", "text/plain", "*/*"];
            request.headers_mut().insert(
                http::header::ACCEPT,
                http::header::HeaderValue::from_str(accept_values[rng.random_range(0..3)])
                    .expect("invalid accept header value"),
            );
        }

        if rng.random::<bool>() {
            let language_values = ["en-US", "fr-FR", "es-ES", "de-DE"];
            request.headers_mut().insert(
                http::header::ACCEPT_LANGUAGE,
                http::header::HeaderValue::from_str(language_values[rng.random_range(0..4)])
                    .expect("invalid accept-language header value"),
            );
        }

        if rng.random::<bool>() {
            let content_type_values = ["application/json", "application/xml", "text/plain"];
            request.headers_mut().insert(
                http::header::CONTENT_TYPE,
                http::header::HeaderValue::from_str(content_type_values[rng.random_range(0..3)])
                    .expect("invalid content-type header value"),
            );
        }

        // Sensitive headers get random credentials so their hashed form varies.
        if rng.random::<bool>() {
            let authorization_value = format!("Bearer {}", random_token(rng, 16, 48));
            request.headers_mut().insert(
                http::header::AUTHORIZATION,
                http::header::HeaderValue::from_str(&authorization_value)
                    .expect("invalid authorization header value"),
            );
        }

        if rng.random::<bool>() {
            let api_key_value = random_token(rng, 12, 32);
            request.headers_mut().insert(
                http::header::HeaderName::from_static("x-api-key"),
                http::header::HeaderValue::from_str(&api_key_value)
                    .expect("invalid x-api-key header value"),
            );
        }

        request
    }
695
696    #[test]
697    fn fuzz_cache_key_hash_collisions_uses_library_key_generator() {
698        let temp_dir = TempDir::new().expect("failed to create temp dir");
699        let cache_path = temp_dir.path().join("cache_key_fuzz.bin");
700        let cache = DriveCache::new(&cache_path, CachePolicy::default());
701
702        let mut observed_hash_to_key: HashMap<u64, String> = HashMap::new();
703        let mut random_generator = StdRng::seed_from_u64(0xD15EA5E5);
704
705        let sample_count = 50_000;
706
707        let mut distinct_key_count = 0usize;
708
709        // Seed one known key first so the equality-assert branch is exercised
710        // deterministically without relying on random hash collisions.
711        let duplicate_request = build_request(
712            Method::GET,
713            "https://example.test/duplicate?a=1&b=2",
714            &[("accept", Some("application/json"))],
715        );
716        let duplicate_key = cache.generate_cache_key(&duplicate_request);
717        let duplicate_hash = compute_hash(duplicate_key.as_bytes());
718        observed_hash_to_key.insert(duplicate_hash, duplicate_key.clone());
719        if let Some(existing_key) = observed_hash_to_key.get(&duplicate_hash) {
720            assert_eq!(existing_key, &duplicate_key);
721        }
722
723        for sample_index in 0..sample_count {
724            let request = if sample_index == 0 {
725                // Reuse the exact same request as the seeded entry so this
726                // loop deterministically hits the "hash already seen" branch.
727                // We are NOT expecting collisions between distinct keys.
728                // A distinct-key collision still fails the test via `assert_eq!`.
729                build_request(
730                    Method::GET,
731                    "https://example.test/duplicate?a=1&b=2",
732                    &[("accept", Some("application/json"))],
733                )
734            } else {
735                build_random_request(&mut random_generator)
736            };
737
738            let cache_key = cache.generate_cache_key(&request);
739            let hash = compute_hash(cache_key.as_bytes());
740
741            if let Some(existing_key) = observed_hash_to_key.get(&hash) {
742                assert_eq!(
743                    existing_key, &cache_key,
744                    "hash collision detected for distinct cache keys"
745                );
746            } else {
747                observed_hash_to_key.insert(hash, cache_key);
748                distinct_key_count += 1;
749            }
750        }
751
752        assert!(
753            distinct_key_count > sample_count / 2,
754            "random generation produced too few distinct keys"
755        );
756    }
757
758    #[tokio::test]
759    async fn is_cached_uses_default_ttl_when_respect_headers_is_disabled() {
760        let temp_dir = TempDir::new().expect("failed to create temp dir");
761        let cache_path = temp_dir.path().join("cache_default_ttl.bin");
762        let cache = DriveCache::new(
763            &cache_path,
764            CachePolicy {
765                default_ttl: Duration::from_secs(60),
766                respect_headers: false,
767                cache_status_override: None,
768            },
769        );
770
771        let request = build_request(
772            Method::GET,
773            "https://example.test/default-ttl",
774            &[("accept", Some("application/json"))],
775        );
776        let cache_key = cache.generate_cache_key(&request);
777        let cache_key_bytes = cache_key.as_bytes();
778
779        let now = SystemTime::now()
780            .duration_since(UNIX_EPOCH)
781            .expect("time went backwards")
782            .as_millis() as u64;
783
784        let cached = CachedResponse {
785            status: 200,
786            headers: vec![("cache-control".to_string(), b"max-age=0".to_vec())],
787            body: b"ok".to_vec(),
788            expiration_timestamp: now,
789        };
790
791        let serialized = bitcode::encode(&cached);
792        cache
793            .store
794            .as_ref()
795            .write(cache_key_bytes, serialized.as_slice())
796            .expect("write cached entry");
797
798        assert!(cache.is_cached(&request).await);
799    }
800
801    #[tokio::test]
802    async fn is_cached_evicts_entry_when_expired() {
803        let temp_dir = TempDir::new().expect("failed to create temp dir");
804        let cache_path = temp_dir.path().join("cache_expired_evict.bin");
805        let cache = DriveCache::new(
806            &cache_path,
807            CachePolicy {
808                default_ttl: Duration::from_millis(0),
809                respect_headers: false,
810                cache_status_override: None,
811            },
812        );
813
814        let request = build_request(
815            Method::GET,
816            "https://example.test/expired-entry",
817            &[("accept", Some("application/json"))],
818        );
819        let cache_key = cache.generate_cache_key(&request);
820        let cache_key_bytes = cache_key.as_bytes();
821
822        let cached = CachedResponse {
823            status: 200,
824            headers: Vec::new(),
825            body: b"stale".to_vec(),
826            expiration_timestamp: 0,
827        };
828
829        let serialized = bitcode::encode(&cached);
830        cache
831            .store
832            .as_ref()
833            .write(cache_key_bytes, serialized.as_slice())
834            .expect("write cached entry");
835
836        assert!(!cache.is_cached(&request).await);
837        let stored = cache
838            .store
839            .as_ref()
840            .read(cache_key_bytes)
841            .expect("read cache key after eviction");
842        assert!(stored.is_none(), "expired key should be evicted");
843    }
844
845    #[test]
846    fn extract_ttl_returns_default_when_header_respect_is_disabled() {
847        let policy = CachePolicy {
848            default_ttl: Duration::from_secs(321),
849            respect_headers: false,
850            cache_status_override: None,
851        };
852
853        let mut headers = HeaderMap::new();
854        headers.insert("cache-control", HeaderValue::from_static("max-age=1"));
855
856        assert_eq!(
857            DriveCache::extract_ttl(&headers, &policy),
858            policy.default_ttl
859        );
860    }
861
862    #[test]
863    fn extract_ttl_uses_cache_control_max_age_when_present() {
864        let policy = CachePolicy {
865            default_ttl: Duration::from_secs(321),
866            respect_headers: true,
867            cache_status_override: None,
868        };
869
870        let mut headers = HeaderMap::new();
871        headers.insert(
872            "cache-control",
873            HeaderValue::from_static("public, max-age=42"),
874        );
875
876        assert_eq!(
877            DriveCache::extract_ttl(&headers, &policy),
878            Duration::from_secs(42)
879        );
880    }
881
882    #[test]
883    fn extract_ttl_uses_expires_header_when_cache_control_missing() {
884        let policy = CachePolicy {
885            default_ttl: Duration::from_secs(600),
886            respect_headers: true,
887            cache_status_override: None,
888        };
889
890        let mut headers = HeaderMap::new();
891        let future = (Utc::now() + chrono::Duration::seconds(120)).to_rfc2822();
892        headers.insert(
893            "expires",
894            HeaderValue::from_str(&future).expect("expires header should be valid"),
895        );
896
897        let ttl = DriveCache::extract_ttl(&headers, &policy);
898        assert!(ttl > Duration::from_secs(0));
899        assert!(ttl < policy.default_ttl);
900    }
901
902    #[test]
903    fn exhaustive_cache_key_matrix_no_hash_collisions_for_distinct_keys() {
904        let cache = build_cache_for_tests();
905
906        let (methods, paths, queries) = (
907            vec![
908                Method::GET,
909                Method::HEAD,
910                Method::POST,
911                Method::PUT,
912                Method::PATCH,
913                Method::DELETE,
914            ],
915            vec!["/resource", "/resource/v2", "/resource/deep/path"],
916            vec![
917                "", "?a=1", "?a=2", "?a=1&b=2", "?b=2&a=1", "?a=1&a=2", "?a=1&a=3", "?z=9",
918            ],
919        );
920
921        let accept_values = [None, Some("application/json"), Some("text/plain")];
922        let language_values = [None, Some("en-US"), Some("fr-FR")];
923        let content_type_values = [None, Some("application/json"), Some("application/xml")];
924        let authorization_values = [None, Some("Bearer alpha-token"), Some("Bearer beta-token")];
925        let api_key_values = [None, Some("alpha-api-key"), Some("beta-api-key")];
926
927        let mut hash_to_key: HashMap<u64, String> = HashMap::new();
928        let mut distinct_keys: HashSet<String> = HashSet::new();
929        let mut sample_count = 0usize;
930
931        for method in &methods {
932            for path in &paths {
933                for query in &queries {
934                    for accept in accept_values {
935                        for accept_language in language_values {
936                            for content_type in content_type_values {
937                                for authorization in authorization_values {
938                                    for api_key in api_key_values {
939                                        sample_count += 1;
940
941                                        let url = format!("https://example.test{}{}", path, query);
942                                        let request = build_request(
943                                            method.clone(),
944                                            &url,
945                                            &[
946                                                ("accept", accept),
947                                                ("accept-language", accept_language),
948                                                ("content-type", content_type),
949                                                ("authorization", authorization),
950                                                ("x-api-key", api_key),
951                                            ],
952                                        );
953
954                                        let cache_key = cache.generate_cache_key(&request);
955                                        let hash = compute_hash(cache_key.as_bytes());
956
957                                        if let Some(existing_key) = hash_to_key.get(&hash) {
958                                            assert_eq!(
959                                                existing_key, &cache_key,
960                                                "hash collision detected for distinct cache keys"
961                                            );
962                                        } else {
963                                            hash_to_key.insert(hash, cache_key.clone());
964                                        }
965
966                                        distinct_keys.insert(cache_key);
967                                    }
968                                }
969                            }
970                        }
971                    }
972                }
973            }
974        }
975
976        let expected_sample_count = methods.len()
977            * paths.len()
978            * queries.len()
979            * accept_values.len()
980            * language_values.len()
981            * content_type_values.len()
982            * authorization_values.len()
983            * api_key_values.len();
984        assert_eq!(sample_count, expected_sample_count);
985        assert!(
986            distinct_keys.len() > sample_count / 2,
987            "matrix generation produced too few distinct keys"
988        );
989    }
990
991    #[test]
992    fn cache_key_query_reordering_is_canonical_and_hash_stable() {
993        let cache = build_cache_for_tests();
994
995        let request_a = build_request(
996            Method::GET,
997            "https://example.test/resource?a=1&b=2",
998            &[("accept", Some("application/json"))],
999        );
1000        let request_b = build_request(
1001            Method::GET,
1002            "https://example.test/resource?b=2&a=1",
1003            &[("accept", Some("application/json"))],
1004        );
1005
1006        let key_a = cache.generate_cache_key(&request_a);
1007        let key_b = cache.generate_cache_key(&request_b);
1008
1009        assert_eq!(key_a, key_b);
1010        assert_eq!(
1011            compute_hash(key_a.as_bytes()),
1012            compute_hash(key_b.as_bytes())
1013        );
1014    }
1015
1016    #[test]
1017    fn cache_key_changes_for_each_response_affecting_dimension() {
1018        let cache = build_cache_for_tests();
1019
1020        let base_request = build_request(
1021            Method::GET,
1022            "https://example.test/resource?a=1&b=2",
1023            &[
1024                ("accept", Some("application/json")),
1025                ("accept-language", Some("en-US")),
1026                ("content-type", Some("application/json")),
1027                ("authorization", Some("Bearer alpha-token")),
1028                ("x-api-key", Some("alpha-api-key")),
1029            ],
1030        );
1031        let base_key = cache.generate_cache_key(&base_request);
1032        let base_hash = compute_hash(base_key.as_bytes());
1033
1034        let variants = vec![
1035            build_request(
1036                Method::POST,
1037                "https://example.test/resource?a=1&b=2",
1038                &[
1039                    ("accept", Some("application/json")),
1040                    ("accept-language", Some("en-US")),
1041                    ("content-type", Some("application/json")),
1042                    ("authorization", Some("Bearer alpha-token")),
1043                    ("x-api-key", Some("alpha-api-key")),
1044                ],
1045            ),
1046            build_request(
1047                Method::GET,
1048                "https://example.test/resource/v2?a=1&b=2",
1049                &[
1050                    ("accept", Some("application/json")),
1051                    ("accept-language", Some("en-US")),
1052                    ("content-type", Some("application/json")),
1053                    ("authorization", Some("Bearer alpha-token")),
1054                    ("x-api-key", Some("alpha-api-key")),
1055                ],
1056            ),
1057            build_request(
1058                Method::GET,
1059                "https://example.test/resource?a=99&b=2",
1060                &[
1061                    ("accept", Some("application/json")),
1062                    ("accept-language", Some("en-US")),
1063                    ("content-type", Some("application/json")),
1064                    ("authorization", Some("Bearer alpha-token")),
1065                    ("x-api-key", Some("alpha-api-key")),
1066                ],
1067            ),
1068            build_request(
1069                Method::GET,
1070                "https://example.test/resource?a=1&b=2",
1071                &[
1072                    ("accept", Some("text/plain")),
1073                    ("accept-language", Some("en-US")),
1074                    ("content-type", Some("application/json")),
1075                    ("authorization", Some("Bearer alpha-token")),
1076                    ("x-api-key", Some("alpha-api-key")),
1077                ],
1078            ),
1079            build_request(
1080                Method::GET,
1081                "https://example.test/resource?a=1&b=2",
1082                &[
1083                    ("accept", Some("application/json")),
1084                    ("accept-language", Some("fr-FR")),
1085                    ("content-type", Some("application/json")),
1086                    ("authorization", Some("Bearer alpha-token")),
1087                    ("x-api-key", Some("alpha-api-key")),
1088                ],
1089            ),
1090            build_request(
1091                Method::GET,
1092                "https://example.test/resource?a=1&b=2",
1093                &[
1094                    ("accept", Some("application/json")),
1095                    ("accept-language", Some("en-US")),
1096                    ("content-type", Some("application/xml")),
1097                    ("authorization", Some("Bearer alpha-token")),
1098                    ("x-api-key", Some("alpha-api-key")),
1099                ],
1100            ),
1101            build_request(
1102                Method::GET,
1103                "https://example.test/resource?a=1&b=2",
1104                &[
1105                    ("accept", Some("application/json")),
1106                    ("accept-language", Some("en-US")),
1107                    ("content-type", Some("application/json")),
1108                    ("authorization", Some("Bearer beta-token")),
1109                    ("x-api-key", Some("alpha-api-key")),
1110                ],
1111            ),
1112            build_request(
1113                Method::GET,
1114                "https://example.test/resource?a=1&b=2",
1115                &[
1116                    ("accept", Some("application/json")),
1117                    ("accept-language", Some("en-US")),
1118                    ("content-type", Some("application/json")),
1119                    ("authorization", Some("Bearer alpha-token")),
1120                    ("x-api-key", Some("beta-api-key")),
1121                ],
1122            ),
1123        ];
1124
1125        for variant in variants {
1126            let variant_key = cache.generate_cache_key(&variant);
1127            let variant_hash = compute_hash(variant_key.as_bytes());
1128
1129            assert_ne!(
1130                variant_key, base_key,
1131                "variant unexpectedly produced same key"
1132            );
1133            assert_ne!(
1134                variant_hash, base_hash,
1135                "variant unexpectedly produced same hash"
1136            );
1137        }
1138    }
1139
1140    /*
1141    Stress experiment (disabled):
1142    - This 100,000,000-sample collision test worked (no collisions observed).
1143    - End-to-end runtime was several hours, even in `--release` mode.
1144    - A Rayon parallelization attempt did not produce a meaningful speedup for
1145      this workload, so the test is commented out to keep normal test cycles fast.
1146
1147        // Kept as commented code because it is only used by the disabled stress
1148        // test below. If we re-enable that test in the future, this helper can be
1149        // uncommented together with it.
1150        // fn build_unique_request_from_index(index: u64) -> Request {
1151        //     let methods = [
1152        //         Method::GET,
1153        //         Method::HEAD,
1154        //         Method::POST,
1155        //         Method::PUT,
1156        //         Method::PATCH,
1157        //         Method::DELETE,
1158        //     ];
1159        //
1160        //     let method = methods[(index % methods.len() as u64) as usize].clone();
1161        //     let path = format!("/resource/{}/{}/{}", index % 97, index % 503, index % 9973);
1162        //
1163        //     let query = format!(
1164        //         "a={}&b={}&c={}&d={}",
1165        //         index,
1166        //         index.wrapping_mul(31),
1167        //         index.rotate_left(7),
1168        //         index ^ 0xA5A5_A5A5_A5A5_A5A5
1169        //     );
1170        //
1171        //     let url = format!("https://example.test{}?{}", path, query);
1172        //
1173        //     let accept_values = ["application/json", "text/plain", "STAR_SLASH_STAR"];
1174        //     let language_values = ["en-US", "fr-FR", "es-ES", "de-DE"];
1175        //     let content_type_values = ["application/json", "application/xml", "text/plain"];
1176        //
1177        //     let mut request = Request::new(
1178        //         method,
1179        //         reqwest::Url::parse(&url).expect("failed to parse stress URL"),
1180        //     );
1181        //
1182        //     request.headers_mut().insert(
1183        //         http::header::ACCEPT,
1184        //         http::header::HeaderValue::from_str(
1185        //             accept_values[(index % accept_values.len() as u64) as usize],
1186        //         )
1187        //         .expect("invalid accept header value"),
1188        //     );
1189        //     request.headers_mut().insert(
1190        //         http::header::ACCEPT_LANGUAGE,
1191        //         http::header::HeaderValue::from_str(
1192        //             language_values[(index % language_values.len() as u64) as usize],
1193        //         )
1194        //         .expect("invalid accept-language header value"),
1195        //     );
1196        //     request.headers_mut().insert(
1197        //         http::header::CONTENT_TYPE,
1198        //         http::header::HeaderValue::from_str(
1199        //             content_type_values[(index % content_type_values.len() as u64) as usize],
1200        //         )
1201        //         .expect("invalid content-type header value"),
1202        //     );
1203        //
1204        //     let authorization_value = format!("Bearer token-{:016x}", index);
1205        //     request.headers_mut().insert(
1206        //         http::header::AUTHORIZATION,
1207        //         http::header::HeaderValue::from_str(&authorization_value)
1208        //             .expect("invalid authorization header value"),
1209        //     );
1210        //
1211        //     let api_key_value = format!("api-key-{:016x}", index.rotate_right(11));
1212        //     request.headers_mut().insert(
1213        //         http::header::HeaderName::from_static("x-api-key"),
1214        //         http::header::HeaderValue::from_str(&api_key_value)
1215        //             .expect("invalid x-api-key header value"),
1216        //     );
1217        //
1218        //     request
1219        // }
1220
1221    #[test]
1222    #[ignore = "expensive: runs 100,000,000 samples"]
1223    fn cache_key_hash_collision_stress_100_million() {
1224        init_test_tracing();
1225
1226        let cache = build_cache_for_tests();
1227
1228        let samples = 100_000_000_u64;
1229        let mut seen_hashes: HashSet<u64> = HashSet::new();
1230        let started_at = Instant::now();
1231
1232        for index in 0..samples {
1233            let request = build_unique_request_from_index(index);
1234            let cache_key = cache.generate_cache_key(&request);
1235            let hash = compute_hash(cache_key.as_bytes());
1236
1237            assert!(
1238                seen_hashes.insert(hash),
1239                "hash collision detected in stress test at sample index {} (hash={})",
1240                index,
1241                hash
1242            );
1243
1244            let completed = index + 1;
1245            if completed % 10_000 == 0 {
1246                let elapsed = started_at.elapsed();
1247                let pct = (completed as f64 / samples as f64) * 100.0;
1248                tracing::info!(
1249                    "stress progress: {}/{} ({:.4}%) elapsed={:?}",
1250                    completed,
1251                    samples,
1252                    pct,
1253                    elapsed
1254                );
1255            }
1256        }
1257
1258        tracing::info!(
1259            "stress complete: {} samples in {:?}",
1260            samples,
1261            started_at.elapsed()
1262        );
1263    }
1264    */
1265}