use chacha20poly1305::{
ChaCha20Poly1305, Nonce,
aead::{Aead, KeyInit, OsRng},
};
use rand::RngCore;
use std::time::{Duration, SystemTime};
use thiserror::Error;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Errors reported by the encryption helpers in this module.
///
/// Variants carrying a `String` hold a human-readable description of the
/// cause, which is interpolated into the `Display` message.
#[derive(Debug, Error)]
pub enum EncryptionError {
    /// A cipher key could not be produced or was rejected, e.g. a
    /// caller-supplied key that is not exactly 32 bytes long.
    #[error("Failed to generate encryption key: {0}")]
    KeyGenerationError(String),

    /// The cipher reported an error while encrypting plaintext.
    #[error("Failed to encrypt data: {0}")]
    EncryptionFailed(String),

    /// The cipher reported an error while decrypting a blob.
    #[error("Failed to decrypt data: {0}")]
    DecryptionFailed(String),

    /// The encrypted blob is structurally invalid, e.g. too short to contain
    /// the nonce prefix.
    #[error("Invalid encrypted data format: {0}")]
    InvalidFormat(String),

    /// The data outlived its retention period.
    ///
    /// Not constructed by the code visible in this file; presumably raised by
    /// callers that enforce a retention policy.
    #[error("Data expired: retention period exceeded")]
    DataExpired,
}
/// Owned sensitive bytes tagged with the moment they were wrapped.
///
/// The `Zeroize` / `ZeroizeOnDrop` derives cover the `data` buffer; the
/// timestamp is excluded through `#[zeroize(skip)]`.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct SecureData {
    /// Wall-clock time at which this wrapper was constructed.
    #[zeroize(skip)]
    created_at: SystemTime,
    /// The wrapped payload.
    data: Vec<u8>,
}

impl SecureData {
    /// Wraps `data`, stamping it with the current system time.
    pub fn new(data: Vec<u8>) -> Self {
        let created_at = SystemTime::now();
        Self { created_at, data }
    }

    /// Borrows the wrapped bytes.
    pub fn as_slice(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Returns the time at which this value was created.
    pub fn created_at(&self) -> SystemTime {
        self.created_at
    }

    /// Reports whether strictly more than `retention_period` has elapsed
    /// since creation.
    ///
    /// If the elapsed time cannot be computed (the system clock now reads
    /// earlier than the creation time) the data is treated as not expired.
    pub fn is_expired(&self, retention_period: Duration) -> bool {
        match self.created_at.elapsed() {
            Ok(age) => age > retention_period,
            Err(_) => false,
        }
    }
}
/// Describes how long data may be kept and whether it should be cleaned up
/// automatically.
#[derive(Debug, Clone)]
pub struct DataRetentionPolicy {
    /// Maximum age for which data is retained.
    pub ttl: Duration,
    /// Whether automatic cleanup is requested. This type only stores the
    /// flag; nothing in this file acts on it.
    pub auto_cleanup: bool,
}

impl Default for DataRetentionPolicy {
    /// A 30-day time-to-live with automatic cleanup turned off.
    fn default() -> Self {
        const THIRTY_DAYS_SECS: u64 = 30 * 24 * 60 * 60;
        Self {
            ttl: Duration::from_secs(THIRTY_DAYS_SECS),
            auto_cleanup: false,
        }
    }
}

impl DataRetentionPolicy {
    /// Builds a policy with the given time-to-live and automatic cleanup
    /// turned on (unlike [`Default`], which leaves it off).
    pub fn new(ttl: Duration) -> Self {
        Self {
            auto_cleanup: true,
            ttl,
        }
    }

    /// Returns `true` while data created at `created_at` is no older than the
    /// configured time-to-live (an age exactly equal to `ttl` is retained).
    ///
    /// A `created_at` that lies in the future relative to the system clock
    /// makes the age uncomputable; such data is retained.
    pub fn should_retain(&self, created_at: SystemTime) -> bool {
        created_at.elapsed().map_or(true, |age| age <= self.ttl)
    }
}
/// Symmetric encryption service built on ChaCha20-Poly1305.
///
/// Every encrypted blob produced by this service is laid out as
/// `nonce (12 bytes) || ciphertext-with-tag`, and the decrypt methods expect
/// that same layout.
pub struct EncryptionService {
    cipher: ChaCha20Poly1305,
}

impl EncryptionService {
    /// Length in bytes of the cipher key.
    const KEY_LEN: usize = 32;
    /// Length in bytes of the nonce prefixed to every encrypted blob.
    const NONCE_LEN: usize = 12;

