1use std::{
2 fmt::Debug,
3 marker::PhantomData,
4 mem::size_of,
5 sync::Arc,
6 time::{Duration, Instant},
7};
8
9use ahash::RandomState;
10use async_trait::async_trait;
11use axum::{body::Body, extract::Request, middleware::Next, response::Response};
12use bytes::Bytes;
13use http::{
14 Method,
15 header::{CACHE_CONTROL, RANGE},
16};
17use http_body::Body as _;
18use ic_bn_lib_common::{
19 traits::{
20 Run,
21 http::{Bypasser, CustomBypassReason, KeyExtractor},
22 },
23 types::http::{CacheBypassReason, CacheError, Error as HttpError},
24};
25use moka::{
26 Expiry,
27 sync::{Cache as MokaCache, CacheBuilder as MokaCacheBuilder},
28};
29use prometheus::{
30 Counter, CounterVec, Histogram, HistogramVec, IntGauge, Registry,
31 register_counter_vec_with_registry, register_counter_with_registry,
32 register_histogram_vec_with_registry, register_histogram_with_registry,
33 register_int_gauge_with_registry,
34};
35use sha1::{Digest, Sha1};
36use strum_macros::{Display, IntoStaticStr};
37use tokio::{select, sync::Mutex, time::sleep};
38use tokio_util::sync::CancellationToken;
39
40use super::{body::buffer_body, calc_headers_size, extract_authority};
41use crate::http::headers::X_CACHE_TTL;
42
43#[derive(Debug, Clone, Display, PartialEq, Eq, IntoStaticStr)]
44pub enum CustomBypassReasonDummy {}
45impl CustomBypassReason for CustomBypassReasonDummy {}
46
47#[derive(Debug, Clone, Display, PartialEq, Eq, Default, IntoStaticStr)]
49#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
50pub enum CacheStatus<R: CustomBypassReason = CustomBypassReasonDummy> {
51 #[default]
52 Disabled,
53 Bypass(CacheBypassReason<R>),
54 Hit,
55 Miss,
56}
57
58impl<B: CustomBypassReason> CacheStatus<B> {
59 pub fn with_response<T>(self, mut resp: Response<T>) -> Response<T> {
61 resp.extensions_mut().insert(self);
62 resp
63 }
64}
65
66enum ResponseType<R: CustomBypassReason> {
67 Fetched(Response<Bytes>, Duration),
68 Streamed(Response, CacheBypassReason<R>),
69}
70
71#[derive(Clone)]
73struct Entry {
74 response: Response<Bytes>,
75 delta: f64,
78 expires: Instant,
79}
80
81impl Entry {
82 fn need_to_refresh(&self, now: Instant, beta: f64) -> bool {
86 if beta == 0.0 {
88 return false;
89 }
90
91 let rnd = rand::random::<f64>();
92 let xfetch = -(self.delta * beta * rnd.ln());
93 let ttl_left = (self.expires - now).as_secs_f64();
94
95 xfetch > ttl_left
96 }
97}
98
99#[derive(Debug, Clone)]
101pub struct NoopBypasser;
102
103impl Bypasser for NoopBypasser {
104 type BypassReason = CustomBypassReasonDummy;
105
106 fn bypass<T>(&self, _req: &Request<T>) -> Result<Option<Self::BypassReason>, CacheError> {
107 Ok(None)
108 }
109}
110
111#[derive(Clone)]
113pub struct Metrics {
114 lock_await: HistogramVec,
115 requests_count: CounterVec,
116 requests_duration: HistogramVec,
117 ttl: Histogram,
118 x_fetch: Counter,
119 memory: IntGauge,
120 entries: IntGauge,
121}
122
123impl Metrics {
124 pub fn new(registry: &Registry) -> Self {
126 let lbls = &["cache_status", "cache_bypass_reason"];
127
128 Self {
129 lock_await: register_histogram_vec_with_registry!(
130 "cache_proxy_lock_await",
131 "Time spent waiting for the proxy cache lock",
132 &["lock_obtained"],
133 registry,
134 )
135 .unwrap(),
136
137 requests_count: register_counter_vec_with_registry!(
138 "cache_requests_count",
139 "Cache requests count",
140 lbls,
141 registry,
142 )
143 .unwrap(),
144
145 requests_duration: register_histogram_vec_with_registry!(
146 "cache_requests_duration",
147 "Time it took to execute the request",
148 lbls,
149 registry,
150 )
151 .unwrap(),
152
153 ttl: register_histogram_with_registry!(
154 "cache_ttl",
155 "TTL that was set when storing the response",
156 vec![1.0, 10.0, 100.0, 1000.0, 10000.0, 86400.0],
157 registry,
158 )
159 .unwrap(),
160
161 x_fetch: register_counter_with_registry!(
162 "cache_xfetch_count",
163 "Number of requests that x-fetch refreshed",
164 registry,
165 )
166 .unwrap(),
167
168 memory: register_int_gauge_with_registry!(
169 "cache_memory",
170 "Memory usage by the cache in bytes",
171 registry,
172 )
173 .unwrap(),
174
175 entries: register_int_gauge_with_registry!(
176 "cache_entries",
177 "Count of entries in the cache",
178 registry,
179 )
180 .unwrap(),
181 }
182 }
183}
184
185pub struct Opts {
187 pub cache_size: u64,
188 pub max_item_size: usize,
189 pub ttl: Duration,
190 pub max_ttl: Duration,
191 pub obey_cache_control: bool,
192 pub lock_timeout: Duration,
193 pub body_timeout: Duration,
194 pub xfetch_beta: f64,
195 pub methods: Vec<Method>,
196}
197
198impl Default for Opts {
199 fn default() -> Self {
200 Self {
201 cache_size: 128 * 1024 * 1024,
202 max_item_size: 16 * 1024 * 1024,
203 ttl: Duration::from_secs(10),
204 max_ttl: Duration::from_secs(86400),
205 obey_cache_control: false,
206 lock_timeout: Duration::from_secs(5),
207 body_timeout: Duration::from_secs(60),
208 xfetch_beta: 0.0,
209 methods: vec![Method::GET],
210 }
211 }
212}
213
/// Caching decision derived from a response's `Cache-Control` header.
#[derive(Debug, PartialEq, Eq)]
enum CacheControl {
    /// Do not cache (`no-cache`, `no-store` or `max-age=0`).
    NoCache,
    /// Cache for at most this long.
    MaxAge(Duration),
}
219
220fn infer_ttl<T>(req: &Response<T>) -> Option<CacheControl> {
222 let hdr = req
224 .headers()
225 .get(CACHE_CONTROL)
226 .and_then(|x| x.to_str().ok())?;
227
228 hdr.split(',').find_map(|x| {
230 let (k, v) = {
231 let mut split = x.split('=').map(|s| s.trim());
232 (split.next().unwrap(), split.next())
233 };
234
235 if ["no-cache", "no-store"].contains(&k) {
236 Some(CacheControl::NoCache)
237 } else if k == "max-age" {
238 let v = v.and_then(|x| x.parse::<u64>().ok());
239 if v == Some(0) {
240 Some(CacheControl::NoCache)
241 } else {
242 v.map(|x| CacheControl::MaxAge(Duration::from_secs(x)))
243 }
244 } else {
245 None
246 }
247 })
248}
249
250struct Expirer<K: KeyExtractor>(PhantomData<K>);
252
253impl<K: KeyExtractor> Expiry<K::Key, Arc<Entry>> for Expirer<K> {
254 fn expire_after_create(
255 &self,
256 _key: &K::Key,
257 value: &Arc<Entry>,
258 created_at: Instant,
259 ) -> Option<Duration> {
260 Some(value.expires - created_at)
261 }
262}
263
264pub struct CacheBuilder<K: KeyExtractor, B: Bypasser> {
266 key_extractor: K,
267 bypasser: Option<B>,
268 opts: Opts,
269 registry: Registry,
270}
271
272impl<K: KeyExtractor> CacheBuilder<K, NoopBypasser> {
273 pub fn new(key_extractor: K) -> Self {
275 Self {
276 key_extractor,
277 bypasser: None,
278 opts: Opts::default(),
279 registry: Registry::new(),
280 }
281 }
282}
283
284impl<K: KeyExtractor, B: Bypasser> CacheBuilder<K, B> {
285 pub fn new_with_bypasser(key_extractor: K, bypasser: B) -> Self {
287 Self {
288 key_extractor,
289 bypasser: Some(bypasser),
290 opts: Opts::default(),
291 registry: Registry::new(),
292 }
293 }
294
295 pub const fn cache_size(mut self, v: u64) -> Self {
297 self.opts.cache_size = v;
298 self
299 }
300
301 pub const fn max_item_size(mut self, v: usize) -> Self {
303 self.opts.max_item_size = v;
304 self
305 }
306
307 pub const fn ttl(mut self, v: Duration) -> Self {
309 self.opts.ttl = v;
310 self
311 }
312
313 pub const fn max_ttl(mut self, v: Duration) -> Self {
315 self.opts.max_ttl = v;
316 self
317 }
318
319 pub const fn lock_timeout(mut self, v: Duration) -> Self {
321 self.opts.lock_timeout = v;
322 self
323 }
324
325 pub const fn body_timeout(mut self, v: Duration) -> Self {
327 self.opts.body_timeout = v;
328 self
329 }
330
331 pub const fn xfetch_beta(mut self, v: f64) -> Self {
333 self.opts.xfetch_beta = v;
334 self
335 }
336
337 pub fn methods(mut self, v: &[Method]) -> Self {
339 self.opts.methods = v.into();
340 self
341 }
342
343 pub const fn obey_cache_control(mut self, v: bool) -> Self {
345 self.opts.obey_cache_control = v;
346 self
347 }
348
349 pub fn registry(mut self, v: &Registry) -> Self {
351 self.registry = v.clone();
352 self
353 }
354
355 pub fn build(self) -> Result<Cache<K, B>, CacheError> {
357 Cache::new(self.opts, self.key_extractor, self.bypasser, &self.registry)
358 }
359}
360
361pub struct Cache<K: KeyExtractor, B: Bypasser = NoopBypasser> {
363 store: MokaCache<K::Key, Arc<Entry>, RandomState>,
364 locks: MokaCache<K::Key, Arc<Mutex<()>>, RandomState>,
365 key_extractor: K,
366 bypasser: Option<B>,
367 metrics: Metrics,
368 opts: Opts,
369}
370
371fn weigh_entry<K: KeyExtractor>(_k: &K::Key, v: &Arc<Entry>) -> u32 {
372 let mut size = size_of::<K::Key>() + size_of::<Arc<Entry>>();
373
374 size += calc_headers_size(v.response.headers());
375 size += v.response.body().len();
376
377 size as u32
378}
379
380impl<K: KeyExtractor + 'static, B: Bypasser + 'static> Cache<K, B> {
381 pub fn new(
383 opts: Opts,
384 key_extractor: K,
385 bypasser: Option<B>,
386 registry: &Registry,
387 ) -> Result<Self, CacheError> {
388 if opts.max_item_size as u64 >= opts.cache_size {
389 return Err(CacheError::Other(
390 "Cache item size should be less than whole cache size".into(),
391 ));
392 }
393
394 if opts.ttl > opts.max_ttl {
395 return Err(CacheError::Other("TTL should be <= max TTL".into()));
396 }
397
398 Ok(Self {
399 store: MokaCacheBuilder::new(opts.cache_size)
400 .expire_after(Expirer::<K>(PhantomData))
401 .weigher(weigh_entry::<K>)
402 .build_with_hasher(RandomState::default()),
403
404 locks: MokaCacheBuilder::new(32768)
406 .time_to_idle(Duration::from_secs(60))
407 .build_with_hasher(RandomState::default()),
408
409 key_extractor,
410 bypasser,
411 metrics: Metrics::new(registry),
412
413 opts,
414 })
415 }
416
417 pub fn get(&self, key: &K::Key, now: Instant, beta: f64) -> Option<Response> {
419 let val = self.store.get(key)?;
420
421 if val.need_to_refresh(now, beta) {
423 self.metrics.x_fetch.inc();
424 return None;
425 }
426
427 let (parts, body) = val.response.clone().into_parts();
428 Some(Response::from_parts(parts, Body::from(body)))
429 }
430
431 pub fn insert(
433 &self,
434 key: K::Key,
435 now: Instant,
436 ttl: Duration,
437 delta: Duration,
438 response: Response<Bytes>,
439 ) {
440 self.metrics.ttl.observe(ttl.as_secs_f64());
441
442 self.store.insert(
443 key,
444 Arc::new(Entry {
445 response,
446 delta: delta.as_secs_f64(),
447 expires: now + ttl,
448 }),
449 );
450 }
451
452 pub async fn process_request(
454 &self,
455 request: Request,
456 next: Next,
457 ) -> Result<Response, CacheError> {
458 let now = Instant::now();
459 let (cache_status, response) = self.process_inner(now, request, next).await?;
460
461 let cache_status_str: &'static str = (&cache_status).into();
463 let cache_bypass_reason_str: &'static str = match cache_status.clone() {
464 CacheStatus::Bypass(v) => v.into_str(),
465 _ => "none",
466 };
467
468 let labels = &[cache_status_str, cache_bypass_reason_str];
469
470 self.metrics.requests_count.with_label_values(labels).inc();
471 self.metrics
472 .requests_duration
473 .with_label_values(labels)
474 .observe(now.elapsed().as_secs_f64());
475
476 Ok(cache_status.with_response(response))
477 }
478
479 async fn process_inner(
480 &self,
481 now: Instant,
482 request: Request,
483 next: Next,
484 ) -> Result<(CacheStatus<B::BypassReason>, Response), CacheError> {
485 if let Some(b) = &self.bypasser {
487 if let Ok(v) = b.bypass(&request) {
489 if let Some(r) = v {
491 return Ok((
492 CacheStatus::Bypass(CacheBypassReason::Custom(r)),
493 next.run(request).await,
494 ));
495 }
496 } else {
497 return Ok((
498 CacheStatus::Bypass(CacheBypassReason::UnableToRunBypasser),
499 next.run(request).await,
500 ));
501 }
502 }
503
504 if !self.opts.methods.contains(request.method()) {
506 return Ok((
507 CacheStatus::Bypass(CacheBypassReason::MethodNotCacheable),
508 next.run(request).await,
509 ));
510 }
511
512 let Ok(key) = self.key_extractor.extract(&request) else {
514 return Ok((
515 CacheStatus::Bypass(CacheBypassReason::UnableToExtractKey),
516 next.run(request).await,
517 ));
518 };
519
520 if let Some(v) = self.get(&key, now, self.opts.xfetch_beta) {
521 return Ok((CacheStatus::Hit, v));
522 }
523
524 let lock = self
526 .locks
527 .get_with_by_ref(&key, || Arc::new(Mutex::new(())));
528
529 let mut lock_obtained = false;
530 select! {
531 _ = lock.lock() => {
534 lock_obtained = true;
535 }
536
537 _ = sleep(self.opts.lock_timeout) => {}
539 }
540
541 self.metrics
543 .lock_await
544 .with_label_values(&[if lock_obtained { "yes" } else { "no" }])
545 .observe(now.elapsed().as_secs_f64());
546
547 if let Some(v) = self.get(&key, now, 0.0) {
550 return Ok((CacheStatus::Hit, v));
551 }
552
553 let now = Instant::now();
555 Ok(match self.pass_request(request, next).await? {
556 ResponseType::Fetched(v, ttl) => {
558 let delta = now.elapsed();
559 self.insert(key, now + delta, ttl, delta, v.clone());
560
561 let (mut parts, body) = v.into_parts();
562 parts.headers.insert(X_CACHE_TTL, ttl.as_secs().into());
563 let response = Response::from_parts(parts, Body::from(body));
564 (CacheStatus::Miss, response)
565 }
566
567 ResponseType::Streamed(v, reason) => (CacheStatus::Bypass(reason), v),
569 })
570 }
571
572 async fn pass_request(
574 &self,
575 request: Request,
576 next: Next,
577 ) -> Result<ResponseType<B::BypassReason>, CacheError> {
578 let response = next.run(request).await;
580
581 if !response.status().is_success() {
583 return Ok(ResponseType::Streamed(
584 response,
585 CacheBypassReason::HTTPError,
586 ));
587 }
588
589 let body_size = response.body().size_hint().exact().map(|x| x as usize);
591
592 let Some(body_size) = body_size else {
594 return Ok(ResponseType::Streamed(
595 response,
596 CacheBypassReason::SizeUnknown,
597 ));
598 };
599
600 if body_size > self.opts.max_item_size {
602 return Ok(ResponseType::Streamed(
603 response,
604 CacheBypassReason::BodyTooBig,
605 ));
606 }
607
608 let ttl = if self.opts.obey_cache_control {
610 let ttl = infer_ttl(&response);
611
612 match ttl {
613 Some(CacheControl::NoCache) => {
615 return Ok(ResponseType::Streamed(
616 response,
617 CacheBypassReason::CacheControl,
618 ));
619 }
620
621 Some(CacheControl::MaxAge(v)) => v.min(self.opts.max_ttl),
623
624 None => self.opts.ttl,
626 }
627 } else {
628 self.opts.ttl
629 };
630
631 let (parts, body) = response.into_parts();
633 let body = buffer_body(body, body_size, self.opts.body_timeout)
634 .await
635 .map_err(|e| match e {
636 HttpError::BodyTooBig => CacheError::FetchBodyTooBig,
637 HttpError::BodyTimedOut => CacheError::FetchBodyTimeout,
638 _ => CacheError::FetchBody(e.to_string()),
639 })?;
640
641 Ok(ResponseType::Fetched(
642 Response::from_parts(parts, body),
643 ttl,
644 ))
645 }
646}
647
648#[async_trait]
649impl<K: KeyExtractor, B: Bypasser> Run for Cache<K, B> {
650 async fn run(&self, _: CancellationToken) -> Result<(), anyhow::Error> {
651 self.store.run_pending_tasks();
652 self.metrics.memory.set(self.store.weighted_size() as i64);
653 self.metrics.entries.set(self.store.entry_count() as i64);
654 Ok(())
655 }
656}
657
#[cfg(test)]
impl<K: KeyExtractor + 'static, B: Bypasser + 'static> Cache<K, B> {
    /// Runs pending moka maintenance on both the store and the lock map so that
    /// counters reflect recent inserts and evictions.
    pub fn housekeep(&self) {
        self.store.run_pending_tasks();
        self.locks.run_pending_tasks();
    }

