use num_bigint::BigUint;
use rand::prelude::Distribution;
use rand::rngs::StdRng;
use rand::SeedableRng;
use super::traits::Codec;
use crate::bytes::left_pad_0s;
/// Block-type marker stored in the second byte of a padded block.
///
/// The enum is fieldless, so it is cheap to copy and has total equality;
/// `Copy` and `Eq` are derived so callers do not need to clone it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BlockType {
    /// Marker byte `0x01`: the padding string is made of `0xff` bytes.
    Type01,
    /// Marker byte `0x02`: the padding string is made of random non-zero bytes.
    Type02,
}
impl TryFrom<u8> for BlockType {
type Error = crate::error::Error;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0x01 => Ok(BlockType::Type01),
0x02 => Ok(BlockType::Type02),
_ => Err(crate::error::Error::InvalidPadding),
}
}
}
/// Padding codec that produces and parses blocks laid out as
/// `0x00 | block type | padding | 0x00 | data`.
#[derive(Clone, Debug, PartialEq)]
pub struct Pkcs1V1_5 {
/// Padding flavour written by `encode` (`0xff` bytes for `Type01`, random
/// non-zero bytes for `Type02`).
pub block_type: BlockType,
/// Stored by `new` but not read by any code visible in this file.
/// NOTE(review): presumably the largest payload callers should pass to
/// `encode`; it is not enforced here, so confirm against the callers.
pub max_data_length: usize,
/// Size in bytes of every encoded block; `new` sets it to `modulus_length`.
pub total_length: usize,
/// Random source for `Type02` padding bytes, seeded in `new`.
rng: StdRng,
}
impl Pkcs1V1_5 {
pub fn new(
block_type: BlockType,
modulus_length: usize,
max_data_length: usize,
seed: u64,
) -> Self {
let total_length = modulus_length;
let rng = StdRng::seed_from_u64(seed);
Self {
block_type,
max_data_length,
total_length,
rng,
}
}
pub fn strip_padding(bytes: &[u8]) -> Result<Vec<u8>, crate::error::Error> {
if bytes[0] != 0x00 {
return Err(crate::error::Error::InvalidPadding);
}
let mut i = 2;
while i < bytes.len() && bytes[i] != 0x00 {
i += 1;
}
if i == bytes.len() {
return Err(crate::error::Error::InvalidPadding);
}
Ok(bytes[i + 1..].to_vec())
}
pub fn decode_type01(&self, bytes: &[u8]) -> Result<Vec<u8>, crate::error::Error> {
if bytes[0] != 0x00 || bytes[1] != 0x01 {
return Err(crate::error::Error::InvalidPadding);
}
let mut i = 2;
while i < bytes.len() && bytes[i] == 0xff {
i += 1;
}
if i == bytes.len() || bytes[i] != 0x00 {
return Err(crate::error::Error::InvalidPadding);
}
Ok(bytes[i + 1..self.total_length].to_vec())
}
pub fn decode_type02(&self, bytes: &[u8]) -> Result<Vec<u8>, crate::error::Error> {
if bytes[0] != 0x00 || bytes[1] != 0x02 {
return Err(crate::error::Error::InvalidPadding);
}
let mut i = 2;
while i < bytes.len() && bytes[i] != 0x00 {
i += 1;
}
if i == bytes.len() {
return Err(crate::error::Error::InvalidPadding);
}
Ok(bytes[i + 1..self.total_length].to_vec())
}
}
impl Codec for Pkcs1V1_5 {
fn encode(&mut self, chunk: &[u8]) -> Result<BigUint, crate::error::Error> {
if chunk.len() + 3 >= self.total_length {
return Err(crate::error::Error::MessageTooLarge);
}
let mut bytes = vec![0; self.total_length];
bytes[1] = match self.block_type {
BlockType::Type01 => 0x01,
BlockType::Type02 => 0x02,
};
let padding_length = self.total_length - 3 - chunk.len();
match self.block_type {
BlockType::Type01 => {
for x in bytes[2..2 + padding_length].iter_mut() {
*x = 0xff;
}
}
BlockType::Type02 => {
let distribution = rand::distributions::Uniform::from(1..=255);
for x in bytes[2..2 + padding_length].iter_mut() {
*x = distribution.sample(&mut self.rng);
}
}
}
bytes[padding_length + 2] = 0x00;
let data_start = padding_length + 3;
let data_end = data_start + chunk.len();
bytes[data_start..data_end].copy_from_slice(chunk);
Ok(BigUint::from_bytes_be(&bytes))
}
fn decode(&self, chunk: &BigUint) -> Result<Vec<u8>, crate::error::Error> {
let bytes: Vec<u8> = left_pad_0s(&chunk.to_bytes_be(), self.total_length);
if bytes[0] != 0x00 {
return Err(crate::error::Error::InvalidPadding);
}
let block_type = bytes[1].try_into()?;
match block_type {
BlockType::Type01 => self.decode_type01(&bytes),
BlockType::Type02 => self.decode_type02(&bytes),
}
}
}
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    /// Fixed seed so the random padding is identical on every run.
    const SEED: u64 = 1234;

    /// Block size, in bytes, shared by every test in this module.
    const MODULUS_LENGTH: usize = 32;

    /// Encodes `plaintext` with a fresh codec and decodes the result again.
    fn round_trip(block_type: BlockType, max_data_length: usize, plaintext: &[u8]) -> Vec<u8> {
        let mut codec = Pkcs1V1_5::new(block_type, MODULUS_LENGTH, max_data_length, SEED);
        let encoded = codec.encode(plaintext).unwrap();
        codec.decode(&encoded).unwrap()
    }

    #[test]
    fn test_type01_encode_decode() {
        let plaintext = b"hello, world!";
        let decoded = round_trip(BlockType::Type01, 21, plaintext);
        assert_eq!(plaintext, decoded.as_slice());
    }

    #[test]
    fn test_type02_encode_decode() {
        let plaintext = b"hello, world!";
        let decoded = round_trip(BlockType::Type02, 21, plaintext);
        assert_eq!(plaintext, decoded.as_slice());
    }

    proptest! {
        #[test]
        fn round_trip_codec_type01(
            plaintext in prop::collection::vec(any::<u8>(), 1..16),
        ) {
            let decoded = round_trip(BlockType::Type01, MODULUS_LENGTH - 8, &plaintext);
            assert_eq!(plaintext, decoded.as_slice());
        }

        #[test]
        fn round_trip_codec_type02(
            plaintext in prop::collection::vec(any::<u8>(), 1..16),
        ) {
            let decoded = round_trip(BlockType::Type02, MODULUS_LENGTH - 8, &plaintext);
            assert_eq!(plaintext, decoded.as_slice());
        }
    }
}