Skip to main content

pqcrypto_std/mldsa/
mod.rs

1//! Implementation of ML-DSA (FIPS-204).
2
3use crate::hash;
4use core::{
5    array,
6    mem::{transmute, transmute_copy, MaybeUninit},
7    ops::{AddAssign, Mul, MulAssign, SubAssign},
8};
9use rand_core::CryptoRngCore;
10use thiserror::Error;
11use zeroize::Zeroize;
12
13pub mod mldsa44;
14pub mod mldsa65;
15pub mod mldsa87;
16
17mod coeff;
18mod reduce;
19
/// The ML-DSA prime modulus, q = 2^23 - 2^13 + 1.
const Q: i32 = (1 << 23) - (1 << 13) + 1;
/// Number of coefficients in every polynomial.
const N: usize = 1 << 8;
/// Root of unity modulo `Q` (as specified by FIPS 204) from which `ZETAS` is built.
const ZETA: i32 = 1_753;
/// Number of low-order bits dropped from `t` by Power2Round.
const D: usize = 13;
24
25/// pre-computed zetas in montgomery form
26/// ordered by ZETAS\[i\] = z^BitRev8(i)
27/// zeta -> zeta * 2^32 (mod Q)
28const ZETAS: [i32; N] = {
29    let mut zetas = [0; N];
30    zetas[0] = reduce::R_MOD_Q;
31
32    let mut i = 1;
33    while i < N {
34        zetas[i] = reduce::mont_mul(zetas[i - 1], reduce::to_mont(ZETA));
35
36        i += 1
37    }
38
39    let mut zetas_bitrev = [0; N];
40
41    i = 0;
42    while i < N {
43        let idx = (i as u8).reverse_bits();
44
45        zetas_bitrev[i] = match zetas[idx as usize] {
46            z if z > Q / 2 => z - Q,
47            z if z < -Q / 2 => z + Q,
48            z => z,
49        };
50
51        i += 1;
52    }
53
54    zetas_bitrev
55};
56
57trait SigningKeyInternal<
58    const K: usize,
59    const L: usize,
60    const ETA: usize,
61    const TAU: usize,
62    const GAMMA1: usize,
63    const GAMMA2: usize,
64    const BETA: usize,
65    const OMEGA: usize,
66    const CT_BYTES: usize,
67    const W1_BYTES: usize,
68    const Z_BYTES: usize,
69>: From<PrivateKey<K, L, ETA>>
70{
71    fn privkey(&self) -> &PrivateKey<K, L, ETA>;
72    fn expand_mask(pvec: &mut PolyVec<L>, rho: &[u8; 64], mu: usize, h: &mut hash::Shake256);
73    fn bitpack_z(pvec: &PolyVec<L>, dst: &mut [u8; Z_BYTES]);
74    fn pack_simple(w1: &PolyVec<K>, z: &mut [u8; W1_BYTES]);
75    fn decompose(x: &PolyVec<K>, x0: &mut PolyVec<K>, x1: &mut PolyVec<K>);
76
77    fn sign_internal(&self, dst: &mut [u8], m: &[u8], rnd: &[u8; 32]) {
78        let (c_tilde, buf) = dst.split_first_chunk_mut::<CT_BYTES>().unwrap();
79        let (w1_bytes, buf) = buf.split_first_chunk_mut::<W1_BYTES>().unwrap();
80        let (mu, buf) = buf.split_first_chunk_mut::<64>().unwrap();
81        let rho_prime2: &mut [u8; 64] = buf.first_chunk_mut().unwrap();
82
83        let mut h = hash::Shake256::init();
84        let privkey = self.privkey();
85
86        h.absorb_and_squeeze(mu, &[&privkey.tr, m]);
87
88        h.absorb_and_squeeze(rho_prime2, &[&privkey.k, rnd, mu]);
89
90        let mut y = PolyVec::zero();
91        let mut y_hat = PolyVec::zero();
92        let mut w = PolyVec::zero();
93        let mut w1 = PolyVec::zero();
94        let mut w0 = PolyVec::zero();
95        let mut z = PolyVec::zero();
96        let mut hint = PolyVec::zero();
97        let mut c_hat = Poly::zero();
98
99        for nonce in (0..).step_by(L) {
100            Self::expand_mask(&mut y, rho_prime2, nonce, &mut h);
101
102            y_hat.ntt(&y);
103
104            w.multiply_matvec_ntt(&privkey.a_hat, &y_hat);
105            w.reduce_invntt_tomont_inplace();
106
107            Self::decompose(&w, &mut w0, &mut w1);
108            Self::pack_simple(&w1, w1_bytes);
109            h.absorb_and_squeeze(c_tilde, &[mu, w1_bytes]);
110
111            h.absorb(c_tilde);
112            h.finalize();
113            c_hat.f.fill(0);
114            c_hat.sample_in_ball(&mut h, TAU);
115            h.reset();
116            c_hat.ntt_inplace();
117
118            z.multiply_poly_ntt(&c_hat, &privkey.s1_hat);
119            z.invntt_tomont_inplace();
120            z += &y;
121            z.reduce();
122
123            if !z.norm_in_bound(GAMMA1 - BETA) {
124                continue;
125            }
126
127            hint.multiply_poly_ntt(&c_hat, &privkey.s2_hat);
128            hint.invntt_tomont_inplace();
129            w0 -= &hint;
130            w0.reduce();
131
132            if !w0.norm_in_bound(GAMMA2 - BETA) {
133                continue;
134            }
135
136            hint.multiply_poly_ntt(&c_hat, &privkey.t0_hat);
137            hint.invntt_tomont_inplace();
138            hint.reduce();
139
140            if !hint.norm_in_bound(GAMMA2) {
141                continue;
142            }
143
144            w0 += &hint;
145
146            let count = hint.make_hint(&w0, &w1, GAMMA2);
147
148            if count >= OMEGA {
149                continue;
150            }
151
152            break;
153        }
154
155        let (z_buf, buf) = dst[CT_BYTES..].split_first_chunk_mut().unwrap();
156
157        Self::bitpack_z(&z, z_buf);
158
159        hint.hint_bitpack::<OMEGA>(buf);
160    }
161
162    fn keygen_internal(vk: &mut [u8], ksi: &[u8; 32]) -> Self {
163        let mut h = hash::Shake256::init();
164        h.absorb_multi(&[ksi, &[K as u8], &[L as u8]]);
165        let rho: [u8; 32] = h.squeeze_array();
166        let rho_prime: [u8; 64] = h.squeeze_array();
167        let k: [u8; 32] = h.squeeze_array();
168
169        let mut s1_hat = PolyVec::zero();
170        let mut s2_hat = PolyVec::zero();
171        let mut t0_hat = PolyVec::zero();
172        let a_hat = PolyMat::expand_a(&rho);
173
174        expand_s::<K, L, ETA>(&mut s1_hat, &mut s2_hat, &rho_prime);
175
176        s1_hat.ntt_inplace();
177
178        let mut t = PolyVec::zero();
179        let mut t1 = PolyVec::zero();
180
181        t.multiply_matvec_ntt(&a_hat, &s1_hat);
182        t.reduce_invntt_tomont_inplace();
183
184        t += &s2_hat;
185
186        t.power2round(&mut t1, &mut t0_hat);
187
188        vk_encode(vk, &rho, &t1);
189
190        h.reset();
191        h.absorb(vk);
192        h.finalize();
193        let tr: [u8; 64] = h.squeeze_array();
194
195        s2_hat.ntt_inplace();
196        t0_hat.ntt_inplace();
197
198        PrivateKey {
199            rho,
200            k,
201            tr,
202            s1_hat,
203            s2_hat,
204            t0_hat,
205            a_hat,
206        }
207        .into()
208    }
209}
210
/// Reasons why ML-DSA signature verification can fail.
#[derive(Debug, Error)]
pub enum VerifyError {
    /// The decoded response vector `z` has a coefficient whose norm is not
    /// below `GAMMA1 - BETA`.
    #[error("z is out of bound")]
    ZoutOfBound,

    /// The signature does not match the message and public key
    /// (for example, the recomputed `c_tilde` differs from the signature's).
    #[error("signature mismatch")]
    Mismatch,

