1#![no_std]
2#![doc = include_str!("../README.md")]
3#![doc(
4 html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
5 html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
6)]
7#![cfg_attr(docsrs, feature(doc_cfg))]
8#![warn(clippy::pedantic)] #![warn(clippy::integer_division_remainder_used)] #![warn(clippy::as_conversions)] #![allow(non_snake_case)] #![allow(clippy::similar_names)] #![allow(clippy::many_single_char_names)] #![allow(clippy::clone_on_copy)] #![deny(missing_docs)] #![warn(unreachable_pub)] mod algebra;
40mod crypto;
41mod encode;
42mod hint;
43mod ntt;
44mod param;
45mod pkcs8;
46mod sampling;
47
48pub use crate::param::{EncodedSignature, EncodedVerifyingKey, ExpandedSigningKey, MlDsaParams};
49pub use signature::{self, Error};
50
51use crate::algebra::{AlgebraExt, Elem, NttMatrix, NttVector, Vector};
52use crate::crypto::H;
53use crate::hint::Hint;
54use crate::ntt::{Ntt, NttInverse};
55use crate::param::{ParameterSet, QMinus1, SamplingSize, SpecQ};
56use crate::sampling::{expand_a, expand_mask, expand_s, sample_in_ball};
57use core::convert::{AsRef, TryFrom, TryInto};
58use core::fmt;
59use hybrid_array::{
60 Array,
61 typenum::{
62 Diff, Length, Prod, Quot, Shleft, U1, U2, U4, U5, U6, U7, U8, U17, U19, U32, U48, U55, U64,
63 U75, U80, U88, Unsigned,
64 },
65};
66use module_lattice::Truncate;
67use sha3::Shake256;
68use signature::{DigestSigner, DigestVerifier, MultipartSigner, MultipartVerifier, Signer};
69
70#[cfg(feature = "rand_core")]
71use {
72 rand_core::{CryptoRng, TryCryptoRng},
73 signature::{RandomizedDigestSigner, RandomizedMultipartSigner, RandomizedSigner},
74};
75
76#[cfg(feature = "zeroize")]
77use zeroize::{Zeroize, ZeroizeOnDrop};
78
/// A 32-byte array, used for seeds and other fixed-size 32-byte values
/// (e.g. `rho`, `K`, and the signing randomness `rnd`).
pub type B32 = Array<u8, U32>;

/// A 64-byte array, used for 64-byte hash outputs such as `tr` and `mu`.
pub(crate) type B64 = Array<u8, U64>;

/// The 32-byte seed `xi` from which an ML-DSA key pair is deterministically derived.
pub type Seed = B32;

/// An ML-DSA signature, consisting of the commitment hash `c_tilde`, the
/// response vector `z`, and the hint `h`.
#[derive(Clone, PartialEq, Debug)]
pub struct Signature<P: MlDsaParams> {
    // Commitment hash; its length is the parameter set's `Lambda`.
    c_tilde: Array<u8, P::Lambda>,
    // Response vector with `L` polynomials.
    z: Vector<P::L>,
    // Hint used by the verifier to recover the high bits of `w`.
    h: Hint<P>,
}
96
97impl<P: MlDsaParams> Signature<P> {
98 pub fn encode(&self) -> EncodedSignature<P> {
101 let c_tilde = self.c_tilde.clone();
102 let z = P::encode_z(&self.z);
103 let h = self.h.bit_pack();
104 P::concat_sig(c_tilde, z, h)
105 }
106
107 pub fn decode(enc: &EncodedSignature<P>) -> Option<Self> {
110 let (c_tilde, z, h) = P::split_sig(enc);
111
112 let c_tilde = c_tilde.clone();
113 let z = P::decode_z(z);
114 let h = Hint::bit_unpack(h)?;
115
116 if z.infinity_norm() >= P::GAMMA1_MINUS_BETA {
117 return None;
118 }
119
120 Some(Self { c_tilde, z, h })
121 }
122}
123
124impl<'a, P: MlDsaParams> TryFrom<&'a [u8]> for Signature<P> {
125 type Error = Error;
126
127 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
128 let enc = EncodedSignature::<P>::try_from(value).map_err(|_| Error::new())?;
129 Self::decode(&enc).ok_or(Error::new())
130 }
131}
132
133impl<P: MlDsaParams> TryInto<EncodedSignature<P>> for Signature<P> {
134 type Error = Error;
135
136 fn try_into(self) -> Result<EncodedSignature<P>, Self::Error> {
137 Ok(self.encode())
138 }
139}
140
/// The byte representation of a [`Signature`] is its fixed-size [`EncodedSignature`].
impl<P: MlDsaParams> signature::SignatureEncoding for Signature<P> {
    type Repr = EncodedSignature<P>;
}
144
145struct MuBuilder(H);
146
147impl MuBuilder {
148 fn new(tr: &[u8], ctx: &[u8]) -> Self {
149 let mut h = H::default();
150 h = h.absorb(tr);
151 h = h.absorb(&[0]);
152 h = h.absorb(&[Truncate::truncate(ctx.len())]);
153 h = h.absorb(ctx);
154
155 Self(h)
156 }
157
158 fn internal(tr: &[u8], Mp: &[&[u8]]) -> B64 {
159 let mut h = H::default().absorb(tr);
160
161 for m in Mp {
162 h = h.absorb(m);
163 }
164
165 h.squeeze_new()
166 }
167
168 fn message(mut self, M: &[&[u8]]) -> B64 {
169 for m in M {
170 self.0 = self.0.absorb(m);
171 }
172
173 self.0.squeeze_new()
174 }
175
176 fn finish(mut self) -> B64 {
177 self.0.squeeze_new()
178 }
179}
180
181impl AsMut<Shake256> for MuBuilder {
182 fn as_mut(&mut self) -> &mut Shake256 {
183 self.0.updatable()
184 }
185}
186
/// An ML-DSA key pair: a signing key, its matching verifying key, and the
/// seed from which both were derived.
pub struct KeyPair<P: MlDsaParams> {
    /// The signing key of the key pair
    signing_key: SigningKey<P>,

