1use std::{
2 alloc::{self, handle_alloc_error, Layout},
3 cmp::Ordering,
4 marker::Send,
5 mem,
6 ptr::NonNull,
7};
8
9use light_hasher::{bigint::bigint_to_be_bytes_array, HasherError};
10use num_bigint::{BigUint, ToBigUint};
11use num_traits::{FromBytes, ToPrimitive};
12use thiserror::Error;
13
14pub mod zero_copy;
15
/// Number of probing attempts made by `HashSet::insert` and
/// `HashSet::find_element_index` before they give up.
pub const ITERATIONS: usize = 20;
17
18#[derive(Debug, Error, PartialEq)]
19pub enum HashSetError {
20 #[error("The hash set is full, cannot add any new elements")]
21 Full,
22 #[error("The provided element is already in the hash set")]
23 ElementAlreadyExists,
24 #[error("The provided element doesn't exist in the hash set")]
25 ElementDoesNotExist,
26 #[error("Could not convert the index from/to usize")]
27 UsizeConv,
28 #[error("Integer overflow")]
29 IntegerOverflow,
30 #[error("Invalid buffer size, expected {0}, got {1}")]
31 BufferSize(usize, usize),
32 #[error("HasherError: big integer conversion error")]
33 Hasher(#[from] HasherError),
34}
35
36impl From<HashSetError> for u32 {
37 fn from(e: HashSetError) -> u32 {
38 match e {
39 HashSetError::Full => 9001,
40 HashSetError::ElementAlreadyExists => 9002,
41 HashSetError::ElementDoesNotExist => 9003,
42 HashSetError::UsizeConv => 9004,
43 HashSetError::IntegerOverflow => 9005,
44 HashSetError::BufferSize(_, _) => 9006,
45 HashSetError::Hasher(e) => e.into(),
46 }
47 }
48}
49
#[cfg(feature = "solana")]
impl From<HashSetError> for solana_program_error::ProgramError {
    /// Wraps the numeric code of the error in `ProgramError::Custom`.
    fn from(error: HashSetError) -> Self {
        let code: u32 = error.into();
        solana_program_error::ProgramError::Custom(code)
    }
}
56
/// A single occupied slot of the hash set.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct HashSetCell {
    /// The stored value as 32 big-endian bytes.
    pub value: [u8; 32],
    /// `None` while the cell is unmarked; `Some(n)` once the cell has been
    /// marked, after which it is valid only for sequence numbers below `n`.
    pub sequence_number: Option<usize>,
}
62
63unsafe impl Send for HashSet {}
64
65impl HashSetCell {
66 pub fn value_bytes(&self) -> [u8; 32] {
68 self.value
69 }
70
71 pub fn value_biguint(&self) -> BigUint {
73 BigUint::from_bytes_be(self.value.as_slice())
74 }
75
76 pub fn sequence_number(&self) -> Option<usize> {
78 self.sequence_number
79 }
80
81 pub fn is_marked(&self) -> bool {
83 self.sequence_number.is_some()
84 }
85
86 pub fn is_valid(&self, current_sequence_number: usize) -> bool {
98 match self.sequence_number {
99 Some(sequence_number) => match sequence_number.cmp(¤t_sequence_number) {
100 Ordering::Less | Ordering::Equal => false,
101 Ordering::Greater => true,
102 },
103 None => true,
104 }
105 }
106}
107
108#[derive(Debug)]
109pub struct HashSet {
110 capacity: usize,
112 pub sequence_threshold: usize,
116
117 buckets: NonNull<Option<HashSetCell>>,
120}
121
122unsafe impl Send for HashSetCell {}
123
124impl HashSet {
/// Size in bytes of the two fixed-size `usize` fields (`capacity` and
/// `sequence_threshold`) that precede the bucket data in a serialized set.
pub fn non_dyn_fields_size() -> usize {
    let capacity_size = mem::size_of::<usize>();
    let sequence_threshold_size = mem::size_of::<usize>();
    capacity_size + sequence_threshold_size
}
132
133 pub fn size_in_account(capacity_values: usize) -> usize {
135 let dyn_fields_size = Self::non_dyn_fields_size();
136
137 let buckets_size_unaligned = mem::size_of::<Option<HashSetCell>>() * capacity_values;
138 let buckets_size = buckets_size_unaligned + mem::align_of::<usize>()
140 - (buckets_size_unaligned % mem::align_of::<usize>());
141
142 dyn_fields_size + buckets_size
143 }
144
145 pub fn new(capacity_values: usize, sequence_threshold: usize) -> Result<Self, HashSetError> {
147 let layout = Layout::array::<Option<HashSetCell>>(capacity_values).unwrap();
149 let values_ptr = unsafe { alloc::alloc(layout) as *mut Option<HashSetCell> };
150 if values_ptr.is_null() {
151 handle_alloc_error(layout);
152 }
153 let values = NonNull::new(values_ptr).unwrap();
154 for i in 0..capacity_values {
155 unsafe {
156 std::ptr::write(values_ptr.add(i), None);
157 }
158 }
159
160 Ok(HashSet {
161 sequence_threshold,
162 capacity: capacity_values,
163 buckets: values,
164 })
165 }
166
167 pub unsafe fn from_bytes_copy(bytes: &mut [u8]) -> Result<Self, HashSetError> {
181 if bytes.len() < Self::non_dyn_fields_size() {
182 return Err(HashSetError::BufferSize(
183 Self::non_dyn_fields_size(),
184 bytes.len(),
185 ));
186 }
187
188 let capacity = usize::from_le_bytes(bytes[0..8].try_into().unwrap());
189 let sequence_threshold = usize::from_le_bytes(bytes[8..16].try_into().unwrap());
190 let expected_size = Self::size_in_account(capacity);
191 if bytes.len() != expected_size {
192 return Err(HashSetError::BufferSize(expected_size, bytes.len()));
193 }
194
195 let buckets_layout = Layout::array::<Option<HashSetCell>>(capacity).unwrap();
196 let buckets_dst_ptr = unsafe { alloc::alloc(buckets_layout) as *mut Option<HashSetCell> };
199 if buckets_dst_ptr.is_null() {
200 handle_alloc_error(buckets_layout);
201 }
202 let buckets = NonNull::new(buckets_dst_ptr).unwrap();
203 for i in 0..capacity {
204 std::ptr::write(buckets_dst_ptr.add(i), None);
205 }
206
207 let offset = Self::non_dyn_fields_size() + mem::size_of::<usize>();
208 let buckets_src_ptr = bytes.as_ptr().add(offset) as *const Option<HashSetCell>;
209 std::ptr::copy(buckets_src_ptr, buckets_dst_ptr, capacity);
210
211 Ok(Self {
212 capacity,
213 sequence_threshold,
214 buckets,
215 })
216 }
217
218 fn probe_index(&self, value: &BigUint, iteration: usize) -> usize {
219 let iteration = iteration + self.capacity / 10;
221 let probe_index = (value
222 + iteration.to_biguint().unwrap() * iteration.to_biguint().unwrap())
223 % self.capacity.to_biguint().unwrap();
224 probe_index.to_usize().unwrap()
225 }
226
227 pub fn get_bucket(&self, index: usize) -> Option<&Option<HashSetCell>> {
230 if index >= self.capacity {
231 return None;
232 }
233 let bucket = unsafe { &*self.buckets.as_ptr().add(index) };
234 Some(bucket)
235 }
236
237 pub fn get_bucket_mut(&mut self, index: usize) -> Option<&mut Option<HashSetCell>> {
240 if index >= self.capacity {
241 return None;
242 }
243 let bucket = unsafe { &mut *self.buckets.as_ptr().add(index) };
244 Some(bucket)
245 }
246
247 pub fn get_unmarked_bucket(&self, index: usize) -> Option<&Option<HashSetCell>> {
250 let bucket = self.get_bucket(index);
251 let is_unmarked = match bucket {
252 Some(Some(bucket)) => !bucket.is_marked(),
253 Some(None) => false,
254 None => false,
255 };
256 if is_unmarked {
257 bucket
258 } else {
259 None
260 }
261 }
262
263 pub fn get_capacity(&self) -> usize {
264 self.capacity
265 }
266
267 fn insert_into_occupied_cell(
268 &mut self,
269 value_index: usize,
270 value: &BigUint,
271 current_sequence_number: usize,
272 ) -> Result<bool, HashSetError> {
273 let bucket = self.get_bucket_mut(value_index).unwrap();
275
276 match bucket {
277 Some(bucket) => {
279 if let Some(element_sequence_number) = bucket.sequence_number {
284 if current_sequence_number >= element_sequence_number {
285 *bucket = HashSetCell {
286 value: bigint_to_be_bytes_array(value)?,
287 sequence_number: None,
288 };
289 return Ok(true);
290 }
291 }
292 if &BigUint::from_be_bytes(bucket.value.as_slice()) == value {
295 return Err(HashSetError::ElementAlreadyExists);
296 }
297 }
298 None => unreachable!(),
302 }
303 Ok(false)
304 }
305
306 pub fn insert(
314 &mut self,
315 value: &BigUint,
316 current_sequence_number: usize,
317 ) -> Result<usize, HashSetError> {
318 let index_bucket = self.find_element_iter(value, current_sequence_number, 0, ITERATIONS)?;
319 let (index, is_new) = match index_bucket {
320 Some(index) => index,
321 None => {
322 return Err(HashSetError::Full);
323 }
324 };
325
326 match is_new {
327 false => {
329 if self.insert_into_occupied_cell(index, value, current_sequence_number)? {
330 return Ok(index);
331 }
332 }
333 true => {
334 let bucket = self.get_bucket_mut(index).unwrap();
336
337 *bucket = Some(HashSetCell {
338 value: bigint_to_be_bytes_array(value)?,
339 sequence_number: None,
340 });
341 return Ok(index);
342 }
343 }
344 Err(HashSetError::Full)
345 }
346
347 pub fn find_element_index(
352 &self,
353 value: &BigUint,
354 current_sequence_number: Option<usize>,
355 ) -> Result<Option<usize>, HashSetError> {
356 for i in 0..ITERATIONS {
357 let probe_index = self.probe_index(value, i);
358 let bucket = self.get_bucket(probe_index).unwrap();
360 match bucket {
361 Some(bucket) => {
362 if &bucket.value_biguint() == value {
363 match current_sequence_number {
364 Some(current_sequence_number) => {
367 if bucket.is_valid(current_sequence_number) {
368 return Ok(Some(probe_index));
369 }
370 continue;
371 }
372 None => return Ok(Some(probe_index)),
373 }
374 }
375 continue;
376 }
377 None => {
380 return Ok(None);
381 }
382 }
383 }
384
385 Ok(None)
386 }
387
388 pub fn find_element(
389 &self,
390 value: &BigUint,
391 current_sequence_number: Option<usize>,
392 ) -> Result<Option<(&HashSetCell, usize)>, HashSetError> {
393 let index = self.find_element_index(value, current_sequence_number)?;
394 match index {
395 Some(index) => {
396 let bucket = self.get_bucket(index).unwrap();
397 match bucket {
398 Some(bucket) => Ok(Some((bucket, index))),
399 None => Ok(None),
400 }
401 }
402 None => Ok(None),
403 }
404 }
405
406 pub fn find_element_mut(
407 &mut self,
408 value: &BigUint,
409 current_sequence_number: Option<usize>,
410 ) -> Result<Option<(&mut HashSetCell, usize)>, HashSetError> {
411 let index = self.find_element_index(value, current_sequence_number)?;
412 match index {
413 Some(index) => {
414 let bucket = self.get_bucket_mut(index).unwrap();
415 match bucket {
416 Some(bucket) => Ok(Some((bucket, index))),
417 None => Ok(None),
418 }
419 }
420 None => Ok(None),
421 }
422 }
423
424 pub fn find_element_iter(
430 &mut self,
431 value: &BigUint,
432 current_sequence_number: usize,
433 start_iter: usize,
434 num_iterations: usize,
435 ) -> Result<Option<(usize, bool)>, HashSetError> {
436 let mut first_free_element: Option<(usize, bool)> = None;
437 for i in start_iter..start_iter + num_iterations {
438 let probe_index = self.probe_index(value, i);
439 let bucket = self.get_bucket(probe_index).unwrap();
440
441 match bucket {
442 Some(bucket) => {
443 let is_valid = bucket.is_valid(current_sequence_number);
444 if first_free_element.is_none() && !is_valid {
445 first_free_element = Some((probe_index, false));
446 }
447 if is_valid && &bucket.value_biguint() == value {
448 return Err(HashSetError::ElementAlreadyExists);
449 } else {
450 continue;
451 }
452 }
453 None => {
454 if first_free_element.is_none() {
457 first_free_element = Some((probe_index, true));
458 }
459 break;
463 }
464 }
465 }
466 Ok(first_free_element)
467 }
468
469 pub fn first(
471 &self,
472 current_sequence_number: usize,
473 ) -> Result<Option<&HashSetCell>, HashSetError> {
474 for i in 0..self.capacity {
475 let bucket = self.get_bucket(i).unwrap();
477 if let Some(bucket) = bucket {
478 if bucket.is_valid(current_sequence_number) {
479 return Ok(Some(bucket));
480 }
481 }
482 }
483
484 Ok(None)
485 }
486
487 pub fn first_no_seq(&self) -> Result<Option<(HashSetCell, u16)>, HashSetError> {
489 for i in 0..self.capacity {
490 let bucket = self.get_bucket(i).unwrap();
492
493 if let Some(bucket) = bucket {
494 if bucket.sequence_number.is_none() {
495 return Ok(Some((*bucket, i as u16)));
496 }
497 }
498 }
499
500 Ok(None)
501 }
502
503 pub fn contains(
505 &self,
506 value: &BigUint,
507 sequence_number: Option<usize>,
508 ) -> Result<bool, HashSetError> {
509 let element = self.find_element(value, sequence_number)?;
510 Ok(element.is_some())
511 }
512
513 pub fn mark_with_sequence_number(
515 &mut self,
516 index: usize,
517 sequence_number: usize,
518 ) -> Result<(), HashSetError> {
519 let sequence_threshold = self.sequence_threshold;
520 let element = self
521 .get_bucket_mut(index)
522 .ok_or(HashSetError::ElementDoesNotExist)?;
523
524 match element {
525 Some(element) => {
526 element.sequence_number = Some(sequence_number + sequence_threshold);
527 Ok(())
528 }
529 None => Err(HashSetError::ElementDoesNotExist),
530 }
531 }
532
533 pub fn iter(&self) -> HashSetIterator {
535 HashSetIterator {
536 hash_set: self,
537 current: 0,
538 }
539 }
540}
541
542impl Drop for HashSet {
543 fn drop(&mut self) {
544 unsafe {
547 let layout = Layout::array::<Option<HashSetCell>>(self.capacity).unwrap();
548 alloc::dealloc(self.buckets.as_ptr() as *mut u8, layout);
549 }
550 }
551}
552
553impl PartialEq for HashSet {
554 fn eq(&self, other: &Self) -> bool {
555 self.capacity.eq(&other.capacity)
556 && self.sequence_threshold.eq(&other.sequence_threshold)
557 && self.iter().eq(other.iter())
558 }
559}
560
561pub struct HashSetIterator<'a> {
562 hash_set: &'a HashSet,
563 current: usize,
564}
565
566impl<'a> Iterator for HashSetIterator<'a> {
567 type Item = (usize, &'a HashSetCell);
568
569 fn next(&mut self) -> Option<Self::Item> {
570 while self.current < self.hash_set.get_capacity() {
571 let element_index = self.current;
572 self.current += 1;
573
574 if let Some(Some(cur_element)) = self.hash_set.get_bucket(element_index) {
575 return Some((element_index, cur_element));
576 }
577 }
578 None
579 }
580}
581
582#[cfg(test)]
583mod test {
584 use ark_bn254::Fr;
585 use ark_ff::UniformRand;
586 use rand::{thread_rng, Rng};
587
588 use super::*;
589 use crate::zero_copy::HashSetZeroCopy;
590
/// Checks `HashSetCell::is_valid` for unmarked and marked cells.
#[test]
fn test_is_valid() {
    let mut rng = thread_rng();

    // An unmarked cell is valid for every sequence number.
    let unmarked = HashSetCell {
        value: [0u8; 32],
        sequence_number: None,
    };
    assert!(unmarked.is_valid(0));
    for _ in 0..100 {
        let seq: usize = rng.gen();
        assert!(unmarked.is_valid(seq));
    }

    // A marked cell is valid only below its sequence number.
    let marked = HashSetCell {
        value: [0u8; 32],
        sequence_number: Some(2400),
    };
    assert!((0..2400).all(|seq| marked.is_valid(seq)));
    assert!((2400..10000).all(|seq| !marked.is_valid(seq)));
}
618
/// Walks a small hash set (capacity 256, threshold 4) through several
/// batches of insertions and markings, checking when re-insertion of a
/// marked value is rejected and when it becomes possible again.
#[test]
fn test_hash_set_manual() {
    let mut hs = HashSet::new(256, 4).unwrap();
    let big = |n: i32| n.to_biguint().unwrap();

    // Batch 1: a single element, inserted at 0 and marked at 1.
    let element_1 = big(1);
    let index_1 = hs.insert(&element_1, 0).unwrap();
    hs.mark_with_sequence_number(index_1, 1).unwrap();

    assert!(hs.contains(&element_1, Some(1)).unwrap());
    assert!(matches!(
        hs.insert(&element_1, 1),
        Err(HashSetError::ElementAlreadyExists)
    ));

    // Batch 2: inserted at 1, marked at 2.
    let batch_2 = [big(3), big(6), big(8), big(9)];
    let indices_2: Vec<usize> = batch_2
        .iter()
        .map(|element| hs.insert(element, 1).unwrap())
        .collect();
    for element in &batch_2 {
        assert!(hs.contains(element, Some(2)).unwrap());
    }
    for &index in &indices_2 {
        hs.mark_with_sequence_number(index, 2).unwrap();
    }
    for element in &batch_2 {
        assert!(matches!(
            hs.insert(element, 2),
            Err(HashSetError::ElementAlreadyExists)
        ));
    }

    // Batch 3: inserted at 2, marked at 3.
    let batch_3 = [big(11), big(13), big(21), big(29)];
    let indices_3: Vec<usize> = batch_3
        .iter()
        .map(|element| hs.insert(element, 2).unwrap())
        .collect();
    for element in &batch_3 {
        assert!(hs.contains(element, Some(3)).unwrap());
    }
    for &index in &indices_3 {
        hs.mark_with_sequence_number(index, 3).unwrap();
    }
    for element in &batch_3 {
        assert!(matches!(
            hs.insert(element, 3),
            Err(HashSetError::ElementAlreadyExists)
        ));
    }

    // Batch 4: inserted at 3, marked at 4.
    let batch_4 = [big(93), big(64), big(72), big(15)];
    let indices_4: Vec<usize> = batch_4
        .iter()
        .map(|element| hs.insert(element, 3).unwrap())
        .collect();
    for element in &batch_4 {
        assert!(hs.contains(element, Some(4)).unwrap());
    }
    for &index in &indices_4 {
        hs.mark_with_sequence_number(index, 4).unwrap();
    }

    // One step before their marked sequence number the old elements still
    // block re-insertion.
    assert!(matches!(
        hs.insert(&element_1, 4),
        Err(HashSetError::ElementAlreadyExists)
    ));
    for element in &batch_2 {
        assert!(matches!(
            hs.insert(element, 5),
            Err(HashSetError::ElementAlreadyExists)
        ));
    }

    // Once the marked sequence number is reached they can be inserted again.
    hs.insert(&element_1, 5).unwrap();
    for element in &batch_2 {
        hs.insert(element, 6).unwrap();
    }
}
758
/// Inserts and marks 10000 random values in batches of 2400, checking the
/// cell contents before and after marking.
#[test]
fn test_hash_set_random() {
    let mut hs = HashSet::new(6857, 2400).unwrap();

    // A fresh set has no first element.
    assert_eq!(hs.first(0).unwrap(), None);

    let mut rng = thread_rng();
    let nullifiers: [BigUint; 10000] =
        std::array::from_fn(|_| BigUint::from(Fr::rand(&mut rng)));

    let mut seq = 0;
    for batch in nullifiers.chunks(2400) {
        for nullifier in batch {
            assert!(!hs.contains(nullifier, Some(seq)).unwrap());
            let index = hs.insert(nullifier, seq).unwrap();
            assert!(hs.contains(nullifier, Some(seq)).unwrap());

            let nullifier_bytes: [u8; 32] = bigint_to_be_bytes_array(nullifier).unwrap();

            // Freshly inserted: unmarked and valid.
            let unmarked = *hs.find_element(nullifier, Some(seq)).unwrap().unwrap().0;
            assert_eq!(
                unmarked,
                HashSetCell {
                    value: nullifier_bytes,
                    sequence_number: None,
                }
            );
            assert_eq!(unmarked.value_bytes(), nullifier_bytes);
            assert_eq!(&unmarked.value_biguint(), nullifier);
            assert_eq!(unmarked.sequence_number(), None);
            assert!(!unmarked.is_marked());
            assert!(unmarked.is_valid(seq));

            // After marking: carries `seq + threshold` and is still valid.
            hs.mark_with_sequence_number(index, seq).unwrap();
            let marked = *hs.find_element(nullifier, Some(seq)).unwrap().unwrap().0;

            assert_eq!(
                marked,
                HashSetCell {
                    value: nullifier_bytes,
                    sequence_number: Some(2400 + seq)
                }
            );
            assert_eq!(marked.value_bytes(), nullifier_bytes);
            assert_eq!(&marked.value_biguint(), nullifier);
            assert_eq!(marked.sequence_number(), Some(2400 + seq));
            assert!(marked.is_marked());
            assert!(marked.is_valid(seq));

            // Just below the marked sequence number the value is still taken.
            assert!(matches!(
                hs.insert(nullifier, seq + 2399),
                Err(HashSetError::ElementAlreadyExists),
            ));
            seq += 1;
        }
        seq += 2400;
    }
}
819
820 fn hash_set_from_bytes_copy<
821 const CAPACITY: usize,
822 const SEQUENCE_THRESHOLD: usize,
823 const OPERATIONS: usize,
824 >() {
825 let mut hs_1 = HashSet::new(CAPACITY, SEQUENCE_THRESHOLD).unwrap();
826
827 let mut rng = thread_rng();
828
829 let mut bytes = vec![0u8; HashSet::size_in_account(CAPACITY)];
831 rng.fill(bytes.as_mut_slice());
832
833 {
835 let mut hs_2 = unsafe {
836 HashSetZeroCopy::from_bytes_zero_copy_init(&mut bytes, CAPACITY, SEQUENCE_THRESHOLD)
837 .unwrap()
838 };
839
840 for seq in 0..OPERATIONS {
841 let value = BigUint::from(Fr::rand(&mut rng));
842 hs_1.insert(&value, seq).unwrap();
843 hs_2.insert(&value, seq).unwrap();
844 }
845
846 assert_eq!(hs_1, *hs_2);
847 }
848
849 {
851 let hs_2 = unsafe { HashSet::from_bytes_copy(&mut bytes).unwrap() };
852 assert_eq!(hs_1, hs_2);
853 }
854 }
855
856 #[test]
857 fn test_hash_set_from_bytes_copy_6857_2400_3600() {
858 hash_set_from_bytes_copy::<6857, 2400, 3600>()
859 }
860
861 #[test]
862 fn test_hash_set_from_bytes_copy_9601_2400_5000() {
863 hash_set_from_bytes_copy::<9601, 2400, 5000>()
864 }
865
866 fn hash_set_full<const CAPACITY: usize, const SEQUENCE_THRESHOLD: usize>() {
867 for _ in 0..100 {
868 let mut hs = HashSet::new(CAPACITY, SEQUENCE_THRESHOLD).unwrap();
869
870 let mut rng = rand::thread_rng();
871
872 for i in 0..CAPACITY {
875 let value = BigUint::from(Fr::rand(&mut rng));
876 match hs.insert(&value, 0) {
877 Ok(index) => hs.mark_with_sequence_number(index, 0).unwrap(),
878 Err(e) => {
879 assert!(matches!(e, HashSetError::Full));
880 println!("initial insertions: {i}: failed, stopping");
881 break;
882 }
883 }
884 }
885
886 for i in 0..1000 {
890 let value = BigUint::from(Fr::rand(&mut rng));
891 let res = hs.insert(&value, 0);
892 if res.is_err() {
893 assert!(matches!(res, Err(HashSetError::Full)));
894 } else {
895 println!("secondary insertions: {i}: apparent success with value: {value:?}");
896 }
897 }
898
899 for i in 0..1000 {
902 let value = BigUint::from(Fr::rand(&mut rng));
903 let sequence_number = rng.gen_range(0..hs.sequence_threshold);
906 let res = hs.insert(&value, sequence_number);
907 if res.is_err() {
908 assert!(matches!(res, Err(HashSetError::Full)));
909 } else {
910 println!("tertiary insertions: {i}: surprising success with value: {value:?}");
911 }
912 }
913
914 for i in 0..CAPACITY {
917 let value = BigUint::from(Fr::rand(&mut rng));
918 if let Err(e) = hs.insert(&value, SEQUENCE_THRESHOLD + i) {
919 assert!(matches!(e, HashSetError::Full));
920 println!("insertions after fillup: {i}: failed, stopping");
921 break;
922 }
923 }
924 }
925 }
926
927 #[test]
928 fn test_hash_set_full_6857_2400() {
929 hash_set_full::<6857, 2400>()
930 }
931
932 #[test]
933 fn test_hash_set_full_9601_2400() {
934 hash_set_full::<9601, 2400>()
935 }
936
/// Marking a bucket of an empty set must fail with `ElementDoesNotExist`;
/// marking freshly inserted elements must succeed.
#[test]
fn test_hash_set_element_does_not_exist() {
    let mut hs = HashSet::new(4800, 2400).unwrap();
    let mut rng = thread_rng();

    for _ in 0..1000 {
        let index = rng.gen_range(0..4800);
        assert!(matches!(
            hs.mark_with_sequence_number(index, 0),
            Err(HashSetError::ElementDoesNotExist)
        ));
    }

    for _ in 0..1000 {
        let value = BigUint::from(Fr::rand(&mut rng));
        let index = hs.insert(&value, 0).unwrap();
        hs.mark_with_sequence_number(index, 1).unwrap();
    }
}
959
/// Inserts ten fixed values and checks the order in which the iterator
/// yields them.
#[test]
fn test_hash_set_iter_manual() {
    let mut hs = HashSet::new(6857, 2400).unwrap();

    // Position `k` holds the value called `nullifier_{k + 1}`.
    let nullifiers = [
        945635_u32.to_biguint().unwrap(),
        3546656654734254353455_u128.to_biguint().unwrap(),
        543543656564_u64.to_biguint().unwrap(),
        43_u8.to_biguint().unwrap(),
        0_u8.to_biguint().unwrap(),
        65423_u32.to_biguint().unwrap(),
        745654665_u32.to_biguint().unwrap(),
        97664353453465354645645465_u128.to_biguint().unwrap(),
        453565465464565635475_u128.to_biguint().unwrap(),
        543645654645_u64.to_biguint().unwrap(),
    ];

    for nullifier in &nullifiers {
        hs.insert(nullifier, 0).unwrap();
    }

    let inserted_nullifiers = hs
        .iter()
        .map(|(_, cell)| cell.value_biguint())
        .collect::<Vec<_>>();
    assert_eq!(inserted_nullifiers.len(), 10);

    // 1-based numbers of the nullifiers in expected iteration order.
    let expected_order: [usize; 10] = [7, 3, 10, 1, 8, 5, 4, 2, 9, 6];
    for (position, number) in expected_order.iter().enumerate() {
        assert_eq!(inserted_nullifiers[position], nullifiers[number - 1]);
    }
}
1002
1003 fn hash_set_iter_random<
1004 const INSERTIONS: usize,
1005 const CAPACITY: usize,
1006 const SEQUENCE_THRESHOLD: usize,
1007 >() {
1008 let mut hs = HashSet::new(CAPACITY, SEQUENCE_THRESHOLD).unwrap();
1009 let mut rng = thread_rng();
1010
1011 let nullifiers: [BigUint; INSERTIONS] =
1012 std::array::from_fn(|_| BigUint::from(Fr::rand(&mut rng)));
1013
1014 for nullifier in nullifiers.iter() {
1015 hs.insert(nullifier, 0).unwrap();
1016 }
1017
1018 let mut sorted_nullifiers = nullifiers.iter().collect::<Vec<_>>();
1019 let mut inserted_nullifiers = hs
1020 .iter()
1021 .map(|(_, nullifier)| nullifier.value_biguint())
1022 .collect::<Vec<_>>();
1023 sorted_nullifiers.sort();
1024 inserted_nullifiers.sort();
1025
1026 let inserted_nullifiers = inserted_nullifiers.iter().collect::<Vec<&BigUint>>();
1027 assert_eq!(inserted_nullifiers.len(), INSERTIONS);
1028 assert_eq!(sorted_nullifiers.as_slice(), inserted_nullifiers.as_slice());
1029 }
1030
1031 #[test]
1032 fn test_hash_set_iter_random_6857_2400() {
1033 hash_set_iter_random::<3500, 6857, 2400>()
1034 }
1035
1036 #[test]
1037 fn test_hash_set_iter_random_9601_2400() {
1038 hash_set_iter_random::<5000, 9601, 2400>()
1039 }
1040
/// Checks `get_bucket` for occupied, empty and out-of-range indices.
#[test]
fn test_hash_set_get_bucket() {
    let mut hs = HashSet::new(6857, 2400).unwrap();

    for i in 0..3600 {
        let bn_i = i.to_biguint().unwrap();
        hs.insert(&bn_i, i).unwrap();
    }

    // Every inserted value must be retrievable from the bucket it is in.
    let mut unused_indices = vec![true; 6857];
    for i in 0..3600 {
        let bn_i = i.to_biguint().unwrap();
        let bucket_index = hs.find_element_index(&bn_i, None).unwrap().unwrap();
        let element = hs.get_bucket(bucket_index).unwrap().unwrap();
        assert_eq!(element.value_biguint(), bn_i);
        unused_indices[bucket_index] = false;
    }

    // Buckets that hold none of the inserted values must be empty.
    for (bucket_index, _) in unused_indices
        .iter()
        .enumerate()
        .filter(|(_, is_unused)| **is_unused)
    {
        assert!(hs.get_bucket(bucket_index).unwrap().is_none());
    }

    // Indices beyond the capacity yield no bucket at all.
    for out_of_range in 6857..10_000 {
        assert!(hs.get_bucket(out_of_range).is_none());
    }
}
1068
/// Checks `get_bucket_mut` for occupied, empty and out-of-range indices,
/// including writing through the returned reference.
#[test]
fn test_hash_set_get_bucket_mut() {
    let mut hs = HashSet::new(6857, 2400).unwrap();

    for i in 0..3600 {
        let bn_i = i.to_biguint().unwrap();
        hs.insert(&bn_i, i).unwrap();
    }

    // Overwrite every inserted element with an all-zero cell.
    let mut used_indices = vec![false; 6857];
    for i in 0..3600 {
        let bn_i = i.to_biguint().unwrap();
        let bucket_index = hs.find_element_index(&bn_i, None).unwrap().unwrap();

        let element = hs.get_bucket_mut(bucket_index).unwrap();
        assert_eq!(element.unwrap().value_biguint(), bn_i);
        used_indices[bucket_index] = true;

        *element = Some(HashSetCell {
            value: [0_u8; 32],
            sequence_number: None,
        });
    }

    // Used buckets now hold the zero value, all other buckets are empty.
    for (bucket_index, is_used) in used_indices.iter().enumerate() {
        if *is_used {
            let element = hs.get_bucket_mut(bucket_index).unwrap().unwrap();
            assert_eq!(element.value_bytes(), [0_u8; 32]);
        } else {
            assert!(hs.get_bucket_mut(bucket_index).unwrap().is_none());
        }
    }

    // Indices beyond the capacity yield no bucket at all.
    for out_of_range in 6857..10_000 {
        assert!(hs.get_bucket_mut(out_of_range).is_none());
    }
}
1111
/// `get_unmarked_bucket` must return inserted elements until they are
/// marked, and nothing afterwards.
#[test]
fn test_hash_set_get_unmarked_bucket() {
    /// Bucket index of the element with value `i`.
    fn bucket_index_of(hs: &HashSet, i: usize) -> usize {
        hs.find_element_index(&i.to_biguint().unwrap(), None)
            .unwrap()
            .unwrap()
    }

    let mut hs = HashSet::new(6857, 2400).unwrap();

    for i in 0..3600 {
        let bn_i = i.to_biguint().unwrap();
        hs.insert(&bn_i, i).unwrap();
    }

    // Unmarked elements are returned.
    for i in 0..3600 {
        let bucket_index = bucket_index_of(&hs, i);
        assert!(hs.get_unmarked_bucket(bucket_index).is_some());
    }

    for i in 0..3600 {
        let bucket_index = bucket_index_of(&hs, i);
        hs.mark_with_sequence_number(bucket_index, i).unwrap();
    }

    // Marked elements are not returned any more.
    for i in 0..3600 {
        let bucket_index = bucket_index_of(&hs, i);
        assert!(hs.get_unmarked_bucket(bucket_index).is_none());
    }
}
1150
/// While no element is marked, `first_no_seq` must keep returning the
/// element with value 0, no matter how many further elements are inserted.
#[test]
fn test_hash_set_first_no_seq() {
    let mut hs = HashSet::new(6857, 2400).unwrap();
    let zero = 0.to_biguint().unwrap();

    for i in 0..3600 {
        hs.insert(&i.to_biguint().unwrap(), i).unwrap();

        let (cell, _index) = hs.first_no_seq().unwrap().unwrap();
        assert_eq!(cell.value_biguint(), zero);
    }
}
1165}