1use crate::error::{QrllibError, Result};
32use sha3::{
33 Digest, Sha3_256, Sha3_512,
34 digest::{ExtendableOutput, Update, XofReader},
35};
36use shake::{Shake128, Shake256};
37use zeroize::{Zeroize, Zeroizing};
38
/// Length in bytes of a serialized decapsulation-key seed, `d || z`.
pub const MLKEM1024_SEED_SIZE: usize = 64;

/// Length in bytes of the shared secret produced by encapsulation and decapsulation.
pub const MLKEM1024_SHARED_KEY_SIZE: usize = 32;

/// Ciphertext length: `K` polynomials packed at 11 bits per coefficient (`u`) followed by one
/// polynomial packed at 5 bits per coefficient (`v`).
pub const MLKEM1024_CIPHERTEXT_SIZE: usize = K * ENCODING_SIZE_11 + ENCODING_SIZE_5;

/// Encapsulation-key length: `K` polynomials packed at 12 bits per coefficient (`t`) followed
/// by the 32-byte matrix seed `rho`.
pub const MLKEM1024_ENCAPSULATION_KEY_SIZE: usize = K * ENCODING_SIZE_12 + 32;

/// Number of coefficients in a ring element.
const N: usize = 256;

/// The prime modulus of the coefficient field.
const Q: u16 = 3329;

/// Module rank (number of ring elements per vector) for ML-KEM-1024.
const K: usize = 4;

// Bytes needed to pack `N` coefficients at 1, 5, 11 and 12 bits each.
const ENCODING_SIZE_1: usize = N / 8;
const ENCODING_SIZE_5: usize = N * 5 / 8;
const ENCODING_SIZE_11: usize = N * 11 / 8;
const ENCODING_SIZE_12: usize = N * 12 / 8;

/// Size of the encapsulated message `m`: one bit per coefficient.
const MESSAGE_SIZE: usize = ENCODING_SIZE_1;

// Compression widths (bits per coefficient) used for `v` and `u` respectively.
const D5: u8 = 5;
const D11: u8 = 11;

/// `ceil(Q / 2)`: the value a set message bit decompresses to.
const HALF_Q_ROUNDED_UP: u16 = (Q + 1) / 2;

