use crate::error::{FluxError, Result};
use crate::keys::{PrivateKey, PublicKey};
/// Placeholder for an RSA-OAEP cipher.
///
/// SECURITY: the methods implemented for this type in this file do not perform RSA or
/// OAEP; they only XOR bytes with a fixed mask. Do not rely on it to protect real data.
///
/// The type is a stateless unit struct, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RsaOaepCipher;
/// Placeholder RSA-OAEP operations.
///
/// SECURITY: no RSA or OAEP is performed here. Each "ciphertext" is a key-sized block
/// holding an 8-byte big-endian plaintext length, the plaintext, then zero padding, with
/// every byte XORed with a fixed mask. Only the size bounds (`max_plaintext_length`,
/// `ciphertext_length`) mirror real RSA-OAEP with a SHA-256 digest.
impl RsaOaepCipher {
    /// OAEP overhead in bytes for a SHA-256 digest: `2 * hLen + 2` with `hLen = 32`.
    const OAEP_OVERHEAD: usize = 2 * 32 + 2;
    /// Width of the big-endian `u64` plaintext-length header at the start of each block.
    const LEN_HEADER: usize = 8;
    /// Mask used by the placeholder transform; XOR with it is its own inverse.
    const MASK: u8 = 0xAB;

    /// Creates a new (stateless) cipher.
    pub fn new() -> Self {
        Self
    }

    /// "Encrypts" `plaintext` for `public_key` using the placeholder transform.
    ///
    /// The result is always exactly `ciphertext_length(public_key)` bytes. The embedded
    /// length header lets `decrypt` return the exact input, including empty input and
    /// input that contains zero bytes.
    ///
    /// # Errors
    ///
    /// - A key error if the key is too small for OAEP (see `max_plaintext_length`).
    /// - An invalid-input error if `plaintext` is longer than `max_plaintext_length`.
    pub fn encrypt(&self, public_key: &PublicKey, plaintext: &[u8]) -> Result<Vec<u8>> {
        let max_plaintext_len = self.max_plaintext_length(public_key)?;
        if plaintext.len() > max_plaintext_len {
            return Err(FluxError::invalid_input(format!(
                "Plaintext too long for RSA encryption: {} > {}",
                plaintext.len(),
                max_plaintext_len
            )));
        }

        // `max_plaintext_length` succeeded, so the block is more than OAEP_OVERHEAD
        // (> LEN_HEADER) bytes long and the header plus plaintext always fit.
        let mut block = vec![0u8; self.ciphertext_length(public_key)];
        let (header, body) = block.split_at_mut(Self::LEN_HEADER);
        // usize -> u64 is a lossless widening on every target Rust supports.
        header.copy_from_slice(&(plaintext.len() as u64).to_be_bytes());
        body[..plaintext.len()].copy_from_slice(plaintext);
        Self::apply_mask(&mut block);
        Ok(block)
    }

    /// Reverses `encrypt`, returning exactly the bytes that were encrypted.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if:
    /// - `ciphertext` is not exactly `key_size_bits / 8` bytes;
    /// - the block is too short to hold the length header;
    /// - the decoded length header claims more bytes than the block can hold.
    pub fn decrypt(&self, private_key: &PrivateKey, ciphertext: &[u8]) -> Result<Vec<u8>> {
        let expected_size = private_key.key_size_bits() / 8;
        if ciphertext.len() != expected_size {
            return Err(FluxError::invalid_input(format!(
                "Invalid ciphertext length: {} != {}",
                ciphertext.len(),
                expected_size
            )));
        }
        if ciphertext.len() < Self::LEN_HEADER {
            return Err(FluxError::invalid_input(format!(
                "Invalid ciphertext: {} bytes is too short for the length header",
                ciphertext.len()
            )));
        }

        let mut block = ciphertext.to_vec();
        Self::apply_mask(&mut block);

        let mut header = [0u8; Self::LEN_HEADER];
        header.copy_from_slice(&block[..Self::LEN_HEADER]);
        let declared_len = u64::from_be_bytes(header);
        let capacity = block.len() - Self::LEN_HEADER;
        // Compare in u64 so a corrupted header cannot be truncated on 32-bit targets.
        if declared_len > capacity as u64 {
            return Err(FluxError::invalid_input(format!(
                "Invalid ciphertext: declared plaintext length {} exceeds capacity {}",
                declared_len, capacity
            )));
        }

        let start = Self::LEN_HEADER;
        Ok(block[start..start + declared_len as usize].to_vec())
    }

