ml_dsa/
lib.rs

1#![no_std]
2#![doc = include_str!("../README.md")]
3#![doc(
4    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
5    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
6)]
7#![cfg_attr(docsrs, feature(doc_auto_cfg))]
8#![warn(clippy::pedantic)] // Be pedantic by default
9#![warn(clippy::integer_division_remainder_used)] // Be judicious about using `/` and `%`
10#![warn(clippy::as_conversions)] // Use proper conversions, not `as`
11#![allow(non_snake_case)] // Allow notation matching the spec
12#![allow(clippy::similar_names)] // Allow notation matching the spec
13#![allow(clippy::many_single_char_names)] // Allow notation matching the spec
14#![allow(clippy::clone_on_copy)] // Be explicit about moving data
15#![deny(missing_docs)] // Require all public interfaces to be documented
16
17//! # Quickstart
18//!
19//! ```
20//! # #[cfg(feature = "rand_core")]
21//! # {
22//! use ml_dsa::{MlDsa65, KeyGen, signature::{Keypair, Signer, Verifier}};
23//!
24//! let mut rng = rand::rng();
25//! let kp = MlDsa65::key_gen(&mut rng);
26//!
27//! let msg = b"Hello world";
28//! let sig = kp.signing_key().sign(msg);
29//!
30//! assert!(kp.verifying_key().verify(msg, &sig).is_ok());
31//! # }
32//! ```
33
34mod algebra;
35mod crypto;
36mod encode;
37mod hint;
38mod ntt;
39mod param;
40mod sampling;
41mod util;
42
43// TODO(RLB) Move module to an independent crate shared with ml_kem
44mod module_lattice;
45
46use core::convert::{AsRef, TryFrom, TryInto};
47use hybrid_array::{
48    Array,
49    typenum::{
50        Diff, Length, Prod, Quot, Shleft, U1, U2, U4, U5, U6, U7, U8, U17, U19, U32, U48, U55, U64,
51        U75, U80, U88, Unsigned,
52    },
53};
54use signature::digest::Update;
55use signature::{DigestSigner, DigestVerifier, MultipartSigner, MultipartVerifier, Signer};
56
57#[cfg(feature = "rand_core")]
58use signature::RandomizedDigestSigner;
59
60#[cfg(feature = "rand_core")]
61use rand_core::{CryptoRng, TryCryptoRng};
62
63use sha3::Shake256;
64#[cfg(feature = "zeroize")]
65use zeroize::{Zeroize, ZeroizeOnDrop};
66
67#[cfg(feature = "pkcs8")]
68use {
69    const_oid::db::fips204,
70    pkcs8::{
71        AlgorithmIdentifierRef, PrivateKeyInfoRef,
72        der::{self, AnyRef},
73        spki::{
74            self, AlgorithmIdentifier, AssociatedAlgorithmIdentifier, SignatureAlgorithmIdentifier,
75            SubjectPublicKeyInfoRef,
76        },
77    },
78};
79
80#[cfg(all(feature = "alloc", feature = "pkcs8"))]
81use pkcs8::{
82    EncodePrivateKey, EncodePublicKey,
83    der::asn1::{BitString, BitStringRef, OctetStringRef},
84    spki::{SignatureBitStringEncoding, SubjectPublicKeyInfo},
85};
86
87use crate::algebra::{AlgebraExt, Elem, NttMatrix, NttVector, Truncate, Vector};
88use crate::crypto::H;
89use crate::hint::Hint;
90use crate::ntt::{Ntt, NttInverse};
91use crate::param::{ParameterSet, QMinus1, SamplingSize, SpecQ};
92use crate::sampling::{expand_a, expand_mask, expand_s, sample_in_ball};
93use crate::util::B64;
94use core::fmt;
95
96pub use crate::param::{EncodedSignature, EncodedSigningKey, EncodedVerifyingKey, MlDsaParams};
97pub use crate::util::B32;
98pub use signature::{self, Error};
99
/// An ML-DSA signature
///
/// Holds the three components produced by signing (FIPS 204 σ = (c̃, z, h)): the commitment
/// hash `c_tilde`, the response vector `z`, and the hint `h`.
#[derive(Clone, PartialEq, Debug)]
pub struct Signature<P: MlDsaParams> {
    // Commitment hash c̃: `P::Lambda` bytes squeezed from H(μ || w1Encode(w1)).
    c_tilde: Array<u8, P::Lambda>,
    // Response vector z = y + c·s1, reduced with `mod_plus_minus` before being stored.
    z: Vector<P::L>,
    // Hint h, built with `Hint::new` while signing and applied with `use_hint` when verifying.
    h: Hint<P>,
}
107
108impl<P: MlDsaParams> Signature<P> {
109    /// Encode the signature in a fixed-size byte array.
110    // Algorithm 26 sigEncode
111    pub fn encode(&self) -> EncodedSignature<P> {
112        let c_tilde = self.c_tilde.clone();
113        let z = P::encode_z(&self.z);
114        let h = self.h.bit_pack();
115        P::concat_sig(c_tilde, z, h)
116    }
117
118    /// Decode the signature from an appropriately sized byte array.
119    // Algorithm 27 sigDecode
120    pub fn decode(enc: &EncodedSignature<P>) -> Option<Self> {
121        let (c_tilde, z, h) = P::split_sig(enc);
122
123        let c_tilde = c_tilde.clone();
124        let z = P::decode_z(z);
125        let h = Hint::bit_unpack(h)?;
126
127        if z.infinity_norm() >= P::GAMMA1_MINUS_BETA {
128            return None;
129        }
130
131        Some(Self { c_tilde, z, h })
132    }
133}
134
135impl<'a, P: MlDsaParams> TryFrom<&'a [u8]> for Signature<P> {
136    type Error = Error;
137
138    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
139        let enc = EncodedSignature::<P>::try_from(value).map_err(|_| Error::new())?;
140        Self::decode(&enc).ok_or(Error::new())
141    }
142}
143
144impl<P: MlDsaParams> TryInto<EncodedSignature<P>> for Signature<P> {
145    type Error = Error;
146
147    fn try_into(self) -> Result<EncodedSignature<P>, Self::Error> {
148        Ok(self.encode())
149    }
150}
151
/// The wire representation of a [`Signature`] is its fixed-size [`EncodedSignature`] array.
impl<P: MlDsaParams> signature::SignatureEncoding for Signature<P> {
    type Repr = EncodedSignature<P>;
}
155
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P: MlDsaParams> SignatureBitStringEncoding for Signature<P> {
    /// Wrap the encoded signature bytes in a DER BIT STRING with no unused bits.
    fn to_bitstring(&self) -> der::Result<BitString> {
        let bytes = self.encode().to_vec();
        BitString::new(0, bytes)
    }
}
162
/// A signature is identified by the same algorithm identifier as its parameter set `P`.
#[cfg(feature = "pkcs8")]
impl<P> AssociatedAlgorithmIdentifier for Signature<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = P::ALGORITHM_IDENTIFIER;
}
173
174// This method takes a slice of slices so that we can accommodate the varying calculations (direct
175// for test vectors, 0... for sign/sign_deterministic, 1... for the pre-hashed version) without
176// having to allocate memory for components.
177fn message_representative(tr: &[u8], Mp: &[&[&[u8]]]) -> B64 {
178    let mut h = H::default().absorb(tr);
179
180    for m in Mp.iter().copied().flatten() {
181        h = h.absorb(m);
182    }
183
184    h.squeeze_new()
185}
186
/// An ML-DSA key pair
///
/// Built by `P::from_seed`, which is used by `SigningKey::from_seed` and by the PKCS#8
/// `TryFrom<PrivateKeyInfoRef>` conversion.
pub struct KeyPair<P: MlDsaParams> {
    /// The signing key of the key pair
    signing_key: SigningKey<P>,

