use async_trait::async_trait;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use aes_gcm::{
aead::{Aead, KeyInit},
Aes256Gcm, Nonce,
};
use chacha20poly1305::{ChaCha20Poly1305, Key as ChaChaKey, Nonce as ChaChaNonce};
/// Errors returned by the encryption service and key providers.
#[derive(Error, Debug)]
pub enum EncryptionError {
    /// An AEAD seal operation failed.
    #[error("Encryption failed")]
    EncryptionFailed,
    /// An AEAD open operation failed (wrong key, tampered data or associated
    /// data), or chunk metadata/ranges were inconsistent with the ciphertext.
    #[error("Decryption failed")]
    DecryptionFailed,
    /// A key had the wrong number of bytes for the cipher.
    #[error("Invalid key length: expected {expected}, got {actual}")]
    InvalidKeyLength { expected: usize, actual: usize },
    /// The requested key id is unknown to the key provider.
    #[error("Key not found: {0}")]
    KeyNotFound(String),
    /// An unsupported or unexpected algorithm was requested.
    #[error("Invalid algorithm: {0}")]
    InvalidAlgorithm(String),
    /// The OS random number generator failed while producing a nonce.
    #[error("Nonce generation failed")]
    NonceGenerationFailed,
    /// Generic key-provider failure (RNG failure, duplicate key id, ...).
    #[error("Key provider error: {0}")]
    KeyProviderError(String),
    /// JSON / base64 (de)serialization failure.
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// Filesystem error, e.g. while reading or writing a persisted KEK file.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
}
/// Authenticated-encryption algorithm recorded in an [`EncryptedData`] envelope.
///
/// Serialized with `rename_all = "UPPERCASE"`, i.e. as `"AES256GCM"` or
/// `"CHACHA20POLY1305"`. The `Default` is AES-256-GCM.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum EncryptionAlgorithm {
    /// AES-256 in Galois/Counter Mode.
    #[default]
    Aes256Gcm,
    /// ChaCha20 stream cipher with the Poly1305 authenticator.
    ChaCha20Poly1305,
}
impl EncryptionAlgorithm {
pub fn key_len(&self) -> usize {
match self {
EncryptionAlgorithm::Aes256Gcm => 32, EncryptionAlgorithm::ChaCha20Poly1305 => 32, }
}
pub fn nonce_len(&self) -> usize {
match self {
EncryptionAlgorithm::Aes256Gcm => 12, EncryptionAlgorithm::ChaCha20Poly1305 => 12, }
}
}
/// Per-chunk metadata for chunked (SSE) encryption.
///
/// The sealed chunk occupies `plaintext_len + GCM_TAG_LEN` bytes of the
/// record's concatenated ciphertext.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkInfo {
    /// Nonce used to seal this chunk.
    pub nonce: Vec<u8>,
    /// Number of plaintext bytes in this chunk.
    pub plaintext_len: u64,
}
/// Length in bytes of the authentication tag appended to each sealed chunk.
/// Chunk offsets are computed as `plaintext_len + GCM_TAG_LEN` per chunk.
const GCM_TAG_LEN: usize = 16;
/// Plaintext bytes per chunk in chunked (SSE) encryption (5 MiB); only the last
/// chunk may be shorter.
pub const SSE_CHUNK_SIZE: usize = 5 * 1024 * 1024;
/// Envelope produced by [`EncryptionService`].
///
/// The payload is encrypted under a random data-encryption key (DEK). The DEK
/// is itself encrypted ("wrapped") under the key-encryption key (KEK) named by
/// `kek_id`, or under a customer-supplied key. Both layers use `algorithm`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedData {
    /// AEAD used for both the payload and the wrapped DEK.
    pub algorithm: EncryptionAlgorithm,
    /// The DEK encrypted under the KEK (ciphertext plus tag).
    pub encrypted_dek: Vec<u8>,
    /// Id of the KEK that wraps the DEK; empty for customer-key records.
    pub kek_id: String,
    /// Nonce used when wrapping the DEK.
    pub dek_nonce: Vec<u8>,
    /// Encrypted payload. For chunked records this is every sealed chunk
    /// concatenated in order.
    pub ciphertext: Vec<u8>,
    /// Nonce for single-shot payloads; empty for chunked records, which carry
    /// one nonce per chunk in `chunks`.
    pub payload_nonce: Vec<u8>,
    /// Associated data given at encryption time. Chunked records use it as the
    /// prefix of each chunk's associated data.
    pub aad: Option<Vec<u8>>,
    /// Per-chunk metadata; empty for single-shot records (and when absent from
    /// serialized input).
    #[serde(default)]
    pub chunks: Vec<ChunkInfo>,
    /// Plaintext bytes per chunk; 0 for single-shot records (and when absent
    /// from serialized input).
    #[serde(default)]
    pub chunk_size: u64,
}
/// Source of key-encryption keys (KEKs) used to wrap data-encryption keys.
#[async_trait]
pub trait KeyProvider: Send + Sync {
    /// Returns the raw key bytes registered under `key_id`.
    async fn get_kek(&self, key_id: &str) -> Result<Vec<u8>, EncryptionError>;
    /// Returns the ids of all known KEKs.
    async fn list_kek_ids(&self) -> Result<Vec<String>, EncryptionError>;
    /// Returns the id used when a caller does not name a KEK explicitly.
    async fn default_kek_id(&self) -> Result<String, EncryptionError>;
    /// Creates and registers a new KEK under `key_id`.
    async fn create_kek(&self, key_id: String) -> Result<(), EncryptionError>;
}
/// In-process [`KeyProvider`] that keeps KEKs in a map behind a tokio `RwLock`.
///
/// Only the default key can be persisted (see `new_with_persistence`); keys
/// added with `create_kek` live in memory only.
pub struct LocalKeyProvider {
    // key id -> raw key bytes
    keys: Arc<tokio::sync::RwLock<HashMap<String, Vec<u8>>>>,
    // id returned by `default_kek_id`
    default_key_id: String,
}
impl LocalKeyProvider {
    /// Id under which the built-in master key is registered.
    const DEFAULT_KEY_ID: &'static str = "master-key-v1";

