1#![no_std]
2#![doc = include_str!("../README.md")]
3#![doc(
4 html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
5 html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
6)]
7#![cfg_attr(docsrs, feature(doc_cfg))]
8#![allow(non_snake_case)] #![allow(clippy::similar_names)] #![allow(clippy::many_single_char_names)] #![allow(clippy::clone_on_copy)] #![cfg_attr(feature = "getrandom", doc = "```")]
29#![cfg_attr(not(feature = "getrandom"), doc = "```ignore")]
30#[cfg(feature = "alloc")]
44extern crate alloc;
45
46#[cfg(feature = "pkcs8")]
47pub mod pkcs8;
48
49mod algebra;
50mod crypto;
51mod encode;
52mod hint;
53mod ntt;
54mod param;
55mod sampling;
56mod signing;
57mod verifying;
58
59pub use crate::{
60 param::{EncodedSignature, EncodedVerifyingKey, ExpandedSigningKeyBytes, MlDsaParams},
61 signing::{ExpandedSigningKey, SigningKey},
62 verifying::VerifyingKey,
63};
64pub use common::{self, KeyExport, KeyInit, KeySizeUser};
65pub use signature::{self, Error, Keypair, SignatureEncoding, Signer, Verifier};
66
67#[cfg(feature = "rand_core")]
68pub use common::Generate;
69
70use crate::{
71 algebra::{AlgebraExt, Vector},
72 crypto::H,
73 hint::Hint,
74 param::{ParameterSet, QMinus1},
75};
76use core::convert::{TryFrom, TryInto};
77use hybrid_array::{
78 Array,
79 sizes::{U1, U2, U4, U5, U6, U7, U8, U17, U19, U32, U48, U55, U64, U75, U80, U88},
80 typenum::{Diff, Length, Prod, Quot, Shleft},
81};
82use module_lattice::{MaybeBox, Truncate};
83use shake::Shake256;
84
/// Fixed-size 32-byte array, used throughout the crate for seeds and hash outputs.
pub type B32 = Array<u8, U32>;

/// Fixed-size 64-byte array; this is the output type of the message-representative
/// (`μ`) computation in `MuBuilder`.
pub(crate) type B64 = Array<u8, U64>;

