/// Authenticated-encryption hook applied to raw byte buffers.
///
/// Implementors must be `Send + Sync` as well as `UnwindSafe` and
/// `RefUnwindSafe`. The trait is object safe, so it can be used as
/// `dyn EncryptionProvider`.
pub trait EncryptionProvider
where
    Self: Send + Sync + std::panic::UnwindSafe + std::panic::RefUnwindSafe,
{
    /// Upper bound, in bytes, on how much larger a ciphertext can be than the
    /// plaintext it was produced from.
    fn max_overhead(&self) -> u32;

    /// Encrypts `plaintext` into a newly allocated buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation fails to encrypt the input.
    fn encrypt(&self, plaintext: &[u8]) -> crate::Result<Vec<u8>>;

    /// Decrypts `ciphertext` into a newly allocated buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if the input cannot be decrypted (for authenticated
    /// schemes this includes a failed integrity check).
    fn decrypt(&self, ciphertext: &[u8]) -> crate::Result<Vec<u8>>;

    /// Owned-buffer variant of [`encrypt`](Self::encrypt).
    ///
    /// The default borrows `plaintext` and forwards to `encrypt`;
    /// implementors may override it to reuse the allocation.
    ///
    /// # Errors
    ///
    /// Same as [`encrypt`](Self::encrypt).
    fn encrypt_vec(&self, plaintext: Vec<u8>) -> crate::Result<Vec<u8>> {
        let borrowed: &[u8] = plaintext.as_slice();
        self.encrypt(borrowed)
    }

    /// Owned-buffer variant of [`decrypt`](Self::decrypt).
    ///
    /// The default borrows `ciphertext` and forwards to `decrypt`;
    /// implementors may override it to reuse the allocation.
    ///
    /// # Errors
    ///
    /// Same as [`decrypt`](Self::decrypt).
    fn decrypt_vec(&self, ciphertext: Vec<u8>) -> crate::Result<Vec<u8>> {
        let borrowed: &[u8] = ciphertext.as_slice();
        self.decrypt(borrowed)
    }
}
/// [`EncryptionProvider`] backed by AES-256-GCM.
///
/// Ciphertexts are framed as `nonce (12 bytes) || encrypted payload || tag (16 bytes)`,
/// and no associated data is authenticated.
#[cfg(feature = "encryption")]
pub struct Aes256GcmProvider {
    cipher: aes_gcm::Aes256Gcm,
}

#[cfg(feature = "encryption")]
impl Aes256GcmProvider {
    /// Length in bytes of the random nonce prefixed to every ciphertext.
    const NONCE_LEN: usize = 12;
    /// Length in bytes of the GCM authentication tag appended to every ciphertext.
    const TAG_LEN: usize = 16;
    /// Total number of bytes a ciphertext adds on top of its plaintext.
    pub const OVERHEAD: usize = Self::TAG_LEN + Self::NONCE_LEN;

    /// Builds a provider from a 256-bit key.
    #[must_use]
    pub fn new(key: &[u8; 32]) -> Self {
        use aes_gcm::KeyInit;
        let cipher = <aes_gcm::Aes256Gcm as KeyInit>::new(key.into());
        Self { cipher }
    }

    /// Builds a provider from a key supplied as a byte slice.
    ///
    /// # Errors
    ///
    /// Returns `crate::Error::Encrypt` when `key` is not exactly 32 bytes long.
    pub fn from_slice(key: &[u8]) -> crate::Result<Self> {
        <&[u8; 32]>::try_from(key)
            .map(Self::new)
            .map_err(|_| crate::Error::Encrypt("AES-256-GCM key must be exactly 32 bytes"))
    }
}
/// Creates a ChaCha20 CSPRNG seeded from the operating-system RNG.
///
/// # Panics
///
/// Panics (intentionally) if the OS RNG cannot provide a seed.
#[cfg(feature = "encryption")]
#[expect(
    clippy::expect_used,
    reason = "intentionally panics if OsRng is unavailable"
)]
fn new_chacha_rng() -> rand_chacha::ChaCha20Rng {
    use aes_gcm::aead::rand_core::{OsRng, SeedableRng};
    let seeded = rand_chacha::ChaCha20Rng::from_rng(OsRng);
    seeded.expect("OS RNG should be available for initial CSPRNG seed")
}
/// ChaCha20 CSPRNG wrapper that reseeds from the OS when the process ID changes.
///
/// After `fork()` the child starts with a copy of the parent's generator
/// state. Both processes replaying that stream would repeat AES-GCM nonces, so
/// every use compares the recorded PID with `std::process::id()` and reseeds
/// on mismatch.
///
/// NOTE(review): detection relies on the child observing a different PID than
/// the parent; a fork that preserves the observed PID (e.g. both ends being
/// PID 1 in different PID namespaces) would not be detected.
#[cfg(feature = "encryption")]
struct ForkAwareRng {
    /// PID of the process whose OS RNG seeded `rng`.
    pid: std::cell::Cell<u32>,
    /// The generator; interior mutability because it is reached through `&self`.
    rng: std::cell::RefCell<rand_chacha::ChaCha20Rng>,
}