    /// The signature's encoded hint was rejected while being decoded; per its
    /// message, it holds too many hint entries.
    #[error("too many hints in signature")]
    TooManyHints,
}
222
223trait VerifyingKeyInternal<
224    const K: usize,
225    const L: usize,
226    const CT_BYTES: usize,
227    const Z_BYTES: usize,
228    const H_BYTES: usize,
229    const W1_BYTES: usize,
230    const SIG_SIZE: usize,
231>
232{
233    const OMEGA: usize;
234    const TAU: usize;
235    const GAMMA1: usize;
236    const GAMMA2: usize;
237    const BETA: usize;
238
239    fn bitunpack_z_hat(b: &[u8; Z_BYTES]) -> PolyVec<L>;
240
241    fn w1encode(w1: &PolyVec<K>) -> [u8; W1_BYTES];
242
243    fn use_hint(w1: &mut PolyVec<K>, h: &PolyVec<K>);
244
245    fn pk(&self) -> &PublicKey<K, L>;
246
247    fn verify_internal(&self, m: &[u8], sig: &[u8; SIG_SIZE]) -> Result<(), VerifyError> {
248        let (c_tilde, sig) = sig.split_first_chunk::<CT_BYTES>().unwrap();
249        let (z_bytes, sig) = sig.split_first_chunk::<Z_BYTES>().unwrap();
250        let h_bytes: &[u8; H_BYTES] = sig.try_into().unwrap();
251
252        let hint = PolyVec::hint_bitunpack(h_bytes, Self::OMEGA)?;
253
254        let mut z_hat = Self::bitunpack_z_hat(z_bytes);
255
256        if !z_hat.norm_in_bound(Self::GAMMA1 - Self::BETA) {
257            return Err(VerifyError::ZoutOfBound);
258        }
259
260        let pk = self.pk();
261
262        let mut h = hash::Shake256::init();
263
264        h.absorb_multi(&[&pk.tr, m]);
265        let mu: [u8; 64] = h.squeeze_array();
266        h.reset();
267
268        let mut c_hat = Poly::zero();
269        h.absorb(c_tilde);
270        h.finalize();
271        c_hat.sample_in_ball(&mut h, Self::TAU);
272        h.reset();
273
274        z_hat.ntt_inplace();
275
276        let mut w1 = PolyVec::zero();
277        w1.multiply_matvec_ntt(&pk.a_hat, &z_hat);
278
279        c_hat.ntt_inplace();
280
281        let mut t1 = pk.t1.shifted_left(D);
282        t1.ntt_inplace();
283        t1 *= &c_hat;
284
285        w1 -= &t1;
286        w1.reduce_invntt_tomont_inplace();
287        Self::use_hint(&mut w1, &hint);
288
289        let w1_bytes = Self::w1encode(&w1);
290
291        h.absorb_multi(&[&mu, &w1_bytes]);
292        let c_tilde_prime = h.squeeze_array();
293
294        if c_tilde == &c_tilde_prime {
295            Ok(())
296        } else {
297            Err(VerifyError::Mismatch)
298        }
299    }
300}
301
/// Signature verifier in ML-DSA.
///
/// The const parameters describe the parameter set: matrix dimensions `K`/`L`
/// and the byte sizes of the signature components and of the whole signature.
pub trait VerifyingKey<
    const K: usize,
    const L: usize,
    const CT_BYTES: usize,
    const Z_BYTES: usize,
    const H_BYTES: usize,
    const W1_BYTES: usize,
    const SIG_SIZE: usize,
>
{
    /// Verify signature `sig` over message `m`.
    fn verify(&self, m: &[u8], sig: &[u8]) -> Result<(), VerifyError>;
    /// Encode the public key into `dst`.
    fn encode(&self, dst: &mut [u8]);
    /// Decode a public key from `src`.
    fn decode(src: &[u8]) -> Self;
}
316
317impl<
318        T,
319        const K: usize,
320        const L: usize,
321        const CT_BYTES: usize,
322        const Z_BYTES: usize,
323        const H_BYTES: usize,
324        const W1_BYTES: usize,
325        const SIG_SIZE: usize,
326    > VerifyingKey<K, L, CT_BYTES, Z_BYTES, H_BYTES, W1_BYTES, SIG_SIZE> for T
327where
328    T: VerifyingKeyInternal<K, L, CT_BYTES, Z_BYTES, H_BYTES, W1_BYTES, SIG_SIZE>
329        + From<PublicKey<K, L>>,
330{
331    fn verify(&self, m: &[u8], sig: &[u8]) -> Result<(), VerifyError> {
332        assert!(sig.len() >= SIG_SIZE);
333        self.verify_internal(m, sig.first_chunk().unwrap())
334    }
335
336    fn encode(&self, dst: &mut [u8]) {
337        assert!(dst.len() >= pubkey_size(K));
338        let key = self.pk();
339        vk_encode(&mut dst[..pubkey_size(K)], &key.rho, &key.t1)
340    }
341
342    fn decode(src: &[u8]) -> Self {
343        assert!(src.len() >= pubkey_size(K));
344        PublicKey::decode(src).into()
345    }
346}
347
/// Signatory in ML-DSA.
///
/// The const parameters describe the parameter set: matrix dimensions `K`/`L`,
/// the scheme parameters (`ETA`, `TAU`, `GAMMA1`, `GAMMA2`, `BETA`, `OMEGA`)
/// and the byte sizes of `c_tilde`, `w1Encode(w1)` and the packed `z`.
pub trait SigningKey<
    const K: usize,
    const L: usize,
    const ETA: usize,
    const TAU: usize,
    const GAMMA1: usize,
    const GAMMA2: usize,
    const BETA: usize,
    const OMEGA: usize,
    const CT_BYTES: usize,
    const W1_BYTES: usize,
    const Z_BYTES: usize,
