// ml_dsa/lib.rs

1#![no_std]
2#![doc = include_str!("../README.md")]
3#![doc(
4    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
5    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
6)]
7#![cfg_attr(docsrs, feature(doc_auto_cfg))]
8#![warn(clippy::pedantic)] // Be pedantic by default
9#![warn(clippy::integer_division_remainder_used)] // Be judicious about using `/` and `%`
10#![warn(clippy::as_conversions)] // Use proper conversions, not `as`
11#![allow(non_snake_case)] // Allow notation matching the spec
12#![allow(clippy::similar_names)] // Allow notation matching the spec
13#![allow(clippy::many_single_char_names)] // Allow notation matching the spec
14#![allow(clippy::clone_on_copy)] // Be explicit about moving data
15#![deny(missing_docs)] // Require all public interfaces to be documented
16
//! # Quickstart
//!
//! ```
//! use ml_dsa::{MlDsa65, KeyGen, signature::{Keypair, Signer, Verifier}};
//!
//! let mut rng = rand::thread_rng();
//! let kp = MlDsa65::key_gen(&mut rng);
//!
//! let msg = b"Hello world";
//! let sig = kp.signing_key().sign(msg);
//!
//! assert!(kp.verifying_key().verify(msg, &sig).is_ok());
//! ```
30
31mod algebra;
32mod crypto;
33mod encode;
34mod hint;
35mod ntt;
36mod param;
37mod sampling;
38mod util;
39
40// TODO(RLB) Move module to an independent crate shared with ml_kem
41mod module_lattice;
42
43use core::convert::{AsRef, TryFrom, TryInto};
44use hybrid_array::{
45    Array,
46    typenum::{
47        Diff, Length, Prod, Quot, Shleft, U1, U2, U4, U5, U6, U7, U8, U17, U19, U32, U48, U55, U64,
48        U75, U80, U88, Unsigned,
49    },
50};
51
52#[cfg(feature = "rand_core")]
53use rand_core::{CryptoRng, CryptoRngCore, RngCore};
54
55#[cfg(feature = "zeroize")]
56use zeroize::{Zeroize, ZeroizeOnDrop};
57
58#[cfg(feature = "pkcs8")]
59use pkcs8::{
60    AlgorithmIdentifierRef, ObjectIdentifier, PrivateKeyInfo,
61    der::{self, AnyRef},
62    spki::{
63        self, AlgorithmIdentifier, AssociatedAlgorithmIdentifier, SignatureAlgorithmIdentifier,
64        SubjectPublicKeyInfoRef,
65    },
66};
67
68#[cfg(all(feature = "alloc", feature = "pkcs8"))]
69use pkcs8::{
70    EncodePrivateKey, EncodePublicKey,
71    der::asn1::{BitString, BitStringRef},
72    spki::{SignatureBitStringEncoding, SubjectPublicKeyInfo},
73};
74
75use crate::algebra::{AlgebraExt, Elem, NttMatrix, NttVector, Truncate, Vector};
76use crate::crypto::H;
77use crate::hint::Hint;
78use crate::ntt::{Ntt, NttInverse};
79use crate::param::{ParameterSet, QMinus1, SamplingSize, SpecQ};
80use crate::sampling::{expand_a, expand_mask, expand_s, sample_in_ball};
81use crate::util::B64;
82use core::fmt;
83
84pub use crate::param::{EncodedSignature, EncodedSigningKey, EncodedVerifyingKey, MlDsaParams};
85pub use crate::util::B32;
86pub use signature::{self, Error};
87
88/// An ML-DSA signature
89#[derive(Clone, PartialEq, Debug)]
90pub struct Signature<P: MlDsaParams> {
91    c_tilde: Array<u8, P::Lambda>,
92    z: Vector<P::L>,
93    h: Hint<P>,
94}
95
96impl<P: MlDsaParams> Signature<P> {
97    /// Encode the signature in a fixed-size byte array.
98    // Algorithm 26 sigEncode
99    pub fn encode(&self) -> EncodedSignature<P> {
100        let c_tilde = self.c_tilde.clone();
101        let z = P::encode_z(&self.z);
102        let h = self.h.bit_pack();
103        P::concat_sig(c_tilde, z, h)
104    }
105
106    /// Decode the signature from an appropriately sized byte array.
107    // Algorithm 27 sigDecode
108    pub fn decode(enc: &EncodedSignature<P>) -> Option<Self> {
109        let (c_tilde, z, h) = P::split_sig(enc);
110
111        let c_tilde = c_tilde.clone();
112        let z = P::decode_z(z);
113        let h = Hint::bit_unpack(h)?;
114
115        if z.infinity_norm() >= P::GAMMA1_MINUS_BETA {
116            return None;
117        }
118
119        Some(Self { c_tilde, z, h })
120    }
121}
122
123impl<'a, P: MlDsaParams> TryFrom<&'a [u8]> for Signature<P> {
124    type Error = Error;
125
126    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
127        let enc = EncodedSignature::<P>::try_from(value).map_err(|_| Error::new())?;
128        Self::decode(&enc).ok_or(Error::new())
129    }
130}
131
132impl<P: MlDsaParams> TryInto<EncodedSignature<P>> for Signature<P> {
133    type Error = Error;
134
135    fn try_into(self) -> Result<EncodedSignature<P>, Self::Error> {
136        Ok(self.encode())
137    }
138}
139
140impl<P: MlDsaParams> signature::SignatureEncoding for Signature<P> {
141    type Repr = EncodedSignature<P>;
142}
143
/// Encodes the signature as a DER `BIT STRING` with no unused bits, containing the bytes
/// produced by [`Signature::encode`].
///
/// `SignatureBitStringEncoding` and `BitString` are only imported when both the `alloc` and
/// `pkcs8` features are enabled (and `der` only with `pkcs8`), so this impl must be gated on
/// both. Gating on `alloc` alone breaks the build for `alloc` without `pkcs8`.
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P: MlDsaParams> SignatureBitStringEncoding for Signature<P> {
    fn to_bitstring(&self) -> der::Result<BitString> {
        BitString::new(0, self.encode().to_vec())
    }
}
150
/// A signature carries the algorithm identifier of its parameter set.
#[cfg(feature = "pkcs8")]
impl<P> AssociatedAlgorithmIdentifier for Signature<P>
where
    P: MlDsaParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> =
        <P as AssociatedAlgorithmIdentifier>::ALGORITHM_IDENTIFIER;
}
161
162// This method takes a slice of slices so that we can accommodate the varying calculations (direct
163// for test vectors, 0... for sign/sign_deterministic, 1... for the pre-hashed version) without
164// having to allocate memory for components.
165fn message_representative(tr: &[u8], Mp: &[&[u8]]) -> B64 {
166    let mut h = H::default().absorb(tr);
167
168    for m in Mp {
169        h = h.absorb(m);
170    }
171
172    h.squeeze_new()
173}
174
175/// An ML-DSA key pair
176pub struct KeyPair<P: MlDsaParams> {
177    /// The signing key of the key pair
178    signing_key: SigningKey<P>,
179
180    /// The verifying key of the key pair
181    verifying_key: VerifyingKey<P>,
182
183    /// The seed this signing key was derived from
184    seed: B32,
185}
186
187impl<P: MlDsaParams> KeyPair<P> {
188    /// The signing key of the key pair
189    pub fn signing_key(&self) -> &SigningKey<P> {
190        &self.signing_key
191    }
192
193    /// The verifying key of the key pair
194    pub fn verifying_key(&self) -> &VerifyingKey<P> {
195        &self.verifying_key
196    }
197}
198
199impl<P: MlDsaParams> AsRef<VerifyingKey<P>> for KeyPair<P> {
200    fn as_ref(&self) -> &VerifyingKey<P> {
201        &self.verifying_key
202    }
203}
204
205impl<P: MlDsaParams> fmt::Debug for KeyPair<P> {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        f.debug_struct("KeyPair")
208            .field("verifying_key", &self.verifying_key)
209            .finish_non_exhaustive()
210    }
211}
212
213impl<P: MlDsaParams> signature::KeypairRef for KeyPair<P> {
214    type VerifyingKey = VerifyingKey<P>;
215}
216
/// Rebuilds a key pair from a PKCS#8 `PrivateKeyInfo` whose private key field holds the seed.
///
/// # Errors
///
/// - Returns `OidUnknown` if the algorithm identifier does not equal `P::ALGORITHM_IDENTIFIER`.
/// - Returns `KeyMalformed` if the private key bytes cannot be converted into a seed.
#[cfg(feature = "pkcs8")]
impl<P> TryFrom<PrivateKeyInfo<'_>> for KeyPair<P>
where
    P: MlDsaParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Error = pkcs8::Error;

    fn try_from(private_key_info: pkcs8::PrivateKeyInfo<'_>) -> pkcs8::Result<Self> {
        let algorithm = private_key_info.algorithm;
        if algorithm != P::ALGORITHM_IDENTIFIER {
            return Err(spki::Error::OidUnknown { oid: algorithm.oid }.into());
        }

        let seed = B32::try_from(private_key_info.private_key)
            .map_err(|_| pkcs8::Error::KeyMalformed)?;
        Ok(P::key_gen_internal(&seed))
    }
}
236
237/// The `Signer` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA, and
238/// only supports signing with an empty context string.
239impl<P: MlDsaParams> signature::Signer<Signature<P>> for KeyPair<P> {
240    fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
241        self.signing_key.sign_deterministic(msg, &[])
242    }
243}
244
/// A key pair advertises the algorithm identifier of the signatures it produces.
#[cfg(feature = "pkcs8")]
impl<P> SignatureAlgorithmIdentifier for KeyPair<P>
where
    P: MlDsaParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
        <Signature<P> as AssociatedAlgorithmIdentifier>::ALGORITHM_IDENTIFIER;
}
256
/// Serializes the key pair as PKCS#8 DER. Only the seed is stored as the private key; the
/// full key is re-derived from it when decoding.
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P> EncodePrivateKey for KeyPair<P>
where
    P: MlDsaParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    fn to_pkcs8_der(&self) -> pkcs8::Result<der::SecretDocument> {
        let info = pkcs8::PrivateKeyInfo::new(P::ALGORITHM_IDENTIFIER, &self.seed);
        der::SecretDocument::encode_msg(&info).map_err(pkcs8::Error::from)
    }
}
268
269/// An ML-DSA signing key
270#[derive(Clone, PartialEq)]
271pub struct SigningKey<P: MlDsaParams> {
272    rho: B32,
273    K: B32,
274    tr: B64,
275    s1: Vector<P::L>,
276    s2: Vector<P::K>,
277    t0: Vector<P::K>,
278
279    // Derived values
280    s1_hat: NttVector<P::L>,
281    s2_hat: NttVector<P::K>,
282    t0_hat: NttVector<P::K>,
283    A_hat: NttMatrix<P::K, P::L>,
284}
285
286impl<P: MlDsaParams> fmt::Debug for SigningKey<P> {
287    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288        f.debug_struct("SigningKey").finish_non_exhaustive()
289    }
290}
291
/// Wipes secret key material when the signing key is dropped.
///
/// The secret vectors are cleared along with their cached NTT-domain copies (`s1_hat`,
/// `s2_hat`, `t0_hat`). The NTT is invertible, so leaving those copies in memory would leak
/// the secret key just as surely as leaving `s1`/`s2`/`t0`. `A_hat` is expanded from the
/// public seed `rho` and is not wiped.
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> Drop for SigningKey<P> {
    fn drop(&mut self) {
        self.rho.zeroize();
        self.K.zeroize();
        self.tr.zeroize();
        self.s1.zeroize();
        self.s2.zeroize();
        self.t0.zeroize();
        // NOTE(review): relies on `NttVector` implementing `Zeroize`, as `Vector` does.
        self.s1_hat.zeroize();
        self.s2_hat.zeroize();
        self.t0_hat.zeroize();
    }
}

