Skip to main content

ic_bn_lib/http/
cache.rs

1use std::{
2    fmt::Debug,
3    marker::PhantomData,
4    mem::size_of,
5    sync::Arc,
6    time::{Duration, Instant},
7};
8
9use ahash::RandomState;
10use async_trait::async_trait;
11use axum::{body::Body, extract::Request, middleware::Next, response::Response};
12use bytes::Bytes;
13use http::{
14    Method,
15    header::{CACHE_CONTROL, RANGE},
16};
17use http_body::Body as _;
18use ic_bn_lib_common::{
19    traits::{
20        Run,
21        http::{Bypasser, CustomBypassReason, KeyExtractor},
22    },
23    types::http::{CacheBypassReason, CacheError, Error as HttpError},
24};
25use moka::{
26    Expiry,
27    sync::{Cache as MokaCache, CacheBuilder as MokaCacheBuilder},
28};
29use prometheus::{
30    Counter, CounterVec, Histogram, HistogramVec, IntGauge, Registry,
31    register_counter_vec_with_registry, register_counter_with_registry,
32    register_histogram_vec_with_registry, register_histogram_with_registry,
33    register_int_gauge_with_registry,
34};
35use sha1::{Digest, Sha1};
36use strum_macros::{Display, IntoStaticStr};
37use tokio::{select, sync::Mutex, time::sleep};
38use tokio_util::sync::CancellationToken;
39
40use super::{body::buffer_body, calc_headers_size, extract_authority};
41use crate::http::headers::X_CACHE_TTL;
42
43#[derive(Debug, Clone, Display, PartialEq, Eq, IntoStaticStr)]
44pub enum CustomBypassReasonDummy {}
45impl CustomBypassReason for CustomBypassReasonDummy {}
46
47/// Status of the cache lookup operation
48#[derive(Debug, Clone, Display, PartialEq, Eq, Default, IntoStaticStr)]
49#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
50pub enum CacheStatus<R: CustomBypassReason = CustomBypassReasonDummy> {
51    #[default]
52    Disabled,
53    Bypass(CacheBypassReason<R>),
54    Hit,
55    Miss,
56}
57
58impl<B: CustomBypassReason> CacheStatus<B> {
59    /// Injects itself into a given response to be accessible by middleware
60    pub fn with_response<T>(self, mut resp: Response<T>) -> Response<T> {
61        resp.extensions_mut().insert(self);
62        resp
63    }
64}
65
66enum ResponseType<R: CustomBypassReason> {
67    Fetched(Response<Bytes>, Duration),
68    Streamed(Response, CacheBypassReason<R>),
69}
70
71/// Cache entry
72#[derive(Clone)]
73struct Entry {
74    response: Response<Bytes>,
75    /// Time it took to generate the response for given entry.
76    /// Used for x-fetch algorithm.
77    delta: f64,
78    expires: Instant,
79}
80
81impl Entry {
82    /// Probabilistically decide if we need to refresh the given cache entry early.
83    /// This is an implementation of x-fetch algorigthm, see:
84    /// https://en.wikipedia.org/wiki/Cache_stampede#Probabilistic_early_expiration
85    fn need_to_refresh(&self, now: Instant, beta: f64) -> bool {
86        // fast path
87        if beta == 0.0 {
88            return false;
89        }
90
91        let rnd = rand::random::<f64>();
92        let xfetch = -(self.delta * beta * rnd.ln());
93        let ttl_left = (self.expires - now).as_secs_f64();
94
95        xfetch > ttl_left
96    }
97}
98
99/// No-op cache bypasser that never bypasses
100#[derive(Debug, Clone)]
101pub struct NoopBypasser;
102
103impl Bypasser for NoopBypasser {
104    type BypassReason = CustomBypassReasonDummy;
105
106    fn bypass<T>(&self, _req: &Request<T>) -> Result<Option<Self::BypassReason>, CacheError> {
107        Ok(None)
108    }
109}
110
111/// Cache metrics
112#[derive(Clone)]
113pub struct Metrics {
114    lock_await: HistogramVec,
115    requests_count: CounterVec,
116    requests_duration: HistogramVec,
117    ttl: Histogram,
118    x_fetch: Counter,
119    memory: IntGauge,
120    entries: IntGauge,
121}
122
123impl Metrics {
124    /// Create new `Metrics`
125    pub fn new(registry: &Registry) -> Self {
126        let lbls = &["cache_status", "cache_bypass_reason"];
127
128        Self {
129            lock_await: register_histogram_vec_with_registry!(
130                "cache_proxy_lock_await",
131                "Time spent waiting for the proxy cache lock",
132                &["lock_obtained"],
133                registry,
134            )
135            .unwrap(),
136
137            requests_count: register_counter_vec_with_registry!(
138                "cache_requests_count",
139                "Cache requests count",
140                lbls,
141                registry,
142            )
143            .unwrap(),
144
145            requests_duration: register_histogram_vec_with_registry!(
146                "cache_requests_duration",
147                "Time it took to execute the request",
148                lbls,
149                registry,
150            )
151            .unwrap(),
152
153            ttl: register_histogram_with_registry!(
154                "cache_ttl",
155                "TTL that was set when storing the response",
156                vec![1.0, 10.0, 100.0, 1000.0, 10000.0, 86400.0],
157                registry,
158            )
159            .unwrap(),
160
161            x_fetch: register_counter_with_registry!(
162                "cache_xfetch_count",
163                "Number of requests that x-fetch refreshed",
164                registry,
165            )
166            .unwrap(),
167
168            memory: register_int_gauge_with_registry!(
169                "cache_memory",
170                "Memory usage by the cache in bytes",
171                registry,
172            )
173            .unwrap(),
174
175            entries: register_int_gauge_with_registry!(
176                "cache_entries",
177                "Count of entries in the cache",
178                registry,
179            )
180            .unwrap(),
181        }
182    }
183}
184
185/// Cache options
186pub struct Opts {
187    pub cache_size: u64,
188    pub max_item_size: usize,
189    pub ttl: Duration,
190    pub max_ttl: Duration,
191    pub obey_cache_control: bool,
192    pub lock_timeout: Duration,
193    pub body_timeout: Duration,
194    pub xfetch_beta: f64,
195    pub methods: Vec<Method>,
196}
197
198impl Default for Opts {
199    fn default() -> Self {
200        Self {
201            cache_size: 128 * 1024 * 1024,
202            max_item_size: 16 * 1024 * 1024,
203            ttl: Duration::from_secs(10),
204            max_ttl: Duration::from_secs(86400),
205            obey_cache_control: false,
206            lock_timeout: Duration::from_secs(5),
207            body_timeout: Duration::from_secs(60),
208            xfetch_beta: 0.0,
209            methods: vec![Method::GET],
210        }
211    }
212}
213
/// Caching instruction derived from a `Cache-Control` response header
#[derive(Debug, PartialEq, Eq)]
enum CacheControl {
    /// The response must not be cached
    NoCache,
    /// The response may be cached for the given duration
    MaxAge(Duration),
}
219
220/// Tries to infer the caching TTL from the response headers
221fn infer_ttl<T>(req: &Response<T>) -> Option<CacheControl> {
222    // Extract the Cache-Control header & try to parse it as a string
223    let hdr = req
224        .headers()
225        .get(CACHE_CONTROL)
226        .and_then(|x| x.to_str().ok())?;
227
228    // Iterate over the key-value pairs (or just keys)
229    hdr.split(',').find_map(|x| {
230        let (k, v) = {
231            let mut split = x.split('=').map(|s| s.trim());
232            (split.next().unwrap(), split.next())
233        };
234
235        if ["no-cache", "no-store"].contains(&k) {
236            Some(CacheControl::NoCache)
237        } else if k == "max-age" {
238            let v = v.and_then(|x| x.parse::<u64>().ok());
239            if v == Some(0) {
240                Some(CacheControl::NoCache)
241            } else {
242                v.map(|x| CacheControl::MaxAge(Duration::from_secs(x)))
243            }
244        } else {
245            None
246        }
247    })
248}
249
250/// Extracts TTL from the Entry
251struct Expirer<K: KeyExtractor>(PhantomData<K>);
252
253impl<K: KeyExtractor> Expiry<K::Key, Arc<Entry>> for Expirer<K> {
254    fn expire_after_create(
255        &self,
256        _key: &K::Key,
257        value: &Arc<Entry>,
258        created_at: Instant,
259    ) -> Option<Duration> {
260        Some(value.expires - created_at)
261    }
262}
263
264/// Builds a cache using some overridable defaults
265pub struct CacheBuilder<K: KeyExtractor, B: Bypasser> {
266    key_extractor: K,
267    bypasser: Option<B>,
268    opts: Opts,
269    registry: Registry,
270}
271
272impl<K: KeyExtractor> CacheBuilder<K, NoopBypasser> {
273    /// Create new `CacheBuilder`
274    pub fn new(key_extractor: K) -> Self {
275        Self {
276            key_extractor,
277            bypasser: None,
278            opts: Opts::default(),
279            registry: Registry::new(),
280        }
281    }
282}
283
284impl<K: KeyExtractor, B: Bypasser> CacheBuilder<K, B> {
285    /// Create new `CacheBuilder` with a bypasser
286    pub fn new_with_bypasser(key_extractor: K, bypasser: B) -> Self {
287        Self {
288            key_extractor,
289            bypasser: Some(bypasser),
290            opts: Opts::default(),
291            registry: Registry::new(),
292        }
293    }
294
295    /// Sets the cache size. Default 128MB.
296    pub const fn cache_size(mut self, v: u64) -> Self {
297        self.opts.cache_size = v;
298        self
299    }
300
301    /// Sets the maximum entry size. Default 16MB.
302    pub const fn max_item_size(mut self, v: usize) -> Self {
303        self.opts.max_item_size = v;
304        self
305    }
306
307    /// Sets the default cache entry TTL. Default 10 sec.
308    pub const fn ttl(mut self, v: Duration) -> Self {
309        self.opts.ttl = v;
310        self
311    }
312
313    /// Sets the maximum cache entry TTL that can be overriden by `Cache-Control` header. Default 1 day.
314    pub const fn max_ttl(mut self, v: Duration) -> Self {
315        self.opts.max_ttl = v;
316        self
317    }
318
319    /// Sets the cache lock timeout. Default 5 sec.
320    pub const fn lock_timeout(mut self, v: Duration) -> Self {
321        self.opts.lock_timeout = v;
322        self
323    }
324
325    /// Sets the body reading timeout. Default 1 min.
326    pub const fn body_timeout(mut self, v: Duration) -> Self {
327        self.opts.body_timeout = v;
328        self
329    }
330
331    /// Sets the beta term of X-Fetch algorithm. Default 0.0
332    pub const fn xfetch_beta(mut self, v: f64) -> Self {
333        self.opts.xfetch_beta = v;
334        self
335    }
336
337    /// Sets cacheable methods. Defaults to only GET.
338    pub fn methods(mut self, v: &[Method]) -> Self {
339        self.opts.methods = v.into();
340        self
341    }
342
343    /// Whether to obey `Cache-Control` headers in the *response*. Defaults to false.
344    pub const fn obey_cache_control(mut self, v: bool) -> Self {
345        self.opts.obey_cache_control = v;
346        self
347    }
348
349    /// Sets the metrics registry to use.
350    pub fn registry(mut self, v: &Registry) -> Self {
351        self.registry = v.clone();
352        self
353    }
354
355    /// Try to build the cache from this builder
356    pub fn build(self) -> Result<Cache<K, B>, CacheError> {
357        Cache::new(self.opts, self.key_extractor, self.bypasser, &self.registry)
358    }
359}
360
361/// HTTP Cache
362pub struct Cache<K: KeyExtractor, B: Bypasser = NoopBypasser> {
363    store: MokaCache<K::Key, Arc<Entry>, RandomState>,
364    locks: MokaCache<K::Key, Arc<Mutex<()>>, RandomState>,
365    key_extractor: K,
366    bypasser: Option<B>,
367    metrics: Metrics,
368    opts: Opts,
369}
370
371fn weigh_entry<K: KeyExtractor>(_k: &K::Key, v: &Arc<Entry>) -> u32 {
372    let mut size = size_of::<K::Key>() + size_of::<Arc<Entry>>();
373
374    size += calc_headers_size(v.response.headers());
375    size += v.response.body().len();
376
377    size as u32
378}
379
380impl<K: KeyExtractor + 'static, B: Bypasser + 'static> Cache<K, B> {
381    /// Create new `Cache`
382    pub fn new(
383        opts: Opts,
384        key_extractor: K,
385        bypasser: Option<B>,
386        registry: &Registry,
387    ) -> Result<Self, CacheError> {
388        if opts.max_item_size as u64 >= opts.cache_size {
389            return Err(CacheError::Other(
390                "Cache item size should be less than whole cache size".into(),
391            ));
392        }
393
394        if opts.ttl > opts.max_ttl {
395            return Err(CacheError::Other("TTL should be <= max TTL".into()));
396        }
397
398        Ok(Self {
399            store: MokaCacheBuilder::new(opts.cache_size)
400                .expire_after(Expirer::<K>(PhantomData))
401                .weigher(weigh_entry::<K>)
402                .build_with_hasher(RandomState::default()),
403
404            // The params of the lock cache are somewhat arbitrary, maybe needs tuning
405            locks: MokaCacheBuilder::new(32768)
406                .time_to_idle(Duration::from_secs(60))
407                .build_with_hasher(RandomState::default()),
408
409            key_extractor,
410            bypasser,
411            metrics: Metrics::new(registry),
412
413            opts,
414        })
415    }
416
417    /// Looks up the given entry
418    pub fn get(&self, key: &K::Key, now: Instant, beta: f64) -> Option<Response> {
419        let val = self.store.get(key)?;
420
421        // Run x-fetch if configured and simulate the cache miss if we need to refresh the entry
422        if val.need_to_refresh(now, beta) {
423            self.metrics.x_fetch.inc();
424            return None;
425        }
426
427        let (parts, body) = val.response.clone().into_parts();
428        Some(Response::from_parts(parts, Body::from(body)))
429    }
430
431    /// Insert a new entry into the cache
432    pub fn insert(
433        &self,
434        key: K::Key,
435        now: Instant,
436        ttl: Duration,
437        delta: Duration,
438        response: Response<Bytes>,
439    ) {
440        self.metrics.ttl.observe(ttl.as_secs_f64());
441
442        self.store.insert(
443            key,
444            Arc::new(Entry {
445                response,
446                delta: delta.as_secs_f64(),
447                expires: now + ttl,
448            }),
449        );
450    }
451
452    /// Process the HTTP request
453    pub async fn process_request(
454        &self,
455        request: Request,
456        next: Next,
457    ) -> Result<Response, CacheError> {
458        let now = Instant::now();
459        let (cache_status, response) = self.process_inner(now, request, next).await?;
460
461        // Record metrics
462        let cache_status_str: &'static str = (&cache_status).into();
463        let cache_bypass_reason_str: &'static str = match cache_status.clone() {
464            CacheStatus::Bypass(v) => v.into_str(),
465            _ => "none",
466        };
467
468        let labels = &[cache_status_str, cache_bypass_reason_str];
469
470        self.metrics.requests_count.with_label_values(labels).inc();
471        self.metrics
472            .requests_duration
473            .with_label_values(labels)
474            .observe(now.elapsed().as_secs_f64());
475
476        Ok(cache_status.with_response(response))
477    }
478
479    async fn process_inner(
480        &self,
481        now: Instant,
482        request: Request,
483        next: Next,
484    ) -> Result<(CacheStatus<B::BypassReason>, Response), CacheError> {
485        // Check if we have bypasser configured
486        if let Some(b) = &self.bypasser {
487            // Run it
488            if let Ok(v) = b.bypass(&request) {
489                // If it decided to bypass - return the custom reason
490                if let Some(r) = v {
491                    return Ok((
492                        CacheStatus::Bypass(CacheBypassReason::Custom(r)),
493                        next.run(request).await,
494                    ));
495                }
496            } else {
497                return Ok((
498                    CacheStatus::Bypass(CacheBypassReason::UnableToRunBypasser),
499                    next.run(request).await,
500                ));
501            }
502        }
503
504        // Check the method
505        if !self.opts.methods.contains(request.method()) {
506            return Ok((
507                CacheStatus::Bypass(CacheBypassReason::MethodNotCacheable),
508                next.run(request).await,
509            ));
510        }
511
512        // Use cached response if found
513        let Ok(key) = self.key_extractor.extract(&request) else {
514            return Ok((
515                CacheStatus::Bypass(CacheBypassReason::UnableToExtractKey),
516                next.run(request).await,
517            ));
518        };
519
520        if let Some(v) = self.get(&key, now, self.opts.xfetch_beta) {
521            return Ok((CacheStatus::Hit, v));
522        }
523
524        // Get synchronization lock to handle parallel requests.
525        let lock = self
526            .locks
527            .get_with_by_ref(&key, || Arc::new(Mutex::new(())));
528
529        let mut lock_obtained = false;
530        select! {
531            // Only one parallel request should execute the response and populate the cache.
532            // Other requests will wait for the lock to be released and get results from the cache.
533            _ = lock.lock() => {
534                lock_obtained = true;
535            }
536
537            // We proceed with the request as is if takes too long to get the lock
538            _ = sleep(self.opts.lock_timeout) => {}
539        }
540
541        // Record prometheus metrics for the time spent waiting for the lock.
542        self.metrics
543            .lock_await
544            .with_label_values(&[if lock_obtained { "yes" } else { "no" }])
545            .observe(now.elapsed().as_secs_f64());
546
547        // Check again the cache in case some other request filled it
548        // while we were waiting for the lock
549        if let Some(v) = self.get(&key, now, 0.0) {
550            return Ok((CacheStatus::Hit, v));
551        }
552
553        // Otherwise pass the request forward
554        let now = Instant::now();
555        Ok(match self.pass_request(request, next).await? {
556            // If the body was fetched - cache it
557            ResponseType::Fetched(v, ttl) => {
558                let delta = now.elapsed();
559                self.insert(key, now + delta, ttl, delta, v.clone());
560
561                let (mut parts, body) = v.into_parts();
562                parts.headers.insert(X_CACHE_TTL, ttl.as_secs().into());
563                let response = Response::from_parts(parts, Body::from(body));
564                (CacheStatus::Miss, response)
565            }
566
567            // Otherwise just pass it up
568            ResponseType::Streamed(v, reason) => (CacheStatus::Bypass(reason), v),
569        })
570    }
571
572    // Passes the request down the line and conditionally fetches the response body
573    async fn pass_request(
574        &self,
575        request: Request,
576        next: Next,
577    ) -> Result<ResponseType<B::BypassReason>, CacheError> {
578        // Execute the response & get the headers
579        let response = next.run(request).await;
580
581        // Do not cache non-2xx responses
582        if !response.status().is_success() {
583            return Ok(ResponseType::Streamed(
584                response,
585                CacheBypassReason::HTTPError,
586            ));
587        }
588
589        // Extract content length from the response header if there's one
590        let body_size = response.body().size_hint().exact().map(|x| x as usize);
591
592        // Do not cache responses that have no known size (probably streaming etc)
593        let Some(body_size) = body_size else {
594            return Ok(ResponseType::Streamed(
595                response,
596                CacheBypassReason::SizeUnknown,
597            ));
598        };
599
600        // Do not cache items larger than configured
601        if body_size > self.opts.max_item_size {
602            return Ok(ResponseType::Streamed(
603                response,
604                CacheBypassReason::BodyTooBig,
605            ));
606        }
607
608        // Infer the TTL if requested to obey Cache-Control headers
609        let ttl = if self.opts.obey_cache_control {
610            let ttl = infer_ttl(&response);
611
612            match ttl {
613                // Do not cache if we're asked not to
614                Some(CacheControl::NoCache) => {
615                    return Ok(ResponseType::Streamed(
616                        response,
617                        CacheBypassReason::CacheControl,
618                    ));
619                }
620
621                // Use TTL from max-age capping it to max_ttl
622                Some(CacheControl::MaxAge(v)) => v.min(self.opts.max_ttl),
623
624                // Otherwise use default
625                None => self.opts.ttl,
626            }
627        } else {
628            self.opts.ttl
629        };
630
631        // Read the response body into a buffer
632        let (parts, body) = response.into_parts();
633        let body = buffer_body(body, body_size, self.opts.body_timeout)
634            .await
635            .map_err(|e| match e {
636                HttpError::BodyTooBig => CacheError::FetchBodyTooBig,
637                HttpError::BodyTimedOut => CacheError::FetchBodyTimeout,
638                _ => CacheError::FetchBody(e.to_string()),
639            })?;
640
641        Ok(ResponseType::Fetched(
642            Response::from_parts(parts, body),
643            ttl,
644        ))
645    }
646}
647
648#[async_trait]
649impl<K: KeyExtractor, B: Bypasser> Run for Cache<K, B> {
650    async fn run(&self, _: CancellationToken) -> Result<(), anyhow::Error> {
651        self.store.run_pending_tasks();
652        self.metrics.memory.set(self.store.weighted_size() as i64);
653        self.metrics.entries.set(self.store.entry_count() as i64);
654        Ok(())
655    }
656}
657
#[cfg(test)]
impl<K: KeyExtractor + 'static, B: Bypasser + 'static> Cache<K, B> {
    /// Runs pending maintenance on both the entry store and the lock store.
    pub fn housekeep(&self) {
        self.store.run_pending_tasks();
        self.locks.run_pending_tasks();
    }