>
{
    /// Sign message `m` using randomness from `rng`; `sig` receives the
    /// encoded signature.
    fn sign(&self, sig: &mut [u8], rng: &mut impl CryptoRngCore, m: &[u8]);
    /// Encode the private key into `dst`.
    fn encode(&self, dst: &mut [u8]);
    /// Decode a private key from `src`.
    fn decode(src: &[u8]) -> Self;

    /// Private key generation.
    ///
    /// `vk` receives the encoded public (verifying) key.
    fn keygen(vk: &mut [u8], rng: &mut impl CryptoRngCore) -> Self;
}
371
372impl<
373        T,
374        const K: usize,
375        const L: usize,
376        const ETA: usize,
377        const TAU: usize,
378        const GAMMA1: usize,
379        const GAMMA2: usize,
380        const BETA: usize,
381        const OMEGA: usize,
382        const CT_BYTES: usize,
383        const W1_BYTES: usize,
384        const Z_BYTES: usize,
385    > SigningKey<K, L, ETA, TAU, GAMMA1, GAMMA2, BETA, OMEGA, CT_BYTES, W1_BYTES, Z_BYTES> for T
386where
387    T: SigningKeyInternal<K, L, ETA, TAU, GAMMA1, GAMMA2, BETA, OMEGA, CT_BYTES, W1_BYTES, Z_BYTES>,
388{
389    fn sign(&self, sig: &mut [u8], rng: &mut impl CryptoRngCore, m: &[u8]) {
390        let mut rnd = [0u8; 32];
391        rng.fill_bytes(&mut rnd);
392
393        self.sign_internal(sig, m.as_ref(), &rnd)
394    }
395
396    fn encode(&self, dst: &mut [u8]) {
397        assert!(dst.len() >= privkey_size(K, L, ETA));
398        self.privkey().encode(dst);
399    }
400
401    fn decode(src: &[u8]) -> Self {
402        assert!(src.len() >= privkey_size(K, L, ETA));
403        PrivateKey::decode(src).into()
404    }
405
406    /// Private key generation.
407    fn keygen(pk: &mut [u8], rng: &mut impl CryptoRngCore) -> Self {
408        debug_assert!(pk.len() >= pubkey_size(K));
409
410        let mut ksi = [0u8; 32];
411        rng.fill_bytes(&mut ksi);
412
413        let sk = Self::keygen_internal(pk, &ksi);
414
415        ksi.zeroize();
416
417        sk
418    }
419}
420
421const fn pubkey_size(k: usize) -> usize {
422    k * Poly::PACKED_10BIT + 32
423}
424
425const fn privkey_size(k: usize, l: usize, eta: usize) -> usize {
426    match eta {
427        2 => 32 + 32 + 64 + l * Poly::PACKED_3BIT + k * (Poly::PACKED_3BIT + Poly::PACKED_13BIT),
428        4 => 32 + 32 + 64 + l * Poly::PACKED_4BIT + k * (Poly::PACKED_4BIT + Poly::PACKED_13BIT),
429        _ => unreachable!(),
430    }
431}
432
/// Bit length of `n`: the number of binary digits needed to write `n`.
///
/// Defined for every `n`, including `n == 0` (bit length 0), where an
/// `ilog2`-based formula would panic.
const fn bitlen(n: usize) -> usize {
    (usize::BITS - n.leading_zeros()) as usize
}
436
437const fn sig_size(k: usize, l: usize, lambda: usize, gamma1: usize, omega: usize) -> usize {
438    lambda / 4 + l * 32 * (1 + bitlen(gamma1 - 1)) + omega + k
439}
440
441fn vk_encode<const K: usize>(dst: &mut [u8], rho: &[u8; 32], t1: &PolyVec<K>) {
442    dst[..32].copy_from_slice(rho);
443    for (xi, z) in
444        t1.v.iter()
445            .zip(dst[32..].chunks_exact_mut(Poly::PACKED_10BIT))
446    {
447        xi.pack_simple_10bit(z.try_into().unwrap())
448    }
449}
450
/// Public key used for verifying.
struct PublicKey<const K: usize, const L: usize> {
    /// Seed from which the matrix `a_hat` is expanded.
    rho: [u8; 32],
    /// Hash of the encoded public key; absorbed together with the message
    /// to form `mu` during verification.
    tr: [u8; 64],
    /// High-order part `t1` of `t`, as decoded from the encoded key.
    t1: PolyVec<K>,
    /// Matrix `A` in NTT form, expanded from `rho` once at decode time.
    a_hat: PolyMat<K, L>,
}
458
459impl<const K: usize, const L: usize> PublicKey<K, L> {
460    /// Decode public key from bytes.
461    fn decode(pk: &[u8]) -> Self {
462        let rho = array::from_fn(|i| pk[i]);
463        let mut t1 = PolyVec::zero();
464
465        for (xi, z) in
466            t1.v.iter_mut()
467                .zip(pk[32..].chunks_exact(Poly::PACKED_10BIT))
468        {
469            xi.unpack_simple_10bit(z.try_into().unwrap())
470        }
471
472        let a_hat = PolyMat::expand_a(&rho);
473
474        let mut h = hash::Shake256::init();
475        h.absorb(pk);
476        h.finalize();
477        let tr = h.squeeze_array();
478
479        Self { rho, tr, t1, a_hat }
480    }
481}
482
/// Private key used for signing.
pub(crate) struct PrivateKey<const K: usize, const L: usize, const ETA: usize> {
    /// Public seed from which `a_hat` is expanded.
    rho: [u8; 32],
    /// Private seed `K`, absorbed into the per-signature seed `rho''`.
    k: [u8; 32],
    /// Hash of the encoded public key, absorbed into `mu`.
    tr: [u8; 64],
    /// Secret vector `s1`, kept in NTT form.
    s1_hat: PolyVec<L>,
    /// Secret vector `s2`, kept in NTT form.
    s2_hat: PolyVec<K>,
    /// Low-order part `t0` of `t`, kept in NTT form.
    t0_hat: PolyVec<K>,
    /// Matrix `A` in NTT form, expanded from `rho`.
    a_hat: PolyMat<K, L>,
}
493
494impl<const K: usize, const L: usize, const ETA: usize> Drop for PrivateKey<K, L, ETA> {
495    fn drop(&mut self) {
496        self.k.zeroize();
497        self.tr.zeroize();
498    }
499}
500
501impl<const K: usize, const L: usize, const ETA: usize> PrivateKey<K, L, ETA> {
502    /// Encode private key to bytes.
503    pub fn encode(&self, dst: &mut [u8]) {
504        let s1 = self.s1_hat.invntt();
505        let s2 = self.s2_hat.invntt();
506        let t0 = self.t0_hat.invntt();
507
508        dst[..32].copy_from_slice(&self.rho);
509        dst[32..64].copy_from_slice(&self.k);
510        dst[64..128].copy_from_slice(&self.tr);
511
512        let buf = &mut dst[128..];
513
514        match ETA {
515            2 => {
516                s1.pack_eta2(&mut buf[..L * Poly::PACKED_3BIT]);
517
518                let buf = &mut buf[L * Poly::PACKED_3BIT..];
519                s2.pack_eta2(&mut buf[..K * Poly::PACKED_3BIT]);
520
521                let buf = &mut buf[K * Poly::PACKED_3BIT..];
522                t0.pack_eta_2powdm1(buf)
523            }
524            4 => {
525                s1.pack_eta4(&mut buf[..L * Poly::PACKED_4BIT]);
526
527                let buf = &mut buf[L * Poly::PACKED_4BIT..];
528                s2.pack_eta4(buf);
529
530                let buf = &mut buf[K * Poly::PACKED_4BIT..];
531                t0.pack_eta_2powdm1(buf)
532            }
533            _ => unreachable!(),
534        }
535    }
536
537    /// Decode private key from bytes.
538    pub fn decode(src: &[u8]) -> Self {
539        let mut rho: MaybeUninit<[u8; 32]> = MaybeUninit::uninit();
540        let mut k: MaybeUninit<[u8; 32]> = MaybeUninit::uninit();
541        let mut tr: MaybeUninit<[u8; 64]> = MaybeUninit::uninit();
542
543        rho.write(src[..32].try_into().unwrap());
544        k.write(src[32..64].try_into().unwrap());
545        tr.write(src[64..128].try_into().unwrap());
546
547        let (rho, k, tr) = unsafe { (rho.assume_init(), k.assume_init(), tr.assume_init()) };
548
549        let mut s1_hat = PolyVec::zero();
550        let mut s2_hat = PolyVec::zero();
551        let mut t0_hat = PolyVec::zero();
552
553        match ETA {
554            2 => {
555                let z = &src[128..];
556                s1_hat.unpack_eta2(&z[..L * Poly::PACKED_3BIT]);
557
558                let z = &z[L * Poly::PACKED_3BIT..];
559                s2_hat.unpack_eta2(&z[..K * Poly::PACKED_3BIT]);
560
561                let z = &z[K * Poly::PACKED_3BIT..];
562                t0_hat.unpack_eta_2powdm1(z)
563            }
564            4 => {
565                let z = &src[128..];
566                s1_hat.unpack_eta4(&z[..L * Poly::PACKED_4BIT]);
567
568                let z = &z[L * Poly::PACKED_4BIT..];
569                s2_hat.unpack_eta4(&z[..K * Poly::PACKED_4BIT]);
570
571                let z = &z[K * Poly::PACKED_4BIT..];
572                t0_hat.unpack_eta_2powdm1(z)
573            }
574            _ => unreachable!(),
575        }
576
577        let a_hat = PolyMat::expand_a(&rho);
578
579        s1_hat.ntt_inplace();
580        s2_hat.ntt_inplace();
581        t0_hat.ntt_inplace();
582
583        Self {
584            rho,
585            k,
586            tr,
587            s1_hat,
588            s2_hat,
589            t0_hat,
590            a_hat,
591        }
592    }
593}
594
/// A polynomial with `N` `i32` coefficients.
///
/// The same type holds values in coefficient form and in NTT form (for example
/// `ntt_inplace` converts in place); it does not record which form it is in.
#[repr(transparent)]
struct Poly {
    f: [i32; N],
}
599
600impl Drop for Poly {
601    fn drop(&mut self) {
602        self.f.zeroize();
603    }
604}
605
606impl Poly {
607    const fn zero() -> Self {
608        Self { f: [0; N] }
609    }
610
611    const fn packed_bytesize(bitlen: usize) -> usize {
612        (N * bitlen) / 8
613    }
614
615    /// NTT(w)
616    fn ntt_inplace(&mut self) {
617        let w = &mut self.f;
618
619        let mut m = 1;
620
621        for len in (0..8).map(|n| 128 >> n) {
622            for start in (0..256).step_by(len << 1) {
623                let zeta = ZETAS[m];
624                m += 1;
625
626                for j in start..start + len {
627                    let t = reduce::mont_mul(zeta, w[j + len]);
628                    w[j + len] = w[j] - t;
629                    w[j] += t;
630                }
631            }
632        }
633    }
634
635    /// NTT(w)
636    fn ntt(&mut self, f: &Self) {
637        let w_hat = &mut self.f;
638        let w = &f.f;
639
640        w_hat.copy_from_slice(w);
641
642        let mut m = 1;
643
644        for len in (0..8).map(|n| 128 >> n) {
645            for start in (0..256).step_by(len << 1) {
646                let zeta = ZETAS[m];
647                m += 1;
648
649                for j in start..start + len {
650                    let t = reduce::mont_mul(zeta, w_hat[j + len]);
651                    w_hat[j + len] = w_hat[j] - t;
652                    w_hat[j] += t;
653                }
654            }
655        }
656    }
657
658    /// NTT^-1 (w_hat)
659    fn invntt(&self) -> Self {
660        let mut w_hat = self.f;
661
662        let mut m = 255;
663
664        for len in (0..8).map(|n| 1 << n) {
665            for start in (0..256).step_by(len << 1) {
666                let zeta = -ZETAS[m];
667                m -= 1;
668                for j in start..start + len {
669                    let t = w_hat[j];
670                    w_hat[j] = t + w_hat[j + len];
671                    w_hat[j + len] = t - w_hat[j + len];
672                    w_hat[j + len] = reduce::mont_mul(zeta, w_hat[j + len]);
673                }
674            }
675        }
676
677        // 2^32 / 256 = 2^{24}
678        const DIV_256: i32 = ((1 << 24) % Q as i64) as i32;
679
680        for a in w_hat.iter_mut() {
681            *a = reduce::mont_mul(*a, DIV_256);
682        }
683
684        Self { f: w_hat }
685    }
686
687    /// NTT^-1 (w_hat)
688    fn invntt_tomont_inplace(&mut self) {
689        let w = &mut self.f;
690
691        let mut m = 255;
692
693        for len in (0..8).map(|n| 1 << n) {
694            for start in (0..256).step_by(len << 1) {
695                let zeta = -ZETAS[m];
696                m -= 1;
697                for j in start..start + len {
698                    let t = w[j];
699                    w[j] = t + w[j + len];
700                    w[j + len] = t - w[j + len];
701                    w[j + len] = reduce::mont_mul(zeta, w[j + len]);
702                }
703            }
704        }
705
706        // (2^32)^2 / 256 = 2^{56}
707        const DIV_256_MONT: i32 = ((1 << 56) % Q as i64) as i32;
708
709        for a in w.iter_mut() {
710            *a = reduce::mont_mul(*a, DIV_256_MONT);
711        }
712    }
713
714    /// RejNTTPoly(rho)
715    fn rej_ntt(g: &mut hash::Shake128) -> Self {
716        let mut f: [MaybeUninit<i32>; N] = [MaybeUninit::uninit(); N];
717        let mut idx = 0;
718
719        while idx < N {
720            let bytes = g.squeezeblock();
721
722            for b in bytes.chunks_exact(3) {
723                if let Some(a) = coeff::from_three_bytes(b[0], b[1], b[2]) {
724                    f[idx].write(a);
725                    idx += 1;
726                }
727
728                if idx == N {
729                    break;
730                }
731            }
732        }
733
734        Self {
735            f: unsafe { transmute::<[MaybeUninit<i32>; N], [i32; N]>(f) },
736        }
737    }
738
739    /// RejBoundedPoly(rho)
740    fn rej_bounded<const ETA: usize>(&mut self, h: &mut hash::Shake256) {
741        let mut idx = 0;
742
743        while idx < N {
744            let bytes = h.squeezeblock();
745
746            for z in bytes
747                .iter()
748                .flat_map(|b| {
749                    let (z0, z1) = coeff::from_halfbytes::<ETA>(*b);
750                    [z0, z1]
751                })
752                .flatten()
753            {
754                self.f[idx] = z;
755                idx += 1;
756
757                if idx == N {
758                    break;
759                }
760            }
761        }
762    }
763
764    /// SampleInBall(rho)
765    fn sample_in_ball(&mut self, h: &mut hash::Shake256, tau: usize) {
766        let mut block = h.squeezeblock();
767
768        let mut hash = u64::from_le_bytes(block[..8].try_into().unwrap());
769
770        let mut iter = block[8..].iter();
771
772        let mut i = N - tau;
773
774        while i < N {
775            let j = if let Some(j) = iter.by_ref().find(|b| (**b as usize) <= i) {
776                *j as usize
777            } else {
778                block = h.squeezeblock();
779                iter = block.iter();
780                continue;
781            };
782
783            self.f[i] = self.f[j];
784            self.f[j] = 1 - ((hash & 1) << 1) as i32;
785
786            hash >>= 1;
787            i += 1;
788        }
789    }
790
791    fn multiply_ntt_acc(&mut self, a: &Self, b: &Self) {
792        for i in 0..N {
793            self.f[i] += reduce::mont_mul(a.f[i], b.f[i])
794        }
795    }
796
797    fn multiply_ntt(&mut self, a: &Self, b: &Self) {
798        for i in 0..N {
799            self.f[i] = reduce::mont_mul(a.f[i], b.f[i])
800        }
801    }
802
803    fn dot_prod_ntt<const K: usize>(&mut self, u: &PolyVec<K>, v: &PolyVec<K>) {
804        self.multiply_ntt(&u.v[0], &v.v[0]);
805
806        for i in 1..K {
807            self.multiply_ntt_acc(&u.v[i], &v.v[i]);
808        }
809    }
810
811    fn reduce(&mut self) {
812        for a in self.f.iter_mut() {
813            *a = reduce::barrett_reduce(*a);
814        }
815    }
816
817    fn power2round(&self, f: &mut Self, g: &mut Self) {
818        for i in 0..N {
819            let (r1, r0) = coeff::power2round(self.f[i]);
820            f.f[i] = r1;
821            g.f[i] = r0;
822        }
823    }
824
825    fn decompose_32(&self, p0: &mut Self, p1: &mut Self) {
826        for i in 0..N {
827            let (r1, r0) = coeff::decompose_32(self.f[i]);
828            p0.f[i] = r0;
829            p1.f[i] = r1;
830        }
831    }
832
833    fn decompose_88(&self, p0: &mut Self, p1: &mut Self) {
834        for i in 0..N {
835            let (r1, r0) = coeff::decompose_88(self.f[i]);
836            p0.f[i] = r0;
837            p1.f[i] = r1;
838        }
839    }
840
841    const PACKED_10BIT: usize = (N * 10) / 8;
842
843    fn pack_simple_10bit(&self, z: &mut [u8; Self::PACKED_10BIT]) {
844        for (b, a) in z.chunks_exact_mut(5).zip(self.f.chunks_exact(4)) {
845            b[0] = a[0] as u8;
846            b[1] = (a[0] >> 8) as u8 | (a[1] << 2) as u8;
847            b[2] = (a[1] >> 6) as u8 | (a[2] << 4) as u8;
848            b[3] = (a[2] >> 4) as u8 | (a[3] << 6) as u8;
849            b[4] = (a[3] >> 2) as u8;
850        }
851    }
852
853    fn unpack_simple_10bit(&mut self, z: &[u8; Self::PACKED_10BIT]) {
854        for (a, b) in self.f.chunks_exact_mut(4).zip(z.chunks_exact(5)) {
855            let b: [i32; 5] = array::from_fn(|i| b[i] as i32);
856            a[0] = (b[0] | (b[1] << 8)) & 0x3FF;
857            a[1] = ((b[1] >> 2) | (b[2] << 6)) & 0x3FF;
858            a[2] = ((b[2] >> 4) | (b[3] << 4)) & 0x3FF;
859            a[3] = ((b[3] >> 6) | (b[4] << 2)) & 0x3FF;
860        }
861    }
862
863    fn pack_simple_4bit(&self, z: &mut [u8; Self::packed_bytesize(4)]) {
864        for (b, a) in z.iter_mut().zip(self.f.chunks_exact(2)) {
865            *b = (a[0] | a[1] << 4) as u8;
866        }
867    }
868
869    fn pack_simple_uninit_4bit(&self, z: &mut [MaybeUninit<u8>; Self::packed_bytesize(4)]) {
870        for (b, a) in z.iter_mut().zip(self.f.chunks_exact(2)) {
871            b.write((a[0] | a[1] << 4) as u8);
872        }
873    }
874
875    fn pack_simple_6bit(&self, z: &mut [u8; Self::packed_bytesize(6)]) {
876        for (b, a) in z.chunks_exact_mut(3).zip(self.f.chunks_exact(4)) {
877            b[0] = ((a[0] >> 0) | (a[1] << 6)) as u8;
878            b[1] = ((a[1] >> 2) | (a[2] << 4)) as u8;
879            b[2] = ((a[2] >> 4) | (a[3] << 2)) as u8;
880        }
881    }
882
883    fn pack_simple_uninit_6bit(&self, z: &mut [MaybeUninit<u8>; Self::packed_bytesize(6)]) {
884        for (b, a) in z.chunks_exact_mut(3).zip(self.f.chunks_exact(4)) {
885            b[0].write(((a[0] >> 0) | (a[1] << 6)) as u8);
886            b[1].write(((a[1] >> 2) | (a[2] << 4)) as u8);
887            b[2].write(((a[2] >> 4) | (a[3] << 2)) as u8);
888        }
889    }
890
891    const PACKED_4BIT: usize = (N * 4) / 8;
892
893    fn pack_eta4(&self, z: &mut [u8; Self::PACKED_4BIT]) {
894        for (b, a) in z.iter_mut().zip(self.f.chunks_exact(2)) {
895            let t0 = (4 - a[0]) as u8;
896            let t1 = (4 - a[1]) as u8;
897            *b = t0 | (t1 << 4);
898        }
899    }
900
901    fn unpack_eta4(&mut self, z: &[u8; Self::PACKED_4BIT]) {
902        for (a, b) in self.f.chunks_exact_mut(2).zip(z) {
903            a[0] = 4 - (b & 0xF) as i32;
904            a[1] = 4 - (b >> 4) as i32;
905        }
906    }
907
908    const PACKED_3BIT: usize = (N * 3) / 8;
909
910    fn pack_eta2(&self, z: &mut [u8; Self::PACKED_3BIT]) {
911        for (b, a) in z.chunks_exact_mut(3).zip(self.f.chunks_exact(8)) {
912            let t: [u8; 8] = array::from_fn(|i| (2 - a[i]) as u8);
913
914            b[0] = t[0] | (t[1] << 3) | (t[2] << 6);
915            b[1] = (t[2] >> 2) | (t[3] << 1) | (t[4] << 4) | (t[5] << 7);
916            b[2] = (t[5] >> 1) | (t[6] << 2) | (t[7] << 5);
917        }
918    }
919
920    fn unpack_eta2(&mut self, z: &[u8; Self::PACKED_3BIT]) {
921        for (a, b) in self.f.chunks_exact_mut(8).zip(z.chunks_exact(3)) {
922            a[0] = 2 - (b[0] & 7) as i32;
923            a[1] = 2 - ((b[0] >> 3) & 7) as i32;
924            a[2] = 2 - ((b[0] >> 6) | (b[1] << 2) & 7) as i32;
925            a[3] = 2 - ((b[1] >> 1) & 7) as i32;
926            a[4] = 2 - ((b[1] >> 4) & 7) as i32;
927            a[5] = 2 - (((b[1] >> 7) | (b[2] << 1)) & 7) as i32;
928            a[6] = 2 - ((b[2] >> 2) & 7) as i32;
929            a[7] = 2 - (b[2] >> 5) as i32
930        }
931    }
932
933    const PACKED_13BIT: usize = (N * 13) / 8;
934
935    fn pack_eta_2powdm1(&self, z: &mut [u8; Self::PACKED_13BIT]) {
936        const ETA: i32 = 1 << (D - 1);
937
938        for (b, a) in z.chunks_exact_mut(13).zip(self.f.chunks_exact(8)) {
939            let a: [u16; 8] = array::from_fn(|i| (ETA - a[i]) as u16);
940
941            b[0] = a[0] as u8;
942            b[1] = ((a[0] >> 8) | a[1] << 5) as u8;
943            b[2] = (a[1] >> 3) as u8;
944            b[3] = ((a[1] >> 11) | a[2] << 2) as u8;
945            b[4] = ((a[2] >> 6) | (a[3] << 7)) as u8;
946            b[5] = (a[3] >> 1) as u8;
947            b[6] = ((a[3] >> 9) | a[4] << 4) as u8;
948            b[7] = (a[4] >> 4) as u8;
949            b[8] = ((a[4] >> 12) | a[5] << 1) as u8;
950            b[9] = ((a[5] >> 7) | a[6] << 6) as u8;
951            b[10] = (a[6] >> 2) as u8;
952            b[11] = ((a[6] >> 10) | a[7] << 3) as u8;
953            b[12] = (a[7] >> 5) as u8;
954        }
955    }
956
957    fn unpack_eta_2powdm1(&mut self, z: &[u8; Self::PACKED_13BIT]) {
958        const ETA: i32 = 1 << (D - 1);
959
960        for (a, b) in self.f.chunks_exact_mut(8).zip(z.chunks_exact(13)) {
961            let b: [i32; 13] = array::from_fn(|i| b[i] as i32);
962
963            a[0] = ETA - ((b[0] | (b[1] << 8)) & 0x1FFF);
964            a[1] = ETA - (((b[1] >> 5) | (b[2] << 3) | (b[3] << 11)) & 0x1FFF);
965            a[2] = ETA - (((b[3] >> 2) | (b[4] << 6)) & 0x1FFF);
966            a[3] = ETA - (((b[4] >> 7) | (b[5] << 1) | (b[6] << 9)) & 0x1FFF);
967            a[4] = ETA - (((b[6] >> 4) | (b[7] << 4) | (b[8] << 12)) & 0x1FFF);
968            a[5] = ETA - (((b[8] >> 1) | (b[9] << 7)) & 0x1FFF);
969            a[6] = ETA - (((b[9] >> 6) | (b[10] << 2) | (b[11] << 10)) & 0x1FFF);
970            a[7] = ETA - (((b[11] >> 3) | (b[12] << 5)) & 0x1FFF);
971        }
972    }
973
974    fn bitpack_2pow17(&self, z: &mut [u8; Self::packed_bytesize(18)]) {
975        const B: i32 = 1 << 17;
976
977        for (b, a) in z.chunks_exact_mut(9).zip(self.f.chunks_exact(4)) {
978            let a0 = B - a[0];
979            let a1 = B - a[1];
980            let a2 = B - a[2];
981            let a3 = B - a[3];
982
983            b[0] = (a0 >> 0) as u8;
984            b[1] = (a0 >> 8) as u8;
985            b[2] = ((a0 >> 16) | (a1 << 2)) as u8;
986            b[3] = (a1 >> 6) as u8;
987            b[4] = ((a1 >> 14) | (a2 << 4)) as u8;
988            b[5] = (a2 >> 4) as u8;
989            b[6] = ((a2 >> 12) | (a3 << 6)) as u8;
990            b[7] = (a3 >> 2) as u8;
991            b[8] = (a3 >> 10) as u8;
992        }
993    }
994
995    fn bitunpack_2pow17(&mut self, z: &[u8; Self::packed_bytesize(18)]) {
996        const B: i32 = 1 << 17;
997        const BITMASK: i32 = 0x3ffff;
998
999        for (a, b) in self.f.chunks_exact_mut(4).zip(z.chunks_exact(9)) {
1000            let b: [i32; 9] = array::from_fn(|i| b[i] as i32);
1001
1002            a[0] = B - (((b[0] >> 0) | (b[1] << 8) | (b[2] << 16)) & BITMASK);
1003            a[1] = B - (((b[2] >> 2) | (b[3] << 6) | (b[4] << 14)) & BITMASK);
1004            a[2] = B - (((b[4] >> 4) | (b[5] << 4) | (b[6] << 12)) & BITMASK);
1005            a[3] = B - ((b[6] >> 6) | (b[7] << 2) | (b[8] << 10));
1006        }
1007    }
1008
1009    fn bitpack_2pow19(&self, z: &mut [u8; Self::packed_bytesize(20)]) {
1010        const B: i32 = 1 << 19;
1011
1012        for (b, a) in z.chunks_exact_mut(5).zip(self.f.chunks_exact(2)) {
1013            let a0 = B - a[0];
1014            let a1 = B - a[1];
1015
1016            b[0] = (a0 >> 0) as u8;
1017            b[1] = (a0 >> 8) as u8;
1018            b[2] = ((a0 >> 16) | (a1 << 4)) as u8;
1019            b[3] = (a1 >> 4) as u8;
1020            b[4] = (a1 >> 12) as u8;
1021        }
1022    }
1023
1024    fn bitunpack_2pow19(&mut self, z: &[u8; Self::packed_bytesize(20)]) {
1025        const B: i32 = 1 << 19;
1026        const BITMASK: i32 = 0xfffff;
1027
1028        for (a, b) in self.f.chunks_exact_mut(2).zip(z.chunks_exact(5)) {
1029            let b: [i32; 5] = array::from_fn(|i| b[i] as i32);
1030
1031            a[0] = B - (((b[0] >> 0) | (b[1] << 8) | (b[2] << 16)) & BITMASK);
1032            a[1] = B - ((b[2] >> 4) | (b[3] << 4) | (b[4] << 12));
1033        }
1034    }
1035
1036    const fn norm_in_bound(&self, bound: usize) -> bool {
1037        let mut i = 0;
1038        while i < N {
1039            if coeff::norm(self.f[i]) >= bound {
1040                return false;
1041            }
1042
1043            i += 1;
1044        }
1045
1046        true
1047    }
1048
1049    fn make_hint(&mut self, p0: &Poly, p1: &Poly, gamma2: usize) -> usize {
1050        let mut sum = 0;
1051
1052        for i in 0..N {
1053            let h = coeff::make_hint(p0.f[i], p1.f[i], gamma2 as i32);
1054
1055            self.f[i] = h as i32;
1056            sum += h;
1057        }
1058
1059        sum
1060    }
1061
1062    fn shifted_left(&self, d: usize) -> Self {
1063        let mut f = [MaybeUninit::uninit(); N];
1064
1065        for (i, a) in f.iter_mut().enumerate() {
1066            a.write(self.f[i] << d);
1067        }
1068
1069        Self {
1070            f: unsafe { transmute::<[MaybeUninit<i32>; N], [i32; N]>(f) },
1071        }
1072    }
1073}
1074
1075impl AddAssign<&Self> for Poly {
1076    fn add_assign(&mut self, rhs: &Self) {
1077        for i in 0..N {
1078            self.f[i] += rhs.f[i];
1079        }
1080    }
1081}
1082
1083impl SubAssign<&Self> for Poly {
1084    fn sub_assign(&mut self, rhs: &Self) {
1085        for i in 0..N {
1086            self.f[i] -= rhs.f[i];
1087        }
1088    }
1089}
1090
1091impl MulAssign<&Self> for Poly {
1092    fn mul_assign(&mut self, rhs: &Self) {
1093        for (i, a) in self.f.iter_mut().enumerate() {
1094            *a = reduce::mont_mul(*a, rhs.f[i]);
1095        }
1096    }
1097}
1098
1099#[repr(transparent)]
1100struct PolyMat<const K: usize, const L: usize> {
1101    m: [PolyVec<L>; K],
1102}
1103
1104impl<const K: usize, const L: usize> PolyMat<K, L> {
1105    /// ExpandA(rho)
1106    fn expand_a(rho: &[u8; 32]) -> Self {
1107        let mut g = hash::Shake128::init();
1108        let mut m: [MaybeUninit<PolyVec<L>>; K] = [const { MaybeUninit::uninit() }; K];
1109
1110        for (r, pvec) in m.iter_mut().enumerate() {
1111            let mut v: [MaybeUninit<Poly>; L] = [const { MaybeUninit::uninit() }; L];
1112
1113            for (s, poly) in v.iter_mut().enumerate() {
1114                g.absorb_multi(&[rho, &u16::to_le_bytes(((r << 8) | s) as u16)]);
1115
1116                poly.write(Poly::rej_ntt(&mut g));
1117
1118                g.reset();
1119            }
1120
1121            pvec.write(PolyVec {
1122                v: unsafe { transmute_copy(&v) },
1123            });
1124        }
1125
1126        Self {
1127            m: unsafe { transmute_copy(&m) },
1128        }
1129    }
1130}
1131
1132#[repr(transparent)]
1133struct PolyVec<const K: usize> {
1134    v: [Poly; K],
1135}
1136
1137impl<const K: usize> PolyVec<K> {
1138    const fn zero() -> Self {
1139        Self {
1140            v: [const { Poly::zero() }; K],
1141        }
1142    }
1143
1144    fn ntt_inplace(&mut self) {
1145        for p in self.v.iter_mut() {
1146            p.ntt_inplace();
1147        }
1148    }
1149
1150    fn ntt(&mut self, v_hat: &Self) {
1151        for (p_hat, p) in self.v.iter_mut().zip(&v_hat.v) {
1152            p_hat.ntt(p);
1153        }
1154    }
1155
1156    fn invntt(&self) -> Self {
1157        let mut v = [const { MaybeUninit::uninit() }; K];
1158
1159        for (i, p) in v.iter_mut().enumerate() {
1160            p.write(self.v[i].invntt());
1161        }
1162
1163        Self {
1164            v: unsafe { transmute_copy(&v) },
1165        }
1166    }
1167
1168    fn reduce(&mut self) {
1169        for p in self.v.iter_mut() {
1170            p.reduce();
1171        }
1172    }
1173
1174    fn reduce_invntt_tomont_inplace(&mut self) {
1175        for p in self.v.iter_mut() {
1176            p.reduce();
1177            p.invntt_tomont_inplace();
1178        }
1179    }
1180
1181    fn invntt_tomont_inplace(&mut self) {
1182        for p in self.v.iter_mut() {
1183            p.invntt_tomont_inplace();
1184        }
1185    }
1186
1187    fn power2round(&self, t1: &mut PolyVec<K>, t0: &mut PolyVec<K>) {
1188        for i in 0..K {
1189            self.v[i].power2round(&mut t1.v[i], &mut t0.v[i]);
1190        }
1191    }
1192
1193    fn decompose_32(&self, x0: &mut PolyVec<K>, x1: &mut PolyVec<K>) {
1194        for i in 0..K {
1195            self.v[i].decompose_32(&mut x0.v[i], &mut x1.v[i]);
1196        }
1197    }
1198
1199    fn decompose_88(&self, x0: &mut PolyVec<K>, x1: &mut PolyVec<K>) {
1200        for i in 0..K {
1201            self.v[i].decompose_88(&mut x0.v[i], &mut x1.v[i]);
1202        }
1203    }
1204
1205    fn pack_eta4(&self, z: &mut [u8]) {
1206        for (buf, p) in z.chunks_exact_mut(Poly::PACKED_4BIT).zip(self.v.iter()) {
1207            p.pack_eta4(buf.try_into().unwrap());
1208        }
1209    }
1210
1211    fn unpack_eta4(&mut self, z: &[u8]) {
1212        for (p, buf) in self.v.iter_mut().zip(z.chunks_exact(Poly::PACKED_4BIT)) {
1213            p.unpack_eta4(buf.try_into().unwrap());
1214        }
1215    }
1216
1217    fn pack_eta2(&self, z: &mut [u8]) {
1218        for (buf, p) in z.chunks_exact_mut(Poly::PACKED_3BIT).zip(self.v.iter()) {
1219            p.pack_eta2(buf.try_into().unwrap());
1220        }
1221    }
1222
1223    fn unpack_eta2(&mut self, z: &[u8]) {
1224        for (p, buf) in self.v.iter_mut().zip(z.chunks_exact(Poly::PACKED_3BIT)) {
1225            p.unpack_eta2(buf.try_into().unwrap());
1226        }
1227    }
1228
1229    fn pack_eta_2powdm1(&self, z: &mut [u8]) {
1230        for (buf, p) in z.chunks_exact_mut(Poly::PACKED_13BIT).zip(self.v.iter()) {
1231            p.pack_eta_2powdm1(buf.try_into().unwrap());
1232        }
1233    }
1234
1235    fn unpack_eta_2powdm1(&mut self, z: &[u8]) {
1236        for (p, buf) in self.v.iter_mut().zip(z.chunks_exact(Poly::PACKED_13BIT)) {
1237            p.unpack_eta_2powdm1(buf.try_into().unwrap());
1238        }
1239    }
1240
1241    fn pack_simple_4bit<const BZ: usize>(&self, z: &mut [u8; BZ]) {
1242        for (chunk, p) in z
1243            .chunks_exact_mut(Poly::packed_bytesize(4))
1244            .zip(self.v.iter())
1245        {
1246            p.pack_simple_4bit(chunk.try_into().unwrap());
1247        }
1248    }
1249
1250    fn pack_simple_6bit(&self, z: &mut [u8]) {
1251        for (chunk, p) in z
1252            .chunks_exact_mut(Poly::packed_bytesize(6))
1253            .zip(self.v.iter())
1254        {
1255            p.pack_simple_6bit(chunk.try_into().unwrap());
1256        }
1257    }
1258
1259    fn multiply_matvec_ntt<const L: usize>(&mut self, m: &PolyMat<K, L>, v: &PolyVec<L>) {
1260        for i in 0..K {
1261            self.v[i].dot_prod_ntt(&m.m[i], v)
1262        }
1263    }
1264
1265    fn multiply_poly_ntt(&mut self, p: &Poly, v: &PolyVec<K>) {
1266        for i in 0..K {
1267            self.v[i].multiply_ntt(p, &v.v[i]);
1268        }
1269    }
1270
1271    fn hint_bitpack<const OMEGA: usize>(&self, dst: &mut [u8]) {
1272        let mut idx = 0;
1273
1274        for i in 0..K {
1275            for j in 0..N {
1276                let h = self.v[i].f[j] as usize;
1277                dst[idx] = (j & h.wrapping_neg()) as u8;
1278                idx += h;
1279            }
1280
1281            dst[OMEGA + i] = idx as u8;
1282        }
1283    }
1284
1285    fn hint_bitunpack(y: &[u8], omega: usize) -> Result<PolyVec<K>, VerifyError> {
1286        let mut h = PolyVec::zero();
1287
1288        let mut idx = 0;
1289
1290        for i in 0..K {
1291            let num_hints = y[omega + i] as usize;
1292
1293            if num_hints < idx || num_hints > omega {
1294                return Err(VerifyError::TooManyHints);
1295            }
1296
1297            if idx >= num_hints {
1298                continue;
1299            }
1300
1301            h.v[i].f[y[idx] as usize] = 1;
1302            idx += 1;
1303
1304            for j in idx..num_hints {
1305                if y[idx - 1] >= y[j] {
1306                    return Err(VerifyError::TooManyHints);
1307                }
1308
1309                h.v[i].f[y[j] as usize] = 1;
1310            }
1311            idx = num_hints;
1312        }
1313
1314        if y[idx..omega].iter().any(|x| *x != 0) {
1315            return Err(VerifyError::TooManyHints);
1316        }
1317
1318        Ok(h)
1319    }
1320
1321    fn bitpack_2pow17(&self, dst: &mut [u8]) {
1322        for (buf, p) in dst
1323            .chunks_exact_mut(Poly::packed_bytesize(18))
1324            .zip(self.v.iter())
1325        {
1326            p.bitpack_2pow17(buf.try_into().unwrap());
1327        }
1328    }
1329
1330    fn bitpack_2pow19(&self, dst: &mut [u8]) {
1331        for (buf, p) in dst
1332            .chunks_exact_mut(Poly::packed_bytesize(20))
1333            .zip(self.v.iter())
1334        {
1335            p.bitpack_2pow19(buf.try_into().unwrap());
1336        }
1337    }
1338
1339    fn expand_mask_2pow17(&mut self, rho: &[u8; 64], mu: usize, h: &mut hash::Shake256) {
1340        let mut blocks = [0u8; 5 * hash::SHAKE_256_RATE];
1341
1342        for (r, p) in self.v.iter_mut().enumerate() {
1343            h.absorb_multi(&[rho, &u16::to_le_bytes((mu + r) as u16)]);
1344            h.squeezeblocks(&mut blocks);
1345
1346            p.bitunpack_2pow17(blocks.first_chunk_mut().unwrap());
1347        }
1348    }
1349
1350    fn expand_mask_2pow19(&mut self, rho: &[u8; 64], mu: usize, h: &mut hash::Shake256) {
1351        let mut blocks = [0u8; 5 * hash::SHAKE_256_RATE];
1352
1353        for (r, p) in self.v.iter_mut().enumerate() {
1354            h.absorb_multi(&[rho, &u16::to_le_bytes((mu + r) as u16)]);
1355            h.squeezeblocks(&mut blocks);
1356
1357            p.bitunpack_2pow19(blocks.first_chunk_mut().unwrap());
1358        }
1359    }
1360
1361    const fn norm_in_bound(&self, bound: usize) -> bool {
1362        let mut i = 0;
1363
1364        while i < K {
1365            if !self.v[i].norm_in_bound(bound) {
1366                return false;
1367            }
1368
1369            i += 1;
1370        }
1371
1372        true
1373    }
1374
1375    fn make_hint(&mut self, x0: &PolyVec<K>, x1: &PolyVec<K>, gamma2: usize) -> usize {
1376        let mut sum = 0;
1377        for i in 0..K {
1378            sum += self.v[i].make_hint(&x0.v[i], &x1.v[i], gamma2);
1379        }
1380
1381        sum
1382    }
1383
1384    fn shifted_left(&self, d: usize) -> Self {
1385        let mut v = [const { MaybeUninit::uninit() }; K];
1386
1387        for (i, poly) in v.iter_mut().enumerate() {
1388            poly.write(self.v[i].shifted_left(d));
1389        }
1390
1391        Self {
1392            v: unsafe { transmute_copy(&v) },
1393        }
1394    }
1395}
1396
1397impl<const K: usize> Mul<&Poly> for &PolyVec<K> {
1398    type Output = PolyVec<K>;
1399
1400    fn mul(self, rhs: &Poly) -> Self::Output {
1401        let mut v = PolyVec::zero();
1402
1403        for i in 0..K {
1404            v.v[i].multiply_ntt(&self.v[i], rhs);
1405        }
1406
1407        v
1408    }
1409}
1410
1411impl<const K: usize> MulAssign<&Poly> for PolyVec<K> {
1412    fn mul_assign(&mut self, rhs: &Poly) {
1413        for poly in self.v.iter_mut() {
1414            *poly *= rhs;
1415        }
1416    }
1417}
1418
1419impl<const K: usize> AddAssign<&Self> for PolyVec<K> {
1420    fn add_assign(&mut self, rhs: &Self) {
1421        for i in 0..K {
1422            self.v[i] += &rhs.v[i];
1423        }
1424    }
1425}
1426
1427impl<const K: usize> SubAssign<&Self> for PolyVec<K> {
1428    fn sub_assign(&mut self, rhs: &Self) {
1429        for i in 0..K {
1430            self.v[i] -= &rhs.v[i];
1431        }
1432    }
1433}
1434
1435/// ExpandS(rho)
1436fn expand_s<const K: usize, const L: usize, const ETA: usize>(
1437    s1: &mut PolyVec<L>,
1438    s2: &mut PolyVec<K>,
1439    rho: &[u8; 64],
1440) {
1441    let mut h = hash::Shake256::init();
1442
1443    for (nonce, poly) in s1.v.iter_mut().chain(s2.v.iter_mut()).enumerate() {
1444        h.absorb_multi(&[rho, &u16::to_le_bytes(nonce as u16)]);
1445        poly.rej_bounded::<ETA>(&mut h);
1446        h.reset();
1447    }
1448}
1449
#[cfg(test)]
mod tests {
    use rand::RngCore;
    use rand_core::OsRng;
    use serde::Deserialize;
    use std::{fs::read_to_string, path::PathBuf};

