1use blake3;
48use serde::{Deserialize, Serialize};
49use std::collections::HashMap;
50
51mod serde_bytes {
53 use serde::{Deserialize, Deserializer, Serializer};
54
55 pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
56 where
57 S: Serializer,
58 {
59 serializer.serialize_bytes(bytes)
60 }
61
62 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
63 where
64 D: Deserializer<'de>,
65 {
66 <Vec<u8>>::deserialize(deserializer)
67 }
68}
69
70mod serde_bytes_32 {
72 use serde::{Deserialize, Deserializer, Serializer};
73
74 pub fn serialize<S>(bytes: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error>
75 where
76 S: Serializer,
77 {
78 serializer.serialize_bytes(bytes)
79 }
80
81 pub fn deserialize<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
82 where
83 D: Deserializer<'de>,
84 {
85 let vec = <Vec<u8>>::deserialize(deserializer)?;
86 vec.try_into()
87 .map_err(|_| serde::de::Error::custom("Expected 32 bytes"))
88 }
89}
90
/// Errors reported by the accumulator types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccumulatorError {
    /// The element is not a member (returned by `HashAccumulator::prove`).
    ElementNotFound,
    /// A membership proof was rejected.
    InvalidProof,
    /// The accumulator holds no elements.
    EmptyAccumulator,
    /// Encoding/decoding failed; carries the underlying error's message.
    SerializationError(String),
    /// The element is already a member.
    ElementExists,
}

impl std::fmt::Display for AccumulatorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::ElementNotFound => "Element not found in accumulator",
            Self::InvalidProof => "Invalid membership proof",
            Self::EmptyAccumulator => "Accumulator is empty",
            Self::ElementExists => "Element already exists in accumulator",
            Self::SerializationError(detail) => {
                // Only variant with a payload; format it inline.
                return write!(f, "Serialization error: {}", detail);
            }
        };
        f.write_str(text)
    }
}

