use alloc::vec::Vec;
use const_oid::AssociatedOid;
use crypto_bigint::{BoxedUint, Choice, CtAssign, CtEq, CtLt, CtSelect};
use digest::{Digest, KeyInit, Mac};
use hmac::Hmac;
use rand_core::TryCryptoRng;
use sha2::Sha256;
use zeroize::Zeroizing;
use crate::errors::{Error, Result};
/// HMAC instantiated with SHA-256; used by `irprf` and `derive_kdk` in this module.
type HmacSha256 = Hmac<Sha256>;
/// Fills `data` with random bytes from `rng`, guaranteeing that no byte is zero.
///
/// The whole buffer is filled in one call; every byte that came out as zero is then
/// resampled individually until it is non-zero.
///
/// # Errors
///
/// Propagates any error returned by the RNG.
#[inline]
fn non_zero_random_bytes<R: TryCryptoRng + ?Sized>(
    rng: &mut R,
    data: &mut [u8],
) -> core::result::Result<(), R::Error> {
    rng.try_fill_bytes(data)?;
    for byte in data.iter_mut() {
        // Rejection-sample this single position until it is non-zero.
        while *byte == 0 {
            rng.try_fill_bytes(core::slice::from_mut(byte))?;
        }
    }
    Ok(())
}
pub(crate) fn pkcs1v15_encrypt_pad<R>(
rng: &mut R,
msg: &[u8],
k: usize,
) -> Result<Zeroizing<Vec<u8>>>
where
R: TryCryptoRng + ?Sized,
{
if msg.len() + 11 > k {
return Err(Error::MessageTooLong);
}
let mut em = Zeroizing::new(vec![0u8; k]);
em[1] = 2;
non_zero_random_bytes(rng, &mut em[2..k - msg.len() - 1]).map_err(|_: R::Error| Error::Rng)?;
em[k - msg.len() - 1] = 0;
em[k - msg.len()..].copy_from_slice(msg);
Ok(em)
}
/// Pseudo-random function used for implicit rejection.
///
/// Concatenates `HMAC-SHA256(key, counter || label || bit_length)` blocks for
/// `counter = 0, 1, ...`. `counter` is a big-endian `u16` and `bit_length` is
/// `output_length * 8` as a big-endian `u16`. The result is truncated to exactly
/// `output_length` bytes.
#[inline]
fn irprf(key: &[u8], label: &[u8], output_length: usize) -> Vec<u8> {
    // The requested length in bits is the same for every block, so encode it once.
    let bit_length = ((output_length * 8) as u16).to_be_bytes();
    let mut stream = Vec::with_capacity(output_length);
    let mut block_index: u16 = 0;
    while stream.len() < output_length {
        let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
        // Feeding the parts separately is equivalent to MACing their concatenation.
        Mac::update(&mut mac, &block_index.to_be_bytes());
        Mac::update(&mut mac, label);
        Mac::update(&mut mac, &bit_length);
        stream.extend_from_slice(&Mac::finalize(mac).into_bytes());
        block_index += 1;
    }
    stream.truncate(output_length);
    stream
}
/// Derives the implicit-rejection key-derivation key (KDK).
///
/// The private exponent is encoded as a big-endian byte string of the same length as
/// `ciphertext` (left-padded with zeros) and hashed with SHA-256. That digest keys an
/// HMAC-SHA256 over `ciphertext`, and the 32-byte tag is returned.
#[inline]
pub(crate) fn derive_kdk(private_exponent: &BoxedUint, ciphertext: &[u8]) -> [u8; 32] {
    let k = ciphertext.len();
    let d_raw = Zeroizing::new(private_exponent.to_be_bytes());
    // Copy the low-order `take` bytes of d into the tail of a k-byte zero buffer.
    let take = d_raw.len().min(k);
    let dst_offset = k - take;
    let src_offset = d_raw.len() - take;
    let mut d_bytes = Zeroizing::new(vec![0u8; k]);
    d_bytes[dst_offset..].copy_from_slice(&d_raw[src_offset..]);

    let hmac_key = Sha256::digest(&*d_bytes);
    let mut mac = HmacSha256::new_from_slice(&hmac_key).expect("HMAC can take key of any size");
    Mac::update(&mut mac, ciphertext);
    Mac::finalize(mac).into_bytes().into()
}
/// Picks the synthetic message length for implicit rejection.
///
/// Interprets the first four bytes of `length_candidates` as a big-endian `u32` and
/// reduces it into `0..=k - 11`. When `k <= 11` the result is always `0`.
///
/// # Panics
///
/// Panics if `length_candidates` has fewer than four bytes.
#[inline]
fn select_alternative_length(length_candidates: &[u8], k: usize) -> u32 {
    let max_message_length = k.saturating_sub(11) as u32;
    let raw = u32::from_be_bytes(core::array::from_fn(|i| length_candidates[i]));
    if max_message_length == 0 {
        0
    } else {
        raw % (max_message_length + 1)
    }
}
#[inline]
pub(crate) fn pkcs1v15_encrypt_unpad(em: Vec<u8>, k: usize, kdk: &[u8; 32]) -> Vec<u8> {
let (valid, out, index) = match decrypt_inner(em, k) {
Ok(result) => result,
Err(_) => (0, vec![0; k], 0),
};
let max_message_length = k.saturating_sub(11);
let synthetic_full = irprf(kdk, b"message", k);
let length_candidates = irprf(kdk, b"length", 256);
let synthetic_length = select_alternative_length(&length_candidates, k) as usize;
let actual_length = (k as u32).saturating_sub(index) as usize;
let valid_choice = Choice::from_u8_lsb(valid);
let output_length = usize::ct_select(&synthetic_length, &actual_length, valid_choice);
let mut output = vec![0u8; max_message_length];
for (i, out_byte) in output.iter_mut().enumerate() {
let src_index = k - max_message_length + i;
let act_byte = out[src_index];
let syn_byte = synthetic_full[src_index];
*out_byte = u8::ct_select(&syn_byte, &act_byte, valid_choice);
}
let start = max_message_length.saturating_sub(output_length);
output[start..].to_vec()
}
/// Checks EME-PKCS1-v1_5 padding of `em` in constant time.
///
/// The expected layout is `0x00 || 0x02 || PS || 0x00 || M`, where `PS` is at least eight
/// non-zero bytes. Returns `(valid, em, index)`:
/// - `valid` is `1` if the padding is well-formed and `0` otherwise;
/// - `em` is handed back unchanged;
/// - `index` is the offset of the first message byte when valid, and `0` otherwise.
///
/// # Errors
///
/// Returns [`Error::Decryption`] when `k < 11` or `em.len() != k`. Both conditions
/// depend only on public sizes, so the early return reveals nothing about the plaintext.
#[inline]
fn decrypt_inner(em: Vec<u8>, k: usize) -> Result<(u8, Vec<u8>, u32)> {
    // A mismatched length would otherwise panic here (len < 2), make the scan run past
    // `k`, or make the caller index out of bounds.
    if k < 11 || em.len() != k {
        return Err(Error::Decryption);
    }
    let first_byte_is_zero = em[0].ct_eq(&0u8);
    let second_byte_is_two = em[1].ct_eq(&2u8);
    // Scan every byte after the header so the running time does not depend on where the
    // separator sits; only the position of the first zero byte is recorded.
    let mut looking_for_index = Choice::TRUE;
    let mut index = 0u32;
    for (i, el) in em.iter().enumerate().skip(2) {
        let equals0 = el.ct_eq(&0u8);
        index.ct_assign(&(i as u32), looking_for_index & equals0);
        looking_for_index &= !equals0;
    }
    // The separator must be at offset >= 10 so that PS spans at least 8 bytes.
    let valid_ps = !index.ct_lt(&10u32);
    let valid = first_byte_is_zero & second_byte_is_two & !looking_for_index & valid_ps;
    // On success point just past the separator; on failure report 0.
    index = u32::ct_select(&0, &(index + 1), valid);
    Ok((valid.to_u8(), em, index))
}
/// Builds an EMSA-PKCS1-v1_5 encoded message of length `k`.
///
/// Layout: `0x00 || 0x01 || PS || 0x00 || prefix || hashed`, where `PS` is filled with
/// `0xff` bytes (at least eight of them).
///
/// # Errors
///
/// Returns [`Error::MessageTooLong`] if `prefix.len() + hashed.len() + 11 > k`.
#[inline]
pub(crate) fn pkcs1v15_sign_pad(prefix: &[u8], hashed: &[u8], k: usize) -> Result<Vec<u8>> {
    let t_len = prefix.len() + hashed.len();
    if t_len + 11 > k {
        return Err(Error::MessageTooLong);
    }
    // Position of the 0x00 separator between PS and the DigestInfo.
    let separator = k - t_len - 1;
    let mut em = Vec::with_capacity(k);
    em.extend_from_slice(&[0x00, 0x01]);
    em.resize(separator, 0xff);
    em.push(0x00);
    em.extend_from_slice(prefix);
    em.extend_from_slice(hashed);
    Ok(em)
}
/// Verifies that `em` is exactly the EMSA-PKCS1-v1_5 encoding of `prefix || hashed`.
///
/// Expects `0x00 || 0x01 || PS || 0x00 || prefix || hashed` with `PS` made of `0xff`
/// bytes. The individual comparisons are accumulated with constant-time operations before
/// the single final decision.
///
/// # Errors
///
/// Returns [`Error::Verification`] in these cases:
/// - `k` is too small for the DigestInfo plus 11 bytes of overhead;
/// - `em` is not exactly `k` bytes long;
/// - any byte of the encoding differs.
#[inline]
pub(crate) fn pkcs1v15_sign_unpad(prefix: &[u8], hashed: &[u8], em: &[u8], k: usize) -> Result<()> {
    let hash_len = hashed.len();
    let t_len = prefix.len() + hashed.len();
    // A wrongly sized EM can never be a valid encoding. Rejecting it up front keeps the
    // indexing below in bounds and stops trailing bytes past `k` from being ignored.
    if k < t_len + 11 || em.len() != k {
        return Err(Error::Verification);
    }
    let mut ok = em[0].ct_eq(&0u8);
    ok &= em[1].ct_eq(&1u8);
    ok &= em[k - hash_len..k].ct_eq(hashed);
    ok &= em[k - t_len..k - hash_len].ct_eq(prefix);
    ok &= em[k - t_len - 1].ct_eq(&0u8);
    // PS occupies offsets 2..k - t_len - 1 and must be all 0xff.
    for el in em.iter().skip(2).take(k - t_len - 3) {
        ok &= el.ct_eq(&0xff)
    }
    if !ok.to_bool() {
        return Err(Error::Verification);
    }
    Ok(())
}
/// Returns the DER-encoded DigestInfo header for digest `D`, excluding the digest itself.
///
/// Encodes `SEQUENCE { SEQUENCE { OID, NULL }, OCTET STRING }`, where the trailing OCTET
/// STRING header announces `D::output_size()` bytes. All lengths use DER short form,
/// stored as single `u8` values.
#[inline]
pub(crate) fn pkcs1v15_generate_prefix<D>() -> Vec<u8>
where
    D: Digest + AssociatedOid,
{
    let oid = D::OID.as_bytes();
    let oid_len = oid.len() as u8;
    let digest_len = <D as Digest>::output_size() as u8;
    // The header consists of 6 tag/length bytes, the OID, then 4 more bytes.
    let mut prefix = Vec::with_capacity(oid.len() + 10);
    prefix.extend_from_slice(&[
        0x30,                        // outer SEQUENCE
        oid_len + 8 + digest_len,    // outer SEQUENCE length
        0x30,                        // AlgorithmIdentifier SEQUENCE
        oid_len + 4,                 // AlgorithmIdentifier length
        0x06,                        // OBJECT IDENTIFIER tag
        oid_len,                     // OID length
    ]);
    prefix.extend_from_slice(oid);
    // NULL parameters, then the OCTET STRING header for the digest.
    prefix.extend_from_slice(&[0x05, 0x00, 0x04, digest_len]);
    prefix
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::ChaCha8Rng;
    use rand_core::SeedableRng;

    #[test]
    fn test_non_zero_bytes() {
        // Vary the seed per round; reseeding with the same value would only repeat the
        // identical check ten times.
        for seed in 0..10u8 {
            let mut rng = ChaCha8Rng::from_seed([seed; 32]);
            let mut b = vec![0u8; 512];
            non_zero_random_bytes(&mut rng, &mut b).unwrap();
            assert!(b.iter().all(|&el| el != 0u8));
        }
    }

    #[test]
    fn test_encrypt_tiny_no_crash() {
        let mut rng = ChaCha8Rng::from_seed([42; 32]);
        let k = 8;
        let message = vec![1u8; 4];
        let res = pkcs1v15_encrypt_pad(&mut rng, &message, k);
        assert_eq!(res, Err(Error::MessageTooLong));
    }

    #[test]
    fn test_encrypt_pad_unpad_roundtrip() {
        let mut rng = ChaCha8Rng::from_seed([7; 32]);
        let k = 64;
        let msg: &[u8] = b"attack at dawn";
        let em = pkcs1v15_encrypt_pad(&mut rng, msg, k).unwrap();
        assert_eq!(em.len(), k);
        assert_eq!(em[0], 0);
        assert_eq!(em[1], 2);
        let kdk = [0x11u8; 32];
        let out = pkcs1v15_encrypt_unpad(em.to_vec(), k, &kdk);
        assert_eq!(out, msg.to_vec());
    }

    #[test]
    fn test_encrypt_unpad_invalid_padding_is_synthetic() {
        let k = 64;
        let mut em = vec![0u8; k];
        em[1] = 3; // wrong block type
        let kdk = [0x22u8; 32];
        let first = pkcs1v15_encrypt_unpad(em.clone(), k, &kdk);
        let second = pkcs1v15_encrypt_unpad(em, k, &kdk);
        // Implicit rejection must be deterministic and bounded by k - 11.
        assert_eq!(first, second);
        assert!(first.len() <= k - 11);
    }

    #[test]
    fn test_sign_pad_unpad_roundtrip() {
        let prefix = [0x30u8, 0x31];
        let hashed = [0xabu8; 32];
        let k = 64;
        let em = pkcs1v15_sign_pad(&prefix, &hashed, k).unwrap();
        assert_eq!(em.len(), k);
        assert!(pkcs1v15_sign_unpad(&prefix, &hashed, &em, k).is_ok());

        // Any change inside PS must be rejected.
        let mut tampered = em.clone();
        tampered[5] = 0xfe;
        assert!(pkcs1v15_sign_unpad(&prefix, &hashed, &tampered, k).is_err());

        // t_len + 11 = 45 bytes are required.
        assert_eq!(
            pkcs1v15_sign_pad(&prefix, &hashed, 44),
            Err(Error::MessageTooLong)
        );
    }
}