#![allow(dead_code)]
use super::keygen::unpack_sk;
use super::ntt::reduce32;
use super::params::{MlDsaParams, N, Q};
use super::poly::Poly;
use super::rounding::{high_bits, make_hint, poly_decompose};
use super::sampling::{expand_a, expand_mask, sample_in_ball};
use arcanum_primitives::shake::Shake256;
/// Upper bound on rejection-sampling iterations in `sign_internal`; once exhausted,
/// signing gives up and returns `None`.
const MAX_ATTEMPTS: usize = 1000;
/// Signature components held as separate fields.
///
/// NOTE(review): presumably mirrors the `(c_tilde, z, h)` tuple returned by
/// `unpack_signature`; it is not referenced anywhere in this file (the module
/// allows `dead_code`) — confirm whether it is still needed.
#[derive(Clone)]
pub struct SignatureInternal {
    /// Commitment hash bytes.
    pub c_tilde: Vec<u8>,
    /// Response polynomials.
    pub z: Vec<Poly>,
    /// Hint polynomials.
    pub h: Vec<Poly>,
}
/// Deterministic ML-DSA signing core (FIPS 204 `ML-DSA.Sign_internal` with `rnd = {0}^32`).
///
/// `sk_bytes` is a packed secret key for parameter set `P`. `message` is absorbed as-is
/// into `mu = H(tr || message, 64)`, so any FIPS 204 `M'` formatting (domain separator,
/// context string) must already have been applied by the caller.
///
/// Returns the packed signature, or `None` if `unpack_sk` rejects `sk_bytes` or no
/// candidate passes the rejection checks within `MAX_ATTEMPTS` iterations.
pub fn sign_internal<P: MlDsaParams>(sk_bytes: &[u8], message: &[u8]) -> Option<Vec<u8>> {
    let (rho, key, tr, mut s1, mut s2, mut t0) = unpack_sk::<P>(sk_bytes)?;
    let a = expand_a::<P>(&rho);

    // s1, s2 and t0 are only ever multiplied by the challenge, so move them into the
    // NTT domain once, in place, instead of cloning them.
    for poly in s1.iter_mut().chain(s2.iter_mut()).chain(t0.iter_mut()) {
        poly.ntt();
    }

    // mu = H(tr || M', 64)
    let mut shake = Shake256::new();
    shake.update(&tr);
    shake.update(message);
    let mut reader = shake.finalize_xof();
    let mut mu = [0u8; 64];
    reader.squeeze(&mut mu);

    // rho'' = H(K || rnd || mu, 64) with rnd = {0}^32 (deterministic variant). This is the
    // seed FIPS 204 ExpandMask takes; it does not depend on kappa, so derive it once.
    let mut shake = Shake256::new();
    shake.update(&key);
    shake.update(&[0u8; 32]);
    shake.update(&mu);
    let mut reader = shake.finalize_xof();
    let mut rho_pp = vec![0u8; 64];
    reader.squeeze(&mut rho_pp);

    let gamma2 = P::GAMMA2 as i32;
    let c_tilde_len = commitment_hash_len::<P>();
    let z_bound = P::GAMMA1 - P::BETA;
    let r0_bound = P::GAMMA2 - P::BETA;
    let mut kappa: u16 = 0;

    for _ in 0..MAX_ATTEMPTS {
        let y = expand_mask::<P>(&rho_pp, kappa, P::GAMMA1);
        // Each attempt consumes L mask indices; advancing here covers every `continue`.
        kappa = kappa.wrapping_add(P::L as u16);

        // w = NTT^-1(A_hat * NTT(y))
        let mut y_ntt = y.clone();
        for poly in &mut y_ntt {
            poly.ntt();
        }
        let mut w = vec![Poly::zero(); P::K];
        for i in 0..P::K {
            for j in 0..P::L {
                let product = a[i][j].pointwise_mul(&y_ntt[j]);
                w[i] = w[i].add(&product);
            }
        }
        for poly in &mut w {
            poly.inv_ntt();
            poly.reduce();
        }

        // w1 = HighBits(w)
        let mut w1 = vec![Poly::zero(); P::K];
        for (w1_i, w_i) in w1.iter_mut().zip(&w) {
            for (dst, &src) in w1_i.coeffs.iter_mut().zip(w_i.coeffs.iter()) {
                *dst = high_bits(src, gamma2);
            }
        }

        let c_tilde = compute_challenge_hash::<P>(&mu, &w1, c_tilde_len);
        let mut c = sample_in_ball(&c_tilde, P::TAU);
        c.ntt();

        // z = y + c*s1; reject if ||z||_inf >= gamma1 - beta.
        let mut z: Vec<Poly> = (0..P::L)
            .map(|i| y[i].add(&challenge_product(&c, &s1[i])))
            .collect();
        if z.iter().any(|poly| poly.infinity_norm() >= z_bound) {
            continue;
        }

        // r0 = LowBits(w - c*s2); reject if ||r0||_inf >= gamma2 - beta.
        let w_minus_cs2: Vec<Poly> = (0..P::K)
            .map(|i| w[i].sub(&challenge_product(&c, &s2[i])))
            .collect();
        if w_minus_cs2
            .iter()
            .any(|poly| poly_decompose(poly, gamma2).1.infinity_norm() >= r0_bound)
        {
            continue;
        }

        // c*t0; reject if ||c*t0||_inf >= gamma2.
        let ct0: Vec<Poly> = (0..P::K).map(|i| challenge_product(&c, &t0[i])).collect();
        if ct0.iter().any(|poly| poly.infinity_norm() >= P::GAMMA2) {
            continue;
        }

        // h = MakeHint(-c*t0, w - c*s2 + c*t0); reject if more than omega hints are set.
        let mut h = vec![Poly::zero(); P::K];
        let mut total_hints = 0usize;
        for i in 0..P::K {
            let r = w_minus_cs2[i].add(&ct0[i]);
            for j in 0..N {
                if make_hint(-ct0[i].coeffs[j], r.coeffs[j], gamma2) {
                    h[i].coeffs[j] = 1;
                    total_hints += 1;
                }
            }
        }
        if total_hints > P::OMEGA {
            continue;
        }

        for poly in &mut z {
            poly.reduce_centered();
        }
        return Some(pack_signature::<P>(&c_tilde, &z, &h));
    }
    None
}

