1use crate::hash;
4use core::{
5 array,
6 mem::{transmute, transmute_copy, MaybeUninit},
7 ops::{AddAssign, Mul, MulAssign, SubAssign},
8};
9use rand_core::CryptoRngCore;
10use thiserror::Error;
11use zeroize::Zeroize;
12
13pub mod mldsa44;
14pub mod mldsa65;
15pub mod mldsa87;
16
17mod coeff;
18mod reduce;
19
/// The ML-DSA prime modulus q = 2^23 - 2^13 + 1 = 8380417.
const Q: i32 = (1 << 23) - (1 << 13) + 1;
/// Number of coefficients in every polynomial.
const N: usize = 256;
/// Root of unity modulo `Q` from which the NTT twiddle factors are generated.
const ZETA: i32 = 1_753;
/// Number of low-order bits split off `t` by `power2round`.
const D: usize = 13;
24
25const ZETAS: [i32; N] = {
29 let mut zetas = [0; N];
30 zetas[0] = reduce::R_MOD_Q;
31
32 let mut i = 1;
33 while i < N {
34 zetas[i] = reduce::mont_mul(zetas[i - 1], reduce::to_mont(ZETA));
35
36 i += 1
37 }
38
39 let mut zetas_bitrev = [0; N];
40
41 i = 0;
42 while i < N {
43 let idx = (i as u8).reverse_bits();
44
45 zetas_bitrev[i] = match zetas[idx as usize] {
46 z if z > Q / 2 => z - Q,
47 z if z < -Q / 2 => z + Q,
48 z => z,
49 };
50
51 i += 1;
52 }
53
54 zetas_bitrev
55};
56
57trait SigningKeyInternal<
58 const K: usize,
59 const L: usize,
60 const ETA: usize,
61 const TAU: usize,
62 const GAMMA1: usize,
63 const GAMMA2: usize,
64 const BETA: usize,
65 const OMEGA: usize,
66 const CT_BYTES: usize,
67 const W1_BYTES: usize,
68 const Z_BYTES: usize,
69>: From<PrivateKey<K, L, ETA>>
70{
71 fn privkey(&self) -> &PrivateKey<K, L, ETA>;
72 fn expand_mask(pvec: &mut PolyVec<L>, rho: &[u8; 64], mu: usize, h: &mut hash::Shake256);
73 fn bitpack_z(pvec: &PolyVec<L>, dst: &mut [u8; Z_BYTES]);
74 fn pack_simple(w1: &PolyVec<K>, z: &mut [u8; W1_BYTES]);
75 fn decompose(x: &PolyVec<K>, x0: &mut PolyVec<K>, x1: &mut PolyVec<K>);
76
77 fn sign_internal(&self, dst: &mut [u8], m: &[u8], rnd: &[u8; 32]) {
78 let (c_tilde, buf) = dst.split_first_chunk_mut::<CT_BYTES>().unwrap();
79 let (w1_bytes, buf) = buf.split_first_chunk_mut::<W1_BYTES>().unwrap();
80 let (mu, buf) = buf.split_first_chunk_mut::<64>().unwrap();
81 let rho_prime2: &mut [u8; 64] = buf.first_chunk_mut().unwrap();
82
83 let mut h = hash::Shake256::init();
84 let privkey = self.privkey();
85
86 h.absorb_and_squeeze(mu, &[&privkey.tr, m]);
87
88 h.absorb_and_squeeze(rho_prime2, &[&privkey.k, rnd, mu]);
89
90 let mut y = PolyVec::zero();
91 let mut y_hat = PolyVec::zero();
92 let mut w = PolyVec::zero();
93 let mut w1 = PolyVec::zero();
94 let mut w0 = PolyVec::zero();
95 let mut z = PolyVec::zero();
96 let mut hint = PolyVec::zero();
97 let mut c_hat = Poly::zero();
98
99 for nonce in (0..).step_by(L) {
100 Self::expand_mask(&mut y, rho_prime2, nonce, &mut h);
101
102 y_hat.ntt(&y);
103
104 w.multiply_matvec_ntt(&privkey.a_hat, &y_hat);
105 w.reduce_invntt_tomont_inplace();
106
107 Self::decompose(&w, &mut w0, &mut w1);
108 Self::pack_simple(&w1, w1_bytes);
109 h.absorb_and_squeeze(c_tilde, &[mu, w1_bytes]);
110
111 h.absorb(c_tilde);
112 h.finalize();
113 c_hat.f.fill(0);
114 c_hat.sample_in_ball(&mut h, TAU);
115 h.reset();
116 c_hat.ntt_inplace();
117
118 z.multiply_poly_ntt(&c_hat, &privkey.s1_hat);
119 z.invntt_tomont_inplace();
120 z += &y;
121 z.reduce();
122
123 if !z.norm_in_bound(GAMMA1 - BETA) {
124 continue;
125 }
126
127 hint.multiply_poly_ntt(&c_hat, &privkey.s2_hat);
128 hint.invntt_tomont_inplace();
129 w0 -= &hint;
130 w0.reduce();
131
132 if !w0.norm_in_bound(GAMMA2 - BETA) {
133 continue;
134 }
135
136 hint.multiply_poly_ntt(&c_hat, &privkey.t0_hat);
137 hint.invntt_tomont_inplace();
138 hint.reduce();
139
140 if !hint.norm_in_bound(GAMMA2) {
141 continue;
142 }
143
144 w0 += &hint;
145
146 let count = hint.make_hint(&w0, &w1, GAMMA2);
147
148 if count >= OMEGA {
149 continue;
150 }
151
152 break;
153 }
154
155 let (z_buf, buf) = dst[CT_BYTES..].split_first_chunk_mut().unwrap();
156
157 Self::bitpack_z(&z, z_buf);
158
159 hint.hint_bitpack::<OMEGA>(buf);
160 }
161
162 fn keygen_internal(vk: &mut [u8], ksi: &[u8; 32]) -> Self {
163 let mut h = hash::Shake256::init();
164 h.absorb_multi(&[ksi, &[K as u8], &[L as u8]]);
165 let rho: [u8; 32] = h.squeeze_array();
166 let rho_prime: [u8; 64] = h.squeeze_array();
167 let k: [u8; 32] = h.squeeze_array();
168
169 let mut s1_hat = PolyVec::zero();
170 let mut s2_hat = PolyVec::zero();
171 let mut t0_hat = PolyVec::zero();
172 let a_hat = PolyMat::expand_a(&rho);
173
174 expand_s::<K, L, ETA>(&mut s1_hat, &mut s2_hat, &rho_prime);
175
176 s1_hat.ntt_inplace();
177
178 let mut t = PolyVec::zero();
179 let mut t1 = PolyVec::zero();
180
181 t.multiply_matvec_ntt(&a_hat, &s1_hat);
182 t.reduce_invntt_tomont_inplace();
183
184 t += &s2_hat;
185
186 t.power2round(&mut t1, &mut t0_hat);
187
188 vk_encode(vk, &rho, &t1);
189
190 h.reset();
191 h.absorb(vk);
192 h.finalize();
193 let tr: [u8; 64] = h.squeeze_array();
194
195 s2_hat.ntt_inplace();
196 t0_hat.ntt_inplace();
197
198 PrivateKey {
199 rho,
200 k,
201 tr,
202 s1_hat,
203 s2_hat,
204 t0_hat,
205 a_hat,
206 }
207 .into()
208 }
209}
210
211#[derive(Debug, Error)]
212pub enum VerifyError {
213 #[error("z is out of bound")]
214 ZoutOfBound,
215
216 #[error("signature mismatch")]
217 Mismatch,
218
219 #[error("too many hints in signature")]
220 TooManyHints,
221}
222
/// Crate-internal verification machinery shared by all ML-DSA parameter sets.
///
/// Implementors supply the parameter-set constants and packing helpers; the default
/// `verify_internal` method performs the signature check itself.
trait VerifyingKeyInternal<
    const K: usize,
    const L: usize,
    const CT_BYTES: usize,
    const Z_BYTES: usize,
    const H_BYTES: usize,
    const W1_BYTES: usize,
    const SIG_SIZE: usize,
