fluxencrypt 0.7.1

A high-performance, secure encryption SDK for Rust applications
Documentation
//! RSA-OAEP asymmetric encryption implementation.
//!
//! This module provides RSA-OAEP encryption and decryption functionality
//! for encrypting small amounts of data, typically AES keys in hybrid encryption.

use crate::error::{FluxError, Result};
use crate::keys::{PrivateKey, PublicKey};
use rsa::Oaep;
use sha2::Sha512;

/// RSA-OAEP cipher for asymmetric encryption operations
#[derive(Debug)]
pub struct RsaOaepCipher;

impl RsaOaepCipher {
    /// Create a new RSA-OAEP cipher
    pub fn new() -> Self {
        Self
    }

    /// Encrypt data with RSA-OAEP
    ///
    /// # Arguments
    /// * `public_key` - The RSA public key to encrypt with
    /// * `plaintext` - The data to encrypt (must be small enough for RSA)
    ///
    /// # Returns
    /// The encrypted data
    pub fn encrypt(&self, public_key: &PublicKey, plaintext: &[u8]) -> Result<Vec<u8>> {
        // Check plaintext size constraints
        let max_plaintext_len = self.max_plaintext_length(public_key)?;
        if plaintext.len() > max_plaintext_len {
            return Err(FluxError::invalid_input(format!(
                "Plaintext too long for RSA encryption: {} > {}",
                plaintext.len(),
                max_plaintext_len
            )));
        }

        // Create OAEP padding with SHA-512
        let padding = Oaep::new::<Sha512>();

        // Get a random number generator
        let mut rng = rand::thread_rng();

        // Encrypt using RSA-OAEP
        public_key
            .inner()
            .encrypt(&mut rng, padding, plaintext)
            .map_err(|e| FluxError::crypto(format!("RSA encryption failed: {}", e)))
    }

    /// Decrypt data with RSA-OAEP
    ///
    /// # Arguments
    /// * `private_key` - The RSA private key to decrypt with
    /// * `ciphertext` - The encrypted data
    ///
    /// # Returns
    /// The decrypted plaintext
    pub fn decrypt(&self, private_key: &PrivateKey, ciphertext: &[u8]) -> Result<Vec<u8>> {
        let expected_size = private_key.key_size_bits() / 8;
        if ciphertext.len() != expected_size {
            return Err(FluxError::invalid_input(format!(
                "Invalid ciphertext length: {} != {}",
                ciphertext.len(),
                expected_size
            )));
        }

        // Create OAEP padding with SHA-512
        let padding = Oaep::new::<Sha512>();

        // Decrypt using RSA-OAEP
        private_key
            .inner()
            .decrypt(padding, ciphertext)
            .map_err(|e| FluxError::crypto(format!("RSA decryption failed: {}", e)))
    }

    /// Calculate the maximum plaintext length for RSA-OAEP encryption
    ///
    /// For RSA-OAEP with SHA-512, the maximum plaintext length is:
    /// key_length_bytes - 2 * hash_length_bytes - 2
    /// where hash_length_bytes = 64 for SHA-512
    pub fn max_plaintext_length(&self, public_key: &PublicKey) -> Result<usize> {
        let key_size_bytes = public_key.key_size_bits() / 8;

        // For RSA-OAEP with SHA-512: overhead = 2 * 64 + 2 = 130 bytes
        let oaep_overhead = 130;

        if key_size_bytes <= oaep_overhead {
            return Err(FluxError::key("RSA key too small for OAEP encryption"));
        }

        Ok(key_size_bytes - oaep_overhead)
    }

    /// Get the ciphertext length for a given RSA key
    pub fn ciphertext_length(&self, public_key: &PublicKey) -> usize {
        public_key.key_size_bits() / 8
    }
}

impl Default for RsaOaepCipher {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::keys::KeyPair;
    use proptest::prelude::*;

