use crate::{
errors::{CrystalsError, EncryptionDecryptionError, KeyGenerationError, PackingError},
indcpa::{
generate_indcpa_key_pair, PrivateKey as IndcpaPrivateKey, PublicKey as IndcpaPublicKey,
},
params::{SecurityLevel, K, MAX_CIPHERTEXT, SHAREDSECRETBYTES, SYMBYTES},
};
use rand_chacha::ChaCha20Rng;
use rand_core::{CryptoRng, RngCore, SeedableRng};
use sha3::{
digest::{ExtendableOutput, Update, XofReader},
Digest, Sha3_256, Sha3_512, Shake256,
};
use subtle::{ConditionallySelectable, ConstantTimeEq};
use tinyvec::ArrayVec;
/// Decapsulation (private) key.
///
/// Without the `decap_key` feature only a `2 * SYMBYTES`-byte seed is stored and the
/// key pair is re-derived from it whenever it is needed (`get_public_key`,
/// `decapsulate`). With `decap_key`, the expanded key material is stored directly.
///
/// NOTE(review): the derived `Debug` prints the secret key material, and the derived
/// `PartialEq` is not guaranteed to be constant-time — consider redacting `Debug` and
/// using constant-time equality for secret keys.
#[derive(Debug, Eq, PartialEq)]
pub struct PrivateKey {
    // Seed from which the key pair is re-derived via `new_key_from_seed`.
    #[cfg(not(feature = "decap_key"))]
    key: PrivateSeed,
    // Expanded key: IND-CPA secret/public keys, public-key hash and `z`.
    #[cfg(feature = "decap_key")]
    key: PrivateKeyInner,
    // Parameter set this key was generated or unpacked for.
    sec_level: SecurityLevel,
}
/// Compact private-key representation used when `decap_key` is disabled.
///
/// The first `SYMBYTES` bytes seed IND-CPA key generation; the remaining `SYMBYTES`
/// bytes become `z` (see `new_key_from_seed`).
#[cfg(not(feature = "decap_key"))]
#[derive(Debug, Eq, PartialEq)]
struct PrivateSeed {
    seed: [u8; 2 * SYMBYTES],
}
/// Expanded private-key material, as produced by `new_key_from_seed`.
#[derive(Debug, Eq, PartialEq)]
struct PrivateKeyInner {
    // IND-CPA secret key; decrypts the ciphertext during decapsulation.
    sk: IndcpaPrivateKey,
    // IND-CPA public key; used to re-encrypt during decapsulation.
    pk: IndcpaPublicKey,
    // SHA3-256 hash of the packed public key.
    h_pk: [u8; SYMBYTES],
    // Secret mixed with the ciphertext (SHAKE256(z || ciphertext)) to produce the
    // fallback secret returned when re-encryption does not match.
    z: [u8; SYMBYTES],
}
/// Encapsulation (public) key.
///
/// Stores the SHA3-256 hash of the packed key alongside it so `encapsulate` can use the
/// hash directly.
#[derive(Debug, Eq, PartialEq)]
pub struct PublicKey {
    // IND-CPA public key used for encryption.
    pk: IndcpaPublicKey,
    // SHA3-256 hash of the packed public key.
    h_pk: [u8; SYMBYTES],
}
/// Output of encapsulation.
///
/// The backing buffer has a fixed capacity of `MAX_CIPHERTEXT` bytes; only the first
/// `len` bytes hold the ciphertext for the key's security level.
pub struct Ciphertext {
    // Fixed-capacity storage; bytes at or past `len` are unused.
    bytes: [u8; MAX_CIPHERTEXT],
    // Number of meaningful bytes at the front of `bytes`.
    len: usize,
}
impl Ciphertext {
    /// Borrows the meaningful part of the ciphertext, i.e. the first `len` bytes of
    /// the backing buffer.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        let valid = self.len;
        &self.bytes[0..valid]
    }
}
fn sha3_256_from(input: &[u8]) -> [u8; SYMBYTES] {
let mut hash = Sha3_256::new();
Digest::update(&mut hash, input);
let output: [u8; SYMBYTES] = hash.finalize().into();
output
}
fn sha3_512_from(input: &[u8]) -> ([u8; SHAREDSECRETBYTES], [u8; SYMBYTES]) {
let mut hash = Sha3_512::new();
Digest::update(&mut hash, input);
let output = hash.finalize();
let mut o1 = [0u8; SHAREDSECRETBYTES];
let mut o2 = [0u8; SYMBYTES];
o1.copy_from_slice(&output[..SHAREDSECRETBYTES]);
o2.copy_from_slice(&output[SHAREDSECRETBYTES..]);
(o1, o2)
}
fn shake256_from(input: &[u8]) -> [u8; SHAREDSECRETBYTES] {
let mut hash = Shake256::default();
hash.update(input);
let mut output = [0u8; SHAREDSECRETBYTES];
hash.finalize_xof().read(&mut output);
output
}
/// Deterministically expands a `2 * SYMBYTES` seed into a key pair.
///
/// The first `SYMBYTES` bytes seed IND-CPA key generation and the remaining bytes are
/// kept as `z`. The public key is packed and hashed with SHA3-256 so both halves of
/// the pair carry the same `h_pk`.
///
/// # Errors
///
/// Propagates failures from IND-CPA key generation, slice conversion, or packing.
fn new_key_from_seed(
    seed: [u8; 2 * SYMBYTES],
    sec_level: SecurityLevel,
) -> Result<(PublicKey, PrivateKeyInner), KeyGenerationError> {
    let (keygen_seed, z_bytes) = seed.split_at(SYMBYTES);
    let (sk, pk) = generate_indcpa_key_pair(keygen_seed, sec_level)?;
    let z: [u8; SYMBYTES] = z_bytes.try_into()?;

    // Scratch buffer with fixed capacity; only the first `pk_len` bytes are used.
    let pk_len = sec_level.indcpa_public_key_bytes();
    let mut pk_buf = [0u8; MAX_CIPHERTEXT];
    pk.pack(&mut pk_buf[..pk_len])?;
    let h_pk = sha3_256_from(&pk_buf[..pk_len]);

    let public = PublicKey { pk, h_pk };
    let private = PrivateKeyInner { sk, pk, h_pk, z };
    Ok((public, private))
}
/// Marker trait for caller-supplied randomness sources.
///
/// Implementors must provide both `RngCore` and `CryptoRng`. It is taken as a trait
/// object (`Option<&mut dyn AcceptableRng>`) by key generation and encapsulation.
///
/// NOTE(review): no blanket impl is visible here, so callers presumably have to write
/// `impl AcceptableRng for MyRng {}` themselves — confirm against the rest of the crate.
pub trait AcceptableRng: RngCore + CryptoRng {}
/// Generates a key pair for parameter set `k`.
///
/// A `2 * SYMBYTES` seed is drawn from `rng`, or from a freshly entropy-seeded
/// `ChaCha20Rng` when `rng` is `None`, and expanded with `new_key_from_seed`. Without
/// the `decap_key` feature the private key keeps only the seed; with it, the expanded
/// key material is stored.
///
/// # Errors
///
/// Propagates RNG failures and key-derivation errors as `KeyGenerationError`.
pub(crate) fn generate_key_pair(
    rng: Option<&mut dyn AcceptableRng>,
    k: K,
) -> Result<(PublicKey, PrivateKey), KeyGenerationError> {
    let mut seed = [0u8; 2 * SYMBYTES];
    match rng {
        Some(source) => source.try_fill_bytes(&mut seed)?,
        None => ChaCha20Rng::from_entropy().try_fill_bytes(&mut seed)?,
    }

    let sec_level = SecurityLevel::new(k);
    // The expanded key is only stored with `decap_key`; the leading underscore keeps
    // the seed-only build free of unused-variable warnings.
    let (public_key, _expanded) = new_key_from_seed(seed, sec_level)?;

    let private_key = PrivateKey {
        #[cfg(not(feature = "decap_key"))]
        key: PrivateSeed { seed },
        #[cfg(feature = "decap_key")]
        #[allow(clippy::used_underscore_binding)]
        key: _expanded,
        sec_level,
    };
    Ok((public_key, private_key))
}
/// Generates a key pair for the `K::Two` parameter set.
///
/// When `rng` is `None`, an entropy-seeded `ChaCha20Rng` is used.
///
/// # Errors
///
/// Propagates any `KeyGenerationError` from randomness or key derivation.
pub fn generate_keypair_512(
    rng: Option<&mut dyn AcceptableRng>,
) -> Result<(PublicKey, PrivateKey), KeyGenerationError> {
    let level = K::Two;
    generate_key_pair(rng, level)
}