>
{
    /// Maximum number of hint indices a signature may carry.
    const OMEGA: usize;
    /// Number of ±1 coefficients placed in the challenge polynomial by `sample_in_ball`.
    const TAU: usize;
    /// Range parameter of the response vector `z`.
    const GAMMA1: usize;
    /// Rounding range parameter (not used by `verify_internal` itself; presumably used by
    /// implementors' `use_hint` — confirm).
    const GAMMA2: usize;
    /// Slack subtracted from `GAMMA1` when bounding `z`.
    const BETA: usize;

    /// Decodes the packed response vector `z` from its signature bytes.
    fn bitunpack_z_hat(b: &[u8; Z_BYTES]) -> PolyVec<L>;

    /// Encodes `w1` for hashing into the recomputed challenge.
    fn w1encode(w1: &PolyVec<K>) -> [u8; W1_BYTES];

    /// Corrects `w1` in place using the hint vector `h`.
    fn use_hint(w1: &mut PolyVec<K>, h: &PolyVec<K>);

    /// Returns the decoded public key.
    fn pk(&self) -> &PublicKey<K, L>;

    /// Verifies the signature `sig` (`c_tilde || z || hint`) over message `m`.
    ///
    /// # Errors
    /// - [`VerifyError::TooManyHints`] if the hint section is malformed,
    /// - [`VerifyError::ZoutOfBound`] if `z` fails the `GAMMA1 - BETA` norm bound,
    /// - [`VerifyError::Mismatch`] if the recomputed challenge hash differs from `c_tilde`.
    fn verify_internal(&self, m: &[u8], sig: &[u8; SIG_SIZE]) -> Result<(), VerifyError> {
        // Split into c_tilde || z || hint. The last `unwrap` assumes
        // SIG_SIZE == CT_BYTES + Z_BYTES + H_BYTES for the implementing parameter set.
        let (c_tilde, sig) = sig.split_first_chunk::<CT_BYTES>().unwrap();
        let (z_bytes, sig) = sig.split_first_chunk::<Z_BYTES>().unwrap();
        let h_bytes: &[u8; H_BYTES] = sig.try_into().unwrap();

        // Malformed hint encodings are rejected before any hashing work.
        let hint = PolyVec::hint_bitunpack(h_bytes, Self::OMEGA)?;

        let mut z_hat = Self::bitunpack_z_hat(z_bytes);

        if !z_hat.norm_in_bound(Self::GAMMA1 - Self::BETA) {
            return Err(VerifyError::ZoutOfBound);
        }

        let pk = self.pk();

        let mut h = hash::Shake256::init();

        // mu = H(tr || m)
        h.absorb_multi(&[&pk.tr, m]);
        let mu: [u8; 64] = h.squeeze_array();
        h.reset();

        // Expand the challenge polynomial c from c_tilde.
        let mut c_hat = Poly::zero();
        h.absorb(c_tilde);
        h.finalize();
        c_hat.sample_in_ball(&mut h, Self::TAU);
        h.reset();

        // Despite its name, z_hat only enters the NTT domain here.
        z_hat.ntt_inplace();

        // w1 <- A_hat * z_hat (NTT domain)
        let mut w1 = PolyVec::zero();
        w1.multiply_matvec_ntt(&pk.a_hat, &z_hat);

        c_hat.ntt_inplace();

        // t1 <- NTT(t1 * 2^D) multiplied by c_hat (via PolyVec's MulAssign, defined elsewhere).
        let mut t1 = pk.t1.shifted_left(D);
        t1.ntt_inplace();
        t1 *= &c_hat;

        // w1 <- InvNTT(A_hat * z_hat - c_hat * t1_hat), then corrected with the hint.
        w1 -= &t1;
        w1.reduce_invntt_tomont_inplace();
        Self::use_hint(&mut w1, &hint);

        let w1_bytes = Self::w1encode(&w1);

        // Recompute c_tilde' = H(mu || w1Encode(w1)) and compare with the signature's c_tilde.
        h.absorb_multi(&[&mu, &w1_bytes]);
        let c_tilde_prime = h.squeeze_array();

        if c_tilde == &c_tilde_prime {
            Ok(())
        } else {
            Err(VerifyError::Mismatch)
        }
    }
}
301
302pub trait VerifyingKey<
303 const K: usize,
304 const L: usize,
305 const CT_BYTES: usize,
306 const Z_BYTES: usize,
307 const H_BYTES: usize,
308 const W1_BYTES: usize,
309 const SIG_SIZE: usize,
310>
311{
312 fn verify(&self, m: &[u8], sig: &[u8]) -> Result<(), VerifyError>;
313 fn encode(&self, dst: &mut [u8]);
314 fn decode(src: &[u8]) -> Self;
315}
316
317impl<
318 T,
319 const K: usize,
320 const L: usize,
321 const CT_BYTES: usize,
322 const Z_BYTES: usize,
323 const H_BYTES: usize,
324 const W1_BYTES: usize,
325 const SIG_SIZE: usize,
326 > VerifyingKey<K, L, CT_BYTES, Z_BYTES, H_BYTES, W1_BYTES, SIG_SIZE> for T
327where
328 T: VerifyingKeyInternal<K, L, CT_BYTES, Z_BYTES, H_BYTES, W1_BYTES, SIG_SIZE>
329 + From<PublicKey<K, L>>,
330{
331 fn verify(&self, m: &[u8], sig: &[u8]) -> Result<(), VerifyError> {
332 assert!(sig.len() >= SIG_SIZE);
333 self.verify_internal(m, sig.first_chunk().unwrap())
334 }
335
336 fn encode(&self, dst: &mut [u8]) {
337 assert!(dst.len() >= pubkey_size(K));
338 let key = self.pk();
339 vk_encode(&mut dst[..pubkey_size(K)], &key.rho, &key.t1)
340 }
341
342 fn decode(src: &[u8]) -> Self {
343 assert!(src.len() >= pubkey_size(K));
344 PublicKey::decode(src).into()
345 }
346}
347
348pub trait SigningKey<
350 const K: usize,
351 const L: usize,
352 const ETA: usize,
353 const TAU: usize,
354 const GAMMA1: usize,
355 const GAMMA2: usize,
356 const BETA: usize,
357 const OMEGA: usize,
358 const CT_BYTES: usize,
359 const W1_BYTES: usize,
360 const Z_BYTES: usize,
361>
362{
363 fn sign(&self, sig: &mut [u8], rng: &mut impl CryptoRngCore, m: &[u8]);
365 fn encode(&self, dst: &mut [u8]);
366 fn decode(src: &[u8]) -> Self;
367
368 fn keygen(vk: &mut [u8], rng: &mut impl CryptoRngCore) -> Self;
370}
371
372impl<
373 T,
374 const K: usize,
375 const L: usize,
376 const ETA: usize,
377 const TAU: usize,
378 const GAMMA1: usize,
379 const GAMMA2: usize,
380 const BETA: usize,
381 const OMEGA: usize,
382 const CT_BYTES: usize,
383 const W1_BYTES: usize,
384 const Z_BYTES: usize,
385 > SigningKey<K, L, ETA, TAU, GAMMA1, GAMMA2, BETA, OMEGA, CT_BYTES, W1_BYTES, Z_BYTES> for T
386where
387 T: SigningKeyInternal<K, L, ETA, TAU, GAMMA1, GAMMA2, BETA, OMEGA, CT_BYTES, W1_BYTES, Z_BYTES>,
388{
389 fn sign(&self, sig: &mut [u8], rng: &mut impl CryptoRngCore, m: &[u8]) {
390 let mut rnd = [0u8; 32];
391 rng.fill_bytes(&mut rnd);
392
393 self.sign_internal(sig, m.as_ref(), &rnd)
394 }
395
396 fn encode(&self, dst: &mut [u8]) {
397 assert!(dst.len() >= privkey_size(K, L, ETA));
398 self.privkey().encode(dst);
399 }
400
401 fn decode(src: &[u8]) -> Self {
402 assert!(src.len() >= privkey_size(K, L, ETA));
403 PrivateKey::decode(src).into()
404 }
405
406 fn keygen(pk: &mut [u8], rng: &mut impl CryptoRngCore) -> Self {
408 debug_assert!(pk.len() >= pubkey_size(K));
409
410 let mut ksi = [0u8; 32];
411 rng.fill_bytes(&mut ksi);
412
413 let sk = Self::keygen_internal(pk, &ksi);
414
415 ksi.zeroize();
416
417 sk
418 }
419}
420
421const fn pubkey_size(k: usize) -> usize {
422 k * Poly::PACKED_10BIT + 32
423}
424
425const fn privkey_size(k: usize, l: usize, eta: usize) -> usize {
426 match eta {
427 2 => 32 + 32 + 64 + l * Poly::PACKED_3BIT + k * (Poly::PACKED_3BIT + Poly::PACKED_13BIT),
428 4 => 32 + 32 + 64 + l * Poly::PACKED_4BIT + k * (Poly::PACKED_4BIT + Poly::PACKED_13BIT),
429 _ => unreachable!(),
430 }
431}
432
/// Number of bits needed to represent `n`; `bitlen(0) == 0`.
///
/// The previous `n.ilog2() + 1` panicked (or failed const evaluation) for `n == 0`.
/// `leading_zeros` is defined for every input and gives the same result for `n > 0`.
const fn bitlen(n: usize) -> usize {
    (usize::BITS - n.leading_zeros()) as usize
}
436
437const fn sig_size(k: usize, l: usize, lambda: usize, gamma1: usize, omega: usize) -> usize {
438 lambda / 4 + l * 32 * (1 + bitlen(gamma1 - 1)) + omega + k
439}
440
441fn vk_encode<const K: usize>(dst: &mut [u8], rho: &[u8; 32], t1: &PolyVec<K>) {
442 dst[..32].copy_from_slice(rho);
443 for (xi, z) in
444 t1.v.iter()
445 .zip(dst[32..].chunks_exact_mut(Poly::PACKED_10BIT))
446 {
447 xi.pack_simple_10bit(z.try_into().unwrap())
448 }
449}
450
/// A decoded public key together with the values precomputed from it.
struct PublicKey<const K: usize, const L: usize> {
    /// Seed from which the matrix `A` is expanded.
    rho: [u8; 32],
    /// Hash of the encoded public key (`tr`), mixed into `mu` during verification.
    tr: [u8; 64],
    /// The `t1` vector in the normal (non-NTT) domain; verification shifts and NTTs a copy.
    t1: PolyVec<K>,
    /// The matrix `A` sampled directly in the NTT domain.
    a_hat: PolyMat<K, L>,
}
458
459impl<const K: usize, const L: usize> PublicKey<K, L> {
460 fn decode(pk: &[u8]) -> Self {
462 let rho = array::from_fn(|i| pk[i]);
463 let mut t1 = PolyVec::zero();
464
465 for (xi, z) in
466 t1.v.iter_mut()
467 .zip(pk[32..].chunks_exact(Poly::PACKED_10BIT))
468 {
469 xi.unpack_simple_10bit(z.try_into().unwrap())
470 }
471
472 let a_hat = PolyMat::expand_a(&rho);
473
474 let mut h = hash::Shake256::init();
475 h.absorb(pk);
476 h.finalize();
477 let tr = h.squeeze_array();
478
479 Self { rho, tr, t1, a_hat }
480 }
481}
482
/// An expanded private key; secret vectors are kept in the NTT domain.
pub(crate) struct PrivateKey<const K: usize, const L: usize, const ETA: usize> {
    /// Public seed of the matrix `A`.
    rho: [u8; 32],
    /// Private seed mixed into the per-signature mask seed `rho''`.
    k: [u8; 32],
    /// Hash of the encoded public key, mixed into `mu` when signing.
    tr: [u8; 64],
    /// Secret vector `s1` in the NTT domain.
    s1_hat: PolyVec<L>,
    /// Secret vector `s2` in the NTT domain.
    s2_hat: PolyVec<K>,
    /// Low-order part `t0` of `t` in the NTT domain.
    t0_hat: PolyVec<K>,
    /// The matrix `A` in the NTT domain.
    a_hat: PolyMat<K, L>,
}
493
494impl<const K: usize, const L: usize, const ETA: usize> Drop for PrivateKey<K, L, ETA> {
495 fn drop(&mut self) {
496 self.k.zeroize();
497 self.tr.zeroize();
498 }
499}
500
501impl<const K: usize, const L: usize, const ETA: usize> PrivateKey<K, L, ETA> {
502 pub fn encode(&self, dst: &mut [u8]) {
504 let s1 = self.s1_hat.invntt();
505 let s2 = self.s2_hat.invntt();
506 let t0 = self.t0_hat.invntt();
507
508 dst[..32].copy_from_slice(&self.rho);
509 dst[32..64].copy_from_slice(&self.k);
510 dst[64..128].copy_from_slice(&self.tr);
511
512 let buf = &mut dst[128..];
513
514 match ETA {
515 2 => {
516 s1.pack_eta2(&mut buf[..L * Poly::PACKED_3BIT]);
517
518 let buf = &mut buf[L * Poly::PACKED_3BIT..];
519 s2.pack_eta2(&mut buf[..K * Poly::PACKED_3BIT]);
520
521 let buf = &mut buf[K * Poly::PACKED_3BIT..];
522 t0.pack_eta_2powdm1(buf)
523 }
524 4 => {
525 s1.pack_eta4(&mut buf[..L * Poly::PACKED_4BIT]);
526
527 let buf = &mut buf[L * Poly::PACKED_4BIT..];
528 s2.pack_eta4(buf);
529
530 let buf = &mut buf[K * Poly::PACKED_4BIT..];
531 t0.pack_eta_2powdm1(buf)
532 }
533 _ => unreachable!(),
534 }
535 }
536
537 pub fn decode(src: &[u8]) -> Self {
539 let mut rho: MaybeUninit<[u8; 32]> = MaybeUninit::uninit();
540 let mut k: MaybeUninit<[u8; 32]> = MaybeUninit::uninit();
541 let mut tr: MaybeUninit<[u8; 64]> = MaybeUninit::uninit();
542
543 rho.write(src[..32].try_into().unwrap());
544 k.write(src[32..64].try_into().unwrap());
545 tr.write(src[64..128].try_into().unwrap());
546
547 let (rho, k, tr) = unsafe { (rho.assume_init(), k.assume_init(), tr.assume_init()) };
548
549 let mut s1_hat = PolyVec::zero();
550 let mut s2_hat = PolyVec::zero();
551 let mut t0_hat = PolyVec::zero();
552
553 match ETA {
554 2 => {
555 let z = &src[128..];
556 s1_hat.unpack_eta2(&z[..L * Poly::PACKED_3BIT]);
557
558 let z = &z[L * Poly::PACKED_3BIT..];
559 s2_hat.unpack_eta2(&z[..K * Poly::PACKED_3BIT]);
560
561 let z = &z[K * Poly::PACKED_3BIT..];
562 t0_hat.unpack_eta_2powdm1(z)
563 }
564 4 => {
565 let z = &src[128..];
566 s1_hat.unpack_eta4(&z[..L * Poly::PACKED_4BIT]);
567
568 let z = &z[L * Poly::PACKED_4BIT..];
569 s2_hat.unpack_eta4(&z[..K * Poly::PACKED_4BIT]);
570
571 let z = &z[K * Poly::PACKED_4BIT..];
572 t0_hat.unpack_eta_2powdm1(z)
573 }
574 _ => unreachable!(),
575 }
576
577 let a_hat = PolyMat::expand_a(&rho);
578
579 s1_hat.ntt_inplace();
580 s2_hat.ntt_inplace();
581 t0_hat.ntt_inplace();
582
583 Self {
584 rho,
585 k,
586 tr,
587 s1_hat,
588 s2_hat,
589 t0_hat,
590 a_hat,
591 }
592 }
593}
594
/// A polynomial of `N` `i32` coefficients.
///
/// Whether it is in the normal or the NTT domain depends on context. Coefficients are
/// zeroized on drop.
#[repr(transparent)]
struct Poly {
    f: [i32; N],
}
599
600impl Drop for Poly {
601 fn drop(&mut self) {
602 self.f.zeroize();
603 }
604}
605
606impl Poly {
607 const fn zero() -> Self {
608 Self { f: [0; N] }
609 }
610
611 const fn packed_bytesize(bitlen: usize) -> usize {
612 (N * bitlen) / 8
613 }
614
615 fn ntt_inplace(&mut self) {
617 let w = &mut self.f;
618
619 let mut m = 1;
620
621 for len in (0..8).map(|n| 128 >> n) {
622 for start in (0..256).step_by(len << 1) {
623 let zeta = ZETAS[m];
624 m += 1;
625
626 for j in start..start + len {
627 let t = reduce::mont_mul(zeta, w[j + len]);
628 w[j + len] = w[j] - t;
629 w[j] += t;
630 }
631 }
632 }
633 }
634
635 fn ntt(&mut self, f: &Self) {
637 let w_hat = &mut self.f;
638 let w = &f.f;
639
640 w_hat.copy_from_slice(w);
641
642 let mut m = 1;
643
644 for len in (0..8).map(|n| 128 >> n) {
645 for start in (0..256).step_by(len << 1) {
646 let zeta = ZETAS[m];
647 m += 1;
648
649 for j in start..start + len {
650 let t = reduce::mont_mul(zeta, w_hat[j + len]);
651 w_hat[j + len] = w_hat[j] - t;
652 w_hat[j] += t;
653 }
654 }
655 }
656 }
657
658 fn invntt(&self) -> Self {
660 let mut w_hat = self.f;
661
662 let mut m = 255;
663
664 for len in (0..8).map(|n| 1 << n) {
665 for start in (0..256).step_by(len << 1) {
666 let zeta = -ZETAS[m];
667 m -= 1;
668 for j in start..start + len {
669 let t = w_hat[j];
670 w_hat[j] = t + w_hat[j + len];
671 w_hat[j + len] = t - w_hat[j + len];
672 w_hat[j + len] = reduce::mont_mul(zeta, w_hat[j + len]);
673 }
674 }
675 }
676
677 const DIV_256: i32 = ((1 << 24) % Q as i64) as i32;
679
680 for a in w_hat.iter_mut() {
681 *a = reduce::mont_mul(*a, DIV_256);
682 }
683
684 Self { f: w_hat }
685 }
686
687 fn invntt_tomont_inplace(&mut self) {
689 let w = &mut self.f;
690
691 let mut m = 255;
692
693 for len in (0..8).map(|n| 1 << n) {
694 for start in (0..256).step_by(len << 1) {
695 let zeta = -ZETAS[m];
696 m -= 1;
697 for j in start..start + len {
698 let t = w[j];
699 w[j] = t + w[j + len];
700 w[j + len] = t - w[j + len];
701 w[j + len] = reduce::mont_mul(zeta, w[j + len]);
702 }
703 }
704 }
705
706 const DIV_256_MONT: i32 = ((1 << 56) % Q as i64) as i32;
708
709 for a in w.iter_mut() {
710 *a = reduce::mont_mul(*a, DIV_256_MONT);
711 }
712 }
713
714 fn rej_ntt(g: &mut hash::Shake128) -> Self {
716 let mut f: [MaybeUninit<i32>; N] = [MaybeUninit::uninit(); N];
717 let mut idx = 0;
718
719 while idx < N {
720 let bytes = g.squeezeblock();
721
722 for b in bytes.chunks_exact(3) {
723 if let Some(a) = coeff::from_three_bytes(b[0], b[1], b[2]) {
724 f[idx].write(a);
725 idx += 1;
726 }
727
728 if idx == N {
729 break;
730 }
731 }
732 }
733
734 Self {
735 f: unsafe { transmute::<[MaybeUninit<i32>; N], [i32; N]>(f) },
736 }
737 }
738
739 fn rej_bounded<const ETA: usize>(&mut self, h: &mut hash::Shake256) {
741 let mut idx = 0;
742
743 while idx < N {
744 let bytes = h.squeezeblock();
745
746 for z in bytes
747 .iter()
748 .flat_map(|b| {
749 let (z0, z1) = coeff::from_halfbytes::<ETA>(*b);
750 [z0, z1]
751 })
752 .flatten()
753 {
754 self.f[idx] = z;
755 idx += 1;
756
757 if idx == N {
758 break;
759 }
760 }
761 }
762 }
763
764 fn sample_in_ball(&mut self, h: &mut hash::Shake256, tau: usize) {
766 let mut block = h.squeezeblock();
767
768 let mut hash = u64::from_le_bytes(block[..8].try_into().unwrap());
769
770 let mut iter = block[8..].iter();
771
772 let mut i = N - tau;
773
774 while i < N {
775 let j = if let Some(j) = iter.by_ref().find(|b| (**b as usize) <= i) {
776 *j as usize
777 } else {
778 block = h.squeezeblock();
779 iter = block.iter();
780 continue;
781 };
782
783 self.f[i] = self.f[j];
784 self.f[j] = 1 - ((hash & 1) << 1) as i32;
785
786 hash >>= 1;
787 i += 1;
788 }
789 }
790
791 fn multiply_ntt_acc(&mut self, a: &Self, b: &Self) {
792 for i in 0..N {
793 self.f[i] += reduce::mont_mul(a.f[i], b.f[i])
794 }
795 }
796
797 fn multiply_ntt(&mut self, a: &Self, b: &Self) {
798 for i in 0..N {
799 self.f[i] = reduce::mont_mul(a.f[i], b.f[i])
800 }
801 }
802
803 fn dot_prod_ntt<const K: usize>(&mut self, u: &PolyVec<K>, v: &PolyVec<K>) {
804 self.multiply_ntt(&u.v[0], &v.v[0]);
805
806 for i in 1..K {
807 self.multiply_ntt_acc(&u.v[i], &v.v[i]);
808 }
809 }
810
811 fn reduce(&mut self) {
812 for a in self.f.iter_mut() {
813 *a = reduce::barrett_reduce(*a);
814 }
815 }
816
817 fn power2round(&self, f: &mut Self, g: &mut Self) {
818 for i in 0..N {
819 let (r1, r0) = coeff::power2round(self.f[i]);
820 f.f[i] = r1;
821 g.f[i] = r0;
822 }
823 }
824
825 fn decompose_32(&self, p0: &mut Self, p1: &mut Self) {
826 for i in 0..N {
827 let (r1, r0) = coeff::decompose_32(self.f[i]);
828 p0.f[i] = r0;
829 p1.f[i] = r1;
830 }
831 }
832
833 fn decompose_88(&self, p0: &mut Self, p1: &mut Self) {
834 for i in 0..N {
835 let (r1, r0) = coeff::decompose_88(self.f[i]);
836 p0.f[i] = r0;
837 p1.f[i] = r1;
838 }
839 }
840
841 const PACKED_10BIT: usize = (N * 10) / 8;
842
843 fn pack_simple_10bit(&self, z: &mut [u8; Self::PACKED_10BIT]) {
844 for (b, a) in z.chunks_exact_mut(5).zip(self.f.chunks_exact(4)) {
845 b[0] = a[0] as u8;
846 b[1] = (a[0] >> 8) as u8 | (a[1] << 2) as u8;
847 b[2] = (a[1] >> 6) as u8 | (a[2] << 4) as u8;
848 b[3] = (a[2] >> 4) as u8 | (a[3] << 6) as u8;
849 b[4] = (a[3] >> 2) as u8;
850 }
851 }
852
853 fn unpack_simple_10bit(&mut self, z: &[u8; Self::PACKED_10BIT]) {
854 for (a, b) in self.f.chunks_exact_mut(4).zip(z.chunks_exact(5)) {
855 let b: [i32; 5] = array::from_fn(|i| b[i] as i32);
856 a[0] = (b[0] | (b[1] << 8)) & 0x3FF;
857 a[1] = ((b[1] >> 2) | (b[2] << 6)) & 0x3FF;
858 a[2] = ((b[2] >> 4) | (b[3] << 4)) & 0x3FF;
859 a[3] = ((b[3] >> 6) | (b[4] << 2)) & 0x3FF;
860 }
861 }
862
863 fn pack_simple_4bit(&self, z: &mut [u8; Self::packed_bytesize(4)]) {
864 for (b, a) in z.iter_mut().zip(self.f.chunks_exact(2)) {
865 *b = (a[0] | a[1] << 4) as u8;
866 }
867 }
868
869 fn pack_simple_uninit_4bit(&self, z: &mut [MaybeUninit<u8>; Self::packed_bytesize(4)]) {
870 for (b, a) in z.iter_mut().zip(self.f.chunks_exact(2)) {
871 b.write((a[0] | a[1] << 4) as u8);
872 }
873 }
874
875 fn pack_simple_6bit(&self, z: &mut [u8; Self::packed_bytesize(6)]) {
876 for (b, a) in z.chunks_exact_mut(3).zip(self.f.chunks_exact(4)) {
877 b[0] = ((a[0] >> 0) | (a[1] << 6)) as u8;
878 b[1] = ((a[1] >> 2) | (a[2] << 4)) as u8;
879 b[2] = ((a[2] >> 4) | (a[3] << 2)) as u8;
880 }
881 }
882
883 fn pack_simple_uninit_6bit(&self, z: &mut [MaybeUninit<u8>; Self::packed_bytesize(6)]) {
884 for (b, a) in z.chunks_exact_mut(3).zip(self.f.chunks_exact(4)) {
885 b[0].write(((a[0] >> 0) | (a[1] << 6)) as u8);
886 b[1].write(((a[1] >> 2) | (a[2] << 4)) as u8);
887 b[2].write(((a[2] >> 4) | (a[3] << 2)) as u8);
888 }
889 }
890
891 const PACKED_4BIT: usize = (N * 4) / 8;
892
893 fn pack_eta4(&self, z: &mut [u8; Self::PACKED_4BIT]) {
894 for (b, a) in z.iter_mut().zip(self.f.chunks_exact(2)) {
895 let t0 = (4 - a[0]) as u8;
896 let t1 = (4 - a[1]) as u8;
897 *b = t0 | (t1 << 4);
898 }
899 }
900
901 fn unpack_eta4(&mut self, z: &[u8; Self::PACKED_4BIT]) {
902 for (a, b) in self.f.chunks_exact_mut(2).zip(z) {
903 a[0] = 4 - (b & 0xF) as i32;
904 a[1] = 4 - (b >> 4) as i32;
905 }
906 }
907
908 const PACKED_3BIT: usize = (N * 3) / 8;
909
910 fn pack_eta2(&self, z: &mut [u8; Self::PACKED_3BIT]) {
911 for (b, a) in z.chunks_exact_mut(3).zip(self.f.chunks_exact(8)) {
912 let t: [u8; 8] = array::from_fn(|i| (2 - a[i]) as u8);
913
914 b[0] = t[0] | (t[1] << 3) | (t[2] << 6);
915 b[1] = (t[2] >> 2) | (t[3] << 1) | (t[4] << 4) | (t[5] << 7);
916 b[2] = (t[5] >> 1) | (t[6] << 2) | (t[7] << 5);
917 }
918 }
919
920 fn unpack_eta2(&mut self, z: &[u8; Self::PACKED_3BIT]) {
921 for (a, b) in self.f.chunks_exact_mut(8).zip(z.chunks_exact(3)) {
922 a[0] = 2 - (b[0] & 7) as i32;
923 a[1] = 2 - ((b[0] >> 3) & 7) as i32;
924 a[2] = 2 - ((b[0] >> 6) | (b[1] << 2) & 7) as i32;
925 a[3] = 2 - ((b[1] >> 1) & 7) as i32;
926 a[4] = 2 - ((b[1] >> 4) & 7) as i32;
927 a[5] = 2 - (((b[1] >> 7) | (b[2] << 1)) & 7) as i32;
928 a[6] = 2 - ((b[2] >> 2) & 7) as i32;
929 a[7] = 2 - (b[2] >> 5) as i32
930 }
931 }
932
933 const PACKED_13BIT: usize = (N * 13) / 8;
934
935 fn pack_eta_2powdm1(&self, z: &mut [u8; Self::PACKED_13BIT]) {
936 const ETA: i32 = 1 << (D - 1);
937
938 for (b, a) in z.chunks_exact_mut(13).zip(self.f.chunks_exact(8)) {
939 let a: [u16; 8] = array::from_fn(|i| (ETA - a[i]) as u16);
940
941 b[0] = a[0] as u8;
942 b[1] = ((a[0] >> 8) | a[1] << 5) as u8;
943 b[2] = (a[1] >> 3) as u8;
944 b[3] = ((a[1] >> 11) | a[2] << 2) as u8;
945 b[4] = ((a[2] >> 6) | (a[3] << 7)) as u8;
946 b[5] = (a[3] >> 1) as u8;
947 b[6] = ((a[3] >> 9) | a[4] << 4) as u8;
948 b[7] = (a[4] >> 4) as u8;
949 b[8] = ((a[4] >> 12) | a[5] << 1) as u8;
950 b[9] = ((a[5] >> 7) | a[6] << 6) as u8;
951 b[10] = (a[6] >> 2) as u8;
952 b[11] = ((a[6] >> 10) | a[7] << 3) as u8;
953 b[12] = (a[7] >> 5) as u8;
954 }
955 }
956
957 fn unpack_eta_2powdm1(&mut self, z: &[u8; Self::PACKED_13BIT]) {
958 const ETA: i32 = 1 << (D - 1);
959
960 for (a, b) in self.f.chunks_exact_mut(8).zip(z.chunks_exact(13)) {
961 let b: [i32; 13] = array::from_fn(|i| b[i] as i32);
962
963 a[0] = ETA - ((b[0] | (b[1] << 8)) & 0x1FFF);
964 a[1] = ETA - (((b[1] >> 5) | (b[2] << 3) | (b[3] << 11)) & 0x1FFF);
965 a[2] = ETA - (((b[3] >> 2) | (b[4] << 6)) & 0x1FFF);
966 a[3] = ETA - (((b[4] >> 7) | (b[5] << 1) | (b[6] << 9)) & 0x1FFF);
967 a[4] = ETA - (((b[6] >> 4) | (b[7] << 4) | (b[8] << 12)) & 0x1FFF);
968 a[5] = ETA - (((b[8] >> 1) | (b[9] << 7)) & 0x1FFF);
969 a[6] = ETA - (((b[9] >> 6) | (b[10] << 2) | (b[11] << 10)) & 0x1FFF);
970 a[7] = ETA - (((b[11] >> 3) | (b[12] << 5)) & 0x1FFF);
971 }
972 }
973
974 fn bitpack_2pow17(&self, z: &mut [u8; Self::packed_bytesize(18)]) {
975 const B: i32 = 1 << 17;
976
977 for (b, a) in z.chunks_exact_mut(9).zip(self.f.chunks_exact(4)) {
978 let a0 = B - a[0];
979 let a1 = B - a[1];
980 let a2 = B - a[2];
981 let a3 = B - a[3];
982
983 b[0] = (a0 >> 0) as u8;
984 b[1] = (a0 >> 8) as u8;
985 b[2] = ((a0 >> 16) | (a1 << 2)) as u8;
986 b[3] = (a1 >> 6) as u8;
987 b[4] = ((a1 >> 14) | (a2 << 4)) as u8;
988 b[5] = (a2 >> 4) as u8;
989 b[6] = ((a2 >> 12) | (a3 << 6)) as u8;
990 b[7] = (a3 >> 2) as u8;
991 b[8] = (a3 >> 10) as u8;
992 }
993 }
994
995 fn bitunpack_2pow17(&mut self, z: &[u8; Self::packed_bytesize(18)]) {
996 const B: i32 = 1 << 17;
997 const BITMASK: i32 = 0x3ffff;
998
999 for (a, b) in self.f.chunks_exact_mut(4).zip(z.chunks_exact(9)) {
1000 let b: [i32; 9] = array::from_fn(|i| b[i] as i32);
1001
1002 a[0] = B - (((b[0] >> 0) | (b[1] << 8) | (b[2] << 16)) & BITMASK);
1003 a[1] = B - (((b[2] >> 2) | (b[3] << 6) | (b[4] << 14)) & BITMASK);
1004 a[2] = B - (((b[4] >> 4) | (b[5] << 4) | (b[6] << 12)) & BITMASK);
1005 a[3] = B - ((b[6] >> 6) | (b[7] << 2) | (b[8] << 10));
1006 }
1007 }
1008
1009 fn bitpack_2pow19(&self, z: &mut [u8; Self::packed_bytesize(20)]) {
1010 const B: i32 = 1 << 19;
1011
1012 for (b, a) in z.chunks_exact_mut(5).zip(self.f.chunks_exact(2)) {
1013 let a0 = B - a[0];
1014 let a1 = B - a[1];
1015
1016 b[0] = (a0 >> 0) as u8;
1017 b[1] = (a0 >> 8) as u8;
1018 b[2] = ((a0 >> 16) | (a1 << 4)) as u8;
1019 b[3] = (a1 >> 4) as u8;
1020 b[4] = (a1 >> 12) as u8;
1021 }
1022 }
1023
1024 fn bitunpack_2pow19(&mut self, z: &[u8; Self::packed_bytesize(20)]) {
1025 const B: i32 = 1 << 19;
1026 const BITMASK: i32 = 0xfffff;
1027
1028 for (a, b) in self.f.chunks_exact_mut(2).zip(z.chunks_exact(5)) {
1029 let b: [i32; 5] = array::from_fn(|i| b[i] as i32);
1030
1031 a[0] = B - (((b[0] >> 0) | (b[1] << 8) | (b[2] << 16)) & BITMASK);
1032 a[1] = B - ((b[2] >> 4) | (b[3] << 4) | (b[4] << 12));
1033 }
1034 }
1035
1036 const fn norm_in_bound(&self, bound: usize) -> bool {
1037 let mut i = 0;
1038 while i < N {
1039 if coeff::norm(self.f[i]) >= bound {
1040 return false;
1041 }
1042
1043 i += 1;
1044 }
1045
1046 true
1047 }
1048
1049 fn make_hint(&mut self, p0: &Poly, p1: &Poly, gamma2: usize) -> usize {
1050 let mut sum = 0;
1051
1052 for i in 0..N {
1053 let h = coeff::make_hint(p0.f[i], p1.f[i], gamma2 as i32);
1054
1055 self.f[i] = h as i32;
1056 sum += h;
1057 }
1058
1059 sum
1060 }
1061
1062 fn shifted_left(&self, d: usize) -> Self {
1063 let mut f = [MaybeUninit::uninit(); N];
1064
1065 for (i, a) in f.iter_mut().enumerate() {
1066 a.write(self.f[i] << d);
1067 }
1068
1069 Self {
1070 f: unsafe { transmute::<[MaybeUninit<i32>; N], [i32; N]>(f) },
1071 }
1072 }
1073}
1074
1075impl AddAssign<&Self> for Poly {
1076 fn add_assign(&mut self, rhs: &Self) {
1077 for i in 0..N {
1078 self.f[i] += rhs.f[i];
1079 }
1080 }
1081}
1082
1083impl SubAssign<&Self> for Poly {
1084 fn sub_assign(&mut self, rhs: &Self) {
1085 for i in 0..N {
1086 self.f[i] -= rhs.f[i];
1087 }
1088 }
1089}
1090
1091impl MulAssign<&Self> for Poly {
1092 fn mul_assign(&mut self, rhs: &Self) {
1093 for (i, a) in self.f.iter_mut().enumerate() {
1094 *a = reduce::mont_mul(*a, rhs.f[i]);
1095 }
1096 }
1097}
1098
/// A `K x L` polynomial matrix, stored as `K` row vectors of `L` polynomials.
#[repr(transparent)]
struct PolyMat<const K: usize, const L: usize> {
    m: [PolyVec<L>; K],
}
1103
1104impl<const K: usize, const L: usize> PolyMat<K, L> {
1105 fn expand_a(rho: &[u8; 32]) -> Self {
1107 let mut g = hash::Shake128::init();
1108 let mut m: [MaybeUninit<PolyVec<L>>; K] = [const { MaybeUninit::uninit() }; K];
1109
1110 for (r, pvec) in m.iter_mut().enumerate() {
1111 let mut v: [MaybeUninit<Poly>; L] = [const { MaybeUninit::uninit() }; L];
1112
1113 for (s, poly) in v.iter_mut().enumerate() {
1114 g.absorb_multi(&[rho, &u16::to_le_bytes(((r << 8) | s) as u16)]);
1115
1116 poly.write(Poly::rej_ntt(&mut g));
1117
1118 g.reset();
1119 }
1120
1121 pvec.write(PolyVec {
1122 v: unsafe { transmute_copy(&v) },
1123 });
1124 }
1125
1126 Self {
1127 m: unsafe { transmute_copy(&m) },
1128 }
1129 }
1130}
1131
/// A vector of `K` polynomials.
#[repr(transparent)]
struct PolyVec<const K: usize> {
    v: [Poly; K],
}
1136
1137impl<const K: usize> PolyVec<K> {
1138 const fn zero() -> Self {
1139 Self {
1140 v: [const { Poly::zero() }; K],
1141 }
1142 }
1143
1144 fn ntt_inplace(&mut self) {
1145 for p in self.v.iter_mut() {
1146 p.ntt_inplace();
1147 }
1148 }
1149
1150 fn ntt(&mut self, v_hat: &Self) {
1151 for (p_hat, p) in self.v.iter_mut().zip(&v_hat.v) {
1152 p_hat.ntt(p);
1153 }
1154 }
1155
1156 fn invntt(&self) -> Self {
1157 let mut v = [const { MaybeUninit::uninit() }; K];
1158
1159 for (i, p) in v.iter_mut().enumerate() {
1160 p.write(self.v[i].invntt());
1161 }
1162
1163 Self {
1164 v: unsafe { transmute_copy(&v) },
1165 }
1166 }
1167
1168 fn reduce(&mut self) {
1169 for p in self.v.iter_mut() {
1170 p.reduce();
1171 }
1172 }
1173
1174 fn reduce_invntt_tomont_inplace(&mut self) {
1175 for p in self.v.iter_mut() {
1176 p.reduce();
1177 p.invntt_tomont_inplace();
1178 }
1179 }
1180
1181 fn invntt_tomont_inplace(&mut self) {
1182 for p in self.v.iter_mut() {
1183 p.invntt_tomont_inplace();
1184 }
1185 }
1186
1187 fn power2round(&self, t1: &mut PolyVec<K>, t0: &mut PolyVec<K>) {
1188 for i in 0..K {
1189 self.v[i].power2round(&mut t1.v[i], &mut t0.v[i]);
1190 }
1191 }
1192
1193 fn decompose_32(&self, x0: &mut PolyVec<K>, x1: &mut PolyVec<K>) {
1194 for i in 0..K {
1195 self.v[i].decompose_32(&mut x0.v[i], &mut x1.v[i]);
1196 }
1197 }
1198
1199 fn decompose_88(&self, x0: &mut PolyVec<K>, x1: &mut PolyVec<K>) {
1200 for i in 0..K {
1201 self.v[i].decompose_88(&mut x0.v[i], &mut x1.v[i]);
1202 }
1203 }
1204
1205 fn pack_eta4(&self, z: &mut [u8]) {
1206 for (buf, p) in z.chunks_exact_mut(Poly::PACKED_4BIT).zip(self.v.iter()) {
1207 p.pack_eta4(buf.try_into().unwrap());
1208 }
1209 }
1210
1211 fn unpack_eta4(&mut self, z: &[u8]) {
1212 for (p, buf) in self.v.iter_mut().zip(z.chunks_exact(Poly::PACKED_4BIT)) {
1213 p.unpack_eta4(buf.try_into().unwrap());
1214 }
1215 }
1216
1217 fn pack_eta2(&self, z: &mut [u8]) {
1218 for (buf, p) in z.chunks_exact_mut(Poly::PACKED_3BIT).zip(self.v.iter()) {
1219 p.pack_eta2(buf.try_into().unwrap());
1220 }
1221 }
1222
1223 fn unpack_eta2(&mut self, z: &[u8]) {
1224 for (p, buf) in self.v.iter_mut().zip(z.chunks_exact(Poly::PACKED_3BIT)) {
1225 p.unpack_eta2(buf.try_into().unwrap());
1226 }
1227 }
1228
1229 fn pack_eta_2powdm1(&self, z: &mut [u8]) {
1230 for (buf, p) in z.chunks_exact_mut(Poly::PACKED_13BIT).zip(self.v.iter()) {
1231 p.pack_eta_2powdm1(buf.try_into().unwrap());
1232 }
1233 }
1234
1235 fn unpack_eta_2powdm1(&mut self, z: &[u8]) {
1236 for (p, buf) in self.v.iter_mut().zip(z.chunks_exact(Poly::PACKED_13BIT)) {
1237 p.unpack_eta_2powdm1(buf.try_into().unwrap());
1238 }
1239 }
1240
1241 fn pack_simple_4bit<const BZ: usize>(&self, z: &mut [u8; BZ]) {
1242 for (chunk, p) in z
1243 .chunks_exact_mut(Poly::packed_bytesize(4))
1244 .zip(self.v.iter())
1245 {
1246 p.pack_simple_4bit(chunk.try_into().unwrap());
1247 }
1248 }
1249
1250 fn pack_simple_6bit(&self, z: &mut [u8]) {
1251 for (chunk, p) in z
1252 .chunks_exact_mut(Poly::packed_bytesize(6))
1253 .zip(self.v.iter())
1254 {
1255 p.pack_simple_6bit(chunk.try_into().unwrap());
1256 }
1257 }
1258
1259 fn multiply_matvec_ntt<const L: usize>(&mut self, m: &PolyMat<K, L>, v: &PolyVec<L>) {
1260 for i in 0..K {
1261 self.v[i].dot_prod_ntt(&m.m[i], v)
1262 }
1263 }
1264
1265 fn multiply_poly_ntt(&mut self, p: &Poly, v: &PolyVec<K>) {
1266 for i in 0..K {
1267 self.v[i].multiply_ntt(p, &v.v[i]);
1268 }
1269 }
1270
1271 fn hint_bitpack<const OMEGA: usize>(&self, dst: &mut [u8]) {
1272 let mut idx = 0;
1273
1274 for i in 0..K {
1275 for j in 0..N {
1276 let h = self.v[i].f[j] as usize;
1277 dst[idx] = (j & h.wrapping_neg()) as u8;
1278 idx += h;
1279 }
1280
1281 dst[OMEGA + i] = idx as u8;
1282 }
1283 }
1284
1285 fn hint_bitunpack(y: &[u8], omega: usize) -> Result<PolyVec<K>, VerifyError> {
1286 let mut h = PolyVec::zero();
1287
1288 let mut idx = 0;
1289
1290 for i in 0..K {
1291 let num_hints = y[omega + i] as usize;
1292
1293 if num_hints < idx || num_hints > omega {
1294 return Err(VerifyError::TooManyHints);
1295 }
1296
1297 if idx >= num_hints {
1298 continue;
1299 }
1300
1301 h.v[i].f[y[idx] as usize] = 1;
1302 idx += 1;
1303
1304 for j in idx..num_hints {
1305 if y[idx - 1] >= y[j] {
1306 return Err(VerifyError::TooManyHints);
1307 }
1308
1309 h.v[i].f[y[j] as usize] = 1;
1310 }
1311 idx = num_hints;
1312 }
1313
1314 if y[idx..omega].iter().any(|x| *x != 0) {
1315 return Err(VerifyError::TooManyHints);
1316 }
1317
1318 Ok(h)
1319 }
1320
1321 fn bitpack_2pow17(&self, dst: &mut [u8]) {
1322 for (buf, p) in dst
1323 .chunks_exact_mut(Poly::packed_bytesize(18))
1324 .zip(self.v.iter())
1325 {
1326 p.bitpack_2pow17(buf.try_into().unwrap());
1327 }
1328 }
1329
1330 fn bitpack_2pow19(&self, dst: &mut [u8]) {
1331 for (buf, p) in dst
1332 .chunks_exact_mut(Poly::packed_bytesize(20))
1333 .zip(self.v.iter())
1334 {
1335 p.bitpack_2pow19(buf.try_into().unwrap());
1336 }
1337 }
1338
1339 fn expand_mask_2pow17(&mut self, rho: &[u8; 64], mu: usize, h: &mut hash::Shake256) {
1340 let mut blocks = [0u8; 5 * hash::SHAKE_256_RATE];
1341
1342 for (r, p) in self.v.iter_mut().enumerate() {
1343 h.absorb_multi(&[rho, &u16::to_le_bytes((mu + r) as u16)]);
1344 h.squeezeblocks(&mut blocks);
1345
1346 p.bitunpack_2pow17(blocks.first_chunk_mut().unwrap());
1347 }
1348 }
1349
1350 fn expand_mask_2pow19(&mut self, rho: &[u8; 64], mu: usize, h: &mut hash::Shake256) {
1351 let mut blocks = [0u8; 5 * hash::SHAKE_256_RATE];
1352
1353 for (r, p) in self.v.iter_mut().enumerate() {
1354 h.absorb_multi(&[rho, &u16::to_le_bytes((mu + r) as u16)]);
1355 h.squeezeblocks(&mut blocks);
1356
1357 p.bitunpack_2pow19(blocks.first_chunk_mut().unwrap());
1358 }
1359 }
1360
1361 const fn norm_in_bound(&self, bound: usize) -> bool {
1362 let mut i = 0;
1363
1364 while i < K {
1365 if !self.v[i].norm_in_bound(bound) {
1366 return false;
1367 }
1368
1369 i += 1;
1370 }
1371
1372 true
1373 }
1374
1375 fn make_hint(&mut self, x0: &PolyVec<K>, x1: &PolyVec<K>, gamma2: usize) -> usize {
1376 let mut sum = 0;
1377 for i in 0..K {
1378 sum += self.v[i].make_hint(&x0.v[i], &x1.v[i], gamma2);
1379 }
1380
1381 sum
1382 }
1383
1384 fn shifted_left(&self, d: usize) -> Self {
1385 let mut v = [const { MaybeUninit::uninit() }; K];
1386
1387 for (i, poly) in v.iter_mut().enumerate() {
1388 poly.write(self.v[i].shifted_left(d));
1389 }
1390
1391 Self {
1392 v: unsafe { transmute_copy(&v) },
1393 }
1394 }
1395}
1396
1397impl<const K: usize> Mul<&Poly> for &PolyVec<K> {
1398 type Output = PolyVec<K>;
1399
1400 fn mul(self, rhs: &Poly) -> Self::Output {
1401 let mut v = PolyVec::zero();
1402
1403 for i in 0..K {
1404 v.v[i].multiply_ntt(&self.v[i], rhs);
1405 }
1406
1407 v
1408 }
1409}
1410
1411impl<const K: usize> MulAssign<&Poly> for PolyVec<K> {
1412 fn mul_assign(&mut self, rhs: &Poly) {
1413 for poly in self.v.iter_mut() {
1414 *poly *= rhs;
1415 }
1416 }
1417}
1418
1419impl<const K: usize> AddAssign<&Self> for PolyVec<K> {
1420 fn add_assign(&mut self, rhs: &Self) {
1421 for i in 0..K {
1422 self.v[i] += &rhs.v[i];
1423 }
1424 }
1425}
1426
1427impl<const K: usize> SubAssign<&Self> for PolyVec<K> {
1428 fn sub_assign(&mut self, rhs: &Self) {
1429 for i in 0..K {
1430 self.v[i] -= &rhs.v[i];
1431 }
1432 }
1433}
1434
1435fn expand_s<const K: usize, const L: usize, const ETA: usize>(
1437 s1: &mut PolyVec<L>,
1438 s2: &mut PolyVec<K>,
1439 rho: &[u8; 64],
1440) {
1441 let mut h = hash::Shake256::init();
1442
1443 for (nonce, poly) in s1.v.iter_mut().chain(s2.v.iter_mut()).enumerate() {
1444 h.absorb_multi(&[rho, &u16::to_le_bytes(nonce as u16)]);
1445 poly.rej_bounded::<ETA>(&mut h);
1446 h.reset();
1447 }
1448}
1449
#[cfg(test)]
mod tests {
    use rand::RngCore;
    use rand_core::OsRng;
    use serde::Deserialize;
    use std::{fs::read_to_string, path::PathBuf};