/// 32-byte seed from which a signing key is derived (see `SigningKey::from_seed` /
/// `SigningKey::to_seed`). The same seed size is used for every parameter set.
pub type Seed = B32;
94
/// An ML-DSA signature for parameter set `P`.
///
/// Instances are produced by signing, or parsed from bytes with [`Signature::decode`] /
/// `TryFrom<&[u8]>`; [`Signature::encode`] produces the fixed-size wire format.
#[derive(Clone, Debug, PartialEq)]
pub struct Signature<P: MlDsaParams> {
    /// The `c̃` component (FIPS 204 notation), `P::Lambda` bytes long.
    c_tilde: Array<u8, P::Lambda>,
    /// The response vector `z` of length `P::L`; `decode` rejects values whose
    /// infinity norm is at least `P::GAMMA1_MINUS_BETA`.
    z: MaybeBox<Vector<P::L>>,
    /// The hint `h`, serialized with `Hint::bit_pack` / `Hint::bit_unpack`.
    h: Hint<P>,
}
102
103impl<P: MlDsaParams> Signature<P> {
104 pub fn encode(&self) -> EncodedSignature<P> {
107 let c_tilde = self.c_tilde.clone();
108 let z = P::encode_z(&self.z);
109 let h = self.h.bit_pack();
110 P::concat_sig(c_tilde, z, h)
111 }
112
113 pub fn decode(enc: &EncodedSignature<P>) -> Option<Self> {
116 let (c_tilde, z, h) = P::split_sig(enc);
117
118 let c_tilde = c_tilde.clone();
119 let z = MaybeBox::new(P::decode_z(z));
120 let h = Hint::bit_unpack(h)?;
121
122 if z.infinity_norm() >= P::GAMMA1_MINUS_BETA {
123 return None;
124 }
125
126 Some(Self { c_tilde, z, h })
127 }
128}
129
130impl<'a, P: MlDsaParams> TryFrom<&'a [u8]> for Signature<P> {
131 type Error = Error;
132
133 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
134 let enc = EncodedSignature::<P>::try_from(value).map_err(|_| Error::new())?;
135 Self::decode(&enc).ok_or(Error::new())
136 }
137}
138
139impl<P: MlDsaParams> TryInto<EncodedSignature<P>> for Signature<P> {
140 type Error = Error;
141
142 fn try_into(self) -> Result<EncodedSignature<P>, Self::Error> {
143 Ok(self.encode())
144 }
145}
146
// Hooks `Signature` into the `signature` crate's encoding trait; the byte
// representation is the fixed-size `EncodedSignature<P>` produced by `encode`.
impl<P: MlDsaParams> SignatureEncoding for Signature<P> {
    type Repr = EncodedSignature<P>;
}
150
151impl<P: MlDsaParams> core::hash::Hash for Signature<P> {
152 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
153 self.encode().hash(state);
154 }
155}
156
157struct MuBuilder(H);
158
159impl MuBuilder {
160 fn new(tr: &[u8], ctx: &[u8]) -> Self {
161 let mut h = H::default();
162 h = h.absorb(tr);
163 h = h.absorb(&[0]);
164 h = h.absorb(&[Truncate::truncate(ctx.len())]);
165 h = h.absorb(ctx);
166
167 Self(h)
168 }
169
170 fn internal(tr: &[u8], Mp: &[&[u8]]) -> B64 {
171 let mut h = H::default().absorb(tr);
172
173 for m in Mp {
174 h = h.absorb(m);
175 }
176
177 h.squeeze_new()
178 }
179
180 fn message(mut self, M: &[&[u8]]) -> B64 {
181 for m in M {
182 self.0 = self.0.absorb(m);
183 }
184
185 self.0.squeeze_new()
186 }
187
188 fn finish(mut self) -> B64 {
189 self.0.squeeze_new()
190 }
191}
192
193impl AsMut<Shake256> for MuBuilder {
194 fn as_mut(&mut self) -> &mut Shake256 {
195 self.0.updatable()
196 }
197}
198
199#[derive(Clone, Copy, Debug, Default, PartialEq)]
202pub struct MlDsa44;
203
204impl ParameterSet for MlDsa44 {
205 type K = U4;
206 type L = U4;
207 type Eta = U2;
208 type Gamma1 = Shleft<U1, U17>;
209 type Gamma2 = Quot<QMinus1, U88>;
210 type TwoGamma2 = Prod<U2, Self::Gamma2>;
211 type W1Bits = Length<Diff<Quot<U88, U2>, U1>>;
212 type Lambda = U32;
213 type Omega = U80;
214 const TAU: usize = 39;
215}
216
217#[derive(Clone, Copy, Debug, Default, PartialEq)]
222pub struct MlDsa65;
223
224impl ParameterSet for MlDsa65 {
225 type K = U6;
226 type L = U5;
227 type Eta = U4;
228 type Gamma1 = Shleft<U1, U19>;
229 type Gamma2 = Quot<QMinus1, U32>;
230 type TwoGamma2 = Prod<U2, Self::Gamma2>;
231 type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
232 type Lambda = U48;
233 type Omega = U55;
234 const TAU: usize = 49;
235}
236
237#[derive(Clone, Copy, Debug, Default, PartialEq)]
240pub struct MlDsa87;
241
242impl ParameterSet for MlDsa87 {
243 type K = U8;
244 type L = U7;
245 type Eta = U2;
246 type Gamma1 = Shleft<U1, U19>;
247 type Gamma2 = Quot<QMinus1, U32>;
248 type TwoGamma2 = Prod<U2, Self::Gamma2>;
249 type W1Bits = Length<Diff<Quot<U32, U2>, U1>>;
250 type Lambda = U64;
251 type Omega = U75;
252 const TAU: usize = 60;
253}
254
// Unit tests covering encoding round-trips, key derivation consistency, the
// internal / μ-based signing paths, and rejection of tampered signatures.
#[cfg(test)]
mod test {
    use super::*;
    use crate::param::*;
    use hybrid_array::typenum::Unsigned;
    use signature::Keypair;

    #[test]
    fn output_sizes() {
        assert_eq!(SigningKeySize::<MlDsa44>::USIZE, 2560);
        assert_eq!(VerifyingKeySize::<MlDsa44>::USIZE, 1312);
        assert_eq!(SignatureSize::<MlDsa44>::USIZE, 2420);

        assert_eq!(SigningKeySize::<MlDsa65>::USIZE, 4032);
        assert_eq!(VerifyingKeySize::<MlDsa65>::USIZE, 1952);
        assert_eq!(SignatureSize::<MlDsa65>::USIZE, 3309);

        assert_eq!(SigningKeySize::<MlDsa87>::USIZE, 4896);
        assert_eq!(VerifyingKeySize::<MlDsa87>::USIZE, 2592);
        assert_eq!(SignatureSize::<MlDsa87>::USIZE, 4627);
    }

    fn encode_decode_round_trip_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let seed = Array::default();
        let ssk = SigningKey::from_seed(&seed);
        assert_eq!(ssk.to_seed(), seed);

        let esk = ssk.expanded_key();
        let vk = ssk.verifying_key();

