#![allow(dead_code)]
use super::keygen::unpack_pk;
use super::params::{D, MlDsaParams, N, Q};
use super::poly::Poly;
use super::rounding::use_hint;
use super::sampling::{expand_a, sample_in_ball};
use super::sign::unpack_signature;
use arcanum_primitives::shake::Shake256;
pub fn verify_internal<P: MlDsaParams>(pk_bytes: &[u8], message: &[u8], sig_bytes: &[u8]) -> bool {
let (rho, mut t1) = match unpack_pk::<P>(pk_bytes) {
Some(pk) => pk,
None => return false,
};
let (c_tilde, mut z, h) = match unpack_signature::<P>(sig_bytes) {
Some(sig) => sig,
None => return false,
};
let hint_count: usize = h
.iter()
.map(|poly| poly.coeffs.iter().filter(|&&c| c != 0).count())
.sum();
if hint_count > P::OMEGA {
return false;
}
let gamma1_minus_beta = P::GAMMA1 - P::BETA;
for poly in &z {
let norm = poly.infinity_norm();
if norm >= gamma1_minus_beta {
return false;
}
}
let a = expand_a::<P>(&rho);
let mut shake = Shake256::new();
shake.update(pk_bytes);
let mut reader = shake.finalize_xof();
let mut tr = [0u8; 64];
reader.squeeze(&mut tr);
let mut shake = Shake256::new();
shake.update(&tr);
shake.update(message);
let mut reader = shake.finalize_xof();
let mut mu = [0u8; 64];
reader.squeeze(&mut mu);
let mut c = sample_in_ball(&c_tilde, P::TAU);
c.ntt();
for poly in &mut z {
poly.ntt();
}
for poly in &mut t1 {
poly.ntt();
}
let mut az_ntt = vec![Poly::zero(); P::K];
for i in 0..P::K {
for j in 0..P::L {
let product = a[i][j].pointwise_mul(&z[j]);
az_ntt[i] = az_ntt[i].add(&product);
}
}
let mut ct1_ntt = vec![Poly::zero(); P::K];
for i in 0..P::K {
ct1_ntt[i] = c.pointwise_mul(&t1[i]);
}
let mut w_prime = vec![Poly::zero(); P::K];
for i in 0..P::K {
let mut az_i = az_ntt[i];
az_i.inv_ntt();
az_i.reduce();
let mut ct1_i = ct1_ntt[i];
ct1_i.inv_ntt();
ct1_i.reduce_centered();
for j in 0..N {
let az_val = az_i.coeffs[j] as i64;
let ct1_scaled = (ct1_i.coeffs[j] as i64) * (1i64 << D);
let mut val = az_val - ct1_scaled;
val = ((val % (Q as i64)) + (Q as i64)) % (Q as i64);
w_prime[i].coeffs[j] = val as i32;
}
}
let mut w_prime_1 = vec![Poly::zero(); P::K];
for i in 0..P::K {
for j in 0..N {
let hint_bit = h[i].coeffs[j] != 0;
w_prime_1[i].coeffs[j] = use_hint(hint_bit, w_prime[i].coeffs[j], P::GAMMA2 as i32);
}
}
let c_tilde_prime = compute_challenge_hash::<P>(&mu, &w_prime_1);
c_tilde == c_tilde_prime
}
fn compute_challenge_hash<P: MlDsaParams>(mu: &[u8; 64], w1: &[Poly]) -> Vec<u8> {
let c_tilde_len = P::LAMBDA / 4;
let mut shake = Shake256::new();
shake.update(mu);
for poly in w1.iter().take(P::K) {
let packed = pack_w1_poly::<P>(poly);
shake.update(&packed);
}
let mut reader = shake.finalize_xof();
let mut c_tilde = vec![0u8; c_tilde_len];
reader.squeeze(&mut c_tilde);
c_tilde
}
/// Encodes one `w1` polynomial, choosing the coefficient width from `P::GAMMA2`.
///
/// The width is 6 bits when `GAMMA2 == (Q - 1) / 88` and 4 bits for any other
/// value.
fn pack_w1_poly<P: MlDsaParams>(poly: &Poly) -> Vec<u8> {
    let six_bit_encoding = P::GAMMA2 == (Q as u32 - 1) / 88;
    if !six_bit_encoding {
        return pack_w1_4bits(poly);
    }
    pack_w1_6bits(poly)
}
/// Packs the first `N` coefficients of `poly` at 6 bits each, least
/// significant bits first.
///
/// Every group of four coefficients becomes three output bytes.
fn pack_w1_6bits(poly: &Poly) -> Vec<u8> {
    let mut packed = Vec::with_capacity(192);
    for quad in poly.coeffs[..N].chunks_exact(4) {
        let (a, b, c, d) = (
            quad[0] as u32,
            quad[1] as u32,
            quad[2] as u32,
            quad[3] as u32,
        );
        packed.extend_from_slice(&[
            (a | (b << 6)) as u8,
            ((b >> 2) | (c << 4)) as u8,
            ((c >> 4) | (d << 2)) as u8,
        ]);
    }
    packed
}
/// Packs the first `N` coefficients of `poly` at 4 bits each.
///
/// Each output byte holds two coefficients: the even-indexed one in the low
/// nibble and the odd-indexed one in the high nibble.
fn pack_w1_4bits(poly: &Poly) -> Vec<u8> {
    poly.coeffs[..N]
        .chunks_exact(2)
        .map(|pair| (pair[0] as u8) | ((pair[1] as u8) << 4))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::super::keygen::{generate_keypair_internal, pack_pk, pack_sk};
    use super::super::params::{Params44, Params65, Params87};
    use super::super::sign::sign_internal;
    use super::*;

    /// Deterministic keypair from a fixed seed, returned as packed (pk, sk) bytes.
    fn get_test_keypair<P: MlDsaParams>() -> (Vec<u8>, Vec<u8>) {
        let seed = [0x42u8; 32];
        let kp = generate_keypair_internal::<P>(&seed);
        let pk = pack_pk::<P>(&kp.rho, &kp.t1);
        let sk = pack_sk::<P>(&kp.rho, &kp.key, &kp.tr, &kp.s1, &kp.s2, &kp.t0);
        (pk, sk)
    }

    /// Builds a polynomial whose j-th coefficient is `f(j)`.
    fn poly_from_fn(f: impl Fn(usize) -> i32) -> Poly {
        let mut poly = Poly::zero();
        for (j, coeff) in poly.coeffs.iter_mut().enumerate() {
            *coeff = f(j);
        }
        poly
    }

    #[test]
    fn test_verify_44_valid_signature() {
        let (pk, sk) = get_test_keypair::<Params44>();
        let message = b"Test message for ML-DSA-44";
        let sig = sign_internal::<Params44>(&sk, message).expect("Signing should succeed");
        assert!(
            verify_internal::<Params44>(&pk, message, &sig),
            "Verification should succeed"
        );
    }

    #[test]
    fn test_verify_65_valid_signature() {
        let (pk, sk) = get_test_keypair::<Params65>();
        let message = b"Test message for ML-DSA-65";
        let sig = sign_internal::<Params65>(&sk, message).expect("Signing should succeed");
        assert!(
            verify_internal::<Params65>(&pk, message, &sig),
            "Verification should succeed"
        );
    }

    #[test]
    fn test_verify_87_valid_signature() {
        let (pk, sk) = get_test_keypair::<Params87>();
        let message = b"Test message for ML-DSA-87";
        let sig = sign_internal::<Params87>(&sk, message).expect("Signing should succeed");
        assert!(
            verify_internal::<Params87>(&pk, message, &sig),
            "Verification should succeed"
        );
    }

    #[test]
    fn test_verify_wrong_message_fails() {
        let (pk, sk) = get_test_keypair::<Params44>();
        let message1 = b"Original message";
        let message2 = b"Different message";
        let sig = sign_internal::<Params44>(&sk, message1).expect("Signing should succeed");
        assert!(
            !verify_internal::<Params44>(&pk, message2, &sig),
            "Verification should fail for wrong message"
        );
    }

    #[test]
    fn test_verify_various_messages() {
        let (pk, sk) = get_test_keypair::<Params44>();
        let messages: &[&[u8]] = &[
            b"",
            b"A",
            b"Test message",
            b"Test message for ML-DSA-44",
            b"The quick brown fox jumps over the lazy dog. 0123456789",
        ];
        for (idx, message) in messages.iter().enumerate() {
            let sig = sign_internal::<Params44>(&sk, message)
                .expect(&format!("Signing message {} should succeed", idx));
            assert!(
                verify_internal::<Params44>(&pk, message, &sig),
                "Verification should succeed for message {} ({} bytes)",
                idx,
                message.len()
            );
        }
    }

    #[test]
    fn test_verify_wrong_key_fails() {
        let (pk1, sk1) = get_test_keypair::<Params44>();
        let seed2 = [0x43u8; 32];
        let kp2 = generate_keypair_internal::<Params44>(&seed2);
        let pk2 = pack_pk::<Params44>(&kp2.rho, &kp2.t1);
        let message = b"Test message";
        let sig = sign_internal::<Params44>(&sk1, message).expect("Signing should succeed");
        assert!(
            !verify_internal::<Params44>(&pk2, message, &sig),
            "Verification should fail with wrong key"
        );
        assert!(
            verify_internal::<Params44>(&pk1, message, &sig),
            "Verification should succeed with correct key"
        );
    }

    #[test]
    fn test_verify_corrupted_signature_fails() {
        let (pk, sk) = get_test_keypair::<Params44>();
        let message = b"Test message";
        let mut sig = sign_internal::<Params44>(&sk, message).expect("Signing should succeed");
        sig[50] ^= 0xFF;
        assert!(
            !verify_internal::<Params44>(&pk, message, &sig),
            "Verification should fail for corrupted signature"
        );
    }

    #[test]
    fn test_verify_corrupted_first_byte_fails() {
        // NOTE(review): byte 0 is presumably inside c_tilde if the signature
        // encoding puts the challenge seed first; any modification must fail anyway.
        let (pk, sk) = get_test_keypair::<Params44>();
        let message = b"Test message";
        let mut sig = sign_internal::<Params44>(&sk, message).expect("Signing should succeed");
        sig[0] ^= 0x01;
        assert!(
            !verify_internal::<Params44>(&pk, message, &sig),
            "Verification should fail when the first signature byte is flipped"
        );
    }

    #[test]
    fn test_verify_invalid_pk_size() {
        let message = b"Test";
        let sig = vec![0u8; Params44::SIG_SIZE];
        let pk = vec![0u8; 100];
        assert!(!verify_internal::<Params44>(&pk, message, &sig));
    }

    #[test]
    fn test_verify_invalid_sig_size() {
        let (pk, _sk) = get_test_keypair::<Params44>();
        let message = b"Test";
        let sig = vec![0u8; 100];
        assert!(!verify_internal::<Params44>(&pk, message, &sig));
    }

    #[test]
    fn test_sign_verify_empty_message() {
        let (pk, sk) = get_test_keypair::<Params65>();
        let message = b"";
        let sig = sign_internal::<Params65>(&sk, message).expect("Signing should succeed");
        assert!(
            verify_internal::<Params65>(&pk, message, &sig),
            "Verification should succeed for empty message"
        );
    }

    #[test]
    fn test_sign_verify_long_message() {
        let (pk, sk) = get_test_keypair::<Params65>();
        let message = vec![0xABu8; 10000];
        let sig = sign_internal::<Params65>(&sk, &message).expect("Signing should succeed");
        assert!(
            verify_internal::<Params65>(&pk, &message, &sig),
            "Verification should succeed for long message"
        );
    }

    #[test]
    fn test_pack_w1_6bits_roundtrip() {
        let poly = poly_from_fn(|j| (j * 7 % 44) as i32);
        let packed = pack_w1_6bits(&poly);
        assert_eq!(packed.len(), N * 6 / 8);
        for (g, bytes) in packed.chunks_exact(3).enumerate() {
            let word = u32::from(bytes[0])
                | (u32::from(bytes[1]) << 8)
                | (u32::from(bytes[2]) << 16);
            for k in 0..4 {
                assert_eq!(((word >> (6 * k)) & 0x3F) as i32, poly.coeffs[4 * g + k]);
            }
        }
    }

    #[test]
    fn test_pack_w1_4bits_roundtrip() {
        let poly = poly_from_fn(|j| (j * 5 % 16) as i32);
        let packed = pack_w1_4bits(&poly);
        assert_eq!(packed.len(), N / 2);
        for (k, &byte) in packed.iter().enumerate() {
            assert_eq!(i32::from(byte & 0x0F), poly.coeffs[2 * k]);
            assert_eq!(i32::from(byte >> 4), poly.coeffs[2 * k + 1]);
        }
    }

    #[test]
    fn test_pack_w1_poly_width_per_parameter_set() {
        let poly = Poly::zero();
        assert_eq!(pack_w1_poly::<Params44>(&poly).len(), N * 6 / 8);
        assert_eq!(pack_w1_poly::<Params65>(&poly).len(), N * 4 / 8);
        assert_eq!(pack_w1_poly::<Params87>(&poly).len(), N * 4 / 8);
    }
}