#![cfg(all(rsa, rsa_oaep))]
use core::marker::PhantomData;
use crate::rsa::RSA;
use crate::sys;
#[cfg(random)]
use crate::random::RNG;
/// Private sealing module.
///
/// The module itself is not `pub`, so code outside this module tree cannot name
/// `private::Sealed`. That makes it impossible to implement [`Hash`] (which
/// requires `Sealed`) for foreign types.
mod private {
    /// Marker supertrait of [`super::Hash`]; implemented only by the hash
    /// marker types declared in this file.
    pub trait Sealed {}
}
/// Hash algorithm used for RSA-OAEP padding.
///
/// The trait is sealed: only the marker types in this module (`Sha1`,
/// `Sha224`, `Sha256`, `Sha384`, `Sha512`, each behind its own `cfg`) can
/// implement it.
pub trait Hash: private::Sealed {
    /// wolfCrypt hash-type identifier (one of the `wc_HashType_*` values), passed
    /// to the OAEP encrypt/decrypt calls as the hash selector.
    const HASH_TYPE: u32;
    /// wolfCrypt mask-generation-function identifier (one of the `WC_MGF1*`
    /// values), passed to the OAEP calls alongside `HASH_TYPE`.
    const MGF: i32;
}
/// Type-level marker selecting SHA-1 (with MGF1-SHA1) for OAEP.
///
/// Uninhabited: it is used only as the `H` parameter of the key types.
#[cfg(sha)]
pub enum Sha1 {}
#[cfg(sha)]
impl Hash for Sha1 {
    const MGF: i32 = sys::WC_MGF1SHA1 as i32;
    const HASH_TYPE: u32 = sys::wc_HashType_WC_HASH_TYPE_SHA;
}
#[cfg(sha)]
impl private::Sealed for Sha1 {}
/// Type-level marker selecting SHA-224 (with MGF1-SHA224) for OAEP.
///
/// Uninhabited: it is used only as the `H` parameter of the key types.
#[cfg(sha224)]
pub enum Sha224 {}
#[cfg(sha224)]
impl Hash for Sha224 {
    const MGF: i32 = sys::WC_MGF1SHA224 as i32;
    const HASH_TYPE: u32 = sys::wc_HashType_WC_HASH_TYPE_SHA224;
}
#[cfg(sha224)]
impl private::Sealed for Sha224 {}
/// Type-level marker selecting SHA-256 (with MGF1-SHA256) for OAEP.
///
/// Uninhabited: it is used only as the `H` parameter of the key types.
#[cfg(sha256)]
pub enum Sha256 {}
#[cfg(sha256)]
impl Hash for Sha256 {
    const MGF: i32 = sys::WC_MGF1SHA256 as i32;
    const HASH_TYPE: u32 = sys::wc_HashType_WC_HASH_TYPE_SHA256;
}
#[cfg(sha256)]
impl private::Sealed for Sha256 {}
/// Type-level marker selecting SHA-384 (with MGF1-SHA384) for OAEP.
///
/// Uninhabited: it is used only as the `H` parameter of the key types.
#[cfg(sha384)]
pub enum Sha384 {}
#[cfg(sha384)]
impl Hash for Sha384 {
    const MGF: i32 = sys::WC_MGF1SHA384 as i32;
    const HASH_TYPE: u32 = sys::wc_HashType_WC_HASH_TYPE_SHA384;
}
#[cfg(sha384)]
impl private::Sealed for Sha384 {}
/// Type-level marker selecting SHA-512 (with MGF1-SHA512) for OAEP.
///
/// Uninhabited: it is used only as the `H` parameter of the key types.
#[cfg(sha512)]
pub enum Sha512 {}
#[cfg(sha512)]
impl Hash for Sha512 {
    const MGF: i32 = sys::WC_MGF1SHA512 as i32;
    const HASH_TYPE: u32 = sys::wc_HashType_WC_HASH_TYPE_SHA512;
}
#[cfg(sha512)]
impl private::Sealed for Sha512 {}
/// An RSA-OAEP ciphertext of exactly `N` bytes.
///
/// `N` is the modulus length of the key that produced or will consume it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Ciphertext<const N: usize>([u8; N]);

impl<const N: usize> Ciphertext<N> {
    /// Wraps raw ciphertext bytes as-is (no validation is performed).
    pub const fn from_bytes(bytes: [u8; N]) -> Self {
        Ciphertext(bytes)
    }