        let vk_bytes = vk.encode();
        let vk2 = VerifyingKey::<P>::decode(&vk_bytes);
        assert!(vk == vk2);

        #[allow(deprecated)]
        {
            let sk_bytes = esk.to_expanded();
            let sk2 = ExpandedSigningKey::<P>::from_expanded(&sk_bytes);
            assert!(esk == &sk2);

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let sig = esk.sign_internal(&[M], &rnd);
            let sig_bytes = sig.encode();
            let sig2 = Signature::<P>::decode(&sig_bytes).unwrap();
            assert!(sig == sig2);
        }
    }

    #[test]
    fn encode_decode_round_trip() {
        encode_decode_round_trip_test::<MlDsa44>();
        encode_decode_round_trip_test::<MlDsa65>();
        encode_decode_round_trip_test::<MlDsa87>();
    }

    fn public_from_private_test<P>()
    where
        P: MlDsaParams + PartialEq,
    {
        let ssk = SigningKey::<P>::from_seed(&Array::default());
        let esk = ssk.expanded_key();
        let vk = ssk.verifying_key();
        let vk_derived = esk.verifying_key();

        assert!(vk == vk_derived);
    }

    #[test]
    fn public_from_private() {
        public_from_private_test::<MlDsa44>();
        public_from_private_test::<MlDsa65>();
        public_from_private_test::<MlDsa87>();
    }

    fn sign_verify_round_trip_test<P>()
    where
        P: MlDsaParams,
    {
        let ssk = SigningKey::<P>::from_seed(&Array::default());
        let esk = ssk.expanded_key();
        let vk = ssk.verifying_key();

        let M = b"Hello world";
        let rnd = Array([0u8; 32]);
        let sig = esk.sign_internal(&[M], &rnd);

        assert!(vk.verify_internal(M, &sig));
    }

    #[test]
    fn sign_verify_round_trip() {
        sign_verify_round_trip_test::<MlDsa44>();
        sign_verify_round_trip_test::<MlDsa65>();
        sign_verify_round_trip_test::<MlDsa87>();
    }

    #[test]
    fn sign_mu_verify_mu_round_trip() {
        fn sign_mu_verify_mu<P>()
        where
            P: MlDsaParams,
        {
            let ssk = SigningKey::<P>::from_seed(&Array::default());
            let esk = ssk.expanded_key();
            let vk = ssk.verifying_key();

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = MuBuilder::internal(&esk.tr, &[M]);
            let sig = esk.raw_sign_mu(&mu, &rnd);

            assert!(vk.raw_verify_mu(&mu, &sig));
        }
        sign_mu_verify_mu::<MlDsa44>();
        sign_mu_verify_mu::<MlDsa65>();
        sign_mu_verify_mu::<MlDsa87>();
    }

    #[test]
    fn sign_mu_verify_internal_round_trip() {
        fn sign_mu_verify_internal<P>()
        where
            P: MlDsaParams,
        {
            let ssk = SigningKey::<P>::from_seed(&Array::default());
            let esk = ssk.expanded_key();
            let vk = ssk.verifying_key();

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = MuBuilder::internal(&esk.tr, &[M]);
            let sig = esk.raw_sign_mu(&mu, &rnd);

            assert!(vk.verify_internal(M, &sig));
        }
        sign_mu_verify_internal::<MlDsa44>();
        sign_mu_verify_internal::<MlDsa65>();
        sign_mu_verify_internal::<MlDsa87>();
    }

    #[test]
    fn sign_internal_verify_mu_round_trip() {
        fn sign_internal_verify_mu<P>()
        where
            P: MlDsaParams,
        {
            let ssk = SigningKey::<P>::from_seed(&Array::default());
            let esk = ssk.expanded_key();
            let vk = ssk.verifying_key();

            let M = b"Hello world";
            let rnd = Array([0u8; 32]);
            let mu = MuBuilder::internal(&esk.tr, &[M]);
            let sig = esk.sign_internal(&[M], &rnd);

            assert!(vk.raw_verify_mu(&mu, &sig));
        }
        sign_internal_verify_mu::<MlDsa44>();
        sign_internal_verify_mu::<MlDsa65>();
        sign_internal_verify_mu::<MlDsa87>();
    }

    #[test]
    fn from_seed_implementations_match() {
        fn assert_from_seed_equality<P>()
        where
            P: MlDsaParams,
        {
            let seed = Seed::default();
            let ssk = SigningKey::<P>::from_seed(&seed);
            let sk1 = ExpandedSigningKey::<P>::from_seed(&seed);
            assert_eq!(ssk.expanded_key(), &sk1);
        }
        assert_from_seed_equality::<MlDsa44>();
        assert_from_seed_equality::<MlDsa65>();
        assert_from_seed_equality::<MlDsa87>();
    }

    #[test]
    fn to_seed_returns_correct_seed() {
        fn test_to_seed<P: MlDsaParams>() {
            let seed = Array([
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                24, 25, 26, 27, 28, 29, 30, 31, 32,
            ]);
            let kp = SigningKey::<P>::from_seed(&seed);
            assert_eq!(kp.to_seed(), seed);
        }
        test_to_seed::<MlDsa44>();
        test_to_seed::<MlDsa65>();
        test_to_seed::<MlDsa87>();
    }

    #[test]
    fn verification_rejects_invalid_signature() {
        fn test_invalid_sig<P: MlDsaParams>() {
            let kp = SigningKey::<P>::from_seed(&Array::default());
            let vk = kp.verifying_key();

            let msg = b"Hello world";
            // Sign through the same context-aware path that is verified below.
            // `sign_internal` signs μ = H(tr ‖ M) without the context prefix, so its
            // signatures fail `verify_with_context` even without tampering, which
            // would make the rejection check below vacuous.
            let mut sig = kp.expanded_key().sign_deterministic(msg, &[]).unwrap();
            assert!(vk.verify_with_context(msg, &[], &sig));

            // Corrupting `c̃` must now be the only reason verification fails.
            sig.c_tilde[0] ^= 0xFF;
            assert!(!vk.verify_with_context(msg, &[], &sig));
        }
        test_invalid_sig::<MlDsa44>();
        test_invalid_sig::<MlDsa65>();
        test_invalid_sig::<MlDsa87>();
    }

    #[test]
    fn verification_rejects_wrong_message() {
        fn test_wrong_msg<P: MlDsaParams>() {
            let kp = SigningKey::<P>::from_seed(&Array::default());
            let vk = kp.verifying_key();

            let msg1 = b"Hello world";
            let msg2 = b"Wrong message";
            // Use the context-aware signer so the signature is valid for `msg1`
            // under `verify_with_context`; only the message substitution may then
            // cause the rejection.
            let sig = kp.expanded_key().sign_deterministic(msg1, &[]).unwrap();
            assert!(vk.verify_with_context(msg1, &[], &sig));

            assert!(!vk.verify_with_context(msg2, &[], &sig));
        }
        test_wrong_msg::<MlDsa44>();
        test_wrong_msg::<MlDsa65>();
        test_wrong_msg::<MlDsa87>();
    }

    #[test]
    fn context_length_validation() {
        fn test_ctx_length<P: MlDsaParams>() {
            let ssk = SigningKey::<P>::from_seed(&Array::default());
            let sk = ssk.expanded_key();
            let vk = ssk.verifying_key();

            let msg = b"Hello world";
            let long_ctx = [0u8; 256];
            let short_ctx = [0u8; 255];

            assert!(sk.sign_deterministic(msg, &long_ctx).is_err());

            let sig = sk.sign_deterministic(msg, &short_ctx).unwrap();
            assert!(!vk.verify_with_context(msg, &long_ctx, &sig));
            assert!(vk.verify_with_context(msg, &short_ctx, &sig));
        }
        test_ctx_length::<MlDsa44>();
        test_ctx_length::<MlDsa65>();
        test_ctx_length::<MlDsa87>();
    }

    #[test]
    fn derived_verifying_key_validates_signatures() {
        fn test_derived_vk<P: MlDsaParams>() {
            let seed = Array([42u8; 32]);
            let ssk = SigningKey::<P>::from_seed(&seed);
            let sk = ssk.expanded_key();
            let derived_vk = sk.verifying_key();

            let msg = b"Test message for derived key";
            let rnd = Array([0u8; 32]);
            let sig = sk.sign_internal(&[msg], &rnd);

            assert!(derived_vk.verify_internal(msg, &sig));
            assert_eq!(derived_vk.encode(), ssk.verifying_key().encode());
        }
        test_derived_vk::<MlDsa44>();
        test_derived_vk::<MlDsa65>();
        test_derived_vk::<MlDsa87>();
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn debug_implementations() {
        extern crate alloc;
        use core::fmt::Write;

        fn test_debug<P: MlDsaParams>() {
            let kp = SigningKey::<P>::from_seed(&Array::default());

            let mut kp_debug = alloc::string::String::new();
            write!(&mut kp_debug, "{:?}", kp).unwrap();
            assert!(kp_debug.contains("SigningKey"));

            let mut sk_debug = alloc::string::String::new();
            write!(&mut sk_debug, "{:?}", kp.expanded_key()).unwrap();
            assert!(sk_debug.contains("ExpandedSigningKey"));
        }
        test_debug::<MlDsa44>();
        test_debug::<MlDsa65>();
        test_debug::<MlDsa87>();
    }
}