Skip to main content

ml_dsa/
lib.rs

1#![no_std]
2#![doc = include_str!("../README.md")]
3#![doc(
4    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
5    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
6)]
7#![cfg_attr(docsrs, feature(doc_cfg))]
8#![allow(non_snake_case)] // Allow notation matching the spec
9#![allow(clippy::similar_names)] // Allow notation matching the spec
10#![allow(clippy::many_single_char_names)] // Allow notation matching the spec
11#![allow(clippy::clone_on_copy)] // Be explicit about moving data
12
13//! # Usage
14//!
//! The following types provide the core functionality of this crate, and are all generic over the
//! [`MlDsaParams`] trait, which defines the security level and is implemented by [`MlDsa44`],
//! [`MlDsa65`], and [`MlDsa87`] (with `MlDsa65` recommended as providing the best balance of
//! security and performance):
19//!
20//! - [`SigningKey`]: secret key capable of generating signatures. Implements the [`KeyInit`],
21//!   [`KeyExport`], [`Keypair`], and [`Signer`] traits, as well as [`Generate`] when the
22//!   `rand_core` feature of this crate is enabled.
23//! - [`VerifyingKey`]: public key associated with a given [`SigningKey`]. Implements the
24//!   [`KeyInit`], [`KeyExport`], and [`Verifier`] traits.
25//! - [`Signature`]: ML-DSA signature generated by a `SigningKey`, and verifiable by a
26//!   `VerifyingKey`. Implements the [`SignatureEncoding`] trait.
27//!
28#![cfg_attr(feature = "getrandom", doc = "```")]
29#![cfg_attr(not(feature = "getrandom"), doc = "```ignore")]
30//! # fn main() -> Result<(), signature::Error> {
//! // NOTE: requires the `getrandom` feature to be enabled
32//! use ml_dsa::{MlDsa65, Generate, Keypair, SigningKey, Signer, Verifier};
33//!
34//! let sk = SigningKey::<MlDsa65>::generate();
35//!
36//! let msg = b"Hello world";
37//! let sig = sk.sign(msg);
38//!
39//! sk.verifying_key().verify(msg, &sig)?;
40//! # Ok(()) }
41//! ```
42
43#[cfg(feature = "alloc")]
44extern crate alloc;
45
46#[cfg(feature = "pkcs8")]
47pub mod pkcs8;
48
49mod algebra;
50mod crypto;
51mod encode;
52mod hint;
53mod ntt;
54mod param;
55mod sampling;
56mod signing;
57mod verifying;
58
59pub use crate::{
60    param::{EncodedSignature, EncodedVerifyingKey, ExpandedSigningKeyBytes, MlDsaParams},
61    signing::{ExpandedSigningKey, SigningKey},
62    verifying::VerifyingKey,
63};
64pub use common::{self, KeyExport, KeyInit, KeySizeUser};
65pub use signature::{self, Error, Keypair, SignatureEncoding, Signer, Verifier};
66
67#[cfg(feature = "rand_core")]
68pub use common::Generate;
69
70use crate::{
71    algebra::{AlgebraExt, Vector},
72    crypto::H,
73    hint::Hint,
74    param::{ParameterSet, QMinus1},
75};
76use core::convert::{TryFrom, TryInto};
77use hybrid_array::{
78    Array,
79    sizes::{U1, U2, U4, U5, U6, U7, U8, U17, U19, U32, U48, U55, U64, U75, U80, U88},
80    typenum::{Diff, Length, Prod, Quot, Shleft},
81};
82use module_lattice::{MaybeBox, Truncate};
83use shake::Shake256;
84
85/// A 32-byte array, defined here for brevity because it is used several times
86pub type B32 = Array<u8, U32>;
87
88/// A 64-byte array, defined here for brevity because it is used several times
89pub(crate) type B64 = Array<u8, U64>;
90
/// ML-DSA seeds are signing (private) keys, which are consistently 32 bytes across all security
/// levels, and are the preferred serialization for representing such keys.
93pub type Seed = B32;
94
95/// An ML-DSA signature
96#[derive(Clone, Debug, PartialEq)]
97pub struct Signature<P: MlDsaParams> {
98    c_tilde: Array<u8, P::Lambda>,
99    z: MaybeBox<Vector<P::L>>,
100    h: Hint<P>,
101}
102
103impl<P: MlDsaParams> Signature<P> {
104    /// Encode the signature in a fixed-size byte array.
105    // Algorithm 26 sigEncode
106    pub fn encode(&self) -> EncodedSignature<P> {
107        let c_tilde = self.c_tilde.clone();
108        let z = P::encode_z(&self.z);
109        let h = self.h.bit_pack();
110        P::concat_sig(c_tilde, z, h)
111    }
112
113    /// Decode the signature from an appropriately sized byte array.
114    // Algorithm 27 sigDecode
115    pub fn decode(enc: &EncodedSignature<P>) -> Option<Self> {
116        let (c_tilde, z, h) = P::split_sig(enc);
117
118        let c_tilde = c_tilde.clone();
119        let z = MaybeBox::new(P::decode_z(z));
120        let h = Hint::bit_unpack(h)?;
121
122        if z.infinity_norm() >= P::GAMMA1_MINUS_BETA {
123            return None;
124        }
125
126        Some(Self { c_tilde, z, h })
127    }
128}
129
130impl<'a, P: MlDsaParams> TryFrom<&'a [u8]> for Signature<P> {
131    type Error = Error;
132
133    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
134        let enc = EncodedSignature::<P>::try_from(value).map_err(|_| Error::new())?;
135        Self::decode(&enc).ok_or(Error::new())
136    }
137}
138
139impl<P: MlDsaParams> TryInto<EncodedSignature<P>> for Signature<P> {
140    type Error = Error;
141
142    fn try_into(self) -> Result<EncodedSignature<P>, Self::Error> {
143        Ok(self.encode())
144    }
145}
146
147impl<P: MlDsaParams> SignatureEncoding for Signature<P> {
148    type Repr = EncodedSignature<P>;
149}
150
151impl<P: MlDsaParams> core::hash::Hash for Signature<P> {
152    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
153        self.encode().hash(state);
154    }
155}
156
157struct MuBuilder(H);
158
159impl MuBuilder {
160    fn new(tr: &[u8], ctx: &[u8]) -> Self {
161        let mut h = H::default();
162        h = h.absorb(tr);
163        h = h.absorb(&[0]);
164        h = h.absorb(&[Truncate::truncate(ctx.len())]);
165        h = h.absorb(ctx);
166
167        Self(h)
168    }
169
170    fn internal(tr: &[u8], Mp: &[&[u8]]) -> B64 {
171        let mut h = H::default().absorb(tr);
172
173        for m in Mp {
174            h = h.absorb(m);
175        }
176
177        h.squeeze_new()
178    }
179
180    fn message(mut self, M: &[&[u8]]) -> B64 {
181        for m in M {
182            self.0 = self.0.absorb(m);
183        }
184
185        self.0.squeeze_new()
186    }
187
188    fn finish(mut self) -> B64 {
189        self.0.squeeze_new()
190    }
191}
192
193impl AsMut<Shake256> for MuBuilder {
194    fn as_mut(&mut self) -> &mut Shake256 {
195        self.0.updatable()
196    }
197}
198
199/// `MlDsa44` is the parameter set for security category 2, providing the equivalent of 128-bit
200/// symmetric security.
201#[derive(Clone, Copy, Debug, Default, PartialEq)]
202pub struct MlDsa44;
203
204impl ParameterSet for MlDsa44 {
205    type K = U4;
206    type L = U4;
207    type Eta = U2;
208    type Gamma1 = Shleft<U1, U17>;
209    type Gamma2 = Quot<QMinus1, U88>;
210    type TwoGamma2 = Prod<U2, Self::Gamma2>;
211    type W1Bits = Length<Diff<Quot<U88, U2>, U1>>;
212    type Lambda = U32;
213    type Omega = U80;
214    const TAU: usize = 39;
215}
216
217/// `MlDsa65` is the parameter set for security category 3, providing the equivalent of 192-bit
218/// symmetric security, and is the recommended parameter set.
219///
220/// This set provides the best balance between performance and security.
221#[derive(Clone, Copy, Debug, Default, PartialEq)]
222pub struct MlDsa65;
223
224impl ParameterSet for MlDsa65 {
225    type K = U6;
226    type L = U5;
227    type Eta = U4;
228    type Gamma1 = Shleft<U1, U19>;
229    type Gamma2 = Quot<QMinus1, U32>;
230    type TwoGamma2 = Prod<U2, Self::Gamma2>;
231    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
232    type Lambda = U48;
233    type Omega = U55;
234    const TAU: usize = 49;
235}
236
237/// `MlDsa87` is the parameter set for security category 5, providing the equivalent of 256-bit
238/// symmetric security.
239#[derive(Clone, Copy, Debug, Default, PartialEq)]
240pub struct MlDsa87;
241
242impl ParameterSet for MlDsa87 {
243    type K = U8;
244    type L = U7;
245    type Eta = U2;
246    type Gamma1 = Shleft<U1, U19>;
247    type Gamma2 = Quot<QMinus1, U32>;
248    type TwoGamma2 = Prod<U2, Self::Gamma2>;
249    type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
250    type Lambda = U64;
251    type Omega = U75;
252    const TAU: usize = 60;
253}
254
255#[cfg(test)]
256mod test {
257    use super::*;
258    use crate::param::*;
259    use hybrid_array::typenum::Unsigned;
260    use signature::Keypair;
261
262    #[test]
263    fn output_sizes() {
264        //           priv pub  sig
265        // ML-DSA-44 2560 1312 2420
266        // ML-DSA-65 4032 1952 3309
267        // ML-DSA-87 4896 2592 4627
268        assert_eq!(SigningKeySize::<MlDsa44>::USIZE, 2560);
269        assert_eq!(VerifyingKeySize::<MlDsa44>::USIZE, 1312);
270        assert_eq!(SignatureSize::<MlDsa44>::USIZE, 2420);
271
272        assert_eq!(SigningKeySize::<MlDsa65>::USIZE, 4032);
273        assert_eq!(VerifyingKeySize::<MlDsa65>::USIZE, 1952);
274        assert_eq!(SignatureSize::<MlDsa65>::USIZE, 3309);
275
276        assert_eq!(SigningKeySize::<MlDsa87>::USIZE, 4896);
277        assert_eq!(VerifyingKeySize::<MlDsa87>::USIZE, 2592);
278        assert_eq!(SignatureSize::<MlDsa87>::USIZE, 4627);
279    }
280
281    fn encode_decode_round_trip_test<P>()
282    where
283        P: MlDsaParams + PartialEq,
284    {
285        let seed = Array::default();
286        let ssk = SigningKey::from_seed(&seed);
287        assert_eq!(ssk.to_seed(), seed);
288
289        let esk = ssk.expanded_key();
290        let vk = ssk.verifying_key();
291
292        let vk_bytes = vk.encode();
293        let vk2 = VerifyingKey::<P>::decode(&vk_bytes);
294        assert!(vk == vk2);
295
296        #[allow(deprecated)]
297        {
298            let sk_bytes = esk.to_expanded();
299            let sk2 = ExpandedSigningKey::<P>::from_expanded(&sk_bytes);
300            assert!(esk == &sk2);
301
302            let M = b"Hello world";
303            let rnd = Array([0u8; 32]);
304            let sig = esk.sign_internal(&[M], &rnd);
305            let sig_bytes = sig.encode();
306            let sig2 = Signature::<P>::decode(&sig_bytes).unwrap();
307            assert!(sig == sig2);
308        }
309    }
310
311    #[test]
312    fn encode_decode_round_trip() {
313        encode_decode_round_trip_test::<MlDsa44>();
314        encode_decode_round_trip_test::<MlDsa65>();
315        encode_decode_round_trip_test::<MlDsa87>();
316    }
317
318    fn public_from_private_test<P>()
319    where
320        P: MlDsaParams + PartialEq,
321    {
322        let ssk = SigningKey::<P>::from_seed(&Array::default());
323        let esk = ssk.expanded_key();
324        let vk = ssk.verifying_key();
325        let vk_derived = esk.verifying_key();
326
327        assert!(vk == vk_derived);
328    }
329
330    #[test]
331    fn public_from_private() {
332        public_from_private_test::<MlDsa44>();
333        public_from_private_test::<MlDsa65>();
334        public_from_private_test::<MlDsa87>();
335    }
336
337    fn sign_verify_round_trip_test<P>()
338    where
339        P: MlDsaParams,
340    {
341        let ssk = SigningKey::<P>::from_seed(&Array::default());
342        let esk = ssk.expanded_key();
343        let vk = ssk.verifying_key();
344
345        let M = b"Hello world";
346        let rnd = Array([0u8; 32]);
347        let sig = esk.sign_internal(&[M], &rnd);
348
349        assert!(vk.verify_internal(M, &sig));
350    }
351
352    #[test]
353    fn sign_verify_round_trip() {
354        sign_verify_round_trip_test::<MlDsa44>();
355        sign_verify_round_trip_test::<MlDsa65>();
356        sign_verify_round_trip_test::<MlDsa87>();
357    }
358
359    #[test]
360    fn sign_mu_verify_mu_round_trip() {
361        fn sign_mu_verify_mu<P>()
362        where
363            P: MlDsaParams,
364        {
365            let ssk = SigningKey::<P>::from_seed(&Array::default());
366            let esk = ssk.expanded_key();
367            let vk = ssk.verifying_key();
368
369            let M = b"Hello world";
370            let rnd = Array([0u8; 32]);
371            let mu = MuBuilder::internal(&esk.tr, &[M]);
372            let sig = esk.raw_sign_mu(&mu, &rnd);
373
374            assert!(vk.raw_verify_mu(&mu, &sig));
375        }
376        sign_mu_verify_mu::<MlDsa44>();
377        sign_mu_verify_mu::<MlDsa65>();
378        sign_mu_verify_mu::<MlDsa87>();
379    }
380
381    #[test]
382    fn sign_mu_verify_internal_round_trip() {
383        fn sign_mu_verify_internal<P>()
384        where
385            P: MlDsaParams,
386        {
387            let ssk = SigningKey::<P>::from_seed(&Array::default());
388            let esk = ssk.expanded_key();
389            let vk = ssk.verifying_key();
390
391            let M = b"Hello world";
392            let rnd = Array([0u8; 32]);
393            let mu = MuBuilder::internal(&esk.tr, &[M]);
394            let sig = esk.raw_sign_mu(&mu, &rnd);
395
396            assert!(vk.verify_internal(M, &sig));
397        }
398        sign_mu_verify_internal::<MlDsa44>();
399        sign_mu_verify_internal::<MlDsa65>();
400        sign_mu_verify_internal::<MlDsa87>();
401    }
402
403    #[test]
404    fn sign_internal_verify_mu_round_trip() {
405        fn sign_internal_verify_mu<P>()
406        where
407            P: MlDsaParams,
408        {
409            let ssk = SigningKey::<P>::from_seed(&Array::default());
410            let esk = ssk.expanded_key();
411            let vk = ssk.verifying_key();
412
413            let M = b"Hello world";
414            let rnd = Array([0u8; 32]);
415            let mu = MuBuilder::internal(&esk.tr, &[M]);
416            let sig = esk.sign_internal(&[M], &rnd);
417
418            assert!(vk.raw_verify_mu(&mu, &sig));
419        }
420        sign_internal_verify_mu::<MlDsa44>();
421        sign_internal_verify_mu::<MlDsa65>();
422        sign_internal_verify_mu::<MlDsa87>();
423    }
424
425    #[test]
426    fn from_seed_implementations_match() {
427        fn assert_from_seed_equality<P>()
428        where
429            P: MlDsaParams,
430        {
431            let seed = Seed::default();
432            let ssk = SigningKey::<P>::from_seed(&seed);
433            let sk1 = ExpandedSigningKey::<P>::from_seed(&seed);
434            assert_eq!(ssk.expanded_key(), &sk1);
435        }
436        assert_from_seed_equality::<MlDsa44>();
437        assert_from_seed_equality::<MlDsa65>();
438        assert_from_seed_equality::<MlDsa87>();
439    }
440
441    #[test]
442    fn to_seed_returns_correct_seed() {
443        fn test_to_seed<P: MlDsaParams>() {
444            let seed = Array([
445                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
446                24, 25, 26, 27, 28, 29, 30, 31, 32,
447            ]);
448            let kp = SigningKey::<P>::from_seed(&seed);
449            assert_eq!(kp.to_seed(), seed);
450        }
451        test_to_seed::<MlDsa44>();
452        test_to_seed::<MlDsa65>();
453        test_to_seed::<MlDsa87>();
454    }
455
456    #[test]
457    fn verification_rejects_invalid_signature() {
458        fn test_invalid_sig<P: MlDsaParams>() {
459            let kp = SigningKey::<P>::from_seed(&Array::default());
460            let vk = kp.verifying_key();
461
462            let msg = b"Hello world";
463            let rnd = Array([0u8; 32]);
464            let mut sig = kp.expanded_key().sign_internal(&[msg], &rnd);
465            sig.c_tilde[0] ^= 0xFF;
466
467            assert!(!vk.verify_with_context(msg, &[], &sig));
468        }
469        test_invalid_sig::<MlDsa44>();
470        test_invalid_sig::<MlDsa65>();
471        test_invalid_sig::<MlDsa87>();
472    }
473
474    #[test]
475    fn verification_rejects_wrong_message() {
476        fn test_wrong_msg<P: MlDsaParams>() {
477            let kp = SigningKey::<P>::from_seed(&Array::default());
478            let vk = kp.verifying_key();
479
480            let msg1 = b"Hello world";
481            let msg2 = b"Wrong message";
482            let rnd = Array([0u8; 32]);
483            let sig = kp.expanded_key().sign_internal(&[msg1], &rnd);
484
485            assert!(!vk.verify_with_context(msg2, &[], &sig));
486        }
487        test_wrong_msg::<MlDsa44>();
488        test_wrong_msg::<MlDsa65>();
489        test_wrong_msg::<MlDsa87>();
490    }
491
492    #[test]
493    fn context_length_validation() {
494        fn test_ctx_length<P: MlDsaParams>() {
495            let ssk = SigningKey::<P>::from_seed(&Array::default());
496            let sk = ssk.expanded_key();
497            let vk = ssk.verifying_key();
498
499            let msg = b"Hello world";
500            let long_ctx = [0u8; 256];
501            let short_ctx = [0u8; 255];
502
503            assert!(sk.sign_deterministic(msg, &long_ctx).is_err());
504
505            let sig = sk.sign_deterministic(msg, &short_ctx).unwrap();
506            assert!(!vk.verify_with_context(msg, &long_ctx, &sig));
507            assert!(vk.verify_with_context(msg, &short_ctx, &sig));
508        }
509        test_ctx_length::<MlDsa44>();
510        test_ctx_length::<MlDsa65>();
511        test_ctx_length::<MlDsa87>();
512    }
513
514    #[test]
515    fn derived_verifying_key_validates_signatures() {
516        fn test_derived_vk<P: MlDsaParams>() {
517            let seed = Array([42u8; 32]);
518            let ssk = SigningKey::<P>::from_seed(&seed);
519            let sk = ssk.expanded_key();
520            let derived_vk = sk.verifying_key();
521
522            let msg = b"Test message for derived key";
523            let rnd = Array([0u8; 32]);
524            let sig = sk.sign_internal(&[msg], &rnd);
525
526            assert!(derived_vk.verify_internal(msg, &sig));
527            assert_eq!(derived_vk.encode(), ssk.verifying_key().encode());
528        }
529        test_derived_vk::<MlDsa44>();
530        test_derived_vk::<MlDsa65>();
531        test_derived_vk::<MlDsa87>();
532    }
533
534    #[test]
535    #[cfg(feature = "alloc")]
536    fn debug_implementations() {
537        extern crate alloc;
538        use core::fmt::Write;
539
540        fn test_debug<P: MlDsaParams>() {
541            let kp = SigningKey::<P>::from_seed(&Array::default());
542
543            let mut kp_debug = alloc::string::String::new();
544            write!(&mut kp_debug, "{:?}", kp).unwrap();
545            assert!(kp_debug.contains("SigningKey"));
546
547            let mut sk_debug = alloc::string::String::new();
548            write!(&mut sk_debug, "{:?}", kp.expanded_key()).unwrap();
549            assert!(sk_debug.contains("ExpandedSigningKey"));
550        }
551        test_debug::<MlDsa44>();
552        test_debug::<MlDsa65>();
553        test_debug::<MlDsa87>();
554    }
555}