    /// The verifying key of the key pair
    verifying_key: VerifyingKey<P>,

    /// The seed this signing key was derived from
    ///
    /// Only kept when the `pkcs8` feature is enabled; it is the value written out by the
    /// `EncodePrivateKey` implementation.
    // NOTE(review): no `Drop`/zeroize impl for `KeyPair` is visible in this file, so this seed
    // may not be wiped when the key pair is dropped. Confirm.
    #[cfg(feature = "pkcs8")]
    seed: B32,
}
199
impl<P: MlDsaParams> KeyPair<P> {
    /// The signing key of the key pair
    ///
    /// Returned by reference; clone it if an owned copy is needed.
    pub fn signing_key(&self) -> &SigningKey<P> {
        &self.signing_key
    }

    /// The verifying key of the key pair
    ///
    /// Returned by reference; clone it if an owned copy is needed.
    pub fn verifying_key(&self) -> &VerifyingKey<P> {
        &self.verifying_key
    }
}
211
/// Borrow the public half of the key pair through `AsRef`.
impl<P: MlDsaParams> AsRef<VerifyingKey<P>> for KeyPair<P> {
    fn as_ref(&self) -> &VerifyingKey<P> {
        &self.verifying_key
    }
}
217
218impl<P: MlDsaParams> fmt::Debug for KeyPair<P> {
219    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
220        f.debug_struct("KeyPair")
221            .field("verifying_key", &self.verifying_key)
222            .finish_non_exhaustive()
223    }
224}
225
/// Declares the verifying key type for `signature::KeypairRef`. That trait builds on
/// `AsRef<VerifyingKey<P>>`, which `KeyPair` implements.
impl<P: MlDsaParams> signature::KeypairRef for KeyPair<P> {
    type VerifyingKey = VerifyingKey<P>;
}
229
#[cfg(feature = "pkcs8")]
impl<P> TryFrom<PrivateKeyInfoRef<'_>> for KeyPair<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Error = pkcs8::Error;

    /// Rebuild a key pair from a PKCS#8 document that stores the 32-byte seed.
    ///
    /// Fails with `OidUnknown` if the algorithm identifier does not match `P`, and with
    /// `KeyMalformed` if the private key is not exactly 32 bytes long.
    fn try_from(private_key_info: pkcs8::PrivateKeyInfoRef<'_>) -> pkcs8::Result<Self> {
        let algorithm = private_key_info.algorithm;
        if algorithm != P::ALGORITHM_IDENTIFIER {
            return Err(spki::Error::OidUnknown { oid: algorithm.oid }.into());
        }

        let seed: B32 = private_key_info
            .private_key
            .as_bytes()
            .try_into()
            .map_err(|_| pkcs8::Error::KeyMalformed)?;
        Ok(P::from_seed(&seed))
    }
}
249
250/// The `Signer` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA, and
251/// only supports signing with an empty context string.
252impl<P: MlDsaParams> Signer<Signature<P>> for KeyPair<P> {
253    fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
254        self.try_multipart_sign(&[msg])
255    }
256}
257
258/// The `Signer` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA, and
259/// only supports signing with an empty context string.
260impl<P: MlDsaParams> MultipartSigner<Signature<P>> for KeyPair<P> {
261    fn try_multipart_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
262        self.signing_key.raw_sign_deterministic(msg, &[])
263    }
264}
265
266/// The `DigestSigner` implementation for `KeyPair` uses the optional deterministic variant of ML-DSA
267/// with a pre-computed μ, and only supports signing with an empty context string.
268impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for KeyPair<P> {
269    fn try_sign_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
270        &self,
271        f: F,
272    ) -> Result<Signature<P>, Error> {
273        self.signing_key.try_sign_digest(&f)
274    }
275}
276
/// Signatures produced by a key pair are identified by `Signature<P>`'s algorithm identifier.
#[cfg(feature = "pkcs8")]
impl<P> SignatureAlgorithmIdentifier for KeyPair<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
        Signature::<P>::ALGORITHM_IDENTIFIER;
}
288
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P> EncodePrivateKey for KeyPair<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    /// Serialize the key pair as PKCS#8 in seed form: the private key field is an OCTET STRING
    /// holding the 32-byte seed.
    fn to_pkcs8_der(&self) -> pkcs8::Result<der::SecretDocument> {
        let seed = OctetStringRef::new(&self.seed)?;
        let info = pkcs8::PrivateKeyInfoRef::new(P::ALGORITHM_IDENTIFIER, seed);
        der::SecretDocument::encode_msg(&info).map_err(pkcs8::Error::from)
    }
}
303
/// An ML-DSA signing key
///
/// Besides the encoded components (ρ, K, tr, s1, s2, t0), it caches the NTT-domain copies of
/// the secret vectors and the expanded matrix Â. All of these are filled in by `SigningKey::new`.
#[derive(Clone, PartialEq)]
pub struct SigningKey<P: MlDsaParams> {
    // Seed from which the matrix Â is expanded (`expand_a`).
    rho: B32,
    // Private seed mixed into the per-signature randomness ρ'' in `raw_sign_mu`.
    K: B32,
    // 64-byte value prefixed to every message representative μ; expected to equal the matching
    // verifying key's `tr` (the hash of its encoding).
    tr: B64,
    // Secret vectors s1, s2, and t0, the low part of t that is dropped from the public key.
    s1: Vector<P::L>,
    s2: Vector<P::K>,
    t0: Vector<P::K>,

