use std::path::PathBuf;
use ml_dsa::{
EncodedSignature, EncodedVerifyingKey, Keypair, MlDsa65, Signature as MlDsaSignature,
SigningKey as MlDsaSigningKey, VerifyingKey as MlDsaVerifyingKey, B32 as DsaB32,
};
use ml_kem::kem::Decapsulate;
use ml_kem::{
Ciphertext, EncapsulateDeterministic, Encoded, EncodedSizeUser, KemCore, MlKem768, B32,
};
use serde_json::Value;
/// Encapsulation-key type of ML-KEM-768, as declared by its `KemCore` impl.
type EncapKey = <MlKem768 as KemCore>::EncapsulationKey;
/// Decapsulation-key type of ML-KEM-768, as declared by its `KemCore` impl.
type DecapKey = <MlKem768 as KemCore>::DecapsulationKey;
/// Reads `tests/nist_kat/<name>` (relative to the crate's manifest directory) and
/// parses it as a JSON document.
///
/// # Panics
/// Panics if the file cannot be read or its contents are not valid JSON.
fn load(name: &str) -> Value {
    let mut fixture = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    fixture.push("tests/nist_kat");
    fixture.push(name);
    let raw = match std::fs::read(&fixture) {
        Ok(raw) => raw,
        Err(e) => panic!("read {name}: {e}"),
    };
    match serde_json::from_slice(&raw) {
        Ok(doc) => doc,
        Err(e) => panic!("parse {name}: {e}"),
    }
}
/// Flattens a `testGroups[*].tests[*]` document into `(group, test)` pairs, in file
/// order, so each test can still consult fields of its enclosing group.
///
/// # Panics
/// Panics if `testGroups`, or any group's `tests`, is not a JSON array.
fn cases(doc: &Value) -> Vec<(&Value, &Value)> {
    let groups = doc["testGroups"].as_array().expect("testGroups");
    groups
        .iter()
        .flat_map(|group| {
            group["tests"]
                .as_array()
                .expect("tests")
                .iter()
                .map(move |test| (group, test))
        })
        .collect()
}
/// Hex-decodes the string field `key` of the JSON object `v`.
///
/// # Panics
/// Panics if the field is absent or not a string, or if it is not valid hex.
fn hx(v: &Value, key: &str) -> Vec<u8> {
    let Some(text) = v[key].as_str() else {
        panic!("missing hex field {key}");
    };
    match hex::decode(text) {
        Ok(bytes) => bytes,
        Err(e) => panic!("bad hex in {key}: {e}"),
    }
}
fn b32(bytes: &[u8]) -> B32 {
B32::try_from(bytes).expect("expected 32 bytes")
}
fn dsa_b32(bytes: &[u8]) -> DsaB32 {
DsaB32::try_from(bytes).expect("expected 32 bytes")
}
/// Checks deterministic ML-KEM-768 key generation against every case in the keyGen
/// KAT file: the seeds `d` and `z` must reproduce `ek` and `dk` byte-for-byte.
#[test]
fn ml_kem_768_keygen_kat() {
    let doc = load("ml_kem_768_keygen.json");
    let all = cases(&doc);
    for &(_, tc) in &all {
        let d = hx(tc, "d");
        let z = hx(tc, "z");
        let (dk, ek) = MlKem768::generate_deterministic(&b32(&d), &b32(&z));
        let id = &tc["tcId"];
        assert_eq!(
            ek.as_bytes().as_slice(),
            hx(tc, "ek").as_slice(),
            "ML-KEM-768 keyGen ek mismatch (tcId {})",
            id
        );
        assert_eq!(
            dk.as_bytes().as_slice(),
            hx(tc, "dk").as_slice(),
            "ML-KEM-768 keyGen dk mismatch (tcId {})",
            id
        );
    }
    let n = all.len();
    assert!(n >= 3, "expected several keyGen cases, got {n}");
}
/// Checks deterministic ML-KEM-768 encapsulation against the `encapsulation` groups
/// of the encap/decap KAT file: encapsulating with `m` under `ek` must reproduce the
/// ciphertext `c` and shared key `k` exactly.
#[test]
fn ml_kem_768_encaps_kat() {
    let doc = load("ml_kem_768_encap_decap.json");
    let mut checked = 0;
    let encaps_cases = cases(&doc)
        .into_iter()
        .filter(|(group, _)| group["function"] == "encapsulation");
    for (_, tc) in encaps_cases {
        let ek_bytes = hx(tc, "ek");
        let ek_encoded = Encoded::<EncapKey>::try_from(ek_bytes.as_slice()).expect("ek size");
        let ek = EncapKey::from_bytes(&ek_encoded);
        let m = b32(&hx(tc, "m"));
        let (ct, shared) = ek
            .encapsulate_deterministic(&m)
            .expect("ML-KEM-768 deterministic encaps");
        let id = &tc["tcId"];
        assert_eq!(
            ct.as_slice(),
            hx(tc, "c").as_slice(),
            "encaps c (tcId {})",
            id
        );
        assert_eq!(
            shared.as_slice(),
            hx(tc, "k").as_slice(),
            "encaps k (tcId {})",
            id
        );
        checked += 1;
    }
    assert!(checked >= 1, "expected encaps cases");
}
/// Checks ML-KEM-768 decapsulation against the `decapsulation` groups of the
/// encap/decap KAT file: decapsulating ciphertext `c` under `dk` must yield exactly
/// the shared key `k`.
///
/// `dk` is read from the test case when present, otherwise from the enclosing group.
/// The NIST ACVP encapDecap JSON format carries the decapsulation key at group level
/// for `decapsulation` groups, so this accepts unflattened ACVP files. Files with a
/// per-test `dk` behave exactly as before.
#[test]
fn ml_kem_768_decaps_kat() {
    let doc = load("ml_kem_768_encap_decap.json");
    let mut n = 0;
    for (g, t) in cases(&doc) {
        if g["function"] != "decapsulation" {
            continue;
        }
        let dk_hex = if t.get("dk").is_some() {
            hx(t, "dk")
        } else {
            hx(g, "dk")
        };
        let dk_enc = Encoded::<DecapKey>::try_from(dk_hex.as_slice()).expect("dk size");
        let dk = DecapKey::from_bytes(&dk_enc);
        let ct = Ciphertext::<MlKem768>::try_from(hx(t, "c").as_slice()).expect("ct size");
        let k = dk.decapsulate(&ct).expect("ML-KEM-768 decaps");
        assert_eq!(
            k.as_slice(),
            hx(t, "k").as_slice(),
            "decaps k (tcId {})",
            t["tcId"]
        );
        n += 1;
    }
    assert!(n >= 1, "expected decaps VAL cases");
}
/// Sanity check that the keyGen comparison is sensitive to a single-bit change.
///
/// The generated `dk` must first equal the expected value. Without that guard, the
/// final `assert_ne!` would pass trivially whenever key generation was already
/// wrong, and the test would prove nothing.
#[test]
fn ml_kem_768_kat_tamper_is_caught() {
    let doc = load("ml_kem_768_keygen.json");
    let all = cases(&doc);
    let (_, t) = *all.first().expect("expected at least one keyGen case");
    let (dk, _ek) = MlKem768::generate_deterministic(&b32(&hx(t, "d")), &b32(&hx(t, "z")));
    let expected = hx(t, "dk");
    assert_eq!(
        dk.as_bytes().as_slice(),
        expected.as_slice(),
        "untampered dk must match before the tamper check is meaningful (tcId {})",
        t["tcId"]
    );
    let mut tampered = expected;
    tampered[0] ^= 0x01;
    assert_ne!(
        dk.as_bytes().as_slice(),
        tampered.as_slice(),
        "a tampered expected dk must not match"
    );
}
/// Checks ML-DSA-65 key generation from a 32-byte seed against every case in the
/// keyGen KAT file: the encoded verifying key must equal `pk` byte-for-byte.
#[test]
fn ml_dsa_65_keygen_pk_kat() {
    let doc = load("ml_dsa_65_keygen.json");
    let all = cases(&doc);
    for &(_, tc) in &all {
        let seed = dsa_b32(&hx(tc, "seed"));
        let sk = MlDsaSigningKey::<MlDsa65>::from_seed(&seed);
        let vk = sk.verifying_key();
        assert_eq!(
            vk.encode().as_slice(),
            hx(tc, "pk").as_slice(),
            "ML-DSA-65 keyGen pk mismatch (tcId {})",
            tc["tcId"]
        );
    }
    let n = all.len();
    assert!(n >= 3, "expected several keyGen cases, got {n}");
}
/// Verifies NIST ML-DSA-65 sigGen signatures with pure (non-prehash) external-interface
/// verification, and checks that a one-bit corruption of each signature is rejected.
///
/// Groups explicitly marked as something `verify_with_context` cannot check are
/// skipped:
/// - `signatureInterface` other than `"external"` (the internal interface has no context);
/// - `preHash` other than `"pure"` (HashML-DSA);
/// - `externalMu: true` (the message field is a precomputed mu).
///
/// Groups lacking these fields are treated as applicable, so curated files behave as
/// before. `pk` is taken from the test case when present, otherwise from its group.
#[test]
fn ml_dsa_65_verify_kat() {
    let doc = load("ml_dsa_65_siggen.json");
    let mut n = 0;
    for (g, t) in cases(&doc) {
        // A missing field means "no restriction"; only an explicit mismatch skips.
        let group_is = |field: &str, wanted: &str| match g.get(field).and_then(Value::as_str) {
            Some(v) => v == wanted,
            None => true,
        };
        let external_mu = g.get("externalMu").and_then(Value::as_bool) == Some(true);
        if !group_is("signatureInterface", "external") || !group_is("preHash", "pure") || external_mu
        {
            continue;
        }
        let pk_hex = if t.get("pk").is_some() {
            hx(t, "pk")
        } else {
            hx(g, "pk")
        };
        let pk_enc =
            EncodedVerifyingKey::<MlDsa65>::try_from(pk_hex.as_slice()).expect("pk size");
        let vk = MlDsaVerifyingKey::<MlDsa65>::decode(&pk_enc);
        let msg = hx(t, "message");
        let ctx = hx(t, "context");
        let sig_bytes = hx(t, "signature");
        let sig_enc =
            EncodedSignature::<MlDsa65>::try_from(sig_bytes.as_slice()).expect("sig size");
        let sig = MlDsaSignature::<MlDsa65>::decode(&sig_enc).expect("decode NIST signature");
        assert!(
            vk.verify_with_context(&msg, &ctx, &sig),
            "ML-DSA-65 must accept the real NIST signature (tcId {})",
            t["tcId"]
        );
        // `sig_enc` owns its copy of the bytes, so the original buffer can be reused.
        let mut bad = sig_bytes;
        bad[0] ^= 0x01;
        // A corrupted signature may fail to decode, or decode and fail to verify;
        // either counts as rejection.
        let accepted = EncodedSignature::<MlDsa65>::try_from(bad.as_slice())
            .ok()
            .and_then(|e| MlDsaSignature::<MlDsa65>::decode(&e))
            .is_some_and(|s| vk.verify_with_context(&msg, &ctx, &s));
        assert!(
            !accepted,
            "a tampered NIST signature must be rejected (tcId {})",
            t["tcId"]
        );
        n += 1;
    }
    assert!(n >= 1, "expected sigGen verify cases");
}