ml_dsa/
lib.rs

1#![no_std]
2#![doc = include_str!("../README.md")]
3#![doc(
4    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
5    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
6)]
7#![cfg_attr(docsrs, feature(doc_cfg))]
8#![warn(clippy::pedantic)] // Be pedantic by default
9#![warn(clippy::integer_division_remainder_used)] // Be judicious about using `/` and `%`
10#![warn(clippy::as_conversions)] // Use proper conversions, not `as`
11#![allow(non_snake_case)] // Allow notation matching the spec
12#![allow(clippy::similar_names)] // Allow notation matching the spec
13#![allow(clippy::many_single_char_names)] // Allow notation matching the spec
14#![allow(clippy::clone_on_copy)] // Be explicit about moving data
15#![deny(missing_docs)] // Require all public interfaces to be documented
16#![warn(unreachable_pub)] // Prevent unexpected interface changes
17
18//! # Quickstart
19//!
20//! ```
21//! # #[cfg(feature = "rand_core")]
22//! # {
23//! use ml_dsa::{MlDsa65, KeyGen, signature::{Keypair, Signer, Verifier}};
24//!
25//! let mut rng = rand::rng();
26//! let kp = MlDsa65::key_gen(&mut rng);
27//!
28//! let msg = b"Hello world";
29//! let sig = kp.signing_key().sign(msg);
30//!
31//! assert!(kp.verifying_key().verify(msg, &sig).is_ok());
32//! # }
33//! ```
34
35mod algebra;
36mod crypto;
37mod encode;
38mod hint;
39mod ntt;
40mod param;
41mod pkcs8;
42mod sampling;
43mod util;
44
45// TODO(RLB) Move module to an independent crate shared with ml_kem
46mod module_lattice;
47
48use core::convert::{AsRef, TryFrom, TryInto};
49use hybrid_array::{
50    Array,
51    typenum::{
52        Diff, Length, Prod, Quot, Shleft, U1, U2, U4, U5, U6, U7, U8, U17, U19, U32, U48, U55, U64,
53        U75, U80, U88, Unsigned,
54    },
55};
56use sha3::Shake256;
57use signature::{DigestSigner, DigestVerifier, MultipartSigner, MultipartVerifier, Signer};
58
59#[cfg(feature = "rand_core")]
60use {
61    rand_core::{CryptoRng, TryCryptoRng},
62    signature::{RandomizedDigestSigner, RandomizedMultipartSigner, RandomizedSigner},
63};
64
65#[cfg(feature = "zeroize")]
66use zeroize::{Zeroize, ZeroizeOnDrop};
67
68use crate::algebra::{AlgebraExt, Elem, NttMatrix, NttVector, Truncate, Vector};
69use crate::crypto::H;
70use crate::hint::Hint;
71use crate::ntt::{Ntt, NttInverse};
72use crate::param::{ParameterSet, QMinus1, SamplingSize, SpecQ};
73use crate::sampling::{expand_a, expand_mask, expand_s, sample_in_ball};
74use crate::util::B64;
75use core::fmt;
76
77pub use crate::param::{EncodedSignature, EncodedSigningKey, EncodedVerifyingKey, MlDsaParams};
78pub use crate::util::B32;
79pub use signature::{self, Error};
80
81/// ML-DSA seeds are signing (private) keys, which are consistently 32-bytes across all security
82/// levels, and are the preferred serialization for representing such keys.
83pub type Seed = B32;
84
85/// An ML-DSA signature
86#[derive(Clone, PartialEq, Debug)]
87pub struct Signature<P: MlDsaParams> {
88    c_tilde: Array<u8, P::Lambda>,
89    z: Vector<P::L>,
90    h: Hint<P>,
91}
92
93impl<P: MlDsaParams> Signature<P> {
94    /// Encode the signature in a fixed-size byte array.
95    // Algorithm 26 sigEncode
96    pub fn encode(&self) -> EncodedSignature<P> {
97        let c_tilde = self.c_tilde.clone();
98        let z = P::encode_z(&self.z);
99        let h = self.h.bit_pack();
100        P::concat_sig(c_tilde, z, h)
101    }
102
103    /// Decode the signature from an appropriately sized byte array.
104    // Algorithm 27 sigDecode
105    pub fn decode(enc: &EncodedSignature<P>) -> Option<Self> {
106        let (c_tilde, z, h) = P::split_sig(enc);
107
108        let c_tilde = c_tilde.clone();
109        let z = P::decode_z(z);
110        let h = Hint::bit_unpack(h)?;
111
112        if z.infinity_norm() >= P::GAMMA1_MINUS_BETA {
113            return None;
114        }
115
116        Some(Self { c_tilde, z, h })
117    }
118}
119
120impl<'a, P: MlDsaParams> TryFrom<&'a [u8]> for Signature<P> {
121    type Error = Error;
122
123    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
124        let enc = EncodedSignature::<P>::try_from(value).map_err(|_| Error::new())?;
125        Self::decode(&enc).ok_or(Error::new())
126    }
127}
128
129impl<P: MlDsaParams> TryInto<EncodedSignature<P>> for Signature<P> {
130    type Error = Error;
131
132    fn try_into(self) -> Result<EncodedSignature<P>, Self::Error> {
133        Ok(self.encode())
134    }
135}
136
137impl<P: MlDsaParams> signature::SignatureEncoding for Signature<P> {
138    type Repr = EncodedSignature<P>;
139}
140
141struct MuBuilder(H);
142
143impl MuBuilder {
144    fn new(tr: &[u8], ctx: &[u8]) -> Self {
145        let mut h = H::default();
146        h = h.absorb(tr);
147        h = h.absorb(&[0]);
148        h = h.absorb(&[Truncate::truncate(ctx.len())]);
149        h = h.absorb(ctx);
150
151        Self(h)
152    }
153
154    fn internal(tr: &[u8], Mp: &[&[u8]]) -> B64 {
155        let mut h = H::default().absorb(tr);
156
157        for m in Mp {
158            h = h.absorb(m);
159        }
160
161        h.squeeze_new()
162    }
163
164    fn message(mut self, M: &[&[u8]]) -> B64 {
165        for m in M {
166            self.0 = self.0.absorb(m);
167        }
168
169        self.0.squeeze_new()
170    }
171
172    fn finish(mut self) -> B64 {
173        self.0.squeeze_new()
174    }
175}
176
177impl AsMut<Shake256> for MuBuilder {
178    fn as_mut(&mut self) -> &mut Shake256 {
179        self.0.updatable()
180    }
181}
182
183/// An ML-DSA key pair
184pub struct KeyPair<P: MlDsaParams> {
185    /// The signing key of the key pair
186    signing_key: SigningKey<P>,
187
188    /// The verifying key of the key pair
189    verifying_key: VerifyingKey<P>,
190
191    /// The seed this signing key was derived from
192    seed: B32,
193}
194
195impl<P: MlDsaParams> KeyPair<P> {
196    /// The signing key of the key pair
197    pub fn signing_key(&self) -> &SigningKey<P> {
198        &self.signing_key
199    }
200
201    /// The verifying key of the key pair
202    pub fn verifying_key(&self) -> &VerifyingKey<P> {
203        &self.verifying_key
204    }
205
206    /// Serialize the [`Seed`] value: 32-bytes which can be used to reconstruct the
207    /// [`KeyPair`].
208    ///
209    /// # ⚠️ Warning!
210    ///
211    /// This value is key material. Please treat it with care.
212    #[inline]
213    pub fn to_seed(&self) -> Seed {
214        self.seed
215    }
216}
217
218impl<P: MlDsaParams> AsRef<VerifyingKey<P>> for KeyPair<P> {
219    fn as_ref(&self) -> &VerifyingKey<P> {
220        &self.verifying_key
221    }
222}
223
224impl<P: MlDsaParams> fmt::Debug for KeyPair<P> {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        f.debug_struct("KeyPair")
227            .field("verifying_key", &self.verifying_key)
228            .finish_non_exhaustive()
229    }
230}
231
232impl<P: MlDsaParams> signature::KeypairRef for KeyPair<P> {
233    type VerifyingKey = VerifyingKey<P>;
234}
235
236/// The `Signer` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA, and
237/// only supports signing with an empty context string.
238impl<P: MlDsaParams> Signer<Signature<P>> for KeyPair<P> {
239    fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
240        self.try_multipart_sign(&[msg])
241    }
242}
243
244/// The `Signer` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA, and
245/// only supports signing with an empty context string.
246impl<P: MlDsaParams> MultipartSigner<Signature<P>> for KeyPair<P> {
247    fn try_multipart_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
248        self.signing_key.raw_sign_deterministic(msg, &[])
249    }
250}
251
252/// The `DigestSigner` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA
253/// with a pre-computed μ, and only supports signing with an empty context string.
254impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for KeyPair<P> {
255    fn try_sign_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
256        &self,
257        f: F,
258    ) -> Result<Signature<P>, Error> {
259        self.signing_key.try_sign_digest(&f)
260    }
261}
262
263/// An ML-DSA signing key
264#[derive(Clone, PartialEq)]
265pub struct SigningKey<P: MlDsaParams> {
266    rho: B32,
267    K: B32,
268    tr: B64,
269    s1: Vector<P::L>,
270    s2: Vector<P::K>,
271    t0: Vector<P::K>,
272
273    // Derived values
274    s1_hat: NttVector<P::L>,
275    s2_hat: NttVector<P::K>,
276    t0_hat: NttVector<P::K>,
277    A_hat: NttMatrix<P::K, P::L>,
278}
279
280impl<P: MlDsaParams> fmt::Debug for SigningKey<P> {
281    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282        f.debug_struct("SigningKey").finish_non_exhaustive()
283    }
284}
285
/// Wipes the secret material held by a [`SigningKey`] when it is dropped.
///
/// This covers the encoded components (`rho`, `K`, `tr`, `s1`, `s2`, `t0`). It also
/// covers the NTT-domain copies of the secret vectors (`s1_hat`, `s2_hat`, `t0_hat`),
/// which are just as sensitive: the NTT is invertible, so the secrets can be recovered
/// from them. `A_hat` is not wiped because it is expanded solely from `rho`, which is
/// part of the public key.
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> Drop for SigningKey<P> {
    fn drop(&mut self) {
        self.rho.zeroize();
        self.K.zeroize();
        self.tr.zeroize();
        self.s1.zeroize();
        self.s2.zeroize();
        self.t0.zeroize();
        self.s1_hat.zeroize();
        self.s2_hat.zeroize();
        self.t0_hat.zeroize();
    }
}

