use aes_gcm::aead::generic_array::typenum::U32;
use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::aes::Aes256;
use aes_gcm::{AesGcm, Nonce};
use crate::error::{Error, Result};
/// AES-256-GCM instantiated with a 32-byte (256-bit) nonce rather than GCM's
/// recommended 12-byte nonce.
/// NOTE(review): the non-standard nonce size appears to be required for
/// interoperability with the TypeScript and Go SDKs exercised by the cross-SDK
/// tests in this file — confirm before changing it.
type Aes256Gcm32 = AesGcm<Aes256, U32>;
/// Length of a symmetric key in bytes (AES-256).
const KEY_SIZE: usize = 32;
/// Length of the random IV/nonce that `encrypt` prepends to every ciphertext.
const IV_SIZE: usize = 32;
/// Length of the GCM authentication tag that trails the encrypted body.
const TAG_SIZE: usize = 16;
/// Shortest input `decrypt` accepts: an IV plus the tag of an empty plaintext.
const MIN_CIPHERTEXT_SIZE: usize = IV_SIZE + TAG_SIZE;
/// A 32-byte AES-256-GCM key.
///
/// The raw bytes are private: `Debug` output prints a redacted placeholder,
/// equality is computed with `subtle::ConstantTimeEq`, and `Drop` overwrites
/// the bytes with zeros. Note that `Clone` produces an independent copy of the
/// key material.
#[derive(Clone)]
pub struct SymmetricKey {
    // Raw key bytes; always exactly KEY_SIZE (32) bytes.
    key: [u8; KEY_SIZE],
}
impl SymmetricKey {
pub fn random() -> Self {
let mut key = [0u8; KEY_SIZE];
getrandom::getrandom(&mut key).expect("Failed to generate random bytes");
Self { key }
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
if bytes.is_empty() {
return Err(Error::InvalidKeyLength {
expected: KEY_SIZE,
actual: 0,
});
}
if bytes.len() > KEY_SIZE {
return Err(Error::InvalidKeyLength {
expected: KEY_SIZE,
actual: bytes.len(),
});
}
let mut key = [0u8; KEY_SIZE];
if bytes.len() < KEY_SIZE {
let offset = KEY_SIZE - bytes.len();
key[offset..].copy_from_slice(bytes);
} else {
key.copy_from_slice(bytes);
}
Ok(Self { key })
}
pub fn as_bytes(&self) -> &[u8; KEY_SIZE] {
&self.key
}
pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
let cipher = Aes256Gcm32::new_from_slice(&self.key)
.map_err(|e| Error::CryptoError(format!("Failed to create cipher: {}", e)))?;
let mut iv = [0u8; IV_SIZE];
getrandom::getrandom(&mut iv)
.map_err(|e| Error::CryptoError(format!("Failed to generate IV: {}", e)))?;
#[allow(deprecated)]
let nonce = Nonce::<U32>::from_slice(&iv);
let ciphertext_with_tag = cipher
.encrypt(nonce, plaintext)
.map_err(|_| Error::CryptoError("Encryption failed".to_string()))?;
let mut result = Vec::with_capacity(IV_SIZE + ciphertext_with_tag.len());
result.extend_from_slice(&iv);
result.extend_from_slice(&ciphertext_with_tag);
Ok(result)
}
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if data.len() < MIN_CIPHERTEXT_SIZE {
return Err(Error::InvalidDataLength {
expected: MIN_CIPHERTEXT_SIZE,
actual: data.len(),
});
}
let cipher = Aes256Gcm32::new_from_slice(&self.key)
.map_err(|e| Error::CryptoError(format!("Failed to create cipher: {}", e)))?;
let (iv, ciphertext_with_tag) = data.split_at(IV_SIZE);
#[allow(deprecated)]
let nonce = Nonce::<U32>::from_slice(iv);
cipher
.decrypt(nonce, ciphertext_with_tag)
.map_err(|_| Error::DecryptionFailed)
}
}
impl std::fmt::Debug for SymmetricKey {
    /// Formats the key without revealing any of its bytes.
    ///
    /// The `key` field is always rendered as the placeholder string
    /// `"[REDACTED]"`, in both compact and alternate (`{:#?}`) forms.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const PLACEHOLDER: &str = "[REDACTED]";
        let mut builder = f.debug_struct("SymmetricKey");
        builder.field("key", &PLACEHOLDER);
        builder.finish()
    }
}
impl PartialEq for SymmetricKey {
    /// Compares the raw key bytes with `subtle::ConstantTimeEq`, so the result
    /// is computed without short-circuiting on the first differing byte.
    fn eq(&self, other: &Self) -> bool {
        use subtle::ConstantTimeEq;

        let same = self.key.ct_eq(&other.key);
        bool::from(same)
    }
}

