Skip to main content

tower_cache_control/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{task, time::Duration};
4
5use futures::future::BoxFuture;
6pub use headers::CacheControl;
7use headers::HeaderMapExt;
8use http::{Request, Response, StatusCode};
9use tower_layer::Layer;
10use tower_service::Service;
11
12/// Middleware [Layer] for the [CacheControlService] service.
13#[derive(Clone, Debug)]
14pub struct CacheControlLayer {
15    default: Option<CacheControl>,
16}
17
18impl CacheControlLayer {
19    pub fn new(header: CacheControl) -> Self {
20        Self {
21            default: Some(header),
22        }
23    }
24}
25
26impl Default for CacheControlLayer {
27    fn default() -> Self {
28        Self { default: None }
29    }
30}
31
32impl<S> Layer<S> for CacheControlLayer {
33    type Service = CacheControlService<S>;
34    fn layer(&self, inner: S) -> Self::Service {
35        CacheControlService {
36            inner,
37            default: self.default.clone(),
38        }
39    }
40}
41
42/// # `Cache-Control` setter [Service].
43///
44/// Assigns a value based on a response status:
45/// * on `1xx` and `2xx` takes a `no-cache` request header directive or falls back to a default one;
46/// * on `301`, likely a permanent move, sets a day *TTL* and asks *CDN* to cache the response, too;
47/// * on any other `3xx` takes the default and prevents *CDN* from caching the response;
48/// * on `4xx` caching is disabled;
49/// * on `5xx` 30 min *TTL* is set.
50///
51/// *TTL* defaults to `5` seconds.
52#[derive(Clone, Debug)]
53pub struct CacheControlService<S> {
54    inner: S,
55    default: Option<CacheControl>,
56}
57
58impl<B, D, S> Service<Request<B>> for CacheControlService<S>
59where
60    S: Service<Request<B>, Response = Response<D>> + Send + 'static,
61    S::Future: Send + 'static,
62{
63    type Response = S::Response;
64    type Error = S::Error;
65    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
66
67    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
68        self.inner.poll_ready(cx)
69    }
70
71    fn call(&mut self, req: Request<B>) -> Self::Future {
72        let default = self
73            .default
74            .clone()
75            .unwrap_or(CacheControl::new().with_max_age(Duration::from_secs(5)));
76        let header = req
77            .headers()
78            .typed_get::<CacheControl>()
79            .and_then(|header| header.ne(&CacheControl::new()).then_some(header));
80        let fut = self.inner.call(req);
81        Box::pin(async move {
82            let mut res = fut.await?;
83            if res.headers().typed_get::<CacheControl>().is_some() {
84                return Ok(res);
85            };
86            let header = match res.status() {
87                StatusCode::MOVED_PERMANENTLY => default
88                    .with_max_age(Duration::from_secs(86_400))
89                    .with_public(),
90                s if s.is_success() => header.unwrap_or(default),
91                s if s.is_redirection() => header.unwrap_or(default).with_private(),
92                s if s.is_client_error() => default.with_no_cache().with_private(),
93                _ => default
94                    .with_max_age(Duration::from_secs(1_800))
95                    .with_public(),
96            };
97            res.headers_mut().typed_insert(header);
98            Ok(res)
99        })
100    }
101}