1#![no_std]
2#![doc = include_str!("../README.md")]
3#![doc(
4 html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
5 html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
6)]
7#![cfg_attr(docsrs, feature(doc_auto_cfg))]
8#![warn(clippy::pedantic)] #![warn(clippy::integer_division_remainder_used)] #![warn(clippy::as_conversions)] #![allow(non_snake_case)] #![allow(clippy::similar_names)] #![allow(clippy::many_single_char_names)] #![allow(clippy::clone_on_copy)] #![deny(missing_docs)] mod algebra;
32mod crypto;
33mod encode;
34mod hint;
35mod ntt;
36mod param;
37mod sampling;
38mod util;
39
40mod module_lattice;
42
43use core::convert::{AsRef, TryFrom, TryInto};
44use hybrid_array::{
45 Array,
46 typenum::{
47 Diff, Length, Prod, Quot, Shleft, U1, U2, U4, U5, U6, U7, U8, U17, U19, U32, U48, U55, U64,
48 U75, U80, U88, Unsigned,
49 },
50};
51
52#[cfg(feature = "rand_core")]
53use rand_core::{CryptoRng, CryptoRngCore, RngCore};
54
55#[cfg(feature = "zeroize")]
56use zeroize::{Zeroize, ZeroizeOnDrop};
57
58#[cfg(feature = "pkcs8")]
59use pkcs8::{
60 AlgorithmIdentifierRef, ObjectIdentifier, PrivateKeyInfo,
61 der::{self, AnyRef},
62 spki::{
63 self, AlgorithmIdentifier, AssociatedAlgorithmIdentifier, SignatureAlgorithmIdentifier,
64 SubjectPublicKeyInfoRef,
65 },
66};
67
68#[cfg(all(feature = "alloc", feature = "pkcs8"))]
69use pkcs8::{
70 EncodePrivateKey, EncodePublicKey,
71 der::asn1::{BitString, BitStringRef},
72 spki::{SignatureBitStringEncoding, SubjectPublicKeyInfo},
73};
74
75use crate::algebra::{AlgebraExt, Elem, NttMatrix, NttVector, Truncate, Vector};
76use crate::crypto::H;
77use crate::hint::Hint;
78use crate::ntt::{Ntt, NttInverse};
79use crate::param::{ParameterSet, QMinus1, SamplingSize, SpecQ};
80use crate::sampling::{expand_a, expand_mask, expand_s, sample_in_ball};
81use crate::util::B64;
82use core::fmt;
83
84pub use crate::param::{EncodedSignature, EncodedSigningKey, EncodedVerifyingKey, MlDsaParams};
85pub use crate::util::B32;
86pub use signature::{self, Error};
87
/// An ML-DSA signature, as produced by [`SigningKey::sign_internal`] and parsed by
/// [`Signature::decode`].
///
/// It consists of the commitment hash `c_tilde`, the response vector `z`, and the hint `h`.
#[derive(Clone, PartialEq, Debug)]
pub struct Signature<P: MlDsaParams> {
    /// Commitment hash, `P::Lambda` bytes long; hashed from `mu` and the encoded high bits `w1`.
    c_tilde: Array<u8, P::Lambda>,
    /// Response vector of `P::L` polynomials.
    z: Vector<P::L>,
    /// Hint consumed by `use_hint` during verification to recover the signer's `w1`.
    h: Hint<P>,
}
95
96impl<P: MlDsaParams> Signature<P> {
97 pub fn encode(&self) -> EncodedSignature<P> {
100 let c_tilde = self.c_tilde.clone();
101 let z = P::encode_z(&self.z);
102 let h = self.h.bit_pack();
103 P::concat_sig(c_tilde, z, h)
104 }
105
106 pub fn decode(enc: &EncodedSignature<P>) -> Option<Self> {
109 let (c_tilde, z, h) = P::split_sig(enc);
110
111 let c_tilde = c_tilde.clone();
112 let z = P::decode_z(z);
113 let h = Hint::bit_unpack(h)?;
114
115 if z.infinity_norm() >= P::GAMMA1_MINUS_BETA {
116 return None;
117 }
118
119 Some(Self { c_tilde, z, h })
120 }
121}
122
123impl<'a, P: MlDsaParams> TryFrom<&'a [u8]> for Signature<P> {
124 type Error = Error;
125
126 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
127 let enc = EncodedSignature::<P>::try_from(value).map_err(|_| Error::new())?;
128 Self::decode(&enc).ok_or(Error::new())
129 }
130}
131
132impl<P: MlDsaParams> TryInto<EncodedSignature<P>> for Signature<P> {
133 type Error = Error;
134
135 fn try_into(self) -> Result<EncodedSignature<P>, Self::Error> {
136 Ok(self.encode())
137 }
138}
139
/// Uses the fixed-size byte encoding as the `signature` crate's representation type.
impl<P: MlDsaParams> signature::SignatureEncoding for Signature<P> {
    type Repr = EncodedSignature<P>;
}
143
/// Encodes the signature bytes as a DER `BIT STRING` with zero unused bits.
///
/// `SignatureBitStringEncoding` and `BitString` are only imported when both the `alloc` and
/// `pkcs8` features are enabled, and `der` only with `pkcs8`. This impl is therefore gated on
/// both features, matching the other `alloc`+`pkcs8` encoders.
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P: MlDsaParams> SignatureBitStringEncoding for Signature<P> {
    fn to_bitstring(&self) -> der::Result<BitString> {
        BitString::new(0, self.encode().to_vec())
    }
}
150
/// A signature carries the same algorithm identifier as its parameter set `P`.
#[cfg(feature = "pkcs8")]
impl<P> AssociatedAlgorithmIdentifier for Signature<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = P::ALGORITHM_IDENTIFIER;
}
161
162fn message_representative(tr: &[u8], Mp: &[&[u8]]) -> B64 {
166 let mut h = H::default().absorb(tr);
167
168 for m in Mp {
169 h = h.absorb(m);
170 }
171
172 h.squeeze_new()
173}
174
/// A matched ML-DSA signing/verifying key pair, together with the seed it was derived from.
///
/// NOTE(review): no `Drop`/zeroize impl for `KeyPair` is visible here, so `seed` (secret key
/// material) is not explicitly cleared when the pair is dropped.
pub struct KeyPair<P: MlDsaParams> {
    /// The signing (private) key.
    signing_key: SigningKey<P>,

