Skip to main content

victauri_plugin/
auth.rs

1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7
8/// Generate a random UUID v4 token suitable for Bearer authentication.
9pub fn generate_token() -> String {
10    uuid::Uuid::new_v4().to_string()
11}
12
/// Authentication settings shared by the MCP server's request pipeline.
///
/// Holds the single Bearer token that incoming requests are checked against.
#[derive(Clone)]
pub struct AuthState {
    /// Token clients must present as `Authorization: Bearer <token>`;
    /// `None` switches authentication off entirely.
    pub token: Option<String>,
}
19
20/// Axum middleware that validates the `Authorization: Bearer <token>` header against [`AuthState`].
21pub async fn require_auth(
22    axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
23    request: Request,
24    next: Next,
25) -> Result<Response, StatusCode> {
26    let expected = match &auth.token {
27        Some(t) => t,
28        None => return Ok(next.run(request).await),
29    };
30
31    let provided = request
32        .headers()
33        .get("authorization")
34        .and_then(|v| v.to_str().ok())
35        .and_then(|v| {
36            let lower = v.to_lowercase();
37            if lower.starts_with("bearer ") {
38                Some(v[7..].to_string())
39            } else {
40                None
41            }
42        });
43
44    match provided {
45        Some(ref token) if token == expected => Ok(next.run(request).await),
46        _ => Err(StatusCode::UNAUTHORIZED),
47    }
48}
49
50// ── Rate Limiter ───────────────────────────────────────────────────────────
51
/// Lock-free token-bucket rate limiter.
///
/// The bucket starts full with `max_requests_per_sec` tokens and refills at that same
/// rate. Time comes from a monotonic clock ([`std::time::Instant`]) at microsecond
/// resolution. Wall-clock adjustments therefore cannot stall the refill. Fractional
/// tokens are carried over between refills rather than discarded.
pub struct RateLimiterState {
    /// Whole tokens currently available.
    tokens: AtomicU64,
    /// Bucket capacity.
    max_tokens: u64,
    /// Monotonic reference point; refill timestamps are microseconds since this instant.
    epoch: std::time::Instant,
    /// Timestamp (µs since `epoch`) up to which elapsed time has been turned into tokens.
    last_refill_us: AtomicU64,
    /// Tokens credited per second of elapsed time.
    refill_rate_per_sec: u64,
}

impl RateLimiterState {
    /// Microseconds per second, as `u128` so refill arithmetic cannot overflow.
    const MICROS_PER_SEC: u128 = 1_000_000;