    // Derived values
    // NTT forms of s1, s2, t0, plus Â (expanded from `rho` unless supplied to `new`).
    s1_hat: NttVector<P::L>,
    s2_hat: NttVector<P::K>,
    t0_hat: NttVector<P::K>,
    A_hat: NttMatrix<P::K, P::L>,
}
320
/// Prints no fields at all, so no key material appears in `Debug` output.
impl<P: MlDsaParams> fmt::Debug for SigningKey<P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SigningKey").finish_non_exhaustive()
    }
}
326
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> Drop for SigningKey<P> {
    /// Wipe the key components before the memory is released.
    fn drop(&mut self) {
        // Secret vectors first...
        self.s1.zeroize();
        self.s2.zeroize();
        self.t0.zeroize();
        // ...then the seeds and the message-representative prefix.
        self.rho.zeroize();
        self.K.zeroize();
        self.tr.zeroize();
        // NOTE(review): `s1_hat`, `s2_hat` and `t0_hat` are NTT-domain copies of the secret
        // vectors but are not wiped here. Consider zeroizing them as well once `NttVector`
        // is confirmed to implement `Zeroize`.
    }
}
338
/// Marker only. The wiping itself is done by the `Drop` impl for `SigningKey`, which is
/// enabled under the same `zeroize` feature.
#[cfg(feature = "zeroize")]
impl<P: MlDsaParams> ZeroizeOnDrop for SigningKey<P> {}
341
342impl<P: MlDsaParams> SigningKey<P> {
343    fn new(
344        rho: B32,
345        K: B32,
346        tr: B64,
347        s1: Vector<P::L>,
348        s2: Vector<P::K>,
349        t0: Vector<P::K>,
350        A_hat: Option<NttMatrix<P::K, P::L>>,
351    ) -> Self {
352        let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
353        let s1_hat = s1.ntt();
354        let s2_hat = s2.ntt();
355        let t0_hat = t0.ntt();
356
357        Self {
358            rho,
359            K,
360            tr,
361            s1,
362            s2,
363            t0,
364
365            s1_hat,
366            s2_hat,
367            t0_hat,
368            A_hat,
369        }
370    }
371
372    /// Deterministically generate a signing key from the specified seed.
373    ///
374    /// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204, but only returns a
375    /// signing key.
376    #[must_use]
377    pub fn from_seed(seed: &B32) -> Self {
378        let kp = P::from_seed(seed);
379        kp.signing_key
380    }
381
382    /// This method reflects the ML-DSA.Sign_internal algorithm from FIPS 204. It does not
383    /// include the domain separator that distinguishes between the normal and pre-hashed cases,
384    /// and it does not separate the context string from the rest of the message.
385    // Algorithm 7 ML-DSA.Sign_internal
386    // TODO(RLB) Only expose based on a feature.  Tests need access, but normal code shouldn't.
387    pub fn sign_internal(&self, Mp: &[&[u8]], rnd: &B32) -> Signature<P>
388    where
389        P: MlDsaParams,
390    {
391        self.raw_sign_internal(&[Mp], rnd)
392    }
393
394    fn raw_sign_internal(&self, Mp: &[&[&[u8]]], rnd: &B32) -> Signature<P>
395    where
396        P: MlDsaParams,
397    {
398        // Compute the message representative
399        // XXX(RLB): This line incorporates some of the logic from ML-DSA.sign to avoid computing
400        // the concatenated M'.
401        // XXX(RLB) Should the API represent this as an input?
402        let mu = message_representative(&self.tr, Mp);
403        self.raw_sign_mu(&mu, rnd)
404    }
405
406    fn raw_sign_mu(&self, mu: &B64, rnd: &B32) -> Signature<P>
407    where
408        P: MlDsaParams,
409    {
410        // Compute the private random seed
411        let rhopp: B64 = H::default()
412            .absorb(&self.K)
413            .absorb(rnd)
414            .absorb(mu)
415            .squeeze_new();
416
417        // Rejection sampling loop
418        for kappa in (0..u16::MAX).step_by(P::L::USIZE) {
419            let y = expand_mask::<P::L, P::Gamma1>(&rhopp, kappa);
420            let w = (&self.A_hat * &y.ntt()).ntt_inverse();
421            let w1 = w.high_bits::<P::TwoGamma2>();
422
423            let w1_tilde = P::encode_w1(&w1);
424            let c_tilde = H::default()
425                .absorb(mu)
426                .absorb(&w1_tilde)
427                .squeeze_new::<P::Lambda>();
428            let c = sample_in_ball(&c_tilde, P::TAU);
429            let c_hat = c.ntt();
430
431            let cs1 = (&c_hat * &self.s1_hat).ntt_inverse();
432            let cs2 = (&c_hat * &self.s2_hat).ntt_inverse();
433
434            let z = &y + &cs1;
435            let r0 = (&w - &cs2).low_bits::<P::TwoGamma2>();
436
437            if z.infinity_norm() >= P::GAMMA1_MINUS_BETA
438                || r0.infinity_norm() >= P::GAMMA2_MINUS_BETA
439            {
440                continue;
441            }
442
443            let ct0 = (&c_hat * &self.t0_hat).ntt_inverse();
444            let minus_ct0 = -&ct0;
445            let w_cs2_ct0 = &(&w - &cs2) + &ct0;
446            let h = Hint::<P>::new(&minus_ct0, &w_cs2_ct0);
447
448            if ct0.infinity_norm() >= P::Gamma2::U32 || h.hamming_weight() > P::Omega::USIZE {
449                continue;
450            }
451
452            let z = z.mod_plus_minus::<SpecQ>();
453            return Signature { c_tilde, z, h };
454        }
455
456        unreachable!("Rejection sampling failed to find a valid signature");
457    }
458
459    /// This method reflects the randomized ML-DSA.Sign algorithm.
460    ///
461    /// # Errors
462    ///
463    /// This method will return an opaque error if the context string is more than 255 bytes long,
464    /// or if it fails to get enough randomness.
465    // Algorithm 2 ML-DSA.Sign
466    #[cfg(feature = "rand_core")]
467    pub fn sign_randomized<R: TryCryptoRng + ?Sized>(
468        &self,
469        M: &[u8],
470        ctx: &[u8],
471        rng: &mut R,
472    ) -> Result<Signature<P>, Error> {
473        if ctx.len() > 255 {
474            return Err(Error::new());
475        }
476
477        let mut rnd = B32::default();
478        rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
479
480        let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
481        Ok(self.sign_internal(Mp, &rnd))
482    }
483
484    /// This method reflects the randomized ML-DSA.Sign algorithm with a pre-computed μ.
485    ///
486    /// # Errors
487    ///
488    /// This method can return an opaque error if it fails to get enough randomness.
489    // Algorithm 2 ML-DSA.Sign (optional pre-computed μ variant)
490    #[cfg(feature = "rand_core")]
491    pub fn sign_mu_randomized<R: TryCryptoRng + ?Sized>(
492        &self,
493        mu: &B64,
494        rng: &mut R,
495    ) -> Result<Signature<P>, Error> {
496        let mut rnd = B32::default();
497        rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;
498
499        Ok(self.raw_sign_mu(mu, &rnd))
500    }
501
502    /// This method reflects the optional deterministic variant of the ML-DSA.Sign algorithm.
503    ///
504    /// # Errors
505    ///
506    /// This method will return an opaque error if the context string is more than 255 bytes long.
507    // Algorithm 2 ML-DSA.Sign (optional deterministic variant)
508    pub fn sign_deterministic(&self, M: &[u8], ctx: &[u8]) -> Result<Signature<P>, Error> {
509        self.raw_sign_deterministic(&[M], ctx)
510    }
511
512    /// This method reflects the optional deterministic variant of the ML-DSA.Sign algorithm with a
513    /// pre-computed μ.
514    // Algorithm 2 ML-DSA.Sign (optional deterministic and pre-computed μ variant)
515    pub fn sign_mu_deterministic(&self, mu: &B64) -> Signature<P> {
516        let rnd = B32::default();
517        self.raw_sign_mu(mu, &rnd)
518    }
519
520    fn raw_sign_deterministic(&self, M: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
521        if ctx.len() > 255 {
522            return Err(Error::new());
523        }
524
525        let rnd = B32::default();
526        let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M];
527        Ok(self.raw_sign_internal(Mp, &rnd))
528    }
529
530    /// Encode the key in a fixed-size byte array.
531    // Algorithm 24 skEncode
532    pub fn encode(&self) -> EncodedSigningKey<P>
533    where
534        P: MlDsaParams,
535    {
536        let s1_enc = P::encode_s1(&self.s1);
537        let s2_enc = P::encode_s2(&self.s2);
538        let t0_enc = P::encode_t0(&self.t0);
539        P::concat_sk(
540            self.rho.clone(),
541            self.K.clone(),
542            self.tr.clone(),
543            s1_enc,
544            s2_enc,
545            t0_enc,
546        )
547    }
548
549    /// Decode the key from an appropriately sized byte array.
550    // Algorithm 25 skDecode
551    pub fn decode(enc: &EncodedSigningKey<P>) -> Self
552    where
553        P: MlDsaParams,
554    {
555        let (rho, K, tr, s1_enc, s2_enc, t0_enc) = P::split_sk(enc);
556        Self::new(
557            rho.clone(),
558            K.clone(),
559            tr.clone(),
560            P::decode_s1(s1_enc),
561            P::decode_s2(s2_enc),
562            P::decode_t0(t0_enc),
563            None,
564        )
565    }
566
567    /// This auxiliary function derives a `VerifyingKey` from a bare
568    /// `SigningKey` (even in the absence of the original seed).
569    ///
570    /// This is a utility function that is useful when importing the private key
571    /// from an external source which does not export the seed and does not
572    /// provide the precomputed public key associated with the private key
573    /// itself.
574    ///
575    /// `SigningKey` implements `signature::Keypair`: this inherent method is
576    /// retained for convenience, so it is available for callers even when the
577    /// `signature::Keypair` trait is out-of-scope.
578    pub fn verifying_key(&self) -> VerifyingKey<P> {
579        let kp: &dyn signature::Keypair<VerifyingKey = VerifyingKey<P>> = self;
580
581        kp.verifying_key()
582    }
583}
584
585/// The `Signer` implementation for `SigningKey` uses the optional deterministic variant of ML-DSA, and
586/// only supports signing with an empty context string.  If you would like to include a context
587/// string, use the [`SigningKey::sign_deterministic`] method.
588impl<P: MlDsaParams> Signer<Signature<P>> for SigningKey<P> {
589    fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
590        self.try_multipart_sign(&[msg])
591    }
592}
593
594/// The `Signer` implementation for `SigningKey` uses the optional deterministic variant of ML-DSA, and
595/// only supports signing with an empty context string. If you would like to include a context
596/// string, use the [`SigningKey::sign_deterministic`] method.
597impl<P: MlDsaParams> MultipartSigner<Signature<P>> for SigningKey<P> {
598    fn try_multipart_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
599        self.raw_sign_deterministic(msg, &[])
600    }
601}
602
603/// The `Signer` implementation for `SigningKey` uses the optional deterministic variant of ML-DSA
604/// with a pre-computed µ, and only supports signing with an empty context string. If you would
605/// like to include a context string, use the [`SigningKey::sign_mu_deterministic`] method.
606impl<P: MlDsaParams> DigestSigner<Shake256, Signature<P>> for SigningKey<P> {
607    fn try_sign_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
608        &self,
609        f: F,
610    ) -> Result<Signature<P>, Error> {
611        let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
612        f(&mut digest)?;
613        let mu = H::pre_digest(digest).squeeze_new();
614
615        Ok(self.sign_mu_deterministic(&mu))
616    }
617}
618
619/// The `KeyPair` implementation for `SigningKey` allows to derive a `VerifyingKey` from
620/// a bare `SigningKey` (even in the absence of the original seed).
621impl<P: MlDsaParams> signature::Keypair for SigningKey<P> {
622    type VerifyingKey = VerifyingKey<P>;
623
624    /// This is a utility function that is useful when importing the private key
625    /// from an external source which does not export the seed and does not
626    /// provide the precomputed public key associated with the private key
627    /// itself.
628    fn verifying_key(&self) -> Self::VerifyingKey {
629        let As1 = &self.A_hat * &self.s1_hat;
630        let t = &As1.ntt_inverse() + &self.s2;
631
632        /* Discard t0 */
633        let (t1, _) = t.power2round();
634
635        VerifyingKey::new(self.rho.clone(), t1, Some(self.A_hat.clone()), None)
636    }
637}
638
/// The `RandomizedSigner` implementation for `SigningKey` always signs with an empty context
/// string. To include a context string, use the [`SigningKey::sign_randomized`] method.
#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> signature::RandomizedSigner<Signature<P>> for SigningKey<P> {
    fn try_sign_with_rng<R: TryCryptoRng + ?Sized>(
        &self,
        rng: &mut R,
        msg: &[u8],
    ) -> Result<Signature<P>, Error> {
        let empty_ctx: &[u8] = &[];
        self.sign_randomized(msg, empty_ctx, rng)
    }
}
652
/// The `RandomizedDigestSigner` implementation for `SigningKey` computes μ from the digest and
/// signs with fresh randomness. It only supports signing with an empty context string.
///
/// Before the closure runs, the SHAKE256 state has already absorbed `tr` followed by the two
/// bytes `0 || 0` (domain separator and empty-context length); the closure only needs to absorb
/// the message. To include a context string, compute μ yourself and use the
/// [`SigningKey::sign_mu_randomized`] method.
#[cfg(feature = "rand_core")]
impl<P: MlDsaParams> RandomizedDigestSigner<Shake256, Signature<P>> for SigningKey<P> {
    fn try_sign_digest_with_rng<
        R: TryCryptoRng + ?Sized,
        F: Fn(&mut Shake256) -> Result<(), Error>,
    >(
        &self,
        rng: &mut R,
        f: F,
    ) -> Result<Signature<P>, Error> {
        let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
        f(&mut digest)?;
        let mu = H::pre_digest(digest).squeeze_new();

        self.sign_mu_randomized(&mu, rng)
    }
}
673
/// Signatures produced by this key are identified by `Signature<P>`'s algorithm identifier.
#[cfg(feature = "pkcs8")]
impl<P> SignatureAlgorithmIdentifier for SigningKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
        Signature::<P>::ALGORITHM_IDENTIFIER;
}
685
#[cfg(feature = "pkcs8")]
impl<P> TryFrom<PrivateKeyInfoRef<'_>> for SigningKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Error = pkcs8::Error;

    /// Decode a PKCS#8 document into a full key pair, then keep only the signing half.
    fn try_from(private_key_info: pkcs8::PrivateKeyInfoRef<'_>) -> pkcs8::Result<Self> {
        KeyPair::<P>::try_from(private_key_info).map(|keypair| keypair.signing_key)
    }
}
700
/// An ML-DSA verification key
///
/// Stores the encoded components (ρ, t1) together with values precomputed in
/// `VerifyingKey::new` to speed up verification.
#[derive(Clone, Debug, PartialEq)]
pub struct VerifyingKey<P: ParameterSet> {
    // Seed from which the matrix Â is expanded (`expand_a`).
    rho: B32,
    // High part of t; this is what `encode` serializes alongside `rho`.
    t1: Vector<P::K>,

