use alloc::vec::Vec;
use digest::{Digest, DynDigest, FixedOutputReset};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
use super::mgf::{mgf1_xor, mgf1_xor_digest};
use crate::errors::{Error, Result};
/// Builds an EMSA-PSS encoded message for an already-hashed message
/// (RFC 8017, section 9.1.1).
///
/// * `m_hash` - digest of the message; must be `hash.output_size()` bytes.
/// * `em_bits` - maximal bit length of the encoded message.
/// * `salt` - salt octets embedded in the encoding (may be empty).
/// * `hash` - digest used both for `H = Hash(M')` and for MGF1 masking.
///
/// The result is `ceil(em_bits / 8)` octets laid out as
/// `maskedDB || H || 0xBC`.
///
/// # Errors
///
/// * [`Error::InputNotHashed`] if `m_hash` is not exactly one digest long.
/// * [`Error::Internal`] if the encoding cannot fit the digest, the salt and
///   the two marker octets.
pub(crate) fn emsa_pss_encode(
    m_hash: &[u8],
    em_bits: usize,
    salt: &[u8],
    hash: &mut dyn DynDigest,
) -> Result<Vec<u8>> {
    let digest_len = hash.output_size();
    let salt_len = salt.len();
    let em_len = em_bits.div_ceil(8);

    // Step 2: the input has to be a digest produced by `hash`.
    if m_hash.len() != digest_len {
        return Err(Error::InputNotHashed);
    }
    // Step 3: room is needed for H, the salt, the 0x01 separator and 0xBC.
    if em_len < digest_len + salt_len + 2 {
        return Err(Error::Internal);
    }

    // DB = PS || 0x01 || salt occupies everything before H and the trailer.
    let db_len = em_len - digest_len - 1;
    let ps_len = db_len - salt_len - 1;
    let mut em = vec![0; em_len];

    {
        let (db, tail) = em.split_at_mut(db_len);
        let h = &mut tail[..digest_len];

        // Steps 5-6: H = Hash(0x00 * 8 || mHash || salt).
        let prefix = [0u8; 8];
        for chunk in [&prefix[..], m_hash, salt] {
            hash.update(chunk);
        }
        h.copy_from_slice(&hash.finalize_reset());

        // Steps 7-8: PS is already zero-filled; write the separator and salt.
        db[ps_len] = 0x01;
        db[ps_len + 1..].copy_from_slice(salt);

        // Steps 9-10: maskedDB = DB xor MGF1(H, db_len).
        mgf1_xor(db, hash, h);

        // Step 11: clear the bits of the first octet that lie above em_bits.
        db[0] &= 0xFF >> (8 * em_len - em_bits);
    }

    // Step 12: trailer octet.
    em[em_len - 1] = 0xBC;
    Ok(em)
}
/// Generic-digest variant of EMSA-PSS encoding (RFC 8017, section 9.1.1).
///
/// * `m_hash` - digest of the message; must be `D::output_size()` bytes.
/// * `em_bits` - maximal bit length of the encoded message.
/// * `salt` - salt octets embedded in the encoding (may be empty).
///
/// A fresh `D` computes `H = Hash(M')` and is then reused for MGF1 masking.
/// The result is `ceil(em_bits / 8)` octets laid out as
/// `maskedDB || H || 0xBC`.
///
/// # Errors
///
/// * [`Error::InputNotHashed`] if `m_hash` is not exactly one digest long.
/// * [`Error::Internal`] if the encoding cannot fit the digest, the salt and
///   the two marker octets.
pub(crate) fn emsa_pss_encode_digest<D>(
    m_hash: &[u8],
    em_bits: usize,
    salt: &[u8],
) -> Result<Vec<u8>>
where
    D: Digest + FixedOutputReset,
{
    let digest_len = <D as Digest>::output_size();
    let salt_len = salt.len();
    let em_len = em_bits.div_ceil(8);

    // Step 2: the input has to be a digest of type `D`.
    if m_hash.len() != digest_len {
        return Err(Error::InputNotHashed);
    }
    // Step 3: room is needed for H, the salt, the 0x01 separator and 0xBC.
    if em_len < digest_len + salt_len + 2 {
        return Err(Error::Internal);
    }

    let db_len = em_len - digest_len - 1;
    let ps_len = db_len - salt_len - 1;
    let mut em = vec![0; em_len];
    let (db, tail) = em.split_at_mut(db_len);
    let h = &mut tail[..digest_len];

    // Steps 5-6: H = Hash(0x00 * 8 || mHash || salt).
    let prefix = [0u8; 8];
    let mut hasher = D::new();
    Digest::update(&mut hasher, prefix);
    Digest::update(&mut hasher, m_hash);
    Digest::update(&mut hasher, salt);
    h.copy_from_slice(&hasher.finalize_reset());

    // Steps 7-8: DB = PS || 0x01 || salt (PS is the zero-filled prefix).
    db[ps_len] = 0x01;
    db[ps_len + 1..].copy_from_slice(salt);

    // Steps 9-11: mask DB with MGF1(H) and clear the bits above em_bits.
    mgf1_xor_digest(db, &mut hasher, h);
    db[0] &= 0xFF >> (8 * em_len - em_bits);

    // Step 12: trailer octet.
    em[em_len - 1] = 0xBC;
    Ok(em)
}
fn emsa_pss_verify_pre<'a>(
m_hash: &[u8],
em: &'a mut [u8],
em_bits: usize,
s_len: Option<usize>,
h_len: usize,
) -> Result<(&'a mut [u8], &'a mut [u8])> {
if m_hash.len() != h_len {
return Err(Error::Verification);
}
let em_len = em.len(); if let Some(s_len) = s_len {
if em_len < h_len + s_len + 2 {
return Err(Error::Verification);
}
}
if em[em.len() - 1] != 0xBC {
return Err(Error::Verification);
}
let (db, h) = em.split_at_mut(em_len - h_len - 1);
let h = &mut h[..h_len];
if db[0]
& (0xFF_u8
.checked_shl(8 - (8 * em_len - em_bits) as u32)
.unwrap_or(0))
!= 0
{
return Err(Error::Verification);
}
Ok((db, h))
}
fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice {
let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2);
let valid: Choice = zeroes
.iter()
.fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00));
valid & rest[0].ct_eq(&0x01)
}
/// Recovers the salt length from an unmasked data block `db` when no expected
/// salt length was supplied (RFC 8017, section 9.1.2, step 10).
///
/// Visits `db[0..=em_len - h_len - 2]`, which is all of `db` when it is the
/// `em_len - h_len - 1`-octet block returned by `emsa_pss_verify_pre`. It
/// locates the first `0x01` separator; every octet before that separator
/// must be `0x00`.
///
/// The loop has no data-dependent branches. Every octet is examined, and
/// state changes go through `subtle` `Choice` and `conditional_select`
/// operations.
///
/// Returns `(salt_len, valid)`:
/// * `salt_len` is the number of octets after the separator.
/// * `salt_len` is forced to 0 when `valid` is unset, i.e. when no separator
///   was found or an octet other than `0x00`/`0x01` precedes it.
///
/// NOTE(review): `em_len - h_len - 2` underflows unless the caller guarantees
/// `em_len >= h_len + 2`. The `as u32` casts also assume both lengths fit in
/// 32 bits.
fn emsa_pss_get_salt_len(db: &[u8], em_len: usize, h_len: usize) -> (usize, Choice) {
// Presumably u32 is used because `subtle` offers `ConditionallySelectable`
// for it (used for `separator_pos` and the result below).
let em_len = em_len as u32;
let h_len = h_len as u32;
// Index of the last octet of the data block.
let max_scan_len = em_len - h_len - 2;
// Position of the first 0x01 seen so far (meaningful only once found).
let mut separator_pos = 0u32;
let mut found_separator = Choice::from(0u8);
// Cleared if any octet before the separator is neither 0x00 nor 0x01.
let mut padding_valid = Choice::from(1u8);
for i in 0..=max_scan_len {
let byte_val = db[i as usize];
let is_zero = byte_val.ct_eq(&0x00);
let is_separator = byte_val.ct_eq(&0x01);
let is_invalid = !(is_zero | is_separator);
// Only the first 0x01 marks the end of PS; later ones belong to the salt.
let should_update_pos = is_separator & !found_separator;
separator_pos = u32::conditional_select(&separator_pos, &i, should_update_pos);
found_separator =
Choice::conditional_select(&found_separator, &Choice::from(1u8), should_update_pos);
// Stray octets only matter while still inside PS; salt octets after the
// separator may take any value.
let corrupts_padding = is_invalid & !found_separator;
padding_valid &= !corrupts_padding;
}
// separator_pos is never greater than max_scan_len, so this does not wrap in
// practice; the result is the number of octets after the separator.
let salt_len = max_scan_len.wrapping_sub(separator_pos);
let final_valid = found_separator & padding_valid;
// Report a zero salt length whenever the padding is invalid.
let result_len = u32::conditional_select(&0u32, &salt_len, final_valid);
(result_len as usize, final_valid)
}
pub(crate) fn emsa_pss_verify(
m_hash: &[u8],
em: &mut [u8],
s_len: Option<usize>,
hash: &mut dyn DynDigest,
key_bits: usize,
) -> Result<()> {
let em_bits = key_bits - 1;
let em_len = em_bits.div_ceil(8);
let key_len = key_bits.div_ceil(8);
let h_len = hash.output_size();
let em = &mut em[key_len - em_len..];
let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
mgf1_xor(db, hash, &*h);
db[0] &= 0xFF >> (8 * em_len - em_bits);
let (s_len, salt_valid) = match s_len {
Some(s_len) => (s_len, emsa_pss_verify_salt(db, em_len, s_len, h_len)),
None => emsa_pss_get_salt_len(db, em_len, h_len),
};
let salt = &db[db.len() - s_len..];
let prefix = [0u8; 8];
hash.update(&prefix[..]);
hash.update(m_hash);
hash.update(salt);
let h0 = hash.finalize_reset();
if (salt_valid & h0.ct_eq(h)).into() {
Ok(())
} else {
Err(Error::Verification)
}
}
pub(crate) fn emsa_pss_verify_digest<D>(
m_hash: &[u8],
em: &mut [u8],
s_len: Option<usize>,
key_bits: usize,
) -> Result<()>
where
D: Digest + FixedOutputReset,
{
let em_bits = key_bits - 1;
let em_len = em_bits.div_ceil(8);
let key_len = key_bits.div_ceil(8);
let h_len = <D as Digest>::output_size();
let em = &mut em[key_len - em_len..];
let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
let mut hash = D::new();
mgf1_xor_digest::<D>(db, &mut hash, &*h);
db[0] &= 0xFF >> (8 * em_len - em_bits);
let (s_len, salt_valid) = match s_len {
Some(s_len) => (s_len, emsa_pss_verify_salt(db, em_len, s_len, h_len)),
None => emsa_pss_get_salt_len(db, em_len, h_len),
};
let salt = &db[db.len() - s_len..];
let prefix = [0u8; 8];
Digest::update(&mut hash, &prefix[..]);
Digest::update(&mut hash, m_hash);
Digest::update(&mut hash, salt);
let h0 = hash.finalize_reset();
if (salt_valid & h0.ct_eq(h)).into() {
Ok(())
} else {
Err(Error::Verification)
}
}