1use super::{
12 super::{variant::Variant, Error},
13 hash_with_namespace,
14};
15use bytes::{Buf, BufMut};
16use commonware_codec::{Error as CodecError, FixedSize, Read, ReadExt, Write};
17use commonware_math::algebra::Additive;
18use commonware_parallel::Strategy;
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq)]
25pub struct PublicKey<V: Variant>(V::Public);
26
27impl<V: Variant> PublicKey<V> {
28 pub fn zero() -> Self {
30 Self(V::Public::zero())
31 }
32
33 pub(crate) const fn inner(&self) -> &V::Public {
35 &self.0
36 }
37
38 pub(crate) fn add(&mut self, other: &V::Public) {
40 self.0 += other;
41 }
42}
43
44impl<V: Variant> Write for PublicKey<V> {
45 fn write(&self, writer: &mut impl BufMut) {
46 self.0.write(writer);
47 }
48}
49
50impl<V: Variant> Read for PublicKey<V> {
51 type Cfg = ();
52
53 fn read_cfg(reader: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
54 Ok(Self(V::Public::read(reader)?))
55 }
56}
57
58impl<V: Variant> FixedSize for PublicKey<V> {
59 const SIZE: usize = V::Public::SIZE;
60}
61
62#[cfg(feature = "arbitrary")]
63impl<V: Variant> arbitrary::Arbitrary<'_> for PublicKey<V>
64where
65 V::Public: for<'a> arbitrary::Arbitrary<'a>,
66{
67 fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
68 Ok(Self(V::Public::arbitrary(u)?))
69 }
70}
71
72#[derive(Debug, Clone, PartialEq, Eq, Hash)]
77pub struct Signature<V: Variant>(V::Signature);
78
79impl<V: Variant> Signature<V> {
80 pub fn zero() -> Self {
82 Self(V::Signature::zero())
83 }
84
85 pub(crate) const fn inner(&self) -> &V::Signature {
87 &self.0
88 }
89
90 pub(crate) fn add(&mut self, other: &V::Signature) {
92 self.0 += other;
93 }
94}
95
96impl<V: Variant> Write for Signature<V> {
97 fn write(&self, writer: &mut impl BufMut) {
98 self.0.write(writer);
99 }
100}
101
102impl<V: Variant> Read for Signature<V> {
103 type Cfg = ();
104
105 fn read_cfg(reader: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
106 Ok(Self(V::Signature::read(reader)?))
107 }
108}
109
110impl<V: Variant> FixedSize for Signature<V> {
111 const SIZE: usize = V::Signature::SIZE;
112}
113
114#[cfg(feature = "arbitrary")]
115impl<V: Variant> arbitrary::Arbitrary<'_> for Signature<V>
116where
117 V::Signature: for<'a> arbitrary::Arbitrary<'a>,
118{
119 fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
120 Ok(Self(V::Signature::arbitrary(u)?))
121 }
122}
123
124#[derive(Debug, Clone, PartialEq, Eq, Hash)]
129pub struct Message<V: Variant>(V::Signature);
130
131impl<V: Variant> Message<V> {
132 pub fn zero() -> Self {
134 Self(V::Signature::zero())
135 }
136
137 pub(crate) const fn inner(&self) -> &V::Signature {
139 &self.0
140 }
141
142 pub(crate) fn add(&mut self, other: &V::Signature) {
144 self.0 += other;
145 }
146
147 pub(crate) fn combine(&mut self, other: &Self) {
149 self.0 += &other.0;
150 }
151}
152
153impl<V: Variant> Write for Message<V> {
154 fn write(&self, writer: &mut impl BufMut) {
155 self.0.write(writer);
156 }
157}
158
159impl<V: Variant> Read for Message<V> {
160 type Cfg = ();
161
162 fn read_cfg(reader: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
163 Ok(Self(V::Signature::read(reader)?))
164 }
165}
166
167impl<V: Variant> FixedSize for Message<V> {
168 const SIZE: usize = V::Signature::SIZE;
169}
170
171#[cfg(feature = "arbitrary")]
172impl<V: Variant> arbitrary::Arbitrary<'_> for Message<V>
173where
174 V::Signature: for<'a> arbitrary::Arbitrary<'a>,
175{
176 fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
177 Ok(Self(V::Signature::arbitrary(u)?))
178 }
179}
180
181pub fn combine_public_keys<'a, V, I>(public_keys: I) -> PublicKey<V>
190where
191 V: Variant,
192 I: IntoIterator<Item = &'a V::Public>,
193 V::Public: 'a,
194{
195 let mut p = PublicKey::zero();
196 for pk in public_keys {
197 p.add(pk);
198 }
199 p
200}
201
202pub fn combine_signatures<'a, V, I>(signatures: I) -> Signature<V>
210where
211 V: Variant,
212 I: IntoIterator<Item = &'a V::Signature>,
213 V::Signature: 'a,
214{
215 let mut s = Signature::zero();
216 for sig in signatures {
217 s.add(sig);
218 }
219 s
220}
221
222pub fn combine_messages<'a, V, I>(messages: I, strategy: &impl Strategy) -> Message<V>
228where
229 V: Variant,
230 I: IntoIterator<Item = &'a (&'a [u8], &'a [u8])> + Send,
231 I::IntoIter: Send,
232{
233 strategy.fold(
234 messages,
235 Message::zero,
236 |mut sum, (namespace, msg)| {
237 let hm = hash_with_namespace::<V>(V::MESSAGE, namespace, msg);
238 sum.add(&hm);
239 sum
240 },
241 |mut a, b| {
242 a.combine(&b);
243 a
244 },
245 )
246}
247
248pub fn verify_same_message<V: Variant>(
262 public: &PublicKey<V>,
263 namespace: &[u8],
264 message: &[u8],
265 signature: &Signature<V>,
266) -> Result<(), Error> {
267 let hm = hash_with_namespace::<V>(V::MESSAGE, namespace, message);
268
269 V::verify(public.inner(), &hm, signature.inner())
271}
272
273pub fn verify_same_signer<V: Variant>(
285 public: &V::Public,
286 message: &Message<V>,
287 signature: &Signature<V>,
288) -> Result<(), Error> {
289 V::verify(public, message.inner(), signature.inner())
290}
291
#[cfg(test)]
mod tests {
    use super::{
        super::{aggregate, keypair, sign_message},
        *,
    };
    use crate::bls12381::primitives::{
        group::{G1_MESSAGE, G2_MESSAGE},
        variant::{MinPk, MinSig},
        Error,
    };
    use blst::BLST_ERROR;
    use commonware_codec::Encode;
    use commonware_parallel::{Rayon, Sequential};
    use commonware_utils::{test_rng, union_unique, NZUsize};