    // Derived values
    // Â, either supplied to `new` or expanded from `rho`.
    A_hat: NttMatrix<P::K, P::L>,
    // NTT(t1 · 2^13), used in the verification equation.
    t1_2d_hat: NttVector<P::K>,
    // 64-byte hash of the encoded key, prefixed to every message representative μ.
    tr: B64,
}
712
713impl<P: MlDsaParams> VerifyingKey<P> {
714    fn new(
715        rho: B32,
716        t1: Vector<P::K>,
717        A_hat: Option<NttMatrix<P::K, P::L>>,
718        enc: Option<EncodedVerifyingKey<P>>,
719    ) -> Self {
720        let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
721        let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1));
722
723        let t1_2d_hat = (Elem::new(1 << 13) * &t1).ntt();
724        let tr: B64 = H::default().absorb(&enc).squeeze_new();
725
726        Self {
727            rho,
728            t1,
729            A_hat,
730            t1_2d_hat,
731            tr,
732        }
733    }
734
735    /// This algorithm reflects the ML-DSA.Verify_internal algorithm from FIPS 204.  It does not
736    /// include the domain separator that distinguishes between the normal and pre-hashed cases,
737    /// and it does not separate the context string from the rest of the message.
738    // Algorithm 8 ML-DSA.Verify_internal
739    pub fn verify_internal(&self, Mp: &[&[u8]], sigma: &Signature<P>) -> bool
740    where
741        P: MlDsaParams,
742    {
743        self.raw_verify_internal(&[Mp], sigma)
744    }
745
746    fn raw_verify_internal(&self, Mp: &[&[&[u8]]], sigma: &Signature<P>) -> bool
747    where
748        P: MlDsaParams,
749    {
750        // Compute the message representative
751        let mu = message_representative(&self.tr, Mp);
752        self.raw_verify_mu(&mu, sigma)
753    }
754
755    fn raw_verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool
756    where
757        P: MlDsaParams,
758    {
759        // Reconstruct w
760        let c = sample_in_ball(&sigma.c_tilde, P::TAU);
761
762        let z_hat = sigma.z.ntt();
763        let c_hat = c.ntt();
764        let Az_hat = &self.A_hat * &z_hat;
765        let ct1_2d_hat = &c_hat * &self.t1_2d_hat;
766
767        let wp_approx = (&Az_hat - &ct1_2d_hat).ntt_inverse();
768        let w1p = sigma.h.use_hint(&wp_approx);
769
770        let w1p_tilde = P::encode_w1(&w1p);
771        let cp_tilde = H::default()
772            .absorb(mu)
773            .absorb(&w1p_tilde)
774            .squeeze_new::<P::Lambda>();
775
776        sigma.c_tilde == cp_tilde
777    }
778
779    /// This algorithm reflects the ML-DSA.Verify algorithm from FIPS 204.
780    // Algorithm 3 ML-DSA.Verify
781    pub fn verify_with_context(&self, M: &[u8], ctx: &[u8], sigma: &Signature<P>) -> bool {
782        self.raw_verify_with_context(&[M], ctx, sigma)
783    }
784
785    /// This algorithm reflects the ML-DSA.Verify algorithm with a pre-computed μ from FIPS 204.
786    // Algorithm 3 ML-DSA.Verify (optional pre-computed μ variant)
787    pub fn verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool {
788        self.raw_verify_mu(mu, sigma)
789    }
790
791    fn raw_verify_with_context(&self, M: &[&[u8]], ctx: &[u8], sigma: &Signature<P>) -> bool {
792        if ctx.len() > 255 {
793            return false;
794        }
795
796        let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M];
797        self.raw_verify_internal(Mp, sigma)
798    }
799
800    fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
801        let t1_enc = P::encode_t1(t1);
802        P::concat_vk(rho.clone(), t1_enc)
803    }
804
805    /// Encode the key in a fixed-size byte array.
806    // Algorithm 22 pkEncode
807    pub fn encode(&self) -> EncodedVerifyingKey<P> {
808        Self::encode_internal(&self.rho, &self.t1)
809    }
810
811    /// Decode the key from an appropriately sized byte array.
812    // Algorithm 23 pkDecode
813    pub fn decode(enc: &EncodedVerifyingKey<P>) -> Self {
814        let (rho, t1_enc) = P::split_vk(enc);
815        let t1 = P::decode_t1(t1_enc);
816        Self::new(rho.clone(), t1, None, Some(enc.clone()))
817    }
818}
819
820impl<P: MlDsaParams> signature::Verifier<Signature<P>> for VerifyingKey<P> {
821    fn verify(&self, msg: &[u8], signature: &Signature<P>) -> Result<(), Error> {
822        self.multipart_verify(&[msg], signature)
823    }
824}
825
826impl<P: MlDsaParams> MultipartVerifier<Signature<P>> for VerifyingKey<P> {
827    fn multipart_verify(&self, msg: &[&[u8]], signature: &Signature<P>) -> Result<(), Error> {
828        self.raw_verify_with_context(msg, &[], signature)
829            .then_some(())
830            .ok_or(Error::new())
831    }
832}
833
834impl<P: MlDsaParams> DigestVerifier<Shake256, Signature<P>> for VerifyingKey<P> {
835    fn verify_digest<F: Fn(&mut Shake256) -> Result<(), Error>>(
836        &self,
837        f: F,
838        signature: &Signature<P>,
839    ) -> Result<(), Error> {
840        let mut digest = Shake256::default().chain(self.tr).chain([0, 0]);
841        f(&mut digest)?;
842        let mu = H::pre_digest(digest).squeeze_new();
843
844        self.raw_verify_mu(&mu, signature)
845            .then_some(())
846            .ok_or(Error::new())
847    }
848}
849
/// Signatures checked by this key are identified by `Signature<P>`'s algorithm identifier.
#[cfg(feature = "pkcs8")]
impl<P> SignatureAlgorithmIdentifier for VerifyingKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Params = AnyRef<'static>;

    const SIGNATURE_ALGORITHM_IDENTIFIER: AlgorithmIdentifier<Self::Params> =
        Signature::<P>::ALGORITHM_IDENTIFIER;
}
861
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl<P> EncodePublicKey for VerifyingKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    /// Serialize as a DER SubjectPublicKeyInfo whose BIT STRING holds the pkEncode output.
    fn to_public_key_der(&self) -> spki::Result<der::Document> {
        let encoded = self.encode();
        let info = SubjectPublicKeyInfo {
            algorithm: P::ALGORITHM_IDENTIFIER,
            subject_public_key: BitStringRef::new(0, &encoded)?,
        };
        info.try_into()
    }
}
879
#[cfg(feature = "pkcs8")]
impl<P> TryFrom<SubjectPublicKeyInfoRef<'_>> for VerifyingKey<P>
where
    P: MlDsaParams,
    P: AssociatedAlgorithmIdentifier<Params = AnyRef<'static>>,
{
    type Error = spki::Error;

    /// Parse a SubjectPublicKeyInfo. Fails with `OidUnknown` on an algorithm mismatch, on a BIT
    /// STRING that is not byte-aligned, or with `KeyMalformed` on a wrong-length key.
    fn try_from(spki: SubjectPublicKeyInfoRef<'_>) -> spki::Result<Self> {
        if spki.algorithm != P::ALGORITHM_IDENTIFIER {
            return Err(spki::Error::OidUnknown {
                oid: spki.algorithm.oid,
            });
        }

        let raw = spki
            .subject_public_key
            .as_bytes()
            .ok_or_else(|| der::Tag::BitString.value_error().to_error())?;
        let encoded =
            EncodedVerifyingKey::<P>::try_from(raw).map_err(|_| pkcs8::Error::KeyMalformed)?;

        Ok(Self::decode(&encoded))
    }
}
904
/// `MlDsa44` is the parameter set for security category 2.
///
/// Zero-sized marker type; its constants come from its `ParameterSet` impl.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MlDsa44;
908
// ML-DSA-44 parameters, named after their FIPS 204 symbols.
impl ParameterSet for MlDsa44 {
    // (k, ℓ): Â is a K×L matrix.
    type K = U4;
    type L = U4;
    // η
    type Eta = U2;
    // γ1 = 2^17, the parameter passed to `expand_mask`.
    type Gamma1 = Shleft<U1, U17>;
    // γ2 = (q − 1) / 88
    type Gamma2 = Quot<QMinus1, U88>;
    type TwoGamma2 = Prod<U2, Self::Gamma2>;
    // Bits per coefficient of w1: bitlen(88 / 2 − 1) = bitlen(43) = 6.
    type W1Bits = Length<Diff<Quot<U88, U2>, U1>>;
    // Byte length of the commitment hash c̃.
    type Lambda = U32;
    // ω: maximum number of ones allowed in the hint h.
    type Omega = U80;
    // τ, passed to `sample_in_ball` when deriving the challenge c.
    const TAU: usize = 39;
}
921
922#[cfg(feature = "pkcs8")]
923impl AssociatedAlgorithmIdentifier for MlDsa44 {
924    type Params = AnyRef<'static>;
925
926    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
927        oid: fips204::ID_ML_DSA_44,
928        parameters: None,
929    };
930}
931
932/// `MlDsa65` is the parameter set for security category 3.
933#[derive(Default, Clone, Debug, PartialEq)]
934pub struct MlDsa65;
935
936impl ParameterSet for MlDsa65 {
937    type K = U6;
938    type L = U5;
939    type Eta = U4;
940    type Gamma1 = Shleft<U1, U19>;
941    type Gamma2 = Quot<QMinus1, U32>;
942    type TwoGamma2 = Prod<U2, Self::Gamma2>;
943    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
944    type Lambda = U48;
945    type Omega = U55;
946    const TAU: usize = 49;
947}
948
949#[cfg(feature = "pkcs8")]
950impl AssociatedAlgorithmIdentifier for MlDsa65 {
951    type Params = AnyRef<'static>;
952
953    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
954        oid: fips204::ID_ML_DSA_65,
955        parameters: None,
956    };
957}
958
959/// `MlKem87` is the parameter set for security category 5.
960#[derive(Default, Clone, Debug, PartialEq)]
961pub struct MlDsa87;
962
963impl ParameterSet for MlDsa87 {
964    type K = U8;
965    type L = U7;
966    type Eta = U2;
967    type Gamma1 = Shleft<U1, U19>;
968    type Gamma2 = Quot<QMinus1, U32>;
969    type TwoGamma2 = Prod<U2, Self::Gamma2>;
970    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
971    type Lambda = U64;
972    type Omega = U75;
973    const TAU: usize = 60;
974}
975
976#[cfg(feature = "pkcs8")]
977impl AssociatedAlgorithmIdentifier for MlDsa87 {
978    type Params = AnyRef<'static>;
979
980    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = AlgorithmIdentifierRef {
981        oid: fips204::ID_ML_DSA_87,
982        parameters: None,
983    };
984}
985
986/// A parameter set that knows how to generate key pairs
987pub trait KeyGen: MlDsaParams {
988    /// The type that is returned by key generation
989    type KeyPair: signature::Keypair;
990
991    /// Generate a signing key pair from the specified RNG
992    #[cfg(feature = "rand_core")]
993    fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> Self::KeyPair;
994
995    /// Deterministically generate a signing key pair from the specified seed
996    ///
997    /// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204.
998    fn from_seed(xi: &B32) -> Self::KeyPair;
999}
1000
1001impl<P> KeyGen for P
1002where
1003    P: MlDsaParams,
1004{
1005    type KeyPair = KeyPair<P>;
1006
1007    /// Generate a signing key pair from the specified RNG
1008    // Algorithm 1 ML-DSA.KeyGen()
1009    #[cfg(feature = "rand_core")]
1010    fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> KeyPair<P> {
1011        let mut xi = B32::default();
1012        rng.fill_bytes(&mut xi);
1013        Self::from_seed(&xi)
1014    }
1015
1016    /// Deterministically generate a signing key pair from the specified seed
1017    ///
1018    /// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204.
1019    // Algorithm 6 ML-DSA.KeyGen_internal
1020    fn from_seed(xi: &B32) -> KeyPair<P>
1021    where
1022        P: MlDsaParams,
1023    {
1024        // Derive seeds
1025        let mut h = H::default()
1026            .absorb(xi)
1027            .absorb(&[P::K::U8])
1028            .absorb(&[P::L::U8]);
1029
1030        let rho: B32 = h.squeeze_new();
1031        let rhop: B64 = h.squeeze_new();
1032        let K: B32 = h.squeeze_new();
1033
1034        // Sample private key components
1035        let A_hat = expand_a::<P::K, P::L>(&rho);
1036        let s1 = expand_s::<P::L>(&rhop, P::Eta::ETA, 0);
1037        let s2 = expand_s::<P::K>(&rhop, P::Eta::ETA, P::L::USIZE);
1038
1039        // Compute derived values
1040        let As1_hat = &A_hat * &s1.ntt();
1041        let t = &As1_hat.ntt_inverse() + &s2;
1042
1043        // Compress and encode
1044        let (t1, t0) = t.power2round();
1045
1046        let verifying_key = VerifyingKey::new(rho, t1, Some(A_hat.clone()), None);
1047        let signing_key =
1048            SigningKey::new(rho, K, verifying_key.tr.clone(), s1, s2, t0, Some(A_hat));
1049
1050        KeyPair {
1051            signing_key,
1052            verifying_key,
1053            #[cfg(feature = "pkcs8")]
1054            seed: xi.clone(),
1055        }
1056    }
1057}
1058
#[cfg(test)]
mod test {
    // Unit tests for encoded sizes, encode/decode and sign/verify round trips, and agreement
    // between the different key-generation, signing, and verification entry points, run
    // against all three parameter sets.
    use super::*;
    use crate::param::*;