    /// The verifying (public) key.
    verifying_key: VerifyingKey<P>,

    /// The 32-byte seed `xi` given to `key_gen_internal`; this is what the PKCS#8 encoder writes.
    seed: B32,
}
186
187impl<P: MlDsaParams> KeyPair<P> {
188 pub fn signing_key(&self) -> &SigningKey<P> {
190 &self.signing_key
191 }
192
193 pub fn verifying_key(&self) -> &VerifyingKey<P> {
195 &self.verifying_key
196 }
197}
198
199impl<P: MlDsaParams> AsRef<VerifyingKey<P>> for KeyPair<P> {
200 fn as_ref(&self) -> &VerifyingKey<P> {
201 &self.verifying_key
202 }
203}
204
205impl<P: MlDsaParams> fmt::Debug for KeyPair<P> {
206 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207 f.debug_struct("KeyPair")
208 .field("verifying_key", &self.verifying_key)
209 .finish_non_exhaustive()
210 }
211}
212
/// Lets the `signature` crate treat `KeyPair` as a key pair whose verifying key is reached
/// through the `AsRef<VerifyingKey<P>>` impl.
///
/// Presumably this is what makes `KeyPair<P>` satisfy the `signature::Keypair` bound on
/// `KeyGen::KeyPair`, via that crate's blanket impl. Confirm before changing.
impl<P: MlDsaParams> signature::KeypairRef for KeyPair<P> {
    type VerifyingKey = VerifyingKey<P>;
}
216
/// Decodes a PKCS#8 `PrivateKeyInfo` whose private key is the 32-byte ML-DSA seed, and
/// regenerates the full key pair from it with `key_gen_internal`.
///
/// # Errors
/// - `spki::Error::OidUnknown` (converted into `pkcs8::Error`) if the algorithm identifier is
///   not equal to `P::ALGORITHM_IDENTIFIER`. The comparison covers the whole identifier, so a
///   matching OID with different parameters is also rejected this way.
/// - `pkcs8::Error::KeyMalformed` if the private key bytes cannot be converted into a `B32`.
#[cfg(feature = "pkcs8")]
impl<P> TryFrom<PrivateKeyInfo<'_>> for KeyPair<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Error = pkcs8::Error;

    fn try_from(private_key_info: pkcs8::PrivateKeyInfo<'_>) -> pkcs8::Result<Self> {
        let algorithm = private_key_info.algorithm;
        if algorithm != P::ALGORITHM_IDENTIFIER {
            return Err(spki::Error::OidUnknown { oid: algorithm.oid }.into());
        }

        let seed = Array::try_from(private_key_info.private_key)
            .map_err(|_| pkcs8::Error::KeyMalformed)?;
        Ok(P::key_gen_internal(&seed))
    }
}
236
237impl<P: MlDsaParams> signature::Signer<Signature<P>> for KeyPair<P> {
240 fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
241 self.signing_key.sign_deterministic(msg, &[])
242 }
243}
244
/// Signatures made with a `KeyPair<P>` use the algorithm identifier of `Signature<P>`, which
/// is that of the parameter set `P`.
#[cfg(feature = "pkcs8")]
impl<P> SignatureAlgorithmIdentifier for KeyPair<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
        Signature::<P>::ALGORITHM_IDENTIFIER;
}
256
/// Serializes the key pair as a DER PKCS#8 `PrivateKeyInfo` whose private key is the 32-byte
/// seed only. The expanded key material is regenerated from the seed on decode.
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P> EncodePrivateKey for KeyPair<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    fn to_pkcs8_der(&self) -> pkcs8::Result<der::SecretDocument> {
        let private_key_info = pkcs8::PrivateKeyInfo::new(P::ALGORITHM_IDENTIFIER, &self.seed);
        der::SecretDocument::encode_msg(&private_key_info).map_err(Into::into)
    }
}
268
/// An ML-DSA signing (private) key.
///
/// Besides the encoded components (`rho`, `K`, `tr`, `s1`, `s2`, `t0`), it caches NTT-domain
/// copies of the secret vectors and the expanded matrix `A_hat`, computed once in
/// `SigningKey::new`, so signing does not recompute them.
///
/// NOTE(review): `PartialEq` is derived, so comparisons of secret fields use ordinary
/// equality, which is presumably not constant-time.
#[derive(Clone, PartialEq)]
pub struct SigningKey<P: MlDsaParams> {
    /// Public seed from which `A_hat` is expanded.
    rho: B32,
    /// Private 32-byte key absorbed when deriving the per-signature mask seed.
    K: B32,
    /// 64-byte hash of the encoded verifying key, absorbed into `mu`.
    tr: B64,
    /// Secret vector of length `P::L`.
    s1: Vector<P::L>,
    /// Secret vector of length `P::K`.
    s2: Vector<P::K>,
    /// Low part of `t` from `power2round` during key generation.
    t0: Vector<P::K>,