    use super::*;

    /// Runs a keygen -> sign -> verify round trip with fresh OS randomness for the
    /// parameter-set module `$m`.
    macro_rules! roundtrip {
        ($m:ident) => {{
            let mut pk = [0u8; $m::PUBKEY_SIZE];
            let sk = $m::PrivateKey::keygen(&mut pk, &mut OsRng);
            let vk = $m::PublicKey::decode(&pk);
            let mut token = [0u8; 32];
            OsRng.fill_bytes(&mut token);
            let mut sig = [0u8; $m::SIG_SIZE];
            sk.sign(&mut sig, &mut OsRng, &token);
            vk.verify(&token, &sig).unwrap();
        }};
    }

    /// Checks every keygen vector in test group `$tg` against module `$m`. It also
    /// checks that decoding and re-encoding both reference keys is lossless.
    macro_rules! check_keygen {
        ($m:ident, $tg:ident) => {{
            for test in &$tg.tests {
                let mut vk_bytes = [0u8; $m::PUBKEY_SIZE];
                let mut sk_bytes = [0u8; $m::PRIVKEY_SIZE];

                let sk = $m::PrivateKey::keygen_internal(&mut vk_bytes, &test.seed);
                sk.encode(&mut sk_bytes);

                assert_eq!(vk_bytes, test.pk[..]);
                assert_eq!(sk_bytes, test.sk[..]);

                let sk_prime = $m::PrivateKey::decode(&test.sk);
                sk_bytes.fill(0);
                sk_prime.encode(&mut sk_bytes);
                assert_eq!(sk_bytes, test.sk[..]);

                let vk_prime = $m::PublicKey::decode(&test.pk);
                vk_bytes.fill(0);
                vk_prime.encode(&mut vk_bytes);
                assert_eq!(vk_bytes, test.pk[..]);
            }
        }};
    }

