use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, AES_256_GCM, CHACHA20_POLY1305};
use ring::hkdf::{Salt, HKDF_SHA256};
use ring::rand::{SecureRandom, SystemRandom};
use secrecy::{ExposeSecret, SecretBox};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::fs;
use std::io;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use thiserror::Error;
pub use crate::crypto::NONCE_SIZE;
pub use crate::crypto::TAG_SIZE;
pub use crate::crypto::KEY_SIZE;
pub use crate::crypto::{KeyInfo, KeyMaterial};
/// Magic prefix identifying an encrypted record: the ASCII bytes `"RVEN"`.
const ENCRYPTION_MAGIC: [u8; 4] = *b"RVEN";
/// Header format version written by this module; `EncryptedHeader::from_bytes` rejects any other.
const FORMAT_VERSION: u8 = 1;
/// Errors produced by the encryption layer.
#[derive(Debug, Error)]
pub enum EncryptionError {
    /// A key source could not supply key material (e.g. the configured environment
    /// variable is unset, or the RNG failed while generating a key).
    #[error("key provider error: {0}")]
    KeyProvider(String),
    /// Sealing, nonce generation, or data-key derivation failed.
    #[error("encryption failed: {0}")]
    Encryption(String),
    /// Authentication of a record failed, or a decryption key could not be rebuilt.
    #[error("decryption failed: {0}")]
    Decryption(String),
    /// Key material was malformed (wrong length, bad hex/UTF-8) or rejected by ring.
    #[error("invalid key: {0}")]
    InvalidKey(String),
    /// A key rotation was refused (non-increasing version) or could not build the new key.
    #[error("key rotation error: {0}")]
    KeyRotation(String),
    /// Underlying filesystem error (key file read/write, permission changes).
    #[error("io error: {0}")]
    Io(#[from] io::Error),
    /// Input bytes are not a well-formed encrypted record/header.
    #[error("invalid format: {0}")]
    InvalidFormat(String),
    /// The header carries a format version this module does not understand.
    #[error("unsupported version: {0}")]
    UnsupportedVersion(u8),
    /// No data key is registered for the given key version.
    #[error("key not found: version {0}")]
    KeyNotFound(u32),
}
/// Result alias used throughout this module, with `EncryptionError` as the error type.
pub type Result<T> = std::result::Result<T, EncryptionError>;
/// AEAD cipher used to seal records.
///
/// The discriminants are part of the on-disk header format: `EncryptedHeader::to_bytes`
/// writes `algorithm as u8` into byte 5 and `EncryptedHeader::from_bytes` decodes the ids
/// 0 and 1. They are therefore pinned explicitly — never renumber or reorder them, and give
/// any new variant a fresh id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[repr(u8)]
pub enum Algorithm {
    /// AES-256 in Galois/Counter Mode (the default). Header id 0.
    #[default]
    Aes256Gcm = 0,
    /// ChaCha20 stream cipher with a Poly1305 authenticator. Header id 1.
    ChaCha20Poly1305 = 1,
}
impl Algorithm {
pub fn key_size(&self) -> usize {
match self {
Algorithm::Aes256Gcm => 32,
Algorithm::ChaCha20Poly1305 => 32,
}
}
pub fn nonce_size(&self) -> usize {
match self {
Algorithm::Aes256Gcm => 12,
Algorithm::ChaCha20Poly1305 => 12,
}
}
pub fn tag_size(&self) -> usize {
match self {
Algorithm::Aes256Gcm => 16,
Algorithm::ChaCha20Poly1305 => 16,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum KeyProvider {
File { path: PathBuf },
Environment { variable: String },
#[serde(skip)]
InMemory(#[serde(skip)] Vec<u8>),
}
impl Default for KeyProvider {
fn default() -> Self {
KeyProvider::Environment {
variable: "RIVVEN_ENCRYPTION_KEY".to_string(),
}
}
}
/// Configuration for `EncryptionManager`.
///
/// Every field has a serde default, so a partial (or empty) config deserializes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptionConfig {
    /// Reported by `EncryptionManager::is_enabled`; defaults to `false`.
    /// Note: `EncryptionManager::encrypt` does not consult this flag itself.
    #[serde(default)]
    pub enabled: bool,
    /// AEAD cipher used for newly written data; defaults to AES-256-GCM.
    #[serde(default)]
    pub algorithm: Algorithm,
    /// Where the master key is loaded from; defaults to the `RIVVEN_ENCRYPTION_KEY`
    /// environment variable.
    #[serde(default)]
    pub key_provider: KeyProvider,
    /// Not read anywhere in this module. NOTE(review): presumably a rotation interval in
    /// days (0 = none?) — confirm against the caller that uses it.
    #[serde(default)]
    pub key_rotation_days: u32,
    /// Used both as the HKDF `info` when deriving data keys and as the AEAD associated data
    /// of every record; defaults to `"rivven"`.
    #[serde(default = "default_aad_scope")]
    pub aad_scope: String,
}
/// Serde/`Default` value for `EncryptionConfig::aad_scope`.
fn default_aad_scope() -> String {
    String::from("rivven")
}
impl Default for EncryptionConfig {
    /// Disabled, default algorithm and key provider, `key_rotation_days` of 0, and the
    /// default AAD scope.
    fn default() -> Self {
        Self {
            aad_scope: default_aad_scope(),
            algorithm: Default::default(),
            enabled: false,
            key_provider: Default::default(),
            key_rotation_days: 0,
        }
    }
}
impl EncryptionConfig {
    /// Same as `EncryptionConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Turns the `enabled` flag on.
    pub fn enabled(mut self) -> Self {
        self.enabled = true;
        self
    }

    /// Selects the AEAD cipher used for new data.
    pub fn with_algorithm(mut self, algorithm: Algorithm) -> Self {
        self.algorithm = algorithm;
        self
    }

    /// Selects where the master key is loaded from.
    pub fn with_key_provider(mut self, provider: KeyProvider) -> Self {
        self.key_provider = provider;
        self
    }

    /// Sets `key_rotation_days`.
    pub fn with_key_rotation_days(mut self, days: u32) -> Self {
        self.key_rotation_days = days;
        self
    }

    /// Sets the AAD scope, which is both the HKDF `info` for data-key derivation and the
    /// AEAD associated data. Data written under one scope cannot be decrypted under another,
    /// so changing it for an existing store makes old records unreadable.
    pub fn with_aad_scope(mut self, scope: impl Into<String>) -> Self {
        self.aad_scope = scope.into();
        self
    }
}
/// Fixed-size header prepended to every encrypted record.
#[derive(Debug, Clone)]
pub struct EncryptedHeader {
    /// Header format version (`FORMAT_VERSION` when written by this module).
    pub version: u8,
    /// Cipher that sealed the record; serialized as a one-byte id.
    pub algorithm: Algorithm,
    /// Version of the data key (in the manager's key store) that sealed the record.
    pub key_version: u32,
    /// AEAD nonce; `EncryptionManager` fills it with the big-endian LSN followed by
    /// random bytes.
    pub nonce: [u8; NONCE_SIZE],
}
impl EncryptedHeader {
    /// Serialized header length in bytes.
    ///
    /// Layout: `[0..4]` magic, `[4]` format version, `[5]` algorithm id,
    /// `[6..10]` big-endian key version, `[10..22]` nonce, `[22..24]` written as zero and
    /// not inspected when parsing.
    pub const SIZE: usize = 24;

    /// Byte offset of the big-endian key version.
    const KEY_VERSION_AT: usize = 6;
    /// Byte offset of the nonce.
    const NONCE_AT: usize = 10;

    /// Builds a header stamped with the current `FORMAT_VERSION`.
    pub fn new(algorithm: Algorithm, key_version: u32, nonce: [u8; NONCE_SIZE]) -> Self {
        Self {
            version: FORMAT_VERSION,
            algorithm,
            key_version,
            nonce,
        }
    }

    /// Encodes the header into its fixed `SIZE`-byte wire form.
    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
        let kv = Self::KEY_VERSION_AT;
        let nn = Self::NONCE_AT;
        let mut out = [0u8; Self::SIZE];
        out[..4].copy_from_slice(&ENCRYPTION_MAGIC);
        out[4] = self.version;
        out[5] = self.algorithm as u8;
        out[kv..kv + 4].copy_from_slice(&self.key_version.to_be_bytes());
        out[nn..nn + NONCE_SIZE].copy_from_slice(&self.nonce);
        out
    }

    /// Parses a header from the first `SIZE` bytes of `data`; trailing bytes are ignored.
    ///
    /// # Errors
    /// `InvalidFormat` for short input, wrong magic, or an unknown algorithm id;
    /// `UnsupportedVersion` for a format version other than `FORMAT_VERSION`.
    pub fn from_bytes(data: &[u8]) -> Result<Self> {
        let available = data.len();
        if available < Self::SIZE {
            return Err(EncryptionError::InvalidFormat(format!(
                "header too short: {} < {}",
                available,
                Self::SIZE
            )));
        }
        if data[..4] != ENCRYPTION_MAGIC {
            return Err(EncryptionError::InvalidFormat("invalid magic bytes".into()));
        }
        let version = data[4];
        if version != FORMAT_VERSION {
            return Err(EncryptionError::UnsupportedVersion(version));
        }
        let algorithm = Self::decode_algorithm(data[5])?;

        let kv = Self::KEY_VERSION_AT;
        let mut key_version_be = [0u8; 4];
        key_version_be.copy_from_slice(&data[kv..kv + 4]);

        let nn = Self::NONCE_AT;
        let mut nonce = [0u8; NONCE_SIZE];
        nonce.copy_from_slice(&data[nn..nn + NONCE_SIZE]);

        Ok(Self {
            version,
            algorithm,
            key_version: u32::from_be_bytes(key_version_be),
            nonce,
        })
    }

    /// Maps a header algorithm id back to an `Algorithm`.
    fn decode_algorithm(id: u8) -> Result<Algorithm> {
        match id {
            0 => Ok(Algorithm::Aes256Gcm),
            1 => Ok(Algorithm::ChaCha20Poly1305),
            unknown => Err(EncryptionError::InvalidFormat(format!(
                "unknown algorithm: {}",
                unknown
            ))),
        }
    }
}
/// Versioned root key from which data keys are derived with HKDF-SHA256.
///
/// The key bytes are held in a `SecretBox`; the `Debug` impl prints them as `[REDACTED]`.
pub struct MasterKey {
    /// The `KEY_SIZE` raw key bytes.
    key: SecretBox<[u8; KEY_SIZE]>,
    /// Version tag; used as the key-store slot and the header `key_version`.
    version: u32,
}
impl MasterKey {
    /// Wraps raw key bytes as a master key tagged with `version`.
    ///
    /// # Errors
    /// `InvalidKey` unless `key` is exactly `KEY_SIZE` bytes long.
    pub fn new(key: Vec<u8>, version: u32) -> Result<Self> {
        if key.len() != KEY_SIZE {
            return Err(EncryptionError::InvalidKey(format!(
                "key must be {} bytes, got {}",
                KEY_SIZE,
                key.len()
            )));
        }
        let mut key_array = [0u8; KEY_SIZE];
        key_array.copy_from_slice(&key);
        Ok(Self {
            key: SecretBox::new(Box::new(key_array)),
            version,
        })
    }

    /// Generates a fresh random master key from the system RNG.
    ///
    /// # Errors
    /// `KeyProvider` if the RNG fails.
    pub fn generate(version: u32) -> Result<Self> {
        let rng = SystemRandom::new();
        let mut key = vec![0u8; KEY_SIZE];
        rng.fill(&mut key)
            .map_err(|_| EncryptionError::KeyProvider("failed to generate random key".into()))?;
        Self::new(key, version)
    }

    /// Loads the master key described by `provider`; the result always has version 1.
    ///
    /// * `File` — a file of exactly `KEY_SIZE` bytes is used as the raw key; anything else
    ///   is read as UTF-8 hex with surrounding whitespace trimmed.
    /// * `Environment` — the variable must hold a hex-encoded key (whitespace trimmed).
    /// * `InMemory` — the bytes are used as-is.
    ///
    /// # Errors
    /// `Io` if the key file cannot be read; `KeyProvider` if the environment variable is
    /// unset or not valid Unicode; `InvalidKey` for non-UTF-8/non-hex content or a decoded
    /// key of the wrong length.
    pub fn from_provider(provider: &KeyProvider) -> Result<Self> {
        // Deliberately exhaustive (no `_` arm): adding a `KeyProvider` variant must fail to
        // compile here instead of silently turning into a runtime "unsupported" error.
        match provider {
            KeyProvider::File { path } => {
                let data = fs::read(path)?;
                let key = if data.len() == KEY_SIZE {
                    data
                } else {
                    let hex_str = String::from_utf8(data)
                        .map_err(|_| EncryptionError::InvalidKey("invalid key file format".into()))?
                        .trim()
                        .to_string();
                    hex::decode(&hex_str).map_err(|e| {
                        EncryptionError::InvalidKey(format!("invalid hex key: {}", e))
                    })?
                };
                Self::new(key, 1)
            }
            KeyProvider::Environment { variable } => {
                let hex_key = std::env::var(variable).map_err(|_| {
                    EncryptionError::KeyProvider(format!(
                        "environment variable '{}' not set",
                        variable
                    ))
                })?;
                let key = hex::decode(hex_key.trim()).map_err(|e| {
                    EncryptionError::InvalidKey(format!("invalid hex key in env var: {}", e))
                })?;
                Self::new(key, 1)
            }
            KeyProvider::InMemory(key) => Self::new(key.clone(), 1),
        }
    }

    /// Version tag of this master key.
    pub fn version(&self) -> u32 {
        self.version
    }

    /// Derives a `KEY_SIZE`-byte data key via HKDF-SHA256 (fixed salt
    /// `rivven-encryption-v1`, caller-supplied `info`).
    ///
    /// The same master key and `info` always yield the same data key.
    ///
    /// # Errors
    /// `Encryption` if ring rejects the requested output length.
    fn derive_data_key(&self, info: &[u8]) -> Result<[u8; KEY_SIZE]> {
        let salt = Salt::new(HKDF_SHA256, b"rivven-encryption-v1");
        let prk = salt.extract(self.key.expose_secret());
        let info_refs = [info];
        let okm = prk
            .expand(&info_refs, DataKeyLen)
            .map_err(|_| EncryptionError::Encryption("key derivation failed".into()))?;
        let mut data_key = [0u8; KEY_SIZE];
        okm.fill(&mut data_key)
            .map_err(|_| EncryptionError::Encryption("key expansion failed".into()))?;
        Ok(data_key)
    }
}
impl fmt::Debug for MasterKey {
    /// Shows the version only; the key itself is always rendered as `[REDACTED]`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("MasterKey");
        out.field("version", &self.version);
        out.field("key", &"[REDACTED]");
        out.finish()
    }
}
/// Output-length marker for ring's HKDF: makes `expand` produce exactly `KEY_SIZE` bytes.
struct DataKeyLen;

impl DataKeyLen {
    /// Number of bytes of keying material requested from HKDF.
    const LEN: usize = KEY_SIZE;
}

impl ring::hkdf::KeyType for DataKeyLen {
    fn len(&self) -> usize {
        Self::LEN
    }
}
/// Encrypts and decrypts records with versioned data keys derived from master keys.
pub struct EncryptionManager {
    /// Configuration captured at construction; its algorithm and AAD scope are used on
    /// every call.
    config: EncryptionConfig,
    /// Master key loaded at construction. `rotate_key` does not replace it; `decrypt` uses it
    /// only to rebuild a key when a record's algorithm differs from the stored key's.
    master_key: MasterKey,
    /// Data keys by key version. This module only ever inserts into it.
    key_store: parking_lot::RwLock<HashMap<u32, LessSafeKey>>,
    /// Source of the random nonce suffix.
    rng: SystemRandom,
    /// Key version used for new encryptions.
    current_key_version: AtomicU32,
}
/// Returns ring's static AEAD descriptor for a configured `Algorithm`.
fn ring_algorithm(algo: Algorithm) -> &'static ring::aead::Algorithm {
    match algo {
        Algorithm::ChaCha20Poly1305 => &CHACHA20_POLY1305,
        Algorithm::Aes256Gcm => &AES_256_GCM,
    }
}
impl EncryptionManager {
pub fn new(config: EncryptionConfig) -> Result<Arc<Self>> {
let master_key = MasterKey::from_provider(&config.key_provider)?;
let data_key_bytes = master_key.derive_data_key(config.aad_scope.as_bytes())?;
let algo = ring_algorithm(config.algorithm);
let version = master_key.version();
let store_unbound = UnboundKey::new(algo, &data_key_bytes)
.map_err(|_| EncryptionError::InvalidKey("failed to create store key".into()))?;
let mut key_store = HashMap::new();
key_store.insert(version, LessSafeKey::new(store_unbound));
Ok(Arc::new(Self {
config,
current_key_version: AtomicU32::new(version),
master_key,
key_store: parking_lot::RwLock::new(key_store),
rng: SystemRandom::new(),
}))
}
pub fn rotate_key(&self, new_master: MasterKey) -> Result<()> {
let new_version = new_master.version();
if new_version <= self.current_key_version.load(Ordering::Acquire) {
return Err(EncryptionError::KeyRotation(
"new key version must be greater than current".into(),
));
}
let data_key_bytes = new_master.derive_data_key(self.config.aad_scope.as_bytes())?;
let algo = ring_algorithm(self.config.algorithm);
let new_key = UnboundKey::new(algo, &data_key_bytes)
.map_err(|_| EncryptionError::KeyRotation("failed to create new key".into()))?;
{
let mut store = self.key_store.write();
store.insert(new_version, LessSafeKey::new(new_key));
}
self.current_key_version
.store(new_version, Ordering::Release);
Ok(())
}
fn get_key_for_version(&self, version: u32) -> Result<()> {
let store = self.key_store.read();
if store.contains_key(&version) {
Ok(())
} else {
Err(EncryptionError::KeyNotFound(version))
}
}
pub fn disabled() -> Arc<DisabledEncryption> {
Arc::new(DisabledEncryption)
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub fn key_version(&self) -> u32 {
self.current_key_version.load(Ordering::Relaxed)
}
fn generate_nonce(&self, lsn: u64) -> Result<[u8; NONCE_SIZE]> {
let mut nonce = [0u8; NONCE_SIZE];
nonce[0..8].copy_from_slice(&lsn.to_be_bytes());
self.rng.fill(&mut nonce[8..12]).map_err(|_| {
EncryptionError::Encryption(
"RNG failure during nonce generation — refusing to encrypt with zero nonce bytes"
.into(),
)
})?;
Ok(nonce)
}
pub fn encrypt(&self, plaintext: &[u8], lsn: u64) -> Result<Vec<u8>> {
let nonce_bytes = self.generate_nonce(lsn)?;
let nonce = Nonce::assume_unique_for_key(nonce_bytes);
let version = self.current_key_version.load(Ordering::Acquire);
let header = EncryptedHeader::new(self.config.algorithm, version, nonce_bytes);
let mut output = Vec::with_capacity(EncryptedHeader::SIZE + plaintext.len() + TAG_SIZE);
output.extend_from_slice(&header.to_bytes());
output.extend_from_slice(plaintext);
let store = self.key_store.read();
let key = store
.get(&version)
.ok_or(EncryptionError::KeyNotFound(version))?;
let ciphertext_start = EncryptedHeader::SIZE;
let tag = key
.seal_in_place_separate_tag(
nonce,
Aad::from(self.config.aad_scope.as_bytes()),
&mut output[ciphertext_start..],
)
.map_err(|_| EncryptionError::Encryption("seal failed".into()))?;
output.extend_from_slice(tag.as_ref());
Ok(output)
}
pub fn decrypt(&self, ciphertext: &[u8], _lsn: u64) -> Result<Vec<u8>> {
if ciphertext.len() < EncryptedHeader::SIZE + TAG_SIZE {
return Err(EncryptionError::InvalidFormat(
"ciphertext too short".into(),
));
}
let header = EncryptedHeader::from_bytes(ciphertext)?;
self.get_key_for_version(header.key_version)?;
let nonce = Nonce::assume_unique_for_key(header.nonce);
let algo = ring_algorithm(header.algorithm);
let mut buffer = ciphertext[EncryptedHeader::SIZE..].to_vec();
let store = self.key_store.read();
let key = store
.get(&header.key_version)
.ok_or(EncryptionError::KeyNotFound(header.key_version))?;
if *key.algorithm() != *algo {
drop(store);
let data_key_bytes = self
.master_key
.derive_data_key(self.config.aad_scope.as_bytes())?;
let unbound = UnboundKey::new(algo, &data_key_bytes)
.map_err(|_| EncryptionError::Decryption("key re-derive failed".into()))?;
let temp_key = LessSafeKey::new(unbound);
let plaintext = temp_key
.open_in_place(
nonce,
Aad::from(self.config.aad_scope.as_bytes()),
&mut buffer,
)
.map_err(|_| EncryptionError::Decryption("authentication failed".into()))?;
return Ok(plaintext.to_vec());
}
let plaintext = key
.open_in_place(
nonce,
Aad::from(self.config.aad_scope.as_bytes()),
&mut buffer,
)
.map_err(|_| EncryptionError::Decryption("authentication failed".into()))?;
Ok(plaintext.to_vec())
}
pub fn encrypted_size(&self, plaintext_len: usize) -> usize {
EncryptedHeader::SIZE + plaintext_len + TAG_SIZE
}
}
impl fmt::Debug for EncryptionManager {
    /// Shows the enabled flag, algorithm and active key version; no key material.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let active_version = self.key_version();
        let mut out = f.debug_struct("EncryptionManager");
        out.field("enabled", &self.config.enabled)
            .field("algorithm", &self.config.algorithm)
            .field("key_version", &active_version);
        out.finish()
    }
}
/// No-op encryptor for when encryption is off: bytes pass through unchanged.
#[derive(Debug)]
pub struct DisabledEncryption;

impl DisabledEncryption {
    /// Returns a copy of `plaintext`; the LSN is ignored.
    pub fn encrypt(&self, plaintext: &[u8], _lsn: u64) -> Result<Vec<u8>> {
        Ok(Vec::from(plaintext))
    }