    // Precomputed NTT-domain values (derived from the fields above in `SigningKey::new`).
    s1_hat: NttVector<P::L>,
    s2_hat: NttVector<P::K>,
    t0_hat: NttVector<P::K>,
    A_hat: NttMatrix<P::K, P::L>,
}
285
286impl<P: MlDsaParams> fmt::Debug for SigningKey<P> {
287 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288 f.debug_struct("SigningKey").finish_non_exhaustive()
289 }
290}
291
/// Clears secret key material when the signing key is dropped.
///
/// The cached NTT-domain vectors `s1_hat`, `s2_hat`, and `t0_hat` are invertible transforms of
/// the secret vectors `s1`, `s2`, and `t0`. They must be wiped too; otherwise zeroizing the
/// plain-domain copies would leave the same secrets recoverable from memory. `A_hat` is
/// expanded from `rho`, which is also part of the public verifying key, so it holds no secret.
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> Drop for SigningKey<P> {
    fn drop(&mut self) {
        self.rho.zeroize();
        self.K.zeroize();
        self.tr.zeroize();
        self.s1.zeroize();
        self.s2.zeroize();
        self.t0.zeroize();
        self.s1_hat.zeroize();
        self.s2_hat.zeroize();
        self.t0_hat.zeroize();
    }
}
303
/// Marker trait: `SigningKey` wipes its secret fields in its `Drop` impl, which is enabled by
/// the same `zeroize` feature.
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> ZeroizeOnDrop for SigningKey<P> {}
306
307impl<P: MlDsaParams> SigningKey<P> {
308 fn new(
309 rho: B32,
310 K: B32,
311 tr: B64,
312 s1: Vector<P::L>,
313 s2: Vector<P::K>,
314 t0: Vector<P::K>,
315 A_hat: Option<NttMatrix<P::K, P::L>>,
316 ) -> Self {
317 let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
318 let s1_hat = s1.ntt();
319 let s2_hat = s2.ntt();
320 let t0_hat = t0.ntt();
321
322 Self {
323 rho,
324 K,
325 tr,
326 s1,
327 s2,
328 t0,
329
330 s1_hat,
331 s2_hat,
332 t0_hat,
333 A_hat,
334 }
335 }
336
337 pub fn sign_internal(&self, Mp: &[&[u8]], rnd: &B32) -> Signature<P>
343 where
344 P: MlDsaParams,
345 {
346 let mu = message_representative(&self.tr, Mp);
351
352 let rhopp: B64 = H::default()
354 .absorb(&self.K)
355 .absorb(rnd)
356 .absorb(&mu)
357 .squeeze_new();
358
359 for kappa in (0..u16::MAX).step_by(P::L::USIZE) {
361 let y = expand_mask::<P::L, P::Gamma1>(&rhopp, kappa);
362 let w = (&self.A_hat * &y.ntt()).ntt_inverse();
363 let w1 = w.high_bits::<P::TwoGamma2>();
364
365 let w1_tilde = P::encode_w1(&w1);
366 let c_tilde = H::default()
367 .absorb(&mu)
368 .absorb(&w1_tilde)
369 .squeeze_new::<P::Lambda>();
370 let c = sample_in_ball(&c_tilde, P::TAU);
371 let c_hat = c.ntt();
372
373 let cs1 = (&c_hat * &self.s1_hat).ntt_inverse();
374 let cs2 = (&c_hat * &self.s2_hat).ntt_inverse();
375
376 let z = &y + &cs1;
377 let r0 = (&w - &cs2).low_bits::<P::TwoGamma2>();
378
379 if z.infinity_norm() >= P::GAMMA1_MINUS_BETA
380 || r0.infinity_norm() >= P::GAMMA2_MINUS_BETA
381 {
382 continue;
383 }
384
385 let ct0 = (&c_hat * &self.t0_hat).ntt_inverse();
386 let minus_ct0 = -&ct0;
387 let w_cs2_ct0 = &(&w - &cs2) + &ct0;
388 let h = Hint::<P>::new(&minus_ct0, &w_cs2_ct0);
389
390 if ct0.infinity_norm() >= P::Gamma2::U32 || h.hamming_weight() > P::Omega::USIZE {
391 continue;
392 }
393
394 let z = z.mod_plus_minus::<SpecQ>();
395 return Signature { c_tilde, z, h };
396 }
397
398 unreachable!("Rejection sampling failed to find a valid signature");
399 }
400
401 #[cfg(feature = "rand_core")]
409 pub fn sign_randomized<R: RngCore + CryptoRng + ?Sized>(
410 &self,
411 M: &[u8],
412 ctx: &[u8],
413 rng: &mut R,
414 ) -> Result<Signature<P>, Error> {
415 if ctx.len() > 255 {
416 return Err(Error::new());
417 }
418
419 let mut rnd = B32::default();
420 rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
421
422 let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
423 Ok(self.sign_internal(Mp, &rnd))
424 }
425
426 pub fn sign_deterministic(&self, M: &[u8], ctx: &[u8]) -> Result<Signature<P>, Error> {
433 if ctx.len() > 255 {
434 return Err(Error::new());
435 }
436
437 let rnd = B32::default();
438 let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
439 Ok(self.sign_internal(Mp, &rnd))
440 }
441
442 pub fn encode(&self) -> EncodedSigningKey<P>
445 where
446 P: MlDsaParams,
447 {
448 let s1_enc = P::encode_s1(&self.s1);
449 let s2_enc = P::encode_s2(&self.s2);
450 let t0_enc = P::encode_t0(&self.t0);
451 P::concat_sk(
452 self.rho.clone(),
453 self.K.clone(),
454 self.tr.clone(),
455 s1_enc,
456 s2_enc,
457 t0_enc,
458 )
459 }
460
461 pub fn decode(enc: &EncodedSigningKey<P>) -> Self
464 where
465 P: MlDsaParams,
466 {
467 let (rho, K, tr, s1_enc, s2_enc, t0_enc) = P::split_sk(enc);
468 Self::new(
469 rho.clone(),
470 K.clone(),
471 tr.clone(),
472 P::decode_s1(s1_enc),
473 P::decode_s2(s2_enc),
474 P::decode_t0(t0_enc),
475 None,
476 )
477 }
478}
479
480impl<P: MlDsaParams> signature::Signer<Signature<P>> for SigningKey<P> {
484 fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
485 self.sign_deterministic(msg, &[])
486 }
487}
488
/// `signature::RandomizedSigner` implementation: hedged signing with an empty context string.
#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> signature::RandomizedSigner<Signature<P>> for SigningKey<P> {
    fn try_sign_with_rng(
        &self,
        rng: &mut impl CryptoRngCore,
        msg: &[u8],
    ) -> Result<Signature<P>, Error> {
        let empty_context: &[u8] = &[];
        Self::sign_randomized(self, msg, empty_context, rng)
    }
}
502
/// Signatures made with a `SigningKey<P>` use the algorithm identifier of `Signature<P>`, which
/// is that of the parameter set `P`.
#[cfg(feature = "pkcs8")]
impl<P> SignatureAlgorithmIdentifier for SigningKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
        Signature::<P>::ALGORITHM_IDENTIFIER;
}
514
/// Decodes a PKCS#8 `PrivateKeyInfo` by regenerating the full key pair from its seed and
/// keeping only the signing key. Errors are those of `KeyPair::try_from`.
#[cfg(feature = "pkcs8")]
impl<P> TryFrom<PrivateKeyInfo<'_>> for SigningKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Error = pkcs8::Error;

    fn try_from(private_key_info: pkcs8::PrivateKeyInfo<'_>) -> pkcs8::Result<Self> {
        KeyPair::<P>::try_from(private_key_info).map(|pair| pair.signing_key)
    }
}
529
/// An ML-DSA verifying (public) key.
///
/// Stores the encoded components `rho` and `t1` plus values precomputed by `VerifyingKey::new`:
/// the expanded matrix `A_hat`, `NTT(t1 * 2^13)`, and `tr`, the 64-byte hash of the encoded key.
#[derive(Clone, Debug, PartialEq)]
pub struct VerifyingKey<P: ParameterSet> {
    /// Public seed from which `A_hat` is expanded.
    rho: B32,
    /// High part of `t` from `power2round` during key generation.
    t1: Vector<P::K>,

