1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7
/// Byte length of the `"Bearer "` scheme prefix (scheme name plus one separating space)
/// at the start of an `Authorization` header value.
const BEARER_PREFIX_LEN: usize = b"Bearer ".len();
9
/// Compares two byte slices for equality without stopping at the first mismatching byte.
///
/// Slices of different lengths are rejected up front, so the length itself is not hidden.
/// For equal-length inputs every byte pair is XOR-ed and the results are OR-ed together,
/// so the amount of work does not depend on where the inputs differ.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let mut diff = 0u8;
        for (left, right) in a.iter().zip(b) {
            diff |= left ^ right;
        }
        diff == 0
    } else {
        false
    }
}
19
20#[must_use]
22pub fn generate_token() -> String {
23 uuid::Uuid::new_v4().to_string()
24}
25
/// Shared state consulted by the `require_auth` middleware.
///
/// When `token` is `None`, `require_auth` forwards every request without checking
/// credentials. When it is `Some`, requests must present that value as a bearer token.
#[derive(Clone)]
pub struct AuthState {
    /// Expected bearer token, or `None` to disable authentication.
    pub token: Option<String>,
}
30
31pub async fn require_auth(
37 axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
38 request: Request,
39 next: Next,
40) -> Result<Response, StatusCode> {
41 let Some(expected) = &auth.token else {
42 return Ok(next.run(request).await);
43 };
44
45 let provided = request
46 .headers()
47 .get("authorization")
48 .and_then(|v| v.to_str().ok())
49 .and_then(|v| {
50 let lower = v.to_lowercase();
51 if lower.starts_with("bearer ") {
52 Some(v[BEARER_PREFIX_LEN..].to_string())
53 } else {
54 None
55 }
56 });
57
58 match provided {
59 Some(ref token) if constant_time_eq(token.as_bytes(), expected.as_bytes()) => {
60 Ok(next.run(request).await)
61 }
62 _ => {
63 tracing::warn!("victauri-browser: rejected request — invalid or missing auth token");
64 Err(StatusCode::UNAUTHORIZED)
65 }
66 }
67}
68
/// Wall-clock milliseconds since the Unix epoch.
///
/// Returns 0 if the clock reports a time before the epoch. Saturates at `u64::MAX` instead of
/// silently truncating the `u128` millisecond count.
fn now_ms() -> u64 {
    let millis = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    u64::try_from(millis).unwrap_or(u64::MAX)
}

/// Minimum elapsed time before a refill is attempted, to avoid contending on every call.
const MIN_REFILL_INTERVAL_MS: u64 = 10;

/// Lock-free token-bucket rate limiter.
///
/// The bucket starts full. Each accepted request consumes one token, and tokens are
/// replenished at `refill_rate_per_sec` per second, never exceeding `max_tokens`.
pub struct RateLimiterState {
    /// Tokens currently available.
    tokens: AtomicU64,
    /// Bucket capacity.
    max_tokens: u64,
    /// Wall-clock time (ms since epoch) up to which elapsed time has been turned into tokens.
    last_refill_ms: AtomicU64,
    /// Tokens added per second of elapsed time.
    refill_rate_per_sec: u64,
}

