use std::fmt;
use aes_gcm_siv::aead::generic_array::GenericArray;
use aes_gcm_siv::aead::{Aead, KeyInit};
use aes_gcm_siv::Aes256GcmSiv;
use argon2::Argon2;
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use rand::RngCore;
use zeroize::Zeroizing;
/// Salt length in bytes used by this module's tests when calling `derive_key`.
/// NOTE(review): presumably also the length callers are expected to supply — confirm.
pub const SALT_LEN: usize = 16;
/// Length in bytes of keys produced by `CryptoVault::derive_key` (an AES-256 key).
pub const KEY_LEN: usize = 32;

/// Largest Reed-Solomon codeword over GF(2^8), in bytes (payload + parity).
// Marked allow in case the constant is only reachable through otherwise-unused items.
#[allow(dead_code)]
const RS_MAX_BLOCK_SIZE: usize = u8::MAX as usize;
/// Parity bytes per block in the default Reed-Solomon codec.
pub const RS_DEFAULT_PARITY_LEN: usize = 32;
/// Payload bytes per block in the default codec: whatever a full codeword leaves after parity.
pub const RS_DEFAULT_DATA_LEN: usize = RS_MAX_BLOCK_SIZE - RS_DEFAULT_PARITY_LEN;

/// Upper bound (50 MiB) on the encrypted `nonce || ciphertext` record.
///
/// Despite the name, both `encrypt_with_key` and `decrypt_with_key` apply it to the record,
/// not to the raw plaintext.
pub const MAX_PLAINTEXT_LEN: usize = 50 << 20;
/// Format version byte written at offset 0 of every blob.
const BLOB_VERSION: u8 = 1;

/// Argon2 memory cost in KiB (64 MiB).
pub const ARGON2_M_COST_KIB: u32 = 64 * 1024;
/// Argon2 time cost (number of passes).
pub const ARGON2_T_COST: u32 = 3;
/// Argon2 parallelism (number of lanes).
pub const ARGON2_P_COST: u32 = 4;
/// Errors raised by the vault's key-derivation, cipher, FEC and framing layers.
///
/// Every variant carries a human-readable detail message.
#[derive(Debug)]
pub enum CryptoError {
    /// Key derivation failed (e.g. the Argon2 backend rejected its inputs).
    KeyDerivation(String),
    /// Cipher construction, encryption or authenticated decryption failed.
    Cipher(String),
    /// Reed-Solomon encoding/decoding could not recover the data.
    ErrorCorrection(String),
    /// Framing problem: bad base64, truncated blob, oversized length header or invalid UTF-8.
    Encoding(String),
    /// A caller-supplied value violated a precondition.
    InvalidInput(String),
}

impl fmt::Display for CryptoError {
    /// Renders as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (category, detail) = match self {
            Self::KeyDerivation(detail) => ("Key derivation error", detail),
            Self::Cipher(detail) => ("Cipher error", detail),
            Self::ErrorCorrection(detail) => ("Error correction error", detail),
            Self::Encoding(detail) => ("Encoding error", detail),
            Self::InvalidInput(detail) => ("Invalid input", detail),
        };
        write!(f, "{category}: {detail}")
    }
}

impl std::error::Error for CryptoError {}
/// A password-based key-derivation function usable as the vault's KDF.
///
/// Implementors must be `Send + Sync`.
pub trait KeyDerivation: Send + Sync {
    /// Stretches `password` with `salt` into `output_len` bytes of key material.
    ///
    /// The key is returned inside [`Zeroizing`], which wipes the buffer when it is dropped.
    ///
    /// # Errors
    ///
    /// Returns a [`CryptoError`] when derivation fails; the Argon2 implementation in this
    /// file uses [`CryptoError::KeyDerivation`].
    fn derive_key(
        &self,
        password: &[u8],
        salt: &[u8],
        output_len: usize,
    ) -> Result<Zeroizing<Vec<u8>>, CryptoError>;
}
/// An authenticated (AEAD) cipher that operates on caller-supplied keys and nonces.
///
/// Implementors must be `Send + Sync`.
pub trait AuthenticatedCipher: Send + Sync {
    /// Nonce length in bytes expected by [`encrypt`](Self::encrypt) and
    /// [`decrypt`](Self::decrypt); the vault generates nonces of exactly this size.
    fn nonce_len(&self) -> usize;