/// Marker asserting that [`SigningKey`] clears its secrets when dropped.
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> ZeroizeOnDrop for SigningKey<P> {}
306
307impl<P: MlDsaParams> SigningKey<P> {
308    fn new(
309        rho: B32,
310        K: B32,
311        tr: B64,
312        s1: Vector<P::L>,
313        s2: Vector<P::K>,
314        t0: Vector<P::K>,
315        A_hat: Option<NttMatrix<P::K, P::L>>,
316    ) -> Self {
317        let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
318        let s1_hat = s1.ntt();
319        let s2_hat = s2.ntt();
320        let t0_hat = t0.ntt();
321
322        Self {
323            rho,
324            K,
325            tr,
326            s1,
327            s2,
328            t0,
329
330            s1_hat,
331            s2_hat,
332            t0_hat,
333            A_hat,
334        }
335    }
336
337    /// This method reflects the ML-DSA.Sign_internal algorithm from FIPS 204. It does not
338    /// include the domain separator that distinguishes between the normal and pre-hashed cases,
339    /// and it does not separate the context string from the rest of the message.
340    // Algorithm 7 ML-DSA.Sign_internal
341    // TODO(RLB) Only expose based on a feature.  Tests need access, but normal code shouldn't.
342    pub fn sign_internal(&self, Mp: &[&[u8]], rnd: &B32) -> Signature<P>
343    where
344        P: MlDsaParams,
345    {
346        // Compute the message representative
347        // XXX(RLB): This line incorporates some of the logic from ML-DSA.sign to avoid computing
348        // the concatenated M'.
349        // XXX(RLB) Should the API represent this as an input?
350        let mu = message_representative(&self.tr, Mp);
351
352        // Compute the private random seed
353        let rhopp: B64 = H::default()
354            .absorb(&self.K)
355            .absorb(rnd)
356            .absorb(&mu)
357            .squeeze_new();
358
359        // Rejection sampling loop
360        for kappa in (0..u16::MAX).step_by(P::L::USIZE) {
361            let y = expand_mask::<P::L, P::Gamma1>(&rhopp, kappa);
362            let w = (&self.A_hat * &y.ntt()).ntt_inverse();
363            let w1 = w.high_bits::<P::TwoGamma2>();
364
365            let w1_tilde = P::encode_w1(&w1);
366            let c_tilde = H::default()
367                .absorb(&mu)
368                .absorb(&w1_tilde)
369                .squeeze_new::<P::Lambda>();
370            let c = sample_in_ball(&c_tilde, P::TAU);
371            let c_hat = c.ntt();
372
373            let cs1 = (&c_hat * &self.s1_hat).ntt_inverse();
374            let cs2 = (&c_hat * &self.s2_hat).ntt_inverse();
375
376            let z = &y + &cs1;
377            let r0 = (&w - &cs2).low_bits::<P::TwoGamma2>();
378
379            if z.infinity_norm() >= P::GAMMA1_MINUS_BETA
380                || r0.infinity_norm() >= P::GAMMA2_MINUS_BETA
381            {
382                continue;
383            }
384
385            let ct0 = (&c_hat * &self.t0_hat).ntt_inverse();
386            let minus_ct0 = -&ct0;
387            let w_cs2_ct0 = &(&w - &cs2) + &ct0;
388            let h = Hint::<P>::new(&minus_ct0, &w_cs2_ct0);
389
390            if ct0.infinity_norm() >= P::Gamma2::U32 || h.hamming_weight() > P::Omega::USIZE {
391                continue;
392            }
393
394            let z = z.mod_plus_minus::<SpecQ>();
395            return Signature { c_tilde, z, h };
396        }
397
398        unreachable!("Rejection sampling failed to find a valid signature");
399    }
400
401    /// This method reflects the randomized ML-DSA.Sign algorithm.
402    ///
403    /// # Errors
404    ///
405    /// This method will return an opaque error if the context string is more than 255 bytes long,
406    /// or if it fails to get enough randomness.
407    // Algorithm 2 ML-DSA.Sign
408    #[cfg(feature = "rand_core")]
409    pub fn sign_randomized<R: RngCore + CryptoRng + ?Sized>(
410        &self,
411        M: &[u8],
412        ctx: &[u8],
413        rng: &mut R,
414    ) -> Result<Signature<P>, Error> {
415        if ctx.len() > 255 {
416            return Err(Error::new());
417        }
418
419        let mut rnd = B32::default();
420        rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
421
422        let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
423        Ok(self.sign_internal(Mp, &rnd))
424    }
425
426    /// This method reflects the optional deterministic variant of the ML-DSA.Sign algorithm.
427    ///
428    /// # Errors
429    ///
430    /// This method will return an opaque error if the context string is more than 255 bytes long.
431    // Algorithm 2 ML-DSA.Sign (optional deterministic variant)
432    pub fn sign_deterministic(&self, M: &[u8], ctx: &[u8]) -> Result<Signature<P>, Error> {
433        if ctx.len() > 255 {
434            return Err(Error::new());
435        }
436
437        let rnd = B32::default();
438        let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
439        Ok(self.sign_internal(Mp, &rnd))
440    }
441
442    /// Encode the key in a fixed-size byte array.
443    // Algorithm 24 skEncode
444    pub fn encode(&self) -> EncodedSigningKey<P>
445    where
446        P: MlDsaParams,
447    {
448        let s1_enc = P::encode_s1(&self.s1);
449        let s2_enc = P::encode_s2(&self.s2);
450        let t0_enc = P::encode_t0(&self.t0);
451        P::concat_sk(
452            self.rho.clone(),
453            self.K.clone(),
454            self.tr.clone(),
455            s1_enc,
456            s2_enc,
457            t0_enc,
458        )
459    }
460
461    /// Decode the key from an appropriately sized byte array.
462    // Algorithm 25 skDecode
463    pub fn decode(enc: &EncodedSigningKey<P>) -> Self
464    where
465        P: MlDsaParams,
466    {
467        let (rho, K, tr, s1_enc, s2_enc, t0_enc) = P::split_sk(enc);
468        Self::new(
469            rho.clone(),
470            K.clone(),
471            tr.clone(),
472            P::decode_s1(s1_enc),
473            P::decode_s2(s2_enc),
474            P::decode_t0(t0_enc),
475            None,
476        )
477    }
478}
479
480/// The `Signer` implementation for `SigningKey` uses the optional deterministic variant of ML-DSA, and
481/// only supports signing with an empty context string.  If you would like to include a context
482/// string, use the [`SigningKey::sign_deterministic`] method.
483impl<P: MlDsaParams> signature::Signer<Signature<P>> for SigningKey<P> {
484    fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
485        self.sign_deterministic(msg, &[])
486    }
487}
488
/// The `RandomizedSigner` implementation for `SigningKey` uses the randomized ML-DSA.Sign
/// algorithm and only supports signing with an empty context string. If you would like to
/// include a context string, use the [`SigningKey::sign_randomized`] method.
///
/// Signing fails with an opaque [`Error`] if the RNG cannot supply randomness.
#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> signature::RandomizedSigner<Signature<P>> for SigningKey<P> {
    fn try_sign_with_rng(
        &self,
        rng: &mut impl CryptoRngCore,
        msg: &[u8],
    ) -> Result<Signature<P>, Error> {
        self.sign_randomized(msg, &[], rng)
    }
}
502
/// A signing key advertises the algorithm identifier of the signatures it produces.
#[cfg(feature = "pkcs8")]
impl<P> SignatureAlgorithmIdentifier for SigningKey<P>
where
    P: MlDsaParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
        <Signature<P> as AssociatedAlgorithmIdentifier>::ALGORITHM_IDENTIFIER;
}
514
/// Loads a signing key from PKCS#8. The full key pair is regenerated from the stored seed,
/// and only its signing half is kept.
#[cfg(feature = "pkcs8")]
impl<P> TryFrom<PrivateKeyInfo<'_>> for SigningKey<P>
where
    P: MlDsaParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Error = pkcs8::Error;

    fn try_from(private_key_info: pkcs8::PrivateKeyInfo<'_>) -> pkcs8::Result<Self> {
        KeyPair::<P>::try_from(private_key_info).map(|keypair| keypair.signing_key)
    }
}
529
530/// An ML-DSA verification key
531#[derive(Clone, Debug, PartialEq)]
532pub struct VerifyingKey<P: ParameterSet> {
533    rho: B32,
534    t1: Vector<P::K>,
535
536    // Derived values
537    A_hat: NttMatrix<P::K, P::L>,
538    t1_2d_hat: NttVector<P::K>,
539    tr: B64,
540}
541
542impl<P: MlDsaParams> VerifyingKey<P> {
543    fn new(
544        rho: B32,
545        t1: Vector<P::K>,
546        A_hat: Option<NttMatrix<P::K, P::L>>,
547        enc: Option<EncodedVerifyingKey<P>>,
548    ) -> Self {
549        let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
550        let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1));
551
552        let t1_2d_hat = (Elem::new(1 << 13) * &t1).ntt();
553        let tr: B64 = H::default().absorb(&enc).squeeze_new();
554
555        Self {
556            rho,
557            t1,
558            A_hat,
559            t1_2d_hat,
560            tr,
561        }
562    }
563
564    /// This algorithm reflects the ML-DSA.Verify_internal algorithm from FIPS 204.  It does not
565    /// include the domain separator that distinguishes between the normal and pre-hashed cases,
566    /// and it does not separate the context string from the rest of the message.
567    // Algorithm 8 ML-DSA.Verify_internal
568    pub fn verify_internal(&self, Mp: &[&[u8]], sigma: &Signature<P>) -> bool
569    where
570        P: MlDsaParams,
571    {
572        // Compute the message representative
573        let mu = message_representative(&self.tr, Mp);
574
575        // Reconstruct w
576        let c = sample_in_ball(&sigma.c_tilde, P::TAU);
577
578        let z_hat = sigma.z.ntt();
579        let c_hat = c.ntt();
580        let Az_hat = &self.A_hat * &z_hat;
581        let ct1_2d_hat = &c_hat * &self.t1_2d_hat;
582
583        let wp_approx = (&Az_hat - &ct1_2d_hat).ntt_inverse();
584        let w1p = sigma.h.use_hint(&wp_approx);
585
586        let w1p_tilde = P::encode_w1(&w1p);
587        let cp_tilde = H::default()
588            .absorb(&mu)
589            .absorb(&w1p_tilde)
590            .squeeze_new::<P::Lambda>();
591
592        sigma.c_tilde == cp_tilde
593    }
594
595    /// This algorithm reflect the ML-DSA.Verify algorithm from FIPS 204.
596    // Algorithm 3 ML-DSA.Verify
597    pub fn verify_with_context(&self, M: &[u8], ctx: &[u8], sigma: &Signature<P>) -> bool {
598        if ctx.len() > 255 {
599            return false;
600        }
601
602        let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
603        self.verify_internal(Mp, sigma)
604    }
605
606    fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
607        let t1_enc = P::encode_t1(t1);
608        P::concat_vk(rho.clone(), t1_enc)
609    }
610
611    /// Encode the key in a fixed-size byte array.
612    // Algorithm 22 pkEncode
613    pub fn encode(&self) -> EncodedVerifyingKey<P> {
614        Self::encode_internal(&self.rho, &self.t1)
615    }
616
617    /// Decode the key from an appropriately sized byte array.
618    // Algorithm 23 pkDecode
619    pub fn decode(enc: &EncodedVerifyingKey<P>) -> Self {
620        let (rho, t1_enc) = P::split_vk(enc);
621        let t1 = P::decode_t1(t1_enc);
622        Self::new(rho.clone(), t1, None, Some(enc.clone()))
623    }
624}
625
626impl<P: MlDsaParams> signature::Verifier<Signature<P>> for VerifyingKey<P> {
627    fn verify(&self, msg: &[u8], signature: &Signature<P>) -> Result<(), Error> {
628        self.verify_with_context(msg, &[], signature)
629            .then_some(())
630            .ok_or(Error::new())
631    }
632}
633
/// A verifying key advertises the algorithm identifier of the signatures it checks.
#[cfg(feature = "pkcs8")]
impl<P> SignatureAlgorithmIdentifier for VerifyingKey<P>
where
    P: MlDsaParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
        <Signature<P> as AssociatedAlgorithmIdentifier>::ALGORITHM_IDENTIFIER;
}
645
/// Serializes the verifying key as a DER `SubjectPublicKeyInfo`. The subject public key bit
/// string holds the bytes produced by [`VerifyingKey::encode`].
///
/// `EncodePublicKey`, `BitStringRef` and `SubjectPublicKeyInfo` are only imported when both
/// the `alloc` and `pkcs8` features are enabled (and `spki`/`der` only with `pkcs8`), so the
/// impl must be gated on both. Gating on `alloc` alone breaks the build for `alloc` without
/// `pkcs8`.
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P> EncodePublicKey for VerifyingKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    fn to_public_key_der(&self) -> spki::Result<der::Document> {
        let public_key = self.encode();
        let subject_public_key = BitStringRef::new(0, &public_key)?;

        SubjectPublicKeyInfo {
            algorithm: P::ALGORITHM_IDENTIFIER,
            subject_public_key,
        }
        .try_into()
    }
}
663
/// Parses a verifying key from a `SubjectPublicKeyInfo`.
///
/// # Errors
///
/// - Returns `OidUnknown` if the algorithm identifier does not match `P`.
/// - Returns a DER error if the bit string has unused bits.
/// - Returns a malformed-key error if the key has the wrong length.
#[cfg(feature = "pkcs8")]
impl<P> TryFrom<SubjectPublicKeyInfoRef<'_>> for VerifyingKey<P>
where
    P: MlDsaParams + AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Error = spki::Error;

    fn try_from(spki: SubjectPublicKeyInfoRef<'_>) -> spki::Result<Self> {
        if spki.algorithm != P::ALGORITHM_IDENTIFIER {
            return Err(spki::Error::OidUnknown {
                oid: spki.algorithm.oid,
            });
        }

        let key_bytes = spki
            .subject_public_key
            .as_bytes()
            .ok_or_else(|| der::Tag::BitString.value_error())?;
        let encoded = EncodedVerifyingKey::<P>::try_from(key_bytes)
            .map_err(|_| pkcs8::Error::KeyMalformed)?;

        Ok(Self::decode(&encoded))
    }
}
688
/// `MlDsa44` is the parameter set for security category 2.
///
/// It is a zero-sized marker type, selected as the type parameter of the key and signature
/// types (e.g. `SigningKey<MlDsa44>`).
#[derive(Clone, Debug, Default, PartialEq)]
pub struct MlDsa44;
692
693impl ParameterSet for MlDsa44 {
694    type K = U4;
695    type L = U4;
696    type Eta = U2;
697    type Gamma1 = Shleft<U1, U17>;
698    type Gamma2 = Quot<QMinus1, U88>;
699    type TwoGamma2 = Prod<U2, Self::Gamma2>;
700    type W1Bits = Length<Diff<Quot<U88, U2>, U1>>;
701    type Lambda = U32;
702    type Omega = U80;
703    const TAU: usize = 39;
704}
705
/// Algorithm identifier for ML-DSA-44 (OID `2.16.840.1.101.3.4.3.17`), with parameters absent.
#[cfg(feature = "pkcs8")]
impl AssociatedAlgorithmIdentifier for MlDsa44 {
    type Params = AnyRef<'static>;

    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
        oid: ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.17"),
        parameters: None,
    };
}
715
/// `MlDsa65` is the parameter set for security category 3.
///
/// It is a zero-sized marker type, selected as the type parameter of the key and signature
/// types (e.g. `SigningKey<MlDsa65>`).
#[derive(Clone, Debug, Default, PartialEq)]
pub struct MlDsa65;
719
720impl ParameterSet for MlDsa65 {
721    type K = U6;
722    type L = U5;
723    type Eta = U4;
724    type Gamma1 = Shleft<U1, U19>;
725    type Gamma2 = Quot<QMinus1, U32>;
726    type TwoGamma2 = Prod<U2, Self::Gamma2>;
727    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
728    type Lambda = U48;
729    type Omega = U55;
730    const TAU: usize = 49;
731}
732
/// Algorithm identifier for ML-DSA-65 (OID `2.16.840.1.101.3.4.3.18`), with parameters absent.
#[cfg(feature = "pkcs8")]
impl AssociatedAlgorithmIdentifier for MlDsa65 {
    type Params = AnyRef<'static>;

    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
        oid: ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.18"),
        parameters: None,
    };
}
742
/// `MlDsa87` is the parameter set for security category 5.
///
/// It is a zero-sized marker type, selected as the type parameter of the key and signature
/// types (e.g. `SigningKey<MlDsa87>`).
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MlDsa87;
746
747impl ParameterSet for MlDsa87 {
748    type K = U8;
749    type L = U7;
750    type Eta = U2;
751    type Gamma1 = Shleft<U1, U19>;
752    type Gamma2 = Quot<QMinus1, U32>;
753    type TwoGamma2 = Prod<U2, Self::Gamma2>;
754    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
755    type Lambda = U64;
756    type Omega = U75;
757    const TAU: usize = 60;
758}
759
/// Algorithm identifier for ML-DSA-87 (OID `2.16.840.1.101.3.4.3.19`), with parameters absent.
#[cfg(feature = "pkcs8")]
impl AssociatedAlgorithmIdentifier for MlDsa87 {
    type Params = AnyRef<'static>;

    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
        oid: ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.19"),
        parameters: None,
    };
}
769
770/// A parameter set that knows how to generate key pairs
771pub trait KeyGen: MlDsaParams {
772    /// The type that is returned by key generation
773    type KeyPair: signature::Keypair;
774
775    /// Generate a signing key pair from the specified RNG
776    #[cfg(feature = "rand_core")]
777    fn key_gen<R: RngCore + CryptoRng + ?Sized>(rng: &mut R) -> Self::KeyPair;
778
779    /// Deterministically generate a signing key pair from the specified seed
780    // TODO(RLB): Only expose this based on a feature.
781    fn key_gen_internal(xi: &B32) -> Self::KeyPair;
782}
783
784impl<P> KeyGen for P
785where
786    P: MlDsaParams,
787{
788    type KeyPair = KeyPair<P>;
789
790    /// Generate a signing key pair from the specified RNG
791    // Algorithm 1 ML-DSA.KeyGen()
792    #[cfg(feature = "rand_core")]
793    fn key_gen<R: RngCore + CryptoRng + ?Sized>(rng: &mut R) -> KeyPair<P> {
794        let mut xi = B32::default();
795        rng.fill_bytes(&mut xi);
796        Self::key_gen_internal(&xi)
797    }
798
799    /// Deterministically generate a signing key pair from the specified seed
800    // Algorithm 6 ML-DSA.KeyGen_internal
801    fn key_gen_internal(xi: &B32) -> KeyPair<P>
802    where
803        P: MlDsaParams,
804    {
805        // Derive seeds
806        let mut h = H::default()
807            .absorb(xi)
808            .absorb(&[P::K::U8])
809            .absorb(&[P::L::U8]);
810
811        let rho: B32 = h.squeeze_new();
812        let rhop: B64 = h.squeeze_new();
813        let K: B32 = h.squeeze_new();
814
815        // Sample private key components
816        let A_hat = expand_a::<P::K, P::L>(&rho);
817        let s1 = expand_s::<P::L>(&rhop, P::Eta::ETA, 0);
818        let s2 = expand_s::<P::K>(&rhop, P::Eta::ETA, P::L::USIZE);
819
820        // Compute derived values
821        let As1_hat = &A_hat * &s1.ntt();
822        let t = &As1_hat.ntt_inverse() + &s2;
823
824        // Compress and encode
825        let (t1, t0) = t.power2round();
826
827        let verifying_key = VerifyingKey::new(rho, t1, Some(A_hat.clone()), None);
828        let signing_key =
829            SigningKey::new(rho, K, verifying_key.tr.clone(), s1, s2, t0, Some(A_hat));
830
831        KeyPair {
832            signing_key,
833            verifying_key,
834            seed: xi.clone(),
835        }
836    }
837}
838
#[cfg(test)]
mod test {
    use super::*;
    use crate::param::*;

