rama_http_headers/common/
cache_control.rs

1use std::fmt;
2use std::iter::FromIterator;
3use std::str::FromStr;
4use std::time::Duration;
5
6use rama_http_types::{HeaderName, HeaderValue};
7
8use crate::util::{self, Seconds, csv};
9use crate::{Error, Header};
10
/// `Cache-Control` header, defined in [RFC7234](https://tools.ietf.org/html/rfc7234#section-5.2)
/// with extensions in [RFC8246](https://www.rfc-editor.org/rfc/rfc8246)
/// (`immutable`); the `must-understand` directive comes from
/// [RFC9111](https://www.rfc-editor.org/rfc/rfc9111#section-5.2.2.3).
///
/// The `Cache-Control` header field is used to specify directives for
/// caches along the request/response chain.  Such cache directives are
/// unidirectional in that the presence of a directive in a request does
/// not imply that the same directive is to be given in the response.
///
/// Only a fixed set of directives is modelled: the argument-less flags
/// `no-cache`, `no-store`, `no-transform`, `only-if-cached`,
/// `must-revalidate`, `must-understand`, `public`, `private`, `immutable`
/// and `proxy-revalidate`, plus the delta-seconds directives `max-age`,
/// `max-stale`, `min-fresh` and `s-maxage`. Any other (extension) directive,
/// and argument-qualified forms of the flags such as `private="Set-Cookie"`,
/// are dropped when decoding.
///
/// ## ABNF
///
/// ```text
/// Cache-Control   = 1#cache-directive
/// cache-directive = token [ "=" ( token / quoted-string ) ]
/// ```
///
/// ## Example values
///
/// * `no-cache`
/// * `private, community="UCI"`
/// * `max-age=30`
///
/// # Example
///
/// ```
/// use rama_http_headers::CacheControl;
/// use std::time::Duration;
///
/// let cc = CacheControl::new()
///     .with_no_cache()
///     .with_max_age(Duration::from_secs(30));
/// assert!(cc.no_cache());
/// assert_eq!(cc.max_age(), Some(Duration::from_secs(30)));
/// ```
#[derive(PartialEq, Clone, Debug)]
pub struct CacheControl {
    /// Bit set of the argument-less directives that are present.
    flags: Flags,
    /// `max-age=<seconds>`, if present.
    max_age: Option<Seconds>,
    /// `max-stale=<seconds>`, if present (a bare `max-stale` without a value
    /// is not recognised by the parser and is dropped).
    max_stale: Option<Seconds>,
    /// `min-fresh=<seconds>`, if present.
    min_fresh: Option<Seconds>,
    /// `s-maxage=<seconds>`, if present.
    s_max_age: Option<Seconds>,
}
47
/// Compact bit set recording which argument-less `Cache-Control`
/// directives are present; each directive owns exactly one bit.
#[derive(Debug, Clone, PartialEq)]
struct Flags {
    bits: u64,
}

impl Flags {
    const NO_CACHE: Self = Self::bit(0);
    const NO_STORE: Self = Self::bit(1);
    const NO_TRANSFORM: Self = Self::bit(2);
    const ONLY_IF_CACHED: Self = Self::bit(3);
    const MUST_REVALIDATE: Self = Self::bit(4);
    const PUBLIC: Self = Self::bit(5);
    const PRIVATE: Self = Self::bit(6);
    const PROXY_REVALIDATE: Self = Self::bit(7);
    const IMMUTABLE: Self = Self::bit(8);
    const MUST_UNDERSTAND: Self = Self::bit(9);

    /// A set containing only the bit at position `index`.
    const fn bit(index: u32) -> Self {
        Flags { bits: 1 << index }
    }

    /// A set with no directive present.
    fn empty() -> Self {
        Flags { bits: 0 }
    }

    /// `true` when any bit of `flag` is also set in `self`.
    fn contains(&self, flag: Self) -> bool {
        let shared = self.bits & flag.bits;
        shared != 0
    }

