use alloc::vec::Vec;
use chacha20poly1305_nostd::ChaCha20Poly1305;
use sha3::Shake256;
use sha3::digest::{ExtendableOutput, Update, XofReader};
use zeroize::Zeroize;
use crate::FsError;
use crate::arch;
/// Length in bytes of a ChaCha20-Poly1305 key (256 bits).
pub const KEY_SIZE: usize = 256 / 8;
/// Length in bytes of the per-message nonce (96 bits) prepended to every sealed block.
pub const NONCE_SIZE: usize = 96 / 8;
/// Length in bytes of the Poly1305 authentication tag (128 bits).
pub const TAG_SIZE: usize = 128 / 8;
/// Seals `data` with ChaCha20-Poly1305 under a freshly generated random nonce.
///
/// The returned buffer is the `NONCE_SIZE`-byte nonce followed by whatever the cipher's
/// `encrypt` produced (the ciphertext together with its authentication tag). Its length
/// is `NONCE_SIZE + data.len() + TAG_SIZE`, and it can be handed directly to
/// [`decrypt_block`]. No associated data is authenticated.
///
/// Nonces are random 96-bit values, so the chance of a repeat grows with the number of
/// blocks sealed under one key. Rotate keys well before on the order of 2^32 encryptions.
///
/// # Errors
///
/// * [`FsError::InvalidArgument`] if `key` is not exactly [`KEY_SIZE`] bytes, or if the
///   cipher constructor rejects it.
/// * [`FsError::EncryptionFailed`] if no entropy source could fill the nonce, or if the
///   cipher reports an encryption failure.
pub fn encrypt_block(data: &[u8], key: &[u8]) -> Result<Vec<u8>, FsError> {
    let aead = if key.len() == KEY_SIZE {
        ChaCha20Poly1305::new(key).map_err(|_| FsError::InvalidArgument {
            reason: "invalid key for ChaCha20Poly1305",
        })?
    } else {
        return Err(FsError::InvalidArgument {
            reason: "encryption key must be 32 bytes",
        });
    };

    // A fresh nonce per call; it is not secret and travels in the clear as the prefix.
    let mut nonce = [0u8; NONCE_SIZE];
    fill_hardware_entropy(&mut nonce)?;

    let sealed = aead
        .encrypt(&nonce, data, None)
        .map_err(|_| FsError::EncryptionFailed)?;

    Ok([nonce.as_slice(), sealed.as_slice()].concat())
}
/// Opens a buffer produced by [`encrypt_block`] and returns the recovered plaintext.
///
/// `data` must be laid out as the `NONCE_SIZE`-byte nonce followed by the cipher output
/// (ciphertext plus tag). The authentication tag is checked by the cipher. Any
/// modification to the nonce, ciphertext or tag, or use of the wrong key, yields
/// [`FsError::DecryptionFailed`].
///
/// # Errors
///
/// * [`FsError::InvalidArgument`] if `key` is not exactly [`KEY_SIZE`] bytes.
/// * [`FsError::InvalidArgument`] if `data` is shorter than `NONCE_SIZE + TAG_SIZE`.
/// * [`FsError::InvalidArgument`] if the cipher constructor rejects the key.
/// * [`FsError::DecryptionFailed`] if authentication or decryption fails.
///
/// The key check runs before the length check.
pub fn decrypt_block(data: &[u8], key: &[u8]) -> Result<Vec<u8>, FsError> {
    if key.len() != KEY_SIZE {
        return Err(FsError::InvalidArgument {
            reason: "decryption key must be 32 bytes",
        });
    }
    if data.len() < NONCE_SIZE + TAG_SIZE {
        return Err(FsError::InvalidArgument {
            reason: "ciphertext too short (minimum 28 bytes)",
        });
    }

    // The length check above guarantees `split_at` cannot panic.
    let (nonce, sealed) = data.split_at(NONCE_SIZE);

    ChaCha20Poly1305::new(key)
        .map_err(|_| FsError::InvalidArgument {
            reason: "invalid key for ChaCha20Poly1305",
        })?
        .decrypt(nonce, sealed, None)
        .map_err(|_| FsError::DecryptionFailed)
}
/// Hashes `data` with SHAKE256 and returns exactly `output_len` bytes of output.
///
/// SHAKE256 is an extendable-output function. For the same input, a shorter output is
/// a prefix of a longer one.
///
/// # Errors
///
/// Returns [`FsError::InvalidArgument`] when `output_len` is zero.
pub fn shake256(data: &[u8], output_len: usize) -> Result<Vec<u8>, FsError> {
    match output_len {
        0 => Err(FsError::InvalidArgument {
            reason: "output length must be greater than 0",
        }),
        len => {
            let mut xof = Shake256::default();
            xof.update(data);
            let mut digest = alloc::vec![0u8; len];
            xof.finalize_xof().read(&mut digest);
            Ok(digest)
        }
    }
}
/// Fills `buf` with random bytes, preferring the architecture's hardware source.
///
/// The hardware path (`arch::fill_hardware_entropy`) is attempted only when
/// `arch::has_rdrand()` reports support. If it is unavailable or returns an error,
/// `crate::crypto::random::fill_random` is used instead.
///
/// NOTE(review): a failed hardware attempt may leave `buf` partially written. This
/// relies on the fallback overwriting the whole buffer — confirm against `fill_random`.
///
/// # Errors
///
/// A fallback failure is mapped to [`FsError::EncryptionFailed`], matching this
/// helper's use for nonce generation in [`encrypt_block`].
fn fill_hardware_entropy(buf: &mut [u8]) -> Result<(), FsError> {
    // `&&` short-circuits: the hardware fill is never attempted without RDRAND support.
    let filled_by_hardware = arch::has_rdrand() && arch::fill_hardware_entropy(buf).is_ok();
    if filled_by_hardware {
        Ok(())
    } else {
        crate::crypto::random::fill_random(buf).map_err(|_| FsError::EncryptionFailed)
    }
}
/// A [`KEY_SIZE`]-byte symmetric key whose storage is wiped when the value is dropped.
///
/// Wiping comes from the `Zeroize` derive combined with `#[zeroize(drop)]`. The type
/// derives neither `Clone` nor `Debug`, so key bytes cannot be duplicated or formatted
/// through those traits.
#[derive(Zeroize)]
#[zeroize(drop)]
pub struct ZeroizingKey {
    /// Raw key material; zeroized on drop.
    bytes: [u8; KEY_SIZE],
}

