1#![no_std]
2#![doc = include_str!("../README.md")]
3#![doc(
4 html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
5 html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
6)]
7#![cfg_attr(docsrs, feature(doc_auto_cfg))]
8#![warn(clippy::pedantic)] #![warn(clippy::integer_division_remainder_used)] #![warn(clippy::as_conversions)] #![allow(non_snake_case)] #![allow(clippy::similar_names)] #![allow(clippy::many_single_char_names)] #![allow(clippy::clone_on_copy)] #![deny(missing_docs)] mod algebra;
35mod crypto;
36mod encode;
37mod hint;
38mod ntt;
39mod param;
40mod sampling;
41mod util;
42
43mod module_lattice;
45
46use core::convert::{AsRef, TryFrom, TryInto};
47use hybrid_array::{
48 Array,
49 typenum::{
50 Diff, Length, Prod, Quot, Shleft, U1, U2, U4, U5, U6, U7, U8, U17, U19, U32, U48, U55, U64,
51 U75, U80, U88, Unsigned,
52 },
53};
54use signature::digest::Update;
55use signature::{DigestSigner, DigestVerifier, MultipartSigner, MultipartVerifier, Signer};
56
57#[cfg(feature = "rand_core")]
58use signature::RandomizedDigestSigner;
59
60#[cfg(feature = "rand_core")]
61use rand_core::{CryptoRng, TryCryptoRng};
62
63use sha3::Shake256;
64#[cfg(feature = "zeroize")]
65use zeroize::{Zeroize, ZeroizeOnDrop};
66
67#[cfg(feature = "pkcs8")]
68use {
69 const_oid::db::fips204,
70 pkcs8::{
71 AlgorithmIdentifierRef, PrivateKeyInfoRef,
72 der::{self, AnyRef},
73 spki::{
74 self, AlgorithmIdentifier, AssociatedAlgorithmIdentifier, SignatureAlgorithmIdentifier,
75 SubjectPublicKeyInfoRef,
76 },
77 },
78};
79
80#[cfg(all(feature = "alloc", feature = "pkcs8"))]
81use pkcs8::{
82 EncodePrivateKey, EncodePublicKey,
83 der::asn1::{BitString, BitStringRef, OctetStringRef},
84 spki::{SignatureBitStringEncoding, SubjectPublicKeyInfo},
85};
86
87use crate::algebra::{AlgebraExt, Elem, NttMatrix, NttVector, Truncate, Vector};
88use crate::crypto::H;
89use crate::hint::Hint;
90use crate::ntt::{Ntt, NttInverse};
91use crate::param::{ParameterSet, QMinus1, SamplingSize, SpecQ};
92use crate::sampling::{expand_a, expand_mask, expand_s, sample_in_ball};
93use crate::util::B64;
94use core::fmt;
95
96pub use crate::param::{EncodedSignature, EncodedSigningKey, EncodedVerifyingKey, MlDsaParams};
97pub use crate::util::B32;
98pub use signature::{self, Error};
99
100#[derive(Clone, PartialEq, Debug)]
102pub struct Signature<P: MlDsaParams> {
103 c_tilde: Array<u8, P::Lambda>,
104 z: Vector<P::L>,
105 h: Hint<P>,
106}
107
108impl<P: MlDsaParams> Signature<P> {
109 pub fn encode(&self) -> EncodedSignature<P> {
112 let c_tilde = self.c_tilde.clone();
113 let z = P::encode_z(&self.z);
114 let h = self.h.bit_pack();
115 P::concat_sig(c_tilde, z, h)
116 }
117
118 pub fn decode(enc: &EncodedSignature<P>) -> Option<Self> {
121 let (c_tilde, z, h) = P::split_sig(enc);
122
123 let c_tilde = c_tilde.clone();
124 let z = P::decode_z(z);
125 let h = Hint::bit_unpack(h)?;
126
127 if z.infinity_norm() >= P::GAMMA1_MINUS_BETA {
128 return None;
129 }
130
131 Some(Self { c_tilde, z, h })
132 }
133}
134
135impl<'a, P: MlDsaParams> TryFrom<&'a [u8]> for Signature<P> {
136 type Error = Error;
137
138 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
139 let enc = EncodedSignature::<P>::try_from(value).map_err(|_| Error::new())?;
140 Self::decode(&enc).ok_or(Error::new())
141 }
142}
143
144impl<P: MlDsaParams> TryInto<EncodedSignature<P>> for Signature<P> {
145 type Error = Error;
146
147 fn try_into(self) -> Result<EncodedSignature<P>, Self::Error> {
148 Ok(self.encode())
149 }
150}
151
152impl<P: MlDsaParams> signature::SignatureEncoding for Signature<P> {
153 type Repr = EncodedSignature<P>;
154}
155
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P: MlDsaParams> SignatureBitStringEncoding for Signature<P> {
    fn to_bitstring(&self) -> der::Result<BitString> {
        // The signature encoding is a whole number of bytes, so zero bits are unused.
        let encoded = self.encode();
        BitString::new(0, encoded.to_vec())
    }
}