    /// Encrypts `data` under `key` and `nonce`, returning the ciphertext
    /// (including any authentication tag the cipher appends).
    fn encrypt(&self, key: &[u8], nonce: &[u8], data: &[u8]) -> Result<Vec<u8>, CryptoError>;

    /// Verifies and decrypts `data` produced by [`encrypt`](Self::encrypt) with the same key
    /// and nonce; returns an error on a wrong key or tampered input.
    fn decrypt(&self, key: &[u8], nonce: &[u8], data: &[u8]) -> Result<Vec<u8>, CryptoError>;
}
/// A forward-error-correction code applied to the encrypted record.
///
/// Implementors must be `Send + Sync`.
pub trait ErrorCorrection: Send + Sync {
    /// Adds redundancy to `data`. Encoding never fails.
    fn encode(&self, data: &[u8]) -> Vec<u8>;

    /// Repairs and strips the redundancy from `encoded`, yielding at most
    /// `original_len` bytes of recovered data.
    ///
    /// # Errors
    ///
    /// Returns a [`CryptoError`] when the data cannot be recovered.
    fn decode(&self, encoded: &[u8], original_len: usize) -> Result<Vec<u8>, CryptoError>;
}
/// Argon2id-based [`KeyDerivation`] implementation with fixed cost parameters.
pub struct Argon2Kdf;

