#![allow(dead_code)]
#![cfg_attr(all(feature = "simd", target_arch = "x86_64"), allow(unsafe_code))]
use super::params::{MlDsaParams, N, Params44, Params65, Params87, Q};
use super::poly::Poly;
use arcanum_primitives::shake::{Shake128, Shake256};
/// Exclusive upper bound for candidates accepted by the uniform rejection
/// samplers in this file; it is the modulus `Q` from `super::params`.
const REJECTION_BOUND: i32 = Q;
/// Samples a polynomial whose coefficients are uniform in `[0, Q)` using
/// rejection sampling (FIPS 204, Algorithm 30 `RejNTTPoly`).
///
/// SHAKE128 absorbs `rho`, then the byte `j`, then the byte `i`. Every three
/// bytes of output form one 23-bit little-endian candidate (the top bit of the
/// third byte is cleared, as in `CoeffFromThreeBytes`); a candidate is kept
/// only if it is below [`REJECTION_BOUND`].
///
/// # Parameters
/// * `rho` - 32-byte public seed.
/// * `i`, `j` - domain-separation bytes; they are absorbed in the order `j`, `i`.
///
/// # Returns
/// A [`Poly`] with all `N` coefficients filled with accepted candidates.
pub fn sample_poly_uniform(rho: &[u8; 32], i: u8, j: u8) -> Poly {
    // A multiple of 3 so a candidate never straddles two squeezes.
    const BUF_LEN: usize = 504;

    let mut poly = Poly::zero();
    let mut shake = Shake128::new();
    shake.update(rho);
    shake.update(&[j, i]);
    let mut reader = shake.finalize_xof();

    let mut buf = [0u8; BUF_LEN];
    let mut filled = 0;
    while filled < N {
        reader.squeeze(&mut buf);
        for triple in buf.chunks_exact(3) {
            // Q < 2^23, so candidates are 23 bits wide: drop the top bit of
            // the third byte before comparing against the modulus.
            let candidate = i32::from(triple[0])
                | (i32::from(triple[1]) << 8)
                | (i32::from(triple[2] & 0x7F) << 16);
            if candidate < REJECTION_BOUND {
                poly.coeffs[filled] = candidate;
                filled += 1;
                if filled == N {
                    break;
                }
            }
        }
    }
    poly
}
/// Expands `rho` into the `P::K` x `P::L` matrix of polynomials.
///
/// When built with the `simd` feature for `x86_64` and AVX2 is detected at
/// runtime, the 4-way batched implementation is used; in every other case the
/// work is delegated to [`expand_a_sequential`].
pub fn expand_a<P: MlDsaParams>(rho: &[u8; 32]) -> Vec<Vec<Poly>> {
    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
    if is_x86_feature_detected!("avx2") {
        // SAFETY: `expand_a_x4` is compiled with `target_feature(enable = "avx2")`
        // and AVX2 support has just been confirmed at runtime.
        return unsafe { expand_a_x4::<P>(rho) };
    }

    expand_a_sequential::<P>(rho)
}
/// Builds the `P::K` x `P::L` matrix one entry at a time.
///
/// Entry `(row, col)` is `sample_poly_uniform(rho, row, col)` with `ntt()`
/// applied to it before it is stored.
pub fn expand_a_sequential<P: MlDsaParams>(rho: &[u8; 32]) -> Vec<Vec<Poly>> {
    (0..P::K)
        .map(|row| {
            (0..P::L)
                .map(|col| {
                    let mut entry = sample_poly_uniform(rho, row as u8, col as u8);
                    entry.ntt();
                    entry
                })
                .collect()
        })
        .collect()
}
/// Expands `rho` into the `P::K` x `P::L` matrix, hashing up to four matrix
/// entries at a time with the 4-lane SHAKE128.
///
/// Each lane absorbs `rho || j || i` for its entry `(i, j)`. Lane output is
/// parsed as in FIPS 204 `RejNTTPoly`: every three bytes give one 23-bit
/// candidate (top bit of the third byte cleared) that is accepted only when it
/// is below [`REJECTION_BOUND`]. Five blocks (840 bytes) are squeezed for all
/// lanes up front; a lane that still lacks coefficients squeezes further
/// 168-byte blocks on its own. `ntt()` is applied to every finished entry.
///
/// # Safety
/// The function is compiled with `target_feature(enable = "avx2")`; the caller
/// must ensure the running CPU supports AVX2.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn expand_a_x4<P: MlDsaParams>(rho: &[u8; 32]) -> Vec<Vec<Poly>> {
    use arcanum_primitives::keccak_x4::Shake128X4;

    // Consumes `bytes` three at a time, storing accepted 23-bit candidates in
    // `poly` starting at `*filled`, and stops as soon as `N` are stored.
    fn take_candidates(bytes: &[u8], poly: &mut Poly, filled: &mut usize) {
        for triple in bytes.chunks_exact(3) {
            if *filled >= N {
                break;
            }
            let candidate = i32::from(triple[0])
                | (i32::from(triple[1]) << 8)
                | (i32::from(triple[2] & 0x7F) << 16);
            if candidate < REJECTION_BOUND {
                poly.coeffs[*filled] = candidate;
                *filled += 1;
            }
        }
    }

    let mut a: Vec<Vec<Poly>> = (0..P::K)
        .map(|_| (0..P::L).map(|_| Poly::zero()).collect())
        .collect();
    let indices: Vec<(usize, usize)> = (0..P::K)
        .flat_map(|i| (0..P::L).map(move |j| (i, j)))
        .collect();

    // The final batch may hold fewer than four entries.
    for batch in indices.chunks(4) {
        let batch_size = batch.len();
        // SAFETY: NOTE(review) the `Shake128X4` calls presumably require AVX2,
        // which this function's own contract guarantees; confirm against the
        // `keccak_x4` module.
        unsafe {
            let mut shake_x4 = Shake128X4::new();
            for (lane, &(i, j)) in batch.iter().enumerate() {
                let mut seed = [0u8; 34];
                seed[..32].copy_from_slice(rho);
                seed[32] = j as u8;
                seed[33] = i as u8;
                shake_x4.absorb(lane, &seed);
            }
            shake_x4.finalize();

            let mut bufs = [[0u8; 840]; 4];
            shake_x4.squeeze_blocks_x4(&mut bufs, 5, batch_size);

            for (lane, &(i, j)) in batch.iter().enumerate() {
                // Fill the (zero-initialised) matrix entry in place.
                let poly = &mut a[i][j];
                let mut filled = 0;
                take_candidates(&bufs[lane], poly, &mut filled);
                // Both 840 and 168 are multiples of 3, so candidates stay
                // aligned across block boundaries.
                while filled < N {
                    let mut extra = [0u8; 168];
                    shake_x4.squeeze_one_block(lane, &mut extra);
                    take_candidates(&extra, poly, &mut filled);
                }
                poly.ntt();
            }
        }
    }
    a
}
/// Expands `rho` into the `P::K` x `P::L` matrix, sampling the entries on
/// rayon's parallel iterator.
///
/// Every `(i, j)` entry is `sample_poly_uniform(rho, i, j)` followed by
/// `ntt()`; the results are then written into a zero-initialised matrix.
#[cfg(feature = "parallel")]
pub fn expand_a_parallel<P: MlDsaParams>(rho: &[u8; 32]) -> Vec<Vec<Poly>> {
    use rayon::prelude::*;

    // Row-major list of all matrix positions.
    let mut cells: Vec<(usize, usize)> = Vec::with_capacity(P::K * P::L);
    for i in 0..P::K {
        for j in 0..P::L {
            cells.push((i, j));
        }
    }

    let sampled: Vec<(usize, usize, Poly)> = cells
        .into_par_iter()
        .map(|(i, j)| {
            let mut entry = sample_poly_uniform(rho, i as u8, j as u8);
            entry.ntt();
            (i, j, entry)
        })
        .collect();

    // Start from an all-zero matrix and drop each sampled entry into place.
    let mut a: Vec<Vec<Poly>> = (0..P::K)
        .map(|_| (0..P::L).map(|_| Poly::zero()).collect())
        .collect();
    for (i, j, entry) in sampled {
        a[i][j] = entry;
    }
    a
}
/// Samples a polynomial with coefficients in `[-eta, eta]` by rejection
/// sampling on half-bytes of SHAKE256 output (FIPS 204 `RejBoundedPoly`).
///
/// SHAKE256 absorbs `seed` followed by `nonce` in little-endian order. Each
/// output byte yields two nibbles, low nibble first:
/// * `eta == 2`: a nibble below 15 is mapped to `2 - (nibble mod 5)`;
/// * `eta == 4`: a nibble below 9 is mapped to `4 - nibble`.
///
/// Rejected nibbles are skipped.
///
/// # Panics
/// Panics if `eta` is neither 2 nor 4. No nibble could ever be accepted for
/// such a value, so sampling would otherwise never terminate.
pub fn sample_poly_eta(seed: &[u8; 64], nonce: u16, eta: usize) -> Poly {
    const BUF_LEN: usize = 272;

    // Exclusive acceptance bound for a nibble; also validates `eta` up front.
    let bound: u8 = match eta {
        2 => 15,
        4 => 9,
        _ => panic!("sample_poly_eta: unsupported eta value {}", eta),
    };

    let mut poly = Poly::zero();
    let mut shake = Shake256::new();
    shake.update(seed);
    shake.update(&nonce.to_le_bytes());
    let mut reader = shake.finalize_xof();

    let mut buf = [0u8; BUF_LEN];
    let mut idx = 0;
    while idx < N {
        reader.squeeze(&mut buf);
        for &byte in buf.iter() {
            for nibble in [byte & 0x0F, byte >> 4] {
                if idx < N && nibble < bound {
                    poly.coeffs[idx] = if eta == 2 {
                        sample_eta2(nibble)
                    } else {
                        4 - i32::from(nibble)
                    };
                    idx += 1;
                }
            }
            if idx == N {
                break;
            }
        }
    }
    poly
}
/// Maps a half-byte `t` to `2 - (t mod 5)`, a value in `[-2, 2]`.
///
/// The function itself performs no range check on `t`.
#[inline]
fn sample_eta2(t: u8) -> i32 {
    2 - i32::from(t % 5)
}
/// Samples a mask polynomial with coefficients in `(-gamma1, gamma1]`.
///
/// SHAKE256 absorbs `seed` followed by `nonce` in little-endian order; the
/// output is unpacked by `sample_gamma1_17` (18-bit fields) or
/// `sample_gamma1_19` (20-bit fields) depending on `gamma1`.
///
/// # Parameters
/// * `seed` - seed bytes absorbed first.
/// * `nonce` - 16-bit counter appended to the seed.
/// * `gamma1` - must be `1 << 17` or `1 << 19`.
///
/// # Panics
/// Panics if `gamma1` is not one of the two supported values. Returning an
/// all-zero polynomial instead would hand the caller an unusable (and, for a
/// signing mask, dangerous) result without any indication of failure.
pub fn sample_poly_gamma1(seed: &[u8], nonce: u16, gamma1: u32) -> Poly {
    const GAMMA1_17: u32 = 1 << 17;
    const GAMMA1_19: u32 = 1 << 19;

    // Validate before hashing so misuse fails fast and loudly.
    assert!(
        gamma1 == GAMMA1_17 || gamma1 == GAMMA1_19,
        "sample_poly_gamma1: unsupported gamma1 value {}",
        gamma1
    );

    let mut poly = Poly::zero();
    let mut shake = Shake256::new();
    shake.update(seed);
    shake.update(&nonce.to_le_bytes());
    let mut reader = shake.finalize_xof();

    if gamma1 == GAMMA1_17 {
        sample_gamma1_17(&mut reader, &mut poly);
    } else {
        sample_gamma1_19(&mut reader, &mut poly);
    }
    poly
}
/// Fills `poly` from 576 squeezed bytes for `gamma1 = 2^17`.
///
/// Every 9-byte group is read as a 72-bit little-endian integer holding four
/// consecutive 18-bit fields `r`; each coefficient is stored as `2^17 - r`.
fn sample_gamma1_17(reader: &mut arcanum_primitives::shake::Shake256Reader, poly: &mut Poly) {
    const GAMMA1: i32 = 1 << 17;
    const FIELD_BITS: usize = 18;
    const FIELD_MASK: u128 = (1 << FIELD_BITS) - 1;

    let mut buf = [0u8; 576];
    reader.squeeze(&mut buf);

    for group in 0..(N / 4) {
        let bytes = &buf[9 * group..9 * group + 9];
        // Little-endian: the first byte ends up in the lowest 8 bits.
        let packed = bytes
            .iter()
            .rev()
            .fold(0u128, |acc, &byte| (acc << 8) | u128::from(byte));
        for k in 0..4 {
            let field = ((packed >> (FIELD_BITS * k)) & FIELD_MASK) as i32;
            poly.coeffs[4 * group + k] = GAMMA1 - field;
        }
    }
}
/// Fills `poly` from 640 squeezed bytes for `gamma1 = 2^19`.
///
/// Every 5-byte group is read as a 40-bit little-endian integer holding two
/// consecutive 20-bit fields `r`; each coefficient is stored as `2^19 - r`.
fn sample_gamma1_19(reader: &mut arcanum_primitives::shake::Shake256Reader, poly: &mut Poly) {
    const GAMMA1: i32 = 1 << 19;
    const FIELD_BITS: usize = 20;
    const FIELD_MASK: u64 = (1 << FIELD_BITS) - 1;

    let mut buf = [0u8; 640];
    reader.squeeze(&mut buf);

    for group in 0..(N / 2) {
        let bytes = &buf[5 * group..5 * group + 5];
        // Little-endian: the first byte ends up in the lowest 8 bits.
        let packed = bytes
            .iter()
            .rev()
            .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));
        let low = (packed & FIELD_MASK) as i32;
        let high = ((packed >> FIELD_BITS) & FIELD_MASK) as i32;
        poly.coeffs[2 * group] = GAMMA1 - low;
        poly.coeffs[2 * group + 1] = GAMMA1 - high;
    }
}
/// Samples a polynomial with exactly `tau` coefficients equal to `+1` or `-1`
/// and all others zero, driven by SHAKE256 over `seed`.
///
/// The first 8 squeezed bytes form a little-endian word of sign bits. Then,
/// for each `i` from `N - tau` up to `N - 1`, single bytes are squeezed until
/// one gives a position `j <= i`; coefficient `j` is moved to slot `i` and
/// slot `j` receives `+1` or `-1` according to the next sign bit
/// (bit set means `-1`).
pub fn sample_in_ball(seed: &[u8], tau: usize) -> Poly {
    let mut poly = Poly::zero();

    let mut shake = Shake256::new();
    shake.update(seed);
    let mut reader = shake.finalize_xof();

    let mut sign_bytes = [0u8; 8];
    reader.squeeze(&mut sign_bytes);
    let signs = u64::from_le_bytes(sign_bytes);

    let first = N - tau;
    let mut candidate = [0u8; 1];
    for (bit, i) in (first..N).enumerate() {
        // Rejection-sample a position in `0..=i`.
        let j = loop {
            reader.squeeze(&mut candidate);
            let position = usize::from(candidate[0]);
            if position <= i {
                break position;
            }
        };
        poly.coeffs[i] = poly.coeffs[j];
        poly.coeffs[j] = if (signs >> bit) & 1 == 1 { -1 } else { 1 };
    }
    poly
}
/// Derives the two vectors `(s1, s2)` from `seed`.
///
/// `s1` holds `P::L` polynomials sampled with nonces `0..P::L`; `s2` holds
/// `P::K` polynomials sampled with nonces `P::L..P::L + P::K`. All of them use
/// `sample_poly_eta` with `P::ETA`.
pub fn expand_s<P: MlDsaParams>(seed: &[u8; 64]) -> (Vec<Poly>, Vec<Poly>) {
    let s1: Vec<Poly> = (0..P::L)
        .map(|r| sample_poly_eta(seed, r as u16, P::ETA))
        .collect();
    let s2: Vec<Poly> = (0..P::K)
        .map(|r| sample_poly_eta(seed, (P::L + r) as u16, P::ETA))
        .collect();
    (s1, s2)
}
/// Produces the `P::L` mask polynomials for the given `seed` and base `nonce`.
///
/// Polynomial `i` is sampled with nonce `nonce + i`. With the `simd` feature
/// on `x86_64`, the 4-way batched path is taken when `P::L` is a multiple of 4
/// and AVX2 is detected at runtime; otherwise each polynomial is sampled with
/// `sample_poly_gamma1`.
pub fn expand_mask<P: MlDsaParams>(seed: &[u8], nonce: u16, gamma1: u32) -> Vec<Poly> {
    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
    {
        let fills_whole_batches = P::L % 4 == 0;
        if fills_whole_batches && is_x86_feature_detected!("avx2") {
            // SAFETY: `expand_mask_x4` is compiled with
            // `target_feature(enable = "avx2")` and AVX2 was just detected.
            return unsafe { expand_mask_x4::<P>(seed, nonce, gamma1) };
        }
    }

    (0..P::L)
        .map(|i| sample_poly_gamma1(seed, nonce + i as u16, gamma1))
        .collect()
}
/// 4-way batched variant of mask expansion: hashes four nonces at a time with
/// the 4-lane SHAKE256 and decodes each lane into one polynomial.
///
/// Produces `4 * (P::L / 4)` polynomials, in nonce order. For the supported
/// `gamma1` values (`1 << 17`, `1 << 19`) five blocks (680 bytes) are squeezed
/// per lane and decoded; for any other value (only reachable when debug
/// assertions are off) each polynomial comes from `sample_poly_gamma1`.
///
/// # Safety
/// The function is compiled with `target_feature(enable = "avx2")`; the caller
/// must ensure the running CPU supports AVX2.
///
/// # Panics
/// Panics if the first nonce of a batch overflows `u16`, and in debug builds
/// if `gamma1` is unsupported.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn expand_mask_x4<P: MlDsaParams>(seed: &[u8], nonce: u16, gamma1: u32) -> Vec<Poly> {
    use arcanum_primitives::keccak_x4::Shake256X4;

    const GAMMA1_17: u32 = 1 << 17;
    const GAMMA1_19: u32 = 1 << 19;

    debug_assert!(
        gamma1 == (1 << 17) || gamma1 == (1 << 19),
        "expand_mask_x4: unsupported gamma1 value {}",
        gamma1
    );

    // Pick the lane decoder once; `None` selects the scalar fallback.
    let decoder: Option<fn(&[u8]) -> Poly> = match gamma1 {
        GAMMA1_17 => Some(decode_gamma1_17),
        GAMMA1_19 => Some(decode_gamma1_19),
        _ => None,
    };

    let mut y = Vec::with_capacity(P::L);
    for batch in 0..(P::L / 4) {
        let base_nonce = nonce
            .checked_add((batch * 4) as u16)
            .expect("expand_mask_x4: nonce overflow");
        // SAFETY: NOTE(review) the `Shake256X4` calls presumably require AVX2,
        // which this function's own contract guarantees; confirm against the
        // `keccak_x4` module.
        unsafe {
            let mut shake_x4 = Shake256X4::new();
            for lane in 0..4 {
                let lane_nonce = base_nonce + lane as u16;
                shake_x4.absorb(lane, seed);
                shake_x4.absorb(lane, &lane_nonce.to_le_bytes());
            }
            shake_x4.finalize();

            match decoder {
                Some(decode) => {
                    let mut bufs = [[0u8; 680]; 4];
                    shake_x4.squeeze_blocks_x4(&mut bufs, 5, 4);
                    for lane_bytes in &bufs {
                        y.push(decode(lane_bytes));
                    }
                }
                None => {
                    for offset in 0..4 {
                        let lane_nonce = base_nonce + offset as u16;
                        y.push(sample_poly_gamma1(seed, lane_nonce, gamma1));
                    }
                }
            }
        }
    }
    y
}
/// Decodes the first 576 bytes of `buf` into a polynomial for `gamma1 = 2^17`.
///
/// Every 9-byte group is read as a 72-bit little-endian integer holding four
/// consecutive 18-bit fields `r`; each coefficient is `2^17 - r`.
///
/// # Panics
/// Panics if `buf` is shorter than `9 * (N / 4)` bytes.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
fn decode_gamma1_17(buf: &[u8]) -> Poly {
    const GAMMA1: i32 = 1 << 17;
    const FIELD_BITS: usize = 18;
    const FIELD_MASK: u128 = (1 << FIELD_BITS) - 1;

    let mut poly = Poly::zero();
    for group in 0..(N / 4) {
        let bytes = &buf[9 * group..9 * group + 9];
        // Little-endian: the first byte ends up in the lowest 8 bits.
        let packed = bytes
            .iter()
            .rev()
            .fold(0u128, |acc, &byte| (acc << 8) | u128::from(byte));
        for k in 0..4 {
            let field = ((packed >> (FIELD_BITS * k)) & FIELD_MASK) as i32;
            poly.coeffs[4 * group + k] = GAMMA1 - field;
        }
    }
    poly
}
/// Decodes the first 640 bytes of `buf` into a polynomial for `gamma1 = 2^19`.
///
/// Every 5-byte group is read as a 40-bit little-endian integer holding two
/// consecutive 20-bit fields `r`; each coefficient is `2^19 - r`.
///
/// # Panics
/// Panics if `buf` is shorter than `5 * (N / 2)` bytes.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
fn decode_gamma1_19(buf: &[u8]) -> Poly {
    const GAMMA1: i32 = 1 << 19;
    const FIELD_BITS: usize = 20;
    const FIELD_MASK: u64 = (1 << FIELD_BITS) - 1;

    let mut poly = Poly::zero();
    for group in 0..(N / 2) {
        let bytes = &buf[5 * group..5 * group + 5];
        // Little-endian: the first byte ends up in the lowest 8 bits.
        let packed = bytes
            .iter()
            .rev()
            .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));
        let low = (packed & FIELD_MASK) as i32;
        let high = ((packed >> FIELD_BITS) & FIELD_MASK) as i32;
        poly.coeffs[2 * group] = GAMMA1 - low;
        poly.coeffs[2 * group + 1] = GAMMA1 - high;
    }
    poly
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sampling twice with identical inputs must give identical polynomials.
    #[test]
    fn test_sample_poly_uniform_deterministic() {
        let rho = [0u8; 32];
        let first = sample_poly_uniform(&rho, 0, 0);
        let second = sample_poly_uniform(&rho, 0, 0);
        for i in 0..N {
            let (lhs, rhs) = (first.coeffs[i], second.coeffs[i]);
            assert_eq!(lhs, rhs, "Mismatch at index {}", i);
        }
    }

    /// Uniformly sampled coefficients must lie in `[0, Q)`.
    #[test]
    fn test_sample_poly_uniform_in_range() {
        let rho = [0x42u8; 32];
        let poly = sample_poly_uniform(&rho, 1, 2);
        for i in 0..N {
            let c = poly.coeffs[i];
            assert!(c >= 0, "Coefficient {} is negative: {}", i, c);
            assert!(c < Q, "Coefficient {} >= q: {}", i, c);
        }
    }

    /// Changing either index byte must change the sampled polynomial.
    #[test]
    fn test_sample_poly_uniform_different_indices() {
        let rho = [0u8; 32];
        let p00 = sample_poly_uniform(&rho, 0, 0);
        let p01 = sample_poly_uniform(&rho, 0, 1);
        let p10 = sample_poly_uniform(&rho, 1, 0);
        let differs = |x: &Poly, y: &Poly| (0..N).any(|i| x.coeffs[i] != y.coeffs[i]);
        assert!(differs(&p00, &p01), "p(0,0) should differ from p(0,1)");
        assert!(differs(&p00, &p10), "p(0,0) should differ from p(1,0)");
    }

    /// `sample_eta2` cycles through 2, 1, 0, -1, -2 for inputs 0..=14.
    #[test]
    fn test_sample_eta2() {
        let cycle = [2, 1, 0, -1, -2];
        for t in 0u8..15 {
            assert_eq!(sample_eta2(t), cycle[usize::from(t % 5)]);
        }
    }

    /// With `eta = 2` every coefficient must lie in `[-2, 2]`.
    #[test]
    fn test_sample_poly_eta2_in_range() {
        let seed = [0x42u8; 64];
        let poly = sample_poly_eta(&seed, 0, 2);
        for i in 0..N {
            let c = poly.coeffs[i];
            assert!(c >= -2, "Coefficient {} < -2: {}", i, c);
            assert!(c <= 2, "Coefficient {} > 2: {}", i, c);
        }
    }

    /// With `eta = 4` every coefficient must lie in `[-4, 4]`.
    #[test]
    fn test_sample_poly_eta4_in_range() {
        let seed = [0x42u8; 64];
        let poly = sample_poly_eta(&seed, 0, 4);
        for i in 0..N {
            let c = poly.coeffs[i];
            assert!(c >= -4, "Coefficient {} < -4: {}", i, c);
            assert!(c <= 4, "Coefficient {} > 4: {}", i, c);
        }
    }

    /// With `gamma1 = 2^17` every coefficient must lie in `(-gamma1, gamma1]`.
    #[test]
    fn test_sample_poly_gamma1_17_in_range() {
        let seed = [0x42u8; 64];
        let gamma1 = 1u32 << 17;
        let bound = gamma1 as i32;
        let poly = sample_poly_gamma1(&seed, 0, gamma1);
        for i in 0..N {
            let c = poly.coeffs[i];
            assert!(c > -bound, "Coefficient {} <= -gamma1: {}", i, c);
            assert!(c <= bound, "Coefficient {} > gamma1: {}", i, c);
        }
    }

    /// With `gamma1 = 2^19` every coefficient must lie in `(-gamma1, gamma1]`.
    #[test]
    fn test_sample_poly_gamma1_19_in_range() {
        let seed = [0x42u8; 64];
        let gamma1 = 1u32 << 19;
        let bound = gamma1 as i32;
        let poly = sample_poly_gamma1(&seed, 0, gamma1);
        for i in 0..N {
            let c = poly.coeffs[i];
            assert!(c > -bound, "Coefficient {} <= -gamma1: {}", i, c);
            assert!(c <= bound, "Coefficient {} > gamma1: {}", i, c);
        }
    }

    /// The challenge polynomial has exactly `tau` non-zero entries, all +/-1.
    #[test]
    fn test_sample_in_ball_tau_nonzero() {
        let seed = [0x42u8; 32];
        let tau = 39;
        let poly = sample_in_ball(&seed, tau);
        let mut nonzero = 0;
        for i in 0..N {
            let c = poly.coeffs[i];
            if c == 0 {
                continue;
            }
            nonzero += 1;
            assert!(
                c == 1 || c == -1,
                "Non-zero coefficient {} is not +/-1: {}",
                i,
                c
            );
        }
        assert_eq!(
            nonzero, tau,
            "Expected {} non-zero coefficients, got {}",
            tau, nonzero
        );
    }

    /// The non-zero count matches `tau` for each tested value.
    #[test]
    fn test_sample_in_ball_different_taus() {
        let seed = [0x42u8; 32];
        for tau in [39usize, 49, 60] {
            let poly = sample_in_ball(&seed, tau);
            let mut nonzero: usize = 0;
            for &c in poly.coeffs.iter() {
                if c != 0 {
                    nonzero += 1;
                }
            }
            assert_eq!(
                nonzero, tau,
                "τ={}: expected {} non-zero, got {}",
                tau, tau, nonzero
            );
        }
    }

    /// `expand_a` returns a `k x l` matrix for every parameter set.
    #[test]
    fn test_expand_a_dimensions() {
        let rho = [0u8; 32];

        let m44: Vec<Vec<Poly>> = expand_a::<Params44>(&rho);
        assert_eq!(m44.len(), 4, "ML-DSA-44: k should be 4");
        assert_eq!(m44[0].len(), 4, "ML-DSA-44: l should be 4");

        let m65: Vec<Vec<Poly>> = expand_a::<Params65>(&rho);
        assert_eq!(m65.len(), 6, "ML-DSA-65: k should be 6");
        assert_eq!(m65[0].len(), 5, "ML-DSA-65: l should be 5");

        let m87: Vec<Vec<Poly>> = expand_a::<Params87>(&rho);
        assert_eq!(m87.len(), 8, "ML-DSA-87: k should be 8");
        assert_eq!(m87[0].len(), 7, "ML-DSA-87: l should be 7");
    }

    /// `expand_s` returns `l` polynomials for `s1` and `k` for `s2`.
    #[test]
    fn test_expand_s_dimensions() {
        let seed = [0u8; 64];

        let (s1, s2) = expand_s::<Params44>(&seed);
        assert_eq!(s1.len(), 4, "ML-DSA-44: s1 should have l=4 polynomials");
        assert_eq!(s2.len(), 4, "ML-DSA-44: s2 should have k=4 polynomials");

        let (s1, s2) = expand_s::<Params65>(&seed);
        assert_eq!(s1.len(), 5, "ML-DSA-65: s1 should have l=5 polynomials");
        assert_eq!(s2.len(), 6, "ML-DSA-65: s2 should have k=6 polynomials");

        let (s1, s2) = expand_s::<Params87>(&seed);
        assert_eq!(s1.len(), 7, "ML-DSA-87: s1 should have l=7 polynomials");
        assert_eq!(s2.len(), 8, "ML-DSA-87: s2 should have k=8 polynomials");
    }
}