impl ZeroizingKey {
    /// Wraps `bytes` as key material.
    ///
    /// `[u8; N]` is `Copy`, so any array the caller still holds is a separate copy. That
    /// copy is not wiped by this type.
    pub fn new(bytes: [u8; KEY_SIZE]) -> Self {
        ZeroizingKey { bytes }
    }

    /// Borrows the key material as a byte slice, e.g. for [`encrypt_block`].
    pub fn as_bytes(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_actually_encrypts() {
        let plaintext = b"secret data that must be encrypted";
        let key = [0u8; 32];
        let ciphertext = encrypt_block(plaintext, &key).unwrap();
        assert_eq!(
            ciphertext.len(),
            NONCE_SIZE + plaintext.len() + TAG_SIZE,
            "Ciphertext should be plaintext + 28 bytes overhead"
        );
        // `&ciphertext[NONCE_SIZE..]` also contains the tag, so it is always longer than
        // the plaintext. An `assert_ne!` against it can never fail, even for an identity
        // "cipher". Instead, require that the plaintext bytes appear nowhere in the output.
        assert!(
            !ciphertext
                .windows(plaintext.len())
                .any(|window| window == plaintext.as_slice()),
            "Encryption must change data - not an identity function!"
        );
    }

    #[test]
    fn test_encrypt_produces_different_ciphertext_each_time() {
        let plaintext = b"same plaintext";
        let key = [0x42u8; 32];
        let ciphertext1 = encrypt_block(plaintext, &key).unwrap();
        let ciphertext2 = encrypt_block(plaintext, &key).unwrap();
        assert_ne!(
            ciphertext1, ciphertext2,
            "Same plaintext should produce different ciphertext due to random nonce"
        );
    }

    #[test]
    fn test_encrypt_invalid_key_length() {
        let plaintext = b"test";
        assert!(matches!(
            encrypt_block(plaintext, &[0u8; 16]),
            Err(FsError::InvalidArgument { .. })
        ));
        assert!(matches!(
            encrypt_block(plaintext, &[0u8; 64]),
            Err(FsError::InvalidArgument { .. })
        ));
    }

    #[test]
    fn test_decrypt_reverses_encrypt() {
        let plaintext = b"secret data to round-trip";
        let key = [0x42u8; 32];
        let ciphertext = encrypt_block(plaintext, &key).unwrap();
        let decrypted = decrypt_block(&ciphertext, &key).unwrap();
        assert_eq!(
            decrypted.as_slice(),
            plaintext.as_slice(),
            "Decryption must recover original plaintext"
        );
    }

    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let plaintext = b"secret";
        let key1 = [0x42u8; 32];
        let key2 = [0x43u8; 32];
        let ciphertext = encrypt_block(plaintext, &key1).unwrap();
        assert!(
            matches!(
                decrypt_block(&ciphertext, &key2),
                Err(FsError::DecryptionFailed)
            ),
            "Decryption with wrong key must fail"
        );
    }

