axum_security/headers/
hsts.rs1use std::time::Duration;
2
3use axum::http::{HeaderValue, header::STRICT_TRANSPORT_SECURITY};
4use tower::Layer;
5
6use crate::utils::headers::InsertHeadersService;
7
/// Smallest `max-age`, in seconds, that `try_build` accepts when `preload` is
/// requested: 365 days of 86 400 seconds each.
const PRELOAD_MIN_MAX_AGE: u64 = 365 * 86_400;
9
10#[derive(Clone)]
11pub struct StrictTransportSecurity {
12 header_value: HeaderValue,
13}
14
15impl StrictTransportSecurity {
16 pub fn builder() -> HstsBuilder {
17 HstsBuilder {
18 max_age: None,
19 include_subdomains: false,
20 preload: false,
21 }
22 }
23}
24
/// Builder for [`StrictTransportSecurity`], normally obtained from
/// `StrictTransportSecurity::builder()`.
///
/// `Default` yields the same starting state: no `max-age`, and no
/// `includeSubDomains` or `preload` directive.
#[derive(Debug, Clone, Default)]
#[must_use = "a builder does nothing until `build` or `try_build` is called"]
pub struct HstsBuilder {
    /// `max-age` in seconds; `None` until one of the `max_age*` setters is called.
    max_age: Option<u64>,
    /// Whether the `includeSubDomains` directive is emitted.
    include_subdomains: bool,
    /// Whether the `preload` directive is emitted (subject to the preload checks).
    preload: bool,
}
30
31impl HstsBuilder {
32 pub fn max_age(mut self, duration: Duration) -> Self {
33 self.max_age = Some(duration.as_secs());
34 self
35 }
36
37 pub fn max_age_seconds(mut self, max_age: u64) -> Self {
38 self.max_age = Some(max_age);
39 self
40 }
41
42 pub fn max_age_days(self, max_age: u64) -> Self {
44 self.max_age_seconds(max_age * 24 * 60 * 60)
45 }
46
47 pub fn max_age_years(self, max_age: u64) -> Self {
49 self.max_age_days(max_age * 365)
50 }
51
52 pub fn include_subdomains(mut self) -> Self {
53 self.include_subdomains = true;
54 self
55 }
56
57 pub fn preload(mut self) -> Self {
58 self.preload = true;
59 self
60 }
61
62 pub fn try_build(self) -> Result<StrictTransportSecurity, StrictTransportSecurityBuilderError> {
63 let Some(max_age) = self.max_age else {
64 return Err(StrictTransportSecurityBuilderError::NoMaxAge);
65 };
66
67 let mut header = format!("max-age={max_age}");
68
69 if self.include_subdomains {
70 header.push_str("; includeSubDomains");
71 }
72
73 if self.preload {
74 if max_age < PRELOAD_MIN_MAX_AGE {
75 return Err(StrictTransportSecurityBuilderError::InvalidMaxAge);
76 } else if !self.include_subdomains {
77 return Err(StrictTransportSecurityBuilderError::IncludeSubdomainsRequired);
78 }
79
80 header.push_str("; preload");
81 }
82
83 let header_value =
84 HeaderValue::from_str(&header).expect("Hsts header does not contain invalid bytes");
85
86 Ok(StrictTransportSecurity { header_value })
87 }
88
89 pub fn build(self) -> StrictTransportSecurity {
90 self.try_build().unwrap()
91 }
92}
93
/// Reasons why `HstsBuilder::try_build` rejects a configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StrictTransportSecurityBuilderError {
    /// No `max-age` was configured; the directive is mandatory.
    NoMaxAge,
    /// `preload` was requested with a `max-age` shorter than one year.
    InvalidMaxAge,
    /// `preload` was requested without `includeSubDomains`.
    IncludeSubdomainsRequired,
}

impl std::fmt::Display for StrictTransportSecurityBuilderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::NoMaxAge => "HSTS header requires a max-age",
            Self::InvalidMaxAge => {
                "HSTS preload requires a max-age of at least one year (31536000 seconds)"
            }
            Self::IncludeSubdomainsRequired => "HSTS preload requires includeSubDomains",
        };
        f.write_str(msg)
    }
}

// Lets the error be boxed as `dyn Error` and propagated with `?` by callers.
impl std::error::Error for StrictTransportSecurityBuilderError {}
100
101impl<S> Layer<S> for StrictTransportSecurity {
102 type Service = InsertHeadersService<S>;
103
104 fn layer(&self, inner: S) -> Self::Service {
105 InsertHeadersService {
106 inner,
107 header_name: STRICT_TRANSPORT_SECURITY,
108 header_value: self.header_value.clone(),
109 }
110 }
111}
112
#[cfg(test)]
mod hsts_tests {
    use axum::{Router, body::Body, extract::Request, http::header::STRICT_TRANSPORT_SECURITY};

    use crate::headers::{StrictTransportSecurity, StrictTransportSecurityBuilderError};
    use tower::ServiceExt;

    /// Header value expected for a one-year, subdomain-wide, preload-ready policy.
    const FULL_POLICY: &str = "max-age=31536000; includeSubDomains; preload";

    #[test]
    fn builder() {
        // A missing max-age is rejected outright.
        let missing_max_age = StrictTransportSecurity::builder().try_build();
        assert!(matches!(missing_max_age, Err(StrictTransportSecurityBuilderError::NoMaxAge)));

        // Preload needs a max-age of at least one year...
        let too_short = StrictTransportSecurity::builder()
            .max_age_days(364)
            .preload()
            .try_build();
        assert!(matches!(too_short, Err(StrictTransportSecurityBuilderError::InvalidMaxAge)));

        // ...and it also needs includeSubDomains.
        let no_subdomains = StrictTransportSecurity::builder()
            .max_age_years(1)
            .preload()
            .try_build();
        assert!(matches!(
            no_subdomains,
            Err(StrictTransportSecurityBuilderError::IncludeSubdomainsRequired)
        ));
    }

    #[test]
    fn header() {
        let bare = StrictTransportSecurity::builder().max_age_seconds(1).build();
        assert_eq!(bare.header_value, "max-age=1");

        let with_subdomains = StrictTransportSecurity::builder()
            .max_age_seconds(1)
            .include_subdomains()
            .build();
        assert_eq!(with_subdomains.header_value, "max-age=1; includeSubDomains");

        let preloaded = StrictTransportSecurity::builder()
            .max_age_years(1)
            .include_subdomains()
            .preload()
            .build();
        assert_eq!(preloaded.header_value, FULL_POLICY);
    }

    #[tokio::test]
    async fn basic() {
        let policy = StrictTransportSecurity::builder()
            .max_age_years(1)
            .include_subdomains()
            .preload()
            .build();
        let app = Router::<()>::new().layer(policy);

        let request = Request::get("/").body(Body::empty()).unwrap();
        let response = app.oneshot(request).await.unwrap();

        assert_eq!(response.headers()[STRICT_TRANSPORT_SECURITY], FULL_POLICY);
    }
}