1use super::{
12 super::{variant::Variant, Error},
13 hash_with_namespace,
14};
15#[cfg(not(feature = "std"))]
16use alloc::vec::Vec;
17use bytes::{Buf, BufMut};
18use commonware_codec::{Error as CodecError, FixedSize, Read, ReadExt, Write};
19use commonware_math::algebra::Additive;
20use commonware_parallel::Strategy;
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq)]
27pub struct PublicKey<V: Variant>(V::Public);
28
29impl<V: Variant> PublicKey<V> {
30 pub fn zero() -> Self {
32 Self(V::Public::zero())
33 }
34
35 pub(crate) const fn inner(&self) -> &V::Public {
37 &self.0
38 }
39
40 pub(crate) fn add(&mut self, other: &V::Public) {
42 self.0 += other;
43 }
44}
45
46impl<V: Variant> Write for PublicKey<V> {
47 fn write(&self, writer: &mut impl BufMut) {
48 self.0.write(writer);
49 }
50}
51
52impl<V: Variant> Read for PublicKey<V> {
53 type Cfg = ();
54
55 fn read_cfg(reader: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
56 Ok(Self(V::Public::read(reader)?))
57 }
58}
59
60impl<V: Variant> FixedSize for PublicKey<V> {
61 const SIZE: usize = V::Public::SIZE;
62}
63
64#[cfg(feature = "arbitrary")]
65impl<V: Variant> arbitrary::Arbitrary<'_> for PublicKey<V>
66where
67 V::Public: for<'a> arbitrary::Arbitrary<'a>,
68{
69 fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
70 Ok(Self(V::Public::arbitrary(u)?))
71 }
72}
73
74#[derive(Debug, Clone, PartialEq, Eq, Hash)]
79pub struct Signature<V: Variant>(V::Signature);
80
81impl<V: Variant> Signature<V> {
82 pub fn zero() -> Self {
84 Self(V::Signature::zero())
85 }
86
87 pub(crate) const fn inner(&self) -> &V::Signature {
89 &self.0
90 }
91
92 pub(crate) fn add(&mut self, other: &V::Signature) {
94 self.0 += other;
95 }
96}
97
98impl<V: Variant> Write for Signature<V> {
99 fn write(&self, writer: &mut impl BufMut) {
100 self.0.write(writer);
101 }
102}
103
104impl<V: Variant> Read for Signature<V> {
105 type Cfg = ();
106
107 fn read_cfg(reader: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
108 Ok(Self(V::Signature::read(reader)?))
109 }
110}
111
112impl<V: Variant> FixedSize for Signature<V> {
113 const SIZE: usize = V::Signature::SIZE;
114}
115
116#[cfg(feature = "arbitrary")]
117impl<V: Variant> arbitrary::Arbitrary<'_> for Signature<V>
118where
119 V::Signature: for<'a> arbitrary::Arbitrary<'a>,
120{
121 fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
122 Ok(Self(V::Signature::arbitrary(u)?))
123 }
124}
125
126#[derive(Debug, Clone, PartialEq, Eq, Hash)]
131pub struct Message<V: Variant>(V::Signature);
132
133impl<V: Variant> Message<V> {
134 pub fn zero() -> Self {
136 Self(V::Signature::zero())
137 }
138
139 pub(crate) const fn inner(&self) -> &V::Signature {
141 &self.0
142 }
143
144 pub(crate) fn add(&mut self, other: &V::Signature) {
146 self.0 += other;
147 }
148
149 pub(crate) fn combine(&mut self, other: &Self) {
151 self.0 += &other.0;
152 }
153}
154
155impl<V: Variant> Write for Message<V> {
156 fn write(&self, writer: &mut impl BufMut) {
157 self.0.write(writer);
158 }
159}
160
161impl<V: Variant> Read for Message<V> {
162 type Cfg = ();
163
164 fn read_cfg(reader: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
165 Ok(Self(V::Signature::read(reader)?))
166 }
167}
168
169impl<V: Variant> FixedSize for Message<V> {
170 const SIZE: usize = V::Signature::SIZE;
171}
172
173#[cfg(feature = "arbitrary")]
174impl<V: Variant> arbitrary::Arbitrary<'_> for Message<V>
175where
176 V::Signature: for<'a> arbitrary::Arbitrary<'a>,
177{
178 fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
179 Ok(Self(V::Signature::arbitrary(u)?))
180 }
181}
182
183pub fn combine_public_keys<'a, V, I>(public_keys: I) -> PublicKey<V>
192where
193 V: Variant,
194 I: IntoIterator<Item = &'a V::Public>,
195 V::Public: 'a,
196{
197 let mut p = PublicKey::zero();
198 for pk in public_keys {
199 p.add(pk);
200 }
201 p
202}
203
204pub fn combine_signatures<'a, V, I>(signatures: I) -> Signature<V>
212where
213 V: Variant,
214 I: IntoIterator<Item = &'a V::Signature>,
215 V::Signature: 'a,
216{
217 let mut s = Signature::zero();
218 for sig in signatures {
219 s.add(sig);
220 }
221 s
222}
223
224pub fn combine_messages<'a, V, I>(messages: I, strategy: &impl Strategy) -> Message<V>
230where
231 V: Variant,
232 I: IntoIterator<Item = &'a (&'a [u8], &'a [u8])> + Send,
233 I::IntoIter: Send,
234{
235 strategy.fold(
236 messages,
237 Message::zero,
238 |mut sum, (namespace, msg)| {
239 let hm = hash_with_namespace::<V>(V::MESSAGE, namespace, msg);
240 sum.add(&hm);
241 sum
242 },
243 |mut a, b| {
244 a.combine(&b);
245 a
246 },
247 )
248}
249
250pub fn verify_same_message<V: Variant>(
264 public: &PublicKey<V>,
265 namespace: &[u8],
266 message: &[u8],
267 signature: &Signature<V>,
268) -> Result<(), Error> {
269 let hm = hash_with_namespace::<V>(V::MESSAGE, namespace, message);
270
271 V::verify(public.inner(), &hm, signature.inner())
273}
274
275pub fn verify_same_signer<V: Variant>(
287 public: &V::Public,
288 message: &Message<V>,
289 signature: &Signature<V>,
290) -> Result<(), Error> {
291 V::verify(public, message.inner(), signature.inner())
292}
293
#[cfg(test)]
mod tests {
    use super::{
        super::{aggregate, keypair, sign_message},
        *,
    };
    use crate::bls12381::primitives::{
        group::{G1_MESSAGE, G2_MESSAGE},
        variant::{MinPk, MinSig},
        Error,
    };
    use blst::BLST_ERROR;
    use commonware_codec::Encode;
    use commonware_parallel::{Rayon, Sequential};
    use commonware_utils::{test_rng, union_unique, NZUsize};