/// Returns `c * v` back in the coefficient domain, reduced with `reduce_centered`,
/// given both operands already in the NTT domain.
fn challenge_product(c_ntt: &Poly, v_ntt: &Poly) -> Poly {
    let mut product = c_ntt.pointwise_mul(v_ntt);
    product.inv_ntt();
    product.reduce_centered();
    product
}
fn compute_challenge_hash<P: MlDsaParams>(mu: &[u8; 64], w1: &[Poly], len: usize) -> Vec<u8> {
let mut shake = Shake256::new();
shake.update(mu);
for poly in w1.iter().take(P::K) {
let packed = pack_w1_poly::<P>(poly);
shake.update(&packed);
}
let mut reader = shake.finalize_xof();
let mut c_tilde = vec![0u8; len];
reader.squeeze(&mut c_tilde);
c_tilde
}
/// Encodes one `w1` polynomial: the 6-bit layout when `GAMMA2 == (q - 1) / 88`,
/// the 4-bit layout for any other `GAMMA2`.
fn pack_w1_poly<P: MlDsaParams>(poly: &Poly) -> Vec<u8> {
    let uses_six_bits = P::GAMMA2 == (Q as u32 - 1) / 88;
    if !uses_six_bits {
        return pack_w1_4bits(poly);
    }
    pack_w1_6bits(poly)
}
/// Packs the first `N` coefficients at 6 bits each, low bits first: every group of
/// four coefficients becomes three bytes.
fn pack_w1_6bits(poly: &Poly) -> Vec<u8> {
    let mut out = Vec::with_capacity(192);
    for quad in poly.coeffs[..N].chunks_exact(4) {
        let (a, b, c, d) = (quad[0] as u32, quad[1] as u32, quad[2] as u32, quad[3] as u32);
        out.push((a | (b << 6)) as u8);
        out.push(((b >> 2) | (c << 4)) as u8);
        out.push(((c >> 4) | (d << 2)) as u8);
    }
    out
}
/// Packs the first `N` coefficients at 4 bits each: two coefficients per byte, the
/// even-indexed one in the low nibble.
fn pack_w1_4bits(poly: &Poly) -> Vec<u8> {
    poly.coeffs[..N]
        .chunks_exact(2)
        .map(|pair| (pair[0] as u8) | ((pair[1] as u8) << 4))
        .collect()
}
/// Byte length of the commitment hash `c_tilde`, i.e. `LAMBDA / 4`.
fn commitment_hash_len<P: MlDsaParams>() -> usize {
    P::LAMBDA >> 2
}
/// Serialises a signature as `c_tilde || z[0..L] || hint-encoding(h)`.
///
/// The output buffer is pre-sized to `P::SIG_SIZE`; only the first `P::L`
/// polynomials of `z` are written.
pub fn pack_signature<P: MlDsaParams>(c_tilde: &[u8], z: &[Poly], h: &[Poly]) -> Vec<u8> {
    let mut out = Vec::with_capacity(P::SIG_SIZE);
    out.extend_from_slice(c_tilde);
    z.iter()
        .take(P::L)
        .for_each(|poly| pack_z_poly::<P>(&mut out, poly));
    pack_hint::<P>(&mut out, h);
    out
}
/// Appends one `z` polynomial encoded as FIPS 204 `BitPack(z, γ1 − 1, γ1)`.
///
/// Uses the 18-bit layout when `P::GAMMA1 == 2^17` and the 20-bit layout otherwise.
fn pack_z_poly<P: MlDsaParams>(bytes: &mut Vec<u8>, poly: &Poly) {
    if P::GAMMA1 == (1 << 17) {
        pack_z_18bits(bytes, poly);
    } else {
        pack_z_20bits(bytes, poly);
    }
}