    // Precomputed values (see `VerifyingKey::new`).
    A_hat: NttMatrix<P::K, P::L>,
    /// `t1` scaled by `2^13`, in the NTT domain.
    t1_2d_hat: NttVector<P::K>,
    /// Hash of the encoded verifying key; shared with the matching signing key.
    tr: B64,
}
541
542impl<P: MlDsaParams> VerifyingKey<P> {
543 fn new(
544 rho: B32,
545 t1: Vector<P::K>,
546 A_hat: Option<NttMatrix<P::K, P::L>>,
547 enc: Option<EncodedVerifyingKey<P>>,
548 ) -> Self {
549 let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
550 let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1));
551
552 let t1_2d_hat = (Elem::new(1 << 13) * &t1).ntt();
553 let tr: B64 = H::default().absorb(&enc).squeeze_new();
554
555 Self {
556 rho,
557 t1,
558 A_hat,
559 t1_2d_hat,
560 tr,
561 }
562 }
563
564 pub fn verify_internal(&self, Mp: &[&[u8]], sigma: &Signature<P>) -> bool
569 where
570 P: MlDsaParams,
571 {
572 let mu = message_representative(&self.tr, Mp);
574
575 let c = sample_in_ball(&sigma.c_tilde, P::TAU);
577
578 let z_hat = sigma.z.ntt();
579 let c_hat = c.ntt();
580 let Az_hat = &self.A_hat * &z_hat;
581 let ct1_2d_hat = &c_hat * &self.t1_2d_hat;
582
583 let wp_approx = (&Az_hat - &ct1_2d_hat).ntt_inverse();
584 let w1p = sigma.h.use_hint(&wp_approx);
585
586 let w1p_tilde = P::encode_w1(&w1p);
587 let cp_tilde = H::default()
588 .absorb(&mu)
589 .absorb(&w1p_tilde)
590 .squeeze_new::<P::Lambda>();
591
592 sigma.c_tilde == cp_tilde
593 }
594
595 pub fn verify_with_context(&self, M: &[u8], ctx: &[u8], sigma: &Signature<P>) -> bool {
598 if ctx.len() > 255 {
599 return false;
600 }
601
602 let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
603 self.verify_internal(Mp, sigma)
604 }
605
606 fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
607 let t1_enc = P::encode_t1(t1);
608 P::concat_vk(rho.clone(), t1_enc)
609 }
610
611 pub fn encode(&self) -> EncodedVerifyingKey<P> {
614 Self::encode_internal(&self.rho, &self.t1)
615 }
616
617 pub fn decode(enc: &EncodedVerifyingKey<P>) -> Self {
620 let (rho, t1_enc) = P::split_vk(enc);
621 let t1 = P::decode_t1(t1_enc);
622 Self::new(rho.clone(), t1, None, Some(enc.clone()))
623 }
624}
625
626impl<P: MlDsaParams> signature::Verifier<Signature<P>> for VerifyingKey<P> {
627 fn verify(&self, msg: &[u8], signature: &Signature<P>) -> Result<(), Error> {
628 self.verify_with_context(msg, &[], signature)
629 .then_some(())
630 .ok_or(Error::new())
631 }
632}
633
/// Signatures checked by a `VerifyingKey<P>` use the algorithm identifier of `Signature<P>`,
/// which is that of the parameter set `P`.
#[cfg(feature = "pkcs8")]
impl<P> SignatureAlgorithmIdentifier for VerifyingKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
        Signature::<P>::ALGORITHM_IDENTIFIER;
}
645
/// Serializes the verifying key as a DER `SubjectPublicKeyInfo`: the algorithm identifier of
/// `P` plus the encoded key as a `BIT STRING` with zero unused bits.
///
/// Gated on both `alloc` and `pkcs8`. `EncodePublicKey`, `BitStringRef`, and
/// `SubjectPublicKeyInfo` are only imported under `all(alloc, pkcs8)`, and
/// `AssociatedAlgorithmIdentifier`, `der`, and `spki` only under `pkcs8`.
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P> EncodePublicKey for VerifyingKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    fn to_public_key_der(&self) -> spki::Result<der::Document> {
        let public_key = self.encode();
        let subject_public_key = BitStringRef::new(0, &public_key)?;

        SubjectPublicKeyInfo {
            algorithm: P::ALGORITHM_IDENTIFIER,
            subject_public_key,
        }
        .try_into()
    }
}
663
/// Decodes a verifying key from a DER `SubjectPublicKeyInfo`.
///
/// # Errors
/// - `spki::Error::OidUnknown` if the algorithm identifier is not equal to
///   `P::ALGORITHM_IDENTIFIER`. The whole identifier is compared, so a matching OID with
///   different parameters is also rejected.
/// - A DER `BIT STRING` value error if `as_bytes()` yields `None` (presumably a bit string with
///   unused bits).
/// - A key-malformed error, converted by `?` from `pkcs8::Error::KeyMalformed`, if the key bytes
///   cannot be converted into `EncodedVerifyingKey<P>`.
#[cfg(feature = "pkcs8")]
impl<P> TryFrom<SubjectPublicKeyInfoRef<'_>> for VerifyingKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Error = spki::Error;

    fn try_from(spki: SubjectPublicKeyInfoRef<'_>) -> spki::Result<Self> {
        if spki.algorithm != P::ALGORITHM_IDENTIFIER {
            return Err(spki::Error::OidUnknown {
                oid: spki.algorithm.oid,
            });
        }

        let key_bytes = spki
            .subject_public_key
            .as_bytes()
            .ok_or_else(|| der::Tag::BitString.value_error())?;
        let encoded = EncodedVerifyingKey::<P>::try_from(key_bytes)
            .map_err(|_| pkcs8::Error::KeyMalformed)?;

        Ok(Self::decode(&encoded))
    }
}
688
/// The ML-DSA-44 parameter set. The constants below are expected to match FIPS 204, Table 1.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MlDsa44;

