1use crate::{PublicKey, SignatureBytes, SigningError};
41use serde::{Deserialize, Serialize};
42use thiserror::Error;
43
/// A single detached signature; an alias for the crate's `SignatureBytes`.
pub type Signature = SignatureBytes;
46
/// A set of individual signatures over one message, stored alongside the
/// public keys of the signers.
///
/// The signatures are kept as a plain list and checked one by one by
/// `AggregateSignature::verify`; nothing is cryptographically combined.
///
/// Because the type derives `Deserialize`, a value can be produced without
/// going through `AggregateSignature::new`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AggregateSignature {
    /// Signer public keys; entry `i` belongs to `signatures[i]`.
    #[serde(with = "serde_pubkey_vec")]
    public_keys: Vec<PublicKey>,
    /// Signatures, in the same order as `public_keys`.
    #[serde(with = "serde_signature_vec")]
    signatures: Vec<Signature>,
    /// BLAKE3 hash of the message, as computed by `AggregateSignature::new`.
    #[serde(with = "serde_bytes")]
    message_hash: Vec<u8>,
}
60
/// Collects `(public key, signature)` pairs one at a time and turns them into
/// an `AggregateSignature` via `SignatureAggregator::finalize`.
#[derive(Default)]
pub struct SignatureAggregator {
    /// Collected pairs in insertion order.
    entries: Vec<(PublicKey, Signature)>,
}
67
/// Errors reported while building or verifying an `AggregateSignature`.
#[derive(Debug, Error)]
pub enum AggregateError {
    /// The public-key list or the signature list was empty.
    #[error("No signatures provided")]
    NoSignatures,

    /// The same public key was supplied more than once.
    #[error("Duplicate public key in aggregate")]
    DuplicatePublicKey,

    /// The key and signature counts differ, or the message does not match the
    /// hash stored in the aggregate.
    #[error("Signature verification failed")]
    VerificationFailed,

