Skip to main content

fraiseql_server/auth/
constant_time.rs

1// Constant-time comparison utilities
2// Prevents timing attacks on token validation
3//
4// ## Integration Points
5//
6// This module provides utilities for constant-time token comparison to prevent
7// timing attacks. Key integration points:
8//
9// 1. **JWT Validation**: Already handled by `jsonwebtoken` crate (uses `subtle` internally)
10// 2. **Session Token Comparison**: Use `compare_session_token()` or `compare_hmac()` when comparing
11//    session token hashes in session_postgres.rs
12// 3. **CSRF State Validation**: Use `compare_state_token()` in state_store retrieve()
13// 4. **PKCE Verifier**: Use `compare_pkce_verifier()` in auth_callback()
14// 5. **Refresh Token Hashes**: Use `compare_refresh_token()` or `compare_hmac()`
15//
16// See constant_time_refactor_notes.md for detailed integration guide.
17
18use subtle::ConstantTimeEq;
19
/// Namespace for constant-time comparison helpers used on security tokens.
///
/// This is a stateless unit struct: every helper is an associated function
/// (no `self`), so it is called as `ConstantTimeOps::compare(..)` without
/// constructing a value. The comparisons are delegated to the `subtle` crate so
/// that their duration does not depend on where two equal-length inputs differ.
///
/// The common traits are derived so the type can be embedded in other
/// `Debug`/`Clone`/`Default` types without friction.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConstantTimeOps;
23
24impl ConstantTimeOps {
25    /// Compare two byte slices in constant time
26    ///
27    /// Returns true if equal, false otherwise.
28    /// Time is independent of where the difference occurs, preventing timing attacks.
29    ///
30    /// # Arguments
31    /// * `expected` - The expected (correct/known) value
32    /// * `actual` - The actual (untrusted) value from the user/attacker
33    ///
34    /// # Examples
35    /// ```ignore
36    /// let stored_token = b"secret_token_value";
37    /// let user_token = b"user_provided_token";
38    /// assert!(!ConstantTimeOps::compare(stored_token, user_token));
39    /// ```
40    pub fn compare(expected: &[u8], actual: &[u8]) -> bool {
41        expected.ct_eq(actual).into()
42    }
43
44    /// Compare two strings in constant time
45    ///
46    /// Converts strings to bytes and performs constant-time comparison.
47    /// Useful for comparing JWT tokens, session tokens, or other string-based secrets.
48    ///
49    /// # Arguments
50    /// * `expected` - The expected (correct/known) string value
51    /// * `actual` - The actual (untrusted) string value from the user/attacker
52    pub fn compare_str(expected: &str, actual: &str) -> bool {
53        Self::compare(expected.as_bytes(), actual.as_bytes())
54    }
55
56    /// Compare two slices with different lengths in constant time
57    ///
58    /// If lengths differ, still compares as much as possible to avoid leaking
59    /// length information through timing.
60    pub fn compare_len_safe(expected: &[u8], actual: &[u8]) -> bool {
61        // If lengths differ, still compare constant-time
62        // First compare what we can, then check length
63        let min_len = expected.len().min(actual.len());
64        let prefix_equal = expected[..min_len].ct_eq(&actual[..min_len]);
65        let length_equal = (expected.len() == actual.len()) as u8;
66
67        (prefix_equal.unwrap_u8() & length_equal) != 0
68    }
69
70    /// Compare JWT tokens in constant time
71    /// Handles the common case of JWT with header.payload.signature format
72    pub fn compare_jwt(expected: &str, actual: &str) -> bool {
73        Self::compare_str(expected, actual)
74    }
75
76    /// Compare session tokens in constant time
77    /// Handles session_id:signature format
78    pub fn compare_session_token(expected: &str, actual: &str) -> bool {
79        Self::compare_str(expected, actual)
80    }
81
82    /// Compare CSRF tokens in constant time
83    pub fn compare_csrf_token(expected: &str, actual: &str) -> bool {
84        Self::compare_str(expected, actual)
85    }
86
87    /// Compare HMAC signatures in constant time
88    /// Used for verifying webhook signatures and other HMAC-based authenticity
89    pub fn compare_hmac(expected: &[u8], actual: &[u8]) -> bool {
90        Self::compare(expected, actual)
91    }
92
93    /// Compare refresh tokens in constant time
94    pub fn compare_refresh_token(expected: &str, actual: &str) -> bool {
95        Self::compare_str(expected, actual)
96    }
97
98    /// Compare authorization codes in constant time (used in OAuth flows)
99    pub fn compare_auth_code(expected: &str, actual: &str) -> bool {
100        Self::compare_str(expected, actual)
101    }
102
103    /// Compare PKCE code verifier in constant time
104    pub fn compare_pkce_verifier(expected: &str, actual: &str) -> bool {
105        Self::compare_str(expected, actual)
106    }
107
108    /// Compare state tokens in constant time (CSRF protection in OAuth)
109    pub fn compare_state_token(expected: &str, actual: &str) -> bool {
110        Self::compare_str(expected, actual)
111    }
112}
113
#[cfg(test)]
mod tests {
    //! Unit tests for `ConstantTimeOps`.
    //!
    //! These tests only check the boolean result of each helper; they do not
    //! attempt to measure timing behaviour.

