use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::codes;
use crate::hash;
use crate::params::{HqcParams, SEED_BYTES};
use crate::parsing;
use crate::poly::mul::{mul_dense_ct, mul_sparse_dense};
use crate::poly::sampling::{sample_fixed_weight, sample_fixed_weight_mod, sample_uniform};
use crate::poly::Poly;
/// Public encryption key of the PKE scheme, generic over the parameter set `P`.
///
/// Produced by `keygen` and consumed by `encrypt`. It can be serialized with
/// `EncryptionKey::to_bytes` and parsed back with `EncryptionKey::from_bytes`.
pub struct EncryptionKey<P: HqcParams> {
/// Seed that `keygen` and `encrypt` feed to `hash::xof` in order to sample
/// the uniform polynomial `h`.
pub seed_ek: [u8; SEED_BYTES],
/// The polynomial `s` computed in `keygen` as `x.add(&hy)`, where
/// `hy = mul_sparse_dense(&y, &h)`.
pub s: Poly<P>,
}
/// Hand-written `Clone` for [`EncryptionKey`].
///
/// The only bound required here is `P: HqcParams`; a derived implementation
/// would additionally require `P: Clone`.
impl<P: HqcParams> Clone for EncryptionKey<P> {
    /// Returns an independent copy of the key: the seed array is copied and
    /// the polynomial `s` is cloned.
    fn clone(&self) -> Self {
        let Self { seed_ek, s } = self;
        Self {
            seed_ek: *seed_ek,
            s: s.clone(),
        }
    }
}
/// Secret decryption key of the PKE scheme.
///
/// Only a seed is stored: `decrypt` passes it to `hash::xof` and re-samples
/// the fixed-weight vector `y` from the resulting stream. The derived
/// `Zeroize` / `ZeroizeOnDrop` implementations zero the seed when the key is
/// dropped.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct DecryptionKey {
/// Seed from which `keygen` samples the fixed-weight vectors `y` and `x`
/// (in that order).
pub(crate) seed_dk: [u8; SEED_BYTES],
}
impl<P: HqcParams> EncryptionKey<P> {
    /// Serializes the key by handing its seed and polynomial `s` to
    /// `parsing::pack_public_key`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let Self { seed_ek, s } = self;
        parsing::pack_public_key::<P>(seed_ek, s)
    }

    /// Parses a key from `bytes` via `parsing::unpack_public_key`.
    ///
    /// Returns `None` whenever the unpacking routine returns `None`.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        parsing::unpack_public_key::<P>(bytes).map(|(seed_ek, s)| Self { seed_ek, s })
    }
}
/// Derives a PKE key pair from `seed_pke`.
///
/// `hash::i_pke_seed` splits the input into a decryption seed and an
/// encryption seed. The decryption seed drives the sampling of the
/// fixed-weight vectors `y` and `x` (weight `P::OMEGA`); the encryption seed
/// drives the sampling of the uniform polynomial `h`. The public polynomial is
/// `s = x + h*y`.
///
/// Returns the encryption key `(seed_ek, s)` together with a decryption key
/// that stores only the decryption seed.
pub fn keygen<P: HqcParams>(seed_pke: &[u8; SEED_BYTES]) -> (EncryptionKey<P>, DecryptionKey) {
    let (seed_dk, seed_ek) = hash::i_pke_seed(seed_pke);

    // Both secret vectors come from the same stream, so the draw order
    // (`y` first, then `x`) must stay in sync with `decrypt`, which re-samples `y`.
    let mut secret_stream = hash::xof(&seed_dk[..]);
    let y = sample_fixed_weight::<P>(&mut secret_stream, P::OMEGA);
    let x = sample_fixed_weight::<P>(&mut secret_stream, P::OMEGA);

    let mut public_stream = hash::xof(&seed_ek[..]);
    let h = sample_uniform::<P>(&mut public_stream);

    // NOTE(review): `y` is secret and `decrypt` multiplies by it with
    // `mul_dense_ct`; whether `mul_sparse_dense` is constant-time is not
    // visible from this file. Confirm.
    let s = x.add(&mul_sparse_dense::<P>(&y, &h));

    let encryption_key = EncryptionKey { seed_ek, s };
    let decryption_key = DecryptionKey { seed_dk: *seed_dk };
    (encryption_key, decryption_key)
}
/// Encrypts the message `m` under `ek`, drawing all randomness from `theta`.
///
/// The uniform polynomial `h` is re-sampled from `ek.seed_ek`. The vectors
/// `r2`, `e` and `r1` (each sampled with weight `P::OMEGA_R`) are drawn, in
/// that order, from a single XOF stream seeded with `theta`. The ciphertext is
///
/// * `u = r1 + h*r2`
/// * `v = encode(m) + s*r2 + e`, with every bit from index `P::N1 * P::N2`
///   upwards cleared.
///
/// # Panics
///
/// In debug builds, panics if `m.len() != P::K`. Release builds do not check
/// the length here.
pub fn encrypt<P: HqcParams>(
    ek: &EncryptionKey<P>,
    m: &[u8],
    theta: &[u8],
) -> (Poly<P>, Poly<P>) {
    debug_assert_eq!(m.len(), P::K, "message must be exactly K bytes");

    let mut ek_stream = hash::xof(&ek.seed_ek[..]);
    let h = sample_uniform::<P>(&mut ek_stream);

    // All three vectors share one stream: the draw order r2, e, r1 is part of
    // the ciphertext definition and must not be changed.
    let mut theta_stream = hash::xof(theta);
    let r2 = sample_fixed_weight_mod::<P>(&mut theta_stream, P::OMEGA_R);
    let e = sample_fixed_weight_mod::<P>(&mut theta_stream, P::OMEGA_R);
    let r1 = sample_fixed_weight_mod::<P>(&mut theta_stream, P::OMEGA_R);

    let u = r1.add(&mul_sparse_dense::<P>(&r2, &h));

    let mut v = codes::encode::<P>(m).add(&mul_sparse_dense::<P>(&r2, &ek.s));
    v.add_assign(&e);
    truncate_to_bits::<P>(&mut v, P::N1 * P::N2);

    (u, v)
}
/// Decrypts the ciphertext `(u, v)` with `dk`.
///
/// The secret vector `y` is re-sampled from `dk.seed_dk` (it is the first
/// fixed-weight draw from that stream, matching `keygen`). The value
/// `v + u*y` is then passed to `codes::decode`.
///
/// Returns whatever `codes::decode` returns, i.e. `None` when decoding fails.
pub fn decrypt<P: HqcParams>(dk: &DecryptionKey, u: &Poly<P>, v: &Poly<P>) -> Option<Vec<u8>> {
    let mut secret_stream = hash::xof(&dk.seed_dk[..]);
    let y = sample_fixed_weight::<P>(&mut secret_stream, P::OMEGA);

    let noisy_codeword = v.add(&mul_dense_ct::<P>(u, &y));
    codes::decode::<P>(&noisy_codeword)
}
/// Clears everything in `p` above the first `nbits` bits.
///
/// The low `nbits % 64` bits of word `nbits / 64` are kept (when that
/// remainder is non-zero) and every following word up to `P::N_WORDS` is set
/// to zero. When `nbits` is a multiple of 64, zeroing starts at word
/// `nbits / 64` itself.
///
/// # Panics
///
/// Panics if the touched word indices fall outside `p.words`.
fn truncate_to_bits<P: HqcParams>(p: &mut Poly<P>, nbits: usize) {
    let (boundary_word, bit_offset) = (nbits / 64, nbits % 64);

    // Mask the partially used word, if any, and work out where whole-word
    // zeroing has to begin.
    let clear_from = if bit_offset == 0 {
        boundary_word
    } else {
        p.words[boundary_word] &= (1u64 << bit_offset) - 1;
        boundary_word + 1
    };

    p.words[clear_from..P::N_WORDS].fill(0);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::params::{Hqc128, Hqc192, Hqc256};

    /// Builds a deterministic `P::K`-byte message whose byte at index `i` is
    /// the low byte of `7 * i + 1`.
    fn test_msg<P: HqcParams>() -> Vec<u8> {
        let mut msg = Vec::with_capacity(P::K);
        for idx in 0..P::K {
            msg.push(idx.wrapping_mul(7).wrapping_add(1) as u8);
        }
        msg
    }

    /// Encrypts and decrypts one fixed message for parameter set `P` and
    /// checks that the plaintext is recovered.
    fn pke_roundtrip<P: HqcParams>() {
        let (ek, dk) = keygen::<P>(&[0x42u8; SEED_BYTES]);
        let message = test_msg::<P>();
        let (u, v) = encrypt::<P>(&ek, &message, &[0x17u8; 32]);
        let recovered = decrypt::<P>(&dk, &u, &v).expect("decrypt returned None");
        assert_eq!(recovered, message, "round-trip mismatch");
    }

    #[test]
    fn pke_roundtrip_128() {
        pke_roundtrip::<Hqc128>();
    }

    #[test]
    fn pke_roundtrip_192() {
        pke_roundtrip::<Hqc192>();
    }

    #[test]
    fn pke_roundtrip_256() {
        pke_roundtrip::<Hqc256>();
    }

    /// Round-trips eight different (seed, message, theta) triples on `Hqc128`.
    #[test]
    fn pke_roundtrip_many_128() {
        for t in 0u8..8 {
            let seed_byte = t.wrapping_mul(31).wrapping_add(3);
            let (ek, dk) = keygen::<Hqc128>(&[seed_byte; SEED_BYTES]);

            let mask = t.wrapping_mul(5);
            let m: Vec<u8> = (0..Hqc128::K).map(|i| (i as u8) ^ mask).collect();

            let (u, v) = encrypt::<Hqc128>(&ek, &m, &[t ^ 0xA5; 32]);
            let got = decrypt::<Hqc128>(&dk, &u, &v).expect("decrypt None");
            assert_eq!(got, m, "mismatch at t={t}");
        }
    }

    /// The same seed must always yield the same key pair.
    #[test]
    fn keygen_is_deterministic() {
        let seed = [0x5Au8; SEED_BYTES];
        let (first_ek, first_dk) = keygen::<Hqc128>(&seed);
        let (second_ek, second_dk) = keygen::<Hqc128>(&seed);
        assert_eq!(first_ek.seed_ek, second_ek.seed_ek);
        assert_eq!(first_ek.s, second_ek.s);
        assert_eq!(first_dk.seed_dk, second_dk.seed_dk);
    }

    /// Identical key, message and theta must yield an identical ciphertext.
    #[test]
    fn encrypt_is_deterministic() {
        let (ek, _dk) = keygen::<Hqc128>(&[1u8; SEED_BYTES]);
        let message = test_msg::<Hqc128>();
        let theta = [9u8; 32];
        let first = encrypt::<Hqc128>(&ek, &message, &theta);
        let second = encrypt::<Hqc128>(&ek, &message, &theta);
        assert_eq!(first.0, second.0);
        assert_eq!(first.1, second.1);
    }

    /// Changing theta must change the ciphertext while both ciphertexts still
    /// decrypt to the original message.
    #[test]
    fn different_theta_changes_ciphertext() {
        let (ek, dk) = keygen::<Hqc128>(&[2u8; SEED_BYTES]);
        let message = test_msg::<Hqc128>();
        let (u1, v1) = encrypt::<Hqc128>(&ek, &message, &[1u8; 32]);
        let (u2, v2) = encrypt::<Hqc128>(&ek, &message, &[2u8; 32]);
        assert!(!(u1 == u2 && v1 == v2), "ciphertext should depend on θ");
        assert_eq!(decrypt::<Hqc128>(&dk, &u1, &v1).unwrap(), message);
        assert_eq!(decrypt::<Hqc128>(&dk, &u2, &v2).unwrap(), message);
    }

    /// Every bit of `v` from index `N1 * N2` up to `N` must be zero after
    /// encryption.
    #[test]
    fn encrypt_v_tail_is_zero() {
        let (ek, _dk) = keygen::<Hqc128>(&[3u8; SEED_BYTES]);
        let message = test_msg::<Hqc128>();
        let (_, v) = encrypt::<Hqc128>(&ek, &message, &[4u8; 32]);
        let codeword_bits = Hqc128::N1 * Hqc128::N2;
        for i in codeword_bits..Hqc128::N {
            assert_eq!(v.get_bit(i), 0, "v bit {i} above codeword must be zero");
        }
    }

    /// `to_bytes` followed by `from_bytes` must reproduce the key, and the
    /// encoding must be exactly `PK_BYTES` long.
    #[test]
    fn encryption_key_byte_roundtrip() {
        let (ek, _dk) = keygen::<Hqc192>(&[7u8; SEED_BYTES]);
        let encoded = ek.to_bytes();
        assert_eq!(encoded.len(), Hqc192::PK_BYTES);
        let decoded = EncryptionKey::<Hqc192>::from_bytes(&encoded).expect("valid length");
        assert_eq!(decoded.seed_ek, ek.seed_ek);
        assert_eq!(decoded.s, ek.s);
    }
}