// fraiseql_server/auth/state_encryption.rs

// State encryption for PKCE protection
// Encrypts OAuth state parameters using ChaCha20-Poly1305 AEAD

4use chacha20poly1305::{
5    ChaCha20Poly1305, Nonce,
6    aead::{Aead, KeyInit, Payload},
7};
8use rand::RngCore;
9
10use crate::auth::{AuthError, error::Result};
11
/// An AEAD-sealed state value paired with the nonce it was sealed under.
///
/// The nonce is not secret; it only needs to be unique per encryption under a
/// given key. Use [`EncryptedState::to_bytes`] and [`EncryptedState::from_bytes`]
/// to convert to and from a single contiguous byte string.
#[derive(Debug, Clone)]
pub struct EncryptedState {
    /// Encrypted payload, with the authentication tag appended at the end
    pub ciphertext: Vec<u8>,
    /// 96-bit (12-byte) nonce the payload was encrypted with
    pub nonce:      [u8; 12],
}
20
21impl EncryptedState {
22    /// Create new encrypted state
23    pub fn new(ciphertext: Vec<u8>, nonce: [u8; 12]) -> Self {
24        Self { ciphertext, nonce }
25    }
26
27    /// Serialize to bytes for storage
28    /// Format: [12-byte nonce][ciphertext with auth tag]
29    pub fn to_bytes(&self) -> Vec<u8> {
30        let mut bytes = Vec::with_capacity(12 + self.ciphertext.len());
31        bytes.extend_from_slice(&self.nonce);
32        bytes.extend_from_slice(&self.ciphertext);
33        bytes
34    }
35
36    /// Deserialize from bytes
37    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
38        if bytes.len() < 12 {
39            return Err(AuthError::InvalidState);
40        }
41
42        let mut nonce = [0u8; 12];
43        nonce.copy_from_slice(&bytes[0..12]);
44        let ciphertext = bytes[12..].to_vec();
45
46        Ok(Self::new(ciphertext, nonce))
47    }
48}
49
50/// State encryption using ChaCha20-Poly1305 AEAD
51///
52/// Provides authenticated encryption for OAuth state parameters.
53/// Uses a fixed encryption key for the deployment lifetime.
54/// Each encryption uses a random nonce for security.
55///
56/// # Security Properties
57/// - **Confidentiality**: State values are encrypted with ChaCha20
58/// - **Authenticity**: Authentication tag prevents tampering detection
59/// - **Replay Prevention**: Random nonce in each encryption
60/// - **Key Isolation**: Separate from signing keys, used only for state
61pub struct StateEncryption {
62    cipher: ChaCha20Poly1305,
63}
64
65impl StateEncryption {
66    /// Create a new state encryption instance
67    ///
68    /// # Arguments
69    /// * `key` - 32-byte encryption key (must be cryptographically random)
70    ///
71    /// # Errors
72    /// Returns error if key is invalid
73    pub fn new(key_bytes: &[u8; 32]) -> Result<Self> {
74        let cipher =
75            ChaCha20Poly1305::new_from_slice(key_bytes).map_err(|_| AuthError::ConfigError {
76                message: "Invalid state encryption key".to_string(),
77            })?;
78
79        Ok(Self { cipher })
80    }
81
82    /// Encrypt a state value
83    ///
84    /// Generates a random 96-bit nonce and encrypts the state using ChaCha20-Poly1305.
85    /// The authentication tag is appended to the ciphertext.
86    ///
87    /// # Arguments
88    /// * `state` - The plaintext state value to encrypt
89    ///
90    /// # Returns
91    /// EncryptedState containing ciphertext and nonce
92    ///
93    /// # Errors
94    /// Returns error if encryption fails (should be rare)
95    pub fn encrypt(&self, state: &str) -> Result<EncryptedState> {
96        // Generate random 96-bit nonce
97        let mut nonce_bytes = [0u8; 12];
98        rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
99        let nonce = Nonce::from(nonce_bytes);
100
101        // Encrypt with AEAD (includes authentication tag)
102        let ciphertext =
103            self.cipher.encrypt(&nonce, Payload::from(state.as_bytes())).map_err(|_| {
104                AuthError::Internal {
105                    message: "State encryption failed".to_string(),
106                }
107            })?;
108
109        Ok(EncryptedState::new(ciphertext, nonce_bytes))
110    }
111
112    /// Decrypt and verify a state value
113    ///
114    /// Uses the nonce from EncryptedState to decrypt the ciphertext.
115    /// Authentication tag verification is automatic - tampering is detected.
116    ///
117    /// # Arguments
118    /// * `encrypted` - The encrypted state to decrypt
119    ///
120    /// # Returns
121    /// The decrypted plaintext state value
122    ///
123    /// # Errors
124    /// Returns error if:
125    /// - Authentication tag verification fails (tampering detected)
126    /// - Decryption fails
127    /// - Result is not valid UTF-8
128    pub fn decrypt(&self, encrypted: &EncryptedState) -> Result<String> {
129        let nonce = Nonce::from(encrypted.nonce);
130
131        // Decrypt and verify authentication tag
132        let plaintext = self
133            .cipher
134            .decrypt(&nonce, Payload::from(encrypted.ciphertext.as_slice()))
135            .map_err(|_| AuthError::InvalidState)?;
136
137        // Convert bytes to UTF-8 string
138        String::from_utf8(plaintext).map_err(|_| AuthError::InvalidState)
139    }
140
141    /// Encrypt state and serialize to bytes
142    pub fn encrypt_to_bytes(&self, state: &str) -> Result<Vec<u8>> {
143        let encrypted = self.encrypt(state)?;
144        Ok(encrypted.to_bytes())
145    }
146
147    /// Decrypt state from serialized bytes
148    pub fn decrypt_from_bytes(&self, bytes: &[u8]) -> Result<String> {
149        let encrypted = EncryptedState::from_bytes(bytes)?;
150        self.decrypt(&encrypted)
151    }
152}
153
154/// Generate a cryptographically random encryption key
155pub fn generate_state_encryption_key() -> [u8; 32] {
156    let mut key = [0u8; 32];
157    rand::rngs::OsRng.fill_bytes(&mut key);
158    key
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic key so failures are reproducible; never use a fixed key outside tests.
    fn test_key() -> [u8; 32] {
        [42u8; 32]
    }

    #[test]
    fn test_encrypt_decrypt() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "oauth_state_test_value";

        let encrypted = encryption.encrypt(state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_encrypt_produces_ciphertext() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "test_state";

        let encrypted = encryption.encrypt(state).expect("Encryption failed");

        // The output must not expose the plaintext, and it carries the 16-byte Poly1305 tag.
        assert_ne!(encrypted.ciphertext, state.as_bytes());
        assert_eq!(encrypted.ciphertext.len(), state.len() + 16);
    }

    #[test]
    fn test_empty_state() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "";

        let encrypted = encryption.encrypt(state).expect("Encryption failed");
        // Even an empty plaintext produces a full authentication tag.
        assert_eq!(encrypted.ciphertext.len(), 16);

        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");
        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_different_keys_fail_decryption() {
        let key1 = [42u8; 32];
        let key2 = [99u8; 32];
        let state = "secret_state";

        let encryption1 = StateEncryption::new(&key1).expect("Init 1 failed");
        let encrypted = encryption1.encrypt(state).expect("Encryption failed");

        let encryption2 = StateEncryption::new(&key2).expect("Init 2 failed");
        let result = encryption2.decrypt(&encrypted);

        // Different key should fail due to authentication tag mismatch
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "tamper_test";

        let mut encrypted = encryption.encrypt(state).expect("Encryption failed");

        // The ciphertext always contains at least the tag, so index 0 exists.
        encrypted.ciphertext[0] ^= 0xFF;

        // Should fail due to authentication tag verification
        let result = encryption.decrypt(&encrypted);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_nonce_fails() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "nonce_tamper";

        let mut encrypted = encryption.encrypt(state).expect("Encryption failed");

        // Tamper with nonce
        encrypted.nonce[0] ^= 0xFF;

        // Should fail due to authentication tag verification
        let result = encryption.decrypt(&encrypted);
        assert!(result.is_err());
    }

    #[test]
    fn test_truncated_ciphertext_fails() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "truncation_test";

        let mut encrypted = encryption.encrypt(state).expect("Encryption failed");

        // Dropping the final byte cuts into the authentication tag.
        let shortened_len = encrypted.ciphertext.len() - 1;
        encrypted.ciphertext.truncate(shortened_len);

        let result = encryption.decrypt(&encrypted);
        assert!(result.is_err());
    }

    #[test]
    fn test_serialization() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "serialization_test";

        // Encrypt and serialize
        let bytes = encryption.encrypt_to_bytes(state).expect("Encryption failed");

        // Deserialize and decrypt
        let decrypted = encryption.decrypt_from_bytes(&bytes).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_serialized_layout() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let encrypted = encryption.encrypt("layout_test").expect("Encryption failed");

        // Wire format is [12-byte nonce][ciphertext with tag].
        let bytes = encrypted.to_bytes();
        assert_eq!(bytes.len(), 12 + encrypted.ciphertext.len());
        assert_eq!(&bytes[..12], &encrypted.nonce[..]);
        assert_eq!(&bytes[12..], encrypted.ciphertext.as_slice());

        let parsed = EncryptedState::from_bytes(&bytes).expect("Parse failed");
        assert_eq!(parsed.nonce, encrypted.nonce);
        assert_eq!(parsed.ciphertext, encrypted.ciphertext);
    }

    #[test]
    fn test_short_input_rejected() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");

        // Too short to even hold a nonce.
        assert!(EncryptedState::from_bytes(&[0u8; 11]).is_err());
        assert!(encryption.decrypt_from_bytes(&[]).is_err());

        // A bare nonce with no ciphertext/tag can never authenticate.
        assert!(encryption.decrypt_from_bytes(&[0u8; 12]).is_err());
    }

    #[test]
    fn test_invalid_utf8_plaintext_fails() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let nonce_bytes = [7u8; 12];

        // Seal bytes that are not valid UTF-8, bypassing the &str-only public API.
        let ciphertext = encryption
            .cipher
            .encrypt(&Nonce::from(nonce_bytes), Payload::from(&[0xFFu8, 0xFE, 0xFD][..]))
            .expect("Encryption failed");
        let encrypted = EncryptedState::new(ciphertext, nonce_bytes);

        // The tag verifies, but the UTF-8 check must still reject the plaintext.
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_random_nonces() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "random_nonce_test";

        let encrypted1 = encryption.encrypt(state).expect("Encryption 1 failed");
        let encrypted2 = encryption.encrypt(state).expect("Encryption 2 failed");

        // Nonces should be different (extremely unlikely to collide)
        assert_ne!(encrypted1.nonce, encrypted2.nonce);

        // Both should decrypt correctly
        let decrypted1 = encryption.decrypt(&encrypted1).expect("Decryption 1 failed");
        let decrypted2 = encryption.decrypt(&encrypted2).expect("Decryption 2 failed");

        assert_eq!(decrypted1, state);
        assert_eq!(decrypted2, state);
    }

    #[test]
    fn test_long_state() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "a".repeat(10_000);

        let encrypted = encryption.encrypt(&state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_special_characters() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "state:with-special_chars.and/symbols!@#$%^&*()";

        let encrypted = encryption.encrypt(state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_unicode_state() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "state_with_emoji_🔐_🔒_🔓_and_emoji";

        let encrypted = encryption.encrypt(state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_null_bytes_in_state() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "state_with\x00null\x00bytes\x00";

        let encrypted = encryption.encrypt(state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_key_generation() {
        let key1 = generate_state_encryption_key();
        let key2 = generate_state_encryption_key();

        // Two draws from the OS RNG should never coincide in practice.
        assert_ne!(key1, key2);

        // Both keys must be usable.
        let enc1 = StateEncryption::new(&key1).expect("Init 1 failed");
        let enc2 = StateEncryption::new(&key2).expect("Init 2 failed");

        let state = "test";
        let encrypted1 = enc1.encrypt(state).expect("Encryption 1 failed");
        let encrypted2 = enc2.encrypt(state).expect("Encryption 2 failed");

        assert_eq!(enc1.decrypt(&encrypted1).expect("Decryption 1 failed"), state);
        assert_eq!(enc2.decrypt(&encrypted2).expect("Decryption 2 failed"), state);
    }

    #[test]
    fn test_large_ciphertext() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "x".repeat(100_000);

        let encrypted = encryption.encrypt(&state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }
}