Skip to main content

ml_dsa/
lib.rs

1#![no_std]
2#![doc = include_str!("../README.md")]
3#![doc(
4    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
5    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
6)]
7#![cfg_attr(docsrs, feature(doc_cfg))]
8#![warn(clippy::pedantic)] // Be pedantic by default
9#![warn(clippy::integer_division_remainder_used)] // Be judicious about using `/` and `%`
10#![warn(clippy::as_conversions)] // Use proper conversions, not `as`
11#![allow(non_snake_case)] // Allow notation matching the spec
12#![allow(clippy::similar_names)] // Allow notation matching the spec
13#![allow(clippy::many_single_char_names)] // Allow notation matching the spec
14#![allow(clippy::clone_on_copy)] // Be explicit about moving data
15#![deny(missing_docs)] // Require all public interfaces to be documented
16#![warn(unreachable_pub)] // Prevent unexpected interface changes
17
18//! # Quickstart
19//!
20//! ```
21//! # #[cfg(feature = "rand_core")]
22//! # {
23//! use ml_dsa::{
24//!     signature::{Keypair, Signer, Verifier},
25//!     MlDsa65, KeyGen,
26//! };
27//! use getrandom::{SysRng, rand_core::UnwrapErr};
28//!
29//! let mut rng = UnwrapErr(SysRng);
30//! let kp = MlDsa65::key_gen(&mut rng);
31//!
32//! let msg = b"Hello world";
33//! let sig = kp.signing_key().sign(msg);
34//!
35//! assert!(kp.verifying_key().verify(msg, &sig).is_ok());
36//! # }
37//! ```
38
39mod algebra;
40mod crypto;
41mod encode;
42mod hint;
43mod ntt;
44mod param;
45mod pkcs8;
46mod sampling;
47
48pub use crate::param::{EncodedSignature, EncodedVerifyingKey, ExpandedSigningKey, MlDsaParams};
49pub use signature::{self, Error};
50
51use crate::algebra::{AlgebraExt, Elem, NttMatrix, NttVector, Vector};
52use crate::crypto::H;
53use crate::hint::Hint;
54use crate::ntt::{Ntt, NttInverse};
55use crate::param::{ParameterSet, QMinus1, SamplingSize, SpecQ};
56use crate::sampling::{expand_a, expand_mask, expand_s, sample_in_ball};
57use core::convert::{AsRef, TryFrom, TryInto};
58use core::fmt;
59use hybrid_array::{
60    Array,
61    typenum::{
62        Diff, Length, Prod, Quot, Shleft, U1, U2, U4, U5, U6, U7, U8, U17, U19, U32, U48, U55, U64,
63        U75, U80, U88, Unsigned,
64    },
65};
66use module_lattice::Truncate;
67use sha3::Shake256;
68use signature::{DigestSigner, DigestVerifier, MultipartSigner, MultipartVerifier, Signer};
69
70#[cfg(feature = "rand_core")]
71use {
72    rand_core::{CryptoRng, TryCryptoRng},
73    signature::{RandomizedDigestSigner, RandomizedMultipartSigner, RandomizedSigner},
74};
75
76#[cfg(feature = "zeroize")]
77use zeroize::{Zeroize, ZeroizeOnDrop};
78
/// A 32-byte array, defined here for brevity because it is used several times
pub type B32 = Array<u8, U32>;

/// A 64-byte array, defined here for brevity because it is used several times
///
/// This alias is crate-internal; public signatures that mention it expose the underlying
/// `Array<u8, U64>` type.
pub(crate) type B64 = Array<u8, U64>;

/// ML-DSA seeds are signing (private) keys, which are consistently 32-bytes across all security
/// levels, and are the preferred serialization for representing such keys.
pub type Seed = B32;
88
/// An ML-DSA signature
///
/// Holds the three decoded components that [`Signature::encode`] packs into the fixed-size
/// byte representation.
#[derive(Clone, PartialEq, Debug)]
pub struct Signature<P: MlDsaParams> {
    /// Commitment hash: `Lambda` bytes squeezed from the hash of μ and the encoded `w1`
    c_tilde: Array<u8, P::Lambda>,
    /// Response vector, computed by the signer as `y + c·s1`
    z: Vector<P::L>,
    /// Hint applied by the verifier (via `use_hint`) to its approximation of `w`
    h: Hint<P>,
}
96
97impl<P: MlDsaParams> Signature<P> {
98    /// Encode the signature in a fixed-size byte array.
99    // Algorithm 26 sigEncode
100    pub fn encode(&self) -> EncodedSignature<P> {
101        let c_tilde = self.c_tilde.clone();
102        let z = P::encode_z(&self.z);
103        let h = self.h.bit_pack();
104        P::concat_sig(c_tilde, z, h)
105    }
106
107    /// Decode the signature from an appropriately sized byte array.
108    // Algorithm 27 sigDecode
109    pub fn decode(enc: &EncodedSignature<P>) -> Option<Self> {
110        let (c_tilde, z, h) = P::split_sig(enc);
111
112        let c_tilde = c_tilde.clone();
113        let z = P::decode_z(z);
114        let h = Hint::bit_unpack(h)?;
115
116        if z.infinity_norm() >= P::GAMMA1_MINUS_BETA {
117            return None;
118        }
119
120        Some(Self { c_tilde, z, h })
121    }
122}
123
124impl<'a, P: MlDsaParams> TryFrom<&'a [u8]> for Signature<P> {
125    type Error = Error;
126
127    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
128        let enc = EncodedSignature::<P>::try_from(value).map_err(|_| Error::new())?;
129        Self::decode(&enc).ok_or(Error::new())
130    }
131}
132
/// Conversion into the encoded form. This always succeeds: the body wraps
/// [`Signature::encode`] in `Ok`.
impl<P: MlDsaParams> TryInto<EncodedSignature<P>> for Signature<P> {
    type Error = Error;