impl ParameterSet for MlDsa44 {
    // k: rows of A; length of s2, t0, and t1.
    type K = U4;
    // l: columns of A; length of s1 and z.
    type L = U4;
    // Secret-coefficient bound, passed to `expand_s` as `P::Eta::ETA`.
    type Eta = U2;
    // gamma1 = 2^17
    type Gamma1 = Shleft<U1, U17>;
    // gamma2 = (q - 1) / 88
    type Gamma2 = Quot<QMinus1, U88>;
    type TwoGamma2 = Prod<U2, Self::Gamma2>;
    // Bit length of 88/2 - 1 = 43, i.e. 6; presumably the per-coefficient width used by `encode_w1`.
    type W1Bits = Length<Diff<Quot<U88, U2>, U1>>;
    // Length of `c_tilde` in bytes.
    type Lambda = U32;
    // Signing rejects any hint whose Hamming weight exceeds this.
    type Omega = U80;
    // Passed to `sample_in_ball` when deriving the challenge.
    const TAU: usize = 39;
}

/// Algorithm identifier for ML-DSA-44: OID 2.16.840.1.101.3.4.3.17 with absent parameters.
#[cfg(feature = "pkcs8")]
impl AssociatedAlgorithmIdentifier for MlDsa44 {
    type Params = AnyRef<'static>;

    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
        oid: ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.17"),
        parameters: None,
    };
}
715
/// The ML-DSA-65 parameter set. The constants below are expected to match FIPS 204, Table 1.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MlDsa65;