#[cfg(feature = "encryption")]
impl ForkAwareRng {
    /// Creates a generator freshly seeded from the OS, tagged with the current PID.
    fn new() -> Self {
        Self {
            pid: std::cell::Cell::new(std::process::id()),
            rng: std::cell::RefCell::new(new_chacha_rng()),
        }
    }

    /// Runs `f` with exclusive access to the generator, reseeding it first if
    /// the current PID differs from the one that seeded it.
    ///
    /// # Panics
    ///
    /// Panics if `f` re-enters `with_rng` on the same instance (the `RefCell`
    /// is already borrowed), or if reseeding from the OS RNG fails.
    fn with_rng<R>(&self, f: impl FnOnce(&mut rand_chacha::ChaCha20Rng) -> R) -> R {
        let mut rng_ref = self.rng.borrow_mut();
        let current_pid = std::process::id();
        if self.pid.get() != current_pid {
            // Reseed *before* recording the new PID. If `new_chacha_rng`
            // panics and the panic is caught, the stale PID is kept, so the
            // next call retries the reseed instead of silently continuing the
            // stream inherited from the parent (which would reuse nonces).
            *rng_ref = new_chacha_rng();
            self.pid.set(current_pid);
        }
        f(&mut rng_ref)
    }
}
#[cfg(feature = "encryption")]
thread_local! {
    /// This thread's fork-aware nonce generator, initialized on first access.
    static THREAD_RNG: ForkAwareRng = ForkAwareRng::new();
}

/// Runs `f` with this thread's fork-aware ChaCha20 generator.
///
/// # Panics
///
/// Panics if the thread-local is accessed during or after its destruction
/// (see `std::thread::LocalKey::with`), if `f` re-enters this function, or if
/// reseeding from the OS RNG fails.
#[cfg(feature = "encryption")]
fn thread_local_rng<R>(f: impl FnOnce(&mut rand_chacha::ChaCha20Rng) -> R) -> R {
    THREAD_RNG.with(move |per_thread| per_thread.with_rng(f))
}
/// Private framing helpers shared by the slice-based and `Vec`-based
/// encrypt/decrypt paths, so nonce generation, length validation and error
/// mapping live in exactly one place.
#[cfg(feature = "encryption")]
impl Aes256GcmProvider {
    /// Encrypts `buf` in place and appends the authentication tag.
    ///
    /// On entry `buf` must be `[nonce slot: NONCE_LEN bytes] || plaintext`; the
    /// slot's current contents are ignored and overwritten with a fresh random
    /// nonce. On success `buf` holds `nonce || ciphertext || tag`. No
    /// associated data is authenticated.
    ///
    /// NOTE: nonces are 96 random bits per message. NIST SP 800-38D bounds
    /// random-IV GCM at 2^32 encryptions per key, so very long-lived keys
    /// should be rotated before reaching that volume.
    ///
    /// # Panics
    ///
    /// Panics if `buf.len() < NONCE_LEN`; both callers reserve the slot first.
    fn seal_in_place(&self, buf: &mut Vec<u8>) -> crate::Result<()> {
        use aes_gcm::AeadCore;
        use aes_gcm::AeadInPlace;
        let nonce = thread_local_rng(|rng| aes_gcm::Aes256Gcm::generate_nonce(rng));
        let (nonce_slot, payload) = buf.split_at_mut(Self::NONCE_LEN);
        nonce_slot.copy_from_slice(&nonce);
        let tag = self
            .cipher
            .encrypt_in_place_detached(&nonce, b"", payload)
            .map_err(|_| crate::Error::Encrypt("AES-256-GCM encryption failed"))?;
        buf.extend_from_slice(&tag);
        Ok(())
    }