    /// Checks every signing vector in test group `$tg` against module `$m`.
    /// Vectors without `rnd` are signed with an all-zero `rnd`.
    macro_rules! check_siggen {
        ($m:ident, $tg:ident) => {{
            let mut sig = [0u8; $m::SIG_SIZE];

            for test in $tg.tests.iter() {
                sig.fill(0);
                let sk = $m::PrivateKey::decode(&test.sk);
                let rnd = test.rnd.as_ref().map_or([0; 32], |r| r.rnd);
                sk.sign_internal(&mut sig, &test.message, &rnd);
                assert_eq!(&sig, &test.signature[..]);
            }
        }};
    }

    /// Checks every verification vector in test group `$tg` against module `$m`,
    /// matching each error variant to the vector's expected outcome or reason.
    macro_rules! check_sigver {
        ($m:ident, $tg:ident) => {{
            let pk = $m::PublicKey::decode(&$tg.pk);

            for test in $tg.tests.iter() {
                match pk.verify_internal(&test.message, test.signature[..].try_into().unwrap()) {
                    Ok(_) => assert!(test.test_passed),
                    Err(VerifyError::ZoutOfBound) => assert_eq!(test.reason, "z too large"),
                    Err(VerifyError::Mismatch) => assert!(!test.test_passed),
                    Err(VerifyError::TooManyHints) => assert_eq!(test.reason, "too many hints"),
                }
            }
        }};
    }

