1use std::sync::Arc;
28
29use aes_gcm::aead::{Aead, KeyInit};
30use aes_gcm::{Aes256Gcm, Key, Nonce};
31use hkdf::Hkdf;
32use hmac::{Hmac, Mac};
33use parking_lot::RwLock;
34use serde::{Deserialize, Serialize};
35use sha2::Sha256;
36
37use crate::error::{RaftError, RaftResult};
41use crate::key_rotation::{KeyManager, KeyVersion, LEGACY_KEY_VERSION};
42
43type HmacSha256 = Hmac<Sha256>;
44
/// A 32-byte secret key used as the master key for log-entry encryption.
///
/// The raw bytes are private to this type; crate-internal code reads them
/// through `as_bytes`.
pub struct LogEncryptionKey {
    /// Raw key material, always exactly 32 bytes.
    key_bytes: [u8; 32],
}
53
54impl LogEncryptionKey {
55 pub fn new(key_bytes: [u8; 32]) -> Self {
57 Self { key_bytes }
58 }
59
60 pub fn from_slice(bytes: &[u8]) -> RaftResult<Self> {
65 let key_bytes: [u8; 32] = bytes.try_into().map_err(|_| RaftError::StorageError {
66 message: format!(
67 "LogEncryptionKey requires exactly 32 bytes, got {}",
68 bytes.len()
69 ),
70 })?;
71 Ok(Self { key_bytes })
72 }
73
74 pub(crate) fn as_bytes(&self) -> &[u8; 32] {
80 &self.key_bytes
81 }
82
83 pub fn random() -> Self {
89 use std::collections::hash_map::RandomState;
90 use std::hash::{BuildHasher, Hasher};
91 use std::time::{SystemTime, UNIX_EPOCH};
92
93 let ts_nanos: u128 = SystemTime::now()
94 .duration_since(UNIX_EPOCH)
95 .map(|d| d.as_nanos())
96 .unwrap_or(0u128);
97
98 let rs1 = RandomState::new();
100 let rs2 = RandomState::new();
101 let rs3 = RandomState::new();
102 let rs4 = RandomState::new();
103
104 let h1: u64 = {
105 let mut h = rs1.build_hasher();
106 h.write_u128(ts_nanos);
107 h.finish()
108 };
109 let h2: u64 = {
110 let mut h = rs2.build_hasher();
111 h.write_u128(ts_nanos ^ 0xcafe_babe_dead_beef_1234_5678_abcd_ef01_u128);
113 h.finish()
114 };
115 let h3: u64 = {
116 let mut h = rs3.build_hasher();
117 h.write_u64(h1);
118 h.write_u64(h2);
119 h.finish()
120 };
121 let h4: u64 = {
122 let mut h = rs4.build_hasher();
123 h.write_u64(h2 ^ h3);
124 h.write_u128(ts_nanos.wrapping_add(0x9e37_79b9_7f4a_7c15_f39c_c060_5c0e_d609_u128));
125 h.finish()
126 };
127
128 let mut ikm = [0u8; 32];
130 ikm[0..8].copy_from_slice(&h1.to_le_bytes());
131 ikm[8..16].copy_from_slice(&h2.to_le_bytes());
132 ikm[16..24].copy_from_slice(&h3.to_le_bytes());
133 ikm[24..32].copy_from_slice(&h4.to_le_bytes());
134
135 let salt = b"amaters-log-encryption-key-v1";
136 let hk = Hkdf::<Sha256>::new(Some(salt), &ikm);
137 let mut key_bytes = [0u8; 32];
138 hk.expand(b"master-key", &mut key_bytes)
140 .expect("HKDF expand for 32 bytes cannot fail");
141
142 Self { key_bytes }
143 }
144}
145
/// Serializable result of encrypting one log entry.
///
/// Produced by `EntryEncryptor::encrypt` and consumed by
/// `EntryEncryptor::decrypt` and `LogIntegrityVerifier`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedPayload {
    /// Bytes returned by the AES-256-GCM `encrypt` call for this entry.
    pub ciphertext: Vec<u8>,
    /// The 12-byte nonce that was used for encryption; `decrypt` reads it
    /// back from here.
    pub nonce: [u8; 12],
    /// Version of the master key the entry was encrypted under.
    ///
    /// When the field is missing from serialized input, deserialization
    /// fills it in via `default_key_version` (i.e. `LEGACY_KEY_VERSION`).
    #[serde(default = "default_key_version")]
    pub key_version: KeyVersion,
}
174
/// Serde default for `EncryptedPayload::key_version`: payloads serialized
/// without a `key_version` field are assigned `LEGACY_KEY_VERSION`.
fn default_key_version() -> KeyVersion {
    LEGACY_KEY_VERSION
}
181
/// Encrypts and decrypts individual log entries with AES-256-GCM, using a
/// per-entry key derived from a master key held in a shared `KeyManager`.
pub struct EntryEncryptor {
    /// Shared, lock-protected key manager; read for every encrypt/decrypt
    /// and handed out to callers via `key_manager()`.
    keys: Arc<RwLock<KeyManager>>,
}
199
200impl EntryEncryptor {
201 pub fn new(key: LogEncryptionKey) -> Self {
207 let mgr = KeyManager::new(key, 1);
208 Self {
209 keys: Arc::new(RwLock::new(mgr)),
210 }
211 }
212
213 pub fn with_key_manager(keys: Arc<RwLock<KeyManager>>) -> Self {
221 Self { keys }
222 }
223
224 pub fn key_manager(&self) -> &Arc<RwLock<KeyManager>> {
227 &self.keys
228 }
229
230 fn derive_key_and_nonce_from(
233 master_key: &LogEncryptionKey,
234 entry_index: u64,
235 ) -> RaftResult<([u8; 32], [u8; 12])> {
236 let hk = Hkdf::<Sha256>::new(None, master_key.as_bytes());
237 let mut derived = [0u8; 44]; hk.expand(&entry_index.to_le_bytes(), &mut derived)
239 .map_err(|e| RaftError::StorageError {
240 message: format!("HKDF expand failed for entry {entry_index}: {e}"),
241 })?;
242
243 let mut key = [0u8; 32];
244 let mut nonce = [0u8; 12];
245 key.copy_from_slice(&derived[..32]);
246 nonce.copy_from_slice(&derived[32..44]);
247 Ok((key, nonce))
248 }
249
250 pub fn encrypt(&self, entry_index: u64, plaintext: &[u8]) -> RaftResult<EncryptedPayload> {
259 let guard = self.keys.read();
260 let (key_version, master_key) = guard.current();
261 let (key_bytes, nonce_bytes) = Self::derive_key_and_nonce_from(master_key, entry_index)?;
262
263 let key = Key::<Aes256Gcm>::from(key_bytes);
264 let cipher = Aes256Gcm::new(&key);
265 let nonce = Nonce::from(nonce_bytes);
266
267 let ciphertext =
268 cipher
269 .encrypt(&nonce, plaintext)
270 .map_err(|e| RaftError::StorageError {
271 message: format!("AES-256-GCM encryption failed for entry {entry_index}: {e}"),
272 })?;
273
274 Ok(EncryptedPayload {
275 ciphertext,
276 nonce: nonce_bytes,
277 key_version,
278 })
279 }
280
281 pub fn decrypt(&self, entry_index: u64, payload: &EncryptedPayload) -> RaftResult<Vec<u8>> {
294 let guard = self.keys.read();
295 let master_key =
296 guard
297 .lookup(payload.key_version)
298 .ok_or_else(|| RaftError::StorageError {
299 message: format!(
300 "EntryEncryptor::decrypt: key version {} not in KeyManager (history exhausted or unknown)",
301 payload.key_version
302 ),
303 })?;
304 let (key_bytes, _derived_nonce) = Self::derive_key_and_nonce_from(master_key, entry_index)?;
305
306 let key = Key::<Aes256Gcm>::from(key_bytes);
307 let cipher = Aes256Gcm::new(&key);
308 let nonce = Nonce::from(payload.nonce);
309
310 cipher
311 .decrypt(&nonce, payload.ciphertext.as_ref())
312 .map_err(|e| RaftError::StorageError {
313 message: format!("AES-256-GCM decryption failed for entry {entry_index}: {e}"),
314 })
315 }
316}
317
/// Computes and checks HMAC-SHA256 tags over encrypted log entries.
pub struct LogIntegrityVerifier {
    /// 32-byte HMAC key.
    key: [u8; 32],
}
329
330impl LogIntegrityVerifier {
331 pub fn new(key: [u8; 32]) -> Self {
333 Self { key }
334 }
335
336 pub fn compute(&self, entry_index: u64, payload: &EncryptedPayload) -> [u8; 32] {
338 let mut mac = <HmacSha256 as KeyInit>::new_from_slice(&self.key)
339 .expect("HMAC-SHA256 accepts any key size including 32 bytes");
340 mac.update(&entry_index.to_le_bytes());
341 mac.update(&payload.nonce);
342 mac.update(&payload.ciphertext);
343
344 let result = mac.finalize().into_bytes();
345 let mut tag = [0u8; 32];
346 tag.copy_from_slice(&result);
347 tag
348 }
349
350 pub fn verify(
355 &self,
356 entry_index: u64,
357 payload: &EncryptedPayload,
358 tag: &[u8; 32],
359 ) -> RaftResult<()> {
360 let mut mac = <HmacSha256 as KeyInit>::new_from_slice(&self.key)
361 .expect("HMAC-SHA256 accepts any key size including 32 bytes");
362 mac.update(&entry_index.to_le_bytes());
363 mac.update(&payload.nonce);
364 mac.update(&payload.ciphertext);
365
366 mac.verify_slice(tag).map_err(|_| RaftError::StorageError {
368 message: "HMAC-SHA256 integrity verification failed: tag mismatch".to_string(),
369 })
370 }
371}
372
/// Unit tests for key construction, entry encryption/decryption across key
/// rotation, HMAC integrity checking, and payload (de)serialization.
#[cfg(test)]
mod tests {
    use super::*;