/// SHAKE128 block size in bytes; `sample_ntt` squeezes one block at a time.
const SHAKE128_RATE: usize = 168;
81
82fn field_reduce_once(a: u16) -> u16 {
88 let x = a.wrapping_sub(Q);
89 x.wrapping_add(Q & (x >> 15).wrapping_neg())
91}
92
93fn field_add(a: u16, b: u16) -> u16 {
94 field_reduce_once(a.wrapping_add(b))
95}
96
97fn field_sub(a: u16, b: u16) -> u16 {
98 field_reduce_once(a.wrapping_sub(b).wrapping_add(Q))
99}
100
101const BARRETT_MULTIPLIER: u64 = 5039;
102const BARRETT_SHIFT: u32 = 24;
103const BARRETT_WIDE_MULTIPLIER: u64 = 1_290_167;
104const BARRETT_WIDE_SHIFT: u32 = 32;
105
106fn field_reduce(a: u32) -> u16 {
107 let quotient = ((a as u64 * BARRETT_MULTIPLIER) >> BARRETT_SHIFT) as u32;
108 field_reduce_once(a.wrapping_sub(quotient.wrapping_mul(Q as u32)) as u16)
109}
110
111fn field_reduce_wide(a: u32) -> u16 {
114 let quotient = ((a as u64 * BARRETT_WIDE_MULTIPLIER) >> BARRETT_WIDE_SHIFT) as u32;
115 field_reduce_once(a.wrapping_sub(quotient.wrapping_mul(Q as u32)) as u16)
116}
117
118fn field_mul(a: u16, b: u16) -> u16 {
119 field_reduce(a as u32 * b as u32)
120}
121
122fn field_mul_wide(a: u16, b: u16) -> u16 {
123 field_reduce_wide(a as u32 * b as u32)
124}
125
126fn field_mul_sub(a: u16, b: u16, c: u16) -> u16 {
127 let x = a as u32 * b.wrapping_sub(c).wrapping_add(Q) as u32;
128 field_reduce(x)
129}
130
131const COMPRESS1_LOWER: u32 = (Q as u32).div_ceil(4); const COMPRESS1_UPPER: u32 = (3 * Q as u32) / 4; fn compress1(x: u16) -> u8 {
135 let ux = x as u32;
136 let ge_lower = ((ux.wrapping_sub(COMPRESS1_LOWER)) >> 31) ^ 1;
137 let le_upper = ((COMPRESS1_UPPER.wrapping_sub(ux)) >> 31) ^ 1;
138 (ge_lower & le_upper) as u8
139}
140
141fn compress5(x: u16) -> u16 {
142 let dividend = (x as u32) << D5;
143 let mut quotient = ((dividend as u64 * BARRETT_MULTIPLIER) >> BARRETT_SHIFT) as u32;
144 let remainder = dividend.wrapping_sub(quotient.wrapping_mul(Q as u32));
145 quotient = quotient.wrapping_add(((Q as u32 / 2).wrapping_sub(remainder) >> 31) & 1);
146 quotient = quotient.wrapping_add(((Q as u32 + Q as u32 / 2).wrapping_sub(remainder) >> 31) & 1);
147 (quotient & 0x1f) as u16
148}
149
150fn compress11(x: u16) -> u16 {
151 let dividend = (x as u32) << D11;
152 let mut quotient = ((dividend as u64 * BARRETT_MULTIPLIER) >> BARRETT_SHIFT) as u32;
153 let remainder = dividend.wrapping_sub(quotient.wrapping_mul(Q as u32));
154 quotient = quotient.wrapping_add(((Q as u32 / 2).wrapping_sub(remainder) >> 31) & 1);
155 quotient = quotient.wrapping_add(((Q as u32 + Q as u32 / 2).wrapping_sub(remainder) >> 31) & 1);
156 (quotient & 0x7ff) as u16
157}
158
159fn decompress(y: u16, d: u8) -> u16 {
160 let dividend = (y as u32) * (Q as u32);
161 let mut quotient = dividend >> d;
162 quotient += (dividend >> (d - 1)) & 1;
163 quotient as u16
164}
165
166type RingElement = [u16; N];
172
173fn new_ring() -> RingElement {
174 [0u16; N]
175}
176
177fn ring_decode_and_decompress1(dst: &mut RingElement, src: &[u8]) {
178 for (i, slot) in dst.iter_mut().enumerate() {
179 let b = (src[i / 8] >> (i % 8)) & 1;
182 *slot = (b as u16) * HALF_Q_ROUNDED_UP;
183 }
184}
185
/// Unpacks `N` 5-bit fields from `src` and decompresses each with `d = 5` into `dst`.
///
/// Fields are packed least-significant bit first. Each group of 8 coefficients occupies
/// exactly 5 bytes (40 bits), and coefficient `k` of a group is bits `5k..5k + 5` of it.
/// `src` must hold at least `ENCODING_SIZE_5` bytes; a shorter slice panics on an
/// out-of-bounds index.
fn ring_decode_and_decompress5(dst: &mut RingElement, src: &[u8]) {
    let mut i = 0usize;
    let mut off = 0usize;
    while i < N {
        let b0 = src[off] as u16;
        let b1 = src[off + 1] as u16;
        let b2 = src[off + 2] as u16;
        let b3 = src[off + 3] as u16;
        let b4 = src[off + 4] as u16;

        // A field that straddles a byte boundary joins the high bits of one byte with the
        // low bits of the next; the trailing `& 0x1f` drops bits that belong to neighbours.
        dst[i] = decompress(b0 & 0x1f, D5);
        dst[i + 1] = decompress((b0 >> 5 | b1 << 3) & 0x1f, D5);
        dst[i + 2] = decompress((b1 >> 2) & 0x1f, D5);
        dst[i + 3] = decompress((b1 >> 7 | b2 << 1) & 0x1f, D5);
        dst[i + 4] = decompress((b2 >> 4 | b3 << 4) & 0x1f, D5);
        dst[i + 5] = decompress((b3 >> 1) & 0x1f, D5);
        dst[i + 6] = decompress((b3 >> 6 | b4 << 2) & 0x1f, D5);
        dst[i + 7] = decompress((b4 >> 3) & 0x1f, D5);

        i += 8;
        off += 5;
    }
}
209
/// Unpacks `N` 11-bit fields from `src` and decompresses each with `d = 11` into `dst`.
///
/// Fields are packed least-significant bit first. Each group of 8 coefficients occupies
/// exactly 11 bytes (88 bits), and coefficient `k` of a group is bits `11k..11k + 11` of it.
/// `src` must hold at least `ENCODING_SIZE_11` bytes; a shorter slice panics on an
/// out-of-bounds index.
fn ring_decode_and_decompress11(dst: &mut RingElement, src: &[u8]) {
    let mut i = 0usize;
    let mut off = 0usize;
    while i < N {
        // Widened to u32 because some fields span three bytes (up to `<< 10`).
        let b0 = src[off] as u32;
        let b1 = src[off + 1] as u32;
        let b2 = src[off + 2] as u32;
        let b3 = src[off + 3] as u32;
        let b4 = src[off + 4] as u32;
        let b5 = src[off + 5] as u32;
        let b6 = src[off + 6] as u32;
        let b7 = src[off + 7] as u32;
        let b8 = src[off + 8] as u32;
        let b9 = src[off + 9] as u32;
        let b10 = src[off + 10] as u32;

        // The trailing `& 0x7ff` drops bits that belong to the neighbouring field.
        dst[i] = decompress(((b0 | b1 << 8) & 0x7ff) as u16, D11);
        dst[i + 1] = decompress(((b1 >> 3 | b2 << 5) & 0x7ff) as u16, D11);
        dst[i + 2] = decompress(((b2 >> 6 | b3 << 2 | b4 << 10) & 0x7ff) as u16, D11);
        dst[i + 3] = decompress(((b4 >> 1 | b5 << 7) & 0x7ff) as u16, D11);
        dst[i + 4] = decompress(((b5 >> 4 | b6 << 4) & 0x7ff) as u16, D11);
        dst[i + 5] = decompress(((b6 >> 7 | b7 << 1 | b8 << 9) & 0x7ff) as u16, D11);
        dst[i + 6] = decompress(((b8 >> 2 | b9 << 6) & 0x7ff) as u16, D11);
        dst[i + 7] = decompress(((b9 >> 5 | b10 << 3) & 0x7ff) as u16, D11);

        i += 8;
        off += 11;
    }
}
239
240fn ring_compress_and_encode1(dst: &mut [u8], src: &RingElement) {
241 let mut i = 0usize;
242 let mut off = 0usize;
243 while i < N {
244 let c0 = compress1(src[i]);
245 let c1 = compress1(src[i + 1]);
246 let c2 = compress1(src[i + 2]);
247 let c3 = compress1(src[i + 3]);
248 let c4 = compress1(src[i + 4]);
249 let c5 = compress1(src[i + 5]);
250 let c6 = compress1(src[i + 6]);
251 let c7 = compress1(src[i + 7]);
252
253 dst[off] = c0 | c1 << 1 | c2 << 2 | c3 << 3 | c4 << 4 | c5 << 5 | c6 << 6 | c7 << 7;
254
255 i += 8;
256 off += 1;
257 }
258}
259
/// Compresses each coefficient of `src` to 5 bits and packs them LSB-first into `dst`.
///
/// Each group of 8 coefficients fills exactly 5 bytes. `dst` must hold at least
/// `ENCODING_SIZE_5` bytes. The `as u8` casts deliberately truncate: higher bits of a field
/// are emitted by the next byte's assignment.
fn ring_compress_and_encode5(dst: &mut [u8], src: &RingElement) {
    let mut i = 0usize;
    let mut off = 0usize;
    while i < N {
        let c0 = compress5(src[i]);
        let c1 = compress5(src[i + 1]);
        let c2 = compress5(src[i + 2]);
        let c3 = compress5(src[i + 3]);
        let c4 = compress5(src[i + 4]);
        let c5 = compress5(src[i + 5]);
        let c6 = compress5(src[i + 6]);
        let c7 = compress5(src[i + 7]);

        dst[off] = (c0 | c1 << 5) as u8;
        dst[off + 1] = (c1 >> 3 | c2 << 2 | c3 << 7) as u8;
        dst[off + 2] = (c3 >> 1 | c4 << 4) as u8;
        dst[off + 3] = (c4 >> 4 | c5 << 1 | c6 << 6) as u8;
        dst[off + 4] = (c6 >> 2 | c7 << 3) as u8;

        i += 8;
        off += 5;
    }
}
283
/// Compresses each coefficient of `src` to 11 bits and packs them LSB-first into `dst`.
///
/// Each group of 8 coefficients fills exactly 11 bytes. `dst` must hold at least
/// `ENCODING_SIZE_11` bytes. The `as u8` casts deliberately truncate: the remaining bits of a
/// field are emitted by the following byte assignments.
fn ring_compress_and_encode11(dst: &mut [u8], src: &RingElement) {
    let mut i = 0usize;
    let mut off = 0usize;
    while i < N {
        // Widened to u32 so shifts like `c2 >> 10` and `c5 << 7` keep all needed bits.
        let c0 = compress11(src[i]) as u32;
        let c1 = compress11(src[i + 1]) as u32;
        let c2 = compress11(src[i + 2]) as u32;
        let c3 = compress11(src[i + 3]) as u32;
        let c4 = compress11(src[i + 4]) as u32;
        let c5 = compress11(src[i + 5]) as u32;
        let c6 = compress11(src[i + 6]) as u32;
        let c7 = compress11(src[i + 7]) as u32;

        dst[off] = c0 as u8;
        dst[off + 1] = (c0 >> 8 | c1 << 3) as u8;
        dst[off + 2] = (c1 >> 5 | c2 << 6) as u8;
        dst[off + 3] = (c2 >> 2) as u8;
        dst[off + 4] = (c2 >> 10 | c3 << 1) as u8;
        dst[off + 5] = (c3 >> 7 | c4 << 4) as u8;
        dst[off + 6] = (c4 >> 4 | c5 << 7) as u8;
        dst[off + 7] = (c5 >> 1) as u8;
        dst[off + 8] = (c5 >> 9 | c6 << 2) as u8;
        dst[off + 9] = (c6 >> 6 | c7 << 5) as u8;
        dst[off + 10] = (c7 >> 3) as u8;

        i += 8;
        off += 11;
    }
}
313
314fn byte_encode12(dst: &mut [u8], p: &RingElement) {
315 let mut i = 0usize;
316 let mut off = 0usize;
317 while i < N {
318 let x = (p[i] as u32) | (p[i + 1] as u32) << 12;
319 dst[off] = x as u8;
320 dst[off + 1] = (x >> 8) as u8;
321 dst[off + 2] = (x >> 16) as u8;
322 i += 2;
323 off += 3;
324 }
325}
326
327fn byte_decode12(dst: &mut RingElement, src: &[u8]) -> Result<()> {
328 let mut i = 0usize;
329 let mut off = 0usize;
330 while i < N {
331 let x = (src[off] as u32) | (src[off + 1] as u32) << 8 | (src[off + 2] as u32) << 16;
332 let c0 = (x & 0x0fff) as u16;
333 let c1 = (x >> 12) as u16;
334 if c0 >= Q || c1 >= Q {
335 return Err(QrllibError::InvalidMlKemEncoding);
336 }
337 dst[i] = c0;
338 dst[i + 1] = c1;
339 i += 2;
340 off += 3;
341 }
342 Ok(())
343}
344
/// Samples an NTT-domain polynomial from SHAKE128(`rho || j_index || i_index`) by rejection
/// sampling.
///
/// Each 3 squeezed bytes yield two 12-bit candidates `x0` and `x1`; candidates `>= Q` are
/// discarded. The XOF is read one `SHAKE128_RATE`-byte block at a time. Since that rate (168)
/// is a multiple of 3, a candidate triple never straddles two blocks.
fn sample_ntt(dst: &mut RingElement, rho: &[u8; 32], j_index: u8, i_index: u8) {
    let mut ctx = Shake128::default();
    ctx.update(rho);
    // The column index is absorbed before the row index.
    ctx.update(&[j_index, i_index]);
    let mut reader = ctx.finalize_xof();

    let mut j = 0usize;
    let mut buf = [0u8; SHAKE128_RATE];
    // Start "exhausted" so the first iteration squeezes a fresh block.
    let mut off = buf.len();

    loop {
        if off >= buf.len() {
            reader.read(&mut buf);
            off = 0;
        }

        let x0 = (buf[off] as u16) | (((buf[off + 1] & 0x0f) as u16) << 8);
        let x1 = ((buf[off + 1] >> 4) as u16) | ((buf[off + 2] as u16) << 4);
        off += 3;

        if x0 < Q {
            dst[j] = x0;
            j += 1;
        }
        // Stop as soon as N coefficients are accepted so `x1` is never written past the end.
        if j >= N {
            break;
        }
        if x1 < Q {
            dst[j] = x1;
            j += 1;
        }
        if j >= N {
            break;
        }
    }
}
386
387fn sample_poly_cbd(dst: &mut RingElement, sigma: &[u8; 32], counter: u8) {
389 let mut prf = Shake256::default();
390 prf.update(sigma);
391 prf.update(&[counter]);
392 let mut reader = prf.finalize_xof();
393 let mut buf = [0u8; 128];
394 reader.read(&mut buf);
395
396 let mut i = 0usize;
397 let mut j = 0usize;
398 while i < buf.len() {
399 let t = u32::from_le_bytes([buf[i], buf[i + 1], buf[i + 2], buf[i + 3]]);
400 let d = (t & 0x5555_5555) + ((t >> 1) & 0x5555_5555);
403
404 dst[j] = cbd2(d, d >> 2);
405 dst[j + 1] = cbd2(d >> 4, d >> 6);
406 dst[j + 2] = cbd2(d >> 8, d >> 10);
407 dst[j + 3] = cbd2(d >> 12, d >> 14);
408 dst[j + 4] = cbd2(d >> 16, d >> 18);
409 dst[j + 5] = cbd2(d >> 20, d >> 22);
410 dst[j + 6] = cbd2(d >> 24, d >> 26);
411 dst[j + 7] = cbd2(d >> 28, d >> 30);
412
413 i += 4;
414 j += 8;
415 }
416}
417
418fn cbd2(a: u32, b: u32) -> u16 {
419 field_reduce_once(Q.wrapping_add((a & 0x3) as u16).wrapping_sub((b & 0x3) as u16))
420}
421
/// NTT twiddle factors: `ZETAS[i] = 17^BitRev7(i) mod Q` (the FIPS 203 Appendix A table).
///
/// `ntt` consumes entries 1..=127 in increasing order. `inverse_ntt` consumes 127 down to 2
/// and folds `ZETAS[1]` into its final-layer constant.
#[rustfmt::skip]
const ZETAS: [u16; 128] = [
    1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746,
    296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
    289, 331, 3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
    2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
    17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
    1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594,
    2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
];
437
/// In-place forward NTT of `f`: 7 layers of Cooley–Tukey butterflies, `length` 128 down
/// to 2. Every output coefficient is fully reduced to `[0, Q)`.
///
/// Reduction is lazy. Only the twiddle product `t` is reduced inside the butterflies, so each
/// layer grows coefficients by less than `Q`. Assuming inputs are `< Q`, they stay below
/// `8 * Q` (well under 2^16) after 7 layers. A single Barrett pass at the end restores the
/// canonical range.
fn ntt(f: &mut RingElement) {
    let mut i = 1usize;
    let mut length = 128usize;
    while length >= 2 {
        let mut start = 0usize;
        while start < 256 {
            let zeta = ZETAS[i];
            i += 1;
            for j in start..start + length {
                // `zeta < Q` and `f[j + length] < 8Q`, so the product is < 8Q^2 < 2^32 and
                // the wide Barrett reduction applies.
                let t = field_mul_wide(zeta, f[j + length]);
                let a = f[j];
                f[j] = a.wrapping_add(t);
                // Adding `Q` before subtracting `t < Q` keeps the difference non-negative.
                f[j + length] = a.wrapping_add(Q).wrapping_sub(t);
            }
            start += 2 * length;
        }
        length /= 2;
    }
    for coeff in f.iter_mut() {
        *coeff = field_reduce(*coeff as u32);
    }
}
463
/// `128^-1 mod Q` (since `128 * 3303 = 127 * Q + 1`): the inverse-NTT scaling factor.
const INVERSE_NTT_SCALE: u16 = 3303;
/// `ZETAS[1] * INVERSE_NTT_SCALE mod Q` (`1729 * 3303 ≡ 1652`). This is the last layer's
/// twiddle with the final scaling folded in, so the odd half needs only one multiplication.
const INVERSE_NTT_FINAL_ZETA: u16 = 1652;

