Skip to main content

qrllib/
mlkem.rs

1//! ML-KEM-1024 key encapsulation mechanism (FIPS 203).
2//!
3//! This is a faithful port of `go-qrllib`'s `crypto/internal/mlkem1024`
4//! (the IND-CCA-secure construction over K-PKE) and its `crypto/mlkem1024`
5//! public wrapper. ML-KEM-1024 targets NIST security category 5.
6//!
7//! # API
8//!
9//! - [`DecapsulationKey`] is the private key. Generate one with
10//!   [`DecapsulationKey::generate`] (fresh randomness) or restore it
11//!   deterministically from a 64-byte `d || z` seed with
12//!   [`DecapsulationKey::from_seed`].
13//! - [`EncapsulationKey`] is the public key, obtained via
14//!   [`DecapsulationKey::encapsulation_key`] or decoded from its bytes with
15//!   [`EncapsulationKey::from_bytes`].
16//! - [`EncapsulationKey::encapsulate`] produces a `(shared_key, ciphertext)`
17//!   pair; [`DecapsulationKey::decapsulate`] recovers the shared key from the
18//!   ciphertext.
19//!
20//! Shared secrets are returned as [`Zeroizing`] byte arrays, and the
21//! decapsulation key zeroizes its secret material on drop.
22//!
23//! # Determinism for test vectors
24//!
25//! [`EncapsulationKey::encapsulate_deterministic`] takes the 32-byte message
26//! explicitly instead of drawing it from the system RNG. It exists only to
27//! reproduce ACVP / Wycheproof / cross-verification vectors — production code
28//! must use [`EncapsulationKey::encapsulate`], because reusing a message
29//! across encapsulations breaks IND-CCA security.
30
31use crate::error::{QrllibError, Result};
32use sha3::{
33    Digest, Sha3_256, Sha3_512,
34    digest::{ExtendableOutput, Update, XofReader},
35};
36use shake::{Shake128, Shake256};
37use zeroize::{Zeroize, Zeroizing};
38
39// ---------------------------------------------------------------------------
40// Public parameters
41// ---------------------------------------------------------------------------
42
/// Size in bytes of the `d || z` seed that deterministically generates an
/// ML-KEM-1024 decapsulation key: the 32-byte key-generation seed `d`
/// followed by the 32-byte implicit-rejection seed `z`.
pub const MLKEM1024_SEED_SIZE: usize = 64;

/// Size in bytes of an ML-KEM-1024 shared secret.
pub const MLKEM1024_SHARED_KEY_SIZE: usize = 32;

/// Size in bytes of an ML-KEM-1024 ciphertext (1568 with `K = 4`).
pub const MLKEM1024_CIPHERTEXT_SIZE: usize = K * ENCODING_SIZE_11 + ENCODING_SIZE_5;

/// Size in bytes of an encoded ML-KEM-1024 encapsulation (public) key: `K`
/// 12-bit-encoded polynomials followed by the 32-byte matrix seed (1600 bytes
/// in total).
pub const MLKEM1024_ENCAPSULATION_KEY_SIZE: usize = K * ENCODING_SIZE_12 + 32;
55
56// ---------------------------------------------------------------------------
57// Internal parameters
58// ---------------------------------------------------------------------------
59
// ML-KEM global parameters: N coefficients per polynomial, prime modulus Q.
const N: usize = 256;
const Q: u16 = 3329;

// ML-KEM-1024 module rank (vectors hold K polynomials, the matrix K * K).
const K: usize = 4;

// Byte lengths of ByteEncode_d(f) output (FIPS 203, Algorithm 5): N
// coefficients of d bits each, i.e. 32, 160, 352 and 384 bytes.
const ENCODING_SIZE_1: usize = N / 8;
const ENCODING_SIZE_5: usize = N * 5 / 8;
const ENCODING_SIZE_11: usize = N * 11 / 8;
const ENCODING_SIZE_12: usize = N * 12 / 8;

// ML-KEM messages are 32-byte values encoded as ByteEncode_1(m).
const MESSAGE_SIZE: usize = ENCODING_SIZE_1;

// Bit widths used by the 5-bit and 11-bit compress/decompress routines.
const D5: u8 = 5;
const D11: u8 = 11;