    /// The verifying key of the key pair
    verifying_key: VerifyingKey<P>,

    /// The seed this key pair was derived from
    seed: B32,
}
198
199impl<P: MlDsaParams> KeyPair<P> {
200 pub fn signing_key(&self) -> &SigningKey<P> {
202 &self.signing_key
203 }
204
205 pub fn verifying_key(&self) -> &VerifyingKey<P> {
207 &self.verifying_key
208 }
209
210 #[inline]
217 pub fn to_seed(&self) -> Seed {
218 self.seed
219 }
220}
221
222impl<P: MlDsaParams> AsRef<VerifyingKey<P>> for KeyPair<P> {
223 fn as_ref(&self) -> &VerifyingKey<P> {
224 &self.verifying_key
225 }
226}
227
228impl<P: MlDsaParams> fmt::Debug for KeyPair<P> {
229 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230 f.debug_struct("KeyPair")
231 .field("verifying_key", &self.verifying_key)
232 .finish_non_exhaustive()
233 }
234}
235
236impl<P: MlDsaParams> signature::KeypairRef for KeyPair<P> {
237 type VerifyingKey = VerifyingKey<P>;
238}
239
240impl<P: MlDsaParams> Signer<Signature<P>> for KeyPair<P> {
243 fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
244 self.try_multipart_sign(&[msg])
245 }
246}
247
248impl<P: MlDsaParams> MultipartSigner<Signature<P>> for KeyPair<P> {
251 fn try_multipart_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
252 self.signing_key.raw_sign_deterministic(msg, &[])
253 }
254}
255
256impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for KeyPair<P> {
259 fn try_sign_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
260 &self,
261 f: F,
262 ) -> Result<Signature<P>, Error> {
263 self.signing_key.try_sign_digest(&f)
264 }
265}
266
/// An ML-DSA signing (private) key.
///
/// Holds the components `rho`, `K`, `tr`, `s1`, `s2` and `t0`, together with
/// values precomputed from them so that signing need not recompute them.
#[derive(Clone, PartialEq)]
pub struct SigningKey<P: MlDsaParams> {
    rho: B32,
    K: B32,
    tr: B64,
    s1: Vector<P::L>,
    s2: Vector<P::K>,
    t0: Vector<P::K>,

    // Derived values: NTT forms of s1, s2 and t0, plus the expanded matrix
    // A_hat, all cached when the key is constructed.
    s1_hat: NttVector<P::L>,
    s2_hat: NttVector<P::K>,
    t0_hat: NttVector<P::K>,
    A_hat: NttMatrix<P::K, P::L>,
}
283
284impl<P: MlDsaParams> fmt::Debug for SigningKey<P> {
285 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286 f.debug_struct("SigningKey").finish_non_exhaustive()
287 }
288}
289
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> Drop for SigningKey<P> {
    /// Wipe the key's secret material from memory when it is dropped.
    fn drop(&mut self) {
        self.rho.zeroize();
        self.K.zeroize();
        self.tr.zeroize();
        self.s1.zeroize();
        self.s2.zeroize();
        self.t0.zeroize();
        // The NTT-domain copies hold the same secrets as s1/s2/t0 in another
        // representation; leaving them behind would defeat wiping the originals.
        // A_hat is expanded from the public seed rho, so it is not secret.
        self.s1_hat.zeroize();
        self.s2_hat.zeroize();
        self.t0_hat.zeroize();
    }
}