    /// Validates the length of a `nonce || ciphertext || tag` frame and returns
    /// the offset at which the tag begins.
    ///
    /// # Errors
    ///
    /// Returns `crate::Error::Decrypt` if `frame_len` cannot hold both a nonce
    /// and a tag.
    fn tag_offset(frame_len: usize) -> crate::Result<usize> {
        if frame_len < Self::OVERHEAD {
            return Err(crate::Error::Decrypt(
                "ciphertext too short for AES-256-GCM (need nonce + tag)",
            ));
        }
        Ok(frame_len - Self::TAG_LEN)
    }

    /// Verifies `tag` over `payload` and, on success, leaves the plaintext in
    /// `payload`.
    ///
    /// # Errors
    ///
    /// Returns `crate::Error::Decrypt` if authentication fails (wrong key or
    /// tampered data).
    ///
    /// # Panics
    ///
    /// Panics unless `nonce.len() == NONCE_LEN` and `tag.len() == TAG_LEN`
    /// (`GenericArray::from_slice` requires an exact length).
    fn open_in_place(&self, nonce: &[u8], payload: &mut [u8], tag: &[u8]) -> crate::Result<()> {
        use aes_gcm::AeadInPlace;
        use aes_gcm::aead::generic_array::GenericArray;
        self.cipher
            .decrypt_in_place_detached(
                GenericArray::from_slice(nonce),
                b"",
                payload,
                GenericArray::from_slice(tag),
            )
            .map_err(|_| {
                crate::Error::Decrypt("AES-256-GCM decryption failed (bad key or tampered data)")
            })
    }
}

#[cfg(feature = "encryption")]
impl EncryptionProvider for Aes256GcmProvider {
    /// Always [`Aes256GcmProvider::OVERHEAD`]: a 12-byte nonce plus a 16-byte tag.
    fn max_overhead(&self) -> u32 {
        #[expect(clippy::cast_possible_truncation, reason = "OVERHEAD is 28")]
        {
            Self::OVERHEAD as u32
        }
    }

    /// Encrypts a borrowed plaintext into a new `nonce || ciphertext || tag` buffer.
    fn encrypt(&self, plaintext: &[u8]) -> crate::Result<Vec<u8>> {
        // Allocate the final size up front so sealing never reallocates.
        let mut buf = Vec::with_capacity(Self::OVERHEAD + plaintext.len());
        // Zeroed nonce slot; `seal_in_place` overwrites it with a fresh nonce.
        buf.resize(Self::NONCE_LEN, 0);
        buf.extend_from_slice(plaintext);
        self.seal_in_place(&mut buf)?;
        Ok(buf)
    }

    /// Authenticates and decrypts a borrowed `nonce || ciphertext || tag` frame.
    fn decrypt(&self, ciphertext: &[u8]) -> crate::Result<Vec<u8>> {
        let tag_start = Self::tag_offset(ciphertext.len())?;
        let (framed, tag) = ciphertext.split_at(tag_start);
        let (nonce, payload) = framed.split_at(Self::NONCE_LEN);
        let mut buf = payload.to_vec();
        self.open_in_place(nonce, &mut buf, tag)?;
        Ok(buf)
    }

    /// Encrypts an owned plaintext, reusing its allocation for the output frame.
    fn encrypt_vec(&self, mut buf: Vec<u8>) -> crate::Result<Vec<u8>> {
        let plaintext_len = buf.len();
        // One reservation covers both the nonce slot and the trailing tag.
        buf.reserve(Self::OVERHEAD);
        // Grow by NONCE_LEN and shift the plaintext right to open the nonce slot.
        buf.resize(plaintext_len + Self::NONCE_LEN, 0);
        buf.copy_within(..plaintext_len, Self::NONCE_LEN);
        self.seal_in_place(&mut buf)?;
        Ok(buf)
    }