    /// Weighted size of the entry store.
    pub fn size(&self) -> u64 {
        self.store.weighted_size()
    }

    /// Number of entries in the entry store.
    // Test-only helper, an `is_empty` counterpart isn't needed
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> u64 {
        self.store.entry_count()
    }

    /// Drops all entries and locks, then runs maintenance so the effect is visible at once.
    pub fn clear(&self) {
        self.store.invalidate_all();
        self.locks.invalidate_all();
        self.housekeep();
    }
}
680
681/// Key extractor that is keyed by URI and a `Range` header
682#[derive(Clone, Debug)]
683pub struct KeyExtractorUriRange;
684
685impl KeyExtractor for KeyExtractorUriRange {
686    type Key = [u8; 20];
687
688    fn extract<T>(&self, request: &Request<T>) -> Result<Self::Key, CacheError> {
689        let authority = extract_authority(request)
690            .ok_or_else(|| CacheError::ExtractKey("no authority found".into()))?
691            .as_bytes();
692        let paq = request
693            .uri()
694            .path_and_query()
695            .ok_or_else(|| CacheError::ExtractKey("no path_and_query found".into()))?
696            .as_str()
697            .as_bytes();
698
699        // Compute a composite hash
700        let mut hash = Sha1::new().chain_update(authority).chain_update(paq);
701        if let Some(v) = request.headers().get(RANGE) {
702            hash = hash.chain_update(v.as_bytes());
703        }
704
705        Ok(hash.finalize().into())
706    }
707}
708
709#[cfg(test)]
710mod tests {
711    use crate::hval;
712
713    use super::*;
714
715    use axum::{
716        Router,
717        body::to_bytes,
718        extract::State,
719        middleware::from_fn_with_state,
720        response::IntoResponse,
721        routing::{get, post},
722    };
723    use http::{Request, Response, StatusCode, Uri};
724    use sha1::Digest;
725    use tower::{Service, ServiceExt};
726
727    #[derive(Clone, Debug)]
728    pub struct KeyExtractorTest;
729
730    impl KeyExtractor for KeyExtractorTest {
731        type Key = [u8; 20];
732
733        fn extract<T>(&self, request: &Request<T>) -> Result<Self::Key, CacheError> {
734            let paq = request
735                .uri()
736                .path_and_query()
737                .ok_or_else(|| CacheError::ExtractKey("no path_and_query found".into()))?
738                .as_str()
739                .as_bytes();
740
741            let hash: [u8; 20] = sha1::Sha1::new().chain_update(paq).finalize().into();
742            Ok(hash)
743        }
744    }
745
746    const MAX_ITEM_SIZE: usize = 1024;
747    const MAX_CACHE_SIZE: u64 = 32768;
748    const PROXY_LOCK_TIMEOUT: Duration = Duration::from_secs(1);
749
750    async fn dispatch_get_request(router: &mut Router, uri: String) -> Option<CacheStatus> {
751        let req = Request::get(uri).body(Body::from("")).unwrap();
752        let result = router.call(req).await.unwrap();
753        assert_eq!(result.status(), StatusCode::OK);
754        result.extensions().get::<CacheStatus>().cloned()
755    }
756
757    async fn handler(_request: Request<Body>) -> impl IntoResponse {
758        "test_body"
759    }
760
761    async fn handler_proxy_cache_lock(request: Request<Body>) -> impl IntoResponse {
762        if request.uri().path().contains("slow_response") {
763            sleep(2 * PROXY_LOCK_TIMEOUT).await;
764        }
765
766        "test_body"
767    }
768
769    async fn handler_too_big(_request: Request<Body>) -> impl IntoResponse {
770        "a".repeat(MAX_ITEM_SIZE + 1)
771    }
772
773    async fn handler_cache_control_max_age_1d(_request: Request<Body>) -> impl IntoResponse {
774        [(CACHE_CONTROL, "max-age=86400")]
775    }
776
777    async fn handler_cache_control_max_age_7d(_request: Request<Body>) -> impl IntoResponse {
778        [(CACHE_CONTROL, "max-age=604800")]
779    }
780
781    async fn handler_cache_control_no_cache(_request: Request<Body>) -> impl IntoResponse {
782        [(CACHE_CONTROL, "no-cache")]
783    }
784
785    async fn handler_cache_control_no_store(_request: Request<Body>) -> impl IntoResponse {
786        [(CACHE_CONTROL, "no-store")]
787    }
788
789    async fn middleware(
790        State(cache): State<Arc<Cache<KeyExtractorTest>>>,
791        request: Request<Body>,
792        next: Next,
793    ) -> impl IntoResponse {
794        cache
795            .process_request(request, next)
796            .await
797            .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
798    }
799
800    #[test]
801    fn test_bypass_reason_serialize() {
802        #[derive(Debug, Clone, Display, PartialEq, Eq, IntoStaticStr)]
803        #[strum(serialize_all = "snake_case")]
804        enum CustomReasonTest {
805            Bar,
806        }
807        impl CustomBypassReason for CustomReasonTest {}
808
809        let a: CacheBypassReason<CustomReasonTest> =
810            CacheBypassReason::Custom(CustomReasonTest::Bar);
811        let txt = a.into_str();
812        assert_eq!(txt, "bar");
813
814        let a: CacheBypassReason<CustomReasonTest> = CacheBypassReason::BodyTooBig;
815        let txt = a.into_str();
816        assert_eq!(txt, "body_too_big");
817    }
818
819    #[test]
820    fn test_key_extractor_uri_range() {
821        let x = KeyExtractorUriRange;
822
823        // Baseline
824        let mut req = Request::new("foo");
825        *req.uri_mut() = Uri::from_static("http://foo.bar.baz:80/foo/bar?abc=1");
826        let key1 = x.extract(&req).unwrap();
827
828        // Make sure that changing authority/path/query changes the key
829        let mut req = Request::new("foo");
830        *req.uri_mut() = Uri::from_static("http://foo.bar.baz:80/foo/bar?abc=2");
831        let key2 = x.extract(&req).unwrap();
832        assert_ne!(key1, key2);
833
834        let mut req = Request::new("foo");
835        *req.uri_mut() = Uri::from_static("http://foo.bar.baz:80/foo/ba?abc=1");
836        let key2 = x.extract(&req).unwrap();
837        assert_ne!(key1, key2);
838
839        let mut req = Request::new("foo");
840        *req.uri_mut() = Uri::from_static("http://foo.bar.ba:80/foo/bar?abc=1");
841        let key2 = x.extract(&req).unwrap();
842        assert_ne!(key1, key2);
843
844        // Make sure that changing schema doesn't affect the key
845        let mut req = Request::new("foo");
846        *req.uri_mut() = Uri::from_static("https://foo.bar.baz:80/foo/bar?abc=1");
847        let key2 = x.extract(&req).unwrap();
848        assert_eq!(key1, key2);
849
850        // Make sure that adding Range header changes the key
851        let mut req = Request::new("foo");
852        *req.uri_mut() = Uri::from_static("http://foo.bar.bar:80/foo/bar?abc=1");
853        (*req.headers_mut()).insert(RANGE, hval!("1000-2000"));
854        let key2 = x.extract(&req).unwrap();
855        assert_ne!(key1, key2);
856    }
857
858    #[test]
859    fn test_infer_ttl() {
860        let mut req = Response::new(());
861
862        assert_eq!(infer_ttl(&req), None);
863
864        // Don't cache
865        req.headers_mut().insert(CACHE_CONTROL, hval!("no-cache"));
866        assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
867
868        req.headers_mut().insert(CACHE_CONTROL, hval!("no-store"));
869        assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
870
871        req.headers_mut()
872            .insert(CACHE_CONTROL, hval!("no-store, no-cache"));
873        assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
874
875        // Order matters
876        req.headers_mut()
877            .insert(CACHE_CONTROL, hval!("no-store, no-cache, max-age=1"));
878        assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
879
880        req.headers_mut()
881            .insert(CACHE_CONTROL, hval!("max-age=1, no-store, no-cache"));
882        assert_eq!(
883            infer_ttl(&req),
884            Some(CacheControl::MaxAge(Duration::from_secs(1)))
885        );
886
887        // Max-age
888        req.headers_mut()
889            .insert(CACHE_CONTROL, hval!("max-age=86400"));
890        assert_eq!(
891            infer_ttl(&req),
892            Some(CacheControl::MaxAge(Duration::from_secs(86400)))
893        );
894        req.headers_mut().insert(CACHE_CONTROL, hval!("max-age=0"));
895        assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
896
897        req.headers_mut()
898            .insert(CACHE_CONTROL, hval!("max-age=foo"));
899        assert_eq!(infer_ttl(&req), None);
900
901        req.headers_mut().insert(CACHE_CONTROL, hval!("max-age="));
902        assert_eq!(infer_ttl(&req), None);
903
904        req.headers_mut().insert(CACHE_CONTROL, hval!("max-age=-1"));
905        assert_eq!(infer_ttl(&req), None);
906
907        // Empty
908        req.headers_mut().insert(CACHE_CONTROL, hval!(""));
909        assert_eq!(infer_ttl(&req), None);
910
911        // Broken
912        req.headers_mut()
913            .insert(CACHE_CONTROL, hval!(", =foobar, "));
914        assert_eq!(infer_ttl(&req), None);
915    }
916
917    #[test]
918    fn test_cache_creation_errors() {
919        let cache = CacheBuilder::new(KeyExtractorTest)
920            .cache_size(1)
921            .max_item_size(2)
922            .build();
923        assert!(cache.is_err());
924
925        let cache = CacheBuilder::new(KeyExtractorTest)
926            .ttl(Duration::from_secs(2))
927            .max_ttl(Duration::from_secs(1))
928            .build();
929        assert!(cache.is_err());
930    }
931
932    #[tokio::test]
933    async fn test_cache_bypass() {
934        let cache = Arc::new(
935            CacheBuilder::new(KeyExtractorTest)
936                .max_item_size(MAX_ITEM_SIZE)
937                .build()
938                .unwrap(),
939        );
940
941        let mut app = Router::new()
942            .route("/", post(handler))
943            .route("/", get(handler))
944            .route("/too_big", get(handler_too_big))
945            .layer(from_fn_with_state(Arc::clone(&cache), middleware));
946
947        // Test only GET requests are cached.
948        let req = Request::post("/").body(Body::from("")).unwrap();
949        let result = app.call(req).await.unwrap();
950        assert_eq!(result.status(), StatusCode::OK);
951        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
952        assert_eq!(cache.len(), 0);
953        assert_eq!(
954            cache_status,
955            CacheStatus::Bypass(CacheBypassReason::MethodNotCacheable)
956        );
957
958        // Test non-2xx response are not cached
959        let req = Request::get("/non_existing_path")
960            .body(Body::from("foobar"))
961            .unwrap();
962        let result = app.call(req).await.unwrap();
963        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
964        assert_eq!(result.status(), StatusCode::NOT_FOUND);
965        assert_eq!(
966            cache_status,
967            CacheStatus::Bypass(CacheBypassReason::HTTPError)
968        );
969        assert_eq!(cache.len(), 0);
970
971        // Test body too big
972        let req = Request::get("/too_big").body(Body::from("foobar")).unwrap();
973        let result = app.call(req).await.unwrap();
974        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
975        assert_eq!(
976            cache_status,
977            CacheStatus::Bypass(CacheBypassReason::BodyTooBig)
978        );
979        assert_eq!(result.status(), StatusCode::OK);
980        assert_eq!(cache.len(), 0);
981    }
982
983    #[tokio::test]
984    async fn test_cache_hit() {
985        let ttl = Duration::from_millis(1500);
986
987        let cache = Arc::new(
988            CacheBuilder::new(KeyExtractorTest)
989                .cache_size(MAX_CACHE_SIZE)
990                .max_item_size(MAX_ITEM_SIZE)
991                .ttl(ttl)
992                .build()
993                .unwrap(),
994        );
995
996        let mut app = Router::new()
997            .route("/{key}", get(handler))
998            .layer(from_fn_with_state(Arc::clone(&cache), middleware));
999
1000        // First request doesn't hit the cache, but is stored in the cache
1001        let req = Request::get("/1").body(Body::from("")).unwrap();
1002        let result = app.call(req).await.unwrap();
1003        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1004        assert_eq!(result.status(), StatusCode::OK);
1005        assert_eq!(cache_status, CacheStatus::Miss);
1006        cache.housekeep();
1007        assert_eq!(cache.len(), 1);
1008
1009        // Next request doesn't hit the cache, but is stored in the cache
1010        let req = Request::get("/2").body(Body::from("")).unwrap();
1011        let result = app.call(req).await.unwrap();
1012        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1013        assert_eq!(result.status(), StatusCode::OK);
1014        assert_eq!(cache_status, CacheStatus::Miss);
1015        cache.housekeep();
1016        assert_eq!(cache.len(), 2);
1017
1018        // Next request hits the cache
1019        let req = Request::get("/1").body(Body::from("")).unwrap();
1020        let result = app.call(req).await.unwrap();
1021        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1022        assert_eq!(result.status(), StatusCode::OK);
1023        assert_eq!(cache_status, CacheStatus::Hit);
1024        let (_, body) = result.into_parts();
1025        let body = to_bytes(body, usize::MAX).await.unwrap().to_vec();
1026        let body = String::from_utf8_lossy(&body);
1027        assert_eq!("test_body", body);
1028        cache.housekeep();
1029        assert_eq!(cache.len(), 2);
1030
1031        // Next request hits again
1032        let req = Request::get("/2").body(Body::from("")).unwrap();
1033        let result = app.call(req).await.unwrap();
1034        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1035        assert_eq!(result.status(), StatusCode::OK);
1036        assert_eq!(cache_status, CacheStatus::Hit);
1037        let (_, body) = result.into_parts();
1038        let body = to_bytes(body, usize::MAX).await.unwrap().to_vec();
1039        let body = String::from_utf8_lossy(&body);
1040        assert_eq!("test_body", body);
1041        cache.housekeep();
1042        assert_eq!(cache.len(), 2);
1043
1044        // After ttl, request doesn't hit the cache anymore
1045        sleep(ttl + Duration::from_millis(300)).await;
1046        cache.housekeep();
1047        let req = Request::get("/1").body(Body::from("")).unwrap();
1048        let result = app.call(req).await.unwrap();
1049        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1050        assert_eq!(result.status(), StatusCode::OK);
1051        assert_eq!(cache_status, CacheStatus::Miss);
1052
1053        // Before cache_size limit is reached all requests should be stored in cache.
1054        cache.clear();
1055        let req_count = 50;
1056        // First dispatch round, all requests miss cache.
1057        for idx in 0..req_count {
1058            let status = dispatch_get_request(&mut app, format!("/{idx}")).await;
1059            assert_eq!(status.unwrap(), CacheStatus::Miss);
1060        }
1061        // Second dispatch round, all requests hit the cache.
1062        for idx in 0..req_count {
1063            let status = dispatch_get_request(&mut app, format!("/{idx}")).await;
1064            assert_eq!(status.unwrap(), CacheStatus::Hit);
1065        }
1066
1067        // Once cache_size limit is reached some requests should be evicted.
1068        cache.clear();
1069        let req_count = 500;
1070        // First dispatch round, all cache misses.
1071        for idx in 0..req_count {
1072            let status = dispatch_get_request(&mut app, format!("/{idx}")).await;
1073            assert_eq!(status.unwrap(), CacheStatus::Miss);
1074        }
1075
1076        // Second dispatch round, some requests hit the cache, some don't
1077        let mut count_misses = 0;
1078        let mut count_hits = 0;
1079        for idx in 0..req_count {
1080            let status = dispatch_get_request(&mut app, format!("/{idx}")).await;
1081            if status == Some(CacheStatus::Miss) {
1082                count_misses += 1;
1083            } else if status == Some(CacheStatus::Hit) {
1084                count_hits += 1;
1085            }
1086        }
1087        assert!(count_misses > 0);
1088        assert!(count_hits > 0);
1089        cache.housekeep();
1090        let entry_size = cache.size() / cache.len();
1091
1092        // Make sure cache size limit was reached.
1093        // Check that adding one more entry to the cache would overflow its max capacity.
1094        assert!(MAX_CACHE_SIZE > cache.size());
1095        assert!(MAX_CACHE_SIZE < cache.size() + entry_size);
1096    }
1097
1098    #[tokio::test]
1099    async fn test_cache_control() {
1100        let cache = Arc::new(
1101            CacheBuilder::new(KeyExtractorTest)
1102                .obey_cache_control(true)
1103                .build()
1104                .unwrap(),
1105        );
1106
1107        let mut app = Router::new()
1108            .route("/", get(handler))
1109            .route(
1110                "/cache_control_no_store",
1111                get(handler_cache_control_no_store),
1112            )
1113            .route(
1114                "/cache_control_no_cache",
1115                get(handler_cache_control_no_cache),
1116            )
1117            .route(
1118                "/cache_control_max_age_1d",
1119                get(handler_cache_control_max_age_1d),
1120            )
1121            .route(
1122                "/cache_control_max_age_7d",
1123                get(handler_cache_control_max_age_7d),
1124            )
1125            .layer(from_fn_with_state(Arc::clone(&cache), middleware));
1126
1127        // Cache-Control no-cache
1128        let req = Request::get("/cache_control_no_cache")
1129            .body(Body::from("foobar"))
1130            .unwrap();
1131        let result = app.call(req).await.unwrap();
1132        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1133        assert_eq!(
1134            cache_status,
1135            CacheStatus::Bypass(CacheBypassReason::CacheControl)
1136        );
1137        assert_eq!(result.status(), StatusCode::OK);
1138        cache.housekeep();
1139        assert_eq!(cache.len(), 0);
1140
1141        // Cache-Control no-store
1142        let req = Request::get("/cache_control_no_store")
1143            .body(Body::from("foobar"))
1144            .unwrap();
1145        let result = app.call(req).await.unwrap();
1146        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1147        assert_eq!(
1148            cache_status,
1149            CacheStatus::Bypass(CacheBypassReason::CacheControl)
1150        );
1151        assert_eq!(result.status(), StatusCode::OK);
1152        cache.housekeep();
1153        assert_eq!(cache.len(), 0);
1154
1155        // Cache-Control max-age 1 day
1156        let req = Request::get("/cache_control_max_age_1d")
1157            .body(Body::from("foobar"))
1158            .unwrap();
1159        let result = app.call(req).await.unwrap();
1160        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1161        let ttl = result
1162            .headers()
1163            .get(X_CACHE_TTL)
1164            .unwrap()
1165            .to_str()
1166            .unwrap()
1167            .parse::<u64>()
1168            .unwrap();
1169        assert_eq!(cache_status, CacheStatus::Miss);
1170        assert_eq!(ttl, 86400);
1171        assert_eq!(result.status(), StatusCode::OK);
1172        cache.housekeep();
1173        assert_eq!(cache.len(), 1);
1174
1175        // Cache-Control max-age 7 days should still be capped to 1 day
1176        let req = Request::get("/cache_control_max_age_7d")
1177            .body(Body::from("foobar"))
1178            .unwrap();
1179        let result = app.call(req).await.unwrap();
1180        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1181        let ttl = result
1182            .headers()
1183            .get(X_CACHE_TTL)
1184            .unwrap()
1185            .to_str()
1186            .unwrap()
1187            .parse::<u64>()
1188            .unwrap();
1189        assert_eq!(cache_status, CacheStatus::Miss);
1190        assert_eq!(ttl, 86400);
1191        assert_eq!(result.status(), StatusCode::OK);
1192        cache.housekeep();
1193        assert_eq!(cache.len(), 2);
1194
1195        // w/o Cache-Control we should get a default 10s TTL
1196        let req = Request::get("/").body(Body::from("foobar")).unwrap();
1197        let result = app.call(req).await.unwrap();
1198        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1199        let ttl = result
1200            .headers()
1201            .get(X_CACHE_TTL)
1202            .unwrap()
1203            .to_str()
1204            .unwrap()
1205            .parse::<u64>()
1206            .unwrap();
1207        assert_eq!(cache_status, CacheStatus::Miss);
1208        assert_eq!(ttl, 10);
1209        assert_eq!(result.status(), StatusCode::OK);
1210        cache.housekeep();
1211        assert_eq!(cache.len(), 3);
1212
1213        // Test when we do not obey
1214        let cache = Arc::new(
1215            CacheBuilder::new(KeyExtractorTest)
1216                .obey_cache_control(false)
1217                .build()
1218                .unwrap(),
1219        );
1220
1221        let mut app = Router::new()
1222            .route("/", get(handler))
1223            .route(
1224                "/cache_control_no_store",
1225                get(handler_cache_control_no_store),
1226            )
1227            .route(
1228                "/cache_control_no_cache",
1229                get(handler_cache_control_no_cache),
1230            )
1231            .route(
1232                "/cache_control_max_age_1d",
1233                get(handler_cache_control_max_age_1d),
1234            )
1235            .route(
1236                "/cache_control_max_age_7d",
1237                get(handler_cache_control_max_age_7d),
1238            )
1239            .layer(from_fn_with_state(Arc::clone(&cache), middleware));
1240
1241        // Cache-Control no-cache
1242        let req = Request::get("/cache_control_no_cache")
1243            .body(Body::from("foobar"))
1244            .unwrap();
1245        let result = app.call(req).await.unwrap();
1246        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1247        assert_eq!(cache_status, CacheStatus::Miss);
1248        assert_eq!(result.status(), StatusCode::OK);
1249        cache.housekeep();
1250        assert_eq!(cache.len(), 1);
1251
1252        // Cache-Control no-store
1253        let req = Request::get("/cache_control_no_store")
1254            .body(Body::from("foobar"))
1255            .unwrap();
1256        let result = app.call(req).await.unwrap();
1257        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1258        assert_eq!(cache_status, CacheStatus::Miss,);
1259        assert_eq!(result.status(), StatusCode::OK);
1260        cache.housekeep();
1261        assert_eq!(cache.len(), 2);
1262
1263        // Cache-Control max-age 1 day
1264        let req = Request::get("/cache_control_max_age_1d")
1265            .body(Body::from("foobar"))
1266            .unwrap();
1267        let result = app.call(req).await.unwrap();
1268        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1269        let ttl = result
1270            .headers()
1271            .get(X_CACHE_TTL)
1272            .unwrap()
1273            .to_str()
1274            .unwrap()
1275            .parse::<u64>()
1276            .unwrap();
1277        assert_eq!(cache_status, CacheStatus::Miss);
1278        assert_eq!(ttl, 10);
1279        assert_eq!(result.status(), StatusCode::OK);
1280        cache.housekeep();
1281        assert_eq!(cache.len(), 3);
1282
1283        // w/o Cache-Control we should get a default 10s TTL
1284        let req = Request::get("/").body(Body::from("foobar")).unwrap();
1285        let result = app.call(req).await.unwrap();
1286        let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1287        let ttl = result
1288            .headers()
1289            .get(X_CACHE_TTL)
1290            .unwrap()
1291            .to_str()
1292            .unwrap()
1293            .parse::<u64>()
1294            .unwrap();
1295        assert_eq!(cache_status, CacheStatus::Miss);
1296        assert_eq!(ttl, 10);
1297        assert_eq!(result.status(), StatusCode::OK);
1298        cache.housekeep();
1299        assert_eq!(cache.len(), 4);
1300    }
1301
1302    #[tokio::test]
1303    async fn test_proxy_cache_lock() {
1304        let cache = Arc::new(
1305            CacheBuilder::new(KeyExtractorTest)
1306                .lock_timeout(PROXY_LOCK_TIMEOUT)
1307                .build()
1308                .unwrap(),
1309        );
1310
1311        let app = Router::new()
1312            .route("/{key}", get(handler_proxy_cache_lock))
1313            .layer(from_fn_with_state(Arc::clone(&cache), middleware));
1314
1315        let req_count = 50;
1316        // Expected cache misses/hits for fast/slow responses, respectively.
1317        let expected_misses = [1, req_count];
1318        let expected_hits = [req_count - 1, 0];
1319        for (idx, uri) in ["/fast_response", "/slow_response"].iter().enumerate() {
1320            let mut tasks = vec![];
1321            // Dispatch requests simultaneously.
1322            for _ in 0..req_count {
1323                let app = app.clone();
1324                tasks.push(tokio::spawn(async move {
1325                    let req = Request::get(*uri).body(Body::from("")).unwrap();
1326                    let result = app.oneshot(req).await.unwrap();
1327                    assert_eq!(result.status(), StatusCode::OK);
1328                    result.extensions().get::<CacheStatus>().cloned()
1329                }));
1330            }
1331            let mut count_hits = 0;
1332            let mut count_misses = 0;
1333            for task in tasks {
1334                task.await
1335                    .map(|res| match res {
1336                        Some(CacheStatus::Hit) => count_hits += 1,
1337                        Some(CacheStatus::Miss) => count_misses += 1,
1338                        _ => panic!("Unexpected cache status"),
1339                    })
1340                    .expect("failed to complete task");
1341            }
1342            assert_eq!(count_hits, expected_hits[idx]);
1343            assert_eq!(count_misses, expected_misses[idx]);
1344            cache.housekeep();
1345            cache.clear();
1346        }
1347    }
1348
#[test]
fn test_xfetch() {
    let start = Instant::now();
    let reqs = 10000;

    // Entry that expires 60 seconds after `start`.
    let entry = Entry {
        response: Response::builder().body(Bytes::new()).unwrap(),
        delta: 0.5,
        expires: start + Duration::from_secs(60),
    };

    // Runs `reqs` independent refresh decisions for `entry` at time `at` with the
    // given `beta` and returns how many of them asked for a refresh.
    let count_refreshes = |at, beta| {
        (0..reqs)
            .filter(|_| entry.need_to_refresh(at, beta))
            .count()
    };

    // Close to expiration (2s left): a noticeable share of requests refresh early.
    let refresh = count_refreshes(start + Duration::from_secs(58), 1.5);
    assert!(refresh > 550 && refresh < 800);

    let halfway = start + Duration::from_secs(30);

    // Halfway to expiration with a small beta: no request refreshes.
    assert_eq!(count_refreshes(halfway, 1.0), 0);

    // Halfway to expiration with a large beta: a few requests do refresh.
    assert!(count_refreshes(halfway, 10.0) > 9);
}
1393}