    #[test]
    fn test_decrypt_tampered_ciphertext_fails() {
        let plaintext = b"secret";
        let key = [0u8; 32];
        let mut ciphertext = encrypt_block(plaintext, &key).unwrap();
        ciphertext[NONCE_SIZE + 1] ^= 0xFF;
        assert!(
            matches!(
                decrypt_block(&ciphertext, &key),
                Err(FsError::DecryptionFailed)
            ),
            "Tampered ciphertext must fail authentication"
        );
    }

    #[test]
    fn test_decrypt_tampered_nonce_or_tag_fails() {
        let plaintext = b"secret";
        let key = [0x42u8; 32];
        let ciphertext = encrypt_block(plaintext, &key).unwrap();

        // The nonce prefix is not encrypted, but it is bound by authentication.
        let mut bad_nonce = ciphertext.clone();
        bad_nonce[0] ^= 0x01;
        assert!(
            matches!(
                decrypt_block(&bad_nonce, &key),
                Err(FsError::DecryptionFailed)
            ),
            "Tampered nonce must fail authentication"
        );

        // The last byte belongs to the authentication tag in the standard layout.
        let mut bad_tag = ciphertext;
        let last = bad_tag.len() - 1;
        bad_tag[last] ^= 0x01;
        assert!(
            matches!(
                decrypt_block(&bad_tag, &key),
                Err(FsError::DecryptionFailed)
            ),
            "Tampered tag must fail authentication"
        );
    }

    #[test]
    fn test_decrypt_invalid_key_length() {
        let fake_ciphertext = [0u8; NONCE_SIZE + TAG_SIZE];
        assert!(matches!(
            decrypt_block(&fake_ciphertext, &[0u8; 16]),
            Err(FsError::InvalidArgument { .. })
        ));
    }

    #[test]
    fn test_decrypt_ciphertext_too_short() {
        let key = [0u8; 32];
        assert!(matches!(
            decrypt_block(&[0u8; NONCE_SIZE + TAG_SIZE - 1], &key),
            Err(FsError::InvalidArgument { .. })
        ));
        assert!(matches!(
            decrypt_block(&[], &key),
            Err(FsError::InvalidArgument { .. })
        ));
    }