    fn try_into(self) -> Result<EncodedSignature<P>, Self::Error> {
        // Encoding has no failure path; `Err` is never returned here.
        Ok(self.encode())
    }
}
140
/// Declares [`EncodedSignature`] as the byte representation used by the `signature` crate's
/// encoding machinery.
impl<P: MlDsaParams> signature::SignatureEncoding for Signature<P> {
    type Repr = EncodedSignature<P>;
}
144
145struct MuBuilder(H);
146
147impl MuBuilder {
148    fn new(tr: &[u8], ctx: &[u8]) -> Self {
149        let mut h = H::default();
150        h = h.absorb(tr);
151        h = h.absorb(&[0]);
152        h = h.absorb(&[Truncate::truncate(ctx.len())]);
153        h = h.absorb(ctx);
154
155        Self(h)
156    }
157
158    fn internal(tr: &[u8], Mp: &[&[u8]]) -> B64 {
159        let mut h = H::default().absorb(tr);
160
161        for m in Mp {
162            h = h.absorb(m);
163        }
164
165        h.squeeze_new()
166    }
167
168    fn message(mut self, M: &[&[u8]]) -> B64 {
169        for m in M {
170            self.0 = self.0.absorb(m);
171        }
172
173        self.0.squeeze_new()
174    }
175
176    fn finish(mut self) -> B64 {
177        self.0.squeeze_new()
178    }
179}
180
/// Gives callers direct access to the underlying `Shake256` state (via `updatable()`), so
/// message bytes can be fed in by an externally supplied closure.
impl AsMut<Shake256> for MuBuilder {
    fn as_mut(&mut self) -> &mut Shake256 {
        self.0.updatable()
    }
}
186
/// An ML-DSA key pair
///
/// Bundles a signing key and its verifying key with the 32-byte seed that is stored
/// alongside them and returned by [`KeyPair::to_seed`].
pub struct KeyPair<P: MlDsaParams> {
    /// The signing key of the key pair
    signing_key: SigningKey<P>,

    /// The verifying key of the key pair
    verifying_key: VerifyingKey<P>,

    /// The seed this signing key was derived from
    seed: B32,
}
198
199impl<P: MlDsaParams> KeyPair<P> {
200    /// The signing key of the key pair
201    pub fn signing_key(&self) -> &SigningKey<P> {
202        &self.signing_key
203    }
204
205    /// The verifying key of the key pair
206    pub fn verifying_key(&self) -> &VerifyingKey<P> {
207        &self.verifying_key
208    }
209
210    /// Serialize the [`Seed`] value: 32-bytes which can be used to reconstruct the
211    /// [`KeyPair`].
212    ///
213    /// # ⚠️ Warning!
214    ///
215    /// This value is key material. Please treat it with care.
216    #[inline]
217    pub fn to_seed(&self) -> Seed {
218        self.seed
219    }
220}
221
222impl<P: MlDsaParams> AsRef<VerifyingKey<P>> for KeyPair<P> {
223    fn as_ref(&self) -> &VerifyingKey<P> {
224        &self.verifying_key
225    }
226}
227
228impl<P: MlDsaParams> fmt::Debug for KeyPair<P> {
229    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230        f.debug_struct("KeyPair")
231            .field("verifying_key", &self.verifying_key)
232            .finish_non_exhaustive()
233    }
234}
235
/// Associates [`VerifyingKey`] with [`KeyPair`] for the `signature` crate's `KeypairRef`
/// trait.
impl<P: MlDsaParams> signature::KeypairRef for KeyPair<P> {
    type VerifyingKey = VerifyingKey<P>;
}
239
240/// The `Signer` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA, and
241/// only supports signing with an empty context string.
242impl<P: MlDsaParams> Signer<Signature<P>> for KeyPair<P> {
243    fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
244        self.try_multipart_sign(&[msg])
245    }
246}
247
248/// The `Signer` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA, and
249/// only supports signing with an empty context string.
250impl<P: MlDsaParams> MultipartSigner<Signature<P>> for KeyPair<P> {
251    fn try_multipart_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
252        self.signing_key.raw_sign_deterministic(msg, &[])
253    }
254}
255
256/// The `DigestSigner` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA
257/// with a pre-computed μ, and only supports signing with an empty context string.
258impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for KeyPair<P> {
259    fn try_sign_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
260        &self,
261        f: F,
262    ) -> Result<Signature<P>, Error> {
263        self.signing_key.try_sign_digest(&f)
264    }
265}
266
/// An ML-DSA signing key
///
/// Stores the components of the expanded signing key together with values derived from them
/// once at construction time.
#[derive(Clone, PartialEq)]
pub struct SigningKey<P: MlDsaParams> {
    /// Seed from which the matrix `A_hat` is expanded
    rho: B32,
    /// Secret bytes hashed with `rnd` and μ to derive the per-signature private seed
    K: B32,
    /// Hash of the encoded verifying key; the first input absorbed when computing μ
    tr: B64,
    /// Secret vector `s1`
    s1: Vector<P::L>,
    /// Secret vector `s2`
    s2: Vector<P::K>,
    /// Low part of `t` produced by `power2round` during key generation
    t0: Vector<P::K>,