impl Argon2Kdf {
    /// Returns the Argon2 cost parameters used by this KDF.
    ///
    /// The parameters are `ARGON2_M_COST_KIB` KiB of memory, `ARGON2_T_COST` passes and
    /// `ARGON2_P_COST` lanes, with no fixed output length (`None`).
    ///
    /// NOTE(review): 64 MiB / t=3 / p=4 coincides with RFC 9106's second recommended
    /// Argon2id option; confirm which guideline the "OWASP" naming refers to.
    ///
    /// # Panics
    ///
    /// Only if the constants are edited to values `argon2::Params::new` rejects.
    pub fn owasp_params() -> argon2::Params {
        let memory_kib = ARGON2_M_COST_KIB;
        let passes = ARGON2_T_COST;
        let lanes = ARGON2_P_COST;
        argon2::Params::new(memory_kib, passes, lanes, None)
            .expect("OWASP Argon2 parameters are statically valid")
    }
}
impl KeyDerivation for Argon2Kdf {
    /// Runs Argon2id (version 0x13) with [`Argon2Kdf::owasp_params`].
    ///
    /// The derived bytes are written into a zeroize-on-drop buffer of `output_len` bytes.
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::KeyDerivation`] when `hash_password_into` reports an error
    /// (e.g. inputs the `argon2` crate rejects).
    fn derive_key(
        &self,
        password: &[u8],
        salt: &[u8],
        output_len: usize,
    ) -> Result<Zeroizing<Vec<u8>>, CryptoError> {
        let hasher = Argon2::new(
            argon2::Algorithm::Argon2id,
            argon2::Version::V0x13,
            Self::owasp_params(),
        );
        let mut derived = Zeroizing::new(vec![0u8; output_len]);
        match hasher.hash_password_into(password, salt, derived.as_mut_slice()) {
            Ok(()) => Ok(derived),
            Err(err) => Err(CryptoError::KeyDerivation(format!("Argon2 failed: {}", err))),
        }
    }
}
pub struct Aes256GcmSivCipher;
const AES_GCM_SIV_NONCE_LEN: usize = 12;
impl AuthenticatedCipher for Aes256GcmSivCipher {
fn encrypt(&self, key: &[u8], nonce: &[u8], data: &[u8]) -> Result<Vec<u8>, CryptoError> {
let cipher = Aes256GcmSiv::new_from_slice(key)
.map_err(|e| CryptoError::Cipher(format!("Cipher init failed: {}", e)))?;
let nonce = GenericArray::from_slice(nonce);
cipher
.encrypt(nonce, data)
.map_err(|e| CryptoError::Cipher(format!("Encryption failed: {}", e)))
}
fn decrypt(&self, key: &[u8], nonce: &[u8], data: &[u8]) -> Result<Vec<u8>, CryptoError> {
let cipher = Aes256GcmSiv::new_from_slice(key)
.map_err(|e| CryptoError::Cipher(format!("Cipher init failed: {}", e)))?;
let nonce = GenericArray::from_slice(nonce);
cipher
.decrypt(nonce, data)
.map_err(|e| CryptoError::Cipher(format!("Decryption failed: {}", e)))
}
fn nonce_len(&self) -> usize {
AES_GCM_SIV_NONCE_LEN
}
}
/// Block-wise Reed-Solomon codec over GF(2^8), used as the vault's FEC layer.
///
/// Each encoded block holds up to `data_len` payload bytes followed by `parity_len`
/// parity bytes.
#[derive(Debug)]
pub struct ReedSolomonCodec {
    /// Parity bytes appended to every block.
    parity_len: usize,
    /// Maximum payload bytes per block.
    data_len: usize,
}

impl Default for ReedSolomonCodec {
    /// RS(255, 223): 223 payload bytes plus 32 parity bytes per full block.
    fn default() -> Self {
        let (parity_len, data_len) = (RS_DEFAULT_PARITY_LEN, RS_DEFAULT_DATA_LEN);
        Self { parity_len, data_len }
    }
}
impl ReedSolomonCodec {
#[allow(dead_code)]
pub fn new(parity_len: usize, data_len: usize) -> Result<Self, CryptoError> {
if parity_len == 0 || data_len == 0 {
return Err(CryptoError::InvalidInput(
"Parity and data length must be greater than zero".to_string(),
));
}
if parity_len + data_len > 255 {
return Err(CryptoError::InvalidInput(format!(
"parity_len ({}) + data_len ({}) exceeds GF(2^8) limit of 255",
parity_len, data_len
)));
}
Ok(Self {
parity_len,
data_len,
})
}
}
impl ErrorCorrection for ReedSolomonCodec {
    /// Splits `data` into chunks of at most `data_len` bytes and appends `parity_len` parity
    /// bytes to each, concatenating the resulting blocks.
    ///
    /// Every block except possibly the last is `data_len + parity_len` bytes long. The last
    /// block is shortened to `remaining + parity_len`. Empty input yields empty output.
    fn encode(&self, data: &[u8]) -> Vec<u8> {
        let enc = reed_solomon::Encoder::new(self.parity_len);
        let chunks = data.chunks(self.data_len);
        // The output is the input plus one parity run per block, so reserve it up front.
        let mut result = Vec::with_capacity(data.len() + chunks.len() * self.parity_len);
        for chunk in chunks {
            result.extend_from_slice(&enc.encode(chunk));
        }
        result
    }

    /// Corrects each `data_len + parity_len` block of `encoded` and returns the first
    /// `original_len` recovered payload bytes.
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::ErrorCorrection`] in three cases:
    /// - a trailing block is too short to contain any payload;
    /// - a block has more corruption than Reed-Solomon can repair;
    /// - the recovered payload is shorter than `original_len`, i.e. the stream was
    ///   truncated. Previously this case silently returned fewer bytes.
    fn decode(&self, encoded: &[u8], original_len: usize) -> Result<Vec<u8>, CryptoError> {
        let dec = reed_solomon::Decoder::new(self.parity_len);
        let block_size = self.data_len + self.parity_len;
        // Recovered payload never exceeds the encoded input. Cap the reservation by it,
        // because `original_len` may come from an untrusted header.
        let mut result = Vec::with_capacity(original_len.min(encoded.len()));
        for chunk in encoded.chunks(block_size) {
            if chunk.len() <= self.parity_len {
                return Err(CryptoError::ErrorCorrection(
                    "Encoded block too short for Reed-Solomon parity".to_string(),
                ));
            }
            let recovered = dec.correct(chunk, None).map_err(|_| {
                CryptoError::ErrorCorrection("Reed-Solomon error correction failed".to_string())
            })?;
            result.extend_from_slice(recovered.data());
        }
        if result.len() < original_len {
            return Err(CryptoError::ErrorCorrection(format!(
                "Decoded {} bytes, fewer than the expected {}",
                result.len(),
                original_len
            )));
        }
        result.truncate(original_len);
        Ok(result)
    }
}
/// High-level facade composing a KDF, an authenticated cipher and an FEC codec.
///
/// `encrypt_with_key` produces blobs laid out, before base64, as
/// `[version: u8][record_len: u32 LE][FEC(nonce || ciphertext)]`.
pub struct CryptoVault {
    /// Turns passwords into cipher keys.
    kdf: Box<dyn KeyDerivation>,
    /// Encrypts and authenticates the payload.
    cipher: Box<dyn AuthenticatedCipher>,
    /// Adds and removes redundancy around the `nonce || ciphertext` record.
    fec: Box<dyn ErrorCorrection>,
}

