#[cfg(not(feature = "std"))]
use crate::internal_alloc::Vec;
use crate::shake256;
use crate::{drbg::HmacDrbgSha256, sha3_256};
use noxtls_core::{Error, Result};
/// Encoded public key size: 32-byte `rho` followed by `MLDSA_K` polynomials packed
/// at 10 bits per coefficient (`32 + 6 * 320`).
pub const MLDSA_PUBLIC_KEY_LEN: usize = 1_952;
/// Encoded private key size as laid out by `keygen_from_seed`:
/// `rho (32) || key (32) || tr (32) || s1 (1280) || s2 (1536) || t0 (960) || hpk (32) || z_fill (128)`.
pub const MLDSA_PRIVATE_KEY_LEN: usize = 4_032;
/// Signature size: packed `z` (1920) || challenge seed `c` (32) || hint bytes (1357).
pub const MLDSA_SIGNATURE_LEN: usize = 3_309;
/// DER content octets of the algorithm OID accepted in SubjectPublicKeyInfo.
///
/// These bytes decode to `1.3.6.1.4.1.2.267.7.6.5`.
/// NOTE(review): that is not the NIST-assigned `id-ml-dsa-65`
/// (`2.16.840.1.101.3.4.3.18`); confirm which OID interoperating peers expect.
pub const OID_ID_MLDSA65: &[u8] = &[
    0x2B, 0x06, 0x01, 0x04, 0x01, 0x02, 0x82, 0x0B, 0x07, 0x06, 0x05,
];
/// Number of coefficients per polynomial.
const MLDSA_N: usize = 256;
/// Length of the "L" vectors (`s1`, `y`, `z`).
const MLDSA_L: usize = 5;
/// Length of the "K" vectors (`s2`, `t`, `w`).
const MLDSA_K: usize = 6;
/// Coefficient modulus q.
const MLDSA_Q: i32 = 8_380_417;
/// Secret coefficients are sampled in `[-MLDSA_ETA_BOUND, MLDSA_ETA_BOUND]`.
const MLDSA_ETA_BOUND: i32 = 2;
/// Magnitude bound for the masking vector `y`.
const MLDSA_GAMMA1_BOUND: i32 = 1 << 17;
/// Bytes per polynomial packed at 12 bits/coefficient.
const MLDSA_POLY_PACKED12_BYTES: usize = 384;
/// Bytes per polynomial packed at 10 bits/coefficient.
const MLDSA_POLY_PACKED10_BYTES: usize = 320;
/// `s1` stored as one byte per coefficient.
const MLDSA_S1_BYTES: usize = MLDSA_L * MLDSA_N;
/// `s2` stored as one byte per coefficient.
const MLDSA_S2_BYTES: usize = MLDSA_K * MLDSA_N;
/// Size of the XOF-derived `t0` field in the private key.
const MLDSA_T0_BYTES: usize = MLDSA_K * 160;
/// Packed `t` inside the public key (after the 32-byte `rho`).
const MLDSA_PUBLIC_T_BYTES: usize = MLDSA_K * MLDSA_POLY_PACKED10_BYTES;
/// Packed `z` at the start of a signature.
const MLDSA_SIGNATURE_Z_BYTES: usize = MLDSA_L * MLDSA_POLY_PACKED12_BYTES;
/// Challenge seed `c` following `z`.
const MLDSA_SIGNATURE_C_BYTES: usize = 32;
/// Everything after `z || c`: the compressed `w1` followed by derived hint bytes.
const MLDSA_SIGNATURE_HINT_BYTES: usize =
    MLDSA_SIGNATURE_LEN - MLDSA_SIGNATURE_Z_BYTES - MLDSA_SIGNATURE_C_BYTES;
/// Compressed `w1`: two 4-bit values per byte.
const MLDSA_SIGNATURE_W1_BYTES: usize = MLDSA_K * MLDSA_N / 2;
/// Maximum number of nonces tried by the signer before giving up.
const MLDSA_SIGN_REJECTION_MAX_ITERS: u32 = 64;
/// Rejection bound on the centered infinity norm of `z`.
const MLDSA_Z_INF_BOUND: i32 = MLDSA_GAMMA1_BOUND * 2;
/// Rejection bound on the centered infinity norm of `r`.
/// NOTE(review): `centered_abs` never exceeds `q / 2`, so this bound never rejects.
const MLDSA_R_INF_BOUND: i32 = MLDSA_Q / 2;
/// Number of ±1 coefficients in the challenge polynomial.
const MLDSA_CHALLENGE_NONZERO_TERMS: usize = 49;
/// Domain-separation prefix for `expand_seed_bytes`.
const MLDSA_XOF_DOMAIN_EXPAND: u8 = 0x11;
/// Domain-separation prefix for `derive_hash32`.
const MLDSA_XOF_DOMAIN_HASH32: u8 = 0x12;
/// Domain-separation prefix shared by `build_challenge` and `challenge_poly_from_digest`.
const MLDSA_XOF_DOMAIN_CHALLENGE: u8 = 0x13;
/// A polynomial represented by its `MLDSA_N` integer coefficients.
#[derive(Clone, Copy)]
struct Poly {
    coeffs: [i32; MLDSA_N],
}

impl Poly {
    /// Returns the polynomial whose coefficients are all zero.
    fn zero() -> Self {
        let coeffs = [0_i32; MLDSA_N];
        Poly { coeffs }
    }
}
/// A vector of `MLDSA_L` polynomials (the shape of `s1`, `y` and `z`).
#[derive(Clone, Copy)]
struct PolyVecL {
    polys: [Poly; MLDSA_L],
}

impl PolyVecL {
    /// Returns a vector whose polynomials are all zero.
    fn zero() -> Self {
        PolyVecL {
            polys: [Poly::zero(); MLDSA_L],
        }
    }
}
/// A vector of `MLDSA_K` polynomials (the shape of `s2`, `t`, `w` and `r`).
#[derive(Clone, Copy)]
struct PolyVecK {
    polys: [Poly; MLDSA_K],
}

impl PolyVecK {
    /// Returns a vector whose polynomials are all zero.
    fn zero() -> Self {
        PolyVecK {
            polys: [Poly::zero(); MLDSA_K],
        }
    }
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct MlDsaPublicKey {
bytes: Vec<u8>,
}
impl MlDsaPublicKey {
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
if bytes.len() != MLDSA_PUBLIC_KEY_LEN {
return Err(Error::InvalidLength("mldsa public key must be 1952 bytes"));
}
Ok(Self {
bytes: bytes.to_vec(),
})
}
#[must_use]
pub fn as_bytes(&self) -> &[u8] {
&self.bytes
}
}
/// A private key stored alongside the public key it was created with.
///
/// `private_bytes` follows the layout written by `keygen_from_seed`:
/// `rho (32) || key (32) || tr (32) || s1 || s2 || t0 || hpk (32) || z_fill (128)`.
///
/// `Debug` is implemented by hand so secret bytes never reach logs or panic
/// messages.
/// NOTE(review): the derived `PartialEq` compares secret bytes in variable time.
#[derive(Clone, Eq, PartialEq)]
pub struct MlDsaPrivateKey {
    private_bytes: Vec<u8>,
    public_bytes: Vec<u8>,
}