    /// Mark every bit of `flag` as present in `self`.
    fn insert(&mut self, flag: Self) {
        self.bits |= flag.bits;
    }
}
77
78impl Default for CacheControl {
79    #[inline]
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl CacheControl {
86    /// Construct a new empty `CacheControl` header.
87    pub fn new() -> Self {
88        CacheControl {
89            flags: Flags::empty(),
90            max_age: None,
91            max_stale: None,
92            min_fresh: None,
93            s_max_age: None,
94        }
95    }
96
97    // getters
98
99    /// Check if the `no-cache` directive is set.
100    pub fn no_cache(&self) -> bool {
101        self.flags.contains(Flags::NO_CACHE)
102    }
103
104    /// Check if the `no-store` directive is set.
105    pub fn no_store(&self) -> bool {
106        self.flags.contains(Flags::NO_STORE)
107    }
108
109    /// Check if the `no-transform` directive is set.
110    pub fn no_transform(&self) -> bool {
111        self.flags.contains(Flags::NO_TRANSFORM)
112    }
113
114    /// Check if the `only-if-cached` directive is set.
115    pub fn only_if_cached(&self) -> bool {
116        self.flags.contains(Flags::ONLY_IF_CACHED)
117    }
118
119    /// Check if the `public` directive is set.
120    pub fn public(&self) -> bool {
121        self.flags.contains(Flags::PUBLIC)
122    }
123
124    /// Check if the `private` directive is set.
125    pub fn private(&self) -> bool {
126        self.flags.contains(Flags::PRIVATE)
127    }
128
129    /// Check if the `immutable` directive is set.
130    pub fn immutable(&self) -> bool {
131        self.flags.contains(Flags::IMMUTABLE)
132    }
133
134    /// Check if the `must-revalidate` directive is set.
135    pub fn must_revalidate(&self) -> bool {
136        self.flags.contains(Flags::MUST_REVALIDATE)
137    }
138
139    /// Check if the `must-understand` directive is set.
140    pub fn must_understand(&self) -> bool {
141        self.flags.contains(Flags::MUST_UNDERSTAND)
142    }
143
144    /// Get the value of the `max-age` directive if set.
145    pub fn max_age(&self) -> Option<Duration> {
146        self.max_age.map(Into::into)
147    }
148
149    /// Get the value of the `max-stale` directive if set.
150    pub fn max_stale(&self) -> Option<Duration> {
151        self.max_stale.map(Into::into)
152    }
153
154    /// Get the value of the `min-fresh` directive if set.
155    pub fn min_fresh(&self) -> Option<Duration> {
156        self.min_fresh.map(Into::into)
157    }
158
159    /// Get the value of the `s-maxage` directive if set.
160    pub fn s_max_age(&self) -> Option<Duration> {
161        self.s_max_age.map(Into::into)
162    }
163
164    // setters
165
166    /// Set the `no-cache` directive.
167    pub fn with_no_cache(mut self) -> Self {
168        self.flags.insert(Flags::NO_CACHE);
169        self
170    }
171
172    /// Set the `no-store` directive.
173    pub fn with_no_store(mut self) -> Self {
174        self.flags.insert(Flags::NO_STORE);
175        self
176    }
177
178    /// Set the `no-transform` directive.
179    pub fn with_no_transform(mut self) -> Self {
180        self.flags.insert(Flags::NO_TRANSFORM);
181        self
182    }
183
184    /// Set the `only-if-cached` directive.
185    pub fn with_only_if_cached(mut self) -> Self {
186        self.flags.insert(Flags::ONLY_IF_CACHED);
187        self
188    }
189
190    /// Set the `private` directive.
191    pub fn with_private(mut self) -> Self {
192        self.flags.insert(Flags::PRIVATE);
193        self
194    }
195
196    /// Set the `public` directive.
197    pub fn with_public(mut self) -> Self {
198        self.flags.insert(Flags::PUBLIC);
199        self
200    }
201
202    /// Set the `immutable` directive.
203    pub fn with_immutable(mut self) -> Self {
204        self.flags.insert(Flags::IMMUTABLE);
205        self
206    }
207
208    /// Set the `must-revalidate` directive.
209    pub fn with_must_revalidate(mut self) -> Self {
210        self.flags.insert(Flags::MUST_REVALIDATE);
211        self
212    }
213
214    /// Set the `must-understand` directive.
215    pub fn with_must_understand(mut self) -> Self {
216        self.flags.insert(Flags::MUST_UNDERSTAND);
217        self
218    }
219
220    /// Set the `max-age` directive.
221    pub fn with_max_age(mut self, duration: Duration) -> Self {
222        self.max_age = Some(duration.into());
223        self
224    }
225
226    /// Set the `max-stale` directive.
227    pub fn with_max_stale(mut self, duration: Duration) -> Self {
228        self.max_stale = Some(duration.into());
229        self
230    }
231
232    /// Set the `min-fresh` directive.
233    pub fn with_min_fresh(mut self, duration: Duration) -> Self {
234        self.min_fresh = Some(duration.into());
235        self
236    }
237
238    /// Set the `s-maxage` directive.
239    pub fn with_s_max_age(mut self, duration: Duration) -> Self {
240        self.s_max_age = Some(duration.into());
241        self
242    }
243}
244
245impl Header for CacheControl {
246    fn name() -> &'static HeaderName {
247        &::rama_http_types::header::CACHE_CONTROL
248    }
249
250    fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(values: &mut I) -> Result<Self, Error> {
251        csv::from_comma_delimited(values).map(|FromIter(cc)| cc)
252    }
253
254    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
255        values.extend(::std::iter::once(util::fmt(Fmt(self))));
256    }
257}
258
259// Adapter to be used in Header::decode
260struct FromIter(CacheControl);
261
262impl FromIterator<KnownDirective> for FromIter {
263    fn from_iter<I>(iter: I) -> Self
264    where
265        I: IntoIterator<Item = KnownDirective>,
266    {
267        let mut cc = CacheControl::new();
268
269        // ignore all unknown directives
270        let iter = iter.into_iter().filter_map(|dir| match dir {
271            KnownDirective::Known(dir) => Some(dir),
272            KnownDirective::Unknown => None,
273        });
274
275        for directive in iter {
276            match directive {
277                Directive::NoCache => {
278                    cc.flags.insert(Flags::NO_CACHE);
279                }
280                Directive::NoStore => {
281                    cc.flags.insert(Flags::NO_STORE);
282                }
283                Directive::NoTransform => {
284                    cc.flags.insert(Flags::NO_TRANSFORM);
285                }
286                Directive::OnlyIfCached => {
287                    cc.flags.insert(Flags::ONLY_IF_CACHED);
288                }
289                Directive::MustRevalidate => {
290                    cc.flags.insert(Flags::MUST_REVALIDATE);
291                }
292                Directive::MustUnderstand => {
293                    cc.flags.insert(Flags::MUST_UNDERSTAND);
294                }
295                Directive::Public => {
296                    cc.flags.insert(Flags::PUBLIC);
297                }
298                Directive::Private => {
299                    cc.flags.insert(Flags::PRIVATE);
300                }
301                Directive::Immutable => {
302                    cc.flags.insert(Flags::IMMUTABLE);
303                }
304                Directive::ProxyRevalidate => {
305                    cc.flags.insert(Flags::PROXY_REVALIDATE);
306                }
307                Directive::MaxAge(secs) => {
308                    cc.max_age = Some(Duration::from_secs(secs).into());
309                }
310                Directive::MaxStale(secs) => {
311                    cc.max_stale = Some(Duration::from_secs(secs).into());
312                }
313                Directive::MinFresh(secs) => {
314                    cc.min_fresh = Some(Duration::from_secs(secs).into());
315                }
316                Directive::SMaxAge(secs) => {
317                    cc.s_max_age = Some(Duration::from_secs(secs).into());
318                }
319            }
320        }
321
322        FromIter(cc)
323    }
324}
325
326struct Fmt<'a>(&'a CacheControl);
327
328impl fmt::Display for Fmt<'_> {
329    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
330        let if_flag = |f: Flags, dir: Directive| {
331            if self.0.flags.contains(f) {
332                Some(dir)
333            } else {
334                None
335            }
336        };
337
338        let slice = &[
339            if_flag(Flags::NO_CACHE, Directive::NoCache),
340            if_flag(Flags::NO_STORE, Directive::NoStore),
341            if_flag(Flags::NO_TRANSFORM, Directive::NoTransform),
342            if_flag(Flags::ONLY_IF_CACHED, Directive::OnlyIfCached),
343            if_flag(Flags::MUST_REVALIDATE, Directive::MustRevalidate),
344            if_flag(Flags::PUBLIC, Directive::Public),
345            if_flag(Flags::PRIVATE, Directive::Private),
346            if_flag(Flags::IMMUTABLE, Directive::Immutable),
347            if_flag(Flags::MUST_UNDERSTAND, Directive::MustUnderstand),
348            if_flag(Flags::PROXY_REVALIDATE, Directive::ProxyRevalidate),
349            self.0
350                .max_age
351                .as_ref()
352                .map(|s| Directive::MaxAge(s.as_u64())),
353            self.0
354                .max_stale
355                .as_ref()
356                .map(|s| Directive::MaxStale(s.as_u64())),
357            self.0
358                .min_fresh
359                .as_ref()
360                .map(|s| Directive::MinFresh(s.as_u64())),
361            self.0
362                .s_max_age
363                .as_ref()
364                .map(|s| Directive::SMaxAge(s.as_u64())),
365        ];
366
367        let iter = slice.iter().filter_map(|o| *o);
368
369        csv::fmt_comma_delimited(f, iter)
370    }
371}
372
/// Outcome of parsing a single comma-separated `Cache-Control` entry.
#[derive(Clone, Copy)]
enum KnownDirective {
    /// A directive that `CacheControl` models.
    Known(Directive),
    /// An acceptable entry that is not modelled (an extension directive, or a
    /// modelled name used in an unsupported form); it is ignored on decode.
    Unknown,
}