/// Marker: every secret field of [`SigningKey`] is zeroized in its `Drop` impl.
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> ZeroizeOnDrop for SigningKey<P> {}
304
305impl<P: MlDsaParams> SigningKey<P> {
306 fn new(
307 rho: B32,
308 K: B32,
309 tr: B64,
310 s1: Vector<P::L>,
311 s2: Vector<P::K>,
312 t0: Vector<P::K>,
313 A_hat: Option<NttMatrix<P::K, P::L>>,
314 ) -> Self {
315 let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
316 let s1_hat = s1.ntt();
317 let s2_hat = s2.ntt();
318 let t0_hat = t0.ntt();
319
320 Self {
321 rho,
322 K,
323 tr,
324 s1,
325 s2,
326 t0,
327
328 s1_hat,
329 s2_hat,
330 t0_hat,
331 A_hat,
332 }
333 }
334
335 #[must_use]
340 pub fn from_seed(seed: &Seed) -> Self {
341 let kp = P::from_seed(seed);
342 kp.signing_key
343 }
344
345 pub fn sign_internal(&self, Mp: &[&[u8]], rnd: &B32) -> Signature<P>
351 where
352 P: MlDsaParams,
353 {
354 let mu = MuBuilder::internal(&self.tr, Mp);
355 self.raw_sign_mu(&mu, rnd)
356 }
357
358 fn raw_sign_mu(&self, mu: &B64, rnd: &B32) -> Signature<P>
359 where
360 P: MlDsaParams,
361 {
362 let rhopp: B64 = H::default()
364 .absorb(&self.K)
365 .absorb(rnd)
366 .absorb(mu)
367 .squeeze_new();
368
369 for kappa in (0..u16::MAX).step_by(P::L::USIZE) {
371 let y = expand_mask::<P::L, P::Gamma1>(&rhopp, kappa);
372 let w = (&self.A_hat * &y.ntt()).ntt_inverse();
373 let w1 = w.high_bits::<P::TwoGamma2>();
374
375 let w1_tilde = P::encode_w1(&w1);
376 let c_tilde = H::default()
377 .absorb(mu)
378 .absorb(&w1_tilde)
379 .squeeze_new::<P::Lambda>();
380 let c = sample_in_ball(&c_tilde, P::TAU);
381 let c_hat = c.ntt();
382
383 let cs1 = (&c_hat * &self.s1_hat).ntt_inverse();
384 let cs2 = (&c_hat * &self.s2_hat).ntt_inverse();
385
386 let z = &y + &cs1;
387 let r0 = (&w - &cs2).low_bits::<P::TwoGamma2>();
388
389 if z.infinity_norm() >= P::GAMMA1_MINUS_BETA
390 || r0.infinity_norm() >= P::GAMMA2_MINUS_BETA
391 {
392 continue;
393 }
394
395 let ct0 = (&c_hat * &self.t0_hat).ntt_inverse();
396 let minus_ct0 = -&ct0;
397 let w_cs2_ct0 = &(&w - &cs2) + &ct0;
398 let h = Hint::<P>::new(&minus_ct0, &w_cs2_ct0);
399
400 if ct0.infinity_norm() >= P::Gamma2::U32 || h.hamming_weight() > P::Omega::USIZE {
401 continue;
402 }
403
404 let z = z.mod_plus_minus::<SpecQ>();
405 return Signature { c_tilde, z, h };
406 }
407
408 unreachable!("Rejection sampling failed to find a valid signature");
409 }
410
411 #[cfg(feature = "rand_core")]
419 pub fn sign_randomized<R: TryCryptoRng + ?Sized>(
420 &self,
421 M: &[u8],
422 ctx: &[u8],
423 rng: &mut R,
424 ) -> Result<Signature<P>, Error> {
425 self.raw_sign_randomized(&[M], ctx, rng)
426 }
427
428 #[cfg(feature = "rand_core")]
429 fn raw_sign_randomized<R: TryCryptoRng + ?Sized>(
430 &self,
431 Mp: &[&[u8]],
432 ctx: &[u8],
433 rng: &mut R,
434 ) -> Result<Signature<P>, Error> {
435 if ctx.len() > 255 {
436 return Err(Error::new());
437 }
438
439 let mut rnd = B32::default();
440 rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
441
442 let mu = MuBuilder::new(&self.tr, ctx).message(Mp);
443 Ok(self.raw_sign_mu(&mu, &rnd))
444 }
445
446 #[cfg(feature = "rand_core")]
453 pub fn sign_mu_randomized<R: TryCryptoRng + ?Sized>(
454 &self,
455 mu: &B64,
456 rng: &mut R,
457 ) -> Result<Signature<P>, Error> {
458 let mut rnd = B32::default();
459 rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
460
461 Ok(self.raw_sign_mu(mu, &rnd))
462 }
463
464 pub fn sign_deterministic(&self, M: &[u8], ctx: &[u8]) -> Result<Signature<P>, Error> {
471 self.raw_sign_deterministic(&[M], ctx)
472 }
473
474 pub fn sign_mu_deterministic(&self, mu: &B64) -> Signature<P> {
478 let rnd = B32::default();
479 self.raw_sign_mu(mu, &rnd)
480 }
481
482 fn raw_sign_deterministic(&self, Mp: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
483 if ctx.len() > 255 {
484 return Err(Error::new());
485 }
486
487 let mu = MuBuilder::new(&self.tr, ctx).message(Mp);
488 Ok(self.sign_mu_deterministic(&mu))
489 }
490
491 pub fn verifying_key(&self) -> VerifyingKey<P> {
503 let kp: &dyn signature::Keypair<VerifyingKey = VerifyingKey<P>> = self;
504 kp.verifying_key()
505 }
506
507 #[deprecated(since = "0.1.0", note = "use `SigningKey::from_seed` instead")]
521 pub fn from_expanded(enc: &ExpandedSigningKey<P>) -> Self
522 where
523 P: MlDsaParams,
524 {
525 let (rho, K, tr, s1_enc, s2_enc, t0_enc) = P::split_sk(enc);
526 Self::new(
527 rho.clone(),
528 K.clone(),
529 tr.clone(),
530 P::decode_s1(s1_enc),
531 P::decode_s2(s2_enc),
532 P::decode_t0(t0_enc),
533 None,
534 )
535 }
536
537 #[deprecated(since = "0.1.0", note = "use `KeyPair::to_seed` instead")]
542 pub fn to_expanded(&self) -> ExpandedSigningKey<P>
543 where
544 P: MlDsaParams,
545 {
546 let s1_enc = P::encode_s1(&self.s1);
547 let s2_enc = P::encode_s2(&self.s2);
548 let t0_enc = P::encode_t0(&self.t0);
549 P::concat_sk(
550 self.rho.clone(),
551 self.K.clone(),
552 self.tr.clone(),
553 s1_enc,
554 s2_enc,
555 t0_enc,
556 )
557 }
558}
559
560impl<P: MlDsaParams> Signer<Signature<P>> for SigningKey<P> {
564 fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
565 self.try_multipart_sign(&[msg])
566 }
567}
568
569impl<P: MlDsaParams> MultipartSigner<Signature<P>> for SigningKey<P> {
573 fn try_multipart_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
574 self.raw_sign_deterministic(msg, &[])
575 }
576}
577
578impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for SigningKey<P> {
582 fn try_sign_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
583 &self,
584 f: F,
585 ) -> Result<Signature<P>, Error> {
586 let mut mu = MuBuilder::new(&self.tr, &[]);
587 f(mu.as_mut())?;
588 let mu = mu.finish();
589
590 Ok(self.sign_mu_deterministic(&mu))
591 }
592}
593
594impl<P: MlDsaParams> signature::Keypair for SigningKey<P> {
597 type VerifyingKey = VerifyingKey<P>;
598
599 fn verifying_key(&self) -> Self::VerifyingKey {
604 let As1 = &self.A_hat * &self.s1_hat;
605 let t = &As1.ntt_inverse() + &self.s2;
606
607 let (t1, _) = t.power2round();
609
610 VerifyingKey::new(self.rho.clone(), t1, Some(self.A_hat.clone()), None)
611 }
612}
613
#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> RandomizedSigner<Signature<P>> for SigningKey<P> {
    /// Sign `msg` with an empty context string, using randomness from `rng`.
    fn try_sign_with_rng<R: TryCryptoRng + ?Sized>(
        &self,
        rng: &mut R,
        msg: &[u8],
    ) -> Result<Signature<P>, Error> {
        RandomizedMultipartSigner::try_multipart_sign_with_rng(self, rng, &[msg])
    }
}