impl ParameterSet for MlDsa65 {
    // k: rows of A; length of s2, t0, and t1.
    type K = U6;
    // l: columns of A; length of s1 and z.
    type L = U5;
    // Secret-coefficient bound, passed to `expand_s` as `P::Eta::ETA`.
    type Eta = U4;
    // gamma1 = 2^19
    type Gamma1 = Shleft<U1, U19>;
    // gamma2 = (q - 1) / 32
    type Gamma2 = Quot<QMinus1, U32>;
    type TwoGamma2 = Prod<U2, Self::Gamma2>;
    // Bit length of 32/2 - 1 = 15, i.e. 4; presumably the per-coefficient width used by `encode_w1`.
    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
    // Length of `c_tilde` in bytes.
    type Lambda = U48;
    // Signing rejects any hint whose Hamming weight exceeds this.
    type Omega = U55;
    // Passed to `sample_in_ball` when deriving the challenge.
    const TAU: usize = 49;
}

/// Algorithm identifier for ML-DSA-65: OID 2.16.840.1.101.3.4.3.18 with absent parameters.
#[cfg(feature = "pkcs8")]
impl AssociatedAlgorithmIdentifier for MlDsa65 {
    type Params = AnyRef<'static>;

    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
        oid: ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.18"),
        parameters: None,
    };
}
742
/// The ML-DSA-87 parameter set. The constants below are expected to match FIPS 204, Table 1.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MlDsa87;