    /// Create a rate limiter with the given maximum requests per second.
    ///
    /// The bucket starts full. A value of `0` yields a limiter that rejects every request.
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            epoch: std::time::Instant::now(),
            last_refill_us: AtomicU64::new(0),
            refill_rate_per_sec: max_requests_per_sec,
        }
    }

    /// Atomically consume one token, returning `true` if the request is allowed.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` yields `None` on an empty bucket, which aborts the update.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1))
            .is_ok()
    }

    /// Microseconds elapsed since `epoch`, saturating at `u64::MAX`.
    fn now_us(&self) -> u64 {
        let micros = self.epoch.elapsed().as_micros();
        micros.min(u128::from(u64::MAX)) as u64
    }

    /// Credit the bucket with the whole tokens earned since the last refill.
    ///
    /// Only the time accounted for by those whole tokens is consumed from the clock.
    /// The fractional remainder is kept for the next call.
    fn refill(&self) {
        if self.refill_rate_per_sec == 0 {
            return;
        }
        let rate = u128::from(self.refill_rate_per_sec);
        let now = self.now_us();
        let last = self.last_refill_us.load(Ordering::Relaxed);
        let elapsed_us = now.saturating_sub(last);
        // Both factors fit in u64, so the u128 product cannot overflow.
        let earned = u128::from(elapsed_us) * rate / Self::MICROS_PER_SEC;
        if earned == 0 {
            return;
        }
        // Time those whole tokens cost, rounded up so the clock always advances.
        // Since earned <= elapsed_us * rate / 1e6, this never exceeds `elapsed_us`.
        let spent_us = (earned * Self::MICROS_PER_SEC + rate - 1) / rate;
        let new_last = last + spent_us.min(u128::from(elapsed_us)) as u64;
        // Only the thread that advances the clock credits tokens, so the same
        // elapsed time is never converted twice.
        if self
            .last_refill_us
            .compare_exchange(last, new_last, Ordering::Relaxed, Ordering::Relaxed)
            .is_err()
        {
            return;
        }
        let add = earned.min(u128::from(u64::MAX)) as u64;
        let cap = self.max_tokens;
        // The closure never returns `None`, so this update always succeeds.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| {
                Some(n.saturating_add(add).min(cap))
            });
    }
}
126
127/// Axum middleware that rejects requests with 429 when the token bucket is exhausted.
128pub async fn rate_limit(
129    axum::extract::State(limiter): axum::extract::State<Arc<RateLimiterState>>,
130    request: Request,
131    next: Next,
132) -> Result<Response, StatusCode> {
133    if limiter.try_acquire() {
134        Ok(next.run(request).await)
135    } else {
136        Err(StatusCode::TOO_MANY_REQUESTS)
137    }
138}
139
140const DEFAULT_RATE_LIMIT: u64 = 100;
141
142/// Create a rate limiter with the default capacity of 100 requests per second.
143pub fn default_rate_limiter() -> Arc<RateLimiterState> {
144    Arc::new(RateLimiterState::new(DEFAULT_RATE_LIMIT))
145}
146
147// ── Security Middlewares ──────────────────────────────────────────────────
148
149/// Axum middleware that blocks DNS rebinding attacks.
150///
151/// Rejects any request where the Host header is not a localhost address.
152pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
153    let host = request
154        .headers()
155        .get("host")
156        .and_then(|v| v.to_str().ok())
157        .unwrap_or("");
158    let host_name = if host.starts_with('[') {
159        host.split(']').next().map(|s| &s[1..]).unwrap_or(host)
160    } else {
161        host.split(':').next().unwrap_or(host)
162    };
163    let is_allowed = matches!(host_name, "localhost" | "127.0.0.1" | "::1" | "");
164    if !is_allowed {
165        tracing::warn!("DNS rebinding attempt blocked: Host={host}");
166        return Err(StatusCode::FORBIDDEN);
167    }
168    Ok(next.run(request).await)
169}
170
171/// Axum middleware that blocks cross-origin requests from browsers.
172pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
173    if let Some(origin) = request
174        .headers()
175        .get("origin")
176        .and_then(|v| v.to_str().ok())
177    {
178        let allowed = origin.starts_with("http://localhost")
179            || origin.starts_with("https://localhost")
180            || origin.starts_with("http://127.0.0.1")
181            || origin.starts_with("https://127.0.0.1")
182            || origin.starts_with("http://[::1]")
183            || origin.starts_with("https://[::1]")
184            || origin.starts_with("tauri://")
185            || origin == "null";
186        if !allowed {
187            tracing::warn!("Cross-origin request blocked: Origin={origin}");
188            return Err(StatusCode::FORBIDDEN);
189        }
190    }
191    Ok(next.run(request).await)
192}
193
194/// Axum middleware that sets security-hardening response headers on every response.
195pub async fn security_headers(request: Request, next: Next) -> Response {
196    let mut response = next.run(request).await;
197    let headers = response.headers_mut();
198    headers.insert(
199        axum::http::header::X_CONTENT_TYPE_OPTIONS,
200        axum::http::HeaderValue::from_static("nosniff"),
201    );
202    headers.insert(
203        axum::http::header::CACHE_CONTROL,
204        axum::http::HeaderValue::from_static("no-store"),
205    );
206    headers.insert(
207        axum::http::header::HeaderName::from_static("x-frame-options"),
208        axum::http::HeaderValue::from_static("DENY"),
209    );
210    response
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::middleware;
    use axum::routing::get;
    use tower::ServiceExt; // for oneshot

    async fn ok_handler() -> &'static str {
        "ok"
    }

    #[test]
    fn token_generation_is_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert_eq!(t1.len(), 36); // UUID v4 format
    }

    #[test]
    fn token_is_valid_uuid() {
        let token = generate_token();
        assert!(uuid::Uuid::parse_str(&token).is_ok());
    }

    // ── Bearer Auth tests ─────────────────────────────────────────────────

    /// Router guarded by `require_auth`; `token` is the expected token (`None` disables auth).
    fn auth_router(token: Option<&str>) -> Router {
        let state = Arc::new(AuthState {
            token: token.map(str::to_owned),
        });
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth))
    }

    fn auth_request(authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(value) = authorization {
            builder = builder.header("authorization", value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Status returned for one request with the given `Authorization` header.
    async fn auth_status(token: Option<&str>, authorization: Option<&str>) -> StatusCode {
        auth_router(token)
            .oneshot(auth_request(authorization))
            .await
            .unwrap()
            .status()
    }

    #[tokio::test]
    async fn auth_disabled_allows_requests_without_header() {
        assert_eq!(auth_status(None, None).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_missing_header() {
        assert_eq!(
            auth_status(Some("secret"), None).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn auth_accepts_matching_token() {
        assert_eq!(
            auth_status(Some("secret"), Some("Bearer secret")).await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn auth_scheme_is_case_insensitive() {
        assert_eq!(
            auth_status(Some("secret"), Some("bEaReR secret")).await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn auth_rejects_wrong_token() {
        assert_eq!(
            auth_status(Some("secret"), Some("Bearer secreT")).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn auth_rejects_token_prefix() {
        assert_eq!(
            auth_status(Some("secret"), Some("Bearer secre")).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn auth_rejects_other_schemes() {
        assert_eq!(
            auth_status(Some("secret"), Some("Basic secret")).await,
            StatusCode::UNAUTHORIZED
        );
    }

    // ── Rate Limiter tests ────────────────────────────────────────────────

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiterState::new(5);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiterState::new(42);
        assert_eq!(limiter.tokens.load(Ordering::Relaxed), 42);
        assert_eq!(limiter.max_tokens, 42);
    }

    #[test]
    fn rate_limiter_concurrent_acquire() {
        // The bucket refills at 1000 tokens/s (one per millisecond). The number of extra
        // tokens that can appear is therefore bounded by how long the test runs, not by
        // a fixed slack that a slow CI machine could exceed.
        let started = std::time::Instant::now();
        let limiter = Arc::new(RateLimiterState::new(1000));
        let mut handles = vec![];
        for _ in 0..10 {
            let l = limiter.clone();
            handles.push(std::thread::spawn(move || {
                let mut acquired = 0;
                for _ in 0..200 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        // +2 absorbs millisecond rounding at the start and end of the measurement.
        let refill_allowance = started.elapsed().as_millis() as u64 + 2;
        // 2000 attempts exceed the 1000 initial tokens, so every one must be dispensed.
        assert!(total >= 1000, "only {total} tokens acquired");
        assert!(
            total <= 1000 + refill_allowance,
            "{total} tokens acquired; at most {refill_allowance} could have been refilled"
        );
    }

    #[test]
    fn default_rate_limiter_has_100_tokens() {
        let limiter = default_rate_limiter();
        assert_eq!(limiter.max_tokens, 100);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiterState::new(0);
        assert!(!limiter.try_acquire());
    }

    #[tokio::test]
    async fn rate_limit_middleware_rejects_when_exhausted() {
        // One token, refilled at one per second: the second request arrives before any refill.
        let limiter = Arc::new(RateLimiterState::new(1));
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(limiter, rate_limit));
        let request = || Request::builder().uri("/test").body(Body::empty()).unwrap();

        let first = app.clone().oneshot(request()).await.unwrap();
        assert_eq!(first.status(), StatusCode::OK);
        let second = app.oneshot(request()).await.unwrap();
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // ── DNS Rebinding Guard tests ─────────────────────────────────────────

    fn dns_rebinding_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(dns_rebinding_guard))
    }

    fn dns_request(host: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn dns_rebinding_allows_localhost() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("localhost"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_127_0_0_1() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("127.0.0.1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed_with_port() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]:7373"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bare() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("::1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_empty_host() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_evil_com() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("evil.com"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_localhost_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("localhost.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_ip_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("127.0.0.1.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // ── Origin Guard tests ────────────────────────────────────────────────

    fn origin_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(origin_guard))
    }

    fn origin_request(origin: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(o) = origin {
            builder = builder.header("origin", o);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn origin_allows_no_origin() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_localhost_http() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://localhost:3000")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_127_0_0_1_https() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("https://127.0.0.1:8080")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_tauri_scheme() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("tauri://localhost")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_null() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(Some("null"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_blocks_evil_com() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // ── Security Headers tests ────────────────────────────────────────────

    fn security_headers_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(security_headers))
    }

    #[tokio::test]
    async fn security_headers_x_content_type_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
    }

    #[tokio::test]
    async fn security_headers_cache_control() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("cache-control").unwrap(), "no-store");
    }

    #[tokio::test]
    async fn security_headers_x_frame_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");
    }
}