fraiseql_auth/
constant_time.rs1use subtle::ConstantTimeEq;
9
/// Namespace for constant-time comparison helpers, intended for comparing secrets such as
/// tokens.
///
/// All functionality is exposed as associated functions; the type itself carries no state.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConstantTimeOps;
13
14impl ConstantTimeOps {
15 pub fn compare(expected: &[u8], actual: &[u8]) -> bool {
32 expected.ct_eq(actual).into()
33 }
34
35 pub fn compare_str(expected: &str, actual: &str) -> bool {
44 Self::compare(expected.as_bytes(), actual.as_bytes())
45 }
46
47 pub fn compare_len_safe(expected: &[u8], actual: &[u8]) -> bool {
57 let min_len = expected.len().min(actual.len());
60 let prefix_equal = expected[..min_len].ct_eq(&actual[..min_len]);
61 let length_equal = u8::from(expected.len() == actual.len());
62
63 (prefix_equal.unwrap_u8() & length_equal) != 0
64 }
65
66 pub fn compare_padded(expected: &[u8], actual: &[u8], fixed_len: usize) -> bool {
92 let mut expected_padded = vec![0u8; fixed_len];
96 let mut actual_padded = vec![0u8; fixed_len];
97
98 let copy_expected = expected.len().min(fixed_len);
99 expected_padded[..copy_expected].copy_from_slice(&expected[..copy_expected]);
100
101 let copy_actual = actual.len().min(fixed_len);
102 actual_padded[..copy_actual].copy_from_slice(&actual[..copy_actual]);
103
104 expected_padded.ct_eq(&actual_padded).into()
106 }
107
108 pub fn compare_jwt_constant(expected: &str, actual: &str) -> bool {
113 Self::compare_padded(expected.as_bytes(), actual.as_bytes(), 512)
115 }
116}
117
118#[cfg(test)]
119mod tests {
120 #[allow(clippy::wildcard_imports)]
121 use super::*;
123
124 #[test]
125 fn test_compare_equal_bytes() {
126 let token1 = b"equal_token_value";
127 let token2 = b"equal_token_value";
128 assert!(ConstantTimeOps::compare(token1, token2));
129 }
130
131 #[test]
132 fn test_compare_different_bytes() {
133 let token1 = b"expected_token";
134 let token2 = b"actual_token_x";
135 assert!(!ConstantTimeOps::compare(token1, token2));
136 }
137
138 #[test]
139 fn test_compare_equal_strings() {
140 let token1 = "equal_token_value";
141 let token2 = "equal_token_value";
142 assert!(ConstantTimeOps::compare_str(token1, token2));
143 }
144
145 #[test]
146 fn test_compare_different_strings() {
147 let token1 = "expected_token";
148 let token2 = "actual_token_x";
149 assert!(!ConstantTimeOps::compare_str(token1, token2));
150 }
151
152 #[test]
153 fn test_compare_empty() {
154 let token1 = b"";
155 let token2 = b"";
156 assert!(ConstantTimeOps::compare(token1, token2));
157 }
158
159 #[test]
160 fn test_compare_different_lengths() {
161 let token1 = b"short";
162 let token2 = b"much_longer_token";
163 assert!(!ConstantTimeOps::compare(token1, token2));
164 }
165
166 #[test]
167 fn test_compare_len_safe() {
168 let expected = b"abcdefghij";
169 let actual = b"abcdefghij";
170 assert!(ConstantTimeOps::compare_len_safe(expected, actual));
171
172 let different = b"abcdefghix";
173 assert!(!ConstantTimeOps::compare_len_safe(expected, different));
174
175 let shorter = b"abcdefgh";
176 assert!(!ConstantTimeOps::compare_len_safe(expected, shorter));
177 }
178
179 #[test]
180 fn test_null_bytes_comparison() {
181 let token1 = b"token\x00with\x00nulls";
182 let token2 = b"token\x00with\x00nulls";
183 assert!(ConstantTimeOps::compare(token1, token2));
184
185 let different = b"token\x00with\x00other";
186 assert!(!ConstantTimeOps::compare(token1, different));
187 }
188
189 #[test]
190 fn test_all_byte_values() {
191 let mut token1 = vec![0u8; 256];
192 let mut token2 = vec![0u8; 256];
193 for i in 0..256 {
194 #[allow(clippy::cast_possible_truncation)]
195 let byte = i as u8;
197 token1[i] = byte;
198 token2[i] = byte;
199 }
200
201 assert!(ConstantTimeOps::compare(&token1, &token2));
202
203 token2[127] = token2[127].wrapping_add(1);
204 assert!(!ConstantTimeOps::compare(&token1, &token2));
205 }
206
207 #[test]
208 fn test_very_long_tokens() {
209 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
210 let token1: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();
212 let token2 = token1.clone();
213 assert!(ConstantTimeOps::compare(&token1, &token2));
214
215 let mut token3 = token1.clone();
216 token3[5_000] = token3[5_000].wrapping_add(1);
217 assert!(!ConstantTimeOps::compare(&token1, &token3));
218 }
219
220 #[test]
221 fn test_compare_padded_equal_length() {
222 let token1 = b"same_token_value";
223 let token2 = b"same_token_value";
224 assert!(ConstantTimeOps::compare_padded(token1, token2, 512));
225 }
226
227 #[test]
228 fn test_compare_padded_different_length_shorter_actual() {
229 let expected = b"this_is_expected_token_value";
230 let actual = b"short";
231 assert!(!ConstantTimeOps::compare_padded(expected, actual, 512));
233 }
234
235 #[test]
236 fn test_compare_padded_different_length_longer_actual() {
237 let expected = b"expected";
238 let actual = b"this_is_a_much_longer_actual_token_that_exceeds_expected";
239 assert!(!ConstantTimeOps::compare_padded(expected, actual, 512));
241 }
242
243 #[test]
244 fn test_compare_padded_timing_consistency() {
245 let short_token = b"short";
247 let long_token = b"this_is_a_much_longer_token_value_with_more_content";
248
249 let _ = ConstantTimeOps::compare_padded(short_token, short_token, 512);
252 let _ = ConstantTimeOps::compare_padded(long_token, long_token, 512);
253
254 assert!(ConstantTimeOps::compare_padded(short_token, short_token, 512));
256 assert!(ConstantTimeOps::compare_padded(long_token, long_token, 512));
257 }
258
259 #[test]
260 fn test_compare_jwt_constant() {
261 let jwt1 = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIn0.signature123";
262 let jwt2 = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIn0.signature123";
263 assert!(ConstantTimeOps::compare_jwt_constant(jwt1, jwt2));
264 }
265
266 #[test]
267 fn test_compare_jwt_constant_different() {
268 let jwt1 = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIn0.signature123";
269 let jwt2 = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIn0.signature999";
270 assert!(!ConstantTimeOps::compare_jwt_constant(jwt1, jwt2));
271 }
272
273 #[test]
274 fn test_compare_jwt_constant_prevents_length_attack() {
275 let short_invalid_jwt = "short";
277 let long_valid_jwt = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIn0.sig123";
278
279 assert!(!ConstantTimeOps::compare_jwt_constant(short_invalid_jwt, long_valid_jwt));
281
282 assert!(!ConstantTimeOps::compare_jwt_constant(short_invalid_jwt, long_valid_jwt));
285 }
286
287 #[test]
288 fn test_compare_padded_zero_length() {
289 let token1 = b"";
291 let token2 = b"";
292 assert!(ConstantTimeOps::compare_padded(token1, token2, 512));
293 }
294
295 #[test]
296 fn test_compare_padded_exact_fixed_length() {
297 let token = b"a".repeat(512);
299 assert!(ConstantTimeOps::compare_padded(&token, &token, 512));
300
301 let mut different = token.clone();
302 different[256] = different[256].wrapping_add(1);
303 assert!(!ConstantTimeOps::compare_padded(&token, &different, 512));
304 }
305
306 #[test]
307 fn test_compare_padded_large_fixed_len() {
308 let token1 = b"test";
310 let token2 = b"test";
311 assert!(ConstantTimeOps::compare_padded(token1, token2, 2048));
312
313 let long_a: Vec<u8> = b"prefix".iter().chain(b"AAAA".iter()).copied().collect();
315 let long_b: Vec<u8> = b"prefix".iter().chain(b"BBBB".iter()).copied().collect();
316 assert!(ConstantTimeOps::compare_padded(&long_a, &long_b, 6));
318 assert!(!ConstantTimeOps::compare_padded(&long_a, &long_b, 10));
320 }
321
322 #[test]
323 fn test_timing_attack_prevention_early_difference() {
324 let token1 = b"XXXXXXX_correct_token";
326 let token2 = b"YYYYYYY_correct_token";
327 let result = ConstantTimeOps::compare(token1, token2);
328 assert!(!result);
329 }
331
332 #[test]
333 fn test_timing_attack_prevention_late_difference() {
334 let token1 = b"correct_token_XXXXXXX";
336 let token2 = b"correct_token_YYYYYYY";
337 let result = ConstantTimeOps::compare(token1, token2);
338 assert!(!result);
339 }
341
342 #[test]
343 fn test_jwt_constant_padding() {
344 let short_jwt = "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ1c2VyIn0.abc";
346 let padded_jwt = "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ1c2VyIn0.abc";
347 assert!(ConstantTimeOps::compare_jwt_constant(short_jwt, padded_jwt));
348 }
349
350 #[test]
351 fn test_jwt_constant_different_lengths() {
352 let jwt1 = "short";
354 let jwt2 = "very_long_jwt_token_with_lots_of_data_making_it_much_longer";
355 let result = ConstantTimeOps::compare_jwt_constant(jwt1, jwt2);
356 assert!(!result);
357 }
359}