1use crate::{
29 CacheKey, CacheOptions, CachePolicy, CacheStorage, PutHandle, StoredEntry,
30 tee::TeeingReader,
31 validation::{AfterResponse, BeforeRequest},
32};
33use futures_lite::{AsyncReadExt, AsyncWriteExt};
34use std::{sync::Arc, time::SystemTime};
35use trillium_client::{
36 Body, Client, ClientHandler, Conn, ConnExt, Headers, KnownHeaderName, Method, ResponseBody,
37 Result, Url,
38};
39
/// Default cap (16 MiB, counted in bytes) on the size of a response body that the cache will
/// write to storage. `Cache::new` starts from this value and `Cache::with_max_cacheable_size`
/// overrides it.
const DEFAULT_MAX_CACHEABLE_SIZE: u64 = 16 * 1024 * 1024;
41
/// A client handler that serves and stores HTTP responses through a [`CacheStorage`] backend.
///
/// The storage is held behind an [`Arc`], so clones of a `Cache` share the same stored entries.
#[derive(Debug)]
pub struct Cache<S: CacheStorage> {
    /// Storage backend, shared between all clones of this handler.
    storage: Arc<S>,
    /// Options handed to [`CachePolicy`] when checking storability and building policies.
    options: CacheOptions,
    /// Largest response body, in bytes, that will be written to storage.
    max_cacheable_size: u64,
}
53
54impl<S: CacheStorage> Clone for Cache<S> {
55 fn clone(&self) -> Self {
56 Self {
57 storage: Arc::clone(&self.storage),
58 options: self.options,
59 max_cacheable_size: self.max_cacheable_size,
60 }
61 }
62}
63
64impl<S: CacheStorage> Cache<S> {
65 pub fn new(storage: S) -> Self {
68 Self {
69 storage: Arc::new(storage),
70 options: CacheOptions::default(),
71 max_cacheable_size: DEFAULT_MAX_CACHEABLE_SIZE,
72 }
73 }
74
75 pub fn with_options(mut self, options: CacheOptions) -> Self {
77 self.options = options;
78 self
79 }
80
81 pub fn shared(mut self) -> Self {
84 self.options.shared = true;
85 self
86 }
87
88 pub fn with_max_cacheable_size(mut self, max: u64) -> Self {
93 self.max_cacheable_size = max;
94 self
95 }
96
97 pub fn storage(&self) -> &S {
99 &self.storage
100 }
101}
102
/// Per-request state recorded by `run` and consumed by `after_response`.
enum CacheCtx<E: StoredEntry> {
    /// The response was served from storage and the conn was halted.
    Hit,
    /// A stale stored entry matched; the request went out with the conditional headers produced
    /// by the entry's policy, and `stored` is kept to reconcile with the origin's answer.
    Revalidation { stored: E, key: CacheKey },
    /// No usable stored entry was found; the origin response may be stored under `key`.
    Miss { key: CacheKey },
    /// The request used an unsafe method; entries for `url` may need to be invalidated.
    Unsafe { url: Url },
}
120
121impl<E: StoredEntry> std::fmt::Debug for CacheCtx<E> {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 match self {
124 Self::Hit => f.write_str("Hit"),
125 Self::Revalidation { key, .. } => f
126 .debug_struct("Revalidation")
127 .field("key", key)
128 .finish_non_exhaustive(),
129 Self::Miss { key } => f.debug_struct("Miss").field("key", key).finish(),
130 Self::Unsafe { url } => f.debug_struct("Unsafe").field("url", url).finish(),
131 }
132 }
133}
134
135impl<S: CacheStorage> ClientHandler for Cache<S> {
136 async fn run(&self, conn: &mut Conn) -> Result<()> {
137 let method = conn.method();
138 let key = CacheKey::new(method, conn.url().clone());
139 log::trace!("cache: run {method} {}", conn.url());
140
141 if !method.is_safe() {
144 log::trace!("cache: unsafe method {method}, bypassing cache read");
145 conn.insert_state(CacheCtx::<S::StoredEntry>::Unsafe {
146 url: conn.url().clone(),
147 });
148 return Ok(());
149 }
150
151 let now = SystemTime::now();
152 let entries = self.storage.get(&key).await;
153 log::trace!("cache: {} stored candidate(s) for {key}", entries.len());
154
155 for entry in entries {
156 match entry.policy().before_request(conn.request_headers(), now) {
157 BeforeRequest::Fresh(cached) => {
158 log::trace!("cache: hit (fresh) for {key}, serving cached response");
159 *conn.response_headers_mut() = cached.headers;
160 let body = match entry.open().await {
161 Ok(b) => b,
162 Err(e) => {
163 log::warn!(
164 "cache: open for hit failed for {key}: {e}, passing through"
165 );
166 return Ok(());
168 }
169 };
170 conn.set_status(cached.status)
171 .set_response_body(body)
172 .halt()
173 .insert_state(CacheCtx::<S::StoredEntry>::Hit);
174 return Ok(());
175 }
176
177 BeforeRequest::NotModified(cached) => {
178 log::trace!("cache: hit (fresh, conditional matches) for {key}, serving 304");
179 *conn.response_headers_mut() = cached.headers;
180 conn.set_status(cached.status)
181 .set_response_body(b"" as &[u8])
182 .halt()
183 .insert_state(CacheCtx::<S::StoredEntry>::Hit);
184 return Ok(());
185 }
186
187 BeforeRequest::Stale {
188 request_headers,
189 matches: true,
190 } => {
191 if entry.policy().is_swr_eligible(now) {
195 log::trace!(
196 "cache: stale-while-revalidate for {key}, serving stale + spawning \
197 background revalidation"
198 );
199 let entry_for_bg = entry.clone();
200 self.spawn_background_revalidation(
201 conn,
202 entry_for_bg,
203 key.clone(),
204 request_headers,
205 );
206 match self.serve_stale(conn, entry, now).await {
207 Ok(()) => {
208 conn.halt();
209 conn.insert_state(CacheCtx::<S::StoredEntry>::Hit);
210 }
211 Err(e) => {
212 log::warn!(
213 "cache: open for stale serve failed for {key}: {e}, passing \
214 through"
215 );
216 }
217 }
218 return Ok(());
219 }
220 log::trace!("cache: stale for {key}, sending conditional revalidation request");
222 *conn.request_headers_mut() = request_headers;
223 conn.insert_state(CacheCtx::Revalidation { stored: entry, key });
224 return Ok(());
225 }
226
227 BeforeRequest::Stale { matches: false, .. } => {
228 log::trace!("cache: candidate vary-mismatch for {key}, trying next");
229 continue;
230 }
231 }
232 }
233
234 log::trace!("cache: miss for {key}, forwarding to origin");
235 conn.insert_state(CacheCtx::<S::StoredEntry>::Miss { key });
236 Ok(())
237 }
238
239 async fn after_response(&self, conn: &mut Conn) -> Result<()> {
240 let Some(ctx) = conn.take_state::<CacheCtx<S::StoredEntry>>() else {
241 log::trace!("cache: after_response with no CacheCtx, nothing to do");
242 return Ok(());
243 };
244
245 if let CacheCtx::Revalidation { ref stored, .. } = ctx {
248 let now = SystemTime::now();
249 let origin_failed =
250 conn.error().is_some() || conn.status().is_some_and(|s| s.is_server_error());
251 if origin_failed && stored.policy().is_sie_eligible(now) {
252 log::trace!(
253 "cache: stale-if-error recovery for {} (origin error/{:?}), serving stale",
254 conn.url(),
255 conn.status()
256 );
257 if let Err(e) = self.serve_stale(conn, stored.clone(), now).await {
258 log::warn!(
259 "cache: open for stale serve failed for {}: {e}, propagating error",
260 conn.url()
261 );
262 return Ok(());
263 }
264 conn.take_error();
265 return Ok(());
266 }
267 }
268
269 if conn.status().is_none() {
270 log::trace!(
271 "cache: transport error with no SIE recovery for {}, propagating",
272 conn.url()
273 );
274 return Ok(());
275 }
276
277 match ctx {
278 CacheCtx::Hit => {
279 log::trace!("cache: hit confirmed in after_response for {}", conn.url());
280 Ok(())
281 }
282 CacheCtx::Revalidation { stored, key } => {
283 self.handle_revalidation(conn, stored, key).await
284 }
285 CacheCtx::Miss { key } => self.handle_miss(conn, key).await,
286 CacheCtx::Unsafe { url } => {
287 let status = conn.status().expect("checked above");
288 if status.is_success() || status.is_redirection() {
289 log::trace!(
290 "cache: unsafe method {} → {}, invalidating GET and HEAD entries for {url}",
291 conn.method(),
292 status
293 );
294 self.invalidate_url(&url).await;
295
296 for header in [KnownHeaderName::Location, KnownHeaderName::ContentLocation] {
297 let Some(value) = conn.response_headers().get_str(header) else {
298 continue;
299 };
300 let Ok(target) = url.join(value) else {
301 log::trace!(
302 "cache: unsafe method secondary invalidation: {header} value \
303 {value:?} did not resolve, skipping"
304 );
305 continue;
306 };
307 if target.host_str() != url.host_str() {
308 log::trace!(
309 "cache: unsafe method secondary invalidation: {header} target \
310 {target} differs in host from request URL, skipping (§4.4 DoS \
311 guard)"
312 );
313 continue;
314 }
315 log::trace!(
316 "cache: unsafe method secondary invalidation via {header}: {target}"
317 );
318 self.invalidate_url(&target).await;
319 }
320 } else {
321 log::trace!(
322 "cache: unsafe method {} → {} for {url}, no invalidation",
323 conn.method(),
324 status
325 );
326 }
327 Ok(())
328 }
329 }
330 }
331}
332
333impl<S: CacheStorage> Cache<S> {
334 async fn invalidate_url(&self, url: &Url) {
337 self.storage
338 .invalidate(&CacheKey::new(Method::Get, url.clone()))
339 .await;
340 self.storage
341 .invalidate(&CacheKey::new(Method::Head, url.clone()))
342 .await;
343 }
344
345 async fn serve_stale(
349 &self,
350 conn: &mut Conn,
351 stored: S::StoredEntry,
352 now: SystemTime,
353 ) -> std::io::Result<()> {
354 let cached = stored.policy().cached_response(now);
355 let body = stored.open().await?;
356 conn.set_status(cached.status);
357 *conn.response_headers_mut() = cached.headers;
358 conn.set_response_body(body);
359 Ok(())
360 }
361
362 fn spawn_background_revalidation(
370 &self,
371 conn: &Conn,
372 stored: S::StoredEntry,
373 key: CacheKey,
374 request_headers: Headers,
375 ) {
376 let runtime = conn.client().connector().runtime();
377 let bypass_client = conn.client().clone().with_handler(());
378 let cache = self.clone();
379 let method = conn.method();
380 let url = conn.url().clone();
381 log::trace!("cache: spawning background revalidation for {key}");
382
383 let _detached = runtime.spawn(async move {
384 cache
385 .background_revalidation(bypass_client, method, url, request_headers, stored, key)
386 .await;
387 });
388 }
389
390 async fn background_revalidation(
391 self,
392 client: Client,
393 method: Method,
394 url: Url,
395 request_headers: Headers,
396 mut stored: S::StoredEntry,
397 key: CacheKey,
398 ) {
399 let mut new_conn = client.build_conn(method, url);
400 *new_conn.request_headers_mut() = request_headers;
401
402 if let Err(e) = (&mut new_conn).await {
403 log::trace!(
404 "cache: background revalidation transport error for {key} ({e}), leaving stored \
405 entry"
406 );
407 return;
408 }
409
410 let now = SystemTime::now();
411 let new_status = new_conn
412 .status()
413 .expect("background revalidation: response not yet received");
414 match stored.policy().after_response(
415 new_conn.request_headers(),
416 new_status,
417 new_conn.response_headers(),
418 now,
419 ) {
420 AfterResponse::NotModified(new_policy, _) => {
421 log::trace!("cache: background revalidation 304 for {key}, refreshing entry");
422 if let Err(e) = stored.refresh_policy(new_policy).await {
423 log::warn!("cache: background refresh_policy failed for {key}: {e}");
424 }
425 }
426 AfterResponse::Modified => {
427 let new_request_method = new_conn.method();
428 let new_request_headers = new_conn.request_headers().clone();
429 let new_response_headers = new_conn.response_headers().clone();
430 if !CachePolicy::is_storable(
431 new_request_method,
432 &new_request_headers,
433 new_status,
434 &new_response_headers,
435 &self.options,
436 ) {
437 log::trace!(
438 "cache: background revalidation 200 for {key}, response not storable, \
439 dropping"
440 );
441 return;
442 }
443 let new_policy = CachePolicy::new(
444 new_request_method,
445 &new_request_headers,
446 new_status,
447 new_response_headers,
448 now,
449 self.options,
450 );
451 let put_handle = match self.storage.put(key.clone(), new_policy).await {
452 Ok(h) => h,
453 Err(e) => {
454 log::warn!(
455 "cache: background put({key}) failed: {e}, leaving stored entry"
456 );
457 return;
458 }
459 };
460 let Some(body) = new_conn.take_response_body() else {
461 log::trace!(
462 "cache: background revalidation 200 for {key}, no body, leaving stored \
463 entry"
464 );
465 return;
466 };
467 if let Err(e) = copy_into_storage(body, put_handle, self.max_cacheable_size).await {
468 log::warn!(
469 "cache: background copy into storage failed for {key}: {e}, leaving \
470 stored entry"
471 );
472 }
473 }
474 }
475 }
476
477 async fn handle_revalidation(
478 &self,
479 conn: &mut Conn,
480 mut stored: S::StoredEntry,
481 key: CacheKey,
482 ) -> Result<()> {
483 let now = SystemTime::now();
484 let new_status = conn.status().expect("checked above");
485 match stored.policy().after_response(
486 conn.request_headers(),
487 new_status,
488 conn.response_headers(),
489 now,
490 ) {
491 AfterResponse::NotModified(new_policy, cached_response) => {
492 log::trace!(
493 "cache: revalidation 304 for {key}, reusing stored body and refreshing entry"
494 );
495 if let Err(e) = stored.refresh_policy(new_policy).await {
496 log::warn!("cache: refresh_policy failed for {key}: {e}");
497 }
498 let body = match stored.open().await {
499 Ok(b) => b,
500 Err(e) => {
501 log::warn!("cache: open after 304 failed for {key}: {e}, passing through");
502 return Ok(());
503 }
504 };
505 conn.set_status(cached_response.status);
506 *conn.response_headers_mut() = cached_response.headers;
507 conn.set_response_body(body);
508 Ok(())
509 }
510 AfterResponse::Modified => {
511 drop(stored);
514 self.handle_miss(conn, key).await
515 }
516 }
517 }
518
519 async fn handle_miss(&self, conn: &mut Conn, key: CacheKey) -> Result<()> {
520 let status = conn.status().expect("checked above");
521 if !CachePolicy::is_storable(
522 conn.method(),
523 conn.request_headers(),
524 status,
525 conn.response_headers(),
526 &self.options,
527 ) {
528 log::trace!("cache: miss for {key}, response not storable, passing through");
529 return Ok(());
530 }
531
532 if let Some(len) = conn
534 .response_headers()
535 .get_str(KnownHeaderName::ContentLength)
536 .and_then(|s| s.parse::<u64>().ok())
537 && len > self.max_cacheable_size
538 {
539 log::trace!(
540 "cache: miss for {key}, body {len} > max {}, not caching",
541 self.max_cacheable_size
542 );
543 return Ok(());
544 }
545
546 let policy = CachePolicy::new(
547 conn.method(),
548 conn.request_headers(),
549 status,
550 conn.response_headers().clone(),
551 SystemTime::now(),
552 self.options,
553 );
554 let put_handle = match self.storage.put(key.clone(), policy).await {
555 Ok(h) => h,
556 Err(e) => {
557 log::warn!("cache: put({key}) failed: {e}, passing through");
558 return Ok(());
559 }
560 };
561
562 let Some(response_body) = conn.take_response_body() else {
563 log::trace!("cache: miss for {key}, no body, passing through");
564 return Ok(());
565 };
566 let len = response_body.content_length();
567 let upstream = Body::new_with_trailers(response_body, len);
568 log::trace!("cache: miss for {key}, streaming through tee");
569 let tee = TeeingReader::new(upstream, put_handle, self.max_cacheable_size);
570 conn.set_response_body(Body::new_with_trailers(tee, len));
571 Ok(())
572 }
573}
574
575async fn copy_into_storage<P: PutHandle>(
579 body: ResponseBody<'static>,
580 mut put: P,
581 cap: u64,
582) -> std::io::Result<()> {
583 let len = body.content_length();
584 let mut body = Body::new_with_trailers(body, len);
585 let mut buf = [0u8; 8192];
586 let mut total: u64 = 0;
587 loop {
588 let n = body.read(&mut buf).await?;
589 if n == 0 {
590 break;
591 }
592 total = total.saturating_add(n as u64);
593 if total > cap {
594 drop(put);
596 log::trace!("cache: background copy exceeded cap {cap}, aborting cache write");
597 return Ok(());
598 }
599 put.write_all(&buf[..n]).await?;
600 }
601 let trailers = body.trailers();
602 put.finalize(trailers).await
603}
604
605#[cfg(test)]
606mod tests {
607 use super::*;
608 use crate::InMemoryStorage;
609 use std::sync::{
610 Arc,
611 atomic::{AtomicUsize, Ordering},
612 };
613 use trillium::{Conn as ServerConn, Handler as ServerHandler, KnownHeaderName, Status};
614 use trillium_client::Client;
615 use trillium_testing::{ServerConnector, TestResult, harness, test};
616
617 #[derive(Debug, Clone)]
618 struct CountingServer {
619 counter: Arc<AtomicUsize>,
620 cache_control: &'static str,
621 etag: Option<&'static str>,
622 }
623
624 impl CountingServer {
625 fn new(cache_control: &'static str) -> Self {
626 Self {
627 counter: Arc::new(AtomicUsize::new(0)),
628 cache_control,
629 etag: None,
630 }
631 }
632
633 fn with_etag(mut self, etag: &'static str) -> Self {
634 self.etag = Some(etag);
635 self
636 }
637 }
638
639 impl ServerHandler for CountingServer {
640 async fn run(&self, conn: ServerConn) -> ServerConn {
641 let n = self.counter.fetch_add(1, Ordering::SeqCst);
642
643 if let Some(etag) = self.etag {
644 if conn.request_headers().get_str(KnownHeaderName::IfNoneMatch) == Some(etag) {
645 return conn
646 .with_status(Status::NotModified)
647 .with_response_header(KnownHeaderName::Etag, etag)
648 .halt();
649 }
650 }
651
652 let mut conn = conn
653 .with_response_header(KnownHeaderName::CacheControl, self.cache_control)
654 .ok(format!("body-{n}"));
655 if let Some(etag) = self.etag {
656 conn.response_headers_mut()
657 .insert(KnownHeaderName::Etag, etag);
658 }
659 conn
660 }
661 }
662
663 fn cache_client(server: CountingServer) -> (Client, Arc<AtomicUsize>) {
664 let counter = server.counter.clone();
665 let client = Client::new(ServerConnector::new(server))
666 .with_handler(Cache::new(InMemoryStorage::new()));
667 (client, counter)
668 }
669
670 #[test(harness)]
671 async fn first_request_misses_subsequent_request_hits() -> TestResult {
672 let (client, counter) = cache_client(CountingServer::new("max-age=600"));
673
674 let mut r1 = client.get("http://example.com/x").await?;
675 assert_eq!(r1.status(), Some(Status::Ok));
676 assert_eq!(r1.response_body().read_string().await?, "body-0");
677
678 let mut r2 = client.get("http://example.com/x").await?;
679 assert_eq!(r2.status(), Some(Status::Ok));
680 assert_eq!(r2.response_body().read_string().await?, "body-0");
681 assert_eq!(counter.load(Ordering::SeqCst), 1, "server only hit once");
682 Ok(())
683 }
684
685 #[test(harness)]
686 async fn different_urls_dont_collide() -> TestResult {
687 let (client, counter) = cache_client(CountingServer::new("max-age=600"));
688
689 let mut r1 = client.get("http://example.com/a").await?;
690 let mut r2 = client.get("http://example.com/b").await?;
691 assert_eq!(r1.response_body().read_string().await?, "body-0");
692 assert_eq!(r2.response_body().read_string().await?, "body-1");
693 assert_eq!(counter.load(Ordering::SeqCst), 2);
694 Ok(())
695 }
696
697 #[test(harness)]
698 async fn no_store_response_is_not_cached() -> TestResult {
699 let (client, counter) = cache_client(CountingServer::new("no-store"));
700
701 let mut r1 = client.get("http://example.com/x").await?;
702 assert_eq!(r1.response_body().read_string().await?, "body-0");
703
704 let mut r2 = client.get("http://example.com/x").await?;
705 assert_eq!(r2.response_body().read_string().await?, "body-1");
706 assert_eq!(counter.load(Ordering::SeqCst), 2);
707 Ok(())
708 }
709
710 #[test(harness)]
711 async fn post_invalidates_existing_entry() -> TestResult {
712 let (client, counter) = cache_client(CountingServer::new("max-age=600"));
713
714 let mut r1 = client.get("http://example.com/x").await?;
715 assert_eq!(r1.response_body().read_string().await?, "body-0");
716
717 let _ = client.post("http://example.com/x").await?;
718
719 let mut r3 = client.get("http://example.com/x").await?;
720 assert_eq!(r3.response_body().read_string().await?, "body-2");
721 assert_eq!(counter.load(Ordering::SeqCst), 3);
722 Ok(())
723 }
724
725 #[test(harness)]
726 async fn post_invalidates_location_and_content_location_targets() -> TestResult {
727 #[derive(Debug, Clone, Default)]
728 struct LclServer(Arc<AtomicUsize>);
729 impl ServerHandler for LclServer {
730 async fn run(&self, conn: ServerConn) -> ServerConn {
731 let n = self.0.fetch_add(1, Ordering::SeqCst);
732 if conn.method() == Method::Post {
733 conn.with_response_header(KnownHeaderName::Location, "/loc")
734 .with_response_header(KnownHeaderName::ContentLocation, "/cl")
735 .ok(format!("post-body-{n}"))
736 } else {
737 conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
738 .ok(format!("get-body-{n}"))
739 }
740 }
741 }
742
743 let server = LclServer::default();
744 let counter = Arc::clone(&server.0);
745 let client = Client::new(ServerConnector::new(server))
746 .with_handler(Cache::new(InMemoryStorage::new()));
747
748 let mut loc = client.get("http://example.com/loc").await?;
750 let _ = loc.response_body().read_string().await?;
751 let mut cl = client.get("http://example.com/cl").await?;
752 let _ = cl.response_body().read_string().await?;
753 assert_eq!(counter.load(Ordering::SeqCst), 2);
754
755 let _ = client.post("http://example.com/anything").await?;
756
757 let _ = client.get("http://example.com/loc").await?;
758 let _ = client.get("http://example.com/cl").await?;
759 assert_eq!(
760 counter.load(Ordering::SeqCst),
761 5,
762 "POST + 2 re-fetches should hit the origin again"
763 );
764 Ok(())
765 }
766
767 #[test(harness)]
768 async fn cross_host_location_does_not_invalidate() -> TestResult {
769 #[derive(Debug, Clone, Default)]
770 struct CrossHostServer(Arc<AtomicUsize>);
771 impl ServerHandler for CrossHostServer {
772 async fn run(&self, conn: ServerConn) -> ServerConn {
773 let n = self.0.fetch_add(1, Ordering::SeqCst);
774 if conn.method() == Method::Post {
775 conn.with_response_header(KnownHeaderName::Location, "http://other.example/loc")
776 .ok(format!("post-{n}"))
777 } else {
778 conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
779 .ok(format!("get-{n}"))
780 }
781 }
782 }
783
784 let server = CrossHostServer::default();
785 let counter = Arc::clone(&server.0);
786 let client = Client::new(ServerConnector::new(server))
787 .with_handler(Cache::new(InMemoryStorage::new()));
788
789 let mut populating = client.get("http://other.example/loc").await?;
792 let _ = populating.response_body().read_string().await?;
793 assert_eq!(counter.load(Ordering::SeqCst), 1);
794
795 let _ = client.post("http://example.com/anything").await?;
796
797 let mut r = client.get("http://other.example/loc").await?;
798 assert_eq!(r.response_body().read_string().await?, "get-0");
799 assert_eq!(
800 counter.load(Ordering::SeqCst),
801 2,
802 "no extra GET to other.example"
803 );
804 Ok(())
805 }
806
807 #[test(harness)]
808 async fn stale_with_etag_revalidates_to_304() -> TestResult {
809 let (client, counter) = cache_client(CountingServer::new("max-age=0").with_etag(r#""v1""#));
810
811 let mut r1 = client.get("http://example.com/x").await?;
812 assert_eq!(r1.response_body().read_string().await?, "body-0");
813 assert_eq!(counter.load(Ordering::SeqCst), 1);
814
815 let mut r2 = client.get("http://example.com/x").await?;
816 assert_eq!(r2.status(), Some(Status::Ok));
817 assert_eq!(r2.response_body().read_string().await?, "body-0");
818 assert_eq!(counter.load(Ordering::SeqCst), 2);
819 Ok(())
820 }
821
822 #[test(harness)]
823 async fn stale_with_mismatching_etag_replaces_body() -> TestResult {
824 #[derive(Debug, Clone)]
825 struct AlwaysFresh {
826 counter: Arc<AtomicUsize>,
827 }
828 impl ServerHandler for AlwaysFresh {
829 async fn run(&self, conn: ServerConn) -> ServerConn {
830 let n = self.counter.fetch_add(1, Ordering::SeqCst);
831 conn.with_response_header(KnownHeaderName::CacheControl, "max-age=0")
832 .with_response_header(KnownHeaderName::Etag, r#""rolling""#)
833 .ok(format!("body-{n}"))
834 }
835 }
836 let counter = Arc::new(AtomicUsize::new(0));
837 let server = AlwaysFresh {
838 counter: counter.clone(),
839 };
840 let client = Client::new(ServerConnector::new(server))
841 .with_handler(Cache::new(InMemoryStorage::new()));
842
843 let mut r1 = client.get("http://example.com/x").await?;
844 assert_eq!(r1.response_body().read_string().await?, "body-0");
845
846 let mut r2 = client.get("http://example.com/x").await?;
847 assert_eq!(r2.response_body().read_string().await?, "body-1");
848 assert_eq!(counter.load(Ordering::SeqCst), 2);
849 Ok(())
850 }
851
852 #[test(harness)]
853 async fn vary_isolates_entries_by_request_header() -> TestResult {
854 #[derive(Debug, Clone)]
855 struct VaryServer {
856 counter: Arc<AtomicUsize>,
857 }
858 impl ServerHandler for VaryServer {
859 async fn run(&self, conn: ServerConn) -> ServerConn {
860 self.counter.fetch_add(1, Ordering::SeqCst);
861 let ae = conn
862 .request_headers()
863 .get_str(KnownHeaderName::AcceptEncoding)
864 .unwrap_or("none")
865 .to_string();
866 conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
867 .with_response_header(KnownHeaderName::Vary, "Accept-Encoding")
868 .ok(format!("body-for-{ae}"))
869 }
870 }
871 let counter = Arc::new(AtomicUsize::new(0));
872 let server = VaryServer {
873 counter: counter.clone(),
874 };
875 let client = Client::new(ServerConnector::new(server))
876 .with_handler(Cache::new(InMemoryStorage::new()));
877
878 let mut r1 = client
879 .get("http://example.com/x")
880 .with_request_header(KnownHeaderName::AcceptEncoding, "gzip")
881 .await?;
882 assert_eq!(r1.response_body().read_string().await?, "body-for-gzip");
883
884 let mut r2 = client
885 .get("http://example.com/x")
886 .with_request_header(KnownHeaderName::AcceptEncoding, "br")
887 .await?;
888 assert_eq!(r2.response_body().read_string().await?, "body-for-br");
889
890 let mut r3 = client
891 .get("http://example.com/x")
892 .with_request_header(KnownHeaderName::AcceptEncoding, "gzip")
893 .await?;
894 assert_eq!(r3.response_body().read_string().await?, "body-for-gzip");
895
896 assert_eq!(counter.load(Ordering::SeqCst), 2);
897 Ok(())
898 }
899
900 #[test(harness)]
901 async fn oversized_body_is_served_but_not_cached() -> TestResult {
902 let server = CountingServer::new("max-age=600");
903 let counter = server.counter.clone();
904 let client = Client::new(ServerConnector::new(server))
905 .with_handler(Cache::new(InMemoryStorage::new()).with_max_cacheable_size(3));
906
907 let mut r1 = client.get("http://example.com/x").await?;
908 assert_eq!(r1.response_body().read_string().await?, "body-0");
909
910 let mut r2 = client.get("http://example.com/x").await?;
911 assert_eq!(r2.response_body().read_string().await?, "body-1");
912 assert_eq!(counter.load(Ordering::SeqCst), 2);
913 Ok(())
914 }
915
916 use crate::test_helpers::exchange;
919 use std::{io, net::SocketAddr};
920 use trillium_client::{Connector, Url};
921
922 #[derive(Debug)]
925 struct FailingConnector {
926 inner: ServerConnector<Status>,
927 }
928
929 impl FailingConnector {
930 fn new() -> Self {
931 Self {
932 inner: ServerConnector::new(Status::Ok),
933 }
934 }
935 }
936
937 impl Connector for FailingConnector {
938 type Runtime = <ServerConnector<Status> as Connector>::Runtime;
939 type Transport = <ServerConnector<Status> as Connector>::Transport;
940 type Udp = <ServerConnector<Status> as Connector>::Udp;
941
942 async fn connect(&self, _url: &Url) -> io::Result<Self::Transport> {
943 Err(io::Error::new(
944 io::ErrorKind::ConnectionRefused,
945 "test failure",
946 ))
947 }
948
949 fn runtime(&self) -> Self::Runtime {
950 self.inner.runtime().clone()
951 }
952
953 async fn resolve(&self, host: &str, port: u16) -> io::Result<Vec<SocketAddr>> {
954 self.inner.resolve(host, port).await
955 }
956 }
957
958 async fn populate_stale_entry(
961 storage: &InMemoryStorage,
962 cache_control: &'static str,
963 body: &'static [u8],
964 ) -> CacheKey {
965 let conn = exchange(
966 Method::Get,
967 &[],
968 Status::Ok,
969 &[(KnownHeaderName::CacheControl, cache_control)],
970 );
971 let policy =
972 crate::test_helpers::policy_from(&conn, SystemTime::now(), CacheOptions::default());
973 let key = CacheKey::new(Method::Get, "http://example.com/x".parse().unwrap());
974 let mut handle = storage.put(key.clone(), policy).await.unwrap();
975 use futures_lite::AsyncWriteExt;
976 handle.write_all(body).await.unwrap();
977 handle.finalize(None).await.unwrap();
978 key
979 }
980
981 #[test(harness)]
982 async fn sie_serves_stale_on_transport_error() -> TestResult {
983 let storage = InMemoryStorage::new();
984 let _ =
985 populate_stale_entry(&storage, "max-age=0, stale-if-error=3600", b"stale body").await;
986 let client = Client::new(FailingConnector::new()).with_handler(Cache::new(storage));
987
988 let mut conn = client.get("http://example.com/x").await?;
989 assert_eq!(conn.status(), Some(Status::Ok));
990 assert_eq!(conn.response_body().read_string().await?, "stale body");
991 Ok(())
992 }
993
994 #[test(harness)]
995 async fn no_sie_propagates_transport_error() -> TestResult {
996 let storage = InMemoryStorage::new();
997 let _ = populate_stale_entry(&storage, "max-age=0", b"stale body").await;
998 let client = Client::new(FailingConnector::new()).with_handler(Cache::new(storage));
999
1000 let result = client.get("http://example.com/x").await;
1001 assert!(
1002 result.is_err(),
1003 "expected transport error to propagate, got {result:?}"
1004 );
1005 Ok(())
1006 }
1007
1008 #[test(harness)]
1009 async fn sie_serves_stale_on_5xx() -> TestResult {
1010 let storage = InMemoryStorage::new();
1011 let _ =
1012 populate_stale_entry(&storage, "max-age=0, stale-if-error=3600", b"stale body").await;
1013 let server = ServerConnector::new(Status::ServiceUnavailable);
1014 let client = Client::new(server).with_handler(Cache::new(storage));
1015
1016 let mut conn = client.get("http://example.com/x").await?;
1017 assert_eq!(conn.status(), Some(Status::Ok));
1018 assert_eq!(conn.response_body().read_string().await?, "stale body");
1019 Ok(())
1020 }
1021
1022 #[test(harness)]
1023 async fn no_sie_serves_5xx_as_received() -> TestResult {
1024 let storage = InMemoryStorage::new();
1025 let _ = populate_stale_entry(&storage, "max-age=0", b"stale body").await;
1026 let server = ServerConnector::new(Status::ServiceUnavailable);
1027 let client = Client::new(server).with_handler(Cache::new(storage));
1028
1029 let conn = client.get("http://example.com/x").await?;
1030 assert_eq!(conn.status(), Some(Status::ServiceUnavailable));
1031 Ok(())
1032 }
1033
1034 use std::time::Duration;
1037
1038 #[test(harness)]
1039 async fn swr_serves_stale_immediately_and_revalidates_in_background() -> TestResult {
1040 let storage = InMemoryStorage::new();
1041 let _ = populate_stale_entry(
1042 &storage,
1043 "max-age=0, stale-while-revalidate=3600",
1044 b"stale-body",
1045 )
1046 .await;
1047
1048 let server = CountingServer::new("max-age=600");
1049 let counter = server.counter.clone();
1050 let client = Client::new(ServerConnector::new(server)).with_handler(Cache::new(storage));
1051
1052 let mut conn = client.get("http://example.com/x").await?;
1053 assert_eq!(conn.status(), Some(Status::Ok));
1054 assert_eq!(conn.response_body().read_string().await?, "stale-body");
1055
1056 let runtime = client.connector().runtime();
1057 for _ in 0..100 {
1058 if counter.load(Ordering::SeqCst) > 0 {
1059 break;
1060 }
1061 runtime.delay(Duration::from_millis(10)).await;
1062 }
1063 assert_eq!(
1064 counter.load(Ordering::SeqCst),
1065 1,
1066 "background revalidation should hit the origin"
1067 );
1068
1069 let cache = client
1070 .downcast_handler::<Cache<InMemoryStorage>>()
1071 .expect("cache handler installed");
1072 let key = CacheKey::new(Method::Get, "http://example.com/x".parse().unwrap());
1073 for _ in 0..100 {
1075 if !cache.storage().get(&key).await.is_empty() {
1076 break;
1077 }
1078 runtime.delay(Duration::from_millis(10)).await;
1079 }
1080 let entries = cache.storage().get(&key).await;
1081 assert_eq!(entries.len(), 1);
1082 let body = entries[0].clone().open().await.unwrap();
1083 use futures_lite::AsyncReadExt;
1084 let mut buf = Vec::new();
1085 let mut body = body;
1086 body.read_to_end(&mut buf).await.unwrap();
1087 assert_eq!(&buf, b"body-0");
1088 Ok(())
1089 }
1090
1091 #[test(harness)]
1092 async fn no_swr_falls_back_to_synchronous_revalidation() -> TestResult {
1093 let storage = InMemoryStorage::new();
1094 let _ = populate_stale_entry(&storage, "max-age=0", b"stale-body").await;
1095
1096 let server = CountingServer::new("max-age=600");
1097 let counter = server.counter.clone();
1098 let client = Client::new(ServerConnector::new(server)).with_handler(Cache::new(storage));
1099
1100 let mut conn = client.get("http://example.com/x").await?;
1101 assert_eq!(conn.response_body().read_string().await?, "body-0");
1102 assert_eq!(counter.load(Ordering::SeqCst), 1);
1103 Ok(())
1104 }
1105
1106 #[test(harness)]
1107 async fn must_revalidate_disables_swr() -> TestResult {
1108 let storage = InMemoryStorage::new();
1109 let _ = populate_stale_entry(
1110 &storage,
1111 "max-age=0, must-revalidate, stale-while-revalidate=3600",
1112 b"stale-body",
1113 )
1114 .await;
1115
1116 let server = CountingServer::new("max-age=600");
1117 let client = Client::new(ServerConnector::new(server)).with_handler(Cache::new(storage));
1118
1119 let mut conn = client.get("http://example.com/x").await?;
1120 assert_eq!(conn.response_body().read_string().await?, "body-0");
1121 Ok(())
1122 }
1123}