    /// Reads a test-vector file given relative to the crate root.
    fn read_vectors(rel: &str) -> String {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push(rel);
        read_to_string(&path).unwrap()
    }

    #[test]
    fn test_gen_sign_verify() {
        roundtrip!(mldsa44);
        roundtrip!(mldsa65);
        roundtrip!(mldsa87);
    }

    #[test]
    fn test_keygen() {
        let json = read_vectors("tests/mldsa-keygen.json");
        let test_data: Tests<KeyGenTg> = serde_json::from_str(&json).unwrap();

        for tg in test_data.test_groups.iter() {
            match tg.parameter_set.as_str() {
                "ML-DSA-44" => check_keygen!(mldsa44, tg),
                "ML-DSA-65" => check_keygen!(mldsa65, tg),
                "ML-DSA-87" => check_keygen!(mldsa87, tg),
                _ => panic!("invalid paramter set"),
            }
        }
    }

    #[test]
    fn test_siggen() {
        let json = read_vectors("tests/mldsa-sign.json");
        let test_data: Tests<SigGenTg> = serde_json::from_str(&json).unwrap();

        for tg in test_data.test_groups.iter() {
            match tg.parameter_set.as_str() {
                "ML-DSA-44" => check_siggen!(mldsa44, tg),
                "ML-DSA-65" => check_siggen!(mldsa65, tg),
                "ML-DSA-87" => check_siggen!(mldsa87, tg),
                _ => panic!("invalid paramter set"),
            }
        }
    }