/// In-place inverse NTT of `f` (Gentleman–Sande butterflies, `length` 2 up to 128),
/// including the `128^-1` scaling. Inputs must be fully reduced (`< Q`); outputs are too.
///
/// The first six layers walk `ZETAS` from index 127 down to 2. The seventh layer
/// (`length = 128`) is peeled out of the loop so the scaling can be fused into it.
fn inverse_ntt(f: &mut RingElement) {
    let mut i = 127usize;
    let mut length = 2usize;
    while length < 128 {
        let mut start = 0usize;
        while start < 256 {
            let zeta = ZETAS[i];
            i -= 1;
            for j in start..start + length {
                let t = f[j];
                f[j] = field_add(t, f[j + length]);
                f[j + length] = field_mul_sub(zeta, f[j + length], t);
            }
            start += 2 * length;
        }
        length *= 2;
    }

    // Final layer with the scaling applied: the even half is multiplied by `128^-1`, and the
    // odd half by `ZETAS[1] * 128^-1` via the pre-multiplied constant.
    for j in 0..128 {
        let t = f[j];
        f[j] = field_mul(field_add(t, f[j + 128]), INVERSE_NTT_SCALE);
        f[j + 128] = field_mul_sub(INVERSE_NTT_FINAL_ZETA, f[j + 128], t);
    }
}
493
/// Base-case multiplication constants: `GAMMAS[k] = 17^(2 * BitRev7(k) + 1) mod Q` (FIPS 203
/// Appendix A).
///
/// Entries come in `(gamma, Q - gamma)` pairs: `GAMMAS[2m] = ZETAS[64 + m]` and
/// `GAMMAS[2m + 1] = Q - GAMMAS[2m]`. `ntt_mul_add4` uses `GAMMAS[i / 2]` for the coefficient
/// pair at `(i, i + 1)`.
#[rustfmt::skip]
const GAMMAS: [u16; 128] = [
    17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229,
    1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279,
    1703, 1626, 1651, 1678, 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
    939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 268, 3061, 641, 2688,
    1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 375, 2954, 2549, 780, 2090, 1239, 1645, 1684,
    1063, 2266, 319, 3010, 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
    2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 2775, 554, 886, 2443,
    1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
];
505
506#[allow(clippy::too_many_arguments)]
511fn ntt_mul_add4(
512 acc: &mut RingElement,
513 a0: &RingElement,
514 b0: &RingElement,
515 a1: &RingElement,
516 b1: &RingElement,
517 a2: &RingElement,
518 b2: &RingElement,
519 a3: &RingElement,
520 b3: &RingElement,
521) {
522 let mut i = 0usize;
523 while i < N {
524 let gamma = GAMMAS[i / 2] as u32;
525
526 let (a00, a01) = (a0[i], a0[i + 1]);
527 let (b00, b01) = (b0[i], b0[i + 1]);
528 let mut acc0 = acc[i] as u32;
529 acc0 += (a00 as u32) * (b00 as u32) + (field_mul(a01, b01) as u32) * gamma;
530 let mut acc1 = acc[i + 1] as u32;
531 acc1 += (a00 as u32) * (b01 as u32) + (a01 as u32) * (b00 as u32);
532
533 let (a10, a11) = (a1[i], a1[i + 1]);
534 let (b10, b11) = (b1[i], b1[i + 1]);
535 acc0 += (a10 as u32) * (b10 as u32) + (field_mul(a11, b11) as u32) * gamma;
536 acc1 += (a10 as u32) * (b11 as u32) + (a11 as u32) * (b10 as u32);
537
538 let (a20, a21) = (a2[i], a2[i + 1]);
539 let (b20, b21) = (b2[i], b2[i + 1]);
540 acc0 += (a20 as u32) * (b20 as u32) + (field_mul(a21, b21) as u32) * gamma;
541 acc1 += (a20 as u32) * (b21 as u32) + (a21 as u32) * (b20 as u32);
542
543 let (a30, a31) = (a3[i], a3[i + 1]);
544 let (b30, b31) = (b3[i], b3[i + 1]);
545 acc0 += (a30 as u32) * (b30 as u32) + (field_mul(a31, b31) as u32) * gamma;
546 acc1 += (a30 as u32) * (b31 as u32) + (a31 as u32) * (b30 as u32);
547
548 acc[i] = field_reduce_wide(acc0);
549 acc[i + 1] = field_reduce_wide(acc1);
550 i += 2;
551 }
552}
553
554fn poly_add_assign(a: &mut RingElement, b: &RingElement) {
555 for i in 0..N {
556 a[i] = field_add(a[i], b[i]);
557 }
558}
559
560fn poly_sub_assign(a: &mut RingElement, b: &RingElement) {
561 for i in 0..N {
562 a[i] = field_sub(a[i], b[i]);
563 }
564}
565
566fn sha3_512(input: &[u8]) -> [u8; 64] {
571 let out = Sha3_512::digest(input);
572 let mut r = [0u8; 64];
573 r.copy_from_slice(&out);
574 r
575}
576
577fn sha3_256(input: &[u8]) -> [u8; 32] {
578 let out = Sha3_256::digest(input);
579 let mut r = [0u8; 32];
580 r.copy_from_slice(&out);
581 r
582}
583
584fn shake256(inputs: &[&[u8]], out: &mut [u8]) {
585 let mut h = Shake256::default();
586 for input in inputs {
587 h.update(input);
588 }
589 let mut reader = h.finalize_xof();
590 reader.read(out);
591}
592
/// Compares two byte slices in time that depends only on their lengths.
///
/// Returns `0xFF` when `a == b` and `0x00` otherwise; the result is suitable as the `mask`
/// argument of `ct_select`. Slices of different lengths always compare unequal. Lengths are
/// treated as public, so that check may branch.
fn ct_eq_mask(a: &[u8], b: &[u8]) -> u8 {
    if a.len() != b.len() {
        // Without this check, `zip` would compare only the common prefix in release builds
        // (where a `debug_assert!` is compiled out) and could report a false match.
        return 0;
    }
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    // Bit 31 of `d | -d` is set exactly when `d != 0`.
    let nonzero = (u32::from(diff) | u32::from(diff).wrapping_neg()) >> 31;
    ((nonzero ^ 1) as u8).wrapping_neg()
}
609
/// Branch-free conditional copy: for each paired byte, takes the bits of `src` where `mask`
/// is set and keeps the bits of `dst` elsewhere. With `mask` from `ct_eq_mask` this is
/// "`dst = src` if equal". Only the common prefix of the two slices is touched.
fn ct_select(mask: u8, dst: &mut [u8], src: &[u8]) {
    for (d, &s) in dst.iter_mut().zip(src) {
        *d ^= (*d ^ s) & mask;
    }
}
617
/// Public K-PKE encryption-key material, with the matrix `Â` pre-expanded from `rho`.
#[derive(Clone)]
struct EncryptionKey {
    // Public vector `t̂`, kept in the NTT domain (computed as `Â ∘ ŝ + ê` in `pke_key_gen`).
    t: [RingElement; K],
    // Matrix `Â`, row-major: `a[i * K + j]` is sampled by `sample_ntt(rho, j, i)`.
    a: [RingElement; K * K],
    // Public seed from which `a` is expanded.
    rho: [u8; 32],
    // Cached serialization: the `K` 12-bit-packed polynomials of `t`, followed by `rho`.
    encoded: [u8; MLKEM1024_ENCAPSULATION_KEY_SIZE], }