impl core::fmt::Debug for MlDsaPrivateKey {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Redact the secret half; the public half is safe to show.
        f.debug_struct("MlDsaPrivateKey")
            .field("private_bytes", &"<redacted>")
            .field("public_bytes", &self.public_bytes)
            .finish()
    }
}

impl MlDsaPrivateKey {
    /// Builds a private key from its encoded private and public halves (both copied).
    ///
    /// # Errors
    /// Returns `Error::InvalidLength` if `private_bytes` is not `MLDSA_PRIVATE_KEY_LEN`
    /// bytes or `public_bytes` is not `MLDSA_PUBLIC_KEY_LEN` bytes. The two halves
    /// are not checked for consistency with each other.
    pub fn from_bytes(private_bytes: &[u8], public_bytes: &[u8]) -> Result<Self> {
        if private_bytes.len() != MLDSA_PRIVATE_KEY_LEN {
            return Err(Error::InvalidLength("mldsa private key must be 4032 bytes"));
        }
        if public_bytes.len() != MLDSA_PUBLIC_KEY_LEN {
            return Err(Error::InvalidLength("mldsa public key must be 1952 bytes"));
        }
        Ok(Self {
            private_bytes: private_bytes.to_vec(),
            public_bytes: public_bytes.to_vec(),
        })
    }

    /// Recomputes the public key from the secret vectors in `private_bytes`.
    ///
    /// The stored `public_bytes` are not consulted.
    ///
    /// # Errors
    /// Propagates decoding/length errors from `derive_public_from_private`.
    pub fn public_key(&self) -> Result<MlDsaPublicKey> {
        let recomputed = derive_public_from_private(&self.private_bytes)?;
        MlDsaPublicKey::from_bytes(&recomputed)
    }

    /// Deterministically signs `message`.
    ///
    /// # Panics
    /// Panics if `sign_internal` returns an error, e.g. if every rejection-sampling
    /// nonce is rejected or the key was cleared with [`Self::clear`].
    #[must_use]
    pub fn sign(&self, message: &[u8]) -> [u8; MLDSA_SIGNATURE_LEN] {
        sign_internal(&self.private_bytes, &self.public_bytes, message)
            .expect("mldsa sign should always succeed for internally generated key material")
    }

    /// Overwrites both buffers with zeros and then empties them.
    ///
    /// NOTE(review): plain stores may be elided by the optimizer; a volatile or
    /// `zeroize`-style wipe would give a stronger guarantee.
    pub fn clear(&mut self) {
        self.private_bytes.fill(0);
        self.private_bytes.clear();
        self.public_bytes.fill(0);
        self.public_bytes.clear();
    }
}

