dbx_core/storage/encryption/
config.rs1use aes_gcm_siv::Aes256GcmSiv;
59use aes_gcm_siv::aead::generic_array::GenericArray;
60use aes_gcm_siv::aead::{Aead, KeyInit};
61use chacha20poly1305::ChaCha20Poly1305;
62use hkdf::Hkdf;
63use rand::RngCore;
64use sha2::Sha256;
65
66use crate::error::{DbxError, DbxResult};
67
/// Length in bytes of the symmetric key used by every supported algorithm (256 bits).
const KEY_SIZE: usize = 32;

/// Length in bytes of the per-message nonce (96 bits) that is prepended to
/// every encrypted value.
const NONCE_SIZE: usize = 12;

/// Fixed HKDF salt used when turning a password into a key.
///
/// Because it is a constant, every password-derived key in every database
/// shares the same salt.
const HKDF_SALT: &'static [u8] = b"dbx-default-salt-v1";

/// HKDF `info` (context) string. It is an input to HKDF-Expand, so changing it
/// changes every password-derived key.
const HKDF_INFO: &'static [u8] = b"dbx-encryption-v1";
81
/// Authenticated-encryption (AEAD) cipher used to protect stored values.
///
/// Both variants are used with the same 256-bit key and 96-bit nonce sizes,
/// so switching algorithms does not change the key material, only the cipher.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum EncryptionAlgorithm {
    /// AES-256 in GCM-SIV mode. This is the variant returned by `Default`.
    #[default]
    Aes256GcmSiv,

    /// ChaCha20 stream cipher authenticated with Poly1305.
    ChaCha20Poly1305,
}
106
107impl std::fmt::Display for EncryptionAlgorithm {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 match self {
110 Self::Aes256GcmSiv => write!(f, "AES-256-GCM-SIV"),
111 Self::ChaCha20Poly1305 => write!(f, "ChaCha20-Poly1305"),
112 }
113 }
114}
115
116impl EncryptionAlgorithm {
117 pub const ALL: &'static [EncryptionAlgorithm] = &[
119 EncryptionAlgorithm::Aes256GcmSiv,
120 EncryptionAlgorithm::ChaCha20Poly1305,
121 ];
122}
123
124#[derive(Clone)]
159pub struct EncryptionConfig {
160 algorithm: EncryptionAlgorithm,
162 key: [u8; KEY_SIZE],
164}
165
166impl std::fmt::Debug for EncryptionConfig {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 f.debug_struct("EncryptionConfig")
169 .field("algorithm", &self.algorithm)
170 .field("key", &"[REDACTED]")
171 .finish()
172 }
173}
174
175impl EncryptionConfig {
176 pub fn from_password(password: &str) -> Self {
183 Self::from_password_with_algorithm(password, EncryptionAlgorithm::default())
184 }
185
186 pub fn from_password_with_algorithm(password: &str, algorithm: EncryptionAlgorithm) -> Self {
188 let key = Self::derive_key(password.as_bytes());
189 Self { algorithm, key }
190 }
191
192 pub fn from_key(key: [u8; KEY_SIZE]) -> Self {
200 Self {
201 algorithm: EncryptionAlgorithm::default(),
202 key,
203 }
204 }
205
206 pub fn from_key_with_algorithm(key: [u8; KEY_SIZE], algorithm: EncryptionAlgorithm) -> Self {
208 Self { algorithm, key }
209 }
210
211 pub fn with_algorithm(mut self, algorithm: EncryptionAlgorithm) -> Self {
213 self.algorithm = algorithm;
214 self
215 }
216
217 pub fn algorithm(&self) -> EncryptionAlgorithm {
221 self.algorithm
222 }
223
224 pub fn encrypt(&self, plaintext: &[u8]) -> DbxResult<Vec<u8>> {
233 let mut nonce_bytes = [0u8; NONCE_SIZE];
234 rand::thread_rng().fill_bytes(&mut nonce_bytes);
235 let nonce = GenericArray::from_slice(&nonce_bytes);
236
237 let ciphertext = match self.algorithm {
238 EncryptionAlgorithm::Aes256GcmSiv => {
239 let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&self.key));
240 cipher.encrypt(nonce, plaintext).map_err(|e| {
241 DbxError::Encryption(format!("AES-GCM-SIV encrypt failed: {}", e))
242 })?
243 }
244 EncryptionAlgorithm::ChaCha20Poly1305 => {
245 let cipher = ChaCha20Poly1305::new(GenericArray::from_slice(&self.key));
246 cipher
247 .encrypt(nonce, plaintext)
248 .map_err(|e| DbxError::Encryption(format!("ChaCha20 encrypt failed: {}", e)))?
249 }
250 };
251
252 let mut output = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
254 output.extend_from_slice(&nonce_bytes);
255 output.extend_from_slice(&ciphertext);
256 Ok(output)
257 }
258
259 pub fn decrypt(&self, encrypted: &[u8]) -> DbxResult<Vec<u8>> {
269 if encrypted.len() < NONCE_SIZE {
270 return Err(DbxError::Encryption(
271 "encrypted data too short (missing nonce)".to_string(),
272 ));
273 }
274
275 let (nonce_bytes, ciphertext) = encrypted.split_at(NONCE_SIZE);
276 let nonce = GenericArray::from_slice(nonce_bytes);
277
278 match self.algorithm {
279 EncryptionAlgorithm::Aes256GcmSiv => {
280 let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&self.key));
281 cipher
282 .decrypt(nonce, ciphertext)
283 .map_err(|e| DbxError::Encryption(format!("AES-GCM-SIV decrypt failed: {}", e)))
284 }
285 EncryptionAlgorithm::ChaCha20Poly1305 => {
286 let cipher = ChaCha20Poly1305::new(GenericArray::from_slice(&self.key));
287 cipher
288 .decrypt(nonce, ciphertext)
289 .map_err(|e| DbxError::Encryption(format!("ChaCha20 decrypt failed: {}", e)))
290 }
291 }
292 }
293
294 pub fn encrypt_with_aad(&self, plaintext: &[u8], aad: &[u8]) -> DbxResult<Vec<u8>> {
299 use aes_gcm_siv::aead::Payload;
300
301 let mut nonce_bytes = [0u8; NONCE_SIZE];
302 rand::thread_rng().fill_bytes(&mut nonce_bytes);
303 let nonce = GenericArray::from_slice(&nonce_bytes);
304
305 let payload = Payload {
306 msg: plaintext,
307 aad,
308 };
309
310 let ciphertext = match self.algorithm {
311 EncryptionAlgorithm::Aes256GcmSiv => {
312 let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&self.key));
313 cipher.encrypt(nonce, payload).map_err(|e| {
314 DbxError::Encryption(format!("AES-GCM-SIV encrypt+AAD failed: {}", e))
315 })?
316 }
317 EncryptionAlgorithm::ChaCha20Poly1305 => {
318 let cipher = ChaCha20Poly1305::new(GenericArray::from_slice(&self.key));
319 cipher.encrypt(nonce, payload).map_err(|e| {
320 DbxError::Encryption(format!("ChaCha20 encrypt+AAD failed: {}", e))
321 })?
322 }
323 };
324
325 let mut output = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
326 output.extend_from_slice(&nonce_bytes);
327 output.extend_from_slice(&ciphertext);
328 Ok(output)
329 }
330
331 pub fn decrypt_with_aad(&self, encrypted: &[u8], aad: &[u8]) -> DbxResult<Vec<u8>> {
336 use aes_gcm_siv::aead::Payload;
337
338 if encrypted.len() < NONCE_SIZE {
339 return Err(DbxError::Encryption(
340 "encrypted data too short (missing nonce)".to_string(),
341 ));
342 }
343
344 let (nonce_bytes, ciphertext) = encrypted.split_at(NONCE_SIZE);
345 let nonce = GenericArray::from_slice(nonce_bytes);
346
347 let payload = Payload {
348 msg: ciphertext,
349 aad,
350 };
351
352 match self.algorithm {
353 EncryptionAlgorithm::Aes256GcmSiv => {
354 let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&self.key));
355 cipher.decrypt(nonce, payload).map_err(|e| {
356 DbxError::Encryption(format!("AES-GCM-SIV decrypt+AAD failed: {}", e))
357 })
358 }
359 EncryptionAlgorithm::ChaCha20Poly1305 => {
360 let cipher = ChaCha20Poly1305::new(GenericArray::from_slice(&self.key));
361 cipher.decrypt(nonce, payload).map_err(|e| {
362 DbxError::Encryption(format!("ChaCha20 decrypt+AAD failed: {}", e))
363 })
364 }
365 }
366 }
367
368 fn derive_key(input: &[u8]) -> [u8; KEY_SIZE] {
372 let hk = Hkdf::<Sha256>::new(Some(HKDF_SALT), input);
373 let mut key = [0u8; KEY_SIZE];
374 hk.expand(HKDF_INFO, &mut key)
375 .expect("HKDF expand should never fail for 32-byte output");
376 key
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Length of the authentication tag appended by both supported algorithms.
    const TAG_SIZE: usize = 16;

    #[test]
    fn default_algorithm_is_aes_gcm_siv() {
        let config = EncryptionConfig::from_password("test");
        assert_eq!(config.algorithm(), EncryptionAlgorithm::Aes256GcmSiv);
    }

    #[test]
    fn round_trip_aes_gcm_siv() {
        let config = EncryptionConfig::from_password("test-password");
        let plaintext = b"Hello, DBX encryption!";

        let encrypted = config.encrypt(plaintext).unwrap();
        assert_ne!(encrypted, plaintext);
        assert!(encrypted.len() > plaintext.len());

        let decrypted = config.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn round_trip_chacha20() {
        let config = EncryptionConfig::from_password("test-password")
            .with_algorithm(EncryptionAlgorithm::ChaCha20Poly1305);
        let plaintext = b"Hello, ChaCha20!";

        let encrypted = config.encrypt(plaintext).unwrap();
        let decrypted = config.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn round_trip_all_algorithms() {
        let plaintext = b"Testing all algorithms";
        for algo in EncryptionAlgorithm::ALL {
            let config = EncryptionConfig::from_password("pw").with_algorithm(*algo);
            let encrypted = config.encrypt(plaintext).unwrap();
            let decrypted = config.decrypt(&encrypted).unwrap();
            assert_eq!(decrypted, plaintext, "Round-trip failed for {:?}", algo);
        }
    }

    #[test]
    fn from_raw_key() {
        let key = [0xABu8; KEY_SIZE];
        let config = EncryptionConfig::from_key(key);
        let plaintext = b"raw key test";

        let encrypted = config.encrypt(plaintext).unwrap();
        let decrypted = config.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn wrong_password_fails() {
        let config1 = EncryptionConfig::from_password("correct-password");
        let config2 = EncryptionConfig::from_password("wrong-password");

        let plaintext = b"secret data";
        let encrypted = config1.encrypt(plaintext).unwrap();

        let result = config2.decrypt(&encrypted);
        assert!(result.is_err(), "Decryption with wrong key should fail");
    }

    #[test]
    fn wrong_algorithm_fails() {
        let config_aes = EncryptionConfig::from_password("same-password");
        let config_chacha = EncryptionConfig::from_password("same-password")
            .with_algorithm(EncryptionAlgorithm::ChaCha20Poly1305);

        let plaintext = b"algorithm mismatch test";
        let encrypted = config_aes.encrypt(plaintext).unwrap();

        let result = config_chacha.decrypt(&encrypted);
        assert!(
            result.is_err(),
            "Decryption with wrong algorithm should fail"
        );
    }

    #[test]
    fn tampered_data_fails() {
        let config = EncryptionConfig::from_password("test");
        let plaintext = b"tamper test";
        let mut encrypted = config.encrypt(plaintext).unwrap();

        // Flip every bit of the last byte, which belongs to the tag.
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0xFF;

        let result = config.decrypt(&encrypted);
        assert!(result.is_err(), "Tampered data should fail authentication");
    }

    #[test]
    fn tampered_nonce_fails() {
        let config = EncryptionConfig::from_password("test");
        let mut encrypted = config.encrypt(b"nonce tamper test").unwrap();

        // The nonce is the first NONCE_SIZE bytes; altering it must break authentication.
        encrypted[0] ^= 0x01;

        assert!(
            config.decrypt(&encrypted).is_err(),
            "Tampered nonce should fail authentication"
        );
    }

    #[test]
    fn too_short_data_fails() {
        let config = EncryptionConfig::from_password("test");

        // Shorter than a nonce.
        let result = config.decrypt(&[0u8; 5]);
        assert!(result.is_err());
    }

    #[test]
    fn truncated_tag_fails() {
        for algo in EncryptionAlgorithm::ALL {
            let config = EncryptionConfig::from_password("test").with_algorithm(*algo);

            // A complete nonce followed by less than one tag's worth of bytes.
            let short = [0u8; NONCE_SIZE + TAG_SIZE - 1];
            assert!(
                config.decrypt(&short).is_err(),
                "Truncated input accepted for {:?}",
                algo
            );
        }
    }

    #[test]
    fn empty_plaintext() {
        let config = EncryptionConfig::from_password("test");
        let plaintext = b"";

        let encrypted = config.encrypt(plaintext).unwrap();
        let decrypted = config.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn large_data_round_trip() {
        let config = EncryptionConfig::from_password("test");
        let plaintext: Vec<u8> = (0..100_000).map(|i| (i % 256) as u8).collect();

        let encrypted = config.encrypt(&plaintext).unwrap();
        let decrypted = config.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn each_encrypt_produces_different_output() {
        let config = EncryptionConfig::from_password("test");
        let plaintext = b"same input";

        let enc1 = config.encrypt(plaintext).unwrap();
        let enc2 = config.encrypt(plaintext).unwrap();

        assert_ne!(enc1, enc2, "Each encryption should use a fresh nonce");

        assert_eq!(config.decrypt(&enc1).unwrap(), plaintext);
        assert_eq!(config.decrypt(&enc2).unwrap(), plaintext);
    }

    #[test]
    fn aad_round_trip() {
        let config = EncryptionConfig::from_password("test");
        let plaintext = b"sensitive data";
        let aad = b"table:users,column:email";

        let encrypted = config.encrypt_with_aad(plaintext, aad).unwrap();
        let decrypted = config.decrypt_with_aad(&encrypted, aad).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn aad_round_trip_all_algorithms() {
        let plaintext = b"sensitive data";
        let aad = b"table:users,column:email";
        for algo in EncryptionAlgorithm::ALL {
            let config = EncryptionConfig::from_password("test").with_algorithm(*algo);
            let encrypted = config.encrypt_with_aad(plaintext, aad).unwrap();
            let decrypted = config.decrypt_with_aad(&encrypted, aad).unwrap();
            assert_eq!(decrypted, plaintext, "AAD round-trip failed for {:?}", algo);
        }
    }

    #[test]
    fn aad_mismatch_fails() {
        let config = EncryptionConfig::from_password("test");
        let plaintext = b"sensitive data";
        let aad = b"table:users";

        let encrypted = config.encrypt_with_aad(plaintext, aad).unwrap();

        let result = config.decrypt_with_aad(&encrypted, b"table:orders");
        assert!(result.is_err(), "Wrong AAD should fail authentication");
    }

    #[test]
    fn aad_mismatch_fails_all_algorithms() {
        for algo in EncryptionAlgorithm::ALL {
            let config = EncryptionConfig::from_password("test").with_algorithm(*algo);
            let encrypted = config.encrypt_with_aad(b"sensitive data", b"table:users").unwrap();
            assert!(
                config.decrypt_with_aad(&encrypted, b"table:orders").is_err(),
                "Wrong AAD accepted for {:?}",
                algo
            );
            // Dropping the AAD entirely must also fail.
            assert!(
                config.decrypt(&encrypted).is_err(),
                "Missing AAD accepted for {:?}",
                algo
            );
        }
    }

    #[test]
    fn empty_aad_is_interchangeable_with_plain_api() {
        let plaintext = b"interchangeable";
        for algo in EncryptionAlgorithm::ALL {
            let config = EncryptionConfig::from_password("test").with_algorithm(*algo);

            let plain = config.encrypt(plaintext).unwrap();
            assert_eq!(config.decrypt_with_aad(&plain, b"").unwrap(), plaintext);

            let with_aad = config.encrypt_with_aad(plaintext, b"").unwrap();
            assert_eq!(config.decrypt(&with_aad).unwrap(), plaintext);
        }
    }

    #[test]
    fn display_names() {
        assert_eq!(
            format!("{}", EncryptionAlgorithm::Aes256GcmSiv),
            "AES-256-GCM-SIV"
        );
        assert_eq!(
            format!("{}", EncryptionAlgorithm::ChaCha20Poly1305),
            "ChaCha20-Poly1305"
        );
    }

    #[test]
    fn all_algorithms_count() {
        assert_eq!(EncryptionAlgorithm::ALL.len(), 2);
    }

    #[test]
    fn debug_redacts_key() {
        let config = EncryptionConfig::from_password("secret");
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("REDACTED"));
        assert!(!debug_str.contains("secret"));

        // The password is never stored, so the check above cannot catch a key
        // leak. Use a recognisable raw key and make sure its bytes are not
        // printed (0xAB == 171).
        let raw = EncryptionConfig::from_key([0xABu8; KEY_SIZE]);
        let raw_debug = format!("{:?}", raw);
        assert!(raw_debug.contains("REDACTED"));
        assert!(!raw_debug.contains("171"), "key bytes leaked: {}", raw_debug);
    }

    #[test]
    fn wire_format_structure() {
        let plaintext = b"hello";
        for algo in EncryptionAlgorithm::ALL {
            let config = EncryptionConfig::from_password("test").with_algorithm(*algo);
            let encrypted = config.encrypt(plaintext).unwrap();

            assert_eq!(
                encrypted.len(),
                NONCE_SIZE + plaintext.len() + TAG_SIZE,
                "Wire format should be nonce + plaintext + tag ({:?})",
                algo
            );
        }
    }
}