use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Key, Nonce};
use hkdf::Hkdf;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use crate::error::{RaftError, RaftResult};
/// HMAC keyed hash over SHA-256, used for the integrity tags computed in this module.
type HmacSha256 = Hmac<Sha256>;
/// A 256-bit master key from which per-entry log encryption keys are derived.
///
/// The raw bytes are a private field, and no `Debug` or `Clone` impl is provided, so
/// this type's API offers no way to print or duplicate the key material.
pub struct LogEncryptionKey {
    key_bytes: [u8; 32],
}

impl LogEncryptionKey {
    /// Wraps 32 caller-supplied key bytes as a master key.
    pub fn new(key_bytes: [u8; 32]) -> Self {
        Self { key_bytes }
    }

    /// Builds a master key from a byte slice that must be exactly 32 bytes long.
    ///
    /// # Errors
    ///
    /// Returns [`RaftError::StorageError`] when `bytes` is shorter or longer than 32 bytes.
    pub fn from_slice(bytes: &[u8]) -> RaftResult<Self> {
        match <[u8; 32]>::try_from(bytes) {
            Ok(key_bytes) => Ok(Self { key_bytes }),
            Err(_) => {
                let got = bytes.len();
                Err(RaftError::StorageError {
                    message: format!("LogEncryptionKey requires exactly 32 bytes, got {got}"),
                })
            }
        }
    }

    /// Generates a fresh master key using only the standard library.
    ///
    /// Four `RandomState` instances hash the current UNIX time (in nanoseconds) and each
    /// other's outputs into 32 bytes of input keying material. That material is then
    /// expanded with HKDF-SHA256 under a fixed salt and the `master-key` info label.
    ///
    /// NOTE(review): std seeds `RandomState` on a best-effort basis from the host's
    /// secure randomness source, but it is not a dedicated CSPRNG interface. If the crate
    /// can depend on an OS RNG, prefer that for long-lived keys.
    ///
    /// # Panics
    ///
    /// Does not panic in practice: HKDF-SHA256 can always expand to 32 bytes.
    pub fn random() -> Self {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};
        use std::time::{SystemTime, UNIX_EPOCH};

        let now_nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_nanos(),
            Err(_) => 0u128,
        };

        // Hashes the 64-bit `words` in order, then the optional 128-bit `wide` value,
        // with a hasher built from `state`.
        let digest = |state: &RandomState, words: &[u64], wide: Option<u128>| -> u64 {
            let mut hasher = state.build_hasher();
            for &word in words {
                hasher.write_u64(word);
            }
            if let Some(value) = wide {
                hasher.write_u128(value);
            }
            hasher.finish()
        };

        // All four states are created before any hashing, in a fixed order.
        let states = [
            RandomState::new(),
            RandomState::new(),
            RandomState::new(),
            RandomState::new(),
        ];

        let w0 = digest(&states[0], &[], Some(now_nanos));
        let w1 = digest(
            &states[1],
            &[],
            Some(now_nanos ^ 0xcafe_babe_dead_beef_1234_5678_abcd_ef01_u128),
        );
        let w2 = digest(&states[2], &[w0, w1], None);
        let w3 = digest(
            &states[3],
            &[w1 ^ w2],
            Some(now_nanos.wrapping_add(0x9e37_79b9_7f4a_7c15_f39c_c060_5c0e_d609_u128)),
        );

        // Lay the four words out little-endian, back to back.
        let mut ikm = [0u8; 32];
        for (slot, word) in ikm.chunks_exact_mut(8).zip([w0, w1, w2, w3].iter()) {
            slot.copy_from_slice(&word.to_le_bytes());
        }

        let salt = b"amaters-log-encryption-key-v1";
        let hk = Hkdf::<Sha256>::new(Some(salt), &ikm);
        let mut key_bytes = [0u8; 32];
        hk.expand(b"master-key", &mut key_bytes)
            .expect("HKDF expand for 32 bytes cannot fail");
        Self { key_bytes }
    }
}
/// Result of encrypting one log entry: the AES-256-GCM output plus the nonce used.
///
/// Both fields must be persisted; decryption needs the exact nonce that produced
/// `ciphertext`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncryptedPayload {
    /// AES-256-GCM output for the entry (the `aes-gcm` crate appends the GCM tag).
    pub ciphertext: Vec<u8>,
    /// The 96-bit nonce `ciphertext` was encrypted under.
    pub nonce: [u8; 12],
}
/// Encrypts and decrypts individual log entries with AES-256-GCM.
///
/// Each entry is encrypted under its own key, derived from the master key and the entry's
/// log index with HKDF-SHA256. A payload therefore only decrypts under the index it was
/// written for.
pub struct EntryEncryptor {
    master_key: LogEncryptionKey,
}