impl EncryptionKey {
    /// All-zero placeholder, filled in by `pke_key_gen` or `EncapsulationKey::from_bytes`.
    fn zeroed() -> Self {
        Self {
            t: [new_ring(); K],
            a: [new_ring(); K * K],
            rho: [0u8; 32],
            encoded: [0u8; MLKEM1024_ENCAPSULATION_KEY_SIZE],
        }
    }
}
640
/// Secret K-PKE decryption key.
#[derive(Clone)]
struct DecryptionKey {
    // Secret vector `ŝ`, in the NTT domain (`pke_key_gen` applies `ntt` after sampling).
    s: [RingElement; K], }

/// An ML-KEM-1024 decapsulation (private) key.
///
/// Secret material (`d`, `z`, `ŝ`) is wiped by `zeroize`, which also runs on drop.
pub struct DecapsulationKey {
    // Seed for K-PKE key generation.
    d: [u8; 32],
    // Implicit-rejection secret: the key returned for an invalid ciphertext is
    // SHAKE256(z || c).
    z: [u8; 32],
    // SHA3-256 of the encoded encapsulation key.
    h: [u8; 32], encryption_key: EncryptionKey,
    decryption_key: DecryptionKey,
}

/// An ML-KEM-1024 encapsulation (public) key.
#[derive(Clone)]
pub struct EncapsulationKey {
    // SHA3-256 of the encoded key; hashed together with the message during encapsulation.
    h: [u8; 32], encryption_key: EncryptionKey,
}
665
666impl DecapsulationKey {
667 pub fn generate() -> Result<Self> {
669 let mut d = [0u8; 32];
670 let mut z = [0u8; 32];
671 getrandom::getrandom(&mut d)?;
672 getrandom::getrandom(&mut z)?;
673 let key = Self::from_d_z(&d, &z);
674 d.zeroize();
675 z.zeroize();
676 Ok(key)
677 }
678
679 pub fn from_seed(seed: &[u8]) -> Result<Self> {
682 if seed.len() != MLKEM1024_SEED_SIZE {
683 return Err(QrllibError::InvalidMlKemSeedSize(seed.len(), MLKEM1024_SEED_SIZE));
684 }
685 let mut d = [0u8; 32];
686 let mut z = [0u8; 32];
687 d.copy_from_slice(&seed[..32]);
688 z.copy_from_slice(&seed[32..]);
689 let key = Self::from_d_z(&d, &z);
690 d.zeroize();
691 z.zeroize();
692 Ok(key)
693 }
694
695 fn from_d_z(d: &[u8; 32], z: &[u8; 32]) -> Self {
696 let mut dk = DecapsulationKey {
697 d: *d,
698 z: *z,
699 h: [0u8; 32],
700 encryption_key: EncryptionKey::zeroed(),
701 decryption_key: DecryptionKey { s: [new_ring(); K] },
702 };
703 pke_key_gen(&mut dk, d);
704 dk.h = sha3_256(&dk.encryption_key.encoded);
705 dk
706 }
707
708 pub fn decapsulate(
713 &self,
714 ciphertext: &[u8],
715 ) -> Result<Zeroizing<[u8; MLKEM1024_SHARED_KEY_SIZE]>> {
716 if ciphertext.len() != MLKEM1024_CIPHERTEXT_SIZE {
717 return Err(QrllibError::InvalidMlKemCiphertextSize(
718 ciphertext.len(),
719 MLKEM1024_CIPHERTEXT_SIZE,
720 ));
721 }
722 let mut ct = [0u8; MLKEM1024_CIPHERTEXT_SIZE];
723 ct.copy_from_slice(ciphertext);
724 Ok(decapsulate(self, &ct))
725 }
726
727 pub fn encapsulation_key(&self) -> EncapsulationKey {
729 EncapsulationKey { h: self.h, encryption_key: self.encryption_key.clone() }
730 }
731
732 pub fn bytes(&self) -> Zeroizing<[u8; MLKEM1024_SEED_SIZE]> {
734 let mut b = [0u8; MLKEM1024_SEED_SIZE];
735 b[..32].copy_from_slice(&self.d);
736 b[32..].copy_from_slice(&self.z);
737 Zeroizing::new(b)
738 }
739
740 pub fn zeroize(&mut self) {
745 self.d.zeroize();
746 self.z.zeroize();
747 for poly in &mut self.decryption_key.s {
748 poly.zeroize();
749 }
750 }
751}
752
753impl Drop for DecapsulationKey {
754 fn drop(&mut self) {
755 self.zeroize();
756 }
757}
758
759impl EncapsulationKey {
760 pub fn from_bytes(ek_bytes: &[u8]) -> Result<Self> {
764 if ek_bytes.len() != MLKEM1024_ENCAPSULATION_KEY_SIZE {
765 return Err(QrllibError::InvalidMlKemEncapsulationKeySize(
766 ek_bytes.len(),
767 MLKEM1024_ENCAPSULATION_KEY_SIZE,
768 ));
769 }
770
771 let mut ek =
772 EncapsulationKey { h: sha3_256(ek_bytes), encryption_key: EncryptionKey::zeroed() };
773 ek.encryption_key.encoded.copy_from_slice(ek_bytes);
774
775 let mut offset = 0usize;
776 for i in 0..K {
777 byte_decode12(
778 &mut ek.encryption_key.t[i],
779 &ek_bytes[offset..offset + ENCODING_SIZE_12],
780 )?;
781 offset += ENCODING_SIZE_12;
782 }
783 ek.encryption_key.rho.copy_from_slice(&ek_bytes[offset..offset + 32]);
784
785 let rho = ek.encryption_key.rho;
786 for i in 0..K {
787 for j in 0..K {
788 sample_ntt(&mut ek.encryption_key.a[i * K + j], &rho, j as u8, i as u8);
789 }
790 }
791
792 Ok(ek)
793 }
794
795 pub fn encapsulate(
798 &self,
799 ) -> Result<(Zeroizing<[u8; MLKEM1024_SHARED_KEY_SIZE]>, [u8; MLKEM1024_CIPHERTEXT_SIZE])> {
800 let mut m = [0u8; MESSAGE_SIZE];
801 getrandom::getrandom(&mut m)?;
802 let (shared_key, ciphertext) = encapsulate_to(&self.encryption_key, &self.h, &m);
803 m.zeroize();
804 Ok((shared_key, ciphertext))
805 }
806
807 pub fn encapsulate_deterministic(
815 &self,
816 m: &[u8; MESSAGE_SIZE],
817 ) -> (Zeroizing<[u8; MLKEM1024_SHARED_KEY_SIZE]>, [u8; MLKEM1024_CIPHERTEXT_SIZE]) {
818 encapsulate_to(&self.encryption_key, &self.h, m)
819 }
820
821 pub fn bytes(&self) -> [u8; MLKEM1024_ENCAPSULATION_KEY_SIZE] {
823 self.encryption_key.encoded
824 }
825}
826
/// K-PKE key generation from the 32-byte seed `d`. It fills `dk.encryption_key`
/// (`rho`, `Â`, `t̂` and its encoding) and `dk.decryption_key.s`.
///
/// `(rho, sigma) = SHA3-512(d || K)`. `Â` is expanded from `rho`. `ŝ` and `ê` are sampled
/// from `sigma` with PRF counters `0..K` and `K..2K`, then `t̂ = Â ∘ ŝ + ê`. Intermediate
/// secrets are wiped before returning.
fn pke_key_gen(dk: &mut DecapsulationKey, d: &[u8; 32]) {
    let mut g_input = [0u8; 33];
    g_input[..32].copy_from_slice(d);
    // Domain separation: the module rank is appended to the seed.
    g_input[32] = K as u8;
    let mut g = sha3_512(&g_input);
    let mut rho = [0u8; 32];
    let mut sigma = [0u8; 32];
    rho.copy_from_slice(&g[..32]);
    sigma.copy_from_slice(&g[32..]);

    dk.encryption_key.rho = rho;
    // `rho` occupies the tail of the encoded key, after the `K` packed polynomials.
    dk.encryption_key.encoded[K * ENCODING_SIZE_12..].copy_from_slice(&rho);

    for i in 0..K {
        for j in 0..K {
            sample_ntt(&mut dk.encryption_key.a[i * K + j], &rho, j as u8, i as u8);
        }
    }

    // The counter must advance across both loops: `s` uses 0..K and `e` uses K..2K.
    let mut counter = 0u8;
    for i in 0..K {
        sample_poly_cbd(&mut dk.decryption_key.s[i], &sigma, counter);
        ntt(&mut dk.decryption_key.s[i]);
        counter += 1;
    }

    for i in 0..K {
        // Row i of `Â` dotted with `ŝ`.
        let mut acc = new_ring();
        ntt_mul_add4(
            &mut acc,
            &dk.encryption_key.a[i * K],
            &dk.decryption_key.s[0],
            &dk.encryption_key.a[i * K + 1],
            &dk.decryption_key.s[1],
            &dk.encryption_key.a[i * K + 2],
            &dk.decryption_key.s[2],
            &dk.encryption_key.a[i * K + 3],
            &dk.decryption_key.s[3],
        );

        let mut e = new_ring();
        sample_poly_cbd(&mut e, &sigma, counter);
        ntt(&mut e);
        counter += 1;
        poly_add_assign(&mut acc, &e);
        e.zeroize();
        dk.encryption_key.t[i] = acc;
        byte_encode12(
            &mut dk.encryption_key.encoded[i * ENCODING_SIZE_12..(i + 1) * ENCODING_SIZE_12],
            &dk.encryption_key.t[i],
        );
    }

    // Wipe the seed expansion; `sigma` in particular determines the secret `s`.
    g_input.zeroize();
    g.zeroize();
    sigma.zeroize();
    rho.zeroize();
}
893
894fn pke_encrypt(
895 dst: &mut [u8; MLKEM1024_CIPHERTEXT_SIZE],
896 ek: &EncryptionKey,
897 m: &[u8; MESSAGE_SIZE],
898 r: &[u8; 32],
899) {
900 let mut counter = 0u8;
901 let mut y = [new_ring(); K];
902 for poly in &mut y {
903 sample_poly_cbd(poly, r, counter);
904 ntt(poly);
905 counter += 1;
906 }
907
908 let mut off = 0usize;
909 for i in 0..K {
910 let mut acc = new_ring();
911 ntt_mul_add4(
914 &mut acc,
915 &ek.a[i],
916 &y[0],
917 &ek.a[K + i],
918 &y[1],
919 &ek.a[2 * K + i],
920 &y[2],
921 &ek.a[3 * K + i],
922 &y[3],
923 );
924 inverse_ntt(&mut acc);
925
926 let mut e1 = new_ring();
927 sample_poly_cbd(&mut e1, r, counter);
928 counter += 1;
929 poly_add_assign(&mut acc, &e1);
930 e1.zeroize(); ring_compress_and_encode11(&mut dst[off..off + ENCODING_SIZE_11], &acc);
933 off += ENCODING_SIZE_11;
934 }
935
936 let mut e2 = new_ring();
937 sample_poly_cbd(&mut e2, r, counter);
938
939 let mut mu = new_ring();
940 ring_decode_and_decompress1(&mut mu, m);
941
942 let mut v = new_ring();
943 ntt_mul_add4(&mut v, &ek.t[0], &y[0], &ek.t[1], &y[1], &ek.t[2], &y[2], &ek.t[3], &y[3]);
944 inverse_ntt(&mut v);
945 poly_add_assign(&mut v, &e2);
946 poly_add_assign(&mut v, &mu);
947
948 ring_compress_and_encode5(&mut dst[off..off + ENCODING_SIZE_5], &v);
949
950 for poly in &mut y {
955 poly.zeroize();
956 }
957 e2.zeroize();
958 mu.zeroize();
959 v.zeroize();
960}
961
962fn pke_decrypt(
963 dst: &mut [u8; MESSAGE_SIZE],
964 dk: &DecapsulationKey,
965 c: &[u8; MLKEM1024_CIPHERTEXT_SIZE],
966) {
967 let mut u = [new_ring(); K];
968 let mut off = 0usize;
969 for poly in &mut u {
970 ring_decode_and_decompress11(poly, &c[off..off + ENCODING_SIZE_11]);
971 off += ENCODING_SIZE_11;
972 ntt(poly);
973 }
974
975 let mut v = new_ring();
976 ring_decode_and_decompress5(&mut v, &c[off..off + ENCODING_SIZE_5]);
977
978 let s = &dk.decryption_key.s;
979 let mut acc = new_ring();
980 ntt_mul_add4(&mut acc, &s[0], &u[0], &s[1], &u[1], &s[2], &u[2], &s[3], &u[3]);
981 inverse_ntt(&mut acc);
982
983 poly_sub_assign(&mut v, &acc);
984 ring_compress_and_encode1(dst, &v);
985
986 acc.zeroize();
991 v.zeroize();
992}
993
994fn encapsulate_to(
999 ek: &EncryptionKey,
1000 ek_h: &[u8; 32],
1001 m: &[u8; MESSAGE_SIZE],
1002) -> (Zeroizing<[u8; MLKEM1024_SHARED_KEY_SIZE]>, [u8; MLKEM1024_CIPHERTEXT_SIZE]) {
1003 let mut g_input = [0u8; MESSAGE_SIZE + 32];
1004 g_input[..MESSAGE_SIZE].copy_from_slice(m);
1005 g_input[MESSAGE_SIZE..].copy_from_slice(ek_h);
1006 let mut g = sha3_512(&g_input);
1007
1008 let mut shared_key = [0u8; MLKEM1024_SHARED_KEY_SIZE];
1009 shared_key.copy_from_slice(&g[..MLKEM1024_SHARED_KEY_SIZE]);
1010 let mut r = [0u8; 32];
1011 r.copy_from_slice(&g[MLKEM1024_SHARED_KEY_SIZE..]);
1012
1013 let mut ciphertext = [0u8; MLKEM1024_CIPHERTEXT_SIZE];
1014 pke_encrypt(&mut ciphertext, ek, m, &r);
1015
1016 g_input.zeroize();
1019 g.zeroize();
1020 r.zeroize();
1021
1022 (Zeroizing::new(shared_key), ciphertext)
1023}
1024
1025fn decapsulate(
1026 dk: &DecapsulationKey,
1027 ct: &[u8; MLKEM1024_CIPHERTEXT_SIZE],
1028) -> Zeroizing<[u8; MLKEM1024_SHARED_KEY_SIZE]> {
1029 let mut m = [0u8; MESSAGE_SIZE];
1030 pke_decrypt(&mut m, dk, ct);
1031
1032 let mut g_input = [0u8; MESSAGE_SIZE + 32];
1033 g_input[..MESSAGE_SIZE].copy_from_slice(&m);
1034 g_input[MESSAGE_SIZE..].copy_from_slice(&dk.h);
1035 let mut g = sha3_512(&g_input);
1036 let mut r = [0u8; 32];
1037 r.copy_from_slice(&g[MLKEM1024_SHARED_KEY_SIZE..]);
1038
1039 let mut k_out = [0u8; MLKEM1024_SHARED_KEY_SIZE];
1042 shake256(&[&dk.z, ct.as_slice()], &mut k_out);
1043
1044 let mut c = [0u8; MLKEM1024_CIPHERTEXT_SIZE];
1045 pke_encrypt(&mut c, &dk.encryption_key, &m, &r);
1046
1047 let matches = ct_eq_mask(ct.as_slice(), &c);
1051 ct_select(matches, &mut k_out, &g[..MLKEM1024_SHARED_KEY_SIZE]);
1052
1053 m.zeroize();
1054 g_input.zeroize();
1055 g.zeroize();
1056 r.zeroize();
1057
1058 Zeroizing::new(k_out)
1059}
1060
#[cfg(test)]
mod tests {
    use super::*;

