use crate::primitives::aes_gcm::{aes_gcm_decrypt_ts_compat, aes_gcm_encrypt_ts_compat};
use crate::primitives::big_number::{BigNumber, Endian};
use crate::primitives::error::PrimitivesError;
use crate::primitives::random::random_bytes;
pub struct SymmetricKey {
key: BigNumber,
}
impl SymmetricKey {
pub fn from_random() -> Self {
let bytes = random_bytes(32);
SymmetricKey {
key: BigNumber::from_bytes(&bytes, Endian::Big),
}
}
pub fn from_hex(hex: &str) -> Result<Self, PrimitivesError> {
if hex.len() != 64 {
return Err(PrimitivesError::InvalidLength(format!(
"symmetric key hex must be 64 characters, got {}",
hex.len()
)));
}
let key = BigNumber::from_hex(hex)?;
Ok(SymmetricKey { key })
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, PrimitivesError> {
if bytes.len() != 32 {
return Err(PrimitivesError::InvalidLength(format!(
"symmetric key must be 32 bytes, got {}",
bytes.len()
)));
}
Ok(SymmetricKey {
key: BigNumber::from_bytes(bytes, Endian::Big),
})
}
pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, PrimitivesError> {
let iv = random_bytes(32);
let key_bytes = self.key.to_array(Endian::Big, Some(32));
let ct_and_tag = aes_gcm_encrypt_ts_compat(&key_bytes, &iv, plaintext)?;
let mut result = Vec::with_capacity(iv.len() + ct_and_tag.len());
result.extend_from_slice(&iv);
result.extend_from_slice(&ct_and_tag);
Ok(result)
}
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, PrimitivesError> {
if data.len() < 48 {
return Err(PrimitivesError::DecryptionFailed);
}
let iv = &data[0..32];
let ciphertext_with_tag = &data[32..];
let key_bytes = self.key.to_array(Endian::Big, Some(32));
aes_gcm_decrypt_ts_compat(&key_bytes, iv, ciphertext_with_tag)
}
pub fn to_hex(&self) -> String {
let bytes = self.key.to_array(Endian::Big, Some(32));
bytes.iter().map(|b| format!("{:02x}", b)).collect()
}
pub fn to_bytes(&self) -> Vec<u8> {
self.key.to_array(Endian::Big, Some(32))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_from_random_produces_32_byte_key() {
let key = SymmetricKey::from_random();
assert_eq!(key.to_bytes().len(), 32);
}
#[test]
fn test_from_hex_round_trip() {
let hex = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
let key = SymmetricKey::from_hex(hex).unwrap();
assert_eq!(key.to_hex(), hex);
}
#[test]
fn test_from_bytes_round_trip() {
let bytes = vec![0xabu8; 32];
let key = SymmetricKey::from_bytes(&bytes).unwrap();
assert_eq!(key.to_bytes(), bytes);
}
#[test]
fn test_from_hex_invalid_length() {
let result = SymmetricKey::from_hex("0123456789abcdef");
assert!(result.is_err());
}
#[test]
fn test_encrypt_decrypt_round_trip_empty() {
let key = SymmetricKey::from_random();
let plaintext = b"";
let ciphertext = key.encrypt(plaintext).unwrap();
let decrypted = key.decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, plaintext);
}
#[test]
fn test_encrypt_decrypt_round_trip_1_byte() {
let key = SymmetricKey::from_random();
let plaintext = b"\x42";
let ciphertext = key.encrypt(plaintext).unwrap();
let decrypted = key.decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, plaintext);
}
#[test]
fn test_encrypt_decrypt_round_trip_16_bytes() {
let key = SymmetricKey::from_random();
let plaintext = b"0123456789abcdef";
let ciphertext = key.encrypt(plaintext).unwrap();
let decrypted = key.decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, plaintext);
}
#[test]
fn test_encrypt_decrypt_round_trip_100_bytes() {
let key = SymmetricKey::from_random();
let plaintext = vec![0x55u8; 100];
let ciphertext = key.encrypt(&plaintext).unwrap();
let decrypted = key.decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, plaintext);
}
#[test]
fn test_encrypt_decrypt_round_trip_1000_bytes() {
let key = SymmetricKey::from_random();
let plaintext = vec![0x77u8; 1000];
let ciphertext = key.encrypt(&plaintext).unwrap();
let decrypted = key.decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, plaintext);
}
#[test]
fn test_ciphertext_format_iv32_ct_tag16() {
let key = SymmetricKey::from_random();
let plaintext = b"hello world";
let ciphertext = key.encrypt(plaintext).unwrap();
assert_eq!(ciphertext.len(), 32 + plaintext.len() + 16);
}
#[test]
fn test_decrypt_rejects_tampered_ciphertext() {
let key = SymmetricKey::from_random();
let plaintext = b"test message";
let mut ciphertext = key.encrypt(plaintext).unwrap();
if ciphertext.len() > 33 {
ciphertext[33] ^= 0xff;
}
let result = key.decrypt(&ciphertext);
assert!(result.is_err());
}
#[test]
fn test_decrypt_rejects_truncated_data() {
let key = SymmetricKey::from_random();
let short_data = vec![0u8; 47];
let result = key.decrypt(&short_data);
assert!(result.is_err());
}
#[test]
fn test_two_encryptions_produce_different_ciphertext() {
let key = SymmetricKey::from_random();
let plaintext = b"same plaintext";
let ct1 = key.encrypt(plaintext).unwrap();
let ct2 = key.encrypt(plaintext).unwrap();
assert_ne!(ct1, ct2);
assert_eq!(key.decrypt(&ct1).unwrap(), plaintext);
assert_eq!(key.decrypt(&ct2).unwrap(), plaintext);
}
#[test]
fn test_to_hex_returns_64_chars() {
let key = SymmetricKey::from_random();
let hex = key.to_hex();
assert_eq!(hex.len(), 64);
}
}