light_hash_set/
lib.rs

1use std::{
2    alloc::{self, handle_alloc_error, Layout},
3    cmp::Ordering,
4    marker::Send,
5    mem,
6    ptr::NonNull,
7};
8
9use light_hasher::{bigint::bigint_to_be_bytes_array, HasherError};
10use num_bigint::{BigUint, ToBigUint};
11use num_traits::{FromBytes, ToPrimitive};
12use thiserror::Error;
13
14pub mod zero_copy;
15
16pub const ITERATIONS: usize = 20;
17
18#[derive(Debug, Error, PartialEq)]
19pub enum HashSetError {
20    #[error("The hash set is full, cannot add any new elements")]
21    Full,
22    #[error("The provided element is already in the hash set")]
23    ElementAlreadyExists,
24    #[error("The provided element doesn't exist in the hash set")]
25    ElementDoesNotExist,
26    #[error("Could not convert the index from/to usize")]
27    UsizeConv,
28    #[error("Integer overflow")]
29    IntegerOverflow,
30    #[error("Invalid buffer size, expected {0}, got {1}")]
31    BufferSize(usize, usize),
32    #[error("HasherError: big integer conversion error")]
33    Hasher(#[from] HasherError),
34}
35
36impl From<HashSetError> for u32 {
37    fn from(e: HashSetError) -> u32 {
38        match e {
39            HashSetError::Full => 9001,
40            HashSetError::ElementAlreadyExists => 9002,
41            HashSetError::ElementDoesNotExist => 9003,
42            HashSetError::UsizeConv => 9004,
43            HashSetError::IntegerOverflow => 9005,
44            HashSetError::BufferSize(_, _) => 9006,
45            HashSetError::Hasher(e) => e.into(),
46        }
47    }
48}
49
/// Wraps the numeric code of a [`HashSetError`] in `ProgramError::Custom`.
#[cfg(feature = "solana")]
impl From<HashSetError> for solana_program_error::ProgramError {
    fn from(error: HashSetError) -> Self {
        let code: u32 = error.into();
        Self::Custom(code)
    }
}
56
/// A single occupied entry of a [`HashSet`].
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct HashSetCell {
    /// The stored value as 32 bytes; `value_biguint` reads them as a
    /// big-endian integer.
    pub value: [u8; 32],
    /// Sequence number attached to the value, `None` while the cell is
    /// unmarked. `is_valid` treats the cell as valid only while this number
    /// is greater than the sequence number it is compared against.
    pub sequence_number: Option<usize>,
}
62
63unsafe impl Send for HashSet {}
64
65impl HashSetCell {
66    /// Returns the value as a byte array.
67    pub fn value_bytes(&self) -> [u8; 32] {
68        self.value
69    }
70
71    /// Returns the value as a big number.
72    pub fn value_biguint(&self) -> BigUint {
73        BigUint::from_bytes_be(self.value.as_slice())
74    }
75
76    /// Returns the associated sequence number.
77    pub fn sequence_number(&self) -> Option<usize> {
78        self.sequence_number
79    }
80
81    /// Checks whether the value is marked with a sequence number.
82    pub fn is_marked(&self) -> bool {
83        self.sequence_number.is_some()
84    }
85
86    /// Checks whether the value is valid according to the provided
87    /// `current_sequence_number` (which usually should be a sequence number
88    /// associated with the Merkle tree).
89    ///
90    /// The value is valid if:
91    ///
92    /// * It was not annotated with sequence number.
93    /// * Its sequence number is lower than the provided `sequence_number`.
94    ///
95    /// The value is invalid if it's lower or equal to the provided
96    /// `sequence_number`.
97    pub fn is_valid(&self, current_sequence_number: usize) -> bool {
98        match self.sequence_number {
99            Some(sequence_number) => match sequence_number.cmp(&current_sequence_number) {
100                Ordering::Less | Ordering::Equal => false,
101                Ordering::Greater => true,
102            },
103            None => true,
104        }
105    }
106}
107
108#[derive(Debug)]
109pub struct HashSet {
110    /// Capacity of the buckets.
111    capacity: usize,
112    /// Difference of sequence numbers, after which the given element can be
113    /// replaced by an another one (with a sequence number higher than the
114    /// threshold).
115    pub sequence_threshold: usize,
116
117    /// An array of buckets. It has a size equal to the expected number of
118    /// elements.
119    buckets: NonNull<Option<HashSetCell>>,
120}
121
122unsafe impl Send for HashSetCell {}
123
124impl HashSet {
125    /// Size of the struct **without** dynamically sized fields.
126    pub fn non_dyn_fields_size() -> usize {
127        // capacity
128        mem::size_of::<usize>()
129        // sequence_threshold
130        + mem::size_of::<usize>()
131    }
132
133    /// Size which needs to be allocated on Solana account to fit the hash set.
134    pub fn size_in_account(capacity_values: usize) -> usize {
135        let dyn_fields_size = Self::non_dyn_fields_size();
136
137        let buckets_size_unaligned = mem::size_of::<Option<HashSetCell>>() * capacity_values;
138        // Make sure that alignment of `values` matches the alignment of `usize`.
139        let buckets_size = buckets_size_unaligned + mem::align_of::<usize>()
140            - (buckets_size_unaligned % mem::align_of::<usize>());
141
142        dyn_fields_size + buckets_size
143    }
144
145    // Create a new hash set with the given capacity
146    pub fn new(capacity_values: usize, sequence_threshold: usize) -> Result<Self, HashSetError> {
147        // SAFETY: It's just a regular allocation.
148        let layout = Layout::array::<Option<HashSetCell>>(capacity_values).unwrap();
149        let values_ptr = unsafe { alloc::alloc(layout) as *mut Option<HashSetCell> };
150        if values_ptr.is_null() {
151            handle_alloc_error(layout);
152        }
153        let values = NonNull::new(values_ptr).unwrap();
154        for i in 0..capacity_values {
155            unsafe {
156                std::ptr::write(values_ptr.add(i), None);
157            }
158        }
159
160        Ok(HashSet {
161            sequence_threshold,
162            capacity: capacity_values,
163            buckets: values,
164        })
165    }
166
167    /// Creates a copy of `HashSet` from the given byte slice.
168    ///
169    /// # Purpose
170    ///
171    /// This method is meant to be used mostly in the SDK code, to convert
172    /// fetched Solana accounts to actual hash sets. Creating a copy is the
173    /// safest way of conversion in async Rust.
174    ///
175    /// # Safety
176    ///
177    /// This is highly unsafe. Ensuring the alignment and that the slice
178    /// provides actual actual data of the hash set is the caller's
179    /// responsibility.
180    pub unsafe fn from_bytes_copy(bytes: &mut [u8]) -> Result<Self, HashSetError> {
181        if bytes.len() < Self::non_dyn_fields_size() {
182            return Err(HashSetError::BufferSize(
183                Self::non_dyn_fields_size(),
184                bytes.len(),
185            ));
186        }
187
188        let capacity = usize::from_le_bytes(bytes[0..8].try_into().unwrap());
189        let sequence_threshold = usize::from_le_bytes(bytes[8..16].try_into().unwrap());
190        let expected_size = Self::size_in_account(capacity);
191        if bytes.len() != expected_size {
192            return Err(HashSetError::BufferSize(expected_size, bytes.len()));
193        }
194
195        let buckets_layout = Layout::array::<Option<HashSetCell>>(capacity).unwrap();
196        // SAFETY: `I` is always a signed integer. Creating a layout for an
197        // array of integers of any size won't cause any panic.
198        let buckets_dst_ptr = unsafe { alloc::alloc(buckets_layout) as *mut Option<HashSetCell> };
199        if buckets_dst_ptr.is_null() {
200            handle_alloc_error(buckets_layout);
201        }
202        let buckets = NonNull::new(buckets_dst_ptr).unwrap();
203        for i in 0..capacity {
204            std::ptr::write(buckets_dst_ptr.add(i), None);
205        }
206
207        let offset = Self::non_dyn_fields_size() + mem::size_of::<usize>();
208        let buckets_src_ptr = bytes.as_ptr().add(offset) as *const Option<HashSetCell>;
209        std::ptr::copy(buckets_src_ptr, buckets_dst_ptr, capacity);
210
211        Ok(Self {
212            capacity,
213            sequence_threshold,
214            buckets,
215        })
216    }
217
218    fn probe_index(&self, value: &BigUint, iteration: usize) -> usize {
219        // Increase stepsize over the capacity of the hash set.
220        let iteration = iteration + self.capacity / 10;
221        let probe_index = (value
222            + iteration.to_biguint().unwrap() * iteration.to_biguint().unwrap())
223            % self.capacity.to_biguint().unwrap();
224        probe_index.to_usize().unwrap()
225    }
226
227    /// Returns a reference to a bucket under the given `index`. Does not check
228    /// the validity.
229    pub fn get_bucket(&self, index: usize) -> Option<&Option<HashSetCell>> {
230        if index >= self.capacity {
231            return None;
232        }
233        let bucket = unsafe { &*self.buckets.as_ptr().add(index) };
234        Some(bucket)
235    }
236
237    /// Returns a mutable reference to a bucket under the given `index`. Does
238    /// not check the validity.
239    pub fn get_bucket_mut(&mut self, index: usize) -> Option<&mut Option<HashSetCell>> {
240        if index >= self.capacity {
241            return None;
242        }
243        let bucket = unsafe { &mut *self.buckets.as_ptr().add(index) };
244        Some(bucket)
245    }
246
247    /// Returns a reference to an unmarked bucket under the given index. If the
248    /// bucket is marked, returns `None`.
249    pub fn get_unmarked_bucket(&self, index: usize) -> Option<&Option<HashSetCell>> {
250        let bucket = self.get_bucket(index);
251        let is_unmarked = match bucket {
252            Some(Some(bucket)) => !bucket.is_marked(),
253            Some(None) => false,
254            None => false,
255        };
256        if is_unmarked {
257            bucket
258        } else {
259            None
260        }
261    }
262
263    pub fn get_capacity(&self) -> usize {
264        self.capacity
265    }
266
267    fn insert_into_occupied_cell(
268        &mut self,
269        value_index: usize,
270        value: &BigUint,
271        current_sequence_number: usize,
272    ) -> Result<bool, HashSetError> {
273        // PANICS: We trust the bounds of `value_index` here.
274        let bucket = self.get_bucket_mut(value_index).unwrap();
275
276        match bucket {
277            // The cell in the value array is already taken.
278            Some(bucket) => {
279                // We can overwrite that cell only if the element
280                // is expired - when the difference between its
281                // sequence number and provided sequence number is
282                // greater than the threshold.
283                if let Some(element_sequence_number) = bucket.sequence_number {
284                    if current_sequence_number >= element_sequence_number {
285                        *bucket = HashSetCell {
286                            value: bigint_to_be_bytes_array(value)?,
287                            sequence_number: None,
288                        };
289                        return Ok(true);
290                    }
291                }
292                // Otherwise, we need to prevent having multiple valid
293                // elements with the same value.
294                if &BigUint::from_be_bytes(bucket.value.as_slice()) == value {
295                    return Err(HashSetError::ElementAlreadyExists);
296                }
297            }
298            // Panics: If there is a hash set cell pointing to a `None` value,
299            // it means we really screwed up in the implementation...
300            // That should never happen.
301            None => unreachable!(),
302        }
303        Ok(false)
304    }
305
306    /// Inserts a value into the hash set, with `self.capacity_values` attempts.
307    ///
308    /// Every attempt uses quadratic probing to find an empty cell or a cell
309    /// which can be overwritten.
310    ///
311    /// `current sequence_number` is used to check whether existing values can
312    /// be overwritten.
313    pub fn insert(
314        &mut self,
315        value: &BigUint,
316        current_sequence_number: usize,
317    ) -> Result<usize, HashSetError> {
318        let index_bucket = self.find_element_iter(value, current_sequence_number, 0, ITERATIONS)?;
319        let (index, is_new) = match index_bucket {
320            Some(index) => index,
321            None => {
322                return Err(HashSetError::Full);
323            }
324        };
325
326        match is_new {
327            // The visited hash set cell points to a value in the array.
328            false => {
329                if self.insert_into_occupied_cell(index, value, current_sequence_number)? {
330                    return Ok(index);
331                }
332            }
333            true => {
334                // PANICS: We trust the bounds of `index`.
335                let bucket = self.get_bucket_mut(index).unwrap();
336
337                *bucket = Some(HashSetCell {
338                    value: bigint_to_be_bytes_array(value)?,
339                    sequence_number: None,
340                });
341                return Ok(index);
342            }
343        }
344        Err(HashSetError::Full)
345    }
346
347    /// Finds an index of the provided `value` inside `buckets`.
348    ///
349    /// Uses the optional `current_sequence_number` arguments for checking the
350    /// validity of the element.
351    pub fn find_element_index(
352        &self,
353        value: &BigUint,
354        current_sequence_number: Option<usize>,
355    ) -> Result<Option<usize>, HashSetError> {
356        for i in 0..ITERATIONS {
357            let probe_index = self.probe_index(value, i);
358            // PANICS: `probe_index()` ensures the bounds.
359            let bucket = self.get_bucket(probe_index).unwrap();
360            match bucket {
361                Some(bucket) => {
362                    if &bucket.value_biguint() == value {
363                        match current_sequence_number {
364                            // If the caller provided `current_sequence_number`,
365                            // check the validity of the bucket.
366                            Some(current_sequence_number) => {
367                                if bucket.is_valid(current_sequence_number) {
368                                    return Ok(Some(probe_index));
369                                }
370                                continue;
371                            }
372                            None => return Ok(Some(probe_index)),
373                        }
374                    }
375                    continue;
376                }
377                // If we found an empty bucket, it means that there is no
378                // chance of our element existing in the hash set.
379                None => {
380                    return Ok(None);
381                }
382            }
383        }
384
385        Ok(None)
386    }
387
388    pub fn find_element(
389        &self,
390        value: &BigUint,
391        current_sequence_number: Option<usize>,
392    ) -> Result<Option<(&HashSetCell, usize)>, HashSetError> {
393        let index = self.find_element_index(value, current_sequence_number)?;
394        match index {
395            Some(index) => {
396                let bucket = self.get_bucket(index).unwrap();
397                match bucket {
398                    Some(bucket) => Ok(Some((bucket, index))),
399                    None => Ok(None),
400                }
401            }
402            None => Ok(None),
403        }
404    }
405
406    pub fn find_element_mut(
407        &mut self,
408        value: &BigUint,
409        current_sequence_number: Option<usize>,
410    ) -> Result<Option<(&mut HashSetCell, usize)>, HashSetError> {
411        let index = self.find_element_index(value, current_sequence_number)?;
412        match index {
413            Some(index) => {
414                let bucket = self.get_bucket_mut(index).unwrap();
415                match bucket {
416                    Some(bucket) => Ok(Some((bucket, index))),
417                    None => Ok(None),
418                }
419            }
420            None => Ok(None),
421        }
422    }
423
424    /// find_element_iter iterates over a fixed range of elements
425    /// in the hash set.
426    /// We always have to iterate over the whole range
427    /// to make sure that the value is not in the hash-set.
428    /// Returns the position of the first free value.
429    pub fn find_element_iter(
430        &mut self,
431        value: &BigUint,
432        current_sequence_number: usize,
433        start_iter: usize,
434        num_iterations: usize,
435    ) -> Result<Option<(usize, bool)>, HashSetError> {
436        let mut first_free_element: Option<(usize, bool)> = None;
437        for i in start_iter..start_iter + num_iterations {
438            let probe_index = self.probe_index(value, i);
439            let bucket = self.get_bucket(probe_index).unwrap();
440
441            match bucket {
442                Some(bucket) => {
443                    let is_valid = bucket.is_valid(current_sequence_number);
444                    if first_free_element.is_none() && !is_valid {
445                        first_free_element = Some((probe_index, false));
446                    }
447                    if is_valid && &bucket.value_biguint() == value {
448                        return Err(HashSetError::ElementAlreadyExists);
449                    } else {
450                        continue;
451                    }
452                }
453                None => {
454                    // A previous bucket could have been freed already even
455                    // though the whole hash set has not been used yet.
456                    if first_free_element.is_none() {
457                        first_free_element = Some((probe_index, true));
458                    }
459                    // Since we encountered an empty bucket we know for sure
460                    // that the element is not in a bucket with higher probe
461                    // index.
462                    break;
463                }
464            }
465        }
466        Ok(first_free_element)
467    }
468
469    /// Returns a first available element.
470    pub fn first(
471        &self,
472        current_sequence_number: usize,
473    ) -> Result<Option<&HashSetCell>, HashSetError> {
474        for i in 0..self.capacity {
475            // PANICS: The loop ensures the bounds.
476            let bucket = self.get_bucket(i).unwrap();
477            if let Some(bucket) = bucket {
478                if bucket.is_valid(current_sequence_number) {
479                    return Ok(Some(bucket));
480                }
481            }
482        }
483
484        Ok(None)
485    }
486
487    /// Returns a first available element that does not have a sequence number.
488    pub fn first_no_seq(&self) -> Result<Option<(HashSetCell, u16)>, HashSetError> {
489        for i in 0..self.capacity {
490            // PANICS: The loop ensures the bounds.
491            let bucket = self.get_bucket(i).unwrap();
492
493            if let Some(bucket) = bucket {
494                if bucket.sequence_number.is_none() {
495                    return Ok(Some((*bucket, i as u16)));
496                }
497            }
498        }
499
500        Ok(None)
501    }
502
503    /// Checks if the hash set contains a value.
504    pub fn contains(
505        &self,
506        value: &BigUint,
507        sequence_number: Option<usize>,
508    ) -> Result<bool, HashSetError> {
509        let element = self.find_element(value, sequence_number)?;
510        Ok(element.is_some())
511    }
512
513    /// Marks the given element with a given sequence number.
514    pub fn mark_with_sequence_number(
515        &mut self,
516        index: usize,
517        sequence_number: usize,
518    ) -> Result<(), HashSetError> {
519        let sequence_threshold = self.sequence_threshold;
520        let element = self
521            .get_bucket_mut(index)
522            .ok_or(HashSetError::ElementDoesNotExist)?;
523
524        match element {
525            Some(element) => {
526                element.sequence_number = Some(sequence_number + sequence_threshold);
527                Ok(())
528            }
529            None => Err(HashSetError::ElementDoesNotExist),
530        }
531    }
532
533    /// Returns an iterator over elements.
534    pub fn iter(&self) -> HashSetIterator {
535        HashSetIterator {
536            hash_set: self,
537            current: 0,
538        }
539    }
540}
541
542impl Drop for HashSet {
543    fn drop(&mut self) {
544        // SAFETY: As long as `next_value_index`, `capacity_indices` and
545        // `capacity_values` are correct, this deallocaion is safe.
546        unsafe {
547            let layout = Layout::array::<Option<HashSetCell>>(self.capacity).unwrap();
548            alloc::dealloc(self.buckets.as_ptr() as *mut u8, layout);
549        }
550    }
551}
552
553impl PartialEq for HashSet {
554    fn eq(&self, other: &Self) -> bool {
555        self.capacity.eq(&other.capacity)
556            && self.sequence_threshold.eq(&other.sequence_threshold)
557            && self.iter().eq(other.iter())
558    }
559}
560
561pub struct HashSetIterator<'a> {
562    hash_set: &'a HashSet,
563    current: usize,
564}
565
566impl<'a> Iterator for HashSetIterator<'a> {
567    type Item = (usize, &'a HashSetCell);
568
569    fn next(&mut self) -> Option<Self::Item> {
570        while self.current < self.hash_set.get_capacity() {
571            let element_index = self.current;
572            self.current += 1;
573
574            if let Some(Some(cur_element)) = self.hash_set.get_bucket(element_index) {
575                return Some((element_index, cur_element));
576            }
577        }
578        None
579    }
580}
581
582#[cfg(test)]
583mod test {
584    use ark_bn254::Fr;
585    use ark_ff::UniformRand;
586    use rand::{thread_rng, Rng};
587
588    use super::*;
589    use crate::zero_copy::HashSetZeroCopy;
590
591    #[test]
592    fn test_is_valid() {
593        let mut rng = thread_rng();
594
595        let cell = HashSetCell {
596            value: [0u8; 32],
597            sequence_number: None,
598        };
599        // It should be always valid, no matter the sequence number.
600        assert!(cell.is_valid(0));
601        for _ in 0..100 {
602            let seq: usize = rng.gen();
603            assert!(cell.is_valid(seq));
604        }
605
606        let cell = HashSetCell {
607            value: [0u8; 32],
608            sequence_number: Some(2400),
609        };
610        // Sequence numbers up to 2400 should succeed.
611        for i in 0..2400 {
612            assert!(cell.is_valid(i));
613        }
614        for i in 2400..10000 {
615            assert!(!cell.is_valid(i));
616        }
617    }
618
619    /// Manual test cases. A simple check whether basic properties of the hash
620    /// set work.
621    #[test]
622    fn test_hash_set_manual() {
623        let mut hs = HashSet::new(256, 4).unwrap();
624
625        // Insert an element and immediately mark it with a sequence number.
626        // An equivalent to a single insertion in Light Protocol
627        let element_1_1 = 1.to_biguint().unwrap();
628        let index_1_1 = hs.insert(&element_1_1, 0).unwrap();
629        hs.mark_with_sequence_number(index_1_1, 1).unwrap();
630
631        // Check if element exists in the set.
632        assert!(hs.contains(&element_1_1, Some(1)).unwrap());
633        // Try inserting the same element, even though we didn't reach the
634        // threshold.
635        assert!(matches!(
636            hs.insert(&element_1_1, 1),
637            Err(HashSetError::ElementAlreadyExists)
638        ));
639
640        // Insert multiple elements and mark them with one sequence number.
641        // An equivalent to a batched insertion in Light Protocol.
642
643        let element_2_3 = 3.to_biguint().unwrap();
644        let element_2_6 = 6.to_biguint().unwrap();
645        let element_2_8 = 8.to_biguint().unwrap();
646        let element_2_9 = 9.to_biguint().unwrap();
647        let index_2_3 = hs.insert(&element_2_3, 1).unwrap();
648        let index_2_6 = hs.insert(&element_2_6, 1).unwrap();
649        let index_2_8 = hs.insert(&element_2_8, 1).unwrap();
650        let index_2_9 = hs.insert(&element_2_9, 1).unwrap();
651        assert!(hs.contains(&element_2_3, Some(2)).unwrap());
652        assert!(hs.contains(&element_2_6, Some(2)).unwrap());
653        assert!(hs.contains(&element_2_8, Some(2)).unwrap());
654        assert!(hs.contains(&element_2_9, Some(2)).unwrap());
655        hs.mark_with_sequence_number(index_2_3, 2).unwrap();
656        hs.mark_with_sequence_number(index_2_6, 2).unwrap();
657        hs.mark_with_sequence_number(index_2_8, 2).unwrap();
658        hs.mark_with_sequence_number(index_2_9, 2).unwrap();
659        assert!(matches!(
660            hs.insert(&element_2_3, 2),
661            Err(HashSetError::ElementAlreadyExists)
662        ));
663        assert!(matches!(
664            hs.insert(&element_2_6, 2),
665            Err(HashSetError::ElementAlreadyExists)
666        ));
667        assert!(matches!(
668            hs.insert(&element_2_8, 2),
669            Err(HashSetError::ElementAlreadyExists)
670        ));
671        assert!(matches!(
672            hs.insert(&element_2_9, 2),
673            Err(HashSetError::ElementAlreadyExists)
674        ));
675
676        let element_3_11 = 11.to_biguint().unwrap();
677        let element_3_13 = 13.to_biguint().unwrap();
678        let element_3_21 = 21.to_biguint().unwrap();
679        let element_3_29 = 29.to_biguint().unwrap();
680        let index_3_11 = hs.insert(&element_3_11, 2).unwrap();
681        let index_3_13 = hs.insert(&element_3_13, 2).unwrap();
682        let index_3_21 = hs.insert(&element_3_21, 2).unwrap();
683        let index_3_29 = hs.insert(&element_3_29, 2).unwrap();
684        assert!(hs.contains(&element_3_11, Some(3)).unwrap());
685        assert!(hs.contains(&element_3_13, Some(3)).unwrap());
686        assert!(hs.contains(&element_3_21, Some(3)).unwrap());
687        assert!(hs.contains(&element_3_29, Some(3)).unwrap());
688        hs.mark_with_sequence_number(index_3_11, 3).unwrap();
689        hs.mark_with_sequence_number(index_3_13, 3).unwrap();
690        hs.mark_with_sequence_number(index_3_21, 3).unwrap();
691        hs.mark_with_sequence_number(index_3_29, 3).unwrap();
692        assert!(matches!(
693            hs.insert(&element_3_11, 3),
694            Err(HashSetError::ElementAlreadyExists)
695        ));
696        assert!(matches!(
697            hs.insert(&element_3_13, 3),
698            Err(HashSetError::ElementAlreadyExists)
699        ));
700        assert!(matches!(
701            hs.insert(&element_3_21, 3),
702            Err(HashSetError::ElementAlreadyExists)
703        ));
704        assert!(matches!(
705            hs.insert(&element_3_29, 3),
706            Err(HashSetError::ElementAlreadyExists)
707        ));
708
709        let element_4_93 = 93.to_biguint().unwrap();
710        let element_4_65 = 64.to_biguint().unwrap();
711        let element_4_72 = 72.to_biguint().unwrap();
712        let element_4_15 = 15.to_biguint().unwrap();
713        let index_4_93 = hs.insert(&element_4_93, 3).unwrap();
714        let index_4_65 = hs.insert(&element_4_65, 3).unwrap();
715        let index_4_72 = hs.insert(&element_4_72, 3).unwrap();
716        let index_4_15 = hs.insert(&element_4_15, 3).unwrap();
717        assert!(hs.contains(&element_4_93, Some(4)).unwrap());
718        assert!(hs.contains(&element_4_65, Some(4)).unwrap());
719        assert!(hs.contains(&element_4_72, Some(4)).unwrap());
720        assert!(hs.contains(&element_4_15, Some(4)).unwrap());
721        hs.mark_with_sequence_number(index_4_93, 4).unwrap();
722        hs.mark_with_sequence_number(index_4_65, 4).unwrap();
723        hs.mark_with_sequence_number(index_4_72, 4).unwrap();
724        hs.mark_with_sequence_number(index_4_15, 4).unwrap();
725
726        // Try inserting the same elements we inserted before.
727        //
728        // Ones with the sequence number difference lower or equal to the
729        // sequence threshold (4) will fail.
730        //
731        // Ones with the higher dif will succeed.
732        assert!(matches!(
733            hs.insert(&element_1_1, 4),
734            Err(HashSetError::ElementAlreadyExists)
735        ));
736        assert!(matches!(
737            hs.insert(&element_2_3, 5),
738            Err(HashSetError::ElementAlreadyExists)
739        ));
740        assert!(matches!(
741            hs.insert(&element_2_6, 5),
742            Err(HashSetError::ElementAlreadyExists)
743        ));
744        assert!(matches!(
745            hs.insert(&element_2_8, 5),
746            Err(HashSetError::ElementAlreadyExists)
747        ));
748        assert!(matches!(
749            hs.insert(&element_2_9, 5),
750            Err(HashSetError::ElementAlreadyExists)
751        ));
752        hs.insert(&element_1_1, 5).unwrap();
753        hs.insert(&element_2_3, 6).unwrap();
754        hs.insert(&element_2_6, 6).unwrap();
755        hs.insert(&element_2_8, 6).unwrap();
756        hs.insert(&element_2_9, 6).unwrap();
757    }
758
759    /// Test cases with random prime field elements.
760    #[test]
761    fn test_hash_set_random() {
762        let mut hs = HashSet::new(6857, 2400).unwrap();
763
764        // The hash set should be empty.
765        assert_eq!(hs.first(0).unwrap(), None);
766        let mut rng = thread_rng();
767        let mut seq = 0;
768        let nullifiers: [BigUint; 10000] =
769            std::array::from_fn(|_| BigUint::from(Fr::rand(&mut rng)));
770        for nf_chunk in nullifiers.chunks(2400) {
771            for nullifier in nf_chunk.iter() {
772                assert!(!hs.contains(nullifier, Some(seq)).unwrap());
773                let index = hs.insert(nullifier, seq).unwrap();
774                assert!(hs.contains(nullifier, Some(seq)).unwrap());
775
776                let nullifier_bytes = bigint_to_be_bytes_array(nullifier).unwrap();
777
778                let element = *hs.find_element(nullifier, Some(seq)).unwrap().unwrap().0;
779                assert_eq!(
780                    element,
781                    HashSetCell {
782                        value: bigint_to_be_bytes_array(nullifier).unwrap(),
783                        sequence_number: None,
784                    }
785                );
786                assert_eq!(element.value_bytes(), nullifier_bytes);
787                assert_eq!(&element.value_biguint(), nullifier);
788                assert_eq!(element.sequence_number(), None);
789                assert!(!element.is_marked());
790                assert!(element.is_valid(seq));
791
792                hs.mark_with_sequence_number(index, seq).unwrap();
793                let element = *hs.find_element(nullifier, Some(seq)).unwrap().unwrap().0;
794
795                assert_eq!(
796                    element,
797                    HashSetCell {
798                        value: nullifier_bytes,
799                        sequence_number: Some(2400 + seq)
800                    }
801                );
802                assert_eq!(element.value_bytes(), nullifier_bytes);
803                assert_eq!(&element.value_biguint(), nullifier);
804                assert_eq!(element.sequence_number(), Some(2400 + seq));
805                assert!(element.is_marked());
806                assert!(element.is_valid(seq));
807
808                // Trying to insert the same nullifier, before reaching the
809                // sequence threshold, should fail.
810                assert!(matches!(
811                    hs.insert(nullifier, seq + 2399),
812                    Err(HashSetError::ElementAlreadyExists),
813                ));
814                seq += 1;
815            }
816            seq += 2400;
817        }
818    }
819
820    fn hash_set_from_bytes_copy<
821        const CAPACITY: usize,
822        const SEQUENCE_THRESHOLD: usize,
823        const OPERATIONS: usize,
824    >() {
825        let mut hs_1 = HashSet::new(CAPACITY, SEQUENCE_THRESHOLD).unwrap();
826
827        let mut rng = thread_rng();
828
829        // Create a buffer with random bytes.
830        let mut bytes = vec![0u8; HashSet::size_in_account(CAPACITY)];
831        rng.fill(bytes.as_mut_slice());
832
833        // Initialize a hash set on top of a byte slice.
834        {
835            let mut hs_2 = unsafe {
836                HashSetZeroCopy::from_bytes_zero_copy_init(&mut bytes, CAPACITY, SEQUENCE_THRESHOLD)
837                    .unwrap()
838            };
839
840            for seq in 0..OPERATIONS {
841                let value = BigUint::from(Fr::rand(&mut rng));
842                hs_1.insert(&value, seq).unwrap();
843                hs_2.insert(&value, seq).unwrap();
844            }
845
846            assert_eq!(hs_1, *hs_2);
847        }
848
849        // Create a copy on top of a byte slice.
850        {
851            let hs_2 = unsafe { HashSet::from_bytes_copy(&mut bytes).unwrap() };
852            assert_eq!(hs_1, hs_2);
853        }
854    }
855
856    #[test]
857    fn test_hash_set_from_bytes_copy_6857_2400_3600() {
858        hash_set_from_bytes_copy::<6857, 2400, 3600>()
859    }
860
861    #[test]
862    fn test_hash_set_from_bytes_copy_9601_2400_5000() {
863        hash_set_from_bytes_copy::<9601, 2400, 5000>()
864    }
865
866    fn hash_set_full<const CAPACITY: usize, const SEQUENCE_THRESHOLD: usize>() {
867        for _ in 0..100 {
868            let mut hs = HashSet::new(CAPACITY, SEQUENCE_THRESHOLD).unwrap();
869
870            let mut rng = rand::thread_rng();
871
872            // Insert as many values as possible. The important point is to
873            // encounter the `HashSetError::Full` at some point
874            for i in 0..CAPACITY {
875                let value = BigUint::from(Fr::rand(&mut rng));
876                match hs.insert(&value, 0) {
877                    Ok(index) => hs.mark_with_sequence_number(index, 0).unwrap(),
878                    Err(e) => {
879                        assert!(matches!(e, HashSetError::Full));
880                        println!("initial insertions: {i}: failed, stopping");
881                        break;
882                    }
883                }
884            }
885
886            // Keep inserting. It should mostly fail, although there might be
887            // also some successful insertions - there might be values which
888            // will end up in unused buckets.
889            for i in 0..1000 {
890                let value = BigUint::from(Fr::rand(&mut rng));
891                let res = hs.insert(&value, 0);
892                if res.is_err() {
893                    assert!(matches!(res, Err(HashSetError::Full)));
894                } else {
895                    println!("secondary insertions: {i}: apparent success with value: {value:?}");
896                }
897            }
898
899            // Try again with defined sequence numbers, but still too small to
900            // vacate any cell.
901            for i in 0..1000 {
902                let value = BigUint::from(Fr::rand(&mut rng));
903                // Sequence numbers lower than the threshold should not vacate
904                // any cell.
905                let sequence_number = rng.gen_range(0..hs.sequence_threshold);
906                let res = hs.insert(&value, sequence_number);
907                if res.is_err() {
908                    assert!(matches!(res, Err(HashSetError::Full)));
909                } else {
910                    println!("tertiary insertions: {i}: surprising success with value: {value:?}");
911                }
912            }
913
914            // Use sequence numbers which are going to vacate cells. All
915            // insertions should be successful now.
916            for i in 0..CAPACITY {
917                let value = BigUint::from(Fr::rand(&mut rng));
918                if let Err(e) = hs.insert(&value, SEQUENCE_THRESHOLD + i) {
919                    assert!(matches!(e, HashSetError::Full));
920                    println!("insertions after fillup: {i}: failed, stopping");
921                    break;
922                }
923            }
924        }
925    }
926
927    #[test]
928    fn test_hash_set_full_6857_2400() {
929        hash_set_full::<6857, 2400>()
930    }
931
932    #[test]
933    fn test_hash_set_full_9601_2400() {
934        hash_set_full::<9601, 2400>()
935    }
936
937    #[test]
938    fn test_hash_set_element_does_not_exist() {
939        let mut hs = HashSet::new(4800, 2400).unwrap();
940
941        let mut rng = thread_rng();
942
943        for _ in 0..1000 {
944            let index = rng.gen_range(0..4800);
945
946            // Assert `ElementDoesNotExist` error.
947            let res = hs.mark_with_sequence_number(index, 0);
948            assert!(matches!(res, Err(HashSetError::ElementDoesNotExist)));
949        }
950
951        for _ in 0..1000 {
952            // After actually appending the value, the same operation should be
953            // possible
954            let value = BigUint::from(Fr::rand(&mut rng));
955            let index = hs.insert(&value, 0).unwrap();
956            hs.mark_with_sequence_number(index, 1).unwrap();
957        }
958    }
959
960    #[test]
961    fn test_hash_set_iter_manual() {
962        let mut hs = HashSet::new(6857, 2400).unwrap();
963
964        let nullifier_1 = 945635_u32.to_biguint().unwrap();
965        let nullifier_2 = 3546656654734254353455_u128.to_biguint().unwrap();
966        let nullifier_3 = 543543656564_u64.to_biguint().unwrap();
967        let nullifier_4 = 43_u8.to_biguint().unwrap();
968        let nullifier_5 = 0_u8.to_biguint().unwrap();
969        let nullifier_6 = 65423_u32.to_biguint().unwrap();
970        let nullifier_7 = 745654665_u32.to_biguint().unwrap();
971        let nullifier_8 = 97664353453465354645645465_u128.to_biguint().unwrap();
972        let nullifier_9 = 453565465464565635475_u128.to_biguint().unwrap();
973        let nullifier_10 = 543645654645_u64.to_biguint().unwrap();
974
975        hs.insert(&nullifier_1, 0).unwrap();
976        hs.insert(&nullifier_2, 0).unwrap();
977        hs.insert(&nullifier_3, 0).unwrap();
978        hs.insert(&nullifier_4, 0).unwrap();
979        hs.insert(&nullifier_5, 0).unwrap();
980        hs.insert(&nullifier_6, 0).unwrap();
981        hs.insert(&nullifier_7, 0).unwrap();
982        hs.insert(&nullifier_8, 0).unwrap();
983        hs.insert(&nullifier_9, 0).unwrap();
984        hs.insert(&nullifier_10, 0).unwrap();
985
986        let inserted_nullifiers = hs
987            .iter()
988            .map(|(_, nullifier)| nullifier.value_biguint())
989            .collect::<Vec<_>>();
990        assert_eq!(inserted_nullifiers.len(), 10);
991        assert_eq!(inserted_nullifiers[0], nullifier_7);
992        assert_eq!(inserted_nullifiers[1], nullifier_3);
993        assert_eq!(inserted_nullifiers[2], nullifier_10);
994        assert_eq!(inserted_nullifiers[3], nullifier_1);
995        assert_eq!(inserted_nullifiers[4], nullifier_8);
996        assert_eq!(inserted_nullifiers[5], nullifier_5);
997        assert_eq!(inserted_nullifiers[6], nullifier_4);
998        assert_eq!(inserted_nullifiers[7], nullifier_2);
999        assert_eq!(inserted_nullifiers[8], nullifier_9);
1000        assert_eq!(inserted_nullifiers[9], nullifier_6);
1001    }
1002
1003    fn hash_set_iter_random<
1004        const INSERTIONS: usize,
1005        const CAPACITY: usize,
1006        const SEQUENCE_THRESHOLD: usize,
1007    >() {
1008        let mut hs = HashSet::new(CAPACITY, SEQUENCE_THRESHOLD).unwrap();
1009        let mut rng = thread_rng();
1010
1011        let nullifiers: [BigUint; INSERTIONS] =
1012            std::array::from_fn(|_| BigUint::from(Fr::rand(&mut rng)));
1013
1014        for nullifier in nullifiers.iter() {
1015            hs.insert(nullifier, 0).unwrap();
1016        }
1017
1018        let mut sorted_nullifiers = nullifiers.iter().collect::<Vec<_>>();
1019        let mut inserted_nullifiers = hs
1020            .iter()
1021            .map(|(_, nullifier)| nullifier.value_biguint())
1022            .collect::<Vec<_>>();
1023        sorted_nullifiers.sort();
1024        inserted_nullifiers.sort();
1025
1026        let inserted_nullifiers = inserted_nullifiers.iter().collect::<Vec<&BigUint>>();
1027        assert_eq!(inserted_nullifiers.len(), INSERTIONS);
1028        assert_eq!(sorted_nullifiers.as_slice(), inserted_nullifiers.as_slice());
1029    }
1030
1031    #[test]
1032    fn test_hash_set_iter_random_6857_2400() {
1033        hash_set_iter_random::<3500, 6857, 2400>()
1034    }
1035
1036    #[test]
1037    fn test_hash_set_iter_random_9601_2400() {
1038        hash_set_iter_random::<5000, 9601, 2400>()
1039    }
1040
1041    #[test]
1042    fn test_hash_set_get_bucket() {
1043        let mut hs = HashSet::new(6857, 2400).unwrap();
1044
1045        for i in 0..3600 {
1046            let bn_i = i.to_biguint().unwrap();
1047            hs.insert(&bn_i, i).unwrap();
1048        }
1049        let mut unused_indices = vec![true; 6857];
1050        for i in 0..3600 {
1051            let bn_i = i.to_biguint().unwrap();
1052            let i = hs.find_element_index(&bn_i, None).unwrap().unwrap();
1053            let element = hs.get_bucket(i).unwrap().unwrap();
1054            assert_eq!(element.value_biguint(), bn_i);
1055            unused_indices[i] = false;
1056        }
1057        // Unused cells within the capacity should be `Some(None)`.
1058        for i in unused_indices.iter().enumerate() {
1059            if *i.1 {
1060                assert!(hs.get_bucket(i.0).unwrap().is_none());
1061            }
1062        }
1063        // Cells over the capacity should be `None`.
1064        for i in 6857..10_000 {
1065            assert!(hs.get_bucket(i).is_none());
1066        }
1067    }
1068
1069    #[test]
1070    fn test_hash_set_get_bucket_mut() {
1071        let mut hs = HashSet::new(6857, 2400).unwrap();
1072
1073        for i in 0..3600 {
1074            let bn_i = i.to_biguint().unwrap();
1075            hs.insert(&bn_i, i).unwrap();
1076        }
1077        let mut unused_indices = vec![false; 6857];
1078
1079        for i in 0..3600 {
1080            let bn_i = i.to_biguint().unwrap();
1081            let i = hs.find_element_index(&bn_i, None).unwrap().unwrap();
1082
1083            let element = hs.get_bucket_mut(i).unwrap();
1084            assert_eq!(element.unwrap().value_biguint(), bn_i);
1085            unused_indices[i] = true;
1086
1087            // "Nullify" the element.
1088            *element = Some(HashSetCell {
1089                value: [0_u8; 32],
1090                sequence_number: None,
1091            });
1092        }
1093
1094        for (i, is_used) in unused_indices.iter().enumerate() {
1095            if *is_used {
1096                let element = hs.get_bucket_mut(i).unwrap().unwrap();
1097                assert_eq!(element.value_bytes(), [0_u8; 32]);
1098            }
1099        }
1100        // Unused cells within the capacity should be `Some(None)`.
1101        for (i, is_used) in unused_indices.iter().enumerate() {
1102            if !*is_used {
1103                assert!(hs.get_bucket_mut(i).unwrap().is_none());
1104            }
1105        }
1106        // Cells over the capacity should be `None`.
1107        for i in 6857..10_000 {
1108            assert!(hs.get_bucket_mut(i).is_none());
1109        }
1110    }
1111
1112    #[test]
1113    fn test_hash_set_get_unmarked_bucket() {
1114        let mut hs = HashSet::new(6857, 2400).unwrap();
1115
1116        // Insert incremental elements, so they end up being in the same
1117        // sequence in the hash set.
1118        (0..3600).for_each(|i| {
1119            let bn_i = i.to_biguint().unwrap();
1120            hs.insert(&bn_i, i).unwrap();
1121        });
1122
1123        for i in 0..3600 {
1124            let i = hs
1125                .find_element_index(&i.to_biguint().unwrap(), None)
1126                .unwrap()
1127                .unwrap();
1128            let element = hs.get_unmarked_bucket(i);
1129            assert!(element.is_some());
1130        }
1131
1132        // Mark the elements.
1133        for i in 0..3600 {
1134            let index = hs
1135                .find_element_index(&i.to_biguint().unwrap(), None)
1136                .unwrap()
1137                .unwrap();
1138            hs.mark_with_sequence_number(index, i).unwrap();
1139        }
1140
1141        for i in 0..3600 {
1142            let i = hs
1143                .find_element_index(&i.to_biguint().unwrap(), None)
1144                .unwrap()
1145                .unwrap();
1146            let element = hs.get_unmarked_bucket(i);
1147            assert!(element.is_none());
1148        }
1149    }
1150
1151    #[test]
1152    fn test_hash_set_first_no_seq() {
1153        let mut hs = HashSet::new(6857, 2400).unwrap();
1154
1155        // Insert incremental elements, so they end up being in the same
1156        // sequence in the hash set.
1157        for i in 0..3600 {
1158            let bn_i = i.to_biguint().unwrap();
1159            hs.insert(&bn_i, i).unwrap();
1160
1161            let element = hs.first_no_seq().unwrap().unwrap();
1162            assert_eq!(element.0.value_biguint(), 0.to_biguint().unwrap());
1163        }
1164    }
1165}