    #[test]
    fn test_sigver() {
        let json = read_vectors("tests/mldsa-verify.json");
        let test_data: Tests<SigVerTg> = serde_json::from_str(&json).unwrap();

        for tg in test_data.test_groups.iter() {
            match tg.parameter_set.as_str() {
                "ML-DSA-44" => check_sigver!(mldsa44, tg),
                "ML-DSA-65" => check_sigver!(mldsa65, tg),
                "ML-DSA-87" => check_sigver!(mldsa87, tg),
                _ => panic!("invalid paramter set"),
            }
        }
    }

    /// One key-generation vector: seed plus the expected encoded keys.
    #[derive(Deserialize)]
    struct KeyGenTV {
        #[serde(with = "hex")]
        pk: Vec<u8>,

        #[serde(with = "hex")]
        seed: [u8; 32],

        #[serde(with = "hex")]
        sk: Vec<u8>,
    }

    /// A group of key-generation vectors for one parameter set.
    #[derive(Deserialize)]
    struct KeyGenTg {
        #[serde(rename = "parameterSet")]
        parameter_set: String,

        tests: Vec<KeyGenTV>,
    }

    /// Top-level vector file: a list of test groups.
    #[derive(Deserialize)]
    struct Tests<T> {
        #[serde(rename = "testGroups")]
        test_groups: Vec<T>,
    }

