1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7use url::Url;
8
/// Byte length of the `"Bearer "` scheme prefix (including the trailing space) that
/// precedes the token in an `Authorization` header value.
const BEARER_PREFIX_LEN: usize = "Bearer ".len();
10
11#[must_use]
13pub fn generate_token() -> String {
14 uuid::Uuid::new_v4().to_string()
15}
16
/// Authentication configuration handed to [`require_auth`] as middleware state.
#[derive(Clone)]
pub struct AuthState {
    /// Token every request must present as a bearer credential.
    /// `None` turns authentication off: `require_auth` then forwards all requests.
    pub token: Option<String>,
}
23
24pub async fn require_auth(
30 axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
31 request: Request,
32 next: Next,
33) -> Result<Response, StatusCode> {
34 let Some(expected) = &auth.token else {
35 return Ok(next.run(request).await);
36 };
37
38 let provided = request
39 .headers()
40 .get("authorization")
41 .and_then(|v| v.to_str().ok())
42 .and_then(|v| {
43 let lower = v.to_lowercase();
44 if lower.starts_with("bearer ") {
45 Some(v[BEARER_PREFIX_LEN..].to_string())
46 } else {
47 None
48 }
49 });
50
51 match provided {
52 Some(ref token) if token == expected => Ok(next.run(request).await),
53 _ => Err(StatusCode::UNAUTHORIZED),
54 }
55}
56
/// Token-bucket rate limiter state.
///
/// The bucket holds at most `max_tokens` tokens. Each admitted request removes one,
/// and tokens are added back at `refill_rate_per_sec` per second of elapsed time.
/// All mutable state is kept in atomics, so the limiter is driven through `&self`.
pub struct RateLimiterState {
    /// Tokens currently available.
    tokens: AtomicU64,
    /// Bucket capacity; a refill never raises `tokens` above this value.
    max_tokens: u64,
    /// `now_ms` timestamp up to which elapsed time has already been converted to tokens.
    last_refill_ms: AtomicU64,
    /// Tokens added per second of elapsed time.
    refill_rate_per_sec: u64,
}

/// Milliseconds elapsed on a monotonic clock since the first call in this process.
///
/// Only differences between two readings are meaningful. A monotonic clock is used
/// instead of wall-clock time so that a system clock adjustment (NTP step, manual
/// change) can neither stall refills nor grant a burst of tokens.
fn now_ms() -> u64 {
    static ANCHOR: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
    let elapsed = ANCHOR.get_or_init(std::time::Instant::now).elapsed();
    // `as_millis` yields u128; clamp instead of silently truncating.
    u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
}