    /// Creates a service with a freshly generated random key.
    ///
    /// The key is drawn from `OsRng` and the local buffer is wiped before
    /// this function returns, on both the success and the error path.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::KeyGenerationError`] if the cipher rejects
    /// the generated key. With the fixed-size buffer used here that is not
    /// expected to happen.
    pub fn new() -> Result<Self, EncryptionError> {
        let mut key_bytes = [0u8; Self::KEY_LEN];
        OsRng.fill_bytes(&mut key_bytes);
        // Hand the cipher a borrow of the buffer rather than a by-value
        // conversion, so no temporary copy of the key is left un-wiped.
        let cipher = ChaCha20Poly1305::new_from_slice(&key_bytes);
        // Wipe before inspecting the result so every path clears the key.
        key_bytes.zeroize();
        let cipher = cipher.map_err(|e| EncryptionError::KeyGenerationError(e.to_string()))?;
        Ok(Self { cipher })
    }

    /// Creates a service from caller-supplied key material.
    ///
    /// The caller keeps ownership of `key` and remains responsible for
    /// wiping it.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::KeyGenerationError`] unless `key` is
    /// exactly 32 bytes long.
    pub fn from_key(key: &[u8]) -> Result<Self, EncryptionError> {
        // `new_from_slice` performs the length check itself and fails
        // instead of panicking, so no separate guard is needed.
        ChaCha20Poly1305::new_from_slice(key)
            .map(|cipher| Self { cipher })
            .map_err(|_| {
                EncryptionError::KeyGenerationError("Key must be exactly 32 bytes".to_string())
            })
    }

    /// Encrypts image bytes.
    ///
    /// Currently identical to [`Self::encrypt_document_data`]; both delegate
    /// to the same routine.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::EncryptionFailed`] if the cipher fails.
    pub fn encrypt_image_data(&self, data: &[u8]) -> Result<Vec<u8>, EncryptionError> {
        self.encrypt_data(data)
    }

    /// Decrypts a blob produced by [`Self::encrypt_image_data`].
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::InvalidFormat`] if `encrypted` is shorter
    /// than the nonce prefix, or [`EncryptionError::DecryptionFailed`] if the
    /// cipher rejects the data.
    pub fn decrypt_image_data(&self, encrypted: &[u8]) -> Result<SecureData, EncryptionError> {
        self.decrypt_data(encrypted)
    }

    /// Encrypts document bytes.
    ///
    /// Currently identical to [`Self::encrypt_image_data`]; both delegate to
    /// the same routine.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::EncryptionFailed`] if the cipher fails.
    pub fn encrypt_document_data(&self, data: &[u8]) -> Result<Vec<u8>, EncryptionError> {
        self.encrypt_data(data)
    }