    /// Wraps a `SigningError`; produced through `From` when `crate::verify`
    /// fails for one of the pairs.
    #[error("Invalid signature: {0}")]
    InvalidSignature(#[from] SigningError),
}
87
/// Shorthand for a `Result` whose error type is `AggregateError`.
pub type AggregateResult<T> = Result<T, AggregateError>;
89
90mod serde_bytes {
92 use serde::{Deserialize, Deserializer, Serializer};
93
94 pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
95 where
96 S: Serializer,
97 {
98 serializer.serialize_bytes(bytes)
99 }
100
101 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
102 where
103 D: Deserializer<'de>,
104 {
105 <Vec<u8>>::deserialize(deserializer)
106 }
107}
108
109mod serde_pubkey_vec {
111 use serde::{Deserialize, Deserializer, Serialize, Serializer};
112
113 pub fn serialize<S>(keys: &[[u8; 32]], serializer: S) -> Result<S::Ok, S::Error>
114 where
115 S: Serializer,
116 {
117 let bytes_vec: Vec<&[u8]> = keys.iter().map(|k| k.as_slice()).collect();
118 bytes_vec.serialize(serializer)
119 }
120
121 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<[u8; 32]>, D::Error>
122 where
123 D: Deserializer<'de>,
124 {
125 let vec_of_vecs: Vec<Vec<u8>> = Vec::deserialize(deserializer)?;
126 vec_of_vecs
127 .into_iter()
128 .map(|v| {
129 if v.len() != 32 {
130 return Err(serde::de::Error::custom("Expected 32 bytes"));
131 }
132 let mut arr = [0u8; 32];
133 arr.copy_from_slice(&v);
134 Ok(arr)
135 })
136 .collect()
137 }
138}
139
140mod serde_signature_vec {
142 use serde::{Deserialize, Deserializer, Serialize, Serializer};
143
144 pub fn serialize<S>(sigs: &[[u8; 64]], serializer: S) -> Result<S::Ok, S::Error>
145 where
146 S: Serializer,
147 {
148 let bytes_vec: Vec<&[u8]> = sigs.iter().map(|s| s.as_slice()).collect();
149 bytes_vec.serialize(serializer)
150 }
151
152 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<[u8; 64]>, D::Error>
153 where
154 D: Deserializer<'de>,
155 {
156 let vec_of_vecs: Vec<Vec<u8>> = Vec::deserialize(deserializer)?;
157 vec_of_vecs
158 .into_iter()
159 .map(|v| {
160 if v.len() != 64 {
161 return Err(serde::de::Error::custom("Expected 64 bytes"));
162 }
163 let mut arr = [0u8; 64];
164 arr.copy_from_slice(&v);
165 Ok(arr)
166 })
167 .collect()
168 }
169}
170
171impl AggregateSignature {
172 pub fn new(
179 public_keys: Vec<PublicKey>,
180 signatures: Vec<Signature>,
181 message: &[u8],
182 ) -> AggregateResult<Self> {
183 if public_keys.is_empty() || signatures.is_empty() {
184 return Err(AggregateError::NoSignatures);
185 }
186
187 if public_keys.len() != signatures.len() {
188 return Err(AggregateError::VerificationFailed);
189 }
190
191 for i in 0..public_keys.len() {
193 for j in (i + 1)..public_keys.len() {
194 if public_keys[i] == public_keys[j] {
195 return Err(AggregateError::DuplicatePublicKey);
196 }
197 }
198 }
199
200 let message_hash = blake3::hash(message).as_bytes().to_vec();
201
202 Ok(Self {
203 public_keys,
204 signatures,
205 message_hash,
206 })
207 }
208
209 pub fn verify(&self, message: &[u8]) -> AggregateResult<()> {
217 let expected_hash = blake3::hash(message);
219 if expected_hash.as_bytes() != self.message_hash.as_slice() {
220 return Err(AggregateError::VerificationFailed);
221 }
222
223 for (public_key, signature) in self.public_keys.iter().zip(self.signatures.iter()) {
225 crate::verify(public_key, message, signature)?;
226 }
227
228 Ok(())
229 }
230
231 pub fn count(&self) -> usize {
233 self.signatures.len()
234 }
235
236 pub fn public_keys(&self) -> &[PublicKey] {
238 &self.public_keys
239 }
240
241 pub fn signatures(&self) -> &[Signature] {
243 &self.signatures
244 }
245
246 pub fn contains_signer(&self, public_key: &PublicKey) -> bool {
248 self.public_keys.contains(public_key)
249 }
250}
251
252impl SignatureAggregator {
253 pub fn new() -> Self {
255 Self {
256 entries: Vec::new(),
257 }
258 }
259
260 pub fn add_signature(&mut self, public_key: &PublicKey, signature: &Signature) {
266 self.entries.push((*public_key, *signature));
267 }
268
269 pub fn add_signatures(&mut self, entries: &[(PublicKey, Signature)]) {
271 self.entries.extend_from_slice(entries);
272 }
273
274 pub fn len(&self) -> usize {
276 self.entries.len()
277 }
278
279 pub fn is_empty(&self) -> bool {
281 self.entries.is_empty()
282 }
283
284 pub fn finalize(self, message: &[u8]) -> AggregateResult<AggregateSignature> {
292 if self.entries.is_empty() {
293 return Err(AggregateError::NoSignatures);
294 }
295
296 let (public_keys, signatures): (Vec<_>, Vec<_>) = self.entries.into_iter().unzip();
297
298 AggregateSignature::new(public_keys, signatures, message)
299 }
300
301 pub fn clear(&mut self) {
303 self.entries.clear();
304 }
305}
306
307pub fn verify_batch(
319 public_keys: &[PublicKey],
320 signatures: &[Signature],
321 message: &[u8],
322) -> AggregateResult<()> {
323 let aggregate = AggregateSignature::new(public_keys.to_vec(), signatures.to_vec(), message)?;
324 aggregate.verify(message)
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::KeyPair;

    /// Generates `n` fresh key pairs.
    fn keypairs(n: usize) -> Vec<KeyPair> {
        (0..n).map(|_| KeyPair::generate()).collect()
    }

    /// Three signers over the same message verify and are all counted.
    #[test]
    fn test_aggregate_signature_basic() {
        let message = b"test message";
        let signers = keypairs(3);

        let public_keys: Vec<PublicKey> = signers.iter().map(|kp| kp.public_key()).collect();
        let signatures: Vec<Signature> = signers.iter().map(|kp| kp.sign(message)).collect();

        let aggregate = AggregateSignature::new(public_keys, signatures, message).unwrap();

        assert!(aggregate.verify(message).is_ok());
        assert_eq!(aggregate.count(), 3);
    }

    /// Pairs fed through the aggregator produce a verifiable aggregate.
    #[test]
    fn test_signature_aggregator() {
        let message = b"bandwidth proof";
        let signers = keypairs(2);

        let mut aggregator = SignatureAggregator::new();
        for signer in &signers {
            let signature = signer.sign(message);
            aggregator.add_signature(&signer.public_key(), &signature);
        }

        assert_eq!(aggregator.len(), 2);

        let aggregate = aggregator.finalize(message).unwrap();
        assert!(aggregate.verify(message).is_ok());
    }

    /// Verifying against a different message than the signed one fails.
    #[test]
    fn test_wrong_message_fails() {
        let signed_message = b"message 1";
        let other_message = b"message 2";

        let signer = KeyPair::generate();
        let signature = signer.sign(signed_message);

        let aggregate =
            AggregateSignature::new(vec![signer.public_key()], vec![signature], signed_message)
                .unwrap();

        assert!(aggregate.verify(other_message).is_err());
    }

    /// Supplying the same public key twice is rejected at construction.
    #[test]
    fn test_duplicate_public_key_rejected() {
        let message = b"test";
        let signer = KeyPair::generate();

        let public_keys = vec![signer.public_key(), signer.public_key()];
        let signatures = vec![signer.sign(message), signer.sign(message)];

        let result = AggregateSignature::new(public_keys, signatures, message);

        assert!(matches!(result, Err(AggregateError::DuplicatePublicKey)));
    }

    /// Empty inputs are rejected with `NoSignatures`.
    #[test]
    fn test_empty_aggregate_rejected() {
        let result = AggregateSignature::new(vec![], vec![], b"test");
        assert!(matches!(result, Err(AggregateError::NoSignatures)));
    }

    /// Two keys with a single signature are rejected with `VerificationFailed`.
    #[test]
    fn test_mismatched_lengths_rejected() {
        let signers = keypairs(2);
        let lone_signature = signers[0].sign(b"test");

        let public_keys = vec![signers[0].public_key(), signers[1].public_key()];
        let result = AggregateSignature::new(public_keys, vec![lone_signature], b"test");

        assert!(matches!(result, Err(AggregateError::VerificationFailed)));
    }

    /// `contains_signer` reports included keys only.
    #[test]
    fn test_contains_signer() {
        let message = b"test";
        let signers = keypairs(2);
        let outsider = KeyPair::generate();

        let public_keys: Vec<PublicKey> = signers.iter().map(|kp| kp.public_key()).collect();
        let signatures: Vec<Signature> = signers.iter().map(|kp| kp.sign(message)).collect();

        let aggregate = AggregateSignature::new(public_keys, signatures, message).unwrap();

        for signer in &signers {
            assert!(aggregate.contains_signer(&signer.public_key()));
        }
        assert!(!aggregate.contains_signer(&outsider.public_key()));
    }

    /// `verify_batch` accepts matching keys and signatures.
    #[test]
    fn test_verify_batch() {
        let message = b"batch test";
        let signers = keypairs(2);

        let public_keys = [signers[0].public_key(), signers[1].public_key()];
        let signatures = [signers[0].sign(message), signers[1].sign(message)];

        let result = verify_batch(&public_keys, &signatures, message);

        assert!(result.is_ok());
    }

    /// `clear` empties the aggregator.
    #[test]
    fn test_aggregator_clear() {
        let mut aggregator = SignatureAggregator::new();

        let signer = KeyPair::generate();
        let signature = signer.sign(b"test");

        aggregator.add_signature(&signer.public_key(), &signature);
        assert_eq!(aggregator.len(), 1);

        aggregator.clear();
        assert_eq!(aggregator.len(), 0);
        assert!(aggregator.is_empty());
    }

    /// An aggregate survives an encode/decode round trip through the crate codec.
    #[test]
    fn test_serialization() {
        let message = b"serialize test";

        let signer = KeyPair::generate();
        let signature = signer.sign(message);

        let aggregate =
            AggregateSignature::new(vec![signer.public_key()], vec![signature], message).unwrap();

        let encoded = crate::codec::encode(&aggregate).unwrap();
        let decoded: AggregateSignature = crate::codec::decode(&encoded).unwrap();

        assert!(decoded.verify(message).is_ok());
        assert_eq!(aggregate.count(), decoded.count());
    }
}