Skip to main content

axum_security/headers/
hsts.rs

1use std::time::Duration;
2
3use axum::http::{HeaderValue, header::STRICT_TRANSPORT_SECURITY};
4use tower::Layer;
5
6use crate::utils::headers::InsertHeadersService;
7
/// Smallest `max-age` (in seconds) accepted when `preload` is requested:
/// one 365-day year, as required for HSTS preload-list submission.
const PRELOAD_MIN_MAX_AGE: u64 = 60 * 60 * 24 * 365;
9
10#[derive(Clone)]
11pub struct StrictTransportSecurity {
12    header_value: HeaderValue,
13}
14
15impl StrictTransportSecurity {
16    pub fn builder() -> HstsBuilder {
17        HstsBuilder {
18            max_age: None,
19            include_subdomains: false,
20            preload: false,
21        }
22    }
23}
24
/// Builder for [`StrictTransportSecurity`], obtained from [`StrictTransportSecurity::builder`].
///
/// Call `build` or `try_build` to validate the configuration and render the header.
#[derive(Debug, Clone)]
#[must_use = "an HstsBuilder does nothing until `build` or `try_build` is called"]
pub struct HstsBuilder {
    /// `max-age` in seconds; building fails while this is `None`.
    max_age: Option<u64>,
    /// Whether the `includeSubDomains` directive is emitted.
    include_subdomains: bool,
    /// Whether the `preload` directive is emitted.
    preload: bool,
}
30
31impl HstsBuilder {
32    pub fn max_age(mut self, duration: Duration) -> Self {
33        self.max_age = Some(duration.as_secs());
34        self
35    }
36
37    pub fn max_age_seconds(mut self, max_age: u64) -> Self {
38        self.max_age = Some(max_age);
39        self
40    }
41
42    /// 24h in a day
43    pub fn max_age_days(self, max_age: u64) -> Self {
44        self.max_age_seconds(max_age * 24 * 60 * 60)
45    }
46
47    /// 365 days in a year
48    pub fn max_age_years(self, max_age: u64) -> Self {
49        self.max_age_days(max_age * 365)
50    }
51
52    pub fn include_subdomains(mut self) -> Self {
53        self.include_subdomains = true;
54        self
55    }
56
57    pub fn preload(mut self) -> Self {
58        self.preload = true;
59        self
60    }
61
62    pub fn try_build(self) -> Result<StrictTransportSecurity, StrictTransportSecurityBuilderError> {
63        let Some(max_age) = self.max_age else {
64            return Err(StrictTransportSecurityBuilderError::NoMaxAge);
65        };
66
67        let mut header = format!("max-age={max_age}");
68
69        if self.include_subdomains {
70            header.push_str("; includeSubDomains");
71        }
72
73        if self.preload {
74            if max_age < PRELOAD_MIN_MAX_AGE {
75                return Err(StrictTransportSecurityBuilderError::InvalidMaxAge);
76            } else if !self.include_subdomains {
77                return Err(StrictTransportSecurityBuilderError::IncludeSubdomainsRequired);
78            }
79
80            header.push_str("; preload");
81        }
82
83        let header_value =
84            HeaderValue::from_str(&header).expect("Hsts header does not contain invalid bytes");
85
86        Ok(StrictTransportSecurity { header_value })
87    }
88
89    pub fn build(self) -> StrictTransportSecurity {
90        self.try_build().unwrap()
91    }
92}
93
/// Reasons why [`HstsBuilder::try_build`] can reject a configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StrictTransportSecurityBuilderError {
    /// No `max-age` was configured; the directive is mandatory.
    NoMaxAge,
    /// `preload` was requested with a `max-age` shorter than one 365-day year.
    InvalidMaxAge,
    /// `preload` was requested without `includeSubDomains`.
    IncludeSubdomainsRequired,
}

