1use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, AES_256_GCM};
42use ring::hkdf::{Salt, HKDF_SHA256};
43use ring::rand::{SecureRandom, SystemRandom};
44use secrecy::{ExposeSecret, SecretBox};
45use serde::{Deserialize, Serialize};
46use std::fmt;
47use std::fs;
48use std::io;
49use std::path::PathBuf;
50use std::sync::atomic::{AtomicU32, Ordering};
51use std::sync::Arc;
52use thiserror::Error;
53
/// Length in bytes of the AEAD nonce stored in every record header.
const NONCE_SIZE: usize = 12;

/// Length in bytes of the authentication tag appended after the ciphertext.
const TAG_SIZE: usize = 16;

/// Length in bytes of master and derived data keys.
const KEY_SIZE: usize = 32;

/// Magic prefix identifying an encrypted record (ASCII `"RVEN"`).
const ENCRYPTION_MAGIC: [u8; 4] = *b"RVEN";

/// Version of the encrypted record layout written by this module.
const FORMAT_VERSION: u8 = 1;
68
69#[derive(Debug, Error)]
71pub enum EncryptionError {
72 #[error("key provider error: {0}")]
73 KeyProvider(String),
74
75 #[error("encryption failed: {0}")]
76 Encryption(String),
77
78 #[error("decryption failed: {0}")]
79 Decryption(String),
80
81 #[error("invalid key: {0}")]
82 InvalidKey(String),
83
84 #[error("key rotation error: {0}")]
85 KeyRotation(String),
86
87 #[error("io error: {0}")]
88 Io(#[from] io::Error),
89
90 #[error("invalid format: {0}")]
91 InvalidFormat(String),
92
93 #[error("unsupported version: {0}")]
94 UnsupportedVersion(u8),
95
96 #[error("key not found: version {0}")]
97 KeyNotFound(u32),
98}
99
100pub type Result<T> = std::result::Result<T, EncryptionError>;
102
103#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
105#[serde(rename_all = "kebab-case")]
106pub enum Algorithm {
107 #[default]
109 Aes256Gcm,
110 ChaCha20Poly1305,
112}
113
114impl Algorithm {
115 pub fn key_size(&self) -> usize {
117 match self {
118 Algorithm::Aes256Gcm => 32,
119 Algorithm::ChaCha20Poly1305 => 32,
120 }
121 }
122
123 pub fn nonce_size(&self) -> usize {
125 match self {
126 Algorithm::Aes256Gcm => 12,
127 Algorithm::ChaCha20Poly1305 => 12,
128 }
129 }
130
131 pub fn tag_size(&self) -> usize {
133 match self {
134 Algorithm::Aes256Gcm => 16,
135 Algorithm::ChaCha20Poly1305 => 16,
136 }
137 }
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize)]
142#[serde(tag = "type", rename_all = "kebab-case")]
143pub enum KeyProvider {
144 File { path: PathBuf },
146
147 Environment { variable: String },
149
150 #[serde(skip)]
154 InMemory(#[serde(skip)] Vec<u8>),
155}
156
157impl Default for KeyProvider {
158 fn default() -> Self {
159 KeyProvider::Environment {
160 variable: "RIVVEN_ENCRYPTION_KEY".to_string(),
161 }
162 }
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct EncryptionConfig {
168 #[serde(default)]
170 pub enabled: bool,
171
172 #[serde(default)]
174 pub algorithm: Algorithm,
175
176 #[serde(default)]
178 pub key_provider: KeyProvider,
179
180 #[serde(default)]
182 pub key_rotation_days: u32,
183
184 #[serde(default = "default_aad_scope")]
186 pub aad_scope: String,
187}
188
/// Default value for `EncryptionConfig::aad_scope`.
fn default_aad_scope() -> String {
    String::from("rivven")
}
192
193impl Default for EncryptionConfig {
194 fn default() -> Self {
195 Self {
196 enabled: false,
197 algorithm: Algorithm::default(),
198 key_provider: KeyProvider::default(),
199 key_rotation_days: 0,
200 aad_scope: default_aad_scope(),
201 }
202 }
203}
204
205impl EncryptionConfig {
206 pub fn new() -> Self {
208 Self::default()
209 }
210
211 pub fn enabled(mut self) -> Self {
213 self.enabled = true;
214 self
215 }
216
217 pub fn with_algorithm(mut self, algorithm: Algorithm) -> Self {
219 self.algorithm = algorithm;
220 self
221 }
222
223 pub fn with_key_provider(mut self, provider: KeyProvider) -> Self {
225 self.key_provider = provider;
226 self
227 }
228
229 pub fn with_key_rotation_days(mut self, days: u32) -> Self {
231 self.key_rotation_days = days;
232 self
233 }
234}
235
236#[derive(Debug, Clone)]
247pub struct EncryptedHeader {
248 pub version: u8,
249 pub algorithm: Algorithm,
250 pub key_version: u32,
251 pub nonce: [u8; NONCE_SIZE],
252}
253
254impl EncryptedHeader {
255 pub const SIZE: usize = 24;
257
258 pub fn new(algorithm: Algorithm, key_version: u32, nonce: [u8; NONCE_SIZE]) -> Self {
260 Self {
261 version: FORMAT_VERSION,
262 algorithm,
263 key_version,
264 nonce,
265 }
266 }
267
268 pub fn to_bytes(&self) -> [u8; Self::SIZE] {
270 let mut buf = [0u8; Self::SIZE];
271 buf[0..4].copy_from_slice(&ENCRYPTION_MAGIC);
272 buf[4] = self.version;
273 buf[5] = self.algorithm as u8;
274 buf[6..10].copy_from_slice(&self.key_version.to_be_bytes());
275 buf[10..22].copy_from_slice(&self.nonce);
276 buf
278 }
279
280 pub fn from_bytes(data: &[u8]) -> Result<Self> {
282 if data.len() < Self::SIZE {
283 return Err(EncryptionError::InvalidFormat(format!(
284 "header too short: {} < {}",
285 data.len(),
286 Self::SIZE
287 )));
288 }
289
290 if data[0..4] != ENCRYPTION_MAGIC {
292 return Err(EncryptionError::InvalidFormat("invalid magic bytes".into()));
293 }
294
295 let version = data[4];
296 if version != FORMAT_VERSION {
297 return Err(EncryptionError::UnsupportedVersion(version));
298 }
299
300 let algorithm = match data[5] {
301 0 => Algorithm::Aes256Gcm,
302 1 => Algorithm::ChaCha20Poly1305,
303 v => {
304 return Err(EncryptionError::InvalidFormat(format!(
305 "unknown algorithm: {}",
306 v
307 )))
308 }
309 };
310
311 let key_version = u32::from_be_bytes([data[6], data[7], data[8], data[9]]);
312 let mut nonce = [0u8; NONCE_SIZE];
313 nonce.copy_from_slice(&data[10..22]);
314
315 Ok(Self {
316 version,
317 algorithm,
318 key_version,
319 nonce,
320 })
321 }
322}
323
324pub struct MasterKey {
326 key: SecretBox<[u8; KEY_SIZE]>,
327 version: u32,
328}
329
330impl MasterKey {
331 pub fn new(key: Vec<u8>, version: u32) -> Result<Self> {
333 if key.len() != KEY_SIZE {
334 return Err(EncryptionError::InvalidKey(format!(
335 "key must be {} bytes, got {}",
336 KEY_SIZE,
337 key.len()
338 )));
339 }
340 let mut key_array = [0u8; KEY_SIZE];
341 key_array.copy_from_slice(&key);
342 Ok(Self {
343 key: SecretBox::new(Box::new(key_array)),
344 version,
345 })
346 }
347
348 pub fn generate(version: u32) -> Result<Self> {
350 let rng = SystemRandom::new();
351 let mut key = vec![0u8; KEY_SIZE];
352 rng.fill(&mut key)
353 .map_err(|_| EncryptionError::KeyProvider("failed to generate random key".into()))?;
354 Self::new(key, version)
355 }
356
357 pub fn from_provider(provider: &KeyProvider) -> Result<Self> {
359 match provider {
360 KeyProvider::File { path } => {
361 let data = fs::read(path)?;
362 let key = if data.len() == KEY_SIZE {
363 data
365 } else {
366 let hex_str = String::from_utf8(data)
368 .map_err(|_| EncryptionError::InvalidKey("invalid key file format".into()))?
369 .trim()
370 .to_string();
371 hex::decode(&hex_str).map_err(|e| {
372 EncryptionError::InvalidKey(format!("invalid hex key: {}", e))
373 })?
374 };
375 Self::new(key, 1)
376 }
377 KeyProvider::Environment { variable } => {
378 let hex_key = std::env::var(variable).map_err(|_| {
379 EncryptionError::KeyProvider(format!(
380 "environment variable '{}' not set",
381 variable
382 ))
383 })?;
384 let key = hex::decode(hex_key.trim()).map_err(|e| {
385 EncryptionError::InvalidKey(format!("invalid hex key in env var: {}", e))
386 })?;
387 Self::new(key, 1)
388 }
389 KeyProvider::InMemory(key) => Self::new(key.clone(), 1),
390 #[allow(unreachable_patterns)]
391 _ => Err(EncryptionError::KeyProvider(
392 "unsupported key provider".into(),
393 )),
394 }
395 }
396
397 pub fn version(&self) -> u32 {
399 self.version
400 }
401
402 fn derive_data_key(&self, info: &[u8]) -> Result<[u8; KEY_SIZE]> {
404 let salt = Salt::new(HKDF_SHA256, b"rivven-encryption-v1");
405 let prk = salt.extract(self.key.expose_secret());
406 let info_refs = [info];
407 let okm = prk
408 .expand(&info_refs, DataKeyLen)
409 .map_err(|_| EncryptionError::Encryption("key derivation failed".into()))?;
410
411 let mut data_key = [0u8; KEY_SIZE];
412 okm.fill(&mut data_key)
413 .map_err(|_| EncryptionError::Encryption("key expansion failed".into()))?;
414 Ok(data_key)
415 }
416}
417
418impl fmt::Debug for MasterKey {
419 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
420 f.debug_struct("MasterKey")
421 .field("version", &self.version)
422 .field("key", &"[REDACTED]")
423 .finish()
424 }
425}
426
427struct DataKeyLen;
429
430impl ring::hkdf::KeyType for DataKeyLen {
431 fn len(&self) -> usize {
432 KEY_SIZE
433 }
434}
435
436pub struct EncryptionManager {
438 config: EncryptionConfig,
439 master_key: MasterKey,
440 data_key: LessSafeKey,
441 rng: SystemRandom,
442 current_key_version: AtomicU32,
443}
444
445impl EncryptionManager {
446 pub fn new(config: EncryptionConfig) -> Result<Arc<Self>> {
448 let master_key = MasterKey::from_provider(&config.key_provider)?;
449 let data_key_bytes = master_key.derive_data_key(config.aad_scope.as_bytes())?;
450
451 let unbound_key = UnboundKey::new(&AES_256_GCM, &data_key_bytes)
452 .map_err(|_| EncryptionError::InvalidKey("failed to create encryption key".into()))?;
453
454 let data_key = LessSafeKey::new(unbound_key);
455
456 Ok(Arc::new(Self {
457 config,
458 current_key_version: AtomicU32::new(master_key.version()),
459 master_key,
460 data_key,
461 rng: SystemRandom::new(),
462 }))
463 }
464
465 pub fn disabled() -> Arc<DisabledEncryption> {
467 Arc::new(DisabledEncryption)
468 }
469
470 pub fn is_enabled(&self) -> bool {
472 self.config.enabled
473 }
474
475 pub fn key_version(&self) -> u32 {
477 self.current_key_version.load(Ordering::Relaxed)
478 }
479
480 fn generate_nonce(&self, lsn: u64) -> [u8; NONCE_SIZE] {
482 let mut nonce = [0u8; NONCE_SIZE];
483 nonce[0..8].copy_from_slice(&lsn.to_be_bytes());
485 self.rng.fill(&mut nonce[8..12]).ok();
486 nonce
487 }
488
489 pub fn encrypt(&self, plaintext: &[u8], lsn: u64) -> Result<Vec<u8>> {
491 let nonce_bytes = self.generate_nonce(lsn);
492 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
493
494 let header = EncryptedHeader::new(
495 self.config.algorithm,
496 self.current_key_version.load(Ordering::Relaxed),
497 nonce_bytes,
498 );
499
500 let mut output = Vec::with_capacity(EncryptedHeader::SIZE + plaintext.len() + TAG_SIZE);
502 output.extend_from_slice(&header.to_bytes());
503 output.extend_from_slice(plaintext);
504
505 let ciphertext_start = EncryptedHeader::SIZE;
507 let tag = self
508 .data_key
509 .seal_in_place_separate_tag(
510 nonce,
511 Aad::from(self.config.aad_scope.as_bytes()),
512 &mut output[ciphertext_start..],
513 )
514 .map_err(|_| EncryptionError::Encryption("seal failed".into()))?;
515
516 output.extend_from_slice(tag.as_ref());
518
519 Ok(output)
520 }
521
522 pub fn decrypt(&self, ciphertext: &[u8], _lsn: u64) -> Result<Vec<u8>> {
524 if ciphertext.len() < EncryptedHeader::SIZE + TAG_SIZE {
525 return Err(EncryptionError::InvalidFormat(
526 "ciphertext too short".into(),
527 ));
528 }
529
530 let header = EncryptedHeader::from_bytes(ciphertext)?;
531
532 if header.key_version != self.master_key.version() {
534 return Err(EncryptionError::KeyNotFound(header.key_version));
535 }
536
537 let nonce = Nonce::assume_unique_for_key(header.nonce);
538
539 let mut buffer = ciphertext[EncryptedHeader::SIZE..].to_vec();
541
542 let plaintext = self
543 .data_key
544 .open_in_place(
545 nonce,
546 Aad::from(self.config.aad_scope.as_bytes()),
547 &mut buffer,
548 )
549 .map_err(|_| EncryptionError::Decryption("authentication failed".into()))?;
550
551 Ok(plaintext.to_vec())
552 }
553
554 pub fn encrypted_size(&self, plaintext_len: usize) -> usize {
556 EncryptedHeader::SIZE + plaintext_len + TAG_SIZE
557 }
558}
559
560impl fmt::Debug for EncryptionManager {
561 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
562 f.debug_struct("EncryptionManager")
563 .field("enabled", &self.config.enabled)
564 .field("algorithm", &self.config.algorithm)
565 .field("key_version", &self.key_version())
566 .finish()
567 }
568}
569
570#[derive(Debug)]
572pub struct DisabledEncryption;
573
574impl DisabledEncryption {
575 pub fn encrypt(&self, plaintext: &[u8], _lsn: u64) -> Result<Vec<u8>> {
577 Ok(plaintext.to_vec())
578 }
579
580 pub fn decrypt(&self, ciphertext: &[u8], _lsn: u64) -> Result<Vec<u8>> {
582 Ok(ciphertext.to_vec())
583 }
584
585 pub fn encrypted_size(&self, plaintext_len: usize) -> usize {
587 plaintext_len
588 }
589
590 pub fn is_enabled(&self) -> bool {
591 false
592 }
593}
594
595pub trait Encryptor: Send + Sync + std::fmt::Debug {
597 fn encrypt(&self, plaintext: &[u8], lsn: u64) -> Result<Vec<u8>>;
598 fn decrypt(&self, ciphertext: &[u8], lsn: u64) -> Result<Vec<u8>>;
599 fn encrypted_size(&self, plaintext_len: usize) -> usize;
600 fn is_enabled(&self) -> bool;
601}
602
603impl Encryptor for EncryptionManager {
604 fn encrypt(&self, plaintext: &[u8], lsn: u64) -> Result<Vec<u8>> {
605 self.encrypt(plaintext, lsn)
606 }
607
608 fn decrypt(&self, ciphertext: &[u8], lsn: u64) -> Result<Vec<u8>> {
609 self.decrypt(ciphertext, lsn)
610 }
611
612 fn encrypted_size(&self, plaintext_len: usize) -> usize {
613 self.encrypted_size(plaintext_len)
614 }
615
616 fn is_enabled(&self) -> bool {
617 self.is_enabled()
618 }
619}
620
621impl Encryptor for DisabledEncryption {
622 fn encrypt(&self, plaintext: &[u8], lsn: u64) -> Result<Vec<u8>> {
623 self.encrypt(plaintext, lsn)
624 }
625
626 fn decrypt(&self, ciphertext: &[u8], lsn: u64) -> Result<Vec<u8>> {
627 self.decrypt(ciphertext, lsn)
628 }
629
630 fn encrypted_size(&self, plaintext_len: usize) -> usize {
631 self.encrypted_size(plaintext_len)
632 }
633
634 fn is_enabled(&self) -> bool {
635 false
636 }
637}
638
639pub fn generate_key_file(path: &std::path::Path) -> Result<()> {
641 let key = MasterKey::generate(1)?;
642 let hex_key = hex::encode(key.key.expose_secret());
643 fs::write(path, hex_key)?;
644
645 #[cfg(unix)]
647 {
648 use std::os::unix::fs::PermissionsExt;
649 let mut perms = fs::metadata(path)?.permissions();
650 perms.set_mode(0o600);
651 fs::set_permissions(path, perms)?;
652 }
653
654 Ok(())
655}
656
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled AES-256-GCM config with an all-zero in-memory key (tests only).
    fn test_config() -> EncryptionConfig {
        let key = vec![0u8; 32];
        EncryptionConfig {
            enabled: true,
            algorithm: Algorithm::Aes256Gcm,
            key_provider: KeyProvider::InMemory(key),
            key_rotation_days: 0,
            aad_scope: "test".to_string(),
        }
    }

    #[test]
    fn test_encrypt_decrypt() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        let plaintext = b"Hello, World! This is sensitive data.";
        let lsn = 12345u64;

        let ciphertext = manager.encrypt(plaintext, lsn).unwrap();
        assert_ne!(ciphertext.as_slice(), plaintext);
        assert!(ciphertext.len() > plaintext.len());

        let decrypted = manager.decrypt(&ciphertext, lsn).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypted_size() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        let plaintext_len = 1000;

        let expected = EncryptedHeader::SIZE + plaintext_len + TAG_SIZE;
        assert_eq!(manager.encrypted_size(plaintext_len), expected);
    }

    #[test]
    fn test_header_roundtrip() {
        let nonce = [1u8; NONCE_SIZE];
        let header = EncryptedHeader::new(Algorithm::Aes256Gcm, 42, nonce);

        let bytes = header.to_bytes();
        let parsed = EncryptedHeader::from_bytes(&bytes).unwrap();

        assert_eq!(parsed.version, header.version);
        assert_eq!(parsed.algorithm, header.algorithm);
        assert_eq!(parsed.key_version, header.key_version);
        assert_eq!(parsed.nonce, header.nonce);
    }

    #[test]
    fn test_header_rejects_unsupported_version_and_algorithm() {
        let base = EncryptedHeader::new(Algorithm::ChaCha20Poly1305, 9, [7u8; NONCE_SIZE]);

        let mut bytes = base.to_bytes();
        bytes[4] = FORMAT_VERSION + 1;
        assert!(matches!(
            EncryptedHeader::from_bytes(&bytes),
            Err(EncryptionError::UnsupportedVersion(v)) if v == FORMAT_VERSION + 1
        ));

        let mut bytes = base.to_bytes();
        bytes[5] = 0xFF;
        assert!(matches!(
            EncryptedHeader::from_bytes(&bytes),
            Err(EncryptionError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_invalid_ciphertext() {
        let manager = EncryptionManager::new(test_config()).unwrap();

        // Shorter than header + tag.
        let result = manager.decrypt(&[0u8; 10], 1);
        assert!(result.is_err());

        // Long enough, but wrong magic.
        let mut bad_magic = vec![0u8; 100];
        let result = manager.decrypt(&bad_magic, 1);
        assert!(result.is_err());

        // Valid magic and version, but garbage payload.
        bad_magic[0..4].copy_from_slice(&ENCRYPTION_MAGIC);
        bad_magic[4] = FORMAT_VERSION;
        let result = manager.decrypt(&bad_magic, 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_key_version_mismatch_is_reported() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        let mut ciphertext = manager.encrypt(b"payload", 1).unwrap();

        // Rewrite the header's key version (bytes 6..10, big-endian).
        ciphertext[6..10].copy_from_slice(&2u32.to_be_bytes());
        assert!(matches!(
            manager.decrypt(&ciphertext, 1),
            Err(EncryptionError::KeyNotFound(2))
        ));
    }

    #[test]
    fn test_tamper_detection() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        let plaintext = b"Sensitive data that must not be tampered with";

        let mut ciphertext = manager.encrypt(plaintext, 1).unwrap();

        // Flip one bit inside the ciphertext body.
        let tamper_pos = EncryptedHeader::SIZE + 10;
        ciphertext[tamper_pos] ^= 0x01;

        let result = manager.decrypt(&ciphertext, 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_different_lsns_produce_different_ciphertexts() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        let plaintext = b"Same plaintext";

        let ct1 = manager.encrypt(plaintext, 1).unwrap();
        let ct2 = manager.encrypt(plaintext, 2).unwrap();

        assert_ne!(ct1, ct2);

        assert_eq!(manager.decrypt(&ct1, 1).unwrap(), plaintext);
        assert_eq!(manager.decrypt(&ct2, 2).unwrap(), plaintext);
    }

    #[test]
    fn test_chacha20_poly1305_roundtrip() {
        let config = EncryptionConfig {
            algorithm: Algorithm::ChaCha20Poly1305,
            ..test_config()
        };
        let manager = EncryptionManager::new(config).unwrap();
        let plaintext = b"ChaCha20-Poly1305 payload";

        let ciphertext = manager.encrypt(plaintext, 7).unwrap();
        let header = EncryptedHeader::from_bytes(&ciphertext).unwrap();
        assert_eq!(header.algorithm, Algorithm::ChaCha20Poly1305);

        assert_eq!(manager.decrypt(&ciphertext, 7).unwrap(), plaintext);
    }

    #[test]
    fn test_disabled_encryption_passthrough() {
        let disabled = DisabledEncryption;
        let plaintext = b"Not encrypted";

        let encrypted = disabled.encrypt(plaintext, 1).unwrap();
        assert_eq!(&encrypted[..], plaintext);

        let decrypted = disabled.decrypt(plaintext, 1).unwrap();
        assert_eq!(&decrypted[..], plaintext);

        assert_eq!(disabled.encrypted_size(100), 100);
        assert!(!disabled.is_enabled());
    }

    #[test]
    fn test_master_key_validation() {
        let result = MasterKey::new(vec![0u8; 16], 1);
        assert!(result.is_err());

        let result = MasterKey::new(vec![0u8; 32], 1);
        assert!(result.is_ok());
    }

    #[test]
    fn test_master_key_debug_is_redacted() {
        let key = MasterKey::new(vec![0xAB; 32], 3).unwrap();
        let rendered = format!("{:?}", key);
        assert!(rendered.contains("[REDACTED]"));
        // 0xAB formats as 171; it must not appear anywhere in the output.
        assert!(!rendered.contains("171"));
    }

    #[test]
    fn test_key_derivation_consistency() {
        let key = MasterKey::new(vec![42u8; 32], 1).unwrap();

        let dk1 = key.derive_data_key(b"scope1").unwrap();
        let dk2 = key.derive_data_key(b"scope1").unwrap();
        let dk3 = key.derive_data_key(b"scope2").unwrap();

        // Same scope yields the same key.
        assert_eq!(dk1, dk2);
        // Different scopes yield different keys.
        assert_ne!(dk1, dk3);
    }

    #[test]
    fn test_large_data_encryption() {
        let manager = EncryptionManager::new(test_config()).unwrap();

        let plaintext: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();

        let ciphertext = manager.encrypt(&plaintext, 999999).unwrap();
        let decrypted = manager.decrypt(&ciphertext, 999999).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_generate_key() {
        let key = MasterKey::generate(1).unwrap();
        assert_eq!(key.version(), 1);
    }
}