impl Default for CryptoVault {
    /// Argon2id KDF + AES-256-GCM-SIV cipher + default Reed-Solomon codec.
    fn default() -> Self {
        let kdf: Box<dyn KeyDerivation> = Box::new(Argon2Kdf);
        let cipher: Box<dyn AuthenticatedCipher> = Box::new(Aes256GcmSivCipher);
        let fec: Box<dyn ErrorCorrection> = Box::new(ReedSolomonCodec::default());
        Self { kdf, cipher, fec }
    }
}
impl CryptoVault {
    /// Bytes preceding the FEC payload in a decoded blob: the version byte (1) plus the
    /// little-endian `u32` record length (4).
    const HEADER_LEN: usize = 1 + 4;

    /// Tag overhead assumed when pre-screening input size before encrypting.
    ///
    /// This matches the 16-byte tag of the default AES-GCM-SIV cipher. The real record length
    /// is re-checked after encryption, so a cipher with a larger tag cannot exceed the cap.
    const ASSUMED_TAG_LEN: usize = 16;

    /// Builds a vault from explicit KDF, cipher and FEC implementations.
    // Not called elsewhere in this file, hence the allow.
    #[allow(dead_code)]
    pub fn new(
        kdf: Box<dyn KeyDerivation>,
        cipher: Box<dyn AuthenticatedCipher>,
        fec: Box<dyn ErrorCorrection>,
    ) -> Self {
        Self { kdf, cipher, fec }
    }

