use super::{
aes_gcm::{AesGcmCipher, AesKey},
rsa_oaep::RsaOaepCipher,
};
use crate::config::Config;
use crate::error::{FluxError, Result};
use crate::keys::{PrivateKey, PublicKey};
/// Hybrid public-key cipher combining RSA-OAEP and AES-GCM.
///
/// Every message is sealed under a freshly generated AES key, and that key is
/// wrapped with the recipient's RSA public key. The only state carried is the
/// [`Config`] whose `cipher_suite` selects the AES variant.
#[derive(Clone, Debug)]
pub struct HybridCipher {
    /// Settings consulted on every encrypt/decrypt call.
    config: Config,
}
impl HybridCipher {
    /// Largest plaintext accepted by `encrypt`, in bytes (512 KiB).
    const MAX_BLOB_SIZE: usize = 512 * 1024;
    /// Length in bytes of the AES-GCM nonce embedded in every blob.
    const NONCE_LEN: usize = 12;
    /// Length in bytes of the AES-GCM authentication tag carried in the AES ciphertext.
    const TAG_LEN: usize = 16;

    /// Creates a hybrid cipher whose AES layer uses `config.cipher_suite`.
    pub fn new(config: Config) -> Self {
        Self { config }
    }

    /// Byte length of an RSA ciphertext for a modulus of `bits` bits.
    ///
    /// RSA output is exactly as long as the modulus in whole bytes, i.e.
    /// `ceil(bits / 8)`. Plain `bits / 8` would under-count moduli whose bit
    /// length is not a multiple of 8 and mis-split the blob.
    fn rsa_ciphertext_len(bits: usize) -> usize {
        bits.div_ceil(8)
    }

    /// Encrypts `plaintext` for the holder of `public_key`.
    ///
    /// A fresh AES key is generated for each call, the plaintext is sealed with
    /// AES-GCM under it, and the AES key is wrapped with RSA-OAEP. The returned
    /// blob is laid out as:
    ///
    /// ```text
    /// [ RSA-wrapped AES key (modulus size) | nonce (12 bytes) | AES-GCM ciphertext incl. 16-byte tag ]
    /// ```
    ///
    /// # Errors
    ///
    /// * an invalid-input error if `plaintext` is larger than 512 KiB;
    /// * a crypto error if the nonce is not 12 bytes or the wrapped key is not
    ///   exactly the RSA modulus size;
    /// * any error propagated from AES key generation, AES-GCM encryption or
    ///   RSA-OAEP encryption.
    pub fn encrypt(&self, public_key: &PublicKey, plaintext: &[u8]) -> Result<Vec<u8>> {
        if plaintext.len() > Self::MAX_BLOB_SIZE {
            return Err(FluxError::invalid_input(format!(
                "Data too large for blob encryption: {} bytes exceeds {} KB limit",
                plaintext.len(),
                Self::MAX_BLOB_SIZE / 1024
            )));
        }

        // Symmetric layer: a new AES key for every message.
        let aes_key = AesKey::generate(self.config.cipher_suite)?;
        let aes_cipher = AesGcmCipher::new(self.config.cipher_suite);
        let (nonce, aes_ciphertext) = aes_cipher.encrypt(&aes_key, plaintext, None)?;

        // `decrypt` reads the nonce at a fixed offset and length, so reject any
        // other size before doing the (expensive) RSA step.
        if nonce.len() != Self::NONCE_LEN {
            return Err(FluxError::crypto(format!(
                "Unexpected nonce size: {} bytes, expected {} bytes for GCM",
                nonce.len(),
                Self::NONCE_LEN
            )));
        }

        // Asymmetric layer: wrap the AES key for the recipient.
        let rsa_cipher = RsaOaepCipher::new();
        let encrypted_aes_key = rsa_cipher.encrypt(public_key, aes_key.as_bytes())?;
        let key_bits = public_key.key_size_bits();
        let expected_key_size = Self::rsa_ciphertext_len(key_bits);
        // `decrypt` locates the nonce by the modulus size, so the wrapped key
        // must be exactly that long.
        if encrypted_aes_key.len() != expected_key_size {
            return Err(FluxError::crypto(format!(
                "Unexpected encrypted key size: {} bytes, expected {} bytes for {}-bit RSA",
                encrypted_aes_key.len(),
                expected_key_size,
                key_bits
            )));
        }

        let mut blob =
            Vec::with_capacity(encrypted_aes_key.len() + nonce.len() + aes_ciphertext.len());
        blob.extend_from_slice(&encrypted_aes_key);
        blob.extend_from_slice(&nonce);
        blob.extend_from_slice(&aes_ciphertext);
        Ok(blob)
    }

