Skip to main content

victauri_plugin/
auth.rs

1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7
/// Number of bytes in the `"Bearer "` scheme prefix (scheme name plus one space) that starts
/// an `Authorization` header value.
const BEARER_PREFIX_LEN: usize = b"Bearer ".len();
9
10/// Generate a random `UUID` v4 token suitable for Bearer authentication.
11#[must_use]
12pub fn generate_token() -> String {
13    uuid::Uuid::new_v4().to_string()
14}
15
/// Authentication settings consulted by [`require_auth`] for every MCP request.
///
/// Handed to the middleware behind an `Arc`, so one instance is shared by all requests.
#[derive(Clone)]
pub struct AuthState {
    /// Bearer token that requests must present; `None` switches authentication off.
    pub token: Option<String>,
}
22
23/// Axum middleware that validates the `Authorization: Bearer <token>` header against [`AuthState`].
24///
25/// # Errors
26///
27/// Returns [`StatusCode::UNAUTHORIZED`] if the token is missing or invalid.
28pub async fn require_auth(
29    axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
30    request: Request,
31    next: Next,
32) -> Result<Response, StatusCode> {
33    let Some(expected) = &auth.token else {
34        return Ok(next.run(request).await);
35    };
36
37    let provided = request
38        .headers()
39        .get("authorization")
40        .and_then(|v| v.to_str().ok())
41        .and_then(|v| {
42            let lower = v.to_lowercase();
43            if lower.starts_with("bearer ") {
44                Some(v[BEARER_PREFIX_LEN..].to_string())
45            } else {
46                None
47            }
48        });
49
50    match provided {
51        Some(ref token) if token == expected => Ok(next.run(request).await),
52        _ => Err(StatusCode::UNAUTHORIZED),
53    }
54}
55
56// ── Rate Limiter ───────────────────────────────────────────────────────────
57
/// Lock-free token-bucket rate limiter using microsecond-precision monotonic timestamps for
/// smooth refill.
///
/// The bucket starts full. Each allowed request consumes one token, and tokens are credited back
/// at `refill_rate_per_sec`, never exceeding `max_tokens`.
pub struct RateLimiterState {
    /// Tokens currently available.
    tokens: AtomicU64,
    /// Bucket capacity; refills never raise `tokens` above this.
    max_tokens: u64,
    /// Timestamp (from `now_micros`) up to which elapsed time has already been converted into
    /// tokens. It may lag "now" by less than one token's worth of time; that remainder is
    /// credited on a later refill instead of being discarded.
    last_refill_us: AtomicU64,
    /// Tokens credited per second of elapsed time.
    refill_rate_per_sec: u64,
}

/// Microseconds per second, used by the refill arithmetic.
const MICROS_PER_SEC: u64 = 1_000_000;

/// Microseconds elapsed on a monotonic clock since the first call in this process.
///
/// A monotonic `Instant` is used instead of `SystemTime`. A backwards wall-clock step (NTP
/// correction, manual change) therefore cannot stall refills until the wall clock catches up
/// with the last recorded refill time.
fn now_micros() -> u64 {
    static START: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
    let elapsed = START.get_or_init(std::time::Instant::now).elapsed();
    // Saturate instead of truncating the u128 -> u64 conversion.
    u64::try_from(elapsed.as_micros()).unwrap_or(u64::MAX)
}

/// Split `elapsed_us` of refill time into whole tokens at `rate_per_sec`.
///
/// Returns `(tokens, consumed_us)`: the number of whole tokens earned, and the microseconds
/// those tokens account for, with `consumed_us <= elapsed_us`. Any fractional remainder is left
/// unconsumed so the caller can carry it into the next refill.
fn refill_step(elapsed_us: u64, rate_per_sec: u64) -> (u64, u64) {
    let tokens = elapsed_us.saturating_mul(rate_per_sec) / MICROS_PER_SEC;
    if tokens == 0 {
        return (0, 0);
    }
    // `tokens > 0` implies `rate_per_sec > 0`, so the division is safe. Rounding up never
    // claims more time than elapsed, because tokens <= elapsed * rate / 1e6. The `min` only
    // matters when the multiplications above saturated.
    let consumed_us = tokens
        .saturating_mul(MICROS_PER_SEC)
        .div_ceil(rate_per_sec)
        .min(elapsed_us);
    (tokens, consumed_us)
}

