use crate::error::Error;
use crate::params::SntrupParams;
use core::marker::PhantomData;
use subtle::ConstantTimeEq;
use zeroize::Zeroize;
/// Encoded public (encapsulation) key for the sntrup parameter set `P`.
///
/// The byte-slice `TryFrom` conversions accept exactly `P::PK_BYTES` bytes.
#[derive(Clone)]
pub struct EncapsulationKey<P: SntrupParams> {
    // Encoded public key bytes.
    bytes: Vec<u8>,
    // Binds the key to its parameter set at the type level; no runtime data.
    _marker: PhantomData<P>,
}
/// Encoded secret (decapsulation) key for the sntrup parameter set `P`.
///
/// The byte-slice `TryFrom` conversions accept exactly `P::SK_BYTES` bytes. The
/// encoding embeds the public key (see `encapsulation_key`). `Debug` output omits
/// the bytes, `==` is constant-time, and the bytes are zeroized on drop.
/// Cloning duplicates the secret; each clone wipes its own copy when dropped.
#[derive(Clone)]
pub struct DecapsulationKey<P: SntrupParams> {
    // Encoded secret key bytes (sensitive).
    bytes: Vec<u8>,
    // Binds the key to its parameter set at the type level; no runtime data.
    _marker: PhantomData<P>,
}
/// KEM ciphertext for the sntrup parameter set `P`.
///
/// The byte-slice `TryFrom` conversions accept exactly `P::CT_BYTES` bytes.
#[derive(Clone)]
pub struct Ciphertext<P: SntrupParams> {
    // Encoded ciphertext bytes.
    bytes: Vec<u8>,
    // Binds the ciphertext to its parameter set at the type level.
    _marker: PhantomData<P>,
}
/// Shared secret produced by encapsulation or decapsulation for parameter set `P`.
///
/// `Debug` output omits the bytes, `==` is constant-time, and the bytes are
/// zeroized on drop.
#[derive(Clone)]
pub struct SharedSecret<P: SntrupParams> {
    // Shared secret bytes (sensitive).
    bytes: Vec<u8>,
    // Binds the secret to its parameter set at the type level.
    _marker: PhantomData<P>,
}
/// Zero-sized handle for the sntrup KEM with parameter set `P`; key generation
/// is exposed as associated functions on it.
///
/// NOTE(review): `derive` adds `P: Debug`/`P: Clone`/`P: Copy` bounds to these
/// impls. If the parameter marker types do not implement those traits, this type
/// won't either. Confirm whether manual impls are wanted.
#[derive(Debug, Clone, Copy)]
pub struct SntrupKem<P: SntrupParams>(PhantomData<P>);
impl<P: SntrupParams> EncapsulationKey<P> {
    /// Wraps crate-produced public-key bytes as-is.
    ///
    /// No length validation happens here, unlike the public `TryFrom` conversions.
    pub(crate) fn from_vec(bytes: Vec<u8>) -> Self {
        let _marker = PhantomData;
        Self { bytes, _marker }
    }
}
impl<P: SntrupParams> DecapsulationKey<P> {
    /// Takes ownership of crate-produced secret-key bytes without validation.
    ///
    /// The vector is moved (not copied) into the key, whose `Drop` wipes it.
    pub(crate) fn from_vec(bytes: Vec<u8>) -> Self {
        Self {
            _marker: PhantomData,
            bytes,
        }
    }
}
impl<P: SntrupParams> Ciphertext<P> {
    /// Wraps crate-produced ciphertext bytes without checking their length.
    pub(crate) fn from_vec(bytes: Vec<u8>) -> Self {
        let _marker = PhantomData;
        Self { bytes, _marker }
    }
}
impl<P: SntrupParams> SharedSecret<P> {
    /// Takes ownership of crate-produced shared-secret bytes without validation.
    ///
    /// The vector is moved into the wrapper, whose `Drop` wipes it.
    pub(crate) fn from_vec(bytes: Vec<u8>) -> Self {
        Self {
            _marker: PhantomData,
            bytes,
        }
    }
}
impl<P: SntrupParams> DecapsulationKey<P> {
    /// Extracts the public encapsulation key embedded in this secret key.
    ///
    /// The public key occupies `pk_size` bytes starting at offset
    /// `2 * small_encode_size` of the encoded secret key. Presumably the prefix
    /// holds two encoded small polynomials; confirm against `crate::kem`.
    ///
    /// # Panics
    /// Panics if the stored bytes are shorter than
    /// `2 * small_encode_size + pk_size`.
    pub fn encapsulation_key(&self) -> EncapsulationKey<P> {
        let p = P::params();
        let offset = p.small_encode_size * 2;
        let embedded_pk = &self.bytes[offset..offset + p.pk_size];
        EncapsulationKey::from_vec(embedded_pk.to_vec())
    }
}
impl<P: SntrupParams> core::fmt::Debug for EncapsulationKey<P> {
    // Public data, so the full key is shown (hex) along with its expected length.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let type_name = format!("{}::EncapsulationKey", P::NAME);
        let hex_bytes = hex::encode(&self.bytes);
        let mut out = f.debug_struct(&type_name);
        out.field("len", &P::PK_BYTES);
        out.field("bytes", &hex_bytes);
        out.finish()
    }
}
impl<P: SntrupParams> core::fmt::Debug for DecapsulationKey<P> {
    // Prints only the type name so secret key bytes never reach logs.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct(&format!("{}::DecapsulationKey", P::NAME))
            .finish()
    }
}
impl<P: SntrupParams> core::fmt::Debug for Ciphertext<P> {
    // Ciphertexts are public; render them as hex with the expected length.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let encoded = hex::encode(&self.bytes);
        let label = format!("{}::Ciphertext", P::NAME);
        let mut builder = f.debug_struct(&label);
        builder.field("len", &P::CT_BYTES).field("bytes", &encoded);
        builder.finish()
    }
}
impl<P: SntrupParams> core::fmt::Debug for SharedSecret<P> {
    // Redacted: only the type name is printed, never the secret bytes.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct(&format!("{}::SharedSecret", P::NAME))
            .finish()
    }
}
impl<P: SntrupParams> AsRef<[u8]> for EncapsulationKey<P> {
    /// Borrows the encoded public key bytes.
    fn as_ref(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}
impl<P: SntrupParams> AsRef<[u8]> for DecapsulationKey<P> {
    /// Borrows the encoded secret key bytes; handle the result as sensitive.
    fn as_ref(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}
impl<P: SntrupParams> AsRef<[u8]> for Ciphertext<P> {
    /// Borrows the encoded ciphertext bytes.
    fn as_ref(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}
impl<P: SntrupParams> AsRef<[u8]> for SharedSecret<P> {
    /// Borrows the shared secret bytes; handle the result as sensitive.
    fn as_ref(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}
impl<P: SntrupParams> TryFrom<&[u8]> for EncapsulationKey<P> {
type Error = Error;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
if bytes.len() != P::PK_BYTES {
return Err(Error::InvalidSize {
expected: P::PK_BYTES,
actual: bytes.len(),
});
}
Ok(Self {
bytes: bytes.to_vec(),
_marker: PhantomData,
})
}
}
impl<P: SntrupParams> TryFrom<&[u8]> for DecapsulationKey<P> {
type Error = Error;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
if bytes.len() != P::SK_BYTES {
return Err(Error::InvalidSize {
expected: P::SK_BYTES,
actual: bytes.len(),
});
}
Ok(Self {
bytes: bytes.to_vec(),
_marker: PhantomData,
})
}
}
impl<P: SntrupParams> TryFrom<&[u8]> for Ciphertext<P> {
type Error = Error;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
if bytes.len() != P::CT_BYTES {
return Err(Error::InvalidSize {
expected: P::CT_BYTES,
actual: bytes.len(),
});
}
Ok(Self {
bytes: bytes.to_vec(),
_marker: PhantomData,
})
}
}
impl<P: SntrupParams> TryFrom<Vec<u8>> for EncapsulationKey<P> {
type Error = Error;
fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
Self::try_from(bytes.as_slice())
}
}
impl<P: SntrupParams> TryFrom<&Vec<u8>> for EncapsulationKey<P> {
    type Error = Error;

