use super::{
aes_gcm::{AesGcmCipher, AesKey},
rsa_oaep::RsaOaepCipher,
};
use crate::config::Config;
use crate::error::{FluxError, Result};
use crate::keys::{PrivateKey, PublicKey};
/// Hybrid public-key cipher combining RSA-OAEP with AES-GCM.
///
/// Every message is sealed under a freshly generated AES key, and that AES key
/// is itself encrypted with the recipient's RSA public key. The wrapped
/// [`Config`] selects the AES-GCM cipher suite.
#[derive(Clone, Debug)]
pub struct HybridCipher {
    /// Settings for this cipher; its `cipher_suite` picks the AES-GCM variant.
    config: Config,
}
/// Hybrid encryption and decryption.
///
/// The ciphertext wire format (every length prefix is a big-endian `u32`) is:
///
/// ```text
/// [key_len][RSA-OAEP(aes_key): key_len bytes][nonce_len][nonce: nonce_len bytes][AES-GCM ciphertext]
/// ```
impl HybridCipher {
    /// Size in bytes of each big-endian `u32` length prefix in the wire format.
    const LEN_PREFIX_SIZE: usize = 4;

    /// Creates a hybrid cipher that uses `config` (notably its `cipher_suite`)
    /// for the symmetric layer.
    pub fn new(config: Config) -> Self {
        Self { config }
    }

    /// Encrypts `plaintext` so that only the holder of the private key matching
    /// `public_key` can decrypt it.
    ///
    /// A fresh AES key is generated for the configured cipher suite. The
    /// plaintext is sealed with AES-GCM under that key, and the AES key is then
    /// wrapped with RSA-OAEP under `public_key`. The result uses the wire format
    /// described on this `impl` block.
    ///
    /// # Errors
    ///
    /// Propagates errors from AES key generation, AES-GCM encryption and
    /// RSA-OAEP encryption. Returns an invalid-input error if a length-prefixed
    /// field is too large for its `u32` length prefix.
    pub fn encrypt(&self, public_key: &PublicKey, plaintext: &[u8]) -> Result<Vec<u8>> {
        let aes_key = AesKey::generate(self.config.cipher_suite)?;
        let aes_cipher = AesGcmCipher::new(self.config.cipher_suite);
        let (nonce, aes_ciphertext) = aes_cipher.encrypt(&aes_key, plaintext, None)?;

        let rsa_cipher = RsaOaepCipher::new();
        let encrypted_aes_key = rsa_cipher.encrypt(public_key, aes_key.as_bytes())?;

        // Size the buffer exactly once instead of growing it field by field.
        let mut result = Vec::with_capacity(
            2 * Self::LEN_PREFIX_SIZE
                + encrypted_aes_key.len()
                + nonce.len()
                + aes_ciphertext.len(),
        );
        Self::push_length_prefixed(&mut result, &encrypted_aes_key)?;
        Self::push_length_prefixed(&mut result, &nonce)?;
        result.extend_from_slice(&aes_ciphertext);
        Ok(result)
    }

    /// Decrypts a buffer produced by [`HybridCipher::encrypt`] using `private_key`.
    ///
    /// Length prefixes come from untrusted input, so all offset arithmetic is
    /// checked. A hostile prefix yields an error rather than an arithmetic
    /// overflow or an out-of-bounds panic (the former is reachable on 32-bit
    /// targets, where `usize` is the same width as the `u32` prefix).
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error in these cases:
    /// - `ciphertext` is shorter than two length prefixes.
    /// - The encrypted-key length points past the end of the buffer.
    /// - The nonce length prefix is missing.
    /// - The nonce length points past the end of the buffer.
    ///
    /// Also propagates RSA-OAEP and AES-GCM decryption errors.
    pub fn decrypt(&self, private_key: &PrivateKey, ciphertext: &[u8]) -> Result<Vec<u8>> {
        if ciphertext.len() < 2 * Self::LEN_PREFIX_SIZE {
            return Err(FluxError::invalid_input("Ciphertext too short"));
        }

        let mut offset = 0;

        // The length check above guarantees this prefix is present.
        let encrypted_key_len = Self::read_length_prefix(ciphertext, offset)
            .ok_or_else(|| FluxError::invalid_input("Ciphertext too short"))?;
        offset += Self::LEN_PREFIX_SIZE;
        let encrypted_aes_key = Self::field_at(ciphertext, offset, encrypted_key_len)
            .ok_or_else(|| FluxError::invalid_input("Invalid encrypted key length"))?;
        // Cannot overflow: `field_at` succeeded, so offset + len <= ciphertext.len().
        offset += encrypted_key_len;

        let nonce_len = Self::read_length_prefix(ciphertext, offset)
            .ok_or_else(|| FluxError::invalid_input("Invalid nonce length field"))?;
        offset += Self::LEN_PREFIX_SIZE;
        let nonce = Self::field_at(ciphertext, offset, nonce_len)
            .ok_or_else(|| FluxError::invalid_input("Invalid nonce length"))?;
        offset += nonce_len;

        // Everything after the nonce is the AES-GCM ciphertext.
        let aes_ciphertext = &ciphertext[offset..];

        let rsa_cipher = RsaOaepCipher::new();
        let aes_key_bytes = rsa_cipher.decrypt(private_key, encrypted_aes_key)?;
        let aes_key = AesKey::new(aes_key_bytes);

        let aes_cipher = AesGcmCipher::new(self.config.cipher_suite);
        let plaintext = aes_cipher.decrypt(&aes_key, nonce, aes_ciphertext, None)?;
        Ok(plaintext)
    }

