Skip to main content

pqcrypto_std/mlkem/
mod.rs

1//! Implementation of ML-KEM (FIPS-203)
2mod compress;
3mod reduce;
4
5use compress::{compr_10bit, compr_1bit, compr_4bit, decompr_10bit, decompr_1bit, decompr_4bit};
6use core::{
7    array,
8    fmt::Display,
9    hint::black_box,
10    mem::{self, transmute, MaybeUninit},
11    ops::{AddAssign, Mul, SubAssign},
12};
13use rand_core::CryptoRngCore;
14
15use crate::hash;
16
/// Polynomial degree: ring elements live in `Z_q[X] / (X^N + 1)`.
const N: usize = 256;
/// Module rank (polynomials per vector); 3 selects the ML-KEM-768 parameter set.
const K: usize = 3;
/// Prime modulus of the coefficient ring.
const Q: i16 = 3329;
/// Bits per coefficient when compressing the ciphertext vector `u`.
const DU: usize = 10;
/// Bits per coefficient when compressing the ciphertext polynomial `v`.
const DV: usize = 4;

/// Bits per coefficient in the uncompressed byte encoding (`ByteEncode_12`).
const COEFFICIENT_BITSIZE: usize = 12;
24
25/// pre-computed zetas in montgomery form
26/// ordered by ZETAS\[i\] = z^BitRev7(i)
27/// zeta -> zeta * R (mod Q)
28const ZETAS: [i16; 128] = {
29    const ZETA1: i16 = reduce::to_mont(17);
30
31    let mut zetas = [0; 128];
32    zetas[0] = reduce::R_MOD_Q as i16;
33
34    let mut i = 1;
35    while i < 128 {
36        zetas[i] = reduce::mont_mul(zetas[i - 1], ZETA1);
37
38        i += 1
39    }
40
41    let mut zetas_bitrev = [0; 128];
42
43    i = 0;
44    while i < 128 {
45        let idx = (i as u8).reverse_bits() >> 1;
46
47        zetas_bitrev[i] = match zetas[idx as usize] {
48            z if z > Q / 2 => z - Q,
49            z if z < -Q / 2 => z + Q,
50            z => z,
51        };
52
53        i += 1;
54    }
55
56    zetas_bitrev
57};
58
59#[derive(Debug, PartialEq)]
60struct Poly {
61    f: [i16; N],
62}
63
64impl Poly {
65    const ENCODED_BYTES: usize = (COEFFICIENT_BITSIZE * N) / 8;
66    const COMPRESSED_BYTES: usize = (N * DV) / 8;
67
68    const fn zero() -> Self {
69        Self { f: [0; N] }
70    }
71
72    /// Algorithm 9 NTT(f)
73    fn ntt(&mut self) {
74        let f = &mut self.f;
75
76        let mut k = 1;
77
78        for len in (0..7).map(|n| 128 >> n) {
79            for start in (0..256).step_by(len << 1) {
80                let zeta = ZETAS[k];
81                k += 1;
82                for j in start..start + len {
83                    let t = reduce::mont_mul(zeta, f[j + len]);
84                    f[j + len] = f[j] - t;
85                    f[j] += t;
86                }
87            }
88        }
89
90        self.reduce();
91    }
92
93    /// Algorithm 10 NTT^-1 (f_hat)
94    fn invntt(&mut self) {
95        let f = &mut self.f;
96
97        let mut k = 127;
98
99        for len in (0..7).map(|n| 2 << n) {
100            for start in (0..256).step_by(len << 1) {
101                let zeta = ZETAS[k];
102                k -= 1;
103                for j in start..start + len {
104                    let t = f[j];
105                    f[j] = reduce::barrett_reduce(t + f[j + len]);
106                    f[j + len] -= t;
107                    f[j + len] = reduce::mont_mul(zeta, f[j + len]);
108                }
109            }
110        }
111
112        // (2^16)^2 / 128 = 2^{25}
113        const DIV_128_MONT: i16 = ((1 << 25) % Q as i32) as i16;
114
115        for a in f.iter_mut() {
116            // a = (a * R) / 128 (mod Q)
117            *a = reduce::mont_mul(*a, DIV_128_MONT);
118        }
119    }
120
121    /// Algorithm 7 SampleNTT(B)
122    fn sample_ntt(xof: &mut hash::Shake128) -> Self {
123        let mut f: [MaybeUninit<i16>; N] = [MaybeUninit::uninit(); N];
124        let mut idx = 0;
125
126        while idx < N {
127            let bytes = xof.squeezeblock();
128
129            for d in bytes
130                .chunks_exact(3)
131                .flat_map(|b| {
132                    let (b0, b1, b2) = (b[0] as u16, b[1] as u16, b[2] as u16);
133                    let d1 = b0 | (b1 & 0xF) << 8;
134                    let d2 = b1 >> 4 | b2 << 4;
135
136                    [d1, d2]
137                })
138                .filter(|d| *d < Q as u16)
139            {
140                f[idx].write(d as i16);
141                idx += 1;
142
143                if idx == N {
144                    break;
145                }
146            }
147        }
148
149        Self {
150            f: unsafe { transmute::<[MaybeUninit<i16>; N], [i16; N]>(f) },
151        }
152    }
153
154    /// Algorithm 8 SamplePolyCBD_2 (B)
155    fn sample_poly_cbd2(&mut self, bytes: &[u8; 128]) {
156        let f = &mut self.f;
157
158        for (i, bytes) in (0..N).step_by(8).zip(bytes.chunks_exact(4)) {
159            let t = u32::from_le_bytes(bytes.try_into().unwrap());
160
161            // add bits to each other in groups of two
162            let d = (t & 0x55555555) + ((t >> 1) & 0x55555555);
163
164            for j in 0..8 {
165                // extract two 2-bit numbers
166                let x = (d >> (j << 2)) & 3;
167                let y = (d >> ((j << 2) + 2)) & 3;
168                f[i + j] = x as i16 - y as i16;
169            }
170        }
171    }
172
173    /// Algorithm 11 MultiplyNTTs(f_hat, g_hat)
174    fn multiply_ntts_acc(&mut self, f: &Poly, g: &Poly) {
175        let h = &mut self.f;
176        let f = &f.f;
177        let g = &g.f;
178
179        for i in (0..N).step_by(4) {
180            let zeta_idx = 64 + (i >> 2);
181
182            let a = basemul(f[i], f[i + 1], g[i], g[i + 1], ZETAS[zeta_idx]);
183            let b = basemul(f[i + 2], f[i + 3], g[i + 2], g[i + 3], -ZETAS[zeta_idx]);
184
185            h[i] += a.0;
186            h[i + 1] += a.1;
187            h[i + 2] += b.0;
188            h[i + 3] += b.1;
189        }
190    }
191
192    fn multiply_acc(&mut self, a: &PolyVec, b: &PolyVec) {
193        for (f, g) in a.vec.iter().zip(b.vec.iter()) {
194            self.multiply_ntts_acc(f, g);
195        }
196
197        self.reduce();
198    }
199
200    fn montgomery_form(&mut self) {
201        for a in self.f.iter_mut() {
202            *a = reduce::to_mont(*a);
203        }
204    }
205
206    fn reduce(&mut self) {
207        for a in self.f.iter_mut() {
208            *a = reduce::barrett_reduce(*a);
209        }
210    }
211
212    /// Encode coefficients as bits in little endian order.
213    fn byte_encode(&self, bytes: &mut [u8; Poly::ENCODED_BYTES]) {
214        for (a, b) in self.f.chunks(2).zip(bytes.chunks_mut(3)) {
215            let (b0, b1, b2) = coeffs2bytes(a[0], a[1]);
216
217            b[0] = b0;
218            b[1] = b1;
219            b[2] = b2;
220        }
221    }
222
223    fn byte_decode(bytes: &[u8; Self::ENCODED_BYTES]) -> Self {
224        let mut coeffs: [MaybeUninit<i16>; N] = [MaybeUninit::uninit(); N];
225
226        for (a, b) in coeffs.chunks_exact_mut(2).zip(bytes.chunks_exact(3)) {
227            let (t0, t1) = bytes2coeffs(b[0], b[1], b[2]);
228
229            a[0].write(t0);
230            a[1].write(t1);
231        }
232
233        Self {
234            f: unsafe { mem::transmute::<[MaybeUninit<i16>; N], [i16; N]>(coeffs) },
235        }
236    }
237
238    fn compress(&self, bytes: &mut [u8; Self::COMPRESSED_BYTES]) {
239        for (b, a) in bytes.iter_mut().zip(self.f.chunks_exact(2)) {
240            let c: [u8; 2] = array::from_fn(|i| compr_4bit(a[i]));
241
242            *b = c[0] | c[1] << 4;
243        }
244    }
245
246    fn decompress(bytes: &[u8; Self::COMPRESSED_BYTES]) -> Self {
247        const MOD_MASK: u8 = (1 << DV) - 1;
248
249        let mut poly = Poly::zero();
250
251        for (a, b) in poly.f.chunks_exact_mut(2).zip(bytes.iter()) {
252            a[0] = decompr_4bit(b & MOD_MASK);
253            a[1] = decompr_4bit(b >> DV);
254        }
255
256        poly
257    }
258
259    fn generate_eta2<I>(r: &[u8; 32], nonce: &mut I) -> Self
260    where
261        I: Iterator<Item = usize>,
262    {
263        let mut poly = Poly::zero();
264
265        let mut prf = hash::Shake256::init();
266        prf.absorb_multi(&[r, &[nonce.next().unwrap() as u8]]);
267        let block = prf.squeezeblock();
268        poly.sample_poly_cbd2(&block[..128].try_into().unwrap());
269        prf.reset();
270        poly
271    }
272
273    fn from_msg(m: &[u8; 32]) -> Self {
274        let mut poly = Poly::zero();
275
276        for (coeffs, byte) in poly.f.chunks_exact_mut(8).zip(m.iter()) {
277            for (a, bit) in coeffs.iter_mut().zip((0..8).map(|n| *byte >> n)) {
278                *a = decompr_1bit(bit);
279            }
280        }
281
282        poly
283    }
284
285    fn to_msg(&self, m: &mut [u8; 32]) {
286        for (byte, coeffs) in m.iter_mut().zip(self.f.chunks_exact(8)) {
287            for (i, a) in coeffs.iter().enumerate() {
288                *byte |= compr_1bit(*a) << i;
289            }
290        }
291    }
292}
293
294impl AddAssign<&Poly> for Poly {
295    fn add_assign(&mut self, rhs: &Poly) {
296        for (a, b) in self.f.iter_mut().zip(rhs.f.iter()) {
297            *a += b;
298        }
299    }
300}
301
302impl SubAssign<&Poly> for Poly {
303    fn sub_assign(&mut self, rhs: &Poly) {
304        for (a, b) in self.f.iter_mut().zip(rhs.f.iter()) {
305            *a -= b;
306        }
307    }
308}
309
310impl Display for Poly {
311    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
312        let mut coeffs = self.f.iter().enumerate().filter(|(_, &a)| a != 0);
313
314        match coeffs.next() {
315            Some((_, a)) => write!(f, "f(X) = {}", a)?,
316            None => return write!(f, "f(X) = 0"),
317        };
318
319        for (i, a) in coeffs {
320            write!(f, " + {}X^{}", a, i)?;
321        }
322
323        Ok(())
324    }
325}
326
327/// Convert 2 integers mod Q into 3 bytes.
328const fn coeffs2bytes(a: i16, b: i16) -> (u8, u8, u8) {
329    let t0 = a + ((a >> 15) & Q);
330    let t1 = b + ((b >> 15) & Q);
331
332    // encode 2 coeficcients (12 bit) into 24 bits
333    (t0 as u8, ((t0 >> 8) | (t1 << 4)) as u8, (t1 >> 4) as u8)
334}
335
/// Unpack three bytes into two 12-bit coefficients (inverse of `coeffs2bytes`).
///
/// The results lie in `0..4096`; they are not reduced mod Q here.
const fn bytes2coeffs(b0: u8, b1: u8, b2: u8) -> (i16, i16) {
    // Assemble the little-endian 24-bit word, then split it into two 12-bit halves.
    let word = (b0 as u32) | ((b1 as u32) << 8) | ((b2 as u32) << 16);

    let low = (word & 0xFFF) as i16;
    let high = ((word >> 12) & 0xFFF) as i16;
    (low, high)
}
343
344/// Algorithm 12 BaseCaseMultiply(a_0, a_1, b_0, b_1, zeta)
345/// Compute:
346/// - c0 = (a0*b0 + a1*b1*zeta)R^-1 (mod Q)
347/// - c1 = (a0*b1 + a1*b0)R^-1 (mod Q)
348const fn basemul(a0: i16, a1: i16, b0: i16, b1: i16, zeta: i16) -> (i16, i16) {
349    let c0 = reduce::mont_mul(a0, b0) + reduce::mont_mul(reduce::mont_mul(a1, b1), zeta);
350    let c1 = reduce::mont_mul(a0, b1) + reduce::mont_mul(a1, b0);
351
352    (c0, c1)
353}
354
355#[derive(Debug, PartialEq)]
356struct PolyVec {
357    vec: [Poly; K],
358}
359
360impl PolyVec {
361    const BYTE_SIZE: usize = K * Poly::ENCODED_BYTES;
362    const COMPRESSED_POLY_BYTES: usize = (N * DU) / 8;
363    const COMPRESSED_BYTES: usize = K * Self::COMPRESSED_POLY_BYTES;
364
365    const fn zero() -> Self {
366        Self {
367            vec: [const { Poly::zero() }; K],
368        }
369    }
370
371    fn reduce(&mut self) {
372        for p in self.vec.iter_mut() {
373            p.reduce();
374        }
375    }
376
377    fn ntt(&mut self) {
378        for p in self.vec.iter_mut() {
379            p.ntt();
380        }
381    }
382
383    fn invntt(&mut self) {
384        for p in self.vec.iter_mut() {
385            p.invntt();
386        }
387    }
388
389    fn byte_encode<const BYTE_SIZE: usize>(&self, bytes: &mut [u8; BYTE_SIZE]) {
390        for (p, buf) in self
391            .vec
392            .iter()
393            .zip(bytes.chunks_exact_mut(Poly::ENCODED_BYTES))
394        {
395            p.byte_encode(buf.try_into().unwrap());
396        }
397    }
398
399    fn from_bytes(bytes: &[u8; K * Poly::ENCODED_BYTES]) -> Self {
400        let mut vec = [const { Poly::zero() }; K];
401
402        for (v, b) in vec.iter_mut().zip(bytes.chunks_exact(Poly::ENCODED_BYTES)) {
403            *v = Poly::byte_decode(unsafe { b.try_into().unwrap_unchecked() });
404        }
405
406        Self { vec }
407    }
408
409    fn compress(&self, bytes: &mut [u8; Self::COMPRESSED_BYTES]) {
410        for (p, b) in self
411            .vec
412            .iter()
413            .zip(bytes.chunks_exact_mut(Self::COMPRESSED_POLY_BYTES))
414        {
415            for (b, a) in b.chunks_exact_mut(5).zip(p.f.chunks_exact(4)) {
416                let t: [u16; 4] = array::from_fn(|i| compr_10bit(a[i]));
417
418                b[0] = t[0] as u8;
419                b[1] = ((t[0] >> 8) | (t[1] << 2)) as u8;
420                b[2] = ((t[1] >> 6) | (t[2] << 4)) as u8;
421                b[3] = ((t[2] >> 4) | (t[3] << 6)) as u8;
422                b[4] = (t[3] >> 2) as u8;
423            }
424        }
425    }
426
427    fn decompress(bytes: &[u8; Self::COMPRESSED_BYTES]) -> Self {
428        let mut pvec = PolyVec::zero();
429        for (p, b) in pvec
430            .vec
431            .iter_mut()
432            .zip(bytes.chunks_exact(Self::COMPRESSED_POLY_BYTES))
433        {
434            for (a, b) in p.f.chunks_exact_mut(4).zip(b.chunks_exact(5)) {
435                let mut t: [u16; 5] = array::from_fn(|i| b[i] as u16);
436                t[0] |= t[1] << 8;
437                t[1] = t[1] >> 2 | t[2] << 6;
438                t[2] = t[2] >> 4 | t[3] << 4;
439                t[3] = t[3] >> 6 | (t[4] << 2);
440
441                for (a, n) in a.iter_mut().zip(&t[..4]) {
442                    *a = decompr_10bit(n & 0x3FF);
443                }
444            }
445        }
446
447        pvec
448    }
449
450    fn generate_eta2<I>(r: &[u8; 32], nonce: &mut I) -> Self
451    where
452        I: Iterator<Item = usize>,
453    {
454        let mut pvec = PolyVec::zero();
455
456        let mut prf = hash::Shake256::init();
457
458        for (poly, nonce) in pvec.vec.iter_mut().zip(nonce) {
459            prf.absorb_multi(&[r, &[nonce as u8]]);
460            let block = prf.squeezeblock();
461            poly.sample_poly_cbd2(&block[..128].try_into().unwrap());
462            prf.reset();
463        }
464
465        pvec
466    }
467}
468
469impl AddAssign<&PolyVec> for PolyVec {
470    fn add_assign(&mut self, rhs: &PolyVec) {
471        for (f, g) in self.vec.iter_mut().zip(rhs.vec.iter()) {
472            f.add_assign(g);
473        }
474    }
475}
476
477impl Mul<&PolyVec> for &PolyVec {
478    type Output = Poly;
479
480    fn mul(self, rhs: &PolyVec) -> Self::Output {
481        let mut out = Poly::zero();
482
483        for (f, g) in self.vec.iter().zip(rhs.vec.iter()) {
484            out.multiply_ntts_acc(f, g);
485        }
486
487        out.reduce();
488
489        out
490    }
491}
492
493#[derive(Debug)]
494struct PolyMatrix {
495    m: [PolyVec; K],
496}
497
498impl PolyMatrix {
499    fn generate(xof: &mut hash::Shake128, rho: &[u8; 32]) -> Self {
500        let mut m: [MaybeUninit<PolyVec>; K] = [const { MaybeUninit::uninit() }; K];
501
502        for (i, pvec) in m.iter_mut().enumerate() {
503            let mut v: [MaybeUninit<Poly>; K] = [const { MaybeUninit::uninit() }; K];
504
505            for (j, poly) in v.iter_mut().enumerate() {
506                xof.absorb_multi(&[rho, &u16::to_le_bytes((j | (i << 8)) as u16)]);
507                poly.write(Poly::sample_ntt(xof));
508                xof.reset();
509            }
510
511            pvec.write(PolyVec {
512                vec: unsafe { transmute::<[MaybeUninit<Poly>; 3], [Poly; 3]>(v) },
513            });
514        }
515
516        Self {
517            m: unsafe { transmute::<[MaybeUninit<PolyVec>; 3], [PolyVec; 3]>(m) },
518        }
519    }
520
521    fn generate_transposed(xof: &mut hash::Shake128, rho: &[u8; 32]) -> Self {
522        let mut m: [MaybeUninit<PolyVec>; K] = [const { MaybeUninit::uninit() }; K];
523
524        for (i, pvec) in m.iter_mut().enumerate() {
525            let mut v: [MaybeUninit<Poly>; K] = [const { MaybeUninit::uninit() }; K];
526
527            for (j, poly) in v.iter_mut().enumerate() {
528                xof.absorb_multi(&[rho, &u16::to_le_bytes((i | (j << 8)) as u16)]);
529                poly.write(Poly::sample_ntt(xof));
530                xof.reset();
531            }
532
533            pvec.write(PolyVec {
534                vec: unsafe { transmute::<[MaybeUninit<Poly>; 3], [Poly; 3]>(v) },
535            });
536        }
537
538        Self {
539            m: unsafe { transmute::<[MaybeUninit<PolyVec>; 3], [PolyVec; 3]>(m) },
540        }
541    }
542}
543
544impl Mul<&PolyVec> for &PolyMatrix {
545    type Output = PolyVec;
546
547    fn mul(self, rhs: &PolyVec) -> Self::Output {
548        let mut out = PolyVec::zero();
549
550        for (poly, rowvec) in out.vec.iter_mut().zip(&self.m) {
551            poly.multiply_acc(rowvec, rhs);
552        }
553
554        out
555    }
556}
557
558fn generate_se(prf: &mut hash::Shake256, sigma: &[u8; 32]) -> (PolyVec, PolyVec) {
559    let mut s = PolyVec::zero();
560    let mut e = PolyVec::zero();
561
562    for (nonce, poly) in s.vec.iter_mut().chain(e.vec.iter_mut()).enumerate() {
563        prf.absorb_multi(&[sigma, &[nonce as u8]]);
564
565        let block = prf.squeezeblock();
566        poly.sample_poly_cbd2(&block[..128].try_into().unwrap());
567
568        prf.reset();
569        poly.ntt();
570    }
571
572    (s, e)
573}
574
575struct PkeEncKey {
576    t: PolyVec,
577    rho: [u8; 32],
578}
579
580impl PkeEncKey {
581    const BYTE_SIZE: usize = PolyVec::BYTE_SIZE + 32;
582    const CIPHERTEXT_SIZE: usize = PolyVec::COMPRESSED_BYTES + Poly::COMPRESSED_BYTES;
583
584    fn to_bytes(&self, bytes: &mut [u8; Self::BYTE_SIZE]) {
585        self.t.byte_encode(bytes);
586        bytes[PolyVec::BYTE_SIZE..].copy_from_slice(&self.rho);
587    }
588
589    fn from_bytes(bytes: &[u8; Self::BYTE_SIZE]) -> Self {
590        let (t_bytes, bytes) = bytes.split_first_chunk().unwrap();
591        let (rho, _) = bytes.split_first_chunk().unwrap();
592
593        let mut t = PolyVec::from_bytes(t_bytes);
594        t.reduce();
595
596        Self { t, rho: *rho }
597    }
598
599    /// Algorithm 14 K-PKE.Encrypt(ek_PKE, m, r)
600    fn encrypt(&self, c: &mut [u8; Self::CIPHERTEXT_SIZE], m: &[u8; 32], r: &[u8; 32]) {
601        let mut xof = hash::Shake128::init();
602        let at = PolyMatrix::generate_transposed(&mut xof, &self.rho);
603
604        let mut nonces = 0..(2 * K + 1);
605
606        let mut y = PolyVec::generate_eta2(r, &mut nonces);
607        let e1 = PolyVec::generate_eta2(r, &mut nonces);
608
609        let e2 = Poly::generate_eta2(r, &mut nonces);
610        y.ntt();
611
612        // u <- NTT^-1(A^T * y) + e_1
613        let mut u = &at * &y;
614        u.invntt();
615        u += &e1;
616        u.reduce();
617
618        let mu = Poly::from_msg(m);
619
620        // v <- NTT^-1(t^T * y) + e_2 + mu
621        let mut v = &self.t * &y;
622        v.invntt();
623        v += &e2;
624        v += &mu;
625        v.reduce();
626
627        let (c1, c2) = c.split_first_chunk_mut().unwrap();
628        let (c2, _) = c2.split_first_chunk_mut().unwrap();
629
630        u.compress(c1);
631        v.compress(c2);
632    }
633}
634
635struct PkeDecKey {
636    s: PolyVec,
637}
638
639impl PkeDecKey {
640    const BYTE_SIZE: usize = K * Poly::ENCODED_BYTES;
641
642    fn to_bytes(&self, bytes: &mut [u8; Self::BYTE_SIZE]) {
643        self.s.byte_encode(bytes);
644    }
645
646    fn from_bytes(bytes: &[u8; Self::BYTE_SIZE]) -> Self {
647        let mut s = PolyVec::from_bytes(bytes);
648        s.reduce();
649
650        Self { s }
651    }
652
653    /// Algorithm 15 K-PKE.Encrypt(dk_PKE, c)
654    fn decrypt(&self, m: &mut [u8; 32], c: &[u8; PkeEncKey::CIPHERTEXT_SIZE]) {
655        let (c1, c2) = c.split_first_chunk().unwrap();
656        let (c2, _) = c2.split_first_chunk().unwrap();
657
658        let mut u_prime = PolyVec::decompress(c1);
659        let mut v_prime = Poly::decompress(c2);
660
661        u_prime.ntt();
662        let mut w = &self.s * &u_prime;
663        w.invntt();
664
665        v_prime -= &w;
666        v_prime.reduce();
667
668        v_prime.to_msg(m);
669    }
670}
671
672/// Algorithm 13 K-PKE.KeyGen(d)
673fn pke_keygen(d: &[u8; 32]) -> (PkeEncKey, PkeDecKey) {
674    let (rho, sigma) = hash::sha3_512_split(&[d, &[K as u8]]);
675
676    let mut xof = hash::Shake128::init();
677    let a = PolyMatrix::generate(&mut xof, &rho);
678
679    let mut prf = hash::Shake256::init();
680
681    let (s, e) = generate_se(&mut prf, &sigma);
682
683    let mut t: PolyVec = PolyVec::zero();
684
685    for i in 0..K {
686        t.vec[i].multiply_acc(&a.m[i], &s);
687        t.vec[i].montgomery_form();
688    }
689
690    t += &e;
691    t.reduce();
692
693    (PkeEncKey { t, rho }, PkeDecKey { s })
694}
695
696/// ML-KEM encapsulation key (public key).
697pub struct EncapsKey {
698    ek_pke: PkeEncKey,
699}
700
701impl EncapsKey {
702    /// Byte size of the bit-packed key.
703    pub const BYTE_SIZE: usize = PkeEncKey::BYTE_SIZE;
704
705    /// Byte size of the produced ciphertext.
706    pub const CIPHERTEXT_SIZE: usize = PkeEncKey::CIPHERTEXT_SIZE;
707
708    /// Encode key to bytes.
709    #[inline]
710    pub fn to_bytes(&self, bytes: &mut [u8; Self::BYTE_SIZE]) {
711        self.ek_pke.to_bytes(bytes);
712    }
713
714    /// Decode key from bytes.
715    #[inline]
716    pub fn from_bytes(bytes: &[u8; Self::BYTE_SIZE]) -> Self {
717        let ek_pke = PkeEncKey::from_bytes(bytes);
718
719        Self { ek_pke }
720    }
721
722    /// Algorithm 17 ML-KEM.Encaps_internal(ek, m)
723    fn encaps_internal(
724        &self,
725        c: &mut [u8; PkeEncKey::CIPHERTEXT_SIZE],
726        k: &mut [u8; 32],
727        m: &[u8; 32],
728    ) {
729        let mut bytes = [0u8; Self::BYTE_SIZE];
730        self.to_bytes(&mut bytes);
731        let h = hash::sha3_256(&[&bytes]);
732
733        let (key, r) = hash::sha3_512_split(&[m, &h]);
734
735        self.ek_pke.encrypt(c, m, &r);
736
737        k.copy_from_slice(&key);
738    }
739
740    /// Algorithm 20 ML-KEM.Encaps(ek)
741    #[inline]
742    pub fn encaps(
743        &self,
744        c: &mut [u8; Self::CIPHERTEXT_SIZE],
745        k: &mut [u8; 32],
746        rng: &mut impl CryptoRngCore,
747    ) {
748        let mut m = [0u8; 32];
749        rng.fill_bytes(&mut m);
750        self.encaps_internal(c, k, &m);
751    }
752}
753
754/// ML-KEM decapsulation key (secret key).
755pub struct DecapsKey {
756    dk_pke: PkeDecKey,
757    h: [u8; 32],
758    z: [u8; 32],
759}
760
761impl DecapsKey {
762    /// Byte size of the bit-packed key.
763    pub const BYTE_SIZE: usize = PkeDecKey::BYTE_SIZE + PkeEncKey::BYTE_SIZE + 32 + 32;
764
765    /// Encode key to bytes.
766    #[inline]
767    pub fn to_bytes(&self, bytes: &mut [u8; Self::BYTE_SIZE], ek: &EncapsKey) {
768        let (dk_bytes, bytes) = bytes.split_first_chunk_mut().unwrap();
769        let (ek_bytes, bytes) = bytes.split_first_chunk_mut().unwrap();
770        let (ek_hash, bytes): (&mut [u8; 32], _) = bytes.split_first_chunk_mut().unwrap();
771        let (z, _): (&mut [u8; 32], _) = bytes.split_first_chunk_mut().unwrap();
772
773        self.dk_pke.to_bytes(dk_bytes);
774        ek.ek_pke.to_bytes(ek_bytes);
775        hash::sha3_256_into(ek_hash, &[ek_bytes]);
776        z.copy_from_slice(&self.z);
777    }
778
779    /// Decode key from bytes.
780    #[inline]
781    pub fn from_bytes(bytes: &[u8; Self::BYTE_SIZE]) -> Self {
782        let (dk_bytes, bytes) = bytes.split_first_chunk().unwrap();
783        let (_ek_bytes, bytes): (&[u8; PkeEncKey::BYTE_SIZE], _) =
784            bytes.split_first_chunk().unwrap();
785        let (h, bytes) = bytes.split_first_chunk().unwrap();
786        let (z_bytes, _) = bytes.split_first_chunk().unwrap();
787
788        let dk_pke = PkeDecKey::from_bytes(dk_bytes);
789
790        Self {
791            dk_pke,
792            h: *h,
793            z: *z_bytes,
794        }
795    }
796
797    /// Algorithm 21 ML-KEM.Decaps(dk, c)
798    /// Algorithm 18 ML-KEM.Decaps_internal(dk, c)
799    #[inline]
800    pub fn decaps(&self, k: &mut [u8; 32], ek: &EncapsKey, c: &[u8; EncapsKey::CIPHERTEXT_SIZE]) {
801        let mut m_prime = [0u8; 32];
802        self.dk_pke.decrypt(&mut m_prime, c);
803
804        let (k_prime, r_prime) = hash::sha3_512_split(&[&m_prime, &self.h]);
805
806        let mut j = hash::Shake256::init();
807        j.absorb_multi(&[&self.z, c]);
808        k.copy_from_slice(&j.squeezeblock()[..32]);
809
810        let mut c_prime = [0u8; EncapsKey::CIPHERTEXT_SIZE];
811        ek.ek_pke.encrypt(&mut c_prime, &m_prime, &r_prime);
812
813        cmov(k, &k_prime, bytes_is_eq(c, &c_prime));
814    }
815}
816
/// Compare two byte arrays without data-dependent branches.
/// Returns 1 if equal and 0 if not equal.
const fn bytes_is_eq<const N: usize>(a: &[u8; N], b: &[u8; N]) -> u32 {
    // OR together every byte difference; `diff` is zero iff the arrays match.
    let mut diff: u32 = 0;
    let mut idx = 0;
    while idx < N {
        diff |= (a[idx] ^ b[idx]) as u32;
        idx += 1;
    }

    // For 0 < diff <= 255 the negation has its top bit set; for diff == 0 it is 0.
    1 ^ (diff.wrapping_neg() >> 31)
}
831
/// Move `src` into `dst` if `cond == 1`, leaving `dst` unchanged if `cond == 0`,
/// without branching on `cond`.
fn cmov<const N: usize>(dst: &mut [u8; N], src: &[u8; N], cond: u32) {
    // 0xFF for cond == 1, 0x00 for cond == 0; `black_box` keeps the optimizer from
    // reasoning about the value.
    let mask = (black_box(cond) as u8).wrapping_neg();

    dst.iter_mut()
        .zip(src)
        .for_each(|(d, &s)| *d ^= mask & (*d ^ s));
}
840
841/// Algorithm 19 ML-KEM.KeyGen()
842#[inline]
843pub fn keygen(rng: &mut impl CryptoRngCore) -> (EncapsKey, DecapsKey) {
844    let mut d = [0u8; 32];
845    rng.fill_bytes(&mut d);
846
847    let mut z = [0u8; 32];
848    rng.fill_bytes(&mut z);
849
850    keygen_deterministic(d, z)
851}
852
853fn keygen_deterministic(d: [u8; 32], z: [u8; 32]) -> (EncapsKey, DecapsKey) {
854    let (ek_pke, dk_pke) = pke_keygen(&d);
855
856    let ek = EncapsKey { ek_pke };
857
858    let mut ek_bytes = [0u8; EncapsKey::BYTE_SIZE];
859    ek.to_bytes(&mut ek_bytes);
860
861    let h = hash::sha3_256(&[&ek_bytes]);
862
863    (ek, DecapsKey { dk_pke, h, z })
864}
865
#[cfg(test)]
mod tests {
    use rand_core::OsRng;
    use serde::Deserialize;
    use std::{fs::read_to_string, path::PathBuf};

