1pub use rune_axum_csrf::{CsrfConfig, CsrfLayer, CsrfService};
64pub use rune_axum_helmet::{
65 CrossOriginEmbedderPolicy, CrossOriginOpenerPolicy, CrossOriginResourcePolicy, Helmet,
66 HelmetLayer, Hsts, ReferrerPolicy, XFrameOptions,
67};
68pub use rune_axum_ipfilter::{FilterMode, IpFilterConfig, IpFilterLayer, IpFilterService, IpSource};
69pub use rune_axum_ratelimit::{KeyExtractor, RateLimitConfig, RateLimitLayer, RateLimitService};
70pub use rune_axum_redirect_https::{RedirectHttps, RedirectHttpsLayer};
71pub use rune_axum_size::{BodyLimit, BodyLimitLayer, BodyLimitService};
72
73#[derive(Clone, Debug)]
111pub struct SecurityStack {
112 helmet: Option<Helmet>,
113 csrf: Option<CsrfConfig>,
114 ratelimit: Option<RateLimitConfig>,
115 ipfilter: Option<IpFilterConfig>,
116 redirect_https: Option<RedirectHttps>,
117 body_limit: Option<BodyLimit>,
118}
119
120impl Default for SecurityStack {
121 fn default() -> Self {
122 Self {
123 helmet: Some(Helmet::new()),
124 csrf: Some(CsrfConfig::new()),
125 ratelimit: Some(RateLimitConfig::new()),
126 ipfilter: None,
127 redirect_https: Some(RedirectHttps::new()),
128 body_limit: Some(BodyLimit::new()),
129 }
130 }
131}
132
133impl SecurityStack {
134 pub fn new() -> Self {
147 Self::default()
148 }
149
150 pub fn helmet(mut self, config: Helmet) -> Self {
152 self.helmet = Some(config);
153 self
154 }
155
156 pub fn without_helmet(mut self) -> Self {
158 self.helmet = None;
159 self
160 }
161
162 pub fn csrf(mut self, config: CsrfConfig) -> Self {
164 self.csrf = Some(config);
165 self
166 }
167
168 pub fn without_csrf(mut self) -> Self {
173 self.csrf = None;
174 self
175 }
176
177 pub fn ratelimit(mut self, config: RateLimitConfig) -> Self {
179 self.ratelimit = Some(config);
180 self
181 }
182
183 pub fn without_ratelimit(mut self) -> Self {
185 self.ratelimit = None;
186 self
187 }
188
189 pub fn ipfilter(mut self, config: IpFilterConfig) -> Self {
193 self.ipfilter = Some(config);
194 self
195 }
196
197 pub fn without_ipfilter(mut self) -> Self {
199 self.ipfilter = None;
200 self
201 }
202
203 pub fn redirect_https(mut self, config: RedirectHttps) -> Self {
205 self.redirect_https = Some(config);
206 self
207 }
208
209 pub fn without_redirect_https(mut self) -> Self {
214 self.redirect_https = None;
215 self
216 }
217
218 pub fn body_limit(mut self, config: BodyLimit) -> Self {
220 self.body_limit = Some(config);
221 self
222 }
223
224 pub fn without_body_limit(mut self) -> Self {
226 self.body_limit = None;
227 self
228 }
229
230 pub fn apply<S>(self, router: axum::Router<S>) -> axum::Router<S>
246 where
247 S: Clone + Send + Sync + 'static,
248 {
249 let mut r = router;
250
251 if let Some(config) = self.csrf {
253 r = r.layer(CsrfLayer::new(config));
254 }
255 if let Some(config) = self.body_limit {
256 r = r.layer(BodyLimitLayer::new(config));
257 }
258 if let Some(config) = self.ratelimit {
259 r = r.layer(RateLimitLayer::new(config));
260 }
261 if let Some(config) = self.ipfilter {
262 r = r.layer(IpFilterLayer::new(config));
263 }
264 if let Some(config) = self.redirect_https {
265 r = r.layer(RedirectHttpsLayer::new(config));
266 }
267 if let Some(config) = self.helmet {
269 r = r.layer(HelmetLayer::new(config));
270 }
271
272 r
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        routing::{get, post},
        Router,
    };
    use http::StatusCode;
    use tower::ServiceExt;

    /// Client address sent via `x-forwarded-for` by the canned requests.
    const CLIENT_IP: &str = "1.2.3.4";

    /// Two trivial routes, `GET /` and `POST /submit`, both answering `"ok"`.
    fn base_router() -> Router {
        Router::new()
            .route("/", get(|| async { "ok" }))
            .route("/submit", post(|| async { "ok" }))
    }

    /// Drives a single request through `app` and returns its response.
    async fn send(app: Router, req: http::Request<Body>) -> http::Response<Body> {
        app.oneshot(req).await.unwrap()
    }

    /// Builds an empty-bodied request, attaching `headers` in the given order.
    fn request(method: &str, uri: &str, headers: &[(&str, &str)]) -> http::Request<Body> {
        headers
            .iter()
            .fold(
                http::Request::builder().method(method).uri(uri),
                |builder, &(name, value)| builder.header(name, value),
            )
            .body(Body::empty())
            .unwrap()
    }

    fn get_req() -> http::Request<Body> {
        request("GET", "/", &[("x-forwarded-for", CLIENT_IP)])
    }

    fn post_req() -> http::Request<Body> {
        request("POST", "/submit", &[("x-forwarded-for", CLIENT_IP)])
    }

    /// A plain-HTTP request as reported by a proxy via `x-forwarded-proto`.
    fn forwarded_http_req() -> http::Request<Body> {
        request(
            "GET",
            "http://example.com/",
            &[("x-forwarded-proto", "http"), ("host", "example.com")],
        )
    }

    /// Default stack with CSRF and the HTTPS redirect switched off: the
    /// baseline most of the focused tests below start from.
    fn plain_stack() -> SecurityStack {
        SecurityStack::new().without_csrf().without_redirect_https()
    }

    #[tokio::test]
    async fn default_stack_allows_get() {
        let app = SecurityStack::default().apply(base_router());
        let status = send(app, get_req()).await.status();
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn default_stack_sets_security_headers() {
        let app = SecurityStack::default().apply(base_router());
        let resp = send(app, get_req()).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let headers = resp.headers();
        assert!(headers.contains_key("x-content-type-options"));
        assert!(headers.contains_key("x-frame-options"));
    }

    #[tokio::test]
    async fn default_stack_blocks_post_without_csrf_token() {
        let app = SecurityStack::default().apply(base_router());
        let status = send(app, post_req()).await.status();
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn without_csrf_allows_post() {
        let app = SecurityStack::default().without_csrf().apply(base_router());
        let status = send(app, post_req()).await.status();
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_ratelimit_zero_blocks_request() {
        let app = plain_stack()
            .ratelimit(RateLimitConfig::new().requests(0))
            .apply(base_router());
        let status = send(app, get_req()).await.status();
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn without_ratelimit_passes() {
        let app = plain_stack().without_ratelimit().apply(base_router());
        let status = send(app, get_req()).await.status();
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn body_limit_rejects_oversized_request() {
        let app = plain_stack()
            .body_limit(BodyLimit::new().limit(10))
            .apply(base_router());
        let req = request("POST", "/submit", &[("content-length", "100")]);
        let status = send(app, req).await.status();
        assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn custom_helmet_deny_frame_options() {
        let app = plain_stack()
            .helmet(Helmet::new().frame_options(XFrameOptions::Deny))
            .apply(base_router());
        let resp = send(app, get_req()).await;
        let frame_options = resp
            .headers()
            .get("x-frame-options")
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(frame_options, "DENY");
    }

    #[tokio::test]
    async fn without_helmet_no_security_headers() {
        let app = plain_stack().without_helmet().apply(base_router());
        let resp = send(app, get_req()).await;
        assert!(!resp.headers().contains_key("x-content-type-options"));
    }

    #[tokio::test]
    async fn redirect_https_triggers_on_forwarded_proto() {
        let app = SecurityStack::new().without_csrf().apply(base_router());
        let status = send(app, forwarded_http_req()).await.status();
        assert_eq!(status, StatusCode::PERMANENT_REDIRECT);
    }

    #[tokio::test]
    async fn without_redirect_https_no_redirect() {
        let app = plain_stack().apply(base_router());
        let status = send(app, forwarded_http_req()).await.status();
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn ipfilter_blocks_configured_cidr() {
        let blocklist = IpFilterConfig::new()
            .mode(FilterMode::Blocklist)
            .cidr("1.2.3.4/32");
        let app = plain_stack().ipfilter(blocklist).apply(base_router());
        let status = send(app, get_req()).await.status();
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn security_headers_present_on_rate_limit_rejection() {
        let app = plain_stack()
            .ratelimit(RateLimitConfig::new().requests(0))
            .apply(base_router());
        let resp = send(app, get_req()).await;
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(resp.headers().contains_key("x-content-type-options"));
    }

    #[tokio::test]
    async fn all_layers_disabled_passes_through() {
        let app = SecurityStack::new()
            .without_helmet()
            .without_csrf()
            .without_ratelimit()
            .without_ipfilter()
            .without_redirect_https()
            .without_body_limit()
            .apply(base_router());
        let status = send(app, get_req()).await.status();
        assert_eq!(status, StatusCode::OK);
    }
}