    #[test]
    fn output_sizes() {
        // Expected encoded sizes in bytes:
        //           priv pub  sig
        // ML-DSA-44 2560 1312 2420
        // ML-DSA-65 4032 1952 3309
        // ML-DSA-87 4896 2592 4627
        assert_eq!(SigningKeySize::<MlDsa44>::USIZE, 2560);
        assert_eq!(VerifyingKeySize::<MlDsa44>::USIZE, 1312);
        assert_eq!(SignatureSize::<MlDsa44>::USIZE, 2420);

        assert_eq!(SigningKeySize::<MlDsa65>::USIZE, 4032);
        assert_eq!(VerifyingKeySize::<MlDsa65>::USIZE, 1952);
        assert_eq!(SignatureSize::<MlDsa65>::USIZE, 3309);

        assert_eq!(SigningKeySize::<MlDsa87>::USIZE, 4896);
        assert_eq!(VerifyingKeySize::<MlDsa87>::USIZE, 2592);
        assert_eq!(SignatureSize::<MlDsa87>::USIZE, 4627);
    }

    /// Encode then decode a verifying key, a signing key and a signature, and check each
    /// survives the round trip unchanged.
    fn encode_decode_round_trip_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let KeyPair {
            signing_key: sk,
            verifying_key: vk,
            ..
        } = P::key_gen_internal(&Default::default());