    /// Optional signing randomness, flattened into `SigGenTV`.
    #[derive(Deserialize)]
    struct Rnd {
        #[serde(with = "hex")]
        rnd: [u8; 32],
    }

    /// One signature-generation vector.
    #[derive(Deserialize)]
    struct SigGenTV {
        #[serde(with = "hex")]
        message: Vec<u8>,

        #[serde(with = "hex")]
        signature: Vec<u8>,

        #[serde(with = "hex")]
        sk: Vec<u8>,

        #[serde(flatten)]
        rnd: Option<Rnd>,
    }

    /// A group of signature-generation vectors for one parameter set.
    #[derive(Deserialize)]
    struct SigGenTg {
        #[serde(rename = "parameterSet")]
        parameter_set: String,

        tests: Vec<SigGenTV>,
    }

    /// One signature-verification vector with its expected outcome.
    #[derive(Deserialize)]
    struct SigVerTV {
        #[serde(with = "hex")]
        message: Vec<u8>,

        reason: String,

        #[serde(with = "hex")]
        signature: Vec<u8>,

        #[serde(rename = "testPassed")]
        test_passed: bool,
    }

    /// A group of verification vectors sharing one public key.
    #[derive(Deserialize)]
    struct SigVerTg {
        #[serde(rename = "parameterSet")]
        parameter_set: String,

        #[serde(with = "hex")]
        pk: Vec<u8>,

        tests: Vec<SigVerTV>,
    }
}