    /// Borrowing conversion; validates and copies via the `&[u8]` impl.
    fn try_from(bytes: &Vec<u8>) -> Result<Self, Self::Error> {
        let slice: &[u8] = bytes;
        Self::try_from(slice)
    }
}
impl<P: SntrupParams> TryFrom<Box<[u8]>> for EncapsulationKey<P> {
    type Error = Error;

    /// Validates and copies the boxed bytes via the `&[u8]` impl.
    fn try_from(bytes: Box<[u8]>) -> Result<Self, Self::Error> {
        Self::try_from(&bytes[..])
    }
}
impl<P: SntrupParams> TryFrom<Vec<u8>> for DecapsulationKey<P> {
type Error = Error;
fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
Self::try_from(bytes.as_slice())
}
}
impl<P: SntrupParams> TryFrom<&Vec<u8>> for DecapsulationKey<P> {
    type Error = Error;

    /// Borrowing conversion; validates and copies via the `&[u8]` impl.
    fn try_from(bytes: &Vec<u8>) -> Result<Self, Self::Error> {
        Self::try_from(&bytes[..])
    }
}
impl<P: SntrupParams> TryFrom<Box<[u8]>> for DecapsulationKey<P> {
type Error = Error;
fn try_from(bytes: Box<[u8]>) -> Result<Self, Self::Error> {
Self::try_from(bytes.as_ref())
}
}
impl<P: SntrupParams> TryFrom<Vec<u8>> for Ciphertext<P> {
type Error = Error;
fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
Self::try_from(bytes.as_slice())
}
}
impl<P: SntrupParams> TryFrom<&Vec<u8>> for Ciphertext<P> {
    type Error = Error;

