use crate::params::HqcParams;
use super::{Poly, MAX_N_WORDS};
/// Length, in 64-bit words, of the wide accumulator arrays used in this file
/// to hold an unreduced product before `reduce_wide` folds it back down.
const WIDE_WORDS: usize = 2 * MAX_N_WORDS + 2;
/// XORs the first `P::N_WORDS` words of `src`, shifted left by `rot` bits,
/// into the wide accumulator `acc`.
///
/// The shift is split into a whole-word offset (`rot / 64`) and a bit offset
/// (`rot % 64`). When the bit offset is non-zero every source word straddles
/// two accumulator words, so both are updated.
///
/// # Panics
///
/// Panics if an index computed from `rot` and `P::N_WORDS` falls outside
/// `acc` or `src`.
#[inline(never)]
fn shift_xor_wide<P: HqcParams>(
    acc: &mut [u64; WIDE_WORDS],
    src: &[u64; MAX_N_WORDS],
    rot: usize,
) {
    let limb_count = P::N_WORDS;
    let (base, left) = (rot / 64, rot % 64);
    match left {
        // Word-aligned shift: each source word lands in exactly one slot.
        0 => {
            for i in 0..limb_count {
                let limb = src[i];
                acc[base + i] ^= limb;
            }
        }
        _ => {
            // `left` is in 1..=63 here, so `right` is too and neither shift
            // below can overflow.
            let right = 64 - left;
            for i in 0..limb_count {
                let limb = src[i];
                acc[base + i] ^= limb << left;
                acc[base + i + 1] ^= limb >> right;
            }
        }
    }
}
/// Folds a wide accumulator back into a `Poly<P>`.
///
/// Output bit `i` (for `i < P::N`) is `acc` bit `i` XOR `acc` bit `i + P::N`,
/// i.e. the upper part of the accumulator wraps around cyclically onto the
/// lower part. Only a single fold is performed. When `P::N` is not a multiple
/// of 64, the bits at and above `P::N` in the last output word are cleared.
fn reduce_wide<P: HqcParams>(acc: &[u64; WIDE_WORDS]) -> Poly<P> {
    let limb_count = P::N_WORDS;
    let (base, shift) = (P::N / 64, P::N % 64);
    let mut reduced = Poly::<P>::zero();

    if shift == 0 {
        // N is word-aligned: the wrapped part starts on a word boundary and
        // no trailing mask is needed.
        for w in 0..limb_count {
            let wrapped = acc[base + w];
            reduced.words[w] = acc[w] ^ wrapped;
        }
        return reduced;
    }

    // Bit N sits `shift` bits into word `base`, so each wrapped word is
    // stitched together from two neighbouring accumulator words.
    let carry = 64 - shift;
    for w in 0..limb_count {
        let wrapped = (acc[base + w] >> shift) | (acc[base + w + 1] << carry);
        reduced.words[w] = acc[w] ^ wrapped;
    }
    // Drop everything at or above bit N in the top word.
    reduced.words[limb_count - 1] &= (1u64 << shift) - 1;
    reduced
}
/// Multiplies `sparse` by `dense` and reduces the result with `reduce_wide`.
///
/// For every set bit of `sparse` at a position below `P::N`, a copy of
/// `dense` shifted by that position is XOR-ed into a wide accumulator.
///
/// The amount of work depends on which bits of `sparse` are set (the inner
/// loop runs once per set bit), so the running time is not independent of
/// `sparse`.
pub fn mul_sparse_dense<P: HqcParams>(sparse: &Poly<P>, dense: &Poly<P>) -> Poly<P> {
    let mut wide = [0u64; WIDE_WORDS];

    for limb_idx in 0..P::N_WORDS {
        let mut pending = sparse.words[limb_idx];
        while pending != 0 {
            let exponent = limb_idx * 64 + pending.trailing_zeros() as usize;
            // Clear the lowest set bit; `pending` is non-zero so the
            // subtraction cannot underflow.
            pending &= pending - 1;

            // Ignore stray bits at or above N in the last word.
            if exponent >= P::N {
                continue;
            }
            shift_xor_wide::<P>(&mut wide, &dense.words, exponent);
        }
    }

    reduce_wide::<P>(&wide)
}
/// Carry-less (GF(2) polynomial) multiplication of two 64-bit words using the
/// PCLMULQDQ instruction.
///
/// Returns `(low, high)`: the low and high 64 bits of the 128-bit product.
#[cfg(all(target_arch = "x86_64", target_feature = "pclmulqdq"))]
#[inline]
fn clmul64(a: u64, b: u64) -> (u64, u64) {
    use core::arch::x86_64::{
        __m128i, _mm_clmulepi64_si128, _mm_set_epi64x, _mm_storeu_si128,
    };

    let mut halves = [0u64; 2];
    // SAFETY: the `cfg` gate on this function only compiles it when the
    // `pclmulqdq` target feature is enabled, and the unaligned store writes
    // exactly 16 bytes into the 16-byte `halves` array.
    unsafe {
        let lhs = _mm_set_epi64x(0, a as i64);
        let rhs = _mm_set_epi64x(0, b as i64);
        // Selector 0x00 multiplies the low 64-bit lanes of both operands.
        let product = _mm_clmulepi64_si128(lhs, rhs, 0x00);
        _mm_storeu_si128(halves.as_mut_ptr().cast::<__m128i>(), product);
    }
    let [low, high] = halves;
    (low, high)
}
/// Portable carry-less (GF(2) polynomial) multiplication of two 64-bit words.
///
/// Returns `(low, high)`: the low and high 64 bits of the 128-bit product.
/// Each bit of `b` is expanded to an all-zeros / all-ones mask that selects
/// the correspondingly shifted copy of `a`, so no branch depends on the
/// operand values.
#[cfg(not(all(target_arch = "x86_64", target_feature = "pclmulqdq")))]
#[inline]
fn clmul64(a: u64, b: u64) -> (u64, u64) {
    // Bit 0 of `b` contributes `a` unshifted and nothing to the high half;
    // handling it up front keeps `64 - shift` within 1..=63 inside the loop.
    let mut low = a & (b & 1).wrapping_neg();
    let mut high = 0u64;

    for shift in 1..64u32 {
        let select = ((b >> shift) & 1).wrapping_neg();
        low ^= (a << shift) & select;
        high ^= (a >> (64 - shift)) & select;
    }

    (low, high)
}
/// Operand length (in 64-bit words) at or below which `karatsuba` stops
/// recursing and multiplies word-by-word with `clmul64`.
const KARATSUBA_THRESHOLD: usize = 16;
/// Length, in 64-bit words, of the scratch array that `mul_dense_ct` hands to
/// `karatsuba`. Each recursion level of `karatsuba` takes `4 * h` words from
/// the front of its scratch slice (with `h = (n + 1) / 2`) and passes the
/// remainder on to its recursive calls.
const KARATSUBA_SCRATCH_WORDS: usize = 4 * MAX_N_WORDS + 256;
/// Carry-less product of two `n`-word operands, written to `out[..2 * n]`.
///
/// * `a`, `b` – operands; only the first `n` words of each are read.
/// * `out` – receives the `2 * n`-word product; previous contents are
///   overwritten.
/// * `scratch` – working memory. Each recursion level takes `4 * h` words
///   (`h = (n + 1) / 2`) from the front and passes the rest down.
///
/// At or below `KARATSUBA_THRESHOLD` words the product is computed directly
/// with `clmul64`; above it the operands are split in two and three
/// half-size products are combined. Control flow depends only on `n`.
///
/// # Panics
///
/// Panics if any slice is too short for the sizes above.
fn karatsuba(out: &mut [u64], a: &[u64], b: &[u64], n: usize, scratch: &mut [u64]) {
    if n <= KARATSUBA_THRESHOLD {
        // Base case: schoolbook multiplication, one 64x64 product at a time.
        out[..2 * n].fill(0);
        for (i, &x) in a[..n].iter().enumerate() {
            for (j, &y) in b[..n].iter().enumerate() {
                let (low, high) = clmul64(x, y);
                out[i + j] ^= low;
                out[i + j + 1] ^= high;
            }
        }
        return;
    }

    // Low halves get the extra word when `n` is odd.
    let half = (n + 1) / 2;
    let upper = n - half;
    let (a_low, a_high) = a[..n].split_at(half);
    let (b_low, b_high) = b[..n].split_at(half);

    // Scratch layout: [a_low ^ a_high | b_low ^ b_high | cross product | rest].
    let (sums, tail) = scratch.split_at_mut(2 * half);
    let (sum_a, sum_b) = sums.split_at_mut(half);
    let (cross, deeper) = tail.split_at_mut(2 * half);

    sum_a.copy_from_slice(a_low);
    for (acc, &word) in sum_a.iter_mut().zip(a_high) {
        *acc ^= word;
    }
    sum_b.copy_from_slice(b_low);
    for (acc, &word) in sum_b.iter_mut().zip(b_high) {
        *acc ^= word;
    }

    // low * low goes to the bottom of `out`, high * high to the top.
    {
        let (out_low, out_high) = out[..2 * n].split_at_mut(2 * half);
        karatsuba(out_low, a_low, b_low, half, deeper);
        karatsuba(out_high, a_high, b_high, upper, deeper);
    }
    karatsuba(cross, sum_a, sum_b, half, deeper);

    // cross = (a_low ^ a_high)(b_low ^ b_high) ^ low*low ^ high*high,
    // which is the middle term of the product.
    for (c, &word) in cross.iter_mut().zip(&out[..2 * half]) {
        *c ^= word;
    }
    for (c, &word) in cross.iter_mut().zip(&out[2 * half..2 * n]) {
        *c ^= word;
    }

    // Add the middle term at an offset of `half` words.
    for (dst, &word) in out[half..3 * half].iter_mut().zip(cross.iter()) {
        *dst ^= word;
    }
}
/// Multiplies two polynomials with `karatsuba` and reduces the wide product
/// with `reduce_wide`.
///
/// Both operands are expected to have no bits set at or above `P::N`; this is
/// checked in debug builds only.
///
/// No branch in this function depends on operand values outside the debug
/// assertion.
/// NOTE(review): whether the whole computation is constant-time also depends
/// on how `clmul64` is compiled for the target — confirm.
pub fn mul_dense_ct<P: HqcParams>(a: &Poly<P>, b: &Poly<P>) -> Poly<P> {
    let limbs = P::N_WORDS;

    debug_assert!(
        match P::N & 63 {
            // N is a multiple of 64: the last word has no unused bits.
            0 => true,
            used_bits => {
                let stray = !((1u64 << used_bits) - 1);
                (a.words[limbs - 1] & stray) == 0 && (b.words[limbs - 1] & stray) == 0
            }
        },
        "mul_dense_ct operands must have zero bits at/above N"
    );

    let mut product = [0u64; WIDE_WORDS];
    let mut workspace = [0u64; KARATSUBA_SCRATCH_WORDS];
    karatsuba(
        &mut product[..2 * limbs],
        &a.words[..limbs],
        &b.words[..limbs],
        limbs,
        &mut workspace,
    );
    reduce_wide::<P>(&product)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the multipliers: algebraic identities, comparison
    //! against a position-based reference product, and cross-checks between
    //! `mul_sparse_dense`, `mul_dense_ct` and a bit-by-bit reference.

    use super::*;
    use crate::params::{Hqc128, Hqc192, Hqc256};
    use crate::poly::sampling::{sample_fixed_weight, sample_uniform};
    use sha3::{
        digest::{ExtendableOutput, Update},
        Shake256,
    };

    /// Reader type produced by finalising a SHAKE256 hasher.
    type Xof = <Shake256 as ExtendableOutput>::Reader;

    /// Returns a SHAKE256 output stream seeded with `seed`.
    fn seeded_xof(seed: &[u8]) -> Xof {
        let mut hasher = Shake256::default();
        hasher.update(seed);
        hasher.finalize_xof()
    }

    /// Builds a polynomial with exactly the given bit positions set.
    fn from_positions<P: HqcParams>(positions: &[usize]) -> Poly<P> {
        let mut p = Poly::<P>::zero();
        for &i in positions {
            p.set_bit(i);
        }
        p
    }

    /// Reference product computed from exponent lists: every pair of
    /// exponents toggles coefficient `(s + d) mod N`.
    fn naive_mul<P: HqcParams>(a_pos: &[usize], b_pos: &[usize]) -> Poly<P> {
        let mut coeffs = vec![0u8; P::N];
        for &s in a_pos {
            for &d in b_pos {
                coeffs[(s + d) % P::N] ^= 1;
            }
        }
        let mut p = Poly::<P>::zero();
        for (i, &bit) in coeffs.iter().enumerate() {
            if bit == 1 {
                p.set_bit(i);
            }
        }
        p
    }

    /// Checks both multipliers against `naive_mul` for the given exponents.
    fn check_against_naive<P: HqcParams>(a_pos: &[usize], b_pos: &[usize]) {
        let a = from_positions::<P>(a_pos);
        let b = from_positions::<P>(b_pos);
        let expected = naive_mul::<P>(a_pos, b_pos);
        assert_eq!(mul_sparse_dense::<P>(&a, &b), expected, "Mode A vs naive");
        assert_eq!(mul_dense_ct::<P>(&b, &a), expected, "Mode B vs naive");
    }

    /// Bit-by-bit reference for the dense product: for every bit position of
    /// `b`, a masked, shifted copy of `a` is XOR-ed into a wide accumulator.
    fn mul_dense_ct_bitwise<P: HqcParams>(a: &Poly<P>, b: &Poly<P>) -> Poly<P> {
        let mut acc = [0u64; WIDE_WORDS];
        let nw = P::N_WORDS;
        for pos in 0..P::N {
            let word_shift = pos >> 6;
            let bit_shift = pos & 63;
            // All-ones when bit `pos` of `b` is set, all-zeros otherwise.
            let mask = ((b.words[word_shift] >> bit_shift) & 1).wrapping_neg();
            if bit_shift == 0 {
                for i in 0..nw {
                    acc[i + word_shift] ^= a.words[i] & mask;
                }
            } else {
                let right_shift = 64 - bit_shift;
                for i in 0..nw {
                    acc[i + word_shift] ^= (a.words[i] << bit_shift) & mask;
                    acc[i + word_shift + 1] ^= (a.words[i] >> right_shift) & mask;
                }
            }
        }
        reduce_wide::<P>(&acc)
    }

    /// Compares `mul_dense_ct` with the bit-by-bit reference on sampled
    /// operands and checks that the dense product is commutative.
    fn karatsuba_matches_bitwise<P: HqcParams>(tag: &[u8]) {
        for seed in 0u8..8 {
            // Seed material is `tag` followed by the one-byte seed.
            let mut material = tag.to_vec();
            material.push(seed);
            let mut xof = seeded_xof(&material);

            let a = sample_uniform::<P>(&mut xof);
            let b = sample_uniform::<P>(&mut xof);
            let fast = mul_dense_ct::<P>(&a, &b);
            let reference = mul_dense_ct_bitwise::<P>(&a, &b);
            assert_eq!(fast, reference, "Karatsuba vs bit-level mismatch, seed={seed}");
            let fast_rev = mul_dense_ct::<P>(&b, &a);
            assert_eq!(fast, fast_rev, "dense product not commutative, seed={seed}");
        }
    }

    /// Checks (1 + X)(1 + X^2) = 1 + X + X^2 + X^3 through both multipliers.
    fn check_small_product<P: HqcParams>() {
        let a = from_positions::<P>(&[0, 1]);
        let b = from_positions::<P>(&[0, 2]);
        let expected = from_positions::<P>(&[0, 1, 2, 3]);
        assert_eq!(mul_sparse_dense::<P>(&a, &b), expected, "sparse-dense product");
        assert_eq!(mul_dense_ct::<P>(&a, &b), expected, "dense product");
    }

    #[test]
    fn multiply_by_zero_is_zero() {
        let a = from_positions::<Hqc128>(&[0, 100]);
        let zero = Poly::<Hqc128>::zero();
        let r = mul_sparse_dense::<Hqc128>(&zero, &a);
        assert_eq!(r.hamming_weight(), 0);
        let r2 = mul_sparse_dense::<Hqc128>(&a, &zero);
        assert_eq!(r2.hamming_weight(), 0);
    }

    #[test]
    fn multiply_by_one_is_identity() {
        let one = from_positions::<Hqc128>(&[0]);
        let a = from_positions::<Hqc128>(&[5, 1000, 17000]);
        let r = mul_sparse_dense::<Hqc128>(&one, &a);
        assert_eq!(r, a);
        let r2 = mul_sparse_dense::<Hqc128>(&a, &one);
        assert_eq!(r2, a);
    }

    #[test]
    fn multiply_by_x_is_cyclic_shift() {
        let x = from_positions::<Hqc128>(&[1]);
        let a = from_positions::<Hqc128>(&[0, 100, Hqc128::N - 1]);
        let r = mul_sparse_dense::<Hqc128>(&x, &a);
        assert_eq!(r.get_bit(1), 1);
        assert_eq!(r.get_bit(101), 1);
        assert_eq!(r.get_bit(0), 1, "X^(N-1) * X should wrap to bit 0");
        assert_eq!(r.hamming_weight(), 3);
    }

    #[test]
    fn matches_naive_large_rotations_128() {
        check_against_naive::<Hqc128>(
            &[0, 1, 12_345, Hqc128::N - 1, Hqc128::N - 2],
            &[0, 7, 17_000, Hqc128::N - 1],
        );
    }

    #[test]
    fn matches_naive_large_rotations_192() {
        check_against_naive::<Hqc192>(
            &[0, 3, 30_000, Hqc192::N - 1],
            &[5, 200, 35_000, Hqc192::N - 5],
        );
    }

    #[test]
    fn matches_naive_large_rotations_256() {
        check_against_naive::<Hqc256>(
            &[0, 9, 50_000, Hqc256::N - 1],
            &[1, 64, 57_000, Hqc256::N - 3],
        );
    }

    #[test]
    fn matches_naive_single_high_bit_128() {
        check_against_naive::<Hqc128>(&[Hqc128::N - 1], &[Hqc128::N - 1]);
    }

    #[test]
    fn sparse_dense_commutativity_hqc128() {
        for seed in 0u8..10 {
            let mut xof = seeded_xof(&[seed, 0]);
            let a = sample_fixed_weight::<Hqc128>(&mut xof, Hqc128::OMEGA);
            let mut xof2 = seeded_xof(&[seed, 1]);
            let b = sample_fixed_weight::<Hqc128>(&mut xof2, Hqc128::OMEGA_R);
            let ab = mul_sparse_dense::<Hqc128>(&a, &b);
            let ba = mul_sparse_dense::<Hqc128>(&b, &a);
            assert_eq!(ab, ba, "commutativity failed for seed={seed}");
        }
    }

    #[test]
    fn sparse_dense_commutativity_hqc256() {
        let mut xof = seeded_xof(b"hqc256-seed");
        let a = sample_fixed_weight::<Hqc256>(&mut xof, Hqc256::OMEGA);
        let mut xof2 = seeded_xof(b"hqc256-seed2");
        let b = sample_fixed_weight::<Hqc256>(&mut xof2, Hqc256::OMEGA_R);
        let ab = mul_sparse_dense::<Hqc256>(&a, &b);
        let ba = mul_sparse_dense::<Hqc256>(&b, &a);
        assert_eq!(ab, ba);
    }

    #[test]
    fn sparse_dense_matches_dense_ct() {
        for seed in 0u8..5 {
            // Both operands are drawn from the same stream, in this order.
            let mut xof = seeded_xof(&[seed]);
            let sparse = sample_fixed_weight::<Hqc128>(&mut xof, Hqc128::OMEGA);
            let dense = sample_uniform::<Hqc128>(&mut xof);
            let r_a = mul_sparse_dense::<Hqc128>(&sparse, &dense);
            let r_b = mul_dense_ct::<Hqc128>(&dense, &sparse);
            assert_eq!(r_a, r_b, "Mode A vs Mode B mismatch for seed={seed}");
        }
    }

    #[test]
    fn karatsuba_matches_bitwise_128() {
        karatsuba_matches_bitwise::<Hqc128>(b"k128");
    }

    #[test]
    fn karatsuba_matches_bitwise_192() {
        karatsuba_matches_bitwise::<Hqc192>(b"k192");
    }

    #[test]
    fn karatsuba_matches_bitwise_256() {
        karatsuba_matches_bitwise::<Hqc256>(b"k256");
    }

    #[test]
    fn distributivity_add_mul() {
        let a = from_positions::<Hqc128>(&[0, 1, 5]);
        let b = from_positions::<Hqc128>(&[2, 7]);
        let c = from_positions::<Hqc128>(&[3, 8, 100]);
        let bc = b.add(&c);
        let lhs = mul_sparse_dense::<Hqc128>(&a, &bc);
        let ab = mul_sparse_dense::<Hqc128>(&a, &b);
        let ac = mul_sparse_dense::<Hqc128>(&a, &c);
        let rhs = ab.add(&ac);
        assert_eq!(lhs, rhs);
    }

    #[test]
    fn result_has_no_overflow_bits_after_mul() {
        // Both operands are drawn from the same stream, in this order.
        let mut xof = seeded_xof(b"overflow-check");
        let a = sample_fixed_weight::<Hqc128>(&mut xof, Hqc128::OMEGA);
        let b = sample_fixed_weight::<Hqc128>(&mut xof, Hqc128::OMEGA_R);
        let r = mul_sparse_dense::<Hqc128>(&a, &b);
        let last_bit = Hqc128::N & 63;
        let mask = (1u64 << last_bit) - 1;
        assert_eq!(
            r.words[Hqc128::N_WORDS - 1] & !mask,
            0,
            "overflow bits not cleared by reduce_wide()"
        );
    }

    #[test]
    fn all_three_param_sets_mul() {
        // Previously the products were discarded; now the result is checked
        // for every parameter set.
        check_small_product::<Hqc128>();
        check_small_product::<Hqc192>();
        check_small_product::<Hqc256>();
    }
}