/// Appends `N` coefficients as 18-bit values `γ1 − z` (γ1 = 2^17), four per 9 bytes.
///
/// FIPS 204 `BitPack(w, a, b)` stores `b − w_i`, and `sigEncode` calls it with `b = γ1`.
/// The stored value must therefore be `γ1 − z`, not `γ1 − 1 − z`. Any other offset still
/// round-trips through `unpack_z_18bits`, but other conforming ML-DSA implementations
/// decode it to a different `z`.
fn pack_z_18bits(bytes: &mut Vec<u8>, poly: &Poly) {
    const GAMMA1: i32 = 1 << 17;
    const MASK: u32 = (1 << 18) - 1;
    for chunk in 0..(N / 4) {
        let enc = |k: usize| ((GAMMA1 - poly.coeffs[4 * chunk + k]) as u32) & MASK;
        let (t0, t1, t2, t3) = (enc(0), enc(1), enc(2), enc(3));
        bytes.extend_from_slice(&[
            t0 as u8,
            (t0 >> 8) as u8,
            ((t0 >> 16) | (t1 << 2)) as u8,
            (t1 >> 6) as u8,
            ((t1 >> 14) | (t2 << 4)) as u8,
            (t2 >> 4) as u8,
            ((t2 >> 12) | (t3 << 6)) as u8,
            (t3 >> 2) as u8,
            (t3 >> 10) as u8,
        ]);
    }
}

/// Appends `N` coefficients as 20-bit values `γ1 − z` (γ1 = 2^19), two per 5 bytes.
///
/// Uses the same `γ1 − z` offset as `pack_z_18bits` (FIPS 204 `BitPack(z, γ1 − 1, γ1)`).
fn pack_z_20bits(bytes: &mut Vec<u8>, poly: &Poly) {
    const GAMMA1: i32 = 1 << 19;
    const MASK: u32 = (1 << 20) - 1;
    for chunk in 0..(N / 2) {
        let enc = |k: usize| ((GAMMA1 - poly.coeffs[2 * chunk + k]) as u32) & MASK;
        let (t0, t1) = (enc(0), enc(1));
        bytes.extend_from_slice(&[
            t0 as u8,
            (t0 >> 8) as u8,
            ((t0 >> 16) | (t1 << 4)) as u8,
            (t1 >> 4) as u8,
            (t1 >> 12) as u8,
        ]);
    }
}