    /// Checks the type-level encoded sizes of signing keys, verifying keys, and signatures
    /// against the byte counts tabulated below (the ML-DSA sizes listed in FIPS 204).
    #[test]
    fn output_sizes() {
        //           priv pub  sig
        // ML-DSA-44 2560 1312 2420
        // ML-DSA-65 4032 1952 3309
        // ML-DSA-87 4896 2592 4627
        assert_eq!(SigningKeySize::<MlDsa44>::USIZE, 2560);
        assert_eq!(VerifyingKeySize::<MlDsa44>::USIZE, 1312);
        assert_eq!(SignatureSize::<MlDsa44>::USIZE, 2420);

        assert_eq!(SigningKeySize::<MlDsa65>::USIZE, 4032);
        assert_eq!(VerifyingKeySize::<MlDsa65>::USIZE, 1952);
        assert_eq!(SignatureSize::<MlDsa65>::USIZE, 3309);

        assert_eq!(SigningKeySize::<MlDsa87>::USIZE, 4896);
        assert_eq!(VerifyingKeySize::<MlDsa87>::USIZE, 2592);
        assert_eq!(SignatureSize::<MlDsa87>::USIZE, 4627);
    }

    /// Derives a key pair from the all-zero seed, then checks that the verifying key, the
    /// signing key, and a signature over a fixed message each decode back to an equal value.
    fn encode_decode_round_trip_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let kp = P::from_seed(&Array::default());
        let sk = kp.signing_key;
        let vk = kp.verifying_key;