    #[test]
    fn test_rsa_oaep_cipher_creation() {
        let cipher = RsaOaepCipher::new();
        assert!(format!("{:?}", cipher).contains("RsaOaepCipher"));

        let default_cipher = RsaOaepCipher;
        assert!(format!("{:?}", default_cipher).contains("RsaOaepCipher"));
    }

    #[test]
    fn test_max_plaintext_length() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();

        let max_len = cipher.max_plaintext_length(keypair.public_key()).unwrap();

        // For 2048-bit RSA with OAEP-SHA512: 256 - 130 = 126 bytes
        assert_eq!(max_len, 126);
    }

    #[test]
    fn test_max_plaintext_length_different_key_sizes() {
        let cipher = RsaOaepCipher::new();

        // Test different key sizes
        let key_sizes = [2048, 3072, 4096];
        let expected_max_lens = [126, 254, 382]; // key_size/8 - 130

        for (i, &key_size) in key_sizes.iter().enumerate() {
            let keypair = KeyPair::generate(key_size).unwrap();
            let max_len = cipher.max_plaintext_length(keypair.public_key()).unwrap();
            assert_eq!(
                max_len, expected_max_lens[i],
                "Incorrect max length for {}-bit key",
                key_size
            );
        }
    }

    #[test]
    fn test_ciphertext_length() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();

        let ciphertext_len = cipher.ciphertext_length(keypair.public_key());

        // For 2048-bit RSA: 2048 / 8 = 256 bytes
        assert_eq!(ciphertext_len, 256);
    }

    #[test]
    fn test_ciphertext_length_different_key_sizes() {
        let cipher = RsaOaepCipher::new();

        let key_sizes = [2048, 3072, 4096];
        let expected_ciphertext_lens = [256, 384, 512]; // key_size / 8

        for (i, &key_size) in key_sizes.iter().enumerate() {
            let keypair = KeyPair::generate(key_size).unwrap();
            let ciphertext_len = cipher.ciphertext_length(keypair.public_key());
            assert_eq!(
                ciphertext_len, expected_ciphertext_lens[i],
                "Incorrect ciphertext length for {}-bit key",
                key_size
            );
        }
    }

    #[test]
    fn test_encrypt_decrypt() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let plaintext = b"Hello, world!";

        // Test that encryption produces output of expected size
        let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        assert_eq!(ciphertext.len(), 2048 / 8); // Should be key size in bytes

        // Test that decryption recovers the plaintext
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_empty_data() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let plaintext = b"";

        let ciphertext = cipher.encrypt(keypair.public_key(), plaintext).unwrap();
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_max_length_data() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();

        let max_len = cipher.max_plaintext_length(keypair.public_key()).unwrap();
        let plaintext = vec![0x42u8; max_len];

        let ciphertext = cipher.encrypt(keypair.public_key(), &plaintext).unwrap();
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_plaintext_too_long() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();

        let max_len = cipher.max_plaintext_length(keypair.public_key()).unwrap();
        let plaintext = vec![0x42u8; max_len + 1]; // One byte too long

        let result = cipher.encrypt(keypair.public_key(), &plaintext);
        assert!(result.is_err());

        if let Err(e) = result {
            assert!(e.to_string().contains("Plaintext too long"));
        }
    }

    #[test]
    fn test_decrypt_invalid_ciphertext_length() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();

        // Test with various invalid lengths
        let invalid_lengths = vec![0, 100, 200, 300]; // All wrong for 2048-bit key

        for &invalid_len in &invalid_lengths {
            let invalid_ciphertext = vec![0u8; invalid_len];
            let result = cipher.decrypt(keypair.private_key(), &invalid_ciphertext);
            assert!(result.is_err(), "Should fail with length {}", invalid_len);

            if let Err(e) = result {
                assert!(e.to_string().contains("Invalid ciphertext length"));
            }
        }
    }

    #[test]
    fn test_encrypt_decrypt_different_key_pairs() {
        let keypair1 = KeyPair::generate(2048).unwrap();
        let _keypair2 = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let plaintext = b"Test data for different key pairs";

        // Encrypt with first key pair
        let ciphertext = cipher.encrypt(keypair1.public_key(), plaintext).unwrap();

        // Should work with correct private key
        let decrypted1 = cipher.decrypt(keypair1.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted1, plaintext);

        // Should fail with wrong private key
        let result = cipher.decrypt(_keypair2.private_key(), &ciphertext);
        assert!(
            result.is_err(),
            "Decryption should fail with wrong private key"
        );
    }

    #[test]
    fn test_encrypt_various_data_sizes() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();
        let max_len = cipher.max_plaintext_length(keypair.public_key()).unwrap();

        let test_sizes = vec![1, 16, 32, 64, max_len / 2, max_len - 1, max_len];

        for &size in &test_sizes {
            let plaintext = vec![0x42u8; size];

            let ciphertext = cipher.encrypt(keypair.public_key(), &plaintext).unwrap();
            let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();

            assert_eq!(decrypted, plaintext, "Failed for data size {}", size);
            assert_eq!(
                ciphertext.len(),
                cipher.ciphertext_length(keypair.public_key())
            );
        }
    }

    #[test]
    fn test_key_size_bounds_checking() {
        // Test that very small keys would be rejected
        // Note: KeyPair::generate might not allow very small keys, but we test the logic
        let _cipher = RsaOaepCipher::new();

        // This is testing the theoretical case where a key is too small
        // In practice, KeyPair::generate should reject keys smaller than minimum secure sizes
        // But the max_plaintext_length should handle edge cases gracefully

        // Test with minimum viable key size (512 bits = 64 bytes)
        // OAEP overhead is 66 bytes, so this should fail
        // We can't actually create such a small key with KeyPair::generate,
        // so this test documents the expected behavior
    }

    #[test]
    fn test_encrypt_with_special_characters() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();

        let special_data = b"!@#$%^&*()_+-=[]{}|;':\",./<>?`~\n\r\t\0";

        let ciphertext = cipher.encrypt(keypair.public_key(), special_data).unwrap();
        let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();
        assert_eq!(decrypted, special_data);
    }

    // Property-based tests
    proptest! {
        #[test]
        fn test_encrypt_decrypt_roundtrip(
            data in prop::collection::vec(any::<u8>(), 1..126) // Max 126 bytes for 2048-bit RSA with SHA-512
        ) {
            let keypair = KeyPair::generate(2048).unwrap();
            let cipher = RsaOaepCipher::new();

            let ciphertext = cipher.encrypt(keypair.public_key(), &data).unwrap();
            let decrypted = cipher.decrypt(keypair.private_key(), &ciphertext).unwrap();

            prop_assert_eq!(decrypted, data);
            prop_assert_eq!(ciphertext.len(), cipher.ciphertext_length(keypair.public_key()));
        }
    }

    #[test]
    fn test_error_message_quality() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = RsaOaepCipher::new();

        // Test plaintext too long error
        let max_len = cipher.max_plaintext_length(keypair.public_key()).unwrap();
        let too_long = vec![0u8; max_len + 50];
        let result = cipher.encrypt(keypair.public_key(), &too_long);

        if let Err(e) = result {
            let error_msg = e.to_string();
            assert!(error_msg.contains("Plaintext too long"));
            assert!(error_msg.contains(&(max_len + 50).to_string()));
            assert!(error_msg.contains(&max_len.to_string()));
        }

        // Test invalid ciphertext length error
        let wrong_size = vec![0u8; 100];
        let result = cipher.decrypt(keypair.private_key(), &wrong_size);

        if let Err(e) = result {
            let error_msg = e.to_string();
            assert!(error_msg.contains("Invalid ciphertext length"));
            assert!(error_msg.contains("100"));
            assert!(error_msg.contains("256"));
        }
    }
}