    /// Derives a [`KEY_LEN`]-byte key from `password` and `salt` using the configured KDF.
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::InvalidInput`] for an empty password. Otherwise it returns
    /// whatever error the KDF reports.
    pub fn derive_key(
        &self,
        password: &str,
        salt: &[u8],
    ) -> Result<Zeroizing<Vec<u8>>, CryptoError> {
        if password.is_empty() {
            return Err(CryptoError::InvalidInput(
                "Password must not be empty".to_string(),
            ));
        }
        self.kdf.derive_key(password.as_bytes(), salt, KEY_LEN)
    }

    /// Encrypts `plaintext` under a pre-derived `key` and returns a base64 (standard
    /// alphabet) blob.
    ///
    /// A fresh nonce is drawn from the OS RNG on every call. The blob layout before base64 is
    /// `[BLOB_VERSION][record_len: u32 LE][FEC(nonce || ciphertext)]`, where `record_len` is
    /// the length of `nonce || ciphertext`. No salt is stored in the blob.
    ///
    /// # Errors
    ///
    /// - [`CryptoError::InvalidInput`] if the record would exceed [`MAX_PLAINTEXT_LEN`].
    /// - [`CryptoError::Cipher`] if the cipher rejects the key or fails to encrypt.
    /// - [`CryptoError::Encoding`] if the record length does not fit the `u32` header.
    pub fn encrypt_with_key(&self, key: &[u8], plaintext: &str) -> Result<String, CryptoError> {
        let nonce_len = self.cipher.nonce_len();
        // Cheap pre-screen: reject obviously oversized input before doing any encryption.
        let projected_original_len = nonce_len
            .saturating_add(plaintext.len())
            .saturating_add(Self::ASSUMED_TAG_LEN);
        Self::check_record_len(projected_original_len)?;

        let mut nonce = vec![0u8; nonce_len];
        rand::rngs::OsRng.fill_bytes(&mut nonce);
        let ciphertext = self.cipher.encrypt(key, &nonce, plaintext.as_bytes())?;

        let mut plaindata = Vec::with_capacity(nonce_len + ciphertext.len());
        plaindata.extend_from_slice(&nonce);
        plaindata.extend_from_slice(&ciphertext);
        // Authoritative check on the actual record. A pluggable cipher may add more than
        // ASSUMED_TAG_LEN bytes, and decrypt_with_key refuses any record above the cap, so
        // skipping this check could emit blobs that can never be decrypted.
        Self::check_record_len(plaindata.len())?;

        let original_len_u32 = u32::try_from(plaindata.len())
            .map_err(|_| CryptoError::Encoding("Data too large for length header".to_string()))?;
        let rs_encoded = self.fec.encode(&plaindata);

        let mut blob = Vec::with_capacity(Self::HEADER_LEN + rs_encoded.len());
        blob.push(BLOB_VERSION);
        blob.extend_from_slice(&original_len_u32.to_le_bytes());
        blob.extend_from_slice(&rs_encoded);
        Ok(STANDARD.encode(&blob))
    }

    /// Reverses [`Self::encrypt_with_key`].
    ///
    /// The steps are: base64-decode, validate the header, run FEC repair, then authenticate
    /// and decrypt with `key`. All header checks run before any FEC work, so a hostile length
    /// field cannot force a large allocation.
    ///
    /// # Errors
    ///
    /// - [`CryptoError::Encoding`]: the input is not valid base64, the blob is shorter than the
    ///   header, or the decrypted bytes are not UTF-8.
    /// - [`CryptoError::InvalidInput`]: any of the following:
    ///   - the blob version is unsupported;
    ///   - the length header is above [`MAX_PLAINTEXT_LEN`];
    ///   - the length header is larger than the payload, or inconsistent with its size;
    ///   - the decoded record is shorter than a nonce.
    /// - FEC or cipher errors are propagated unchanged.
    pub fn decrypt_with_key(
        &self,
        key: &[u8],
        encrypted_base64: &str,
    ) -> Result<String, CryptoError> {
        let nonce_len = self.cipher.nonce_len();
        let blob = STANDARD
            .decode(encrypted_base64)
            .map_err(|e| CryptoError::Encoding(format!("Invalid base64: {}", e)))?;
        if blob.len() < Self::HEADER_LEN {
            return Err(CryptoError::Encoding(
                "Encrypted blob too short".to_string(),
            ));
        }
        let (header, encoded) = blob.split_at(Self::HEADER_LEN);

        let version = header[0];
        if version != BLOB_VERSION {
            return Err(CryptoError::InvalidInput(format!(
                "Unsupported blob version {} (expected {})",
                version, BLOB_VERSION
            )));
        }

        // u32 -> usize is lossless on every 32/64-bit target.
        let original_len =
            u32::from_le_bytes([header[1], header[2], header[3], header[4]]) as usize;
        if original_len > MAX_PLAINTEXT_LEN {
            return Err(CryptoError::InvalidInput(format!(
                "Length header {} exceeds MAX_PLAINTEXT_LEN ({}); refusing to allocate",
                original_len, MAX_PLAINTEXT_LEN
            )));
        }
        if original_len > encoded.len() {
            return Err(CryptoError::InvalidInput(
                "Length header exceeds encoded data size".to_string(),
            ));
        }
        // Reject payloads far larger than the header implies.
        // NOTE(review): the 2x + 4096 slack comfortably covers the default RS(255,223)
        // overhead (32/223, about 14%). A custom `ErrorCorrection` whose parity outweighs its
        // data per block would trip this check on legitimate blobs; confirm before plugging
        // in heavier codecs.
        if encoded.len() > original_len.saturating_mul(2).saturating_add(4096) {
            return Err(CryptoError::InvalidInput(format!(
                "Encoded blob length {} is inconsistent with declared plaintext length {}; refusing to allocate",
                encoded.len(),
                original_len
            )));
        }

        let plaindata = self.fec.decode(encoded, original_len)?;
        if plaindata.len() < nonce_len {
            return Err(CryptoError::InvalidInput(
                "Decoded blob too short for nonce".to_string(),
            ));
        }
        let (nonce, ciphertext) = plaindata.split_at(nonce_len);
        let plaintext = self.cipher.decrypt(key, nonce, ciphertext)?;
        String::from_utf8(plaintext)
            .map_err(|e| CryptoError::Encoding(format!("Invalid UTF-8: {}", e)))
    }

    /// Enforces the [`MAX_PLAINTEXT_LEN`] cap on a `nonce || ciphertext` record of
    /// `record_len` bytes.
    fn check_record_len(record_len: usize) -> Result<(), CryptoError> {
        if record_len > MAX_PLAINTEXT_LEN {
            return Err(CryptoError::InvalidInput(format!(
                "Record length {} (nonce+ciphertext) exceeds MAX_PLAINTEXT_LEN ({})",
                record_len, MAX_PLAINTEXT_LEN
            )));
        }
        Ok(())
    }
}
// Unit tests for the vault's wire format, cipher wiring, FEC codec and KDF parameters.
// Most tests derive real keys through the default vault's Argon2id KDF, so each
// derivation pays the full configured memory/time cost.
#[cfg(test)]
mod tests {
use super::*;
// Helper: derives a key for password "pw" with an all-zero SALT_LEN-byte salt.
fn k(vault: &CryptoVault) -> Zeroizing<Vec<u8>> {
vault.derive_key("pw", &[0u8; SALT_LEN]).unwrap()
}
// Round trip: decrypting with the same key returns the original plaintext.
#[test]
fn test_encrypt_with_key_roundtrips() {
let vault = CryptoVault::default();
let key = k(&vault);
let pt = "sk-ant-secret-payload";
let blob = vault.encrypt_with_key(&key, pt).unwrap();
assert_eq!(vault.decrypt_with_key(&key, &blob).unwrap(), pt);
}
// Pins the wire format: byte 0 of the decoded blob is the version. The test compares
// against the literal 1 rather than BLOB_VERSION, so bumping the version fails it.
#[test]
fn test_blob_carries_version_byte() {
let vault = CryptoVault::default();
let key = k(&vault);
let raw = STANDARD
.decode(vault.encrypt_with_key(&key, "payload").unwrap())
.unwrap();
assert_eq!(raw[0], 1, "blob must start with the format version byte");
}
// Overwrites the version byte with 2 and expects an error whose message mentions "version".
#[test]
fn test_decrypt_rejects_unsupported_blob_version() {
let vault = CryptoVault::default();
let key = k(&vault);
let mut raw = STANDARD
.decode(vault.encrypt_with_key(&key, "x").unwrap())
.unwrap();
raw[0] = 2; let err = vault
.decrypt_with_key(&key, &STANDARD.encode(&raw))
.unwrap_err();
assert!(
err.to_string().to_lowercase().contains("version"),
"an unsupported blob version must be rejected with a version error: {err}"
);
}
// Decodes the FEC layer by hand (header at bytes 1..5, payload from byte 5) with the default
// codec. Checks the record is exactly nonce (12) + ciphertext (plaintext + 16-byte tag), so
// no salt or other field is embedded.
#[test]
fn test_blob_layout_carries_no_salt() {
let vault = CryptoVault::default();
let key = k(&vault);
let pt = "0123456789"; let blob = vault.encrypt_with_key(&key, pt).unwrap();
let raw = STANDARD.decode(&blob).unwrap();
let original_len = u32::from_le_bytes(raw[1..5].try_into().unwrap()) as usize;
let codec = ReedSolomonCodec::default();
let plaindata = codec.decode(&raw[5..], original_len).unwrap();
assert_eq!(plaindata.len(), 12 + (pt.len() + 16));
}
// Two encryptions of identical plaintext under one key must differ, because a fresh random
// nonce is drawn per call.
#[test]
fn test_encrypt_with_key_uses_independent_nonce() {
let vault = CryptoVault::default();
let key = k(&vault);
let pt = "identical plaintext";
let a = vault.encrypt_with_key(&key, pt).unwrap();
let b = vault.encrypt_with_key(&key, pt).unwrap();
assert_ne!(a, b, "independent nonces must yield different blobs");
}
// Same salt, different passwords -> different keys. Decryption must surface a Cipher
// error rather than panic.
#[test]
fn test_decrypt_with_wrong_key_errors_without_panic() {
let vault = CryptoVault::default();
let key_a = vault.derive_key("pw-a", &[1u8; SALT_LEN]).unwrap();
let key_b = vault.derive_key("pw-b", &[1u8; SALT_LEN]).unwrap();
let blob = vault.encrypt_with_key(&key_a, "secret").unwrap();
assert!(matches!(
vault.decrypt_with_key(&key_b, &blob),
Err(CryptoError::Cipher(_))
));
}
// Exercises the allocation guards in decrypt_with_key:
//   (1) a length header of 0xFFFF_FFFF exceeds MAX_PLAINTEXT_LEN;
//   (2) header 100 with 20_000 payload bytes exceeds the `2 * len + 4096` consistency bound.
// NOTE(review): "c7" presumably refers to an external issue/requirement id — confirm.
#[test]
fn test_decrypt_with_key_preserves_c7_caps() {
let vault = CryptoVault::default();
let key = k(&vault);
let mut over = vec![1u8]; over.extend_from_slice(&0xFFFF_FFFFu32.to_le_bytes());
over.extend_from_slice(&[0u8; 8]);
assert!(matches!(
vault.decrypt_with_key(&key, &STANDARD.encode(&over)),
Err(CryptoError::InvalidInput(_))
));
let mut big = vec![1u8]; big.extend_from_slice(&100u32.to_le_bytes());
big.extend_from_slice(&vec![0u8; 20_000]);
assert!(matches!(
vault.decrypt_with_key(&key, &STANDARD.encode(&big)),
Err(CryptoError::InvalidInput(_))
));
}
// A 5-byte key must produce Cipher errors for both encrypt and decrypt, with no panic
// from cipher construction.
#[test]
fn test_invalid_key_length_errors_without_panic() {
let vault = CryptoVault::default();
assert!(matches!(
vault.encrypt_with_key(&[0u8; 5], "x"),
Err(CryptoError::Cipher(_))
));
let valid = vault.encrypt_with_key(&k(&vault), "x").unwrap();
assert!(matches!(
vault.decrypt_with_key(&[0u8; 5], &valid),
Err(CryptoError::Cipher(_))
));
}
// Boundary: a length header of exactly MAX_PLAINTEXT_LEN + 1 must be rejected.
#[test]
fn test_decrypt_rejects_prefix_at_exactly_cap_plus_one() {
let oversized = (MAX_PLAINTEXT_LEN + 1) as u32;
let mut blob = vec![1u8]; blob.extend_from_slice(&oversized.to_le_bytes());
blob.extend_from_slice(&[0u8; 8]);
let encoded = STANDARD.encode(&blob);
let vault = CryptoVault::default();
let key = k(&vault);
assert!(
matches!(
vault.decrypt_with_key(&key, &encoded),
Err(CryptoError::InvalidInput(_))
),
"a length prefix one byte over the cap must be rejected"
);
}
// Encrypting more than MAX_PLAINTEXT_LEN bytes must be refused up front.
// Note: this allocates a string of MAX_PLAINTEXT_LEN + 1 bytes (about 50 MiB).
#[test]
fn test_encrypt_rejects_plaintext_over_cap() {
let vault = CryptoVault::default();
let key = k(&vault);
let huge = "a".repeat(MAX_PLAINTEXT_LEN + 1);
assert!(
matches!(
vault.encrypt_with_key(&key, &huge),
Err(CryptoError::InvalidInput(_))
),
"encrypting beyond MAX_PLAINTEXT_LEN must be rejected"
);
}
// The 51-byte payload fits in one block (51 data + 32 parity bytes). The test flips 10 bytes
// at offsets 0, 7, ..., 63 and expects the default codec to recover the payload exactly.
#[test]
fn rs_corrects_corrupted_data() {
let rs = ReedSolomonCodec::default();
let data = b"FEC correction test payload for Reed-Solomon codec.";
let mut encoded = rs.encode(data);
for i in 0..10 {
encoded[i * 7] ^= 0xAA;
}
let decoded = rs.decode(&encoded, data.len()).unwrap();
assert_eq!(decoded, data);
}
// Pins the Argon2 cost parameters returned by Argon2Kdf::owasp_params.
#[test]
fn test_argon2_uses_owasp_2025_parameters() {
let params = Argon2Kdf::owasp_params();
assert_eq!(
params.m_cost(),
65536,
"memory cost must be 64 MiB (OWASP 2025)"
);
assert_eq!(params.t_cost(), 3, "time cost (iterations) must be 3");
assert_eq!(params.p_cost(), 4, "parallelism must be 4");
}
}