#[cfg(feature = "pkcs8")]
impl<P> AssociatedAlgorithmIdentifier for Signature<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    // A signature is identified by the same OID as its parameter set.
    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = P::ALGORITHM_IDENTIFIER;
}
173
174fn message_representative(tr: &[u8], Mp: &[&[&[u8]]]) -> B64 {
178 let mut h = H::default().absorb(tr);
179
180 for m in Mp.iter().copied().flatten() {
181 h = h.absorb(m);
182 }
183
184 h.squeeze_new()
185}
186
187pub struct KeyPair<P: MlDsaParams> {
189 signing_key: SigningKey<P>,
191
192 verifying_key: VerifyingKey<P>,
194
195 #[cfg(feature = "pkcs8")]
197 seed: B32,
198}
199
200impl<P: MlDsaParams> KeyPair<P> {
201 pub fn signing_key(&self) -> &SigningKey<P> {
203 &self.signing_key
204 }
205
206 pub fn verifying_key(&self) -> &VerifyingKey<P> {
208 &self.verifying_key
209 }
210}
211
212impl<P: MlDsaParams> AsRef<VerifyingKey<P>> for KeyPair<P> {
213 fn as_ref(&self) -> &VerifyingKey<P> {
214 &self.verifying_key
215 }
216}
217
218impl<P: MlDsaParams> fmt::Debug for KeyPair<P> {
219 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
220 f.debug_struct("KeyPair")
221 .field("verifying_key", &self.verifying_key)
222 .finish_non_exhaustive()
223 }
224}
225
226impl<P: MlDsaParams> signature::KeypairRef for KeyPair<P> {
227 type VerifyingKey = VerifyingKey<P>;
228}
229
#[cfg(feature = "pkcs8")]
impl<P> TryFrom<PrivateKeyInfoRef<'_>> for KeyPair<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Error = pkcs8::Error;

    /// Rebuilds a key pair from a PKCS#8 document whose private key is the 32-byte seed.
    fn try_from(private_key_info: pkcs8::PrivateKeyInfoRef<'_>) -> pkcs8::Result<Self> {
        let algorithm = private_key_info.algorithm;
        if algorithm != P::ALGORITHM_IDENTIFIER {
            return Err(spki::Error::OidUnknown { oid: algorithm.oid }.into());
        }

        let seed_bytes = private_key_info.private_key.as_bytes();
        let seed = Array::try_from(seed_bytes).map_err(|_| pkcs8::Error::KeyMalformed)?;
        Ok(P::from_seed(&seed))
    }
}
249
250impl<P: MlDsaParams> Signer<Signature<P>> for KeyPair<P> {
253 fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
254 self.try_multipart_sign(&[msg])
255 }
256}
257
258impl<P: MlDsaParams> MultipartSigner<Signature<P>> for KeyPair<P> {
261 fn try_multipart_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
262 self.signing_key.raw_sign_deterministic(msg, &[])
263 }
264}
265
266impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for KeyPair<P> {
269 fn try_sign_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
270 &self,
271 f: F,
272 ) -> Result<Signature<P>, Error> {
273 self.signing_key.try_sign_digest(&f)
274 }
275}
276
#[cfg(feature = "pkcs8")]
impl<P> SignatureAlgorithmIdentifier for KeyPair<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    // Same identifier that `Signature<P>` advertises, i.e. the parameter set's OID.
    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
        P::ALGORITHM_IDENTIFIER;
}

#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P> EncodePrivateKey for KeyPair<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    /// Serializes the key pair as PKCS#8 DER, storing the 32-byte seed as the private key.
    fn to_pkcs8_der(&self) -> pkcs8::Result<der::SecretDocument> {
        let private_key = OctetStringRef::new(&self.seed)?;
        let info = pkcs8::PrivateKeyInfoRef::new(P::ALGORITHM_IDENTIFIER, private_key);
        let document = der::SecretDocument::encode_msg(&info)?;
        Ok(document)
    }
}
303
304#[derive(Clone, PartialEq)]
306pub struct SigningKey<P: MlDsaParams> {
307 rho: B32,
308 K: B32,
309 tr: B64,
310 s1: Vector<P::L>,
311 s2: Vector<P::K>,
312 t0: Vector<P::K>,
313
314 s1_hat: NttVector<P::L>,
316 s2_hat: NttVector<P::K>,
317 t0_hat: NttVector<P::K>,
318 A_hat: NttMatrix<P::K, P::L>,
319}
320
321impl<P: MlDsaParams> fmt::Debug for SigningKey<P> {
322 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
323 f.debug_struct("SigningKey").finish_non_exhaustive()
324 }
325}
326
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> Drop for SigningKey<P> {
    /// Wipes all secret key material when the key goes out of scope.
    fn drop(&mut self) {
        // Components of the FIPS 204 secret-key encoding.
        self.rho.zeroize();
        self.K.zeroize();
        self.tr.zeroize();
        self.s1.zeroize();
        self.s2.zeroize();
        self.t0.zeroize();

        // The cached NTT-domain vectors are the same secrets in another
        // representation; leaving them behind would defeat the wipe above.
        // `A_hat` is left alone: it is expanded solely from `rho`, which is
        // also carried by the verifying key.
        self.s1_hat.zeroize();
        self.s2_hat.zeroize();
        self.t0_hat.zeroize();
    }
}