/// Marker trait: a [`SigningKey`] zeroizes its secrets on drop (see the `Drop` impl).
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> ZeroizeOnDrop for SigningKey<P> {}
300
301impl<P: MlDsaParams> SigningKey<P> {
302    fn new(
303        rho: B32,
304        K: B32,
305        tr: B64,
306        s1: Vector<P::L>,
307        s2: Vector<P::K>,
308        t0: Vector<P::K>,
309        A_hat: Option<NttMatrix<P::K, P::L>>,
310    ) -> Self {
311        let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
312        let s1_hat = s1.ntt();
313        let s2_hat = s2.ntt();
314        let t0_hat = t0.ntt();
315
316        Self {
317            rho,
318            K,
319            tr,
320            s1,
321            s2,
322            t0,
323
324            s1_hat,
325            s2_hat,
326            t0_hat,
327            A_hat,
328        }
329    }
330
331    /// Deterministically generate a signing key from the specified seed.
332    ///
333    /// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204, but only returns a
334    /// signing key.
335    #[must_use]
336    pub fn from_seed(seed: &Seed) -> Self {
337        let kp = P::from_seed(seed);
338        kp.signing_key
339    }
340
341    /// This method reflects the ML-DSA.Sign_internal algorithm from FIPS 204. It does not
342    /// include the domain separator that distinguishes between the normal and pre-hashed cases,
343    /// and it does not separate the context string from the rest of the message.
344    // Algorithm 7 ML-DSA.Sign_internal
345    // TODO(RLB) Only expose based on a feature.  Tests need access, but normal code shouldn't.
346    pub fn sign_internal(&self, Mp: &[&[u8]], rnd: &B32) -> Signature<P>
347    where
348        P: MlDsaParams,
349    {
350        let mu = MuBuilder::internal(&self.tr, Mp);
351        self.raw_sign_mu(&mu, rnd)
352    }
353
354    fn raw_sign_mu(&self, mu: &B64, rnd: &B32) -> Signature<P>
355    where
356        P: MlDsaParams,
357    {
358        // Compute the private random seed
359        let rhopp: B64 = H::default()
360            .absorb(&self.K)
361            .absorb(rnd)
362            .absorb(mu)
363            .squeeze_new();
364
365        // Rejection sampling loop
366        for kappa in (0..u16::MAX).step_by(P::L::USIZE) {
367            let y = expand_mask::<P::L, P::Gamma1>(&rhopp, kappa);
368            let w = (&self.A_hat * &y.ntt()).ntt_inverse();
369            let w1 = w.high_bits::<P::TwoGamma2>();
370
371            let w1_tilde = P::encode_w1(&w1);
372            let c_tilde = H::default()
373                .absorb(mu)
374                .absorb(&w1_tilde)
375                .squeeze_new::<P::Lambda>();
376            let c = sample_in_ball(&c_tilde, P::TAU);
377            let c_hat = c.ntt();
378
379            let cs1 = (&c_hat * &self.s1_hat).ntt_inverse();
380            let cs2 = (&c_hat * &self.s2_hat).ntt_inverse();
381
382            let z = &y + &cs1;
383            let r0 = (&w - &cs2).low_bits::<P::TwoGamma2>();
384
385            if z.infinity_norm() >= P::GAMMA1_MINUS_BETA
386                || r0.infinity_norm() >= P::GAMMA2_MINUS_BETA
387            {
388                continue;
389            }
390
391            let ct0 = (&c_hat * &self.t0_hat).ntt_inverse();
392            let minus_ct0 = -&ct0;
393            let w_cs2_ct0 = &(&w - &cs2) + &ct0;
394            let h = Hint::<P>::new(&minus_ct0, &w_cs2_ct0);
395
396            if ct0.infinity_norm() >= P::Gamma2::U32 || h.hamming_weight() > P::Omega::USIZE {
397                continue;
398            }
399
400            let z = z.mod_plus_minus::<SpecQ>();
401            return Signature { c_tilde, z, h };
402        }
403
404        unreachable!("Rejection sampling failed to find a valid signature");
405    }
406
407    /// This method reflects the randomized ML-DSA.Sign algorithm.
408    ///
409    /// # Errors
410    ///
411    /// This method will return an opaque error if the context string is more than 255 bytes long,
412    /// or if it fails to get enough randomness.
413    // Algorithm 2 ML-DSA.Sign
414    #[cfg(feature = "rand_core")]
415    pub fn sign_randomized<R: TryCryptoRng + ?Sized>(
416        &self,
417        M: &[u8],
418        ctx: &[u8],
419        rng: &mut R,
420    ) -> Result<Signature<P>, Error> {
421        self.raw_sign_randomized(&[M], ctx, rng)
422    }
423
424    #[cfg(feature = "rand_core")]
425    fn raw_sign_randomized<R: TryCryptoRng + ?Sized>(
426        &self,
427        Mp: &[&[u8]],
428        ctx: &[u8],
429        rng: &mut R,
430    ) -> Result<Signature<P>, Error> {
431        if ctx.len() > 255 {
432            return Err(Error::new());
433        }
434
435        let mut rnd = B32::default();
436        rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
437
438        let mu = MuBuilder::new(&self.tr, ctx).message(Mp);
439        Ok(self.raw_sign_mu(&mu, &rnd))
440    }
441
442    /// This method reflects the randomized ML-DSA.Sign algorithm with a pre-computed μ.
443    ///
444    /// # Errors
445    ///
446    /// This method can return an opaque error if it fails to get enough randomness.
447    // Algorithm 2 ML-DSA.Sign (optional pre-computed μ variant)
448    #[cfg(feature = "rand_core")]
449    pub fn sign_mu_randomized<R: TryCryptoRng + ?Sized>(
450        &self,
451        mu: &B64,
452        rng: &mut R,
453    ) -> Result<Signature<P>, Error> {
454        let mut rnd = B32::default();
455        rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
456
457        Ok(self.raw_sign_mu(mu, &rnd))
458    }
459
460    /// This method reflects the optional deterministic variant of the ML-DSA.Sign algorithm.
461    ///
462    /// # Errors
463    ///
464    /// This method will return an opaque error if the context string is more than 255 bytes long.
465    // Algorithm 2 ML-DSA.Sign (optional deterministic variant)
466    pub fn sign_deterministic(&self, M: &[u8], ctx: &[u8]) -> Result<Signature<P>, Error> {
467        self.raw_sign_deterministic(&[M], ctx)
468    }
469
470    /// This method reflects the optional deterministic variant of the ML-DSA.Sign algorithm with a
471    /// pre-computed μ.
472    // Algorithm 2 ML-DSA.Sign (optional deterministic and pre-computed μ variant)
473    pub fn sign_mu_deterministic(&self, mu: &B64) -> Signature<P> {
474        let rnd = B32::default();
475        self.raw_sign_mu(mu, &rnd)
476    }
477
478    fn raw_sign_deterministic(&self, Mp: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
479        if ctx.len() > 255 {
480            return Err(Error::new());
481        }
482
483        let mu = MuBuilder::new(&self.tr, ctx).message(Mp);
484        Ok(self.sign_mu_deterministic(&mu))
485    }
486
487    /// Encode the key in a fixed-size byte array.
488    // Algorithm 24 skEncode
489    pub fn encode(&self) -> EncodedSigningKey<P>
490    where
491        P: MlDsaParams,
492    {
493        let s1_enc = P::encode_s1(&self.s1);
494        let s2_enc = P::encode_s2(&self.s2);
495        let t0_enc = P::encode_t0(&self.t0);
496        P::concat_sk(
497            self.rho.clone(),
498            self.K.clone(),
499            self.tr.clone(),
500            s1_enc,
501            s2_enc,
502            t0_enc,
503        )
504    }
505
506    /// Decode the key from an appropriately sized byte array.
507    // Algorithm 25 skDecode
508    pub fn decode(enc: &EncodedSigningKey<P>) -> Self
509    where
510        P: MlDsaParams,
511    {
512        let (rho, K, tr, s1_enc, s2_enc, t0_enc) = P::split_sk(enc);
513        Self::new(
514            rho.clone(),
515            K.clone(),
516            tr.clone(),
517            P::decode_s1(s1_enc),
518            P::decode_s2(s2_enc),
519            P::decode_t0(t0_enc),
520            None,
521        )
522    }
523
524    /// This auxiliary function derives a `VerifyingKey` from a bare
525    /// `SigningKey` (even in the absence of the original seed).
526    ///
527    /// This is a utility function that is useful when importing the private key
528    /// from an external source which does not export the seed and does not
529    /// provide the precomputed public key associated with the private key
530    /// itself.
531    ///
532    /// `SigningKey` implements `signature::Keypair`: this inherent method is
533    /// retained for convenience, so it is available for callers even when the
534    /// `signature::Keypair` trait is out-of-scope.
535    pub fn verifying_key(&self) -> VerifyingKey<P> {
536        let kp: &dyn signature::Keypair<VerifyingKey = VerifyingKey<P>> = self;
537
538        kp.verifying_key()
539    }
540}
541
542/// The `Signer` implementation for `SigningKey` uses the optional deterministic variant of ML-DSA, and
543/// only supports signing with an empty context string.  If you would like to include a context
544/// string, use the [`SigningKey::sign_deterministic`] method.
545impl<P: MlDsaParams> Signer<Signature<P>> for SigningKey<P> {
546    fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
547        self.try_multipart_sign(&[msg])
548    }
549}
550
551/// The `Signer` implementation for `SigningKey` uses the optional deterministic variant of ML-DSA, and
552/// only supports signing with an empty context string. If you would like to include a context
553/// string, use the [`SigningKey::sign_deterministic`] method.
554impl<P: MlDsaParams> MultipartSigner<Signature<P>> for SigningKey<P> {
555    fn try_multipart_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
556        self.raw_sign_deterministic(msg, &[])
557    }
558}
559
560/// The `Signer` implementation for `SigningKey` uses the optional deterministic variant of ML-DSA
561/// with a pre-computed µ, and only supports signing with an empty context string. If you would
562/// like to include a context string, use the [`SigningKey::sign_mu_deterministic`] method.
563impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for SigningKey<P> {
564    fn try_sign_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
565        &self,
566        f: F,
567    ) -> Result<Signature<P>, Error> {
568        let mut mu = MuBuilder::new(&self.tr, &[]);
569        f(mu.as_mut())?;
570        let mu = mu.finish();
571
572        Ok(self.sign_mu_deterministic(&mu))
573    }
574}
575
576/// The `KeyPair` implementation for `SigningKey` allows to derive a `VerifyingKey` from
577/// a bare `SigningKey` (even in the absence of the original seed).
578impl<P: MlDsaParams> signature::Keypair for SigningKey<P> {
579    type VerifyingKey = VerifyingKey<P>;
580
581    /// This is a utility function that is useful when importing the private key
582    /// from an external source which does not export the seed and does not
583    /// provide the precomputed public key associated with the private key
584    /// itself.
585    fn verifying_key(&self) -> Self::VerifyingKey {
586        let As1 = &self.A_hat * &self.s1_hat;
587        let t = &As1.ntt_inverse() + &self.s2;
588
589        /* Discard t0 */
590        let (t1, _) = t.power2round();
591
592        VerifyingKey::new(self.rho.clone(), t1, Some(self.A_hat.clone()), None)
593    }
594}
595
596/// The `RandomizedSigner` implementation for `SigningKey` only supports signing with an empty
597/// context string. If you would like to include a context string, use the
598/// [`SigningKey::sign_randomized`] method.
599#[cfg(feature = "rand_core")]
600impl<P: MlDsaParams> RandomizedSigner<Signature<P>> for SigningKey<P> {
601    fn try_sign_with_rng<R: TryCryptoRng + ?Sized>(
602        &self,
603        rng: &mut R,
604        msg: &[u8],
605    ) -> Result<Signature<P>, Error> {
606        self.try_multipart_sign_with_rng(rng, &[msg])
607    }
608}
609
610/// The `RandomizedSigner` implementation for `SigningKey` only supports signing with an empty
611/// context string. If you would like to include a context string, use the
612/// [`SigningKey::sign_randomized`] method.
613#[cfg(feature = "rand_core")]
614impl<P: MlDsaParams> RandomizedMultipartSigner<Signature<P>> for SigningKey<P> {
615    fn try_multipart_sign_with_rng<R: TryCryptoRng + ?Sized>(
616        &self,
617        rng: &mut R,
618        msg: &[&[u8]],
619    ) -> Result<Signature<P>, Error> {
620        self.raw_sign_randomized(msg, &[], rng)
621    }
622}
623
624/// The `RandomizedSigner` implementation for `SigningKey` only supports signing with an empty
625/// context string. If you would like to include a context string, use the
626/// [`SigningKey::sign_mu_randomized`] method.
627#[cfg(feature = "rand_core")]
628impl<P: MlDsaParams> RandomizedDigestSigner<Shake256, Signature<P>> for SigningKey<P> {
629    fn try_sign_digest_with_rng<
630        R: TryCryptoRng + ?Sized,
631        F: Fn(&mut Shake256) -> Result<(), Error>,
632    >(
633        &self,
634        rng: &mut R,
635        f: F,
636    ) -> Result<Signature<P>, Error> {
637        let mut mu = MuBuilder::new(&self.tr, &[]);
638        f(mu.as_mut())?;
639        let mu = mu.finish();
640
641        self.sign_mu_randomized(&mu, rng)
642    }
643}
644
645/// An ML-DSA verification key
646#[derive(Clone, Debug, PartialEq)]
647pub struct VerifyingKey<P: ParameterSet> {
648    rho: B32,
649    t1: Vector<P::K>,
650
651    // Derived values
652    A_hat: NttMatrix<P::K, P::L>,
653    t1_2d_hat: NttVector<P::K>,
654    tr: B64,
655}
656
657impl<P: MlDsaParams> VerifyingKey<P> {
658    fn new(
659        rho: B32,
660        t1: Vector<P::K>,
661        A_hat: Option<NttMatrix<P::K, P::L>>,
662        enc: Option<EncodedVerifyingKey<P>>,
663    ) -> Self {
664        let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
665        let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1));
666
667        let t1_2d_hat = (Elem::new(1 << 13) * &t1).ntt();
668        let tr: B64 = H::default().absorb(&enc).squeeze_new();
669
670        Self {
671            rho,
672            t1,
673            A_hat,
674            t1_2d_hat,
675            tr,
676        }
677    }
678
679    /// Computes µ according to FIPS 204 for use in ML-DSA.Sign and ML-DSA.Verify.
680    ///
681    /// # Errors
682    ///
683    /// Returns [`Error`] if the given `Mp` returns one.
684    pub fn compute_mu<F: FnOnce(&mut Shake256) -> Result<(), Error>>(
685        &self,
686        Mp: F,
687        ctx: &[u8],
688    ) -> Result<B64, Error> {
689        let mut mu = MuBuilder::new(&self.tr, ctx);
690        Mp(mu.as_mut())?;
691        Ok(mu.finish())
692    }
693
694    /// This algorithm reflects the ML-DSA.Verify_internal algorithm from FIPS 204.  It does not
695    /// include the domain separator that distinguishes between the normal and pre-hashed cases,
696    /// and it does not separate the context string from the rest of the message.
697    // Algorithm 8 ML-DSA.Verify_internal
698    pub fn verify_internal(&self, M: &[u8], sigma: &Signature<P>) -> bool
699    where
700        P: MlDsaParams,
701    {
702        let mu = MuBuilder::internal(&self.tr, &[M]);
703        self.raw_verify_mu(&mu, sigma)
704    }
705
706    fn raw_verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool
707    where
708        P: MlDsaParams,
709    {
710        // Reconstruct w
711        let c = sample_in_ball(&sigma.c_tilde, P::TAU);
712
713        let z_hat = sigma.z.ntt();
714        let c_hat = c.ntt();
715        let Az_hat = &self.A_hat * &z_hat;
716        let ct1_2d_hat = &c_hat * &self.t1_2d_hat;
717
718        let wp_approx = (&Az_hat - &ct1_2d_hat).ntt_inverse();
719        let w1p = sigma.h.use_hint(&wp_approx);
720
721        let w1p_tilde = P::encode_w1(&w1p);
722        let cp_tilde = H::default()
723            .absorb(mu)
724            .absorb(&w1p_tilde)
725            .squeeze_new::<P::Lambda>();
726
727        sigma.c_tilde == cp_tilde
728    }
729
730    /// This algorithm reflects the ML-DSA.Verify algorithm from FIPS 204.
731    // Algorithm 3 ML-DSA.Verify
732    pub fn verify_with_context(&self, M: &[u8], ctx: &[u8], sigma: &Signature<P>) -> bool {
733        self.raw_verify_with_context(&[M], ctx, sigma)
734    }
735
736    /// This algorithm reflects the ML-DSA.Verify algorithm with a pre-computed μ from FIPS 204.
737    // Algorithm 3 ML-DSA.Verify (optional pre-computed μ variant)
738    pub fn verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool {
739        self.raw_verify_mu(mu, sigma)
740    }
741
742    fn raw_verify_with_context(&self, M: &[&[u8]], ctx: &[u8], sigma: &Signature<P>) -> bool {
743        if ctx.len() > 255 {
744            return false;
745        }
746
747        let mu = MuBuilder::new(&self.tr, ctx).message(M);
748        self.verify_mu(&mu, sigma)
749    }
750
751    fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
752        let t1_enc = P::encode_t1(t1);
753        P::concat_vk(rho.clone(), t1_enc)
754    }
755
756    /// Encode the key in a fixed-size byte array.
757    // Algorithm 22 pkEncode
758    pub fn encode(&self) -> EncodedVerifyingKey<P> {
759        Self::encode_internal(&self.rho, &self.t1)
760    }
761
762    /// Decode the key from an appropriately sized byte array.
763    // Algorithm 23 pkDecode
764    pub fn decode(enc: &EncodedVerifyingKey<P>) -> Self {
765        let (rho, t1_enc) = P::split_vk(enc);
766        let t1 = P::decode_t1(t1_enc);
767        Self::new(rho.clone(), t1, None, Some(enc.clone()))
768    }
769}
770
771impl<P: MlDsaParams> signature::Verifier<Signature<P>> for VerifyingKey<P> {
772    fn verify(&self, msg: &[u8], signature: &Signature<P>) -> Result<(), Error> {
773        self.multipart_verify(&[msg], signature)
774    }
775}
776
777impl<P: MlDsaParams> MultipartVerifier<Signature<P>> for VerifyingKey<P> {
778    fn multipart_verify(&self, msg: &[&[u8]], signature: &Signature<P>) -> Result<(), Error> {
779        self.raw_verify_with_context(msg, &[], signature)
780            .then_some(())
781            .ok_or(Error::new())
782    }
783}
784
785impl<P: MlDsaParams> DigestVerifier<Shake256, Signature<P>> for VerifyingKey<P> {
786    fn verify_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
787        &self,
788        f: F,
789        signature: &Signature<P>,
790    ) -> Result<(), Error> {
791        let mut mu = MuBuilder::new(&self.tr, &[]);
792        f(mu.as_mut())?;
793        let mu = mu.finish();
794
795        self.raw_verify_mu(&mu, signature)
796            .then_some(())
797            .ok_or(Error::new())
798    }
799}
800
801/// `MlDsa44` is the parameter set for security category 2.
802#[derive(Default, Clone, Debug, PartialEq)]
803pub struct MlDsa44;
804
805impl ParameterSet for MlDsa44 {
806    type K = U4;
807    type L = U4;
808    type Eta = U2;
809    type Gamma1 = Shleft<U1, U17>;
810    type Gamma2 = Quot<QMinus1, U88>;
811    type TwoGamma2 = Prod<U2, Self::Gamma2>;
812    type W1Bits = Length<Diff<Quot<U88, U2>, U1>>;
813    type Lambda = U32;
814    type Omega = U80;
815    const TAU: usize = 39;
816}
817
818/// `MlDsa65` is the parameter set for security category 3.
819#[derive(Default, Clone, Debug, PartialEq)]
820pub struct MlDsa65;
821
822impl ParameterSet for MlDsa65 {
823    type K = U6;
824    type L = U5;
825    type Eta = U4;
826    type Gamma1 = Shleft<U1, U19>;
827    type Gamma2 = Quot<QMinus1, U32>;
828    type TwoGamma2 = Prod<U2, Self::Gamma2>;
829    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
830    type Lambda = U48;
831    type Omega = U55;
832    const TAU: usize = 49;
833}
834
835/// `MlKem87` is the parameter set for security category 5.
836#[derive(Default, Clone, Debug, PartialEq)]
837pub struct MlDsa87;
838
839impl ParameterSet for MlDsa87 {
840    type K = U8;
841    type L = U7;
842    type Eta = U2;
843    type Gamma1 = Shleft<U1, U19>;
844    type Gamma2 = Quot<QMinus1, U32>;
845    type TwoGamma2 = Prod<U2, Self::Gamma2>;
846    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
847    type Lambda = U64;
848    type Omega = U75;
849    const TAU: usize = 60;
850}
851
852/// A parameter set that knows how to generate key pairs
853pub trait KeyGen: MlDsaParams {
854    /// The type that is returned by key generation
855    type KeyPair: signature::Keypair;
856
857    /// Generate a signing key pair from the specified RNG
858    #[cfg(feature = "rand_core")]
859    fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> Self::KeyPair;
860
861    /// Deterministically generate a signing key pair from the specified seed
862    ///
863    /// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204.
864    fn from_seed(xi: &B32) -> Self::KeyPair;
865}
866
867impl<P> KeyGen for P
868where
869    P: MlDsaParams,
870{
871    type KeyPair = KeyPair<P>;
872
873    /// Generate a signing key pair from the specified RNG
874    // Algorithm 1 ML-DSA.KeyGen()
875    #[cfg(feature = "rand_core")]
876    fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> KeyPair<P> {
877        let mut xi = B32::default();
878        rng.fill_bytes(&mut xi);
879        Self::from_seed(&xi)
880    }
881
882    /// Deterministically generate a signing key pair from the specified seed
883    ///
884    /// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204.
885    // Algorithm 6 ML-DSA.KeyGen_internal
886    fn from_seed(xi: &Seed) -> KeyPair<P>
887    where
888        P: MlDsaParams,
889    {
890        // Derive seeds
891        let mut h = H::default()
892            .absorb(xi)
893            .absorb(&[P::K::U8])
894            .absorb(&[P::L::U8]);
895
896        let rho: B32 = h.squeeze_new();
897        let rhop: B64 = h.squeeze_new();
898        let K: B32 = h.squeeze_new();
899
900        // Sample private key components
901        let A_hat = expand_a::<P::K, P::L>(&rho);
902        let s1 = expand_s::<P::L>(&rhop, P::Eta::ETA, 0);
903        let s2 = expand_s::<P::K>(&rhop, P::Eta::ETA, P::L::USIZE);
904
905        // Compute derived values
906        let As1_hat = &A_hat * &s1.ntt();
907        let t = &As1_hat.ntt_inverse() + &s2;
908
909        // Compress and encode
910        let (t1, t0) = t.power2round();
911
912        let verifying_key = VerifyingKey::new(rho, t1, Some(A_hat.clone()), None);
913        let signing_key =
914            SigningKey::new(rho, K, verifying_key.tr.clone(), s1, s2, t0, Some(A_hat));
915
916        KeyPair {
917            signing_key,
918            verifying_key,
919            seed: xi.clone(),
920        }
921    }
922}
923
#[cfg(test)]
mod test {
    use super::*;
    use crate::param::*;
    use signature::digest::Update;