/// Appends the `OMEGA + K`-byte hint encoding.
///
/// Layout: the positions of all nonzero coefficients, polynomial by polynomial, in the
/// first `OMEGA` bytes (zero-padded), then one running count per polynomial in the
/// final `K` bytes.
///
/// # Panics
///
/// Panics if `h` has fewer than `P::K` polynomials, or more than `P::OMEGA` nonzero
/// coefficients in total. Without the explicit check, the extra indices would silently
/// overwrite the per-polynomial count bytes and yield a corrupt signature.
fn pack_hint<P: MlDsaParams>(bytes: &mut Vec<u8>, h: &[Poly]) {
    let mut hint_bytes = vec![0u8; P::OMEGA + P::K];
    let mut idx = 0usize;
    for (i, poly) in h[..P::K].iter().enumerate() {
        for (j, &coeff) in poly.coeffs.iter().enumerate().take(N) {
            if coeff != 0 {
                assert!(
                    idx < P::OMEGA,
                    "hint vector has more than OMEGA nonzero coefficients"
                );
                hint_bytes[idx] = j as u8;
                idx += 1;
            }
        }
        hint_bytes[P::OMEGA + i] = idx as u8;
    }
    bytes.extend_from_slice(&hint_bytes);
}

/// Parses a packed signature into `(c_tilde, z, h)`.
///
/// Returns `None` if `bytes` is not exactly `P::SIG_SIZE` long or if the hint section
/// is malformed (as judged by `unpack_hint`).
pub fn unpack_signature<P: MlDsaParams>(bytes: &[u8]) -> Option<(Vec<u8>, Vec<Poly>, Vec<Poly>)> {
    if bytes.len() != P::SIG_SIZE {
        return None;
    }
    let c_tilde_len = commitment_hash_len::<P>();
    // 32 * bitlen(2*γ1 - 1) bytes per polynomial: 18-bit coefficients for γ1 = 2^17,
    // 20-bit otherwise.
    let z_poly_size = if P::GAMMA1 == (1 << 17) { 576 } else { 640 };
    let (c_tilde, rest) = bytes.split_at(c_tilde_len);
    let (z_bytes, hint_bytes) = rest.split_at(P::L * z_poly_size);
    let z = z_bytes
        .chunks_exact(z_poly_size)
        .map(|chunk| {
            let mut poly = Poly::zero();
            unpack_z_poly::<P>(chunk, &mut poly);
            poly
        })
        .collect();
    let h = unpack_hint::<P>(hint_bytes)?;
    Some((c_tilde.to_vec(), z, h))
}

/// Decodes one `z` polynomial (FIPS 204 `BitUnpack(·, γ1 − 1, γ1)`), choosing the
/// 18-bit or 20-bit layout from `P::GAMMA1`.
fn unpack_z_poly<P: MlDsaParams>(bytes: &[u8], poly: &mut Poly) {
    if P::GAMMA1 == (1 << 17) {
        unpack_z_18bits(bytes, poly);
    } else {
        unpack_z_20bits(bytes, poly);
    }
}

/// Inverse of `pack_z_18bits`: reads four 18-bit values `t` per 9 bytes and sets
/// `z = γ1 − t` (γ1 = 2^17).
fn unpack_z_18bits(bytes: &[u8], poly: &mut Poly) {
    const GAMMA1: i32 = 1 << 17;
    for chunk in 0..(N / 4) {
        let b = &bytes[9 * chunk..9 * chunk + 9];
        let t0 = (b[0] as u32) | ((b[1] as u32) << 8) | ((b[2] as u32 & 0x03) << 16);
        let t1 = ((b[2] as u32) >> 2) | ((b[3] as u32) << 6) | ((b[4] as u32 & 0x0F) << 14);
        let t2 = ((b[4] as u32) >> 4) | ((b[5] as u32) << 4) | ((b[6] as u32 & 0x3F) << 12);
        let t3 = ((b[6] as u32) >> 6) | ((b[7] as u32) << 2) | ((b[8] as u32) << 10);
        poly.coeffs[4 * chunk] = GAMMA1 - (t0 as i32);
        poly.coeffs[4 * chunk + 1] = GAMMA1 - (t1 as i32);
        poly.coeffs[4 * chunk + 2] = GAMMA1 - (t2 as i32);
        poly.coeffs[4 * chunk + 3] = GAMMA1 - (t3 as i32);
    }
}