    use super::*;

    #[test]
    fn test_compare_equal_bytes() {
        let token1 = b"equal_token_value";
        let token2 = b"equal_token_value";
        assert!(ConstantTimeOps::compare(token1, token2));
    }

    #[test]
    fn test_compare_different_bytes() {
        let token1 = b"expected_token";
        let token2 = b"actual_token_x";
        assert!(!ConstantTimeOps::compare(token1, token2));
    }

    #[test]
    fn test_compare_equal_strings() {
        let token1 = "equal_token_value";
        let token2 = "equal_token_value";
        assert!(ConstantTimeOps::compare_str(token1, token2));
    }

    #[test]
    fn test_compare_different_strings() {
        let token1 = "expected_token";
        let token2 = "actual_token_x";
        assert!(!ConstantTimeOps::compare_str(token1, token2));
    }

    #[test]
    fn test_compare_str_different_lengths() {
        assert!(!ConstantTimeOps::compare_str("short", "much_longer_token"));
        assert!(!ConstantTimeOps::compare_str("", "non_empty"));
    }

    #[test]
    fn test_compare_empty() {
        let token1 = b"";
        let token2 = b"";
        assert!(ConstantTimeOps::compare(token1, token2));
    }

    #[test]
    fn test_compare_different_lengths() {
        let token1 = b"short";
        let token2 = b"much_longer_token";
        assert!(!ConstantTimeOps::compare(token1, token2));
    }

    #[test]
    fn test_compare_len_safe() {
        let expected = b"abcdefghij";
        let actual = b"abcdefghij";
        assert!(ConstantTimeOps::compare_len_safe(expected, actual));

        let different = b"abcdefghix";
        assert!(!ConstantTimeOps::compare_len_safe(expected, different));

        let shorter = b"abcdefgh";
        assert!(!ConstantTimeOps::compare_len_safe(expected, shorter));
    }

    #[test]
    fn test_compare_len_safe_longer_actual() {
        // A matching prefix must not be enough when `actual` has extra bytes.
        let expected = b"abcdefghij";
        let longer = b"abcdefghijkl";
        assert!(!ConstantTimeOps::compare_len_safe(expected, longer));
    }

    #[test]
    fn test_compare_len_safe_empty() {
        assert!(ConstantTimeOps::compare_len_safe(b"", b""));
        // The shared prefix is empty (trivially equal), so only the length
        // check can reject these.
        assert!(!ConstantTimeOps::compare_len_safe(b"", b"x"));
        assert!(!ConstantTimeOps::compare_len_safe(b"x", b""));
    }

    #[test]
    fn test_jwt_comparison() {
        let jwt1 = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIn0.signature123";
        let jwt2 = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIn0.signature123";
        assert!(ConstantTimeOps::compare_jwt(jwt1, jwt2));

        let different = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIn0.signature999";
        assert!(!ConstantTimeOps::compare_jwt(jwt1, different));
    }