    /// Length in bytes of every master key this provider generates or loads.
    const MASTER_KEY_LEN: usize = 32;

    /// Creates a provider holding one freshly generated random master key,
    /// registered as `"master-key-v1"` and used as the default KEK.
    ///
    /// # Errors
    /// `KeyProviderError` if the OS random number generator fails.
    pub fn new() -> Result<Self, EncryptionError> {
        let master_key = Self::random_master_key()?;
        Ok(Self::with_master_key(Self::DEFAULT_KEY_ID.to_string(), master_key))
    }

    /// Creates a provider holding exactly `master_key` under `key_id`, which
    /// also becomes the default KEK id. The key length is not checked here.
    pub fn with_master_key(key_id: String, master_key: Vec<u8>) -> Self {
        let mut keys = HashMap::new();
        keys.insert(key_id.clone(), master_key);
        Self {
            keys: Arc::new(tokio::sync::RwLock::new(keys)),
            default_key_id: key_id,
        }
    }

    /// Creates a provider whose default master key (`"master-key-v1"`) is
    /// persisted as JSON at `kek_path`.
    ///
    /// If the file exists, the key is loaded from it. Otherwise a new random
    /// key is generated and written via a temporary file that is then renamed
    /// over `kek_path`. On Unix the file is created with mode `0600`, because
    /// it holds the raw master key. It is flushed to disk before the rename.
    ///
    /// # Errors
    /// `Io` for filesystem failures, `SerializationError` for malformed JSON or
    /// base64, `InvalidKeyLength` if the stored key is not 32 bytes, and
    /// `KeyProviderError` if key generation fails.
    pub fn new_with_persistence(kek_path: std::path::PathBuf) -> Result<Self, EncryptionError> {
        use base64::Engine as _;

        #[derive(serde::Serialize, serde::Deserialize)]
        struct KekFile {
            kek_id: String,
            key_base64: String,
        }

        let master_key: Vec<u8> = if kek_path.exists() {
            let raw = std::fs::read(&kek_path)?;
            let record: KekFile = serde_json::from_slice(&raw)
                .map_err(|e| EncryptionError::SerializationError(e.to_string()))?;
            let key = base64::engine::general_purpose::STANDARD
                .decode(&record.key_base64)
                .map_err(|e| {
                    EncryptionError::SerializationError(format!("base64 decode error: {}", e))
                })?;
            // Reject a truncated or hand-edited key at startup instead of when
            // the first object is encrypted or decrypted.
            if key.len() != Self::MASTER_KEY_LEN {
                return Err(EncryptionError::InvalidKeyLength {
                    expected: Self::MASTER_KEY_LEN,
                    actual: key.len(),
                });
            }
            key
        } else {
            let key = Self::random_master_key()?;
            if let Some(parent) = kek_path.parent() {
                std::fs::create_dir_all(parent)?;
            }
            let record = KekFile {
                kek_id: Self::DEFAULT_KEY_ID.to_string(),
                key_base64: base64::engine::general_purpose::STANDARD.encode(&key),
            };
            let json = serde_json::to_vec(&record)
                .map_err(|e| EncryptionError::SerializationError(e.to_string()))?;
            let tmp_path = kek_path.with_extension("tmp");
            Self::write_secret_file(&tmp_path, &json)?;
            std::fs::rename(&tmp_path, &kek_path)?;
            key
        };
        Ok(Self::with_master_key(Self::DEFAULT_KEY_ID.to_string(), master_key))
    }

    /// Generates a random master key of `MASTER_KEY_LEN` bytes.
    fn random_master_key() -> Result<Vec<u8>, EncryptionError> {
        let mut key = vec![0u8; Self::MASTER_KEY_LEN];
        getrandom::fill(&mut key).map_err(|e| {
            EncryptionError::KeyProviderError(format!("Failed to generate random key: {}", e))
        })?;
        Ok(key)
    }

    /// Creates `path` fresh (owner-only permissions on Unix), writes `contents`
    /// and syncs the data to disk.
    fn write_secret_file(path: &std::path::Path, contents: &[u8]) -> Result<(), EncryptionError> {
        use std::io::Write as _;

        // Remove a stale temp file first. Permission bits only apply when a
        // file is created, so reusing an existing file could keep lax permissions.
        match std::fs::remove_file(path) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => return Err(e.into()),
        }
        let mut options = std::fs::OpenOptions::new();
        options.write(true).create_new(true);
        #[cfg(unix)]
        {
            use std::os::unix::fs::OpenOptionsExt as _;
            options.mode(0o600);
        }
        let mut file = options.open(path)?;
        file.write_all(contents)?;
        file.sync_all()?;
        Ok(())
    }
}
impl Default for LocalKeyProvider {
    /// Builds a provider holding a freshly generated random master key.
    ///
    /// # Panics
    /// Panics if the OS random number generator fails. `Default` cannot report
    /// errors, and the only alternative is a predictable (e.g. all-zero) master
    /// key, which would make every wrapped DEK trivially recoverable. Call
    /// [`LocalKeyProvider::new`] to handle that failure as an error instead.
    fn default() -> Self {
        Self::new().expect(
            "LocalKeyProvider::default: OS random number generator unavailable; \
             refusing to fall back to a predictable master key",
        )
    }
}
/// Map-backed implementation: every operation takes the tokio lock once.
#[async_trait]
impl KeyProvider for LocalKeyProvider {
    /// Returns a copy of the key stored under `key_id`, or `KeyNotFound`.
    async fn get_kek(&self, key_id: &str) -> Result<Vec<u8>, EncryptionError> {
        match self.keys.read().await.get(key_id) {
            Some(key) => Ok(key.clone()),
            None => Err(EncryptionError::KeyNotFound(key_id.to_string())),
        }
    }