#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> ZeroizeOnDrop for SigningKey<P> {}
341
342impl<P: MlDsaParams> SigningKey<P> {
343 fn new(
344 rho: B32,
345 K: B32,
346 tr: B64,
347 s1: Vector<P::L>,
348 s2: Vector<P::K>,
349 t0: Vector<P::K>,
350 A_hat: Option<NttMatrix<P::K, P::L>>,
351 ) -> Self {
352 let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
353 let s1_hat = s1.ntt();
354 let s2_hat = s2.ntt();
355 let t0_hat = t0.ntt();
356
357 Self {
358 rho,
359 K,
360 tr,
361 s1,
362 s2,
363 t0,
364
365 s1_hat,
366 s2_hat,
367 t0_hat,
368 A_hat,
369 }
370 }
371
372 #[must_use]
377 pub fn from_seed(seed: &B32) -> Self {
378 let kp = P::from_seed(seed);
379 kp.signing_key
380 }
381
382 pub fn sign_internal(&self, Mp: &[&[u8]], rnd: &B32) -> Signature<P>
388 where
389 P: MlDsaParams,
390 {
391 self.raw_sign_internal(&[Mp], rnd)
392 }
393
394 fn raw_sign_internal(&self, Mp: &[&[&[u8]]], rnd: &B32) -> Signature<P>
395 where
396 P: MlDsaParams,
397 {
398 let mu = message_representative(&self.tr, Mp);
403 self.raw_sign_mu(&mu, rnd)
404 }
405
406 fn raw_sign_mu(&self, mu: &B64, rnd: &B32) -> Signature<P>
407 where
408 P: MlDsaParams,
409 {
410 let rhopp: B64 = H::default()
412 .absorb(&self.K)
413 .absorb(rnd)
414 .absorb(mu)
415 .squeeze_new();
416
417 for kappa in (0..u16::MAX).step_by(P::L::USIZE) {
419 let y = expand_mask::<P::L, P::Gamma1>(&rhopp, kappa);
420 let w = (&self.A_hat * &y.ntt()).ntt_inverse();
421 let w1 = w.high_bits::<P::TwoGamma2>();
422
423 let w1_tilde = P::encode_w1(&w1);
424 let c_tilde = H::default()
425 .absorb(mu)
426 .absorb(&w1_tilde)
427 .squeeze_new::<P::Lambda>();
428 let c = sample_in_ball(&c_tilde, P::TAU);
429 let c_hat = c.ntt();
430
431 let cs1 = (&c_hat * &self.s1_hat).ntt_inverse();
432 let cs2 = (&c_hat * &self.s2_hat).ntt_inverse();
433
434 let z = &y + &cs1;
435 let r0 = (&w - &cs2).low_bits::<P::TwoGamma2>();
436
437 if z.infinity_norm() >= P::GAMMA1_MINUS_BETA
438 || r0.infinity_norm() >= P::GAMMA2_MINUS_BETA
439 {
440 continue;
441 }
442
443 let ct0 = (&c_hat * &self.t0_hat).ntt_inverse();
444 let minus_ct0 = -&ct0;
445 let w_cs2_ct0 = &(&w - &cs2) + &ct0;
446 let h = Hint::<P>::new(&minus_ct0, &w_cs2_ct0);
447
448 if ct0.infinity_norm() >= P::Gamma2::U32 || h.hamming_weight() > P::Omega::USIZE {
449 continue;
450 }
451
452 let z = z.mod_plus_minus::<SpecQ>();
453 return Signature { c_tilde, z, h };
454 }
455
456 unreachable!("Rejection sampling failed to find a valid signature");
457 }
458
459 #[cfg(feature = "rand_core")]
467 pub fn sign_randomized<R: TryCryptoRng + ?Sized>(
468 &self,
469 M: &[u8],
470 ctx: &[u8],
471 rng: &mut R,
472 ) -> Result<Signature<P>, Error> {
473 if ctx.len() > 255 {
474 return Err(Error::new());
475 }
476
477 let mut rnd = B32::default();
478 rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
479
480 let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
481 Ok(self.sign_internal(Mp, &rnd))
482 }
483
484 #[cfg(feature = "rand_core")]
491 pub fn sign_mu_randomized<R: TryCryptoRng + ?Sized>(
492 &self,
493 mu: &B64,
494 rng: &mut R,
495 ) -> Result<Signature<P>, Error> {
496 let mut rnd = B32::default();
497 rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
498
499 Ok(self.raw_sign_mu(mu, &rnd))
500 }
501
502 pub fn sign_deterministic(&self, M: &[u8], ctx: &[u8]) -> Result<Signature<P>, Error> {
509 self.raw_sign_deterministic(&[M], ctx)
510 }
511
512 pub fn sign_mu_deterministic(&self, mu: &B64) -> Signature<P> {
516 let rnd = B32::default();
517 self.raw_sign_mu(mu, &rnd)
518 }
519
520 fn raw_sign_deterministic(&self, M: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
521 if ctx.len() > 255 {
522 return Err(Error::new());
523 }
524
525 let rnd = B32::default();
526 let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M];
527 Ok(self.raw_sign_internal(Mp, &rnd))
528 }
529
530 pub fn encode(&self) -> EncodedSigningKey<P>
533 where
534 P: MlDsaParams,
535 {
536 let s1_enc = P::encode_s1(&self.s1);
537 let s2_enc = P::encode_s2(&self.s2);
538 let t0_enc = P::encode_t0(&self.t0);
539 P::concat_sk(
540 self.rho.clone(),
541 self.K.clone(),
542 self.tr.clone(),
543 s1_enc,
544 s2_enc,
545 t0_enc,
546 )
547 }
548
549 pub fn decode(enc: &EncodedSigningKey<P>) -> Self
552 where
553 P: MlDsaParams,
554 {
555 let (rho, K, tr, s1_enc, s2_enc, t0_enc) = P::split_sk(enc);
556 Self::new(
557 rho.clone(),
558 K.clone(),
559 tr.clone(),
560 P::decode_s1(s1_enc),
561 P::decode_s2(s2_enc),
562 P::decode_t0(t0_enc),
563 None,
564 )
565 }
566
567 pub fn verifying_key(&self) -> VerifyingKey<P> {
579 let kp: &dyn signature::Keypair<VerifyingKey = VerifyingKey<P>> = self;
580
581 kp.verifying_key()
582 }
583}
584
585impl<P: MlDsaParams> Signer<Signature<P>> for SigningKey<P> {
589 fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
590 self.try_multipart_sign(&[msg])
591 }
592}
593
594impl<P: MlDsaParams> MultipartSigner<Signature<P>> for SigningKey<P> {
598 fn try_multipart_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
599 self.raw_sign_deterministic(msg, &[])
600 }
601}
602
603impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for SigningKey<P> {
607 fn try_sign_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
608 &self,
609 f: F,
610 ) -> Result<Signature<P>, Error> {
611 let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
612 f(&mut digest)?;
613 let mu = H::pre_digest(digest).squeeze_new();
614
615 Ok(self.sign_mu_deterministic(&mu))
616 }
617}
618
619impl<P: MlDsaParams> signature::Keypair for SigningKey<P> {
622 type VerifyingKey = VerifyingKey<P>;
623
624 fn verifying_key(&self) -> Self::VerifyingKey {
629 let As1 = &self.A_hat * &self.s1_hat;
630 let t = &As1.ntt_inverse() + &self.s2;
631
632 let (t1, _) = t.power2round();
634
635 VerifyingKey::new(self.rho.clone(), t1, Some(self.A_hat.clone()), None)
636 }
637}
638
#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> signature::RandomizedSigner<Signature<P>> for SigningKey<P> {
    /// Signs `msg` with an empty context string and fresh randomness from `rng`.
    fn try_sign_with_rng<R: TryCryptoRng + ?Sized>(
        &self,
        rng: &mut R,
        msg: &[u8],
    ) -> Result<Signature<P>, Error> {
        let empty_ctx: &[u8] = &[];
        SigningKey::sign_randomized(self, msg, empty_ctx, rng)
    }
}