// Byte-array equality is reflexive, symmetric and transitive, so `Eq` holds.
impl Eq for SymmetricKey {}
impl Drop for SymmetricKey {
    /// Overwrites the key bytes with zeros before the value's memory is released.
    ///
    /// A plain `self.key.fill(0)` is a dead store from the optimizer's point of
    /// view, because the memory is never read again. The compiler may therefore
    /// delete it and leave the key in freed memory. Volatile writes cannot be
    /// elided. The trailing compiler fence keeps them from being reordered past
    /// the end of `drop`.
    fn drop(&mut self) {
        for byte in self.key.iter_mut() {
            // SAFETY: `byte` is a valid, aligned, exclusive `&mut u8` pointing into
            // `self.key`, so a volatile write through it is sound.
            unsafe { std::ptr::write_volatile(byte, 0) };
        }
        std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_key() {
        let key1 = SymmetricKey::random();
        let key2 = SymmetricKey::random();
        assert_ne!(key1.as_bytes(), key2.as_bytes());
        assert_eq!(key1.as_bytes().len(), 32);
    }

    #[test]
    fn test_from_bytes_32() {
        let bytes = [0xABu8; 32];
        let key = SymmetricKey::from_bytes(&bytes).unwrap();
        assert_eq!(key.as_bytes(), &bytes);
    }

    #[test]
    fn test_from_bytes_31_padding() {
        let key_31 = vec![0xFFu8; 31];
        let sym_key = SymmetricKey::from_bytes(&key_31).unwrap();
        assert_eq!(sym_key.as_bytes()[0], 0x00);
        assert_eq!(sym_key.as_bytes()[1], 0xFF);
        assert_eq!(sym_key.as_bytes()[31], 0xFF);
    }

    #[test]
    fn test_from_bytes_short_key() {
        let key_16 = vec![0xABu8; 16];
        let sym_key = SymmetricKey::from_bytes(&key_16).unwrap();
        let (padding, body) = sym_key.as_bytes().split_at(16);
        assert!(padding.iter().all(|&b| b == 0x00));
        assert!(body.iter().all(|&b| b == 0xAB));
    }

    /// The shortest accepted key (1 byte) is right-aligned behind 31 zero bytes.
    #[test]
    fn test_from_bytes_single_byte() {
        let sym_key = SymmetricKey::from_bytes(&[0x42]).unwrap();
        let (padding, body) = sym_key.as_bytes().split_at(31);
        assert!(padding.iter().all(|&b| b == 0x00));
        assert_eq!(body, &[0x42]);
    }

    #[test]
    fn test_from_bytes_empty() {
        let result = SymmetricKey::from_bytes(&[]);
        assert!(matches!(
            result,
            Err(Error::InvalidKeyLength {
                expected: 32,
                actual: 0
            })
        ));
    }

    #[test]
    fn test_from_bytes_too_long() {
        let bytes = [0u8; 33];
        let result = SymmetricKey::from_bytes(&bytes);
        assert!(matches!(
            result,
            Err(Error::InvalidKeyLength {
                expected: 32,
                actual: 33
            })
        ));
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = SymmetricKey::random();
        let plaintext = b"Hello, BSV!";
        let ciphertext = key.encrypt(plaintext).unwrap();
        let decrypted = key.decrypt(&ciphertext).unwrap();
        assert_eq!(plaintext, &decrypted[..]);
    }

    #[test]
    fn test_encrypt_decrypt_empty_plaintext() {
        let key = SymmetricKey::random();
        let plaintext = b"";
        let ciphertext = key.encrypt(plaintext).unwrap();
        assert_eq!(ciphertext.len(), 48);
        let decrypted = key.decrypt(&ciphertext).unwrap();
        assert_eq!(plaintext, &decrypted[..]);
    }

    #[test]
    fn test_encrypt_decrypt_large_data() {
        let key = SymmetricKey::random();
        let plaintext = vec![0xABu8; 10000];
        let ciphertext = key.encrypt(&plaintext).unwrap();
        let decrypted = key.decrypt(&ciphertext).unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_ciphertext_format() {
        let key = SymmetricKey::random();
        let plaintext = b"test";
        let ciphertext = key.encrypt(plaintext).unwrap();
        assert_eq!(ciphertext.len(), 32 + 4 + 16);
    }

    #[test]
    fn test_decrypt_too_short() {
        let key = SymmetricKey::random();
        let short_data = vec![0u8; 47];
        let result = key.decrypt(&short_data);
        assert!(matches!(
            result,
            Err(Error::InvalidDataLength {
                expected: 48,
                actual: 47
            })
        ));
    }

    #[test]
    fn test_decrypt_wrong_key() {
        let key1 = SymmetricKey::random();
        let key2 = SymmetricKey::random();
        let plaintext = b"Hello, BSV!";
        let ciphertext = key1.encrypt(plaintext).unwrap();
        let result = key2.decrypt(&ciphertext);
        assert!(matches!(result, Err(Error::DecryptionFailed)));
    }

    #[test]
    fn test_decrypt_tampered_ciphertext() {
        let key = SymmetricKey::random();
        let plaintext = b"Hello, BSV!";
        let mut ciphertext = key.encrypt(plaintext).unwrap();
        // For an 11-byte message the encrypted body occupies bytes 32..43, so
        // byte 40 is guaranteed to be inside it.
        assert_eq!(ciphertext.len(), 32 + plaintext.len() + 16);
        ciphertext[40] ^= 0xFF;
        let result = key.decrypt(&ciphertext);
        assert!(matches!(result, Err(Error::DecryptionFailed)));
    }

    #[test]
    fn test_decrypt_tampered_tag() {
        let key = SymmetricKey::random();
        let plaintext = b"Hello, BSV!";
        let mut ciphertext = key.encrypt(plaintext).unwrap();
        let last_idx = ciphertext.len() - 1;
        ciphertext[last_idx] ^= 0xFF;
        let result = key.decrypt(&ciphertext);
        assert!(matches!(result, Err(Error::DecryptionFailed)));
    }

    #[test]
    fn test_different_encryptions_produce_different_ciphertexts() {
        let key = SymmetricKey::random();
        let plaintext = b"Hello, BSV!";
        let ciphertext1 = key.encrypt(plaintext).unwrap();
        let ciphertext2 = key.encrypt(plaintext).unwrap();
        assert_ne!(ciphertext1, ciphertext2);
        let decrypted1 = key.decrypt(&ciphertext1).unwrap();
        let decrypted2 = key.decrypt(&ciphertext2).unwrap();
        assert_eq!(decrypted1, decrypted2);
    }

    #[test]
    fn test_key_equality_constant_time() {
        let bytes1 = [0xABu8; 32];
        let bytes2 = [0xABu8; 32];
        let bytes3 = [0xCDu8; 32];
        let key1 = SymmetricKey::from_bytes(&bytes1).unwrap();
        let key2 = SymmetricKey::from_bytes(&bytes2).unwrap();
        let key3 = SymmetricKey::from_bytes(&bytes3).unwrap();
        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_debug_redacts_key() {
        let key = SymmetricKey::random();
        let debug_output = format!("{:?}", key);
        assert!(debug_output.contains("REDACTED"));
    }

    /// Interoperability checks: ciphertexts produced by the TypeScript and Go
    /// SDKs must decrypt with this implementation, including the 31-byte key
    /// left-padding behaviour.
    mod cross_sdk_tests {
        use super::SymmetricKey;

        fn hex_decode(s: &str) -> Vec<u8> {
            hex::decode(s).expect("Invalid hex string")
        }

        fn get_31_byte_key() -> SymmetricKey {
            let key_bytes =
                hex_decode("6f54f86a07f22ac6934a61e5a2bf0da03ce1cd6e6f978bfd064a37d0e1a111");
            assert_eq!(key_bytes.len(), 31, "Expected 31-byte key");
            SymmetricKey::from_bytes(&key_bytes).unwrap()
        }

        fn get_32_byte_key() -> SymmetricKey {
            let key_bytes =
                hex_decode("cb3b4168ccd86a783945e4cf69243d1b546078610cb9cee3e9beeed2428aa54e");
            assert_eq!(key_bytes.len(), 32, "Expected 32-byte key");
            SymmetricKey::from_bytes(&key_bytes).unwrap()
        }

        const EXPECTED_PLAINTEXT: &[u8] = b"cross-sdk test message";

        #[test]
        fn test_decrypt_typescript_ciphertext_31_byte_key() {
            let key = get_31_byte_key();
            assert_eq!(
                key.as_bytes()[0],
                0x00,
                "31-byte key should be padded with leading zero"
            );
            assert_eq!(
                key.as_bytes()[1],
                0x6f,
                "First byte of original key should be at index 1"
            );
            let ts_ciphertexts = [
                "c374d70a4623036f1dd7b971dbeeea375630dc1da40e7068f4c4aa03487d3b19de3afb26a29173deccfbb1ece4fee6c92406b25948e6fe9cb53383057cb826d0a20269e290bd",
                "1025d330504549601a611b75af4450722353f431ca2fc3f6aed41ca7b53e7859fa9cfea4654c871668449308c420282b372c1008dcd7a21fb5b1410c4f3a913c74c86a1aa070",
                "efb87383667dda0bca519acb50a264cb958447f6d0f5cb20adace5fae8e812d4c39b569ad8a64ba70ca5a941d8096ded43a45cde8eec16b6a396112c248effce132797a73698",
            ];
            for (i, hex_ciphertext) in ts_ciphertexts.iter().enumerate() {
                let ciphertext = hex_decode(hex_ciphertext);
                let decrypted = key.decrypt(&ciphertext).unwrap_or_else(|_| {
                    panic!("Failed to decrypt TS ciphertext {} with 31-byte key", i)
                });
                assert_eq!(
                    decrypted, EXPECTED_PLAINTEXT,
                    "TS ciphertext {} decryption mismatch",
                    i
                );
            }
        }

        #[test]
        fn test_decrypt_go_ciphertext_31_byte_key() {
            let key = get_31_byte_key();
            let go_ciphertext = "7604d5bdb0eb843051d21873c871c9b1507c3de7ba222e1b407c163c2c166277df95de73be9534a2caf9d4b72157f78e5e2e69d97bc25b18ff4cfbd61a1306c02c0b8b2d165e";
            let ciphertext = hex_decode(go_ciphertext);
            let decrypted = key
                .decrypt(&ciphertext)
                .expect("Failed to decrypt Go ciphertext with 31-byte key");
            assert_eq!(decrypted, EXPECTED_PLAINTEXT);
        }

        #[test]
        fn test_decrypt_typescript_ciphertext_32_byte_key() {
            let key = get_32_byte_key();
            let ts_ciphertexts = [
                "2059fc32910bef280d89c4c7edbbc587b31be22339e609fdcc23319bf458840a91ad1b2da87aea13a5dc5cb3469b41c52001070b8003863843978acbdf57755b24491581a059",
                "b6b751277049399fdf5d35fda899c8433509268b0528c25ac8cf60c23dbeef23441c9efcdb996312c6aa32352637789bcf19d02b990903003a9a894efe874a65e84b6e57d30b",
            ];
            for (i, hex_ciphertext) in ts_ciphertexts.iter().enumerate() {
                let ciphertext = hex_decode(hex_ciphertext);
                let decrypted = key.decrypt(&ciphertext).unwrap_or_else(|_| {
                    panic!("Failed to decrypt TS ciphertext {} with 32-byte key", i)
                });
                assert_eq!(
                    decrypted, EXPECTED_PLAINTEXT,
                    "TS ciphertext {} decryption mismatch",
                    i
                );
            }
        }

        #[test]
        fn test_decrypt_go_ciphertext_32_byte_key() {
            let key = get_32_byte_key();
            let go_ciphertext = "d7744c85ad3dafcb9fc5752ab0d04c40f87084e8a466f6b6013ebe0fc5170daab8184aaef66ab2c2733f01c0dc3de322ba3ddeea976499548bc6ec166581181f919c69aa2de5";
            let ciphertext = hex_decode(go_ciphertext);
            let decrypted = key
                .decrypt(&ciphertext)
                .expect("Failed to decrypt Go ciphertext with 32-byte key");
            assert_eq!(decrypted, EXPECTED_PLAINTEXT);
        }

        #[test]
        fn test_rust_encrypt_can_be_decrypted() {
            let key = get_32_byte_key();
            let ciphertext = key.encrypt(EXPECTED_PLAINTEXT).unwrap();
            let decrypted = key.decrypt(&ciphertext).unwrap();
            assert_eq!(decrypted, EXPECTED_PLAINTEXT);
        }

        #[test]
        fn test_31_byte_key_padding_matches_go() {
            let key_31_bytes =
                hex_decode("6f54f86a07f22ac6934a61e5a2bf0da03ce1cd6e6f978bfd064a37d0e1a111");
            let sym_key = SymmetricKey::from_bytes(&key_31_bytes).unwrap();
            let expected_padded =
                hex_decode("006f54f86a07f22ac6934a61e5a2bf0da03ce1cd6e6f978bfd064a37d0e1a111");
            assert_eq!(sym_key.as_bytes(), expected_padded.as_slice());
        }
    }

    /// Checks against the shared JSON test vectors (base64-encoded key and
    /// ciphertext, UTF-8 plaintext).
    mod vector_tests {
        use super::*;
        use base64::{engine::general_purpose::STANDARD as BASE64, Engine};

        #[derive(Debug, serde::Deserialize)]
        struct TestVector {
            ciphertext: String,
            key: String,
            plaintext: String,
        }

        fn load_test_vectors() -> Vec<TestVector> {
            let json_str = include_str!("../../tests/vectors/symmetric_key.json");
            serde_json::from_str(json_str).expect("Failed to parse test vectors")
        }

        #[test]
        fn test_decrypt_vectors() {
            let vectors = load_test_vectors();
            for (i, vector) in vectors.iter().enumerate() {
                let key_bytes = BASE64
                    .decode(&vector.key)
                    .unwrap_or_else(|_| panic!("Failed to decode key for vector {}", i));
                let ciphertext = BASE64
                    .decode(&vector.ciphertext)
                    .unwrap_or_else(|_| panic!("Failed to decode ciphertext for vector {}", i));
                let expected_plaintext = vector.plaintext.as_bytes();
                let sym_key = SymmetricKey::from_bytes(&key_bytes)
                    .unwrap_or_else(|_| panic!("Failed to create key for vector {}", i));
                let decrypted = sym_key
                    .decrypt(&ciphertext)
                    .unwrap_or_else(|_| panic!("Failed to decrypt vector {}", i));
                assert_eq!(
                    decrypted, expected_plaintext,
                    "Vector {} decryption mismatch",
                    i
                );
            }
        }

        #[test]
        fn test_encrypt_decrypt_vectors() {
            let vectors = load_test_vectors();
            for (i, vector) in vectors.iter().enumerate() {
                let key_bytes = BASE64
                    .decode(&vector.key)
                    .unwrap_or_else(|_| panic!("Failed to decode key for vector {}", i));
                let plaintext = vector.plaintext.as_bytes();
                let sym_key = SymmetricKey::from_bytes(&key_bytes)
                    .unwrap_or_else(|_| panic!("Failed to create key for vector {}", i));
                let ciphertext = sym_key
                    .encrypt(plaintext)
                    .unwrap_or_else(|_| panic!("Failed to encrypt vector {}", i));
                let decrypted = sym_key.decrypt(&ciphertext).unwrap_or_else(|_| {
                    panic!("Failed to decrypt our ciphertext for vector {}", i)
                });
                assert_eq!(decrypted, plaintext, "Vector {} round-trip mismatch", i);
            }
        }
    }

    /// Authentication and framing checks.
    ///
    /// Every modification of the IV, body or tag must be rejected as
    /// `DecryptionFailed`. Every input shorter than IV + tag must be rejected
    /// as `InvalidDataLength`.
    mod aesgcm_security_tests {
        use super::*;

        #[test]
        fn test_empty_plaintext_encryption_decryption() {
            let key = SymmetricKey::random();
            let ciphertext = key.encrypt(b"").expect("Empty plaintext encryption failed");
            assert_eq!(
                ciphertext.len(),
                48,
                "Empty plaintext ciphertext wrong size"
            );
            let decrypted = key
                .decrypt(&ciphertext)
                .expect("Empty plaintext decryption failed");
            assert!(
                decrypted.is_empty(),
                "Decrypted empty plaintext should be empty"
            );
        }

        #[test]
        fn test_block_boundary_sizes() {
            let key = SymmetricKey::random();
            let test_sizes = [15, 16, 17, 31, 32, 33, 63, 64, 65];
            for size in test_sizes {
                let plaintext = vec![0xABu8; size];
                let ciphertext = key
                    .encrypt(&plaintext)
                    .unwrap_or_else(|_| panic!("Failed to encrypt {} bytes", size));
                assert_eq!(
                    ciphertext.len(),
                    32 + size + 16,
                    "Wrong ciphertext size for {} byte plaintext",
                    size
                );
                let decrypted = key
                    .decrypt(&ciphertext)
                    .unwrap_or_else(|_| panic!("Failed to decrypt {} bytes", size));
                assert_eq!(
                    decrypted, plaintext,
                    "Decryption mismatch for {} bytes",
                    size
                );
            }
        }

        #[test]
        fn test_tag_tampering_detection() {
            let key = SymmetricKey::random();
            let plaintext = b"test message for tag tampering";
            let mut ciphertext = key.encrypt(plaintext).expect("Encryption failed");
            let tag_start = ciphertext.len() - 16;
            for i in 0..16 {
                let original_byte = ciphertext[tag_start + i];
                ciphertext[tag_start + i] ^= 0x01;
                let result = key.decrypt(&ciphertext);
                assert!(
                    matches!(result, Err(Error::DecryptionFailed)),
                    "Tampered tag byte {} should fail decryption",
                    i
                );
                ciphertext[tag_start + i] = original_byte;
            }
        }

        #[test]
        fn test_nonce_tampering_detection() {
            let key = SymmetricKey::random();
            let plaintext = b"test message for nonce tampering";
            let mut ciphertext = key.encrypt(plaintext).expect("Encryption failed");
            for i in 0..32 {
                let original_byte = ciphertext[i];
                ciphertext[i] ^= 0x01;
                let result = key.decrypt(&ciphertext);
                assert!(
                    matches!(result, Err(Error::DecryptionFailed)),
                    "Tampered nonce byte {} should fail decryption",
                    i
                );
                ciphertext[i] = original_byte;
            }
        }

        #[test]
        fn test_ciphertext_tampering_detection() {
            let key = SymmetricKey::random();
            let plaintext = b"test message for ciphertext tampering detection";
            let mut ciphertext = key.encrypt(plaintext).expect("Encryption failed");
            let ct_start = 32;
            let ct_end = ciphertext.len() - 16;
            // The plaintext is non-empty, so the body range must be non-empty too;
            // assert it rather than silently skipping the loop.
            assert!(ct_end > ct_start, "Encrypted body unexpectedly empty");
            for i in ct_start..ct_end {
                let original_byte = ciphertext[i];
                ciphertext[i] ^= 0xFF;
                let result = key.decrypt(&ciphertext);
                assert!(
                    matches!(result, Err(Error::DecryptionFailed)),
                    "Tampered ciphertext byte {} should fail decryption",
                    i
                );
                ciphertext[i] = original_byte;
            }
        }

        #[test]
        fn test_minimum_length_validation() {
            let key = SymmetricKey::random();
            for len in 0..48 {
                let data = vec![0u8; len];
                let result = key.decrypt(&data);
                assert!(
                    matches!(result, Err(Error::InvalidDataLength { .. })),
                    "Ciphertext of {} bytes should fail (min is 48)",
                    len
                );
            }
            let valid_empty = key.encrypt(b"").expect("Encryption failed");
            assert_eq!(valid_empty.len(), 48);
            let result = key.decrypt(&valid_empty);
            assert!(
                result.is_ok(),
                "48-byte ciphertext (empty plaintext) should work"
            );
        }

        #[test]
        fn test_nonce_uniqueness() {
            let key = SymmetricKey::random();
            let plaintext = b"same message encrypted multiple times";
            let mut nonces = Vec::new();
            for _ in 0..10 {
                let ciphertext = key.encrypt(plaintext).expect("Encryption failed");
                let nonce: Vec<u8> = ciphertext[..32].to_vec();
                nonces.push(nonce);
            }
            for i in 0..nonces.len() {
                for j in (i + 1)..nonces.len() {
                    assert_ne!(
                        nonces[i], nonces[j],
                        "Nonces {} and {} should be different",
                        i, j
                    );
                }
            }
        }

        #[test]
        fn test_wrong_key_decryption() {
            let key1 = SymmetricKey::random();
            let key2 = SymmetricKey::random();
            let plaintext = b"secret message";
            let ciphertext = key1.encrypt(plaintext).expect("Encryption failed");
            let result = key2.decrypt(&ciphertext);
            assert!(
                matches!(result, Err(Error::DecryptionFailed)),
                "Decryption with wrong key should fail"
            );
        }

        #[test]
        fn test_truncated_ciphertext() {
            let key = SymmetricKey::random();
            let plaintext = b"test message for truncation";
            let ciphertext = key.encrypt(plaintext).expect("Encryption failed");
            for truncate_at in [47, 40, 32, 16, 8, 1] {
                if truncate_at < ciphertext.len() {
                    let truncated = &ciphertext[..truncate_at];
                    let result = key.decrypt(truncated);
                    assert!(
                        matches!(result, Err(Error::InvalidDataLength { .. })),
                        "Truncated ciphertext at {} bytes should fail",
                        truncate_at
                    );
                }
            }
        }

        #[test]
        fn test_extended_ciphertext() {
            let key = SymmetricKey::random();
            let plaintext = b"test message";
            let mut ciphertext = key.encrypt(plaintext).expect("Encryption failed");
            ciphertext.push(0x00);
            let result = key.decrypt(&ciphertext);
            assert!(
                matches!(result, Err(Error::DecryptionFailed)),
                "Extended ciphertext should fail decryption"
            );
        }
    }
}