1#![no_std]
2#![allow(clippy::too_many_arguments)]
4
5pub use fn_dsa_comm::{
67 DOMAIN_NONE,
68 DomainContext,
69 FN_DSA_LOGN_512,
70 FN_DSA_LOGN_1024,
71 HASH_ID_RAW,
72 HASH_ID_SHA3_256,
73 HASH_ID_SHA3_384,
74 HASH_ID_SHA3_512,
75 HASH_ID_SHA256,
76 HASH_ID_SHA384,
77 HASH_ID_SHA512,
78 HASH_ID_SHA512_256,
79 HASH_ID_SHAKE128,
80 HASH_ID_SHAKE256,
81 HashIdentifier,
82 signature_size,
83 vrfy_key_size,
84};
85use fn_dsa_comm::{
86 codec,
87 hash_to_point,
88 mq,
89 shake,
90};
91
92pub trait VerifyingKey: Sized {
94 fn decode(src: &[u8]) -> Option<Self>;
102
103 fn verify(&self, sig: &[u8], ctx: &DomainContext, id: &HashIdentifier, hv: &[u8]) -> bool;
117}
118
macro_rules! vrfy_key_impl {
    ($typename:ident, $logn_min:expr, $logn_max:expr) => {
        #[doc = concat!("Signature verifier for degrees (`logn`) ",
            stringify!($logn_min), " to ", stringify!($logn_max), " only.")]
        #[derive(Copy, Clone, Debug)]
        pub struct $typename {
            /// Degree of the decoded key (`n = 1 << logn`).
            logn: u32,
            /// Public polynomial `h` in NTT form; only the first `1 << logn`
            /// coefficients are meaningful.
            h: [u16; 1 << ($logn_max)],
            /// 64-byte SHAKE256 digest of the encoded verifying key.
            hashed_key: [u8; 64],
            /// Backend selected at decode time; `verify` follows this choice.
            #[cfg(all(
                not(feature = "no_avx2"),
                any(target_arch = "x86_64", target_arch = "x86")
            ))]
            use_avx2: bool,
        }

        impl VerifyingKey for $typename {
            fn decode(src: &[u8]) -> Option<Self> {
                let mut h = [0u16; 1 << ($logn_max)];

                // Decide on the backend once, then decode with it.
                #[cfg(all(
                    not(feature = "no_avx2"),
                    any(target_arch = "x86_64", target_arch = "x86")
                ))]
                let use_avx2 = fn_dsa_comm::has_avx2();

                #[cfg(all(
                    not(feature = "no_avx2"),
                    any(target_arch = "x86_64", target_arch = "x86")
                ))]
                let logn = if use_avx2 {
                    // SAFETY: `has_avx2()` just reported AVX2 support at runtime.
                    unsafe { decode_avx2_inner($logn_min, $logn_max, &mut h[..], src)? }
                } else {
                    decode_inner($logn_min, $logn_max, &mut h[..], src)?
                };

                #[cfg(not(all(
                    not(feature = "no_avx2"),
                    any(target_arch = "x86_64", target_arch = "x86")
                )))]
                let logn = decode_inner($logn_min, $logn_max, &mut h[..], src)?;

                // Digest of the encoded key; it is passed to hash-to-point
                // during verification.
                let mut hashed_key = [0u8; 64];
                let mut sh = shake::SHAKE256::new();
                sh.inject(src);
                sh.flip();
                sh.extract(&mut hashed_key);

                Some(Self {
                    logn,
                    h,
                    hashed_key,
                    #[cfg(all(
                        not(feature = "no_avx2"),
                        any(target_arch = "x86_64", target_arch = "x86")
                    ))]
                    use_avx2,
                })
            }

            fn verify(
                &self,
                sig: &[u8],
                ctx: &DomainContext,
                id: &HashIdentifier,
                hv: &[u8],
            ) -> bool {
                let logn = self.logn;
                let n = 1usize << logn;
                let h = &self.h[..n];

                // Scratch space is sized for the largest supported degree;
                // only the prefix matching this key's degree is handed down.
                let mut buf_i16 = [0i16; 1 << ($logn_max)];
                let mut buf_u16 = [0u16; 2 << ($logn_max)];
                let scratch_i16 = &mut buf_i16[..n];
                let scratch_u16 = &mut buf_u16[..(2 * n)];

                #[cfg(all(
                    not(feature = "no_avx2"),
                    any(target_arch = "x86_64", target_arch = "x86")
                ))]
                if self.use_avx2 {
                    // SAFETY: `use_avx2` is only set to true by `decode`, after
                    // `has_avx2()` confirmed AVX2 support.
                    return unsafe {
                        verify_avx2_inner(
                            logn,
                            h,
                            &self.hashed_key,
                            sig,
                            ctx,
                            id,
                            hv,
                            scratch_i16,
                            scratch_u16,
                        )
                    };
                }

                verify_inner(
                    logn,
                    h,
                    &self.hashed_key,
                    sig,
                    ctx,
                    id,
                    hv,
                    scratch_i16,
                    scratch_u16,
                )
            }
        }
    };
}
223
224vrfy_key_impl!(VerifyingKeyStandard, 9, 10);
226
227vrfy_key_impl!(VerifyingKey512, 9, 9);
229
230vrfy_key_impl!(VerifyingKey1024, 10, 10);
232
233vrfy_key_impl!(VerifyingKeyWeak, 2, 8);
236
237fn decode_inner(logn_min: u32, logn_max: u32, h: &mut [u16], src: &[u8]) -> Option<u32> {
245 if src.is_empty() {
246 return None;
247 }
248 let head = src[0];
249 if (head & 0xF0) != 0x00 {
250 return None;
251 }
252 let logn = (head & 0x0F) as u32;
253 if logn < logn_min || logn > logn_max {
254 return None;
255 }
256 if src.len() != vrfy_key_size(logn) {
257 return None;
258 }
259 let n = 1usize << logn;
260 let _ = codec::modq_decode(&src[1..], &mut h[..n])?;
261 mq::mqpoly_ext_to_int(logn, h);
262 mq::mqpoly_int_to_NTT(logn, h);
263 Some(logn)
264}
265
266fn verify_inner(
267 logn: u32,
268 h: &[u16],
269 hashed_key: &[u8],
270 sig: &[u8],
271 ctx: &DomainContext,
272 id: &HashIdentifier,
273 hv: &[u8],
274 tmp_i16: &mut [i16],
275 tmp_u16: &mut [u16],
276) -> bool {
277 let n = 1usize << logn;
280 let s2i = &mut tmp_i16[..n];
281 let (t1, tmp_u16) = tmp_u16.split_at_mut(n);
282 let (t2, _) = tmp_u16.split_at_mut(n);
283
284 if sig.len() != signature_size(logn) {
286 return false;
287 }
288 let head = sig[0];
289 if head != (0x30 + logn) as u8 {
290 return false;
291 }
292 if !codec::comp_decode(&sig[41..], s2i) {
293 return false;
294 }
295
296 let norm2 = mq::signed_poly_sqnorm(logn, &*s2i);
300
301 hash_to_point(&sig[1..41], hashed_key, ctx, id, hv, t1);
303 mq::mqpoly_ext_to_int(logn, t1);
304
305 mq::mqpoly_signed_to_ext(logn, &*s2i, t2);
307 mq::mqpoly_ext_to_int(logn, t2);
308 mq::mqpoly_int_to_NTT(logn, t2);
309
310 mq::mqpoly_mul_ntt(logn, t2, h);
312 mq::mqpoly_NTT_to_int(logn, t2);
313 mq::mqpoly_sub_int(logn, t1, t2);
314 mq::mqpoly_int_to_ext(logn, t1);
315
316 let norm1 = mq::mqpoly_sqnorm(logn, &*t1);
318
319 norm1 < norm2.wrapping_neg() && (norm1 + norm2) <= mq::SQBETA[logn as usize]
322}
323
324#[cfg(all(
326 not(feature = "no_avx2"),
327 any(target_arch = "x86_64", target_arch = "x86")
328))]
329#[target_feature(enable = "avx2")]
330unsafe fn decode_avx2_inner(
331 logn_min: u32,
332 logn_max: u32,
333 h: &mut [u16],
334 src: &[u8],
335) -> Option<u32> {
336 use fn_dsa_comm::mq_avx2;
337
338 if src.is_empty() {
339 return None;
340 }
341 let head = src[0];
342 if (head & 0xF0) != 0x00 {
343 return None;
344 }
345 let logn = (head & 0x0F) as u32;
346 if logn < logn_min || logn > logn_max {
347 return None;
348 }
349 if src.len() != vrfy_key_size(logn) {
350 return None;
351 }
352 let n = 1usize << logn;
353 let _ = codec::modq_decode(&src[1..], &mut h[..n])?;
354 unsafe {
355 mq_avx2::mqpoly_ext_to_int(logn, h);
356 mq_avx2::mqpoly_int_to_NTT(logn, h);
357 }
358 Some(logn)
359}
360
361#[cfg(all(
363 not(feature = "no_avx2"),
364 any(target_arch = "x86_64", target_arch = "x86")
365))]
366#[target_feature(enable = "avx2")]
367unsafe fn verify_avx2_inner(
368 logn: u32,
369 h: &[u16],
370 hashed_key: &[u8],
371 sig: &[u8],
372 ctx: &DomainContext,
373 id: &HashIdentifier,
374 hv: &[u8],
375 tmp_i16: &mut [i16],
376 tmp_u16: &mut [u16],
377) -> bool {
378 use fn_dsa_comm::mq_avx2;
379
380 let n = 1usize << logn;
383 let s2i = &mut tmp_i16[..n];
384 let (t1, tmp_u16) = tmp_u16.split_at_mut(n);
385 let (t2, _) = tmp_u16.split_at_mut(n);
386
387 if sig.len() != signature_size(logn) {
389 return false;
390 }
391 let head = sig[0];
392 if head != (0x30 + logn) as u8 {
393 return false;
394 }
395 if !codec::comp_decode(&sig[41..], s2i) {
396 return false;
397 }
398
399 let norm2 = unsafe { mq_avx2::signed_poly_sqnorm(logn, &*s2i) };
403
404 hash_to_point(&sig[1..41], hashed_key, ctx, id, hv, t1);
406 let norm1 = unsafe {
407 mq_avx2::mqpoly_ext_to_int(logn, t1);
408
409 mq_avx2::mqpoly_signed_to_ext(logn, &*s2i, t2);
411 mq_avx2::mqpoly_ext_to_int(logn, t2);
412 mq_avx2::mqpoly_int_to_NTT(logn, t2);
413
414 mq_avx2::mqpoly_mul_ntt(logn, t2, h);
416 mq_avx2::mqpoly_NTT_to_int(logn, t2);
417 mq_avx2::mqpoly_sub_int(logn, t1, t2);
418 mq_avx2::mqpoly_int_to_ext(logn, t1);
419
420 mq_avx2::mqpoly_sqnorm(logn, &*t1)
422 };
423
424 norm1 < norm2.wrapping_neg() && (norm1 + norm2) <= mq_avx2::SQBETA[logn as usize]
427}
428
#[cfg(test)]
mod tests {