#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> RandomizedMultipartSigner<Signature<P>> for SigningKey<P> {
    /// Sign the concatenation of `msg` with an empty context string, using
    /// randomness from `rng`.
    fn try_multipart_sign_with_rng<R: TryCryptoRng + ?Sized>(
        &self,
        rng: &mut R,
        msg: &[&[u8]],
    ) -> Result<Signature<P>, Error> {
        const EMPTY_CTX: &[u8] = &[];
        self.raw_sign_randomized(msg, EMPTY_CTX, rng)
    }
}

#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> RandomizedDigestSigner<Shake256, Signature<P>> for SigningKey<P> {
    /// Sign a message streamed into SHAKE256 by `f`, using randomness from `rng`.
    ///
    /// `f` receives the hasher after `tr` and an empty context have been absorbed.
    fn try_sign_digest_with_rng<
        R: TryCryptoRng + ?Sized,
        F: Fn(&mut Shake256) -> Result<(), Error>,
    >(
        &self,
        rng: &mut R,
        f: F,
    ) -> Result<Signature<P>, Error> {
        let mut builder = MuBuilder::new(&self.tr, &[]);
        f(builder.as_mut())?;
        self.sign_mu_randomized(&builder.finish(), rng)
    }
}
662
/// An ML-DSA verifying (public) key.
///
/// Holds `rho` and `t1`, plus values derived from them that are cached to
/// speed up verification.
#[derive(Clone, Debug, PartialEq)]
pub struct VerifyingKey<P: ParameterSet> {
    rho: B32,
    t1: Vector<P::K>,