    /// Decrypts a blob produced by `encrypt` using `private_key`.
    ///
    /// The AES layer uses this cipher's configured `cipher_suite`, which should
    /// match the one used on the encrypting side.
    ///
    /// # Errors
    ///
    /// * an invalid-input error if `ciphertext` is shorter than the wrapped key
    ///   plus the 12-byte nonce plus the 16-byte tag;
    /// * any error propagated from RSA-OAEP unwrapping of the AES key or from
    ///   AES-GCM decryption (e.g. tampered or foreign data).
    ///
    /// NOTE(review): RSA and AES failures surface as distinct errors; if these
    /// reach remote callers they could serve as an oracle, so consider mapping
    /// them to one generic error at API boundaries.
    pub fn decrypt(&self, private_key: &PrivateKey, ciphertext: &[u8]) -> Result<Vec<u8>> {
        let encrypted_key_size = Self::rsa_ciphertext_len(private_key.key_size_bits());
        let min_size = encrypted_key_size + Self::NONCE_LEN + Self::TAG_LEN;
        if ciphertext.len() < min_size {
            return Err(FluxError::invalid_input(format!(
                "Ciphertext too short: {} bytes, minimum {} bytes required",
                ciphertext.len(),
                min_size
            )));
        }

        // The length check above guarantees both splits are in bounds.
        let (encrypted_aes_key, rest) = ciphertext.split_at(encrypted_key_size);
        let (nonce, aes_ciphertext) = rest.split_at(Self::NONCE_LEN);

        let rsa_cipher = RsaOaepCipher::new();
        let aes_key = AesKey::new(rsa_cipher.decrypt(private_key, encrypted_aes_key)?);
        let aes_cipher = AesGcmCipher::new(self.config.cipher_suite);
        let plaintext = aes_cipher.decrypt(&aes_key, nonce, aes_ciphertext, None)?;
        Ok(plaintext)
    }

    /// Returns the configuration this cipher was built with.
    pub fn config(&self) -> &Config {
        &self.config
    }
}
impl Default for HybridCipher {
fn default() -> Self {
Self::new(Config::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{CipherSuite, Config, RsaKeySize};
    use crate::keys::KeyPair;
    use proptest::prelude::*;
    use proptest::test_runner::TestRunner;

    #[test]
    fn test_hybrid_cipher_creation() {
        let cipher = HybridCipher::default();
        assert!(cipher.config().validate().is_ok());
    }

    #[test]
    fn test_hybrid_cipher_with_custom_config() {
        let config = Config::builder()
            .cipher_suite(CipherSuite::Aes128Gcm)
            .rsa_key_size(RsaKeySize::Rsa3072)
            .build()
            .unwrap();
        let cipher = HybridCipher::new(config.clone());
        assert_eq!(cipher.config().cipher_suite, config.cipher_suite);
        assert_eq!(cipher.config().rsa_key_size, config.rsa_key_size);
    }

    #[test]
    fn test_hybrid_cipher_debug() {
        let cipher = HybridCipher::default();
        let debug_str = format!("{:?}", cipher);
        assert!(debug_str.contains("HybridCipher"));
        assert!(debug_str.contains("config"));
    }

    #[test]
    fn test_hybrid_cipher_clone() {
        let config = Config::builder()
            .cipher_suite(CipherSuite::Aes256Gcm)
            .build()
            .unwrap();
        let cipher1 = HybridCipher::new(config);
        let cipher2 = cipher1.clone();
        assert_eq!(cipher1.config().cipher_suite, cipher2.config().cipher_suite);
    }

    #[test]
    fn test_encrypt_decrypt() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = b"Hello, FluxEncrypt hybrid encryption!";
        let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        assert!(!ciphertext.is_empty());
        assert!(ciphertext.len() > plaintext.len());
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_empty_data() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = b"";
        let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_large_data_encryption() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = vec![42u8; 10000];
        let ciphertext = cipher.encrypt(keypair.public_key(), &plaintext).unwrap();
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_very_large_data_encryption() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = vec![0x42u8; 512 * 1024];
        let ciphertext = cipher.encrypt(keypair.public_key(), &plaintext).unwrap();
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_data_size_limit_exceeded() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = vec![0x42u8; 512 * 1024 + 1];
        let result = cipher.encrypt(keypair.public_key(), &plaintext);
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Data too large for blob encryption"));
        }
    }

    #[test]
    fn test_different_cipher_suites() {
        let keypair = KeyPair::generate(2048).unwrap();
        let plaintext = b"Test data for different cipher suites";
        for cipher_suite in &[CipherSuite::Aes128Gcm, CipherSuite::Aes256Gcm] {
            let config = Config::builder()
                .cipher_suite(*cipher_suite)
                .build()
                .unwrap();
            let cipher = HybridCipher::new(config);
            let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
            let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
            assert_eq!(decrypted, plaintext);
        }
    }

    #[test]
    fn test_different_key_sizes() {
        let plaintext = b"Test data for different key sizes";
        for key_size in &[2048, 3072, 4096] {
            let keypair = KeyPair::generate(*key_size).unwrap();
            let cipher = HybridCipher::default();
            let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
            let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
            assert_eq!(decrypted, plaintext);
        }
    }