    #[test]
    fn test_encrypt_decrypt_empty_plaintext() {
        let plaintext = b"";
        let key = [0x42u8; 32];
        let ciphertext = encrypt_block(plaintext, &key).unwrap();
        assert_eq!(ciphertext.len(), NONCE_SIZE + TAG_SIZE);
        let decrypted = decrypt_block(&ciphertext, &key).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_encrypt_decrypt_large_data() {
        let plaintext = alloc::vec![0xAB_u8; 1024 * 1024];
        let key = [0x42u8; 32];
        let ciphertext = encrypt_block(&plaintext, &key).unwrap();
        let decrypted = decrypt_block(&ciphertext, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_shake256_not_zeros() {
        let data = b"test data for hashing";
        let hash = shake256(data, 64).unwrap();
        assert_ne!(
            hash,
            alloc::vec![0u8; 64],
            "SHAKE256 must not return zeros - it must actually hash!"
        );
    }

    #[test]
    fn test_shake256_deterministic() {
        let data = b"deterministic test";
        let hash1 = shake256(data, 64).unwrap();
        let hash2 = shake256(data, 64).unwrap();
        assert_eq!(hash1, hash2, "SHAKE256 must be deterministic");
    }

    #[test]
    fn test_shake256_different_inputs_different_outputs() {
        let hash1 = shake256(b"input 1", 32).unwrap();
        let hash2 = shake256(b"input 2", 32).unwrap();
        assert_ne!(
            hash1, hash2,
            "Different inputs must produce different hashes"
        );
    }

    #[test]
    fn test_shake256_variable_output_length() {
        let data = b"test";
        let hash16 = shake256(data, 16).unwrap();
        let hash32 = shake256(data, 32).unwrap();
        let hash64 = shake256(data, 64).unwrap();
        let hash128 = shake256(data, 128).unwrap();
        assert_eq!(hash16.len(), 16);
        assert_eq!(hash32.len(), 32);
        assert_eq!(hash64.len(), 64);
        assert_eq!(hash128.len(), 128);
        // XOF property: shorter outputs are prefixes of longer ones.
        assert_eq!(&hash32[..16], &hash16[..]);
        assert_eq!(&hash64[..32], &hash32[..]);
        assert_eq!(&hash128[..64], &hash64[..]);
    }

    #[test]
    fn test_shake256_zero_length_fails() {
        assert!(matches!(
            shake256(b"test", 0),
            Err(FsError::InvalidArgument { .. })
        ));
    }

    #[test]
    fn test_shake256_empty_input() {
        let hash = shake256(b"", 32).unwrap();
        assert_eq!(hash.len(), 32);
        assert_ne!(hash, alloc::vec![0u8; 32]);
    }

    #[test]
    fn test_shake256_known_test_vector() {
        let hash = shake256(b"", 32).unwrap();
        let expected: [u8; 32] = [
            0x46, 0xb9, 0xdd, 0x2b, 0x0b, 0xa8, 0x8d, 0x13, 0x23, 0x3b, 0x3f, 0xeb, 0x74, 0x3e,
            0xeb, 0x24, 0x3f, 0xcd, 0x52, 0xea, 0x62, 0xb8, 0x1b, 0x82, 0xb5, 0x0c, 0x27, 0x64,
            0x6e, 0xd5, 0x76, 0x2f,
        ];
        assert_eq!(
            hash.as_slice(),
            &expected[..],
            "SHAKE256 must match NIST test vector"
        );
    }

    #[test]
    fn test_zeroizing_key_creation() {
        let key_bytes = [0x42u8; 32];
        let key = ZeroizingKey::new(key_bytes);
        assert_eq!(key.as_bytes(), &key_bytes);
    }

    #[test]
    fn test_encrypt_then_hash_workflow() {
        let plaintext = b"data to encrypt and hash";
        let key = [0x42u8; 32];
        let ciphertext = encrypt_block(plaintext, &key).unwrap();
        let hash = shake256(&ciphertext, 32).unwrap();
        assert_ne!(hash, alloc::vec![0u8; 32]);
        assert_eq!(hash.len(), 32);
        let decrypted = decrypt_block(&ciphertext, &key).unwrap();
        assert_eq!(decrypted.as_slice(), plaintext.as_slice());
    }
}