    use super::*;

    /// Reads `tests/<file>` under the crate root and deserializes it as a JSON
    /// test-vector file with a top-level `testGroups` array.
    fn load_vectors<T: serde::de::DeserializeOwned>(file: &str) -> Tests<T> {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests")
            .join(file);
        let data = read_to_string(&path)
            .unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
        serde_json::from_str(&data)
            .unwrap_or_else(|e| panic!("failed to parse {}: {e}", path.display()))
    }

    /// Generates a random key pair for parameter-set module `$m`, signs a random
    /// token with it and checks that the signature verifies.
    macro_rules! check_roundtrip {
        ($m:ident) => {{
            let mut pk = [0u8; $m::PUBKEY_SIZE];
            let sk = $m::PrivateKey::keygen(&mut pk, &mut OsRng);
            let vk = $m::PublicKey::decode(&pk);
            let mut token = [0u8; 32];
            OsRng.fill_bytes(&mut token);
            let mut sig = [0u8; $m::SIG_SIZE];
            sk.sign(&mut sig, &mut OsRng, &token);
            vk.verify(&token, &sig).unwrap();
        }};
    }

    /// Checks every key-generation vector in group `$tg` against module `$m`. It also
    /// round-trips the expected private and public keys through decode/encode.
    macro_rules! check_keygen {
        ($m:ident, $tg:expr) => {{
            for test in &$tg.tests {
                let mut vk_bytes = [0u8; $m::PUBKEY_SIZE];
                let mut sk_bytes = [0u8; $m::PRIVKEY_SIZE];

                let sk = $m::PrivateKey::keygen_internal(&mut vk_bytes, &test.seed);
                sk.encode(&mut sk_bytes);

                assert_eq!(vk_bytes, test.pk[..]);
                assert_eq!(sk_bytes, test.sk[..]);

                let sk_prime = $m::PrivateKey::decode(&test.sk);
                sk_bytes.fill(0);
                sk_prime.encode(&mut sk_bytes);
                assert_eq!(sk_bytes, test.sk[..]);

                let vk_prime = $m::PublicKey::decode(&test.pk);
                vk_bytes.fill(0);
                vk_prime.encode(&mut vk_bytes);
                assert_eq!(vk_bytes, test.pk[..]);
            }
        }};
    }

