use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::cell::UnsafeCell;
use core::ffi::c_void;
use core::ptr;
use crate::error::{check, len_as_u32, WolfCryptError};
use crate::rand::WolfRng;
use wolfcrypt_rs::{
wc_HpkeDeserializePublicKey, wc_HpkeFreeKey, wc_HpkeGenerateKeyPair, wc_HpkeInit,
wc_HpkeOpenBase, wc_HpkeSealBase, wc_HpkeSerializePublicKey, HPKE_Nt_MAX, WcHpke,
DHKEM_P256_ENC_LEN, DHKEM_P256_HKDF_SHA256, DHKEM_P384_ENC_LEN, DHKEM_P384_HKDF_SHA384,
DHKEM_P521_ENC_LEN, DHKEM_P521_HKDF_SHA512, DHKEM_X25519_ENC_LEN, DHKEM_X25519_HKDF_SHA256,
DHKEM_X448_ENC_LEN, DHKEM_X448_HKDF_SHA512, HPKE_AES_128_GCM, HPKE_AES_256_GCM,
HPKE_HKDF_SHA256, HPKE_HKDF_SHA384, HPKE_HKDF_SHA512,
};
/// An HPKE ciphersuite: the KEM, KDF and AEAD identifiers handed to `wc_HpkeInit`.
///
/// The fields are private, so code outside this module obtains values through the
/// associated constants (e.g. `HpkeSuite::X25519_SHA256_AES128`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HpkeSuite {
/// KEM identifier (one of the `DHKEM_*` constants used by the presets).
kem: i32,
/// KDF identifier (one of the `HPKE_HKDF_*` constants used by the presets).
kdf: i32,
/// AEAD identifier (one of the `HPKE_AES_*_GCM` constants used by the presets).
aead: i32,
}
impl HpkeSuite {
    /// Builds a suite from raw wolfCrypt identifiers; keeps the presets below one line each.
    const fn from_ids(kem: i32, kdf: i32, aead: i32) -> Self {
        Self { kem, kdf, aead }
    }

    /// KEM `DHKEM_P256_HKDF_SHA256`, KDF `HPKE_HKDF_SHA256`, AEAD `HPKE_AES_128_GCM`.
    pub const P256_SHA256_AES128: Self =
        Self::from_ids(DHKEM_P256_HKDF_SHA256, HPKE_HKDF_SHA256, HPKE_AES_128_GCM);

    /// KEM `DHKEM_P256_HKDF_SHA256`, KDF `HPKE_HKDF_SHA256`, AEAD `HPKE_AES_256_GCM`.
    pub const P256_SHA256_AES256: Self =
        Self::from_ids(DHKEM_P256_HKDF_SHA256, HPKE_HKDF_SHA256, HPKE_AES_256_GCM);

    /// KEM `DHKEM_X25519_HKDF_SHA256`, KDF `HPKE_HKDF_SHA256`, AEAD `HPKE_AES_128_GCM`.
    pub const X25519_SHA256_AES128: Self =
        Self::from_ids(DHKEM_X25519_HKDF_SHA256, HPKE_HKDF_SHA256, HPKE_AES_128_GCM);

    /// KEM `DHKEM_X25519_HKDF_SHA256`, KDF `HPKE_HKDF_SHA256`, AEAD `HPKE_AES_256_GCM`.
    pub const X25519_SHA256_AES256: Self =
        Self::from_ids(DHKEM_X25519_HKDF_SHA256, HPKE_HKDF_SHA256, HPKE_AES_256_GCM);

    /// KEM `DHKEM_P384_HKDF_SHA384`, KDF `HPKE_HKDF_SHA384`, AEAD `HPKE_AES_256_GCM`.
    pub const P384_SHA384_AES256: Self =
        Self::from_ids(DHKEM_P384_HKDF_SHA384, HPKE_HKDF_SHA384, HPKE_AES_256_GCM);

    /// KEM `DHKEM_P521_HKDF_SHA512`, KDF `HPKE_HKDF_SHA512`, AEAD `HPKE_AES_256_GCM`.
    pub const P521_SHA512_AES256: Self =
        Self::from_ids(DHKEM_P521_HKDF_SHA512, HPKE_HKDF_SHA512, HPKE_AES_256_GCM);

    /// KEM `DHKEM_X448_HKDF_SHA512`, KDF `HPKE_HKDF_SHA512`, AEAD `HPKE_AES_256_GCM`.
    pub const X448_SHA512_AES256: Self =
        Self::from_ids(DHKEM_X448_HKDF_SHA512, HPKE_HKDF_SHA512, HPKE_AES_256_GCM);