impl EntryEncryptor {
    /// Creates an encryptor that derives every per-entry key from `key`.
    pub fn new(key: LogEncryptionKey) -> Self {
        Self { master_key: key }
    }

    /// Derives the AES-256 key for `entry_index`.
    ///
    /// Uses HKDF-SHA256 with no salt, the master key as input keying material, and the
    /// little-endian index as `info`. HKDF-Expand output blocks do not depend on the
    /// requested length, so these 32 bytes equal the key half of the 44-byte key+nonce
    /// expansion used previously. Payloads written that way still decrypt.
    ///
    /// # Errors
    ///
    /// Returns [`RaftError::StorageError`] if HKDF expansion fails.
    fn derive_entry_key(&self, entry_index: u64) -> RaftResult<[u8; 32]> {
        let hk = Hkdf::<Sha256>::new(None, &self.master_key.key_bytes);
        let mut key = [0u8; 32];
        hk.expand(&entry_index.to_le_bytes(), &mut key)
            .map_err(|e| RaftError::StorageError {
                message: format!("HKDF expand failed for entry {entry_index}: {e}"),
            })?;
        Ok(key)
    }

    /// Builds the AES-256-GCM cipher keyed for `entry_index`.
    fn cipher_for(&self, entry_index: u64) -> RaftResult<Aes256Gcm> {
        let key_bytes = self.derive_entry_key(entry_index)?;
        let key = Key::<Aes256Gcm>::from(key_bytes);
        Ok(Aes256Gcm::new(&key))
    }

    /// Returns a new 96-bit GCM nonce for a single encryption.
    ///
    /// Two SipHash digests from freshly created `RandomState`s fill 16 bytes, of which the
    /// first 12 are used. std seeds those hashers on a best-effort basis from the host's
    /// secure randomness source. A process-wide counter and the wall clock are hashed in
    /// as well, so no two calls hash identical input.
    ///
    /// NOTE(review): this mirrors the `RandomState` approach of `LogEncryptionKey::random`.
    /// Swap in an OS CSPRNG if the crate can depend on one.
    fn fresh_nonce() -> [u8; 12] {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        static NONCE_COUNTER: AtomicU64 = AtomicU64::new(0);

        // Relaxed suffices: `fetch_add` alone guarantees each caller a distinct value.
        let counter = NONCE_COUNTER.fetch_add(1, Ordering::Relaxed);
        let ts_nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);

        let mut block = [0u8; 16];
        for (lane, chunk) in (0u64..).zip(block.chunks_exact_mut(8)) {
            let mut hasher = RandomState::new().build_hasher();
            hasher.write_u64(lane);
            hasher.write_u64(counter);
            hasher.write_u128(ts_nanos);
            chunk.copy_from_slice(&hasher.finish().to_le_bytes());
        }
        let mut nonce = [0u8; 12];
        nonce.copy_from_slice(&block[..12]);
        nonce
    }

    /// Encrypts `plaintext` as the log entry at `entry_index`.
    ///
    /// A fresh nonce is generated on every call and returned in the payload.
    ///
    /// # Errors
    ///
    /// Returns [`RaftError::StorageError`] if key derivation or encryption fails.
    pub fn encrypt(&self, entry_index: u64, plaintext: &[u8]) -> RaftResult<EncryptedPayload> {
        let cipher = self.cipher_for(entry_index)?;
        // The per-entry key depends only on the index. If the same index is ever encrypted
        // again (e.g. Raft overwriting a conflicting entry), an index-derived nonce would
        // repeat under the same key. AES-GCM nonce reuse leaks the plaintext XOR and the
        // GHASH key, so the nonce must be fresh per call.
        let nonce_bytes = Self::fresh_nonce();
        let nonce = Nonce::from(nonce_bytes);
        let ciphertext = cipher
            .encrypt(&nonce, plaintext)
            .map_err(|e| RaftError::StorageError {
                message: format!("AES-256-GCM encryption failed for entry {entry_index}: {e}"),
            })?;
        Ok(EncryptedPayload {
            ciphertext,
            nonce: nonce_bytes,
        })
    }

    /// Decrypts and authenticates `payload` as the log entry at `entry_index`.
    ///
    /// # Errors
    ///
    /// Returns [`RaftError::StorageError`] if key derivation fails, or if GCM
    /// authentication fails. That happens for a wrong index, a wrong master key, or a
    /// tampered ciphertext or nonce.
    pub fn decrypt(&self, entry_index: u64, payload: &EncryptedPayload) -> RaftResult<Vec<u8>> {
        let cipher = self.cipher_for(entry_index)?;
        // The nonce comes from the payload, not from a derivation. Payloads with fresh
        // nonces and older ones with index-derived nonces therefore both decrypt.
        let nonce = Nonce::from(payload.nonce);
        cipher
            .decrypt(&nonce, payload.ciphertext.as_ref())
            .map_err(|e| RaftError::StorageError {
                message: format!("AES-256-GCM decryption failed for entry {entry_index}: {e}"),
            })
    }
}
/// Computes and checks HMAC-SHA256 tags over encrypted log entries.
///
/// A tag covers, in order: the entry index (8 bytes, little-endian), the 12-byte nonce,
/// and the ciphertext.
pub struct LogIntegrityVerifier {
    key: [u8; 32],
}