    /// A 64-byte seed made of one repeated byte, for deterministic key derivation.
    fn seed(byte: u8) -> [u8; MLKEM1024_SEED_SIZE] {
        [byte; MLKEM1024_SEED_SIZE]
    }

    #[test]
    fn sizes_match_fips_203_ml_kem_1024() {
        assert_eq!(MLKEM1024_SEED_SIZE, 64);
        assert_eq!(MLKEM1024_SHARED_KEY_SIZE, 32);
        assert_eq!(MLKEM1024_CIPHERTEXT_SIZE, 1568);
        assert_eq!(MLKEM1024_ENCAPSULATION_KEY_SIZE, 1568);
    }

    #[test]
    fn encapsulate_then_decapsulate_recovers_shared_secret() {
        let dk = DecapsulationKey::from_seed(&seed(0x42)).expect("decap key");
        let ek = dk.encapsulation_key();
        let (shared_a, ciphertext) = ek.encapsulate().expect("encapsulate");
        let shared_b = dk.decapsulate(&ciphertext).expect("decapsulate");
        assert_eq!(*shared_a, *shared_b);
        assert_eq!(ciphertext.len(), MLKEM1024_CIPHERTEXT_SIZE);
    }

    #[test]
    fn generated_decapsulation_key_round_trips() {
        // Uses OS randomness; only the round-trip property can be asserted.
        let dk = DecapsulationKey::generate().expect("generated decap key");
        let ek = dk.encapsulation_key();
        let (shared_a, ciphertext) = ek.encapsulate().expect("encapsulate");
        let shared_b = dk.decapsulate(&ciphertext).expect("decapsulate");
        assert_eq!(*shared_a, *shared_b);
    }