    use fn_dsa_comm::shake::{
        SHA3_256,
        SHAKE256,
    };
    use fn_dsa_comm::{
        Infallible,
        TryCryptoRng,
        TryRng,
        sign_key_size,
    };
    use fn_dsa_kgen::{
        KeyPairGenerator,
        KeyPairGenerator512,
        KeyPairGenerator1024,
    };
    use fn_dsa_sign::{
        SigningKey,
        SigningKey512,
        SigningKey1024,
    };

    use super::*;

    /// Deterministic RNG backed by a SHAKE256 stream, so that generated key
    /// pairs and signatures are reproducible from the seed.
    struct ShakeRng(SHAKE256);

    impl ShakeRng {
        fn from_seed(seed: &[u8]) -> Self {
            let mut sh = SHAKE256::new();
            sh.inject(seed);
            sh.flip();
            Self(sh)
        }
    }

    impl TryRng for ShakeRng {
        type Error = Infallible;

        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            let mut buf = [0u8; 4];
            self.0.extract(&mut buf);
            Ok(u32::from_le_bytes(buf))
        }

        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            let mut buf = [0u8; 8];
            self.0.extract(&mut buf);
            Ok(u64::from_le_bytes(buf))
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
            self.0.extract(dest);
            Ok(())
        }
    }

    impl TryCryptoRng for ShakeRng {}

    fn sha3_256_digest_32(msg: &[u8]) -> [u8; 32] {
        let mut sh = SHA3_256::new();
        sh.update(msg);
        sh.digest()
    }

    /// Runs the portable core verifier and, when the CPU supports it, the
    /// AVX2 core verifier on `sig`, asserting that each returns `expected`.
    fn assert_backends(vk: &VerifyingKeyStandard, sig: &[u8], hv: &[u8], expected: bool) {
        let logn = vk.logn;
        let n = 1usize << logn;
        let mut tmp_i16 = [0i16; 1 << 10];
        let mut tmp_u16 = [0u16; 2 << 10];

        let portable = verify_inner(
            logn,
            &vk.h[..n],
            &vk.hashed_key,
            sig,
            &DOMAIN_NONE,
            &HASH_ID_SHA3_256,
            hv,
            &mut tmp_i16[..n],
            &mut tmp_u16[..(2 * n)],
        );
        assert_eq!(portable, expected, "portable verifier, logn = {}", logn);

        #[cfg(all(
            not(feature = "no_avx2"),
            any(target_arch = "x86_64", target_arch = "x86")
        ))]
        if fn_dsa_comm::has_avx2() {
            // SAFETY: AVX2 support was confirmed at runtime by `has_avx2()`.
            let avx2 = unsafe {
                verify_avx2_inner(
                    logn,
                    &vk.h[..n],
                    &vk.hashed_key,
                    sig,
                    &DOMAIN_NONE,
                    &HASH_ID_SHA3_256,
                    hv,
                    &mut tmp_i16[..n],
                    &mut tmp_u16[..(2 * n)],
                )
            };
            assert_eq!(avx2, expected, "AVX2 verifier, logn = {}", logn);
        }
    }

    /// Checks that malformed encodings of the valid key `vk` (of degree
    /// `logn`) are rejected, and that each fixed-degree type enforces its
    /// degree range.
    fn assert_decode_rejections(logn: u32, vk: &[u8]) {
        // Empty and truncated inputs.
        assert!(VerifyingKeyStandard::decode(&[]).is_none());
        assert!(VerifyingKeyStandard::decode(&vk[..vk.len() - 1]).is_none());

        // Non-zero high nibble in the header byte.
        let mut buf = [0u8; vrfy_key_size(10)];
        let bad_head = &mut buf[..vk.len()];
        bad_head.copy_from_slice(vk);
        bad_head[0] |= 0x10;
        assert!(VerifyingKeyStandard::decode(bad_head).is_none());

        // Degree restrictions of the specialized key types.
        assert_eq!(VerifyingKey512::decode(vk).is_some(), logn == 9);
        assert_eq!(VerifyingKey1024::decode(vk).is_some(), logn == 10);
        assert!(VerifyingKeyWeak::decode(vk).is_none());
    }

    fn verify_roundtrip_logn(logn: u32) {
        let hv = sha3_256_digest_32(b"fn-dsa-vrfy kat");

        let mut rng = ShakeRng::from_seed(&[0x55u8, logn as u8, 0xAA]);

        let mut sk_buf = [0u8; sign_key_size(10)];
        let mut vk_buf = [0u8; vrfy_key_size(10)];
        let mut sig_storage = [0u8; signature_size(10)];
        let sk_sl = &mut sk_buf[..sign_key_size(logn)];
        let vk_sl = &mut vk_buf[..vrfy_key_size(logn)];
        let sig = &mut sig_storage[..signature_size(logn)];

        match logn {
            9 => {
                let mut kg = KeyPairGenerator512::default();
                kg.keygen(logn, &mut rng, sk_sl, vk_sl);
                let mut sk = SigningKey512::decode(sk_sl).unwrap();
                sk.sign(&mut rng, &DOMAIN_NONE, &HASH_ID_SHA3_256, &hv, sig);
            }
            10 => {
                let mut kg = KeyPairGenerator1024::default();
                kg.keygen(logn, &mut rng, sk_sl, vk_sl);
                let mut sk = SigningKey1024::decode(sk_sl).unwrap();
                sk.sign(&mut rng, &DOMAIN_NONE, &HASH_ID_SHA3_256, &hv, sig);
            }
            _ => unreachable!(),
        }

        assert_decode_rejections(logn, vk_sl);

        let vk = VerifyingKeyStandard::decode(vk_sl).unwrap();

        // A genuine signature is accepted by the public API and every backend.
        assert!(vk.verify(sig, &DOMAIN_NONE, &HASH_ID_SHA3_256, &hv));
        assert_backends(&vk, sig, &hv, true);

        // A modified message hash is rejected.
        let mut hv_bad = hv;
        hv_bad[0] ^= 0x01;
        assert!(!vk.verify(sig, &DOMAIN_NONE, &HASH_ID_SHA3_256, &hv_bad));
        assert_backends(&vk, sig, &hv_bad, false);

        // A bit flip inside the compressed s2 is rejected.
        let mut bad_storage = [0u8; signature_size(10)];
        let sig_bad = &mut bad_storage[..sig.len()];
        sig_bad.copy_from_slice(sig);
        sig_bad[50] ^= 0x01;
        assert!(!vk.verify(sig_bad, &DOMAIN_NONE, &HASH_ID_SHA3_256, &hv));
        assert_backends(&vk, sig_bad, &hv, false);

        // A wrong signature header byte is rejected.
        let mut hdr_storage = [0u8; signature_size(10)];
        let sig_hdr = &mut hdr_storage[..sig.len()];
        sig_hdr.copy_from_slice(sig);
        sig_hdr[0] ^= 0x01;
        assert!(!vk.verify(sig_hdr, &DOMAIN_NONE, &HASH_ID_SHA3_256, &hv));
        assert_backends(&vk, sig_hdr, &hv, false);

        // A truncated signature is rejected.
        let sig_short = &sig[..sig.len() - 1];
        assert!(!vk.verify(sig_short, &DOMAIN_NONE, &HASH_ID_SHA3_256, &hv));
        assert_backends(&vk, sig_short, &hv, false);
    }

    #[test]
    fn verify_kat_512() {
        verify_roundtrip_logn(9);
    }

    #[test]
    fn verify_kat_1024() {
        verify_roundtrip_logn(10);
    }
}