use aes_gcm::{AeadInPlace, KeyInit};
use aes_gcm::aead::generic_array::typenum::U32;
use aes_gcm::aead::generic_array::GenericArray;
use aes_gcm::AesGcm;
use aes::Aes256;
use rand::RngCore;
use crate::PrimitivesError;
/// Length in bytes of the random IV (nonce) prepended to every ciphertext.
///
/// AES-GCM's recommended nonce size is 12 bytes. This format uses 32, which GCM
/// accepts by deriving its initial counter block from the nonce via GHASH. The value
/// is part of the wire format (`IV || ciphertext || tag`). Changing it breaks
/// decryption of previously produced messages.
const IV_LEN: usize = 32;
/// Length in bytes of the AES-GCM authentication tag appended to every ciphertext.
const TAG_LEN: usize = 16;

/// A 256-bit AES-GCM key.
///
/// The key bytes are held in a [`zeroize::Zeroizing`] wrapper so they are wiped
/// when the key is dropped.
pub struct SymmetricKey {
    key: zeroize::Zeroizing<[u8; 32]>,
}

/// Redacted `Debug`: identifies the type without ever printing key material, so a
/// `SymmetricKey` can sit inside `#[derive(Debug)]` structs or appear in logs safely.
impl std::fmt::Debug for SymmetricKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SymmetricKey").finish_non_exhaustive()
    }
}
impl SymmetricKey {
pub fn new(key: &[u8]) -> Self {
let mut padded = [0u8; 32];
if key.len() < 32 {
padded[32 - key.len()..].copy_from_slice(key);
} else {
padded.copy_from_slice(&key[..32]);
}
SymmetricKey { key: zeroize::Zeroizing::new(padded) }
}
pub fn new_random() -> Self {
let mut key = [0u8; 32];
rand::rngs::OsRng.fill_bytes(&mut key);
SymmetricKey { key: zeroize::Zeroizing::new(key) }
}
pub fn from_base64(b64: &str) -> Result<Self, PrimitivesError> {
use base64::Engine;
let bytes = base64::engine::general_purpose::STANDARD
.decode(b64)
.map_err(|e| PrimitivesError::Other(e.to_string()))?;
Ok(Self::new(&bytes))
}
pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, PrimitivesError> {
let mut iv = [0u8; IV_LEN];
rand::rngs::OsRng.fill_bytes(&mut iv);
self.encrypt_with_iv(plaintext, &iv)
}
fn encrypt_with_iv(
&self,
plaintext: &[u8],
iv: &[u8; IV_LEN],
) -> Result<Vec<u8>, PrimitivesError> {
let cipher = AesGcm::<Aes256, U32>::new(GenericArray::from_slice(self.key.as_ref()));
let nonce = GenericArray::from_slice(iv);
let mut buffer = plaintext.to_vec();
let tag = cipher
.encrypt_in_place_detached(nonce, &[], &mut buffer)
.map_err(|_| PrimitivesError::Aead)?;
let mut result = Vec::with_capacity(IV_LEN + buffer.len() + TAG_LEN);
result.extend_from_slice(iv);
result.extend_from_slice(&buffer);
result.extend_from_slice(&tag);
Ok(result)
}
pub fn decrypt(&self, message: &[u8]) -> Result<Vec<u8>, PrimitivesError> {
if message.len() < IV_LEN + TAG_LEN {
return Err(PrimitivesError::DecryptionError(
"message is too short to be a valid encrypted message".to_string(),
));
}
let iv = &message[..IV_LEN];
let ciphertext = &message[IV_LEN..message.len() - TAG_LEN];
let tag = &message[message.len() - TAG_LEN..];
let cipher = AesGcm::<Aes256, U32>::new(GenericArray::from_slice(self.key.as_ref()));
let nonce = GenericArray::from_slice(iv);
let tag = GenericArray::from_slice(tag);
let mut buffer = ciphertext.to_vec();
cipher
.decrypt_in_place_detached(nonce, &[], &mut buffer, tag)
.map_err(|_| PrimitivesError::Aead)?;
Ok(buffer)
}
pub fn encrypt_string(&self, message: &str) -> Result<Vec<u8>, PrimitivesError> {
self.encrypt(message.as_bytes())
}
pub fn decrypt_string(&self, message: &[u8]) -> Result<String, PrimitivesError> {
let plaintext = self.decrypt(message)?;
String::from_utf8(plaintext)
.map_err(|e| PrimitivesError::DecryptionError(e.to_string()))
}
pub fn to_bytes(&self) -> &[u8; 32] {
&self.key
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_symmetric_key_encryption_and_decryption() {
        let key = SymmetricKey::new_random();
        let plaintext = b"a thing to encrypt";
        let ciphertext = key.encrypt(plaintext).unwrap();
        let decrypted = key.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_symmetric_key_decryption_vectors() {
        let vectors_json = include_str!("testdata/SymmetricKey.vectors.json");
        let vectors: Vec<serde_json::Value> = serde_json::from_str(vectors_json).unwrap();
        for (i, v) in vectors.iter().enumerate() {
            let key_b64 = v["key"].as_str().unwrap();
            let ciphertext_b64 = v["ciphertext"].as_str().unwrap();
            let expected_plaintext = v["plaintext"].as_str().unwrap();
            use base64::Engine;
            let ciphertext = base64::engine::general_purpose::STANDARD
                .decode(ciphertext_b64)
                .unwrap();
            let sym_key = SymmetricKey::from_base64(key_b64).unwrap();
            let decrypted = sym_key.decrypt(&ciphertext).unwrap_or_else(|e| {
                panic!("vector #{}: decryption failed: {}", i + 1, e);
            });
            assert_eq!(
                String::from_utf8_lossy(&decrypted),
                expected_plaintext,
                "vector #{}: plaintext mismatch",
                i + 1
            );
        }
    }

    #[test]
    fn test_symmetric_key_with_short_key() {
        let short_key = vec![0xABu8; 31];
        let sym_key = SymmetricKey::new(&short_key);
        // Short keys must be left-padded with zeros (big-endian semantics).
        let mut expected = [0xABu8; 32];
        expected[0] = 0;
        assert_eq!(sym_key.to_bytes(), &expected);
        let plaintext = b"test message";
        let ciphertext = sym_key.encrypt(plaintext).unwrap();
        let decrypted = sym_key.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_symmetric_key_decrypt_too_short() {
        let key = SymmetricKey::new_random();
        let short_msg = vec![0u8; 10];
        assert!(key.decrypt(&short_msg).is_err());
        // One byte below the minimum (IV + tag) is rejected by the length check,
        // before any AEAD work is attempted.
        let just_short = vec![0u8; IV_LEN + TAG_LEN - 1];
        assert!(matches!(
            key.decrypt(&just_short),
            Err(PrimitivesError::DecryptionError(_))
        ));
    }

    #[test]
    fn test_symmetric_key_empty_plaintext_round_trip() {
        let key = SymmetricKey::new_random();
        let ciphertext = key.encrypt(b"").unwrap();
        assert_eq!(ciphertext.len(), IV_LEN + TAG_LEN);
        assert!(key.decrypt(&ciphertext).unwrap().is_empty());
    }

    #[test]
    fn test_symmetric_key_output_layout_and_fresh_iv() {
        let key = SymmetricKey::new_random();
        let plaintext = b"same input twice";
        let a = key.encrypt(plaintext).unwrap();
        let b = key.encrypt(plaintext).unwrap();
        assert_eq!(a.len(), IV_LEN + plaintext.len() + TAG_LEN);
        // A random IV per call means identical plaintexts encrypt differently.
        assert_ne!(a, b);
    }

    #[test]
    fn test_symmetric_key_rejects_tampered_ciphertext() {
        let key = SymmetricKey::new_random();
        let mut ciphertext = key.encrypt(b"integrity matters").unwrap();
        let last = ciphertext.len() - 1;
        ciphertext[last] ^= 0x01;
        assert!(key.decrypt(&ciphertext).is_err());
    }

    #[test]
    fn test_symmetric_key_rejects_wrong_key() {
        let key = SymmetricKey::new_random();
        let other = SymmetricKey::new_random();
        let ciphertext = key.encrypt(b"for one key only").unwrap();
        assert!(other.decrypt(&ciphertext).is_err());
    }

    #[test]
    fn test_symmetric_key_string_round_trip_and_invalid_utf8() {
        let key = SymmetricKey::new_random();
        let ciphertext = key.encrypt_string("héllo wörld").unwrap();
        assert_eq!(key.decrypt_string(&ciphertext).unwrap(), "héllo wörld");
        let not_utf8 = key.encrypt(&[0xFF, 0xFE]).unwrap();
        assert!(key.decrypt_string(&not_utf8).is_err());
    }
}