/// Inverse of `pack_z_20bits`: reads two 20-bit values `t` per 5 bytes and sets
/// `z = γ1 − t` (γ1 = 2^19).
fn unpack_z_20bits(bytes: &[u8], poly: &mut Poly) {
    const GAMMA1: i32 = 1 << 19;
    for chunk in 0..(N / 2) {
        let b = &bytes[5 * chunk..5 * chunk + 5];
        let t0 = (b[0] as u32) | ((b[1] as u32) << 8) | ((b[2] as u32 & 0x0F) << 16);
        let t1 = ((b[2] as u32) >> 4) | ((b[3] as u32) << 4) | ((b[4] as u32) << 12);
        poly.coeffs[2 * chunk] = GAMMA1 - (t0 as i32);
        poly.coeffs[2 * chunk + 1] = GAMMA1 - (t1 as i32);
    }
}
/// Decodes the `OMEGA + K`-byte hint section into `K` polynomials with 0/1 coefficients.
///
/// Follows FIPS 204 `HintBitUnpack`. It returns `None` when:
/// - a per-polynomial count decreases or exceeds `OMEGA`;
/// - an index is out of range;
/// - indices within one polynomial are not strictly increasing;
/// - an unused index slot is nonzero.
///
/// The ordering and padding checks make the encoding canonical. Without them, one hint
/// vector has several valid byte encodings, which breaks strong unforgeability.
fn unpack_hint<P: MlDsaParams>(bytes: &[u8]) -> Option<Vec<Poly>> {
    if bytes.len() < P::OMEGA + P::K {
        return None;
    }
    let (indices, limits) = bytes.split_at(P::OMEGA);
    let mut h = vec![Poly::zero(); P::K];
    let mut k = 0usize;
    for (i, &limit) in limits.iter().take(P::K).enumerate() {
        let limit = limit as usize;
        if limit < k || limit > P::OMEGA {
            return None;
        }
        let first = k;
        while k < limit {
            let j = indices[k] as usize;
            if j >= N {
                return None;
            }
            // Positions within one polynomial must be strictly increasing.
            if k > first && indices[k - 1] >= indices[k] {
                return None;
            }
            h[i].coeffs[j] = 1;
            k += 1;
        }
    }
    // Index slots beyond the last used one must be zero padding.
    if indices[k..].iter().any(|&b| b != 0) {
        return None;
    }
    Some(h)
}
#[cfg(test)]
mod tests {
    use super::super::keygen::{generate_keypair_internal, pack_sk};
    use super::super::params::{Params44, Params65, Params87};
    use super::*;

    /// Packed secret key derived from a fixed seed, so every test is reproducible.
    fn get_test_sk<P: MlDsaParams>() -> Vec<u8> {
        let seed = [0x42u8; 32];
        let kp = generate_keypair_internal::<P>(&seed);
        pack_sk::<P>(&kp.rho, &kp.key, &kp.tr, &kp.s1, &kp.s2, &kp.t0)
    }

    #[test]
    fn test_sign_44_produces_valid_size() {
        let sk = get_test_sk::<Params44>();
        let message = b"Test message";
        let sig = sign_internal::<Params44>(&sk, message).expect("Signing should succeed");
        assert_eq!(sig.len(), Params44::SIG_SIZE);
    }

    #[test]
    fn test_sign_65_produces_valid_size() {
        let sk = get_test_sk::<Params65>();
        let message = b"Test message";
        let sig = sign_internal::<Params65>(&sk, message).expect("Signing should succeed");
        assert_eq!(sig.len(), Params65::SIG_SIZE);
    }

    #[test]
    fn test_sign_87_produces_valid_size() {
        let sk = get_test_sk::<Params87>();
        let message = b"Test message";
        let sig = sign_internal::<Params87>(&sk, message).expect("Signing should succeed");
        assert_eq!(sig.len(), Params87::SIG_SIZE);
    }

    #[test]
    fn test_sign_deterministic_with_same_key() {
        let sk = get_test_sk::<Params65>();
        let message = b"Hello, ML-DSA!";
        let sig1 = sign_internal::<Params65>(&sk, message).expect("First sign should succeed");
        let sig2 = sign_internal::<Params65>(&sk, message).expect("Second sign should succeed");
        assert_eq!(sig1.len(), Params65::SIG_SIZE);
        // The signer takes no randomness, so identical inputs must give identical output.
        assert_eq!(sig1, sig2);
    }