    /// A payload encrypted for an index decrypts back to the original bytes
    /// when the same index and encryptor are used.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = LogEncryptionKey::random();
        let encryptor = EntryEncryptor::new(key);
        let plaintext = b"Hello, Raft log entry!";

        let payload = encryptor
            .encrypt(42, plaintext)
            .expect("encrypt should succeed");
        let decrypted = encryptor
            .decrypt(42, &payload)
            .expect("decrypt should succeed");

        assert_eq!(decrypted.as_slice(), plaintext.as_ref());
    }

    /// The same plaintext encrypted at two different indices yields different
    /// ciphertexts and different nonces (both are derived from the index).
    #[test]
    fn test_different_indices_produce_different_ciphertexts() {
        let key = LogEncryptionKey::new([0xab; 32]);
        let encryptor = EntryEncryptor::new(key);
        let plaintext = b"same plaintext for both entries";

        let payload1 = encryptor.encrypt(1, plaintext).expect("encrypt entry 1");
        let payload2 = encryptor.encrypt(2, plaintext).expect("encrypt entry 2");

        assert_ne!(payload1.ciphertext, payload2.ciphertext);
        assert_ne!(payload1.nonce, payload2.nonce);
    }

    /// A tag produced by `compute` is accepted by `verify` for the same
    /// index and payload.
    #[test]
    fn test_hmac_verify_valid() {
        let key = [0x12u8; 32];
        let verifier = LogIntegrityVerifier::new(key);
        let payload = EncryptedPayload {
            ciphertext: vec![0xde, 0xad, 0xbe, 0xef],
            nonce: [0u8; 12],
            key_version: 1,
        };

        let tag = verifier.compute(7, &payload);
        verifier
            .verify(7, &payload, &tag)
            .expect("HMAC should verify successfully");
    }

    /// Flipping bits in the ciphertext after the tag was computed makes
    /// `verify` fail.
    #[test]
    fn test_hmac_verify_tampered_fails() {
        let key = [0x34u8; 32];
        let verifier = LogIntegrityVerifier::new(key);
        let mut payload = EncryptedPayload {
            ciphertext: vec![0x01, 0x02, 0x03, 0x04, 0x05],
            nonce: [0u8; 12],
            key_version: 1,
        };

        let tag = verifier.compute(99, &payload);

        // Corrupt one ciphertext byte so the stored tag no longer matches.
        payload.ciphertext[2] ^= 0xff;

        let result = verifier.verify(99, &payload, &tag);
        assert!(
            result.is_err(),
            "verification of tampered payload should fail"
        );
    }

    /// `from_slice` accepts exactly 32 bytes and rejects shorter or longer
    /// input.
    #[test]
    fn test_key_from_slice_wrong_length() {
        let too_short = [0u8; 16];
        assert!(
            LogEncryptionKey::from_slice(&too_short).is_err(),
            "should reject a 16-byte slice"
        );

        let too_long = [0u8; 64];
        assert!(
            LogEncryptionKey::from_slice(&too_long).is_err(),
            "should reject a 64-byte slice"
        );

        let correct = [0u8; 32];
        assert!(
            LogEncryptionKey::from_slice(&correct).is_ok(),
            "should accept a 32-byte slice"
        );
    }

    /// An empty plaintext encrypts successfully and round-trips to an empty
    /// buffer.
    #[test]
    fn test_encrypt_empty_plaintext() {
        let key = LogEncryptionKey::new([0xcc; 32]);
        let encryptor = EntryEncryptor::new(key);

        let payload = encryptor
            .encrypt(0, b"")
            .expect("encrypting empty plaintext should succeed");
        let decrypted = encryptor
            .decrypt(0, &payload)
            .expect("decrypting empty ciphertext should succeed");

        assert!(
            decrypted.is_empty(),
            "round-tripped empty plaintext must be empty"
        );
    }

    /// `encrypt` stamps each payload with the key manager's current version:
    /// 1 before rotation, 2 after one rotation.
    #[test]
    fn test_entry_encryptor_uses_current_key_for_encrypt() {
        let mgr = KeyManager::new(LogEncryptionKey::new([0x01; 32]), 3);
        let mgr = Arc::new(RwLock::new(mgr));
        let encryptor = EntryEncryptor::with_key_manager(Arc::clone(&mgr));

        let payload_v1 = encryptor.encrypt(7, b"hello").expect("encrypt v1");
        assert_eq!(payload_v1.key_version, 1);

        // Rotate through the shared handle; the encryptor must observe it.
        mgr.write().rotate(LogEncryptionKey::new([0x02; 32]));
        let payload_v2 = encryptor.encrypt(8, b"hello").expect("encrypt v2");
        assert_eq!(payload_v2.key_version, 2);
    }

    /// `decrypt` selects the master key by the version recorded in the
    /// payload, so payloads written before and after a rotation both decrypt.
    #[test]
    fn test_entry_encryptor_uses_payload_version_for_decrypt() {
        let k1 = LogEncryptionKey::new([0x11; 32]);
        let k2 = LogEncryptionKey::new([0x22; 32]);
        let mgr = Arc::new(RwLock::new(KeyManager::new(k1, 3)));
        let encryptor = EntryEncryptor::with_key_manager(Arc::clone(&mgr));

        let payload_v1 = encryptor.encrypt(100, b"under-v1").expect("encrypt v1");
        assert_eq!(payload_v1.key_version, 1);

        mgr.write().rotate(k2);
        let payload_v2 = encryptor.encrypt(101, b"under-v2").expect("encrypt v2");
        assert_eq!(payload_v2.key_version, 2);

        // The v1 payload is decrypted while v2 is already the current key.
        let pt_v1 = encryptor.decrypt(100, &payload_v1).expect("decrypt v1");
        assert_eq!(pt_v1.as_slice(), b"under-v1");

        let pt_v2 = encryptor.decrypt(101, &payload_v2).expect("decrypt v2");
        assert_eq!(pt_v2.as_slice(), b"under-v2");
    }

    /// A payload written under version 1 is still decryptable after two
    /// rotations, and tampering with it is still detected.
    #[test]
    fn test_key_manager_decrypts_old_version_payload() {
        let k1 = LogEncryptionKey::new([0xaa; 32]);
        let k2 = LogEncryptionKey::new([0xbb; 32]);
        let k3 = LogEncryptionKey::new([0xcc; 32]);
        // NOTE(review): the `3` is presumably the number of retained key
        // versions — confirm against `KeyManager::new`.
        let mgr = Arc::new(RwLock::new(KeyManager::new(k1, 3)));
        let encryptor = EntryEncryptor::with_key_manager(Arc::clone(&mgr));

        let payload_v1 = encryptor
            .encrypt(42, b"persisted-under-v1")
            .expect("encrypt");
        assert_eq!(payload_v1.key_version, 1);

        mgr.write().rotate(k2);
        mgr.write().rotate(k3);

        let recovered = encryptor
            .decrypt(42, &payload_v1)
            .expect("decrypt under historical key v1 must succeed");
        assert_eq!(recovered.as_slice(), b"persisted-under-v1");

        // Corrupt the first ciphertext byte; authentication must reject it.
        let mut tampered = payload_v1.clone();
        tampered.ciphertext[0] ^= 0xff;
        assert!(
            encryptor.decrypt(42, &tampered).is_err(),
            "tampered ciphertext must still fail authentication post-rotation"
        );
    }

    /// Once the key version a payload was written under is no longer known
    /// to the key manager, `decrypt` returns an error.
    #[test]
    fn test_decrypt_fails_when_key_version_pruned() {
        let k1 = LogEncryptionKey::new([0x01; 32]);
        let k2 = LogEncryptionKey::new([0x02; 32]);
        // NOTE(review): with the second argument set to 1, a single rotation
        // is expected to drop version 1 — confirm against `KeyManager`.
        let mgr = Arc::new(RwLock::new(KeyManager::new(k1, 1)));
        let encryptor = EntryEncryptor::with_key_manager(Arc::clone(&mgr));

        let payload_v1 = encryptor.encrypt(0, b"will-be-lost").expect("encrypt");
        mgr.write().rotate(k2);

        let result = encryptor.decrypt(0, &payload_v1);
        assert!(
            result.is_err(),
            "decryption of pruned key version must surface a clear error"
        );
    }

    /// JSON without a `key_version` field deserializes with
    /// `LEGACY_KEY_VERSION`; an explicit `key_version` is preserved.
    #[test]
    fn test_encrypted_payload_serde_default_key_version() {
        // Legacy shape: no `key_version` field at all.
        let json = r#"{"ciphertext":[1,2,3,4],"nonce":[0,0,0,0,0,0,0,0,0,0,0,0]}"#;
        let payload: EncryptedPayload =
            serde_json::from_str(json).expect("legacy payload must deserialize");
        assert_eq!(payload.key_version, LEGACY_KEY_VERSION);
        assert_eq!(payload.ciphertext, vec![1, 2, 3, 4]);

        // Versioned shape: the explicit value must win over the default.
        let with_version =
            r#"{"ciphertext":[5,6],"nonce":[1,2,3,4,5,6,7,8,9,10,11,12],"key_version":7}"#;
        let payload: EncryptedPayload =
            serde_json::from_str(with_version).expect("v-tagged payload must deserialize");
        assert_eq!(payload.key_version, 7);
    }
}