//! cryptoy 0.4.0
//!
//! Toy implementations of cryptographic protocols for educational purposes.
use num_bigint::{BigUint, ModInverse, RandPrime};
use num_integer::Integer;
use rand::rngs::StdRng;
use rand::SeedableRng;

use crate::error::Error;

pub mod codec;
pub mod traits;

use traits::{Codec, RsaKey};

/// RSA cipher that pairs a message [`Codec`] with a [`KeyPair`].
///
/// The codec converts byte messages to integers (and back); the key pair
/// performs the modular exponentiation on those integers.
#[derive(Clone, Debug, PartialEq)]
pub struct Rsa<C>
where
    C: Codec,
{
    /// Converts between byte messages and the integers that are encrypted.
    pub codec: C,
    /// Public key used by `encrypt`, private key used by `decrypt`.
    pub key_pair: KeyPair,
}

impl<C> Rsa<C>
where
    C: Codec,
{
    /// Encodes `msg` with the codec, then applies the public key to the result.
    ///
    /// Takes `&mut self` because the codec's `encode` needs mutable access.
    ///
    /// # Errors
    ///
    /// Propagates any error from the codec's `encode` or from the public key's
    /// `crypt`.
    pub fn encrypt(&mut self, msg: &[u8]) -> Result<BigUint, Error> {
        let plaintext_int = self.codec.encode(msg)?;
        self.key_pair.public_key.crypt(&plaintext_int)
    }

    /// Applies the private key to `msg`, then decodes the result with the codec.
    ///
    /// # Errors
    ///
    /// Propagates any error from the private key's `crypt` or from the codec's
    /// `decode`.
    pub fn decrypt(&self, msg: &BigUint) -> Result<Vec<u8>, Error> {
        let recovered_int = self.key_pair.private_key.crypt(msg)?;
        self.codec.decode(&recovered_int)
    }

    /// Like [`Rsa::decrypt`], but takes the ciphertext as big-endian bytes.
    ///
    /// # Errors
    ///
    /// Same as [`Rsa::decrypt`].
    pub fn decrypt_bytes(&self, msg: &[u8]) -> Result<Vec<u8>, Error> {
        self.decrypt(&BigUint::from_bytes_be(msg))
    }
}

/// An RSA key pair together with the two factors of its modulus.
///
/// `KeyPair::generate` fills `p` and `q` with distinct primes; `KeyPair::new`
/// does not verify primality.
#[derive(Clone, Debug, PartialEq)]
pub struct KeyPair {
    /// First factor of the modulus.
    pub p: BigUint,
    /// Second factor of the modulus.
    pub q: BigUint,
    /// Private exponent `d` together with the modulus `n = p * q`.
    pub private_key: Key,
    /// Public exponent `e` together with the modulus `n = p * q`.
    pub public_key: Key,
}

/// One half of an RSA key pair: an exponent and the modulus it is applied under.
#[derive(Clone, Debug, PartialEq)]
pub struct Key {
    /// Exponent used by `crypt` (public `e` or private `d`).
    pub exponent: BigUint,
    /// Modulus `n`; `crypt` rejects messages that are not smaller than it.
    pub modulus: BigUint,
}

impl Key {
    /// Builds a key from an `exponent` and a `modulus`, taking ownership of both.
    pub fn new(exponent: BigUint, modulus: BigUint) -> Self {
        Key { modulus, exponent }
    }
}