#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> RandomizedDigestSigner<Shake256, Signature<P>> for SigningKey<P> {
    /// Signs a message streamed into SHAKE256 by `f`, using fresh randomness from `rng`.
    fn try_sign_digest_with_rng<
        R: TryCryptoRng + ?Sized,
        F: Fn(&mut Shake256) -> Result<(), Error>,
    >(
        &self,
        rng: &mut R,
        f: F,
    ) -> Result<Signature<P>, Error> {
        // Same empty-context prefix as the deterministic digest signer.
        let mut hasher = Shake256::default();
        hasher.update(&self.tr);
        hasher.update(&[0, 0]);
        f(&mut hasher)?;

        let mu = H::pre_digest(hasher).squeeze_new();
        self.sign_mu_randomized(&mu, rng)
    }
}
673
#[cfg(feature = "pkcs8")]
impl<P> SignatureAlgorithmIdentifier for SigningKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    // Same identifier that `Signature<P>` advertises, i.e. the parameter set's OID.
    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
        P::ALGORITHM_IDENTIFIER;
}

#[cfg(feature = "pkcs8")]
impl<P> TryFrom<PrivateKeyInfoRef<'_>> for SigningKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Error = pkcs8::Error;

    /// Parses the full key pair and keeps only its signing half.
    fn try_from(private_key_info: pkcs8::PrivateKeyInfoRef<'_>) -> pkcs8::Result<Self> {
        KeyPair::<P>::try_from(private_key_info).map(|keypair| keypair.signing_key)
    }
}
700
701#[derive(Clone, Debug, PartialEq)]
703pub struct VerifyingKey<P: ParameterSet> {
704 rho: B32,
705 t1: Vector<P::K>,
706
707 A_hat: NttMatrix<P::K, P::L>,
709 t1_2d_hat: NttVector<P::K>,
710 tr: B64,
711}
712
713impl<P: MlDsaParams> VerifyingKey<P> {
714 fn new(
715 rho: B32,
716 t1: Vector<P::K>,
717 A_hat: Option<NttMatrix<P::K, P::L>>,
718 enc: Option<EncodedVerifyingKey<P>>,
719 ) -> Self {
720 let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
721 let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1));
722
723 let t1_2d_hat = (Elem::new(1 << 13) * &t1).ntt();
724 let tr: B64 = H::default().absorb(&enc).squeeze_new();
725
726 Self {
727 rho,
728 t1,
729 A_hat,
730 t1_2d_hat,
731 tr,
732 }
733 }
734
735 pub fn verify_internal(&self, Mp: &[&[u8]], sigma: &Signature<P>) -> bool
740 where
741 P: MlDsaParams,
742 {
743 self.raw_verify_internal(&[Mp], sigma)
744 }
745
746 fn raw_verify_internal(&self, Mp: &[&[&[u8]]], sigma: &Signature<P>) -> bool
747 where
748 P: MlDsaParams,
749 {
750 let mu = message_representative(&self.tr, Mp);
752 self.raw_verify_mu(&mu, sigma)
753 }
754
755 fn raw_verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool
756 where
757 P: MlDsaParams,
758 {
759 let c = sample_in_ball(&sigma.c_tilde, P::TAU);
761
762 let z_hat = sigma.z.ntt();
763 let c_hat = c.ntt();
764 let Az_hat = &self.A_hat * &z_hat;
765 let ct1_2d_hat = &c_hat * &self.t1_2d_hat;
766
767 let wp_approx = (&Az_hat - &ct1_2d_hat).ntt_inverse();
768 let w1p = sigma.h.use_hint(&wp_approx);
769
770 let w1p_tilde = P::encode_w1(&w1p);
771 let cp_tilde = H::default()
772 .absorb(mu)
773 .absorb(&w1p_tilde)
774 .squeeze_new::<P::Lambda>();
775
776 sigma.c_tilde == cp_tilde
777 }
778
779 pub fn verify_with_context(&self, M: &[u8], ctx: &[u8], sigma: &Signature<P>) -> bool {
782 self.raw_verify_with_context(&[M], ctx, sigma)
783 }
784
785 pub fn verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool {
788 self.raw_verify_mu(mu, sigma)
789 }
790
791 fn raw_verify_with_context(&self, M: &[&[u8]], ctx: &[u8], sigma: &Signature<P>) -> bool {
792 if ctx.len() > 255 {
793 return false;
794 }
795
796 let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M];
797 self.raw_verify_internal(Mp, sigma)
798 }
799
800 fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
801 let t1_enc = P::encode_t1(t1);
802 P::concat_vk(rho.clone(), t1_enc)
803 }
804
805 pub fn encode(&self) -> EncodedVerifyingKey<P> {
808 Self::encode_internal(&self.rho, &self.t1)
809 }
810
811 pub fn decode(enc: &EncodedVerifyingKey<P>) -> Self {
814 let (rho, t1_enc) = P::split_vk(enc);
815 let t1 = P::decode_t1(t1_enc);
816 Self::new(rho.clone(), t1, None, Some(enc.clone()))
817 }
818}
819
820impl<P: MlDsaParams> signature::Verifier<Signature<P>> for VerifyingKey<P> {
821 fn verify(&self, msg: &[u8], signature: &Signature<P>) -> Result<(), Error> {
822 self.multipart_verify(&[msg], signature)
823 }
824}
825
826impl<P: MlDsaParams> MultipartVerifier<Signature<P>> for VerifyingKey<P> {
827 fn multipart_verify(&self, msg: &[&[u8]], signature: &Signature<P>) -> Result<(), Error> {
828 self.raw_verify_with_context(msg, &[], signature)
829 .then_some(())
830 .ok_or(Error::new())
831 }
832}
833
834impl<P: MlDsaParams> DigestVerifier<Shake256, Signature<P>> for VerifyingKey<P> {
835 fn verify_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
836 &self,
837 f: F,
838 signature: &Signature<P>,
839 ) -> Result<(), Error> {
840 let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
841 f(&mut digest)?;
842 let mu = H::pre_digest(digest).squeeze_new();
843
844 self.raw_verify_mu(&mu, signature)
845 .then_some(())
846 .ok_or(Error::new())
847 }
848}
849
#[cfg(feature = "pkcs8")]
impl<P> SignatureAlgorithmIdentifier for VerifyingKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    // Same identifier that `Signature<P>` advertises, i.e. the parameter set's OID.
    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
        P::ALGORITHM_IDENTIFIER;
}