impl RateLimiterState {
    /// Creates a limiter whose bucket starts full with `max_requests_per_sec` tokens
    /// and is refilled at the same number of tokens per second.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_ms: AtomicU64::new(now_ms()),
            refill_rate_per_sec: max_requests_per_sec,
        }
    }

    /// Takes one token if any is available.
    ///
    /// Returns `true` when a token was taken (the request may proceed) and `false`
    /// when the bucket is empty.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` returns `None` at zero, which aborts the update and reports `Err`.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                available.checked_sub(1)
            })
            .is_ok()
    }

    /// Credits the tokens earned since the last refill, using the current time.
    fn refill(&self) {
        self.refill_at(now_ms());
    }

    /// Credits the tokens earned between the last refill and `now` (a `now_ms` reading).
    fn refill_at(&self, now: u64) {
        let last = self.last_refill_ms.load(Ordering::Relaxed);
        let elapsed_ms = now.saturating_sub(last);
        let rate = u128::from(self.refill_rate_per_sec);
        // Widened to u128: the product of two u64 values always fits, so extreme
        // rates or idle periods cannot overflow.
        let earned = u128::from(elapsed_ms) * rate / 1000;
        if earned == 0 {
            // Less than one whole token so far (or a zero rate). The timestamp is left
            // alone so the elapsed time keeps accumulating.
            return;
        }

        // Advance the timestamp only by the time the whole tokens account for, so the
        // remainder carries over to the next refill instead of being discarded.
        // Rounded up so the bucket is never over-credited. `earned > 0` implies
        // `rate > 0`, and the quotient never exceeds `elapsed_ms`.
        let spent_ms = u64::try_from((earned * 1000).div_ceil(rate)).unwrap_or(elapsed_ms);
        let credited_until = last.saturating_add(spent_ms);
        let add = u64::try_from(earned).unwrap_or(u64::MAX);

        // Only the caller that wins the timestamp update credits the tokens, so one
        // time slice is never credited twice by concurrent callers.
        if self
            .last_refill_ms
            .compare_exchange(last, credited_until, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok()
        {
            // The closure always returns `Some`, so this update cannot fail.
            let _ = self
                .tokens
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                    Some(available.saturating_add(add).min(self.max_tokens))
                });
        }
    }
}
134
135pub async fn rate_limit(
141 axum::extract::State(limiter): axum::extract::State<Arc<RateLimiterState>>,
142 request: Request,
143 next: Next,
144) -> Result<Response, StatusCode> {
145 if limiter.try_acquire() {
146 Ok(next.run(request).await)
147 } else {
148 Err(StatusCode::TOO_MANY_REQUESTS)
149 }
150}
151
152const DEFAULT_RATE_LIMIT: u64 = 1000;
153
154#[must_use]
156pub fn default_rate_limiter() -> Arc<RateLimiterState> {
157 Arc::new(RateLimiterState::new(DEFAULT_RATE_LIMIT))
158}
159
160pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
170 let host = request
171 .headers()
172 .get("host")
173 .and_then(|v| v.to_str().ok())
174 .unwrap_or("");
175 let host_name = if host.starts_with('[') {
176 host.split(']').next().map_or(host, |s| &s[1..])
177 } else {
178 host.split(':').next().unwrap_or(host)
179 };
180 let is_allowed = matches!(host_name, "localhost" | "127.0.0.1" | "::1" | "");
181 if !is_allowed {
182 tracing::warn!("DNS rebinding attempt blocked: Host={host}");
183 return Err(StatusCode::FORBIDDEN);
184 }
185 Ok(next.run(request).await)
186}
187
188pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
195 if let Some(origin) = request
196 .headers()
197 .get("origin")
198 .and_then(|v| v.to_str().ok())
199 && !is_allowed_origin(origin)
200 {
201 tracing::warn!("Cross-origin request blocked: Origin={origin}");
202 return Err(StatusCode::FORBIDDEN);
203 }
204 Ok(next.run(request).await)
205}
206
207fn is_allowed_origin(origin: &str) -> bool {
208 if origin.starts_with("tauri://") {
209 return true;
210 }
211 let Ok(parsed) = Url::parse(origin) else {
212 return false;
213 };
214 matches!(parsed.scheme(), "http" | "https")
215 && matches!(
216 parsed.host_str(),
217 Some("localhost" | "127.0.0.1" | "[::1]" | "::1")
218 )
219}
220
221pub async fn security_headers(request: Request, next: Next) -> Response {
223 let mut response = next.run(request).await;
224 let headers = response.headers_mut();
225 headers.insert(
226 axum::http::header::X_CONTENT_TYPE_OPTIONS,
227 axum::http::HeaderValue::from_static("nosniff"),
228 );
229 headers.insert(
230 axum::http::header::CACHE_CONTROL,
231 axum::http::HeaderValue::from_static("no-store"),
232 );
233 headers.insert(
234 axum::http::header::HeaderName::from_static("x-frame-options"),
235 axum::http::HeaderValue::from_static("DENY"),
236 );
237 response
238}
239
// Unit and integration-style tests for the token helper, the rate limiter and each
// middleware. Middleware tests build a one-route `Router` and drive it with `oneshot`.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::middleware;
    use axum::routing::get;
    use tower::ServiceExt;

    /// Handler behind every test route; always answers with the body `"ok"`.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    // ---- generate_token ----

    #[test]
    fn token_generation_is_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        // The token text is expected to be 36 characters long.
        assert_eq!(t1.len(), 36);
    }

    #[test]
    fn token_is_valid_uuid() {
        let token = generate_token();
        assert!(uuid::Uuid::parse_str(&token).is_ok());
    }

    // ---- RateLimiterState ----

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiterState::new(5);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiterState::new(42);
        assert_eq!(limiter.tokens.load(Ordering::Relaxed), 42);
        assert_eq!(limiter.max_tokens, 42);
    }

    #[test]
    fn rate_limiter_concurrent_acquire() {
        let limiter = Arc::new(RateLimiterState::new(1000));
        // 10 threads x 200 attempts = 2000 attempts against a 1000-token bucket.
        let mut handles = vec![];
        for _ in 0..10 {
            let l = limiter.clone();
            handles.push(std::thread::spawn(move || {
                let mut acquired = 0;
                for _ in 0..200 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        // The upper bound leaves room for a few tokens refilled while the test runs.
        assert!((1000..=1010).contains(&total));
    }

    #[test]
    fn default_rate_limiter_has_expected_tokens() {
        let limiter = default_rate_limiter();
        assert_eq!(limiter.max_tokens, 1000);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiterState::new(0);
        assert!(!limiter.try_acquire());
    }

    // ---- dns_rebinding_guard ----

    /// Router with a single `/test` route wrapped in `dns_rebinding_guard`.
    fn dns_rebinding_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(dns_rebinding_guard))
    }

    /// GET `/test` request; sets the `host` header only when `host` is `Some`.
    fn dns_request(host: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn dns_rebinding_allows_localhost() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("localhost"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_127_0_0_1() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("127.0.0.1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed_with_port() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]:7373"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bare() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("::1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_empty_host() {
        let app = dns_rebinding_router();
        // No `host` header at all.
        let resp = app.oneshot(dns_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_evil_com() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("evil.com"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_localhost_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("localhost.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_ip_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("127.0.0.1.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // ---- origin_guard ----

    /// Router with a single `/test` route wrapped in `origin_guard`.
    fn origin_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(origin_guard))
    }

    /// GET `/test` request; sets the `origin` header only when `origin` is `Some`.
    fn origin_request(origin: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(o) = origin {
            builder = builder.header("origin", o);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn origin_allows_no_origin() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_localhost_http() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://localhost:3000")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_127_0_0_1_https() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("https://127.0.0.1:8080")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_tauri_scheme() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("tauri://localhost")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_blocks_null() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(Some("null"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn origin_blocks_evil_com() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // ---- security_headers ----

    /// Router with a single `/test` route wrapped in `security_headers`.
    fn security_headers_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(security_headers))
    }

    #[tokio::test]
    async fn security_headers_x_content_type_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
    }

    #[tokio::test]
    async fn security_headers_cache_control() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("cache-control").unwrap(), "no-store");
    }

    #[tokio::test]
    async fn security_headers_x_frame_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");
    }

    // ---- require_auth ----

    /// Router with a single `/test` route wrapped in `require_auth`; `token` becomes
    /// the configured `AuthState::token`.
    fn auth_router(token: Option<&str>) -> Router {
        let state = Arc::new(AuthState {
            token: token.map(String::from),
        });
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth))
    }

    /// GET `/test` request; adds `authorization: Bearer <token>` when `token` is `Some`.
    fn auth_request(token: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(t) = token {
            builder = builder.header("authorization", format!("Bearer {t}"));
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn auth_allows_correct_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(Some("secret-123"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_wrong_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app
            .oneshot(auth_request(Some("wrong-token")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_rejects_missing_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_allows_any_when_disabled() {
        let app = auth_router(None);
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_case_insensitive_bearer_prefix() {
        let state = Arc::new(AuthState {
            token: Some("my-token".into()),
        });
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth));

        // Upper-case scheme; `auth_request` always writes "Bearer", so build by hand.
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "BEARER my-token")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_non_bearer_scheme() {
        let app = auth_router(Some("secret"));
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Basic c2VjcmV0")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    // ---- rate_limit middleware ----

    #[tokio::test]
    async fn rate_limiter_returns_429_when_exhausted() {
        let limiter = Arc::new(RateLimiterState::new(2));
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(limiter, rate_limit));

        // `oneshot` consumes the router, so clone it once per request.
        let app2 = app.clone();
        let app3 = app2.clone();

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(app.oneshot(req).await.unwrap().status(), StatusCode::OK);

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(app2.oneshot(req).await.unwrap().status(), StatusCode::OK);

        // Third request exceeds the two-token budget.
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(
            app3.oneshot(req).await.unwrap().status(),
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    // ---- all middleware stacked together ----

    #[tokio::test]
    async fn combined_layers_enforce_all_guards() {
        let auth_state = Arc::new(AuthState {
            token: Some("tok-123".into()),
        });
        let limiter = Arc::new(RateLimiterState::new(100));

        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(auth_state, require_auth))
            .layer(middleware::from_fn_with_state(limiter, rate_limit))
            .layer(middleware::from_fn(security_headers))
            .layer(middleware::from_fn(origin_guard))
            .layer(middleware::from_fn(dns_rebinding_guard));

        // Valid token and local host: accepted, and the security headers are present.
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "127.0.0.1:7373")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");

        // Valid token but foreign host: rejected with 403.
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // Valid token and host but foreign origin: rejected with 403.
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "localhost")
            .header("origin", "https://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // Local host but no token: rejected with 401.
        let req = Request::builder()
            .uri("/test")
            .header("host", "localhost")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    // ---- is_allowed_origin, called directly ----

    #[test]
    fn origin_guard_allows_localhost_variants() {
        assert!(is_allowed_origin("http://localhost"));
        assert!(is_allowed_origin("http://localhost:7373"));
        assert!(is_allowed_origin("https://localhost"));
        assert!(is_allowed_origin("https://localhost:443"));
        assert!(is_allowed_origin("http://127.0.0.1"));
        assert!(is_allowed_origin("http://127.0.0.1:8080"));
        assert!(is_allowed_origin("https://127.0.0.1"));
        assert!(is_allowed_origin("http://[::1]"));
        assert!(is_allowed_origin("http://[::1]:7373"));
        assert!(is_allowed_origin("tauri://localhost"));
        assert!(is_allowed_origin("tauri://some-app"));
    }

    #[test]
    fn origin_guard_rejects_prefix_smuggling() {
        // Allowed names used as a prefix of a foreign host must not pass.
        assert!(!is_allowed_origin("http://localhost.evil.com"));
        assert!(!is_allowed_origin("https://localhost.evil.com"));
        assert!(!is_allowed_origin("https://127.0.0.1.evil.com"));
        assert!(!is_allowed_origin("http://[::1].evil.com"));
    }

    #[test]
    fn origin_guard_rejects_userinfo_trick() {
        // Allowed names placed before `@` must not pass.
        assert!(!is_allowed_origin("http://localhost@evil.com"));
        assert!(!is_allowed_origin("http://127.0.0.1@evil.com"));
    }

    #[test]
    fn origin_guard_rejects_foreign_and_malformed() {
        assert!(!is_allowed_origin("http://evil.com"));
        assert!(!is_allowed_origin("https://attacker.io"));
        assert!(!is_allowed_origin("not-a-url"));
        assert!(!is_allowed_origin(""));
        assert!(!is_allowed_origin("ftp://localhost"));
    }
}