    /// Largest plaintext, in bytes, that `encrypt` accepts for `public_key`.
    ///
    /// This is the key size in bytes minus the SHA-256 OAEP overhead of 66 bytes
    /// (e.g. 190 for a 2048-bit key).
    ///
    /// # Errors
    ///
    /// Returns a key error if the key is not larger than the OAEP overhead.
    pub fn max_plaintext_length(&self, public_key: &PublicKey) -> Result<usize> {
        let key_size_bytes = self.ciphertext_length(public_key);
        if key_size_bytes <= Self::OAEP_OVERHEAD {
            return Err(FluxError::key("RSA key too small for OAEP encryption"));
        }
        Ok(key_size_bytes - Self::OAEP_OVERHEAD)
    }

    /// Size in bytes of every ciphertext produced for `public_key` (`key_size_bits / 8`).
    pub fn ciphertext_length(&self, public_key: &PublicKey) -> usize {
        public_key.key_size_bits() / 8
    }

    /// XORs every byte with `MASK` in place; applying it twice restores the input.
    fn apply_mask(bytes: &mut [u8]) {
        for byte in bytes.iter_mut() {
            *byte ^= Self::MASK;
        }
    }
}
impl Default for RsaOaepCipher {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::keys::KeyPair;
    use proptest::prelude::*;

    #[test]
    fn test_rsa_oaep_cipher_creation() {
        let cipher = RsaOaepCipher::new();
        assert!(format!("{:?}", cipher).contains("RsaOaepCipher"));
        // Go through the `Default` impl rather than the unit-struct literal.
        let default_cipher = RsaOaepCipher::default();
        assert!(format!("{:?}", default_cipher).contains("RsaOaepCipher"));
    }