    /// Size in bytes of a serialized encapsulated key (`enc`) for this suite's KEM.
    ///
    /// Returns `0` when the KEM identifier is not one of the five `DHKEM_*` values
    /// listed in the table below; callers treat `0` as "unsupported".
    pub fn enc_len(&self) -> usize {
        const ENC_LENS: [(i32, usize); 5] = [
            (DHKEM_P256_HKDF_SHA256, DHKEM_P256_ENC_LEN),
            (DHKEM_P384_HKDF_SHA384, DHKEM_P384_ENC_LEN),
            (DHKEM_P521_HKDF_SHA512, DHKEM_P521_ENC_LEN),
            (DHKEM_X25519_HKDF_SHA256, DHKEM_X25519_ENC_LEN),
            (DHKEM_X448_HKDF_SHA512, DHKEM_X448_ENC_LEN),
        ];
        ENC_LENS
            .iter()
            .find(|(id, _)| *id == self.kem)
            .map_or(0, |&(_, len)| len)
    }

    /// AEAD tag length in bytes; this wrapper reports `HPKE_Nt_MAX` for every suite.
    pub fn tag_len(&self) -> usize {
        HPKE_Nt_MAX
    }
}
/// A wolfCrypt HPKE key object created by `Hpke::generate_keypair` or
/// `Hpke::deserialize_public_key`; it is released with `wc_HpkeFreeKey` on drop.
pub struct HpkeKeyPair {
/// Opaque wolfCrypt key handle owned by this value (checked non-null at construction).
key: *mut c_void,
/// KEM identifier (a copy of `suite.kem`) passed to `wc_HpkeFreeKey` on drop.
kem: i32,
/// Suite of the `Hpke` context that created this key.
suite: HpkeSuite,
/// The creating context, shared through `Arc` with that `Hpke` and any sibling key pairs;
/// used by `serialize_public_key` and `drop`.
hpke: Arc<UnsafeCell<WcHpke>>,
}
impl HpkeKeyPair {
    /// Serializes this key's public part with `wc_HpkeSerializePublicKey`.
    ///
    /// The output buffer starts at the suite's `enc_len()` bytes and is then shrunk to the
    /// number of bytes wolfCrypt reports having written.
    ///
    /// # Errors
    ///
    /// * `WolfCryptError::InvalidInput` if the suite's KEM has no known encoded length.
    /// * The error produced by `check` if `wc_HpkeSerializePublicKey` fails.
    pub fn serialize_public_key(&mut self) -> Result<Vec<u8>, WolfCryptError> {
        let capacity = match self.suite.enc_len() {
            0 => return Err(WolfCryptError::InvalidInput),
            n => n,
        };
        // In/out parameter: buffer size going in, bytes written coming out.
        let mut written = capacity as u16;
        let mut encoded = vec![0u8; capacity];
        // SAFETY: `self.key` is the non-null handle this pair owns, `encoded` has room for
        // `written` bytes, and `written` is a valid out-parameter.
        // NOTE(review): the `WcHpke` behind `self.hpke` is shared with the originating `Hpke`
        // and sibling key pairs; this assumes none of them is in use concurrently.
        let status = unsafe {
            wc_HpkeSerializePublicKey(
                &mut *self.hpke.get(),
                self.key,
                encoded.as_mut_ptr(),
                &mut written,
            )
        };
        check(status, "wc_HpkeSerializePublicKey")?;
        encoded.truncate(usize::from(written));
        Ok(encoded)
    }
}
impl Drop for HpkeKeyPair {
    /// Releases the wolfCrypt key object through `wc_HpkeFreeKey`.
    fn drop(&mut self) {
        if self.key.is_null() {
            return;
        }
        let kem_id = self.kem as u16;
        // SAFETY: `self.key` is a non-null handle owned exclusively by this value and is freed
        // exactly once, here. The `Arc` in `self.hpke` keeps the context alive for this call.
        // The null heap hint matches the one given to `wc_HpkeInit` in `Hpke::new`.
        unsafe {
            wc_HpkeFreeKey(self.hpke.get(), kem_id, self.key, ptr::null_mut());
        }
    }
}
// SAFETY (review): this overrides the compiler. `*mut c_void` is not `Send`, and
// `Arc<UnsafeCell<WcHpke>>` is not `Send` either, because `UnsafeCell` is not `Sync`.
// `serialize_public_key` and `drop` dereference the `WcHpke` that is shared with the
// originating `Hpke` and with sibling key pairs, and they do so without any
// synchronization. Moving this value to another thread is only sound if those objects are
// never used concurrently. NOTE(review): confirm this is the intended contract, or guard
// the shared context with a lock.
unsafe impl Send for HpkeKeyPair {}
/// A wolfCrypt HPKE context (`WcHpke`) initialised by `wc_HpkeInit` for one `HpkeSuite`.
pub struct Hpke {
/// The context, shared through `Arc` with every `HpkeKeyPair` created from it.
hpke: Arc<UnsafeCell<WcHpke>>,
/// The suite passed to `wc_HpkeInit`.
suite: HpkeSuite,
}
impl Hpke {
pub fn new(suite: HpkeSuite) -> Result<Self, WolfCryptError> {
let hpke = Arc::new(UnsafeCell::new(WcHpke::zeroed()));
let rc = unsafe {
wc_HpkeInit(
&mut *hpke.get(),
suite.kem,
suite.kdf,
suite.aead,
ptr::null_mut(),
)
};
check(rc, "wc_HpkeInit")?;
Ok(Self { hpke, suite })
}
pub fn suite(&self) -> HpkeSuite {
self.suite
}
pub fn generate_keypair(&mut self, rng: &mut WolfRng) -> Result<HpkeKeyPair, WolfCryptError> {
let mut key: *mut c_void = ptr::null_mut();
let rc =
unsafe { wc_HpkeGenerateKeyPair(&mut *self.hpke.get(), &mut key, &mut rng.rng) };
check(rc, "wc_HpkeGenerateKeyPair")?;
if key.is_null() {
return Err(WolfCryptError::AllocFailed);
}
Ok(HpkeKeyPair {
key,
kem: self.suite.kem,
suite: self.suite,
hpke: Arc::clone(&self.hpke),
})
}
pub fn deserialize_public_key(&mut self, enc: &[u8]) -> Result<HpkeKeyPair, WolfCryptError> {
let mut key: *mut c_void = ptr::null_mut();
let rc = unsafe {
wc_HpkeDeserializePublicKey(
&mut *self.hpke.get(),
&mut key,
enc.as_ptr(),
enc.len() as u16,
)
};
check(rc, "wc_HpkeDeserializePublicKey")?;
if key.is_null() {
return Err(WolfCryptError::AllocFailed);
}
Ok(HpkeKeyPair {
key,
kem: self.suite.kem,
suite: self.suite,
hpke: Arc::clone(&self.hpke),
})
}
pub fn seal_base(
&mut self,
ephemeral: &mut HpkeKeyPair,
receiver_pub: &mut HpkeKeyPair,
info: &[u8],
aad: &[u8],
plaintext: &[u8],
) -> Result<(Vec<u8>, Vec<u8>), WolfCryptError> {
let enc = ephemeral.serialize_public_key()?;
let ct_len = plaintext.len() + self.suite.tag_len();
let mut ciphertext = vec![0u8; ct_len];
let mut info_buf = Vec::from(info);
let mut aad_buf = Vec::from(aad);
let mut pt_buf = Vec::from(plaintext);
let rc = unsafe {
wc_HpkeSealBase(
&mut *self.hpke.get(),
ephemeral.key,
receiver_pub.key,
info_buf.as_mut_ptr(),
len_as_u32(info_buf.len()),
aad_buf.as_mut_ptr(),
len_as_u32(aad_buf.len()),
pt_buf.as_mut_ptr(),
len_as_u32(pt_buf.len()),
ciphertext.as_mut_ptr(),
)
};
check(rc, "wc_HpkeSealBase")?;
Ok((enc, ciphertext))
}
pub fn open_base(
&mut self,
receiver: &mut HpkeKeyPair,
enc: &[u8],
info: &[u8],
aad: &[u8],
ciphertext: &[u8],
) -> Result<Vec<u8>, WolfCryptError> {
let tag_len = self.suite.tag_len();
if ciphertext.len() < tag_len {
return Err(WolfCryptError::InvalidInput);
}
let pt_len = ciphertext.len() - tag_len;
let mut plaintext = vec![0u8; pt_len];
let mut info_buf = Vec::from(info);
let mut aad_buf = Vec::from(aad);
let mut ct_buf = Vec::from(ciphertext);
let rc = unsafe {
wc_HpkeOpenBase(
&mut *self.hpke.get(),
receiver.key,
enc.as_ptr(),
enc.len() as u16,
info_buf.as_mut_ptr(),
len_as_u32(info_buf.len()),
aad_buf.as_mut_ptr(),
len_as_u32(aad_buf.len()),
ct_buf.as_mut_ptr(),
len_as_u32(ct_buf.len()),
plaintext.as_mut_ptr(),
)
};
check(rc, "wc_HpkeOpenBase")?;
Ok(plaintext)
}
}
// SAFETY (review): `Arc<UnsafeCell<WcHpke>>` is not `Send` on its own, because `UnsafeCell`
// is not `Sync`, so this overrides the compiler. The same context is shared with every
// `HpkeKeyPair` this `Hpke` creates, and both sides dereference it without synchronization.
// Sending this value to another thread is only sound if it is never used concurrently with
// those key pairs. NOTE(review): confirm this contract, or add a lock around the context.
unsafe impl Send for Hpke {}