Skip to main content

victauri_plugin/
auth.rs

1pub use victauri_core::middleware::{
2    AuthState, default_rate_limiter, dns_rebinding_guard, origin_guard, rate_limit, require_auth,
3    security_headers,
4};
5pub use victauri_core::security::{
6    self, RateLimiter as RateLimiterState, constant_time_eq, generate_token, is_allowed_origin,
7    is_localhost_host,
8};
9
#[cfg(test)]
mod tests {
    //! Tests for the security primitives re-exported above from `victauri_core`:
    //! token generation, the rate limiter, constant-time comparison, origin checks,
    //! and the axum middleware (DNS-rebinding guard, origin guard, security headers,
    //! bearer auth, rate limiting), both individually and stacked together.

    use std::sync::Arc;

    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::extract::Request;
    use axum::http::{HeaderValue, StatusCode};
    use axum::middleware;
    use axum::response::Response;
    use axum::routing::get;
    use tower::ServiceExt;

    async fn ok_handler() -> &'static str {
        "ok"
    }

    // ── Shared request/response helpers ──────────────────────────────────

    /// Builds a bodiless request to `/test` carrying `headers` in the given order.
    fn test_request(headers: &[(&str, &str)]) -> Request<Body> {
        headers
            .iter()
            .fold(Request::builder().uri("/test"), |builder, &(name, value)| {
                builder.header(name, value)
            })
            .body(Body::empty())
            .expect("test request headers are valid")
    }

    /// Builds a `/test` request that carries header `name` only when `value` is `Some`.
    fn request_with_optional_header(name: &str, value: Option<&str>) -> Request<Body> {
        match value {
            Some(v) => test_request(&[(name, v)]),
            None => test_request(&[]),
        }
    }

    /// Drives a single request through `app` and returns the response.
    async fn send(app: Router, req: Request<Body>) -> Response {
        // `Router`'s service error type is `Infallible`, so this cannot fail.
        app.oneshot(req).await.expect("axum Router never errors")
    }

    /// Returns header `name` from `resp`, panicking with the header name if it is absent.
    fn header<'a>(resp: &'a Response, name: &str) -> &'a HeaderValue {
        resp.headers()
            .get(name)
            .unwrap_or_else(|| panic!("response is missing the `{name}` header"))
    }

    // ── Token generation ─────────────────────────────────────────────────

    #[test]
    fn token_generation_is_unique() {
        let first = generate_token();
        let second = generate_token();
        assert_ne!(first, second);
        assert_eq!(first.len(), 36); // UUID v4 format
    }

    #[test]
    fn token_is_valid_uuid() {
        let token = generate_token();
        assert!(uuid::Uuid::parse_str(&token).is_ok());
    }

    // ── Rate limiter unit tests ──────────────────────────────────────────

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiterState::new(5);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiterState::new(42);
        assert_eq!(limiter.current_tokens(), 42);
        assert_eq!(limiter.max_tokens(), 42);
    }

    #[test]
    fn rate_limiter_concurrent_acquire() {
        let limiter = Arc::new(RateLimiterState::new(1000));
        // 10 threads x 200 attempts = 2000 attempts against a budget of 1000.
        // Collect the handles first so every thread is spawned before any join.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let limiter = Arc::clone(&limiter);
                std::thread::spawn(move || (0..200).filter(|_| limiter.try_acquire()).count())
            })
            .collect();
        let total: usize = handles
            .into_iter()
            .map(|handle| handle.join().expect("worker thread panicked"))
            .sum();
        assert!(
            total >= 1000,
            "should dispense at least the initial budget, got {total}"
        );
        // Leave some slack for tokens refilled while the threads are running.
        assert!(total <= 1200, "refill overshoot too high, got {total}");
    }

    #[test]
    fn default_rate_limiter_has_expected_tokens() {
        let limiter = default_rate_limiter();
        assert_eq!(limiter.max_tokens(), 1000);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiterState::new(0);
        assert!(!limiter.try_acquire());
    }

    // ── DNS Rebinding Guard tests ─────────────────────────────────────────

    fn dns_rebinding_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(dns_rebinding_guard))
    }

    /// Sends a request with an optional `Host` header through the DNS-rebinding guard.
    async fn dns_status(host: Option<&str>) -> StatusCode {
        let req = request_with_optional_header("host", host);
        send(dns_rebinding_router(), req).await.status()
    }

    #[tokio::test]
    async fn dns_rebinding_allows_localhost() {
        assert_eq!(dns_status(Some("localhost")).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_127_0_0_1() {
        assert_eq!(dns_status(Some("127.0.0.1")).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed() {
        assert_eq!(dns_status(Some("[::1]")).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed_with_port() {
        assert_eq!(dns_status(Some("[::1]:7373")).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bare() {
        assert_eq!(dns_status(Some("::1")).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_empty_host() {
        assert_eq!(dns_status(None).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_evil_com() {
        assert_eq!(dns_status(Some("evil.com")).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_localhost_subdomain() {
        assert_eq!(
            dns_status(Some("localhost.evil.com")).await,
            StatusCode::FORBIDDEN
        );
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_ip_subdomain() {
        assert_eq!(
            dns_status(Some("127.0.0.1.evil.com")).await,
            StatusCode::FORBIDDEN
        );
    }

    // ── Origin Guard tests ────────────────────────────────────────────────

    fn origin_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(origin_guard))
    }

    /// Sends a request with an optional `Origin` header through the origin guard.
    async fn origin_status(origin: Option<&str>) -> StatusCode {
        let req = request_with_optional_header("origin", origin);
        send(origin_router(), req).await.status()
    }

    #[tokio::test]
    async fn origin_allows_no_origin() {
        assert_eq!(origin_status(None).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_localhost_http() {
        assert_eq!(
            origin_status(Some("http://localhost:3000")).await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn origin_allows_127_0_0_1_https() {
        assert_eq!(
            origin_status(Some("https://127.0.0.1:8080")).await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn origin_allows_tauri_scheme() {
        assert_eq!(
            origin_status(Some("tauri://localhost")).await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn origin_blocks_null() {
        assert_eq!(origin_status(Some("null")).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn origin_blocks_evil_com() {
        assert_eq!(
            origin_status(Some("http://evil.com")).await,
            StatusCode::FORBIDDEN
        );
    }

    // ── Security Headers tests ────────────────────────────────────────────

    fn security_headers_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(security_headers))
    }

    /// Fetches `/test` through the security-headers middleware and checks it succeeded.
    async fn security_headers_response() -> Response {
        let resp = send(security_headers_router(), test_request(&[])).await;
        assert_eq!(resp.status(), StatusCode::OK);
        resp
    }

    #[tokio::test]
    async fn security_headers_x_content_type_options() {
        let resp = security_headers_response().await;
        assert_eq!(header(&resp, "x-content-type-options"), "nosniff");
    }

    #[tokio::test]
    async fn security_headers_cache_control() {
        let resp = security_headers_response().await;
        assert_eq!(header(&resp, "cache-control"), "no-store");
    }

    #[tokio::test]
    async fn security_headers_x_frame_options() {
        let resp = security_headers_response().await;
        assert_eq!(header(&resp, "x-frame-options"), "DENY");
    }

    // ── Auth middleware integration tests ─────────────────────────────────

    /// Router guarded by `require_auth`; `None` means auth is not configured.
    fn auth_router(token: Option<&str>) -> Router {
        let state = Arc::new(AuthState {
            token: token.map(String::from),
        });
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth))
    }

    /// Request presenting `Authorization: Bearer <token>` when `token` is `Some`.
    fn auth_request(token: Option<&str>) -> Request<Body> {
        let value = token.map(|t| format!("Bearer {t}"));
        request_with_optional_header("authorization", value.as_deref())
    }

    #[tokio::test]
    async fn auth_allows_correct_token() {
        let resp = send(auth_router(Some("secret-123")), auth_request(Some("secret-123"))).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_wrong_token() {
        let resp = send(auth_router(Some("secret-123")), auth_request(Some("wrong-token"))).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_rejects_missing_token() {
        let resp = send(auth_router(Some("secret-123")), auth_request(None)).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_allows_any_when_disabled() {
        let resp = send(auth_router(None), auth_request(None)).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_case_insensitive_bearer_prefix() {
        let req = test_request(&[("authorization", "BEARER my-token")]);
        let resp = send(auth_router(Some("my-token")), req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_non_bearer_scheme() {
        let req = test_request(&[("authorization", "Basic c2VjcmV0")]);
        let resp = send(auth_router(Some("secret")), req).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    // ── Rate limiter middleware integration test ──────────────────────────

    #[tokio::test]
    async fn rate_limiter_returns_429_when_exhausted() {
        let limiter = Arc::new(RateLimiterState::new(2));
        // Clones of the router share the same `Arc`'d limiter state.
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(limiter, rate_limit));

        // The first two requests spend the whole budget...
        for _ in 0..2 {
            let resp = send(app.clone(), test_request(&[])).await;
            assert_eq!(resp.status(), StatusCode::OK);
        }
        // ...so the third is throttled.
        let resp = send(app, test_request(&[])).await;
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // ── Combined security layer test ─────────────────────────────────────

    #[tokio::test]
    async fn combined_layers_enforce_all_guards() {
        let auth_state = Arc::new(AuthState {
            token: Some("tok-123".into()),
        });
        let limiter = Arc::new(RateLimiterState::new(100));

        // Each `.layer` wraps everything added before it, so the DNS-rebinding guard
        // is outermost and `require_auth` sits closest to the handler.
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(auth_state, require_auth))
            .layer(middleware::from_fn_with_state(limiter, rate_limit))
            .layer(middleware::from_fn(security_headers))
            .layer(middleware::from_fn(origin_guard))
            .layer(middleware::from_fn(dns_rebinding_guard));

        // Good request: all guards pass
        let req = test_request(&[
            ("authorization", "Bearer tok-123"),
            ("host", "127.0.0.1:7373"),
        ]);
        let resp = send(app.clone(), req).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(header(&resp, "x-frame-options"), "DENY");

        // Bad host: DNS rebinding guard blocks
        let req = test_request(&[("authorization", "Bearer tok-123"), ("host", "evil.com")]);
        let resp = send(app.clone(), req).await;
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // Bad origin: origin guard blocks
        let req = test_request(&[
            ("authorization", "Bearer tok-123"),
            ("host", "localhost"),
            ("origin", "https://evil.com"),
        ]);
        let resp = send(app.clone(), req).await;
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // Missing auth: auth middleware blocks
        let req = test_request(&[("host", "localhost")]);
        let resp = send(app, req).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    // ── Origin predicate tests ───────────────────────────────────────────

    #[test]
    fn origin_guard_allows_localhost_variants() {
        assert!(is_allowed_origin("http://localhost"));
        assert!(is_allowed_origin("http://localhost:7373"));
        assert!(is_allowed_origin("https://localhost"));
        assert!(is_allowed_origin("https://localhost:443"));
        assert!(is_allowed_origin("http://127.0.0.1"));
        assert!(is_allowed_origin("http://127.0.0.1:8080"));
        assert!(is_allowed_origin("https://127.0.0.1"));
        assert!(is_allowed_origin("http://[::1]"));
        assert!(is_allowed_origin("http://[::1]:7373"));
        assert!(is_allowed_origin("tauri://localhost"));
        assert!(is_allowed_origin("tauri://some-app"));
    }

    #[test]
    fn origin_guard_rejects_prefix_smuggling() {
        assert!(!is_allowed_origin("http://localhost.evil.com"));
        assert!(!is_allowed_origin("https://localhost.evil.com"));
        assert!(!is_allowed_origin("https://127.0.0.1.evil.com"));
        assert!(!is_allowed_origin("http://[::1].evil.com"));
    }

    #[test]
    fn origin_guard_rejects_userinfo_trick() {
        assert!(!is_allowed_origin("http://localhost@evil.com"));
        assert!(!is_allowed_origin("http://127.0.0.1@evil.com"));
    }

    #[test]
    fn origin_guard_rejects_foreign_and_malformed() {
        assert!(!is_allowed_origin("http://evil.com"));
        assert!(!is_allowed_origin("https://attacker.io"));
        assert!(!is_allowed_origin("not-a-url"));
        assert!(!is_allowed_origin(""));
        assert!(!is_allowed_origin("ftp://localhost"));
    }

    // ── Constant-time comparison tests ───────────────────────────────────

    #[test]
    fn constant_time_eq_equal_strings() {
        assert!(constant_time_eq(b"secret-token-123", b"secret-token-123"));
    }

    #[test]
    fn constant_time_eq_different_strings() {
        assert!(!constant_time_eq(b"secret-token-123", b"wrong-token-9999"));
    }

    #[test]
    fn constant_time_eq_different_lengths() {
        assert!(!constant_time_eq(b"short", b"longer-string"));
    }

    #[test]
    fn constant_time_eq_empty_strings() {
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn constant_time_eq_one_empty() {
        assert!(!constant_time_eq(b"", b"notempty"));
        assert!(!constant_time_eq(b"notempty", b""));
    }

    #[test]
    fn constant_time_eq_single_bit_difference() {
        assert!(!constant_time_eq(b"A", b"B"));
    }

    // ── Security headers: CORS + CSP tests ───────────────────────────────

    #[tokio::test]
    async fn security_headers_cors_deny() {
        let resp = security_headers_response().await;
        assert_eq!(header(&resp, "access-control-allow-origin"), "null");
    }

    #[tokio::test]
    async fn security_headers_csp() {
        let resp = security_headers_response().await;
        assert_eq!(
            header(&resp, "content-security-policy"),
            "default-src 'none'"
        );
    }
}
575}