    /// Returns every registered key id (in map iteration order).
    async fn list_kek_ids(&self) -> Result<Vec<String>, EncryptionError> {
        let guard = self.keys.read().await;
        let ids = guard.keys().map(String::clone).collect::<Vec<_>>();
        Ok(ids)
    }

    /// Returns the id chosen at construction time.
    async fn default_kek_id(&self) -> Result<String, EncryptionError> {
        Ok(self.default_key_id.to_owned())
    }

    /// Generates a random 32-byte key under `key_id`. Fails if the id is taken
    /// or the RNG fails; in both cases the map is left unchanged.
    async fn create_kek(&self, key_id: String) -> Result<(), EncryptionError> {
        let mut guard = self.keys.write().await;
        match guard.entry(key_id) {
            std::collections::hash_map::Entry::Occupied(existing) => {
                Err(EncryptionError::KeyProviderError(format!(
                    "Key {} already exists",
                    existing.key()
                )))
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                let mut fresh = vec![0u8; 32];
                getrandom::fill(&mut fresh).map_err(|e| {
                    EncryptionError::KeyProviderError(format!(
                        "Failed to generate random key: {}",
                        e
                    ))
                })?;
                slot.insert(fresh);
                Ok(())
            }
        }
    }
}
/// Envelope encryption on top of a [`KeyProvider`]: each payload gets a fresh
/// random DEK, which is then wrapped under a KEK from the provider.
pub struct EncryptionService {
    // source of KEKs used to wrap/unwrap DEKs
    key_provider: Arc<dyn KeyProvider>,
    // algorithm for single-shot encryption (`encrypt`, `encrypt_with_kek_id`)
    default_algorithm: EncryptionAlgorithm,
}
impl EncryptionService {
pub fn new(key_provider: Arc<dyn KeyProvider>) -> Self {
Self {
key_provider,
default_algorithm: EncryptionAlgorithm::Aes256Gcm,
}
}
pub fn with_algorithm(
key_provider: Arc<dyn KeyProvider>,
algorithm: EncryptionAlgorithm,
) -> Self {
Self {
key_provider,
default_algorithm: algorithm,
}
}
fn generate_dek(&self, algorithm: EncryptionAlgorithm) -> Result<Vec<u8>, EncryptionError> {
let mut dek = vec![0u8; algorithm.key_len()];
getrandom::fill(&mut dek).map_err(|_| {
EncryptionError::KeyProviderError("Failed to generate random DEK".to_string())
})?;
Ok(dek)
}
fn generate_nonce(&self, algorithm: EncryptionAlgorithm) -> Result<Vec<u8>, EncryptionError> {
let mut nonce = vec![0u8; algorithm.nonce_len()];
getrandom::fill(&mut nonce).map_err(|_| EncryptionError::NonceGenerationFailed)?;
Ok(nonce)
}
fn encrypt_aes256gcm(
&self,
plaintext: &[u8],
key: &[u8],
nonce: &[u8],
aad: Option<&[u8]>,
) -> Result<Vec<u8>, EncryptionError> {
let cipher =
Aes256Gcm::new_from_slice(key).map_err(|_| EncryptionError::InvalidKeyLength {
expected: 32,
actual: key.len(),
})?;
let nonce_array = Nonce::from_slice(nonce);
let ciphertext = if let Some(aad_data) = aad {
cipher
.encrypt(
nonce_array,
aes_gcm::aead::Payload {
msg: plaintext,
aad: aad_data,
},
)
.map_err(|_| EncryptionError::EncryptionFailed)?
} else {
cipher
.encrypt(nonce_array, plaintext)
.map_err(|_| EncryptionError::EncryptionFailed)?
};
Ok(ciphertext)
}
fn decrypt_aes256gcm(
&self,
ciphertext: &[u8],
key: &[u8],
nonce: &[u8],
aad: Option<&[u8]>,
) -> Result<Vec<u8>, EncryptionError> {
let cipher =
Aes256Gcm::new_from_slice(key).map_err(|_| EncryptionError::InvalidKeyLength {
expected: 32,
actual: key.len(),
})?;
let nonce_array = Nonce::from_slice(nonce);
let plaintext = if let Some(aad_data) = aad {
cipher
.decrypt(
nonce_array,
aes_gcm::aead::Payload {
msg: ciphertext,
aad: aad_data,
},
)
.map_err(|_| EncryptionError::DecryptionFailed)?
} else {
cipher
.decrypt(nonce_array, ciphertext)
.map_err(|_| EncryptionError::DecryptionFailed)?
};
Ok(plaintext)
}
fn encrypt_chacha20poly1305(
&self,
plaintext: &[u8],
key: &[u8],
nonce: &[u8],
aad: Option<&[u8]>,
) -> Result<Vec<u8>, EncryptionError> {
let key_array = ChaChaKey::from_slice(key);
let cipher = ChaCha20Poly1305::new(key_array);
let nonce_array = ChaChaNonce::from_slice(nonce);
let ciphertext = if let Some(aad_data) = aad {
cipher
.encrypt(
nonce_array,
chacha20poly1305::aead::Payload {
msg: plaintext,
aad: aad_data,
},
)
.map_err(|_| EncryptionError::EncryptionFailed)?
} else {
cipher
.encrypt(nonce_array, plaintext)
.map_err(|_| EncryptionError::EncryptionFailed)?
};
Ok(ciphertext)
}
fn decrypt_chacha20poly1305(
&self,
ciphertext: &[u8],
key: &[u8],
nonce: &[u8],
aad: Option<&[u8]>,
) -> Result<Vec<u8>, EncryptionError> {
let key_array = ChaChaKey::from_slice(key);
let cipher = ChaCha20Poly1305::new(key_array);
let nonce_array = ChaChaNonce::from_slice(nonce);
let plaintext = if let Some(aad_data) = aad {
cipher
.decrypt(
nonce_array,
chacha20poly1305::aead::Payload {
msg: ciphertext,
aad: aad_data,
},
)
.map_err(|_| EncryptionError::DecryptionFailed)?
} else {
cipher
.decrypt(nonce_array, ciphertext)
.map_err(|_| EncryptionError::DecryptionFailed)?
};
Ok(plaintext)
}
pub async fn encrypt(
&self,
plaintext: &[u8],
aad: Option<&[u8]>,
) -> Result<EncryptedData, EncryptionError> {
self.encrypt_with_algorithm(plaintext, self.default_algorithm, aad)
.await
}
pub async fn encrypt_with_algorithm(
&self,
plaintext: &[u8],
algorithm: EncryptionAlgorithm,
aad: Option<&[u8]>,
) -> Result<EncryptedData, EncryptionError> {
let dek = self.generate_dek(algorithm)?;
let payload_nonce = self.generate_nonce(algorithm)?;
let ciphertext = match algorithm {
EncryptionAlgorithm::Aes256Gcm => {
self.encrypt_aes256gcm(plaintext, &dek, &payload_nonce, aad)?
}
EncryptionAlgorithm::ChaCha20Poly1305 => {
self.encrypt_chacha20poly1305(plaintext, &dek, &payload_nonce, aad)?
}
};
let kek_id = self.key_provider.default_kek_id().await?;
let kek = self.key_provider.get_kek(&kek_id).await?;
let dek_nonce = self.generate_nonce(algorithm)?;
let encrypted_dek = match algorithm {
EncryptionAlgorithm::Aes256Gcm => {
self.encrypt_aes256gcm(&dek, &kek, &dek_nonce, None)?
}
EncryptionAlgorithm::ChaCha20Poly1305 => {
self.encrypt_chacha20poly1305(&dek, &kek, &dek_nonce, None)?
}
};
Ok(EncryptedData {
algorithm,
encrypted_dek,
kek_id,
dek_nonce,
ciphertext,
payload_nonce,
aad: aad.map(|a| a.to_vec()),
chunks: vec![],
chunk_size: 0,
})
}
pub async fn decrypt(&self, encrypted: &EncryptedData) -> Result<Vec<u8>, EncryptionError> {
if encrypted.chunks.is_empty() {
self.decrypt_single_shot(encrypted).await
} else {
self.decrypt_chunked_all(encrypted).await
}
}
async fn decrypt_single_shot(
&self,
encrypted: &EncryptedData,
) -> Result<Vec<u8>, EncryptionError> {
let kek = self.key_provider.get_kek(&encrypted.kek_id).await?;
let dek = match encrypted.algorithm {
EncryptionAlgorithm::Aes256Gcm => {
self.decrypt_aes256gcm(&encrypted.encrypted_dek, &kek, &encrypted.dek_nonce, None)?
}
EncryptionAlgorithm::ChaCha20Poly1305 => self.decrypt_chacha20poly1305(
&encrypted.encrypted_dek,
&kek,
&encrypted.dek_nonce,
None,
)?,
};
let plaintext = match encrypted.algorithm {
EncryptionAlgorithm::Aes256Gcm => self.decrypt_aes256gcm(
&encrypted.ciphertext,
&dek,
&encrypted.payload_nonce,
encrypted.aad.as_deref(),
)?,
EncryptionAlgorithm::ChaCha20Poly1305 => self.decrypt_chacha20poly1305(
&encrypted.ciphertext,
&dek,
&encrypted.payload_nonce,
encrypted.aad.as_deref(),
)?,
};
Ok(plaintext)
}
async fn decrypt_chunked_all(
&self,
encrypted: &EncryptedData,
) -> Result<Vec<u8>, EncryptionError> {
let kek = self.key_provider.get_kek(&encrypted.kek_id).await?;
let dek =
self.decrypt_aes256gcm(&encrypted.encrypted_dek, &kek, &encrypted.dek_nonce, None)?;
let aad_prefix = encrypted.aad.as_deref().unwrap_or(&[]);
let mut result: Vec<u8> = Vec::new();
let mut offset: usize = 0;
for (i, chunk_info) in encrypted.chunks.iter().enumerate() {
let ct_len = chunk_info.plaintext_len as usize + GCM_TAG_LEN;
let chunk_ct = encrypted
.ciphertext
.get(offset..offset + ct_len)
.ok_or(EncryptionError::DecryptionFailed)?;
let chunk_aad = build_chunk_aad(aad_prefix, i);
let chunk_pt =
self.decrypt_aes256gcm(chunk_ct, &dek, &chunk_info.nonce, Some(&chunk_aad))?;
result.extend_from_slice(&chunk_pt);
offset += ct_len;
}
Ok(result)
}
pub async fn rotate_key(
&self,
encrypted: &EncryptedData,
new_kek_id: &str,
) -> Result<EncryptedData, EncryptionError> {
let old_kek = self.key_provider.get_kek(&encrypted.kek_id).await?;
let dek = match encrypted.algorithm {
EncryptionAlgorithm::Aes256Gcm => self.decrypt_aes256gcm(
&encrypted.encrypted_dek,
&old_kek,
&encrypted.dek_nonce,
None,
)?,
EncryptionAlgorithm::ChaCha20Poly1305 => self.decrypt_chacha20poly1305(
&encrypted.encrypted_dek,
&old_kek,
&encrypted.dek_nonce,
None,
)?,
};
let new_kek = self.key_provider.get_kek(new_kek_id).await?;
let new_dek_nonce = self.generate_nonce(encrypted.algorithm)?;
let new_encrypted_dek = match encrypted.algorithm {
EncryptionAlgorithm::Aes256Gcm => {
self.encrypt_aes256gcm(&dek, &new_kek, &new_dek_nonce, None)?
}
EncryptionAlgorithm::ChaCha20Poly1305 => {
self.encrypt_chacha20poly1305(&dek, &new_kek, &new_dek_nonce, None)?
}
};
Ok(EncryptedData {
algorithm: encrypted.algorithm,
encrypted_dek: new_encrypted_dek,
kek_id: new_kek_id.to_string(),
dek_nonce: new_dek_nonce,
ciphertext: encrypted.ciphertext.clone(),
payload_nonce: encrypted.payload_nonce.clone(),
aad: encrypted.aad.clone(),
chunks: encrypted.chunks.clone(),
chunk_size: encrypted.chunk_size,
})
}
pub async fn encrypt_with_customer_key(
&self,
plaintext: &[u8],
customer_key: &[u8; 32],
aad: Option<&[u8]>,
) -> Result<EncryptedData, EncryptionError> {
let algorithm = EncryptionAlgorithm::Aes256Gcm;
let dek = self.generate_dek(algorithm)?;
let payload_nonce = self.generate_nonce(algorithm)?;
let ciphertext = self.encrypt_aes256gcm(plaintext, &dek, &payload_nonce, aad)?;
let dek_nonce = self.generate_nonce(algorithm)?;
let encrypted_dek = self.encrypt_aes256gcm(&dek, customer_key, &dek_nonce, None)?;
Ok(EncryptedData {
algorithm,
encrypted_dek,
kek_id: String::new(),
dek_nonce,
ciphertext,
payload_nonce,
aad: aad.map(|a| a.to_vec()),
chunks: vec![],
chunk_size: 0,
})
}
pub async fn decrypt_with_customer_key(
&self,
encrypted: &EncryptedData,
customer_key: &[u8; 32],
) -> Result<Vec<u8>, EncryptionError> {
let dek = self.decrypt_aes256gcm(
&encrypted.encrypted_dek,
customer_key,
&encrypted.dek_nonce,
None,
)?;
let plaintext = self.decrypt_aes256gcm(
&encrypted.ciphertext,
&dek,
&encrypted.payload_nonce,
encrypted.aad.as_deref(),
)?;
Ok(plaintext)
}
pub async fn encrypt_with_kek_id(
&self,
plaintext: &[u8],
kek_id: &str,
aad: Option<&[u8]>,
) -> Result<EncryptedData, EncryptionError> {
let algorithm = self.default_algorithm;
let dek = self.generate_dek(algorithm)?;
let payload_nonce = self.generate_nonce(algorithm)?;
let ciphertext = match algorithm {
EncryptionAlgorithm::Aes256Gcm => {
self.encrypt_aes256gcm(plaintext, &dek, &payload_nonce, aad)?
}
EncryptionAlgorithm::ChaCha20Poly1305 => {
self.encrypt_chacha20poly1305(plaintext, &dek, &payload_nonce, aad)?
}
};
let kek = self.key_provider.get_kek(kek_id).await?;
let dek_nonce = self.generate_nonce(algorithm)?;
let encrypted_dek = match algorithm {
EncryptionAlgorithm::Aes256Gcm => {
self.encrypt_aes256gcm(&dek, &kek, &dek_nonce, None)?
}
EncryptionAlgorithm::ChaCha20Poly1305 => {
self.encrypt_chacha20poly1305(&dek, &kek, &dek_nonce, None)?
}
};
Ok(EncryptedData {
algorithm,
encrypted_dek,
kek_id: kek_id.to_string(),
dek_nonce,
ciphertext,
payload_nonce,
aad: aad.map(|a| a.to_vec()),
chunks: vec![],
chunk_size: 0,
})
}
pub async fn resolve_kms_key_id(
&self,
requested: Option<&str>,
) -> Result<String, EncryptionError> {
match requested {
None => self.key_provider.default_kek_id().await,
Some(id) => {
self.key_provider.get_kek(id).await?;
Ok(id.to_string())
}
}
}
pub async fn encrypt_chunked(
&self,
plaintext: &[u8],
aad: Option<&[u8]>,
) -> Result<EncryptedData, EncryptionError> {
let kek_id = self.key_provider.default_kek_id().await?;
self.encrypt_chunked_internal(plaintext, &kek_id, aad).await
}
pub async fn encrypt_chunked_with_kek_id(
&self,
plaintext: &[u8],
kek_id: &str,
aad: Option<&[u8]>,
) -> Result<EncryptedData, EncryptionError> {
self.key_provider.get_kek(kek_id).await?;
self.encrypt_chunked_internal(plaintext, kek_id, aad).await
}
async fn encrypt_chunked_internal(
&self,
plaintext: &[u8],
kek_id: &str,
aad: Option<&[u8]>,
) -> Result<EncryptedData, EncryptionError> {
let algorithm = EncryptionAlgorithm::Aes256Gcm;
let dek = self.generate_dek(algorithm)?;
let kek = self.key_provider.get_kek(kek_id).await?;
let dek_nonce = self.generate_nonce(algorithm)?;
let encrypted_dek = self.encrypt_aes256gcm(&dek, &kek, &dek_nonce, None)?;
let aad_prefix = aad.unwrap_or(&[]);
let mut ciphertext: Vec<u8> = Vec::new();
let mut chunks: Vec<ChunkInfo> = Vec::new();
for (i, chunk_plain) in plaintext.chunks(SSE_CHUNK_SIZE).enumerate() {
let nonce = self.generate_nonce(algorithm)?;
let chunk_aad = build_chunk_aad(aad_prefix, i);
let chunk_ct = self.encrypt_aes256gcm(chunk_plain, &dek, &nonce, Some(&chunk_aad))?;
chunks.push(ChunkInfo {
nonce,
plaintext_len: chunk_plain.len() as u64,
});
ciphertext.extend_from_slice(&chunk_ct);
}
Ok(EncryptedData {
algorithm,
encrypted_dek,
kek_id: kek_id.to_string(),
dek_nonce,
ciphertext,
payload_nonce: vec![],
aad: aad.map(|a| a.to_vec()),
chunks,
chunk_size: SSE_CHUNK_SIZE as u64,
})
}
pub async fn decrypt_chunked_range(
&self,
sidecar_kek_id: &str,
sidecar_encrypted_dek: &[u8],
sidecar_dek_nonce: &[u8],
sidecar_chunk_size: u64,
sidecar_chunks: &[crate::storage::SidecarChunk],
ciphertext_bytes: &[u8],
range_start: u64,
range_end: u64, aad_prefix: &[u8],
) -> Result<Vec<u8>, EncryptionError> {
if sidecar_chunk_size == 0 || sidecar_chunks.is_empty() {
return Err(EncryptionError::DecryptionFailed);
}
if range_start >= range_end {
return Err(EncryptionError::DecryptionFailed);
}
let kek = self.key_provider.get_kek(sidecar_kek_id).await?;
let dek = self.decrypt_aes256gcm(sidecar_encrypted_dek, &kek, sidecar_dek_nonce, None)?;
let first_chunk = (range_start / sidecar_chunk_size) as usize;
let last_chunk = ((range_end - 1) / sidecar_chunk_size) as usize;
let mut file_offsets: Vec<usize> = Vec::with_capacity(sidecar_chunks.len() + 1);
let mut acc: usize = 0;
for chunk in sidecar_chunks.iter() {
file_offsets.push(acc);
acc += chunk.plaintext_len as usize + GCM_TAG_LEN;
}
file_offsets.push(acc);
let mut plaintext_parts: Vec<u8> = Vec::new();
for idx in first_chunk..=last_chunk {
let chunk = sidecar_chunks
.get(idx)
.ok_or(EncryptionError::DecryptionFailed)?;
let file_start = *file_offsets
.get(idx)
.ok_or(EncryptionError::DecryptionFailed)?;
let file_end = file_start + chunk.plaintext_len as usize + GCM_TAG_LEN;
let chunk_ct = ciphertext_bytes
.get(file_start..file_end)
.ok_or(EncryptionError::DecryptionFailed)?;
let chunk_aad = build_chunk_aad(aad_prefix, idx);
let chunk_pt =
self.decrypt_aes256gcm(chunk_ct, &dek, &chunk.nonce, Some(&chunk_aad))?;
plaintext_parts.extend_from_slice(&chunk_pt);
}
let first_chunk_start_byte = first_chunk as u64 * sidecar_chunk_size;
let local_start = (range_start - first_chunk_start_byte) as usize;
let local_end = (range_end - first_chunk_start_byte) as usize;
let result = plaintext_parts
.get(local_start..local_end)
.ok_or(EncryptionError::DecryptionFailed)?
.to_vec();
Ok(result)
}
pub async fn encrypt_bytes(&self, data: &Bytes) -> Result<Bytes, EncryptionError> {
let encrypted = self.encrypt(data, None).await?;
let serialized = serde_json::to_vec(&encrypted)
.map_err(|e| EncryptionError::SerializationError(e.to_string()))?;
Ok(Bytes::from(serialized))
}
pub async fn decrypt_bytes(&self, data: &Bytes) -> Result<Bytes, EncryptionError> {
let encrypted: EncryptedData = serde_json::from_slice(data)
.map_err(|e| EncryptionError::SerializationError(e.to_string()))?;
let plaintext = self.decrypt(&encrypted).await?;
Ok(Bytes::from(plaintext))
}
}
/// Associated data for chunk `chunk_index`: `chunk/<index>` on its own, or
/// `<prefix>/chunk/<index>` when a non-empty prefix is given.
fn build_chunk_aad(prefix: &[u8], chunk_index: usize) -> Vec<u8> {
    let tail = format!("chunk/{}", chunk_index);
    match prefix {
        [] => tail.into_bytes(),
        _ => {
            let mut joined = Vec::with_capacity(prefix.len() + 1 + tail.len());
            joined.extend_from_slice(prefix);
            joined.push(b'/');
            joined.extend_from_slice(tail.as_bytes());
            joined
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_local_key_provider() {
        let provider = LocalKeyProvider::new().expect("Failed to create provider");
        let default_id = provider
            .default_kek_id()
            .await
            .expect("Failed to get default key ID");
        assert_eq!(default_id, "master-key-v1");
        let key = provider
            .get_kek(&default_id)
            .await
            .expect("Failed to get key");
        assert_eq!(key.len(), 32);
        let keys = provider.list_kek_ids().await.expect("Failed to list keys");
        assert_eq!(keys.len(), 1);
        assert!(keys.contains(&"master-key-v1".to_string()));
        provider
            .create_kek("master-key-v2".to_string())
            .await
            .expect("Failed to create key");
        let keys = provider.list_kek_ids().await.expect("Failed to list keys");
        assert_eq!(keys.len(), 2);
    }

    #[tokio::test]
    async fn test_encryption_aes256gcm() {
        let provider = Arc::new(LocalKeyProvider::new().expect("Failed to create provider"));
        let service = EncryptionService::new(provider);
        let plaintext = b"Hello, World! This is a secret message.";
        let encrypted = service
            .encrypt(plaintext, None)
            .await
            .expect("Failed to encrypt");
        assert_eq!(encrypted.algorithm, EncryptionAlgorithm::Aes256Gcm);
        assert!(!encrypted.ciphertext.is_empty());
        assert!(!encrypted.encrypted_dek.is_empty());
        let decrypted = service
            .decrypt(&encrypted)
            .await
            .expect("Failed to decrypt");
        assert_eq!(decrypted, plaintext);
    }

    #[tokio::test]
    async fn test_encryption_chacha20poly1305() {
        let provider = Arc::new(LocalKeyProvider::new().expect("Failed to create provider"));
        let service =
            EncryptionService::with_algorithm(provider, EncryptionAlgorithm::ChaCha20Poly1305);
        let plaintext = b"Hello, ChaCha20-Poly1305!";
        let encrypted = service
            .encrypt(plaintext, None)
            .await
            .expect("Failed to encrypt");
        assert_eq!(encrypted.algorithm, EncryptionAlgorithm::ChaCha20Poly1305);
        let decrypted = service
            .decrypt(&encrypted)
            .await
            .expect("Failed to decrypt");
        assert_eq!(decrypted, plaintext);
    }

    #[tokio::test]
    async fn test_encryption_with_aad() {
        let provider = Arc::new(LocalKeyProvider::new().expect("Failed to create provider"));
        let service = EncryptionService::new(provider);
        let plaintext = b"Secret data";
        let aad = b"bucket=test-bucket,key=test-key";
        let encrypted = service
            .encrypt(plaintext, Some(aad))
            .await
            .expect("Failed to encrypt");
        let decrypted = service
            .decrypt(&encrypted)
            .await
            .expect("Failed to decrypt");
        assert_eq!(decrypted, plaintext);
        let mut tampered = encrypted.clone();
        tampered.aad = Some(b"tampered-aad".to_vec());
        let result = service.decrypt(&tampered).await;
        assert!(result.is_err(), "Decryption should fail with tampered AAD");
    }

    #[tokio::test]
    async fn test_key_rotation() {
        let provider = Arc::new(LocalKeyProvider::new().expect("Failed to create provider"));
        provider
            .create_kek("master-key-v2".to_string())
            .await
            .expect("Failed to create new key");
        let service = EncryptionService::new(provider);
        let plaintext = b"Data to be rotated";
        let encrypted = service
            .encrypt(plaintext, None)
            .await
            .expect("Failed to encrypt");
        assert_eq!(encrypted.kek_id, "master-key-v1");
        let rotated = service
            .rotate_key(&encrypted, "master-key-v2")
            .await
            .expect("Failed to rotate key");
        assert_eq!(rotated.kek_id, "master-key-v2");
        assert_eq!(rotated.ciphertext, encrypted.ciphertext);
        assert_ne!(rotated.encrypted_dek, encrypted.encrypted_dek);
        let decrypted = service.decrypt(&rotated).await.expect("Failed to decrypt");
        assert_eq!(decrypted, plaintext);
    }

    #[tokio::test]
    async fn test_encrypt_decrypt_bytes() {
        let provider = Arc::new(LocalKeyProvider::new().expect("Failed to create provider"));
        let service = EncryptionService::new(provider);
        let plaintext = Bytes::from("Test data for bytes encryption");
        let encrypted = service
            .encrypt_bytes(&plaintext)
            .await
            .expect("Failed to encrypt bytes");
        assert!(!encrypted.is_empty());
        let decrypted = service
            .decrypt_bytes(&encrypted)
            .await
            .expect("Failed to decrypt bytes");
        assert_eq!(decrypted, plaintext);
    }

    #[tokio::test]
    async fn test_different_plaintexts_different_ciphertexts() {
        let provider = Arc::new(LocalKeyProvider::new().expect("Failed to create provider"));
        let service = EncryptionService::new(provider);
        let plaintext1 = b"Message 1";
        let plaintext2 = b"Message 2";
        let encrypted1 = service
            .encrypt(plaintext1, None)
            .await
            .expect("Failed to encrypt");
        let encrypted2 = service
            .encrypt(plaintext2, None)
            .await
            .expect("Failed to encrypt");
        assert_ne!(encrypted1.ciphertext, encrypted2.ciphertext);
        let encrypted3 = service
            .encrypt(plaintext1, None)
            .await
            .expect("Failed to encrypt");
        assert_ne!(encrypted1.ciphertext, encrypted3.ciphertext);
    }

    #[test]
    fn test_build_chunk_aad() {
        assert_eq!(build_chunk_aad(b"", 0), b"chunk/0".to_vec());
        assert_eq!(build_chunk_aad(b"bucket/key", 7), b"bucket/key/chunk/7".to_vec());
    }

    #[tokio::test]
    async fn test_chunked_roundtrip_spans_multiple_chunks() {
        let provider = Arc::new(LocalKeyProvider::new().expect("Failed to create provider"));
        let service = EncryptionService::new(provider);
        // One full chunk plus a 1000-byte tail.
        let plaintext: Vec<u8> = (0..SSE_CHUNK_SIZE + 1000).map(|i| (i % 251) as u8).collect();
        let encrypted = service
            .encrypt_chunked(&plaintext, None)
            .await
            .expect("Failed to encrypt");
        assert_eq!(encrypted.chunks.len(), 2);
        assert_eq!(encrypted.chunk_size, SSE_CHUNK_SIZE as u64);
        assert_eq!(encrypted.chunks[0].plaintext_len, SSE_CHUNK_SIZE as u64);
        assert_eq!(encrypted.chunks[1].plaintext_len, 1000);
        assert_eq!(encrypted.ciphertext.len(), plaintext.len() + 2 * GCM_TAG_LEN);
        assert!(encrypted.payload_nonce.is_empty());
        let decrypted = service
            .decrypt(&encrypted)
            .await
            .expect("Failed to decrypt");
        assert_eq!(decrypted, plaintext);
    }

    #[tokio::test]
    async fn test_chunked_detects_tampering() {
        let provider = Arc::new(LocalKeyProvider::new().expect("Failed to create provider"));
        let service = EncryptionService::new(provider);
        let plaintext = b"small chunked payload";
        let aad = b"bucket=b,key=k";
        let encrypted = service
            .encrypt_chunked(plaintext, Some(aad))
            .await
            .expect("Failed to encrypt");
        assert_eq!(encrypted.chunks.len(), 1);
        let decrypted = service
            .decrypt(&encrypted)
            .await
            .expect("Failed to decrypt");
        assert_eq!(decrypted, plaintext);

        let mut flipped = encrypted.clone();
        flipped.ciphertext[0] ^= 0x01;
        assert!(
            service.decrypt(&flipped).await.is_err(),
            "Decryption should fail with a modified ciphertext byte"
        );

        let mut truncated = encrypted.clone();
        truncated.ciphertext.pop();
        assert!(
            service.decrypt(&truncated).await.is_err(),
            "Decryption should fail with truncated ciphertext"
        );

        let mut wrong_aad = encrypted.clone();
        wrong_aad.aad = Some(b"bucket=b,key=other".to_vec());
        assert!(
            service.decrypt(&wrong_aad).await.is_err(),
            "Decryption should fail with a different AAD prefix"
        );
    }

    #[tokio::test]
    async fn test_chunked_key_rotation() {
        let provider = Arc::new(LocalKeyProvider::new().expect("Failed to create provider"));
        provider
            .create_kek("master-key-v2".to_string())
            .await
            .expect("Failed to create new key");
        let service = EncryptionService::new(provider);
        let plaintext = b"chunked data to be rotated";
        let encrypted = service
            .encrypt_chunked(plaintext, None)
            .await
            .expect("Failed to encrypt");
        let rotated = service
            .rotate_key(&encrypted, "master-key-v2")
            .await
            .expect("Failed to rotate key");
        assert_eq!(rotated.kek_id, "master-key-v2");
        assert_eq!(rotated.ciphertext, encrypted.ciphertext);
        assert_eq!(rotated.chunks.len(), encrypted.chunks.len());
        let decrypted = service.decrypt(&rotated).await.expect("Failed to decrypt");
        assert_eq!(decrypted, plaintext);
    }

    #[tokio::test]
    async fn test_customer_key_roundtrip() {
        let provider = Arc::new(LocalKeyProvider::new().expect("Failed to create provider"));
        let service = EncryptionService::new(provider);
        let customer_key = [7u8; 32];
        let encrypted = service
            .encrypt_with_customer_key(b"customer secret", &customer_key, Some(b"ctx"))
            .await
            .expect("Failed to encrypt");
        assert!(encrypted.kek_id.is_empty());
        let decrypted = service
            .decrypt_with_customer_key(&encrypted, &customer_key)
            .await
            .expect("Failed to decrypt");
        assert_eq!(decrypted, b"customer secret");
        assert!(service
            .decrypt_with_customer_key(&encrypted, &[8u8; 32])
            .await
            .is_err());
        // Customer-key records carry no KEK id, so the KEK path cannot open them.
        assert!(matches!(
            service.decrypt(&encrypted).await,
            Err(EncryptionError::KeyNotFound(_))
        ));
    }

    #[tokio::test]
    async fn test_kek_selection_and_resolution() {
        let provider = Arc::new(LocalKeyProvider::new().expect("Failed to create provider"));
        provider
            .create_kek("master-key-v2".to_string())
            .await
            .expect("Failed to create new key");
        let service = EncryptionService::new(provider);

        assert_eq!(
            service.resolve_kms_key_id(None).await.expect("Failed to resolve"),
            "master-key-v1"
        );
        assert_eq!(
            service
                .resolve_kms_key_id(Some("master-key-v2"))
                .await
                .expect("Failed to resolve"),
            "master-key-v2"
        );
        assert!(matches!(
            service.resolve_kms_key_id(Some("missing")).await,
            Err(EncryptionError::KeyNotFound(_))
        ));

        let encrypted = service
            .encrypt_with_kek_id(b"pinned", "master-key-v2", None)
            .await
            .expect("Failed to encrypt");
        assert_eq!(encrypted.kek_id, "master-key-v2");
        let decrypted = service
            .decrypt(&encrypted)
            .await
            .expect("Failed to decrypt");
        assert_eq!(decrypted, b"pinned");
        assert!(service
            .encrypt_with_kek_id(b"x", "missing", None)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_persistent_kek_survives_reload() {
        let unique = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let dir = std::env::temp_dir().join(format!(
            "kek-persist-test-{}-{}",
            std::process::id(),
            unique
        ));
        let path = dir.join("kek.json");

        let first = LocalKeyProvider::new_with_persistence(path.clone())
            .expect("Failed to create persistent provider");
        assert!(path.exists());
        let second = LocalKeyProvider::new_with_persistence(path.clone())
            .expect("Failed to reload persistent provider");
        let key1 = first.get_kek("master-key-v1").await.expect("Failed to get key");
        let key2 = second.get_kek("master-key-v1").await.expect("Failed to get key");
        let _ = std::fs::remove_dir_all(&dir);

        assert_eq!(key1.len(), 32);
        assert_eq!(key1, key2);
    }
}