http_cache_tower_server/lib.rs
//! Server-side HTTP response caching middleware for Tower.
//!
//! This crate provides Tower middleware for caching HTTP responses on the server side.
//! Unlike client-side caching, this middleware caches your own application's responses
//! to reduce load and improve performance.
//!
//! # Key Features
//!
//! - Response-first architecture: Caches based on response headers, not requests
//! - Preserves request context: Maintains all request extensions (path params, state, etc.)
//! - Handler-centric: Calls the handler first, then decides whether to cache
//! - RFC 7234 compliant: Respects Cache-Control, Vary, and other standard headers
//! - Reuses existing infrastructure: Leverages `CacheManager` trait from `http-cache`
//!
//! # Example
//!
//! ```rust
//! use http::{Request, Response};
//! use http_body_util::Full;
//! use bytes::Bytes;
//! use http_cache_tower_server::ServerCacheLayer;
//! use tower::{Service, Layer};
//! # use http_cache::{CacheManager, HttpResponse, HttpVersion};
//! # use http_cache_semantics::CachePolicy;
//! # use std::collections::HashMap;
//! # use std::sync::{Arc, Mutex};
//! #
//! # #[derive(Clone)]
//! # struct MemoryCacheManager {
//! #     store: Arc<Mutex<HashMap<String, (HttpResponse, CachePolicy)>>>,
//! # }
//! #
//! # impl MemoryCacheManager {
//! #     fn new() -> Self {
//! #         Self { store: Arc::new(Mutex::new(HashMap::new())) }
//! #     }
//! # }
//! #
//! # #[async_trait::async_trait]
//! # impl CacheManager for MemoryCacheManager {
//! #     async fn get(&self, cache_key: &str) -> http_cache::Result<Option<(HttpResponse, CachePolicy)>> {
//! #         Ok(self.store.lock().unwrap().get(cache_key).cloned())
//! #     }
//! #     async fn put(&self, cache_key: String, res: HttpResponse, policy: CachePolicy) -> http_cache::Result<HttpResponse> {
//! #         self.store.lock().unwrap().insert(cache_key, (res.clone(), policy));
//! #         Ok(res)
//! #     }
//! #     async fn delete(&self, cache_key: &str) -> http_cache::Result<()> {
//! #         self.store.lock().unwrap().remove(cache_key);
//! #         Ok(())
//! #     }
//! # }
//!
//! # tokio_test::block_on(async {
//! let manager = MemoryCacheManager::new();
//! let layer = ServerCacheLayer::new(manager);
//!
//! // Apply the layer to your Tower service
//! let service = tower::service_fn(|_req: Request<Full<Bytes>>| async {
//!     Ok::<_, std::io::Error>(
//!         Response::builder()
//!             .header("cache-control", "max-age=60")
//!             .body(Full::new(Bytes::from("Hello, World!")))
//!             .unwrap()
//!     )
//! });
//!
//! let mut cached_service = layer.layer(service);
//! # });
//! ```
//!
//! # Vary Header Support
//!
//! This cache enforces `Vary` headers using `http-cache-semantics`. When a response includes
//! a `Vary` header, subsequent requests must have matching header values to receive the cached
//! response. Requests with different header values will result in cache misses.
//!
//! For example, if a response has `Vary: Accept-Language`, a cached English response won't be
//! served to a request with `Accept-Language: de`.
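//!
//! As an illustrative sketch (the handler and header values here are hypothetical), a handler
//! only needs to set the response headers; the cached copy is then served solely to requests
//! whose `Accept-Language` matches:
//!
//! ```rust
//! use http::{Request, Response};
//! use http_body_util::Full;
//! use bytes::Bytes;
//!
//! let handler = tower::service_fn(|req: Request<Full<Bytes>>| async move {
//!     // Pick a language from the request; the `Vary: accept-language` header tells the
//!     // cache to reuse this response only for requests with a matching value.
//!     let lang = req.headers()
//!         .get("accept-language")
//!         .and_then(|v| v.to_str().ok())
//!         .unwrap_or("en")
//!         .to_string();
//!     Ok::<_, std::io::Error>(
//!         Response::builder()
//!             .header("cache-control", "max-age=60")
//!             .header("vary", "accept-language")
//!             .body(Full::new(Bytes::from(format!("Hello ({lang})"))))
//!             .unwrap()
//!     )
//! });
//! # let _ = handler;
//! ```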
//!
//! # Security Warnings
//!
//! This is a **shared cache**: cached responses are served to ALL users. Improper configuration
//! can leak user-specific data between different users.
//!
//! ## Authorization and Authentication
//!
//! Cache keys do not include `Authorization` headers or session cookies. By default, requests
//! carrying an `Authorization` header are not cached at all (see
//! `ServerCacheOptions::respect_authorization`), but cookie-authenticated endpoints get no such
//! protection: caching them without proper cache key differentiation will cause user A's
//! response to be served to user B.
//!
//! **Do NOT cache authenticated endpoints** unless you use a `CustomKeyer` that includes
//! the user or session identifier in the cache key:
//!
//! ```rust
//! # use http_cache_tower_server::CustomKeyer;
//! # use http::Request;
//! // Example: Include session ID in cache key
//! let keyer = CustomKeyer::new(|req: &Request<()>| {
//!     let session = req.headers()
//!         .get("cookie")
//!         .and_then(|v| v.to_str().ok())
//!         .and_then(|c| extract_session_id(c))
//!         .unwrap_or("anonymous");
//!     format!("{} {} session:{}", req.method(), req.uri().path(), session)
//! });
//! # fn extract_session_id(cookie: &str) -> Option<&str> { None }
//! ```
//!
//! ## General Security Considerations
//!
//! - Never cache responses containing user-specific data without user-specific cache keys
//! - Validate cache keys to prevent cache poisoning attacks
//! - Be careful with header-based caching due to header injection risks
//! - Use `Cache-Control: private` for user-specific responses; this cache automatically refuses to store them

#![warn(missing_docs)]
#![deny(unsafe_code)]

use bytes::Bytes;
use http::{header::HeaderValue, Request, Response};
use http_body::{Body as HttpBody, Frame};
use http_body_util::BodyExt;
use http_cache::{CacheManager, HttpResponse, HttpVersion};
use http_cache_semantics::{BeforeRequest, CachePolicy};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::error::Error as StdError;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, SystemTime};
use tower::{Layer, Service};

type BoxError = Box<dyn StdError + Send + Sync>;

/// Cache performance metrics.
///
/// Tracks hits, misses, and stores for monitoring cache effectiveness.
#[derive(Debug, Default)]
pub struct CacheMetrics {
    /// Number of cache hits.
    pub hits: AtomicU64,
    /// Number of cache misses.
    pub misses: AtomicU64,
    /// Number of responses stored in cache.
    pub stores: AtomicU64,
    /// Number of responses skipped (too large, not cacheable, etc.).
    pub skipped: AtomicU64,
}

impl CacheMetrics {
    /// Create a new metrics instance.
    pub fn new() -> Self {
        Self::default()
    }

    /// Calculate the cache hit rate as a fraction between 0.0 and 1.0.
    ///
    /// Returns 0.0 when no lookups have been recorded.
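    ///
    /// For example (in normal operation the middleware updates these counters; here they are
    /// bumped directly for illustration):
    ///
    /// ```
    /// # use http_cache_tower_server::CacheMetrics;
    /// # use std::sync::atomic::Ordering;
    /// let metrics = CacheMetrics::new();
    /// metrics.hits.fetch_add(3, Ordering::Relaxed);
    /// metrics.misses.fetch_add(1, Ordering::Relaxed);
    /// assert_eq!(metrics.hit_rate(), 0.75);
    /// ```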
    pub fn hit_rate(&self) -> f64 {
        let hits = self.hits.load(Ordering::Relaxed);
        let total = hits + self.misses.load(Ordering::Relaxed);
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    /// Reset all metrics to zero.
    pub fn reset(&self) {
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
        self.stores.store(0, Ordering::Relaxed);
        self.skipped.store(0, Ordering::Relaxed);
    }
}

/// A trait for generating cache keys from HTTP requests.
pub trait Keyer: Clone + Send + Sync + 'static {
    /// Generate a cache key for the given request.
    fn cache_key<B>(&self, req: &Request<B>) -> String;
}

/// Default keyer that uses HTTP method and path.
///
/// Generates keys in the format: `{METHOD} {path}`
///
/// # Example
///
/// ```
/// # use http::Request;
/// # use http_cache_tower_server::{Keyer, DefaultKeyer};
/// let keyer = DefaultKeyer;
/// let req = Request::get("/users/123").body(()).unwrap();
/// let key = keyer.cache_key(&req);
/// assert_eq!(key, "GET /users/123");
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultKeyer;

impl Keyer for DefaultKeyer {
    fn cache_key<B>(&self, req: &Request<B>) -> String {
        format!("{} {}", req.method(), req.uri().path())
    }
}

/// Keyer that includes query parameters in the cache key.
///
/// Generates keys in the format: `{METHOD} {path}?{query}`
///
/// # Example
///
/// ```
/// # use http::Request;
/// # use http_cache_tower_server::{Keyer, QueryKeyer};
/// let keyer = QueryKeyer;
/// let req = Request::get("/users?page=1").body(()).unwrap();
/// let key = keyer.cache_key(&req);
/// assert_eq!(key, "GET /users?page=1");
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct QueryKeyer;

impl Keyer for QueryKeyer {
    fn cache_key<B>(&self, req: &Request<B>) -> String {
        format!("{} {}", req.method(), req.uri())
    }
}

/// Custom keyer that uses a user-provided function.
///
/// Use this when the default method+path keying is insufficient, such as:
/// - Content negotiation based on request headers (Accept-Language, Accept-Encoding)
/// - User-specific or session-specific caching
/// - Query parameter normalization
///
/// # Examples
///
/// Basic custom format:
///
/// ```
/// # use http::Request;
/// # use http_cache_tower_server::{Keyer, CustomKeyer};
/// let keyer = CustomKeyer::new(|req: &Request<()>| {
///     format!("custom-{}-{}", req.method(), req.uri().path())
/// });
/// let req = Request::get("/users").body(()).unwrap();
/// let key = keyer.cache_key(&req);
/// assert_eq!(key, "custom-GET-/users");
/// ```
///
/// Content negotiation (Accept-Language):
///
/// ```
/// # use http::Request;
/// # use http_cache_tower_server::{Keyer, CustomKeyer};
/// let keyer = CustomKeyer::new(|req: &Request<()>| {
///     let lang = req.headers()
///         .get("accept-language")
///         .and_then(|v| v.to_str().ok())
///         .and_then(|s| s.split(',').next())
///         .unwrap_or("en");
///     format!("{} {} lang:{}", req.method(), req.uri().path(), lang)
/// });
/// ```
///
/// User-specific caching (session-based):
///
/// ```
/// # use http::Request;
/// # use http_cache_tower_server::{Keyer, CustomKeyer};
/// let keyer = CustomKeyer::new(|req: &Request<()>| {
///     let user_id = req.headers()
///         .get("x-user-id")
///         .and_then(|v| v.to_str().ok())
///         .unwrap_or("anonymous");
///     format!("{} {} user:{}", req.method(), req.uri().path(), user_id)
/// });
/// ```
///
/// # Security Warning
///
/// When caching user-specific or session-specific data, ensure the user/session identifier
/// is included in the cache key. Failure to do so will cause responses from one user to be
/// served to other users.
#[derive(Clone)]
pub struct CustomKeyer<F> {
    func: F,
}

impl<F> CustomKeyer<F> {
    /// Create a new custom keyer with the given function.
    pub fn new(func: F) -> Self {
        Self { func }
    }
}

impl<F> Keyer for CustomKeyer<F>
where
    F: Fn(&Request<()>) -> String + Clone + Send + Sync + 'static,
{
    fn cache_key<B>(&self, req: &Request<B>) -> String {
        // Create a temporary request with the same parts but () body
        let mut temp_req = Request::builder()
            .method(req.method())
            .uri(req.uri())
            .version(req.version())
            .body(())
            .unwrap();

        // Copy headers for content negotiation support
        *temp_req.headers_mut() = req.headers().clone();

        (self.func)(&temp_req)
    }
}

/// Configuration options for server-side caching.
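///
/// A minimal construction sketch (the values shown are illustrative, with the remaining
/// fields taken from `Default`):
///
/// ```
/// # use http_cache_tower_server::ServerCacheOptions;
/// # use std::time::Duration;
/// let options = ServerCacheOptions {
///     default_ttl: Some(Duration::from_secs(30)),
///     cache_by_default: true,
///     ..Default::default()
/// };
/// assert!(options.respect_vary);
/// ```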
#[derive(Debug, Clone)]
pub struct ServerCacheOptions {
    /// Default TTL when response has no Cache-Control header.
    pub default_ttl: Option<Duration>,

    /// Maximum TTL, even if response specifies longer.
    pub max_ttl: Option<Duration>,

    /// Minimum TTL, even if response specifies shorter.
    pub min_ttl: Option<Duration>,

    /// Whether to add X-Cache headers (HIT/MISS).
    pub cache_status_headers: bool,

    /// Maximum response body size to cache (in bytes).
    pub max_body_size: usize,

    /// Whether to cache responses without explicit Cache-Control.
    pub cache_by_default: bool,

    /// Whether to respect Vary header for content negotiation.
    ///
    /// When true (default), cached responses are only served if the request's
    /// headers match those specified in the response's Vary header. This is
    /// enforced via `http-cache-semantics`.
    pub respect_vary: bool,

    /// Whether to respect Authorization headers per RFC 9111 §3.5.
    ///
    /// When true (default), requests with `Authorization` headers are not cached
    /// unless the response explicitly permits it via `public`, `s-maxage`, or
    /// `must-revalidate` directives.
    ///
    /// This prevents accidental caching of authenticated responses that could
    /// leak user-specific data to other users.
    pub respect_authorization: bool,
}

impl Default for ServerCacheOptions {
    fn default() -> Self {
        Self {
            default_ttl: Some(Duration::from_secs(60)),
            max_ttl: Some(Duration::from_secs(3600)),
            min_ttl: None,
            cache_status_headers: true,
            max_body_size: 128 * 1024 * 1024,
            cache_by_default: false,
            respect_vary: true,
            respect_authorization: true,
        }
    }
}

/// A cached HTTP response with metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResponse {
    /// Response status code.
    pub status: u16,

    /// Response headers.
    pub headers: HashMap<String, String>,

    /// Response body bytes.
    pub body: Vec<u8>,

    /// When this response was cached.
    pub cached_at: SystemTime,

    /// Time-to-live duration.
    pub ttl: Duration,

    /// Optional vary headers for content negotiation.
    pub vary: Option<Vec<String>>,
}

impl CachedResponse {
    /// Check if this cached response is stale.
    pub fn is_stale(&self) -> bool {
        SystemTime::now()
            .duration_since(self.cached_at)
            .unwrap_or(Duration::MAX)
            > self.ttl
    }

    /// Convert to an HTTP response.
    pub fn into_response(self) -> Response<Bytes> {
        let mut builder = Response::builder().status(self.status);

        for (key, value) in self.headers {
            if let Ok(header_value) = HeaderValue::from_str(&value) {
                builder = builder.header(key, header_value);
            }
        }

        builder.body(Bytes::from(self.body)).unwrap()
    }
}

/// Response body types.
#[derive(Debug)]
pub enum ResponseBody {
    /// Cached response body.
    Cached(Bytes),
    /// Fresh response body.
    Fresh(Bytes),
    /// Uncacheable response body.
    Uncacheable(Bytes),
}

impl HttpBody for ResponseBody {
    type Data = Bytes;
    type Error = BoxError;

    fn poll_frame(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<std::result::Result<Frame<Self::Data>, Self::Error>>> {
        let bytes = match &mut *self {
            ResponseBody::Cached(b)
            | ResponseBody::Fresh(b)
            | ResponseBody::Uncacheable(b) => {
                std::mem::replace(b, Bytes::new())
            }
        };

        if bytes.is_empty() {
            Poll::Ready(None)
        } else {
            Poll::Ready(Some(Ok(Frame::data(bytes))))
        }
    }

    fn is_end_stream(&self) -> bool {
        match self {
            ResponseBody::Cached(b)
            | ResponseBody::Fresh(b)
            | ResponseBody::Uncacheable(b) => b.is_empty(),
        }
    }
}

/// Tower layer for server-side HTTP response caching.
///
/// This layer should be placed AFTER routing to ensure request
/// extensions (like path parameters) are preserved.
///
/// # Shared Cache Behavior
///
/// This implements a **shared cache** as defined in RFC 9111. Responses cached by this layer
/// are served to all users making requests with matching cache keys. The cache automatically
/// rejects responses with the `private` directive and, by default, refuses to store responses
/// to requests carrying an `Authorization` header (see
/// `ServerCacheOptions::respect_authorization`), but it does not inspect session cookies or
/// include them in cache keys.
///
/// For authenticated or user-specific endpoints, either:
/// - Set `Cache-Control: private` in responses (prevents caching)
/// - Use a `CustomKeyer` that includes user/session identifiers in the cache key
#[derive(Clone)]
pub struct ServerCacheLayer<M, K = DefaultKeyer>
where
    M: CacheManager,
    K: Keyer,
{
    manager: M,
    keyer: K,
    options: ServerCacheOptions,
    metrics: Arc<CacheMetrics>,
}

impl<M> ServerCacheLayer<M, DefaultKeyer>
where
    M: CacheManager,
{
    /// Create a new cache layer with default options.
    pub fn new(manager: M) -> Self {
        Self {
            manager,
            keyer: DefaultKeyer,
            options: ServerCacheOptions::default(),
            metrics: Arc::new(CacheMetrics::new()),
        }
    }
}

impl<M, K> ServerCacheLayer<M, K>
where
    M: CacheManager,
    K: Keyer,
{
    /// Create a cache layer with a custom keyer.
    pub fn with_keyer(manager: M, keyer: K) -> Self {
        Self {
            manager,
            keyer,
            options: ServerCacheOptions::default(),
            metrics: Arc::new(CacheMetrics::new()),
        }
    }

    /// Set custom options.
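    ///
    /// A minimal sketch (the `build_layer` helper and its values are hypothetical):
    ///
    /// ```
    /// # use http_cache_tower_server::{QueryKeyer, ServerCacheLayer, ServerCacheOptions};
    /// # use http_cache::CacheManager;
    /// # use std::time::Duration;
    /// fn build_layer<M: CacheManager>(manager: M) -> ServerCacheLayer<M, QueryKeyer> {
    ///     ServerCacheLayer::with_keyer(manager, QueryKeyer)
    ///         .with_options(ServerCacheOptions {
    ///             default_ttl: Some(Duration::from_secs(30)),
    ///             ..Default::default()
    ///         })
    /// }
    /// ```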
    pub fn with_options(mut self, options: ServerCacheOptions) -> Self {
        self.options = options;
        self
    }

    /// Get a reference to the cache metrics.
    pub fn metrics(&self) -> &Arc<CacheMetrics> {
        &self.metrics
    }

    /// Invalidate a specific cache entry by its key.
    pub async fn invalidate(&self, cache_key: &str) -> Result<(), BoxError> {
        self.manager.delete(cache_key).await
    }

    /// Invalidate cache entry for a specific request.
    ///
    /// Uses the configured keyer to generate the cache key from the request.
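    ///
    /// Illustrative sketch (the `evict_user` helper and the path are hypothetical):
    ///
    /// ```
    /// # use http::Request;
    /// # use http_cache::CacheManager;
    /// # use http_cache_tower_server::{DefaultKeyer, ServerCacheLayer};
    /// async fn evict_user<M: CacheManager>(layer: &ServerCacheLayer<M, DefaultKeyer>) {
    ///     // After a write, drop the cached GET response for the same path.
    ///     let req = Request::get("/users/123").body(()).unwrap();
    ///     let _ = layer.invalidate_request(&req).await;
    /// }
    /// ```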
    pub async fn invalidate_request<B>(
        &self,
        req: &Request<B>,
    ) -> Result<(), BoxError> {
        let cache_key = self.keyer.cache_key(req);
        self.invalidate(&cache_key).await
    }
}

impl<S, M, K> Layer<S> for ServerCacheLayer<M, K>
where
    M: CacheManager + Clone,
    K: Keyer,
{
    type Service = ServerCacheService<S, M, K>;

    fn layer(&self, inner: S) -> Self::Service {
        ServerCacheService {
            inner,
            manager: self.manager.clone(),
            keyer: self.keyer.clone(),
            options: self.options.clone(),
            metrics: self.metrics.clone(),
        }
    }
}

/// Tower service that implements response caching.
#[derive(Clone)]
pub struct ServerCacheService<S, M, K>
where
    M: CacheManager,
    K: Keyer,
{
    inner: S,
    manager: M,
    keyer: K,
    options: ServerCacheOptions,
    metrics: Arc<CacheMetrics>,
}

impl<S, ReqBody, ResBody, M, K> Service<Request<ReqBody>>
    for ServerCacheService<S, M, K>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>
        + Clone
        + Send
        + 'static,
    S::Error: Into<BoxError>,
    S::Future: Send + 'static,
    M: CacheManager + Clone,
    K: Keyer,
    ReqBody: Send + 'static,
    ResBody: HttpBody + Send + 'static,
    ResBody::Data: Send,
    ResBody::Error: Into<BoxError>,
{
    type Response = Response<ResponseBody>;
    type Error = BoxError;
    type Future = Pin<
        Box<
            dyn std::future::Future<
                    Output = std::result::Result<Self::Response, Self::Error>,
                > + Send,
        >,
    >;

    fn poll_ready(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let manager = self.manager.clone();
        let keyer = self.keyer.clone();
        let options = self.options.clone();
        let metrics = self.metrics.clone();
        let mut inner = self.inner.clone();

        Box::pin(async move {
            // Store request parts for later use in should_cache
            let (req_parts, req_body) = req.into_parts();

            // Generate cache key from request parts
            let temp_req = Request::from_parts(req_parts.clone(), ());
            let cache_key = keyer.cache_key(&temp_req);

            // Try to get from cache
            if let Ok(Some((cached_resp, policy))) =
                manager.get(&cache_key).await
            {
                // Deserialize cached response first
                if let Ok(cached) =
                    serde_json::from_slice::<CachedResponse>(&cached_resp.body)
                {
                    // Check freshness using both CachePolicy and our TTL tracking.
                    // CachePolicy handles Vary header matching.
                    // Our is_stale() handles the TTL we assigned (especially for cache_by_default).
                    let before_req =
                        policy.before_request(&req_parts, SystemTime::now());

                    // Determine if response had explicit freshness directives
                    // (max-age or s-maxage). If it only has "public" or other directives
                    // without explicit TTL, we use our own TTL tracking.
                    let has_explicit_ttl =
                        cached.headers.get("cache-control").is_some_and(|cc| {
                            cc.contains("max-age") || cc.contains("s-maxage")
                        });

                    let is_fresh = match before_req {
                        BeforeRequest::Fresh(_) => {
                            // CachePolicy says fresh - use it
                            true
                        }
                        BeforeRequest::Stale { .. } => {
                            // CachePolicy says stale. This could be due to:
                            // 1. Vary header mismatch
                            // 2. Time-based staleness per cache headers
                            // 3. No explicit TTL (cache_by_default or public-only)
                            //
                            // For case 3, our TTL tracking is authoritative.
                            // For cases 1-2, we should respect CachePolicy.
                            if has_explicit_ttl {
                                // Had explicit TTL - trust CachePolicy
                                false
                            } else {
                                // No explicit TTL - use our TTL
                                !cached.is_stale()
                            }
                        }
                    };

                    if is_fresh {
                        // Cache hit
                        metrics.hits.fetch_add(1, Ordering::Relaxed);
                        let mut response = cached.into_response();

                        if options.cache_status_headers {
                            response.headers_mut().insert(
                                "x-cache",
                                HeaderValue::from_static("HIT"),
                            );
                        }

                        return Ok(response.map(ResponseBody::Cached));
                    }
                }
            }

            // Reconstruct request for handler
            let req = Request::from_parts(req_parts.clone(), req_body);

            // Cache miss or stale - call the handler
            metrics.misses.fetch_add(1, Ordering::Relaxed);
            let response = inner.call(req).await.map_err(Into::into)?;

            // Split response to check if we should cache
            let (res_parts, body) = response.into_parts();

            // Check if we should cache this response
            if let Some(ttl) = should_cache(&req_parts, &res_parts, &options) {
                // Buffer the response body
                let body_bytes = match collect_body(body).await {
                    Ok(bytes) => bytes,
                    Err(e) => {
                        // If we can't collect the body, propagate the error
                        return Err(e);
                    }
                };

                // Check size limit
                if body_bytes.len() <= options.max_body_size {
                    metrics.stores.fetch_add(1, Ordering::Relaxed);
                    // Create cached response
                    let cached = CachedResponse {
                        status: res_parts.status.as_u16(),
                        headers: res_parts
                            .headers
                            .iter()
                            .filter_map(|(k, v)| {
                                v.to_str()
                                    .ok()
                                    .map(|s| (k.to_string(), s.to_string()))
                            })
                            .collect(),
                        body: body_bytes.to_vec(),
                        cached_at: SystemTime::now(),
                        ttl,
                        vary: extract_vary_headers(&res_parts),
                    };

                    // Store in cache (fire and forget)
                    let cached_json = serde_json::to_vec(&cached)
                        .map_err(|e| Box::new(e) as BoxError)?;
                    let http_response = HttpResponse {
                        body: cached_json,
                        headers: Default::default(),
                        status: 200,
                        url: cache_key.clone().parse().unwrap_or_else(|_| {
                            "http://localhost/".parse().unwrap()
                        }),
                        version: HttpVersion::Http11,
                    };

                    // Create CachePolicy from actual request/response for Vary support
                    let policy_req = Request::from_parts(req_parts.clone(), ());
                    let policy_res =
                        Response::from_parts(res_parts.clone(), ());
                    let policy = CachePolicy::new(&policy_req, &policy_res);

                    // Spawn cache write asynchronously
                    let manager_clone = manager.clone();
                    tokio::spawn(async move {
                        let _ = manager_clone
                            .put(cache_key, http_response, policy)
                            .await;
                    });
                } else {
                    // Body too large
                    metrics.skipped.fetch_add(1, Ordering::Relaxed);
                }

                // Return response with MISS header
                let mut response = Response::from_parts(res_parts, body_bytes);
                if options.cache_status_headers {
                    response
                        .headers_mut()
                        .insert("x-cache", HeaderValue::from_static("MISS"));
                }
                return Ok(response.map(ResponseBody::Fresh));
            }

            // Don't cache - just return
            metrics.skipped.fetch_add(1, Ordering::Relaxed);
            let body_bytes = collect_body(body).await?;
            Ok(Response::from_parts(res_parts, body_bytes)
                .map(ResponseBody::Uncacheable))
        })
    }
}

/// Collect a body into bytes.
async fn collect_body<B>(body: B) -> std::result::Result<Bytes, BoxError>
where
    B: HttpBody,
    B::Error: Into<BoxError>,
{
    body.collect()
        .await
        .map(|collected| collected.to_bytes())
        .map_err(Into::into)
}

/// Extract Vary headers from response parts.
fn extract_vary_headers(parts: &http::response::Parts) -> Option<Vec<String>> {
    parts
        .headers
        .get(http::header::VARY)
        .and_then(|v| v.to_str().ok())
        .map(|s| s.split(',').map(|h| h.trim().to_string()).collect())
}

/// Helper function to check if a Cache-Control directive is present.
/// This properly parses directives by splitting on commas and matching exact names,
/// so `max-age=60` matches `max-age` while `s-maxage=60` does not.
fn has_directive(cache_control: &str, directive: &str) -> bool {
    cache_control
        .split(',')
        .map(|d| d.trim())
        .any(|d| d == directive || d.starts_with(&format!("{}=", directive)))
}

/// Check if response explicitly permits caching of authorized requests per RFC 9111 §3.5.
///
/// Returns true if the response contains directives that allow caching despite
/// the request having an Authorization header.
fn response_permits_authorized_caching(cc_str: &str) -> bool {
    has_directive(cc_str, "public")
        || has_directive(cc_str, "s-maxage")
        || has_directive(cc_str, "must-revalidate")
}

/// Determine if a response should be cached based on its headers, and if so, for how long.
/// Implements RFC 7234/9111 requirements for shared caches.
fn should_cache(
    req_parts: &http::request::Parts,
    res_parts: &http::response::Parts,
    options: &ServerCacheOptions,
) -> Option<Duration> {
    // Only cache successful (2xx) responses. RFC 7234 permits caching some other
    // status codes, but this implementation is deliberately conservative.
    if !res_parts.status.is_success() {
        return None;
    }

    // RFC 9111 §3.5: Check Authorization header
    let has_authorization =
        req_parts.headers.contains_key(http::header::AUTHORIZATION);

    // RFC 7234: Check Cache-Control directives
    if let Some(cc) = res_parts.headers.get(http::header::CACHE_CONTROL) {
        let cc_str = cc.to_str().ok()?;

        // RFC 9111 §3.5: If request has Authorization header, only cache if
        // response explicitly permits it
        if has_authorization
            && options.respect_authorization
            && !response_permits_authorized_caching(cc_str)
        {
            return None;
        }

        // RFC 7234: MUST NOT store if no-store directive present
        if has_directive(cc_str, "no-store") {
            return None;
        }

        // RFC 7234: MUST NOT store if no-cache
        // Note: Per RFC, no-cache means "cache but always revalidate". However,
        // without conditional request support (ETag/If-None-Match), we cannot
        // revalidate, so we skip caching entirely.
        if has_directive(cc_str, "no-cache") {
            return None;
        }

        // RFC 7234: Shared caches MUST NOT store responses with private directive
        if has_directive(cc_str, "private") {
            return None;
        }

        // RFC 7234: s-maxage directive overrides max-age for shared caches
        if let Some(s_maxage) = parse_s_maxage(cc_str) {
            let ttl = Duration::from_secs(s_maxage);
            let ttl = apply_ttl_constraints(ttl, options);
            return Some(ttl);
        }

        // RFC 7234: Extract max-age for cache lifetime
        if let Some(max_age) = parse_max_age(cc_str) {
            let ttl = Duration::from_secs(max_age);
            let ttl = apply_ttl_constraints(ttl, options);
            return Some(ttl);
        }

        // RFC 7234: public directive makes response cacheable
        if has_directive(cc_str, "public") {
            return options.default_ttl;
        }
    } else {
        // No Cache-Control header
        // RFC 9111 §3.5: Don't cache authorized requests without explicit permission
        if has_authorization && options.respect_authorization {
            return None;
        }
    }

    // RFC 7234: Fall back to the Expires header when Cache-Control yielded no TTL
    if let Some(expires) = res_parts.headers.get(http::header::EXPIRES) {
        if let Ok(expires_str) = expires.to_str() {
            if let Some(ttl) = parse_expires(expires_str) {
                let ttl = apply_ttl_constraints(ttl, options);
                return Some(ttl);
            }
        }
    }

    // No explicit caching directive
    if options.cache_by_default {
        options.default_ttl
    } else {
        None
    }
}

/// Apply min/max TTL constraints from options.
fn apply_ttl_constraints(
    ttl: Duration,
    options: &ServerCacheOptions,
) -> Duration {
    let mut result = ttl;

    if let Some(max) = options.max_ttl {
        result = result.min(max);
    }

    if let Some(min) = options.min_ttl {
        result = result.max(min);
    }

    result
}

/// Parse max-age from Cache-Control header.
fn parse_max_age(cache_control: &str) -> Option<u64> {
    for directive in cache_control.split(',') {
        let directive = directive.trim();
        if let Some(value) = directive.strip_prefix("max-age=") {
            return value.parse().ok();
        }
    }
    None
}

/// Parse s-maxage from Cache-Control header (shared cache specific).
fn parse_s_maxage(cache_control: &str) -> Option<u64> {
    for directive in cache_control.split(',') {
        let directive = directive.trim();
        if let Some(value) = directive.strip_prefix("s-maxage=") {
            return value.parse().ok();
        }
    }
    None
}

/// Parse Expires header to calculate TTL.
///
/// Returns the duration until expiration, or None if the date is invalid or in the past.
fn parse_expires(expires: &str) -> Option<Duration> {
    let expires_time = httpdate::parse_http_date(expires).ok()?;
    let now = SystemTime::now();

    expires_time.duration_since(now).ok()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_keyer() {
        let keyer = DefaultKeyer;
        let req = Request::get("/users/123").body(()).unwrap();
        let key = keyer.cache_key(&req);
        assert_eq!(key, "GET /users/123");
    }

    #[test]
    fn test_query_keyer() {
        let keyer = QueryKeyer;
        let req = Request::get("/users?page=1").body(()).unwrap();
        let key = keyer.cache_key(&req);
        assert_eq!(key, "GET /users?page=1");
    }

    #[test]
    fn test_parse_max_age() {
        assert_eq!(parse_max_age("max-age=3600"), Some(3600));
        assert_eq!(parse_max_age("public, max-age=3600"), Some(3600));
        assert_eq!(parse_max_age("max-age=3600, public"), Some(3600));
        assert_eq!(parse_max_age("public"), None);
    }
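
    // Added sketch: exercises the exact-name matching described on `has_directive`
    // (a directive that merely contains another directive's name does not match).
    #[test]
    fn test_has_directive() {
        assert!(has_directive("no-store", "no-store"));
        assert!(has_directive("public, max-age=60", "max-age"));
        assert!(has_directive("s-maxage=60", "s-maxage"));
        assert!(!has_directive("s-maxage=60", "max-age"));
        assert!(!has_directive("no-store-like", "no-store"));
    }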

    #[test]
    fn test_parse_s_maxage() {
        assert_eq!(parse_s_maxage("s-maxage=7200"), Some(7200));
        assert_eq!(parse_s_maxage("public, s-maxage=7200"), Some(7200));
        assert_eq!(parse_s_maxage("s-maxage=7200, max-age=3600"), Some(7200));
        assert_eq!(parse_s_maxage("public"), None);
    }
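
    // Added sketch: round-trips Expires values through `httpdate` to show how
    // `parse_expires` derives a TTL (future dates yield a duration, past dates yield None).
    #[test]
    fn test_parse_expires() {
        let future =
            httpdate::fmt_http_date(SystemTime::now() + Duration::from_secs(3600));
        let ttl = parse_expires(&future).expect("future date should yield a TTL");
        assert!(ttl <= Duration::from_secs(3600));
        assert!(ttl > Duration::from_secs(3500));

        let past =
            httpdate::fmt_http_date(SystemTime::now() - Duration::from_secs(3600));
        assert_eq!(parse_expires(&past), None);
    }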

    #[test]
    fn test_apply_ttl_constraints() {
        let options = ServerCacheOptions {
            min_ttl: Some(Duration::from_secs(10)),
            max_ttl: Some(Duration::from_secs(100)),
            ..Default::default()
        };

        assert_eq!(
            apply_ttl_constraints(Duration::from_secs(5), &options),
            Duration::from_secs(10)
        );
        assert_eq!(
            apply_ttl_constraints(Duration::from_secs(50), &options),
            Duration::from_secs(50)
        );
        assert_eq!(
            apply_ttl_constraints(Duration::from_secs(200), &options),
            Duration::from_secs(100)
        );
    }
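
    // Added sketch: checks the shared-cache rules documented on `should_cache`
    // (max-age is honored, `private` is rejected, and Authorization-bearing requests
    // are not cached unless the response explicitly permits it).
    #[test]
    fn test_should_cache_directives() {
        let options = ServerCacheOptions::default();
        let (req_parts, _) = Request::get("/data").body(()).unwrap().into_parts();

        let (res_parts, _) = Response::builder()
            .header("cache-control", "max-age=120")
            .body(())
            .unwrap()
            .into_parts();
        assert_eq!(
            should_cache(&req_parts, &res_parts, &options),
            Some(Duration::from_secs(120))
        );

        let (res_parts, _) = Response::builder()
            .header("cache-control", "private, max-age=120")
            .body(())
            .unwrap()
            .into_parts();
        assert_eq!(should_cache(&req_parts, &res_parts, &options), None);

        let (auth_parts, _) = Request::get("/data")
            .header("authorization", "Bearer example-token")
            .body(())
            .unwrap()
            .into_parts();
        let (res_parts, _) = Response::builder()
            .header("cache-control", "max-age=120")
            .body(())
            .unwrap()
            .into_parts();
        assert_eq!(should_cache(&auth_parts, &res_parts, &options), None);
    }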
}