    #[test]
    fn test_sign_different_messages() {
        let sk = get_test_sk::<Params44>();
        let msg1 = b"Message 1";
        let msg2 = b"Message 2";
        let sig1 = sign_internal::<Params44>(&sk, msg1).expect("Sign msg1 should succeed");
        let sig2 = sign_internal::<Params44>(&sk, msg2).expect("Sign msg2 should succeed");
        assert_ne!(sig1, sig2);
    }

    #[test]
    fn test_pack_unpack_signature_44() {
        let sk = get_test_sk::<Params44>();
        let message = b"Test";
        let sig_bytes = sign_internal::<Params44>(&sk, message).expect("Signing should succeed");
        let (c_tilde, z, h) =
            unpack_signature::<Params44>(&sig_bytes).expect("Unpack should succeed");
        assert_eq!(c_tilde.len(), 32);
        assert_eq!(z.len(), Params44::L);
        assert_eq!(h.len(), Params44::K);
    }

    #[test]
    fn test_pack_unpack_signature_65() {
        let sk = get_test_sk::<Params65>();
        let message = b"Test";
        let sig_bytes = sign_internal::<Params65>(&sk, message).expect("Signing should succeed");
        let (c_tilde, z, h) =
            unpack_signature::<Params65>(&sig_bytes).expect("Unpack should succeed");
        assert_eq!(c_tilde.len(), 48);
        assert_eq!(z.len(), Params65::L);
        assert_eq!(h.len(), Params65::K);
    }

    #[test]
    fn test_pack_unpack_roundtrip_44() {
        let sk = get_test_sk::<Params44>();
        let sig_bytes =
            sign_internal::<Params44>(&sk, b"Roundtrip").expect("Signing should succeed");
        let (c_tilde, z, h) =
            unpack_signature::<Params44>(&sig_bytes).expect("Unpack should succeed");
        let repacked = pack_signature::<Params44>(&c_tilde, &z, &h);
        assert_eq!(repacked, sig_bytes);
    }

    #[test]
    fn test_z_coefficients_in_range() {
        let sk = get_test_sk::<Params44>();
        let message = b"Test";
        let sig_bytes = sign_internal::<Params44>(&sk, message).expect("Signing should succeed");
        let (_, z, _) = unpack_signature::<Params44>(&sig_bytes).expect("Unpack should succeed");
        for poly in &z {
            for &c in &poly.coeffs {
                assert!(c >= -(Params44::GAMMA1 as i32) && c < Params44::GAMMA1 as i32);
            }
        }
    }

    #[test]
    fn test_hint_weight_within_bounds() {
        let sk = get_test_sk::<Params44>();
        let message = b"Test";
        let sig_bytes = sign_internal::<Params44>(&sk, message).expect("Signing should succeed");
        let (_, _, h) = unpack_signature::<Params44>(&sig_bytes).expect("Unpack should succeed");
        let mut total = 0;
        for poly in &h {
            for &c in &poly.coeffs {
                if c != 0 {
                    total += 1;
                }
            }
        }
        assert!(
            total <= Params44::OMEGA,
            "Hint weight {} > ω={}",
            total,
            Params44::OMEGA
        );
    }

    #[test]
    fn test_unpack_signature_rejects_hint_count_above_omega() {
        let sk = get_test_sk::<Params44>();
        let mut sig_bytes =
            sign_internal::<Params44>(&sk, b"Test").expect("Signing should succeed");
        // The last byte is the cumulative hint count for the final polynomial.
        let last = sig_bytes.len() - 1;
        sig_bytes[last] = (Params44::OMEGA + 1) as u8;
        assert!(unpack_signature::<Params44>(&sig_bytes).is_none());
    }

    #[test]
    fn test_sign_invalid_sk_size() {
        let short_sk = vec![0u8; 100];
        let message = b"Test";
        let result = sign_internal::<Params44>(&short_sk, message);
        assert!(result.is_none());
    }

    #[test]
    fn test_unpack_signature_invalid_size() {
        let short_sig = vec![0u8; 100];
        let result = unpack_signature::<Params44>(&short_sig);
        assert!(result.is_none());
    }
}