    /// Returns the configuration this cipher was built with.
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Appends `field` to `out`, preceded by its length as a big-endian `u32`.
    ///
    /// Fails instead of silently truncating when `field` is longer than
    /// `u32::MAX` bytes.
    fn push_length_prefixed(out: &mut Vec<u8>, field: &[u8]) -> Result<()> {
        if field.len() > u32::MAX as usize {
            return Err(FluxError::invalid_input(
                "Field too large for u32 length prefix",
            ));
        }
        out.extend_from_slice(&(field.len() as u32).to_be_bytes());
        out.extend_from_slice(field);
        Ok(())
    }

    /// Reads the big-endian `u32` length prefix starting at `offset`.
    ///
    /// Returns `None` if fewer than four bytes remain.
    fn read_length_prefix(data: &[u8], offset: usize) -> Option<usize> {
        let end = offset.checked_add(Self::LEN_PREFIX_SIZE)?;
        let bytes = data.get(offset..end)?;
        Some(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize)
    }

    /// Returns the `len`-byte field starting at `offset`.
    ///
    /// Returns `None` if the field would run past the end of `data`, including
    /// when `offset + len` overflows.
    fn field_at(data: &[u8], offset: usize, len: usize) -> Option<&[u8]> {
        data.get(offset..offset.checked_add(len)?)
    }
}
impl Default for HybridCipher {
fn default() -> Self {
Self::new(Config::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{CipherSuite, Config, RsaKeySize};
    use crate::keys::KeyPair;
    use proptest::prelude::*;

    #[test]
    fn test_hybrid_cipher_creation() {
        let cipher = HybridCipher::default();
        assert!(cipher.config().validate().is_ok());
    }

    #[test]
    fn test_hybrid_cipher_with_custom_config() {
        let config = Config::builder()
            .cipher_suite(CipherSuite::Aes128Gcm)
            .rsa_key_size(RsaKeySize::Rsa3072)
            .build()
            .unwrap();
        let cipher = HybridCipher::new(config.clone());
        assert_eq!(cipher.config().cipher_suite, config.cipher_suite);
        assert_eq!(cipher.config().rsa_key_size, config.rsa_key_size);
    }

    #[test]
    fn test_hybrid_cipher_debug() {
        let cipher = HybridCipher::default();
        let debug_str = format!("{:?}", cipher);
        assert!(debug_str.contains("HybridCipher"));
        assert!(debug_str.contains("config"));
    }

    #[test]
    fn test_hybrid_cipher_clone() {
        let config = Config::builder()
            .cipher_suite(CipherSuite::Aes256Gcm)
            .build()
            .unwrap();
        let cipher1 = HybridCipher::new(config);
        let cipher2 = cipher1.clone();
        assert_eq!(cipher1.config().cipher_suite, cipher2.config().cipher_suite);
    }

    // The round-trip tests below generate RSA key pairs and are `#[ignore]`d by
    // default (presumably for runtime); run them with `cargo test -- --ignored`.

    #[test]
    #[ignore]
    fn test_encrypt_decrypt() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = b"Hello, FluxEncrypt hybrid encryption!";
        let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        assert!(!ciphertext.is_empty());
        assert!(ciphertext.len() > plaintext.len());
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    #[ignore]
    fn test_encrypt_decrypt_empty_data() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = b"";
        let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    #[ignore]
    fn test_large_data_encryption() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = vec![42u8; 10000];
        let ciphertext = cipher.encrypt(keypair.public_key(), &plaintext).unwrap();
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    #[ignore]
    fn test_very_large_data_encryption() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = vec![0x42u8; 1_000_000];
        let ciphertext = cipher.encrypt(keypair.public_key(), &plaintext).unwrap();
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    #[ignore]
    fn test_different_cipher_suites() {
        let keypair = KeyPair::generate(2048).unwrap();
        let plaintext = b"Test data for different cipher suites";
        for cipher_suite in &[CipherSuite::Aes128Gcm, CipherSuite::Aes256Gcm] {
            let config = Config::builder()
                .cipher_suite(*cipher_suite)
                .build()
                .unwrap();
            let cipher = HybridCipher::new(config);
            let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
            let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
            assert_eq!(decrypted, plaintext);
        }
    }

    #[test]
    #[ignore]
    fn test_different_key_sizes() {
        let plaintext = b"Test data for different key sizes";
        for key_size in &[2048, 3072, 4096] {
            let keypair = KeyPair::generate(*key_size).unwrap();
            let cipher = HybridCipher::default();
            let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
            let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
            assert_eq!(decrypted, plaintext);
        }
    }

    #[test]
    fn test_decrypt_invalid_ciphertext_too_short() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        // Anything shorter than the two 4-byte length prefixes is rejected up front.
        for len in 0..8 {
            let short_ciphertext = vec![0u8; len];
            let result = cipher.decrypt(keypair.private_key(), &short_ciphertext);
            assert!(result.is_err(), "Should fail with length {}", len);
            if let Err(e) = result {
                assert!(e.to_string().contains("Ciphertext too short"));
            }
        }
    }

    #[test]
    fn test_decrypt_invalid_encrypted_key_length() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        // Declares a 1000-byte encrypted key, but only 8 bytes follow the prefix.
        let mut invalid_ciphertext = Vec::new();
        invalid_ciphertext.extend_from_slice(&(1000u32).to_be_bytes());
        invalid_ciphertext.resize(12, 0);
        let result = cipher.decrypt(keypair.private_key(), &invalid_ciphertext);
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Invalid encrypted key length"));
        }
    }

    #[test]
    fn test_decrypt_invalid_nonce_length_field() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        // A complete 256-byte key, followed by only 2 of the 4 nonce-length bytes.
        let mut invalid_ciphertext = Vec::new();
        invalid_ciphertext.extend_from_slice(&(256u32).to_be_bytes());
        invalid_ciphertext.resize(256 + 4 + 2, 0);
        let result = cipher.decrypt(keypair.private_key(), &invalid_ciphertext);
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Invalid nonce length field"));
        }
    }

    #[test]
    fn test_decrypt_invalid_nonce_length() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        // A complete 256-byte key, then a nonce declared as 1000 bytes but only 10 present.
        let mut invalid_ciphertext = Vec::new();
        invalid_ciphertext.extend_from_slice(&(256u32).to_be_bytes());
        invalid_ciphertext.resize(256 + 4, 0);
        invalid_ciphertext.extend_from_slice(&(1000u32).to_be_bytes());
        invalid_ciphertext.resize(256 + 4 + 4 + 10, 0);
        let result = cipher.decrypt(keypair.private_key(), &invalid_ciphertext);
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Invalid nonce length"));
        }
    }

    #[test]
    #[ignore]
    fn test_ciphertext_format_integrity() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = b"Test ciphertext format integrity";
        let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        assert!(
            ciphertext.len() >= 8,
            "Ciphertext should have at least length fields"
        );
        let mut offset = 0;
        let encrypted_key_len = u32::from_be_bytes([
            ciphertext[offset],
            ciphertext[offset + 1],
            ciphertext[offset + 2],
            ciphertext[offset + 3],
        ]) as usize;
        offset += 4;
        assert_eq!(
            encrypted_key_len, 256,
            "Encrypted key should be 256 bytes for 2048-bit RSA"
        );
        offset += encrypted_key_len;
        let nonce_len = u32::from_be_bytes([
            ciphertext[offset],
            ciphertext[offset + 1],
            ciphertext[offset + 2],
            ciphertext[offset + 3],
        ]) as usize;
        offset += 4;
        assert_eq!(nonce_len, 12, "Nonce should be 12 bytes for GCM");
        offset += nonce_len;
        let aes_ciphertext_len = ciphertext.len() - offset;
        assert_eq!(
            aes_ciphertext_len,
            plaintext.len() + 16,
            "AES ciphertext should be plaintext + 16-byte tag"
        );
    }

    #[test]
    #[ignore]
    fn test_different_plaintexts_produce_different_ciphertexts() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext1 = b"First test message";
        let plaintext2 = b"Second test message";
        let ciphertext1 = cipher.encrypt(keypair.public_key(), plaintext1).unwrap();
        let ciphertext2 = cipher.encrypt(keypair.public_key(), plaintext2).unwrap();
        assert_ne!(
            ciphertext1, ciphertext2,
            "Different plaintexts should produce different ciphertexts"
        );
    }

    #[test]
    #[ignore]
    fn test_same_plaintext_produces_different_ciphertexts() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = b"Same test message";
        let ciphertext1 = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        let ciphertext2 = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        assert_ne!(
            ciphertext1, ciphertext2,
            "Same plaintext should produce different ciphertexts due to randomness"
        );
        let decrypted1 = cipher.decrypt(keypair.private_key(), &ciphertext1).unwrap();
        let decrypted2 = cipher.decrypt(keypair.private_key(), &ciphertext2).unwrap();
        assert_eq!(decrypted1, plaintext);
        assert_eq!(decrypted2, plaintext);
    }

    #[test]
    #[ignore]
    fn test_tampered_ciphertext_detection() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = b"Test tamper detection";
        let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        // Sanity check: the untampered buffer decrypts.
        assert!(cipher.decrypt(keypair.private_key(), &ciphertext).is_ok());

        // Byte 0 is the most significant byte of the encrypted-key length prefix;
        // flipping it makes the declared key length exceed the buffer.
        let mut tampered = ciphertext.clone();
        tampered[0] ^= 1;
        assert!(
            cipher.decrypt(keypair.private_key(), &tampered).is_err(),
            "Tampering with encrypted key length should be detected"
        );

        // Byte 8 lies inside the RSA-OAEP-encrypted AES key (bytes 4..260 for a
        // 2048-bit key), so key unwrapping or GCM authentication must fail.
        let mut tampered = ciphertext.clone();
        tampered[8] ^= 1;
        assert!(
            cipher.decrypt(keypair.private_key(), &tampered).is_err(),
            "Tampering with the encrypted AES key should be detected"
        );

        // The final byte belongs to the AES-GCM ciphertext, so authentication must fail.
        let mut tampered = ciphertext;
        let last = tampered.len() - 1;
        tampered[last] ^= 1;
        assert!(
            cipher.decrypt(keypair.private_key(), &tampered).is_err(),
            "Tampering with the AES-GCM ciphertext should be detected"
        );
    }

    proptest! {
        #[test]
        #[ignore]
        fn test_encrypt_decrypt_roundtrip(
            data in prop::collection::vec(any::<u8>(), 0..10000)
        ) {
            let keypair = KeyPair::generate(2048).unwrap();
            let cipher = HybridCipher::default();
            let ciphertext = cipher.encrypt(keypair.public_key(), &data).unwrap();
            let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
            let data_len = data.len();
            let is_empty = data.is_empty();
            prop_assert_eq!(decrypted, data);
            if !is_empty {
                prop_assert!(ciphertext.len() > data_len, "Ciphertext should be larger than plaintext");
            }
        }
    }

    #[test]
    #[ignore]
    fn test_encrypt_with_different_data_patterns() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let test_patterns = vec![
            vec![0x00; 100],
            vec![0xFF; 100],
            (0..100u8).collect(),
            [0xAA, 0x55].repeat(50),
            b"The quick brown fox jumps over the lazy dog".repeat(10),
        ];
        for pattern in test_patterns {
            let ciphertext = cipher.encrypt(keypair.public_key(), &pattern).unwrap();
            let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
            assert_eq!(decrypted, pattern);
        }
    }

    #[test]
    fn test_config_access() {
        let config = Config::builder()
            .cipher_suite(CipherSuite::Aes128Gcm)
            .build()
            .unwrap();
        let cipher = HybridCipher::new(config.clone());
        assert_eq!(cipher.config().cipher_suite, config.cipher_suite);
    }
}