    /// Decrypts a blob produced by [`Self::encrypt_document_data`].
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::InvalidFormat`] if `encrypted` is shorter
    /// than the nonce prefix, or [`EncryptionError::DecryptionFailed`] if the
    /// cipher rejects the data.
    pub fn decrypt_document_data(&self, encrypted: &[u8]) -> Result<SecureData, EncryptionError> {
        self.decrypt_data(encrypted)
    }

    /// Encrypts `data` under a fresh random nonce and returns
    /// `nonce || ciphertext-with-tag`.
    ///
    /// NOTE(review): nonces are random 96-bit values. If a single key may
    /// encrypt a very large number of messages, confirm the collision risk is
    /// acceptable or consider an extended-nonce variant.
    fn encrypt_data(&self, data: &[u8]) -> Result<Vec<u8>, EncryptionError> {
        let mut nonce_bytes = [0u8; Self::NONCE_LEN];
        OsRng.fill_bytes(&mut nonce_bytes);
        let nonce = Nonce::from(nonce_bytes);

        let ciphertext = self
            .cipher
            .encrypt(&nonce, data)
            .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;

        let mut result = Vec::with_capacity(Self::NONCE_LEN + ciphertext.len());
        result.extend_from_slice(&nonce_bytes);
        result.extend_from_slice(&ciphertext);
        Ok(result)
    }

    /// Splits `encrypted` into its nonce prefix and ciphertext, then
    /// decrypts the ciphertext.
    ///
    /// The returned [`SecureData`] is stamped with the time of decryption,
    /// not the time the data was originally encrypted.
    fn decrypt_data(&self, encrypted: &[u8]) -> Result<SecureData, EncryptionError> {
        if encrypted.len() < Self::NONCE_LEN {
            return Err(EncryptionError::InvalidFormat(
                "Data too short to contain nonce".to_string(),
            ));
        }

        let (nonce_bytes, ciphertext) = encrypted.split_at(Self::NONCE_LEN);
        let nonce = Nonce::from_slice(nonce_bytes);

        let plaintext = self
            .cipher
            .decrypt(nonce, ciphertext)
            .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?;
        Ok(SecureData::new(plaintext))
    }
}
/// Unit tests for the encryption service, the secure-data wrapper and the
/// retention policy.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encryption_service_creation() {
        let service = EncryptionService::new();
        assert!(service.is_ok());
    }

    #[test]
    fn test_encryption_service_from_key() {
        let key = [0u8; 32];
        let service = EncryptionService::from_key(&key);
        assert!(service.is_ok());
    }

    #[test]
    fn test_encryption_service_from_invalid_key() {
        let key = [0u8; 16];
        let service = EncryptionService::from_key(&key);
        assert!(matches!(
            service,
            Err(EncryptionError::KeyGenerationError(_))
        ));
    }

    /// Every length other than 32 must be rejected, including the empty key
    /// and keys that are too long.
    #[test]
    fn test_encryption_service_from_key_rejects_wrong_lengths() {
        for len in [0usize, 1, 16, 31, 33, 64] {
            let key = vec![0u8; len];
            let service = EncryptionService::from_key(&key);
            assert!(
                matches!(service, Err(EncryptionError::KeyGenerationError(_))),
                "key of length {len} should be rejected"
            );
        }
    }

    #[test]
    fn test_encrypt_decrypt_image_data() {
        let service = EncryptionService::new().unwrap();
        let original = b"fake image data for testing";
        let encrypted = service.encrypt_image_data(original).unwrap();
        assert_ne!(encrypted.as_slice(), original);
        assert!(encrypted.len() > original.len());
        let decrypted = service.decrypt_image_data(&encrypted).unwrap();
        assert_eq!(decrypted.as_slice(), original);
    }

    #[test]
    fn test_encrypt_decrypt_document_data() {
        let service = EncryptionService::new().unwrap();
        let original = b"fake PDF document content";
        let encrypted = service.encrypt_document_data(original).unwrap();
        assert_ne!(encrypted.as_slice(), original);
        let decrypted = service.decrypt_document_data(&encrypted).unwrap();
        assert_eq!(decrypted.as_slice(), original);
    }

    /// An empty plaintext still yields the nonce (12 bytes) plus the
    /// authentication tag (16 bytes) and round-trips to an empty payload.
    #[test]
    fn test_encrypt_decrypt_empty_data() {
        let service = EncryptionService::new().unwrap();
        let encrypted = service.encrypt_image_data(&[]).unwrap();
        assert_eq!(encrypted.len(), 12 + 16);
        let decrypted = service.decrypt_image_data(&encrypted).unwrap();
        assert!(decrypted.as_slice().is_empty());
    }

    #[test]
    fn test_decrypt_invalid_data() {
        let service = EncryptionService::new().unwrap();
        let invalid_data = b"not encrypted data";
        let result = service.decrypt_image_data(invalid_data);
        assert!(matches!(result, Err(EncryptionError::DecryptionFailed(_))));
    }

    #[test]
    fn test_decrypt_data_too_short() {
        let service = EncryptionService::new().unwrap();
        let too_short = b"short";
        let result = service.decrypt_image_data(too_short);
        assert!(matches!(result, Err(EncryptionError::InvalidFormat(_))));
    }

    /// Exactly one nonce's worth of bytes passes the format check but has no
    /// room for an authentication tag, so decryption itself must fail.
    #[test]
    fn test_decrypt_nonce_only_data() {
        let service = EncryptionService::new().unwrap();
        let nonce_only = [0u8; 12];
        let result = service.decrypt_image_data(&nonce_only);
        assert!(matches!(result, Err(EncryptionError::DecryptionFailed(_))));
    }

    /// Flipping a single bit of the blob must be detected.
    #[test]
    fn test_decrypt_tampered_data_fails() {
        let service = EncryptionService::new().unwrap();
        let mut encrypted = service
            .encrypt_document_data(b"sensitive document")
            .unwrap();
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;
        let result = service.decrypt_document_data(&encrypted);
        assert!(matches!(result, Err(EncryptionError::DecryptionFailed(_))));
    }

    /// Data encrypted under one key must not decrypt under another.
    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let sender = EncryptionService::from_key(&[1u8; 32]).unwrap();
        let stranger = EncryptionService::from_key(&[2u8; 32]).unwrap();
        let encrypted = sender.encrypt_image_data(b"secret").unwrap();
        let result = stranger.decrypt_image_data(&encrypted);
        assert!(matches!(result, Err(EncryptionError::DecryptionFailed(_))));
    }

    /// Two services built from the same key can read each other's output.
    #[test]
    fn test_from_key_services_interoperate() {
        let key = [7u8; 32];
        let writer = EncryptionService::from_key(&key).unwrap();
        let reader = EncryptionService::from_key(&key).unwrap();
        let original = b"shared secret payload";
        let encrypted = writer.encrypt_document_data(original).unwrap();
        let decrypted = reader.decrypt_document_data(&encrypted).unwrap();
        assert_eq!(decrypted.as_slice(), original);
    }

    /// Smoke test only: checks the accessor and that dropping does not
    /// panic. It does not inspect memory after the drop.
    #[test]
    fn test_secure_data_zeroization() {
        let data = vec![1u8, 2, 3, 4, 5];
        let secure = SecureData::new(data.clone());
        assert_eq!(secure.as_slice(), data.as_slice());
        drop(secure);
    }

    #[test]
    fn test_secure_data_expiration() {
        let data = vec![1u8, 2, 3];
        let secure = SecureData::new(data);
        let long_period = Duration::from_secs(3600);
        assert!(!secure.is_expired(long_period));
        let zero_period = Duration::from_secs(0);
        // Let a measurable amount of time pass so the age is strictly
        // greater than zero.
        std::thread::sleep(Duration::from_millis(10));
        assert!(secure.is_expired(zero_period));
    }

    #[test]
    fn test_data_retention_policy_default() {
        let policy = DataRetentionPolicy::default();
        assert_eq!(policy.ttl, Duration::from_secs(30 * 24 * 60 * 60));
        assert!(!policy.auto_cleanup);
    }

    #[test]
    fn test_data_retention_policy_custom() {
        let ttl = Duration::from_secs(3600);
        let policy = DataRetentionPolicy::new(ttl);
        assert_eq!(policy.ttl, ttl);
        assert!(policy.auto_cleanup);
    }

    #[test]
    fn test_retention_policy_should_retain() {
        let policy = DataRetentionPolicy::new(Duration::from_secs(60));
        let now = SystemTime::now();
        assert!(policy.should_retain(now));
        let old = now - Duration::from_secs(120);
        assert!(!policy.should_retain(old));
    }

    /// A creation time ahead of the system clock makes the age
    /// uncomputable; the policy keeps such data.
    #[test]
    fn test_retention_policy_retains_future_timestamp() {
        let policy = DataRetentionPolicy::new(Duration::from_secs(60));
        let future = SystemTime::now() + Duration::from_secs(3600);
        assert!(policy.should_retain(future));
    }

    #[test]
    fn test_encryption_produces_different_ciphertexts() {
        let service = EncryptionService::new().unwrap();
        let data = b"same data";
        let encrypted1 = service.encrypt_image_data(data).unwrap();
        let encrypted2 = service.encrypt_image_data(data).unwrap();
        assert_ne!(encrypted1, encrypted2);
        let decrypted1 = service.decrypt_image_data(&encrypted1).unwrap();
        let decrypted2 = service.decrypt_image_data(&encrypted2).unwrap();
        assert_eq!(decrypted1.as_slice(), data);
        assert_eq!(decrypted2.as_slice(), data);
    }

    #[test]
    fn test_large_data_encryption() {
        let service = EncryptionService::new().unwrap();
        let large_data = vec![42u8; 1024 * 100];
        let encrypted = service.encrypt_image_data(&large_data).unwrap();
        let decrypted = service.decrypt_image_data(&encrypted).unwrap();
        assert_eq!(decrypted.as_slice(), large_data.as_slice());
    }
}