    /// Returns a copy of the raw ciphertext bytes.
    pub const fn to_bytes(&self) -> [u8; N] {
        let Ciphertext(raw) = *self;
        raw
    }
}
impl<const N: usize> AsRef<[u8]> for Ciphertext<N> {
fn as_ref(&self) -> &[u8] {
&self.0
}
}
impl<const N: usize> TryFrom<&[u8]> for Ciphertext<N> {
type Error = i32;
fn try_from(bytes: &[u8]) -> Result<Self, i32> {
let arr: [u8; N] = bytes.try_into()
.map_err(|_| sys::wolfCrypt_ErrorCodes_BAD_FUNC_ARG)?;
Ok(Self(arr))
}
}
impl<const N: usize> From<Ciphertext<N>> for [u8; N] {
fn from(ct: Ciphertext<N>) -> Self {
ct.0
}
}
/// Verifies that `rsa`'s encryption size (as reported by `get_encrypt_size`)
/// equals `expected` bytes.
///
/// Errors from `get_encrypt_size` are propagated unchanged; a size mismatch
/// yields `BAD_FUNC_ARG`.
fn check_modulus_size(rsa: &RSA, expected: usize) -> Result<(), i32> {
    match rsa.get_encrypt_size()? {
        size if size == expected => Ok(()),
        _ => Err(sys::wolfCrypt_ErrorCodes_BAD_FUNC_ARG),
    }
}
/// Capacity, in bytes, of the inline public-exponent buffer in [`EncryptingKey`].
///
/// The constructors reject exponents that do not fit in this many bytes.
const MAX_E_LEN: usize = 8;