impl RsaKey for Key {
    /// Computes `message ^ exponent mod modulus` (the raw RSA primitive).
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidSize`] if `message >= modulus`, because the
    /// primitive is only defined on representatives in `[0, modulus)`.
    fn crypt(&self, message: &BigUint) -> Result<BigUint, Error> {
        if message >= &self.modulus {
            // Deliberately not logged: `message` may be plaintext or key-derived data,
            // and library code should not write to stdout.
            return Err(Error::InvalidSize);
        }
        Ok(message.modpow(&self.exponent, &self.modulus))
    }
}

/// Unpadded codec: a message is used directly as a big-endian integer.
///
/// The mapping is numeric, so leading zero bytes are not preserved across an
/// encode/decode round trip: encoding `[0, 1]` and decoding the result gives `[1]`.
impl Codec for Key {
    /// Interprets `chunk` as a big-endian unsigned integer.
    ///
    /// # Errors
    ///
    /// Returns [`Error::MessageTooLarge`] if the integer is not strictly smaller
    /// than the modulus.
    fn encode(&mut self, chunk: &[u8]) -> Result<BigUint, Error> {
        let msg = BigUint::from_bytes_be(chunk);
        // `crypt` rejects `msg >= modulus`, so `msg == modulus` must be rejected
        // here as well; otherwise encode succeeds and encryption then fails.
        if msg >= self.modulus {
            return Err(Error::MessageTooLarge);
        }

        Ok(msg)
    }

    /// Returns the minimal big-endian byte representation of `chunk`.
    fn decode(&self, chunk: &BigUint) -> Result<Vec<u8>, Error> {
        Ok(chunk.to_bytes_be())
    }
}

impl KeyPair {
    /// Generates a key pair from two distinct random primes of `prime_bits` bits each.
    ///
    /// The primes come from a [`StdRng`] seeded with `seed`, so for a given `rand`
    /// version the same `(prime_bits, seed)` always yields the same key pair. This
    /// determinism is intended for reproducible examples and tests only.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidKeyParams`] if `prime_bits < 8`, and propagates any
    /// error from [`KeyPair::new`].
    pub fn generate(prime_bits: usize, seed: u64) -> Result<Self, Error> {
        if prime_bits < 8 {
            return Err(Error::InvalidKeyParams { prime_bits });
        }

        let mut rng = StdRng::seed_from_u64(seed);

        let p: BigUint = rng.gen_prime(prime_bits);
        // `new` requires distinct factors, so redraw `q` until it differs from `p`.
        let mut q: BigUint = rng.gen_prime(prime_bits);
        while q == p {
            q = rng.gen_prime(prime_bits);
        }

        Self::new(&p, &q)
    }

    /// Builds a key pair from the factors `p` and `q`.
    ///
    /// The modulus is `n = p * q` and the totient is `lcm(p - 1, q - 1)`, which is
    /// Carmichael's `λ(n)` when `p` and `q` are distinct primes. The public exponent
    /// is the smallest odd `e >= 3` coprime to the totient. The private exponent is
    /// `d = e⁻¹ mod totient`. Primality of `p` and `q` is not checked.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UnknownError`] if:
    /// - `p` or `q` is less than 2;
    /// - `p == q`;
    /// - no odd `e` below the totient is coprime to it.
    pub fn new(p: &BigUint, q: &BigUint) -> Result<Self, Error> {
        let one: BigUint = 1u8.into();

        // `p - 1` would panic (BigUint underflow) for p == 0. With p == q, n = p^2
        // and lcm(p - 1, q - 1) is not a valid exponent order mod n, so the key
        // would silently fail to decrypt. Reject both cases up front.
        if p <= &one || q <= &one || p == q {
            return Err(Error::UnknownError);
        }

        let modulus = p * q;
        let totient = (p - 1u8).lcm(&(q - 1u8));

        // For odd primes the totient is even, so only odd candidates can be coprime.
        let mut public_exponent: BigUint = 3u8.into();
        while public_exponent < totient && public_exponent.gcd(&totient) != one {
            public_exponent += 2u8;
        }

        if public_exponent >= totient {
            return Err(Error::UnknownError);
        }

        let private_exponent = public_exponent
            .clone()
            .mod_inverse(&totient)
            .ok_or(Error::UnknownError)?
            .to_biguint()
            .ok_or(Error::UnknownError)?;

        Ok(Self {
            p: p.clone(),
            q: q.clone(),
            private_key: Key::new(private_exponent, modulus.clone()),
            public_key: Key::new(public_exponent, modulus),
        })
    }
}

// Unit and property tests for key generation, the unpadded `Key` codec, and the
// PKCS#1 v1.5 codec, including interop checks against the `rsa` crate.
#[cfg(test)]
mod tests {
    use std::num::NonZeroU8;

    use proptest::prelude::*;
    use rsa::{Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey};

