cryptoy 0.4.0

Toy implementations of cryptographic protocols for educational purposes
Documentation
use num_bigint::BigUint;
use rand::prelude::Distribution;
use rand::rngs::StdRng;
use rand::SeedableRng;

use super::traits::Codec;
use crate::bytes::left_pad_0s;

/// PKCS#1 v1.5 block type: the value stored in the second byte of an encoded block.
///
/// Type 00 is omitted since its encoding/decoding is ambiguous if the data
/// has a leading 0.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BlockType {
    /// Padding string made of `0xff` bytes.
    Type01,
    /// Padding string made of random non-zero bytes.
    Type02,
}

impl TryFrom<u8> for BlockType {
    type Error = crate::error::Error;

    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            0x01 => Ok(BlockType::Type01),
            0x02 => Ok(BlockType::Type02),
            _ => Err(crate::error::Error::InvalidPadding),
        }
    }
}

/// PKCS#1 v1.5-style padding codec that converts message chunks to and from
/// fixed-length big-endian integers.
///
/// An encoded block has the layout `0x00 || block type || PS || 0x00 || data`
/// and is `total_length` bytes long before being turned into an integer.
#[derive(Clone, Debug, PartialEq)]
pub struct Pkcs1V1_5 {
    /// Block type used when encoding. Decoding dispatches on the block-type
    /// byte found in the input rather than on this field.
    pub block_type: BlockType,

    /// max bytes which can be encoded per chunk
    ///
    /// NOTE(review): `encode` in this file does not check this value; it only
    /// requires `chunk.len() + 3 < total_length`.
    pub max_data_length: usize,

    /// total length of the encoded chunk before conversion to an int
    /// (set to the modulus length by `new`)
    pub total_length: usize,

    /// source of randomness for type-02 padding bytes, seeded from the `seed`
    /// passed to `new`
    rng: StdRng,
}

impl Pkcs1V1_5 {
    pub fn new(
        block_type: BlockType,
        modulus_length: usize,
        max_data_length: usize,
        seed: u64,
    ) -> Self {
        let total_length = modulus_length;
        let rng = StdRng::seed_from_u64(seed);

        Self {
            block_type,
            max_data_length,
            total_length,
            rng,
        }
    }

    pub fn strip_padding(bytes: &[u8]) -> Result<Vec<u8>, crate::error::Error> {
        if bytes[0] != 0x00 {
            return Err(crate::error::Error::InvalidPadding);
        }

        let mut i = 2;
        while i < bytes.len() && bytes[i] != 0x00 {
            i += 1;
        }

        if i == bytes.len() {
            return Err(crate::error::Error::InvalidPadding);
        }

        Ok(bytes[i + 1..].to_vec())
    }

    pub fn decode_type01(&self, bytes: &[u8]) -> Result<Vec<u8>, crate::error::Error> {
        if bytes[0] != 0x00 || bytes[1] != 0x01 {
            return Err(crate::error::Error::InvalidPadding);
        }

        let mut i = 2;
        while i < bytes.len() && bytes[i] == 0xff {
            i += 1;
        }

        if i == bytes.len() || bytes[i] != 0x00 {
            return Err(crate::error::Error::InvalidPadding);
        }

        Ok(bytes[i + 1..self.total_length].to_vec())
    }

    pub fn decode_type02(&self, bytes: &[u8]) -> Result<Vec<u8>, crate::error::Error> {
        if bytes[0] != 0x00 || bytes[1] != 0x02 {
            return Err(crate::error::Error::InvalidPadding);
        }

        let mut i = 2;
        while i < bytes.len() && bytes[i] != 0x00 {
            i += 1;
        }

        if i == bytes.len() {
            return Err(crate::error::Error::InvalidPadding);
        }

        Ok(bytes[i + 1..self.total_length].to_vec())
    }
}