    // Derived values
    A_hat: NttMatrix<P::K, P::L>,
    // NTT of t1 scaled by 2^13, as computed in the constructor.
    t1_2d_hat: NttVector<P::K>,
    // Hash of the encoded verifying key, absorbed first when computing mu.
    tr: B64,
}
674
675impl<P: MlDsaParams> VerifyingKey<P> {
676 fn new(
677 rho: B32,
678 t1: Vector<P::K>,
679 A_hat: Option<NttMatrix<P::K, P::L>>,
680 enc: Option<EncodedVerifyingKey<P>>,
681 ) -> Self {
682 let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
683 let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1));
684
685 let t1_2d_hat = (Elem::new(1 << 13) * &t1).ntt();
686 let tr: B64 = H::default().absorb(&enc).squeeze_new();
687
688 Self {
689 rho,
690 t1,
691 A_hat,
692 t1_2d_hat,
693 tr,
694 }
695 }
696
697 pub fn compute_mu<F: FnOnce(&mut Shake256) -> Result<(), Error>>(
703 &self,
704 Mp: F,
705 ctx: &[u8],
706 ) -> Result<B64, Error> {
707 let mut mu = MuBuilder::new(&self.tr, ctx);
708 Mp(mu.as_mut())?;
709 Ok(mu.finish())
710 }
711
712 pub fn verify_internal(&self, M: &[u8], sigma: &Signature<P>) -> bool
717 where
718 P: MlDsaParams,
719 {
720 let mu = MuBuilder::internal(&self.tr, &[M]);
721 self.raw_verify_mu(&mu, sigma)
722 }
723
724 fn raw_verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool
725 where
726 P: MlDsaParams,
727 {
728 let c = sample_in_ball(&sigma.c_tilde, P::TAU);
730
731 let z_hat = sigma.z.ntt();
732 let c_hat = c.ntt();
733 let Az_hat = &self.A_hat * &z_hat;
734 let ct1_2d_hat = &c_hat * &self.t1_2d_hat;
735
736 let wp_approx = (&Az_hat - &ct1_2d_hat).ntt_inverse();
737 let w1p = sigma.h.use_hint(&wp_approx);
738
739 let w1p_tilde = P::encode_w1(&w1p);
740 let cp_tilde = H::default()
741 .absorb(mu)
742 .absorb(&w1p_tilde)
743 .squeeze_new::<P::Lambda>();
744
745 sigma.c_tilde == cp_tilde
746 }
747
748 pub fn verify_with_context(&self, M: &[u8], ctx: &[u8], sigma: &Signature<P>) -> bool {
751 self.raw_verify_with_context(&[M], ctx, sigma)
752 }
753
754 pub fn verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool {
757 self.raw_verify_mu(mu, sigma)
758 }
759
760 fn raw_verify_with_context(&self, M: &[&[u8]], ctx: &[u8], sigma: &Signature<P>) -> bool {
761 if ctx.len() > 255 {
762 return false;
763 }
764
765 let mu = MuBuilder::new(&self.tr, ctx).message(M);
766 self.verify_mu(&mu, sigma)
767 }
768
769 fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
770 let t1_enc = P::encode_t1(t1);
771 P::concat_vk(rho.clone(), t1_enc)
772 }
773
774 pub fn encode(&self) -> EncodedVerifyingKey<P> {
777 Self::encode_internal(&self.rho, &self.t1)
778 }
779
780 pub fn decode(enc: &EncodedVerifyingKey<P>) -> Self {
783 let (rho, t1_enc) = P::split_vk(enc);
784 let t1 = P::decode_t1(t1_enc);
785 Self::new(rho.clone(), t1, None, Some(enc.clone()))
786 }
787}
788
789impl<P: MlDsaParams> signature::Verifier<Signature<P>> for VerifyingKey<P> {
790 fn verify(&self, msg: &[u8], signature: &Signature<P>) -> Result<(), Error> {
791 self.multipart_verify(&[msg], signature)
792 }
793}
794
795impl<P: MlDsaParams> MultipartVerifier<Signature<P>> for VerifyingKey<P> {
796 fn multipart_verify(&self, msg: &[&[u8]], signature: &Signature<P>) -> Result<(), Error> {
797 self.raw_verify_with_context(msg, &[], signature)
798 .then_some(())
799 .ok_or(Error::new())
800 }
801}
802
803impl<P: MlDsaParams> DigestVerifier<Shake256, Signature<P>> for VerifyingKey<P> {
804 fn verify_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
805 &self,
806 f: F,
807 signature: &Signature<P>,
808 ) -> Result<(), Error> {
809 let mut mu = MuBuilder::new(&self.tr, &[]);
810 f(mu.as_mut())?;
811 let mu = mu.finish();
812
813 self.raw_verify_mu(&mu, signature)
814 .then_some(())
815 .ok_or(Error::new())
816 }
817}
818
/// The ML-DSA-44 parameter set (k = 4, l = 4).
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MlDsa44;