        let vk_round = VerifyingKey::<P>::decode(&vk.encode());
        assert!(vk == vk_round);

        let sk_round = SigningKey::<P>::decode(&sk.encode());
        assert!(sk == sk_round);

        let msg = b"Hello world";
        let rnd = Array([0u8; 32]);
        let sig = sk.sign_internal(&[msg], &rnd);
        let sig_round = Signature::<P>::decode(&sig.encode()).unwrap();
        assert!(sig == sig_round);
    }

    #[test]
    fn encode_decode_round_trip() {
        encode_decode_round_trip_test::<MlDsa44>();
        encode_decode_round_trip_test::<MlDsa65>();
        encode_decode_round_trip_test::<MlDsa87>();
    }

    /// A signature from a freshly generated key must verify under the matching public key.
    fn sign_verify_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        let KeyPair {
            signing_key: sk,
            verifying_key: vk,
            ..
        } = P::key_gen_internal(&Default::default());

        let msg = b"Hello world";
        let rnd = Array([0u8; 32]);
        let sig = sk.sign_internal(&[msg], &rnd);

        assert!(vk.verify_internal(&[msg], &sig));
    }

    #[test]
    fn sign_verify_round_trip() {
        sign_verify_round_trip_test::<MlDsa44>();
        sign_verify_round_trip_test::<MlDsa65>();
        sign_verify_round_trip_test::<MlDsa87>();
    }

    /// Repeat sign / encode / decode / verify over many random seeds.
    fn many_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        use rand::Rng;

        const ITERATIONS: usize = 1000;
        const MSG: &[u8] = b"Hello world";

        let mut rng = rand::rngs::OsRng;
        let mut seed = B32::default();
        let rnd = Array([0u8; 32]);

        for _ in 0..ITERATIONS {
            let seed_data: &mut [u8] = seed.as_mut();
            rng.fill(seed_data);

            let KeyPair {
                signing_key,
                verifying_key,
                ..
            } = P::key_gen_internal(&seed);

            let sig = signing_key.sign_internal(&[MSG], &rnd);
            let sig_dec = Signature::<P>::decode(&sig.encode()).unwrap();

            assert_eq!(sig_dec, sig);
            assert!(verifying_key.verify_internal(&[MSG], &sig_dec));
        }
    }

    #[test]
    fn many_round_trip() {
        many_round_trip_test::<MlDsa44>();
        many_round_trip_test::<MlDsa65>();
        many_round_trip_test::<MlDsa87>();
    }
}