use crate::config::CipherSuite;
use crate::encryption::aes_gcm::{AesGcmCipher, AesKey};
use crate::error::{FluxError, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
/// Symmetric AES-256-GCM helper that encrypts data into base64 strings and
/// decrypts them back.
///
/// NOTE(review): `Debug` is derived, so formatting this value also formats
/// `key`; confirm that `AesKey`'s `Debug` impl redacts the key material.
#[derive(Debug)]
pub struct SymmetricCipher {
// AEAD primitive; every constructor in this file builds it for `CipherSuite::Aes256Gcm`.
cipher: AesGcmCipher,
// Key material passed to every encrypt/decrypt call.
key: AesKey,
}
impl SymmetricCipher {
pub fn new(hex_key: &str) -> Result<Self> {
let key_bytes = hex::decode(hex_key)
.map_err(|_| FluxError::key("Invalid hex encoding for encryption key"))?;
if key_bytes.len() != 32 {
return Err(FluxError::key(format!(
"Encryption key must be 32 bytes (64 hex characters), got {} bytes",
key_bytes.len()
)));
}
let key = AesKey::new(key_bytes);
let cipher = AesGcmCipher::new(CipherSuite::Aes256Gcm);
Ok(Self { cipher, key })
}
pub fn from_bytes(key_bytes: &[u8]) -> Result<Self> {
if key_bytes.len() != 32 {
return Err(FluxError::key(format!(
"Encryption key must be 32 bytes, got {} bytes",
key_bytes.len()
)));
}
let key = AesKey::new(key_bytes.to_vec());
let cipher = AesGcmCipher::new(CipherSuite::Aes256Gcm);
Ok(Self { cipher, key })
}
pub fn encrypt(&self, plaintext: &str) -> Result<String> {
self.encrypt_bytes(plaintext.as_bytes())
}
pub fn encrypt_bytes(&self, plaintext: &[u8]) -> Result<String> {
let (nonce, ciphertext) = self.cipher.encrypt(&self.key, plaintext, None)?;
let mut combined = nonce;
combined.extend(ciphertext);
Ok(BASE64.encode(combined))
}
pub fn decrypt(&self, encrypted: &str) -> Result<String> {
let decrypted_bytes = self.decrypt_bytes(encrypted)?;
String::from_utf8(decrypted_bytes)
.map_err(|_| FluxError::crypto("Decrypted data is not valid UTF-8"))
}
pub fn decrypt_bytes(&self, encrypted: &str) -> Result<Vec<u8>> {
let combined = BASE64
.decode(encrypted)
.map_err(|_| FluxError::invalid_input("Invalid base64 encoding in ciphertext"))?;
if combined.len() < 28 {
return Err(FluxError::invalid_input(
"Ciphertext too short (must be at least 28 bytes after base64 decoding)",
));
}
let (nonce, ciphertext) = combined.split_at(12);
self.cipher.decrypt(&self.key, nonce, ciphertext, None)
}
pub fn generate_key() -> Result<String> {
let key = AesKey::generate(CipherSuite::Aes256Gcm)?;
Ok(hex::encode(key.as_bytes()))
}
}
impl Clone for SymmetricCipher {
    /// Returns an independent cipher that holds its own copy of the key bytes.
    ///
    /// The AEAD primitive is rebuilt for AES-256-GCM rather than cloned.
    fn clone(&self) -> Self {
        let cipher = AesGcmCipher::new(CipherSuite::Aes256Gcm);
        let copied_key_bytes = self.key.as_bytes().to_vec();
        let key = AesKey::new(copied_key_bytes);
        Self { cipher, key }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces a fresh random hex key for each test.
    fn generate_test_key() -> String {
        SymmetricCipher::generate_key().unwrap()
    }

    #[test]
    fn test_encrypt_decrypt() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        let plaintext = "access-sandbox-abc123-secret-token";
        let encrypted = cipher.encrypt(plaintext).unwrap();
        assert_ne!(encrypted, plaintext);
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_bytes() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        let plaintext = b"binary\x00data\xff\xfe";
        let encrypted = cipher.encrypt_bytes(plaintext).unwrap();
        let decrypted = cipher.decrypt_bytes(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_decrypt_rejects_non_utf8_plaintext() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        let encrypted = cipher.encrypt_bytes(b"\xff\xfe\xfd").unwrap();
        // The byte-level API accepts it, while the string API must refuse it.
        assert_eq!(cipher.decrypt_bytes(&encrypted).unwrap(), b"\xff\xfe\xfd");
        assert!(cipher.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_invalid_key_not_hex() {
        let result =
            SymmetricCipher::new("not-hex-gggggggggggggggggggggggggggggggggggggggggggggggggggggg");
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_key_too_short() {
        let result = SymmetricCipher::new("tooshort");
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_key_too_long() {
        let result = SymmetricCipher::new(&"a".repeat(66));
        assert!(result.is_err());
    }

    #[test]
    fn test_valid_key_exact_length() {
        let result = SymmetricCipher::new(&"a".repeat(64));
        assert!(result.is_ok());
    }

    #[test]
    fn test_empty_string() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        let encrypted = cipher.encrypt("").unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, "");
    }

    #[test]
    fn test_empty_plaintext_payload_is_nonce_plus_tag() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        let encrypted = cipher.encrypt("").unwrap();
        let decoded = BASE64.decode(&encrypted).unwrap();
        // 12-byte nonce + 16-byte GCM tag: the minimum accepted by decrypt_bytes.
        assert_eq!(decoded.len(), 28);
    }

    #[test]
    fn test_unicode_content() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        let unicode_text = "Hello 世界! 🔐 Привет мир! café résumé";
        let encrypted = cipher.encrypt(unicode_text).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, unicode_text);
    }

    #[test]
    fn test_very_long_data() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        let long_text: String = (0..10000)
            .map(|i| ((i % 26) as u8 + b'a') as char)
            .collect();
        let encrypted = cipher.encrypt(&long_text).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, long_text);
    }

    #[test]
    fn test_decrypt_invalid_base64() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        assert!(cipher.decrypt("not-valid-base64!!!").is_err());
    }

    #[test]
    fn test_decrypt_too_short() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        assert!(cipher.decrypt("YWJj").is_err());
    }

    #[test]
    fn test_decrypt_length_boundary() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        // One byte below the minimum is rejected by the length check.
        assert!(cipher.decrypt_bytes(&BASE64.encode([0u8; 27])).is_err());
        // Exactly the minimum passes the length check but must fail authentication.
        assert!(cipher.decrypt_bytes(&BASE64.encode([0u8; 28])).is_err());
    }

    #[test]
    fn test_decrypt_with_wrong_key() {
        let key1 = generate_test_key();
        let key2 = generate_test_key();
        let cipher1 = SymmetricCipher::new(&key1).unwrap();
        let cipher2 = SymmetricCipher::new(&key2).unwrap();
        let encrypted = cipher1.encrypt("secret data").unwrap();
        assert!(cipher2.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_different_encryptions_produce_different_ciphertext() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        let plaintext = "same input";
        let encrypted1 = cipher.encrypt(plaintext).unwrap();
        let encrypted2 = cipher.encrypt(plaintext).unwrap();
        assert_ne!(encrypted1, encrypted2);
        assert_eq!(cipher.decrypt(&encrypted1).unwrap(), plaintext);
        assert_eq!(cipher.decrypt(&encrypted2).unwrap(), plaintext);
    }

    #[test]
    fn test_special_characters() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        let special_chars = "!@#$%^&*()_+-=[]{}|;':\",./<>?\n\t\r\\";
        let encrypted = cipher.encrypt(special_chars).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, special_chars);
    }

    #[test]
    fn test_json_content() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        let json = r#"{"access_token":"secret","refresh_token":"also_secret","expires_in":3600}"#;
        let encrypted = cipher.encrypt(json).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, json);
    }

    #[test]
    fn test_generate_key() {
        let key = SymmetricCipher::generate_key().unwrap();
        assert_eq!(key.len(), 64);
        assert!(hex::decode(&key).is_ok());
        let cipher = SymmetricCipher::new(&key).unwrap();
        let encrypted = cipher.encrypt("test").unwrap();
        assert_eq!(cipher.decrypt(&encrypted).unwrap(), "test");
    }

    #[test]
    fn test_from_bytes() {
        let key_bytes = [0x42u8; 32];
        let cipher = SymmetricCipher::from_bytes(&key_bytes).unwrap();
        let encrypted = cipher.encrypt("test").unwrap();
        assert_eq!(cipher.decrypt(&encrypted).unwrap(), "test");
    }

    #[test]
    fn test_from_bytes_wrong_length() {
        let short = [0u8; 16];
        assert!(SymmetricCipher::from_bytes(&short).is_err());
        let long = [0u8; 64];
        assert!(SymmetricCipher::from_bytes(&long).is_err());
    }

    #[test]
    fn test_clone() {
        let key = generate_test_key();
        let cipher1 = SymmetricCipher::new(&key).unwrap();
        let cipher2 = cipher1.clone();
        let encrypted = cipher1.encrypt("test").unwrap();
        assert_eq!(cipher2.decrypt(&encrypted).unwrap(), "test");
    }

    #[test]
    fn test_tampered_ciphertext() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        let encrypted = cipher.encrypt("secret").unwrap();
        let mut bytes = BASE64.decode(&encrypted).unwrap();
        // Every valid payload is at least 28 bytes, so index 12 (the first byte
        // after the nonce) always exists.
        bytes[12] ^= 1;
        let tampered = BASE64.encode(&bytes);
        assert!(cipher.decrypt(&tampered).is_err());
    }

    #[test]
    fn test_tampered_nonce_and_tag() {
        let key = generate_test_key();
        let cipher = SymmetricCipher::new(&key).unwrap();
        let encrypted = cipher.encrypt("secret").unwrap();
        let original = BASE64.decode(&encrypted).unwrap();
        // First nonce byte, last nonce byte, and the final byte of the payload.
        for idx in [0, 11, original.len() - 1] {
            let mut bytes = original.clone();
            bytes[idx] ^= 0x01;
            assert!(
                cipher.decrypt(&BASE64.encode(&bytes)).is_err(),
                "bit flip at byte {} was accepted",
                idx
            );
        }
    }
}