    /// Returns a copy of `ciphertext`; the LSN is ignored.
    pub fn decrypt(&self, ciphertext: &[u8], _lsn: u64) -> Result<Vec<u8>> {
        Ok(Vec::from(ciphertext))
    }

    /// Output size equals input size: no header and no tag are added.
    pub fn encrypted_size(&self, plaintext_len: usize) -> usize {
        plaintext_len
    }

    /// Always `false`.
    pub fn is_enabled(&self) -> bool {
        false
    }
}
/// Object-safe encryption interface implemented by `EncryptionManager` and
/// `DisabledEncryption`.
pub trait Encryptor: Send + Sync + std::fmt::Debug {
    /// Encrypts `plaintext` for the record at `lsn`.
    fn encrypt(&self, plaintext: &[u8], lsn: u64) -> Result<Vec<u8>>;
    /// Reverses `encrypt` for the record at `lsn`.
    fn decrypt(&self, ciphertext: &[u8], lsn: u64) -> Result<Vec<u8>>;
    /// Size of `encrypt`'s output for a `plaintext_len`-byte input.
    fn encrypted_size(&self, plaintext_len: usize) -> usize;
    /// Reports whether encryption is enabled (for `EncryptionManager`, the config's
    /// `enabled` flag; for `DisabledEncryption`, always `false`).
    fn is_enabled(&self) -> bool;
}
impl Encryptor for EncryptionManager {
fn encrypt(&self, plaintext: &[u8], lsn: u64) -> Result<Vec<u8>> {
self.encrypt(plaintext, lsn)
}
fn decrypt(&self, ciphertext: &[u8], lsn: u64) -> Result<Vec<u8>> {
self.decrypt(ciphertext, lsn)
}
fn encrypted_size(&self, plaintext_len: usize) -> usize {
self.encrypted_size(plaintext_len)
}
fn is_enabled(&self) -> bool {
self.is_enabled()
}
}
impl Encryptor for DisabledEncryption {
fn encrypt(&self, plaintext: &[u8], lsn: u64) -> Result<Vec<u8>> {
self.encrypt(plaintext, lsn)
}
fn decrypt(&self, ciphertext: &[u8], lsn: u64) -> Result<Vec<u8>> {
self.decrypt(ciphertext, lsn)
}
fn encrypted_size(&self, plaintext_len: usize) -> usize {
self.encrypted_size(plaintext_len)
}
fn is_enabled(&self) -> bool {
false
}
}
/// Generates a fresh random master key (version 1) and writes it hex-encoded to `path`.
///
/// An existing file is truncated and overwritten.
///
/// On Unix, the file is created with mode `0o600`, and a pre-existing file is chmod-ed to
/// `0o600`, *before* any key bytes are written. Writing first and chmod-ing afterwards would
/// leave the key readable by other users (per the umask) for a window.
///
/// On Windows, the file is marked read-only after writing and a warning is logged asking the
/// operator to verify the ACLs.
///
/// # Errors
/// `KeyProvider` if the RNG fails; `Io` for any create/write/permission error.
pub fn generate_key_file(path: &std::path::Path) -> Result<()> {
    let key = MasterKey::generate(1)?;
    let hex_key = hex::encode(key.key.expose_secret());
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
        let mut file = fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)?;
        // `mode` only applies when the file is newly created; tighten an existing file's
        // permissions before the key is written into it.
        file.set_permissions(fs::Permissions::from_mode(0o600))?;
        file.write_all(hex_key.as_bytes())?;
    }
    #[cfg(not(unix))]
    {
        fs::write(path, &hex_key)?;
    }
    #[cfg(windows)]
    {
        let mut perms = fs::metadata(path)?.permissions();
        perms.set_readonly(true);
        fs::set_permissions(path, perms)?;
        tracing::warn!(
            path = %path.display(),
            "Key file created on Windows \u{2014} manually verify file ACLs restrict access to current user only"
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// AES-256-GCM config backed by an all-zero in-memory key and the "test" AAD scope.
    fn test_config() -> EncryptionConfig {
        let key = vec![0u8; 32];
        EncryptionConfig {
            enabled: true,
            algorithm: Algorithm::Aes256Gcm,
            key_provider: KeyProvider::InMemory(key),
            key_rotation_days: 0,
            aad_scope: "test".to_string(),
        }
    }

    #[test]
    fn test_encrypt_decrypt() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        let plaintext = b"Hello, World! This is sensitive data.";
        let lsn = 12345u64;
        let ciphertext = manager.encrypt(plaintext, lsn).unwrap();
        assert_ne!(ciphertext.as_slice(), plaintext);
        assert!(ciphertext.len() > plaintext.len());
        let decrypted = manager.decrypt(&ciphertext, lsn).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypted_size() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        let plaintext_len = 1000;
        let expected = EncryptedHeader::SIZE + plaintext_len + TAG_SIZE;
        assert_eq!(manager.encrypted_size(plaintext_len), expected);
    }

    #[test]
    fn test_header_roundtrip() {
        let nonce = [1u8; NONCE_SIZE];
        let header = EncryptedHeader::new(Algorithm::Aes256Gcm, 42, nonce);
        let bytes = header.to_bytes();
        let parsed = EncryptedHeader::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.version, header.version);
        assert_eq!(parsed.algorithm, header.algorithm);
        assert_eq!(parsed.key_version, header.key_version);
        assert_eq!(parsed.nonce, header.nonce);
    }

    #[test]
    fn test_header_rejects_bad_version_algorithm_and_length() {
        let header = EncryptedHeader::new(Algorithm::Aes256Gcm, 1, [0u8; NONCE_SIZE]);

        let mut bad_version = header.to_bytes();
        bad_version[4] = FORMAT_VERSION.wrapping_add(1);
        assert!(matches!(
            EncryptedHeader::from_bytes(&bad_version),
            Err(EncryptionError::UnsupportedVersion(v)) if v == FORMAT_VERSION.wrapping_add(1)
        ));

        let mut bad_algorithm = header.to_bytes();
        bad_algorithm[5] = 0xFF;
        assert!(matches!(
            EncryptedHeader::from_bytes(&bad_algorithm),
            Err(EncryptionError::InvalidFormat(_))
        ));

        assert!(matches!(
            EncryptedHeader::from_bytes(&header.to_bytes()[..EncryptedHeader::SIZE - 1]),
            Err(EncryptionError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_invalid_ciphertext() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        // Shorter than header + tag.
        let result = manager.decrypt(&[0u8; 10], 1);
        assert!(result.is_err());
        // Wrong magic.
        let mut bad_magic = vec![0u8; 100];
        let result = manager.decrypt(&bad_magic, 1);
        assert!(result.is_err());
        // Valid header prefix but key version 0 / garbage body.
        bad_magic[0..4].copy_from_slice(&ENCRYPTION_MAGIC);
        bad_magic[4] = FORMAT_VERSION;
        let result = manager.decrypt(&bad_magic, 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_unknown_key_version() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        let mut ciphertext = manager.encrypt(b"payload", 7).unwrap();
        // Bytes 6..10 of the header hold the big-endian key version.
        ciphertext[6..10].copy_from_slice(&99u32.to_be_bytes());
        assert!(matches!(
            manager.decrypt(&ciphertext, 7),
            Err(EncryptionError::KeyNotFound(99))
        ));
    }

    #[test]
    fn test_tamper_detection() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        let plaintext = b"Sensitive data that must not be tampered with";
        let mut ciphertext = manager.encrypt(plaintext, 1).unwrap();
        let tamper_pos = EncryptedHeader::SIZE + 10;
        ciphertext[tamper_pos] ^= 0x01;
        let result = manager.decrypt(&ciphertext, 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_different_lsns_produce_different_ciphertexts() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        let plaintext = b"Same plaintext";
        let ct1 = manager.encrypt(plaintext, 1).unwrap();
        let ct2 = manager.encrypt(plaintext, 2).unwrap();
        assert_ne!(ct1, ct2);
        assert_eq!(manager.decrypt(&ct1, 1).unwrap(), plaintext);
        assert_eq!(manager.decrypt(&ct2, 2).unwrap(), plaintext);
    }

    #[test]
    fn test_disabled_encryption_passthrough() {
        let disabled = DisabledEncryption;
        let plaintext = b"Not encrypted";
        let encrypted = disabled.encrypt(plaintext, 1).unwrap();
        assert_eq!(&encrypted[..], plaintext);
        let decrypted = disabled.decrypt(plaintext, 1).unwrap();
        assert_eq!(&decrypted[..], plaintext);
        assert_eq!(disabled.encrypted_size(100), 100);
        assert!(!disabled.is_enabled());
    }

    #[test]
    fn test_master_key_validation() {
        let result = MasterKey::new(vec![0u8; 16], 1);
        assert!(result.is_err());
        let result = MasterKey::new(vec![0u8; 32], 1);
        assert!(result.is_ok());
    }

    #[test]
    fn test_master_key_debug_is_redacted() {
        let key = MasterKey::new(vec![0xAB; 32], 3).unwrap();
        let rendered = format!("{:?}", key);
        assert!(rendered.contains("[REDACTED]"));
        assert!(rendered.contains("version: 3"));
        assert!(!rendered.contains("171")); // 0xAB as a decimal byte must not appear
    }

    #[test]
    fn test_key_derivation_consistency() {
        let key = MasterKey::new(vec![42u8; 32], 1).unwrap();
        let dk1 = key.derive_data_key(b"scope1").unwrap();
        let dk2 = key.derive_data_key(b"scope1").unwrap();
        let dk3 = key.derive_data_key(b"scope2").unwrap();
        assert_eq!(dk1, dk2);
        assert_ne!(dk1, dk3);
    }

    #[test]
    fn test_large_data_encryption() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        let plaintext: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();
        let ciphertext = manager.encrypt(&plaintext, 999999).unwrap();
        let decrypted = manager.decrypt(&ciphertext, 999999).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_generate_key() {
        let key = MasterKey::generate(1).unwrap();
        assert_eq!(key.version(), 1);
    }

    #[test]
    fn test_chacha20_poly1305_encrypt_decrypt() {
        let config = EncryptionConfig {
            enabled: true,
            algorithm: Algorithm::ChaCha20Poly1305,
            key_provider: KeyProvider::InMemory(vec![0u8; 32]),
            key_rotation_days: 0,
            aad_scope: "test".to_string(),
        };
        let manager = EncryptionManager::new(config).unwrap();
        let plaintext = b"ChaCha20-Poly1305 test payload";
        let lsn = 42u64;
        let ciphertext = manager.encrypt(plaintext, lsn).unwrap();
        assert_ne!(ciphertext.as_slice(), plaintext.as_slice());
        let header = EncryptedHeader::from_bytes(&ciphertext).unwrap();
        assert_eq!(header.algorithm, Algorithm::ChaCha20Poly1305);
        let decrypted = manager.decrypt(&ciphertext, lsn).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_key_rotation() {
        let config = EncryptionConfig {
            enabled: true,
            algorithm: Algorithm::Aes256Gcm,
            key_provider: KeyProvider::InMemory(vec![1u8; 32]),
            key_rotation_days: 30,
            aad_scope: "test".to_string(),
        };
        let manager = EncryptionManager::new(config).unwrap();
        let plaintext = b"data encrypted with key v1";
        let ct_v1 = manager.encrypt(plaintext, 100).unwrap();
        assert_eq!(manager.key_version(), 1);
        let new_master = MasterKey::new(vec![2u8; 32], 2).unwrap();
        manager.rotate_key(new_master).unwrap();
        assert_eq!(manager.key_version(), 2);
        let decrypted = manager.decrypt(&ct_v1, 100).unwrap();
        assert_eq!(decrypted, plaintext);
        let bad_master = MasterKey::new(vec![3u8; 32], 1).unwrap();
        assert!(manager.rotate_key(bad_master).is_err());
    }

    #[test]
    fn test_encrypt_after_rotation_uses_new_version() {
        let manager = EncryptionManager::new(test_config()).unwrap();
        manager
            .rotate_key(MasterKey::new(vec![7u8; 32], 5).unwrap())
            .unwrap();
        let ciphertext = manager.encrypt(b"after rotation", 3).unwrap();
        assert_eq!(
            EncryptedHeader::from_bytes(&ciphertext).unwrap().key_version,
            5
        );
        assert_eq!(manager.decrypt(&ciphertext, 3).unwrap(), b"after rotation");
    }
}