    /// Authenticates and decrypts an owned frame, reusing its allocation for
    /// the plaintext.
    fn decrypt_vec(&self, mut buf: Vec<u8>) -> crate::Result<Vec<u8>> {
        let tag_start = Self::tag_offset(buf.len())?;
        let (framed, tag) = buf.split_at_mut(tag_start);
        let (nonce, payload) = framed.split_at_mut(Self::NONCE_LEN);
        self.open_in_place(nonce, payload, tag)?;
        // Move the authenticated plaintext to the front and drop the framing.
        buf.copy_within(Self::NONCE_LEN..tag_start, 0);
        buf.truncate(tag_start - Self::NONCE_LEN);
        Ok(buf)
    }
}
#[cfg(test)]
// Test-only lint relaxations: fixtures are cloned freely, tests return
// `Result` so `?` can be used even when nothing fails, and closures are kept
// where they read better than method paths.
#[allow(
    clippy::doc_markdown,
    clippy::redundant_clone,
    clippy::unnecessary_wraps,
    clippy::redundant_closure_for_method_calls
)]
mod tests {
    use super::*;

    #[test]
    fn encryption_provider_trait_is_object_safe() {
        // Compiles only if `EncryptionProvider` can be used as a trait object.
        fn _assert_object_safe(_: &dyn EncryptionProvider) {}
    }

    /// Trivial provider used to exercise the trait's default `*_vec` methods.
    /// As a unit struct it is automatically `UnwindSafe + RefUnwindSafe`.
    struct XorProvider;

    impl EncryptionProvider for XorProvider {
        fn encrypt(&self, plaintext: &[u8]) -> crate::Result<Vec<u8>> {
            Ok(plaintext.iter().map(|b| b ^ 0xAA).collect())
        }

        fn max_overhead(&self) -> u32 {
            0
        }

        fn decrypt(&self, ciphertext: &[u8]) -> crate::Result<Vec<u8>> {
            Ok(ciphertext.iter().map(|b| b ^ 0xAA).collect())
        }
    }

    #[test]
    fn default_encrypt_vec_delegates_to_encrypt() -> crate::Result<()> {
        let provider = XorProvider;
        let plaintext = b"test default encrypt_vec";
        let via_encrypt = provider.encrypt(plaintext)?;
        let via_encrypt_vec = provider.encrypt_vec(plaintext.to_vec())?;
        assert_eq!(via_encrypt, via_encrypt_vec);
        let decrypted = provider.decrypt(&via_encrypt_vec)?;
        assert_eq!(decrypted, plaintext);
        Ok(())
    }

    #[test]
    fn default_decrypt_vec_delegates_to_decrypt() -> crate::Result<()> {
        let provider = XorProvider;
        let plaintext = b"test default decrypt_vec";
        let ciphertext = provider.encrypt(plaintext)?;
        let via_decrypt = provider.decrypt(&ciphertext)?;
        let via_decrypt_vec = provider.decrypt_vec(ciphertext)?;
        assert_eq!(via_decrypt, via_decrypt_vec);
        assert_eq!(via_decrypt_vec, plaintext);
        Ok(())
    }

    #[cfg(feature = "encryption")]
    mod aes256gcm {
        use super::*;

        fn test_key() -> [u8; 32] {
            [0x42; 32]
        }

        /// Draws 32 bytes so stream comparisons have negligible false-match odds.
        fn draw_block(rng: &mut rand_chacha::ChaCha20Rng) -> [u8; 32] {
            use aes_gcm::aead::rand_core::RngCore;
            let mut block = [0_u8; 32];
            rng.fill_bytes(&mut block);
            block
        }

        #[test]
        fn roundtrip_basic() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"hello world, this is a block of data!";
            let ciphertext = provider.encrypt(plaintext)?;
            assert_ne!(&ciphertext[..], plaintext.as_slice());
            assert_eq!(
                ciphertext.len(),
                Aes256GcmProvider::NONCE_LEN + plaintext.len() + Aes256GcmProvider::TAG_LEN,
            );
            let decrypted = provider.decrypt(&ciphertext)?;
            assert_eq!(decrypted, plaintext);
            Ok(())
        }