impl LogIntegrityVerifier {
    /// Creates a verifier whose HMAC-SHA256 is keyed with `key`.
    pub fn new(key: [u8; 32]) -> Self {
        Self { key }
    }

    /// Returns an HMAC that has already absorbed every authenticated field.
    ///
    /// `compute` and `verify` both build their MAC here, so the transcript that is
    /// tagged and the transcript that is checked cannot drift apart.
    fn keyed_mac(&self, entry_index: u64, payload: &EncryptedPayload) -> HmacSha256 {
        let mut mac = <HmacSha256 as KeyInit>::new_from_slice(&self.key)
            .expect("HMAC-SHA256 accepts any key size including 32 bytes");
        mac.update(&entry_index.to_le_bytes());
        mac.update(&payload.nonce);
        mac.update(&payload.ciphertext);
        mac
    }

    /// Computes the 32-byte integrity tag for `payload` stored at `entry_index`.
    ///
    /// # Panics
    ///
    /// Does not panic in practice: HMAC accepts keys of any length.
    #[must_use]
    pub fn compute(&self, entry_index: u64, payload: &EncryptedPayload) -> [u8; 32] {
        let digest = self.keyed_mac(entry_index, payload).finalize().into_bytes();
        let mut tag = [0u8; 32];
        tag.copy_from_slice(&digest);
        tag
    }