const HALF_Q_ROUNDED_UP: u16 = Q.div_ceil(2); // Decompress_1(1) == (q + 1) / 2
// Size in bytes of the buffer `sample_ntt` refills from the SHAKE128 stream
// (168 bytes is the SHAKE128 rate, and a multiple of the 3-byte sample group).
const SHAKE128_RATE: usize = 168;
81
82// ---------------------------------------------------------------------------
83// Field arithmetic (Z_q, q = 3329)
84// ---------------------------------------------------------------------------
85
86/// Reduces a value in `[0, 2q)` to `[0, q)`.
87fn field_reduce_once(a: u16) -> u16 {
88    let x = a.wrapping_sub(Q);
89    // Add q back iff the subtraction went negative (top bit set after wrap).
90    x.wrapping_add(Q & (x >> 15).wrapping_neg())
91}
92
93fn field_add(a: u16, b: u16) -> u16 {
94    field_reduce_once(a.wrapping_add(b))
95}
96
97fn field_sub(a: u16, b: u16) -> u16 {
98    field_reduce_once(a.wrapping_sub(b).wrapping_add(Q))
99}
100
// Barrett reduction constants: each multiplier is floor(2^shift / q), so
// `(a * multiplier) >> shift` estimates `a / q` without a division.
const BARRETT_MULTIPLIER: u64 = 5039; // floor(2^24 / 3329)
const BARRETT_SHIFT: u32 = 24;
const BARRETT_WIDE_MULTIPLIER: u64 = 1_290_167; // floor(2^32 / 3329)
const BARRETT_WIDE_SHIFT: u32 = 32;
105
106fn field_reduce(a: u32) -> u16 {
107    let quotient = ((a as u64 * BARRETT_MULTIPLIER) >> BARRETT_SHIFT) as u32;
108    field_reduce_once(a.wrapping_sub(quotient.wrapping_mul(Q as u32)) as u16)
109}
110
111/// Reduces lazy products and accumulators that do not fit the 24-bit Barrett
112/// reducer. Current callers stay below about `8*q*q`.
113fn field_reduce_wide(a: u32) -> u16 {
114    let quotient = ((a as u64 * BARRETT_WIDE_MULTIPLIER) >> BARRETT_WIDE_SHIFT) as u32;
115    field_reduce_once(a.wrapping_sub(quotient.wrapping_mul(Q as u32)) as u16)
116}
117
118fn field_mul(a: u16, b: u16) -> u16 {
119    field_reduce(a as u32 * b as u32)
120}
121
122fn field_mul_wide(a: u16, b: u16) -> u16 {
123    field_reduce_wide(a as u32 * b as u32)
124}
125
126fn field_mul_sub(a: u16, b: u16, c: u16) -> u16 {
127    let x = a as u32 * b.wrapping_sub(c).wrapping_add(Q) as u32;
128    field_reduce(x)
129}
130
131const COMPRESS1_LOWER: u32 = (Q as u32).div_ceil(4); // ceil(q/4) == (q + 3) / 4
132const COMPRESS1_UPPER: u32 = (3 * Q as u32) / 4; // floor(3q/4)
133
134fn compress1(x: u16) -> u8 {
135    let ux = x as u32;
136    let ge_lower = ((ux.wrapping_sub(COMPRESS1_LOWER)) >> 31) ^ 1;
137    let le_upper = ((COMPRESS1_UPPER.wrapping_sub(ux)) >> 31) ^ 1;
138    (ge_lower & le_upper) as u8
139}
140
141fn compress5(x: u16) -> u16 {
142    let dividend = (x as u32) << D5;
143    let mut quotient = ((dividend as u64 * BARRETT_MULTIPLIER) >> BARRETT_SHIFT) as u32;
144    let remainder = dividend.wrapping_sub(quotient.wrapping_mul(Q as u32));
145    quotient = quotient.wrapping_add(((Q as u32 / 2).wrapping_sub(remainder) >> 31) & 1);
146    quotient = quotient.wrapping_add(((Q as u32 + Q as u32 / 2).wrapping_sub(remainder) >> 31) & 1);
147    (quotient & 0x1f) as u16
148}
149
150fn compress11(x: u16) -> u16 {
151    let dividend = (x as u32) << D11;
152    let mut quotient = ((dividend as u64 * BARRETT_MULTIPLIER) >> BARRETT_SHIFT) as u32;
153    let remainder = dividend.wrapping_sub(quotient.wrapping_mul(Q as u32));
154    quotient = quotient.wrapping_add(((Q as u32 / 2).wrapping_sub(remainder) >> 31) & 1);
155    quotient = quotient.wrapping_add(((Q as u32 + Q as u32 / 2).wrapping_sub(remainder) >> 31) & 1);
156    (quotient & 0x7ff) as u16
157}
158
159fn decompress(y: u16, d: u8) -> u16 {
160    let dividend = (y as u32) * (Q as u32);
161    let mut quotient = dividend >> d;
162    quotient += (dividend >> (d - 1)) & 1;
163    quotient as u16
164}
165
166// ---------------------------------------------------------------------------
167// Ring elements and (de)serialisation
168// ---------------------------------------------------------------------------
169
/// A polynomial of degree at most 255 in `Z_q[X]/(X^256 + 1)`, stored as `N`
/// `u16` coefficients. The same type carries both coefficient-domain and
/// NTT-domain values.
type RingElement = [u16; N];
172
173fn new_ring() -> RingElement {
174    [0u16; N]
175}
176
177fn ring_decode_and_decompress1(dst: &mut RingElement, src: &[u8]) {
178    for (i, slot) in dst.iter_mut().enumerate() {
179        // Decode one message bit, so the result is either 0 or 1; since q is
180        // odd, Decompress_1 maps 1 to (q+1)/2, rounding q/2 up.
181        let b = (src[i / 8] >> (i % 8)) & 1;
182        *slot = (b as u16) * HALF_Q_ROUNDED_UP;
183    }
184}
185
186fn ring_decode_and_decompress5(dst: &mut RingElement, src: &[u8]) {
187    let mut i = 0usize;
188    let mut off = 0usize;
189    while i < N {
190        let b0 = src[off] as u16;
191        let b1 = src[off + 1] as u16;
192        let b2 = src[off + 2] as u16;
193        let b3 = src[off + 3] as u16;
194        let b4 = src[off + 4] as u16;
195
196        dst[i] = decompress(b0 & 0x1f, D5);
197        dst[i + 1] = decompress((b0 >> 5 | b1 << 3) & 0x1f, D5);
198        dst[i + 2] = decompress((b1 >> 2) & 0x1f, D5);
199        dst[i + 3] = decompress((b1 >> 7 | b2 << 1) & 0x1f, D5);
200        dst[i + 4] = decompress((b2 >> 4 | b3 << 4) & 0x1f, D5);
201        dst[i + 5] = decompress((b3 >> 1) & 0x1f, D5);
202        dst[i + 6] = decompress((b3 >> 6 | b4 << 2) & 0x1f, D5);
203        dst[i + 7] = decompress((b4 >> 3) & 0x1f, D5);
204
205        i += 8;
206        off += 5;
207    }
208}
209
210fn ring_decode_and_decompress11(dst: &mut RingElement, src: &[u8]) {
211    let mut i = 0usize;
212    let mut off = 0usize;
213    while i < N {
214        let b0 = src[off] as u32;
215        let b1 = src[off + 1] as u32;
216        let b2 = src[off + 2] as u32;
217        let b3 = src[off + 3] as u32;
218        let b4 = src[off + 4] as u32;
219        let b5 = src[off + 5] as u32;
220        let b6 = src[off + 6] as u32;
221        let b7 = src[off + 7] as u32;
222        let b8 = src[off + 8] as u32;
223        let b9 = src[off + 9] as u32;
224        let b10 = src[off + 10] as u32;
225
226        dst[i] = decompress(((b0 | b1 << 8) & 0x7ff) as u16, D11);
227        dst[i + 1] = decompress(((b1 >> 3 | b2 << 5) & 0x7ff) as u16, D11);
228        dst[i + 2] = decompress(((b2 >> 6 | b3 << 2 | b4 << 10) & 0x7ff) as u16, D11);
229        dst[i + 3] = decompress(((b4 >> 1 | b5 << 7) & 0x7ff) as u16, D11);
230        dst[i + 4] = decompress(((b5 >> 4 | b6 << 4) & 0x7ff) as u16, D11);
231        dst[i + 5] = decompress(((b6 >> 7 | b7 << 1 | b8 << 9) & 0x7ff) as u16, D11);
232        dst[i + 6] = decompress(((b8 >> 2 | b9 << 6) & 0x7ff) as u16, D11);
233        dst[i + 7] = decompress(((b9 >> 5 | b10 << 3) & 0x7ff) as u16, D11);
234
235        i += 8;
236        off += 11;
237    }
238}
239
240fn ring_compress_and_encode1(dst: &mut [u8], src: &RingElement) {
241    let mut i = 0usize;
242    let mut off = 0usize;
243    while i < N {
244        let c0 = compress1(src[i]);
245        let c1 = compress1(src[i + 1]);
246        let c2 = compress1(src[i + 2]);
247        let c3 = compress1(src[i + 3]);
248        let c4 = compress1(src[i + 4]);
249        let c5 = compress1(src[i + 5]);
250        let c6 = compress1(src[i + 6]);
251        let c7 = compress1(src[i + 7]);
252
253        dst[off] = c0 | c1 << 1 | c2 << 2 | c3 << 3 | c4 << 4 | c5 << 5 | c6 << 6 | c7 << 7;
254
255        i += 8;
256        off += 1;
257    }
258}
259
260fn ring_compress_and_encode5(dst: &mut [u8], src: &RingElement) {
261    let mut i = 0usize;
262    let mut off = 0usize;
263    while i < N {
264        let c0 = compress5(src[i]);
265        let c1 = compress5(src[i + 1]);
266        let c2 = compress5(src[i + 2]);
267        let c3 = compress5(src[i + 3]);
268        let c4 = compress5(src[i + 4]);
269        let c5 = compress5(src[i + 5]);
270        let c6 = compress5(src[i + 6]);
271        let c7 = compress5(src[i + 7]);
272
273        dst[off] = (c0 | c1 << 5) as u8;
274        dst[off + 1] = (c1 >> 3 | c2 << 2 | c3 << 7) as u8;
275        dst[off + 2] = (c3 >> 1 | c4 << 4) as u8;
276        dst[off + 3] = (c4 >> 4 | c5 << 1 | c6 << 6) as u8;
277        dst[off + 4] = (c6 >> 2 | c7 << 3) as u8;
278
279        i += 8;
280        off += 5;
281    }
282}
283
284fn ring_compress_and_encode11(dst: &mut [u8], src: &RingElement) {
285    let mut i = 0usize;
286    let mut off = 0usize;
287    while i < N {
288        let c0 = compress11(src[i]) as u32;
289        let c1 = compress11(src[i + 1]) as u32;
290        let c2 = compress11(src[i + 2]) as u32;
291        let c3 = compress11(src[i + 3]) as u32;
292        let c4 = compress11(src[i + 4]) as u32;
293        let c5 = compress11(src[i + 5]) as u32;
294        let c6 = compress11(src[i + 6]) as u32;
295        let c7 = compress11(src[i + 7]) as u32;
296
297        dst[off] = c0 as u8;
298        dst[off + 1] = (c0 >> 8 | c1 << 3) as u8;
299        dst[off + 2] = (c1 >> 5 | c2 << 6) as u8;
300        dst[off + 3] = (c2 >> 2) as u8;
301        dst[off + 4] = (c2 >> 10 | c3 << 1) as u8;
302        dst[off + 5] = (c3 >> 7 | c4 << 4) as u8;
303        dst[off + 6] = (c4 >> 4 | c5 << 7) as u8;
304        dst[off + 7] = (c5 >> 1) as u8;
305        dst[off + 8] = (c5 >> 9 | c6 << 2) as u8;
306        dst[off + 9] = (c6 >> 6 | c7 << 5) as u8;
307        dst[off + 10] = (c7 >> 3) as u8;
308
309        i += 8;
310        off += 11;
311    }
312}
313
314fn byte_encode12(dst: &mut [u8], p: &RingElement) {
315    let mut i = 0usize;
316    let mut off = 0usize;
317    while i < N {
318        let x = (p[i] as u32) | (p[i + 1] as u32) << 12;
319        dst[off] = x as u8;
320        dst[off + 1] = (x >> 8) as u8;
321        dst[off + 2] = (x >> 16) as u8;
322        i += 2;
323        off += 3;
324    }
325}
326
327fn byte_decode12(dst: &mut RingElement, src: &[u8]) -> Result<()> {
328    let mut i = 0usize;
329    let mut off = 0usize;
330    while i < N {
331        let x = (src[off] as u32) | (src[off + 1] as u32) << 8 | (src[off + 2] as u32) << 16;
332        let c0 = (x & 0x0fff) as u16;
333        let c1 = (x >> 12) as u16;
334        if c0 >= Q || c1 >= Q {
335            return Err(QrllibError::InvalidMlKemEncoding);
336        }
337        dst[i] = c0;
338        dst[i + 1] = c1;
339        i += 2;
340        off += 3;
341    }
342    Ok(())
343}
344
345// ---------------------------------------------------------------------------
346// Sampling
347// ---------------------------------------------------------------------------
348
349/// Samples the NTT-domain matrix entry `A[i,j]` from `SHAKE128(rho || j || i)`.
350fn sample_ntt(dst: &mut RingElement, rho: &[u8; 32], j_index: u8, i_index: u8) {
351    let mut ctx = Shake128::default();
352    ctx.update(rho);
353    ctx.update(&[j_index, i_index]);
354    let mut reader = ctx.finalize_xof();
355
356    let mut j = 0usize;
357    let mut buf = [0u8; SHAKE128_RATE];
358    let mut off = buf.len();
359
360    loop {
361        if off >= buf.len() {
362            reader.read(&mut buf);
363            off = 0;
364        }
365
366        let x0 = (buf[off] as u16) | (((buf[off + 1] & 0x0f) as u16) << 8);
367        let x1 = ((buf[off + 1] >> 4) as u16) | ((buf[off + 2] as u16) << 4);
368        off += 3;
369
370        if x0 < Q {
371            dst[j] = x0;
372            j += 1;
373        }
374        if j >= N {
375            break;
376        }
377        if x1 < Q {
378            dst[j] = x1;
379            j += 1;
380        }
381        if j >= N {
382            break;
383        }
384    }
385}
386
387/// Samples a noise polynomial with CBD_2 from `SHAKE256(sigma || counter)`.
388fn sample_poly_cbd(dst: &mut RingElement, sigma: &[u8; 32], counter: u8) {
389    let mut prf = Shake256::default();
390    prf.update(sigma);
391    prf.update(&[counter]);
392    let mut reader = prf.finalize_xof();
393    let mut buf = [0u8; 128];
394    reader.read(&mut buf);
395
396    let mut i = 0usize;
397    let mut j = 0usize;
398    while i < buf.len() {
399        let t = u32::from_le_bytes([buf[i], buf[i + 1], buf[i + 2], buf[i + 3]]);
400        // Each two-bit field in d is the Hamming weight of one input bit pair;
401        // CBD_2 maps adjacent weights to one coefficient as a-b mod q.
402        let d = (t & 0x5555_5555) + ((t >> 1) & 0x5555_5555);
403
404        dst[j] = cbd2(d, d >> 2);
405        dst[j + 1] = cbd2(d >> 4, d >> 6);
406        dst[j + 2] = cbd2(d >> 8, d >> 10);
407        dst[j + 3] = cbd2(d >> 12, d >> 14);
408        dst[j + 4] = cbd2(d >> 16, d >> 18);
409        dst[j + 5] = cbd2(d >> 20, d >> 22);
410        dst[j + 6] = cbd2(d >> 24, d >> 26);
411        dst[j + 7] = cbd2(d >> 28, d >> 30);
412
413        i += 4;
414        j += 8;
415    }
416}
417
418fn cbd2(a: u32, b: u32) -> u16 {
419    field_reduce_once(Q.wrapping_add((a & 0x3) as u16).wrapping_sub((b & 0x3) as u16))
420}
421
422// ---------------------------------------------------------------------------
423// Number-theoretic transform
424// ---------------------------------------------------------------------------
425
// Twiddle factors for the NTT butterflies: powers of the 256th root of unity
// zeta = 17 in 7-bit bit-reversed order, i.e. ZETAS[i] = 17^BitRev7(i) mod q
// (for example ZETAS[64] = 17 and ZETAS[1] = 17^64 = 1729). `ntt` walks the
// table forwards from index 1, `inverse_ntt` backwards from index 127.
#[rustfmt::skip]
const ZETAS: [u16; 128] = [
    1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746,
    296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
    289, 331, 3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
    2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
    17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
    1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594,
    2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
];
437
438fn ntt(f: &mut RingElement) {
439    let mut i = 1usize;
440    let mut length = 128usize;
441    while length >= 2 {
442        let mut start = 0usize;
443        while start < 256 {
444            let zeta = ZETAS[i];
445            i += 1;
446            for j in start..start + length {
447                // Keep butterfly outputs unreduced between layers. Each layer
448                // can grow coefficients by at most q, so across the seven NTT
449                // layers they stay below 8q and are canonicalized at the end.
450                let t = field_mul_wide(zeta, f[j + length]);
451                let a = f[j];
452                f[j] = a.wrapping_add(t);
453                f[j + length] = a.wrapping_add(Q).wrapping_sub(t);
454            }
455            start += 2 * length;
456        }
457        length /= 2;
458    }
459    for coeff in f.iter_mut() {
460        *coeff = field_reduce(*coeff as u32);
461    }
462}
463
// 128^-1 mod q: the scaling factor applied by the inverse NTT.
const INVERSE_NTT_SCALE: u16 = 3303;
// The final inverse NTT layer multiplies lower-half outputs by INVERSE_NTT_SCALE
// directly and folds the upper-half scaling into its zeta.
const INVERSE_NTT_FINAL_ZETA: u16 = 1652; // zetas[1] * INVERSE_NTT_SCALE mod q
468
469fn inverse_ntt(f: &mut RingElement) {
470    let mut i = 127usize;
471    let mut length = 2usize;
472    while length < 128 {
473        let mut start = 0usize;
474        while start < 256 {
475            let zeta = ZETAS[i];
476            i -= 1;
477            for j in start..start + length {
478                let t = f[j];
479                f[j] = field_add(t, f[j + length]);
480                f[j + length] = field_mul_sub(zeta, f[j + length], t);
481            }
482            start += 2 * length;
483        }
484        length *= 2;
485    }
486
487    for j in 0..128 {
488        let t = f[j];
489        f[j] = field_mul(field_add(t, f[j + 128]), INVERSE_NTT_SCALE);
490        f[j + 128] = field_mul_sub(INVERSE_NTT_FINAL_ZETA, f[j + 128], t);
491    }
492}
493
// Per-pair constants for multiplication in the NTT domain: coefficient pair
// `i` is multiplied modulo `X^2 - GAMMAS[i]`, where
// GAMMAS[i] = 17^(2*BitRev7(i) + 1) mod q. Entries come in +/- pairs
// (17 and 3312 = q - 17, 2761 and 568 = q - 2761, ...).
#[rustfmt::skip]
const GAMMAS: [u16; 128] = [
    17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229,
    1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279,
    1703, 1626, 1651, 1678, 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
    939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 268, 3061, 641, 2688,
    1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 375, 2954, 2549, 780, 2090, 1239, 1645, 1684,
    1063, 2266, 319, 3010, 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
    2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 2775, 554, 886, 2443,
    1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
];
505
/// Fuses the four multiplication terms in an ML-KEM-1024 NTT dot product. The
/// repeated lane blocks are intentionally unrolled so each coefficient pair
/// loads `acc` and `gamma` once, accumulates all four products lazily, and
/// reduces only once per output coefficient.
///
/// Computes `acc += a0*b0 + a1*b1 + a2*b2 + a3*b3`, where each product is the
/// pairwise NTT-domain multiplication: for pair `(x0, x1)`, `(y0, y1)` with
/// constant `gamma`, the result is `(x0*y0 + x1*y1*gamma, x0*y1 + x1*y0)`.
///
/// Assumes every input coefficient (including `acc`) is canonical (`< q`);
/// then each `u32` accumulator stays below `8*q*q + q`, which fits in `u32`
/// and in the domain of `field_reduce_wide`.
#[allow(clippy::too_many_arguments)]
fn ntt_mul_add4(
    acc: &mut RingElement,
    a0: &RingElement,
    b0: &RingElement,
    a1: &RingElement,
    b1: &RingElement,
    a2: &RingElement,
    b2: &RingElement,
    a3: &RingElement,
    b3: &RingElement,
) {
    let mut i = 0usize;
    while i < N {
        // One gamma per coefficient pair.
        let gamma = GAMMAS[i / 2] as u32;

        // Lane 0; also seeds the accumulators with the existing `acc` values.
        let (a00, a01) = (a0[i], a0[i + 1]);
        let (b00, b01) = (b0[i], b0[i + 1]);
        let mut acc0 = acc[i] as u32;
        // The odd*odd product is reduced before scaling by gamma so the term
        // stays below q*q.
        acc0 += (a00 as u32) * (b00 as u32) + (field_mul(a01, b01) as u32) * gamma;
        let mut acc1 = acc[i + 1] as u32;
        acc1 += (a00 as u32) * (b01 as u32) + (a01 as u32) * (b00 as u32);

        // Lane 1.
        let (a10, a11) = (a1[i], a1[i + 1]);
        let (b10, b11) = (b1[i], b1[i + 1]);
        acc0 += (a10 as u32) * (b10 as u32) + (field_mul(a11, b11) as u32) * gamma;
        acc1 += (a10 as u32) * (b11 as u32) + (a11 as u32) * (b10 as u32);

        // Lane 2.
        let (a20, a21) = (a2[i], a2[i + 1]);
        let (b20, b21) = (b2[i], b2[i + 1]);
        acc0 += (a20 as u32) * (b20 as u32) + (field_mul(a21, b21) as u32) * gamma;
        acc1 += (a20 as u32) * (b21 as u32) + (a21 as u32) * (b20 as u32);

        // Lane 3.
        let (a30, a31) = (a3[i], a3[i + 1]);
        let (b30, b31) = (b3[i], b3[i + 1]);
        acc0 += (a30 as u32) * (b30 as u32) + (field_mul(a31, b31) as u32) * gamma;
        acc1 += (a30 as u32) * (b31 as u32) + (a31 as u32) * (b30 as u32);

        // Single reduction per output coefficient.
        acc[i] = field_reduce_wide(acc0);
        acc[i + 1] = field_reduce_wide(acc1);
        i += 2;
    }
}
553
554fn poly_add_assign(a: &mut RingElement, b: &RingElement) {
555    for i in 0..N {
556        a[i] = field_add(a[i], b[i]);
557    }
558}
559
560fn poly_sub_assign(a: &mut RingElement, b: &RingElement) {
561    for i in 0..N {
562        a[i] = field_sub(a[i], b[i]);
563    }
564}
565
566// ---------------------------------------------------------------------------
567// Hash helpers (FIPS 202)
568// ---------------------------------------------------------------------------
569
570fn sha3_512(input: &[u8]) -> [u8; 64] {
571    let out = Sha3_512::digest(input);
572    let mut r = [0u8; 64];
573    r.copy_from_slice(&out);
574    r
575}
576
577fn sha3_256(input: &[u8]) -> [u8; 32] {
578    let out = Sha3_256::digest(input);
579    let mut r = [0u8; 32];
580    r.copy_from_slice(&out);
581    r
582}
583
584fn shake256(inputs: &[&[u8]], out: &mut [u8]) {
585    let mut h = Shake256::default();
586    for input in inputs {
587        h.update(input);
588    }
589    let mut reader = h.finalize_xof();
590    reader.read(out);
591}
592
593// ---------------------------------------------------------------------------
594// Constant-time helpers
595// ---------------------------------------------------------------------------
596
/// Returns `0xFF` if the slices are equal, `0x00` otherwise, in time
/// independent of where the first difference (if any) occurs.
///
/// Slices of different lengths are never equal. Lengths are public, so the
/// early return on a mismatch leaks nothing secret. Without it, `zip` would
/// stop at the shorter slice and report a prefix as equal whenever debug
/// assertions are disabled.
fn ct_eq_mask(a: &[u8], b: &[u8]) -> u8 {
    if a.len() != b.len() {
        return 0x00;
    }
    // OR together every byte difference without branching on the data.
    let mut diff = 0u8;
    for (x, y) in a.iter().zip(b.iter()) {
        diff |= x ^ y;
    }
    // diff != 0  ->  bit 31 set  ->  nonzero == 1
    let nonzero = (diff as u32 | (diff as u32).wrapping_neg()) >> 31;
    // 1 -> 0x00 (different), 0 -> 0xFF (equal).
    ((nonzero ^ 1) as u8).wrapping_neg()
}
609
/// Conditionally overwrites `dst` with `src`: a `mask` of `0xFF` copies
/// `src`, a `mask` of `0x00` leaves `dst` untouched. No branch depends on the
/// data, and only the overlapping prefix of the two slices is processed.
fn ct_select(mask: u8, dst: &mut [u8], src: &[u8]) {
    let keep = !mask;
    dst.iter_mut().zip(src).for_each(|(out, &replacement)| {
        *out = (*out & keep) | (replacement & mask);
    });
}
617
618// ---------------------------------------------------------------------------
619// Key material
620// ---------------------------------------------------------------------------
621
/// The K-PKE public (encryption) key in expanded form: the decoded vector
/// `t`, the matrix `A` re-derived from `rho`, and the encoded bytes it was
/// built from. `a` is stored row-major, with entry `(i, j)` at index
/// `i * K + j`.
#[derive(Clone)]
struct EncryptionKey {
    t: [RingElement; K],     // public key vector in NTT domain
    a: [RingElement; K * K], // public matrix A in NTT domain
    rho: [u8; 32],           // matrix seed
    encoded: [u8; MLKEM1024_ENCAPSULATION_KEY_SIZE], // encoded t || rho
}
629
630impl EncryptionKey {
631    fn zeroed() -> Self {
632        Self {
633            t: [new_ring(); K],
634            a: [new_ring(); K * K],
635            rho: [0u8; 32],
636            encoded: [0u8; MLKEM1024_ENCAPSULATION_KEY_SIZE],
637        }
638    }
639}
640
/// The K-PKE secret (decryption) key: the secret vector `s`.
// NOTE(review): this derives `Clone` although it holds secret material and
// `DecapsulationKey` is deliberately not `Clone` — confirm the derive is
// still needed.
#[derive(Clone)]
struct DecryptionKey {
    s: [RingElement; K], // secret key vector in NTT domain
}
645
/// An ML-KEM-1024 decapsulation (private) key.
///
/// Holds secret material — the `d`/`z` seeds and the secret vector `s` — which
/// is zeroized on drop. Deliberately not [`Clone`]: copy the seed via
/// [`DecapsulationKey::bytes`] and re-derive instead.
///
/// The expanded public key and its hash are cached alongside the secrets so
/// that [`DecapsulationKey::encapsulation_key`] needs no recomputation.
pub struct DecapsulationKey {
    d: [u8; 32], // decapsulation key seed
    z: [u8; 32], // implicit-rejection seed
    h: [u8; 32], // H(ek): SHA3-256 of the encoded encapsulation key
    encryption_key: EncryptionKey,
    decryption_key: DecryptionKey,
}
658
/// An ML-KEM-1024 encapsulation (public) key.
///
/// Stores the expanded K-PKE encryption key together with the hash of its
/// encoded form.
#[derive(Clone)]
pub struct EncapsulationKey {
    h: [u8; 32], // H(ek): SHA3-256 of the encoded encapsulation key
    encryption_key: EncryptionKey,
}
665
666impl DecapsulationKey {
667    /// Generates a fresh decapsulation key from system randomness.
668    pub fn generate() -> Result<Self> {
669        let mut d = [0u8; 32];
670        let mut z = [0u8; 32];
671        getrandom::getrandom(&mut d)?;
672        getrandom::getrandom(&mut z)?;
673        let key = Self::from_d_z(&d, &z);
674        d.zeroize();
675        z.zeroize();
676        Ok(key)
677    }
678
679    /// Deterministically derives a decapsulation key from a
680    /// [`MLKEM1024_SEED_SIZE`]-byte seed in `d || z` form.
681    pub fn from_seed(seed: &[u8]) -> Result<Self> {
682        if seed.len() != MLKEM1024_SEED_SIZE {
683            return Err(QrllibError::InvalidMlKemSeedSize(seed.len(), MLKEM1024_SEED_SIZE));
684        }
685        let mut d = [0u8; 32];
686        let mut z = [0u8; 32];
687        d.copy_from_slice(&seed[..32]);
688        z.copy_from_slice(&seed[32..]);
689        let key = Self::from_d_z(&d, &z);
690        d.zeroize();
691        z.zeroize();
692        Ok(key)
693    }
694
695    fn from_d_z(d: &[u8; 32], z: &[u8; 32]) -> Self {
696        let mut dk = DecapsulationKey {
697            d: *d,
698            z: *z,
699            h: [0u8; 32],
700            encryption_key: EncryptionKey::zeroed(),
701            decryption_key: DecryptionKey { s: [new_ring(); K] },
702        };
703        pke_key_gen(&mut dk, d);
704        dk.h = sha3_256(&dk.encryption_key.encoded);
705        dk
706    }
707
708    /// Recovers the shared secret from an ML-KEM-1024 ciphertext. Implements
709    /// the Fujisaki-Okamoto re-encryption check with constant-time implicit
710    /// rejection: a malformed ciphertext yields a pseudo-random key derived
711    /// from `z`, never an error.
712    pub fn decapsulate(
713        &self,
714        ciphertext: &[u8],
715    ) -> Result<Zeroizing<[u8; MLKEM1024_SHARED_KEY_SIZE]>> {
716        if ciphertext.len() != MLKEM1024_CIPHERTEXT_SIZE {
717            return Err(QrllibError::InvalidMlKemCiphertextSize(
718                ciphertext.len(),
719                MLKEM1024_CIPHERTEXT_SIZE,
720            ));
721        }
722        let mut ct = [0u8; MLKEM1024_CIPHERTEXT_SIZE];
723        ct.copy_from_slice(ciphertext);
724        Ok(decapsulate(self, &ct))
725    }
726
727    /// Returns the corresponding encapsulation (public) key.
728    pub fn encapsulation_key(&self) -> EncapsulationKey {
729        EncapsulationKey { h: self.h, encryption_key: self.encryption_key.clone() }
730    }
731
732    /// Returns the decapsulation key seed in `d || z` form.
733    pub fn bytes(&self) -> Zeroizing<[u8; MLKEM1024_SEED_SIZE]> {
734        let mut b = [0u8; MLKEM1024_SEED_SIZE];
735        b[..32].copy_from_slice(&self.d);
736        b[32..].copy_from_slice(&self.z);
737        Zeroizing::new(b)
738    }
739
740    /// Overwrites the secret material — the `d`/`z` seeds and secret vector
741    /// `s` — with zeros. Best-effort under Rust's memory model. Non-secret
742    /// fields (the encapsulation key, matrix seed, and `H(ek)`) are left
743    /// intact.
744    pub fn zeroize(&mut self) {
745        self.d.zeroize();
746        self.z.zeroize();
747        for poly in &mut self.decryption_key.s {
748            poly.zeroize();
749        }
750    }
751}
752
753impl Drop for DecapsulationKey {
754    fn drop(&mut self) {
755        self.zeroize();
756    }
757}
758
759impl EncapsulationKey {
760    /// Constructs an encapsulation key from its
761    /// [`MLKEM1024_ENCAPSULATION_KEY_SIZE`]-byte encoded form. Rejects
762    /// non-canonical encodings (any coefficient `>= q`).
763    pub fn from_bytes(ek_bytes: &[u8]) -> Result<Self> {
764        if ek_bytes.len() != MLKEM1024_ENCAPSULATION_KEY_SIZE {
765            return Err(QrllibError::InvalidMlKemEncapsulationKeySize(
766                ek_bytes.len(),
767                MLKEM1024_ENCAPSULATION_KEY_SIZE,
768            ));
769        }
770
771        let mut ek =
772            EncapsulationKey { h: sha3_256(ek_bytes), encryption_key: EncryptionKey::zeroed() };
773        ek.encryption_key.encoded.copy_from_slice(ek_bytes);
774
775        let mut offset = 0usize;
776        for i in 0..K {
777            byte_decode12(
778                &mut ek.encryption_key.t[i],
779                &ek_bytes[offset..offset + ENCODING_SIZE_12],
780            )?;
781            offset += ENCODING_SIZE_12;
782        }
783        ek.encryption_key.rho.copy_from_slice(&ek_bytes[offset..offset + 32]);
784
785        let rho = ek.encryption_key.rho;
786        for i in 0..K {
787            for j in 0..K {
788                sample_ntt(&mut ek.encryption_key.a[i * K + j], &rho, j as u8, i as u8);
789            }
790        }
791
792        Ok(ek)
793    }
794
795    /// Produces a fresh `(shared_key, ciphertext)` pair using system
796    /// randomness for the encapsulated message.
797    pub fn encapsulate(
798        &self,
799    ) -> Result<(Zeroizing<[u8; MLKEM1024_SHARED_KEY_SIZE]>, [u8; MLKEM1024_CIPHERTEXT_SIZE])> {
800        let mut m = [0u8; MESSAGE_SIZE];
801        getrandom::getrandom(&mut m)?;
802        let (shared_key, ciphertext) = encapsulate_to(&self.encryption_key, &self.h, &m);
803        m.zeroize();
804        Ok((shared_key, ciphertext))
805    }
806
807    /// Derandomized counterpart to [`EncapsulationKey::encapsulate`] that takes
808    /// the 32-byte message `m` explicitly.
809    ///
810    /// **Test-only.** Reusing `m` across encapsulations breaks IND-CCA
811    /// security; this exists solely to reproduce ACVP / Wycheproof /
812    /// cross-verification vectors. Production code must call
813    /// [`EncapsulationKey::encapsulate`].
814    pub fn encapsulate_deterministic(
815        &self,
816        m: &[u8; MESSAGE_SIZE],
817    ) -> (Zeroizing<[u8; MLKEM1024_SHARED_KEY_SIZE]>, [u8; MLKEM1024_CIPHERTEXT_SIZE]) {
818        encapsulate_to(&self.encryption_key, &self.h, m)
819    }
820
821    /// Returns the encoded form of the encapsulation key.
822    pub fn bytes(&self) -> [u8; MLKEM1024_ENCAPSULATION_KEY_SIZE] {
823        self.encryption_key.encoded
824    }
825}
826
827// ---------------------------------------------------------------------------
828// K-PKE (FIPS 203, Section 5): the IND-CPA-secure PKE that ML-KEM wraps with
829// the FO transform.
830// ---------------------------------------------------------------------------
831
/// K-PKE key generation: fills `dk.encryption_key` (rho, matrix `A`, vector
/// `t` and the encoded `t || rho`) and `dk.decryption_key.s` from the seed
/// `d`.
///
/// The PRF counter order is significant: counters `0..K` sample the secret
/// vector `s`, counters `K..2K` sample the noise vector `e`. All temporaries
/// derived from `d` are wiped before returning.
fn pke_key_gen(dk: &mut DecapsulationKey, d: &[u8; 32]) {
    // (rho, sigma) = SHA3-512(d || K); the rank byte is appended to the seed.
    let mut g_input = [0u8; 33];
    g_input[..32].copy_from_slice(d);
    g_input[32] = K as u8;
    let mut g = sha3_512(&g_input);
    let mut rho = [0u8; 32];
    let mut sigma = [0u8; 32];
    rho.copy_from_slice(&g[..32]);
    sigma.copy_from_slice(&g[32..]);

    // rho is public: it is stored and forms the tail of the encoded key.
    dk.encryption_key.rho = rho;
    dk.encryption_key.encoded[K * ENCODING_SIZE_12..].copy_from_slice(&rho);

    // Expand the matrix A from rho, row-major (entry (i, j) at i * K + j).
    for i in 0..K {
        for j in 0..K {
            sample_ntt(&mut dk.encryption_key.a[i * K + j], &rho, j as u8, i as u8);
        }
    }

    // Sample the secret vector s and move it to the NTT domain.
    let mut counter = 0u8;
    for i in 0..K {
        sample_poly_cbd(&mut dk.decryption_key.s[i], &sigma, counter);
        ntt(&mut dk.decryption_key.s[i]);
        counter += 1;
    }

    // t[i] = sum_j A[i][j] * s[j] + e[i], all in the NTT domain.
    for i in 0..K {
        let mut acc = new_ring();
        ntt_mul_add4(
            &mut acc,
            &dk.encryption_key.a[i * K],
            &dk.decryption_key.s[0],
            &dk.encryption_key.a[i * K + 1],
            &dk.decryption_key.s[1],
            &dk.encryption_key.a[i * K + 2],
            &dk.decryption_key.s[2],
            &dk.encryption_key.a[i * K + 3],
            &dk.decryption_key.s[3],
        );

        let mut e = new_ring();
        sample_poly_cbd(&mut e, &sigma, counter);
        ntt(&mut e);
        counter += 1;
        poly_add_assign(&mut acc, &e);
        e.zeroize(); // noise secret; no longer needed

        // Store t[i] and write its 12-bit encoding into the encoded key.
        dk.encryption_key.t[i] = acc;
        byte_encode12(
            &mut dk.encryption_key.encoded[i * ENCODING_SIZE_12..(i + 1) * ENCODING_SIZE_12],
            &dk.encryption_key.t[i],
        );
    }

    // Wipe key-generation secrets: g_input holds the seed d, g/sigma hold the
    // CBD sampling seed.
    g_input.zeroize();
    g.zeroize();
    sigma.zeroize();
    rho.zeroize();
}
893
894fn pke_encrypt(
895    dst: &mut [u8; MLKEM1024_CIPHERTEXT_SIZE],
896    ek: &EncryptionKey,
897    m: &[u8; MESSAGE_SIZE],
898    r: &[u8; 32],
899) {
900    let mut counter = 0u8;
901    let mut y = [new_ring(); K];
902    for poly in &mut y {
903        sample_poly_cbd(poly, r, counter);
904        ntt(poly);
905        counter += 1;
906    }
907
908    let mut off = 0usize;
909    for i in 0..K {
910        let mut acc = new_ring();
911        // ek.a is stored row-major as A[row*K + column]. K-PKE.Encrypt needs
912        // A^T * y, so this walks one column of A for each output polynomial.
913        ntt_mul_add4(
914            &mut acc,
915            &ek.a[i],
916            &y[0],
917            &ek.a[K + i],
918            &y[1],
919            &ek.a[2 * K + i],
920            &y[2],
921            &ek.a[3 * K + i],
922            &y[3],
923        );
924        inverse_ntt(&mut acc);
925
926        let mut e1 = new_ring();
927        sample_poly_cbd(&mut e1, r, counter);
928        counter += 1;
929        poly_add_assign(&mut acc, &e1);
930        e1.zeroize(); // noise secret; acc (= u_i) is public ciphertext
931
932        ring_compress_and_encode11(&mut dst[off..off + ENCODING_SIZE_11], &acc);
933        off += ENCODING_SIZE_11;
934    }
935
936    let mut e2 = new_ring();
937    sample_poly_cbd(&mut e2, r, counter);
938
939    let mut mu = new_ring();
940    ring_decode_and_decompress1(&mut mu, m);
941
942    let mut v = new_ring();
943    ntt_mul_add4(&mut v, &ek.t[0], &y[0], &ek.t[1], &y[1], &ek.t[2], &y[2], &ek.t[3], &y[3]);
944    inverse_ntt(&mut v);
945    poly_add_assign(&mut v, &e2);
946    poly_add_assign(&mut v, &mu);
947
948    ring_compress_and_encode5(&mut dst[off..off + ENCODING_SIZE_5], &v);
949
950    // Wipe encryption secrets: y is the encryption randomness vector, e2/mu
951    // derive from the message randomness, and full-precision v carries mu
952    // before compression rounding. The u_i accumulators are not wiped — they
953    // are the public ciphertext content.
954    for poly in &mut y {
955        poly.zeroize();
956    }
957    e2.zeroize();
958    mu.zeroize();
959    v.zeroize();
960}
961
962fn pke_decrypt(
963    dst: &mut [u8; MESSAGE_SIZE],
964    dk: &DecapsulationKey,
965    c: &[u8; MLKEM1024_CIPHERTEXT_SIZE],
966) {
967    let mut u = [new_ring(); K];
968    let mut off = 0usize;
969    for poly in &mut u {
970        ring_decode_and_decompress11(poly, &c[off..off + ENCODING_SIZE_11]);
971        off += ENCODING_SIZE_11;
972        ntt(poly);
973    }
974
975    let mut v = new_ring();
976    ring_decode_and_decompress5(&mut v, &c[off..off + ENCODING_SIZE_5]);
977
978    let s = &dk.decryption_key.s;
979    let mut acc = new_ring();
980    ntt_mul_add4(&mut acc, &s[0], &u[0], &s[1], &u[1], &s[2], &u[2], &s[3], &u[3]);
981    inverse_ntt(&mut acc);
982
983    poly_sub_assign(&mut v, &acc);
984    ring_compress_and_encode1(dst, &v);
985
986    // Wipe decryption secrets: acc is s^T·u (secret-key-dependent) and v holds
987    // the noisy plaintext polynomial after the subtraction. The decoded u is
988    // public ciphertext content and is left as is. dst (the decrypted message)
989    // is wiped by decapsulate after the FO re-encryption check.
990    acc.zeroize();
991    v.zeroize();
992}
993
994// ---------------------------------------------------------------------------
995// ML-KEM FO transform
996// ---------------------------------------------------------------------------
997
998fn encapsulate_to(
999    ek: &EncryptionKey,
1000    ek_h: &[u8; 32],
1001    m: &[u8; MESSAGE_SIZE],
1002) -> (Zeroizing<[u8; MLKEM1024_SHARED_KEY_SIZE]>, [u8; MLKEM1024_CIPHERTEXT_SIZE]) {
1003    let mut g_input = [0u8; MESSAGE_SIZE + 32];
1004    g_input[..MESSAGE_SIZE].copy_from_slice(m);
1005    g_input[MESSAGE_SIZE..].copy_from_slice(ek_h);
1006    let mut g = sha3_512(&g_input);
1007
1008    let mut shared_key = [0u8; MLKEM1024_SHARED_KEY_SIZE];
1009    shared_key.copy_from_slice(&g[..MLKEM1024_SHARED_KEY_SIZE]);
1010    let mut r = [0u8; 32];
1011    r.copy_from_slice(&g[MLKEM1024_SHARED_KEY_SIZE..]);
1012
1013    let mut ciphertext = [0u8; MLKEM1024_CIPHERTEXT_SIZE];
1014    pke_encrypt(&mut ciphertext, ek, m, &r);
1015
1016    // Wipe transient secret material derived from the message randomness. m is
1017    // owned by the caller and left intact.
1018    g_input.zeroize();
1019    g.zeroize();
1020    r.zeroize();
1021
1022    (Zeroizing::new(shared_key), ciphertext)
1023}
1024
1025fn decapsulate(
1026    dk: &DecapsulationKey,
1027    ct: &[u8; MLKEM1024_CIPHERTEXT_SIZE],
1028) -> Zeroizing<[u8; MLKEM1024_SHARED_KEY_SIZE]> {
1029    let mut m = [0u8; MESSAGE_SIZE];
1030    pke_decrypt(&mut m, dk, ct);
1031
1032    let mut g_input = [0u8; MESSAGE_SIZE + 32];
1033    g_input[..MESSAGE_SIZE].copy_from_slice(&m);
1034    g_input[MESSAGE_SIZE..].copy_from_slice(&dk.h);
1035    let mut g = sha3_512(&g_input);
1036    let mut r = [0u8; 32];
1037    r.copy_from_slice(&g[MLKEM1024_SHARED_KEY_SIZE..]);
1038
1039    // Implicit-rejection key J(z || ct): the default output for a ciphertext
1040    // that fails the re-encryption check.
1041    let mut k_out = [0u8; MLKEM1024_SHARED_KEY_SIZE];
1042    shake256(&[&dk.z, ct.as_slice()], &mut k_out);
1043
1044    let mut c = [0u8; MLKEM1024_CIPHERTEXT_SIZE];
1045    pke_encrypt(&mut c, &dk.encryption_key, &m, &r);
1046
1047    // If the re-encryption matches, replace the implicit-rejection key with the
1048    // real shared key G(m || H(ek))[:32]. Constant-time; data-independent wipes
1049    // below add no timing side channel.
1050    let matches = ct_eq_mask(ct.as_slice(), &c);
1051    ct_select(matches, &mut k_out, &g[..MLKEM1024_SHARED_KEY_SIZE]);
1052
1053    m.zeroize();
1054    g_input.zeroize();
1055    g.zeroize();
1056    r.zeroize();
1057
1058    Zeroizing::new(k_out)
1059}
1060
/// Round-trip, determinism, input-validation, and zeroization tests for the
/// public ML-KEM-1024 API.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `d || z` seed with every byte set to `byte`.
    fn seed(byte: u8) -> [u8; MLKEM1024_SEED_SIZE] {
        [byte; MLKEM1024_SEED_SIZE]
    }

    #[test]
    fn sizes_match_fips_203_ml_kem_1024() {
        assert_eq!(MLKEM1024_SEED_SIZE, 64);
        assert_eq!(MLKEM1024_SHARED_KEY_SIZE, 32);
        assert_eq!(MLKEM1024_CIPHERTEXT_SIZE, 1568);
        assert_eq!(MLKEM1024_ENCAPSULATION_KEY_SIZE, 1568);
    }

    #[test]
    fn encapsulate_then_decapsulate_recovers_shared_secret() {
        let dk = DecapsulationKey::from_seed(&seed(0x42)).expect("decap key");
        let ek = dk.encapsulation_key();
        let (shared_a, ciphertext) = ek.encapsulate().expect("encapsulate");
        let shared_b = dk.decapsulate(&ciphertext).expect("decapsulate");
        assert_eq!(*shared_a, *shared_b);
        assert_eq!(ciphertext.len(), MLKEM1024_CIPHERTEXT_SIZE);
    }

    #[test]
    fn generated_decapsulation_key_round_trips() {
        // Exercises the system-randomness `generate()` path (as opposed to the
        // deterministic `from_seed`): the resulting keypair must complete a full
        // encapsulate -> decapsulate round-trip with matching shared secrets.
        let dk = DecapsulationKey::generate().expect("generated decap key");
        let ek = dk.encapsulation_key();
        let (shared_a, ciphertext) = ek.encapsulate().expect("encapsulate");
        let shared_b = dk.decapsulate(&ciphertext).expect("decapsulate");
        assert_eq!(*shared_a, *shared_b);
    }

    #[test]
    fn encapsulation_key_round_trips_through_bytes() {
        let dk = DecapsulationKey::from_seed(&seed(7)).expect("decap key");
        let ek = dk.encapsulation_key();
        let ek_bytes = ek.bytes();
        assert_eq!(ek_bytes.len(), MLKEM1024_ENCAPSULATION_KEY_SIZE);

        let restored = EncapsulationKey::from_bytes(&ek_bytes).expect("restore ek");
        let (shared, ciphertext) = restored.encapsulate().expect("encapsulate");
        assert_eq!(*dk.decapsulate(&ciphertext).expect("decapsulate"), *shared);
    }

    #[test]
    fn from_seed_is_deterministic_and_round_trips() {
        let dk1 = DecapsulationKey::from_seed(&seed(0x11)).expect("decap key");
        let dk2 = DecapsulationKey::from_seed(&seed(0x11)).expect("decap key");
        assert_eq!(dk1.encapsulation_key().bytes(), dk2.encapsulation_key().bytes());
        assert_eq!(*dk1.bytes(), *dk2.bytes());

        // A different seed yields a different public key.
        let dk3 = DecapsulationKey::from_seed(&seed(0x12)).expect("decap key");
        assert_ne!(dk1.encapsulation_key().bytes(), dk3.encapsulation_key().bytes());
    }

    #[test]
    fn deterministic_encapsulation_is_reproducible() {
        let dk = DecapsulationKey::from_seed(&seed(0x99)).expect("decap key");
        let ek = dk.encapsulation_key();
        let m = [0x5a_u8; MESSAGE_SIZE];
        let (shared_a, ct_a) = ek.encapsulate_deterministic(&m);
        let (shared_b, ct_b) = ek.encapsulate_deterministic(&m);
        assert_eq!(*shared_a, *shared_b);
        assert_eq!(ct_a, ct_b);
        assert_eq!(*dk.decapsulate(&ct_a).expect("decapsulate"), *shared_a);
    }

    #[test]
    fn decapsulate_implicitly_rejects_malformed_ciphertext() {
        let dk = DecapsulationKey::from_seed(&seed(0x33)).expect("decap key");
        let ek = dk.encapsulation_key();
        let (shared, mut ciphertext) = ek.encapsulate().expect("encapsulate");

        // Flip a byte: decapsulation must return a pseudo-random key (derived
        // from z), not an error, and not the real shared secret.
        ciphertext[0] ^= 0xff;
        let rejected = dk.decapsulate(&ciphertext).expect("implicit rejection still succeeds");

        // Compare against the secret that belongs to the ciphertext before it
        // was tampered with. A secret from an unrelated encapsulation would
        // differ regardless of whether rejection works.
        assert_ne!(*rejected, *shared);

        // The rejection key must be exactly J(z || ct) for the tampered
        // ciphertext, which also makes it stable across repeated calls.
        let mut expected = [0u8; MLKEM1024_SHARED_KEY_SIZE];
        shake256(&[&dk.z, ciphertext.as_slice()], &mut expected);
        assert_eq!(*rejected, expected);
        let rejected_again =
            dk.decapsulate(&ciphertext).expect("implicit rejection still succeeds");
        assert_eq!(*rejected, *rejected_again);
    }

    #[test]
    fn wrong_length_inputs_are_rejected() {
        assert!(matches!(
            DecapsulationKey::from_seed(&[0u8; 32]),
            Err(QrllibError::InvalidMlKemSeedSize(32, 64))
        ));
        let dk = DecapsulationKey::from_seed(&seed(1)).expect("decap key");
        assert!(matches!(
            dk.decapsulate(&[0u8; 10]),
            Err(QrllibError::InvalidMlKemCiphertextSize(10, 1568))
        ));
        assert!(matches!(
            EncapsulationKey::from_bytes(&[0u8; 100]),
            Err(QrllibError::InvalidMlKemEncapsulationKeySize(100, 1568))
        ));
    }

    #[test]
    fn from_bytes_rejects_non_canonical_encoding() {
        let dk = DecapsulationKey::from_seed(&seed(0x55)).expect("decap key");
        let mut ek_bytes = dk.encapsulation_key().bytes();
        // Force the first encoded coefficient (12 bits spanning bytes[0..2]) to
        // 0xfff = 4095, which is >= q, so byte_decode12 must reject the key
        // rather than accept a non-canonical encoding.
        ek_bytes[0] = 0xff;
        ek_bytes[1] |= 0x0f;
        assert!(matches!(
            EncapsulationKey::from_bytes(&ek_bytes),
            Err(QrllibError::InvalidMlKemEncoding)
        ));
    }

    #[test]
    fn decapsulation_key_zeroize_clears_secret_material() {
        let mut dk = DecapsulationKey::from_seed(&seed(0x77)).expect("decap key");
        dk.zeroize();
        assert!(dk.d.iter().all(|b| *b == 0));
        assert!(dk.z.iter().all(|b| *b == 0));
        assert!(dk.decryption_key.s.iter().all(|poly| poly.iter().all(|c| *c == 0)));
    }
}
1190
/// NIST ACVP known-answer tests for ML-KEM-1024 (FIPS 203), ported from
/// go-qrllib's `crypto/internal/mlkem1024/acvp_test.go`.
///
/// These tests sit inside the module because the ACVP `decapsulation` and
/// `decapsulationKeyCheck` functions work on the **expanded** decapsulation-key
/// encoding (`dkPKE || ek || H(ek) || z`, 3168 bytes). Building that form needs
/// private-field access that the public API deliberately does not expose
/// (parity with go-qrllib, which keeps the seed `d || z` as the canonical key
/// bytes).
///
/// The vectors are **not** vendored. Set `MLKEM_ACVP_VECTORS_DIR` to a
/// directory holding the decompressed NIST ACVP suites
/// `ML-KEM-keyGen-FIPS203/{prompt,expectedResults}.json` and
/// `ML-KEM-encapDecap-FIPS203/{prompt,expectedResults}.json`. With the variable
/// unset the tests log a skip and pass, so day-to-day `cargo test` does not
/// require the vectors. See `.github/acvp/README.md`.
#[cfg(test)]
mod acvp {
    use super::*;
    use serde::Deserialize;
    use std::{
        env, fs,
        path::{Path, PathBuf},
    };