    // Derived values
    /// NTT of `s1`
    s1_hat: NttVector<P::L>,
    /// NTT of `s2`
    s2_hat: NttVector<P::K>,
    /// NTT of `t0`
    t0_hat: NttVector<P::K>,
    /// Matrix expanded from `rho` (or supplied by the caller of `new`)
    A_hat: NttMatrix<P::K, P::L>,
}
283
284impl<P: MlDsaParams> fmt::Debug for SigningKey<P> {
285    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286        f.debug_struct("SigningKey").finish_non_exhaustive()
287    }
288}
289
/// Wipes the key material when a [`SigningKey`] is dropped.
///
/// The NTT-domain copies `s1_hat`, `s2_hat` and `t0_hat` are transforms of the secret vectors
/// and therefore just as sensitive as `s1`, `s2` and `t0`; they are cleared as well. `A_hat`
/// is left alone because it is expanded from `rho` alone.
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> Drop for SigningKey<P> {
    fn drop(&mut self) {
        self.rho.zeroize();
        self.K.zeroize();
        self.tr.zeroize();
        self.s1.zeroize();
        self.s2.zeroize();
        self.t0.zeroize();

        // Derived values that carry the same secrets in NTT form
        self.s1_hat.zeroize();
        self.s2_hat.zeroize();
        self.t0_hat.zeroize();
    }
}
301
/// Marker advertising that [`SigningKey`] clears its contents in its `Drop` implementation.
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> ZeroizeOnDrop for SigningKey<P> {}
304
305impl<P: MlDsaParams> SigningKey<P> {
306    fn new(
307        rho: B32,
308        K: B32,
309        tr: B64,
310        s1: Vector<P::L>,
311        s2: Vector<P::K>,
312        t0: Vector<P::K>,
313        A_hat: Option<NttMatrix<P::K, P::L>>,
314    ) -> Self {
315        let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
316        let s1_hat = s1.ntt();
317        let s2_hat = s2.ntt();
318        let t0_hat = t0.ntt();
319
320        Self {
321            rho,
322            K,
323            tr,
324            s1,
325            s2,
326            t0,
327
328            s1_hat,
329            s2_hat,
330            t0_hat,
331            A_hat,
332        }
333    }
334
335    /// Deterministically generate a signing key from the specified seed.
336    ///
337    /// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204, but only returns a
338    /// signing key.
339    #[must_use]
340    pub fn from_seed(seed: &Seed) -> Self {
341        let kp = P::from_seed(seed);
342        kp.signing_key
343    }
344
345    /// This method reflects the ML-DSA.Sign_internal algorithm from FIPS 204. It does not
346    /// include the domain separator that distinguishes between the normal and pre-hashed cases,
347    /// and it does not separate the context string from the rest of the message.
348    // Algorithm 7 ML-DSA.Sign_internal
349    // TODO(RLB) Only expose based on a feature.  Tests need access, but normal code shouldn't.
350    pub fn sign_internal(&self, Mp: &[&[u8]], rnd: &B32) -> Signature<P>
351    where
352        P: MlDsaParams,
353    {
354        let mu = MuBuilder::internal(&self.tr, Mp);
355        self.raw_sign_mu(&mu, rnd)
356    }
357
    /// Produce a signature over the message representative `mu` using the per-signature
    /// randomness `rnd`.
    ///
    /// Candidates are generated from successive `kappa` values until one passes every check.
    ///
    /// # Panics
    ///
    /// Hits `unreachable!` if every `kappa` below `u16::MAX` is tried without an accepted
    /// candidate.
    fn raw_sign_mu(&self, mu: &B64, rnd: &B32) -> Signature<P>
    where
        P: MlDsaParams,
    {
        // Compute the private random seed
        let rhopp: B64 = H::default()
            .absorb(&self.K)
            .absorb(rnd)
            .absorb(mu)
            .squeeze_new();

        // Rejection sampling loop
        // `kappa` advances by L on every attempt
        for kappa in (0..u16::MAX).step_by(P::L::USIZE) {
            // Mask vector y and commitment w = A·y (product taken in the NTT domain)
            let y = expand_mask::<P::L, P::Gamma1>(&rhopp, kappa);
            let w = (&self.A_hat * &y.ntt()).ntt_inverse();
            let w1 = w.high_bits::<P::TwoGamma2>();

            // Commitment hash over μ and the encoded high bits, then the challenge c
            let w1_tilde = P::encode_w1(&w1);
            let c_tilde = H::default()
                .absorb(mu)
                .absorb(&w1_tilde)
                .squeeze_new::<P::Lambda>();
            let c = sample_in_ball(&c_tilde, P::TAU);
            let c_hat = c.ntt();

            let cs1 = (&c_hat * &self.s1_hat).ntt_inverse();
            let cs2 = (&c_hat * &self.s2_hat).ntt_inverse();

            let z = &y + &cs1;
            let r0 = (&w - &cs2).low_bits::<P::TwoGamma2>();

            // First rejection test: both norms must stay below their bounds
            if z.infinity_norm() >= P::GAMMA1_MINUS_BETA
                || r0.infinity_norm() >= P::GAMMA2_MINUS_BETA
            {
                continue;
            }

            let ct0 = (&c_hat * &self.t0_hat).ntt_inverse();
            let minus_ct0 = -&ct0;
            let w_cs2_ct0 = &(&w - &cs2) + &ct0;
            let h = Hint::<P>::new(&minus_ct0, &w_cs2_ct0);

            // Second rejection test: c·t0 too large, or too many hint bits set
            if ct0.infinity_norm() >= P::Gamma2::U32 || h.hamming_weight() > P::Omega::USIZE {
                continue;
            }

            let z = z.mod_plus_minus::<SpecQ>();
            return Signature { c_tilde, z, h };
        }

        unreachable!("Rejection sampling failed to find a valid signature");
    }
410
411    /// This method reflects the randomized ML-DSA.Sign algorithm.
412    ///
413    /// # Errors
414    ///
415    /// This method will return an opaque error if the context string is more than 255 bytes long,
416    /// or if it fails to get enough randomness.
417    // Algorithm 2 ML-DSA.Sign
418    #[cfg(feature = "rand_core")]
419    pub fn sign_randomized<R: TryCryptoRng + ?Sized>(
420        &self,
421        M: &[u8],
422        ctx: &[u8],
423        rng: &mut R,
424    ) -> Result<Signature<P>, Error> {
425        self.raw_sign_randomized(&[M], ctx, rng)
426    }
427
428    #[cfg(feature = "rand_core")]
429    fn raw_sign_randomized<R: TryCryptoRng + ?Sized>(
430        &self,
431        Mp: &[&[u8]],
432        ctx: &[u8],
433        rng: &mut R,
434    ) -> Result<Signature<P>, Error> {
435        if ctx.len() > 255 {
436            return Err(Error::new());
437        }
438
439        let mut rnd = B32::default();
440        rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
441
442        let mu = MuBuilder::new(&self.tr, ctx).message(Mp);
443        Ok(self.raw_sign_mu(&mu, &rnd))
444    }
445
446    /// This method reflects the randomized ML-DSA.Sign algorithm with a pre-computed μ.
447    ///
448    /// # Errors
449    ///
450    /// This method can return an opaque error if it fails to get enough randomness.
451    // Algorithm 2 ML-DSA.Sign (optional pre-computed μ variant)
452    #[cfg(feature = "rand_core")]
453    pub fn sign_mu_randomized<R: TryCryptoRng + ?Sized>(
454        &self,
455        mu: &B64,
456        rng: &mut R,
457    ) -> Result<Signature<P>, Error> {
458        let mut rnd = B32::default();
459        rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
460
461        Ok(self.raw_sign_mu(mu, &rnd))
462    }
463
464    /// This method reflects the optional deterministic variant of the ML-DSA.Sign algorithm.
465    ///
466    /// # Errors
467    ///
468    /// This method will return an opaque error if the context string is more than 255 bytes long.
469    // Algorithm 2 ML-DSA.Sign (optional deterministic variant)
470    pub fn sign_deterministic(&self, M: &[u8], ctx: &[u8]) -> Result<Signature<P>, Error> {
471        self.raw_sign_deterministic(&[M], ctx)
472    }
473
474    /// This method reflects the optional deterministic variant of the ML-DSA.Sign algorithm with a
475    /// pre-computed μ.
476    // Algorithm 2 ML-DSA.Sign (optional deterministic and pre-computed μ variant)
477    pub fn sign_mu_deterministic(&self, mu: &B64) -> Signature<P> {
478        let rnd = B32::default();
479        self.raw_sign_mu(mu, &rnd)
480    }
481
482    fn raw_sign_deterministic(&self, Mp: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
483        if ctx.len() > 255 {
484            return Err(Error::new());
485        }
486
487        let mu = MuBuilder::new(&self.tr, ctx).message(Mp);
488        Ok(self.sign_mu_deterministic(&mu))
489    }
490
491    /// This auxiliary function derives a `VerifyingKey` from a bare
492    /// `SigningKey` (even in the absence of the original seed).
493    ///
494    /// This is a utility function that is useful when importing the private key
495    /// from an external source which does not export the seed and does not
496    /// provide the precomputed public key associated with the private key
497    /// itself.
498    ///
499    /// `SigningKey` implements `signature::Keypair`: this inherent method is
500    /// retained for convenience, so it is available for callers even when the
501    /// `signature::Keypair` trait is out-of-scope.
502    pub fn verifying_key(&self) -> VerifyingKey<P> {
503        let kp: &dyn signature::Keypair<VerifyingKey = VerifyingKey<P>> = self;
504        kp.verifying_key()
505    }
506
507    /// DEPRECATED: decode the key from an appropriately sized byte array.
508    ///
509    /// Note that this form is deprecated in practice; prefer to use [`SigningKey::from_seed`].
510    ///
511    /// <div class="warning">
512    /// <b>Panics</b>
513    ///
514    /// This API does not validate expanded signing keys and can potentially panic if keys are
515    /// malformed or maliciously generated.
516    ///
517    /// To avoid panics, use [`SigningKey::from_seed`] instead.
518    /// </div>
519    // Algorithm 25 skDecode
520    #[deprecated(since = "0.1.0", note = "use `SigningKey::from_seed` instead")]
521    pub fn from_expanded(enc: &ExpandedSigningKey<P>) -> Self
522    where
523        P: MlDsaParams,
524    {
525        let (rho, K, tr, s1_enc, s2_enc, t0_enc) = P::split_sk(enc);
526        Self::new(
527            rho.clone(),
528            K.clone(),
529            tr.clone(),
530            P::decode_s1(s1_enc),
531            P::decode_s2(s2_enc),
532            P::decode_t0(t0_enc),
533            None,
534        )
535    }
536
537    /// DEPRECATED: encode the key in a fixed-size byte array.
538    ///
539    /// Note that this form is deprecated in practice; prefer to use [`KeyPair::to_seed`].
540    // Algorithm 24 skEncode
541    #[deprecated(since = "0.1.0", note = "use `KeyPair::to_seed` instead")]
542    pub fn to_expanded(&self) -> ExpandedSigningKey<P>
543    where
544        P: MlDsaParams,
545    {
546        let s1_enc = P::encode_s1(&self.s1);
547        let s2_enc = P::encode_s2(&self.s2);
548        let t0_enc = P::encode_t0(&self.t0);
549        P::concat_sk(
550            self.rho.clone(),
551            self.K.clone(),
552            self.tr.clone(),
553            s1_enc,
554            s2_enc,
555            t0_enc,
556        )
557    }
558}
559
560/// The `Signer` implementation for `SigningKey` uses the optional deterministic variant of ML-DSA, and
561/// only supports signing with an empty context string.  If you would like to include a context
562/// string, use the [`SigningKey::sign_deterministic`] method.
563impl<P: MlDsaParams> Signer<Signature<P>> for SigningKey<P> {
564    fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
565        self.try_multipart_sign(&[msg])
566    }
567}
568
569/// The `Signer` implementation for `SigningKey` uses the optional deterministic variant of ML-DSA, and
570/// only supports signing with an empty context string. If you would like to include a context
571/// string, use the [`SigningKey::sign_deterministic`] method.
572impl<P: MlDsaParams> MultipartSigner<Signature<P>> for SigningKey<P> {
573    fn try_multipart_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
574        self.raw_sign_deterministic(msg, &[])
575    }
576}
577
578/// The `Signer` implementation for `SigningKey` uses the optional deterministic variant of ML-DSA
579/// with a pre-computed µ, and only supports signing with an empty context string. If you would
580/// like to include a context string, use the [`SigningKey::sign_mu_deterministic`] method.
581impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for SigningKey<P> {
582    fn try_sign_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
583        &self,
584        f: F,
585    ) -> Result<Signature<P>, Error> {
586        let mut mu = MuBuilder::new(&self.tr, &[]);
587        f(mu.as_mut())?;
588        let mu = mu.finish();
589
590        Ok(self.sign_mu_deterministic(&mu))
591    }
592}
593
594/// The `KeyPair` implementation for `SigningKey` allows to derive a `VerifyingKey` from
595/// a bare `SigningKey` (even in the absence of the original seed).
596impl<P: MlDsaParams> signature::Keypair for SigningKey<P> {
597    type VerifyingKey = VerifyingKey<P>;
598
599    /// This is a utility function that is useful when importing the private key
600    /// from an external source which does not export the seed and does not
601    /// provide the precomputed public key associated with the private key
602    /// itself.
603    fn verifying_key(&self) -> Self::VerifyingKey {
604        let As1 = &self.A_hat * &self.s1_hat;
605        let t = &As1.ntt_inverse() + &self.s2;
606
607        /* Discard t0 */
608        let (t1, _) = t.power2round();
609
610        VerifyingKey::new(self.rho.clone(), t1, Some(self.A_hat.clone()), None)
611    }
612}
613
/// The `RandomizedSigner` implementation for `SigningKey` only supports signing with an empty
/// context string. If you would like to include a context string, use the
/// [`SigningKey::sign_randomized`] method.
#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> RandomizedSigner<Signature<P>> for SigningKey<P> {
    fn try_sign_with_rng<R: TryCryptoRng + ?Sized>(
        &self,
        rng: &mut R,
        msg: &[u8],
    ) -> Result<Signature<P>, Error> {
        // A single message is a one-part multipart message
        let parts = [msg];
        self.try_multipart_sign_with_rng(rng, &parts)
    }
}
627
/// The `RandomizedMultipartSigner` implementation for `SigningKey` only supports signing with an
/// empty context string. If you would like to include a context string, use the
/// [`SigningKey::sign_randomized`] method.
#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> RandomizedMultipartSigner<Signature<P>> for SigningKey<P> {
    fn try_multipart_sign_with_rng<R: TryCryptoRng + ?Sized>(
        &self,
        rng: &mut R,
        msg: &[&[u8]],
    ) -> Result<Signature<P>, Error> {
        let empty_ctx: &[u8] = &[];
        self.raw_sign_randomized(msg, empty_ctx, rng)
    }
}
641
/// The `RandomizedDigestSigner` implementation for `SigningKey` only supports signing with an
/// empty context string. If you already hold a µ computed with a different context string, use
/// the [`SigningKey::sign_mu_randomized`] method.
#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> RandomizedDigestSigner<Shake256, Signature<P>> for SigningKey<P> {
    fn try_sign_digest_with_rng<
        R: TryCryptoRng + ?Sized,
        F: Fn(&mut Shake256) -> Result<(), Error>,
    >(
        &self,
        rng: &mut R,
        f: F,
    ) -> Result<Signature<P>, Error> {
        // Prime the hasher with `tr` and an empty context, then let the caller feed the message
        let mut builder = MuBuilder::new(&self.tr, &[]);
        f(builder.as_mut())?;

        self.sign_mu_randomized(&builder.finish(), rng)
    }
}
662
/// An ML-DSA verification key
///
/// Stores the two encoded components (`rho`, `t1`) together with values derived from them
/// once at construction time.
#[derive(Clone, Debug, PartialEq)]
pub struct VerifyingKey<P: ParameterSet> {
    /// Seed from which the matrix `A_hat` is expanded
    rho: B32,
    /// High part of `t` produced by `power2round`
    t1: Vector<P::K>,

