use digest::{Digest, DynDigest, FixedOutputReset};
use rand_core::CryptoRngCore;
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
use zeroize::Zeroizing;
use super::mgf::{mgf1_xor, mgf1_xor_digest};
use crate::errors::{Error, Result};
mod label;
pub use label::Label;
/// Exclusive upper bound on OAEP label length in bytes (2^61); labels at least
/// this long are rejected with `Error::LabelTooLong` by the encrypt/decrypt entry points.
const MAX_LABEL_LEN: u64 = 0x2000_0000_0000_0000;
/// Presumably the largest digest output size supported, in bytes (64 = SHA-512);
/// not referenced in this part of the file — TODO confirm against its users.
const MAX_DIGEST_LEN: usize = 0x40;
#[inline]
fn encrypt_internal<'a, R: CryptoRngCore + ?Sized, MGF: FnMut(&mut [u8], &mut [u8])>(
rng: &mut R,
msg: &[u8],
p_hash: &[u8],
h_size: usize,
k: usize,
mut mgf: MGF,
storage: &'a mut [u8],
) -> Result<&'a [u8]> {
if msg.len() + 2 * h_size + 2 > k {
return Err(Error::MessageTooLong);
}
let mut em = storage.get_mut(..k).ok_or(Error::OutputBufferTooSmall)?;
let (_, payload) = em.split_at_mut(1);
let (seed, db) = payload.split_at_mut(h_size);
rng.fill_bytes(seed);
let db_len = k - h_size - 1;
db[0..h_size].copy_from_slice(p_hash);
db[db_len - msg.len() - 1] = 1;
db[db_len - msg.len()..].copy_from_slice(msg);
mgf(seed, db);
Ok(em)
}
#[inline]
pub(crate) fn oaep_encrypt<'a, R: CryptoRngCore + ?Sized, D>(
rng: &mut R,
msg: &[u8],
digest: &mut D,
mgf_digest: &mut D,
label: Option<Label>,
k: usize,
storage: &'a mut [u8],
) -> Result<&'a [u8]>
where
D: Digest + FixedOutputReset,
{
let h_size = <D as Digest>::output_size();
todo!()
}
/// OAEP-encodes `msg` into `storage`, using the digest type `D` to hash the label and a
/// freshly constructed `MGD` instance to drive MGF1.
///
/// `k` is the length in bytes of the encoded message (the RSA modulus size). On success
/// the encoded message is returned as a slice borrowed from `storage`.
///
/// # Errors
/// * `Error::LabelTooLong` if the label is `MAX_LABEL_LEN` bytes or longer.
/// * Any error produced by `encrypt_internal` (message too long, storage too small).
#[inline]
pub(crate) fn oaep_encrypt_digest<
    'a,
    R: CryptoRngCore + ?Sized,
    D: Digest,
    MGD: Digest + FixedOutputReset,
>(
    rng: &mut R,
    msg: &[u8],
    label: Option<Label>,
    k: usize,
    storage: &'a mut [u8],
) -> Result<&'a [u8]> {
    let hash_len = <D as Digest>::output_size();

    let label = label.unwrap_or_default();
    let label_too_long = label.len() as u64 >= MAX_LABEL_LEN;
    if label_too_long {
        return Err(Error::LabelTooLong);
    }

    // lHash = Hash(L)
    let label_hash = D::digest(label.as_bytes());

    encrypt_internal(
        rng,
        msg,
        &label_hash,
        hash_len,
        k,
        |seed, db| {
            let mut mgf = MGD::new();
            // maskedDB = DB xor MGF1(seed) ...
            mgf1_xor_digest(db, &mut mgf, seed);
            // ... then maskedSeed = seed xor MGF1(maskedDB).
            mgf1_xor_digest(seed, &mut mgf, db);
        },
        storage,
    )
}
#[inline]
pub(crate) fn oaep_decrypt<'a, D>(
em: &mut [u8],
digest: &mut D,
mgf_digest: &mut D,
label: Option<Label>,
k: usize,
storage: &'a mut [u8],
) -> Result<&'a [u8]>
where
D: Digest + FixedOutputReset,
{
let h_size = <D as Digest>::output_size();
todo!()
}
#[inline]
pub(crate) fn oaep_decrypt_digest<'a, D: Digest, MGD: Digest + FixedOutputReset>(
em: &mut [u8],
label: Option<Label>,
k: usize,
storage: &'a mut [u8],
) -> Result<&'a [u8]> {
let h_size = <D as Digest>::output_size();
let label = label.unwrap_or_default();
if label.len() as u64 >= MAX_LABEL_LEN {
return Err(Error::LabelTooLong);
}
let expected_p_hash = D::digest(label.as_bytes());
let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| {
let mut mgf_digest = MGD::new();
mgf1_xor_digest(seed, &mut mgf_digest, db);
mgf1_xor_digest(db, &mut mgf_digest, seed);
}, storage)?;
if res.is_none().into() {
return Err(Error::Decryption);
}
let (out, index) = res.unwrap();
todo!()
}
#[inline]
fn decrypt_inner<'a,MGF: FnMut(&mut [u8], &mut [u8])>(
em: &mut [u8],
h_size: usize,
expected_p_hash: &[u8],
k: usize,
mut mgf: MGF,
storage: &'a mut [u8],
) -> Result<CtOption<(&'a [u8], u32)>> {
if k < 11 {
return Err(Error::Decryption);
}
if k < h_size * 2 + 2 {
return Err(Error::Decryption);
}
let first_byte_is_zero = em[0].ct_eq(&0u8);
let (_, payload) = em.split_at_mut(1);
let (seed, db) = payload.split_at_mut(h_size);
mgf(seed, db);
let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash);
let mut looking_for_index = Choice::from(1u8);
let mut index = 0u32;
let mut nonzero_before_one = Choice::from(0u8);
for (i, el) in db.iter().skip(h_size).enumerate() {
let equals0 = el.ct_eq(&0u8);
let equals1 = el.ct_eq(&1u8);
index.conditional_assign(&(i as u32), looking_for_index & equals1);
looking_for_index &= !equals1;
nonzero_before_one |= looking_for_index & !equals0;
}
let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index;
todo!()
}