impl ParameterSet for MlDsa87 {
    // k: rows of A; length of s2, t0, and t1.
    type K = U8;
    // l: columns of A; length of s1 and z.
    type L = U7;
    // Secret-coefficient bound, passed to `expand_s` as `P::Eta::ETA`.
    type Eta = U2;
    // gamma1 = 2^19
    type Gamma1 = Shleft<U1, U19>;
    // gamma2 = (q - 1) / 32
    type Gamma2 = Quot<QMinus1, U32>;
    type TwoGamma2 = Prod<U2, Self::Gamma2>;
    // Bit length of 32/2 - 1 = 15, i.e. 4; presumably the per-coefficient width used by `encode_w1`.
    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
    // Length of `c_tilde` in bytes.
    type Lambda = U64;
    // Signing rejects any hint whose Hamming weight exceeds this.
    type Omega = U75;
    // Passed to `sample_in_ball` when deriving the challenge.
    const TAU: usize = 60;
}

/// Algorithm identifier for ML-DSA-87: OID 2.16.840.1.101.3.4.3.19 with absent parameters.
#[cfg(feature = "pkcs8")]
impl AssociatedAlgorithmIdentifier for MlDsa87 {
    type Params = AnyRef<'static>;

    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
        oid: ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.19"),
        parameters: None,
    };
}
769
/// Key generation for an ML-DSA parameter set.
pub trait KeyGen: MlDsaParams {
    /// The key pair type produced by key generation.
    type KeyPair: signature::Keypair;

    /// Generates a key pair from a fresh 32-byte seed drawn from `rng`.
    #[cfg(feature = "rand_core")]
    fn key_gen<R: RngCore + CryptoRng + ?Sized>(rng: &mut R) -> Self::KeyPair;

    /// Deterministically derives a key pair from the 32-byte seed `xi`
    /// (FIPS 204 Algorithm 6, `ML-DSA.KeyGen_internal`).
    fn key_gen_internal(xi: &B32) -> Self::KeyPair;
}
783
784impl<P> KeyGen for P
785where
786 P: MlDsaParams,
787{
788 type KeyPair = KeyPair<P>;
789
790 #[cfg(feature = "rand_core")]
793 fn key_gen<R: RngCore + CryptoRng + ?Sized>(rng: &mut R) -> KeyPair<P> {
794 let mut xi = B32::default();
795 rng.fill_bytes(&mut xi);
796 Self::key_gen_internal(&xi)
797 }
798
799 fn key_gen_internal(xi: &B32) -> KeyPair<P>
802 where
803 P: MlDsaParams,
804 {
805 let mut h = H::default()
807 .absorb(xi)
808 .absorb(&[P::K::U8])
809 .absorb(&[P::L::U8]);
810
811 let rho: B32 = h.squeeze_new();
812 let rhop: B64 = h.squeeze_new();
813 let K: B32 = h.squeeze_new();
814
815 let A_hat = expand_a::<P::K, P::L>(&rho);
817 let s1 = expand_s::<P::L>(&rhop, P::Eta::ETA, 0);
818 let s2 = expand_s::<P::K>(&rhop, P::Eta::ETA, P::L::USIZE);
819
820 let As1_hat = &A_hat * &s1.ntt();
822 let t = &As1_hat.ntt_inverse() + &s2;
823
824 let (t1, t0) = t.power2round();
826
827 let verifying_key = VerifyingKey::new(rho, t1, Some(A_hat.clone()), None);
828 let signing_key =
829 SigningKey::new(rho, K, verifying_key.tr.clone(), s1, s2, t0, Some(A_hat));
830
831 KeyPair {
832 signing_key,
833 verifying_key,
834 seed: xi.clone(),
835 }
836 }
837}
838
#[cfg(test)]
mod test {
    use super::*;
    use crate::param::*;