/// A single `Cache-Control` directive understood by this module.
#[derive(Clone, Copy)]
enum Directive {
    // usable in requests and responses
    NoCache,
    NoStore,
    NoTransform,
    // request-only
    OnlyIfCached,

    // delta-seconds directives: `max-age` is used in both requests and
    // responses, `max-stale` and `min-fresh` are request-only
    MaxAge(u64),
    MaxStale(u64),
    MinFresh(u64),

    // response directives
    MustRevalidate,
    MustUnderstand,
    Public,
    Private,
    Immutable,
    ProxyRevalidate,
    SMaxAge(u64),
}

impl fmt::Display for Directive {
    /// Writes the canonical lowercase wire form, e.g. `no-cache` or `max-age=30`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(
            match *self {
                Directive::NoCache => "no-cache",
                Directive::NoStore => "no-store",
                Directive::NoTransform => "no-transform",
                Directive::OnlyIfCached => "only-if-cached",

                Directive::MaxAge(secs) => return write!(f, "max-age={}", secs),
                Directive::MaxStale(secs) => return write!(f, "max-stale={}", secs),
                Directive::MinFresh(secs) => return write!(f, "min-fresh={}", secs),

                Directive::MustRevalidate => "must-revalidate",
                Directive::MustUnderstand => "must-understand",
                Directive::Public => "public",
                Directive::Private => "private",
                Directive::Immutable => "immutable",
                Directive::ProxyRevalidate => "proxy-revalidate",
                Directive::SMaxAge(secs) => return write!(f, "s-maxage={}", secs),
            },
            f,
        )
    }
}