        let vk_bytes = vk.encode();
        let vk2 = VerifyingKey::<P>::decode(&vk_bytes);
        assert!(vk == vk2);

        let sk_bytes = sk.encode();
        let sk2 = SigningKey::<P>::decode(&sk_bytes);
        assert!(sk == sk2);

        let M = b"Hello world";
        // Fixed all-zero `rnd` input for `sign_internal`.
        let rnd = Array([0u8; 32]);
        let sig = sk.sign_internal(&[M], &rnd);
        let sig_bytes = sig.encode();
        // Unlike key decoding, signature decoding is fallible.
        let sig2 = Signature::<P>::decode(&sig_bytes).unwrap();
        assert!(sig == sig2);
    }

    /// Runs the encode/decode round trip for every parameter set.
    #[test]
    fn encode_decode_round_trip() {
        encode_decode_round_trip_test::<MlDsa44>();
        encode_decode_round_trip_test::<MlDsa65>();
        encode_decode_round_trip_test::<MlDsa87>();
    }

    /// Checks that the verifying key derived from a signing key via `verifying_key()` equals
    /// the one produced alongside it by `from_seed`.
    fn public_from_private_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let kp = P::from_seed(&Array::default());
        let sk = kp.signing_key;
        let vk = kp.verifying_key;
        let vk_derived = sk.verifying_key();

        assert!(vk == vk_derived);
    }

    /// Runs the public-from-private check for every parameter set.
    #[test]
    fn public_from_private() {
        public_from_private_test::<MlDsa44>();
        public_from_private_test::<MlDsa65>();
        public_from_private_test::<MlDsa87>();
    }

    /// Signs a fixed message with `sign_internal` and verifies it with `verify_internal`.
    fn sign_verify_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        let kp = P::from_seed(&Array::default());
        let sk = kp.signing_key;
        let vk = kp.verifying_key;

        let M = b"Hello world";
        let rnd = Array([0u8; 32]);
        let sig = sk.sign_internal(&[M], &rnd);

        assert!(vk.verify_internal(&[M], &sig));
    }

    /// Runs the sign/verify round trip for every parameter set.
    #[test]
    fn sign_verify_round_trip() {
        sign_verify_round_trip_test::<MlDsa44>();
        sign_verify_round_trip_test::<MlDsa65>();
        sign_verify_round_trip_test::<MlDsa87>();
    }

    /// For `ITERATIONS` random seeds: generate a key pair, sign, check that the signature
    /// survives an encode/decode round trip, and verify the decoded signature.
    fn many_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        use rand::Rng;

        const ITERATIONS: usize = 1000;

        // Seeds come from `rand::rng()` and are not printed, so a failing seed is not
        // reported by this test.
        let mut rng = rand::rng();
        let mut seed = B32::default();

        for _i in 0..ITERATIONS {
            let seed_data: &mut [u8] = seed.as_mut();
            rng.fill(seed_data);

            let kp = P::from_seed(&seed);
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let sig = sk.sign_internal(&[M], &rnd);

            let sig_enc = sig.encode();
            let sig_dec = Signature::<P>::decode(&sig_enc).unwrap();

            assert_eq!(sig_dec, sig);
            assert!(vk.verify_internal(&[M], &sig_dec));
        }
    }

    /// Runs the randomized round-trip loop for every parameter set.
    #[test]
    fn many_round_trip() {
        many_round_trip_test::<MlDsa44>();
        many_round_trip_test::<MlDsa65>();
        many_round_trip_test::<MlDsa87>();
    }

    /// Signing and verifying over a precomputed message representative μ: `raw_sign_mu`
    /// output must be accepted by `raw_verify_mu`.
    #[test]
    fn sign_mu_verify_mu_round_trip() {
        fn sign_mu_verify_mu<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            // μ computed from the key's `tr` and the message.
            let mu = message_representative(&sk.tr, &[&[M]]);
            let sig = sk.raw_sign_mu(&mu, &rnd);

            assert!(vk.raw_verify_mu(&mu, &sig));
        }
        sign_mu_verify_mu::<MlDsa44>();
        sign_mu_verify_mu::<MlDsa65>();
        sign_mu_verify_mu::<MlDsa87>();
    }

    /// Cross-check: a signature made with `raw_sign_mu` over μ must be accepted by
    /// `verify_internal` on the original message.
    #[test]
    fn sign_mu_verify_internal_round_trip() {
        fn sign_mu_verify_internal<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = message_representative(&sk.tr, &[&[M]]);
            let sig = sk.raw_sign_mu(&mu, &rnd);

            assert!(vk.verify_internal(&[M], &sig));
        }
        sign_mu_verify_internal::<MlDsa44>();
        sign_mu_verify_internal::<MlDsa65>();
        sign_mu_verify_internal::<MlDsa87>();
    }

    /// Cross-check in the other direction: a signature made with `sign_internal` on the
    /// message must be accepted by `raw_verify_mu` over the corresponding μ.
    #[test]
    fn sign_internal_verify_mu_round_trip() {
        fn sign_internal_verify_mu<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = message_representative(&sk.tr, &[&[M]]);
            let sig = sk.sign_internal(&[M], &rnd);

            assert!(vk.raw_verify_mu(&mu, &sig));
        }
        sign_internal_verify_mu::<MlDsa44>();
        sign_internal_verify_mu::<MlDsa65>();
        sign_internal_verify_mu::<MlDsa87>();
    }

    /// `sign_digest` fed the message must produce the same signature as `Signer::sign`, and
    /// `verify_digest` fed the same message must accept it.
    #[test]
    fn sign_digest_round_trip() {
        fn sign_digest<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let sig = sk.sign_digest(|digest| digest.update(M));
            assert_eq!(sig, sk.sign(M));

            vk.verify_digest(
                |digest| {
                    digest.update(M);
                    Ok(())
                },
                &sig,
            )
            .unwrap();
        }
        sign_digest::<MlDsa44>();
        sign_digest::<MlDsa65>();
        sign_digest::<MlDsa87>();
    }

    /// Randomized digest signing (`sign_digest_with_rng` with `rand::rng()`) must produce a
    /// signature that `verify_digest` accepts. Only the verification is checked, not equality
    /// with another signature.
    #[test]
    #[cfg(feature = "rand_core")]
    fn sign_randomized_digest_round_trip() {
        fn sign_digest<P>()
        where
            P: MlDsaParams,
        {
            let kp = P::from_seed(&Array::default());
            let sk = kp.signing_key;
            let vk = kp.verifying_key;

            let M = b"Hello world";
            let sig = sk.sign_digest_with_rng(&mut rand::rng(), |digest| digest.update(M));

            vk.verify_digest(
                |digest| {
                    digest.update(M);
                    Ok(())
                },
                &sig,
            )
            .unwrap();
        }
        sign_digest::<MlDsa44>();
        sign_digest::<MlDsa65>();
        sign_digest::<MlDsa87>();
    }

    /// `KeyGen::from_seed` and `SigningKey::from_seed` must agree on the signing key, and the
    /// verifying key derived from the latter must match the key pair's verifying key.
    #[test]
    fn from_seed_implementations_match() {
        fn assert_from_seed_equality<P>()
        where
            P: MlDsaParams,
        {
            let seed = Array([0u8; 32]);
            let kp1 = P::from_seed(&seed);
            let sk1 = SigningKey::<P>::from_seed(&seed);
            let vk1 = sk1.verifying_key();
            assert_eq!(kp1.signing_key, sk1);
            assert_eq!(kp1.verifying_key, vk1);
        }
        assert_from_seed_equality::<MlDsa44>();
        assert_from_seed_equality::<MlDsa65>();
        assert_from_seed_equality::<MlDsa87>();
    }
}