    /// Checks every signature-generation vector in group `$tg` against module `$m`.
    macro_rules! check_siggen {
        ($m:ident, $tg:expr) => {{
            let mut sig = [0u8; $m::SIG_SIZE];

            for test in &$tg.tests {
                sig.fill(0);
                let sk = $m::PrivateKey::decode(&test.sk);
                // Vectors without an `rnd` field are signed with an all-zero `rnd`.
                let rnd = test.rnd.as_ref().map_or([0; 32], |r| r.rnd);
                sk.sign_internal(&mut sig, &test.message, &rnd);
                assert_eq!(&sig, &test.signature[..]);
            }
        }};
    }

    /// Checks every signature-verification vector in group `$tg` against module `$m`.
    /// Any error must correspond to a vector marked as failing.
    macro_rules! check_sigver {
        ($m:ident, $tg:expr) => {{
            let pk = $m::PublicKey::decode(&$tg.pk);

            for test in &$tg.tests {
                match pk.verify_internal(&test.message, test.signature[..].try_into().unwrap())
                {
                    Ok(_) => assert!(test.test_passed),
                    Err(VerifyError::ZoutOfBound) => {
                        assert!(!test.test_passed);
                        assert_eq!(test.reason, "z too large");
                    }
                    Err(VerifyError::Mismatch) => assert!(!test.test_passed),
                    Err(VerifyError::TooManyHints) => {
                        assert!(!test.test_passed);
                        assert_eq!(test.reason, "too many hints");
                    }
                }
            }
        }};
    }