/// Generates a key pair for the `K::Three` parameter set.
///
/// When `rng` is `None`, an entropy-seeded `ChaCha20Rng` is used.
///
/// # Errors
///
/// Propagates any `KeyGenerationError` from randomness or key derivation.
pub fn generate_keypair_768(
    rng: Option<&mut dyn AcceptableRng>,
) -> Result<(PublicKey, PrivateKey), KeyGenerationError> {
    let level = K::Three;
    generate_key_pair(rng, level)
}

/// Generates a key pair for the `K::Four` parameter set.
///
/// When `rng` is `None`, an entropy-seeded `ChaCha20Rng` is used.
///
/// # Errors
///
/// Propagates any `KeyGenerationError` from randomness or key derivation.
pub fn generate_keypair_1024(
    rng: Option<&mut dyn AcceptableRng>,
) -> Result<(PublicKey, PrivateKey), KeyGenerationError> {
    let level = K::Four;
    generate_key_pair(rng, level)
}
impl PrivateKey {
#[cfg(feature = "decap_key")]
pub(crate) const fn sec_level(&self) -> SecurityLevel {
self.key.sk.sec_level()
}
#[allow(clippy::missing_panics_doc, clippy::unwrap_used)]
#[must_use]
pub fn get_public_key(&self) -> PublicKey {
#[cfg(not(feature = "decap_key"))]
{
let (pk, _) = new_key_from_seed(self.key.seed, self.sec_level).unwrap();
pk
}
#[cfg(feature = "decap_key")]
{
PublicKey {
pk: self.key.pk,
h_pk: self.key.h_pk,
}
}
}
#[must_use]
#[cfg(not(feature = "decap_key"))]
pub const fn pack(&self) -> [u8; 2 * SYMBYTES] {
self.key.seed
}
#[cfg(feature = "decap_key")]
pub fn pack(&self, bytes: &mut [u8]) -> Result<(), PackingError> {
let sec_level = self.sec_level();
if bytes.len() != sec_level.private_key_bytes() {
return Err(CrystalsError::IncorrectBufferLength(
bytes.len(),
sec_level.private_key_bytes(),
)
.into());
}
let (sk_bytes, rest) = bytes.split_at_mut(sec_level.indcpa_private_key_bytes());
let (pk_bytes, rest) = rest.split_at_mut(sec_level.indcpa_public_key_bytes());
let (h_pk_bytes, z_bytes) = rest.split_at_mut(SYMBYTES);
self.key.sk.pack(sk_bytes)?;
self.key.pk.pack(pk_bytes)?;
h_pk_bytes.copy_from_slice(&self.key.h_pk);
z_bytes.copy_from_slice(&self.key.z);
Ok(())
}
#[must_use]
#[cfg(not(feature = "decap_key"))]
pub const fn unpack_512(bytes: [u8; 2 * SYMBYTES]) -> Self {
Self {
key: PrivateSeed { seed: bytes },
sec_level: SecurityLevel::new(K::Two),
}
}
#[must_use]
#[cfg(not(feature = "decap_key"))]
pub const fn unpack_768(bytes: [u8; 2 * SYMBYTES]) -> Self {
Self {
key: PrivateSeed { seed: bytes },
sec_level: SecurityLevel::new(K::Three),
}
}
#[must_use]
#[cfg(not(feature = "decap_key"))]
pub const fn unpack_1024(bytes: [u8; 2 * SYMBYTES]) -> Self {
Self {
key: PrivateSeed { seed: bytes },
sec_level: SecurityLevel::new(K::Four),
}
}
#[cfg(feature = "decap_key")]
pub fn unpack(bytes: &[u8]) -> Result<Self, PackingError> {
let sec_level = match bytes.len() {
1632 => SecurityLevel::new(K::Two),
2400 => SecurityLevel::new(K::Three),
3168 => SecurityLevel::new(K::Four),
_ => return Err(CrystalsError::IncorrectBufferLength(bytes.len(), 3168).into()),
};
let (sk_bytes, rest) = bytes.split_at(sec_level.indcpa_private_key_bytes());
let (pk_bytes, rest) = rest.split_at(sec_level.indcpa_public_key_bytes());
let (h_pk_bytes, z_bytes) = rest.split_at(SYMBYTES);
let sk = IndcpaPrivateKey::unpack(sk_bytes)?;
let pk = IndcpaPublicKey::unpack(pk_bytes)?;
let mut h_pk = [0u8; SYMBYTES];
h_pk.copy_from_slice(h_pk_bytes);
let mut z = [0u8; SYMBYTES];
z.copy_from_slice(z_bytes);
Ok(Self {
key: PrivateKeyInner { sk, pk, h_pk, z },
sec_level,
})
}
pub fn decapsulate(
&self,
ciphertext: &[u8],
) -> Result<[u8; SHAREDSECRETBYTES], EncryptionDecryptionError> {
let valid_bytes = [
SecurityLevel::new(K::Two).ciphertext_bytes(),
SecurityLevel::new(K::Three).ciphertext_bytes(),
SecurityLevel::new(K::Four).ciphertext_bytes(),
];
let sec_level = match ciphertext.len() {
len if len == valid_bytes[0] => {
Ok::<SecurityLevel, CrystalsError>(SecurityLevel::new(K::Two))
}
len if len == valid_bytes[1] => {
Ok::<SecurityLevel, CrystalsError>(SecurityLevel::new(K::Three))
}
len if len == valid_bytes[2] => {
Ok::<SecurityLevel, CrystalsError>(SecurityLevel::new(K::Four))
}
_ => Err(CrystalsError::InvalidCiphertextLength(ciphertext.len())),
}?;
#[cfg(not(feature = "decap_key"))]
let (_, inner) = new_key_from_seed(self.key.seed, sec_level)?;
#[cfg(feature = "decap_key")]
let inner = &self.key;
let m = inner.sk.decrypt(ciphertext)?;
let (k, r) = sha3_512_from(&[m, inner.h_pk].concat());
let k_bar = shake256_from(&[&inner.z, ciphertext].concat());
let mut ct = [0u8; MAX_CIPHERTEXT]; inner
.pk
.encrypt(&m, &r, &mut ct[..sec_level.indcpa_bytes()])?;
let equal = ct.ct_eq(ciphertext);
Ok(k.iter()
.zip(k_bar.iter())
.map(|(x, y)| u8::conditional_select(x, y, equal))
.collect::<ArrayVec<[u8; SHAREDSECRETBYTES]>>()
.into_inner())
}
}
impl PublicKey {
    /// Security level of the wrapped IND-CPA public key.
    pub(crate) const fn sec_level(&self) -> SecurityLevel {
        self.pk.sec_level()
    }