    /// Borrowing conversion; validates and copies via the `&[u8]` impl.
    fn try_from(bytes: &Vec<u8>) -> Result<Self, Self::Error> {
        Self::try_from(&**bytes)
    }
}
impl<P: SntrupParams> TryFrom<Box<[u8]>> for Ciphertext<P> {
    type Error = Error;

    /// Validates and copies the boxed bytes via the `&[u8]` impl.
    fn try_from(bytes: Box<[u8]>) -> Result<Self, Self::Error> {
        let view: &[u8] = &bytes;
        Self::try_from(view)
    }
}
impl<P: SntrupParams> PartialEq for EncapsulationKey<P> {
    // Public data, so an ordinary (variable-time) byte comparison is acceptable.
    fn eq(&self, other: &Self) -> bool {
        self.bytes.as_slice() == other.bytes.as_slice()
    }
}
impl<P: SntrupParams> Eq for EncapsulationKey<P> {}
impl<P: SntrupParams> PartialEq for Ciphertext<P> {
    // Ciphertexts are public; a plain byte comparison is sufficient.
    fn eq(&self, other: &Self) -> bool {
        self.bytes.as_slice() == other.bytes.as_slice()
    }
}
impl<P: SntrupParams> Eq for Ciphertext<P> {}
impl<P: SntrupParams> ConstantTimeEq for DecapsulationKey<P> {
    /// Compares the secret key bytes with `subtle`'s constant-time slice equality.
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let (lhs, rhs): (&[u8], &[u8]) = (&self.bytes, &other.bytes);
        lhs.ct_eq(rhs)
    }
}
impl<P: SntrupParams> PartialEq for DecapsulationKey<P> {
    // Routed through `ct_eq` so `==` on secret keys does not short-circuit.
    fn eq(&self, other: &Self) -> bool {
        bool::from(self.ct_eq(other))
    }
}
impl<P: SntrupParams> Eq for DecapsulationKey<P> {}
impl<P: SntrupParams> ConstantTimeEq for SharedSecret<P> {
    /// Compares the shared secret bytes with `subtle`'s constant-time slice equality.
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let (lhs, rhs): (&[u8], &[u8]) = (&self.bytes, &other.bytes);
        lhs.ct_eq(rhs)
    }
}
impl<P: SntrupParams> PartialEq for SharedSecret<P> {
    // Routed through `ct_eq` so `==` on secrets does not short-circuit.
    fn eq(&self, other: &Self) -> bool {
        bool::from(self.ct_eq(other))
    }
}
impl<P: SntrupParams> Eq for SharedSecret<P> {}
impl<P: SntrupParams> Zeroize for DecapsulationKey<P> {
fn zeroize(&mut self) {
self.bytes.zeroize();
}
}
impl<P: SntrupParams> Drop for DecapsulationKey<P> {
fn drop(&mut self) {
self.zeroize();
}
}
impl<P: SntrupParams> Zeroize for SharedSecret<P> {
fn zeroize(&mut self) {
self.bytes.zeroize();
}
}
impl<P: SntrupParams> Drop for SharedSecret<P> {
fn drop(&mut self) {
self.zeroize();
}
}
#[cfg(feature = "kgen")]
impl<P: SntrupParams> SntrupKem<P> {
    /// Generates a fresh key pair for parameter set `P`, drawing randomness from `rng`.
    pub fn generate_key(
        rng: &mut impl rand::CryptoRng,
    ) -> (EncapsulationKey<P>, DecapsulationKey<P>) {
        let (public, secret) = crate::kem::keygen(P::params(), rng);
        let ek = EncapsulationKey::from_vec(public);
        let dk = DecapsulationKey::from_vec(secret);
        (ek, dk)
    }