    #[test]
    fn test_gen_sign_verify() {
        check_roundtrip!(mldsa44);
        check_roundtrip!(mldsa65);
        check_roundtrip!(mldsa87);
    }

    #[test]
    fn test_keygen() {
        let test_data: Tests<KeyGenTg> = load_vectors("mldsa-keygen.json");

        for tg in &test_data.test_groups {
            match tg.parameter_set.as_str() {
                "ML-DSA-44" => check_keygen!(mldsa44, tg),
                "ML-DSA-65" => check_keygen!(mldsa65, tg),
                "ML-DSA-87" => check_keygen!(mldsa87, tg),
                other => panic!("invalid parameter set: {other}"),
            }
        }
    }

    #[test]
    fn test_siggen() {
        let test_data: Tests<SigGenTg> = load_vectors("mldsa-sign.json");

        for tg in &test_data.test_groups {
            match tg.parameter_set.as_str() {
                "ML-DSA-44" => check_siggen!(mldsa44, tg),
                "ML-DSA-65" => check_siggen!(mldsa65, tg),
                "ML-DSA-87" => check_siggen!(mldsa87, tg),
                other => panic!("invalid parameter set: {other}"),
            }
        }
    }

    #[test]
    fn test_sigver() {
        let test_data: Tests<SigVerTg> = load_vectors("mldsa-verify.json");

        for tg in &test_data.test_groups {
            match tg.parameter_set.as_str() {
                "ML-DSA-44" => check_sigver!(mldsa44, tg),
                "ML-DSA-65" => check_sigver!(mldsa65, tg),
                "ML-DSA-87" => check_sigver!(mldsa87, tg),
                other => panic!("invalid parameter set: {other}"),
            }
        }
    }

