// fraiseql_server/auth/state_encryption.rs
use chacha20poly1305::{
    ChaCha20Poly1305, Nonce,
    aead::{Aead, KeyInit, Payload},
};
use rand::RngCore;

use crate::auth::{AuthError, error::Result};
11
/// An encrypted state value paired with the nonce it was sealed under.
///
/// Values are produced by `StateEncryption::encrypt`. The byte encoding used by
/// `to_bytes` / `from_bytes` is the 12-byte nonce followed by the ciphertext.
#[derive(Debug, Clone)]
pub struct EncryptedState {
    /// AEAD output for the state value (ciphertext with its authentication tag appended).
    pub ciphertext: Vec<u8>,
    /// The 12-byte nonce; `StateEncryption::encrypt` generates a fresh one per call.
    pub nonce: [u8; 12],
}
20
21impl EncryptedState {
22 pub fn new(ciphertext: Vec<u8>, nonce: [u8; 12]) -> Self {
24 Self { ciphertext, nonce }
25 }
26
27 pub fn to_bytes(&self) -> Vec<u8> {
30 let mut bytes = Vec::with_capacity(12 + self.ciphertext.len());
31 bytes.extend_from_slice(&self.nonce);
32 bytes.extend_from_slice(&self.ciphertext);
33 bytes
34 }
35
36 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
38 if bytes.len() < 12 {
39 return Err(AuthError::InvalidState);
40 }
41
42 let mut nonce = [0u8; 12];
43 nonce.copy_from_slice(&bytes[0..12]);
44 let ciphertext = bytes[12..].to_vec();
45
46 Ok(Self::new(ciphertext, nonce))
47 }
48}
49
50pub struct StateEncryption {
62 cipher: ChaCha20Poly1305,
63}
64
65impl StateEncryption {
66 pub fn new(key_bytes: &[u8; 32]) -> Result<Self> {
74 let cipher =
75 ChaCha20Poly1305::new_from_slice(key_bytes).map_err(|_| AuthError::ConfigError {
76 message: "Invalid state encryption key".to_string(),
77 })?;
78
79 Ok(Self { cipher })
80 }
81
82 pub fn encrypt(&self, state: &str) -> Result<EncryptedState> {
96 let mut nonce_bytes = [0u8; 12];
98 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
99 let nonce = Nonce::from(nonce_bytes);
100
101 let ciphertext =
103 self.cipher.encrypt(&nonce, Payload::from(state.as_bytes())).map_err(|_| {
104 AuthError::Internal {
105 message: "State encryption failed".to_string(),
106 }
107 })?;
108
109 Ok(EncryptedState::new(ciphertext, nonce_bytes))
110 }
111
112 pub fn decrypt(&self, encrypted: &EncryptedState) -> Result<String> {
129 let nonce = Nonce::from(encrypted.nonce);
130
131 let plaintext = self
133 .cipher
134 .decrypt(&nonce, Payload::from(encrypted.ciphertext.as_slice()))
135 .map_err(|_| AuthError::InvalidState)?;
136
137 String::from_utf8(plaintext).map_err(|_| AuthError::InvalidState)
139 }
140
141 pub fn encrypt_to_bytes(&self, state: &str) -> Result<Vec<u8>> {
143 let encrypted = self.encrypt(state)?;
144 Ok(encrypted.to_bytes())
145 }
146
147 pub fn decrypt_from_bytes(&self, bytes: &[u8]) -> Result<String> {
149 let encrypted = EncryptedState::from_bytes(bytes)?;
150 self.decrypt(&encrypted)
151 }
152}
153
154pub fn generate_state_encryption_key() -> [u8; 32] {
156 let mut key = [0u8; 32];
157 rand::rngs::OsRng.fill_bytes(&mut key);
158 key
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Length of the Poly1305 authentication tag appended to every ciphertext.
    const TAG_LEN: usize = 16;

    /// Fixed key so test runs are reproducible.
    fn test_key() -> [u8; 32] {
        [42u8; 32]
    }

    /// Builds a `StateEncryption` keyed with [`test_key`].
    fn test_encryption() -> StateEncryption {
        StateEncryption::new(&test_key()).expect("Init failed")
    }

    /// Encrypts `state`, decrypts it again and checks that the value survives unchanged.
    fn assert_round_trip(state: &str) {
        let encryption = test_encryption();
        let encrypted = encryption.encrypt(state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");
        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_encrypt_decrypt() {
        assert_round_trip("oauth_state_test_value");
    }

    #[test]
    fn test_encrypt_produces_ciphertext() {
        let state = "test_state";
        let encrypted = test_encryption().encrypt(state).expect("Encryption failed");

        assert_ne!(encrypted.ciphertext, state.as_bytes());
    }

    #[test]
    fn test_ciphertext_length_includes_tag() {
        let state = "length_test";
        let encrypted = test_encryption().encrypt(state).expect("Encryption failed");

        assert_eq!(encrypted.ciphertext.len(), state.len() + TAG_LEN);
    }

    #[test]
    fn test_empty_state() {
        assert_round_trip("");
    }

    #[test]
    fn test_different_keys_fail_decryption() {
        let encryption1 = StateEncryption::new(&[42u8; 32]).expect("Init 1 failed");
        let encryption2 = StateEncryption::new(&[99u8; 32]).expect("Init 2 failed");

        let encrypted = encryption1.encrypt("secret_state").expect("Encryption failed");

        assert!(matches!(encryption2.decrypt(&encrypted), Err(AuthError::InvalidState)));
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let encryption = test_encryption();
        let mut encrypted = encryption.encrypt("tamper_test").expect("Encryption failed");

        // The ciphertext always carries a tag, so index 0 exists; no guard needed.
        encrypted.ciphertext[0] ^= 0xFF;

        assert!(matches!(encryption.decrypt(&encrypted), Err(AuthError::InvalidState)));
    }

    #[test]
    fn test_tampered_tag_fails() {
        let encryption = test_encryption();
        let mut encrypted = encryption.encrypt("tag_tamper").expect("Encryption failed");

        // Flip a single bit in the trailing authentication tag.
        let last = encrypted.ciphertext.len() - 1;
        encrypted.ciphertext[last] ^= 0x01;

        assert!(matches!(encryption.decrypt(&encrypted), Err(AuthError::InvalidState)));
    }

    #[test]
    fn test_tampered_nonce_fails() {
        let encryption = test_encryption();
        let mut encrypted = encryption.encrypt("nonce_tamper").expect("Encryption failed");

        encrypted.nonce[0] ^= 0xFF;

        assert!(matches!(encryption.decrypt(&encrypted), Err(AuthError::InvalidState)));
    }

    #[test]
    fn test_truncated_ciphertext_fails() {
        let encryption = test_encryption();
        let mut encrypted = encryption.encrypt("truncation_test").expect("Encryption failed");

        encrypted.ciphertext.pop().expect("ciphertext always carries a tag");

        assert!(matches!(encryption.decrypt(&encrypted), Err(AuthError::InvalidState)));
    }

    #[test]
    fn test_serialization() {
        let encryption = test_encryption();
        let state = "serialization_test";

        let bytes = encryption.encrypt_to_bytes(state).expect("Encryption failed");
        let decrypted = encryption.decrypt_from_bytes(&bytes).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_to_bytes_layout_is_nonce_then_ciphertext() {
        let encrypted = test_encryption().encrypt("layout_test").expect("Encryption failed");
        let bytes = encrypted.to_bytes();

        assert_eq!(bytes.len(), 12 + encrypted.ciphertext.len());
        assert_eq!(&bytes[..12], &encrypted.nonce[..]);
        assert_eq!(&bytes[12..], encrypted.ciphertext.as_slice());

        let parsed = EncryptedState::from_bytes(&bytes).expect("Parse failed");
        assert_eq!(parsed.nonce, encrypted.nonce);
        assert_eq!(parsed.ciphertext, encrypted.ciphertext);
    }

    #[test]
    fn test_from_bytes_rejects_input_shorter_than_nonce() {
        for len in 0..12 {
            let result = EncryptedState::from_bytes(&vec![0u8; len]);
            assert!(matches!(result, Err(AuthError::InvalidState)), "len = {len}");
        }
    }

    #[test]
    fn test_bare_nonce_parses_but_fails_decryption() {
        let encryption = test_encryption();
        let bytes = [7u8; 12];

        // Exactly a nonce is a valid layout, but there is no tag to authenticate.
        let parsed = EncryptedState::from_bytes(&bytes).expect("Parse failed");
        assert!(parsed.ciphertext.is_empty());

        assert!(matches!(encryption.decrypt_from_bytes(&bytes), Err(AuthError::InvalidState)));
    }

    #[test]
    fn test_tampered_serialized_bytes_fail() {
        let encryption = test_encryption();
        let mut bytes = encryption.encrypt_to_bytes("bytes_tamper").expect("Encryption failed");

        let last = bytes.len() - 1;
        bytes[last] ^= 0x01;

        assert!(matches!(encryption.decrypt_from_bytes(&bytes), Err(AuthError::InvalidState)));
    }

    #[test]
    fn test_random_nonces() {
        let encryption = test_encryption();
        let state = "random_nonce_test";

        let encrypted1 = encryption.encrypt(state).expect("Encryption 1 failed");
        let encrypted2 = encryption.encrypt(state).expect("Encryption 2 failed");

        // Fresh nonces mean identical plaintexts do not produce identical output.
        assert_ne!(encrypted1.nonce, encrypted2.nonce);
        assert_ne!(encrypted1.ciphertext, encrypted2.ciphertext);

        assert_eq!(encryption.decrypt(&encrypted1).expect("Decryption 1 failed"), state);
        assert_eq!(encryption.decrypt(&encrypted2).expect("Decryption 2 failed"), state);
    }

    #[test]
    fn test_long_state() {
        assert_round_trip(&"a".repeat(10_000));
    }

    #[test]
    fn test_special_characters() {
        assert_round_trip("state:with-special_chars.and/symbols!@#$%^&*()");
    }

    #[test]
    fn test_unicode_state() {
        assert_round_trip("state_with_emoji_🔐_🔒_🔓_and_emoji");
    }

    #[test]
    fn test_null_bytes_in_state() {
        assert_round_trip("state_with\x00null\x00bytes\x00");
    }

    #[test]
    fn test_key_generation() {
        let key1 = generate_state_encryption_key();
        let key2 = generate_state_encryption_key();

        // Two independent 256-bit draws colliding would indicate a broken RNG.
        assert_ne!(key1, key2);

        let enc1 = StateEncryption::new(&key1).expect("Init 1 failed");
        let enc2 = StateEncryption::new(&key2).expect("Init 2 failed");

        let state = "test";
        let encrypted1 = enc1.encrypt(state).expect("Encryption 1 failed");
        let encrypted2 = enc2.encrypt(state).expect("Encryption 2 failed");

        assert_eq!(enc1.decrypt(&encrypted1).expect("Decryption 1 failed"), state);
        assert_eq!(enc2.decrypt(&encrypted2).expect("Decryption 2 failed"), state);

        // A generated key only opens ciphertexts produced under that same key.
        assert!(enc2.decrypt(&encrypted1).is_err());
    }

    #[test]
    fn test_large_ciphertext() {
        assert_round_trip(&"x".repeat(100_000));
    }
}