impl ParameterSet for MlDsa44 {
    type K = U4;
    type L = U4;
    type Eta = U2;
    // gamma1 = 2^17
    type Gamma1 = Shleft<U1, U17>;
    // gamma2 = (q - 1) / 88
    type Gamma2 = Quot<QMinus1, U88>;
    type TwoGamma2 = Prod<U2, Self::Gamma2>;
    // Bits per w1 coefficient: bitlen(88 / 2 - 1) = bitlen(43) = 6
    type W1Bits = Length<Diff<Quot<U88, U2>, U1>>;
    // Length of c_tilde in bytes
    type Lambda = U32;
    // Maximum number of set hint bits
    type Omega = U80;
    // Number of nonzero coefficients in the challenge polynomial
    const TAU: usize = 39;
}

/// The ML-DSA-65 parameter set (k = 6, l = 5).
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MlDsa65;

impl ParameterSet for MlDsa65 {
    type K = U6;
    type L = U5;
    type Eta = U4;
    // gamma1 = 2^19
    type Gamma1 = Shleft<U1, U19>;
    // gamma2 = (q - 1) / 32
    type Gamma2 = Quot<QMinus1, U32>;
    type TwoGamma2 = Prod<U2, Self::Gamma2>;
    // Bits per w1 coefficient: bitlen(32 / 2 - 1) = bitlen(15) = 4
    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
    // Length of c_tilde in bytes
    type Lambda = U48;
    // Maximum number of set hint bits
    type Omega = U55;
    // Number of nonzero coefficients in the challenge polynomial
    const TAU: usize = 49;
}

/// The ML-DSA-87 parameter set (k = 8, l = 7).
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MlDsa87;

impl ParameterSet for MlDsa87 {
    type K = U8;
    type L = U7;
    type Eta = U2;
    // gamma1 = 2^19
    type Gamma1 = Shleft<U1, U19>;
    // gamma2 = (q - 1) / 32
    type Gamma2 = Quot<QMinus1, U32>;
    type TwoGamma2 = Prod<U2, Self::Gamma2>;
    // Bits per w1 coefficient: bitlen(32 / 2 - 1) = bitlen(15) = 4
    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
    // Length of c_tilde in bytes
    type Lambda = U64;
    // Maximum number of set hint bits
    type Omega = U75;
    // Number of nonzero coefficients in the challenge polynomial
    const TAU: usize = 60;
}
869
/// Key generation for an ML-DSA parameter set.
pub trait KeyGen: MlDsaParams {
    /// The key pair type produced by key generation.
    type KeyPair: signature::Keypair;

    /// Generate a fresh key pair, drawing the 32-byte seed from `rng`.
    #[cfg(feature = "rand_core")]
    fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> Self::KeyPair;

    /// Deterministically derive a key pair from the 32-byte seed `xi`.
    fn from_seed(xi: &B32) -> Self::KeyPair;
}
884
885impl<P> KeyGen for P
886where
887 P: MlDsaParams,
888{
889 type KeyPair = KeyPair<P>;
890
891 #[cfg(feature = "rand_core")]
894 fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> KeyPair<P> {
895 let mut xi = B32::default();
896 rng.fill_bytes(&mut xi);
897 Self::from_seed(&xi)
898 }
899
900 fn from_seed(xi: &Seed) -> KeyPair<P>
905 where
906 P: MlDsaParams,
907 {
908 let mut h = H::default()
910 .absorb(xi)
911 .absorb(&[P::K::U8])
912 .absorb(&[P::L::U8]);
913
914 let rho: B32 = h.squeeze_new();
915 let rhop: B64 = h.squeeze_new();
916 let K: B32 = h.squeeze_new();
917
918 let A_hat = expand_a::<P::K, P::L>(&rho);
920 let s1 = expand_s::<P::L>(&rhop, P::Eta::ETA, 0);
921 let s2 = expand_s::<P::K>(&rhop, P::Eta::ETA, P::L::USIZE);
922
923 let As1_hat = &A_hat * &s1.ntt();
925 let t = &As1_hat.ntt_inverse() + &s2;
926
927 let (t1, t0) = t.power2round();
929
930 let verifying_key = VerifyingKey::new(rho, t1, Some(A_hat.clone()), None);
931 let signing_key =
932 SigningKey::new(rho, K, verifying_key.tr.clone(), s1, s2, t0, Some(A_hat));
933
934 KeyPair {
935 signing_key,
936 verifying_key,
937 seed: xi.clone(),
938 }
939 }
940}
941
#[cfg(test)]
mod test {
    use super::*;
    use crate::param::*;