    /// Generates a key pair from a 32-byte seed.
    ///
    /// The seed keys a `ChaCha20Rng`, which then drives `generate_key`, so a
    /// given seed is intended to reproduce the same key pair. Treat the seed as
    /// secret key material.
    pub fn generate_key_deterministic(
        seed: &[u8; 32],
    ) -> (EncapsulationKey<P>, DecapsulationKey<P>) {
        use rand::SeedableRng;
        let mut chacha = rand_chacha::ChaCha20Rng::from_seed(*seed);
        Self::generate_key(&mut chacha)
    }
}
#[cfg(feature = "ecap")]
impl<P: SntrupParams> EncapsulationKey<P> {
    /// Encapsulates to this public key using randomness from `rng`.
    ///
    /// Returns the ciphertext (for the holder of the matching secret key) and
    /// the shared secret to keep locally.
    pub fn encapsulate(&self, rng: &mut impl rand::CryptoRng) -> (Ciphertext<P>, SharedSecret<P>) {
        let params = P::params();
        let (ciphertext, shared) = crate::kem::encaps(&self.bytes, params, rng);
        (
            Ciphertext::from_vec(ciphertext),
            SharedSecret::from_vec(shared),
        )
    }
}
#[cfg(feature = "dcap")]
impl<P: SntrupParams> DecapsulationKey<P> {
    /// Recovers the shared secret carried by `ct` using this secret key.
    pub fn decapsulate(&self, ct: &Ciphertext<P>) -> SharedSecret<P> {
        SharedSecret::from_vec(crate::kem::decaps(&self.bytes, &ct.bytes, P::params()))
    }
}
#[cfg(feature = "serde")]
mod serde_impl {
    //! `serde` support for the KEM byte wrappers.
    //!
    //! Serialization goes through `serdect`: lowercase hex for human-readable
    //! formats, raw bytes for binary formats. Deserialization accepts only
    //! inputs that decode to exactly the size required by the parameter set.

    use super::*;

    /// `serde` "expected" description used in length-mismatch errors.
    struct ExactLen(usize);

    impl serde::de::Expected for ExactLen {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            write!(f, "exactly {} bytes", self.0)
        }
    }

    /// Decodes hex (human-readable) or raw bytes (binary) into a buffer of
    /// exactly `expected` bytes.
    ///
    /// `serdect::slice::deserialize_hex_or_bin` accepts input *shorter* than
    /// the destination buffer and returns only the filled prefix. The decoded
    /// length must therefore be checked explicitly; otherwise truncated input
    /// would be accepted and silently zero-padded. The scratch buffer is held in
    /// `Zeroizing` so that partially decoded secret material (decapsulation
    /// keys, shared secrets) is wiped on every error path.
    fn deserialize_exact<'de, D: serde::Deserializer<'de>>(
        d: D,
        expected: usize,
    ) -> Result<Vec<u8>, D::Error> {
        let mut buf = zeroize::Zeroizing::new(vec![0u8; expected]);
        let decoded = serdect::slice::deserialize_hex_or_bin(buf.as_mut_slice(), d)?.len();
        if decoded != expected {
            return Err(serde::de::Error::invalid_length(
                decoded,
                &ExactLen(expected),
            ));
        }
        // Move the bytes out; only the empty vector left behind is dropped
        // (and zeroized) with `buf`.
        Ok(core::mem::take(&mut *buf))
    }

    impl<P: SntrupParams> serde::Serialize for EncapsulationKey<P> {
        fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
            serdect::slice::serialize_hex_lower_or_bin(&self.bytes, s)
        }
    }
    impl<'de, P: SntrupParams> serde::Deserialize<'de> for EncapsulationKey<P> {
        /// Requires exactly `P::PK_BYTES` decoded bytes.
        fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
            deserialize_exact(d, P::PK_BYTES).map(Self::from_vec)
        }
    }
    impl<P: SntrupParams> serde::Serialize for DecapsulationKey<P> {
        fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
            serdect::slice::serialize_hex_lower_or_bin(&self.bytes, s)
        }
    }
    impl<'de, P: SntrupParams> serde::Deserialize<'de> for DecapsulationKey<P> {
        /// Requires exactly `P::SK_BYTES` decoded bytes.
        fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
            deserialize_exact(d, P::SK_BYTES).map(Self::from_vec)
        }
    }
    impl<P: SntrupParams> serde::Serialize for Ciphertext<P> {
        fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
            serdect::slice::serialize_hex_lower_or_bin(&self.bytes, s)
        }
    }
    impl<'de, P: SntrupParams> serde::Deserialize<'de> for Ciphertext<P> {
        /// Requires exactly `P::CT_BYTES` decoded bytes.
        fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
            deserialize_exact(d, P::CT_BYTES).map(Self::from_vec)
        }
    }
    impl<P: SntrupParams> serde::Serialize for SharedSecret<P> {
        fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
            serdect::slice::serialize_hex_lower_or_bin(&self.bytes, s)
        }
    }
    impl<'de, P: SntrupParams> serde::Deserialize<'de> for SharedSecret<P> {
        /// Requires exactly `P::SS_BYTES` decoded bytes.
        fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
            deserialize_exact(d, P::SS_BYTES).map(Self::from_vec)
        }
    }
}