impl RateLimiterState {
    /// Creates a limiter whose bucket holds, and refills at, `max_requests_per_sec` tokens.
    ///
    /// With `max_requests_per_sec == 0` every [`try_acquire`](Self::try_acquire) fails.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_ms: AtomicU64::new(now_ms()),
            refill_rate_per_sec: max_requests_per_sec,
        }
    }

    /// Tops up the bucket for elapsed time, then tries to take one token.
    ///
    /// Returns `true` if a token was taken, `false` if the bucket is empty.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` yields `None` at zero, which makes `fetch_update` give up
        // without touching the counter.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |t| t.checked_sub(1))
            .is_ok()
    }

    /// Converts time elapsed since `last_refill_ms` into tokens, capped at `max_tokens`.
    fn refill(&self) {
        let now = now_ms();
        let last = self.last_refill_ms.load(Ordering::Relaxed);

        if now < last {
            // The wall clock stepped backwards. Re-anchor to the current time; otherwise no
            // tokens would be added until the clock caught up again, starving every caller.
            let _ = self
                .last_refill_ms
                .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed);
            return;
        }

        let elapsed_ms = now - last;
        if elapsed_ms < MIN_REFILL_INTERVAL_MS {
            return;
        }

        // Saturate rather than overflow for very large rates or very long idle periods.
        let new_tokens = elapsed_ms.saturating_mul(self.refill_rate_per_sec) / 1000;
        if new_tokens == 0 {
            return;
        }

        // Advance the anchor only by the time the granted whole tokens account for, so the
        // fractional remainder carries over instead of being discarded. `refill_rate_per_sec`
        // is non-zero here because `new_tokens > 0`.
        let consumed_ms =
            (new_tokens.saturating_mul(1000) / self.refill_rate_per_sec).min(elapsed_ms);

        // Only the thread that wins the anchor update adds tokens, so each interval is
        // credited once.
        if self
            .last_refill_ms
            .compare_exchange(last, last + consumed_ms, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok()
        {
            // Add atomically so that acquisitions racing with the refill are not overwritten.
            let _ = self
                .tokens
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |t| {
                    Some(t.saturating_add(new_tokens).min(self.max_tokens))
                });
        }
    }
}
134
135#[must_use]
137pub fn default_rate_limiter() -> Arc<RateLimiterState> {
138 Arc::new(RateLimiterState::new(1000))
139}
140
141pub async fn rate_limit(
148 axum::extract::State(limiter): axum::extract::State<Arc<RateLimiterState>>,
149 request: Request,
150 next: Next,
151) -> Result<
152 Response,
153 (
154 StatusCode,
155 [(axum::http::HeaderName, axum::http::HeaderValue); 1],
156 ),
157> {
158 if limiter.try_acquire() {
159 Ok(next.run(request).await)
160 } else {
161 Err((
162 StatusCode::TOO_MANY_REQUESTS,
163 [(
164 axum::http::header::RETRY_AFTER,
165 axum::http::HeaderValue::from_static("1"),
166 )],
167 ))
168 }
169}
170
171pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
180 let host = request
181 .headers()
182 .get("host")
183 .and_then(|v| v.to_str().ok())
184 .unwrap_or("");
185 let host_name = if host.starts_with('[') {
186 host.split(']').next().map_or(host, |s| &s[1..])
187 } else if host.contains("::") {
188 host
189 } else {
190 host.split(':').next().unwrap_or(host)
191 };
192 let is_allowed = matches!(host_name, "localhost" | "127.0.0.1" | "::1");
193 if !is_allowed {
194 tracing::warn!("DNS rebinding attempt blocked: Host={host}");
195 return Err(StatusCode::FORBIDDEN);
196 }
197 Ok(next.run(request).await)
198}
199
200pub async fn security_headers(request: Request, next: Next) -> Response {
202 let mut response = next.run(request).await;
203 let headers = response.headers_mut();
204 headers.insert(
205 axum::http::header::X_CONTENT_TYPE_OPTIONS,
206 axum::http::HeaderValue::from_static("nosniff"),
207 );
208 headers.insert(
209 axum::http::header::CACHE_CONTROL,
210 axum::http::HeaderValue::from_static("no-store"),
211 );
212 headers.insert(
213 axum::http::header::HeaderName::from_static("x-frame-options"),
214 axum::http::HeaderValue::from_static("DENY"),
215 );
216 headers.insert(
217 axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
218 axum::http::HeaderValue::from_static("null"),
219 );
220 headers.insert(
221 axum::http::header::HeaderName::from_static("content-security-policy"),
222 axum::http::HeaderValue::from_static("default-src 'none'"),
223 );
224 response
225}
226
227pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
235 if let Some(origin) = request
236 .headers()
237 .get("origin")
238 .and_then(|v| v.to_str().ok())
239 {
240 let is_local = is_localhost_origin(origin);
241 if !is_local {
242 tracing::warn!("rejected non-local origin: {origin}");
243 return Err(StatusCode::FORBIDDEN);
244 }
245 }
246 Ok(next.run(request).await)
247}
248
/// Reports whether an `Origin` header value names a loopback host.
///
/// An optional `scheme://` prefix is skipped. The host is then isolated:
/// - a bracketed IPv6 literal is kept up to and including its `]` (then cut at any `/`);
/// - any other host ends at the first `:` or `/`.
///
/// Only `localhost`, `127.0.0.1` and `[::1]` match, exactly and case-sensitively. The scheme
/// and port are not checked.
fn is_localhost_origin(origin: &str) -> bool {
    let authority = match origin.split_once("://") {
        Some((_scheme, rest)) => rest,
        None => origin,
    };

    let host = if authority.starts_with('[') {
        // Keep the brackets so the literal compares against "[::1]".
        let literal = match authority.find(']') {
            Some(close) => &authority[..=close],
            None => authority,
        };
        literal.split('/').next().unwrap_or(literal)
    } else {
        authority
            .split(|c: char| c == ':' || c == '/')
            .next()
            .unwrap_or(authority)
    };

    matches!(host, "127.0.0.1" | "localhost" | "[::1]")
}
270
// Unit tests for token generation, constant-time comparison, the token-bucket rate limiter,
// host-name matching, and `is_localhost_origin`.
#[cfg(test)]
mod tests {
    use super::*;