    use super::*;

    // Fixed seed so every test generates the same key pair on every run.
    const SEED: u64 = 1234;

    // `KeyPair::generate` rejects prime sizes below 8 bits.
    #[test]
    fn fails_to_generate_primes_too_small() {
        let prime_bits = 7usize;
        let result = KeyPair::generate(prime_bits, SEED);
        assert!(result.is_err());
    }

    proptest! {
        // Every prime size in 8..128 bits must yield a key pair with distinct primes.
        #[test]
        fn generates_keys_with_valid_params(primebits in 8usize..128) {
            let result = KeyPair::generate(primebits, SEED);
            assert!(result.is_ok());

            let keypair = result.unwrap();
            assert!(keypair.p != keypair.q);
        }
    }

    // NOTE(review): this only checks that the `rsa` crate derives the same public
    // key from (p, q, e) as from (n, e). It never compares this crate's modulus
    // or private exponent with the reference; consider asserting on those too.
    #[test]
    fn keys_match_known_good_implementation() {
        let prime_bits = 128usize;
        let keypair = KeyPair::generate(prime_bits, SEED).unwrap();

        let p = keypair.p.clone();
        let q = keypair.q.clone();
        let n = p.clone() * q.clone();
        let public_exponent = keypair.public_key.exponent.clone();

        let other_keypair = RsaPrivateKey::from_p_q(p, q, public_exponent.clone()).unwrap();
        let other_pubkey = RsaPublicKey::new(n, public_exponent).unwrap();

        assert_eq!(other_pubkey, other_keypair.to_public_key());
    }

    // Encrypt with this crate's PKCS#1 v1.5 block type 02, decrypt with `rsa`.
    #[test]
    fn can_be_decrypted_by_known_good_implementation() {
        let prime_bits = 128usize;
        let keypair = KeyPair::generate(prime_bits, SEED).unwrap();

        let modulus_bytes = keypair.private_key.modulus.to_bytes_be().len();

        // The third argument is presumably the maximum payload length (k - 11 for
        // PKCS#1 v1.5) — confirm against `codec::Pkcs1V1_5::new`.
        let codec = codec::Pkcs1V1_5::new(
            codec::BlockType::Type02,
            modulus_bytes,
            modulus_bytes - 11,
            SEED,
        );

        let mut rsa = Rsa::<codec::Pkcs1V1_5> {
            codec,
            key_pair: keypair.clone(),
        };

        let kg_keypair = RsaPrivateKey::from_p_q(
            keypair.p.clone(),
            keypair.q.clone(),
            keypair.public_key.exponent.clone(),
        )
        .expect("failed to generate keypair");

        let plaintext = b"Hello, world!";
        let ciphertext = rsa.encrypt(plaintext).unwrap();

        let decrypted = kg_keypair
            .decrypt(Pkcs1v15Encrypt, &ciphertext.to_bytes_be())
            .expect("should decrypt");

        assert_eq!(plaintext.as_slice(), &decrypted);
    }

    // Encrypt with `rsa`, decrypt with this crate's PKCS#1 v1.5 codec.
    #[test]
    fn can_encrypt_for_known_good_implementation() {
        // Separate seeded RNG for the `rsa` crate's random padding.
        let mut rng = StdRng::seed_from_u64(1002);

        let prime_bits = 128usize;
        let keypair = KeyPair::generate(prime_bits, SEED).unwrap();

        let modulus_bytes = keypair.private_key.modulus.to_bytes_be().len();

        let codec = codec::Pkcs1V1_5::new(
            codec::BlockType::Type02,
            modulus_bytes,
            modulus_bytes - 11,
            SEED,
        );

        let kg_keypair = RsaPrivateKey::from_p_q(
            keypair.p.clone(),
            keypair.q.clone(),
            keypair.public_key.exponent.clone(),
        )
        .expect("failed to generate keypair");

        let plaintext = b"Hello, world!";
        let ciphertext = kg_keypair
            .to_public_key()
            .encrypt(&mut rng, Pkcs1v15Encrypt, &plaintext[..])
            .expect("failed to encrypt");

        let rsa = Rsa::<codec::Pkcs1V1_5> {
            codec,
            key_pair: keypair.clone(),
        };

        let cipher_text_int = BigUint::from_bytes_be(ciphertext.as_slice());
        let decrypted = rsa.decrypt(&cipher_text_int).unwrap();

        assert_eq!(plaintext.as_slice(), &decrypted);
    }