#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P> EncodePublicKey for VerifyingKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    /// Serializes the key as a DER `SubjectPublicKeyInfo` holding the raw encoding.
    fn to_public_key_der(&self) -> spki::Result<der::Document> {
        let encoded = self.encode();

        SubjectPublicKeyInfo {
            algorithm: P::ALGORITHM_IDENTIFIER,
            subject_public_key: BitStringRef::new(0, &encoded)?,
        }
        .try_into()
    }
}
879
#[cfg(feature = "pkcs8")]
impl<P> TryFrom<SubjectPublicKeyInfoRef<'_>> for VerifyingKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Error = spki::Error;

    /// Parses a verifying key from an X.509 `SubjectPublicKeyInfo`.
    ///
    /// The algorithm identifier must equal this parameter set's identifier, and the key
    /// must be exactly the expected encoded length.
    fn try_from(spki: SubjectPublicKeyInfoRef<'_>) -> spki::Result<Self> {
        if spki.algorithm != P::ALGORITHM_IDENTIFIER {
            return Err(spki::Error::OidUnknown {
                oid: spki.algorithm.oid,
            });
        }

        // `as_bytes` yields `None` for a BIT STRING with unused bits.
        let key_bytes = spki
            .subject_public_key
            .as_bytes()
            .ok_or_else(|| der::Tag::BitString.value_error().to_error())?;

        // Report a wrong-length key with this impl's own error type instead of
        // building a `pkcs8::Error` and relying on a cross-crate conversion.
        let encoded = EncodedVerifyingKey::<P>::try_from(key_bytes)
            .map_err(|_| spki::Error::KeyMalformed)?;

        Ok(Self::decode(&encoded))
    }
}
904
905#[derive(Default, Clone, Debug, PartialEq)]
907pub struct MlDsa44;
908
909impl ParameterSet for MlDsa44 {
910 type K = U4;
911 type L = U4;
912 type Eta = U2;
913 type Gamma1 = Shleft<U1, U17>;
914 type Gamma2 = Quot<QMinus1, U88>;
915 type TwoGamma2 = Prod<U2, Self::Gamma2>;
916 type W1Bits = Length<Diff<Quot<U88, U2>, U1>>;
917 type Lambda = U32;
918 type Omega = U80;
919 const TAU: usize = 39;
920}
921
922#[cfg(feature = "pkcs8")]
923impl AssociatedAlgorithmIdentifier for MlDsa44 {
924 type Params = AnyRef<'static>;
925
926 const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
927 oid: fips204::ID_ML_DSA_44,
928 parameters: None,
929 };
930}
931
932#[derive(Default, Clone, Debug, PartialEq)]
934pub struct MlDsa65;
935
936impl ParameterSet for MlDsa65 {
937 type K = U6;
938 type L = U5;
939 type Eta = U4;
940 type Gamma1 = Shleft<U1, U19>;
941 type Gamma2 = Quot<QMinus1, U32>;
942 type TwoGamma2 = Prod<U2, Self::Gamma2>;
943 type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
944 type Lambda = U48;
945 type Omega = U55;
946 const TAU: usize = 49;
947}
948
949#[cfg(feature = "pkcs8")]
950impl AssociatedAlgorithmIdentifier for MlDsa65 {
951 type Params = AnyRef<'static>;
952
953 const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
954 oid: fips204::ID_ML_DSA_65,
955 parameters: None,
956 };
957}
958
959#[derive(Default, Clone, Debug, PartialEq)]
961pub struct MlDsa87;
962
963impl ParameterSet for MlDsa87 {
964 type K = U8;
965 type L = U7;
966 type Eta = U2;
967 type Gamma1 = Shleft<U1, U19>;
968 type Gamma2 = Quot<QMinus1, U32>;
969 type TwoGamma2 = Prod<U2, Self::Gamma2>;
970 type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
971 type Lambda = U64;
972 type Omega = U75;
973 const TAU: usize = 60;
974}
975
976#[cfg(feature = "pkcs8")]
977impl AssociatedAlgorithmIdentifier for MlDsa87 {
978 type Params = AnyRef<'static>;
979
980 const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
981 oid: fips204::ID_ML_DSA_87,
982 parameters: None,
983 };
984}
985
986pub trait KeyGen: MlDsaParams {
988 type KeyPair: signature::Keypair;
990
991 #[cfg(feature = "rand_core")]
993 fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> Self::KeyPair;
994
995 fn from_seed(xi: &B32) -> Self::KeyPair;
999}
1000
1001impl<P> KeyGen for P
1002where
1003 P: MlDsaParams,
1004{
1005 type KeyPair = KeyPair<P>;
1006
1007 #[cfg(feature = "rand_core")]
1010 fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> KeyPair<P> {
1011 let mut xi = B32::default();
1012 rng.fill_bytes(&mut xi);
1013 Self::from_seed(&xi)
1014 }
1015
1016 fn from_seed(xi: &B32) -> KeyPair<P>
1021 where
1022 P: MlDsaParams,
1023 {
1024 let mut h = H::default()
1026 .absorb(xi)
1027 .absorb(&[P::K::U8])
1028 .absorb(&[P::L::U8]);
1029
1030 let rho: B32 = h.squeeze_new();
1031 let rhop: B64 = h.squeeze_new();
1032 let K: B32 = h.squeeze_new();
1033
1034 let A_hat = expand_a::<P::K, P::L>(&rho);
1036 let s1 = expand_s::<P::L>(&rhop, P::Eta::ETA, 0);
1037 let s2 = expand_s::<P::K>(&rhop, P::Eta::ETA, P::L::USIZE);
1038
1039 let As1_hat = &A_hat * &s1.ntt();
1041 let t = &As1_hat.ntt_inverse() + &s2;
1042
1043 let (t1, t0) = t.power2round();
1045
1046 let verifying_key = VerifyingKey::new(rho, t1, Some(A_hat.clone()), None);
1047 let signing_key =
1048 SigningKey::new(rho, K, verifying_key.tr.clone(), s1, s2, t0, Some(A_hat));
1049
1050 KeyPair {
1051 signing_key,
1052 verifying_key,
1053 #[cfg(feature = "pkcs8")]
1054 seed: xi.clone(),
1055 }
1056 }
1057}
1058
#[cfg(test)]
mod test {
    use super::*;
    use crate::param::*;