    /// One key-generation vector: seed plus expected encoded keys.
    #[derive(Deserialize)]
    struct KeyGenTV {
        #[serde(with = "hex")]
        pk: Vec<u8>,

        #[serde(with = "hex")]
        seed: [u8; 32],

        #[serde(with = "hex")]
        sk: Vec<u8>,
    }

    /// A group of key-generation vectors for one parameter set.
    #[derive(Deserialize)]
    struct KeyGenTg {
        #[serde(rename = "parameterSet")]
        parameter_set: String,

        tests: Vec<KeyGenTV>,
    }

    /// Top-level layout shared by all vector files.
    #[derive(Deserialize)]
    struct Tests<T> {
        #[serde(rename = "testGroups")]
        test_groups: Vec<T>,
    }

    /// Optional signing randomness, flattened into a signature-generation vector.
    #[derive(Deserialize)]
    struct Rnd {
        #[serde(with = "hex")]
        rnd: [u8; 32],
    }

    /// One signature-generation vector.
    #[derive(Deserialize)]
    struct SigGenTV {
        #[serde(with = "hex")]
        message: Vec<u8>,

        #[serde(with = "hex")]
        signature: Vec<u8>,

        #[serde(with = "hex")]
        sk: Vec<u8>,

        #[serde(flatten)]
        rnd: Option<Rnd>,
    }

    /// A group of signature-generation vectors for one parameter set.
    #[derive(Deserialize)]
    struct SigGenTg {
        #[serde(rename = "parameterSet")]
        parameter_set: String,

        tests: Vec<SigGenTV>,
    }

    /// One signature-verification vector and its expected outcome.
    #[derive(Deserialize)]
    struct SigVerTV {
        #[serde(with = "hex")]
        message: Vec<u8>,

        reason: String,

        #[serde(with = "hex")]
        signature: Vec<u8>,

        #[serde(rename = "testPassed")]
        test_passed: bool,
    }

    /// A group of signature-verification vectors sharing one public key.
    #[derive(Deserialize)]
    struct SigVerTg {
        #[serde(rename = "parameterSet")]
        parameter_set: String,

        #[serde(with = "hex")]
        pk: Vec<u8>,

        tests: Vec<SigVerTV>,
    }
}