1use std::{
2 fmt::Debug,
3 marker::PhantomData,
4 mem::size_of,
5 sync::Arc,
6 time::{Duration, Instant},
7};
8
9use ahash::RandomState;
10use async_trait::async_trait;
11use axum::{body::Body, extract::Request, middleware::Next, response::Response};
12use bytes::Bytes;
13use http::{
14 Method,
15 header::{CACHE_CONTROL, RANGE},
16};
17use http_body::Body as _;
18use ic_bn_lib_common::{
19 traits::{
20 Run,
21 http::{Bypasser, CustomBypassReason, KeyExtractor},
22 },
23 types::http::{CacheBypassReason, CacheError, Error as HttpError},
24};
25use moka::{
26 Expiry,
27 sync::{Cache as MokaCache, CacheBuilder as MokaCacheBuilder},
28};
29use prometheus::{
30 Counter, CounterVec, Histogram, HistogramVec, IntGauge, Registry,
31 register_counter_vec_with_registry, register_counter_with_registry,
32 register_histogram_vec_with_registry, register_histogram_with_registry,
33 register_int_gauge_with_registry,
34};
35use sha1::{Digest, Sha1};
36use strum_macros::{Display, IntoStaticStr};
37use tokio::{select, sync::Mutex, time::sleep};
38use tokio_util::sync::CancellationToken;
39
40use super::{body::buffer_body, calc_headers_size, extract_authority};
41use crate::http::headers::X_CACHE_TTL;
42
43#[derive(Debug, Clone, Display, PartialEq, Eq, IntoStaticStr)]
44pub enum CustomBypassReasonDummy {}
45impl CustomBypassReason for CustomBypassReasonDummy {}
46
47#[derive(Debug, Clone, Display, PartialEq, Eq, Default, IntoStaticStr)]
49#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
50pub enum CacheStatus<R: CustomBypassReason = CustomBypassReasonDummy> {
51 #[default]
52 Disabled,
53 Bypass(CacheBypassReason<R>),
54 Hit(i64),
55 Miss(i64),
56}
57
58impl<B: CustomBypassReason> CacheStatus<B> {
59 pub fn with_response<T>(self, mut resp: Response<T>) -> Response<T> {
61 resp.extensions_mut().insert(self);
62 resp
63 }
64}
65
66enum ResponseType<R: CustomBypassReason> {
67 Fetched(Response<Bytes>, Duration),
68 Streamed(Response, CacheBypassReason<R>),
69}
70
71#[derive(Clone)]
73struct Entry {
74 response: Response<Bytes>,
75 delta: f64,
78 expires: Instant,
79}
80
81impl Entry {
82 fn need_to_refresh(&self, now: Instant, beta: f64) -> bool {
86 if beta == 0.0 {
88 return false;
89 }
90
91 let rnd = rand::random::<f64>();
92 let xfetch = -(self.delta * beta * rnd.ln());
93 let ttl_left = (self.expires - now).as_secs_f64();
94
95 xfetch > ttl_left
96 }
97}
98
/// Bypasser that never asks to skip the cache; the default for [`Cache`].
#[derive(Debug, Clone)]
pub struct NoopBypasser;
102
103impl Bypasser for NoopBypasser {
104 type BypassReason = CustomBypassReasonDummy;
105
106 fn bypass<T>(&self, _req: &Request<T>) -> Result<Option<Self::BypassReason>, CacheError> {
107 Ok(None)
108 }
109}
110
111#[derive(Clone)]
113pub struct Metrics {
114 lock_await: HistogramVec,
115 requests_count: CounterVec,
116 requests_duration: HistogramVec,
117 ttl: Histogram,
118 x_fetch: Counter,
119 memory: IntGauge,
120 entries: IntGauge,
121}
122
123impl Metrics {
124 pub fn new(registry: &Registry) -> Self {
126 let lbls = &["cache_status", "cache_bypass_reason"];
127
128 Self {
129 lock_await: register_histogram_vec_with_registry!(
130 "cache_proxy_lock_await",
131 "Time spent waiting for the proxy cache lock",
132 &["lock_obtained"],
133 registry,
134 )
135 .unwrap(),
136
137 requests_count: register_counter_vec_with_registry!(
138 "cache_requests_count",
139 "Cache requests count",
140 lbls,
141 registry,
142 )
143 .unwrap(),
144
145 requests_duration: register_histogram_vec_with_registry!(
146 "cache_requests_duration",
147 "Time it took to execute the request",
148 lbls,
149 registry,
150 )
151 .unwrap(),
152
153 ttl: register_histogram_with_registry!(
154 "cache_ttl",
155 "TTL that was set when storing the response",
156 vec![1.0, 10.0, 100.0, 1000.0, 10000.0, 86400.0],
157 registry,
158 )
159 .unwrap(),
160
161 x_fetch: register_counter_with_registry!(
162 "cache_xfetch_count",
163 "Number of requests that x-fetch refreshed",
164 registry,
165 )
166 .unwrap(),
167
168 memory: register_int_gauge_with_registry!(
169 "cache_memory",
170 "Memory usage by the cache in bytes",
171 registry,
172 )
173 .unwrap(),
174
175 entries: register_int_gauge_with_registry!(
176 "cache_entries",
177 "Count of entries in the cache",
178 registry,
179 )
180 .unwrap(),
181 }
182 }
183}
184
185pub struct Opts {
187 pub cache_size: u64,
188 pub max_item_size: usize,
189 pub ttl: Duration,
190 pub max_ttl: Duration,
191 pub obey_cache_control: bool,
192 pub lock_timeout: Duration,
193 pub body_timeout: Duration,
194 pub xfetch_beta: f64,
195 pub methods: Vec<Method>,
196}
197
198impl Default for Opts {
199 fn default() -> Self {
200 Self {
201 cache_size: 128 * 1024 * 1024,
202 max_item_size: 16 * 1024 * 1024,
203 ttl: Duration::from_secs(10),
204 max_ttl: Duration::from_secs(86400),
205 obey_cache_control: false,
206 lock_timeout: Duration::from_secs(5),
207 body_timeout: Duration::from_secs(60),
208 xfetch_beta: 0.0,
209 methods: vec![Method::GET],
210 }
211 }
212}
213
/// Caching policy derived from a response's `Cache-Control` header.
#[derive(PartialEq, Eq, Debug)]
enum CacheControl {
    /// The response must not be cached.
    NoCache,
    /// The response may be cached for the given duration.
    MaxAge(Duration),
}
219
220fn infer_ttl<T>(req: &Response<T>) -> Option<CacheControl> {
222 let hdr = req
224 .headers()
225 .get(CACHE_CONTROL)
226 .and_then(|x| x.to_str().ok())?;
227
228 hdr.split(',').find_map(|x| {
230 let (k, v) = {
231 let mut split = x.split('=').map(|s| s.trim());
232 (split.next().unwrap(), split.next())
233 };
234
235 if ["no-cache", "no-store"].contains(&k) {
236 Some(CacheControl::NoCache)
237 } else if k == "max-age" {
238 let v = v.and_then(|x| x.parse::<u64>().ok());
239 if v == Some(0) {
240 Some(CacheControl::NoCache)
241 } else {
242 v.map(|x| CacheControl::MaxAge(Duration::from_secs(x)))
243 }
244 } else {
245 None
246 }
247 })
248}
249
250struct Expirer<K: KeyExtractor>(PhantomData<K>);
252
253impl<K: KeyExtractor> Expiry<K::Key, Arc<Entry>> for Expirer<K> {
254 fn expire_after_create(
255 &self,
256 _key: &K::Key,
257 value: &Arc<Entry>,
258 created_at: Instant,
259 ) -> Option<Duration> {
260 Some(value.expires - created_at)
261 }
262}
263
264pub struct CacheBuilder<K: KeyExtractor, B: Bypasser> {
266 key_extractor: K,
267 bypasser: Option<B>,
268 opts: Opts,
269 registry: Registry,
270}
271
272impl<K: KeyExtractor> CacheBuilder<K, NoopBypasser> {
273 pub fn new(key_extractor: K) -> Self {
275 Self {
276 key_extractor,
277 bypasser: None,
278 opts: Opts::default(),
279 registry: Registry::new(),
280 }
281 }
282}
283
284impl<K: KeyExtractor, B: Bypasser> CacheBuilder<K, B> {
285 pub fn new_with_bypasser(key_extractor: K, bypasser: B) -> Self {
287 Self {
288 key_extractor,
289 bypasser: Some(bypasser),
290 opts: Opts::default(),
291 registry: Registry::new(),
292 }
293 }
294
295 pub const fn cache_size(mut self, v: u64) -> Self {
297 self.opts.cache_size = v;
298 self
299 }
300
301 pub const fn max_item_size(mut self, v: usize) -> Self {
303 self.opts.max_item_size = v;
304 self
305 }
306
307 pub const fn ttl(mut self, v: Duration) -> Self {
309 self.opts.ttl = v;
310 self
311 }
312
313 pub const fn max_ttl(mut self, v: Duration) -> Self {
315 self.opts.max_ttl = v;
316 self
317 }
318
319 pub const fn lock_timeout(mut self, v: Duration) -> Self {
321 self.opts.lock_timeout = v;
322 self
323 }
324
325 pub const fn body_timeout(mut self, v: Duration) -> Self {
327 self.opts.body_timeout = v;
328 self
329 }
330
331 pub const fn xfetch_beta(mut self, v: f64) -> Self {
333 self.opts.xfetch_beta = v;
334 self
335 }
336
337 pub fn methods(mut self, v: &[Method]) -> Self {
339 self.opts.methods = v.into();
340 self
341 }
342
343 pub const fn obey_cache_control(mut self, v: bool) -> Self {
345 self.opts.obey_cache_control = v;
346 self
347 }
348
349 pub fn registry(mut self, v: &Registry) -> Self {
351 self.registry = v.clone();
352 self
353 }
354
355 pub fn build(self) -> Result<Cache<K, B>, CacheError> {
357 Cache::new(self.opts, self.key_extractor, self.bypasser, &self.registry)
358 }
359}
360
361pub struct Cache<K: KeyExtractor, B: Bypasser = NoopBypasser> {
363 store: MokaCache<K::Key, Arc<Entry>, RandomState>,
364 locks: MokaCache<K::Key, Arc<Mutex<()>>, RandomState>,
365 key_extractor: K,
366 bypasser: Option<B>,
367 metrics: Metrics,
368 opts: Opts,
369}
370
371fn weigh_entry<K: KeyExtractor>(_k: &K::Key, v: &Arc<Entry>) -> u32 {
372 let mut size = size_of::<K::Key>() + size_of::<Arc<Entry>>();
373
374 size += calc_headers_size(v.response.headers());
375 size += v.response.body().len();
376
377 size as u32
378}
379
380impl<K: KeyExtractor + 'static, B: Bypasser + 'static> Cache<K, B> {
381 pub fn new(
383 opts: Opts,
384 key_extractor: K,
385 bypasser: Option<B>,
386 registry: &Registry,
387 ) -> Result<Self, CacheError> {
388 if opts.max_item_size as u64 >= opts.cache_size {
389 return Err(CacheError::Other(
390 "Cache item size should be less than whole cache size".into(),
391 ));
392 }
393
394 if opts.ttl > opts.max_ttl {
395 return Err(CacheError::Other("TTL should be <= max TTL".into()));
396 }
397
398 Ok(Self {
399 store: MokaCacheBuilder::new(opts.cache_size)
400 .expire_after(Expirer::<K>(PhantomData))
401 .weigher(weigh_entry::<K>)
402 .build_with_hasher(RandomState::default()),
403
404 locks: MokaCacheBuilder::new(32768)
406 .time_to_idle(Duration::from_secs(60))
407 .build_with_hasher(RandomState::default()),
408
409 key_extractor,
410 bypasser,
411 metrics: Metrics::new(registry),
412
413 opts,
414 })
415 }
416
417 pub fn get(&self, key: &K::Key, now: Instant, beta: f64) -> Option<(Response, i64)> {
419 let val = self.store.get(key)?;
420
421 if val.need_to_refresh(now, beta) {
423 self.metrics.x_fetch.inc();
424 return None;
425 }
426
427 let (mut parts, body) = val.response.clone().into_parts();
428 let ttl_left = if now < val.expires {
429 (val.expires - now).as_secs() as i64
430 } else {
431 -((now - val.expires).as_secs() as i64)
432 };
433
434 if ttl_left >= 0 {
435 parts.headers.insert(X_CACHE_TTL, ttl_left.into());
436 }
437
438 Some((Response::from_parts(parts, Body::from(body)), ttl_left))
439 }
440
441 pub fn insert(
443 &self,
444 key: K::Key,
445 now: Instant,
446 ttl: Duration,
447 delta: Duration,
448 response: Response<Bytes>,
449 ) {
450 self.metrics.ttl.observe(ttl.as_secs_f64());
451
452 self.store.insert(
453 key,
454 Arc::new(Entry {
455 response,
456 delta: delta.as_secs_f64(),
457 expires: now + ttl,
458 }),
459 );
460 }
461
462 pub async fn process_request(
464 &self,
465 request: Request,
466 next: Next,
467 ) -> Result<Response, CacheError> {
468 let now = Instant::now();
469 let (cache_status, response) = self.process_inner(now, request, next).await?;
470
471 let cache_status_str: &'static str = (&cache_status).into();
473 let cache_bypass_reason_str: &'static str = match cache_status.clone() {
474 CacheStatus::Bypass(v) => v.into_str(),
475 _ => "none",
476 };
477
478 let labels = &[cache_status_str, cache_bypass_reason_str];
479
480 self.metrics.requests_count.with_label_values(labels).inc();
481 self.metrics
482 .requests_duration
483 .with_label_values(labels)
484 .observe(now.elapsed().as_secs_f64());
485
486 Ok(cache_status.with_response(response))
487 }
488
489 async fn process_inner(
490 &self,
491 now: Instant,
492 request: Request,
493 next: Next,
494 ) -> Result<(CacheStatus<B::BypassReason>, Response), CacheError> {
495 if let Some(b) = &self.bypasser {
497 if let Ok(v) = b.bypass(&request) {
499 if let Some(r) = v {
501 return Ok((
502 CacheStatus::Bypass(CacheBypassReason::Custom(r)),
503 next.run(request).await,
504 ));
505 }
506 } else {
507 return Ok((
508 CacheStatus::Bypass(CacheBypassReason::UnableToRunBypasser),
509 next.run(request).await,
510 ));
511 }
512 }
513
514 if !self.opts.methods.contains(request.method()) {
516 return Ok((
517 CacheStatus::Bypass(CacheBypassReason::MethodNotCacheable),
518 next.run(request).await,
519 ));
520 }
521
522 let Ok(key) = self.key_extractor.extract(&request) else {
524 return Ok((
525 CacheStatus::Bypass(CacheBypassReason::UnableToExtractKey),
526 next.run(request).await,
527 ));
528 };
529
530 if let Some((v, ttl_left)) = self.get(&key, now, self.opts.xfetch_beta) {
531 return Ok((CacheStatus::Hit(ttl_left), v));
532 }
533
534 let lock = self
536 .locks
537 .get_with_by_ref(&key, || Arc::new(Mutex::new(())));
538
539 let mut lock_obtained = false;
540 select! {
541 _ = lock.lock() => {
544 lock_obtained = true;
545 }
546
547 _ = sleep(self.opts.lock_timeout) => {}
549 }
550
551 self.metrics
553 .lock_await
554 .with_label_values(&[if lock_obtained { "yes" } else { "no" }])
555 .observe(now.elapsed().as_secs_f64());
556
557 if let Some((v, ttl_left)) = self.get(&key, now, 0.0) {
560 return Ok((CacheStatus::Hit(ttl_left), v));
561 }
562
563 let now = Instant::now();
565 Ok(match self.pass_request(request, next).await? {
566 ResponseType::Fetched(v, ttl) => {
568 let delta = now.elapsed();
569 self.insert(key, now + delta, ttl, delta, v.clone());
570
571 let ttl = ttl.as_secs();
572 let (mut parts, body) = v.into_parts();
573 parts.headers.insert(X_CACHE_TTL, ttl.into());
574 let response = Response::from_parts(parts, Body::from(body));
575 (CacheStatus::Miss(ttl as i64), response)
576 }
577
578 ResponseType::Streamed(v, reason) => (CacheStatus::Bypass(reason), v),
580 })
581 }
582
583 async fn pass_request(
585 &self,
586 request: Request,
587 next: Next,
588 ) -> Result<ResponseType<B::BypassReason>, CacheError> {
589 let response = next.run(request).await;
591
592 if !response.status().is_success() {
594 return Ok(ResponseType::Streamed(
595 response,
596 CacheBypassReason::HTTPError,
597 ));
598 }
599
600 let body_size = response.body().size_hint().exact().map(|x| x as usize);
602
603 let Some(body_size) = body_size else {
605 return Ok(ResponseType::Streamed(
606 response,
607 CacheBypassReason::SizeUnknown,
608 ));
609 };
610
611 if body_size > self.opts.max_item_size {
613 return Ok(ResponseType::Streamed(
614 response,
615 CacheBypassReason::BodyTooBig,
616 ));
617 }
618
619 let ttl = if self.opts.obey_cache_control {
621 let ttl = infer_ttl(&response);
622
623 match ttl {
624 Some(CacheControl::NoCache) => {
626 return Ok(ResponseType::Streamed(
627 response,
628 CacheBypassReason::CacheControl,
629 ));
630 }
631
632 Some(CacheControl::MaxAge(v)) => v.min(self.opts.max_ttl),
634
635 None => self.opts.ttl,
637 }
638 } else {
639 self.opts.ttl
640 };
641
642 let (parts, body) = response.into_parts();
644 let body = buffer_body(body, body_size, self.opts.body_timeout)
645 .await
646 .map_err(|e| match e {
647 HttpError::BodyTooBig => CacheError::FetchBodyTooBig,
648 HttpError::BodyTimedOut => CacheError::FetchBodyTimeout,
649 _ => CacheError::FetchBody(e.to_string()),
650 })?;
651
652 Ok(ResponseType::Fetched(
653 Response::from_parts(parts, body),
654 ttl,
655 ))
656 }
657}
658
659#[async_trait]
660impl<K: KeyExtractor, B: Bypasser> Run for Cache<K, B> {
661 async fn run(&self, _: CancellationToken) -> Result<(), anyhow::Error> {
662 self.store.run_pending_tasks();
663 self.metrics.memory.set(self.store.weighted_size() as i64);
664 self.metrics.entries.set(self.store.entry_count() as i64);
665 Ok(())
666 }
667}
668
#[cfg(test)]
impl<K: KeyExtractor + 'static, B: Bypasser + 'static> Cache<K, B> {
    /// Runs pending moka maintenance on both the entry store and the lock map,
    /// so the sizes and counts seen by tests are current.
    pub fn housekeep(&self) {
        self.store.run_pending_tasks();
        self.locks.run_pending_tasks();
    }