impl Drop for MlDsaPrivateKey {
    fn drop(&mut self) {
        self.clear();
    }
}
/// Generates a fresh key pair from 32 bytes of DRBG output.
///
/// # Errors
/// Propagates DRBG failures and the length checks of the key constructors.
pub fn mldsa_generate_keypair_auto(
    drbg: &mut HmacDrbgSha256,
) -> Result<(MlDsaPrivateKey, MlDsaPublicKey)> {
    let seed = drbg.generate(32, b"mldsa keygen seed")?;
    let (sk_bytes, pk_bytes) = keygen_from_seed(&seed);
    let private = MlDsaPrivateKey::from_bytes(&sk_bytes, &pk_bytes)?;
    MlDsaPublicKey::from_bytes(&pk_bytes).map(|public| (private, public))
}
pub fn mldsa_verify(public_key: &MlDsaPublicKey, message: &[u8], signature: &[u8]) -> Result<()> {
if signature.len() != MLDSA_SIGNATURE_LEN {
return Err(Error::InvalidLength("mldsa signature must be 3309 bytes"));
}
verify_internal(public_key.as_bytes(), message, signature)
}
/// Extracts a public key from a DER `SubjectPublicKeyInfo`.
///
/// Accepts exactly `SEQUENCE { SEQUENCE { OID OID_ID_MLDSA65 }, BIT STRING }`, where the
/// algorithm has no parameters, the BIT STRING has zero unused bits, and no bytes trail.
///
/// # Errors
/// `Error::ParseFailure` on any structural mismatch; `Error::InvalidLength` if the
/// embedded key is not `MLDSA_PUBLIC_KEY_LEN` bytes.
pub fn mldsa_public_key_from_subject_public_key_info(der: &[u8]) -> Result<MlDsaPublicKey> {
    // Outer SubjectPublicKeyInfo SEQUENCE must consume the whole input.
    let (spki_tag, spki_body, trailing) = parse_der_node_local(der)?;
    if !(spki_tag == 0x30 && trailing.is_empty()) {
        return Err(Error::ParseFailure("mldsa SPKI must be one sequence"));
    }

    // AlgorithmIdentifier: the OID only, parameters absent.
    let (algorithm_tag, algorithm_body, key_part) = parse_der_node_local(spki_body)?;
    if algorithm_tag != 0x30 {
        return Err(Error::ParseFailure("mldsa SPKI missing algorithm sequence"));
    }
    let (oid_tag, oid_value, parameters) = parse_der_node_local(algorithm_body)?;
    let oid_matches = oid_tag == 0x06 && oid_value == OID_ID_MLDSA65;
    if !oid_matches {
        return Err(Error::ParseFailure("mldsa SPKI algorithm OID mismatch"));
    }
    if !parameters.is_empty() {
        return Err(Error::ParseFailure(
            "mldsa SPKI algorithm parameters unsupported",
        ));
    }

    // subjectPublicKey: leading octet is the unused-bit count and must be zero.
    let (bits_tag, bits_body, rest) = parse_der_node_local(key_part)?;
    let key_bytes = match bits_body.split_first() {
        Some((&0, key_bytes)) if bits_tag == 0x03 && rest.is_empty() => key_bytes,
        _ => {
            return Err(Error::ParseFailure(
                "mldsa SPKI missing zero-unused-bits BIT STRING",
            ))
        }
    };
    MlDsaPublicKey::from_bytes(key_bytes)
}
fn parse_der_node_local(input: &[u8]) -> Result<(u8, &[u8], &[u8])> {
if input.len() < 2 {
return Err(Error::ParseFailure("DER node too short"));
}
let tag = input[0];
let (len, len_len) = parse_der_length_local(&input[1..])?;
let start = 1 + len_len;
let end = start + len;
if input.len() < end {
return Err(Error::ParseFailure("DER node length exceeds input"));
}
Ok((tag, &input[start..end], &input[end..]))
}
fn parse_der_length_local(input: &[u8]) -> Result<(usize, usize)> {
if input.is_empty() {
return Err(Error::ParseFailure("missing DER length"));
}
let first = input[0];
if first & 0x80 == 0 {
return Ok((usize::from(first), 1));
}
let octets = usize::from(first & 0x7f);
if octets == 0 || octets > 4 || input.len() < 1 + octets {
return Err(Error::ParseFailure("unsupported DER length"));
}
let mut len = 0_usize;
for b in &input[1..1 + octets] {
len = (len << 8) | usize::from(*b);
}
Ok((len, 1 + octets))
}
/// Deterministically derives `(private_bytes, public_bytes)` from `seed`.
///
/// Public key: `rho || pack10(t)` with `t = A * s1 + s2 (mod q)`.
/// Private key: `rho (32) || key (32) || tr (32) || s1 || s2 || t0 || hpk (32) || z_fill (128)`,
/// where `tr` and `hpk` are both `sha3_256(public_bytes)`.
fn keygen_from_seed(seed: &[u8]) -> (Vec<u8>, Vec<u8>) {
    let rho = derive_hash32(b"mldsa-rho", seed);
    let key = derive_hash32(b"mldsa-key", seed);
    let a = generate_matrix(&rho);
    let s1 = sample_small_vec_l(&key, b"mldsa-s1");
    let s2 = sample_small_vec_k(&key, b"mldsa-s2");
    let mut t = mat_vec_mul(&a, &s1);
    add_vec_k_inplace(&mut t, &s2);
    normalize_vec_k(&mut t);
    let public_bytes = encode_public_key(&rho, &t);
    // The same public-key digest fills both the `tr` and `hpk` slots; hash it once.
    let tr = sha3_256(&public_bytes);
    let hpk = tr;
    let t0 = derive_t0_bytes(&t);
    let z_fill = expand_seed_bytes(b"mldsa-sk-zfill", &hpk, 128);
    let mut private_bytes = Vec::with_capacity(MLDSA_PRIVATE_KEY_LEN);
    private_bytes.extend_from_slice(&rho);
    private_bytes.extend_from_slice(&key);
    private_bytes.extend_from_slice(&tr);
    private_bytes.extend_from_slice(&encode_small_vec_l(&s1));
    private_bytes.extend_from_slice(&encode_small_vec_k(&s2));
    private_bytes.extend_from_slice(&t0);
    private_bytes.extend_from_slice(&hpk);
    private_bytes.extend_from_slice(&z_fill);
    (private_bytes, public_bytes)
}
/// Deterministically signs `message` with the given key material.
///
/// As implemented below:
/// 1. decode `rho`, `key`, `tr`, `s1`, `s2` from `private_key` and `t` from `public_key`;
/// 2. `mu = sha3_256(tr || message)` and `A = generate_matrix(rho)`;
/// 3. for each nonce in `0..MLDSA_SIGN_REJECTION_MAX_ITERS`, derive `y` from
///    `sha3_256(sha3_256(key || mu || message) || nonce_le)`, then compute `w = A*y`,
///    `w1 = compress_vec_k_hint(w)`, `c = build_challenge(mu, w1)`, `z = y + c*s1`,
///    `r = w - c*s2`, and retry with the next nonce if either norm bound is exceeded;
/// 4. emit `z (12-bit packed) || c || hints`, where the hints start with `w1`.
///
/// NOTE(review): `centered_abs` never exceeds `q / 2`, so the `r` check
/// (`MLDSA_R_INF_BOUND == q / 2`) cannot reject. With `|y| <= 2^17` and 49 ±1 challenge
/// terms the `z` check also looks unreachable — confirm. `public_key` is not checked
/// to correspond to `private_key`.
///
/// # Errors
/// `Error::InvalidLength` for wrong key lengths, decode errors from the helpers, or
/// `Error::CryptoFailure` if every nonce is rejected.
fn sign_internal(
    private_key: &[u8],
    public_key: &[u8],
    message: &[u8],
) -> Result<[u8; MLDSA_SIGNATURE_LEN]> {
    if private_key.len() != MLDSA_PRIVATE_KEY_LEN || public_key.len() != MLDSA_PUBLIC_KEY_LEN {
        return Err(Error::InvalidLength("mldsa key material length mismatch"));
    }
    // Private key layout: rho (32) || key (32) || tr (32) || s1 || s2 || ...
    let rho = &private_key[..32];
    let key = &private_key[32..64];
    let tr = array32_from_slice(&private_key[64..96])?;
    let s1 = decode_small_vec_l(&private_key[96..96 + MLDSA_S1_BYTES])?;
    let s2_offset = 96 + MLDSA_S1_BYTES;
    let s2 = decode_small_vec_k(&private_key[s2_offset..s2_offset + MLDSA_S2_BYTES])?;
    // `t` is only used to derive the trailing hint bytes.
    let t = decode_public_t(&public_key[32..])?;
    // mu binds the message to the public-key digest `tr`.
    let mut mu_input = Vec::with_capacity(tr.len() + message.len());
    mu_input.extend_from_slice(&tr);
    mu_input.extend_from_slice(message);
    let mu = sha3_256(&mu_input);
    let a = generate_matrix(&array32_from_slice(rho)?);
    // Deterministic nonce base: secret `key`, `mu` and the message.
    let mut y_seed = Vec::with_capacity(key.len() + mu.len() + message.len());
    y_seed.extend_from_slice(key);
    y_seed.extend_from_slice(&mu);
    y_seed.extend_from_slice(message);
    let base_y_seed = sha3_256(&y_seed);
    for nonce in 0..MLDSA_SIGN_REJECTION_MAX_ITERS {
        let mut seeded = Vec::with_capacity(base_y_seed.len() + 4);
        seeded.extend_from_slice(&base_y_seed);
        seeded.extend_from_slice(&nonce.to_le_bytes());
        let y = sample_y_vec_l(&sha3_256(&seeded));
        let mut w = mat_vec_mul(&a, &y);
        normalize_vec_k(&mut w);
        let w1 = compress_vec_k_hint(&w);
        let c = build_challenge(&mu, &w1);
        let c_poly = challenge_poly_from_digest(&c);
        // z = y + c * s1 (mod q)
        let mut z = y;
        add_challenge_vec_l_inplace(&mut z, &s1, &c_poly);
        normalize_vec_l(&mut z);
        if max_abs_vec_l(&z) > MLDSA_Z_INF_BOUND {
            continue;
        }
        // r = w - c * s2 (mod q)
        let mut r = w;
        sub_challenge_vec_k_inplace(&mut r, &s2, &c_poly);
        normalize_vec_k(&mut r);
        if max_abs_vec_k(&r) > MLDSA_R_INF_BOUND {
            continue;
        }
        let z_bytes = encode_vec_l_12bit(&z);
        let hints =
            derive_hint_bytes_from_signature(&w1, &z_bytes, &t, &c, MLDSA_SIGNATURE_HINT_BYTES);
        // Signature layout: z || c || hints.
        let mut signature = [0_u8; MLDSA_SIGNATURE_LEN];
        signature[..MLDSA_SIGNATURE_Z_BYTES].copy_from_slice(&z_bytes);
        signature[MLDSA_SIGNATURE_Z_BYTES..MLDSA_SIGNATURE_Z_BYTES + MLDSA_SIGNATURE_C_BYTES]
            .copy_from_slice(&c);
        signature[MLDSA_SIGNATURE_Z_BYTES + MLDSA_SIGNATURE_C_BYTES..].copy_from_slice(&hints);
        return Ok(signature);
    }
    Err(Error::CryptoFailure(
        "mldsa signing rejection sampling exhausted",
    ))
}
/// Checks `signature` against `public_key` and `message`.
///
/// Signature layout: `z || c (32) || hint`, where the first `MLDSA_SIGNATURE_W1_BYTES`
/// of `hint` carry the compressed `w1`. The checks performed, in order:
/// 1. `c == build_challenge(mu, w1)` with `mu = sha3_256(sha3_256(pk) || message)`;
/// 2. the rest of `hint` equals `derive_hint_bytes_from_signature(w1, z, t, c)`;
/// 3. the centered infinity norm of the decoded `z` is at most `MLDSA_Z_INF_BOUND`.
///
/// SECURITY NOTE(review): every input to these checks (the public key, the message
/// and the signature bytes themselves) is public, and no check relates `z` or `w1`
/// to the matrix `A` or to `t`. Anyone holding only the public key can therefore
/// construct bytes that pass. Check 3 cannot fail either, because
/// `decode_vec_l_12bit` yields coefficients below 4096. Do not rely on this for
/// authentication until a real verification equation is implemented.
///
/// # Errors
/// `Error::InvalidLength` for wrong input lengths; `Error::CryptoFailure` if any
/// check fails.
fn verify_internal(public_key: &[u8], message: &[u8], signature: &[u8]) -> Result<()> {
    if public_key.len() != MLDSA_PUBLIC_KEY_LEN || signature.len() != MLDSA_SIGNATURE_LEN {
        return Err(Error::InvalidLength(
            "mldsa verify material length mismatch",
        ));
    }
    let t = decode_public_t(&public_key[32..])?;
    let (z_bytes, after_z) = signature.split_at(MLDSA_SIGNATURE_Z_BYTES);
    let z = decode_vec_l_12bit(z_bytes)?;
    let (c, hint) = after_z.split_at(MLDSA_SIGNATURE_C_BYTES);
    if hint.len() < MLDSA_SIGNATURE_W1_BYTES {
        return Err(Error::InvalidLength("mldsa signature hint bytes too short"));
    }
    let (w1_bytes, hint_tail) = hint.split_at(MLDSA_SIGNATURE_W1_BYTES);

    // Recompute mu exactly as the signer does (tr = digest of the public key).
    let tr = sha3_256(public_key);
    let mut mu_input = Vec::with_capacity(tr.len() + message.len());
    mu_input.extend_from_slice(&tr);
    mu_input.extend_from_slice(message);
    let mu = sha3_256(&mu_input);

    let c_check = build_challenge(&mu, w1_bytes);
    if c_check.as_slice() != c {
        return Err(Error::CryptoFailure("mldsa signature verification failed"));
    }
    let expected_hint = derive_hint_bytes_from_signature(
        w1_bytes,
        z_bytes,
        &t,
        &c_check,
        MLDSA_SIGNATURE_HINT_BYTES,
    );
    if &expected_hint[MLDSA_SIGNATURE_W1_BYTES..] != hint_tail {
        return Err(Error::CryptoFailure("mldsa signature verification failed"));
    }
    if max_abs_vec_l(&z) > MLDSA_Z_INF_BOUND {
        return Err(Error::CryptoFailure("mldsa signature verification failed"));
    }
    Ok(())
}
/// Recomputes the encoded public key `rho || pack10(A * s1 + s2)` from a private key.
///
/// Only `rho`, `s1` and `s2` are read; the stored `tr`/`t0`/`hpk` fields are ignored.
///
/// # Errors
/// `Error::InvalidLength` if `private_bytes` is not `MLDSA_PRIVATE_KEY_LEN` bytes.
fn derive_public_from_private(private_bytes: &[u8]) -> Result<Vec<u8>> {
    if private_bytes.len() != MLDSA_PRIVATE_KEY_LEN {
        return Err(Error::InvalidLength("mldsa private key must be 4032 bytes"));
    }
    let rho = array32_from_slice(&private_bytes[..32])?;
    let s1 = decode_small_vec_l(&private_bytes[96..96 + MLDSA_S1_BYTES])?;
    let s2_offset = 96 + MLDSA_S1_BYTES;
    let s2 = decode_small_vec_k(&private_bytes[s2_offset..s2_offset + MLDSA_S2_BYTES])?;
    let a = generate_matrix(&rho);
    let mut t = mat_vec_mul(&a, &s1);
    add_vec_k_inplace(&mut t, &s2);
    normalize_vec_k(&mut t);
    Ok(encode_public_key(&rho, &t))
}
/// Serialises a public key as `rho || pack10(t)` (`MLDSA_PUBLIC_KEY_LEN` bytes).
fn encode_public_key(rho: &[u8; 32], t: &PolyVecK) -> Vec<u8> {
    let packed_t = encode_vec_k_10bit(t);
    let mut encoded = Vec::with_capacity(MLDSA_PUBLIC_KEY_LEN);
    encoded.extend_from_slice(&rho[..]);
    encoded.extend_from_slice(&packed_t);
    encoded
}
/// Derives the `MLDSA_T0_BYTES`-byte `t0` field of the private key.
///
/// This is an XOF expansion (label `mldsa-t0`) of the 10-bit packing of `t`.
fn derive_t0_bytes(t: &PolyVecK) -> Vec<u8> {
    // Feed the packed encoding straight to the XOF; no intermediate copy is needed.
    let packed_t = encode_vec_k_10bit(t);
    expand_seed_bytes(b"mldsa-t0", &packed_t, MLDSA_T0_BYTES)
}
/// Expands `rho` into the `MLDSA_K x MLDSA_L` public matrix `A`.
///
/// Entry `(row, col)` is `sample_uniform_poly(rho, row, col)`.
fn generate_matrix(rho: &[u8; 32]) -> [[Poly; MLDSA_L]; MLDSA_K] {
    let mut matrix = [[Poly::zero(); MLDSA_L]; MLDSA_K];
    for row in 0..MLDSA_K {
        for col in 0..MLDSA_L {
            matrix[row][col] = sample_uniform_poly(rho, row as u8, col as u8);
        }
    }
    matrix
}
/// Samples matrix entry `(row, col)` by expanding `seed || row || col`.
///
/// Each coefficient consumes 3 XOF bytes (little-endian), keeps the low 23 bits and
/// reduces them mod q. NOTE(review): 2^23 > q and there is no rejection step, so the
/// result is slightly non-uniform.
fn sample_uniform_poly(seed: &[u8; 32], row: u8, col: u8) -> Poly {
    let mut input = Vec::with_capacity(seed.len() + 2);
    input.extend_from_slice(seed);
    input.extend_from_slice(&[row, col]);
    let stream = expand_seed_bytes(b"mldsa-matrix", &input, MLDSA_N * 3);
    let mut poly = Poly::zero();
    for (coeff, bytes) in poly.coeffs.iter_mut().zip(stream.chunks_exact(3)) {
        let word =
            i32::from(bytes[0]) | (i32::from(bytes[1]) << 8) | (i32::from(bytes[2]) << 16);
        *coeff = mod_q(word & 0x007F_FFFF);
    }
    poly
}
/// Samples `MLDSA_L` short polynomials, the i-th from `(seed, label, i)`.
fn sample_small_vec_l(seed: &[u8; 32], label: &[u8]) -> PolyVecL {
    let mut vec = PolyVecL::zero();
    for (index, poly) in vec.polys.iter_mut().enumerate() {
        *poly = sample_small_poly(seed, label, index as u8);
    }
    vec
}
/// Samples `MLDSA_K` short polynomials, the i-th from `(seed, label, i)`.
fn sample_small_vec_k(seed: &[u8; 32], label: &[u8]) -> PolyVecK {
    let mut vec = PolyVecK::zero();
    for (index, poly) in vec.polys.iter_mut().enumerate() {
        *poly = sample_small_poly(seed, label, index as u8);
    }
    vec
}
/// Samples the masking vector `y` from `seed`.
///
/// Polynomial `i` expands `seed || i`. Each coefficient uses 3 XOF bytes, keeps the
/// low 18 bits and recentres them to `[-2^17, 2^17)`. The final clamp to
/// `±MLDSA_GAMMA1_BOUND` is a no-op for that range and is kept as a guard.
fn sample_y_vec_l(seed: &[u8; 32]) -> PolyVecL {
    let mut y = PolyVecL::zero();
    for (index, poly) in y.polys.iter_mut().enumerate() {
        let mut input = Vec::with_capacity(seed.len() + 1);
        input.extend_from_slice(seed);
        input.push(index as u8);
        let stream = expand_seed_bytes(b"mldsa-y", &input, MLDSA_N * 3);
        for (coeff, bytes) in poly.coeffs.iter_mut().zip(stream.chunks_exact(3)) {
            let word =
                i32::from(bytes[0]) | (i32::from(bytes[1]) << 8) | (i32::from(bytes[2]) << 16);
            let centered = (word & 0x03_FFFF) - MLDSA_GAMMA1_BOUND;
            *coeff = clamp(centered, -MLDSA_GAMMA1_BOUND, MLDSA_GAMMA1_BOUND);
        }
    }
    y
}
/// Samples a polynomial with coefficients in `[-MLDSA_ETA_BOUND, MLDSA_ETA_BOUND]`.
///
/// Expands `label || seed || index` and maps each byte to `(byte mod 5) - eta`.
/// NOTE(review): 256 is not a multiple of 5, so the mapping is slightly biased.
fn sample_small_poly(seed: &[u8; 32], label: &[u8], index: u8) -> Poly {
    let range = (2 * MLDSA_ETA_BOUND + 1) as u8;
    let mut input = Vec::with_capacity(seed.len() + label.len() + 1);
    input.extend_from_slice(label);
    input.extend_from_slice(seed);
    input.push(index);
    let stream = expand_seed_bytes(b"mldsa-small", &input, MLDSA_N);
    let mut poly = Poly::zero();
    for (coeff, &byte) in poly.coeffs.iter_mut().zip(stream.iter()) {
        *coeff = i32::from(byte % range) - MLDSA_ETA_BOUND;
    }
    poly
}
/// Computes `A * s` in `Z_q[X]/(X^N + 1)`, with every output coefficient in `[0, q)`.
fn mat_vec_mul(a: &[[Poly; MLDSA_L]; MLDSA_K], s: &PolyVecL) -> PolyVecK {
    let mut product = PolyVecK::zero();
    for (out_poly, row) in product.polys.iter_mut().zip(a.iter()) {
        let mut sum = Poly::zero();
        for (entry, s_poly) in row.iter().zip(s.polys.iter()) {
            add_poly_inplace(&mut sum, &poly_mul(entry, s_poly));
        }
        normalize_poly(&mut sum);
        *out_poly = sum;
    }
    product
}
/// Schoolbook product in `Z_q[X]/(X^N + 1)` (negacyclic convolution).
///
/// Terms are accumulated in `i64`. A term landing at degree `i + j >= N` wraps to
/// `i + j - N` with its sign flipped, because `X^N = -1`. Output coefficients lie in `[0, q)`.
fn poly_mul(a: &Poly, b: &Poly) -> Poly {
    let mut sums = [0_i64; MLDSA_N];
    for (i, &lhs) in a.coeffs.iter().enumerate() {
        let lhs = i64::from(lhs);
        for (j, &rhs) in b.coeffs.iter().enumerate() {
            let product = lhs * i64::from(rhs);
            let degree = i + j;
            if degree < MLDSA_N {
                sums[degree] += product;
            } else {
                sums[degree - MLDSA_N] -= product;
            }
        }
    }
    let mut out = Poly::zero();
    for (coeff, &sum) in out.coeffs.iter_mut().zip(sums.iter()) {
        *coeff = mod_q_i64(sum);
    }
    out
}
/// Adds `src` into `dst` coefficient-wise, reducing each sum mod q.
fn add_poly_inplace(dst: &mut Poly, src: &Poly) {
    for (d, &s) in dst.coeffs.iter_mut().zip(src.coeffs.iter()) {
        *d = mod_q(*d + s);
    }
}
/// Adds `src` into `dst` polynomial-wise (mod q).
fn add_vec_k_inplace(dst: &mut PolyVecK, src: &PolyVecK) {
    dst.polys
        .iter_mut()
        .zip(src.polys.iter())
        .for_each(|(d, s)| add_poly_inplace(d, s));
}
/// Computes `dst += challenge * src` polynomial-wise (mod q).
fn add_challenge_vec_l_inplace(dst: &mut PolyVecL, src: &PolyVecL, challenge: &Poly) {
    for (d, s) in dst.polys.iter_mut().zip(src.polys.iter()) {
        let product = poly_mul(challenge, s);
        add_poly_inplace(d, &product);
    }
}
/// Computes `dst -= challenge * src` polynomial-wise, reducing each coefficient mod q.
fn sub_challenge_vec_k_inplace(dst: &mut PolyVecK, src: &PolyVecK, challenge: &Poly) {
    for (d, s) in dst.polys.iter_mut().zip(src.polys.iter()) {
        let product = poly_mul(challenge, s);
        for (dc, &pc) in d.coeffs.iter_mut().zip(product.coeffs.iter()) {
            *dc = mod_q(*dc - pc);
        }
    }
}
/// Reduces every coefficient of `poly` into `[0, q)`.
fn normalize_poly(poly: &mut Poly) {
    poly.coeffs.iter_mut().for_each(|c| *c = mod_q(*c));
}
fn normalize_vec_l(vec: &mut PolyVecL) {
for poly in &mut vec.polys {
normalize_poly(poly);
}
}
fn normalize_vec_k(vec: &mut PolyVecK) {
for poly in &mut vec.polys {
normalize_poly(poly);
}
}
/// Encodes each coefficient `c` of `vec` as the single byte `c + eta`.
///
/// Coefficients are expected in `[-eta, eta]`; other values are truncated by the `as u8` cast.
fn encode_small_vec_l(vec: &PolyVecL) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(MLDSA_S1_BYTES);
    encoded.extend(
        vec.polys
            .iter()
            .flat_map(|poly| poly.coeffs.iter())
            .map(|&c| (c + MLDSA_ETA_BOUND) as u8),
    );
    encoded
}
/// Encodes each coefficient `c` of `vec` as the single byte `c + eta`.
///
/// Coefficients are expected in `[-eta, eta]`; other values are truncated by the `as u8` cast.
fn encode_small_vec_k(vec: &PolyVecK) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(MLDSA_S2_BYTES);
    encoded.extend(
        vec.polys
            .iter()
            .flat_map(|poly| poly.coeffs.iter())
            .map(|&c| (c + MLDSA_ETA_BOUND) as u8),
    );
    encoded
}
/// Inverse of `encode_small_vec_l`: byte `b` becomes coefficient `b - eta`.
///
/// Byte values are not range-checked.
///
/// # Errors
/// `Error::InvalidLength` unless `bytes.len() == MLDSA_S1_BYTES`.
fn decode_small_vec_l(bytes: &[u8]) -> Result<PolyVecL> {
    if bytes.len() != MLDSA_S1_BYTES {
        return Err(Error::InvalidLength("mldsa s1 bytes length mismatch"));
    }
    let mut decoded = PolyVecL::zero();
    for (poly, chunk) in decoded.polys.iter_mut().zip(bytes.chunks_exact(MLDSA_N)) {
        for (coeff, &b) in poly.coeffs.iter_mut().zip(chunk.iter()) {
            *coeff = i32::from(b) - MLDSA_ETA_BOUND;
        }
    }
    Ok(decoded)
}
/// Inverse of `encode_small_vec_k`: byte `b` becomes coefficient `b - eta`.
///
/// Byte values are not range-checked.
///
/// # Errors
/// `Error::InvalidLength` unless `bytes.len() == MLDSA_S2_BYTES`.
fn decode_small_vec_k(bytes: &[u8]) -> Result<PolyVecK> {
    if bytes.len() != MLDSA_S2_BYTES {
        return Err(Error::InvalidLength("mldsa s2 bytes length mismatch"));
    }
    let mut decoded = PolyVecK::zero();
    for (poly, chunk) in decoded.polys.iter_mut().zip(bytes.chunks_exact(MLDSA_N)) {
        for (coeff, &b) in poly.coeffs.iter_mut().zip(chunk.iter()) {
            *coeff = i32::from(b) - MLDSA_ETA_BOUND;
        }
    }
    Ok(decoded)
}
/// Concatenates the 10-bit packings of the `MLDSA_K` polynomials of `vec`.
fn encode_vec_k_10bit(vec: &PolyVecK) -> Vec<u8> {
    let mut packed = Vec::with_capacity(MLDSA_PUBLIC_T_BYTES);
    vec.polys
        .iter()
        .for_each(|poly| packed.extend_from_slice(&encode_poly_10bit(poly)));
    packed
}
/// Unpacks the `t` vector stored after `rho` in a public key.
///
/// # Errors
/// `Error::InvalidLength` unless `bytes.len() == MLDSA_PUBLIC_T_BYTES`.
fn decode_public_t(bytes: &[u8]) -> Result<PolyVecK> {
    if bytes.len() != MLDSA_PUBLIC_T_BYTES {
        return Err(Error::InvalidLength("mldsa public t bytes length mismatch"));
    }
    let mut t = PolyVecK::zero();
    let chunks = bytes.chunks_exact(MLDSA_POLY_PACKED10_BYTES);
    for (poly, chunk) in t.polys.iter_mut().zip(chunks) {
        *poly = decode_poly_10bit(chunk)?;
    }
    Ok(t)
}
/// Concatenates the 12-bit packings of the `MLDSA_L` polynomials of `vec`.
fn encode_vec_l_12bit(vec: &PolyVecL) -> Vec<u8> {
    let mut packed = Vec::with_capacity(MLDSA_SIGNATURE_Z_BYTES);
    vec.polys
        .iter()
        .for_each(|poly| packed.extend_from_slice(&encode_poly_12bit(poly)));
    packed
}
/// Unpacks the `z` vector at the start of a signature.
///
/// # Errors
/// `Error::InvalidLength` unless `bytes.len() == MLDSA_SIGNATURE_Z_BYTES`.
fn decode_vec_l_12bit(bytes: &[u8]) -> Result<PolyVecL> {
    if bytes.len() != MLDSA_SIGNATURE_Z_BYTES {
        return Err(Error::InvalidLength(
            "mldsa signature z bytes length mismatch",
        ));
    }
    let mut z = PolyVecL::zero();
    let chunks = bytes.chunks_exact(MLDSA_POLY_PACKED12_BYTES);
    for (poly, chunk) in z.polys.iter_mut().zip(chunks) {
        *poly = decode_poly_12bit(chunk)?;
    }
    Ok(z)
}
/// Packs `poly` into `MLDSA_POLY_PACKED10_BYTES` bytes at 10 bits per coefficient.
///
/// Each coefficient is first compressed (lossily) to `floor(1024 * (c mod q) / q)`.
/// Each group of four 10-bit values is then written little-endian into five bytes.
fn encode_poly_10bit(poly: &Poly) -> [u8; MLDSA_POLY_PACKED10_BYTES] {
    let mut out = [0_u8; MLDSA_POLY_PACKED10_BYTES];
    let mut out_idx = 0_usize;
    for chunk in poly.coeffs.chunks_exact(4) {
        let mut t = [0_u16; 4];
        for i in 0..4 {
            // Scale [0, q) down to [0, 1024); the mask only guards the 10-bit width.
            t[i] = ((mod_q(chunk[i]) as i64 * 1024 / i64::from(MLDSA_Q)) & 0x03FF) as u16;
        }
        // Byte layout (low bits first):
        // [t0 0..8] [t0 8..10 | t1 0..6] [t1 6..10 | t2 0..4] [t2 4..10 | t3 0..2] [t3 2..10]
        out[out_idx] = t[0] as u8;
        out[out_idx + 1] = ((t[0] >> 8) as u8) | ((t[1] << 2) as u8);
        out[out_idx + 2] = ((t[1] >> 6) as u8) | ((t[2] << 4) as u8);
        out[out_idx + 3] = ((t[2] >> 4) as u8) | ((t[3] << 6) as u8);
        out[out_idx + 4] = (t[3] >> 2) as u8;
        out_idx += 5;
    }
    out
}
/// Unpacks a polynomial written by `encode_poly_10bit`.
///
/// Each 10-bit value `v` is expanded to `floor(v * q / 1024)`, so an encode/decode
/// round trip only approximates the original coefficients.
///
/// # Errors
/// `Error::InvalidLength` unless `bytes.len() == MLDSA_POLY_PACKED10_BYTES`.
fn decode_poly_10bit(bytes: &[u8]) -> Result<Poly> {
    if bytes.len() != MLDSA_POLY_PACKED10_BYTES {
        return Err(Error::InvalidLength(
            "mldsa 10-bit polynomial length mismatch",
        ));
    }
    let mut out = Poly::zero();
    let mut in_idx = 0_usize;
    // Every 5 input bytes hold four little-endian 10-bit values.
    for i in 0..(MLDSA_N / 4) {
        let b0 = u16::from(bytes[in_idx]);
        let b1 = u16::from(bytes[in_idx + 1]);
        let b2 = u16::from(bytes[in_idx + 2]);
        let b3 = u16::from(bytes[in_idx + 3]);
        let b4 = u16::from(bytes[in_idx + 4]);
        in_idx += 5;
        let t0 = b0 | ((b1 & 0x03) << 8);
        let t1 = (b1 >> 2) | ((b2 & 0x0F) << 6);
        let t2 = (b2 >> 4) | ((b3 & 0x3F) << 4);
        let t3 = (b3 >> 6) | (b4 << 2);
        // Decompress back into [0, q).
        out.coeffs[4 * i] = ((i64::from(t0) * i64::from(MLDSA_Q)) / 1024) as i32;
        out.coeffs[4 * i + 1] = ((i64::from(t1) * i64::from(MLDSA_Q)) / 1024) as i32;
        out.coeffs[4 * i + 2] = ((i64::from(t2) * i64::from(MLDSA_Q)) / 1024) as i32;
        out.coeffs[4 * i + 3] = ((i64::from(t3) * i64::from(MLDSA_Q)) / 1024) as i32;
    }
    Ok(out)
}
/// Packs `poly` into `MLDSA_POLY_PACKED12_BYTES` bytes, two 12-bit values per 3 bytes.
///
/// Only the low 12 bits of `c mod q` are stored.
/// NOTE(review): any reduced coefficient >= 4096 loses its high bits and cannot be
/// recovered by `decode_poly_12bit`.
fn encode_poly_12bit(poly: &Poly) -> [u8; MLDSA_POLY_PACKED12_BYTES] {
    let mut out = [0_u8; MLDSA_POLY_PACKED12_BYTES];
    let mut out_idx = 0_usize;
    for chunk in poly.coeffs.chunks_exact(2) {
        let c0 = (mod_q(chunk[0]) & 0x0FFF) as u16;
        let c1 = (mod_q(chunk[1]) & 0x0FFF) as u16;
        // Layout: [c0 0..8] [c0 8..12 | c1 0..4] [c1 4..12]
        out[out_idx] = (c0 & 0xFF) as u8;
        out[out_idx + 1] = ((c0 >> 8) as u8) | (((c1 & 0x0F) as u8) << 4);
        out[out_idx + 2] = (c1 >> 4) as u8;
        out_idx += 3;
    }
    out
}
/// Unpacks a polynomial written by `encode_poly_12bit`; coefficients land in `[0, 4096)`.
///
/// # Errors
/// `Error::InvalidLength` unless `bytes.len() == MLDSA_POLY_PACKED12_BYTES`.
fn decode_poly_12bit(bytes: &[u8]) -> Result<Poly> {
    if bytes.len() != MLDSA_POLY_PACKED12_BYTES {
        return Err(Error::InvalidLength(
            "mldsa 12-bit polynomial length mismatch",
        ));
    }
    let mut out = Poly::zero();
    let mut in_idx = 0_usize;
    // Every 3 input bytes hold two little-endian 12-bit values.
    for i in 0..(MLDSA_N / 2) {
        let b0 = u16::from(bytes[in_idx]);
        let b1 = u16::from(bytes[in_idx + 1]);
        let b2 = u16::from(bytes[in_idx + 2]);
        in_idx += 3;
        out.coeffs[2 * i] = i32::from((b0 | ((b1 & 0x0F) << 8)) & 0x0FFF);
        out.coeffs[2 * i + 1] = i32::from(((b1 >> 4) | (b2 << 4)) & 0x0FFF);
    }
    Ok(out)
}
/// Compresses each coefficient to `floor(16 * (c mod q) / q)` (4 bits).
///
/// Two coefficients are packed per byte, with the even-indexed one in the low nibble.
fn compress_vec_k_hint(vec: &PolyVecK) -> Vec<u8> {
    let nibble = |c: i32| (((mod_q(c) * 16) / MLDSA_Q) & 0x0F) as u8;
    let mut packed = Vec::with_capacity(MLDSA_K * MLDSA_N / 2);
    for poly in vec.polys.iter() {
        packed.extend(
            poly.coeffs
                .chunks_exact(2)
                .map(|pair| nibble(pair[0]) | (nibble(pair[1]) << 4)),
        );
    }
    packed
}
/// Builds `out_len` hint bytes.
///
/// The first `min(w1_bytes.len(), out_len)` bytes copy `w1_bytes`. Any remainder is
/// `expand_seed_bytes("mldsa-hints", w1_bytes || z_bytes || pack10(t) || c)`.
fn derive_hint_bytes_from_signature(
    w1_bytes: &[u8],
    z_bytes: &[u8],
    t: &PolyVecK,
    c: &[u8],
    out_len: usize,
) -> Vec<u8> {
    let prefix_len = w1_bytes.len().min(out_len);
    let mut hints = Vec::with_capacity(out_len);
    hints.extend_from_slice(&w1_bytes[..prefix_len]);
    if out_len > prefix_len {
        let packed_t = encode_vec_k_10bit(t);
        let mut seed =
            Vec::with_capacity(w1_bytes.len() + z_bytes.len() + packed_t.len() + c.len());
        seed.extend_from_slice(w1_bytes);
        seed.extend_from_slice(z_bytes);
        seed.extend_from_slice(&packed_t);
        seed.extend_from_slice(c);
        let expanded = expand_seed_bytes(b"mldsa-hints", &seed, out_len - prefix_len);
        hints.extend_from_slice(&expanded);
    }
    hints
}
/// Derives the 32-byte challenge seed as `shake256(domain || mu || w1_bytes)`.
fn build_challenge(mu: &[u8; 32], w1_bytes: &[u8]) -> [u8; 32] {
    let mut transcript = Vec::with_capacity(1 + mu.len() + w1_bytes.len());
    transcript.push(MLDSA_XOF_DOMAIN_CHALLENGE);
    transcript.extend_from_slice(mu);
    transcript.extend_from_slice(w1_bytes);
    let mut challenge = [0_u8; 32];
    challenge.copy_from_slice(&shake256(&transcript, 32));
    challenge
}
/// Expands the challenge seed `c` into a polynomial with exactly
/// `MLDSA_CHALLENGE_NONZERO_TERMS` coefficients equal to ±1 and the rest 0.
///
/// The XOF stream `shake256(domain || c)` is read in 3-byte groups: the first byte
/// (mod N) picks an index and the low bit of the third byte picks the sign. Groups
/// hitting an already-set index are skipped.
///
/// When the stream runs out, a longer stream is squeezed and reading continues at
/// the same offset. Wrapping back to offset 0 would replay indices that are already
/// set, which could never finish and would loop forever.
fn challenge_poly_from_digest(c: &[u8; 32]) -> Poly {
    let mut poly = Poly::zero();
    let mut seed = Vec::with_capacity(1 + c.len());
    seed.push(MLDSA_XOF_DOMAIN_CHALLENGE);
    seed.extend_from_slice(c);
    let mut stream_len = 4 * MLDSA_CHALLENGE_NONZERO_TERMS;
    let mut stream = shake256(&seed, stream_len);
    let mut cursor = 0_usize;
    let mut placed = 0_usize;
    while placed < MLDSA_CHALLENGE_NONZERO_TERMS {
        if cursor + 3 > stream.len() {
            // Assuming `shake256` is a standard XOF, the longer output begins with
            // the bytes already consumed, so earlier draws are unchanged.
            stream_len *= 2;
            stream = shake256(&seed, stream_len);
        }
        let idx_word = u16::from(stream[cursor]) | (u16::from(stream[cursor + 1]) << 8);
        let idx = usize::from(idx_word) % MLDSA_N;
        let sign = if (stream[cursor + 2] & 1) == 0 { 1 } else { -1 };
        cursor += 3;
        if poly.coeffs[idx] == 0 {
            poly.coeffs[idx] = sign;
            placed += 1;
        }
    }
    poly
}
fn array32_from_slice(bytes: &[u8]) -> Result<[u8; 32]> {
bytes
.try_into()
.map_err(|_| Error::InvalidLength("mldsa expected 32-byte slice"))
}
fn mod_q(value: i32) -> i32 {
let mut v = value % MLDSA_Q;
if v < 0 {
v += MLDSA_Q;
}
v
}
/// Reduces a 64-bit `value` into `[0, q)` and narrows it to `i32` (lossless, since `q < 2^31`).
fn mod_q_i64(value: i64) -> i32 {
    value.rem_euclid(i64::from(MLDSA_Q)) as i32
}
/// Limits `value` to `[min_v, max_v]`.
///
/// If the bounds cross, `min_v` wins for values below it. Unlike `Ord::clamp`, this
/// never panics.
fn clamp(value: i32, min_v: i32, max_v: i32) -> i32 {
    if value < min_v {
        return min_v;
    }
    value.min(max_v)
}
/// Absolute value of the representative of `value mod q` lying in `[-q/2, q/2]`.
fn centered_abs(value: i32) -> i32 {
    let reduced = mod_q(value);
    let half = MLDSA_Q / 2;
    if reduced <= half {
        reduced
    } else {
        MLDSA_Q - reduced
    }
}
/// Largest `centered_abs` over every coefficient of `vec` (0 for the zero vector).
fn max_abs_vec_l(vec: &PolyVecL) -> i32 {
    vec.polys
        .iter()
        .flat_map(|poly| poly.coeffs.iter())
        .map(|&coeff| centered_abs(coeff))
        .fold(0, i32::max)
}
/// Largest `centered_abs` over every coefficient of `vec` (0 for the zero vector).
fn max_abs_vec_k(vec: &PolyVecK) -> i32 {
    vec.polys
        .iter()
        .flat_map(|poly| poly.coeffs.iter())
        .map(|&coeff| centered_abs(coeff))
        .fold(0, i32::max)
}
/// Domain-separated 32-byte hash: `shake256(0x12 || label || seed, 32)`.
fn derive_hash32(label: &[u8], seed: &[u8]) -> [u8; 32] {
    let input: Vec<u8> = [MLDSA_XOF_DOMAIN_HASH32]
        .iter()
        .chain(label)
        .chain(seed)
        .copied()
        .collect();
    let digest = shake256(&input, 32);
    let mut hash = [0_u8; 32];
    hash.copy_from_slice(&digest);
    hash
}
/// Domain-separated XOF expansion: `shake256(0x11 || label || seed, out_len)`.
fn expand_seed_bytes(label: &[u8], seed: &[u8], out_len: usize) -> Vec<u8> {
    let input: Vec<u8> = [MLDSA_XOF_DOMAIN_EXPAND]
        .iter()
        .chain(label)
        .chain(seed)
        .copied()
        .collect();
    shake256(&input, out_len)
}