    /// Checks the encoded key and signature sizes for every parameter set.
    #[test]
    fn output_sizes() {
        assert_eq!(SigningKeySize::<MlDsa44>::USIZE, 2560);
        assert_eq!(VerifyingKeySize::<MlDsa44>::USIZE, 1312);
        assert_eq!(SignatureSize::<MlDsa44>::USIZE, 2420);

        assert_eq!(SigningKeySize::<MlDsa65>::USIZE, 4032);
        assert_eq!(VerifyingKeySize::<MlDsa65>::USIZE, 1952);
        assert_eq!(SignatureSize::<MlDsa65>::USIZE, 3309);

        assert_eq!(SigningKeySize::<MlDsa87>::USIZE, 4896);
        assert_eq!(VerifyingKeySize::<MlDsa87>::USIZE, 2592);
        assert_eq!(SignatureSize::<MlDsa87>::USIZE, 4627);
    }

    /// Round-trips keys and a signature through their byte encodings and checks equality.
    fn encode_decode_round_trip_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let kp = P::key_gen_internal(&Default::default());
        let sk = kp.signing_key;
        let vk = kp.verifying_key;

        let vk_bytes = vk.encode();
        let vk2 = VerifyingKey::<P>::decode(&vk_bytes);
        assert!(vk == vk2);

        // `assert!` rather than `assert_eq!` keeps secret key material out of failure output.
        let sk_bytes = sk.encode();
        let sk2 = SigningKey::<P>::decode(&sk_bytes);
        assert!(sk == sk2);

        let M = b"Hello world";
        let rnd = Array([0u8; 32]);
        let sig = sk.sign_internal(&[M], &rnd);
        let sig_bytes = sig.encode();
        let sig2 = Signature::<P>::decode(&sig_bytes).unwrap();
        assert!(sig == sig2);
    }

    #[test]
    fn encode_decode_round_trip() {
        encode_decode_round_trip_test::<MlDsa44>();
        encode_decode_round_trip_test::<MlDsa65>();
        encode_decode_round_trip_test::<MlDsa87>();
    }

    /// Signs with a fixed seed and checks that the signature verifies.
    fn sign_verify_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        let kp = P::key_gen_internal(&Default::default());
        let sk = kp.signing_key;
        let vk = kp.verifying_key;

        let M = b"Hello world";
        let rnd = Array([0u8; 32]);
        let sig = sk.sign_internal(&[M], &rnd);

        assert!(vk.verify_internal(&[M], &sig));
    }

    #[test]
    fn sign_verify_round_trip() {
        sign_verify_round_trip_test::<MlDsa44>();
        sign_verify_round_trip_test::<MlDsa65>();
        sign_verify_round_trip_test::<MlDsa87>();
    }

    /// Generates, signs, encodes, decodes, and verifies under many random key seeds.
    ///
    /// Every failure message includes the seed, so a failure seen with one random seed can be
    /// reproduced deterministically via `key_gen_internal`.
    fn many_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        use rand::Rng;

        const ITERATIONS: usize = 1000;

        let mut rng = rand::rngs::OsRng;
        let mut seed = B32::default();

        for _i in 0..ITERATIONS {
            let seed_data: &mut [u8] = seed.as_mut();
            rng.fill(seed_data);

            let kp = P::key_gen_internal(&seed);
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let sig = sk.sign_internal(&[M], &rnd);

            let sig_enc = sig.encode();
            let sig_dec = Signature::<P>::decode(&sig_enc)
                .unwrap_or_else(|| panic!("signature failed to decode (seed = {:?})", seed));

            assert_eq!(
                sig_dec, sig,
                "decoded signature differs from original (seed = {:?})",
                seed
            );
            assert!(
                vk.verify_internal(&[M], &sig_dec),
                "signature failed to verify (seed = {:?})",
                seed
            );
        }
    }

    #[test]
    fn many_round_trip() {
        many_round_trip_test::<MlDsa44>();
        many_round_trip_test::<MlDsa65>();
        many_round_trip_test::<MlDsa87>();
    }
}