/// Argument-less directives, keyed by their canonical lowercase name.
const FLAG_DIRECTIVES: &[(&str, Directive)] = &[
    ("no-cache", Directive::NoCache),
    ("no-store", Directive::NoStore),
    ("no-transform", Directive::NoTransform),
    ("only-if-cached", Directive::OnlyIfCached),
    ("must-revalidate", Directive::MustRevalidate),
    ("public", Directive::Public),
    ("private", Directive::Private),
    ("immutable", Directive::Immutable),
    ("must-understand", Directive::MustUnderstand),
    ("proxy-revalidate", Directive::ProxyRevalidate),
];

/// Directives carrying a delta-seconds argument, keyed by their canonical
/// lowercase name and mapped to the matching constructor.
const SECONDS_DIRECTIVES: &[(&str, fn(u64) -> Directive)] = &[
    ("max-age", Directive::MaxAge),
    ("max-stale", Directive::MaxStale),
    ("min-fresh", Directive::MinFresh),
    ("s-maxage", Directive::SMaxAge),
];

/// Find `name` in `table`, comparing names ASCII case-insensitively.
fn lookup_directive<T: Copy>(table: &[(&str, T)], name: &str) -> Option<T> {
    table
        .iter()
        .find(|(key, _)| key.eq_ignore_ascii_case(name))
        .map(|&(_, value)| value)
}

impl FromStr for KnownDirective {
    type Err = ();

