1pub use victauri_core::middleware::{
2 AuthState, default_rate_limiter, dns_rebinding_guard, origin_guard, rate_limit, require_auth,
3 security_headers,
4};
5pub use victauri_core::security::{
6 self, RateLimiter as RateLimiterState, constant_time_eq, generate_token, is_allowed_origin,
7 is_localhost_host,
8};
9
/// Unit tests for the security primitives re-exported by this module from `victauri_core`.
#[cfg(test)]
mod tests {
    use super::*;

    // ── Tokens and basic primitives ─────────────────────────────────────

    /// Two freshly generated tokens differ and have the 36-character hyphenated length.
    #[test]
    fn token_generation() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert_eq!(t1.len(), 36);
    }

    /// A limiter created with budget 10 grants exactly 10 acquisitions, then refuses.
    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    /// Equal inputs match; differing contents or lengths (in either argument order) do not.
    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"hello", b"hell"));
        assert!(!constant_time_eq(b"hell", b"hello"));
    }

    /// Two empty inputs are equal; an empty input never equals a non-empty one.
    #[test]
    fn constant_time_eq_empty_strings() {
        assert!(constant_time_eq(b"", b""));
        assert!(!constant_time_eq(b"", b"x"));
        assert!(!constant_time_eq(b"x", b""));
    }

    /// Inputs differing in a single bit are rejected.
    #[test]
    fn constant_time_eq_single_bit_diff() {
        assert!(!constant_time_eq(b"\x00", b"\x01"));
        assert!(!constant_time_eq(b"\xff", b"\xfe"));
    }

    /// A budget of one grants one acquisition and then refuses.
    #[test]
    fn rate_limiter_single_token() {
        let limiter = RateLimiterState::new(1);
        assert!(limiter.try_acquire());
        assert!(!limiter.try_acquire());
    }

    /// The token uses the canonical hyphenated UUID layout `8-4-4-4-12`:
    /// dashes at byte offsets 8, 13, 18 and 23, ASCII hex digits everywhere else.
    #[test]
    fn token_format_is_uuid() {
        let token = generate_token();
        assert_eq!(token.len(), 36);
        assert_eq!(token.chars().filter(|c| *c == '-').count(), 4);
        for (i, c) in token.char_indices() {
            if matches!(i, 8 | 13 | 18 | 23) {
                assert_eq!(c, '-', "expected '-' at offset {i} in {token}");
            } else {
                assert!(
                    c.is_ascii_hexdigit(),
                    "non-hex char {c:?} at offset {i} in {token}"
                );
            }
        }
    }

    /// The limiter returned by `default_rate_limiter` can grant at least one request.
    #[test]
    fn default_rate_limiter_has_budget() {
        let limiter = default_rate_limiter();
        assert!(limiter.try_acquire());
    }

    // ── Rate limiter and comparison edge cases ──────────────────────────

    /// Exactly 100 acquisitions succeed; repeated attempts past the budget keep failing.
    #[test]
    fn rate_limiter_exact_boundary() {
        let limiter = RateLimiterState::new(100);
        for i in 0..100 {
            assert!(limiter.try_acquire(), "failed at iteration {i}");
        }
        for attempt in 0..3 {
            assert!(
                !limiter.try_acquire(),
                "over-budget acquire #{attempt} unexpectedly succeeded"
            );
        }
    }

    /// Ten threads race for a 1000-token budget with 200 attempts each (2000 in total).
    /// The limiter must hand out at least its initial budget. The 1200 upper bound
    /// leaves slack for tokens refilled while the threads run.
    #[test]
    fn rate_limiter_concurrent_contention() {
        const THREADS: usize = 10;
        const ATTEMPTS_PER_THREAD: usize = 200;

        let limiter = RateLimiterState::new(1000);
        // Scoped threads may borrow stack data, so the limiter is shared by reference
        // instead of being wrapped in an `Arc`.
        let limiter = &limiter;

        let total: usize = std::thread::scope(|s| {
            let mut handles = Vec::with_capacity(THREADS);
            for _ in 0..THREADS {
                handles.push(s.spawn(move || {
                    (0..ATTEMPTS_PER_THREAD)
                        .filter(|_| limiter.try_acquire())
                        .count()
                }));
            }
            handles
                .into_iter()
                .map(|h| h.join().expect("rate limiter worker thread panicked"))
                .sum()
        });

        assert!(
            total >= 1000,
            "should dispense at least the initial budget, got {total}"
        );
        assert!(total <= 1200, "refill overshoot too high, got {total}");
    }

    /// Long inputs compare correctly for the equal case, the length-mismatch case,
    /// and the equal-length case that differs only in the final byte.
    #[test]
    fn constant_time_eq_long_strings() {
        let a = "a".repeat(10_000);
        let b = "a".repeat(10_000);
        assert!(constant_time_eq(a.as_bytes(), b.as_bytes()));

        // Length mismatch: `a` is a strict prefix of `c`.
        let mut c = "a".repeat(10_000);
        c.push('b');
        assert!(!constant_time_eq(a.as_bytes(), c.as_bytes()));

        // Same length, differs only in the last byte, so the comparison cannot
        // be settled by the length check alone.
        let mut d = "a".repeat(9_999);
        d.push('b');
        assert_eq!(a.len(), d.len());
        assert!(!constant_time_eq(a.as_bytes(), d.as_bytes()));
    }

    /// Despite its name, this test does not measure timing. It only checks that
    /// token-shaped inputs are rejected whether the mismatch is at the start or the end.
    #[test]
    fn constant_time_eq_timing_consistency() {
        let token = "8f14e45f-ceea-367f-a27f-c790e5a0fdc4";
        let wrong1 = "0000000f-ceea-367f-a27f-c790e5a0fdc4";
        let wrong2 = "8f14e45f-ceea-367f-a27f-c790e5a0fd00";

        assert!(!constant_time_eq(token.as_bytes(), wrong1.as_bytes()));
        assert!(!constant_time_eq(token.as_bytes(), wrong2.as_bytes()));
    }

    /// No duplicates appear across 1000 consecutive generations.
    #[test]
    fn token_uniqueness_over_1000_generations() {
        let mut tokens = std::collections::HashSet::with_capacity(1000);
        for _ in 0..1000 {
            let t = generate_token();
            assert!(tokens.insert(t), "duplicate token generated");
        }
    }

    /// A limiter created with zero budget refuses immediately.
    #[test]
    fn rate_limiter_zero_budget() {
        let limiter = RateLimiterState::new(0);
        assert!(!limiter.try_acquire());
    }

    /// Every single-byte value equals itself and differs from its successor.
    /// `wrapping_add` pairs 0xFF with 0x00, so no byte value is skipped.
    #[test]
    fn constant_time_eq_all_byte_values() {
        for b in u8::MIN..=u8::MAX {
            let a = [b];
            assert!(constant_time_eq(&a, &a));
            let c = [b.wrapping_add(1)];
            assert!(
                !constant_time_eq(&a, &c),
                "byte {b:#04x} compared equal to {:#04x}",
                c[0]
            );
        }
    }

    // ── Host-header checks (`is_localhost_host`) ────────────────────────
    // NOTE(review): the `dns_rebinding_guard_*` names suggest this predicate backs the
    // `dns_rebinding_guard` middleware — confirm in victauri_core.

    #[test]
    fn dns_rebinding_guard_allows_localhost() {
        assert!(is_localhost_host("localhost"));
        assert!(is_localhost_host("localhost:7474"));
    }

    #[test]
    fn dns_rebinding_guard_allows_127() {
        assert!(is_localhost_host("127.0.0.1"));
        assert!(is_localhost_host("127.0.0.1:7474"));
    }

    #[test]
    fn dns_rebinding_guard_allows_ipv6() {
        assert!(is_localhost_host("[::1]"));
        assert!(is_localhost_host("[::1]:7474"));
        assert!(is_localhost_host("::1"));
    }

    #[test]
    fn dns_rebinding_guard_blocks_evil() {
        assert!(!is_localhost_host("evil.com"));
    }

    /// A host that merely starts with "localhost" is not loopback.
    #[test]
    fn dns_rebinding_guard_blocks_localhost_subdomain() {
        assert!(!is_localhost_host("localhost.evil.com"));
    }

    #[test]
    fn dns_rebinding_guard_blocks_empty() {
        assert!(!is_localhost_host(""));
    }

    // ── Origin checks (`is_allowed_origin`) ─────────────────────────────

    #[test]
    fn localhost_origin_accepted() {
        assert!(is_allowed_origin("http://localhost:3000"));
        assert!(is_allowed_origin("http://localhost"));
        assert!(is_allowed_origin("https://localhost:7474"));
    }

    #[test]
    fn ipv4_loopback_accepted() {
        assert!(is_allowed_origin("http://127.0.0.1:7474"));
        assert!(is_allowed_origin("http://127.0.0.1"));
        assert!(is_allowed_origin("https://127.0.0.1:443"));
    }

    #[test]
    fn ipv6_loopback_accepted() {
        assert!(is_allowed_origin("http://[::1]:7474"));
        assert!(is_allowed_origin("http://[::1]"));
    }

    /// Loopback names embedded in a foreign host must not be accepted.
    #[test]
    fn subdomain_bypass_rejected() {
        assert!(!is_allowed_origin("https://localhost.evil.com"));
        assert!(!is_allowed_origin("https://127.0.0.1.evil.com"));
        assert!(!is_allowed_origin("https://evil-localhost.com"));
    }

    /// Loopback names appearing only in the path must not be accepted.
    #[test]
    fn path_bypass_rejected() {
        assert!(!is_allowed_origin("https://evil.com/localhost"));
        assert!(!is_allowed_origin("https://evil.com/127.0.0.1"));
    }

    #[test]
    fn external_origins_rejected() {
        assert!(!is_allowed_origin("https://google.com"));
        assert!(!is_allowed_origin("https://example.com:443"));
        assert!(!is_allowed_origin("http://attacker.com"));
    }
}