impl std::fmt::Display for StrictTransportSecurityBuilderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NoMaxAge => "HSTS max-age is required",
            Self::InvalidMaxAge => {
                "HSTS preload requires a max-age of at least 31536000 seconds (one year)"
            }
            Self::IncludeSubdomainsRequired => "HSTS preload requires includeSubDomains",
        };
        f.write_str(message)
    }
}

impl std::error::Error for StrictTransportSecurityBuilderError {}
100
101impl<S> Layer<S> for StrictTransportSecurity {
102    type Service = InsertHeadersService<S>;
103
104    fn layer(&self, inner: S) -> Self::Service {
105        InsertHeadersService {
106            inner,
107            header_name: STRICT_TRANSPORT_SECURITY,
108            header_value: self.header_value.clone(),
109        }
110    }
111}
112
#[cfg(test)]
mod hsts_tests {
    use std::time::Duration;

    use axum::{Router, body::Body, extract::Request, http::header::STRICT_TRANSPORT_SECURITY};
    use tower::ServiceExt;

    use crate::headers::{StrictTransportSecurity, StrictTransportSecurityBuilderError};

    /// Invalid configurations are rejected with the matching error variant.
    #[test]
    fn builder() {
        let hsts = StrictTransportSecurity::builder().try_build();
        assert!(matches!(
            hsts,
            Err(StrictTransportSecurityBuilderError::NoMaxAge)
        ));

        // 364 days is one day short of the preload minimum.
        let hsts = StrictTransportSecurity::builder()
            .max_age_days(364)
            .preload()
            .try_build();
        assert!(matches!(
            hsts,
            Err(StrictTransportSecurityBuilderError::InvalidMaxAge)
        ));

        let hsts = StrictTransportSecurity::builder()
            .max_age_years(1)
            .preload()
            .try_build();
        assert!(matches!(
            hsts,
            Err(StrictTransportSecurityBuilderError::IncludeSubdomainsRequired)
        ));
    }

    /// Exactly one 365-day year satisfies the preload `max-age` requirement.
    #[test]
    fn preload_accepts_exact_minimum() {
        let hsts = StrictTransportSecurity::builder()
            .max_age_days(365)
            .include_subdomains()
            .preload()
            .try_build();
        assert!(hsts.is_ok());
    }

    /// The unit helpers convert to whole seconds.
    #[test]
    fn max_age_units() {
        let from_duration = StrictTransportSecurity::builder()
            .max_age(Duration::from_millis(90_500))
            .build();
        assert_eq!(from_duration.header_value, "max-age=90");

        let from_days = StrictTransportSecurity::builder().max_age_days(1).build();
        assert_eq!(from_days.header_value, "max-age=86400");
    }

    /// `build` panics where `try_build` would return an error.
    #[test]
    #[should_panic]
    fn build_panics_without_max_age() {
        let _ = StrictTransportSecurity::builder().build();
    }

    /// Directives are rendered in a fixed order.
    #[test]
    fn header() {
        let hsts = StrictTransportSecurity::builder()
            .max_age_seconds(1)
            .build();
        assert_eq!(hsts.header_value, "max-age=1");

        let hsts = StrictTransportSecurity::builder()
            .max_age_seconds(1)
            .include_subdomains()
            .build();
        assert_eq!(hsts.header_value, "max-age=1; includeSubDomains");

        let hsts = StrictTransportSecurity::builder()
            .max_age_years(1)
            .include_subdomains()
            .preload()
            .build();
        assert_eq!(
            hsts.header_value,
            "max-age=31536000; includeSubDomains; preload"
        );
    }

    /// The layer inserts the configured header into a router response.
    #[tokio::test]
    async fn basic() {
        let hsts = StrictTransportSecurity::builder()
            .max_age_years(1)
            .include_subdomains()
            .preload()
            .build();

        let router = Router::<()>::new().layer(hsts);

        let res = router
            .oneshot(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(
            res.headers()[STRICT_TRANSPORT_SECURITY],
            "max-age=31536000; includeSubDomains; preload"
        );
    }
}