    /// Parse one comma-separated `Cache-Control` entry.
    ///
    /// Directive names are matched ASCII case-insensitively, because RFC 9111
    /// §5.2 states that cache directives are "compared case-insensitively".
    /// Previously `No-Cache` or `MAX-AGE=30` were silently treated as unknown.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` for an empty entry. It also returns `Err(())` for a
    /// modelled delta-seconds directive whose argument is not a valid `u64`
    /// after surrounding `"` quotes are stripped. (The module tests show
    /// `max-age=lolz` makes the whole header fail to decode.)
    ///
    /// Unmodelled names, and modelled names in an unsupported form, yield
    /// `Ok(KnownDirective::Unknown)`. Examples of unsupported forms are
    /// `no-cache="x"` or `max-age=` with no value.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.is_empty() {
            return Err(());
        }

        match s.split_once('=') {
            // A bare token can only be one of the argument-less directives.
            None => Ok(lookup_directive(FLAG_DIRECTIVES, s)
                .map_or(KnownDirective::Unknown, KnownDirective::Known)),
            // `name=` with nothing after the `=` carries no usable value.
            Some((_, "")) => Ok(KnownDirective::Unknown),
            Some((name, value)) => match lookup_directive(SECONDS_DIRECTIVES, name) {
                Some(make) => value
                    .trim_matches('"')
                    .parse::<u64>()
                    .map(|secs| KnownDirective::Known(make(secs)))
                    .map_err(|_| ()),
                None => Ok(KnownDirective::Unknown),
            },
        }
    }
}
463
#[cfg(test)]
mod tests {
    use super::super::{test_decode, test_encode};
    use super::*;

    #[test]
    fn test_parse_multiple_headers() {
        let decoded = test_decode::<CacheControl>(&["no-cache", "private"]).unwrap();
        let expected = CacheControl::new().with_no_cache().with_private();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_parse_argument() {
        let decoded = test_decode::<CacheControl>(&["max-age=100, private"]).unwrap();
        let expected = CacheControl::new()
            .with_max_age(Duration::from_secs(100))
            .with_private();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_parse_quote_form() {
        let decoded = test_decode::<CacheControl>(&["max-age=\"200\""]).unwrap();
        let expected = CacheControl::new().with_max_age(Duration::from_secs(200));
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_parse_extension() {
        let decoded = test_decode::<CacheControl>(&["foo, no-cache, bar=baz"]).unwrap();
        let expected = CacheControl::new().with_no_cache();
        assert_eq!(
            decoded, expected,
            "unknown extensions are ignored but shouldn't fail parsing",
        );
    }

    #[test]
    fn test_immutable() {
        let cc = CacheControl::new().with_immutable();
        assert!(cc.immutable());

        let headers = test_encode(cc.clone());
        assert_eq!(headers["cache-control"], "immutable");

        let decoded = test_decode::<CacheControl>(&["immutable"]).unwrap();
        assert_eq!(decoded, cc);
    }

    #[test]
    fn test_must_revalidate() {
        let cc = CacheControl::new().with_must_revalidate();
        assert!(cc.must_revalidate());

        let headers = test_encode(cc.clone());
        assert_eq!(headers["cache-control"], "must-revalidate");

        let decoded = test_decode::<CacheControl>(&["must-revalidate"]).unwrap();
        assert_eq!(decoded, cc);
    }

    #[test]
    fn test_must_understand() {
        let cc = CacheControl::new().with_must_understand();
        assert!(cc.must_understand());

        let headers = test_encode(cc.clone());
        assert_eq!(headers["cache-control"], "must-understand");

        let decoded = test_decode::<CacheControl>(&["must-understand"]).unwrap();
        assert_eq!(decoded, cc);
    }

    #[test]
    fn test_parse_bad_syntax() {
        let decoded = test_decode::<CacheControl>(&["max-age=lolz"]);
        assert!(decoded.is_none());
    }

    #[test]
    fn encode_one_flag_directive() {
        let headers = test_encode(CacheControl::new().with_no_cache());
        assert_eq!(headers["cache-control"], "no-cache");
    }

    #[test]
    fn encode_one_param_directive() {
        let headers = test_encode(CacheControl::new().with_max_age(Duration::from_secs(300)));
        assert_eq!(headers["cache-control"], "max-age=300");
    }

    #[test]
    fn encode_two_directive() {
        let flags_only = CacheControl::new().with_no_cache().with_private();
        let headers = test_encode(flags_only);
        assert_eq!(headers["cache-control"], "no-cache, private");

        let flag_and_age = CacheControl::new()
            .with_no_cache()
            .with_max_age(Duration::from_secs(100));
        let headers = test_encode(flag_and_age);
        assert_eq!(headers["cache-control"], "no-cache, max-age=100");
    }
}