    /// Serialises the public key into `bytes`.
    ///
    /// # Errors
    ///
    /// Returns `CrystalsError::IncorrectBufferLength` when `bytes.len()` is not the
    /// level's `public_key_bytes()`, and propagates IND-CPA packing errors.
    pub fn pack(&self, bytes: &mut [u8]) -> Result<(), PackingError> {
        let expected_len = self.sec_level().public_key_bytes();
        if bytes.len() != expected_len {
            return Err(CrystalsError::IncorrectBufferLength(bytes.len(), expected_len).into());
        }
        self.pk.pack(bytes)?;
        Ok(())
    }

    /// Parses a packed public key and caches the SHA3-256 hash of its encoding.
    ///
    /// # Errors
    ///
    /// Propagates any error from unpacking the IND-CPA public key.
    pub fn unpack(bytes: &[u8]) -> Result<Self, PackingError> {
        Ok(Self {
            pk: IndcpaPublicKey::unpack(bytes)?,
            h_pk: sha3_256_from(bytes),
        })
    }

    /// Encapsulates a fresh shared secret to this public key.
    ///
    /// The `SYMBYTES` message is taken from `seed` when given (which then takes
    /// precedence over `rng`), else drawn from `rng`, else from an entropy-seeded
    /// `ChaCha20Rng`. Returns the ciphertext and the shared secret.
    ///
    /// # Errors
    ///
    /// Returns `CrystalsError::InvalidSeedLength` when `seed` is not `SYMBYTES` long,
    /// and propagates RNG and encryption errors.
    pub fn encapsulate(
        &self,
        seed: Option<&[u8]>,
        rng: Option<&mut dyn AcceptableRng>,
    ) -> Result<(Ciphertext, [u8; SHAREDSECRETBYTES]), EncryptionDecryptionError> {
        let sec_level = self.pk.sec_level();
        let ct_len = sec_level.ciphertext_bytes();

        let mut m = [0u8; SYMBYTES];
        match (seed, rng) {
            (Some(seed), _) => {
                if seed.len() != SYMBYTES {
                    return Err(CrystalsError::InvalidSeedLength(seed.len(), SYMBYTES).into());
                }
                m.copy_from_slice(seed);
            }
            (None, Some(source)) => source.try_fill_bytes(&mut m)?,
            (None, None) => ChaCha20Rng::from_entropy().try_fill_bytes(&mut m)?,
        }

        let (shared_secret, coins) = sha3_512_from(&[m, self.h_pk].concat());
        let mut bytes = [0u8; MAX_CIPHERTEXT];
        self.pk.encrypt(&m, &coins, &mut bytes[..ct_len])?;

        let ciphertext = Ciphertext { bytes, len: ct_len };
        Ok((ciphertext, shared_secret))
    }
}