    /// Reference check for same-message aggregation using `blst` directly.
    ///
    /// Re-decodes each individual public key and the aggregate signature into `blst` types.
    /// `G1_MESSAGE` uses blst's `min_sig` module and `G2_MESSAGE` uses `min_pk`. It then runs
    /// `fast_aggregate_verify` with `V::MESSAGE` as the DST. `message` must be the exact
    /// signed bytes; the callers build it with `union_unique(namespace, message)`.
    fn blst_aggregate_verify_same_message<'a, V, I>(
        public: I,
        message: &[u8],
        signature: &Signature<V>,
    ) -> Result<(), BLST_ERROR>
    where
        V: Variant,
        I: IntoIterator<Item = &'a V::Public>,
        V::Public: 'a,
    {
        match V::MESSAGE {
            G1_MESSAGE => {
                let public = public
                    .into_iter()
                    .map(|pk| blst::min_sig::PublicKey::from_bytes(&pk.encode()).unwrap())
                    .collect::<Vec<_>>();
                let public = public.iter().collect::<Vec<_>>();
                let signature =
                    blst::min_sig::Signature::from_bytes(&signature.inner().encode()).unwrap();
                match signature.fast_aggregate_verify(true, message, V::MESSAGE, &public) {
                    BLST_ERROR::BLST_SUCCESS => Ok(()),
                    e => Err(e),
                }
            }
            G2_MESSAGE => {
                let public = public
                    .into_iter()
                    .map(|pk| blst::min_pk::PublicKey::from_bytes(&pk.encode()).unwrap())
                    .collect::<Vec<_>>();
                let public = public.iter().collect::<Vec<_>>();
                let signature =
                    blst::min_pk::Signature::from_bytes(&signature.inner().encode()).unwrap();
                match signature.fast_aggregate_verify(true, message, V::MESSAGE, &public) {
                    BLST_ERROR::BLST_SUCCESS => Ok(()),
                    e => Err(e),
                }
            }
            _ => panic!("Unsupported Variant"),
        }
    }

    /// Three signers sign the same message. The aggregate must verify both here and in blst.
    fn aggregate_verify_same_message_correct<V: Variant>() {
        let mut rng = test_rng();
        let (private1, public1) = keypair::<_, V>(&mut rng);
        let (private2, public2) = keypair::<_, V>(&mut rng);
        let (private3, public3) = keypair::<_, V>(&mut rng);
        let namespace = b"test";
        let message = b"message";
        let sig1 = sign_message::<V>(&private1, namespace, message);
        let sig2 = sign_message::<V>(&private2, namespace, message);
        let sig3 = sign_message::<V>(&private3, namespace, message);
        let pks = vec![public1, public2, public3];
        let signatures = vec![sig1, sig2, sig3];

        let aggregate_pk = aggregate::combine_public_keys::<V, _>(&pks);
        let aggregate_sig = aggregate::combine_signatures::<V, _>(&signatures);

        verify_same_message::<V>(&aggregate_pk, namespace, message, &aggregate_sig)
            .expect("Aggregated signature should be valid");

        // blst must agree when given the namespaced payload that was actually signed.
        let payload = union_unique(namespace, message);
        blst_aggregate_verify_same_message::<V, _>(&pks, &payload, &aggregate_sig)
            .expect("Aggregated signature should be valid");
    }

    #[test]
    fn test_aggregate_verify_same_message() {
        aggregate_verify_same_message_correct::<MinPk>();
        aggregate_verify_same_message_correct::<MinSig>();
    }

    /// Replacing one signer's key with an unrelated key must make verification fail.
    fn aggregate_verify_same_message_wrong_public_keys<V: Variant>() {
        let mut rng = test_rng();
        let (private1, public1) = keypair::<_, V>(&mut rng);
        let (private2, public2) = keypair::<_, V>(&mut rng);
        let (private3, _) = keypair::<_, V>(&mut rng);
        let namespace = b"test";
        let message = b"message";
        let sig1 = sign_message::<V>(&private1, namespace, message);
        let sig2 = sign_message::<V>(&private2, namespace, message);
        let sig3 = sign_message::<V>(&private3, namespace, message);
        let signatures = vec![sig1, sig2, sig3];

        let (_, public4) = keypair::<_, V>(&mut rng);
        let wrong_pks = vec![public1, public2, public4];
        let wrong_aggregate_pk = aggregate::combine_public_keys::<V, _>(&wrong_pks);
        let aggregate_sig = aggregate::combine_signatures::<V, _>(&signatures);
        let result =
            verify_same_message::<V>(&wrong_aggregate_pk, namespace, message, &aggregate_sig);
        assert!(matches!(result, Err(Error::InvalidSignature)));
    }

    #[test]
    fn test_aggregate_verify_same_message_wrong_public_keys() {
        aggregate_verify_same_message_wrong_public_keys::<MinPk>();
        aggregate_verify_same_message_wrong_public_keys::<MinSig>();
    }

    /// Omitting one signer's key from the aggregate must make verification fail.
    fn aggregate_verify_same_message_wrong_public_key_count<V: Variant>() {
        let mut rng = test_rng();
        let (private1, public1) = keypair::<_, V>(&mut rng);
        let (private2, public2) = keypair::<_, V>(&mut rng);
        let (private3, _) = keypair::<_, V>(&mut rng);
        let namespace = b"test";
        let message = b"message";
        let sig1 = sign_message::<V>(&private1, namespace, message);
        let sig2 = sign_message::<V>(&private2, namespace, message);
        let sig3 = sign_message::<V>(&private3, namespace, message);
        let signatures = vec![sig1, sig2, sig3];

        let wrong_pks = vec![public1, public2];
        let wrong_aggregate_pk = aggregate::combine_public_keys::<V, _>(&wrong_pks);
        let aggregate_sig = aggregate::combine_signatures::<V, _>(&signatures);
        let result =
            verify_same_message::<V>(&wrong_aggregate_pk, namespace, message, &aggregate_sig);
        assert!(matches!(result, Err(Error::InvalidSignature)));
    }

    #[test]
    fn test_aggregate_verify_same_message_wrong_public_key_count() {
        aggregate_verify_same_message_wrong_public_key_count::<MinPk>();
        aggregate_verify_same_message_wrong_public_key_count::<MinSig>();
    }

    /// Reference check for single-signer, multi-message aggregation using `blst` directly.
    ///
    /// Repeats the same public key once per message and calls blst's `aggregate_verify`
    /// with `V::MESSAGE` as the DST. Each entry of `msgs` must be the exact signed bytes.
    fn blst_aggregate_verify_same_signer<'a, V, I>(
        public: &V::Public,
        msgs: I,
        signature: &Signature<V>,
    ) -> Result<(), BLST_ERROR>
    where
        V: Variant,
        I: IntoIterator<Item = &'a [u8]>,
    {
        match V::MESSAGE {
            G1_MESSAGE => {
                let public = blst::min_sig::PublicKey::from_bytes(&public.encode()).unwrap();
                let msgs = msgs.into_iter().collect::<Vec<_>>();
                let pks = vec![&public; msgs.len()];
                let signature =
                    blst::min_sig::Signature::from_bytes(&signature.inner().encode()).unwrap();
                match signature.aggregate_verify(true, &msgs, V::MESSAGE, &pks, true) {
                    BLST_ERROR::BLST_SUCCESS => Ok(()),
                    e => Err(e),
                }
            }
            G2_MESSAGE => {
                let public = blst::min_pk::PublicKey::from_bytes(&public.encode()).unwrap();
                let msgs = msgs.into_iter().collect::<Vec<_>>();
                let pks = vec![&public; msgs.len()];
                let signature =
                    blst::min_pk::Signature::from_bytes(&signature.inner().encode()).unwrap();
                match signature.aggregate_verify(true, &msgs, V::MESSAGE, &pks, true) {
                    BLST_ERROR::BLST_SUCCESS => Ok(()),
                    e => Err(e),
                }
            }
            _ => panic!("Unsupported Variant"),
        }
    }

    /// One signer signs three messages. The combined message must verify whether it was
    /// built sequentially or in parallel, and blst must accept the aggregate signature.
    fn aggregate_verify_same_signer_correct<V: Variant>() {
        let (private, public) = keypair::<_, V>(&mut test_rng());
        let namespace = b"test";
        let messages: Vec<(&[u8], &[u8])> = vec![
            (namespace, b"Message 1"),
            (namespace, b"Message 2"),
            (namespace, b"Message 3"),
        ];
        let signatures: Vec<_> = messages
            .iter()
            .map(|(namespace, msg)| sign_message::<V>(&private, namespace, msg))
            .collect();

        let aggregate_sig = aggregate::combine_signatures::<V, _>(&signatures);

        let combined_msg = aggregate::combine_messages::<V, _>(&messages, &Sequential);
        aggregate::verify_same_signer::<V>(&public, &combined_msg, &aggregate_sig)
            .expect("Aggregated signature should be valid");

        // Same computation under a multi-threaded strategy.
        let parallel = Rayon::new(NZUsize!(4)).unwrap();
        let combined_msg_parallel = aggregate::combine_messages::<V, _>(&messages, &parallel);
        aggregate::verify_same_signer::<V>(&public, &combined_msg_parallel, &aggregate_sig)
            .expect("Aggregated signature should be valid with parallelism");

        let payload_msgs: Vec<_> = messages
            .iter()
            .map(|(ns, msg)| union_unique(ns, msg))
            .collect();
        let payload_refs: Vec<&[u8]> = payload_msgs.iter().map(|p| p.as_ref()).collect();
        blst_aggregate_verify_same_signer::<V, _>(&public, payload_refs, &aggregate_sig)
            .expect("blst should also accept aggregated signature");
    }

    #[test]
    fn test_aggregate_verify_same_signer_correct() {
        aggregate_verify_same_signer_correct::<MinPk>();
        aggregate_verify_same_signer_correct::<MinSig>();
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<PublicKey<MinSig>>,
            CodecConformance<Message<MinSig>>,
            CodecConformance<Signature<MinSig>>,
        }
    }
}