        #[test]
        fn roundtrip_empty() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"";
            let ciphertext = provider.encrypt(plaintext)?;
            let decrypted = provider.decrypt(&ciphertext)?;
            assert_eq!(decrypted, plaintext);
            Ok(())
        }

        #[test]
        fn different_nonces_produce_different_ciphertexts() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"deterministic input";
            let ct1 = provider.encrypt(plaintext)?;
            let ct2 = provider.encrypt(plaintext)?;
            assert_ne!(
                ct1, ct2,
                "random nonces should produce different ciphertexts"
            );
            assert_eq!(provider.decrypt(&ct1)?, provider.decrypt(&ct2)?);
            Ok(())
        }

        #[test]
        fn wrong_key_fails_decrypt() -> crate::Result<()> {
            let provider1 = Aes256GcmProvider::new(&[0x01; 32]);
            let provider2 = Aes256GcmProvider::new(&[0x02; 32]);
            let ciphertext = provider1.encrypt(b"secret")?;
            let result = provider2.decrypt(&ciphertext);
            assert!(result.is_err());
            Ok(())
        }

        #[test]
        fn wrong_key_fails_decrypt_vec() -> crate::Result<()> {
            let ciphertext =
                Aes256GcmProvider::new(&[0x01; 32]).encrypt_vec(b"secret".to_vec())?;
            let result = Aes256GcmProvider::new(&[0x02; 32]).decrypt_vec(ciphertext);
            assert!(result.is_err());
            Ok(())
        }

        #[test]
        fn tampered_ciphertext_fails_decrypt() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let mut ciphertext = provider.encrypt(b"data")?;
            let mid = Aes256GcmProvider::NONCE_LEN + 1;
            if mid < ciphertext.len() {
                #[expect(clippy::indexing_slicing)]
                {
                    ciphertext[mid] ^= 0xFF;
                }
            }
            let result = provider.decrypt(&ciphertext);
            assert!(result.is_err());
            Ok(())
        }

        #[test]
        fn truncated_ciphertext_fails_decrypt() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let result = provider.decrypt(&[0u8; 10]);
            assert!(result.is_err());
            Ok(())
        }

        #[test]
        fn inputs_shorter_than_overhead_are_rejected() {
            let provider = Aes256GcmProvider::new(&test_key());
            for len in 0..Aes256GcmProvider::OVERHEAD {
                let short = vec![0_u8; len];
                assert!(
                    provider.decrypt(&short).is_err(),
                    "decrypt accepted a {len}-byte frame"
                );
                assert!(
                    provider.decrypt_vec(short).is_err(),
                    "decrypt_vec accepted a {len}-byte frame"
                );
            }
        }

        #[test]
        fn forged_minimum_length_frame_fails_authentication() {
            let provider = Aes256GcmProvider::new(&test_key());
            // Exactly nonce + tag passes the length check, but an all-zero tag
            // must not authenticate the empty message.
            let forged = vec![0_u8; Aes256GcmProvider::OVERHEAD];
            assert!(provider.decrypt(&forged).is_err());
            assert!(provider.decrypt_vec(forged).is_err());
        }

        #[test]
        fn max_overhead_matches_ciphertext_expansion() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            assert_eq!(provider.max_overhead(), 28);
            let samples: [&[u8]; 3] = [b"", b"x", b"several bytes of data"];
            for plaintext in samples {
                let via_slice = provider.encrypt(plaintext)?;
                assert_eq!(
                    via_slice.len() - plaintext.len(),
                    Aes256GcmProvider::OVERHEAD
                );
                let via_vec = provider.encrypt_vec(plaintext.to_vec())?;
                assert_eq!(via_vec.len() - plaintext.len(), Aes256GcmProvider::OVERHEAD);
            }
            Ok(())
        }

        #[test]
        fn from_slice_rejects_wrong_length() {
            assert!(Aes256GcmProvider::from_slice(&[0u8; 16]).is_err());
            assert!(Aes256GcmProvider::from_slice(&[0u8; 31]).is_err());
            assert!(Aes256GcmProvider::from_slice(&[0u8; 33]).is_err());
            assert!(Aes256GcmProvider::from_slice(&[0u8; 32]).is_ok());
        }

        #[test]
        fn roundtrip_large_payload() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = vec![0xAB_u8; 64 * 1024];
            let ciphertext = provider.encrypt(&plaintext)?;
            let decrypted = provider.decrypt(&ciphertext)?;
            assert_eq!(decrypted, plaintext);
            Ok(())
        }

        #[test]
        fn thread_local_rng_produces_unique_nonces() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"nonce uniqueness test";
            let mut nonces = std::collections::HashSet::new();
            for _ in 0..1000 {
                let ct = provider.encrypt(plaintext)?;
                #[expect(clippy::indexing_slicing, reason = "ct always >= NONCE_LEN")]
                #[expect(clippy::expect_used, reason = "test assertion")]
                let nonce: [u8; Aes256GcmProvider::NONCE_LEN] = ct
                    [..Aes256GcmProvider::NONCE_LEN]
                    .try_into()
                    .expect("nonce has expected length");
                assert!(
                    nonces.insert(nonce),
                    "nonce collision detected — CSPRNG produced duplicate nonce"
                );
            }
            Ok(())
        }

        #[test]
        fn fork_aware_rng_reseeds_on_pid_change() {
            let rng = ForkAwareRng::new();

            // Control: with an unchanged PID the generator continues its stream.
            let mut snapshot = rng.rng.borrow().clone();
            assert_eq!(
                rng.with_rng(draw_block),
                draw_block(&mut snapshot),
                "an unchanged PID must not trigger a reseed"
            );

            // Simulate a fork: the recorded PID no longer matches the live one.
            let mut inherited = rng.rng.borrow().clone();
            let fake_pid = std::process::id() ^ 1;
            rng.pid.set(fake_pid);
            assert_eq!(rng.pid.get(), fake_pid, "PID should be set to fake value");
            assert_ne!(
                rng.with_rng(draw_block),
                draw_block(&mut inherited),
                "a PID change must replace the inherited stream with a fresh seed"
            );
            assert_eq!(
                rng.pid.get(),
                std::process::id(),
                "PID should be restored to real process ID after reseed"
            );
        }

        #[test]
        fn encrypt_vec_roundtrip() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"block data for encrypt_vec test";
            let ciphertext = provider.encrypt_vec(plaintext.to_vec())?;
            assert_eq!(
                ciphertext.len(),
                Aes256GcmProvider::NONCE_LEN + plaintext.len() + Aes256GcmProvider::TAG_LEN,
            );
            let decrypted = provider.decrypt(&ciphertext)?;
            assert_eq!(decrypted, plaintext);
            Ok(())
        }

        #[test]
        fn decrypt_vec_roundtrip() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = b"block data for decrypt_vec test";
            let ciphertext = provider.encrypt(plaintext)?;
            let decrypted = provider.decrypt_vec(ciphertext)?;
            assert_eq!(decrypted, plaintext);
            Ok(())
        }

        #[test]
        fn encrypt_vec_decrypt_vec_roundtrip() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let plaintext = vec![0xCD_u8; 16 * 1024];
            let ciphertext = provider.encrypt_vec(plaintext.clone())?;
            let decrypted = provider.decrypt_vec(ciphertext)?;
            assert_eq!(decrypted, plaintext);
            Ok(())
        }

        #[test]
        fn encrypt_vec_empty() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let ciphertext = provider.encrypt_vec(vec![])?;
            let decrypted = provider.decrypt_vec(ciphertext)?;
            assert!(decrypted.is_empty());
            Ok(())
        }

        #[test]
        fn decrypt_vec_truncated_fails() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let result = provider.decrypt_vec(vec![0u8; 10]);
            assert!(result.is_err());
            Ok(())
        }

        #[test]
        fn decrypt_vec_tampered_fails() -> crate::Result<()> {
            let provider = Aes256GcmProvider::new(&test_key());
            let mut ciphertext = provider.encrypt_vec(b"data".to_vec())?;
            let mid = Aes256GcmProvider::NONCE_LEN + 1;
            if mid < ciphertext.len() {
                #[expect(clippy::indexing_slicing)]
                {
                    ciphertext[mid] ^= 0xFF;
                }
            }
            let result = provider.decrypt_vec(ciphertext);
            assert!(result.is_err());
            Ok(())
        }
    }
}