    // Derived values
    /// Matrix expanded from `rho` (or supplied by the caller of `new`)
    A_hat: NttMatrix<P::K, P::L>,
    /// NTT of `t1` scaled by `1 << 13`
    t1_2d_hat: NttVector<P::K>,
    /// 64-byte hash of the encoded verifying key
    tr: B64,
}
674
675impl<P: MlDsaParams> VerifyingKey<P> {
676    fn new(
677        rho: B32,
678        t1: Vector<P::K>,
679        A_hat: Option<NttMatrix<P::K, P::L>>,
680        enc: Option<EncodedVerifyingKey<P>>,
681    ) -> Self {
682        let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
683        let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1));
684
685        let t1_2d_hat = (Elem::new(1 << 13) * &t1).ntt();
686        let tr: B64 = H::default().absorb(&enc).squeeze_new();
687
688        Self {
689            rho,
690            t1,
691            A_hat,
692            t1_2d_hat,
693            tr,
694        }
695    }
696
697    /// Computes µ according to FIPS 204 for use in ML-DSA.Sign and ML-DSA.Verify.
698    ///
699    /// # Errors
700    ///
701    /// Returns [`Error`] if the given `Mp` returns one.
702    pub fn compute_mu<F: FnOnce(&mut Shake256) -> Result<(), Error>>(
703        &self,
704        Mp: F,
705        ctx: &[u8],
706    ) -> Result<B64, Error> {
707        let mut mu = MuBuilder::new(&self.tr, ctx);
708        Mp(mu.as_mut())?;
709        Ok(mu.finish())
710    }
711
712    /// This algorithm reflects the ML-DSA.Verify_internal algorithm from FIPS 204.  It does not
713    /// include the domain separator that distinguishes between the normal and pre-hashed cases,
714    /// and it does not separate the context string from the rest of the message.
715    // Algorithm 8 ML-DSA.Verify_internal
716    pub fn verify_internal(&self, M: &[u8], sigma: &Signature<P>) -> bool
717    where
718        P: MlDsaParams,
719    {
720        let mu = MuBuilder::internal(&self.tr, &[M]);
721        self.raw_verify_mu(&mu, sigma)
722    }
723
724    fn raw_verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool
725    where
726        P: MlDsaParams,
727    {
728        // Reconstruct w
729        let c = sample_in_ball(&sigma.c_tilde, P::TAU);
730
731        let z_hat = sigma.z.ntt();
732        let c_hat = c.ntt();
733        let Az_hat = &self.A_hat * &z_hat;
734        let ct1_2d_hat = &c_hat * &self.t1_2d_hat;
735
736        let wp_approx = (&Az_hat - &ct1_2d_hat).ntt_inverse();
737        let w1p = sigma.h.use_hint(&wp_approx);
738
739        let w1p_tilde = P::encode_w1(&w1p);
740        let cp_tilde = H::default()
741            .absorb(mu)
742            .absorb(&w1p_tilde)
743            .squeeze_new::<P::Lambda>();
744
745        sigma.c_tilde == cp_tilde
746    }
747
748    /// This algorithm reflects the ML-DSA.Verify algorithm from FIPS 204.
749    // Algorithm 3 ML-DSA.Verify
750    pub fn verify_with_context(&self, M: &[u8], ctx: &[u8], sigma: &Signature<P>) -> bool {
751        self.raw_verify_with_context(&[M], ctx, sigma)
752    }
753
    /// This algorithm reflects the ML-DSA.Verify algorithm with a pre-computed μ from FIPS 204.
    ///
    /// Returns `true` when `sigma` is valid for the supplied `mu` under this key.
    // Algorithm 3 ML-DSA.Verify (optional pre-computed μ variant)
    pub fn verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool {
        self.raw_verify_mu(mu, sigma)
    }