    /// This one failed during a migration between bigint libraries. It is
    /// included as an explicit test case because it's a bit easier to debug a
    /// single failing case than the failing case of a whole proptest, plus
    /// faster.
    #[test]
    fn round_trips_encryption_plain_2() {
        let primebits = 33;
        let plaintext = vec![1; 7];
        let mut keypair = KeyPair::generate(primebits, SEED).unwrap();

        let encoded_plaintext = keypair.public_key.encode(&plaintext).unwrap();

        let ciphertext = keypair.public_key.crypt(&encoded_plaintext).unwrap();
        let decrypted = keypair.private_key.crypt(&ciphertext).unwrap();

        let decrypted_text = keypair.private_key.decode(&decrypted).unwrap();

        assert_eq!(plaintext, decrypted_text);
    }

    proptest! {
        // Round trip through the unpadded `Key` codec. Bytes are drawn as NonZeroU8:
        // that codec maps bytes to a big-endian integer, so leading zero bytes would
        // not survive decoding (presumably why zeros are excluded).
        #[test]
        fn round_trips_encryption_plain(
            plaintext in prop::collection::vec(any::<NonZeroU8>(), 1..10),
        ) {
            let primebits = 64;
            let mut keypair = KeyPair::generate(primebits, SEED).unwrap();

            let plaintext = plaintext.into_iter().map(|nz| nz.get()).collect::<Vec<u8>>();
            let encoded_plaintext = keypair.public_key.encode(&plaintext).unwrap();

            let ciphertext = keypair.public_key.crypt(&encoded_plaintext).unwrap();
            let decrypted = keypair.private_key.crypt(&ciphertext).unwrap();

            let decrypted_text = keypair.private_key.decode(&decrypted).unwrap();

            assert_eq!(plaintext, decrypted_text);
        }
    }

    proptest! {
        // Round trip through the PKCS#1 v1.5 codec with block type 01.
        // NOTE(review): `bits() / 8` rounds down, unlike `to_bytes_be().len()`
        // used in the interop tests above; confirm which length the codec expects.
        #[test]
        fn round_trips_encryption_pkcs_type01(
            plaintext in prop::collection::vec(any::<u8>(), 1..10),
        ) {
            let primebits = 64;

            let keypair = KeyPair::generate(primebits, SEED).unwrap();
            let modulus_bytes = keypair.private_key.modulus.bits() / 8;
            let codec = codec::Pkcs1V1_5::new(codec::BlockType::Type01, modulus_bytes, modulus_bytes - 11, SEED);

            let mut rsa = Rsa::<codec::Pkcs1V1_5> {
                codec,
                key_pair: keypair,
            };

            let ciphertext = rsa.encrypt(&plaintext).unwrap();
            let decrypted_text = rsa.decrypt(&ciphertext).unwrap();

            assert_eq!(plaintext, decrypted_text);
        }
    }

    proptest! {
        // Round trip through the PKCS#1 v1.5 codec with block type 02.
        #[test]
        fn round_trips_encryption_pkcs_type02(
            plaintext in prop::collection::vec(any::<u8>(), 1..10),
        ) {
            let primebits = 64;
            let keypair = KeyPair::generate(primebits, SEED).unwrap();
            let modulus_bytes = keypair.private_key.modulus.bits() / 8;
            let codec = codec::Pkcs1V1_5::new(codec::BlockType::Type02, modulus_bytes, modulus_bytes - 11, SEED);

            let mut rsa = Rsa::<codec::Pkcs1V1_5> {
                codec,
                key_pair: keypair.clone(),
            };

            let ciphertext = rsa.encrypt(&plaintext).unwrap();
            let decrypted_text = rsa.decrypt(&ciphertext).unwrap();
            assert_eq!(plaintext, decrypted_text);
        }
    }
}