impl std::error::Error for AccumulatorError {}
119
120pub type AccumulatorResult<T> = Result<T, AccumulatorError>;
122
123pub const ACCUMULATOR_DIGEST_SIZE: usize = 32;
125
126pub fn hash_element(element: &[u8]) -> [u8; 32] {
128 blake3::hash(element).into()
129}
130
131#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
133pub struct AccumulatorDigest {
134 #[serde(with = "serde_bytes_32")]
135 digest: [u8; ACCUMULATOR_DIGEST_SIZE],
136}
137
138impl AccumulatorDigest {
139 pub fn from_bytes(bytes: [u8; ACCUMULATOR_DIGEST_SIZE]) -> Self {
141 Self { digest: bytes }
142 }
143
144 pub fn as_bytes(&self) -> &[u8; ACCUMULATOR_DIGEST_SIZE] {
146 &self.digest
147 }
148}
149
150#[derive(Clone, Debug, Serialize, Deserialize)]
152pub struct MembershipProof {
153 element_hash: [u8; 32],
154 #[serde(with = "serde_bytes")]
155 witness: Vec<u8>,
156}
157
158impl MembershipProof {
159 pub fn to_bytes(&self) -> AccumulatorResult<Vec<u8>> {
161 crate::codec::encode(self).map_err(|e| AccumulatorError::SerializationError(e.to_string()))
162 }
163
164 pub fn from_bytes(bytes: &[u8]) -> AccumulatorResult<Self> {
166 crate::codec::decode(bytes).map_err(|e| AccumulatorError::SerializationError(e.to_string()))
167 }
168}
169
/// Set accumulator that stores every member and commits to the whole set with
/// one 32-byte digest.
///
/// Members are keyed by `hash_element(member)`. The digest is all zeros for an
/// empty set. Otherwise it is the BLAKE3 hash of all member hashes concatenated
/// in ascending order.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HashAccumulator {
    // hash_element(element) -> the element's original bytes.
    elements: HashMap<[u8; 32], Vec<u8>>,
    // Cached commitment to `elements`, refreshed by the mutating methods.
    digest: [u8; 32],
}
180
181impl HashAccumulator {
182 pub fn new() -> Self {
184 Self {
185 elements: HashMap::new(),
186 digest: [0u8; 32],
187 }
188 }
189
190 pub fn from_elements(elements: &[&[u8]]) -> Self {
192 let mut acc = Self::new();
193 for elem in elements {
194 acc.add(elem);
195 }
196 acc
197 }
198
199 fn update_digest(&mut self) {
201 if self.elements.is_empty() {
202 self.digest = [0u8; 32];
203 return;
204 }
205
206 let mut sorted_hashes: Vec<_> = self.elements.keys().collect();
208 sorted_hashes.sort();
209
210 let mut hasher = blake3::Hasher::new();
211 for hash in sorted_hashes {
212 hasher.update(hash);
213 }
214 self.digest = hasher.finalize().into();
215 }
216
217 pub fn add(&mut self, element: &[u8]) -> bool {
221 let hash = hash_element(element);
222 let newly_added = self.elements.insert(hash, element.to_vec()).is_none();
223 if newly_added {
224 self.update_digest();
225 }
226 newly_added
227 }
228
229 pub fn remove(&mut self, element: &[u8]) -> bool {
233 let hash = hash_element(element);
234 let was_present = self.elements.remove(&hash).is_some();
235 if was_present {
236 self.update_digest();
237 }
238 was_present
239 }
240
241 pub fn contains(&self, element: &[u8]) -> bool {
243 let hash = hash_element(element);
244 self.elements.contains_key(&hash)
245 }
246
247 pub fn prove(&self, element: &[u8]) -> AccumulatorResult<MembershipProof> {
253 let element_hash = hash_element(element);
254
255 if !self.elements.contains_key(&element_hash) {
256 return Err(AccumulatorError::ElementNotFound);
257 }
258
259 let mut witness_hashes: Vec<_> = self
261 .elements
262 .keys()
263 .filter(|&&h| h != element_hash)
264 .collect();
265 witness_hashes.sort();
266
267 let mut witness = Vec::new();
268 for hash in witness_hashes {
269 witness.extend_from_slice(hash);
270 }
271
272 Ok(MembershipProof {
273 element_hash,
274 witness,
275 })
276 }
277
278 pub fn verify(&self, element: &[u8], proof: &MembershipProof) -> bool {
280 let element_hash = hash_element(element);
281
282 if element_hash != proof.element_hash {
283 return false;
284 }
285
286 let mut all_hashes = vec![element_hash];
288 for chunk in proof.witness.chunks(32) {
289 if chunk.len() == 32 {
290 let hash: [u8; 32] = chunk.try_into().unwrap();
291 all_hashes.push(hash);
292 }
293 }
294 all_hashes.sort();
295
296 let mut hasher = blake3::Hasher::new();
297 for hash in all_hashes {
298 hasher.update(&hash);
299 }
300 let computed_digest: [u8; 32] = hasher.finalize().into();
301
302 computed_digest == self.digest
303 }
304
305 pub fn digest(&self) -> AccumulatorDigest {
307 AccumulatorDigest::from_bytes(self.digest)
308 }
309
310 pub fn len(&self) -> usize {
312 self.elements.len()
313 }
314
315 pub fn is_empty(&self) -> bool {
317 self.elements.is_empty()
318 }
319
320 pub fn elements(&self) -> Vec<Vec<u8>> {
322 self.elements.values().cloned().collect()
323 }
324
325 pub fn add_batch(&mut self, elements: &[&[u8]]) -> usize {
327 let mut added = 0;
328 for elem in elements {
329 let hash = hash_element(elem);
330 if self.elements.insert(hash, elem.to_vec()).is_none() {
331 added += 1;
332 }
333 }
334 if added > 0 {
335 self.update_digest();
336 }
337 added
338 }
339
340 pub fn remove_batch(&mut self, elements: &[&[u8]]) -> usize {
342 let mut removed = 0;
343 for elem in elements {
344 let hash = hash_element(elem);
345 if self.elements.remove(&hash).is_some() {
346 removed += 1;
347 }
348 }
349 if removed > 0 {
350 self.update_digest();
351 }
352 removed
353 }
354
355 pub fn to_bytes(&self) -> AccumulatorResult<Vec<u8>> {
357 crate::codec::encode(self).map_err(|e| AccumulatorError::SerializationError(e.to_string()))
358 }
359
360 pub fn from_bytes(bytes: &[u8]) -> AccumulatorResult<Self> {
362 crate::codec::decode(bytes).map_err(|e| AccumulatorError::SerializationError(e.to_string()))
363 }
364}
365
366impl Default for HashAccumulator {
367 fn default() -> Self {
368 Self::new()
369 }
370}
371
/// Digest-plus-count snapshot of a `HashAccumulator` that can check membership
/// proofs without holding the members themselves.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompactAccumulator {
    // Digest copied from the source accumulator (or supplied to `new`).
    digest: [u8; 32],
    // Member count the digest commits to; `verify` requires proofs to cover exactly this many hashes.
    count: usize,
}
381
382impl CompactAccumulator {
383 pub fn from_accumulator(acc: &HashAccumulator) -> Self {
385 Self {
386 digest: acc.digest,
387 count: acc.len(),
388 }
389 }
390
391 pub fn new(digest: [u8; 32], count: usize) -> Self {
393 Self { digest, count }
394 }
395
396 pub fn digest(&self) -> AccumulatorDigest {
398 AccumulatorDigest::from_bytes(self.digest)
399 }
400
401 pub fn count(&self) -> usize {
403 self.count
404 }
405
406 pub fn verify(&self, element: &[u8], proof: &MembershipProof) -> bool {
408 let element_hash = hash_element(element);
409
410 if element_hash != proof.element_hash {
411 return false;
412 }
413
414 let mut all_hashes = vec![element_hash];
416 for chunk in proof.witness.chunks(32) {
417 if chunk.len() == 32 {
418 let hash: [u8; 32] = chunk.try_into().unwrap();
419 all_hashes.push(hash);
420 }
421 }
422
423 if all_hashes.len() != self.count {
425 return false;
426 }
427
428 all_hashes.sort();
429
430 let mut hasher = blake3::Hasher::new();
431 for hash in all_hashes {
432 hasher.update(&hash);
433 }
434 let computed_digest: [u8; 32] = hasher.finalize().into();
435
436 computed_digest == self.digest
437 }
438}
439
/// Bloom-filter membership summary.
///
/// `might_contain` returns `true` for every element passed to `add`, and may
/// also return `true` for elements that were never added.
#[derive(Clone, Debug)]
pub struct BloomAccumulator {
    // One flag per filter bit.
    bits: Vec<bool>,
    // Number of bit positions set by `add` / checked by `might_contain` per element.
    num_hashes: usize,
    // Number of `add` calls, duplicates included.
    count: usize,
}
450
451impl BloomAccumulator {
452 pub fn new(capacity: usize, false_positive_rate: f64) -> Self {
459 let bits_per_element = -1.44 * false_positive_rate.log2();
460 let num_bits = ((capacity as f64) * bits_per_element).ceil() as usize;
461 let num_hashes = (bits_per_element * std::f64::consts::LN_2).ceil() as usize;
462
463 Self {
464 bits: vec![false; num_bits],
465 num_hashes,
466 count: 0,
467 }
468 }
469
470 pub fn add(&mut self, element: &[u8]) {
472 let hash = blake3::hash(element);
473 let hash_bytes = hash.as_bytes();
474
475 for i in 0..self.num_hashes {
476 let index = self.hash_index(hash_bytes, i);
477 self.bits[index] = true;
478 }
479 self.count += 1;
480 }
481
482 pub fn might_contain(&self, element: &[u8]) -> bool {
484 let hash = blake3::hash(element);
485 let hash_bytes = hash.as_bytes();
486
487 for i in 0..self.num_hashes {
488 let index = self.hash_index(hash_bytes, i);
489 if !self.bits[index] {
490 return false;
491 }
492 }
493 true
494 }
495
496 fn hash_index(&self, hash: &[u8], i: usize) -> usize {
497 let offset = (i * 8) % hash.len();
498 let mut bytes = [0u8; 8];
499 for j in 0..8 {
500 bytes[j] = hash[(offset + j) % hash.len()];
501 }
502 let value = u64::from_le_bytes(bytes);
503 (value as usize) % self.bits.len()
504 }
505
506 pub fn count(&self) -> usize {
508 self.count
509 }
510
511 pub fn size(&self) -> usize {
513 self.bits.len()
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_accumulator_basic() {
        let mut acc = HashAccumulator::new();
        assert!(acc.is_empty());

        let elem1 = b"element1";
        let elem2 = b"element2";

        assert!(acc.add(elem1));
        assert!(acc.add(elem2));
        assert_eq!(acc.len(), 2);

        assert!(acc.contains(elem1));
        assert!(acc.contains(elem2));
        assert!(!acc.contains(b"element3"));
    }

    #[test]
    fn test_hash_accumulator_proof() {
        let mut acc = HashAccumulator::new();
        let elem1 = b"peer_id_1";
        let elem2 = b"peer_id_2";
        let elem3 = b"peer_id_3";

        acc.add(elem1);
        acc.add(elem2);
        acc.add(elem3);

        let proof1 = acc.prove(elem1).unwrap();
        assert!(acc.verify(elem1, &proof1));

        let proof2 = acc.prove(elem2).unwrap();
        assert!(acc.verify(elem2, &proof2));

        // A proof is bound to the element it was issued for.
        assert!(!acc.verify(b"wrong", &proof1));
    }

    #[test]
    fn test_hash_accumulator_remove() {
        let mut acc = HashAccumulator::new();
        let elem1 = b"element1";
        let elem2 = b"element2";

        acc.add(elem1);
        acc.add(elem2);

        assert!(acc.remove(elem1));
        assert!(!acc.contains(elem1));
        assert!(acc.contains(elem2));

        // Removing an absent element reports false.
        assert!(!acc.remove(elem1));
    }

    #[test]
    fn test_hash_accumulator_from_elements() {
        let elements: [&[u8]; 3] = [b"elem1", b"elem2", b"elem3"];
        let acc = HashAccumulator::from_elements(&elements);

        assert_eq!(acc.len(), 3);
        for elem in &elements {
            assert!(acc.contains(elem));
        }
    }

    #[test]
    fn test_hash_accumulator_batch_operations() {
        let mut acc = HashAccumulator::new();
        let elements: [&[u8]; 3] = [b"elem1", b"elem2", b"elem3"];

        let added = acc.add_batch(&elements);
        assert_eq!(added, 3);
        assert_eq!(acc.len(), 3);

        let removed = acc.remove_batch(&elements[0..2]);
        assert_eq!(removed, 2);
        assert_eq!(acc.len(), 1);
    }

    #[test]
    fn test_compact_accumulator() {
        let mut acc = HashAccumulator::new();
        acc.add(b"elem1");
        acc.add(b"elem2");
        acc.add(b"elem3");

        let proof = acc.prove(b"elem1").unwrap();

        let compact = CompactAccumulator::from_accumulator(&acc);
        assert_eq!(compact.count(), 3);
        assert!(compact.verify(b"elem1", &proof));
    }

    #[test]
    fn test_compact_rejects_wrong_count() {
        let mut acc = HashAccumulator::new();
        acc.add(b"elem1");
        acc.add(b"elem2");
        acc.add(b"elem3");

        let proof = acc.prove(b"elem1").unwrap();
        let compact = CompactAccumulator::new(*acc.digest().as_bytes(), 2);
        assert!(!compact.verify(b"elem1", &proof));
    }

    #[test]
    fn test_bloom_accumulator() {
        let mut bloom = BloomAccumulator::new(1000, 0.01);

        let elements = [b"elem1", b"elem2", b"elem3"];
        for elem in &elements {
            bloom.add(*elem);
        }
        assert_eq!(bloom.count(), 3);

        for elem in &elements {
            assert!(bloom.might_contain(*elem));
        }

        // No assertion: a Bloom filter is allowed to report false positives.
        let not_added = b"definitely_not_added_unique_12345";
        let _ = bloom.might_contain(not_added);
    }

    #[test]
    fn test_accumulator_serialization() {
        let mut acc = HashAccumulator::new();
        acc.add(b"elem1");
        acc.add(b"elem2");

        let bytes = acc.to_bytes().unwrap();
        let restored = HashAccumulator::from_bytes(&bytes).unwrap();

        assert_eq!(acc.digest(), restored.digest());
        assert_eq!(acc.len(), restored.len());
        assert!(restored.contains(b"elem1"));
        assert!(restored.contains(b"elem2"));
    }

    #[test]
    fn test_proof_serialization() {
        let mut acc = HashAccumulator::new();
        acc.add(b"elem1");
        acc.add(b"elem2");

        let proof = acc.prove(b"elem1").unwrap();
        let bytes = proof.to_bytes().unwrap();
        let restored = MembershipProof::from_bytes(&bytes).unwrap();

        assert!(acc.verify(b"elem1", &restored));
    }

    #[test]
    fn test_accumulator_digest_changes() {
        let mut acc = HashAccumulator::new();
        let digest1 = acc.digest();

        acc.add(b"elem1");
        let digest2 = acc.digest();
        assert_ne!(digest1, digest2);

        acc.add(b"elem2");
        let digest3 = acc.digest();
        assert_ne!(digest2, digest3);

        acc.remove(b"elem1");
        let digest4 = acc.digest();
        assert_ne!(digest3, digest4);
    }

    #[test]
    fn test_empty_digest_is_zero() {
        let acc = HashAccumulator::new();
        assert_eq!(acc.digest().as_bytes(), &[0u8; ACCUMULATOR_DIGEST_SIZE]);
    }

    #[test]
    fn test_digest_is_order_independent() {
        let elements: [&[u8]; 3] = [b"a", b"b", b"c"];
        let forward = HashAccumulator::from_elements(&elements);

        let mut backward = HashAccumulator::new();
        backward.add(b"c");
        backward.add(b"b");
        backward.add(b"a");

        assert_eq!(forward.digest(), backward.digest());
    }

    #[test]
    fn test_remove_restores_digest() {
        let mut acc = HashAccumulator::new();
        acc.add(b"elem1");
        let before = acc.digest();

        acc.add(b"elem2");
        acc.remove(b"elem2");
        assert_eq!(acc.digest(), before);
    }

    #[test]
    fn test_stale_proof_rejected() {
        let mut acc = HashAccumulator::new();
        acc.add(b"elem1");
        acc.add(b"elem2");

        let proof = acc.prove(b"elem1").unwrap();
        acc.add(b"elem3");
        assert!(!acc.verify(b"elem1", &proof));
    }

    #[test]
    fn test_single_element_proof() {
        let mut acc = HashAccumulator::new();
        acc.add(b"only");

        let proof = acc.prove(b"only").unwrap();
        assert!(acc.verify(b"only", &proof));
    }

    #[test]
    fn test_proof_not_found() {
        let acc = HashAccumulator::new();
        assert!(acc.prove(b"nonexistent").is_err());
    }

    #[test]
    fn test_duplicate_add() {
        let mut acc = HashAccumulator::new();
        assert!(acc.add(b"elem1"));
        assert!(!acc.add(b"elem1"));
        assert_eq!(acc.len(), 1);
    }
}