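//! `victauri_browser/auth.rs`
//!
//! HTTP middleware for bearer-token authentication, rate limiting, security
//! response headers, and localhost-only `Host` / `Origin` enforcement.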

1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7
/// Number of bytes occupied by the `Bearer ` scheme prefix (including the
/// trailing space) at the start of an `Authorization` header value.
const BEARER_PREFIX_LEN: usize = b"Bearer ".len();
9
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different lengths are rejected immediately, so only the content
/// comparison (not the length comparison) is independent of where a mismatch
/// occurs.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    // OR together the XOR of every byte pair; any difference leaves a bit set.
    let mut diff = 0u8;
    for (&lhs, &rhs) in a.iter().zip(b) {
        diff |= lhs ^ rhs;
    }
    diff == 0
}
19
20/// Generate a random UUID v4 token for Bearer authentication.
21#[must_use]
22pub fn generate_token() -> String {
23    uuid::Uuid::new_v4().to_string()
24}
25
/// Shared configuration consumed by the [`require_auth`] middleware.
#[derive(Clone)]
pub struct AuthState {
    /// Bearer token every request must present.
    ///
    /// When this is `None`, [`require_auth`] lets all requests through without
    /// inspecting the `Authorization` header.
    pub token: Option<String>,
}
30
31/// Axum middleware that validates Bearer token authentication.
32///
33/// # Errors
34///
35/// Returns `401 Unauthorized` if the token is missing or invalid.
36pub async fn require_auth(
37    axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
38    request: Request,
39    next: Next,
40) -> Result<Response, StatusCode> {
41    let Some(expected) = &auth.token else {
42        return Ok(next.run(request).await);
43    };
44
45    let provided = request
46        .headers()
47        .get("authorization")
48        .and_then(|v| v.to_str().ok())
49        .and_then(|v| {
50            let lower = v.to_lowercase();
51            if lower.starts_with("bearer ") {
52                Some(v[BEARER_PREFIX_LEN..].to_string())
53            } else {
54                None
55            }
56        });
57
58    match provided {
59        Some(ref token) if constant_time_eq(token.as_bytes(), expected.as_bytes()) => {
60            Ok(next.run(request).await)
61        }
62        _ => {
63            tracing::warn!("victauri-browser: rejected request — invalid or missing auth token");
64            Err(StatusCode::UNAUTHORIZED)
65        }
66    }
67}
68
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch, and
/// saturates at `u64::MAX` instead of silently truncating the 128-bit
/// millisecond count.
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
75
76pub struct RateLimiterState {
77    tokens: AtomicU64,
78    max_tokens: u64,
79    last_refill_ms: AtomicU64,
80    refill_rate_per_sec: u64,
81}
82
83impl RateLimiterState {
84    #[must_use]
85    pub fn new(max_requests_per_sec: u64) -> Self {
86        Self {
87            tokens: AtomicU64::new(max_requests_per_sec),
88            max_tokens: max_requests_per_sec,
89            last_refill_ms: AtomicU64::new(now_ms()),
90            refill_rate_per_sec: max_requests_per_sec,
91        }
92    }
93
94    /// Try to consume one token. Returns `true` if allowed.
95    pub fn try_acquire(&self) -> bool {
96        self.refill();
97        loop {
98            let current = self.tokens.load(Ordering::Relaxed);
99            if current == 0 {
100                return false;
101            }
102            if self
103                .tokens
104                .compare_exchange_weak(current, current - 1, Ordering::Relaxed, Ordering::Relaxed)
105                .is_ok()
106            {
107                return true;
108            }
109        }
110    }
111
112    fn refill(&self) {
113        let now = now_ms();
114        let last = self.last_refill_ms.load(Ordering::Relaxed);
115        let elapsed_ms = now.saturating_sub(last);
116        if elapsed_ms < 10 {
117            return;
118        }
119        let new_tokens = (elapsed_ms * self.refill_rate_per_sec) / 1000;
120        if new_tokens == 0 {
121            return;
122        }
123        if self
124            .last_refill_ms
125            .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
126            .is_ok()
127        {
128            let current = self.tokens.load(Ordering::Relaxed);
129            let capped = (current + new_tokens).min(self.max_tokens);
130            self.tokens.store(capped, Ordering::Relaxed);
131        }
132    }
133}
134
135/// Default rate limiter: 1000 requests per second.
136#[must_use]
137pub fn default_rate_limiter() -> Arc<RateLimiterState> {
138    Arc::new(RateLimiterState::new(1000))
139}
140
141/// Axum middleware for rate limiting.
142///
143/// # Errors
144///
145/// Returns `429 Too Many Requests` with `Retry-After: 1` header when the rate
146/// limit is exceeded.
147pub async fn rate_limit(
148    axum::extract::State(limiter): axum::extract::State<Arc<RateLimiterState>>,
149    request: Request,
150    next: Next,
151) -> Result<
152    Response,
153    (
154        StatusCode,
155        [(axum::http::HeaderName, axum::http::HeaderValue); 1],
156    ),
157> {
158    if limiter.try_acquire() {
159        Ok(next.run(request).await)
160    } else {
161        Err((
162            StatusCode::TOO_MANY_REQUESTS,
163            [(
164                axum::http::header::RETRY_AFTER,
165                axum::http::HeaderValue::from_static("1"),
166            )],
167        ))
168    }
169}
170
171/// Axum middleware that blocks DNS rebinding attacks.
172///
173/// Rejects any request where the Host header is not a localhost address.
174///
175/// # Errors
176///
177/// Returns [`StatusCode::FORBIDDEN`] if the `Host` header is not `localhost`,
178/// `127.0.0.1`, or `::1`.
179pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
180    let host = request
181        .headers()
182        .get("host")
183        .and_then(|v| v.to_str().ok())
184        .unwrap_or("");
185    let host_name = if host.starts_with('[') {
186        host.split(']').next().map_or(host, |s| &s[1..])
187    } else if host.contains("::") {
188        host
189    } else {
190        host.split(':').next().unwrap_or(host)
191    };
192    let is_allowed = matches!(host_name, "localhost" | "127.0.0.1" | "::1");
193    if !is_allowed {
194        tracing::warn!("DNS rebinding attempt blocked: Host={host}");
195        return Err(StatusCode::FORBIDDEN);
196    }
197    Ok(next.run(request).await)
198}
199
200/// Axum middleware that sets security-hardening response headers on every response.
201pub async fn security_headers(request: Request, next: Next) -> Response {
202    let mut response = next.run(request).await;
203    let headers = response.headers_mut();
204    headers.insert(
205        axum::http::header::X_CONTENT_TYPE_OPTIONS,
206        axum::http::HeaderValue::from_static("nosniff"),
207    );
208    headers.insert(
209        axum::http::header::CACHE_CONTROL,
210        axum::http::HeaderValue::from_static("no-store"),
211    );
212    headers.insert(
213        axum::http::header::HeaderName::from_static("x-frame-options"),
214        axum::http::HeaderValue::from_static("DENY"),
215    );
216    headers.insert(
217        axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
218        axum::http::HeaderValue::from_static("null"),
219    );
220    headers.insert(
221        axum::http::header::HeaderName::from_static("content-security-policy"),
222        axum::http::HeaderValue::from_static("default-src 'none'"),
223    );
224    response
225}
226
227/// Localhost origin guard: rejects requests with non-localhost Origin header.
228///
229/// Parses the origin as a URL and checks the host component directly,
230/// preventing bypass via subdomains like "localhost.evil.com".
231///
232/// # Errors
233/// Returns `403 Forbidden` if the Origin header contains a non-localhost host.
234pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
235    if let Some(origin) = request
236        .headers()
237        .get("origin")
238        .and_then(|v| v.to_str().ok())
239    {
240        let is_local = is_localhost_origin(origin);
241        if !is_local {
242            tracing::warn!("rejected non-local origin: {origin}");
243            return Err(StatusCode::FORBIDDEN);
244        }
245    }
246    Ok(next.run(request).await)
247}
248
/// Returns `true` if `origin` (`scheme://host[:port]`, scheme optional) names
/// the local machine: `localhost`, `127.0.0.1`, or `[::1]`.
///
/// The host must match exactly and anything after it must be a well-formed
/// numeric port, so inputs such as `http://localhost.evil.com`,
/// `http://[::1]evil.com`, or `http://localhost:80@evil.com` are rejected.
fn is_localhost_origin(origin: &str) -> bool {
    // Drop the scheme, if there is one.
    let after_scheme = origin.split_once("://").map_or(origin, |(_, rest)| rest);
    // Drop any trailing path so only `host[:port]` remains.
    let authority = after_scheme.split('/').next().unwrap_or(after_scheme);

    let (host, port) = if authority.starts_with('[') {
        // IPv6 literal: `[::1]` or `[::1]:port`. The brackets stay in `host`.
        let Some(close) = authority.find(']') else {
            return false;
        };
        let (host, rest) = authority.split_at(close + 1);
        match rest.strip_prefix(':') {
            Some(port) => (host, port),
            None if rest.is_empty() => (host, ""),
            // Trailing text after `]` that is not a port, e.g. `[::1]evil.com`.
            None => return false,
        }
    } else {
        authority.split_once(':').unwrap_or((authority, ""))
    };

    // A non-numeric "port" means the text after the colon is something else
    // (for example userinfo in `localhost:pw@evil.com`), so refuse it.
    let port_is_numeric = port.bytes().all(|b| b.is_ascii_digit());
    port_is_numeric && matches!(host, "127.0.0.1" | "localhost" | "[::1]")
}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Strips an optional `:port` suffix from `host` and checks the remaining
    /// name against the loopback allow-list.
    fn passes_host_check(host: &str) -> bool {
        let host_name = host.split(':').next().unwrap_or(host);
        matches!(host_name, "localhost" | "127.0.0.1" | "::1")
    }

    // --- Token generation ---

    #[test]
    fn token_generation() {
        let (first, second) = (generate_token(), generate_token());
        assert_ne!(first, second);
        assert_eq!(first.len(), 36);
    }

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        assert!((0..10).all(|_| limiter.try_acquire()));
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"hello", b"hello"));
        for other in [&b"world"[..], &b"hell"[..]] {
            assert!(!constant_time_eq(b"hello", other));
        }
    }

    #[test]
    fn constant_time_eq_empty_strings() {
        let empty: &[u8] = b"";
        assert!(constant_time_eq(empty, empty));
        assert!(!constant_time_eq(empty, b"x"));
    }

    #[test]
    fn constant_time_eq_single_bit_diff() {
        for (left, right) in [(0x00u8, 0x01u8), (0xff, 0xfe)] {
            assert!(!constant_time_eq(&[left], &[right]));
        }
    }

    #[test]
    fn rate_limiter_single_token() {
        let limiter = RateLimiterState::new(1);
        let first = limiter.try_acquire();
        let second = limiter.try_acquire();
        assert!(first);
        assert!(!second);
    }

    #[test]
    fn token_format_is_uuid() {
        let token = generate_token();
        let hyphens = token.matches('-').count();
        assert_eq!(token.len(), 36);
        assert_eq!(hyphens, 4);
    }

    #[test]
    fn default_rate_limiter_has_budget() {
        assert!(default_rate_limiter().try_acquire());
    }

    // --- Adversarial stress tests ---

    #[test]
    fn rate_limiter_exact_boundary() {
        let limiter = RateLimiterState::new(100);
        for i in 0..100 {
            let granted = limiter.try_acquire();
            assert!(granted, "failed at iteration {i}");
        }
        // Once drained, repeated attempts keep failing.
        for _ in 0..3 {
            assert!(!limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_concurrent_contention() {
        use std::sync::Arc;
        use std::thread;

        let limiter = Arc::new(RateLimiterState::new(50));

        // Spawn every worker before joining any of them so they contend.
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let shared = Arc::clone(&limiter);
                thread::spawn(move || {
                    let mut acquired = 0u32;
                    for _ in 0..20 {
                        if shared.try_acquire() {
                            acquired += 1;
                        }
                    }
                    acquired
                })
            })
            .collect();

        let mut total = 0u32;
        for worker in workers {
            total += worker.join().unwrap();
        }

        // With 50 tokens and 10 threads each trying 20 times, at most 50 succeed
        assert!(total <= 50, "acquired {total} but budget was 50");
        assert!(total >= 45, "should acquire most tokens, got {total}");
    }

    #[test]
    fn constant_time_eq_long_strings() {
        let base = "a".repeat(10_000);
        let same = base.clone();
        assert!(constant_time_eq(base.as_bytes(), same.as_bytes()));

        // One extra byte makes the lengths differ.
        let longer = format!("{base}b");
        assert!(!constant_time_eq(base.as_bytes(), longer.as_bytes()));
    }

    #[test]
    fn constant_time_eq_timing_consistency() {
        let token = "8f14e45f-ceea-367f-a27f-c790e5a0fdc4";
        let mismatches = [
            "0000000f-ceea-367f-a27f-c790e5a0fdc4",
            "8f14e45f-ceea-367f-a27f-c790e5a0fd00",
        ];

        // Both should fail regardless of where the mismatch is
        for wrong in mismatches {
            assert!(!constant_time_eq(token.as_bytes(), wrong.as_bytes()));
        }
    }

    #[test]
    fn token_uniqueness_over_1000_generations() {
        let mut seen = std::collections::HashSet::new();
        for _ in 0..1000 {
            assert!(seen.insert(generate_token()), "duplicate token generated");
        }
    }

    #[test]
    fn rate_limiter_zero_budget() {
        assert!(!RateLimiterState::new(0).try_acquire());
    }

    #[test]
    fn constant_time_eq_all_byte_values() {
        for byte in u8::MIN..=u8::MAX {
            assert!(constant_time_eq(&[byte], &[byte]));
            if let Some(next) = byte.checked_add(1) {
                assert!(!constant_time_eq(&[byte], &[next]));
            }
        }
    }

    // --- DNS rebinding guard tests ---

    #[test]
    fn dns_rebinding_guard_allows_localhost() {
        assert!(passes_host_check("localhost"));
    }

    #[test]
    fn dns_rebinding_guard_allows_127() {
        assert!(passes_host_check("127.0.0.1:7474"));
    }

    #[test]
    fn dns_rebinding_guard_blocks_evil() {
        assert!(!passes_host_check("evil.com"));
    }

    #[test]
    fn dns_rebinding_guard_blocks_localhost_subdomain() {
        assert!(!passes_host_check("localhost.evil.com"));
    }

    #[test]
    fn dns_rebinding_guard_blocks_empty() {
        assert!(!passes_host_check(""));
    }

    // --- Origin guard tests ---

    #[test]
    fn localhost_origin_accepted() {
        for origin in [
            "http://localhost:3000",
            "http://localhost",
            "https://localhost:7474",
        ] {
            assert!(is_localhost_origin(origin));
        }
    }

    #[test]
    fn ipv4_loopback_accepted() {
        for origin in [
            "http://127.0.0.1:7474",
            "http://127.0.0.1",
            "https://127.0.0.1:443",
        ] {
            assert!(is_localhost_origin(origin));
        }
    }

    #[test]
    fn ipv6_loopback_accepted() {
        for origin in ["http://[::1]:7474", "http://[::1]"] {
            assert!(is_localhost_origin(origin));
        }
    }

    #[test]
    fn subdomain_bypass_rejected() {
        for origin in [
            "https://localhost.evil.com",
            "https://127.0.0.1.evil.com",
            "https://evil-localhost.com",
        ] {
            assert!(!is_localhost_origin(origin));
        }
    }

    #[test]
    fn path_bypass_rejected() {
        for origin in ["https://evil.com/localhost", "https://evil.com/127.0.0.1"] {
            assert!(!is_localhost_origin(origin));
        }
    }

    #[test]
    fn external_origins_rejected() {
        for origin in [
            "https://google.com",
            "https://example.com:443",
            "http://attacker.com",
        ] {
            assert!(!is_localhost_origin(origin));
        }
    }
}