use std::sync::Arc;
use crate::encryption::Cipher;
use crate::encryption::error::EncryptionError;
/// Number of leading bytes of a serialized WAL entry that are never encrypted.
///
/// `encrypt_wal_payload` and `decrypt_wal_payload` copy this prefix through
/// unchanged and pass it to the cipher as the second argument of
/// `encrypt` / `decrypt`; everything after it is treated as the payload.
// NOTE(review): the field layout that adds up to 25 bytes is defined by the WAL
// serializer, not in this file — confirm the value still matches that layout.
const WAL_HEADER_SIZE: usize = 25;
/// Encrypts the payload portion of a serialized WAL entry.
///
/// The first `WAL_HEADER_SIZE` bytes of `serialized` are emitted verbatim and
/// are also supplied to `cipher.encrypt` as its second argument, so the cipher
/// can bind the ciphertext to them. The remaining bytes (possibly none) are
/// replaced by whatever `cipher.encrypt` returns.
///
/// # Errors
///
/// * `EncryptionError::InvalidWalEntry` if `serialized` is shorter than
///   `WAL_HEADER_SIZE` bytes.
/// * Any error returned by `cipher.encrypt`, propagated unchanged.
pub fn encrypt_wal_payload(
    serialized: &[u8],
    cipher: &Arc<dyn Cipher>,
) -> Result<Vec<u8>, EncryptionError> {
    // A checked range lookup doubles as the length validation: it yields
    // `None` exactly when the entry is shorter than the clear-text prefix.
    let body = serialized
        .get(WAL_HEADER_SIZE..)
        .ok_or(EncryptionError::InvalidWalEntry {
            expected: WAL_HEADER_SIZE,
            actual: serialized.len(),
        })?;
    let (clear_prefix, _) = serialized.split_at(WAL_HEADER_SIZE);

    let sealed = cipher.encrypt(body, clear_prefix)?;

    // Output layout: [clear-text prefix][cipher output].
    let mut framed = Vec::with_capacity(WAL_HEADER_SIZE + sealed.len());
    framed.extend_from_slice(clear_prefix);
    framed.extend_from_slice(&sealed);
    Ok(framed)
}
/// Reverses `encrypt_wal_payload`: decrypts the payload of an encrypted WAL entry.
///
/// The first `WAL_HEADER_SIZE` bytes of `encrypted_entry` are taken as the
/// clear-text prefix; they are passed to `cipher.decrypt` as its second
/// argument and copied to the front of the result. The rest of the input is
/// handed to `cipher.decrypt` as the ciphertext.
///
/// # Errors
///
/// * `EncryptionError::InvalidWalEntry` if `encrypted_entry` is shorter than
///   `WAL_HEADER_SIZE + cipher.overhead()` bytes.
/// * Any error returned by `cipher.decrypt`, propagated unchanged.
pub fn decrypt_wal_payload(
    encrypted_entry: &[u8],
    cipher: &Arc<dyn Cipher>,
) -> Result<Vec<u8>, EncryptionError> {
    // Smallest acceptable input: the prefix plus the cipher's reported overhead.
    let required = WAL_HEADER_SIZE + cipher.overhead();
    let actual = encrypted_entry.len();
    if actual < required {
        return Err(EncryptionError::InvalidWalEntry {
            expected: required,
            actual,
        });
    }

    let (clear_prefix, sealed) = encrypted_entry.split_at(WAL_HEADER_SIZE);
    let opened = cipher.decrypt(sealed, clear_prefix)?;

    // Output layout: [clear-text prefix][decrypted payload].
    let mut entry = Vec::with_capacity(WAL_HEADER_SIZE + opened.len());
    entry.extend_from_slice(clear_prefix);
    entry.extend_from_slice(&opened);
    Ok(entry)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encryption::{Aes256GcmCipher, ChaCha20Poly1305Cipher};
    use rand::RngCore;
    use zeroize::Zeroizing;

    /// Builds a synthetic serialized entry: a `WAL_HEADER_SIZE`-byte prefix
    /// (little-endian `42u64` in bytes 0..8, `1` at index 24, zeros elsewhere)
    /// followed by `payload_len` bytes counting 0, 1, 2, ... modulo 256.
    fn fake_entry(payload_len: usize) -> Vec<u8> {
        let mut entry = Vec::with_capacity(WAL_HEADER_SIZE + payload_len);
        entry.extend_from_slice(&42u64.to_le_bytes());
        entry.resize(WAL_HEADER_SIZE, 0);
        entry[24] = 1;
        entry.extend((0..payload_len).map(|i| (i % 256) as u8));
        entry
    }

    /// Fresh 32 random bytes from the thread-local RNG, wrapped in `Zeroizing`.
    fn random_key() -> Zeroizing<[u8; 32]> {
        let mut material = Zeroizing::new([0u8; 32]);
        let mut rng = rand::thread_rng();
        rng.fill_bytes(material.as_mut());
        material
    }

    /// AES-256-GCM cipher under a throwaway random key.
    fn aes_cipher() -> Arc<dyn Cipher> {
        let key = random_key();
        Arc::new(Aes256GcmCipher::new(&key))
    }

    /// ChaCha20-Poly1305 cipher under a throwaway random key.
    fn chacha_cipher() -> Arc<dyn Cipher> {
        let key = random_key();
        Arc::new(ChaCha20Poly1305Cipher::new(&key))
    }

    /// Encrypts `entry` and immediately decrypts the result with the same cipher.
    fn roundtrip(cipher: &Arc<dyn Cipher>, entry: &[u8]) -> Vec<u8> {
        let sealed = encrypt_wal_payload(entry, cipher).unwrap();
        decrypt_wal_payload(&sealed, cipher).unwrap()
    }

    #[test]
    fn roundtrip_aes() {
        let entry = fake_entry(128);
        assert_eq!(roundtrip(&aes_cipher(), &entry), entry);
    }

    #[test]
    fn roundtrip_chacha() {
        let entry = fake_entry(128);
        assert_eq!(roundtrip(&chacha_cipher(), &entry), entry);
    }

    #[test]
    fn header_preserved_in_encrypted_output() {
        let plain = fake_entry(64);
        let sealed = encrypt_wal_payload(&plain, &aes_cipher()).unwrap();
        let (sealed_header, _) = sealed.split_at(WAL_HEADER_SIZE);
        let (plain_header, _) = plain.split_at(WAL_HEADER_SIZE);
        assert_eq!(sealed_header, plain_header);
    }

    #[test]
    fn payload_is_encrypted() {
        let plain = fake_entry(64);
        let sealed = encrypt_wal_payload(&plain, &aes_cipher()).unwrap();
        assert_ne!(
            &sealed[WAL_HEADER_SIZE..],
            &plain[WAL_HEADER_SIZE..],
            "payload should be encrypted (different from plaintext)"
        );
    }

    #[test]
    fn encrypted_is_larger() {
        let cipher = aes_cipher();
        let plain = fake_entry(64);
        let sealed = encrypt_wal_payload(&plain, &cipher).unwrap();
        // Growth must equal exactly what the cipher reports as its overhead.
        assert_eq!(sealed.len(), plain.len() + cipher.overhead());
    }

    #[test]
    fn tampered_header_fails_decryption() {
        let cipher = aes_cipher();
        let mut sealed = encrypt_wal_payload(&fake_entry(64), &cipher).unwrap();
        // Invert every bit of the first clear-text byte.
        sealed[0] = !sealed[0];
        let outcome = decrypt_wal_payload(&sealed, &cipher);
        assert!(outcome.is_err(), "tampered header should fail AAD check");
    }

    #[test]
    fn empty_payload_roundtrips() {
        let entry = fake_entry(0);
        assert_eq!(entry.len(), WAL_HEADER_SIZE);
        assert_eq!(roundtrip(&aes_cipher(), &entry), entry);
    }

    #[test]
    fn too_short_entry_fails() {
        let cipher = aes_cipher();
        let outcome = encrypt_wal_payload(&[0u8; 10], &cipher);
        assert!(matches!(
            outcome,
            Err(EncryptionError::InvalidWalEntry {
                expected: WAL_HEADER_SIZE,
                actual: 10
            })
        ));
    }
}