impl Codec for Pkcs1V1_5 {
    /// Pads `chunk` into a `total_length`-byte block and reads it as a big-endian integer.
    ///
    /// Block layout: `0x00 || block type || PS || 0x00 || chunk`. PS is filled
    /// with `0xff` for `Type01`, or with bytes drawn uniformly from `1..=255`
    /// for `Type02`.
    ///
    /// # Errors
    ///
    /// Returns `MessageTooLarge` unless `chunk.len() + 3 < total_length`, i.e.
    /// unless at least one padding byte fits.
    fn encode(&mut self, chunk: &[u8]) -> Result<BigUint, crate::error::Error> {
        // Fixed overhead: leading 0x00, the block-type byte and the 0x00 separator.
        const OVERHEAD: usize = 3;

        if chunk.len() + OVERHEAD >= self.total_length {
            return Err(crate::error::Error::MessageTooLarge);
        }

        let padding_len = self.total_length - OVERHEAD - chunk.len();
        let separator_at = 2 + padding_len;

        let mut block = vec![0u8; self.total_length];
        block[1] = match self.block_type {
            BlockType::Type01 => 0x01,
            BlockType::Type02 => 0x02,
        };

        let padding = &mut block[2..separator_at];
        match self.block_type {
            BlockType::Type01 => padding.fill(0xff),
            BlockType::Type02 => {
                // Excluding 0 keeps every padding byte distinct from the separator.
                let non_zero = rand::distributions::Uniform::from(1..=255);
                padding.fill_with(|| non_zero.sample(&mut self.rng));
            }
        }

        block[separator_at] = 0x00;
        block[separator_at + 1..].copy_from_slice(chunk);

        Ok(BigUint::from_bytes_be(&block))
    }

    /// Reverses `encode`: left-pads the integer's big-endian bytes to
    /// `total_length`, then dispatches on the block-type byte.
    ///
    /// # Errors
    ///
    /// Returns `InvalidPadding` if the first byte is not `0x00`, if the
    /// block-type byte is unknown, or if the type-specific decoder rejects the
    /// block.
    fn decode(&self, chunk: &BigUint) -> Result<Vec<u8>, crate::error::Error> {
        let block: Vec<u8> = left_pad_0s(&chunk.to_bytes_be(), self.total_length);

        if block[0] != 0x00 {
            return Err(crate::error::Error::InvalidPadding);
        }

        match BlockType::try_from(block[1])? {
            BlockType::Type01 => self.decode_type01(&block),
            BlockType::Type02 => self.decode_type02(&block),
        }
    }
}

#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    const SEED: u64 = 1234;

    /// Builds a codec seeded with `SEED`, encodes `plaintext`, and returns the decoded bytes.
    fn encode_then_decode(
        block_type: BlockType,
        modulus_length: usize,
        max_data_length: usize,
        plaintext: &[u8],
    ) -> Vec<u8> {
        let mut codec = Pkcs1V1_5::new(block_type, modulus_length, max_data_length, SEED);
        let encoded = codec.encode(plaintext).unwrap();
        codec.decode(&encoded).unwrap()
    }

    #[test]
    fn test_type01_encode_decode() {
        let plaintext = b"hello, world!";
        let decoded = encode_then_decode(BlockType::Type01, 32, 21, plaintext);

        assert_eq!(plaintext, decoded.as_slice());
    }

    #[test]
    fn test_type02_encode_decode() {
        let plaintext = b"hello, world!";
        let decoded = encode_then_decode(BlockType::Type02, 32, 21, plaintext);

        assert_eq!(plaintext, decoded.as_slice());
    }

    proptest! {
        #[test]
        fn round_trip_codec_type01(
            plaintext in prop::collection::vec(any::<u8>(), 1..16),
        ) {
            let modulus_length = 32;
            let decoded = encode_then_decode(
                BlockType::Type01,
                modulus_length,
                modulus_length - 8,
                &plaintext,
            );

            assert_eq!(plaintext, decoded.as_slice());
        }

        #[test]
        fn round_trip_codec_type02(
            plaintext in prop::collection::vec(any::<u8>(), 1..16),
        ) {
            let modulus_length = 32;
            let decoded = encode_then_decode(
                BlockType::Type02,
                modulus_length,
                modulus_length - 8,
                &plaintext,
            );

            assert_eq!(plaintext, decoded.as_slice());
        }
    }
}