    // Two generated tokens differ and have the 36-character hyphenated UUID length.
    #[test]
    fn token_generation() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert_eq!(t1.len(), 36);
    }

    // A fresh limiter grants exactly its initial budget, then refuses.
    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"hello", b"hell"));
    }

    #[test]
    fn constant_time_eq_empty_strings() {
        assert!(constant_time_eq(b"", b""));
        assert!(!constant_time_eq(b"", b"x"));
    }

    #[test]
    fn constant_time_eq_single_bit_diff() {
        assert!(!constant_time_eq(b"\x00", b"\x01"));
        assert!(!constant_time_eq(b"\xff", b"\xfe"));
    }

    #[test]
    fn rate_limiter_single_token() {
        let limiter = RateLimiterState::new(1);
        assert!(limiter.try_acquire());
        assert!(!limiter.try_acquire());
    }

    // Hyphenated UUID shape: 36 characters containing 4 hyphens.
    #[test]
    fn token_format_is_uuid() {
        let token = generate_token();
        assert_eq!(token.len(), 36);
        assert_eq!(token.chars().filter(|c| *c == '-').count(), 4);
    }

    #[test]
    fn default_rate_limiter_has_budget() {
        let limiter = default_rate_limiter();
        assert!(limiter.try_acquire());
    }

    // Exhausting the budget keeps failing on repeated attempts.
    #[test]
    fn rate_limiter_exact_boundary() {
        let limiter = RateLimiterState::new(100);
        for i in 0..100 {
            assert!(limiter.try_acquire(), "failed at iteration {i}");
        }
        assert!(!limiter.try_acquire());
        assert!(!limiter.try_acquire());
        assert!(!limiter.try_acquire());
    }

    // 10 threads x 20 attempts contend for a budget of 50.
    // NOTE(review): the `<= 50` bound assumes the test finishes before a refill occurs
    // (at 50 tokens/s, one token needs 20 ms of elapsed time).
    #[test]
    fn rate_limiter_concurrent_contention() {
        use std::sync::Arc;
        use std::thread;

        let limiter = Arc::new(RateLimiterState::new(50));
        let mut handles = vec![];

        for _ in 0..10 {
            let l = Arc::clone(&limiter);
            handles.push(thread::spawn(move || {
                let mut acquired = 0u32;
                for _ in 0..20 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }

        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert!(total <= 50, "acquired {total} but budget was 50");
        assert!(total >= 45, "should acquire most tokens, got {total}");
    }

    #[test]
    fn constant_time_eq_long_strings() {
        let a = "a".repeat(10_000);
        let b = "a".repeat(10_000);
        assert!(constant_time_eq(a.as_bytes(), b.as_bytes()));

        let mut c = "a".repeat(10_000);
        c.push('b');
        assert!(!constant_time_eq(a.as_bytes(), c.as_bytes()));
    }

    // NOTE(review): despite its name, this test does not measure timing. It only checks that
    // mismatches near the start and near the end of a token are both rejected.
    #[test]
    fn constant_time_eq_timing_consistency() {
        let token = "8f14e45f-ceea-367f-a27f-c790e5a0fdc4";
        let wrong1 = "0000000f-ceea-367f-a27f-c790e5a0fdc4";
        let wrong2 = "8f14e45f-ceea-367f-a27f-c790e5a0fd00";

        assert!(!constant_time_eq(token.as_bytes(), wrong1.as_bytes()));
        assert!(!constant_time_eq(token.as_bytes(), wrong2.as_bytes()));
    }

    #[test]
    fn token_uniqueness_over_1000_generations() {
        let mut tokens = std::collections::HashSet::new();
        for _ in 0..1000 {
            let t = generate_token();
            assert!(tokens.insert(t), "duplicate token generated");
        }
    }

    #[test]
    fn rate_limiter_zero_budget() {
        let limiter = RateLimiterState::new(0);
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn constant_time_eq_all_byte_values() {
        for b in 0..=255u8 {
            let a = [b];
            assert!(constant_time_eq(&a, &a));
            if b < 255 {
                let c = [b + 1];
                assert!(!constant_time_eq(&a, &c));
            }
        }
    }

    // NOTE(review): the `dns_rebinding_guard_*` tests below re-implement only the
    // non-bracketed branch of the host parsing inline. They never call `dns_rebinding_guard`,
    // so they do not exercise the middleware or its IPv6 handling.
    #[test]
    fn dns_rebinding_guard_allows_localhost() {
        let host = "localhost";
        let host_name = host.split(':').next().unwrap_or(host);
        assert!(matches!(host_name, "localhost" | "127.0.0.1" | "::1"));
    }

    #[test]
    fn dns_rebinding_guard_allows_127() {
        let host = "127.0.0.1:7474";
        let host_name = host.split(':').next().unwrap_or(host);
        assert!(matches!(host_name, "localhost" | "127.0.0.1" | "::1"));
    }

    #[test]
    fn dns_rebinding_guard_blocks_evil() {
        let host = "evil.com";
        let host_name = host.split(':').next().unwrap_or(host);
        assert!(!matches!(host_name, "localhost" | "127.0.0.1" | "::1"));
    }

    #[test]
    fn dns_rebinding_guard_blocks_localhost_subdomain() {
        let host = "localhost.evil.com";
        let host_name = host.split(':').next().unwrap_or(host);
        assert!(!matches!(host_name, "localhost" | "127.0.0.1" | "::1"));
    }

    #[test]
    fn dns_rebinding_guard_blocks_empty() {
        let host = "";
        let host_name = host.split(':').next().unwrap_or(host);
        assert!(!matches!(host_name, "localhost" | "127.0.0.1" | "::1"));
    }

    // The tests from here on call `is_localhost_origin` directly.
    #[test]
    fn localhost_origin_accepted() {
        assert!(is_localhost_origin("http://localhost:3000"));
        assert!(is_localhost_origin("http://localhost"));
        assert!(is_localhost_origin("https://localhost:7474"));
    }

    #[test]
    fn ipv4_loopback_accepted() {
        assert!(is_localhost_origin("http://127.0.0.1:7474"));
        assert!(is_localhost_origin("http://127.0.0.1"));
        assert!(is_localhost_origin("https://127.0.0.1:443"));
    }

    #[test]
    fn ipv6_loopback_accepted() {
        assert!(is_localhost_origin("http://[::1]:7474"));
        assert!(is_localhost_origin("http://[::1]"));
    }

    // Hosts that merely contain a loopback name must not match.
    #[test]
    fn subdomain_bypass_rejected() {
        assert!(!is_localhost_origin("https://localhost.evil.com"));
        assert!(!is_localhost_origin("https://127.0.0.1.evil.com"));
        assert!(!is_localhost_origin("https://evil-localhost.com"));
    }

    // A loopback name that appears only in the path must not match.
    #[test]
    fn path_bypass_rejected() {
        assert!(!is_localhost_origin("https://evil.com/localhost"));
        assert!(!is_localhost_origin("https://evil.com/127.0.0.1"));
    }

    #[test]
    fn external_origins_rejected() {
        assert!(!is_localhost_origin("https://google.com"));
        assert!(!is_localhost_origin("https://example.com:443"));
        assert!(!is_localhost_origin("http://attacker.com"));
    }
}