#![allow(dead_code)]
use super::ntt::reduce32;
use super::params::{MlDsaParams, N, Q};
use super::poly::Poly;
use super::rounding::poly_power2round;
use super::sampling::{expand_a, expand_s};
use arcanum_primitives::shake::Shake256;
/// Unpacked result of `generate_keypair_internal`: every component needed to
/// serialize the public key (`rho`, `t1`) and the secret key
/// (`rho`, `key`, `tr`, `s1`, `s2`, `t0`).
#[derive(Clone)]
pub struct KeyPairInternal {
/// First 32 bytes squeezed from SHAKE256(seed || K || L); passed to `expand_a`
/// and written at the start of both the packed public and secret key.
pub rho: [u8; 32],
/// Third value (32 bytes) squeezed from the same SHAKE256 stream, after `rho`
/// and the 64-byte `rho_prime`; stored in the secret key right after `rho`.
pub key: [u8; 32],
/// 64-byte SHAKE256 output over the packed public key bytes.
pub tr: [u8; 64],
/// First vector returned by `expand_s` (kept in its original, non-NTT form).
pub s1: Vec<Poly>,
/// Second vector returned by `expand_s`.
pub s2: Vec<Poly>,
/// First component returned by `poly_power2round` for each polynomial of `t`.
pub t1: Vec<Poly>,
/// Second component returned by `poly_power2round` for each polynomial of `t`.
pub t0: Vec<Poly>,
}
pub fn generate_keypair_internal<P: MlDsaParams>(seed: &[u8; 32]) -> KeyPairInternal {
let mut inbuf = [0u8; 34];
inbuf[..32].copy_from_slice(seed);
inbuf[32] = P::K as u8;
inbuf[33] = P::L as u8;
let mut shake = Shake256::new();
shake.update(&inbuf);
let mut reader = shake.finalize_xof();
let mut rho = [0u8; 32];
let mut rho_prime = [0u8; 64];
let mut key = [0u8; 32];
reader.squeeze(&mut rho);
reader.squeeze(&mut rho_prime);
reader.squeeze(&mut key);
let a = expand_a::<P>(&rho);
let (s1, s2) = expand_s::<P>(&rho_prime);
let mut s1_ntt = s1.clone();
for poly in &mut s1_ntt {
poly.ntt();
}
let mut t = vec![Poly::zero(); P::K];
for i in 0..P::K {
for j in 0..P::L {
let product = a[i][j].pointwise_mul(&s1_ntt[j]);
t[i] = t[i].add(&product);
}
}
for poly in &mut t {
poly.inv_ntt();
poly.reduce();
}
for i in 0..P::K {
t[i] = t[i].add(&s2[i]);
t[i].reduce();
}
let mut t1 = Vec::with_capacity(P::K);
let mut t0 = Vec::with_capacity(P::K);
for poly in &t {
let (high, low) = poly_power2round(poly);
t1.push(high);
t0.push(low);
}
let pk_bytes = pack_pk::<P>(&rho, &t1);
let mut shake = Shake256::new();
shake.update(&pk_bytes);
let mut reader = shake.finalize_xof();
let mut tr = [0u8; 64];
reader.squeeze(&mut tr);
KeyPairInternal {
rho,
key,
tr,
s1,
s2,
t1,
t0,
}
}
/// Serializes a public key as `rho` followed by the packed `t1` polynomials.
///
/// At most `P::K` polynomials of `t1` are written; the output buffer is
/// pre-sized to `P::PK_SIZE`.
pub fn pack_pk<P: MlDsaParams>(rho: &[u8; 32], t1: &[Poly]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(P::PK_SIZE);
    encoded.extend_from_slice(rho);
    t1.iter()
        .take(P::K)
        .for_each(|poly| pack_t1_poly(&mut encoded, poly));
    encoded
}
/// Appends the 10-bit packing of `poly` to `bytes`.
///
/// Every group of four coefficients becomes five bytes, least-significant
/// bits first, so `N` coefficients produce `5 * N / 4` bytes. Coefficients are
/// not masked: values outside `0..1024` bleed into neighbouring fields.
fn pack_t1_poly(bytes: &mut Vec<u8>, poly: &Poly) {
    for quad in poly.coeffs[..N].chunks_exact(4) {
        let w0 = quad[0] as u32;
        let w1 = quad[1] as u32;
        let w2 = quad[2] as u32;
        let w3 = quad[3] as u32;
        bytes.extend_from_slice(&[
            (w0 & 0xFF) as u8,
            ((w0 >> 8) | (w1 << 2)) as u8,
            ((w1 >> 6) | (w2 << 4)) as u8,
            ((w2 >> 4) | (w3 << 6)) as u8,
            (w3 >> 2) as u8,
        ]);
    }
}
/// Parses a serialized public key into `(rho, t1)`.
///
/// Returns `None` unless `bytes` is exactly `P::PK_SIZE` long. The first
/// 32 bytes are `rho`; each following 320-byte block is decoded into one of
/// the `P::K` polynomials of `t1`.
pub fn unpack_pk<P: MlDsaParams>(bytes: &[u8]) -> Option<([u8; 32], Vec<Poly>)> {
    const RHO_LEN: usize = 32;
    const T1_POLY_LEN: usize = 320;

    if bytes.len() != P::PK_SIZE {
        return None;
    }

    let (rho_bytes, t1_bytes) = bytes.split_at(RHO_LEN);
    let mut rho = [0u8; 32];
    rho.copy_from_slice(rho_bytes);

    let t1: Vec<Poly> = (0..P::K)
        .map(|index| {
            let start = index * T1_POLY_LEN;
            let mut poly = Poly::zero();
            unpack_t1_poly(&t1_bytes[start..start + T1_POLY_LEN], &mut poly);
            poly
        })
        .collect();

    Some((rho, t1))
}
/// Decodes `5 * N / 4` bytes of 10-bit packed data into the coefficients of `poly`.
///
/// Each five-byte group is read as a 40-bit little-endian word holding four
/// consecutive 10-bit fields, lowest field first.
///
/// Panics if `bytes` is shorter than `5 * N / 4`.
fn unpack_t1_poly(bytes: &[u8], poly: &mut Poly) {
    let groups = bytes[..5 * (N / 4)].chunks_exact(5);
    for (quad, packed) in poly.coeffs[..N].chunks_exact_mut(4).zip(groups) {
        // Assemble the group little-endian: packed[0] ends up in the low byte.
        let word = packed
            .iter()
            .rev()
            .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));
        for (slot, coeff) in quad.iter_mut().enumerate() {
            *coeff = ((word >> (10 * slot)) & 0x3FF) as i32;
        }
    }
}
/// Serializes a secret key.
///
/// Layout, in order: `rho` (32 bytes), `key` (32 bytes), `tr` (64 bytes),
/// up to `P::L` eta-packed polynomials of `s1`, up to `P::K` eta-packed
/// polynomials of `s2`, then up to `P::K` packed polynomials of `t0`.
/// The output buffer is pre-sized to `P::SK_SIZE`.
pub fn pack_sk<P: MlDsaParams>(
    rho: &[u8; 32],
    key: &[u8; 32],
    tr: &[u8; 64],
    s1: &[Poly],
    s2: &[Poly],
    t0: &[Poly],
) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(P::SK_SIZE);
    encoded.extend_from_slice(rho);
    encoded.extend_from_slice(key);
    encoded.extend_from_slice(tr);

    // s1 and s2 share the same eta-based encoding; s1 is written first.
    let secret_polys = s1.iter().take(P::L).chain(s2.iter().take(P::K));
    for poly in secret_polys {
        pack_eta_poly::<P>(&mut encoded, poly);
    }

    t0.iter()
        .take(P::K)
        .for_each(|poly| pack_t0_poly(&mut encoded, poly));

    encoded
}
/// Appends the eta-dependent packing of a secret polynomial to `bytes`.
///
/// Each coefficient `c` is stored as `eta - c`:
/// * `eta == 2`: 3 bits per coefficient, eight coefficients per three bytes
///   (`3 * N / 8` bytes in total);
/// * `eta == 4`: 4 bits per coefficient, two coefficients per byte, low nibble
///   first (`N / 2` bytes in total).
///
/// Coefficients are not range-checked; values outside `[-eta, eta]` spill
/// into neighbouring fields.
///
/// # Panics
///
/// Panics if `P::ETA` is neither 2 nor 4. Previously such a parameter set
/// silently appended nothing, which produced a truncated secret key.
fn pack_eta_poly<P: MlDsaParams>(bytes: &mut Vec<u8>, poly: &Poly) {
    match P::ETA {
        2 => {
            for octet in poly.coeffs[..N].chunks_exact(8) {
                // Eight 3-bit fields, lowest coefficient in the lowest bits.
                let bits = octet
                    .iter()
                    .enumerate()
                    .fold(0u32, |acc, (slot, &coeff)| {
                        acc | (((2 - coeff) as u32) << (3 * slot))
                    });
                bytes.extend_from_slice(&bits.to_le_bytes()[..3]);
            }
        }
        4 => {
            for pair in poly.coeffs[..N].chunks_exact(2) {
                let low = (4 - pair[0]) as u8;
                let high = (4 - pair[1]) as u8;
                bytes.push(low | (high << 4));
            }
        }
        other => unreachable!(
            "secret polynomials can only be packed for eta = 2 or 4 (got {})",
            other
        ),
    }
}
/// Appends the 13-bit packing of `poly` to `bytes`.
///
/// Each coefficient `c` is stored as `OFFSET - c` with `OFFSET = 2^12 - 1`;
/// eight such 13-bit values fill thirteen bytes, least-significant bits first,
/// so `N` coefficients produce `13 * N / 8` bytes. Values are not masked.
///
/// NOTE(review): FIPS 204 `skEncode` packs `t0` as `2^12 - c`, whereas this
/// function subtracts from `2^12 - 1`. With this offset a coefficient of
/// exactly `2^12` wraps to `0xFFFF_FFFF` and corrupts neighbouring fields.
/// Confirm the output range of `poly_power2round` and interoperability with
/// other ML-DSA implementations; `unpack_t0_poly` must change in lockstep.
fn pack_t0_poly(bytes: &mut Vec<u8>, poly: &Poly) {
    const OFFSET: i32 = (1 << 12) - 1;
    for octet in poly.coeffs[..N].chunks_exact(8) {
        let mut v = [0u32; 8];
        for (dst, &coeff) in v.iter_mut().zip(octet) {
            *dst = (OFFSET - coeff) as u32;
        }
        bytes.extend_from_slice(&[
            (v[0] & 0xFF) as u8,
            ((v[0] >> 8) | (v[1] << 5)) as u8,
            ((v[1] >> 3) & 0xFF) as u8,
            ((v[1] >> 11) | (v[2] << 2)) as u8,
            ((v[2] >> 6) | (v[3] << 7)) as u8,
            ((v[3] >> 1) & 0xFF) as u8,
            ((v[3] >> 9) | (v[4] << 4)) as u8,
            ((v[4] >> 4) & 0xFF) as u8,
            ((v[4] >> 12) | (v[5] << 1)) as u8,
            ((v[5] >> 7) | (v[6] << 6)) as u8,
            ((v[6] >> 2) & 0xFF) as u8,
            ((v[6] >> 10) | (v[7] << 3)) as u8,
            (v[7] >> 5) as u8,
        ]);
    }
}
/// Parses a serialized secret key into `(rho, key, tr, s1, s2, t0)`.
///
/// Returns `None` unless `bytes` is exactly `P::SK_SIZE` long. The layout is
/// `rho` (32 bytes), `key` (32 bytes), `tr` (64 bytes), `P::L` eta-packed
/// polynomials for `s1`, `P::K` eta-packed polynomials for `s2`, and `P::K`
/// polynomials of 416 bytes each for `t0`. An eta-packed polynomial occupies
/// 96 bytes when `P::ETA == 2` and 128 bytes otherwise.
// The six-element tuple is the established return shape of this function.
#[allow(clippy::type_complexity)]
pub fn unpack_sk<P: MlDsaParams>(
    bytes: &[u8],
) -> Option<(
    [u8; 32],
    [u8; 32],
    [u8; 64],
    Vec<Poly>,
    Vec<Poly>,
    Vec<Poly>,
)> {
    const T0_POLY_LEN: usize = 416;

    if bytes.len() != P::SK_SIZE {
        return None;
    }

    // Fixed-size header: rho || key || tr.
    let (rho_bytes, rest) = bytes.split_at(32);
    let (key_bytes, rest) = rest.split_at(32);
    let (tr_bytes, rest) = rest.split_at(64);

    let mut rho = [0u8; 32];
    rho.copy_from_slice(rho_bytes);
    let mut key = [0u8; 32];
    key.copy_from_slice(key_bytes);
    let mut tr = [0u8; 64];
    tr.copy_from_slice(tr_bytes);

    let eta_poly_len: usize = if P::ETA == 2 { 96 } else { 128 };
    let (s1_bytes, rest) = rest.split_at(eta_poly_len * P::L);
    let (s2_bytes, t0_bytes) = rest.split_at(eta_poly_len * P::K);

    let s1: Vec<Poly> = (0..P::L)
        .map(|index| {
            let start = index * eta_poly_len;
            let mut poly = Poly::zero();
            unpack_eta_poly::<P>(&s1_bytes[start..start + eta_poly_len], &mut poly);
            poly
        })
        .collect();

    let s2: Vec<Poly> = (0..P::K)
        .map(|index| {
            let start = index * eta_poly_len;
            let mut poly = Poly::zero();
            unpack_eta_poly::<P>(&s2_bytes[start..start + eta_poly_len], &mut poly);
            poly
        })
        .collect();

    let t0: Vec<Poly> = (0..P::K)
        .map(|index| {
            let start = index * T0_POLY_LEN;
            let mut poly = Poly::zero();
            unpack_t0_poly(&t0_bytes[start..start + T0_POLY_LEN], &mut poly);
            poly
        })
        .collect();

    Some((rho, key, tr, s1, s2, t0))
}
/// Decodes an eta-packed secret polynomial from `bytes` into `poly`.
///
/// Each stored field `f` is mapped back to the coefficient `eta - f`:
/// * `eta == 2`: reads `3 * N / 8` bytes, eight 3-bit fields per three bytes;
/// * `eta == 4`: reads `N / 2` bytes, low nibble first.
///
/// Fields are not validated: a field larger than `2 * eta` (for example 7 when
/// `eta == 2`) decodes to a coefficient below `-eta` instead of being rejected.
///
/// # Panics
///
/// Panics if `bytes` is shorter than the length listed above, or if `P::ETA`
/// is neither 2 nor 4. Previously an unsupported eta silently left `poly`
/// untouched, yielding an all-zero secret polynomial.
fn unpack_eta_poly<P: MlDsaParams>(bytes: &[u8], poly: &mut Poly) {
    match P::ETA {
        2 => {
            let groups = bytes[..3 * (N / 8)].chunks_exact(3);
            for (octet, packed) in poly.coeffs[..N].chunks_exact_mut(8).zip(groups) {
                let word = u32::from(packed[0])
                    | (u32::from(packed[1]) << 8)
                    | (u32::from(packed[2]) << 16);
                for (slot, coeff) in octet.iter_mut().enumerate() {
                    let field = ((word >> (3 * slot)) & 0x07) as i32;
                    *coeff = 2 - field;
                }
            }
        }
        4 => {
            let nibbles = &bytes[..N / 2];
            for (pair, &packed) in poly.coeffs[..N].chunks_exact_mut(2).zip(nibbles) {
                pair[0] = 4 - i32::from(packed & 0x0F);
                pair[1] = 4 - i32::from(packed >> 4);
            }
        }
        other => unreachable!(
            "secret polynomials can only be unpacked for eta = 2 or 4 (got {})",
            other
        ),
    }
}
/// Decodes `13 * N / 8` bytes of 13-bit packed data into the coefficients of `poly`.
///
/// Each thirteen-byte group is read as a 104-bit little-endian word holding
/// eight consecutive 13-bit fields, lowest field first; a field `f` becomes
/// the coefficient `OFFSET - f` with `OFFSET = 2^12 - 1`.
///
/// NOTE(review): FIPS 204 `skDecode` recovers `t0` as `2^12 - f`, whereas this
/// function uses `2^12 - 1`. It mirrors `pack_t0_poly`, so round trips within
/// this crate agree; confirm interoperability before changing either side.
///
/// Panics if `bytes` is shorter than `13 * N / 8`.
fn unpack_t0_poly(bytes: &[u8], poly: &mut Poly) {
    const OFFSET: i32 = (1 << 12) - 1;
    const GROUP_LEN: usize = 13;

    let groups = bytes[..GROUP_LEN * (N / 8)].chunks_exact(GROUP_LEN);
    for (octet, packed) in poly.coeffs[..N].chunks_exact_mut(8).zip(groups) {
        // Zero-extend the 13 bytes to 16 so they can be read as one integer.
        let mut raw = [0u8; 16];
        raw[..GROUP_LEN].copy_from_slice(packed);
        let word = u128::from_le_bytes(raw);
        for (slot, coeff) in octet.iter_mut().enumerate() {
            let field = ((word >> (13 * slot)) & 0x1FFF) as i32;
            *coeff = OFFSET - field;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::params::{Params44, Params65, Params87};
    use super::*;

    /// Seed shared by the deterministic key-generation tests.
    const SEED: [u8; 32] = [0x42u8; 32];

    /// Asserts that two polynomial vectors have the same length and coefficients.
    fn assert_polys_eq(actual: &[Poly], expected: &[Poly], label: &str) {
        assert_eq!(actual.len(), expected.len(), "{} length mismatch", label);
        for (i, (got, want)) in actual.iter().zip(expected).enumerate() {
            for j in 0..N {
                assert_eq!(
                    got.coeffs[j], want.coeffs[j],
                    "{}[{}][{}] mismatch",
                    label, i, j
                );
            }
        }
    }

    /// Checks that every component of a generated key pair has the expected length.
    fn assert_keygen_sizes<P: MlDsaParams>() {
        let keypair = generate_keypair_internal::<P>(&SEED);
        assert_eq!(keypair.rho.len(), 32);
        assert_eq!(keypair.key.len(), 32);
        assert_eq!(keypair.tr.len(), 64);
        assert_eq!(keypair.s1.len(), P::L);
        assert_eq!(keypair.s2.len(), P::K);
        assert_eq!(keypair.t1.len(), P::K);
        assert_eq!(keypair.t0.len(), P::K);
    }

    /// Packs and unpacks a public key and checks that nothing changed.
    fn assert_pk_round_trip<P: MlDsaParams>() {
        let keypair = generate_keypair_internal::<P>(&SEED);
        let pk_bytes = pack_pk::<P>(&keypair.rho, &keypair.t1);
        assert_eq!(pk_bytes.len(), P::PK_SIZE);

        let (rho, t1) = unpack_pk::<P>(&pk_bytes).expect("packed public key must unpack");
        assert_eq!(rho, keypair.rho);
        assert_polys_eq(&t1, &keypair.t1, "t1");
    }

    /// Packs and unpacks a secret key and checks every component, including `s1`.
    fn assert_sk_round_trip<P: MlDsaParams>() {
        let keypair = generate_keypair_internal::<P>(&SEED);
        let sk_bytes = pack_sk::<P>(
            &keypair.rho,
            &keypair.key,
            &keypair.tr,
            &keypair.s1,
            &keypair.s2,
            &keypair.t0,
        );
        assert_eq!(sk_bytes.len(), P::SK_SIZE);

        let (rho, key, tr, s1, s2, t0) =
            unpack_sk::<P>(&sk_bytes).expect("packed secret key must unpack");
        assert_eq!(rho, keypair.rho);
        assert_eq!(key, keypair.key);
        assert_eq!(tr, keypair.tr);
        assert_polys_eq(&s1, &keypair.s1, "s1");
        assert_polys_eq(&s2, &keypair.s2, "s2");
        assert_polys_eq(&t0, &keypair.t0, "t0");
    }

    #[test]
    fn test_keygen_44_produces_valid_sizes() {
        assert_keygen_sizes::<Params44>();
    }

    #[test]
    fn test_keygen_65_produces_valid_sizes() {
        assert_keygen_sizes::<Params65>();
    }

    #[test]
    fn test_keygen_87_produces_valid_sizes() {
        assert_keygen_sizes::<Params87>();
    }

    #[test]
    fn test_keygen_deterministic() {
        let kp1 = generate_keypair_internal::<Params65>(&SEED);
        let kp2 = generate_keypair_internal::<Params65>(&SEED);
        assert_eq!(kp1.rho, kp2.rho);
        assert_eq!(kp1.key, kp2.key);
        assert_eq!(kp1.tr, kp2.tr);
        assert_polys_eq(&kp1.s1, &kp2.s1, "s1");
        assert_polys_eq(&kp1.s2, &kp2.s2, "s2");
        assert_polys_eq(&kp1.t1, &kp2.t1, "t1");
        assert_polys_eq(&kp1.t0, &kp2.t0, "t0");
    }

    #[test]
    fn test_keygen_different_seeds() {
        let seed1 = [0x42u8; 32];
        let seed2 = [0x43u8; 32];
        let kp1 = generate_keypair_internal::<Params44>(&seed1);
        let kp2 = generate_keypair_internal::<Params44>(&seed2);
        assert_ne!(kp1.rho, kp2.rho);
    }

    #[test]
    fn test_t1_coefficients_in_range() {
        let keypair = generate_keypair_internal::<Params44>(&SEED);
        for poly in &keypair.t1 {
            for &c in &poly.coeffs {
                assert!(c >= 0, "t1 coefficient < 0: {}", c);
                assert!(c < 1024, "t1 coefficient >= 2^10: {}", c);
            }
        }
    }

    #[test]
    fn test_t0_coefficients_in_range() {
        let keypair = generate_keypair_internal::<Params44>(&SEED);
        let bound: i32 = 1 << 12;
        for poly in &keypair.t0 {
            for &c in &poly.coeffs {
                assert!(c >= -bound, "t0 coefficient < -2^12: {}", c);
                assert!(c < bound, "t0 coefficient >= 2^12: {}", c);
            }
        }
    }

    #[test]
    fn test_pack_unpack_pk_44() {
        assert_pk_round_trip::<Params44>();
    }

    #[test]
    fn test_pack_unpack_pk_65() {
        assert_pk_round_trip::<Params65>();
    }

    #[test]
    fn test_pack_unpack_pk_87() {
        assert_pk_round_trip::<Params87>();
    }

    #[test]
    fn test_pack_unpack_sk_44() {
        assert_sk_round_trip::<Params44>();
    }

    #[test]
    fn test_pack_unpack_sk_65() {
        assert_sk_round_trip::<Params65>();
    }

    #[test]
    fn test_pack_unpack_sk_87() {
        assert_sk_round_trip::<Params87>();
    }

    #[test]
    fn test_unpack_pk_wrong_size() {
        let bytes = vec![0u8; 100];
        assert!(unpack_pk::<Params44>(&bytes).is_none());
        assert!(unpack_pk::<Params65>(&bytes).is_none());
        assert!(unpack_pk::<Params87>(&bytes).is_none());
    }

    #[test]
    fn test_unpack_sk_wrong_size() {
        let bytes = vec![0u8; 100];
        assert!(unpack_sk::<Params44>(&bytes).is_none());
        assert!(unpack_sk::<Params65>(&bytes).is_none());
        assert!(unpack_sk::<Params87>(&bytes).is_none());
    }
}