1mod compress;
3mod reduce;
4
5use compress::{compr_10bit, compr_1bit, compr_4bit, decompr_10bit, decompr_1bit, decompr_4bit};
6use core::{
7 array,
8 fmt::Display,
9 hint::black_box,
10 mem::{self, transmute, MaybeUninit},
11 ops::{AddAssign, Mul, SubAssign},
12};
13use rand_core::CryptoRngCore;
14
15use crate::hash;
16
/// Number of coefficients in every polynomial.
const N: usize = 256;
/// Number of polynomials per vector, and side length of the square matrix.
const K: usize = 3;
/// Modulus of the coefficient ring.
const Q: i16 = 3329;
/// Bits kept per coefficient when a polynomial vector is compressed.
const DU: usize = 10;
/// Bits kept per coefficient when a single polynomial is compressed.
const DV: usize = 4;

/// Bits used per coefficient by the uncompressed byte encoding.
const COEFFICIENT_BITSIZE: usize = 12;
24
25const ZETAS: [i16; 128] = {
29 const ZETA1: i16 = reduce::to_mont(17);
30
31 let mut zetas = [0; 128];
32 zetas[0] = reduce::R_MOD_Q as i16;
33
34 let mut i = 1;
35 while i < 128 {
36 zetas[i] = reduce::mont_mul(zetas[i - 1], ZETA1);
37
38 i += 1
39 }
40
41 let mut zetas_bitrev = [0; 128];
42
43 i = 0;
44 while i < 128 {
45 let idx = (i as u8).reverse_bits() >> 1;
46
47 zetas_bitrev[i] = match zetas[idx as usize] {
48 z if z > Q / 2 => z - Q,
49 z if z < -Q / 2 => z + Q,
50 z => z,
51 };
52
53 i += 1;
54 }
55
56 zetas_bitrev
57};
58
/// A polynomial with `N` signed 16-bit coefficients.
///
/// Whether the coefficients are in the normal or the NTT domain is not tracked
/// by the type; callers are responsible for keeping the two apart.
#[derive(Debug, PartialEq)]
struct Poly {
    /// Coefficients, index `i` holding the coefficient of `X^i`.
    f: [i16; N],
}
63
64impl Poly {
    /// Size in bytes of a polynomial encoded with `COEFFICIENT_BITSIZE` bits per coefficient.
    const ENCODED_BYTES: usize = (COEFFICIENT_BITSIZE * N) / 8;
    /// Size in bytes of a polynomial compressed to `DV` bits per coefficient.
    const COMPRESSED_BYTES: usize = (N * DV) / 8;
67
68 const fn zero() -> Self {
69 Self { f: [0; N] }
70 }
71
72 fn ntt(&mut self) {
74 let f = &mut self.f;
75
76 let mut k = 1;
77
78 for len in (0..7).map(|n| 128 >> n) {
79 for start in (0..256).step_by(len << 1) {
80 let zeta = ZETAS[k];
81 k += 1;
82 for j in start..start + len {
83 let t = reduce::mont_mul(zeta, f[j + len]);
84 f[j + len] = f[j] - t;
85 f[j] += t;
86 }
87 }
88 }
89
90 self.reduce();
91 }
92
93 fn invntt(&mut self) {
95 let f = &mut self.f;
96
97 let mut k = 127;
98
99 for len in (0..7).map(|n| 2 << n) {
100 for start in (0..256).step_by(len << 1) {
101 let zeta = ZETAS[k];
102 k -= 1;
103 for j in start..start + len {
104 let t = f[j];
105 f[j] = reduce::barrett_reduce(t + f[j + len]);
106 f[j + len] -= t;
107 f[j + len] = reduce::mont_mul(zeta, f[j + len]);
108 }
109 }
110 }
111
112 const DIV_128_MONT: i16 = ((1 << 25) % Q as i32) as i16;
114
115 for a in f.iter_mut() {
116 *a = reduce::mont_mul(*a, DIV_128_MONT);
118 }
119 }
120
121 fn sample_ntt(xof: &mut hash::Shake128) -> Self {
123 let mut f: [MaybeUninit<i16>; N] = [MaybeUninit::uninit(); N];
124 let mut idx = 0;
125
126 while idx < N {
127 let bytes = xof.squeezeblock();
128
129 for d in bytes
130 .chunks_exact(3)
131 .flat_map(|b| {
132 let (b0, b1, b2) = (b[0] as u16, b[1] as u16, b[2] as u16);
133 let d1 = b0 | (b1 & 0xF) << 8;
134 let d2 = b1 >> 4 | b2 << 4;
135
136 [d1, d2]
137 })
138 .filter(|d| *d < Q as u16)
139 {
140 f[idx].write(d as i16);
141 idx += 1;
142
143 if idx == N {
144 break;
145 }
146 }
147 }
148
149 Self {
150 f: unsafe { transmute::<[MaybeUninit<i16>; N], [i16; N]>(f) },
151 }
152 }
153
154 fn sample_poly_cbd2(&mut self, bytes: &[u8; 128]) {
156 let f = &mut self.f;
157
158 for (i, bytes) in (0..N).step_by(8).zip(bytes.chunks_exact(4)) {
159 let t = u32::from_le_bytes(bytes.try_into().unwrap());
160
161 let d = (t & 0x55555555) + ((t >> 1) & 0x55555555);
163
164 for j in 0..8 {
165 let x = (d >> (j << 2)) & 3;
167 let y = (d >> ((j << 2) + 2)) & 3;
168 f[i + j] = x as i16 - y as i16;
169 }
170 }
171 }
172
173 fn multiply_ntts_acc(&mut self, f: &Poly, g: &Poly) {
175 let h = &mut self.f;
176 let f = &f.f;
177 let g = &g.f;
178
179 for i in (0..N).step_by(4) {
180 let zeta_idx = 64 + (i >> 2);
181
182 let a = basemul(f[i], f[i + 1], g[i], g[i + 1], ZETAS[zeta_idx]);
183 let b = basemul(f[i + 2], f[i + 3], g[i + 2], g[i + 3], -ZETAS[zeta_idx]);
184
185 h[i] += a.0;
186 h[i + 1] += a.1;
187 h[i + 2] += b.0;
188 h[i + 3] += b.1;
189 }
190 }
191
192 fn multiply_acc(&mut self, a: &PolyVec, b: &PolyVec) {
193 for (f, g) in a.vec.iter().zip(b.vec.iter()) {
194 self.multiply_ntts_acc(f, g);
195 }
196
197 self.reduce();
198 }
199
200 fn montgomery_form(&mut self) {
201 for a in self.f.iter_mut() {
202 *a = reduce::to_mont(*a);
203 }
204 }
205
206 fn reduce(&mut self) {
207 for a in self.f.iter_mut() {
208 *a = reduce::barrett_reduce(*a);
209 }
210 }
211
212 fn byte_encode(&self, bytes: &mut [u8; Poly::ENCODED_BYTES]) {
214 for (a, b) in self.f.chunks(2).zip(bytes.chunks_mut(3)) {
215 let (b0, b1, b2) = coeffs2bytes(a[0], a[1]);
216
217 b[0] = b0;
218 b[1] = b1;
219 b[2] = b2;
220 }
221 }
222
223 fn byte_decode(bytes: &[u8; Self::ENCODED_BYTES]) -> Self {
224 let mut coeffs: [MaybeUninit<i16>; N] = [MaybeUninit::uninit(); N];
225
226 for (a, b) in coeffs.chunks_exact_mut(2).zip(bytes.chunks_exact(3)) {
227 let (t0, t1) = bytes2coeffs(b[0], b[1], b[2]);
228
229 a[0].write(t0);
230 a[1].write(t1);
231 }
232
233 Self {
234 f: unsafe { mem::transmute::<[MaybeUninit<i16>; N], [i16; N]>(coeffs) },
235 }
236 }
237
238 fn compress(&self, bytes: &mut [u8; Self::COMPRESSED_BYTES]) {
239 for (b, a) in bytes.iter_mut().zip(self.f.chunks_exact(2)) {
240 let c: [u8; 2] = array::from_fn(|i| compr_4bit(a[i]));
241
242 *b = c[0] | c[1] << 4;
243 }
244 }
245
246 fn decompress(bytes: &[u8; Self::COMPRESSED_BYTES]) -> Self {
247 const MOD_MASK: u8 = (1 << DV) - 1;
248
249 let mut poly = Poly::zero();
250
251 for (a, b) in poly.f.chunks_exact_mut(2).zip(bytes.iter()) {
252 a[0] = decompr_4bit(b & MOD_MASK);
253 a[1] = decompr_4bit(b >> DV);
254 }
255
256 poly
257 }
258
259 fn generate_eta2<I>(r: &[u8; 32], nonce: &mut I) -> Self
260 where
261 I: Iterator<Item = usize>,
262 {
263 let mut poly = Poly::zero();
264
265 let mut prf = hash::Shake256::init();
266 prf.absorb_multi(&[r, &[nonce.next().unwrap() as u8]]);
267 let block = prf.squeezeblock();
268 poly.sample_poly_cbd2(&block[..128].try_into().unwrap());
269 prf.reset();
270 poly
271 }
272
273 fn from_msg(m: &[u8; 32]) -> Self {
274 let mut poly = Poly::zero();
275
276 for (coeffs, byte) in poly.f.chunks_exact_mut(8).zip(m.iter()) {
277 for (a, bit) in coeffs.iter_mut().zip((0..8).map(|n| *byte >> n)) {
278 *a = decompr_1bit(bit);
279 }
280 }
281
282 poly
283 }
284
285 fn to_msg(&self, m: &mut [u8; 32]) {
286 for (byte, coeffs) in m.iter_mut().zip(self.f.chunks_exact(8)) {
287 for (i, a) in coeffs.iter().enumerate() {
288 *byte |= compr_1bit(*a) << i;
289 }
290 }
291 }
292}
293
294impl AddAssign<&Poly> for Poly {
295 fn add_assign(&mut self, rhs: &Poly) {
296 for (a, b) in self.f.iter_mut().zip(rhs.f.iter()) {
297 *a += b;
298 }
299 }
300}
301
302impl SubAssign<&Poly> for Poly {
303 fn sub_assign(&mut self, rhs: &Poly) {
304 for (a, b) in self.f.iter_mut().zip(rhs.f.iter()) {
305 *a -= b;
306 }
307 }
308}
309
310impl Display for Poly {
311 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
312 let mut coeffs = self.f.iter().enumerate().filter(|(_, &a)| a != 0);
313
314 match coeffs.next() {
315 Some((_, a)) => write!(f, "f(X) = {}", a)?,
316 None => return write!(f, "f(X) = 0"),
317 };
318
319 for (i, a) in coeffs {
320 write!(f, " + {}X^{}", a, i)?;
321 }
322
323 Ok(())
324 }
325}
326
327const fn coeffs2bytes(a: i16, b: i16) -> (u8, u8, u8) {
329 let t0 = a + ((a >> 15) & Q);
330 let t1 = b + ((b >> 15) & Q);
331
332 (t0 as u8, ((t0 >> 8) | (t1 << 4)) as u8, (t1 >> 4) as u8)
334}
335
/// Unpacks three bytes into two 12-bit values (little-endian bit order).
///
/// The first value uses `b0` and the low nibble of `b1`; the second uses the
/// high nibble of `b1` and all of `b2`. Results are in `0..=0xFFF`.
const fn bytes2coeffs(b0: u8, b1: u8, b2: u8) -> (i16, i16) {
    let low = b0 as u16;
    let mid = b1 as u16;
    let high = b2 as u16;

    let first = low | ((mid & 0x0F) << 8);
    let second = (mid >> 4) | (high << 4);

    (first as i16, second as i16)
}
343
344const fn basemul(a0: i16, a1: i16, b0: i16, b1: i16, zeta: i16) -> (i16, i16) {
349 let c0 = reduce::mont_mul(a0, b0) + reduce::mont_mul(reduce::mont_mul(a1, b1), zeta);
350 let c1 = reduce::mont_mul(a0, b1) + reduce::mont_mul(a1, b0);
351
352 (c0, c1)
353}
354
/// A vector of `K` polynomials.
#[derive(Debug, PartialEq)]
struct PolyVec {
    /// The component polynomials.
    vec: [Poly; K],
}
359
360impl PolyVec {
    /// Size in bytes of the uncompressed encoding of the whole vector.
    const BYTE_SIZE: usize = K * Poly::ENCODED_BYTES;
    /// Size in bytes of one polynomial compressed to `DU` bits per coefficient.
    const COMPRESSED_POLY_BYTES: usize = (N * DU) / 8;
    /// Size in bytes of the whole vector compressed to `DU` bits per coefficient.
    const COMPRESSED_BYTES: usize = K * Self::COMPRESSED_POLY_BYTES;
364
365 const fn zero() -> Self {
366 Self {
367 vec: [const { Poly::zero() }; K],
368 }
369 }
370
371 fn reduce(&mut self) {
372 for p in self.vec.iter_mut() {
373 p.reduce();
374 }
375 }
376
377 fn ntt(&mut self) {
378 for p in self.vec.iter_mut() {
379 p.ntt();
380 }
381 }
382
383 fn invntt(&mut self) {
384 for p in self.vec.iter_mut() {
385 p.invntt();
386 }
387 }
388
389 fn byte_encode<const BYTE_SIZE: usize>(&self, bytes: &mut [u8; BYTE_SIZE]) {
390 for (p, buf) in self
391 .vec
392 .iter()
393 .zip(bytes.chunks_exact_mut(Poly::ENCODED_BYTES))
394 {
395 p.byte_encode(buf.try_into().unwrap());
396 }
397 }
398
399 fn from_bytes(bytes: &[u8; K * Poly::ENCODED_BYTES]) -> Self {
400 let mut vec = [const { Poly::zero() }; K];
401
402 for (v, b) in vec.iter_mut().zip(bytes.chunks_exact(Poly::ENCODED_BYTES)) {
403 *v = Poly::byte_decode(unsafe { b.try_into().unwrap_unchecked() });
404 }
405
406 Self { vec }
407 }
408
409 fn compress(&self, bytes: &mut [u8; Self::COMPRESSED_BYTES]) {
410 for (p, b) in self
411 .vec
412 .iter()
413 .zip(bytes.chunks_exact_mut(Self::COMPRESSED_POLY_BYTES))
414 {
415 for (b, a) in b.chunks_exact_mut(5).zip(p.f.chunks_exact(4)) {
416 let t: [u16; 4] = array::from_fn(|i| compr_10bit(a[i]));
417
418 b[0] = t[0] as u8;
419 b[1] = ((t[0] >> 8) | (t[1] << 2)) as u8;
420 b[2] = ((t[1] >> 6) | (t[2] << 4)) as u8;
421 b[3] = ((t[2] >> 4) | (t[3] << 6)) as u8;
422 b[4] = (t[3] >> 2) as u8;
423 }
424 }
425 }
426
427 fn decompress(bytes: &[u8; Self::COMPRESSED_BYTES]) -> Self {
428 let mut pvec = PolyVec::zero();
429 for (p, b) in pvec
430 .vec
431 .iter_mut()
432 .zip(bytes.chunks_exact(Self::COMPRESSED_POLY_BYTES))
433 {
434 for (a, b) in p.f.chunks_exact_mut(4).zip(b.chunks_exact(5)) {
435 let mut t: [u16; 5] = array::from_fn(|i| b[i] as u16);
436 t[0] |= t[1] << 8;
437 t[1] = t[1] >> 2 | t[2] << 6;
438 t[2] = t[2] >> 4 | t[3] << 4;
439 t[3] = t[3] >> 6 | (t[4] << 2);
440
441 for (a, n) in a.iter_mut().zip(&t[..4]) {
442 *a = decompr_10bit(n & 0x3FF);
443 }
444 }
445 }
446
447 pvec
448 }
449
450 fn generate_eta2<I>(r: &[u8; 32], nonce: &mut I) -> Self
451 where
452 I: Iterator<Item = usize>,
453 {
454 let mut pvec = PolyVec::zero();
455
456 let mut prf = hash::Shake256::init();
457
458 for (poly, nonce) in pvec.vec.iter_mut().zip(nonce) {
459 prf.absorb_multi(&[r, &[nonce as u8]]);
460 let block = prf.squeezeblock();
461 poly.sample_poly_cbd2(&block[..128].try_into().unwrap());
462 prf.reset();
463 }
464
465 pvec
466 }
467}
468
469impl AddAssign<&PolyVec> for PolyVec {
470 fn add_assign(&mut self, rhs: &PolyVec) {
471 for (f, g) in self.vec.iter_mut().zip(rhs.vec.iter()) {
472 f.add_assign(g);
473 }
474 }
475}
476
477impl Mul<&PolyVec> for &PolyVec {
478 type Output = Poly;
479
480 fn mul(self, rhs: &PolyVec) -> Self::Output {
481 let mut out = Poly::zero();
482
483 for (f, g) in self.vec.iter().zip(rhs.vec.iter()) {
484 out.multiply_ntts_acc(f, g);
485 }
486
487 out.reduce();
488
489 out
490 }
491}
492
/// A `K` x `K` matrix of polynomials, stored as `K` vectors of `K` polynomials.
#[derive(Debug)]
struct PolyMatrix {
    /// The stored vectors; matrix-vector multiplication treats each as a row.
    m: [PolyVec; K],
}
497
498impl PolyMatrix {
499 fn generate(xof: &mut hash::Shake128, rho: &[u8; 32]) -> Self {
500 let mut m: [MaybeUninit<PolyVec>; K] = [const { MaybeUninit::uninit() }; K];
501
502 for (i, pvec) in m.iter_mut().enumerate() {
503 let mut v: [MaybeUninit<Poly>; K] = [const { MaybeUninit::uninit() }; K];
504
505 for (j, poly) in v.iter_mut().enumerate() {
506 xof.absorb_multi(&[rho, &u16::to_le_bytes((j | (i << 8)) as u16)]);
507 poly.write(Poly::sample_ntt(xof));
508 xof.reset();
509 }
510
511 pvec.write(PolyVec {
512 vec: unsafe { transmute::<[MaybeUninit<Poly>; 3], [Poly; 3]>(v) },
513 });
514 }
515
516 Self {
517 m: unsafe { transmute::<[MaybeUninit<PolyVec>; 3], [PolyVec; 3]>(m) },
518 }
519 }
520
521 fn generate_transposed(xof: &mut hash::Shake128, rho: &[u8; 32]) -> Self {
522 let mut m: [MaybeUninit<PolyVec>; K] = [const { MaybeUninit::uninit() }; K];
523
524 for (i, pvec) in m.iter_mut().enumerate() {
525 let mut v: [MaybeUninit<Poly>; K] = [const { MaybeUninit::uninit() }; K];
526
527 for (j, poly) in v.iter_mut().enumerate() {
528 xof.absorb_multi(&[rho, &u16::to_le_bytes((i | (j << 8)) as u16)]);
529 poly.write(Poly::sample_ntt(xof));
530 xof.reset();
531 }
532
533 pvec.write(PolyVec {
534 vec: unsafe { transmute::<[MaybeUninit<Poly>; 3], [Poly; 3]>(v) },
535 });
536 }
537
538 Self {
539 m: unsafe { transmute::<[MaybeUninit<PolyVec>; 3], [PolyVec; 3]>(m) },
540 }
541 }
542}
543
544impl Mul<&PolyVec> for &PolyMatrix {
545 type Output = PolyVec;
546
547 fn mul(self, rhs: &PolyVec) -> Self::Output {
548 let mut out = PolyVec::zero();
549
550 for (poly, rowvec) in out.vec.iter_mut().zip(&self.m) {
551 poly.multiply_acc(rowvec, rhs);
552 }
553
554 out
555 }
556}
557
558fn generate_se(prf: &mut hash::Shake256, sigma: &[u8; 32]) -> (PolyVec, PolyVec) {
559 let mut s = PolyVec::zero();
560 let mut e = PolyVec::zero();
561
562 for (nonce, poly) in s.vec.iter_mut().chain(e.vec.iter_mut()).enumerate() {
563 prf.absorb_multi(&[sigma, &[nonce as u8]]);
564
565 let block = prf.squeezeblock();
566 poly.sample_poly_cbd2(&block[..128].try_into().unwrap());
567
568 prf.reset();
569 poly.ntt();
570 }
571
572 (s, e)
573}
574
/// Encryption key of the underlying public-key encryption scheme.
struct PkeEncKey {
    /// Public vector, multiplied with the NTT-domain randomness in `encrypt`.
    t: PolyVec,
    /// Seed from which the public matrix is regenerated.
    rho: [u8; 32],
}
579
580impl PkeEncKey {
581 const BYTE_SIZE: usize = PolyVec::BYTE_SIZE + 32;
582 const CIPHERTEXT_SIZE: usize = PolyVec::COMPRESSED_BYTES + Poly::COMPRESSED_BYTES;
583
584 fn to_bytes(&self, bytes: &mut [u8; Self::BYTE_SIZE]) {
585 self.t.byte_encode(bytes);
586 bytes[PolyVec::BYTE_SIZE..].copy_from_slice(&self.rho);
587 }
588
589 fn from_bytes(bytes: &[u8; Self::BYTE_SIZE]) -> Self {
590 let (t_bytes, bytes) = bytes.split_first_chunk().unwrap();
591 let (rho, _) = bytes.split_first_chunk().unwrap();
592
593 let mut t = PolyVec::from_bytes(t_bytes);
594 t.reduce();
595
596 Self { t, rho: *rho }
597 }
598
599 fn encrypt(&self, c: &mut [u8; Self::CIPHERTEXT_SIZE], m: &[u8; 32], r: &[u8; 32]) {
601 let mut xof = hash::Shake128::init();
602 let at = PolyMatrix::generate_transposed(&mut xof, &self.rho);
603
604 let mut nonces = 0..(2 * K + 1);
605
606 let mut y = PolyVec::generate_eta2(r, &mut nonces);
607 let e1 = PolyVec::generate_eta2(r, &mut nonces);
608
609 let e2 = Poly::generate_eta2(r, &mut nonces);
610 y.ntt();
611
612 let mut u = &at * &y;
614 u.invntt();
615 u += &e1;
616 u.reduce();
617
618 let mu = Poly::from_msg(m);
619
620 let mut v = &self.t * &y;
622 v.invntt();
623 v += &e2;
624 v += μ
625 v.reduce();
626
627 let (c1, c2) = c.split_first_chunk_mut().unwrap();
628 let (c2, _) = c2.split_first_chunk_mut().unwrap();
629
630 u.compress(c1);
631 v.compress(c2);
632 }
633}
634
/// Decryption key of the underlying public-key encryption scheme.
struct PkeDecKey {
    /// Secret vector, multiplied with the NTT-domain ciphertext vector in `decrypt`.
    s: PolyVec,
}
638
639impl PkeDecKey {
640 const BYTE_SIZE: usize = K * Poly::ENCODED_BYTES;
641
642 fn to_bytes(&self, bytes: &mut [u8; Self::BYTE_SIZE]) {
643 self.s.byte_encode(bytes);
644 }
645
646 fn from_bytes(bytes: &[u8; Self::BYTE_SIZE]) -> Self {
647 let mut s = PolyVec::from_bytes(bytes);
648 s.reduce();
649
650 Self { s }
651 }
652
653 fn decrypt(&self, m: &mut [u8; 32], c: &[u8; PkeEncKey::CIPHERTEXT_SIZE]) {
655 let (c1, c2) = c.split_first_chunk().unwrap();
656 let (c2, _) = c2.split_first_chunk().unwrap();
657
658 let mut u_prime = PolyVec::decompress(c1);
659 let mut v_prime = Poly::decompress(c2);
660
661 u_prime.ntt();
662 let mut w = &self.s * &u_prime;
663 w.invntt();
664
665 v_prime -= &w;
666 v_prime.reduce();
667
668 v_prime.to_msg(m);
669 }
670}
671
672fn pke_keygen(d: &[u8; 32]) -> (PkeEncKey, PkeDecKey) {
674 let (rho, sigma) = hash::sha3_512_split(&[d, &[K as u8]]);
675
676 let mut xof = hash::Shake128::init();
677 let a = PolyMatrix::generate(&mut xof, &rho);
678
679 let mut prf = hash::Shake256::init();
680
681 let (s, e) = generate_se(&mut prf, &sigma);
682
683 let mut t: PolyVec = PolyVec::zero();
684
685 for i in 0..K {
686 t.vec[i].multiply_acc(&a.m[i], &s);
687 t.vec[i].montgomery_form();
688 }
689
690 t += &e;
691 t.reduce();
692
693 (PkeEncKey { t, rho }, PkeDecKey { s })
694}
695
/// Public encapsulation key of the KEM.
pub struct EncapsKey {
    /// The wrapped public-key-encryption key.
    ek_pke: PkeEncKey,
}
700
701impl EncapsKey {
702 pub const BYTE_SIZE: usize = PkeEncKey::BYTE_SIZE;
704
705 pub const CIPHERTEXT_SIZE: usize = PkeEncKey::CIPHERTEXT_SIZE;
707
708 #[inline]
710 pub fn to_bytes(&self, bytes: &mut [u8; Self::BYTE_SIZE]) {
711 self.ek_pke.to_bytes(bytes);
712 }
713
714 #[inline]
716 pub fn from_bytes(bytes: &[u8; Self::BYTE_SIZE]) -> Self {
717 let ek_pke = PkeEncKey::from_bytes(bytes);
718
719 Self { ek_pke }
720 }
721
722 fn encaps_internal(
724 &self,
725 c: &mut [u8; PkeEncKey::CIPHERTEXT_SIZE],
726 k: &mut [u8; 32],
727 m: &[u8; 32],
728 ) {
729 let mut bytes = [0u8; Self::BYTE_SIZE];
730 self.to_bytes(&mut bytes);
731 let h = hash::sha3_256(&[&bytes]);
732
733 let (key, r) = hash::sha3_512_split(&[m, &h]);
734
735 self.ek_pke.encrypt(c, m, &r);
736
737 k.copy_from_slice(&key);
738 }
739
740 #[inline]
742 pub fn encaps(
743 &self,
744 c: &mut [u8; Self::CIPHERTEXT_SIZE],
745 k: &mut [u8; 32],
746 rng: &mut impl CryptoRngCore,
747 ) {
748 let mut m = [0u8; 32];
749 rng.fill_bytes(&mut m);
750 self.encaps_internal(c, k, &m);
751 }
752}
753
/// Private decapsulation key of the KEM.
pub struct DecapsKey {
    /// The wrapped public-key-encryption decryption key.
    dk_pke: PkeDecKey,
    /// SHA3-256 digest of the serialized encapsulation key.
    h: [u8; 32],
    /// Secret value hashed with the ciphertext to produce the fallback key
    /// returned when re-encryption does not match.
    z: [u8; 32],
}
760
761impl DecapsKey {
762 pub const BYTE_SIZE: usize = PkeDecKey::BYTE_SIZE + PkeEncKey::BYTE_SIZE + 32 + 32;
764
765 #[inline]
767 pub fn to_bytes(&self, bytes: &mut [u8; Self::BYTE_SIZE], ek: &EncapsKey) {
768 let (dk_bytes, bytes) = bytes.split_first_chunk_mut().unwrap();
769 let (ek_bytes, bytes) = bytes.split_first_chunk_mut().unwrap();
770 let (ek_hash, bytes): (&mut [u8; 32], _) = bytes.split_first_chunk_mut().unwrap();
771 let (z, _): (&mut [u8; 32], _) = bytes.split_first_chunk_mut().unwrap();
772
773 self.dk_pke.to_bytes(dk_bytes);
774 ek.ek_pke.to_bytes(ek_bytes);
775 hash::sha3_256_into(ek_hash, &[ek_bytes]);
776 z.copy_from_slice(&self.z);
777 }
778
779 #[inline]
781 pub fn from_bytes(bytes: &[u8; Self::BYTE_SIZE]) -> Self {
782 let (dk_bytes, bytes) = bytes.split_first_chunk().unwrap();
783 let (_ek_bytes, bytes): (&[u8; PkeEncKey::BYTE_SIZE], _) =
784 bytes.split_first_chunk().unwrap();
785 let (h, bytes) = bytes.split_first_chunk().unwrap();
786 let (z_bytes, _) = bytes.split_first_chunk().unwrap();
787
788 let dk_pke = PkeDecKey::from_bytes(dk_bytes);
789
790 Self {
791 dk_pke,
792 h: *h,
793 z: *z_bytes,
794 }
795 }
796
797 #[inline]
800 pub fn decaps(&self, k: &mut [u8; 32], ek: &EncapsKey, c: &[u8; EncapsKey::CIPHERTEXT_SIZE]) {
801 let mut m_prime = [0u8; 32];
802 self.dk_pke.decrypt(&mut m_prime, c);
803
804 let (k_prime, r_prime) = hash::sha3_512_split(&[&m_prime, &self.h]);
805
806 let mut j = hash::Shake256::init();
807 j.absorb_multi(&[&self.z, c]);
808 k.copy_from_slice(&j.squeezeblock()[..32]);
809
810 let mut c_prime = [0u8; EncapsKey::CIPHERTEXT_SIZE];
811 ek.ek_pke.encrypt(&mut c_prime, &m_prime, &r_prime);
812
813 cmov(k, &k_prime, bytes_is_eq(c, &c_prime));
814 }
815}
816
/// Compares two byte arrays without data-dependent branches.
///
/// Returns `1` when the arrays are equal and `0` otherwise. Every byte is
/// always visited.
const fn bytes_is_eq<const N: usize>(a: &[u8; N], b: &[u8; N]) -> u32 {
    // OR of all byte-wise differences; zero only when the arrays match.
    let mut diff: u32 = 0;

    let mut pos = 0;
    while pos < N {
        diff |= (a[pos] ^ b[pos]) as u32;
        pos += 1;
    }

    // Negating a non-zero value in 1..=255 sets the top bit; zero stays zero.
    let differs = diff.wrapping_neg() >> 31;

    differs ^ 1
}
831
/// Branch-free conditional move: copies `src` into `dst` when `cond` is `1`
/// and leaves `dst` unchanged when `cond` is `0`.
///
/// `cond` is passed through `black_box` before the mask is derived from it.
fn cmov<const N: usize>(dst: &mut [u8; N], src: &[u8; N], cond: u32) {
    // 0 -> 0x00, 1 -> 0xFF.
    let mask = black_box(cond).wrapping_neg() as u8;

    for (d, s) in dst.iter_mut().zip(src.iter()) {
        let delta = (*d ^ *s) & mask;
        *d ^= delta;
    }
}
840
841#[inline]
843pub fn keygen(rng: &mut impl CryptoRngCore) -> (EncapsKey, DecapsKey) {
844 let mut d = [0u8; 32];
845 rng.fill_bytes(&mut d);
846
847 let mut z = [0u8; 32];
848 rng.fill_bytes(&mut z);
849
850 keygen_deterministic(d, z)
851}
852
853fn keygen_deterministic(d: [u8; 32], z: [u8; 32]) -> (EncapsKey, DecapsKey) {
854 let (ek_pke, dk_pke) = pke_keygen(&d);
855
856 let ek = EncapsKey { ek_pke };
857
858 let mut ek_bytes = [0u8; EncapsKey::BYTE_SIZE];
859 ek.to_bytes(&mut ek_bytes);
860
861 let h = hash::sha3_256(&[&ek_bytes]);
862
863 (ek, DecapsKey { dk_pke, h, z })
864}
865
866#[cfg(test)]
867mod tests {
868 use rand_core::OsRng;
869 use serde::Deserialize;
870 use std::{fs::read_to_string, path::PathBuf};
871
872 use super::*;
873
874 #[test]
875 fn test_keygen() {
876 let mut test_data_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
877 test_data_path.push("tests/kyber-keygen.json");
878
879 let test_data = read_to_string(&test_data_path).unwrap();
880 let test_data: Tests<KeyGenTestGroup> = serde_json::from_str(&test_data).unwrap();
881
882 for test_group in test_data
883 .test_groups
884 .iter()
885 .filter(|g| g.parameter_set == "ML-KEM-768")
886 {
887 for test in test_group.tests.iter() {
888 assert_eq!(test.ek.len(), EncapsKey::BYTE_SIZE);
890 assert_eq!(test.dk.len(), DecapsKey::BYTE_SIZE);
891
892 let (ek, dk) = keygen_deterministic(test.d, test.z);
893
894 let test_dk = DecapsKey::from_bytes(test.dk.as_slice().try_into().unwrap());
896 assert_eq!(test_dk.z, test.z);
897 assert_eq!(dk.z, test.z);
898 assert_eq!(test_dk.dk_pke.s, dk.dk_pke.s);
899
900 let test_ek = EncapsKey::from_bytes(test.ek.as_slice().try_into().unwrap());
902 assert_eq!(test_ek.ek_pke.rho, ek.ek_pke.rho);
903 assert_eq!(test_ek.ek_pke.t.vec, ek.ek_pke.t.vec);
904
905 let mut ek_bytes = [0u8; EncapsKey::BYTE_SIZE];
907 ek.to_bytes(&mut ek_bytes);
908 assert_eq!(ek_bytes, test.ek.as_slice());
909
910 let mut dk_bytes = [0u8; DecapsKey::BYTE_SIZE];
912 dk.to_bytes(&mut dk_bytes, &ek);
913 assert_eq!(dk_bytes, test.dk.as_slice());
914 }
915 }
916 }
917
918 #[test]
919 fn test_kem() {
920 let mut test_data_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
921 test_data_path.push("tests/kyber-kem.json");
922
923 let test_data = read_to_string(&test_data_path).unwrap();
924 let test_data: Tests<KemTestGroup> = serde_json::from_str(&test_data).unwrap();
925
926 for test_group in test_data
927 .test_groups
928 .iter()
929 .filter(|g| g.parameter_set == "ML-KEM-768")
930 {
931 match &test_group.params {
932 KemTestGroupKind::Aft { tests } => {
933 for test in tests.iter() {
934 assert_eq!(test.c.len(), EncapsKey::CIPHERTEXT_SIZE);
935 let ek = EncapsKey::from_bytes(test.ek.as_slice().try_into().unwrap());
936 let dk = DecapsKey::from_bytes(test.dk.as_slice().try_into().unwrap());
937
938 let mut c = [0u8; EncapsKey::CIPHERTEXT_SIZE];
939 let mut k = [0u8; 32];
940 ek.encaps_internal(&mut c, &mut k, test.m.as_slice().try_into().unwrap());
941
942 assert_eq!(c, test.c.as_slice());
943 assert_eq!(k, test.k.as_slice());
944
945 let mut k_prime = [0u8; 32];
946 dk.decaps(&mut k_prime, &ek, &c);
947 assert_eq!(&k, &k_prime);
948 }
949 }
950 KemTestGroupKind::Val { tests, dk, ek } => {
951 let ek = EncapsKey::from_bytes(ek.as_slice().try_into().unwrap());
952 let dk = DecapsKey::from_bytes(dk.as_slice().try_into().unwrap());
953 for test in tests.iter() {
954 assert_eq!(test.c.len(), EncapsKey::CIPHERTEXT_SIZE);
955
956 let mut k = [0u8; 32];
957 dk.decaps(&mut k, &ek, test.c[..].try_into().unwrap());
958
959 assert_eq!(&k, &test.k[..]);
960 }
961 }
962 }
963 }
964 }
965
966 #[test]
967 fn test_kem_random() {
968 let (ek, dk) = keygen(&mut OsRng);
969 let mut c = [0u8; EncapsKey::CIPHERTEXT_SIZE];
970 let mut k = [0u8; 32];
971 ek.encaps(&mut c, &mut k, &mut OsRng);
972
973 let mut k_prime = [0u8; 32];
974 dk.decaps(&mut k_prime, &ek, &c);
975
976 assert_eq!(&k, &k_prime);
977 }
978
979 fn gen_rand_bytes<const N: usize>(rng: &mut impl CryptoRngCore) -> [u8; N] {
980 let mut bytes = [0; N];
981 rng.fill_bytes(&mut bytes);
982 bytes
983 }
984
985 #[test]
986 fn test_compress() {
987 let compr_pvec = gen_rand_bytes(&mut OsRng);
988 let mut compr_pvec_prime = [0; PolyVec::COMPRESSED_BYTES];
989 let pvec = PolyVec::decompress(&compr_pvec);
990 pvec.compress(&mut compr_pvec_prime);
991 assert_eq!(&compr_pvec, &compr_pvec_prime);
992
993 let compr_poly = gen_rand_bytes(&mut OsRng);
994 let mut compr_poly_prime = [0; Poly::COMPRESSED_BYTES];
995 let poly = Poly::decompress(&compr_poly);
996 poly.compress(&mut compr_poly_prime);
997 assert_eq!(&compr_poly, &compr_poly_prime)
998 }
999
    /// Top-level layout of a JSON test-vector file; `T` is the group type.
    #[derive(Deserialize)]
    struct Tests<T> {
        /// Parsed but unused by the tests.
        #[serde(rename = "isSample")]
        _is_sample: bool,

        /// The groups of test vectors contained in the file.
        #[serde(rename = "testGroups")]
        test_groups: Vec<T>,

        /// Parsed but unused by the tests.
        #[serde(rename = "vsId")]
        _vs_id: i64,
    }
1011
    /// One group of key-generation vectors sharing a parameter set.
    #[derive(Deserialize)]
    struct KeyGenTestGroup {
        /// Parameter-set name; the tests only run groups equal to "ML-KEM-768".
        #[serde(rename = "parameterSet")]
        parameter_set: String,

        /// Parsed but unused by the tests.
        #[serde(rename = "testType")]
        _test_type: String,

        /// The vectors of this group.
        tests: Vec<KeyGenTestVector>,

        /// Parsed but unused by the tests.
        #[serde(rename = "tgId")]
        _tg_id: i64,
    }
1025
    /// A single key-generation vector; byte fields are hex strings in the JSON.
    #[derive(Deserialize)]
    struct KeyGenTestVector {
        /// First seed passed to `keygen_deterministic`.
        #[serde(with = "hex")]
        d: [u8; 32],

        /// Second seed passed to `keygen_deterministic`.
        #[serde(with = "hex")]
        z: [u8; 32],

        /// Expected serialized decapsulation key.
        #[serde(with = "hex")]
        dk: Vec<u8>,

        /// Expected serialized encapsulation key.
        #[serde(with = "hex")]
        ek: Vec<u8>,

        /// Parsed but unused by the tests.
        #[serde(rename = "tcId")]
        _tc_id: i64,
    }
1043
    /// A single "AFT" vector; byte fields are hex strings in the JSON.
    #[derive(Deserialize)]
    struct KemTestVectorAft {
        /// Expected ciphertext.
        #[serde(with = "hex")]
        c: Vec<u8>,

        /// Serialized decapsulation key.
        #[serde(with = "hex")]
        dk: Vec<u8>,

        /// Serialized encapsulation key.
        #[serde(with = "hex")]
        ek: Vec<u8>,

        /// Expected shared secret.
        #[serde(with = "hex")]
        k: Vec<u8>,

        /// Message fed to `encaps_internal` in place of random bytes.
        #[serde(with = "hex")]
        m: Vec<u8>,

        /// Parsed but unused by the tests.
        #[serde(rename = "tcId")]
        _tc_id: i64,
    }
1064
    /// A single "VAL" vector; byte fields are hex strings in the JSON.
    #[derive(Deserialize)]
    struct KemTestVectorVal {
        /// Ciphertext to decapsulate.
        #[serde(with = "hex")]
        c: Vec<u8>,

        /// Shared secret that decapsulation is expected to produce.
        #[serde(with = "hex")]
        k: Vec<u8>,

        /// Parsed but unused by the tests.
        #[serde(rename = "tcId")]
        _tc_id: i64,
    }
1076
    /// One group of KEM vectors sharing a parameter set.
    #[derive(Deserialize)]
    struct KemTestGroup {
        /// Parameter-set name; the tests only run groups equal to "ML-KEM-768".
        #[serde(rename = "parameterSet")]
        parameter_set: String,

        /// Parsed but unused by the tests.
        #[serde(rename = "tgId")]
        _tg_id: i64,

        /// Remaining fields of the group, selected by its `testType` value.
        #[serde(flatten)]
        params: KemTestGroupKind,
    }
1088
    /// Payload of a KEM test group, discriminated by the JSON `testType` field.
    #[derive(Deserialize)]
    #[serde(tag = "testType")]
    enum KemTestGroupKind {
        /// `testType == "AFT"`: every vector carries its own key pair.
        #[serde(rename = "AFT")]
        Aft { tests: Vec<KemTestVectorAft> },
        /// `testType == "VAL"`: all vectors share the group's key pair.
        #[serde(rename = "VAL")]
        Val {
            tests: Vec<KemTestVectorVal>,

            /// Serialized decapsulation key (hex string in the JSON).
            #[serde(with = "hex")]
            dk: Vec<u8>,

            /// Serialized encapsulation key (hex string in the JSON).
            #[serde(with = "hex")]
            ek: Vec<u8>,
        },
    }
1105}