    /// Top level of an ACVP `prompt.json` file.
    #[derive(Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct PromptFile {
        test_groups: Vec<PromptGroup>,
    }

    /// One prompt test group; `function` is absent in the keyGen suite.
    #[derive(Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct PromptGroup {
        tg_id: u32,
        parameter_set: String,
        #[serde(default)]
        function: String,
        tests: Vec<PromptTest>,
    }

    /// One prompt test case. Every hex field is optional because each ACVP
    /// function supplies a different subset of them.
    #[derive(Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct PromptTest {
        tc_id: u32,
        #[serde(default)]
        d: String,
        #[serde(default)]
        z: String,
        #[serde(default)]
        ek: String,
        #[serde(default)]
        dk: String,
        #[serde(default)]
        m: String,
        #[serde(default)]
        c: String,
    }

    /// Top level of an ACVP `expectedResults.json` file.
    #[derive(Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct ExpectedFile {
        test_groups: Vec<ExpectedGroup>,
    }

    /// One expected-results group, matched to its prompt group by `tgId`.
    #[derive(Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct ExpectedGroup {
        tg_id: u32,
        tests: Vec<ExpectedTest>,
    }

    /// One expected result, matched to its prompt test by `tcId`.
    #[derive(Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct ExpectedTest {
        tc_id: u32,
        #[serde(default)]
        ek: String,
        #[serde(default)]
        dk: String,
        #[serde(default)]
        c: String,
        #[serde(default)]
        k: String,
        #[serde(default)]
        test_passed: bool,
    }

    /// Returns the vector directory named by `MLKEM_ACVP_VECTORS_DIR`, if set.
    fn vectors_dir() -> Option<PathBuf> {
        let raw = env::var_os("MLKEM_ACVP_VECTORS_DIR")?;
        Some(PathBuf::from(raw))
    }

    /// Reads and deserialises `dir/suite/name`, panicking with the offending
    /// path on an I/O or JSON error.
    fn load<T: serde::de::DeserializeOwned>(dir: &Path, suite: &str, name: &str) -> T {
        let path = dir.join(suite).join(name);
        let data = match fs::read_to_string(&path) {
            Ok(text) => text,
            Err(e) => panic!("read {}: {}", path.display(), e),
        };
        match serde_json::from_str(&data) {
            Ok(parsed) => parsed,
            Err(e) => panic!("parse {}: {}", path.display(), e),
        }
    }

    /// Decodes an ACVP hex string, panicking on malformed input.
    fn decode(value: &str) -> Vec<u8> {
        let bytes = hex::decode(value);
        bytes.expect("ACVP hex")
    }

    /// Looks up the expected result for test `tc_id` in group `tg_id`,
    /// panicking if either is missing.
    fn expected_test(expected: &ExpectedFile, tg_id: u32, tc_id: u32) -> &ExpectedTest {
        let Some(group) = expected.test_groups.iter().find(|g| g.tg_id == tg_id) else {
            panic!("missing expected group {tg_id}")
        };
        let Some(test) = group.tests.iter().find(|t| t.tc_id == tc_id) else {
            panic!("missing expected test {tc_id} in group {tg_id}")
        };
        test
    }

    /// Serialise a decapsulation key into the FIPS 203 / ACVP expanded form
    /// `ByteEncode12(s) || ek || H(ek) || z`.
    fn to_expanded(dk: &DecapsulationKey) -> Vec<u8> {
        let total = K * ENCODING_SIZE_12 + MLKEM1024_ENCAPSULATION_KEY_SIZE + 64;
        let mut out = Vec::with_capacity(total);
        let mut scratch = [0u8; ENCODING_SIZE_12];
        for poly in dk.decryption_key.s.iter() {
            byte_encode12(&mut scratch, poly);
            out.extend_from_slice(&scratch);
        }
        for part in [&dk.encryption_key.encoded[..], &dk.h[..], &dk.z[..]] {
            out.extend_from_slice(part);
        }
        out
    }

    /// Reconstruct a decapsulation key from the ACVP expanded form, validating
    /// the secret vector, embedded encapsulation key, and `H(ek)` consistency —
    /// the predicate the `decapsulationKeyCheck` function asserts.
    fn from_expanded(b: &[u8]) -> Result<DecapsulationKey> {
        const EXPANDED: usize = K * ENCODING_SIZE_12 + MLKEM1024_ENCAPSULATION_KEY_SIZE + 64;
        // Coverage: every ACVP `decapsulationKeyCheck` vector is exactly EXPANDED
        // bytes, so this outer length guard never trips; the check vectors instead
        // exercise the content-validation branches below.
        if b.len() != EXPANDED {
            //coverage:ignore reason=defensively-unreachable
            return Err(QrllibError::InvalidMlKemEncoding);
        }

        // Carve the buffer into its four fields up front; the guard above
        // guarantees the final piece is the 32-byte `z`.
        let (s_bytes, rest) = b.split_at(K * ENCODING_SIZE_12);
        let (ek_bytes, rest) = rest.split_at(MLKEM1024_ENCAPSULATION_KEY_SIZE);
        let (h_bytes, z_bytes) = rest.split_at(32);

        let mut s = [new_ring(); K];
        for (poly, chunk) in s.iter_mut().zip(s_bytes.chunks_exact(ENCODING_SIZE_12)) {
            byte_decode12(poly, chunk)?;
        }

        let ek = EncapsulationKey::from_bytes(ek_bytes)?;
        if ek.h[..] != *h_bytes {
            return Err(QrllibError::InvalidMlKemEncoding);
        }

        let mut z = [0u8; 32];
        z.copy_from_slice(z_bytes);
        // `d` is unused by decapsulation (which consumes s, h, z, and the
        // encryption key), so a zero placeholder is sufficient here.
        Ok(DecapsulationKey {
            d: [0u8; 32],
            z,
            h: ek.h,
            encryption_key: ek.encryption_key,
            decryption_key: DecryptionKey { s },
        })
    }

    #[test]
    fn acvp_keygen_matches_nist_vectors() {
        // Coverage: the coverage run always sets MLKEM_ACVP_VECTORS_DIR, so the
        // skip arm (for environments where the NIST vectors are not available)
        // is never taken here.
        let Some(dir) = vectors_dir() else {
            //coverage:ignore start reason=defensively-unreachable
            eprintln!("MLKEM_ACVP_VECTORS_DIR not set; skipping ML-KEM ACVP keyGen test");
            return;
            //coverage:ignore end
        };
        let suite = "ML-KEM-keyGen-FIPS203";
        let prompt: PromptFile = load(&dir, suite, "prompt.json");
        let expected: ExpectedFile = load(&dir, suite, "expectedResults.json");

        let mut tested = 0u32;
        let groups = prompt.test_groups.iter().filter(|g| g.parameter_set == "ML-KEM-1024");
        for group in groups {
            for test in &group.tests {
                tested += 1;
                let want = expected_test(&expected, group.tg_id, test.tc_id);

                // The seed is `d || z`, each half 32 bytes.
                let mut seed = [0u8; MLKEM1024_SEED_SIZE];
                {
                    let (d_half, z_half) = seed.split_at_mut(32);
                    d_half.copy_from_slice(&decode(&test.d));
                    z_half.copy_from_slice(&decode(&test.z));
                }
                let dk = DecapsulationKey::from_seed(&seed).expect("decapsulation key");

                let want_ek = decode(&want.ek);
                assert_eq!(
                    dk.encapsulation_key().bytes().as_slice(),
                    want_ek.as_slice(),
                    "tc{}: encapsulation key mismatch",
                    test.tc_id
                );
                let want_dk = decode(&want.dk);
                assert_eq!(
                    to_expanded(&dk),
                    want_dk,
                    "tc{}: expanded decapsulation key mismatch",
                    test.tc_id
                );
            }
        }
        assert!(tested > 0, "no ML-KEM-1024 ACVP keyGen test cases");
        eprintln!("ACVP ML-KEM-1024 keyGen: {tested} cases passed");
    }

    #[test]
    fn acvp_encap_decap_matches_nist_vectors() {
        // Coverage: the coverage run always sets MLKEM_ACVP_VECTORS_DIR, so the
        // skip arm (for environments where the NIST vectors are not available)
        // is never taken here.
        let Some(dir) = vectors_dir() else {
            //coverage:ignore start reason=defensively-unreachable
            eprintln!("MLKEM_ACVP_VECTORS_DIR not set; skipping ML-KEM ACVP encapDecap test");
            return;
            //coverage:ignore end
        };
        let suite = "ML-KEM-encapDecap-FIPS203";
        let prompt: PromptFile = load(&dir, suite, "prompt.json");
        let expected: ExpectedFile = load(&dir, suite, "expectedResults.json");

        // One counter per ACVP function, so a suite missing any of them fails.
        let mut encap = 0u32;
        let mut decap = 0u32;
        let mut decap_check = 0u32;
        let mut encap_check = 0u32;

        let groups = prompt.test_groups.iter().filter(|g| g.parameter_set == "ML-KEM-1024");
        for group in groups {
            for test in &group.tests {
                let want = expected_test(&expected, group.tg_id, test.tc_id);
                match group.function.as_str() {
                    "encapsulation" => {
                        let ek = EncapsulationKey::from_bytes(&decode(&test.ek))
                            .expect("encapsulation key");
                        let m: [u8; 32] = decode(&test.m).try_into().expect("32-byte m");
                        let (shared, ciphertext) = ek.encapsulate_deterministic(&m);
                        assert_eq!(ciphertext, decode(&want.c).as_slice(), "tc{}: ct", test.tc_id);
                        assert_eq!(*shared, decode(&want.k).as_slice(), "tc{}: K", test.tc_id);
                        encap += 1;
                    }
                    "decapsulation" => {
                        let dk = from_expanded(&decode(&test.dk)).expect("decapsulation key");
                        let shared = dk.decapsulate(&decode(&test.c)).expect("decapsulate");
                        assert_eq!(*shared, decode(&want.k).as_slice(), "tc{}: K", test.tc_id);
                        decap += 1;
                    }
                    "decapsulationKeyCheck" => {
                        let ok = from_expanded(&decode(&test.dk)).is_ok();
                        assert_eq!(ok, want.test_passed, "tc{}: dk check", test.tc_id);
                        decap_check += 1;
                    }
                    "encapsulationKeyCheck" => {
                        let ok = EncapsulationKey::from_bytes(&decode(&test.ek)).is_ok();
                        assert_eq!(ok, want.test_passed, "tc{}: ek check", test.tc_id);
                        encap_check += 1;
                    }
                    // Coverage: the NIST ML-KEM-1024 encapDecap suite only contains
                    // the four functions matched above; this guard fires solely if
                    // a future vector set introduces a new function tag.
                    //coverage:ignore reason=defensively-unreachable
                    other => panic!("unexpected ACVP function {other:?}"),
                }
            }
        }
        assert!(
            encap > 0 && decap > 0 && decap_check > 0 && encap_check > 0,
            "missing an ML-KEM-1024 encapDecap function (encap={encap} decap={decap} \
             decapCheck={decap_check} encapCheck={encap_check})"
        );
        eprintln!(
            "ACVP ML-KEM-1024 encapDecap: encap={encap} decap={decap} \
             decapKeyCheck={decap_check} encapKeyCheck={encap_check} passed"
        );
    }
}