tower_cache_control/
lib.rs1#![doc = include_str!("../README.md")]
2
3use std::{task, time::Duration};
4
5use futures::future::BoxFuture;
6pub use headers::CacheControl;
7use headers::HeaderMapExt;
8use http::{Request, Response, StatusCode};
9use tower_layer::Layer;
10use tower_service::Service;
11
12#[derive(Clone, Debug)]
14pub struct CacheControlLayer {
15 default: Option<CacheControl>,
16}
17
18impl CacheControlLayer {
19 pub fn new(header: CacheControl) -> Self {
20 Self {
21 default: Some(header),
22 }
23 }
24}
25
26impl Default for CacheControlLayer {
27 fn default() -> Self {
28 Self { default: None }
29 }
30}
31
32impl<S> Layer<S> for CacheControlLayer {
33 type Service = CacheControlService<S>;
34 fn layer(&self, inner: S) -> Self::Service {
35 CacheControlService {
36 inner,
37 default: self.default.clone(),
38 }
39 }
40}
41
42#[derive(Clone, Debug)]
53pub struct CacheControlService<S> {
54 inner: S,
55 default: Option<CacheControl>,
56}
57
58impl<B, D, S> Service<Request<B>> for CacheControlService<S>
59where
60 S: Service<Request<B>, Response = Response<D>> + Send + 'static,
61 S::Future: Send + 'static,
62{
63 type Response = S::Response;
64 type Error = S::Error;
65 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
66
67 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
68 self.inner.poll_ready(cx)
69 }
70
71 fn call(&mut self, req: Request<B>) -> Self::Future {
72 let default = self
73 .default
74 .clone()
75 .unwrap_or(CacheControl::new().with_max_age(Duration::from_secs(5)));
76 let header = req
77 .headers()
78 .typed_get::<CacheControl>()
79 .and_then(|header| header.ne(&CacheControl::new()).then_some(header));
80 let fut = self.inner.call(req);
81 Box::pin(async move {
82 let mut res = fut.await?;
83 if res.headers().typed_get::<CacheControl>().is_some() {
84 return Ok(res);
85 };
86 let header = match res.status() {
87 StatusCode::MOVED_PERMANENTLY => default
88 .with_max_age(Duration::from_secs(86_400))
89 .with_public(),
90 s if s.is_success() => header.unwrap_or(default),
91 s if s.is_redirection() => header.unwrap_or(default).with_private(),
92 s if s.is_client_error() => default.with_no_cache().with_private(),
93 _ => default
94 .with_max_age(Duration::from_secs(1_800))
95 .with_public(),
96 };
97 res.headers_mut().typed_insert(header);
98 Ok(res)
99 })
100 }
101}