/// RSA-OAEP public (encrypting) key with an `N`-byte modulus, padding with hash `H`.
///
/// The key material is kept inline as plain bytes (no wolfCrypt handle). A
/// temporary wolfCrypt key is rebuilt from these bytes for each encryption.
pub struct EncryptingKey<H: Hash, const N: usize> {
    /// Modulus bytes, exactly `N` of them.
    n: [u8; N],
    /// Public exponent stored in `e[..e_len]`; bytes past `e_len` are zero padding.
    e: [u8; MAX_E_LEN],
    /// Number of meaningful bytes at the start of `e`.
    e_len: u8,
    /// Records the hash choice at the type level without storing an `H`.
    _hash: PhantomData<H>,
}
// Written by hand instead of derived: `#[derive(Clone, Copy)]` would add
// `H: Clone` / `H: Copy` bounds, and the hash marker enums in this module
// implement neither. All stored fields are plain byte arrays, so copying is
// always valid regardless of `H`.
impl<H: Hash, const N: usize> Copy for EncryptingKey<H, N> {}
impl<H: Hash, const N: usize> Clone for EncryptingKey<H, N> {
    /// Bitwise copy; identical to `Copy`.
    fn clone(&self) -> Self {
        *self
    }
}
/// Shows the modulus and the meaningful exponent bytes.
///
/// The exponent's zero padding and the hash parameter `H` are not printed.
impl<H: Hash, const N: usize> core::fmt::Debug for EncryptingKey<H, N> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let modulus: &[u8] = &self.n;
        let exponent: &[u8] = self.exponent();
        let mut out = f.debug_struct("EncryptingKey");
        out.field("n", &modulus);
        out.field("e", &exponent);
        out.finish()
    }
}
/// Two keys are equal when their modulus bytes and meaningful exponent bytes match.
///
/// The zero padding past `e_len` is deliberately ignored.
impl<H: Hash, const N: usize> PartialEq for EncryptingKey<H, N> {
    fn eq(&self, other: &Self) -> bool {
        if self.exponent() != other.exponent() {
            return false;
        }
        self.n == other.n
    }
}
impl<H: Hash, const N: usize> Eq for EncryptingKey<H, N> {}
impl<H: Hash, const N: usize> EncryptingKey<H, N> {
pub fn from_components(n: &[u8], e: &[u8]) -> Result<Self, i32> {
if n.len() != N || e.is_empty() || e.len() > MAX_E_LEN {
return Err(sys::wolfCrypt_ErrorCodes_BAD_FUNC_ARG);
}
let mut n_arr = [0u8; N];
n_arr.copy_from_slice(n);
let mut e_arr = [0u8; MAX_E_LEN];
e_arr[..e.len()].copy_from_slice(e);
Ok(Self {
n: n_arr,
e: e_arr,
e_len: e.len() as u8,
_hash: PhantomData,
})
}
pub fn from_rsa(rsa: &RSA) -> Result<Self, i32> {
check_modulus_size(rsa, N)?;
let mut n = [0u8; N];
let mut e = [0u8; MAX_E_LEN];
let mut n_len: u32 = n.len() as u32;
let mut e_len: u32 = e.len() as u32;
#[cfg(rsa_const_api)]
let key = &rsa.wc_rsakey;
#[cfg(not(rsa_const_api))]
let key = core::ptr::addr_of!(rsa.wc_rsakey) as *mut sys::RsaKey;
let rc = unsafe {
sys::wc_RsaFlattenPublicKey(
key,
e.as_mut_ptr(), &mut e_len,
n.as_mut_ptr(), &mut n_len,
)
};
if rc != 0 {
return Err(rc);
}
if (n_len as usize) != N || e_len == 0 || (e_len as usize) > MAX_E_LEN {
return Err(sys::wolfCrypt_ErrorCodes_BAD_FUNC_ARG);
}
Ok(Self {
n,
e,
e_len: e_len as u8,
_hash: PhantomData,
})
}
pub fn from_public_der(der: &[u8]) -> Result<Self, i32> {
let rsa = RSA::new_public_from_der(der)?;
Self::from_rsa(&rsa)
}
pub const fn modulus(&self) -> &[u8; N] {
&self.n
}
pub fn exponent(&self) -> &[u8] {
&self.e[..self.e_len as usize]
}
#[cfg(random)]
pub fn encrypt(&self, rng: &RNG, msg: &[u8]) -> Result<Ciphertext<N>, i32> {
self.encrypt_inner(rng, msg, None)
}
#[cfg(random)]
pub fn encrypt_with_label(&self, rng: &RNG, msg: &[u8], label: &[u8]) -> Result<Ciphertext<N>, i32> {
self.encrypt_inner(rng, msg, Some(label))
}
#[cfg(random)]
fn encrypt_inner(&self, rng: &RNG, msg: &[u8], label: Option<&[u8]>) -> Result<Ciphertext<N>, i32> {
let mut rsa = RSA::new_public_from_raw(&self.n, self.exponent())?;
let mut out = [0u8; N];
let len = rsa.public_encrypt_oaep_ex(msg, &mut out, H::HASH_TYPE, H::MGF, label, rng)?;
if len != N {
return Err(sys::wolfCrypt_ErrorCodes_BAD_FUNC_ARG);
}
Ok(Ciphertext(out))
}
}
/// RSA-OAEP private (decrypting) key with an `N`-byte modulus, using hash `H`.
///
/// It wraps an owned wolfCrypt `RSA` key. The constructors that go through
/// `from_rsa` check that the modulus is `N` bytes.
pub struct DecryptingKey<H: Hash, const N: usize> {
    /// The underlying wolfCrypt RSA private key.
    inner: RSA,
    /// Records the OAEP hash choice at the type level.
    _hash: PhantomData<H>,
}
impl<H: Hash, const N: usize> DecryptingKey<H, N> {
#[cfg(all(random, rsa_keygen))]
pub fn generate(rng: RNG) -> Result<Self, i32> {
let bits: i32 = (N * 8).try_into().map_err(|_| sys::wolfCrypt_ErrorCodes_BAD_FUNC_ARG)?;
let mut rsa = RSA::generate(bits, 65537, &rng)?;
rsa.set_rng(rng)?;
Ok(Self { inner: rsa, _hash: PhantomData })
}
#[cfg(random)]
pub fn from_rsa(rsa: RSA, rng: RNG) -> Result<Self, i32> {
check_modulus_size(&rsa, N)?;
let mut rsa = rsa;
rsa.set_rng(rng)?;
Ok(Self { inner: rsa, _hash: PhantomData })
}
#[cfg(random)]
pub fn from_private_der(der: &[u8], rng: RNG) -> Result<Self, i32> {
let rsa = RSA::new_from_der(der)?;
Self::from_rsa(rsa, rng)
}
pub fn as_rsa(&self) -> &RSA {
&self.inner
}
pub fn into_rsa(self) -> RSA {
self.inner
}
pub fn encrypting_key(&self) -> Result<EncryptingKey<H, N>, i32> {
EncryptingKey::from_rsa(&self.inner)
}
pub fn decrypt(&mut self, ciphertext: &Ciphertext<N>, out: &mut [u8]) -> Result<usize, i32> {
self.decrypt_inner(ciphertext, out, None)
}
pub fn decrypt_with_label(&mut self, ciphertext: &Ciphertext<N>, out: &mut [u8], label: &[u8]) -> Result<usize, i32> {
self.decrypt_inner(ciphertext, out, Some(label))
}
fn decrypt_inner(&mut self, ciphertext: &Ciphertext<N>, out: &mut [u8], label: Option<&[u8]>) -> Result<usize, i32> {
self.inner.private_decrypt_oaep_ex(&ciphertext.0, out, H::HASH_TYPE, H::MGF, label)
}
}