759
760    fn raw_verify_with_context(&self, M: &[&[u8]], ctx: &[u8], sigma: &Signature<P>) -> bool {
761        if ctx.len() > 255 {
762            return false;
763        }
764
765        let mu = MuBuilder::new(&self.tr, ctx).message(M);
766        self.verify_mu(&mu, sigma)
767    }
768
769    fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
770        let t1_enc = P::encode_t1(t1);
771        P::concat_vk(rho.clone(), t1_enc)
772    }
773
    /// Encode the key in a fixed-size byte array.
    ///
    /// The encoding is rebuilt from `rho` and `t1` on every call.
    // Algorithm 22 pkEncode
    pub fn encode(&self) -> EncodedVerifyingKey<P> {
        Self::encode_internal(&self.rho, &self.t1)
    }
779
780    /// Decode the key from an appropriately sized byte array.
781    // Algorithm 23 pkDecode
782    pub fn decode(enc: &EncodedVerifyingKey<P>) -> Self {
783        let (rho, t1_enc) = P::split_vk(enc);
784        let t1 = P::decode_t1(t1_enc);
785        Self::new(rho.clone(), t1, None, Some(enc.clone()))
786    }
787}
788
789impl<P: MlDsaParams> signature::Verifier<Signature<P>> for VerifyingKey<P> {
790    fn verify(&self, msg: &[u8], signature: &Signature<P>) -> Result<(), Error> {
791        self.multipart_verify(&[msg], signature)
792    }
793}
794
795impl<P: MlDsaParams> MultipartVerifier<Signature<P>> for VerifyingKey<P> {
796    fn multipart_verify(&self, msg: &[&[u8]], signature: &Signature<P>) -> Result<(), Error> {
797        self.raw_verify_with_context(msg, &[], signature)
798            .then_some(())
799            .ok_or(Error::new())
800    }
801}
802
803impl<P: MlDsaParams> DigestVerifier<Shake256, Signature<P>> for VerifyingKey<P> {
804    fn verify_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
805        &self,
806        f: F,
807        signature: &Signature<P>,
808    ) -> Result<(), Error> {
809        let mut mu = MuBuilder::new(&self.tr, &[]);
810        f(mu.as_mut())?;
811        let mu = mu.finish();
812
813        self.raw_verify_mu(&mu, signature)
814            .then_some(())
815            .ok_or(Error::new())
816    }
817}
818
/// `MlDsa44` is the parameter set for security category 2.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MlDsa44;