    /// Message used by every round-trip test below.
    const MSG: &[u8] = b"Hello world";

    /// Key pair derived from the all-zero seed, split into its two halves.
    fn zero_seed_keys<P: MlDsaParams>() -> (SigningKey<P>, VerifyingKey<P>) {
        let kp = P::from_seed(&Array::default());
        (kp.signing_key, kp.verifying_key)
    }

    /// All-zero signing randomness (deterministic signing).
    fn zero_rnd() -> B32 {
        Array([0u8; 32])
    }

    #[test]
    fn output_sizes() {
        // ML-DSA-44
        assert_eq!(SigningKeySize::<MlDsa44>::USIZE, 2560);
        assert_eq!(VerifyingKeySize::<MlDsa44>::USIZE, 1312);
        assert_eq!(SignatureSize::<MlDsa44>::USIZE, 2420);

        // ML-DSA-65
        assert_eq!(SigningKeySize::<MlDsa65>::USIZE, 4032);
        assert_eq!(VerifyingKeySize::<MlDsa65>::USIZE, 1952);
        assert_eq!(SignatureSize::<MlDsa65>::USIZE, 3309);

        // ML-DSA-87
        assert_eq!(SigningKeySize::<MlDsa87>::USIZE, 4896);
        assert_eq!(VerifyingKeySize::<MlDsa87>::USIZE, 2592);
        assert_eq!(SignatureSize::<MlDsa87>::USIZE, 4627);
    }

