use digest::{Digest, DynDigest, FixedOutputReset};
use subtle::{Choice, ConstantTimeEq};
use super::mgf::{mgf1_xor, mgf1_xor_digest};
use crate::errors::{Error, Result};
use core::marker::PhantomData;
const MAX_DIGEST_LEN: usize = 64;
/// Validates the inputs of EMSA-PSS encoding (RFC 8017 §9.1.1, steps 2–3) using a
/// dynamically dispatched hasher.
///
/// * `m_hash` – the message digest; must be exactly `hash.output_size()` bytes.
/// * `em_bits` – bit length of the encoded message; `emLen = ceil(em_bits / 8)`.
/// * `salt` – the salt that would be embedded in the encoding.
/// * `hash` – hasher whose output size defines `hLen`.
///
/// # Errors
/// * [`Error::InputNotHashed`] if `m_hash.len() != hLen`.
/// * [`Error::Internal`] if `emLen < hLen + salt.len() + 2`.
///
/// The encoding steps after validation are still a `todo!()`, so a call that passes
/// both checks panics.
pub(crate) fn emsa_pss_encode(
    m_hash: &[u8],
    em_bits: usize,
    salt: &[u8],
    hash: &mut dyn DynDigest,
) -> Result<()> {
    let digest_len = hash.output_size();
    let encoded_len = (em_bits + 7) / 8;

    // mHash must already be a digest of the size produced by `hash`.
    if m_hash.len() != digest_len {
        return Err(Error::InputNotHashed);
    }

    // EM needs room for H, the salt, the 0x01 separator and the 0xBC trailer.
    let required_len = digest_len + salt.len() + 2;
    if encoded_len < required_len {
        return Err(Error::Internal);
    }

    todo!()
}
/// Validates the inputs of EMSA-PSS encoding (RFC 8017 §9.1.1, steps 2–3) for a
/// statically known digest type `D`.
///
/// * `m_hash` – the message digest; must be exactly `D::output_size()` bytes.
/// * `em_bits` – bit length of the encoded message; `emLen = ceil(em_bits / 8)`.
/// * `salt` – the salt that would be embedded in the encoding.
///
/// # Errors
/// * [`Error::InputNotHashed`] if `m_hash.len() != hLen`.
/// * [`Error::Internal`] if `emLen < hLen + salt.len() + 2`.
///
/// The encoding steps after validation are still a `todo!()`, so a call that passes
/// both checks panics.
pub(crate) fn emsa_pss_encode_digest<D>(m_hash: &[u8], em_bits: usize, salt: &[u8]) -> Result<()>
where
    D: Digest + FixedOutputReset,
{
    let digest_len = <D as Digest>::output_size();
    let encoded_len = (em_bits + 7) / 8;

    // mHash must already be a digest of the size produced by `D`.
    if m_hash.len() != digest_len {
        return Err(Error::InputNotHashed);
    }

    // EM needs room for H, the salt, the 0x01 separator and the 0xBC trailer.
    let required_len = digest_len + salt.len() + 2;
    if encoded_len < required_len {
        return Err(Error::Internal);
    }

    todo!()
}
fn emsa_pss_verify_pre<'a>(
m_hash: &[u8],
em: &'a mut [u8],
em_bits: usize,
s_len: usize,
h_len: usize,
) -> Result<(&'a mut [u8], &'a mut [u8])> {
if m_hash.len() != h_len {
return Err(Error::Verification);
}
let em_len = em.len(); if em_len < h_len + s_len + 2 {
return Err(Error::Verification);
}
if em[em.len() - 1] != 0xBC {
return Err(Error::Verification);
}
let (db, h) = em.split_at_mut(em_len - h_len - 1);
let h = &mut h[..h_len];
if db[0]
& (0xFF_u8
.checked_shl(8 - (8 * em_len - em_bits) as u32)
.unwrap_or(0))
!= 0
{
return Err(Error::Verification);
}
Ok((db, h))
}
fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice {
let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2);
let valid: Choice = zeroes
.iter()
.fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00));
valid & rest[0].ct_eq(&0x01)
}
pub(crate) fn emsa_pss_verify<D>(
m_hash: &[u8],
em: &mut [u8],
s_len: usize,
hash: &mut D,
key_bits: usize,
) -> Result<()>
where
D: Digest + FixedOutputReset,
{
let em_bits = key_bits - 1;
let em_len = (em_bits + 7) / 8;
let key_len = (key_bits + 7) / 8;
let h_len = <D as Digest>::output_size();
let em = &mut em[key_len - em_len..];
let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
mgf1_xor(db, hash, &*h);
db[0] &= 0xFF >> (8 * em_len - em_bits);
let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len);
let salt = &db[db.len() - s_len..];
let prefix = [0u8; 8];
Digest::update(hash, &prefix[..]);
Digest::update(hash, m_hash);
Digest::update(hash, salt);
let mut digest_storage = [0u8; MAX_DIGEST_LEN];
todo!()
}
/// EMSA-PSS verification (RFC 8017 §9.1.2) using a freshly constructed hasher of type `D`.
///
/// * `m_hash` – the already-hashed message.
/// * `em` – the encoded message recovered from the signature. It is re-sliced to its
///   last `emLen = ceil((key_bits - 1) / 8)` bytes, presumably because it is
///   `ceil(key_bits / 8)` bytes long (TODO confirm with callers).
/// * `s_len` – expected salt length in bytes.
/// * `key_bits` – modulus size in bits; `emBits = key_bits - 1`.
///
/// # Errors
/// Returns [`Error::Verification`] if the encoding is malformed, the `PS || 0x01`
/// prefix is wrong, or the recomputed hash `H'` does not match `H`.
pub(crate) fn emsa_pss_verify_digest<D>(
    m_hash: &[u8],
    em: &mut [u8],
    s_len: usize,
    key_bits: usize,
) -> Result<()>
where
    D: Digest + FixedOutputReset,
{
    let em_bits = key_bits - 1;
    let key_len = (key_bits + 7) / 8;
    let em_len = (em_bits + 7) / 8;
    let digest_len = <D as Digest>::output_size();

    // Keep only the trailing emLen bytes of the buffer.
    let encoded = &mut em[key_len - em_len..];
    let (db, expected_h) = emsa_pss_verify_pre(m_hash, encoded, em_bits, s_len, digest_len)?;

    // Unmask DB in place with MGF1 keyed by H, then clear the unused leftmost bits.
    let mut hasher = D::new();
    mgf1_xor_digest::<D>(db, &mut hasher, &*expected_h);
    db[0] &= 0xFF >> (8 * em_len - em_bits);

    // Constant-time check of the zero padding and the 0x01 separator.
    let padding_ok = emsa_pss_verify_salt(db, em_len, s_len, digest_len);
    let salt_start = db.len() - s_len;
    let salt = &db[salt_start..];

    // H' = Hash(0x00 * 8 || mHash || salt)
    Digest::update(&mut hasher, &[0u8; 8][..]);
    Digest::update(&mut hasher, m_hash);
    Digest::update(&mut hasher, salt);
    let recomputed_h = hasher.finalize_reset();

    // Combine both checks so neither short-circuits the other.
    let consistent: bool = (padding_ok & recomputed_h.ct_eq(expected_h)).into();
    if consistent {
        Ok(())
    } else {
        Err(Error::Verification)
    }
}