    #[test]
    fn output_sizes() {
        assert_eq!(SigningKeySize::<MlDsa44>::USIZE, 2560);
        assert_eq!(VerifyingKeySize::<MlDsa44>::USIZE, 1312);
        assert_eq!(SignatureSize::<MlDsa44>::USIZE, 2420);

        assert_eq!(SigningKeySize::<MlDsa65>::USIZE, 4032);
        assert_eq!(VerifyingKeySize::<MlDsa65>::USIZE, 1952);
        assert_eq!(SignatureSize::<MlDsa65>::USIZE, 3309);

        assert_eq!(SigningKeySize::<MlDsa87>::USIZE, 4896);
        assert_eq!(VerifyingKeySize::<MlDsa87>::USIZE, 2592);
        assert_eq!(SignatureSize::<MlDsa87>::USIZE, 4627);
    }

    fn encode_decode_round_trip_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let seed = Array::default();
        let kp = P::from_seed(&seed);
        assert_eq!(kp.to_seed(), seed);

        let sk = kp.signing_key;
        let vk = kp.verifying_key;

        let vk_bytes = vk.encode();
        let vk2 = VerifyingKey::<P>::decode(&vk_bytes);
        assert!(vk == vk2);

        // The expanded signing-key encoding is deprecated but must still round-trip.
        #[allow(deprecated)]
        {
            let sk_bytes = sk.to_expanded();
            let sk2 = SigningKey::<P>::from_expanded(&sk_bytes);
            assert!(sk == sk2);
        }