    #[test]
    fn encapsulation_key_round_trips_through_bytes() {
        let dk = DecapsulationKey::from_seed(&seed(7)).expect("decap key");
        let ek = dk.encapsulation_key();
        let ek_bytes = ek.bytes();
        assert_eq!(ek_bytes.len(), MLKEM1024_ENCAPSULATION_KEY_SIZE);

        let restored = EncapsulationKey::from_bytes(&ek_bytes).expect("restore ek");
        let (shared, ciphertext) = restored.encapsulate().expect("encapsulate");
        assert_eq!(*dk.decapsulate(&ciphertext).expect("decapsulate"), *shared);
    }

    #[test]
    fn from_seed_is_deterministic_and_round_trips() {
        let dk1 = DecapsulationKey::from_seed(&seed(0x11)).expect("decap key");
        let dk2 = DecapsulationKey::from_seed(&seed(0x11)).expect("decap key");
        assert_eq!(dk1.encapsulation_key().bytes(), dk2.encapsulation_key().bytes());
        assert_eq!(*dk1.bytes(), *dk2.bytes());

        // A different seed must give a different public key.
        let dk3 = DecapsulationKey::from_seed(&seed(0x12)).expect("decap key");
        assert_ne!(dk1.encapsulation_key().bytes(), dk3.encapsulation_key().bytes());
    }