    #[test]
    fn test_decrypt_invalid_ciphertext_too_short() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        // Minimum valid blob: wrapped key + 12-byte nonce + 16-byte tag (284 for RSA-2048).
        let min_len = keypair.private_key().key_size_bits() / 8 + 12 + 16;
        for len in 0..min_len {
            let short_ciphertext = vec![0u8; len];
            let result = cipher.decrypt(keypair.private_key(), &short_ciphertext);
            assert!(result.is_err(), "Should fail with length {}", len);
            if let Err(e) = result {
                assert!(e.to_string().contains("Ciphertext too short"));
            }
        }
    }

    #[test]
    fn test_decrypt_invalid_encrypted_key_data() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let encrypted_key_size = keypair.private_key().key_size_bits() / 8;
        let invalid_ciphertext = vec![0u8; encrypted_key_size + 12 + 16];
        let result = cipher.decrypt(keypair.private_key(), &invalid_ciphertext);
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("RSA decryption failed"));
        }
    }

    #[test]
    fn test_ciphertext_format_integrity() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = b"Test ciphertext format integrity";
        let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        let encrypted_key_size = keypair.public_key().key_size_bits() / 8;
        let nonce_size = 12;
        let tag_size = 16;
        let expected_min_size = encrypted_key_size + nonce_size + tag_size;
        assert!(
            ciphertext.len() >= expected_min_size,
            "Ciphertext should have at least {} bytes, got {}",
            expected_min_size,
            ciphertext.len()
        );
        assert_eq!(
            encrypted_key_size, 256,
            "Encrypted key should be 256 bytes for 2048-bit RSA"
        );
        let aes_data_size = ciphertext.len() - encrypted_key_size;
        assert_eq!(
            aes_data_size,
            nonce_size + plaintext.len() + tag_size,
            "AES data should be nonce + plaintext + tag"
        );
    }

    #[test]
    fn test_different_plaintexts_produce_different_ciphertexts() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext1 = b"First test message";
        let plaintext2 = b"Second test message";
        let ciphertext1 = cipher.encrypt(keypair.public_key(), plaintext1).unwrap();
        let ciphertext2 = cipher.encrypt(keypair.public_key(), plaintext2).unwrap();
        assert_ne!(
            ciphertext1, ciphertext2,
            "Different plaintexts should produce different ciphertexts"
        );
    }

    #[test]
    fn test_same_plaintext_produces_different_ciphertexts() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = b"Same test message";
        let ciphertext1 = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        let ciphertext2 = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        assert_ne!(
            ciphertext1, ciphertext2,
            "Same plaintext should produce different ciphertexts due to randomness"
        );
        let decrypted1 = cipher.decrypt(keypair.private_key(), &ciphertext1).unwrap();
        let decrypted2 = cipher.decrypt(keypair.private_key(), &ciphertext2).unwrap();
        assert_eq!(decrypted1, plaintext);
        assert_eq!(decrypted2, plaintext);
    }

    #[test]
    fn test_tampered_ciphertext_detection() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let plaintext = b"Test tamper detection";
        let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        let encrypted_key_size = keypair.public_key().key_size_bits() / 8;
        // One byte inside the RSA-wrapped key, one inside the nonce, and the
        // final byte (expected to lie in the GCM tag). Each flip must be rejected.
        let positions = [10, encrypted_key_size + 5, ciphertext.len() - 1];
        for &pos in &positions {
            let mut tampered = ciphertext.clone();
            tampered[pos] ^= 1;
            assert!(
                cipher.decrypt(keypair.private_key(), &tampered).is_err(),
                "Tampering with byte {} should be detected",
                pos
            );
        }
        // The untouched blob must still decrypt, proving the failures above are
        // caused by the tampering.
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        // RSA key generation dominates runtime, so a single key pair is shared
        // by every generated case instead of being regenerated for each one.
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let strategy = prop::collection::vec(any::<u8>(), 0..10000);
        let mut runner = TestRunner::default();
        runner
            .run(&strategy, |data| {
                let ciphertext = cipher.encrypt(keypair.public_key(), &data).unwrap();
                let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
                prop_assert_eq!(&decrypted, &data);
                // Holds even for empty input: the blob always carries key, nonce and tag.
                prop_assert!(
                    ciphertext.len() > data.len(),
                    "Ciphertext should be larger than plaintext"
                );
                Ok(())
            })
            .unwrap();
    }

    #[test]
    fn test_encrypt_with_different_data_patterns() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = HybridCipher::default();
        let test_patterns = vec![
            vec![0x00; 100],
            vec![0xFF; 100],
            (0..100u8).collect(),
            [0xAA, 0x55].repeat(50),
            b"The quick brown fox jumps over the lazy dog".repeat(10),
        ];
        for pattern in test_patterns {
            let ciphertext = cipher.encrypt(keypair.public_key(), &pattern).unwrap();
            let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
            assert_eq!(decrypted, pattern);
        }
    }

    #[test]
    fn test_config_access() {
        let config = Config::builder()
            .cipher_suite(CipherSuite::Aes128Gcm)
            .build()
            .unwrap();
        let cipher = HybridCipher::new(config.clone());
        assert_eq!(cipher.config().cipher_suite, config.cipher_suite);
    }
}