impl RateLimiterState {
    /// Create a rate limiter with the given maximum requests per second.
    ///
    /// The value is used both as the bucket capacity (burst size) and as the refill rate.
    /// A value of `0` yields a limiter that rejects every request.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_us: AtomicU64::new(now_micros()),
            refill_rate_per_sec: max_requests_per_sec,
        }
    }

    /// Atomically consume one token, returning `true` if the request is allowed.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` yields `None` at zero. That makes `fetch_update` fail and leave the
        // counter untouched.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |t| t.checked_sub(1))
            .is_ok()
    }

    /// Credit the tokens earned since the last refill, capped at `max_tokens`.
    fn refill(&self) {
        let last = self.last_refill_us.load(Ordering::Relaxed);
        let elapsed = now_micros().saturating_sub(last);
        let (add, consumed_us) = refill_step(elapsed, self.refill_rate_per_sec);
        if add == 0 {
            return;
        }
        // Only the thread whose CAS succeeds credits this interval, so concurrent callers
        // cannot double-count it. Advancing by `consumed_us` rather than jumping to "now"
        // preserves the fractional remainder. `consumed_us <= now - last`, so this cannot
        // overflow.
        if self
            .last_refill_us
            .compare_exchange(last, last + consumed_us, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok()
        {
            let cap = self.max_tokens;
            // The closure always returns `Some`, so this update cannot fail.
            let _ = self.tokens.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |t| {
                Some(t.saturating_add(add).min(cap))
            });
        }
    }
}
133
134/// Axum middleware that rejects requests with 429 when the token bucket is exhausted.
135///
136/// # Errors
137///
138/// Returns [`StatusCode::TOO_MANY_REQUESTS`] if the token bucket has no remaining capacity.
139pub async fn rate_limit(
140    axum::extract::State(limiter): axum::extract::State<Arc<RateLimiterState>>,
141    request: Request,
142    next: Next,
143) -> Result<Response, StatusCode> {
144    if limiter.try_acquire() {
145        Ok(next.run(request).await)
146    } else {
147        Err(StatusCode::TOO_MANY_REQUESTS)
148    }
149}
150
151const DEFAULT_RATE_LIMIT: u64 = 1000;
152
153/// Create a rate limiter with the default capacity of 1000 requests per second.
154#[must_use]
155pub fn default_rate_limiter() -> Arc<RateLimiterState> {
156    Arc::new(RateLimiterState::new(DEFAULT_RATE_LIMIT))
157}
158
159// ── Security Middlewares ──────────────────────────────────────────────────
160
161/// Axum middleware that blocks DNS rebinding attacks.
162///
163/// Rejects any request where the Host header is not a localhost address.
164///
165/// # Errors
166///
167/// Returns [`StatusCode::FORBIDDEN`] if the `Host` header is not `localhost`, `127.0.0.1`, or `::1`.
168pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
169    let host = request
170        .headers()
171        .get("host")
172        .and_then(|v| v.to_str().ok())
173        .unwrap_or("");
174    let host_name = if host.starts_with('[') {
175        host.split(']').next().map_or(host, |s| &s[1..])
176    } else {
177        host.split(':').next().unwrap_or(host)
178    };
179    let is_allowed = matches!(host_name, "localhost" | "127.0.0.1" | "::1" | "");
180    if !is_allowed {
181        tracing::warn!("DNS rebinding attempt blocked: Host={host}");
182        return Err(StatusCode::FORBIDDEN);
183    }
184    Ok(next.run(request).await)
185}
186
187/// Axum middleware that blocks cross-origin requests from browsers.
188///
189/// # Errors
190///
191/// Returns [`StatusCode::FORBIDDEN`] if the `Origin` header is present and does not match a
192/// localhost or `tauri://` origin.
193pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
194    if let Some(origin) = request
195        .headers()
196        .get("origin")
197        .and_then(|v| v.to_str().ok())
198    {
199        let allowed = origin.starts_with("http://localhost")
200            || origin.starts_with("https://localhost")
201            || origin.starts_with("http://127.0.0.1")
202            || origin.starts_with("https://127.0.0.1")
203            || origin.starts_with("http://[::1]")
204            || origin.starts_with("https://[::1]")
205            || origin.starts_with("tauri://");
206        if !allowed {
207            tracing::warn!("Cross-origin request blocked: Origin={origin}");
208            return Err(StatusCode::FORBIDDEN);
209        }
210    }
211    Ok(next.run(request).await)
212}
213
214/// Axum middleware that sets security-hardening response headers on every response.
215pub async fn security_headers(request: Request, next: Next) -> Response {
216    let mut response = next.run(request).await;
217    let headers = response.headers_mut();
218    headers.insert(
219        axum::http::header::X_CONTENT_TYPE_OPTIONS,
220        axum::http::HeaderValue::from_static("nosniff"),
221    );
222    headers.insert(
223        axum::http::header::CACHE_CONTROL,
224        axum::http::HeaderValue::from_static("no-store"),
225    );
226    headers.insert(
227        axum::http::header::HeaderName::from_static("x-frame-options"),
228        axum::http::HeaderValue::from_static("DENY"),
229    );
230    response
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::middleware;
    use axum::routing::get;
    use tower::ServiceExt; // for oneshot

    /// Trivial handler returning `"ok"`, used as the inner service behind each middleware.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    #[test]
    fn token_generation_is_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert_eq!(t1.len(), 36); // UUID v4 format
    }

    #[test]
    fn token_is_valid_uuid() {
        let token = generate_token();
        assert!(uuid::Uuid::parse_str(&token).is_ok());
    }

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiterState::new(5);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiterState::new(42);
        assert_eq!(limiter.tokens.load(Ordering::Relaxed), 42);
        assert_eq!(limiter.max_tokens, 42);
    }

    #[test]
    fn rate_limiter_concurrent_acquire() {
        // Large bucket shared by 10 threads making 2000 attempts in total. The refill rate equals
        // the capacity (1000 tokens/s, i.e. about one token per millisecond), so refills only
        // stay negligible if the threads finish within a few milliseconds.
        let limiter = Arc::new(RateLimiterState::new(1000));
        let mut handles = vec![];
        for _ in 0..10 {
            let l = limiter.clone();
            handles.push(std::thread::spawn(move || {
                let mut acquired = 0;
                for _ in 0..200 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        // All 1000 initial tokens must be dispensed: an attempt only fails once the counter is 0.
        // Refills during the run can add roughly one token per elapsed millisecond.
        // NOTE(review): the upper bound of 1010 assumes the threads finish within ~10 ms and may
        // be flaky on a slow or heavily loaded CI machine.
        assert!((1000..=1010).contains(&total));
    }

    #[test]
    fn default_rate_limiter_has_expected_tokens() {
        let limiter = default_rate_limiter();
        assert_eq!(limiter.max_tokens, 1000);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiterState::new(0);
        assert!(!limiter.try_acquire());
    }

    // ── DNS Rebinding Guard tests ─────────────────────────────────────────

    /// Router exposing `/test` behind [`dns_rebinding_guard`] only.
    fn dns_rebinding_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(dns_rebinding_guard))
    }

    /// Build a `GET /test` request, setting the `Host` header only when `host` is `Some`.
    fn dns_request(host: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn dns_rebinding_allows_localhost() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("localhost"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_127_0_0_1() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("127.0.0.1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed_with_port() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]:7373"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bare() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("::1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_empty_host() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_evil_com() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("evil.com"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_localhost_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("localhost.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_ip_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("127.0.0.1.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // ── Origin Guard tests ────────────────────────────────────────────────

    /// Router exposing `/test` behind [`origin_guard`] only.
    fn origin_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(origin_guard))
    }

    /// Build a `GET /test` request, setting the `Origin` header only when `origin` is `Some`.
    fn origin_request(origin: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(o) = origin {
            builder = builder.header("origin", o);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn origin_allows_no_origin() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_localhost_http() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://localhost:3000")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_127_0_0_1_https() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("https://127.0.0.1:8080")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_tauri_scheme() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("tauri://localhost")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_blocks_null() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(Some("null"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn origin_blocks_evil_com() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // ── Security Headers tests ────────────────────────────────────────────

    /// Router exposing `/test` behind [`security_headers`] only.
    fn security_headers_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(security_headers))
    }

    #[tokio::test]
    async fn security_headers_x_content_type_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
    }

    #[tokio::test]
    async fn security_headers_cache_control() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("cache-control").unwrap(), "no-store");
    }

    #[tokio::test]
    async fn security_headers_x_frame_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");
    }

    // ── Auth middleware integration tests ─────────────────────────────────

    /// Router exposing `/test` behind [`require_auth`], configured with `token`
    /// (`None` disables authentication).
    fn auth_router(token: Option<&str>) -> Router {
        let state = Arc::new(AuthState {
            token: token.map(String::from),
        });
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth))
    }

    /// Build a `GET /test` request carrying `Authorization: Bearer <token>` when `token` is `Some`.
    fn auth_request(token: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(t) = token {
            builder = builder.header("authorization", format!("Bearer {t}"));
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn auth_allows_correct_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(Some("secret-123"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_wrong_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app
            .oneshot(auth_request(Some("wrong-token")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_rejects_missing_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_allows_any_when_disabled() {
        let app = auth_router(None);
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_case_insensitive_bearer_prefix() {
        let state = Arc::new(AuthState {
            token: Some("my-token".into()),
        });
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth));

        let req = Request::builder()
            .uri("/test")
            .header("authorization", "BEARER my-token")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_non_bearer_scheme() {
        let app = auth_router(Some("secret"));
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Basic c2VjcmV0")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    // ── Rate limiter middleware integration test ──────────────────────────

    #[tokio::test]
    async fn rate_limiter_returns_429_when_exhausted() {
        let limiter = Arc::new(RateLimiterState::new(2));
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(limiter, rate_limit));

        // `oneshot` consumes the router, so each request uses its own clone. The clones share
        // the same `Arc<RateLimiterState>`, so they draw from one bucket.
        let app2 = app.clone();
        let app3 = app2.clone();

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(app.oneshot(req).await.unwrap().status(), StatusCode::OK);

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(app2.oneshot(req).await.unwrap().status(), StatusCode::OK);

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(
            app3.oneshot(req).await.unwrap().status(),
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    // ── Combined security layer test ─────────────────────────────────────

    #[tokio::test]
    async fn combined_layers_enforce_all_guards() {
        let auth_state = Arc::new(AuthState {
            token: Some("tok-123".into()),
        });
        let limiter = Arc::new(RateLimiterState::new(100));

        // Layers added later wrap the earlier ones, so requests pass through the DNS rebinding
        // guard first and the auth check last.
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(auth_state, require_auth))
            .layer(middleware::from_fn_with_state(limiter, rate_limit))
            .layer(middleware::from_fn(security_headers))
            .layer(middleware::from_fn(origin_guard))
            .layer(middleware::from_fn(dns_rebinding_guard));

        // Good request: all guards pass
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "127.0.0.1:7373")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");

        // Bad host: DNS rebinding guard blocks
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // Bad origin: origin guard blocks
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "localhost")
            .header("origin", "https://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // Missing auth: auth middleware blocks
        let req = Request::builder()
            .uri("/test")
            .header("host", "localhost")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}