    #[test]
    fn test_session_token_comparison() {
        let token1 = "sess_abc123:hmac_sig_xyz";
        let token2 = "sess_abc123:hmac_sig_xyz";
        assert!(ConstantTimeOps::compare_session_token(token1, token2));

        let different = "sess_abc123:hmac_sig_abc";
        assert!(!ConstantTimeOps::compare_session_token(token1, different));
    }

    #[test]
    fn test_csrf_token_comparison() {
        let token1 = "csrf_token_xyz123abc";
        let token2 = "csrf_token_xyz123abc";
        assert!(ConstantTimeOps::compare_csrf_token(token1, token2));

        let different = "csrf_token_abc123xyz";
        assert!(!ConstantTimeOps::compare_csrf_token(token1, different));
    }

    #[test]
    fn test_hmac_comparison() {
        let sig1 = b"\x48\x6d\x61\x63\x5f\x73\x69\x67\x6e\x61\x74\x75\x72\x65";
        let sig2 = b"\x48\x6d\x61\x63\x5f\x73\x69\x67\x6e\x61\x74\x75\x72\x65";
        assert!(ConstantTimeOps::compare_hmac(sig1, sig2));

        let different = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
        assert!(!ConstantTimeOps::compare_hmac(sig1, different));
    }

    #[test]
    fn test_refresh_token_comparison() {
        let token1 = "refresh_token_long_value_xyz";
        let token2 = "refresh_token_long_value_xyz";
        assert!(ConstantTimeOps::compare_refresh_token(token1, token2));

        let different = "refresh_token_long_value_abc";
        assert!(!ConstantTimeOps::compare_refresh_token(token1, different));
    }

    #[test]
    fn test_auth_code_comparison() {
        let code1 = "auth_code_xyz_123_abc";
        let code2 = "auth_code_xyz_123_abc";
        assert!(ConstantTimeOps::compare_auth_code(code1, code2));

        let different = "auth_code_xyz_123_xyz";
        assert!(!ConstantTimeOps::compare_auth_code(code1, different));
    }

    #[test]
    fn test_pkce_verifier_comparison() {
        let verifier1 = "pkce_verifier_xyz123abc";
        let verifier2 = "pkce_verifier_xyz123abc";
        assert!(ConstantTimeOps::compare_pkce_verifier(verifier1, verifier2));

        let different = "pkce_verifier_abc123xyz";
        assert!(!ConstantTimeOps::compare_pkce_verifier(verifier1, different));
    }

    #[test]
    fn test_state_token_comparison() {
        let state1 = "state_token_xyz123abc";
        let state2 = "state_token_xyz123abc";
        assert!(ConstantTimeOps::compare_state_token(state1, state2));

        let different = "state_token_abc123xyz";
        assert!(!ConstantTimeOps::compare_state_token(state1, different));
    }

    #[test]
    fn test_null_bytes_comparison() {
        let token1 = b"token\x00with\x00nulls";
        let token2 = b"token\x00with\x00nulls";
        assert!(ConstantTimeOps::compare(token1, token2));

        let different = b"token\x00with\x00other";
        assert!(!ConstantTimeOps::compare(token1, different));
    }

    #[test]
    fn test_all_byte_values() {
        // One byte of every possible value, 0x00 through 0xFF.
        let token1: Vec<u8> = (0..=u8::MAX).collect();
        let mut token2 = token1.clone();

        assert!(ConstantTimeOps::compare(&token1, &token2));

        token2[127] = token2[127].wrapping_add(1);
        assert!(!ConstantTimeOps::compare(&token1, &token2));
    }

    #[test]
    fn test_very_long_tokens() {
        let token1: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();
        let token2 = token1.clone();
        assert!(ConstantTimeOps::compare(&token1, &token2));

        // Flip a single byte in the middle of the buffer.
        let mut token3 = token1.clone();
        token3[5_000] = token3[5_000].wrapping_add(1);
        assert!(!ConstantTimeOps::compare(&token1, &token3));
    }
}