    use super::*;

    /// Parameter set exercised by the vector-based tests.
    const PARAMETER_SET: &str = "ML-KEM-768";

    /// Load and parse a JSON test-vector file from `<crate>/tests/`.
    fn load_vectors<T>(file_name: &str) -> Tests<T>
    where
        T: for<'de> Deserialize<'de>,
    {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests")
            .join(file_name);
        let json = read_to_string(&path).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    #[test]
    fn test_keygen() {
        let vectors: Tests<KeyGenTestGroup> = load_vectors("kyber-keygen.json");

        let cases = vectors
            .test_groups
            .iter()
            .filter(|group| group.parameter_set == PARAMETER_SET)
            .flat_map(|group| group.tests.iter());

        for case in cases {
            // Reference encodings must have the expected sizes.
            assert_eq!(case.ek.len(), EncapsKey::BYTE_SIZE);
            assert_eq!(case.dk.len(), DecapsKey::BYTE_SIZE);

            let (ek, dk) = keygen_deterministic(case.d, case.z);

            // Decoding the reference decapsulation key must agree with ours.
            let ref_dk = DecapsKey::from_bytes(case.dk.as_slice().try_into().unwrap());
            assert_eq!(ref_dk.z, case.z);
            assert_eq!(dk.z, case.z);
            assert_eq!(ref_dk.dk_pke.s, dk.dk_pke.s);

            // Decoding the reference encapsulation key must agree with ours.
            let ref_ek = EncapsKey::from_bytes(case.ek.as_slice().try_into().unwrap());
            assert_eq!(ref_ek.ek_pke.rho, ek.ek_pke.rho);
            assert_eq!(ref_ek.ek_pke.t.vec, ek.ek_pke.t.vec);

            // Our encodings must reproduce the reference bytes exactly.
            let mut ek_bytes = [0u8; EncapsKey::BYTE_SIZE];
            ek.to_bytes(&mut ek_bytes);
            assert_eq!(ek_bytes, case.ek.as_slice());

            let mut dk_bytes = [0u8; DecapsKey::BYTE_SIZE];
            dk.to_bytes(&mut dk_bytes, &ek);
            assert_eq!(dk_bytes, case.dk.as_slice());
        }
    }

    #[test]
    fn test_kem() {
        let vectors: Tests<KemTestGroup> = load_vectors("kyber-kem.json");

        for group in vectors
            .test_groups
            .iter()
            .filter(|group| group.parameter_set == PARAMETER_SET)
        {
            match &group.params {
                KemTestGroupKind::Aft { tests } => tests.iter().for_each(check_aft),
                KemTestGroupKind::Val { tests, dk, ek } => {
                    let ek = EncapsKey::from_bytes(ek.as_slice().try_into().unwrap());
                    let dk = DecapsKey::from_bytes(dk.as_slice().try_into().unwrap());
                    for case in tests {
                        check_val(case, &ek, &dk);
                    }
                }
            }
        }
    }

    /// AFT case: encapsulate with the given `m`, compare `c` and `k`, then decapsulate.
    fn check_aft(case: &KemTestVectorAft) {
        assert_eq!(case.c.len(), EncapsKey::CIPHERTEXT_SIZE);
        let ek = EncapsKey::from_bytes(case.ek.as_slice().try_into().unwrap());
        let dk = DecapsKey::from_bytes(case.dk.as_slice().try_into().unwrap());

        let mut c = [0u8; EncapsKey::CIPHERTEXT_SIZE];
        let mut k = [0u8; 32];
        ek.encaps_internal(&mut c, &mut k, case.m.as_slice().try_into().unwrap());

        assert_eq!(c, case.c.as_slice());
        assert_eq!(k, case.k.as_slice());

        let mut k_prime = [0u8; 32];
        dk.decaps(&mut k_prime, &ek, &c);
        assert_eq!(&k, &k_prime);
    }

    /// VAL case: decapsulate the given ciphertext and compare the shared secret.
    fn check_val(case: &KemTestVectorVal, ek: &EncapsKey, dk: &DecapsKey) {
        assert_eq!(case.c.len(), EncapsKey::CIPHERTEXT_SIZE);

        let mut k = [0u8; 32];
        dk.decaps(&mut k, ek, case.c[..].try_into().unwrap());

        assert_eq!(&k, &case.k[..]);
    }

    #[test]
    fn test_kem_random() {
        let (ek, dk) = keygen(&mut OsRng);

        let mut c = [0u8; EncapsKey::CIPHERTEXT_SIZE];
        let mut k = [0u8; 32];
        ek.encaps(&mut c, &mut k, &mut OsRng);

        let mut k_prime = [0u8; 32];
        dk.decaps(&mut k_prime, &ek, &c);

        assert_eq!(&k, &k_prime);
    }

    /// Fill an `N`-byte array from `rng`.
    fn gen_rand_bytes<const N: usize>(rng: &mut impl CryptoRngCore) -> [u8; N] {
        let mut bytes = [0; N];
        rng.fill_bytes(&mut bytes);
        bytes
    }

    #[test]
    fn test_compress() {
        // decompress followed by compress must round-trip any compressed input.
        let compr_pvec = gen_rand_bytes(&mut OsRng);
        let mut compr_pvec_prime = [0; PolyVec::COMPRESSED_BYTES];
        PolyVec::decompress(&compr_pvec).compress(&mut compr_pvec_prime);
        assert_eq!(&compr_pvec, &compr_pvec_prime);

        let compr_poly = gen_rand_bytes(&mut OsRng);
        let mut compr_poly_prime = [0; Poly::COMPRESSED_BYTES];
        Poly::decompress(&compr_poly).compress(&mut compr_poly_prime);
        assert_eq!(&compr_poly, &compr_poly_prime);
    }

    /// Top-level layout of a JSON test-vector file.
    #[derive(Deserialize)]
    struct Tests<T> {
        #[serde(rename = "isSample")]
        _is_sample: bool,

        #[serde(rename = "testGroups")]
        test_groups: Vec<T>,

        #[serde(rename = "vsId")]
        _vs_id: i64,
    }

    /// A group of key-generation vectors for one parameter set.
    #[derive(Deserialize)]
    struct KeyGenTestGroup {
        #[serde(rename = "parameterSet")]
        parameter_set: String,

        #[serde(rename = "testType")]
        _test_type: String,

        tests: Vec<KeyGenTestVector>,

        #[serde(rename = "tgId")]
        _tg_id: i64,
    }

    /// One key-generation vector: seeds `d`, `z` and the expected encoded keys.
    #[derive(Deserialize)]
    struct KeyGenTestVector {
        #[serde(with = "hex")]
        d: [u8; 32],

        #[serde(with = "hex")]
        z: [u8; 32],

        #[serde(with = "hex")]
        dk: Vec<u8>,

        #[serde(with = "hex")]
        ek: Vec<u8>,

        #[serde(rename = "tcId")]
        _tc_id: i64,
    }

    /// AFT encapsulation vector: keys, randomness `m`, expected `c` and `k`.
    #[derive(Deserialize)]
    struct KemTestVectorAft {
        #[serde(with = "hex")]
        c: Vec<u8>,

        #[serde(with = "hex")]
        dk: Vec<u8>,

        #[serde(with = "hex")]
        ek: Vec<u8>,

        #[serde(with = "hex")]
        k: Vec<u8>,

        #[serde(with = "hex")]
        m: Vec<u8>,

        #[serde(rename = "tcId")]
        _tc_id: i64,
    }

    /// VAL decapsulation vector: ciphertext `c` and expected shared secret `k`.
    #[derive(Deserialize)]
    struct KemTestVectorVal {
        #[serde(with = "hex")]
        c: Vec<u8>,

        #[serde(with = "hex")]
        k: Vec<u8>,

        #[serde(rename = "tcId")]
        _tc_id: i64,
    }

    /// A group of encapsulation/decapsulation vectors for one parameter set.
    #[derive(Deserialize)]
    struct KemTestGroup {
        #[serde(rename = "parameterSet")]
        parameter_set: String,

        #[serde(rename = "tgId")]
        _tg_id: i64,

        #[serde(flatten)]
        params: KemTestGroupKind,
    }

    /// Group payload, selected by the JSON `testType` tag.
    #[derive(Deserialize)]
    #[serde(tag = "testType")]
    enum KemTestGroupKind {
        #[serde(rename = "AFT")]
        Aft { tests: Vec<KemTestVectorAft> },
        #[serde(rename = "VAL")]
        Val {
            tests: Vec<KemTestVectorVal>,

            #[serde(with = "hex")]
            dk: Vec<u8>,

            #[serde(with = "hex")]
            ek: Vec<u8>,
        },
    }
}