    /// Current weighted size of the response store.
    pub fn size(&self) -> u64 {
        self.store.weighted_size()
    }

    /// Number of entries in the response store.
    // Test-only helper; a matching `is_empty` would be unused.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> u64 {
        self.store.entry_count()
    }

    /// Invalidates every cached response and lock, then runs maintenance so the
    /// effect is immediately visible.
    pub fn clear(&self) {
        self.store.invalidate_all();
        self.locks.invalidate_all();
        self.housekeep();
    }
}
680
681#[derive(Clone, Debug)]
683pub struct KeyExtractorUriRange;
684
685impl KeyExtractor for KeyExtractorUriRange {
686 type Key = [u8; 20];
687
688 fn extract<T>(&self, request: &Request<T>) -> Result<Self::Key, CacheError> {
689 let authority = extract_authority(request)
690 .ok_or_else(|| CacheError::ExtractKey("no authority found".into()))?
691 .as_bytes();
692 let paq = request
693 .uri()
694 .path_and_query()
695 .ok_or_else(|| CacheError::ExtractKey("no path_and_query found".into()))?
696 .as_str()
697 .as_bytes();
698
699 let mut hash = Sha1::new().chain_update(authority).chain_update(paq);
701 if let Some(v) = request.headers().get(RANGE) {
702 hash = hash.chain_update(v.as_bytes());
703 }
704
705 Ok(hash.finalize().into())
706 }
707}
708
709#[cfg(test)]
710mod tests {
711 use crate::hval;
712
713 use super::*;
714
715 use axum::{
716 Router,
717 body::to_bytes,
718 extract::State,
719 middleware::from_fn_with_state,
720 response::IntoResponse,
721 routing::{get, post},
722 };
723 use http::{Request, Response, StatusCode, Uri};
724 use sha1::Digest;
725 use tower::{Service, ServiceExt};
726
727 #[derive(Clone, Debug)]
728 pub struct KeyExtractorTest;
729
730 impl KeyExtractor for KeyExtractorTest {
731 type Key = [u8; 20];
732
733 fn extract<T>(&self, request: &Request<T>) -> Result<Self::Key, CacheError> {
734 let paq = request
735 .uri()
736 .path_and_query()
737 .ok_or_else(|| CacheError::ExtractKey("no path_and_query found".into()))?
738 .as_str()
739 .as_bytes();
740
741 let hash: [u8; 20] = sha1::Sha1::new().chain_update(paq).finalize().into();
742 Ok(hash)
743 }
744 }
745
746 const MAX_ITEM_SIZE: usize = 1024;
747 const MAX_CACHE_SIZE: u64 = 32768;
748 const PROXY_LOCK_TIMEOUT: Duration = Duration::from_secs(1);
749
750 async fn dispatch_get_request(router: &mut Router, uri: String) -> Option<CacheStatus> {
751 let req = Request::get(uri).body(Body::from("")).unwrap();
752 let result = router.call(req).await.unwrap();
753 assert_eq!(result.status(), StatusCode::OK);
754 result.extensions().get::<CacheStatus>().cloned()
755 }
756
757 async fn handler(_request: Request<Body>) -> impl IntoResponse {
758 "test_body"
759 }
760
761 async fn handler_proxy_cache_lock(request: Request<Body>) -> impl IntoResponse {
762 if request.uri().path().contains("slow_response") {
763 sleep(2 * PROXY_LOCK_TIMEOUT).await;
764 }
765
766 "test_body"
767 }
768
769 async fn handler_too_big(_request: Request<Body>) -> impl IntoResponse {
770 "a".repeat(MAX_ITEM_SIZE + 1)
771 }
772
773 async fn handler_cache_control_max_age_1d(_request: Request<Body>) -> impl IntoResponse {
774 [(CACHE_CONTROL, "max-age=86400")]
775 }
776
777 async fn handler_cache_control_max_age_7d(_request: Request<Body>) -> impl IntoResponse {
778 [(CACHE_CONTROL, "max-age=604800")]
779 }
780
781 async fn handler_cache_control_no_cache(_request: Request<Body>) -> impl IntoResponse {
782 [(CACHE_CONTROL, "no-cache")]
783 }
784
785 async fn handler_cache_control_no_store(_request: Request<Body>) -> impl IntoResponse {
786 [(CACHE_CONTROL, "no-store")]
787 }
788
789 async fn middleware(
790 State(cache): State<Arc<Cache<KeyExtractorTest>>>,
791 request: Request<Body>,
792 next: Next,
793 ) -> impl IntoResponse {
794 cache
795 .process_request(request, next)
796 .await
797 .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
798 }
799
800 #[test]
801 fn test_bypass_reason_serialize() {
802 #[derive(Debug, Clone, Display, PartialEq, Eq, IntoStaticStr)]
803 #[strum(serialize_all = "snake_case")]
804 enum CustomReasonTest {
805 Bar,
806 }
807 impl CustomBypassReason for CustomReasonTest {}
808
809 let a: CacheBypassReason<CustomReasonTest> =
810 CacheBypassReason::Custom(CustomReasonTest::Bar);
811 let txt = a.into_str();
812 assert_eq!(txt, "bar");
813
814 let a: CacheBypassReason<CustomReasonTest> = CacheBypassReason::BodyTooBig;
815 let txt = a.into_str();
816 assert_eq!(txt, "body_too_big");
817 }
818
819 #[test]
820 fn test_key_extractor_uri_range() {
821 let x = KeyExtractorUriRange;
822
823 let mut req = Request::new("foo");
825 *req.uri_mut() = Uri::from_static("http://foo.bar.baz:80/foo/bar?abc=1");
826 let key1 = x.extract(&req).unwrap();
827
828 let mut req = Request::new("foo");
830 *req.uri_mut() = Uri::from_static("http://foo.bar.baz:80/foo/bar?abc=2");
831 let key2 = x.extract(&req).unwrap();
832 assert_ne!(key1, key2);
833
834 let mut req = Request::new("foo");
835 *req.uri_mut() = Uri::from_static("http://foo.bar.baz:80/foo/ba?abc=1");
836 let key2 = x.extract(&req).unwrap();
837 assert_ne!(key1, key2);
838
839 let mut req = Request::new("foo");
840 *req.uri_mut() = Uri::from_static("http://foo.bar.ba:80/foo/bar?abc=1");
841 let key2 = x.extract(&req).unwrap();
842 assert_ne!(key1, key2);
843
844 let mut req = Request::new("foo");
846 *req.uri_mut() = Uri::from_static("https://foo.bar.baz:80/foo/bar?abc=1");
847 let key2 = x.extract(&req).unwrap();
848 assert_eq!(key1, key2);
849
850 let mut req = Request::new("foo");
852 *req.uri_mut() = Uri::from_static("http://foo.bar.bar:80/foo/bar?abc=1");
853 (*req.headers_mut()).insert(RANGE, hval!("1000-2000"));
854 let key2 = x.extract(&req).unwrap();
855 assert_ne!(key1, key2);
856 }
857
858 #[test]
859 fn test_infer_ttl() {
860 let mut req = Response::new(());
861
862 assert_eq!(infer_ttl(&req), None);
863
864 req.headers_mut().insert(CACHE_CONTROL, hval!("no-cache"));
866 assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
867
868 req.headers_mut().insert(CACHE_CONTROL, hval!("no-store"));
869 assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
870
871 req.headers_mut()
872 .insert(CACHE_CONTROL, hval!("no-store, no-cache"));
873 assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
874
875 req.headers_mut()
877 .insert(CACHE_CONTROL, hval!("no-store, no-cache, max-age=1"));
878 assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
879
880 req.headers_mut()
881 .insert(CACHE_CONTROL, hval!("max-age=1, no-store, no-cache"));
882 assert_eq!(
883 infer_ttl(&req),
884 Some(CacheControl::MaxAge(Duration::from_secs(1)))
885 );
886
887 req.headers_mut()
889 .insert(CACHE_CONTROL, hval!("max-age=86400"));
890 assert_eq!(
891 infer_ttl(&req),
892 Some(CacheControl::MaxAge(Duration::from_secs(86400)))
893 );
894 req.headers_mut().insert(CACHE_CONTROL, hval!("max-age=0"));
895 assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
896
897 req.headers_mut()
898 .insert(CACHE_CONTROL, hval!("max-age=foo"));
899 assert_eq!(infer_ttl(&req), None);
900
901 req.headers_mut().insert(CACHE_CONTROL, hval!("max-age="));
902 assert_eq!(infer_ttl(&req), None);
903
904 req.headers_mut().insert(CACHE_CONTROL, hval!("max-age=-1"));
905 assert_eq!(infer_ttl(&req), None);
906
907 req.headers_mut().insert(CACHE_CONTROL, hval!(""));
909 assert_eq!(infer_ttl(&req), None);
910
911 req.headers_mut()
913 .insert(CACHE_CONTROL, hval!(", =foobar, "));
914 assert_eq!(infer_ttl(&req), None);
915 }
916
917 #[test]
918 fn test_cache_creation_errors() {
919 let cache = CacheBuilder::new(KeyExtractorTest)
920 .cache_size(1)
921 .max_item_size(2)
922 .build();
923 assert!(cache.is_err());
924
925 let cache = CacheBuilder::new(KeyExtractorTest)
926 .ttl(Duration::from_secs(2))
927 .max_ttl(Duration::from_secs(1))
928 .build();
929 assert!(cache.is_err());
930 }
931
932 #[tokio::test]
933 async fn test_cache_bypass() {
934 let cache = Arc::new(
935 CacheBuilder::new(KeyExtractorTest)
936 .max_item_size(MAX_ITEM_SIZE)
937 .build()
938 .unwrap(),
939 );
940
941 let mut app = Router::new()
942 .route("/", post(handler))
943 .route("/", get(handler))
944 .route("/too_big", get(handler_too_big))
945 .layer(from_fn_with_state(Arc::clone(&cache), middleware));
946
947 let req = Request::post("/").body(Body::from("")).unwrap();
949 let result = app.call(req).await.unwrap();
950 assert_eq!(result.status(), StatusCode::OK);
951 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
952 assert_eq!(cache.len(), 0);
953 assert_eq!(
954 cache_status,
955 CacheStatus::Bypass(CacheBypassReason::MethodNotCacheable)
956 );
957
958 let req = Request::get("/non_existing_path")
960 .body(Body::from("foobar"))
961 .unwrap();
962 let result = app.call(req).await.unwrap();
963 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
964 assert_eq!(result.status(), StatusCode::NOT_FOUND);
965 assert_eq!(
966 cache_status,
967 CacheStatus::Bypass(CacheBypassReason::HTTPError)
968 );
969 assert_eq!(cache.len(), 0);
970
971 let req = Request::get("/too_big").body(Body::from("foobar")).unwrap();
973 let result = app.call(req).await.unwrap();
974 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
975 assert_eq!(
976 cache_status,
977 CacheStatus::Bypass(CacheBypassReason::BodyTooBig)
978 );
979 assert_eq!(result.status(), StatusCode::OK);
980 assert_eq!(cache.len(), 0);
981 }
982
983 #[tokio::test]
984 async fn test_cache_hit() {
985 let ttl = Duration::from_millis(1500);
986
987 let cache = Arc::new(
988 CacheBuilder::new(KeyExtractorTest)
989 .cache_size(MAX_CACHE_SIZE)
990 .max_item_size(MAX_ITEM_SIZE)
991 .ttl(ttl)
992 .build()
993 .unwrap(),
994 );
995
996 let mut app = Router::new()
997 .route("/{key}", get(handler))
998 .layer(from_fn_with_state(Arc::clone(&cache), middleware));
999
1000 let req = Request::get("/1").body(Body::from("")).unwrap();
1002 let result = app.call(req).await.unwrap();
1003 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1004 assert_eq!(result.status(), StatusCode::OK);
1005 assert_eq!(cache_status, CacheStatus::Miss);
1006 cache.housekeep();
1007 assert_eq!(cache.len(), 1);
1008
1009 let req = Request::get("/2").body(Body::from("")).unwrap();
1011 let result = app.call(req).await.unwrap();
1012 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1013 assert_eq!(result.status(), StatusCode::OK);
1014 assert_eq!(cache_status, CacheStatus::Miss);
1015 cache.housekeep();
1016 assert_eq!(cache.len(), 2);
1017
1018 let req = Request::get("/1").body(Body::from("")).unwrap();
1020 let result = app.call(req).await.unwrap();
1021 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1022 assert_eq!(result.status(), StatusCode::OK);
1023 assert_eq!(cache_status, CacheStatus::Hit);
1024 let (_, body) = result.into_parts();
1025 let body = to_bytes(body, usize::MAX).await.unwrap().to_vec();
1026 let body = String::from_utf8_lossy(&body);
1027 assert_eq!("test_body", body);
1028 cache.housekeep();
1029 assert_eq!(cache.len(), 2);
1030
1031 let req = Request::get("/2").body(Body::from("")).unwrap();
1033 let result = app.call(req).await.unwrap();
1034 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1035 assert_eq!(result.status(), StatusCode::OK);
1036 assert_eq!(cache_status, CacheStatus::Hit);
1037 let (_, body) = result.into_parts();
1038 let body = to_bytes(body, usize::MAX).await.unwrap().to_vec();
1039 let body = String::from_utf8_lossy(&body);
1040 assert_eq!("test_body", body);
1041 cache.housekeep();
1042 assert_eq!(cache.len(), 2);
1043
1044 sleep(ttl + Duration::from_millis(300)).await;
1046 cache.housekeep();
1047 let req = Request::get("/1").body(Body::from("")).unwrap();
1048 let result = app.call(req).await.unwrap();
1049 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1050 assert_eq!(result.status(), StatusCode::OK);
1051 assert_eq!(cache_status, CacheStatus::Miss);
1052
1053 cache.clear();
1055 let req_count = 50;
1056 for idx in 0..req_count {
1058 let status = dispatch_get_request(&mut app, format!("/{idx}")).await;
1059 assert_eq!(status.unwrap(), CacheStatus::Miss);
1060 }
1061 for idx in 0..req_count {
1063 let status = dispatch_get_request(&mut app, format!("/{idx}")).await;
1064 assert_eq!(status.unwrap(), CacheStatus::Hit);
1065 }
1066
1067 cache.clear();
1069 let req_count = 500;
1070 for idx in 0..req_count {
1072 let status = dispatch_get_request(&mut app, format!("/{idx}")).await;
1073 assert_eq!(status.unwrap(), CacheStatus::Miss);
1074 }
1075
1076 let mut count_misses = 0;
1078 let mut count_hits = 0;
1079 for idx in 0..req_count {
1080 let status = dispatch_get_request(&mut app, format!("/{idx}")).await;
1081 if status == Some(CacheStatus::Miss) {
1082 count_misses += 1;
1083 } else if status == Some(CacheStatus::Hit) {
1084 count_hits += 1;
1085 }
1086 }
1087 assert!(count_misses > 0);
1088 assert!(count_hits > 0);
1089 cache.housekeep();
1090 let entry_size = cache.size() / cache.len();
1091
1092 assert!(MAX_CACHE_SIZE > cache.size());
1095 assert!(MAX_CACHE_SIZE < cache.size() + entry_size);
1096 }
1097
1098 #[tokio::test]
1099 async fn test_cache_control() {
1100 let cache = Arc::new(
1101 CacheBuilder::new(KeyExtractorTest)
1102 .obey_cache_control(true)
1103 .build()
1104 .unwrap(),
1105 );
1106
1107 let mut app = Router::new()
1108 .route("/", get(handler))
1109 .route(
1110 "/cache_control_no_store",
1111 get(handler_cache_control_no_store),
1112 )
1113 .route(
1114 "/cache_control_no_cache",
1115 get(handler_cache_control_no_cache),
1116 )
1117 .route(
1118 "/cache_control_max_age_1d",
1119 get(handler_cache_control_max_age_1d),
1120 )
1121 .route(
1122 "/cache_control_max_age_7d",
1123 get(handler_cache_control_max_age_7d),
1124 )
1125 .layer(from_fn_with_state(Arc::clone(&cache), middleware));
1126
1127 let req = Request::get("/cache_control_no_cache")
1129 .body(Body::from("foobar"))
1130 .unwrap();
1131 let result = app.call(req).await.unwrap();
1132 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1133 assert_eq!(
1134 cache_status,
1135 CacheStatus::Bypass(CacheBypassReason::CacheControl)
1136 );
1137 assert_eq!(result.status(), StatusCode::OK);
1138 cache.housekeep();
1139 assert_eq!(cache.len(), 0);
1140
1141 let req = Request::get("/cache_control_no_store")
1143 .body(Body::from("foobar"))
1144 .unwrap();
1145 let result = app.call(req).await.unwrap();
1146 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1147 assert_eq!(
1148 cache_status,
1149 CacheStatus::Bypass(CacheBypassReason::CacheControl)
1150 );
1151 assert_eq!(result.status(), StatusCode::OK);
1152 cache.housekeep();
1153 assert_eq!(cache.len(), 0);
1154
1155 let req = Request::get("/cache_control_max_age_1d")
1157 .body(Body::from("foobar"))
1158 .unwrap();
1159 let result = app.call(req).await.unwrap();
1160 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1161 let ttl = result
1162 .headers()
1163 .get(X_CACHE_TTL)
1164 .unwrap()
1165 .to_str()
1166 .unwrap()
1167 .parse::<u64>()
1168 .unwrap();
1169 assert_eq!(cache_status, CacheStatus::Miss);
1170 assert_eq!(ttl, 86400);
1171 assert_eq!(result.status(), StatusCode::OK);
1172 cache.housekeep();
1173 assert_eq!(cache.len(), 1);
1174
1175 let req = Request::get("/cache_control_max_age_7d")
1177 .body(Body::from("foobar"))
1178 .unwrap();
1179 let result = app.call(req).await.unwrap();
1180 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1181 let ttl = result
1182 .headers()
1183 .get(X_CACHE_TTL)
1184 .unwrap()
1185 .to_str()
1186 .unwrap()
1187 .parse::<u64>()
1188 .unwrap();
1189 assert_eq!(cache_status, CacheStatus::Miss);
1190 assert_eq!(ttl, 86400);
1191 assert_eq!(result.status(), StatusCode::OK);
1192 cache.housekeep();
1193 assert_eq!(cache.len(), 2);
1194
1195 let req = Request::get("/").body(Body::from("foobar")).unwrap();
1197 let result = app.call(req).await.unwrap();
1198 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1199 let ttl = result
1200 .headers()
1201 .get(X_CACHE_TTL)
1202 .unwrap()
1203 .to_str()
1204 .unwrap()
1205 .parse::<u64>()
1206 .unwrap();
1207 assert_eq!(cache_status, CacheStatus::Miss);
1208 assert_eq!(ttl, 10);
1209 assert_eq!(result.status(), StatusCode::OK);
1210 cache.housekeep();
1211 assert_eq!(cache.len(), 3);
1212
1213 let cache = Arc::new(
1215 CacheBuilder::new(KeyExtractorTest)
1216 .obey_cache_control(false)
1217 .build()
1218 .unwrap(),
1219 );
1220
1221 let mut app = Router::new()
1222 .route("/", get(handler))
1223 .route(
1224 "/cache_control_no_store",
1225 get(handler_cache_control_no_store),
1226 )
1227 .route(
1228 "/cache_control_no_cache",
1229 get(handler_cache_control_no_cache),
1230 )
1231 .route(
1232 "/cache_control_max_age_1d",
1233 get(handler_cache_control_max_age_1d),
1234 )
1235 .route(
1236 "/cache_control_max_age_7d",
1237 get(handler_cache_control_max_age_7d),
1238 )
1239 .layer(from_fn_with_state(Arc::clone(&cache), middleware));
1240
1241 let req = Request::get("/cache_control_no_cache")
1243 .body(Body::from("foobar"))
1244 .unwrap();
1245 let result = app.call(req).await.unwrap();
1246 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1247 assert_eq!(cache_status, CacheStatus::Miss);
1248 assert_eq!(result.status(), StatusCode::OK);
1249 cache.housekeep();
1250 assert_eq!(cache.len(), 1);
1251
1252 let req = Request::get("/cache_control_no_store")
1254 .body(Body::from("foobar"))
1255 .unwrap();
1256 let result = app.call(req).await.unwrap();
1257 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1258 assert_eq!(cache_status, CacheStatus::Miss,);
1259 assert_eq!(result.status(), StatusCode::OK);
1260 cache.housekeep();
1261 assert_eq!(cache.len(), 2);
1262
1263 let req = Request::get("/cache_control_max_age_1d")
1265 .body(Body::from("foobar"))
1266 .unwrap();
1267 let result = app.call(req).await.unwrap();
1268 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1269 let ttl = result
1270 .headers()
1271 .get(X_CACHE_TTL)
1272 .unwrap()
1273 .to_str()
1274 .unwrap()
1275 .parse::<u64>()
1276 .unwrap();
1277 assert_eq!(cache_status, CacheStatus::Miss);
1278 assert_eq!(ttl, 10);
1279 assert_eq!(result.status(), StatusCode::OK);
1280 cache.housekeep();
1281 assert_eq!(cache.len(), 3);
1282
1283 let req = Request::get("/").body(Body::from("foobar")).unwrap();
1285 let result = app.call(req).await.unwrap();
1286 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1287 let ttl = result
1288 .headers()
1289 .get(X_CACHE_TTL)
1290 .unwrap()
1291 .to_str()
1292 .unwrap()
1293 .parse::<u64>()
1294 .unwrap();
1295 assert_eq!(cache_status, CacheStatus::Miss);
1296 assert_eq!(ttl, 10);
1297 assert_eq!(result.status(), StatusCode::OK);
1298 cache.housekeep();
1299 assert_eq!(cache.len(), 4);
1300 }
1301
1302 #[tokio::test]
1303 async fn test_proxy_cache_lock() {
1304 let cache = Arc::new(
1305 CacheBuilder::new(KeyExtractorTest)
1306 .lock_timeout(PROXY_LOCK_TIMEOUT)
1307 .build()
1308 .unwrap(),
1309 );
1310
1311 let app = Router::new()
1312 .route("/{key}", get(handler_proxy_cache_lock))
1313 .layer(from_fn_with_state(Arc::clone(&cache), middleware));
1314
1315 let req_count = 50;
1316 let expected_misses = [1, req_count];
1318 let expected_hits = [req_count - 1, 0];
1319 for (idx, uri) in ["/fast_response", "/slow_response"].iter().enumerate() {
1320 let mut tasks = vec![];
1321 for _ in 0..req_count {
1323 let app = app.clone();
1324 tasks.push(tokio::spawn(async move {
1325 let req = Request::get(*uri).body(Body::from("")).unwrap();
1326 let result = app.oneshot(req).await.unwrap();
1327 assert_eq!(result.status(), StatusCode::OK);
1328 result.extensions().get::<CacheStatus>().cloned()
1329 }));
1330 }
1331 let mut count_hits = 0;
1332 let mut count_misses = 0;
1333 for task in tasks {
1334 task.await
1335 .map(|res| match res {
1336 Some(CacheStatus::Hit) => count_hits += 1,
1337 Some(CacheStatus::Miss) => count_misses += 1,
1338 _ => panic!("Unexpected cache status"),
1339 })
1340 .expect("failed to complete task");
1341 }
1342 assert_eq!(count_hits, expected_hits[idx]);
1343 assert_eq!(count_misses, expected_misses[idx]);
1344 cache.housekeep();
1345 cache.clear();
1346 }
1347 }
1348
/// Statistical check of `Entry::need_to_refresh` (XFetch-style early refresh).
///
/// The entry expires 60s after `now` and has `delta = 0.5`. For each scenario,
/// `need_to_refresh` is sampled `reqs` times and the number of `true` answers is
/// bounded.
///
/// NOTE(review): the bounds assume the decision is randomized per call, following
/// the XFetch rule `now - delta * beta * ln(rand()) >= expires`. Confirm against
/// `Entry::need_to_refresh`. Under that rule the expected counts per 10000 samples
/// are roughly:
/// - ~695 with 2s left and beta 1.5;
/// - ~0 with 30s left and beta 1.0;
/// - ~25 with 30s left and beta 10.0.
#[test]
fn test_xfetch() {
    let now = Instant::now();
    let reqs = 10000;

    let entry = Entry {
        response: Response::builder().body(Bytes::new()).unwrap(),
        delta: 0.5,
        expires: now + Duration::from_secs(60),
    };

    // How many of `reqs` samples ask for a refresh at time `at` with the given `beta`.
    let count_refreshes =
        |at, beta| (0..reqs).filter(|_| entry.need_to_refresh(at, beta)).count();

    // 2s before expiry: a noticeable minority of samples should trigger a refresh.
    let refresh = count_refreshes(now + Duration::from_secs(58), 1.5);
    assert!(refresh > 550 && refresh < 800, "refresh = {refresh}");

    // 30s before expiry with beta 1.0: early refresh should effectively never fire.
    let refresh = count_refreshes(now + Duration::from_secs(30), 1.0);
    assert_eq!(refresh, 0);

    // Same point in time, but a large beta makes early refresh possible again.
    let refresh = count_refreshes(now + Duration::from_secs(30), 10.0);
    assert!(refresh > 9, "refresh = {refresh}");
}
1393}