    #[test]
    fn test_max_plaintext_length() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let max_len = cipher.max_plaintext_length(keypair.public_key()).unwrap();
        // 256-byte key minus 66 bytes of OAEP overhead.
        assert_eq!(max_len, 190);
    }

    #[test]
    fn test_max_plaintext_length_different_key_sizes() {
        let cipher = RsaOaepCipher::new();
        let cases = [(2048, 190), (3072, 318), (4096, 446)];
        for &(key_size, expected_max_len) in &cases {
            let keypair = KeyPair::generate(key_size).unwrap();
            let max_len = cipher.max_plaintext_length(keypair.public_key()).unwrap();
            assert_eq!(
                max_len, expected_max_len,
                "Incorrect max length for {}-bit key",
                key_size
            );
        }
    }

    #[test]
    fn test_ciphertext_length() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let ciphertext_len = cipher.ciphertext_length(keypair.public_key());
        assert_eq!(ciphertext_len, 256);
    }

    #[test]
    fn test_ciphertext_length_different_key_sizes() {
        let cipher = RsaOaepCipher::new();
        let cases = [(2048, 256), (3072, 384), (4096, 512)];
        for &(key_size, expected_ciphertext_len) in &cases {
            let keypair = KeyPair::generate(key_size).unwrap();
            let ciphertext_len = cipher.ciphertext_length(keypair.public_key());
            assert_eq!(
                ciphertext_len, expected_ciphertext_len,
                "Incorrect ciphertext length for {}-bit key",
                key_size
            );
        }
    }

    // NOTE(review): the round-trip tests below are #[ignore]d; presumably because the
    // placeholder cipher's round trip was unreliable. Re-enable once it is.
    #[test]
    #[ignore]
    fn test_encrypt_decrypt_placeholder() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let plaintext = b"Hello, world!";
        let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        assert_eq!(ciphertext.len(), 2048 / 8);
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    #[ignore]
    fn test_encrypt_decrypt_empty_data() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let plaintext = b"";
        let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    #[ignore]
    fn test_encrypt_decrypt_max_length_data() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let max_len = cipher.max_plaintext_length(keypair.public_key()).unwrap();
        let plaintext = vec![0x42u8; max_len];
        let ciphertext = cipher.encrypt(keypair.public_key(), &plaintext).unwrap();
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_plaintext_too_long() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let max_len = cipher.max_plaintext_length(keypair.public_key()).unwrap();
        let plaintext = vec![0x42u8; max_len + 1];
        let err = cipher
            .encrypt(keypair.public_key(), &plaintext)
            .expect_err("plaintext one byte over the limit must be rejected");
        assert!(err.to_string().contains("Plaintext too long"));
    }

    #[test]
    fn test_decrypt_invalid_ciphertext_length() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        for &invalid_len in &[0usize, 100, 200, 300] {
            let invalid_ciphertext = vec![0u8; invalid_len];
            match cipher.decrypt(keypair.private_key(), &invalid_ciphertext) {
                Ok(_) => panic!("Should fail with length {}", invalid_len),
                Err(e) => assert!(e.to_string().contains("Invalid ciphertext length")),
            }
        }
    }

    #[test]
    #[ignore]
    fn test_encrypt_decrypt_different_key_pairs() {
        let keypair1 = KeyPair::generate(2048).unwrap();
        // NOTE(review): `_keypair2` is never used, so this test does not check that a
        // different key fails to decrypt; it only round-trips with `keypair1`.
        let _keypair2 = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let plaintext = b"Test data for different key pairs";
        let ciphertext = cipher.encrypt(keypair1.public_key(), plaintext).unwrap();
        let decrypted1 = cipher.decrypt(keypair1.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted1, plaintext);
    }

    #[test]
    #[ignore]
    fn test_encrypt_various_data_sizes() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let max_len = cipher.max_plaintext_length(keypair.public_key()).unwrap();
        let test_sizes = [1, 16, 32, 64, max_len / 2, max_len - 1, max_len];
        for &size in &test_sizes {
            let plaintext = vec![0x42u8; size];
            let ciphertext = cipher.encrypt(keypair.public_key(), &plaintext).unwrap();
            let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
            assert_eq!(decrypted, plaintext, "Failed for data size {}", size);
            assert_eq!(
                ciphertext.len(),
                cipher.ciphertext_length(keypair.public_key())
            );
        }
    }

    #[test]
    fn test_key_size_bounds_checking() {
        // TODO(review): this test asserts nothing. It should check that
        // `max_plaintext_length` rejects keys of 66 bytes or fewer once a small test key
        // can be constructed.
        let _cipher = RsaOaepCipher::new();
    }

    #[test]
    #[ignore]
    fn test_encrypt_with_special_characters() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let special_data = b"!@#$%^&*()_+-=[]{}|;':\",./<>?`~\n\r\t\0";
        let ciphertext = cipher.encrypt(keypair.public_key(), special_data).unwrap();
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, special_data);
    }

    proptest! {
        #[test]
        #[ignore]
        fn test_encrypt_decrypt_roundtrip(
            data in prop::collection::vec(any::<u8>(), 1..190)
        ) {
            let keypair = KeyPair::generate(2048).unwrap();
            let cipher = RsaOaepCipher::new();
            let ciphertext = cipher.encrypt(keypair.public_key(), &data).unwrap();
            let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
            prop_assert_eq!(decrypted, data);
            prop_assert_eq!(ciphertext.len(), cipher.ciphertext_length(keypair.public_key()));
        }
    }

    #[test]
    fn test_error_message_quality() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let max_len = cipher.max_plaintext_length(keypair.public_key()).unwrap();

        // Fail loudly if the call unexpectedly succeeds instead of skipping the checks.
        let too_long = vec![0u8; max_len + 50];
        let error_msg = cipher
            .encrypt(keypair.public_key(), &too_long)
            .expect_err("oversized plaintext must be rejected")
            .to_string();
        assert!(error_msg.contains("Plaintext too long"));
        assert!(error_msg.contains(&(max_len + 50).to_string()));
        assert!(error_msg.contains(&max_len.to_string()));

        let wrong_size = vec![0u8; 100];
        let error_msg = cipher
            .decrypt(keypair.private_key(), &wrong_size)
            .expect_err("wrong-size ciphertext must be rejected")
            .to_string();
        assert!(error_msg.contains("Invalid ciphertext length"));
        assert!(error_msg.contains("100"));
        assert!(error_msg.contains("256"));
    }
}