use sha3::digest::{ExtendableOutput, XofReader};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
use crate::params::HqcParams;
use super::Poly;
/// Fills `out` with the next `out.len()` bytes of `xof`, then silently skips
/// however many further bytes are needed so that the total amount consumed by
/// this call is a multiple of 8.
///
/// Callers therefore always leave the stream positioned on an 8-byte boundary
/// relative to where the call started.
fn xof_get_bytes(xof: &mut impl XofReader, out: &mut [u8]) {
    xof.read(out);
    let remainder = out.len() % 8;
    if remainder == 0 {
        return;
    }
    // Squeeze and throw away the bytes that complete the final 8-byte unit.
    let mut scratch = [0u8; 8];
    xof.read(&mut scratch[remainder..]);
}
/// Squeezes the next eight bytes from `xof` and decodes them as a
/// little-endian `u64`.
#[inline(always)]
fn read_u64(xof: &mut impl XofReader) -> u64 {
    u64::from_le_bytes({
        let mut le = [0u8; 8];
        xof.read(&mut le);
        le
    })
}
/// Samples a polynomial of exact Hamming weight `weight` by rejection sampling.
///
/// Candidates are 24-bit big-endian integers read from `xof` in batches of
/// `3 * weight` bytes. Each batch read is padded to an 8-byte boundary by the
/// file-local `xof_get_bytes` helper.
///
/// A candidate is accepted only if it lies below the largest multiple of `N`
/// that fits in 24 bits, so `cand % N` is uniform over `[0, N)`. An accepted
/// position equal to an earlier one is discarded and another is drawn.
///
/// The amount of XOF output consumed is data-dependent. However, the secret
/// positions are written into the polynomial without secret-dependent memory
/// indexing: every word is visited and bits are merged with constant-time
/// selects, in the same way as [`sample_fixed_weight_mod`].
///
/// # Panics
///
/// In debug builds, panics if `weight > 256` or `weight > P::N`. In release
/// builds, `weight > 256` still panics when slicing the internal buffer, while
/// `weight > P::N` would loop forever.
pub fn sample_fixed_weight<P: HqcParams>(
    xof: &mut impl XofReader,
    weight: usize,
) -> Poly<P> {
    debug_assert!(weight <= 256, "weight {weight} exceeds internal buffer size");
    debug_assert!(weight <= P::N, "weight {weight} > N={}", P::N);
    let n = P::N as u32;
    // Largest multiple of N not exceeding 2^24; candidates at or above it are
    // rejected so the reduction modulo N introduces no bias.
    let threshold = ((1u32 << 24) / n) * n;
    let batch = 3 * weight;
    let mut buf = [0u8; 3 * 256];
    // Start with the batch marked as exhausted so the first draw refills it.
    let mut j = batch;
    let mut positions = [0u32; 256];
    let mut filled: usize = 0;
    while filled < weight {
        let pos = loop {
            if j == batch {
                xof_get_bytes(xof, &mut buf[..batch]);
                j = 0;
            }
            let cand = (u32::from(buf[j]) << 16)
                | (u32::from(buf[j + 1]) << 8)
                | u32::from(buf[j + 2]);
            j += 3;
            if cand < threshold {
                break cand % n;
            }
        };
        // Store tentatively, then keep it only if it is new. The scan visits
        // every earlier entry rather than stopping at the first match.
        positions[filled] = pos;
        let mut dup = false;
        for &prev in &positions[..filled] {
            dup |= prev == pos;
        }
        if !dup {
            filled += 1;
        }
    }

    // Constant-time placement. For each word, OR in the bit of every position
    // that falls in that word. Selection uses `ct_eq`/`conditional_select`, so
    // no memory access is indexed by a secret position. The positions are
    // pairwise distinct, so this yields exactly `weight` set bits.
    let mut poly = Poly::<P>::zero();
    for word_idx in 0..P::N_WORDS {
        let wi = word_idx as u32;
        let mut acc = 0u64;
        for &pos in &positions[..weight] {
            let bit = 1u64 << (pos & 63);
            acc |= u64::conditional_select(&0u64, &bit, (pos >> 6).ct_eq(&wi));
        }
        poly.words[word_idx] |= acc;
    }
    poly
}
/// Samples a polynomial of exact Hamming weight `weight` using a fixed amount
/// of XOF output (`4 * weight` bytes, padded to a multiple of 8).
///
/// Three steps:
/// 1. The i-th little-endian `u32` `r` is mapped to
///    `i + floor(r * (N - i) / 2^32)`, a value in `[i, N)`.
/// 2. Walking from the last index down, an entry equal to any later entry is
///    replaced by its own index `i`. Every later entry is at least `i + 1`
///    (either its original value in `[j, N)` or its replacement `j`), so the
///    replacement cannot collide and all entries end up distinct.
/// 3. The bits are placed by scanning every word and merging with
///    constant-time selects, so no memory access is indexed by a position.
///
/// Comparisons in steps 2 and 3 use `subtle`'s `ct_eq` / `conditional_select`.
///
/// # Panics
///
/// In debug builds, panics if `weight > 256` or `weight > P::N`.
pub fn sample_fixed_weight_mod<P: HqcParams>(
    xof: &mut impl XofReader,
    weight: usize,
) -> Poly<P> {
    debug_assert!(weight <= 256, "weight {weight} exceeds internal buffer size");
    debug_assert!(weight <= P::N, "weight {weight} > N={}", P::N);

    let mut raw = [0u8; 4 * 256];
    xof_get_bytes(xof, &mut raw[..4 * weight]);

    // Step 1: scale each random u32 into the shrinking range [i, N).
    let mut support = [0u32; 256];
    let lanes = support.iter_mut().zip(raw.chunks_exact(4)).take(weight);
    for (i, (slot, chunk)) in lanes.enumerate() {
        let r = u64::from(u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
        let span = (P::N - i) as u64;
        *slot = (i as u64 + ((r * span) >> 32)) as u32;
    }

    // Step 2: constant-time dedup, back to front.
    for i in (0..weight).rev() {
        let current = support[i];
        let collides = support[i + 1..weight]
            .iter()
            .fold(Choice::from(0u8), |seen, later| seen | later.ct_eq(&current));
        support[i] = u32::conditional_select(&current, &(i as u32), collides);
    }

    // Step 3: split each position into (word index, single-bit mask), then
    // merge into every word without secret-indexed writes.
    let mut word_of = [0u32; 256];
    let mut bit_of = [0u64; 256];
    for (k, &pos) in support[..weight].iter().enumerate() {
        word_of[k] = pos >> 6;
        bit_of[k] = 1u64 << (pos & 63);
    }
    let mut poly = Poly::<P>::zero();
    for word_idx in 0..P::N_WORDS {
        let target = word_idx as u32;
        let hits = word_of[..weight]
            .iter()
            .zip(&bit_of[..weight])
            .fold(0u64, |acc, (w, b)| {
                acc | u64::conditional_select(&0u64, b, w.ct_eq(&target))
            });
        poly.words[word_idx] |= hits;
    }
    poly
}
/// Samples a uniformly random polynomial of length `N`.
///
/// Each of the `N_WORDS` words is filled with eight XOF bytes (little-endian).
/// Any bits of the final word at positions `>= N` are then cleared.
pub fn sample_uniform<P: HqcParams>(xof: &mut impl XofReader) -> Poly<P> {
    let mut poly = Poly::<P>::zero();
    for word_idx in 0..P::N_WORDS {
        poly.words[word_idx] = read_u64(xof);
    }
    // Number of meaningful bits in the last word; 0 means it is fully used.
    let tail_bits = P::N % 64;
    if tail_bits > 0 {
        poly.words[P::N_WORDS - 1] &= (1u64 << tail_bits) - 1;
    }
    poly
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::params::{Hqc128, Hqc192, Hqc256};
    use sha3::{Shake256, digest::Update, digest::XofReader};

    /// Builds a SHAKE256 reader that has absorbed `seed`.
    fn make_xof(seed: &[u8]) -> impl XofReader {
        use sha3::digest::ExtendableOutput;
        let mut h = Shake256::default();
        h.update(seed);
        h.finalize_xof()
    }

    /// Asserts that no bit at index `>= N` is set in the last word of `p`.
    ///
    /// When `N` is a multiple of 64 every bit of the last word is valid, so the
    /// check is skipped. The naive mask `(1 << 0) - 1 == 0` would otherwise
    /// wrongly demand an all-zero word.
    fn assert_no_overflow_bits<P: HqcParams>(p: &Poly<P>) {
        let last_bit = P::N & 63;
        if last_bit != 0 {
            let mask = (1u64 << last_bit) - 1;
            assert_eq!(p.words[P::N_WORDS - 1] & !mask, 0, "bits above N-1 must be zero");
        }
    }

    #[test]
    fn fixed_weight_correct_weight_128() {
        let mut xof = make_xof(b"test-seed-0");
        let p = sample_fixed_weight::<Hqc128>(&mut xof, Hqc128::OMEGA);
        assert_eq!(p.hamming_weight(), Hqc128::OMEGA);
    }

    #[test]
    fn fixed_weight_correct_weight_192() {
        let mut xof = make_xof(b"test-seed-1");
        let p = sample_fixed_weight::<Hqc192>(&mut xof, Hqc192::OMEGA);
        assert_eq!(p.hamming_weight(), Hqc192::OMEGA);
    }

    #[test]
    fn fixed_weight_correct_weight_256() {
        let mut xof = make_xof(b"test-seed-2");
        let p = sample_fixed_weight::<Hqc256>(&mut xof, Hqc256::OMEGA);
        assert_eq!(p.hamming_weight(), Hqc256::OMEGA);
    }

    #[test]
    fn fixed_weight_all_bits_in_range() {
        let mut xof = make_xof(b"test-seed-range");
        let p = sample_fixed_weight::<Hqc128>(&mut xof, Hqc128::OMEGA_R);
        // Every set bit must be found among indices 0..N.
        let in_range = (0..Hqc128::N).filter(|&i| p.get_bit(i) == 1).count();
        assert_eq!(in_range, Hqc128::OMEGA_R);
        assert_eq!(p.hamming_weight(), Hqc128::OMEGA_R);
        assert_no_overflow_bits(&p);
    }

    #[test]
    fn fixed_weight_no_duplicates() {
        // An exact Hamming weight implies the sampled positions were pairwise
        // distinct. Weight 1 can never collide, so also use the real weights,
        // where collisions occur and must be redrawn.
        for seed in 0u8..20 {
            for &weight in &[1, Hqc128::OMEGA, Hqc128::OMEGA_R] {
                let mut xof = make_xof(&[seed]);
                let p = sample_fixed_weight::<Hqc128>(&mut xof, weight);
                assert_eq!(p.hamming_weight(), weight, "seed={seed}, weight={weight}");
            }
        }
    }

    #[test]
    fn fixed_weight_deterministic() {
        let mut xof1 = make_xof(b"deterministic");
        let mut xof2 = make_xof(b"deterministic");
        let p1 = sample_fixed_weight::<Hqc128>(&mut xof1, Hqc128::OMEGA);
        let p2 = sample_fixed_weight::<Hqc128>(&mut xof2, Hqc128::OMEGA);
        assert_eq!(p1, p2);
    }

    #[test]
    fn fixed_weight_different_seeds_differ() {
        let mut xof1 = make_xof(b"seed-A");
        let mut xof2 = make_xof(b"seed-B");
        let p1 = sample_fixed_weight::<Hqc128>(&mut xof1, Hqc128::OMEGA);
        let p2 = sample_fixed_weight::<Hqc128>(&mut xof2, Hqc128::OMEGA);
        assert_ne!(p1, p2, "different seeds should almost certainly differ");
    }

    /// Test reader that replays `bytes` and then yields zeros forever.
    struct FixedXof {
        bytes: Vec<u8>,
        pos: usize,
    }

    impl FixedXof {
        fn new(bytes: Vec<u8>) -> Self {
            FixedXof { bytes, pos: 0 }
        }
    }

    impl XofReader for FixedXof {
        fn read(&mut self, buffer: &mut [u8]) {
            for b in buffer.iter_mut() {
                *b = self.bytes.get(self.pos).copied().unwrap_or(0);
                self.pos += 1;
            }
        }
    }

    fn mod_correct_weight<P: HqcParams>(seed: &[u8], weight: usize) {
        let mut xof = make_xof(seed);
        let p = sample_fixed_weight_mod::<P>(&mut xof, weight);
        assert_eq!(p.hamming_weight(), weight, "exact weight (implies distinct positions)");
        assert_no_overflow_bits(&p);
    }

    #[test]
    fn mod_correct_weight_128() {
        mod_correct_weight::<Hqc128>(b"mod-seed-0", Hqc128::OMEGA_R);
    }

    #[test]
    fn mod_correct_weight_192() {
        mod_correct_weight::<Hqc192>(b"mod-seed-1", Hqc192::OMEGA_R);
    }

    #[test]
    fn mod_correct_weight_256() {
        mod_correct_weight::<Hqc256>(b"mod-seed-2", Hqc256::OMEGA_R);
    }

    #[test]
    fn mod_deterministic() {
        let mut xof1 = make_xof(b"mod-det");
        let mut xof2 = make_xof(b"mod-det");
        let p1 = sample_fixed_weight_mod::<Hqc128>(&mut xof1, Hqc128::OMEGA_R);
        let p2 = sample_fixed_weight_mod::<Hqc128>(&mut xof2, Hqc128::OMEGA_R);
        assert_eq!(p1, p2);
    }

    #[test]
    fn mod_differs_from_rejection() {
        let mut xof_a = make_xof(b"same-seed");
        let mut xof_b = make_xof(b"same-seed");
        let a = sample_fixed_weight::<Hqc128>(&mut xof_a, Hqc128::OMEGA_R);
        let b = sample_fixed_weight_mod::<Hqc128>(&mut xof_b, Hqc128::OMEGA_R);
        assert_ne!(a, b, "rejection and mod samplers must produce different vectors");
    }

    #[test]
    fn mod_all_zero_words_gives_identity_positions() {
        let weight = 8;
        let mut xof = FixedXof::new(vec![0u8; 4 * weight]);
        let p = sample_fixed_weight_mod::<Hqc128>(&mut xof, weight);
        assert_eq!(p.hamming_weight(), weight);
        for i in 0..weight {
            assert_eq!(p.get_bit(i), 1, "bit {i} must be set");
        }
        assert_eq!(p.get_bit(weight), 0, "bit {weight} must be clear");
    }

    #[test]
    fn mod_all_ones_words_exercises_dedup() {
        let weight = 8;
        let mut xof = FixedXof::new(vec![0xFFu8; 4 * weight]);
        let p = sample_fixed_weight_mod::<Hqc128>(&mut xof, weight);
        assert_eq!(p.hamming_weight(), weight, "dedup must still yield distinct positions");
        for i in 0..(weight - 1) {
            assert_eq!(p.get_bit(i), 1, "bit {i} must be set");
        }
        assert_eq!(p.get_bit(weight - 1), 0, "bit weight−1 must be clear (it became N−1)");
        assert_eq!(p.get_bit(Hqc128::N - 1), 1, "bit N−1 must be set");
    }

    #[test]
    fn mod_reads_u32_little_endian() {
        let mut xof = FixedXof::new(vec![0x00, 0x00, 0x00, 0x80]);
        let p = sample_fixed_weight_mod::<Hqc128>(&mut xof, 1);
        assert_eq!(p.hamming_weight(), 1);
        assert_eq!(p.get_bit(8834), 1, "expected position floor(N/2) = 8834 (little-endian)");
    }

    #[test]
    fn uniform_no_overflow_bits() {
        let mut xof = make_xof(b"uniform-test");
        let p = sample_uniform::<Hqc128>(&mut xof);
        assert_no_overflow_bits(&p);
    }

    #[test]
    fn uniform_deterministic() {
        let mut xof1 = make_xof(b"uniform-det");
        let mut xof2 = make_xof(b"uniform-det");
        let p1 = sample_uniform::<Hqc128>(&mut xof1);
        let p2 = sample_uniform::<Hqc128>(&mut xof2);
        assert_eq!(p1, p2);
    }

    #[test]
    fn uniform_256_no_overflow() {
        let mut xof = make_xof(b"uniform-256");
        let p = sample_uniform::<Hqc256>(&mut xof);
        assert_no_overflow_bits(&p);
    }
}