impl ParameterSet for MlDsa44 {
    // Dimensions of the matrix A: K rows, L columns
    type K = U4;
    type L = U4;
    type Eta = U2;
    // 1 << 17
    type Gamma1 = Shleft<U1, U17>;
    // QMinus1 / 88
    type Gamma2 = Quot<QMinus1, U88>;
    type TwoGamma2 = Prod<U2, Self::Gamma2>;
    // Bit length of (88 / 2 - 1)
    type W1Bits = Length<Diff<Quot<U88, U2>, U1>>;
    // Length in bytes of the commitment hash `c_tilde`
    type Lambda = U32;
    // Upper bound on the Hamming weight of the hint accepted during signing
    type Omega = U80;
    const TAU: usize = 39;
}
835
/// `MlDsa65` is the parameter set for security category 3.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MlDsa65;

impl ParameterSet for MlDsa65 {
    // Dimensions of the matrix A: K rows, L columns
    type K = U6;
    type L = U5;
    type Eta = U4;
    // 1 << 19
    type Gamma1 = Shleft<U1, U19>;
    // QMinus1 / 32
    type Gamma2 = Quot<QMinus1, U32>;
    type TwoGamma2 = Prod<U2, Self::Gamma2>;
    // Bit length of (32 / 2 - 1)
    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
    // Length in bytes of the commitment hash `c_tilde`
    type Lambda = U48;
    // Upper bound on the Hamming weight of the hint accepted during signing
    type Omega = U55;
    const TAU: usize = 49;
}
852
/// `MlDsa87` is the parameter set for security category 5.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MlDsa87;

impl ParameterSet for MlDsa87 {
    // Dimensions of the matrix A: K rows, L columns
    type K = U8;
    type L = U7;
    type Eta = U2;
    // 1 << 19
    type Gamma1 = Shleft<U1, U19>;
    // QMinus1 / 32
    type Gamma2 = Quot<QMinus1, U32>;
    type TwoGamma2 = Prod<U2, Self::Gamma2>;
    // Bit length of (32 / 2 - 1)
    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
    // Length in bytes of the commitment hash `c_tilde`
    type Lambda = U64;
    // Upper bound on the Hamming weight of the hint accepted during signing
    type Omega = U75;
    const TAU: usize = 60;
}
869
/// A parameter set that knows how to generate key pairs
///
/// A blanket implementation is provided for every type implementing [`MlDsaParams`].
pub trait KeyGen: MlDsaParams {
    /// The type that is returned by key generation
    type KeyPair: signature::Keypair;

    /// Generate a signing key pair from the specified RNG
    ///
    /// Only available when the `rand_core` feature is enabled.
    #[cfg(feature = "rand_core")]
    fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> Self::KeyPair;