    /// Differs from the `b"Hello world"` message signed throughout these tests in
    /// a single byte. A signature over the original must not verify against it,
    /// which guards against a verifier that accepts every input.
    const OTHER_MSG: &[u8; 11] = b"Hello World";

    #[test]
    fn output_sizes() {
        //           priv pub  sig
        // ML-DSA-44 2560 1312 2420
        // ML-DSA-65 4032 1952 3309
        // ML-DSA-87 4896 2592 4627
        assert_eq!(SigningKeySize::<MlDsa44>::USIZE, 2560);
        assert_eq!(VerifyingKeySize::<MlDsa44>::USIZE, 1312);
        assert_eq!(SignatureSize::<MlDsa44>::USIZE, 2420);

        assert_eq!(SigningKeySize::<MlDsa65>::USIZE, 4032);
        assert_eq!(VerifyingKeySize::<MlDsa65>::USIZE, 1952);
        assert_eq!(SignatureSize::<MlDsa65>::USIZE, 3309);

        assert_eq!(SigningKeySize::<MlDsa87>::USIZE, 4896);
        assert_eq!(VerifyingKeySize::<MlDsa87>::USIZE, 2592);
        assert_eq!(SignatureSize::<MlDsa87>::USIZE, 4627);
    }

    /// Keys and signatures must survive an encode/decode round trip, and a key
    /// pair must report the seed it was generated from.
    fn encode_decode_round_trip_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let seed = Array::default();
        let kp = P::from_seed(&seed);
        assert_eq!(kp.to_seed(), seed);