    /// Weighted size of the entry store.
    pub fn size(&self) -> u64 {
        self.store.weighted_size()
    }

    /// Number of entries in the store.
    // Test-only helper; an `is_empty` counterpart isn't needed.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> u64 {
        self.store.entry_count()
    }

    /// Invalidates every entry and lock, then runs maintenance so the removal is visible.
    pub fn clear(&self) {
        let (store, locks) = (&self.store, &self.locks);
        store.invalidate_all();
        locks.invalidate_all();
        self.housekeep();
    }
}
691
/// Key extractor that keys requests by authority, path+query and the `Range` header.
///
/// The URI scheme is not part of the key.
#[derive(Debug, Clone)]
pub struct KeyExtractorUriRange;
695
696impl KeyExtractor for KeyExtractorUriRange {
697 type Key = [u8; 20];
698
699 fn extract<T>(&self, request: &Request<T>) -> Result<Self::Key, CacheError> {
700 let authority = extract_authority(request)
701 .ok_or_else(|| CacheError::ExtractKey("no authority found".into()))?
702 .as_bytes();
703 let paq = request
704 .uri()
705 .path_and_query()
706 .ok_or_else(|| CacheError::ExtractKey("no path_and_query found".into()))?
707 .as_str()
708 .as_bytes();
709
710 let mut hash = Sha1::new().chain_update(authority).chain_update(paq);
712 if let Some(v) = request.headers().get(RANGE) {
713 hash = hash.chain_update(v.as_bytes());
714 }
715
716 Ok(hash.finalize().into())
717 }
718}
719
720#[cfg(test)]
721mod tests {
722 use crate::hval;
723
724 use super::*;
725
726 use axum::{
727 Router,
728 body::to_bytes,
729 extract::State,
730 middleware::from_fn_with_state,
731 response::IntoResponse,
732 routing::{get, post},
733 };
734 use http::{Request, Response, StatusCode, Uri};
735 use sha1::Digest;
736 use tower::{Service, ServiceExt};
737
738 #[derive(Clone, Debug)]
739 pub struct KeyExtractorTest;
740
741 impl KeyExtractor for KeyExtractorTest {
742 type Key = [u8; 20];
743
744 fn extract<T>(&self, request: &Request<T>) -> Result<Self::Key, CacheError> {
745 let paq = request
746 .uri()
747 .path_and_query()
748 .ok_or_else(|| CacheError::ExtractKey("no path_and_query found".into()))?
749 .as_str()
750 .as_bytes();
751
752 let hash: [u8; 20] = sha1::Sha1::new().chain_update(paq).finalize().into();
753 Ok(hash)
754 }
755 }
756
757 const MAX_ITEM_SIZE: usize = 1024;
758 const MAX_CACHE_SIZE: u64 = 32768;
759 const PROXY_LOCK_TIMEOUT: Duration = Duration::from_secs(1);
760
761 async fn dispatch_get_request(router: &mut Router, uri: String) -> Option<CacheStatus> {
762 let req = Request::get(uri).body(Body::from("")).unwrap();
763 let result = router.call(req).await.unwrap();
764 assert_eq!(result.status(), StatusCode::OK);
765 result.extensions().get::<CacheStatus>().cloned()
766 }
767
768 async fn handler(_request: Request<Body>) -> impl IntoResponse {
769 "test_body"
770 }
771
772 async fn handler_proxy_cache_lock(request: Request<Body>) -> impl IntoResponse {
773 if request.uri().path().contains("slow_response") {
774 sleep(2 * PROXY_LOCK_TIMEOUT).await;
775 }
776
777 "test_body"
778 }
779
780 async fn handler_too_big(_request: Request<Body>) -> impl IntoResponse {
781 "a".repeat(MAX_ITEM_SIZE + 1)
782 }
783
784 async fn handler_cache_control_max_age_1d(_request: Request<Body>) -> impl IntoResponse {
785 [(CACHE_CONTROL, "max-age=86400")]
786 }
787
788 async fn handler_cache_control_max_age_7d(_request: Request<Body>) -> impl IntoResponse {
789 [(CACHE_CONTROL, "max-age=604800")]
790 }
791
792 async fn handler_cache_control_no_cache(_request: Request<Body>) -> impl IntoResponse {
793 [(CACHE_CONTROL, "no-cache")]
794 }
795
796 async fn handler_cache_control_no_store(_request: Request<Body>) -> impl IntoResponse {
797 [(CACHE_CONTROL, "no-store")]
798 }
799
800 async fn middleware(
801 State(cache): State<Arc<Cache<KeyExtractorTest>>>,
802 request: Request<Body>,
803 next: Next,
804 ) -> impl IntoResponse {
805 cache
806 .process_request(request, next)
807 .await
808 .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
809 }
810
811 #[test]
812 fn test_bypass_reason_serialize() {
813 #[derive(Debug, Clone, Display, PartialEq, Eq, IntoStaticStr)]
814 #[strum(serialize_all = "snake_case")]
815 enum CustomReasonTest {
816 Bar,
817 }
818 impl CustomBypassReason for CustomReasonTest {}
819
820 let a: CacheBypassReason<CustomReasonTest> =
821 CacheBypassReason::Custom(CustomReasonTest::Bar);
822 let txt = a.into_str();
823 assert_eq!(txt, "bar");
824
825 let a: CacheBypassReason<CustomReasonTest> = CacheBypassReason::BodyTooBig;
826 let txt = a.into_str();
827 assert_eq!(txt, "body_too_big");
828 }
829
830 #[test]
831 fn test_key_extractor_uri_range() {
832 let x = KeyExtractorUriRange;
833
834 let mut req = Request::new("foo");
836 *req.uri_mut() = Uri::from_static("http://foo.bar.baz:80/foo/bar?abc=1");
837 let key1 = x.extract(&req).unwrap();
838
839 let mut req = Request::new("foo");
841 *req.uri_mut() = Uri::from_static("http://foo.bar.baz:80/foo/bar?abc=2");
842 let key2 = x.extract(&req).unwrap();
843 assert_ne!(key1, key2);
844
845 let mut req = Request::new("foo");
846 *req.uri_mut() = Uri::from_static("http://foo.bar.baz:80/foo/ba?abc=1");
847 let key2 = x.extract(&req).unwrap();
848 assert_ne!(key1, key2);
849
850 let mut req = Request::new("foo");
851 *req.uri_mut() = Uri::from_static("http://foo.bar.ba:80/foo/bar?abc=1");
852 let key2 = x.extract(&req).unwrap();
853 assert_ne!(key1, key2);
854
855 let mut req = Request::new("foo");
857 *req.uri_mut() = Uri::from_static("https://foo.bar.baz:80/foo/bar?abc=1");
858 let key2 = x.extract(&req).unwrap();
859 assert_eq!(key1, key2);
860
861 let mut req = Request::new("foo");
863 *req.uri_mut() = Uri::from_static("http://foo.bar.bar:80/foo/bar?abc=1");
864 (*req.headers_mut()).insert(RANGE, hval!("1000-2000"));
865 let key2 = x.extract(&req).unwrap();
866 assert_ne!(key1, key2);
867 }
868
869 #[test]
870 fn test_infer_ttl() {
871 let mut req = Response::new(());
872
873 assert_eq!(infer_ttl(&req), None);
874
875 req.headers_mut().insert(CACHE_CONTROL, hval!("no-cache"));
877 assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
878
879 req.headers_mut().insert(CACHE_CONTROL, hval!("no-store"));
880 assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
881
882 req.headers_mut()
883 .insert(CACHE_CONTROL, hval!("no-store, no-cache"));
884 assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
885
886 req.headers_mut()
888 .insert(CACHE_CONTROL, hval!("no-store, no-cache, max-age=1"));
889 assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
890
891 req.headers_mut()
892 .insert(CACHE_CONTROL, hval!("max-age=1, no-store, no-cache"));
893 assert_eq!(
894 infer_ttl(&req),
895 Some(CacheControl::MaxAge(Duration::from_secs(1)))
896 );
897
898 req.headers_mut()
900 .insert(CACHE_CONTROL, hval!("max-age=86400"));
901 assert_eq!(
902 infer_ttl(&req),
903 Some(CacheControl::MaxAge(Duration::from_secs(86400)))
904 );
905 req.headers_mut().insert(CACHE_CONTROL, hval!("max-age=0"));
906 assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
907
908 req.headers_mut()
909 .insert(CACHE_CONTROL, hval!("max-age=foo"));
910 assert_eq!(infer_ttl(&req), None);
911
912 req.headers_mut().insert(CACHE_CONTROL, hval!("max-age="));
913 assert_eq!(infer_ttl(&req), None);
914
915 req.headers_mut().insert(CACHE_CONTROL, hval!("max-age=-1"));
916 assert_eq!(infer_ttl(&req), None);
917
918 req.headers_mut().insert(CACHE_CONTROL, hval!(""));
920 assert_eq!(infer_ttl(&req), None);
921
922 req.headers_mut()
924 .insert(CACHE_CONTROL, hval!(", =foobar, "));
925 assert_eq!(infer_ttl(&req), None);
926 }
927
928 #[test]
929 fn test_cache_creation_errors() {
930 let cache = CacheBuilder::new(KeyExtractorTest)
931 .cache_size(1)
932 .max_item_size(2)
933 .build();
934 assert!(cache.is_err());
935
936 let cache = CacheBuilder::new(KeyExtractorTest)
937 .ttl(Duration::from_secs(2))
938 .max_ttl(Duration::from_secs(1))
939 .build();
940 assert!(cache.is_err());
941 }
942
943 #[tokio::test]
944 async fn test_cache_bypass() {
945 let cache = Arc::new(
946 CacheBuilder::new(KeyExtractorTest)
947 .max_item_size(MAX_ITEM_SIZE)
948 .build()
949 .unwrap(),
950 );
951
952 let mut app = Router::new()
953 .route("/", post(handler))
954 .route("/", get(handler))
955 .route("/too_big", get(handler_too_big))
956 .layer(from_fn_with_state(Arc::clone(&cache), middleware));
957
958 let req = Request::post("/").body(Body::from("")).unwrap();
960 let result = app.call(req).await.unwrap();
961 assert_eq!(result.status(), StatusCode::OK);
962 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
963 assert_eq!(cache.len(), 0);
964 assert_eq!(
965 cache_status,
966 CacheStatus::Bypass(CacheBypassReason::MethodNotCacheable)
967 );
968
969 let req = Request::get("/non_existing_path")
971 .body(Body::from("foobar"))
972 .unwrap();
973 let result = app.call(req).await.unwrap();
974 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
975 assert_eq!(result.status(), StatusCode::NOT_FOUND);
976 assert_eq!(
977 cache_status,
978 CacheStatus::Bypass(CacheBypassReason::HTTPError)
979 );
980 assert_eq!(cache.len(), 0);
981
982 let req = Request::get("/too_big").body(Body::from("foobar")).unwrap();
984 let result = app.call(req).await.unwrap();
985 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
986 assert_eq!(
987 cache_status,
988 CacheStatus::Bypass(CacheBypassReason::BodyTooBig)
989 );
990 assert_eq!(result.status(), StatusCode::OK);
991 assert_eq!(cache.len(), 0);
992 }
993
994 #[tokio::test]
995 async fn test_cache_hit() {
996 let ttl = Duration::from_millis(1500);
997
998 let cache = Arc::new(
999 CacheBuilder::new(KeyExtractorTest)
1000 .cache_size(MAX_CACHE_SIZE)
1001 .max_item_size(MAX_ITEM_SIZE)
1002 .ttl(ttl)
1003 .build()
1004 .unwrap(),
1005 );
1006
1007 let mut app = Router::new()
1008 .route("/{key}", get(handler))
1009 .layer(from_fn_with_state(Arc::clone(&cache), middleware));
1010
1011 let req = Request::get("/1").body(Body::from("")).unwrap();
1013 let result = app.call(req).await.unwrap();
1014 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1015 assert_eq!(result.status(), StatusCode::OK);
1016 assert!(matches!(cache_status, CacheStatus::Miss(_)));
1017 cache.housekeep();
1018 assert_eq!(cache.len(), 1);
1019
1020 let req = Request::get("/2").body(Body::from("")).unwrap();
1022 let result = app.call(req).await.unwrap();
1023 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1024 assert_eq!(result.status(), StatusCode::OK);
1025 assert!(matches!(cache_status, CacheStatus::Miss(_)));
1026 cache.housekeep();
1027 assert_eq!(cache.len(), 2);
1028
1029 let req = Request::get("/1").body(Body::from("")).unwrap();
1031 let result = app.call(req).await.unwrap();
1032 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1033 assert_eq!(result.status(), StatusCode::OK);
1034 assert!(matches!(cache_status, CacheStatus::Hit(_)));
1035 let (_, body) = result.into_parts();
1036 let body = to_bytes(body, usize::MAX).await.unwrap().to_vec();
1037 let body = String::from_utf8_lossy(&body);
1038 assert_eq!("test_body", body);
1039 cache.housekeep();
1040 assert_eq!(cache.len(), 2);
1041
1042 let req = Request::get("/2").body(Body::from("")).unwrap();
1044 let result = app.call(req).await.unwrap();
1045 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1046 assert_eq!(result.status(), StatusCode::OK);
1047 assert!(matches!(cache_status, CacheStatus::Hit(_)));
1048 let (_, body) = result.into_parts();
1049 let body = to_bytes(body, usize::MAX).await.unwrap().to_vec();
1050 let body = String::from_utf8_lossy(&body);
1051 assert_eq!("test_body", body);
1052 cache.housekeep();
1053 assert_eq!(cache.len(), 2);
1054
1055 sleep(ttl + Duration::from_millis(300)).await;
1057 cache.housekeep();
1058 let req = Request::get("/1").body(Body::from("")).unwrap();
1059 let result = app.call(req).await.unwrap();
1060 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1061 assert_eq!(result.status(), StatusCode::OK);
1062 assert!(matches!(cache_status, CacheStatus::Miss(_)));
1063
1064 cache.clear();
1066 let req_count = 50;
1067 for idx in 0..req_count {
1069 let status = dispatch_get_request(&mut app, format!("/{idx}")).await;
1070 assert!(matches!(status, Some(CacheStatus::Miss(_))));
1071 }
1072 for idx in 0..req_count {
1074 let status = dispatch_get_request(&mut app, format!("/{idx}")).await;
1075 assert!(matches!(status, Some(CacheStatus::Hit(_))));
1076 }
1077
1078 cache.clear();
1080 let req_count = 500;
1081 for idx in 0..req_count {
1083 let status = dispatch_get_request(&mut app, format!("/{idx}")).await;
1084 assert!(matches!(status, Some(CacheStatus::Miss(_))));
1085 }
1086
1087 let mut count_misses = 0;
1089 let mut count_hits = 0;
1090 for idx in 0..req_count {
1091 let status = dispatch_get_request(&mut app, format!("/{idx}")).await;
1092 if matches!(status, Some(CacheStatus::Miss(_))) {
1093 count_misses += 1;
1094 } else if matches!(status, Some(CacheStatus::Hit(_))) {
1095 count_hits += 1;
1096 }
1097 }
1098 assert!(count_misses > 0);
1099 assert!(count_hits > 0);
1100 cache.housekeep();
1101 let entry_size = cache.size() / cache.len();
1102
1103 assert!(MAX_CACHE_SIZE > cache.size());
1106 assert!(MAX_CACHE_SIZE < cache.size() + entry_size);
1107 }
1108
1109 #[tokio::test]
1110 async fn test_cache_control() {
1111 let cache = Arc::new(
1112 CacheBuilder::new(KeyExtractorTest)
1113 .obey_cache_control(true)
1114 .build()
1115 .unwrap(),
1116 );
1117
1118 let mut app = Router::new()
1119 .route("/", get(handler))
1120 .route(
1121 "/cache_control_no_store",
1122 get(handler_cache_control_no_store),
1123 )
1124 .route(
1125 "/cache_control_no_cache",
1126 get(handler_cache_control_no_cache),
1127 )
1128 .route(
1129 "/cache_control_max_age_1d",
1130 get(handler_cache_control_max_age_1d),
1131 )
1132 .route(
1133 "/cache_control_max_age_7d",
1134 get(handler_cache_control_max_age_7d),
1135 )
1136 .layer(from_fn_with_state(Arc::clone(&cache), middleware));
1137
1138 let req = Request::get("/cache_control_no_cache")
1140 .body(Body::from("foobar"))
1141 .unwrap();
1142 let result = app.call(req).await.unwrap();
1143 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1144 assert_eq!(
1145 cache_status,
1146 CacheStatus::Bypass(CacheBypassReason::CacheControl)
1147 );
1148 assert_eq!(result.status(), StatusCode::OK);
1149 cache.housekeep();
1150 assert_eq!(cache.len(), 0);
1151
1152 let req = Request::get("/cache_control_no_store")
1154 .body(Body::from("foobar"))
1155 .unwrap();
1156 let result = app.call(req).await.unwrap();
1157 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1158 assert_eq!(
1159 cache_status,
1160 CacheStatus::Bypass(CacheBypassReason::CacheControl)
1161 );
1162 assert_eq!(result.status(), StatusCode::OK);
1163 cache.housekeep();
1164 assert_eq!(cache.len(), 0);
1165
1166 let req = Request::get("/cache_control_max_age_1d")
1168 .body(Body::from("foobar"))
1169 .unwrap();
1170 let result = app.call(req).await.unwrap();
1171 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1172 let ttl = result
1173 .headers()
1174 .get(X_CACHE_TTL)
1175 .unwrap()
1176 .to_str()
1177 .unwrap()
1178 .parse::<u64>()
1179 .unwrap();
1180 assert!(matches!(cache_status, CacheStatus::Miss(_)));
1181 assert_eq!(ttl, 86400);
1182 assert_eq!(result.status(), StatusCode::OK);
1183 cache.housekeep();
1184 assert_eq!(cache.len(), 1);
1185
1186 let req = Request::get("/cache_control_max_age_7d")
1188 .body(Body::from("foobar"))
1189 .unwrap();
1190 let result = app.call(req).await.unwrap();
1191 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1192 let ttl = result
1193 .headers()
1194 .get(X_CACHE_TTL)
1195 .unwrap()
1196 .to_str()
1197 .unwrap()
1198 .parse::<u64>()
1199 .unwrap();
1200 assert!(matches!(cache_status, CacheStatus::Miss(_)));
1201 assert_eq!(ttl, 86400);
1202 assert_eq!(result.status(), StatusCode::OK);
1203 cache.housekeep();
1204 assert_eq!(cache.len(), 2);
1205
1206 let req = Request::get("/").body(Body::from("foobar")).unwrap();
1208 let result = app.call(req).await.unwrap();
1209 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1210 let ttl = result
1211 .headers()
1212 .get(X_CACHE_TTL)
1213 .unwrap()
1214 .to_str()
1215 .unwrap()
1216 .parse::<u64>()
1217 .unwrap();
1218 assert!(matches!(cache_status, CacheStatus::Miss(_)));
1219 assert_eq!(ttl, 10);
1220 assert_eq!(result.status(), StatusCode::OK);
1221 cache.housekeep();
1222 assert_eq!(cache.len(), 3);
1223
1224 let cache = Arc::new(
1226 CacheBuilder::new(KeyExtractorTest)
1227 .obey_cache_control(false)
1228 .build()
1229 .unwrap(),
1230 );
1231
1232 let mut app = Router::new()
1233 .route("/", get(handler))
1234 .route(
1235 "/cache_control_no_store",
1236 get(handler_cache_control_no_store),
1237 )
1238 .route(
1239 "/cache_control_no_cache",
1240 get(handler_cache_control_no_cache),
1241 )
1242 .route(
1243 "/cache_control_max_age_1d",
1244 get(handler_cache_control_max_age_1d),
1245 )
1246 .route(
1247 "/cache_control_max_age_7d",
1248 get(handler_cache_control_max_age_7d),
1249 )
1250 .layer(from_fn_with_state(Arc::clone(&cache), middleware));
1251
1252 let req = Request::get("/cache_control_no_cache")
1254 .body(Body::from("foobar"))
1255 .unwrap();
1256 let result = app.call(req).await.unwrap();
1257 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1258 assert!(matches!(cache_status, CacheStatus::Miss(_)));
1259 assert_eq!(result.status(), StatusCode::OK);
1260 cache.housekeep();
1261 assert_eq!(cache.len(), 1);
1262
1263 let req = Request::get("/cache_control_no_store")
1265 .body(Body::from("foobar"))
1266 .unwrap();
1267 let result = app.call(req).await.unwrap();
1268 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1269 assert!(matches!(cache_status, CacheStatus::Miss(_)));
1270 assert_eq!(result.status(), StatusCode::OK);
1271 cache.housekeep();
1272 assert_eq!(cache.len(), 2);
1273
1274 let req = Request::get("/cache_control_max_age_1d")
1276 .body(Body::from("foobar"))
1277 .unwrap();
1278 let result = app.call(req).await.unwrap();
1279 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1280 let ttl = result
1281 .headers()
1282 .get(X_CACHE_TTL)
1283 .unwrap()
1284 .to_str()
1285 .unwrap()
1286 .parse::<u64>()
1287 .unwrap();
1288 assert!(matches!(cache_status, CacheStatus::Miss(_)));
1289 assert_eq!(ttl, 10);
1290 assert_eq!(result.status(), StatusCode::OK);
1291 cache.housekeep();
1292 assert_eq!(cache.len(), 3);
1293
1294 let req = Request::get("/").body(Body::from("foobar")).unwrap();
1296 let result = app.call(req).await.unwrap();
1297 let cache_status = result.extensions().get::<CacheStatus>().cloned().unwrap();
1298 let ttl = result
1299 .headers()
1300 .get(X_CACHE_TTL)
1301 .unwrap()
1302 .to_str()
1303 .unwrap()
1304 .parse::<u64>()
1305 .unwrap();
1306 assert!(matches!(cache_status, CacheStatus::Miss(_)));
1307 assert_eq!(ttl, 10);
1308 assert_eq!(result.status(), StatusCode::OK);
1309 cache.housekeep();
1310 assert_eq!(cache.len(), 4);
1311 }
1312
1313 #[tokio::test]
1314 async fn test_proxy_cache_lock() {
1315 let cache = Arc::new(
1316 CacheBuilder::new(KeyExtractorTest)
1317 .lock_timeout(PROXY_LOCK_TIMEOUT)
1318 .build()
1319 .unwrap(),
1320 );
1321
1322 let app = Router::new()
1323 .route("/{key}", get(handler_proxy_cache_lock))
1324 .layer(from_fn_with_state(Arc::clone(&cache), middleware));
1325
1326 let req_count = 50;
1327 let expected_misses = [1, req_count];
1329 let expected_hits = [req_count - 1, 0];
1330 for (idx, uri) in ["/fast_response", "/slow_response"].iter().enumerate() {
1331 let mut tasks = vec![];
1332 for _ in 0..req_count {
1334 let app = app.clone();
1335 tasks.push(tokio::spawn(async move {
1336 let req = Request::get(*uri).body(Body::from("")).unwrap();
1337 let result = app.oneshot(req).await.unwrap();
1338 assert_eq!(result.status(), StatusCode::OK);
1339 result.extensions().get::<CacheStatus>().cloned()
1340 }));
1341 }
1342 let mut count_hits = 0;
1343 let mut count_misses = 0;
1344 for task in tasks {
1345 task.await
1346 .map(|res| match res {
1347 Some(CacheStatus::Hit(_)) => count_hits += 1,
1348 Some(CacheStatus::Miss(_)) => count_misses += 1,
1349 _ => panic!("Unexpected cache status"),
1350 })
1351 .expect("failed to complete task");
1352 }
1353 assert_eq!(count_hits, expected_hits[idx]);
1354 assert_eq!(count_misses, expected_misses[idx]);
1355 cache.housekeep();
1356 cache.clear();
1357 }
1358 }
1359
/// Exercises the probabilistic early-refresh decision made by `Entry::need_to_refresh`.
///
/// The entry expires 60s after `now` with `delta = 0.5`. Each scenario asks
/// `need_to_refresh` the same question `reqs` times at a fixed instant and counts how
/// many calls requested a refresh:
/// - 2s before expiry, factor 1.5: between 550 and 800 of the calls refresh;
/// - 30s before expiry, factor 1.0: no call refreshes;
/// - 30s before expiry, factor 10.0: more than 9 calls refresh.
// NOTE(review): the second argument is presumably the XFetch beta factor and the bounds
// presumably follow the XFetch formula — confirm against `need_to_refresh`.
#[test]
fn test_xfetch() {
    let now = Instant::now();
    let reqs = 10000;

    let entry = Entry {
        response: Response::builder().body(Bytes::new()).unwrap(),
        delta: 0.5,
        expires: now + Duration::from_secs(60),
    };

    // Repeat the (randomized) decision `reqs` times at `at` and count the positives.
    let count_refreshes = |at, factor| {
        (0..reqs)
            .filter(|_| entry.need_to_refresh(at, factor))
            .count()
    };

    let near_expiry = now + Duration::from_secs(58);
    let halfway = now + Duration::from_secs(30);

    let refresh = count_refreshes(near_expiry, 1.5);
    assert!(refresh > 550 && refresh < 800);

    let refresh = count_refreshes(halfway, 1.0);
    assert_eq!(refresh, 0);

    let refresh = count_refreshes(halfway, 10.0);
    assert!(refresh > 9);
}
1404}