    #[test]
    fn deterministic_encapsulation_is_reproducible() {
        let dk = DecapsulationKey::from_seed(&seed(0x99)).expect("decap key");
        let ek = dk.encapsulation_key();
        let m = [0x5a_u8; MESSAGE_SIZE];
        let (shared_a, ct_a) = ek.encapsulate_deterministic(&m);
        let (shared_b, ct_b) = ek.encapsulate_deterministic(&m);
        assert_eq!(*shared_a, *shared_b);
        assert_eq!(ct_a, ct_b);
        assert_eq!(*dk.decapsulate(&ct_a).expect("decapsulate"), *shared_a);
    }

    #[test]
    fn decapsulate_implicitly_rejects_malformed_ciphertext() {
        let dk = DecapsulationKey::from_seed(&seed(0x33)).expect("decap key");
        let ek = dk.encapsulation_key();
        let (shared, mut ciphertext) = ek.encapsulate().expect("encapsulate");

        ciphertext[0] ^= 0xff;
        // Implicit rejection: decapsulating a tampered ciphertext still succeeds. It must
        // return the rejection key SHAKE256(z || c), not the sender's shared key.
        let rejected = dk.decapsulate(&ciphertext).expect("implicit rejection still succeeds");
        assert_ne!(*rejected, *shared);

        let mut expected = [0u8; MLKEM1024_SHARED_KEY_SIZE];
        shake256(&[&dk.z, ciphertext.as_slice()], &mut expected);
        assert_eq!(*rejected, expected);

        // The rejection key is a deterministic function of the key and the ciphertext.
        let again = dk.decapsulate(&ciphertext).expect("decapsulate");
        assert_eq!(*rejected, *again);
    }

    #[test]
    fn wrong_length_inputs_are_rejected() {
        assert!(matches!(
            DecapsulationKey::from_seed(&[0u8; 32]),
            Err(QrllibError::InvalidMlKemSeedSize(32, 64))
        ));
        let dk = DecapsulationKey::from_seed(&seed(1)).expect("decap key");
        assert!(matches!(
            dk.decapsulate(&[0u8; 10]),
            Err(QrllibError::InvalidMlKemCiphertextSize(10, 1568))
        ));
        assert!(matches!(
            EncapsulationKey::from_bytes(&[0u8; 100]),
            Err(QrllibError::InvalidMlKemEncapsulationKeySize(100, 1568))
        ));
    }

    #[test]
    fn from_bytes_rejects_non_canonical_encoding() {
        let dk = DecapsulationKey::from_seed(&seed(0x55)).expect("decap key");
        let mut ek_bytes = dk.encapsulation_key().bytes();
        // Force the first 12-bit coefficient of `t` to 0xFFF, which is >= Q and so must
        // fail the modulus check.
        ek_bytes[0] = 0xff;
        ek_bytes[1] |= 0x0f;
        assert!(matches!(
            EncapsulationKey::from_bytes(&ek_bytes),
            Err(QrllibError::InvalidMlKemEncoding)
        ));
    }