        let sk = kp.signing_key;
        let vk = kp.verifying_key;

        // `assert_eq!` reports both values on failure, unlike `assert!(a == b)`.
        let vk_bytes = vk.encode();
        let vk2 = VerifyingKey::<P>::decode(&vk_bytes);
        assert_eq!(vk, vk2);

        let sk_bytes = sk.encode();
        let sk2 = SigningKey::<P>::decode(&sk_bytes);
        assert_eq!(sk, sk2);

        let M = b"Hello world";
        let rnd = Array([0u8; 32]);
        let sig = sk.sign_internal(&[M], &rnd);
        let sig_bytes = sig.encode();
        let sig2 = Signature::<P>::decode(&sig_bytes).unwrap();
        assert_eq!(sig, sig2);
    }

    #[test]
    fn encode_decode_round_trip() {
        encode_decode_round_trip_test::<MlDsa44>();
        encode_decode_round_trip_test::<MlDsa65>();
        encode_decode_round_trip_test::<MlDsa87>();
    }

    /// The verifying key derived from a signing key must equal the one produced
    /// alongside it by key generation.
    fn public_from_private_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let kp = P::from_seed(&Array::default());
        let sk = kp.signing_key;
        let vk = kp.verifying_key;
        let vk_derived = sk.verifying_key();

        assert_eq!(vk, vk_derived);
    }

    #[test]
    fn public_from_private() {
        public_from_private_test::<MlDsa44>();
        public_from_private_test::<MlDsa65>();
        public_from_private_test::<MlDsa87>();
    }

    /// A signature from `sign_internal` must verify with `verify_internal` for
    /// the signed message, and must be rejected for a different message.
    fn sign_verify_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        let kp = P::from_seed(&Array::default());
        let sk = kp.signing_key;
        let vk = kp.verifying_key;

        let M = b"Hello world";
        let rnd = Array([0u8; 32]);
        let sig = sk.sign_internal(&[M], &rnd);

        assert!(vk.verify_internal(M, &sig));
        assert!(!vk.verify_internal(OTHER_MSG, &sig));
    }

    #[test]
    fn sign_verify_round_trip() {
        sign_verify_round_trip_test::<MlDsa44>();
        sign_verify_round_trip_test::<MlDsa65>();
        sign_verify_round_trip_test::<MlDsa87>();
    }

    /// Randomized sign / encode / decode / verify loop over many random seeds.
    fn many_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        use rand::Rng;

        const ITERATIONS: usize = 1000;

        let mut rng = rand::rng();
        let mut seed = B32::default();

        for _i in 0..ITERATIONS {
            let seed_data: &mut [u8] = seed.as_mut();
            rng.fill(seed_data);

            let kp = P::from_seed(&seed);
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let sig = sk.sign_internal(&[M], &rnd);

            let sig_enc = sig.encode();
            let sig_dec = Signature::<P>::decode(&sig_enc).unwrap();

            assert_eq!(sig_dec, sig);
            assert!(vk.verify_internal(M, &sig_dec));
        }
    }

    #[test]
    fn many_round_trip() {
        many_round_trip_test::<MlDsa44>();
        many_round_trip_test::<MlDsa65>();
        many_round_trip_test::<MlDsa87>();
    }

    #[test]
    fn sign_mu_verify_mu_round_trip() {
        fn sign_mu_verify_mu<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = MuBuilder::internal(&sk.tr, &[M]);
            let sig = sk.raw_sign_mu(&mu, &rnd);

            assert!(vk.raw_verify_mu(&mu, &sig));

            // A mu computed over a different message must not verify.
            let other_mu = MuBuilder::internal(&sk.tr, &[OTHER_MSG]);
            assert!(!vk.raw_verify_mu(&other_mu, &sig));
        }
        sign_mu_verify_mu::<MlDsa44>();
        sign_mu_verify_mu::<MlDsa65>();
        sign_mu_verify_mu::<MlDsa87>();
    }

    #[test]
    fn sign_mu_verify_internal_round_trip() {
        fn sign_mu_verify_internal<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = MuBuilder::internal(&sk.tr, &[M]);
            let sig = sk.raw_sign_mu(&mu, &rnd);

            assert!(vk.verify_internal(M, &sig));
            assert!(!vk.verify_internal(OTHER_MSG, &sig));
        }
        sign_mu_verify_internal::<MlDsa44>();
        sign_mu_verify_internal::<MlDsa65>();
        sign_mu_verify_internal::<MlDsa87>();
    }

    #[test]
    fn sign_internal_verify_mu_round_trip() {
        fn sign_internal_verify_mu<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = MuBuilder::internal(&sk.tr, &[M]);
            let sig = sk.sign_internal(&[M], &rnd);

            assert!(vk.raw_verify_mu(&mu, &sig));
        }
        sign_internal_verify_mu::<MlDsa44>();
        sign_internal_verify_mu::<MlDsa65>();
        sign_internal_verify_mu::<MlDsa87>();
    }

    #[test]
    fn sign_digest_round_trip() {
        fn sign_digest<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let sig = sk.sign_digest(|digest| digest.update(M));
            assert_eq!(sig, sk.sign(M));

            vk.verify_digest(
                |digest| {
                    digest.update(M);
                    Ok(())
                },
                &sig,
            )
            .unwrap();

            // Feeding a different message into the digest must be rejected.
            assert!(
                vk.verify_digest(
                    |digest| {
                        digest.update(OTHER_MSG);
                        Ok(())
                    },
                    &sig,
                )
                .is_err()
            );
        }
        sign_digest::<MlDsa44>();
        sign_digest::<MlDsa65>();
        sign_digest::<MlDsa87>();
    }

    #[test]
    #[cfg(feature = "rand_core")]
    fn sign_randomized_digest_round_trip() {
        fn sign_digest<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let sig = sk.sign_digest_with_rng(&mut rand::rng(), |digest| digest.update(M));

            vk.verify_digest(
                |digest| {
                    digest.update(M);
                    Ok(())
                },
                &sig,
            )
            .unwrap();

            // Feeding a different message into the digest must be rejected.
            assert!(
                vk.verify_digest(
                    |digest| {
                        digest.update(OTHER_MSG);
                        Ok(())
                    },
                    &sig,
                )
                .is_err()
            );
        }
        sign_digest::<MlDsa44>();
        sign_digest::<MlDsa65>();
        sign_digest::<MlDsa87>();
    }

    #[test]
    fn from_seed_implementations_match() {
        fn assert_from_seed_equality<P>()
        where
            P: MlDsaParams,
        {
            let seed = Seed::default();
            let kp1 = P::from_seed(&seed);
            let sk1 = SigningKey::<P>::from_seed(&seed);
            let vk1 = sk1.verifying_key();
            assert_eq!(kp1.signing_key, sk1);
            assert_eq!(kp1.verifying_key, vk1);
        }
        assert_from_seed_equality::<MlDsa44>();
        assert_from_seed_equality::<MlDsa65>();
        assert_from_seed_equality::<MlDsa87>();
    }
}