    /// Deterministically generate a signing key pair from the specified seed
    ///
    /// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204.
    fn from_seed(xi: &B32) -> Self::KeyPair;
}
884
885impl<P> KeyGen for P
886where
887    P: MlDsaParams,
888{
889    type KeyPair = KeyPair<P>;
890
891    /// Generate a signing key pair from the specified RNG
892    // Algorithm 1 ML-DSA.KeyGen()
893    #[cfg(feature = "rand_core")]
894    fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> KeyPair<P> {
895        let mut xi = B32::default();
896        rng.fill_bytes(&mut xi);
897        Self::from_seed(&xi)
898    }
899
900    /// Deterministically generate a signing key pair from the specified seed
901    ///
902    /// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204.
903    // Algorithm 6 ML-DSA.KeyGen_internal
904    fn from_seed(xi: &Seed) -> KeyPair<P>
905    where
906        P: MlDsaParams,
907    {
908        // Derive seeds
909        let mut h = H::default()
910            .absorb(xi)
911            .absorb(&[P::K::U8])
912            .absorb(&[P::L::U8]);
913
914        let rho: B32 = h.squeeze_new();
915        let rhop: B64 = h.squeeze_new();
916        let K: B32 = h.squeeze_new();
917
918        // Sample private key components
919        let A_hat = expand_a::<P::K, P::L>(&rho);
920        let s1 = expand_s::<P::L>(&rhop, P::Eta::ETA, 0);
921        let s2 = expand_s::<P::K>(&rhop, P::Eta::ETA, P::L::USIZE);
922
923        // Compute derived values
924        let As1_hat = &A_hat * &s1.ntt();
925        let t = &As1_hat.ntt_inverse() + &s2;
926
927        // Compress and encode
928        let (t1, t0) = t.power2round();
929
930        let verifying_key = VerifyingKey::new(rho, t1, Some(A_hat.clone()), None);
931        let signing_key =
932            SigningKey::new(rho, K, verifying_key.tr.clone(), s1, s2, t0, Some(A_hat));
933
934        KeyPair {
935            signing_key,
936            verifying_key,
937            seed: xi.clone(),
938        }
939    }
940}
941
#[cfg(test)]
mod test {
    use super::*;
    use crate::param::*;

    /// The encoded key and signature sizes of every parameter set match the expected values.
    #[test]
    fn output_sizes() {
        //           priv pub  sig
        // ML-DSA-44 2560 1312 2420
        // ML-DSA-65 4032 1952 3309
        // ML-DSA-87 4896 2592 4627
        assert_eq!(SigningKeySize::<MlDsa44>::USIZE, 2560);
        assert_eq!(VerifyingKeySize::<MlDsa44>::USIZE, 1312);
        assert_eq!(SignatureSize::<MlDsa44>::USIZE, 2420);

        assert_eq!(SigningKeySize::<MlDsa65>::USIZE, 4032);
        assert_eq!(VerifyingKeySize::<MlDsa65>::USIZE, 1952);
        assert_eq!(SignatureSize::<MlDsa65>::USIZE, 3309);

        assert_eq!(SigningKeySize::<MlDsa87>::USIZE, 4896);
        assert_eq!(VerifyingKeySize::<MlDsa87>::USIZE, 2592);
        assert_eq!(SignatureSize::<MlDsa87>::USIZE, 4627);
    }

    /// Keys and signatures survive an encode/decode round trip unchanged, and the key
    /// pair reports the seed it was generated from.
    fn encode_decode_round_trip_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let seed = Array::default();
        let kp = P::from_seed(&seed);
        assert_eq!(kp.to_seed(), seed);

        let sk = kp.signing_key;
        let vk = kp.verifying_key;

        let vk_bytes = vk.encode();
        let vk2 = VerifyingKey::<P>::decode(&vk_bytes);
        assert!(vk == vk2);