    /// Reference check: verifies `signature` over `message` against every key in
    /// `public` with blst's `fast_aggregate_verify`.
    ///
    /// Keys and the signature are converted to blst types through their encoded
    /// bytes. `V::MESSAGE` selects the blst API (`min_sig` for `G1_MESSAGE`,
    /// `min_pk` for `G2_MESSAGE`) and is also passed as blst's `dst` argument.
    /// `message` is given to blst unchanged, so callers must pass the exact signed
    /// payload (e.g. `union_unique(namespace, message)`).
    ///
    /// # Panics
    ///
    /// Panics on an unsupported `V::MESSAGE` or if blst rejects an encoding.
    fn blst_aggregate_verify_same_message<'a, V, I>(
        public: I,
        message: &[u8],
        signature: &Signature<V>,
    ) -> Result<(), BLST_ERROR>
    where
        V: Variant,
        I: IntoIterator<Item = &'a V::Public>,
        V::Public: 'a,
    {
        match V::MESSAGE {
            G1_MESSAGE => {
                let public = public
                    .into_iter()
                    .map(|pk| blst::min_sig::PublicKey::from_bytes(&pk.encode()).unwrap())
                    .collect::<Vec<_>>();
                // blst expects a slice of references.
                let public = public.iter().collect::<Vec<_>>();
                let signature =
                    blst::min_sig::Signature::from_bytes(&signature.inner().encode()).unwrap();
                match signature.fast_aggregate_verify(true, message, V::MESSAGE, &public) {
                    BLST_ERROR::BLST_SUCCESS => Ok(()),
                    e => Err(e),
                }
            }
            G2_MESSAGE => {
                let public = public
                    .into_iter()
                    .map(|pk| blst::min_pk::PublicKey::from_bytes(&pk.encode()).unwrap())
                    .collect::<Vec<_>>();
                // blst expects a slice of references.
                let public = public.iter().collect::<Vec<_>>();
                let signature =
                    blst::min_pk::Signature::from_bytes(&signature.inner().encode()).unwrap();
                match signature.fast_aggregate_verify(true, message, V::MESSAGE, &public) {
                    BLST_ERROR::BLST_SUCCESS => Ok(()),
                    e => Err(e),
                }
            }
            _ => panic!("Unsupported Variant"),
        }
    }

    /// Three signers sign the same message; both this module and blst must accept
    /// the aggregate.
    fn aggregate_verify_same_message_correct<V: Variant>() {
        let mut rng = test_rng();
        let (private1, public1) = keypair::<_, V>(&mut rng);
        let (private2, public2) = keypair::<_, V>(&mut rng);
        let (private3, public3) = keypair::<_, V>(&mut rng);
        let namespace = b"test";
        let message = b"message";
        let sig1 = sign_message::<V>(&private1, namespace, message);
        let sig2 = sign_message::<V>(&private2, namespace, message);
        let sig3 = sign_message::<V>(&private3, namespace, message);
        let pks = vec![public1, public2, public3];
        let signatures = vec![sig1, sig2, sig3];

        let aggregate_pk = aggregate::combine_public_keys::<V, _>(&pks);
        let aggregate_sig = aggregate::combine_signatures::<V, _>(&signatures);

        verify_same_message::<V>(&aggregate_pk, namespace, message, &aggregate_sig)
            .expect("Aggregated signature should be valid");

        // blst signs/verifies raw bytes, so hand it the namespaced payload.
        let payload = union_unique(namespace, message);
        blst_aggregate_verify_same_message::<V, _>(&pks, &payload, &aggregate_sig)
            .expect("Aggregated signature should be valid");
    }

    #[test]
    fn test_aggregate_verify_same_message() {
        aggregate_verify_same_message_correct::<MinPk>();
        aggregate_verify_same_message_correct::<MinSig>();
    }

    /// Swapping one signer's key for an unrelated key must make verification fail.
    fn aggregate_verify_same_message_wrong_public_keys<V: Variant>() {
        let mut rng = test_rng();
        let (private1, public1) = keypair::<_, V>(&mut rng);
        let (private2, public2) = keypair::<_, V>(&mut rng);
        let (private3, _) = keypair::<_, V>(&mut rng);
        let namespace = b"test";
        let message = b"message";
        let sig1 = sign_message::<V>(&private1, namespace, message);
        let sig2 = sign_message::<V>(&private2, namespace, message);
        let sig3 = sign_message::<V>(&private3, namespace, message);
        let signatures = vec![sig1, sig2, sig3];

        let (_, public4) = keypair::<_, V>(&mut rng);
        let wrong_pks = vec![public1, public2, public4];
        let wrong_aggregate_pk = aggregate::combine_public_keys::<V, _>(&wrong_pks);
        let aggregate_sig = aggregate::combine_signatures::<V, _>(&signatures);
        let result =
            verify_same_message::<V>(&wrong_aggregate_pk, namespace, message, &aggregate_sig);
        assert!(matches!(result, Err(Error::InvalidSignature)));
    }

    #[test]
    fn test_aggregate_verify_same_message_wrong_public_keys() {
        aggregate_verify_same_message_wrong_public_keys::<MinPk>();
        aggregate_verify_same_message_wrong_public_keys::<MinSig>();
    }

    /// Omitting one signer's key from the aggregate must make verification fail.
    fn aggregate_verify_same_message_wrong_public_key_count<V: Variant>() {
        let mut rng = test_rng();
        let (private1, public1) = keypair::<_, V>(&mut rng);
        let (private2, public2) = keypair::<_, V>(&mut rng);
        let (private3, _) = keypair::<_, V>(&mut rng);
        let namespace = b"test";
        let message = b"message";
        let sig1 = sign_message::<V>(&private1, namespace, message);
        let sig2 = sign_message::<V>(&private2, namespace, message);
        let sig3 = sign_message::<V>(&private3, namespace, message);
        let signatures = vec![sig1, sig2, sig3];

        let wrong_pks = vec![public1, public2];
        let wrong_aggregate_pk = aggregate::combine_public_keys::<V, _>(&wrong_pks);
        let aggregate_sig = aggregate::combine_signatures::<V, _>(&signatures);
        let result =
            verify_same_message::<V>(&wrong_aggregate_pk, namespace, message, &aggregate_sig);
        assert!(matches!(result, Err(Error::InvalidSignature)));
    }

    #[test]
    fn test_aggregate_verify_same_message_wrong_public_key_count() {
        aggregate_verify_same_message_wrong_public_key_count::<MinPk>();
        aggregate_verify_same_message_wrong_public_key_count::<MinSig>();
    }

    /// Reference check: verifies `signature` as one signer's aggregate over `msgs`
    /// using blst's `aggregate_verify`, repeating `public` once per message.
    ///
    /// # Panics
    ///
    /// Panics on an unsupported `V::MESSAGE` or if blst rejects an encoding.
    fn blst_aggregate_verify_same_signer<'a, V, I>(
        public: &V::Public,
        msgs: I,
        signature: &Signature<V>,
    ) -> Result<(), BLST_ERROR>
    where
        V: Variant,
        I: IntoIterator<Item = &'a [u8]>,
    {
        match V::MESSAGE {
            G1_MESSAGE => {
                let public = blst::min_sig::PublicKey::from_bytes(&public.encode()).unwrap();
                let msgs = msgs.into_iter().collect::<Vec<_>>();
                let pks = vec![&public; msgs.len()];
                let signature =
                    blst::min_sig::Signature::from_bytes(&signature.inner().encode()).unwrap();
                match signature.aggregate_verify(true, &msgs, V::MESSAGE, &pks, true) {
                    BLST_ERROR::BLST_SUCCESS => Ok(()),
                    e => Err(e),
                }
            }
            G2_MESSAGE => {
                let public = blst::min_pk::PublicKey::from_bytes(&public.encode()).unwrap();
                let msgs = msgs.into_iter().collect::<Vec<_>>();
                let pks = vec![&public; msgs.len()];
                let signature =
                    blst::min_pk::Signature::from_bytes(&signature.inner().encode()).unwrap();
                match signature.aggregate_verify(true, &msgs, V::MESSAGE, &pks, true) {
                    BLST_ERROR::BLST_SUCCESS => Ok(()),
                    e => Err(e),
                }
            }
            _ => panic!("Unsupported Variant"),
        }
    }

    /// One signer signs three messages. The aggregate must verify via both the
    /// sequential and parallel `combine_messages` paths, and via blst.
    fn aggregate_verify_same_signer_correct<V: Variant>() {
        let (private, public) = keypair::<_, V>(&mut test_rng());
        let namespace = b"test";
        let messages: Vec<(&[u8], &[u8])> = vec![
            (namespace, b"Message 1"),
            (namespace, b"Message 2"),
            (namespace, b"Message 3"),
        ];
        let signatures: Vec<_> = messages
            .iter()
            .map(|(namespace, msg)| sign_message::<V>(&private, namespace, msg))
            .collect();

        let aggregate_sig = aggregate::combine_signatures::<V, _>(&signatures);

        let combined_msg = aggregate::combine_messages::<V, _>(&messages, &Sequential);
        aggregate::verify_same_signer::<V>(&public, &combined_msg, &aggregate_sig)
            .expect("Aggregated signature should be valid");

        // Previously mis-encoded as `¶llel`; must be a reference to `parallel`.
        let parallel = Rayon::new(NZUsize!(4)).unwrap();
        let combined_msg_parallel = aggregate::combine_messages::<V, _>(&messages, &parallel);
        aggregate::verify_same_signer::<V>(&public, &combined_msg_parallel, &aggregate_sig)
            .expect("Aggregated signature should be valid with parallelism");

        let payload_msgs: Vec<_> = messages
            .iter()
            .map(|(ns, msg)| union_unique(ns, msg))
            .collect();
        let payload_refs: Vec<&[u8]> = payload_msgs.iter().map(|p| p.as_ref()).collect();
        blst_aggregate_verify_same_signer::<V, _>(&public, payload_refs, &aggregate_sig)
            .expect("blst should also accept aggregated signature");
    }

    #[test]
    fn test_aggregate_verify_same_signer_correct() {
        aggregate_verify_same_signer_correct::<MinPk>();
        aggregate_verify_same_signer_correct::<MinSig>();
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<PublicKey<MinSig>>,
            CodecConformance<Message<MinSig>>,
            CodecConformance<Signature<MinSig>>,
        }
    }
}