        let M = b"Hello world";
        let rnd = Array([0u8; 32]);
        let sig = sk.sign_internal(&[M], &rnd);
        let sig_bytes = sig.encode();
        let sig2 = Signature::<P>::decode(&sig_bytes).unwrap();
        assert!(sig == sig2);
    }

    #[test]
    fn encode_decode_round_trip() {
        encode_decode_round_trip_test::<MlDsa44>();
        encode_decode_round_trip_test::<MlDsa65>();
        encode_decode_round_trip_test::<MlDsa87>();
    }

    fn public_from_private_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let kp = P::from_seed(&Array::default());
        let sk = kp.signing_key;
        let vk = kp.verifying_key;
        let vk_derived = sk.verifying_key();

        assert!(vk == vk_derived);
    }

    #[test]
    fn public_from_private() {
        public_from_private_test::<MlDsa44>();
        public_from_private_test::<MlDsa65>();
        public_from_private_test::<MlDsa87>();
    }

    fn sign_verify_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        let kp = P::from_seed(&Array::default());
        let sk = kp.signing_key;
        let vk = kp.verifying_key;

        let M = b"Hello world";
        let rnd = Array([0u8; 32]);
        let sig = sk.sign_internal(&[M], &rnd);

        assert!(vk.verify_internal(M, &sig));
    }

    #[test]
    fn sign_verify_round_trip() {
        sign_verify_round_trip_test::<MlDsa44>();
        sign_verify_round_trip_test::<MlDsa65>();
        sign_verify_round_trip_test::<MlDsa87>();
    }

    #[test]
    fn sign_mu_verify_mu_round_trip() {
        fn sign_mu_verify_mu<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = MuBuilder::internal(&sk.tr, &[M]);
            let sig = sk.raw_sign_mu(&mu, &rnd);

            assert!(vk.raw_verify_mu(&mu, &sig));
        }
        sign_mu_verify_mu::<MlDsa44>();
        sign_mu_verify_mu::<MlDsa65>();
        sign_mu_verify_mu::<MlDsa87>();
    }

    #[test]
    fn sign_mu_verify_internal_round_trip() {
        fn sign_mu_verify_internal<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = MuBuilder::internal(&sk.tr, &[M]);
            let sig = sk.raw_sign_mu(&mu, &rnd);

            assert!(vk.verify_internal(M, &sig));
        }
        sign_mu_verify_internal::<MlDsa44>();
        sign_mu_verify_internal::<MlDsa65>();
        sign_mu_verify_internal::<MlDsa87>();
    }

    #[test]
    fn sign_internal_verify_mu_round_trip() {
        fn sign_internal_verify_mu<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = MuBuilder::internal(&sk.tr, &[M]);
            let sig = sk.sign_internal(&[M], &rnd);

            assert!(vk.raw_verify_mu(&mu, &sig));
        }
        sign_internal_verify_mu::<MlDsa44>();
        sign_internal_verify_mu::<MlDsa65>();
        sign_internal_verify_mu::<MlDsa87>();
    }

    #[test]
    fn from_seed_implementations_match() {
        fn assert_from_seed_equality<P>()
        where
            P: MlDsaParams,
        {
            let seed = Seed::default();
            let kp1 = P::from_seed(&seed);
            let sk1 = SigningKey::<P>::from_seed(&seed);
            let vk1 = sk1.verifying_key();
            assert_eq!(kp1.signing_key, sk1);
            assert_eq!(kp1.verifying_key, vk1);
        }
        assert_from_seed_equality::<MlDsa44>();
        assert_from_seed_equality::<MlDsa65>();
        assert_from_seed_equality::<MlDsa87>();
    }

    #[test]
    fn to_seed_returns_correct_seed() {
        fn test_to_seed<P: MlDsaParams>() {
            let seed = Array([
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                24, 25, 26, 27, 28, 29, 30, 31, 32,
            ]);
            let kp = P::from_seed(&seed);
            assert_eq!(kp.to_seed(), seed);
        }
        test_to_seed::<MlDsa44>();
        test_to_seed::<MlDsa65>();
        test_to_seed::<MlDsa87>();
    }

    #[test]
    fn verification_rejects_invalid_signature() {
        fn test_invalid_sig<P: MlDsaParams>() {
            let kp = P::from_seed(&Array::default());
            let vk = kp.verifying_key();

            let msg = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mut sig = kp.signing_key().sign_internal(&[msg], &rnd);

            // Positive control: verify through the same (internal) interface used
            // for signing, so the rejection below can only come from tampering.
            assert!(vk.verify_internal(msg, &sig));

            sig.c_tilde[0] ^= 0xFF;
            assert!(!vk.verify_internal(msg, &sig));
        }
        test_invalid_sig::<MlDsa44>();
        test_invalid_sig::<MlDsa65>();
        test_invalid_sig::<MlDsa87>();
    }

    #[test]
    fn verification_rejects_wrong_message() {
        fn test_wrong_msg<P: MlDsaParams>() {
            let kp = P::from_seed(&Array::default());
            let vk = kp.verifying_key();

            let msg1 = b"Hello world";
            let msg2 = b"Wrong message";
            let rnd = Array([0u8; 32]);
            let sig = kp.signing_key().sign_internal(&[msg1], &rnd);

            // Positive control on the signed message, then the actual check that
            // only the message differs when verification fails.
            assert!(vk.verify_internal(msg1, &sig));
            assert!(!vk.verify_internal(msg2, &sig));
        }
        test_wrong_msg::<MlDsa44>();
        test_wrong_msg::<MlDsa65>();
        test_wrong_msg::<MlDsa87>();
    }

    #[test]
    fn context_length_validation() {
        fn test_ctx_length<P: MlDsaParams>() {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key();
            let vk = kp.verifying_key();

            let msg = b"Hello world";
            let long_ctx = [0u8; 256];
            let short_ctx = [0u8; 255];

            assert!(sk.sign_deterministic(msg, &long_ctx).is_err());

            let sig = sk.sign_deterministic(msg, &short_ctx).unwrap();
            assert!(!vk.verify_with_context(msg, &long_ctx, &sig));
            assert!(vk.verify_with_context(msg, &short_ctx, &sig));
        }
        test_ctx_length::<MlDsa44>();
        test_ctx_length::<MlDsa65>();
        test_ctx_length::<MlDsa87>();
    }

    #[test]
    fn derived_verifying_key_validates_signatures() {
        fn test_derived_vk<P: MlDsaParams>() {
            let seed = Array([42u8; 32]);
            let kp = P::from_seed(&seed);
            let sk = kp.signing_key();
            let derived_vk = sk.verifying_key();

            let msg = b"Test message for derived key";
            let rnd = Array([0u8; 32]);
            let sig = sk.sign_internal(&[msg], &rnd);

            assert!(derived_vk.verify_internal(msg, &sig));
            assert_eq!(derived_vk.encode(), kp.verifying_key().encode());
        }
        test_derived_vk::<MlDsa44>();
        test_derived_vk::<MlDsa65>();
        test_derived_vk::<MlDsa87>();
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn debug_implementations() {
        extern crate alloc;
        use core::fmt::Write;

        fn test_debug<P: MlDsaParams>() {
            let kp = P::from_seed(&Array::default());

            let mut kp_debug = alloc::string::String::new();
            write!(&mut kp_debug, "{:?}", kp).unwrap();
            assert!(kp_debug.contains("KeyPair"));

            let mut sk_debug = alloc::string::String::new();
            write!(&mut sk_debug, "{:?}", kp.signing_key()).unwrap();
            assert!(sk_debug.contains("SigningKey"));
        }
        test_debug::<MlDsa44>();
        test_debug::<MlDsa65>();
        test_debug::<MlDsa87>();
    }
}
1242}