    /// Checks that `tag` matches `payload` stored at `entry_index`.
    ///
    /// The comparison goes through `Mac::verify_slice`.
    ///
    /// # Errors
    ///
    /// Returns [`RaftError::StorageError`] when the tag does not match.
    pub fn verify(
        &self,
        entry_index: u64,
        payload: &EncryptedPayload,
        tag: &[u8; 32],
    ) -> RaftResult<()> {
        self.keyed_mac(entry_index, payload)
            .verify_slice(tag)
            .map_err(|_| RaftError::StorageError {
                message: "HMAC-SHA256 integrity verification failed: tag mismatch".to_string(),
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = LogEncryptionKey::random();
        let encryptor = EntryEncryptor::new(key);
        let plaintext = b"Hello, Raft log entry!";
        let payload = encryptor
            .encrypt(42, plaintext)
            .expect("encrypt should succeed");
        let decrypted = encryptor
            .decrypt(42, &payload)
            .expect("decrypt should succeed");
        assert_eq!(decrypted.as_slice(), plaintext.as_ref());
    }

    #[test]
    fn test_different_indices_produce_different_ciphertexts() {
        let key = LogEncryptionKey::new([0xab; 32]);
        let encryptor = EntryEncryptor::new(key);
        let plaintext = b"same plaintext for both entries";
        let payload1 = encryptor.encrypt(1, plaintext).expect("encrypt entry 1");
        let payload2 = encryptor.encrypt(2, plaintext).expect("encrypt entry 2");
        assert_ne!(payload1.ciphertext, payload2.ciphertext);
        assert_ne!(payload1.nonce, payload2.nonce);
    }

    #[test]
    fn test_reencrypting_same_index_round_trips() {
        // Raft may overwrite an index with a different entry; both versions must decrypt.
        let encryptor = EntryEncryptor::new(LogEncryptionKey::new([0x3c; 32]));
        let first = encryptor.encrypt(7, b"term 1 entry").expect("encrypt first");
        let second = encryptor.encrypt(7, b"term 2 entry").expect("encrypt second");
        assert_eq!(
            encryptor.decrypt(7, &first).expect("decrypt first").as_slice(),
            b"term 1 entry".as_ref()
        );
        assert_eq!(
            encryptor.decrypt(7, &second).expect("decrypt second").as_slice(),
            b"term 2 entry".as_ref()
        );
    }

    #[test]
    fn test_decrypt_with_wrong_index_fails() {
        let encryptor = EntryEncryptor::new(LogEncryptionKey::new([0x5a; 32]));
        let payload = encryptor.encrypt(10, b"entry ten").expect("encrypt entry 10");
        assert!(
            encryptor.decrypt(11, &payload).is_err(),
            "a payload must not decrypt under a different entry index"
        );
    }

    #[test]
    fn test_decrypt_tampered_ciphertext_fails() {
        let encryptor = EntryEncryptor::new(LogEncryptionKey::new([0x77; 32]));
        let mut payload = encryptor
            .encrypt(3, b"authenticated")
            .expect("encrypt entry 3");
        payload.ciphertext[0] ^= 0x01;
        assert!(
            encryptor.decrypt(3, &payload).is_err(),
            "GCM must reject a modified ciphertext"
        );
    }

    #[test]
    fn test_decrypt_tampered_nonce_fails() {
        let encryptor = EntryEncryptor::new(LogEncryptionKey::new([0x19; 32]));
        let mut payload = encryptor
            .encrypt(4, b"nonce-bound")
            .expect("encrypt entry 4");
        payload.nonce[0] ^= 0x01;
        assert!(
            encryptor.decrypt(4, &payload).is_err(),
            "GCM must reject a modified nonce"
        );
    }

    #[test]
    fn test_random_keys_differ() {
        let a = LogEncryptionKey::random();
        let b = LogEncryptionKey::random();
        assert_ne!(a.key_bytes, b.key_bytes);
    }

    #[test]
    fn test_hmac_verify_valid() {
        let key = [0x12u8; 32];
        let verifier = LogIntegrityVerifier::new(key);
        let payload = EncryptedPayload {
            ciphertext: vec![0xde, 0xad, 0xbe, 0xef],
            nonce: [0u8; 12],
        };
        let tag = verifier.compute(7, &payload);
        verifier
            .verify(7, &payload, &tag)
            .expect("HMAC should verify successfully");
    }

    #[test]
    fn test_hmac_verify_tampered_fails() {
        let key = [0x34u8; 32];
        let verifier = LogIntegrityVerifier::new(key);
        let mut payload = EncryptedPayload {
            ciphertext: vec![0x01, 0x02, 0x03, 0x04, 0x05],
            nonce: [0u8; 12],
        };
        let tag = verifier.compute(99, &payload);
        payload.ciphertext[2] ^= 0xff;
        let result = verifier.verify(99, &payload, &tag);
        assert!(
            result.is_err(),
            "verification of tampered payload should fail"
        );
    }

    #[test]
    fn test_hmac_verify_wrong_index_fails() {
        let verifier = LogIntegrityVerifier::new([0x56u8; 32]);
        let payload = EncryptedPayload {
            ciphertext: vec![9, 8, 7],
            nonce: [1u8; 12],
        };
        let tag = verifier.compute(5, &payload);
        assert!(
            verifier.verify(6, &payload, &tag).is_err(),
            "a tag must not verify for a different entry index"
        );
    }

    #[test]
    fn test_key_from_slice_wrong_length() {
        let too_short = [0u8; 16];
        assert!(
            LogEncryptionKey::from_slice(&too_short).is_err(),
            "should reject a 16-byte slice"
        );
        let too_long = [0u8; 64];
        assert!(
            LogEncryptionKey::from_slice(&too_long).is_err(),
            "should reject a 64-byte slice"
        );
        let correct = [0u8; 32];
        assert!(
            LogEncryptionKey::from_slice(&correct).is_ok(),
            "should accept a 32-byte slice"
        );
    }

    #[test]
    fn test_encrypt_empty_plaintext() {
        let key = LogEncryptionKey::new([0xcc; 32]);
        let encryptor = EntryEncryptor::new(key);
        let payload = encryptor
            .encrypt(0, b"")
            .expect("encrypting empty plaintext should succeed");
        let decrypted = encryptor
            .decrypt(0, &payload)
            .expect("decrypting empty ciphertext should succeed");
        assert!(
            decrypted.is_empty(),
            "round-tripped empty plaintext must be empty"
        );
    }
}