    #[test]
    fn decapsulation_key_zeroize_clears_secret_material() {
        let mut dk = DecapsulationKey::from_seed(&seed(0x77)).expect("decap key");
        dk.zeroize();
        assert!(dk.d.iter().all(|b| *b == 0));
        assert!(dk.z.iter().all(|b| *b == 0));
        assert!(dk.decryption_key.s.iter().all(|poly| poly.iter().all(|c| *c == 0)));
    }
}
1190
1191#[cfg(test)]
1207mod acvp {
1208 use super::*;
1209 use serde::Deserialize;
1210 use std::{
1211 env, fs,
1212 path::{Path, PathBuf},
1213 };
1214
    /// Top level of an ACVP prompt (input) document: a list of test groups.
    #[derive(Deserialize)]
    struct PromptFile {
        #[serde(rename = "testGroups")]
        test_groups: Vec<PromptGroup>,
    }

    /// One ACVP prompt test group. `function` defaults to empty when absent.
    #[derive(Deserialize)]
    struct PromptGroup {
        #[serde(rename = "tgId")]
        tg_id: u32,
        #[serde(rename = "parameterSet")]
        parameter_set: String,
        #[serde(default)]
        function: String,
        tests: Vec<PromptTest>,
    }

    /// One ACVP prompt test case. Fields absent from a given test type deserialize as empty
    /// strings; present values are hex strings decoded later with `decode`.
    #[derive(Deserialize)]
    struct PromptTest {
        #[serde(rename = "tcId")]
        tc_id: u32,
        #[serde(default)]
        d: String,
        #[serde(default)]
        z: String,
        #[serde(default)]
        ek: String,
        #[serde(default)]
        dk: String,
        #[serde(default)]
        m: String,
        #[serde(default)]
        c: String,
    }
1249
    /// Top level of an ACVP expected-results document: a list of test groups.
    #[derive(Deserialize)]
    struct ExpectedFile {
        #[serde(rename = "testGroups")]
        test_groups: Vec<ExpectedGroup>,
    }

    /// Expected results for one test group, matched to prompts by `tg_id`.
    #[derive(Deserialize)]
    struct ExpectedGroup {
        #[serde(rename = "tgId")]
        tg_id: u32,
        tests: Vec<ExpectedTest>,
    }

    /// Expected outputs for one test case, matched to prompts by `tc_id`. Absent fields
    /// default to empty strings, and `test_passed` defaults to `false`.
    #[derive(Deserialize)]
    struct ExpectedTest {
        #[serde(rename = "tcId")]
        tc_id: u32,
        #[serde(default)]
        ek: String,
        #[serde(default)]
        dk: String,
        #[serde(default)]
        c: String,
        #[serde(default)]
        k: String,
        #[serde(rename = "testPassed", default)]
        test_passed: bool,
    }
1278
1279 fn vectors_dir() -> Option<PathBuf> {
1280 env::var_os("MLKEM_ACVP_VECTORS_DIR").map(PathBuf::from)
1281 }
1282
1283 fn load<T: serde::de::DeserializeOwned>(dir: &Path, suite: &str, name: &str) -> T {
1284 let path = dir.join(suite).join(name);
1285 let data =
1286 fs::read_to_string(&path).unwrap_or_else(|e| panic!("read {}: {}", path.display(), e));
1287 serde_json::from_str(&data).unwrap_or_else(|e| panic!("parse {}: {}", path.display(), e))
1288 }
1289
1290 fn decode(value: &str) -> Vec<u8> {
1291 hex::decode(value).expect("ACVP hex")
1292 }
1293
1294 fn expected_test(expected: &ExpectedFile, tg_id: u32, tc_id: u32) -> &ExpectedTest {
1295 expected
1296 .test_groups
1297 .iter()
1298 .find(|g| g.tg_id == tg_id)
1299 .unwrap_or_else(|| panic!("missing expected group {tg_id}"))
1300 .tests
1301 .iter()
1302 .find(|t| t.tc_id == tc_id)
1303 .unwrap_or_else(|| panic!("missing expected test {tc_id} in group {tg_id}"))
1304 }
1305
1306 fn to_expanded(dk: &DecapsulationKey) -> Vec<u8> {
1309 let mut out =
1310 Vec::with_capacity(K * ENCODING_SIZE_12 + MLKEM1024_ENCAPSULATION_KEY_SIZE + 64);
1311 let mut encoded = [0u8; ENCODING_SIZE_12];
1312 for poly in &dk.decryption_key.s {
1313 byte_encode12(&mut encoded, poly);
1314 out.extend_from_slice(&encoded);
1315 }
1316 out.extend_from_slice(&dk.encryption_key.encoded);
1317 out.extend_from_slice(&dk.h);
1318 out.extend_from_slice(&dk.z);
1319 out
1320 }
1321
1322 fn from_expanded(b: &[u8]) -> Result<DecapsulationKey> {
1326 const EXPANDED: usize = K * ENCODING_SIZE_12 + MLKEM1024_ENCAPSULATION_KEY_SIZE + 64;
1327 if b.len() != EXPANDED {
1331 return Err(QrllibError::InvalidMlKemEncoding);
1333 }
1334 let mut s = [new_ring(); K];
1335 let mut off = 0usize;
1336 for poly in &mut s {
1337 byte_decode12(poly, &b[off..off + ENCODING_SIZE_12])?;
1338 off += ENCODING_SIZE_12;
1339 }
1340 let ek = EncapsulationKey::from_bytes(&b[off..off + MLKEM1024_ENCAPSULATION_KEY_SIZE])?;
1341 off += MLKEM1024_ENCAPSULATION_KEY_SIZE;
1342 if ek.h[..] != b[off..off + 32] {
1343 return Err(QrllibError::InvalidMlKemEncoding);
1344 }
1345 off += 32;
1346 let mut z = [0u8; 32];
1347 z.copy_from_slice(&b[off..off + 32]);
1348 Ok(DecapsulationKey {
1351 d: [0u8; 32],
1352 z,
1353 h: ek.h,
1354 encryption_key: ek.encryption_key,
1355 decryption_key: DecryptionKey { s },
1356 })
1357 }
1358
1359 #[test]
1360 fn acvp_keygen_matches_nist_vectors() {
1361 let Some(dir) = vectors_dir() else {
1365 eprintln!("MLKEM_ACVP_VECTORS_DIR not set; skipping ML-KEM ACVP keyGen test");
1367 return;
1368 };
1370 let suite = "ML-KEM-keyGen-FIPS203";
1371 let prompt: PromptFile = load(&dir, suite, "prompt.json");
1372 let expected: ExpectedFile = load(&dir, suite, "expectedResults.json");
1373
1374 let mut tested = 0u32;
1375 for group in &prompt.test_groups {
1376 if group.parameter_set != "ML-KEM-1024" {
1377 continue;
1378 }
1379 for test in &group.tests {
1380 tested += 1;
1381 let want = expected_test(&expected, group.tg_id, test.tc_id);
1382 let mut seed = [0u8; MLKEM1024_SEED_SIZE];
1383 seed[..32].copy_from_slice(&decode(&test.d));
1384 seed[32..].copy_from_slice(&decode(&test.z));
1385 let dk = DecapsulationKey::from_seed(&seed).expect("decapsulation key");
1386 assert_eq!(
1387 dk.encapsulation_key().bytes().as_slice(),
1388 decode(&want.ek).as_slice(),
1389 "tc{}: encapsulation key mismatch",
1390 test.tc_id
1391 );
1392 assert_eq!(
1393 to_expanded(&dk),
1394 decode(&want.dk),
1395 "tc{}: expanded decapsulation key mismatch",
1396 test.tc_id
1397 );
1398 }
1399 }
1400 assert!(tested > 0, "no ML-KEM-1024 ACVP keyGen test cases");
1401 eprintln!("ACVP ML-KEM-1024 keyGen: {tested} cases passed");
1402 }
1403
1404 #[test]
1405 fn acvp_encap_decap_matches_nist_vectors() {
1406 let Some(dir) = vectors_dir() else {
1410 eprintln!("MLKEM_ACVP_VECTORS_DIR not set; skipping ML-KEM ACVP encapDecap test");
1412 return;
1413 };
1415 let suite = "ML-KEM-encapDecap-FIPS203";
1416 let prompt: PromptFile = load(&dir, suite, "prompt.json");
1417 let expected: ExpectedFile = load(&dir, suite, "expectedResults.json");
1418
1419 let (mut encap, mut decap, mut decap_check, mut encap_check) = (0u32, 0u32, 0u32, 0u32);
1420 for group in &prompt.test_groups {
1421 if group.parameter_set != "ML-KEM-1024" {
1422 continue;
1423 }
1424 for test in &group.tests {
1425 let want = expected_test(&expected, group.tg_id, test.tc_id);
1426 match group.function.as_str() {
1427 "encapsulation" => {
1428 let ek = EncapsulationKey::from_bytes(&decode(&test.ek))
1429 .expect("encapsulation key");
1430 let m: [u8; 32] = decode(&test.m).try_into().expect("32-byte m");
1431 let (shared, ciphertext) = ek.encapsulate_deterministic(&m);
1432 assert_eq!(ciphertext, decode(&want.c).as_slice(), "tc{}: ct", test.tc_id);
1433 assert_eq!(*shared, decode(&want.k).as_slice(), "tc{}: K", test.tc_id);
1434 encap += 1;
1435 }
1436 "decapsulation" => {
1437 let dk = from_expanded(&decode(&test.dk)).expect("decapsulation key");
1438 let shared = dk.decapsulate(&decode(&test.c)).expect("decapsulate");
1439 assert_eq!(*shared, decode(&want.k).as_slice(), "tc{}: K", test.tc_id);
1440 decap += 1;
1441 }
1442 "decapsulationKeyCheck" => {
1443 let ok = from_expanded(&decode(&test.dk)).is_ok();
1444 assert_eq!(ok, want.test_passed, "tc{}: dk check", test.tc_id);
1445 decap_check += 1;
1446 }
1447 "encapsulationKeyCheck" => {
1448 let ok = EncapsulationKey::from_bytes(&decode(&test.ek)).is_ok();
1449 assert_eq!(ok, want.test_passed, "tc{}: ek check", test.tc_id);
1450 encap_check += 1;
1451 }
1452 other => panic!("unexpected ACVP function {other:?}"),
1457 }
1458 }
1459 }
1460 assert!(
1461 encap > 0 && decap > 0 && decap_check > 0 && encap_check > 0,
1462 "missing an ML-KEM-1024 encapDecap function (encap={encap} decap={decap} \
1463 decapCheck={decap_check} encapCheck={encap_check})"
1464 );
1465 eprintln!(
1466 "ACVP ML-KEM-1024 encapDecap: encap={encap} decap={decap} \
1467 decapKeyCheck={decap_check} encapKeyCheck={encap_check} passed"
1468 );
1469 }
1470}