        // The expanded signing key encoding is deprecated but must keep round-tripping.
        #[allow(deprecated)]
        {
            let sk_bytes = sk.to_expanded();
            let sk2 = SigningKey::<P>::from_expanded(&sk_bytes);
            assert!(sk == sk2);

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let sig = sk.sign_internal(&[M], &rnd);
            let sig_bytes = sig.encode();
            let sig2 = Signature::<P>::decode(&sig_bytes).unwrap();
            assert!(sig == sig2);
        }
    }

    #[test]
    fn encode_decode_round_trip() {
        encode_decode_round_trip_test::<MlDsa44>();
        encode_decode_round_trip_test::<MlDsa65>();
        encode_decode_round_trip_test::<MlDsa87>();
    }

    /// The verifying key derived from a signing key equals the one produced by key generation.
    fn public_from_private_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let kp = P::from_seed(&Array::default());
        let sk = kp.signing_key;
        let vk = kp.verifying_key;
        let vk_derived = sk.verifying_key();

        assert!(vk == vk_derived);
    }

    #[test]
    fn public_from_private() {
        public_from_private_test::<MlDsa44>();
        public_from_private_test::<MlDsa65>();
        public_from_private_test::<MlDsa87>();
    }

    /// A signature produced by `sign_internal` is accepted by `verify_internal`.
    fn sign_verify_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        let kp = P::from_seed(&Array::default());
        let sk = kp.signing_key;
        let vk = kp.verifying_key;

        let M = b"Hello world";
        let rnd = Array([0u8; 32]);
        let sig = sk.sign_internal(&[M], &rnd);

        assert!(vk.verify_internal(M, &sig));
    }

    #[test]
    fn sign_verify_round_trip() {
        sign_verify_round_trip_test::<MlDsa44>();
        sign_verify_round_trip_test::<MlDsa65>();
        sign_verify_round_trip_test::<MlDsa87>();
    }

    /// A signature over a precomputed `mu` is accepted when verifying against the same `mu`.
    #[test]
    fn sign_mu_verify_mu_round_trip() {
        fn sign_mu_verify_mu<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = MuBuilder::internal(&sk.tr, &[M]);
            let sig = sk.raw_sign_mu(&mu, &rnd);

            assert!(vk.raw_verify_mu(&mu, &sig));
        }
        sign_mu_verify_mu::<MlDsa44>();
        sign_mu_verify_mu::<MlDsa65>();
        sign_mu_verify_mu::<MlDsa87>();
    }

    /// A signature over `MuBuilder::internal(tr, M)` is accepted by `verify_internal` on `M`.
    #[test]
    fn sign_mu_verify_internal_round_trip() {
        fn sign_mu_verify_internal<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = MuBuilder::internal(&sk.tr, &[M]);
            let sig = sk.raw_sign_mu(&mu, &rnd);

            assert!(vk.verify_internal(M, &sig));
        }
        sign_mu_verify_internal::<MlDsa44>();
        sign_mu_verify_internal::<MlDsa65>();
        sign_mu_verify_internal::<MlDsa87>();
    }

    /// A `sign_internal` signature on `M` is accepted when verifying against
    /// `MuBuilder::internal(tr, M)`.
    #[test]
    fn sign_internal_verify_mu_round_trip() {
        fn sign_internal_verify_mu<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = MuBuilder::internal(&sk.tr, &[M]);
            let sig = sk.sign_internal(&[M], &rnd);

            assert!(vk.raw_verify_mu(&mu, &sig));
        }
        sign_internal_verify_mu::<MlDsa44>();
        sign_internal_verify_mu::<MlDsa65>();
        sign_internal_verify_mu::<MlDsa87>();
    }

    /// `KeyGen::from_seed` and `SigningKey::from_seed` derive identical keys from one seed.
    #[test]
    fn from_seed_implementations_match() {
        fn assert_from_seed_equality<P>()
        where
            P: MlDsaParams,
        {
            let seed = Seed::default();
            let kp1 = P::from_seed(&seed);
            let sk1 = SigningKey::<P>::from_seed(&seed);
            let vk1 = sk1.verifying_key();
            assert_eq!(kp1.signing_key, sk1);
            assert_eq!(kp1.verifying_key, vk1);
        }
        assert_from_seed_equality::<MlDsa44>();
        assert_from_seed_equality::<MlDsa65>();
        assert_from_seed_equality::<MlDsa87>();
    }

    /// `to_seed` returns the exact (non-default) seed the key pair was generated from.
    #[test]
    fn to_seed_returns_correct_seed() {
        fn test_to_seed<P: MlDsaParams>() {
            let seed = Array([
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                24, 25, 26, 27, 28, 29, 30, 31, 32,
            ]);
            let kp = P::from_seed(&seed);
            assert_eq!(kp.to_seed(), seed);
        }
        test_to_seed::<MlDsa44>();
        test_to_seed::<MlDsa65>();
        test_to_seed::<MlDsa87>();
    }

    /// A valid signature stops verifying once its `c_tilde` component is corrupted.
    #[test]
    fn verification_rejects_invalid_signature() {
        fn test_invalid_sig<P: MlDsaParams>() {
            let kp = P::from_seed(&Array::default());
            let vk = kp.verifying_key();

            let msg = b"Hello world";
            // Sign with `sign_deterministic` so that the signature pairs with
            // `verify_with_context`, the verification entry point exercised below.
            let mut sig = kp.signing_key().sign_deterministic(msg, &[]).unwrap();

            // Positive control: the untouched signature must verify, so the rejection
            // below can only be caused by the tampering.
            assert!(vk.verify_with_context(msg, &[], &sig));

            sig.c_tilde[0] ^= 0xFF;
            assert!(!vk.verify_with_context(msg, &[], &sig));
        }
        test_invalid_sig::<MlDsa44>();
        test_invalid_sig::<MlDsa65>();
        test_invalid_sig::<MlDsa87>();
    }

    /// A valid signature does not verify against a different message.
    #[test]
    fn verification_rejects_wrong_message() {
        fn test_wrong_msg<P: MlDsaParams>() {
            let kp = P::from_seed(&Array::default());
            let vk = kp.verifying_key();

            let msg1 = b"Hello world";
            let msg2 = b"Wrong message";
            // Sign with `sign_deterministic` so that the signature pairs with
            // `verify_with_context`, the verification entry point exercised below.
            let sig = kp.signing_key().sign_deterministic(msg1, &[]).unwrap();

            // Positive control: the signature verifies for the message that was signed,
            // so the rejection below can only be caused by the message mismatch.
            assert!(vk.verify_with_context(msg1, &[], &sig));

            assert!(!vk.verify_with_context(msg2, &[], &sig));
        }
        test_wrong_msg::<MlDsa44>();
        test_wrong_msg::<MlDsa65>();
        test_wrong_msg::<MlDsa87>();
    }

    /// Contexts of 256 bytes are refused by signing and verification, while 255 bytes work.
    #[test]
    fn context_length_validation() {
        fn test_ctx_length<P: MlDsaParams>() {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key();
            let vk = kp.verifying_key();

            let msg = b"Hello world";
            let long_ctx = [0u8; 256];
            let short_ctx = [0u8; 255];

            assert!(sk.sign_deterministic(msg, &long_ctx).is_err());

            let sig = sk.sign_deterministic(msg, &short_ctx).unwrap();
            assert!(!vk.verify_with_context(msg, &long_ctx, &sig));
            assert!(vk.verify_with_context(msg, &short_ctx, &sig));
        }
        test_ctx_length::<MlDsa44>();
        test_ctx_length::<MlDsa65>();
        test_ctx_length::<MlDsa87>();
    }

    /// A verifying key derived from the signing key accepts that key's signatures and
    /// encodes identically to the key pair's own verifying key.
    #[test]
    fn derived_verifying_key_validates_signatures() {
        fn test_derived_vk<P: MlDsaParams>() {
            let seed = Array([42u8; 32]);
            let kp = P::from_seed(&seed);
            let sk = kp.signing_key();
            let derived_vk = sk.verifying_key();

            let msg = b"Test message for derived key";
            let rnd = Array([0u8; 32]);
            let sig = sk.sign_internal(&[msg], &rnd);

            assert!(derived_vk.verify_internal(msg, &sig));
            assert_eq!(derived_vk.encode(), kp.verifying_key().encode());
        }
        test_derived_vk::<MlDsa44>();
        test_derived_vk::<MlDsa65>();
        test_derived_vk::<MlDsa87>();
    }

    /// The `Debug` output of key pairs and signing keys names the respective type.
    #[test]
    #[cfg(feature = "alloc")]
    fn debug_implementations() {
        // The crate is `no_std`, so `String` has to come from `alloc`.
        extern crate alloc;
        use core::fmt::Write;

        fn test_debug<P: MlDsaParams>() {
            let kp = P::from_seed(&Array::default());

            let mut kp_debug = alloc::string::String::new();
            write!(&mut kp_debug, "{kp:?}").unwrap();
            assert!(kp_debug.contains("KeyPair"));

            let mut sk_debug = alloc::string::String::new();
            write!(&mut sk_debug, "{:?}", kp.signing_key()).unwrap();
            assert!(sk_debug.contains("SigningKey"));
        }
        test_debug::<MlDsa44>();
        test_debug::<MlDsa65>();
        test_debug::<MlDsa87>();
    }
}