    fn encode_decode_round_trip_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let (sk, vk) = zero_seed_keys::<P>();

        let vk_again = VerifyingKey::<P>::decode(&vk.encode());
        assert!(vk == vk_again);

        let sk_again = SigningKey::<P>::decode(&sk.encode());
        assert!(sk == sk_again);

        let sig = sk.sign_internal(&[MSG], &zero_rnd());
        let sig_again = Signature::<P>::decode(&sig.encode()).unwrap();
        assert!(sig == sig_again);
    }

    #[test]
    fn encode_decode_round_trip() {
        encode_decode_round_trip_test::<MlDsa44>();
        encode_decode_round_trip_test::<MlDsa65>();
        encode_decode_round_trip_test::<MlDsa87>();
    }

    fn public_from_private_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let (sk, vk) = zero_seed_keys::<P>();
        assert!(vk == sk.verifying_key());
    }

    #[test]
    fn public_from_private() {
        public_from_private_test::<MlDsa44>();
        public_from_private_test::<MlDsa65>();
        public_from_private_test::<MlDsa87>();
    }

    fn sign_verify_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        let (sk, vk) = zero_seed_keys::<P>();
        let sig = sk.sign_internal(&[MSG], &zero_rnd());
        assert!(vk.verify_internal(&[MSG], &sig));
    }

    #[test]
    fn sign_verify_round_trip() {
        sign_verify_round_trip_test::<MlDsa44>();
        sign_verify_round_trip_test::<MlDsa65>();
        sign_verify_round_trip_test::<MlDsa87>();
    }

    fn many_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        use rand::Rng;

        const ITERATIONS: usize = 1000;

        let mut rng = rand::rng();
        let mut seed = B32::default();

        for _ in 0..ITERATIONS {
            rng.fill(&mut seed[..]);

            let kp = P::from_seed(&seed);
            let sig = kp.signing_key.sign_internal(&[MSG], &zero_rnd());

            let decoded = Signature::<P>::decode(&sig.encode()).unwrap();
            assert_eq!(decoded, sig);
            assert!(kp.verifying_key.verify_internal(&[MSG], &decoded));
        }
    }

    #[test]
    fn many_round_trip() {
        many_round_trip_test::<MlDsa44>();
        many_round_trip_test::<MlDsa65>();
        many_round_trip_test::<MlDsa87>();
    }

    #[test]
    fn sign_mu_verify_mu_round_trip() {
        fn check<P: MlDsaParams>() {
            let (sk, vk) = zero_seed_keys::<P>();
            let mu = message_representative(&sk.tr, &[&[MSG]]);
            let sig = sk.raw_sign_mu(&mu, &zero_rnd());
            assert!(vk.raw_verify_mu(&mu, &sig));
        }
        check::<MlDsa44>();
        check::<MlDsa65>();
        check::<MlDsa87>();
    }

    #[test]
    fn sign_mu_verify_internal_round_trip() {
        fn check<P: MlDsaParams>() {
            let (sk, vk) = zero_seed_keys::<P>();
            let mu = message_representative(&sk.tr, &[&[MSG]]);
            let sig = sk.raw_sign_mu(&mu, &zero_rnd());
            assert!(vk.verify_internal(&[MSG], &sig));
        }
        check::<MlDsa44>();
        check::<MlDsa65>();
        check::<MlDsa87>();
    }

    #[test]
    fn sign_internal_verify_mu_round_trip() {
        fn check<P: MlDsaParams>() {
            let (sk, vk) = zero_seed_keys::<P>();
            let mu = message_representative(&sk.tr, &[&[MSG]]);
            let sig = sk.sign_internal(&[MSG], &zero_rnd());
            assert!(vk.raw_verify_mu(&mu, &sig));
        }
        check::<MlDsa44>();
        check::<MlDsa65>();
        check::<MlDsa87>();
    }

    #[test]
    fn sign_digest_round_trip() {
        fn check<P: MlDsaParams>() {
            let (sk, vk) = zero_seed_keys::<P>();

            // Streaming signing must agree with the one-shot deterministic signer.
            let sig = sk.sign_digest(|digest| digest.update(MSG));
            assert_eq!(sig, sk.sign(MSG));

            vk.verify_digest(
                |digest| {
                    digest.update(MSG);
                    Ok(())
                },
                &sig,
            )
            .unwrap();
        }
        check::<MlDsa44>();
        check::<MlDsa65>();
        check::<MlDsa87>();
    }

    #[test]
    #[cfg(feature = "rand_core")]
    fn sign_randomized_digest_round_trip() {
        fn check<P: MlDsaParams>() {
            let (sk, vk) = zero_seed_keys::<P>();

            let sig = sk.sign_digest_with_rng(&mut rand::rng(), |digest| digest.update(MSG));

            vk.verify_digest(
                |digest| {
                    digest.update(MSG);
                    Ok(())
                },
                &sig,
            )
            .unwrap();
        }
        check::<MlDsa44>();
        check::<MlDsa65>();
        check::<MlDsa87>();
    }

    #[test]
    fn from_seed_implementations_match() {
        fn check<P: MlDsaParams>() {
            let seed = Array([0u8; 32]);
            let kp = P::from_seed(&seed);
            let sk = SigningKey::<P>::from_seed(&seed);
            let vk = sk.verifying_key();
            assert_eq!(kp.signing_key, sk);
            assert_eq!(kp.verifying_key, vk);
        }
        check::<MlDsa44>();
        check::<MlDsa65>();
        check::<MlDsa87>();
    }
}