use light_utils::{bigint::bigint_to_be_bytes_array, UtilsError};
use num_bigint::{BigUint, ToBigUint};
use num_traits::{FromBytes, ToPrimitive};
use std::{
alloc::{self, handle_alloc_error, Layout},
cmp::Ordering,
marker::Send,
mem,
ptr::NonNull,
};
use thiserror::Error;
pub mod zero_copy;
/// Number of probing attempts made for a single value: `find_element_index`
/// loops over `0..ITERATIONS` and `insert` passes it as the iteration count.
pub const ITERATIONS: usize = 20;
/// Errors returned by [`HashSet`] operations.
#[derive(Debug, Error, PartialEq)]
pub enum HashSetError {
/// No bucket could be used for the value within the probing limit.
#[error("The hash set is full, cannot add any new elements")]
Full,
/// The value is already stored in a bucket that is still valid.
#[error("The provided element is already in the hash set")]
ElementAlreadyExists,
/// The addressed bucket is out of range or empty.
#[error("The provided element doesn't exist in the hash set")]
ElementDoesNotExist,
/// An index could not be converted from/to `usize`.
#[error("Could not convert the index from/to usize")]
UsizeConv,
/// An integer computation overflowed.
#[error("Integer overflow")]
IntegerOverflow,
/// A byte buffer had the wrong length; fields are `(expected, got)`.
#[error("Invalid buffer size, expected {0}, got {1}")]
BufferSize(usize, usize),
/// Error propagated from the big integer conversion helpers.
#[error("Utils: big integer conversion error")]
Utils(#[from] UtilsError),
}
#[cfg(feature = "solana")]
impl From<HashSetError> for u32 {
    /// Maps every hash set error to its numeric code (9001 to 9006). A wrapped
    /// `UtilsError` is converted through its own conversion into `u32`.
    fn from(error: HashSetError) -> u32 {
        match error {
            HashSetError::Utils(inner) => inner.into(),
            HashSetError::BufferSize(..) => 9006,
            HashSetError::IntegerOverflow => 9005,
            HashSetError::UsizeConv => 9004,
            HashSetError::ElementDoesNotExist => 9003,
            HashSetError::ElementAlreadyExists => 9002,
            HashSetError::Full => 9001,
        }
    }
}
#[cfg(feature = "solana")]
impl From<HashSetError> for solana_program::program_error::ProgramError {
    /// Wraps the numeric code of the error in `ProgramError::Custom`.
    fn from(error: HashSetError) -> Self {
        let code: u32 = error.into();
        Self::Custom(code)
    }
}
/// A single occupied bucket of the [`HashSet`].
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct HashSetCell {
/// The stored value as 32 bytes (read back as a big-endian integer).
pub value: [u8; 32],
/// `None` while the cell is unmarked; once marked, the cell counts as valid
/// only for current sequence numbers strictly below this one.
pub sequence_number: Option<usize>,
}
// `HashSet` holds a `NonNull` pointer to its bucket array and is therefore not
// `Send` automatically.
// NOTE(review): this relies on every `HashSet` being the sole owner of the
// memory behind `buckets` — that holds for values built by `new` and
// `from_bytes_copy` in this file; confirm for any other constructor.
unsafe impl Send for HashSet {}
impl HashSetCell {
    /// Returns a copy of the raw 32-byte value stored in the cell.
    pub fn value_bytes(&self) -> [u8; 32] {
        self.value
    }

    /// Returns the stored value interpreted as a big-endian unsigned integer.
    pub fn value_biguint(&self) -> BigUint {
        BigUint::from_bytes_be(self.value.as_slice())
    }

    /// Returns the sequence number the cell was marked with, if any.
    pub fn sequence_number(&self) -> Option<usize> {
        self.sequence_number
    }

    /// Returns `true` if the cell carries a sequence number.
    pub fn is_marked(&self) -> bool {
        self.sequence_number.is_some()
    }

    /// Returns whether the cell is still valid at `current_sequence_number`.
    ///
    /// An unmarked cell is always valid. A marked cell is valid only while its
    /// sequence number is strictly greater than `current_sequence_number`.
    pub fn is_valid(&self, current_sequence_number: usize) -> bool {
        match self.sequence_number {
            Some(sequence_number) => sequence_number > current_sequence_number,
            None => true,
        }
    }
}
/// Hash set of 32-byte values stored in a fixed-size array of buckets that is
/// addressed by probing (see `probe_index`).
#[derive(Debug)]
pub struct HashSet {
/// Number of buckets behind `buckets`.
pub capacity: usize,
/// Offset added to the sequence number when an element is marked
/// (see `mark_with_sequence_number`).
pub sequence_threshold: usize,
/// Pointer to the first of `capacity` buckets; `None` is an empty bucket.
buckets: NonNull<Option<HashSetCell>>,
}
// NOTE(review): `HashSetCell` only contains `[u8; 32]` and `Option<usize>`,
// which are `Send` already, so this explicit impl looks redundant — confirm
// before removing.
unsafe impl Send for HashSetCell {}
impl HashSet {
/// Size in bytes of the fixed-size header of a serialized hash set.
pub fn non_dyn_fields_size() -> usize {
    // Two `usize` header fields: `capacity` followed by `sequence_threshold`.
    2 * mem::size_of::<usize>()
}
/// Number of bytes a serialized hash set with `capacity_values` buckets
/// occupies: the fixed header, the bucket array and trailing padding.
pub fn size_in_account(capacity_values: usize) -> usize {
    let usize_align = mem::align_of::<usize>();
    let cells_size = mem::size_of::<Option<HashSetCell>>() * capacity_values;
    // The padding is always between 1 and `usize_align` bytes: a bucket array
    // that is already a multiple of the alignment still gets a full extra word.
    let padded_cells_size = cells_size + usize_align - (cells_size % usize_align);
    Self::non_dyn_fields_size() + padded_cells_size
}
pub fn new(capacity_values: usize, sequence_threshold: usize) -> Result<Self, HashSetError> {
let layout = Layout::array::<Option<HashSetCell>>(capacity_values).unwrap();
let values_ptr = unsafe { alloc::alloc(layout) as *mut Option<HashSetCell> };
if values_ptr.is_null() {
handle_alloc_error(layout);
}
let values = NonNull::new(values_ptr).unwrap();
for i in 0..capacity_values {
unsafe {
std::ptr::write(values_ptr.add(i), None);
}
}
Ok(HashSet {
sequence_threshold,
capacity: capacity_values,
buckets: values,
})
}
/// Creates an owned [`HashSet`] by copying its serialized form out of `bytes`.
///
/// The buffer is read as: `capacity` and `sequence_threshold` (one
/// native-endian `usize` each), one further `usize`-sized word that is
/// skipped, and then `capacity` buckets of type `Option<HashSetCell>`.
///
/// # Errors
///
/// * [`HashSetError::BufferSize`] if `bytes` is shorter than the header, if
///   its length differs from `Self::size_in_account` for the stored capacity,
///   or if the bucket region would run past the end of the buffer.
/// * [`HashSetError::IntegerOverflow`] if the stored capacity is too large for
///   the size of the bucket array to be computed.
///
/// # Safety
///
/// The bucket region of `bytes` must hold valid `Option<HashSetCell>` values.
/// No alignment is required of `bytes`: the buckets are copied byte-wise.
///
/// NOTE(review): a stored capacity of 0 requests a zero-sized allocation,
/// which `alloc::alloc` does not support.
pub unsafe fn from_bytes_copy(bytes: &mut [u8]) -> Result<Self, HashSetError> {
    const USIZE_SIZE: usize = mem::size_of::<usize>();

    let header_size = Self::non_dyn_fields_size();
    if bytes.len() < header_size {
        return Err(HashSetError::BufferSize(header_size, bytes.len()));
    }
    let capacity = usize::from_ne_bytes(bytes[..USIZE_SIZE].try_into().unwrap());
    let sequence_threshold =
        usize::from_ne_bytes(bytes[USIZE_SIZE..2 * USIZE_SIZE].try_into().unwrap());

    // `capacity` comes straight from the buffer, so validate the arithmetic
    // before `size_in_account` repeats it without checks.
    let offset = header_size + USIZE_SIZE;
    let buckets_size = capacity
        .checked_mul(mem::size_of::<Option<HashSetCell>>())
        .ok_or(HashSetError::IntegerOverflow)?;
    let buckets_end = buckets_size
        .checked_add(offset)
        .ok_or(HashSetError::IntegerOverflow)?;

    let expected_size = Self::size_in_account(capacity);
    if bytes.len() != expected_size {
        return Err(HashSetError::BufferSize(expected_size, bytes.len()));
    }
    if buckets_end > bytes.len() {
        return Err(HashSetError::BufferSize(buckets_end, bytes.len()));
    }

    let buckets_layout = Layout::array::<Option<HashSetCell>>(capacity)
        .map_err(|_| HashSetError::IntegerOverflow)?;
    // SAFETY: `buckets_layout` describes an array of `Option<HashSetCell>`;
    // see the note above regarding zero capacity.
    let buckets_dst_ptr = unsafe { alloc::alloc(buckets_layout) as *mut Option<HashSetCell> };
    let buckets = match NonNull::new(buckets_dst_ptr) {
        Some(buckets) => buckets,
        None => handle_alloc_error(buckets_layout),
    };

    // Copy byte-wise: a typed copy would require the source to be aligned for
    // `Option<HashSetCell>`, which a `&[u8]` does not guarantee. Every bucket
    // is overwritten, so no prior initialization of the destination is needed.
    //
    // SAFETY: `offset + buckets_size <= bytes.len()` was checked above, the
    // destination was allocated for exactly `buckets_size` bytes, and a fresh
    // allocation cannot overlap `bytes`.
    unsafe {
        std::ptr::copy_nonoverlapping(
            bytes.as_ptr().add(offset),
            buckets_dst_ptr as *mut u8,
            buckets_size,
        );
    }

    Ok(Self {
        capacity,
        sequence_threshold,
        buckets,
    })
}
/// Computes the bucket index probed for `value` on the given `iteration`:
/// `(value + step²) mod capacity`, where `step = iteration + capacity / 10`.
fn probe_index(&self, value: &BigUint, iteration: usize) -> usize {
    let step = (iteration + self.capacity / 10).to_biguint().unwrap();
    let modulus = self.capacity.to_biguint().unwrap();
    let slot = (value + &step * &step) % modulus;
    slot.to_usize().unwrap()
}
/// Returns a reference to the bucket at `index`, or `None` when `index` is
/// not below the capacity.
pub fn get_bucket(&self, index: usize) -> Option<&Option<HashSetCell>> {
    if index < self.capacity {
        // Only reached for `index < capacity`.
        Some(unsafe { &*self.buckets.as_ptr().add(index) })
    } else {
        None
    }
}
/// Returns a mutable reference to the bucket at `index`, or `None` when
/// `index` is not below the capacity.
pub fn get_bucket_mut(&mut self, index: usize) -> Option<&mut Option<HashSetCell>> {
    if index < self.capacity {
        // Only reached for `index < capacity`.
        Some(unsafe { &mut *self.buckets.as_ptr().add(index) })
    } else {
        None
    }
}
/// Returns the bucket at `index` only if it is occupied and its cell is not
/// marked with a sequence number; otherwise returns `None`.
pub fn get_unmarked_bucket(&self, index: usize) -> Option<&Option<HashSetCell>> {
    self.get_bucket(index)
        .filter(|slot| matches!(slot, Some(cell) if !cell.is_marked()))
}
/// Tries to reuse the occupied bucket at `value_index` for `value`.
///
/// Returns `Ok(true)` when the cell was marked and its sequence number is not
/// greater than `current_sequence_number`, in which case it is overwritten
/// with an unmarked cell holding `value`. Returns
/// `Err(ElementAlreadyExists)` when the cell is kept and already stores
/// `value`, and `Ok(false)` when the cell is kept and stores something else.
///
/// Panics if `value_index` is out of range or the bucket is empty.
fn insert_into_occupied_cell(
    &mut self,
    value_index: usize,
    value: &BigUint,
    current_sequence_number: usize,
) -> Result<bool, HashSetError> {
    let cell = match self.get_bucket_mut(value_index).unwrap() {
        Some(cell) => cell,
        None => unreachable!(),
    };

    let expired = matches!(
        cell.sequence_number,
        Some(marked_at) if current_sequence_number >= marked_at
    );
    if expired {
        *cell = HashSetCell {
            value: bigint_to_be_bytes_array(value)?,
            sequence_number: None,
        };
        return Ok(true);
    }

    if &BigUint::from_be_bytes(cell.value.as_slice()) == value {
        return Err(HashSetError::ElementAlreadyExists);
    }
    Ok(false)
}
/// Inserts `value` and returns the index of the bucket it was written to.
///
/// Returns `Err(ElementAlreadyExists)` if a valid cell already holds the
/// value and `Err(Full)` if probing found no bucket that can take it.
pub fn insert(
    &mut self,
    value: &BigUint,
    current_sequence_number: usize,
) -> Result<usize, HashSetError> {
    let (index, is_new) = self
        .find_element_iter(value, current_sequence_number, 0, ITERATIONS)?
        .ok_or(HashSetError::Full)?;

    if is_new {
        // Empty bucket: write a fresh, unmarked cell.
        let bucket = self.get_bucket_mut(index).unwrap();
        *bucket = Some(HashSetCell {
            value: bigint_to_be_bytes_array(value)?,
            sequence_number: None,
        });
        return Ok(index);
    }

    // Occupied bucket: it can only be taken over if its cell has expired.
    if self.insert_into_occupied_cell(index, value, current_sequence_number)? {
        Ok(index)
    } else {
        Err(HashSetError::Full)
    }
}
/// Probes for `value` and returns the index of the bucket holding it.
///
/// With `current_sequence_number` set, only cells that are valid at that
/// sequence number count as a match. Probing stops at the first empty bucket
/// or after `ITERATIONS` attempts, returning `Ok(None)`.
pub fn find_element_index(
    &self,
    value: &BigUint,
    current_sequence_number: Option<usize>,
) -> Result<Option<usize>, HashSetError> {
    for attempt in 0..ITERATIONS {
        let probe_index = self.probe_index(value, attempt);
        let cell = match self.get_bucket(probe_index).unwrap() {
            Some(cell) => cell,
            // An empty bucket ends the search.
            None => return Ok(None),
        };
        if &cell.value_biguint() != value {
            continue;
        }
        let acceptable = current_sequence_number.map_or(true, |seq| cell.is_valid(seq));
        if acceptable {
            return Ok(Some(probe_index));
        }
    }
    Ok(None)
}
/// Looks up `value` (see `find_element_index`) and returns a reference to
/// its cell together with the bucket index, or `Ok(None)` if it is absent.
pub fn find_element(
    &self,
    value: &BigUint,
    current_sequence_number: Option<usize>,
) -> Result<Option<(&HashSetCell, usize)>, HashSetError> {
    match self.find_element_index(value, current_sequence_number)? {
        Some(index) => Ok(self
            .get_bucket(index)
            .unwrap()
            .as_ref()
            .map(|cell| (cell, index))),
        None => Ok(None),
    }
}
/// Looks up `value` (see `find_element_index`) and returns a mutable
/// reference to its cell together with the bucket index, or `Ok(None)` if it
/// is absent.
pub fn find_element_mut(
    &mut self,
    value: &BigUint,
    current_sequence_number: Option<usize>,
) -> Result<Option<(&mut HashSetCell, usize)>, HashSetError> {
    match self.find_element_index(value, current_sequence_number)? {
        Some(index) => Ok(self
            .get_bucket_mut(index)
            .unwrap()
            .as_mut()
            .map(|cell| (cell, index))),
        None => Ok(None),
    }
}
/// Probes iterations `start_iter..start_iter + num_iterations` for a bucket
/// that can take `value`.
///
/// Returns `(index, is_new)` for the first usable bucket seen: `is_new` is
/// `true` for an empty bucket and `false` for an occupied bucket whose cell is
/// no longer valid at `current_sequence_number`. Probing stops at the first
/// empty bucket. Returns `Ok(None)` when no usable bucket was seen and
/// `Err(ElementAlreadyExists)` when a valid cell already holds `value`.
pub fn find_element_iter(
    &mut self,
    value: &BigUint,
    current_sequence_number: usize,
    start_iter: usize,
    num_iterations: usize,
) -> Result<Option<(usize, bool)>, HashSetError> {
    let mut candidate: Option<(usize, bool)> = None;
    for attempt in start_iter..start_iter + num_iterations {
        let probe_index = self.probe_index(value, attempt);
        match self.get_bucket(probe_index).unwrap() {
            None => {
                // Keep an earlier expired bucket if there was one.
                candidate = candidate.or(Some((probe_index, true)));
                break;
            }
            Some(cell) if cell.is_valid(current_sequence_number) => {
                if &cell.value_biguint() == value {
                    return Err(HashSetError::ElementAlreadyExists);
                }
            }
            Some(_) => {
                // Expired cell: remember only the first one.
                candidate = candidate.or(Some((probe_index, false)));
            }
        }
    }
    Ok(candidate)
}
pub fn first(
&self,
current_sequence_number: usize,
) -> Result<Option<&HashSetCell>, HashSetError> {
for i in 0..self.capacity {
let bucket = self.get_bucket(i).unwrap();
if let Some(bucket) = bucket {
if bucket.is_valid(current_sequence_number) {
return Ok(Some(bucket));
}
}
}
Ok(None)
}
pub fn first_no_seq(&self) -> Result<Option<(HashSetCell, u16)>, HashSetError> {
for i in 0..self.capacity {
let bucket = self.get_bucket(i).unwrap();
if let Some(bucket) = bucket {
if bucket.sequence_number.is_none() {
return Ok(Some((*bucket, i as u16)));
}
}
}
Ok(None)
}
/// Returns whether `value` is found by `find_element` with the given
/// `sequence_number`.
pub fn contains(
    &self,
    value: &BigUint,
    sequence_number: Option<usize>,
) -> Result<bool, HashSetError> {
    self.find_element(value, sequence_number)
        .map(|element| element.is_some())
}
/// Marks the cell in bucket `index` with `sequence_number` plus the set's
/// `sequence_threshold`.
///
/// # Errors
///
/// * [`HashSetError::ElementDoesNotExist`] if `index` is out of range or the
///   bucket is empty.
/// * [`HashSetError::IntegerOverflow`] if the sum does not fit into `usize`
///   (previously this panicked in debug builds and wrapped in release builds).
pub fn mark_with_sequence_number(
    &mut self,
    index: usize,
    sequence_number: usize,
) -> Result<(), HashSetError> {
    let sequence_threshold = self.sequence_threshold;
    match self.get_bucket_mut(index) {
        Some(Some(element)) => {
            // Checked only after the element was found, so a missing element
            // is still reported as `ElementDoesNotExist`.
            let marked_until = sequence_number
                .checked_add(sequence_threshold)
                .ok_or(HashSetError::IntegerOverflow)?;
            element.sequence_number = Some(marked_until);
            Ok(())
        }
        _ => Err(HashSetError::ElementDoesNotExist),
    }
}
/// Returns an iterator over the occupied buckets, starting at index 0.
pub fn iter(&self) -> HashSetIterator<'_> {
    HashSetIterator {
        current: 0,
        hash_set: self,
    }
}
}
impl Drop for HashSet {
/// Frees the bucket array.
fn drop(&mut self) {
// NOTE(review): this assumes `buckets` was obtained from the global
// allocator with this same `Layout::array` for `capacity` buckets, as `new`
// and `from_bytes_copy` do — confirm for any other constructor.
unsafe {
let layout = Layout::array::<Option<HashSetCell>>(self.capacity).unwrap();
alloc::dealloc(self.buckets.as_ptr() as *mut u8, layout);
}
}
}
impl PartialEq for HashSet {
    /// Two hash sets are equal when capacity, sequence threshold and the
    /// sequence of `(index, cell)` pairs of their occupied buckets all match.
    fn eq(&self, other: &Self) -> bool {
        if self.capacity != other.capacity
            || self.sequence_threshold != other.sequence_threshold
        {
            return false;
        }
        self.iter().eq(other.iter())
    }
}
/// Iterator over the occupied buckets of a [`HashSet`], created by
/// `HashSet::iter`.
pub struct HashSetIterator<'a> {
/// The hash set being iterated.
hash_set: &'a HashSet,
/// Index of the next bucket to inspect.
current: usize,
}
impl<'a> Iterator for HashSetIterator<'a> {
    type Item = (usize, &'a HashSetCell);

    /// Advances to the next occupied bucket and yields its index and cell.
    fn next(&mut self) -> Option<Self::Item> {
        let set = self.hash_set;
        loop {
            let element_index = self.current;
            if element_index >= set.capacity {
                return None;
            }
            self.current += 1;
            if let Some(Some(cell)) = set.get_bucket(element_index) {
                return Some((element_index, cell));
            }
        }
    }
}
#[cfg(test)]
mod test {
use ark_bn254::Fr;
use ark_ff::UniformRand;
use rand::{thread_rng, Rng};
use crate::zero_copy::HashSetZeroCopy;
use super::*;
/// `is_valid` is always true for unmarked cells and, for marked cells, only
/// for sequence numbers below the stored one.
#[test]
fn test_is_valid() {
    let mut rng = thread_rng();

    let unmarked = HashSetCell {
        value: [0u8; 32],
        sequence_number: None,
    };
    assert!(unmarked.is_valid(0));
    for _ in 0..100 {
        let seq: usize = rng.gen();
        assert!(unmarked.is_valid(seq));
    }

    let marked = HashSetCell {
        value: [0u8; 32],
        sequence_number: Some(2400),
    };
    (0..2400).for_each(|seq| assert!(marked.is_valid(seq)));
    (2400..10000).for_each(|seq| assert!(!marked.is_valid(seq)));
}
// Walks a small set (capacity 256, sequence threshold 4) through four batches
// of insert + mark and checks that a marked element blocks re-insertion until
// the current sequence number reaches its mark (sequence number + 4).
#[test]
fn test_hash_set_manual() {
let mut hs = HashSet::new(256, 4).unwrap();
// Batch 1: inserted at sequence number 0, marked at 1 (stored mark: 5).
let element_1_1 = 1.to_biguint().unwrap();
let index_1_1 = hs.insert(&element_1_1, 0).unwrap();
hs.mark_with_sequence_number(index_1_1, 1).unwrap();
assert_eq!(hs.contains(&element_1_1, Some(1)).unwrap(), true);
assert!(matches!(
hs.insert(&element_1_1, 1),
Err(HashSetError::ElementAlreadyExists)
));
// Batch 2: inserted at sequence number 1, marked at 2 (stored mark: 6).
let element_2_3 = 3.to_biguint().unwrap();
let element_2_6 = 6.to_biguint().unwrap();
let element_2_8 = 8.to_biguint().unwrap();
let element_2_9 = 9.to_biguint().unwrap();
let index_2_3 = hs.insert(&element_2_3, 1).unwrap();
let index_2_6 = hs.insert(&element_2_6, 1).unwrap();
let index_2_8 = hs.insert(&element_2_8, 1).unwrap();
let index_2_9 = hs.insert(&element_2_9, 1).unwrap();
assert_eq!(hs.contains(&element_2_3, Some(2)).unwrap(), true);
assert_eq!(hs.contains(&element_2_6, Some(2)).unwrap(), true);
assert_eq!(hs.contains(&element_2_8, Some(2)).unwrap(), true);
assert_eq!(hs.contains(&element_2_9, Some(2)).unwrap(), true);
hs.mark_with_sequence_number(index_2_3, 2).unwrap();
hs.mark_with_sequence_number(index_2_6, 2).unwrap();
hs.mark_with_sequence_number(index_2_8, 2).unwrap();
hs.mark_with_sequence_number(index_2_9, 2).unwrap();
assert!(matches!(
hs.insert(&element_2_3, 2),
Err(HashSetError::ElementAlreadyExists)
));
assert!(matches!(
hs.insert(&element_2_6, 2),
Err(HashSetError::ElementAlreadyExists)
));
assert!(matches!(
hs.insert(&element_2_8, 2),
Err(HashSetError::ElementAlreadyExists)
));
assert!(matches!(
hs.insert(&element_2_9, 2),
Err(HashSetError::ElementAlreadyExists)
));
// Batch 3: inserted at sequence number 2, marked at 3.
let element_3_11 = 11.to_biguint().unwrap();
let element_3_13 = 13.to_biguint().unwrap();
let element_3_21 = 21.to_biguint().unwrap();
let element_3_29 = 29.to_biguint().unwrap();
let index_3_11 = hs.insert(&element_3_11, 2).unwrap();
let index_3_13 = hs.insert(&element_3_13, 2).unwrap();
let index_3_21 = hs.insert(&element_3_21, 2).unwrap();
let index_3_29 = hs.insert(&element_3_29, 2).unwrap();
assert_eq!(hs.contains(&element_3_11, Some(3)).unwrap(), true);
assert_eq!(hs.contains(&element_3_13, Some(3)).unwrap(), true);
assert_eq!(hs.contains(&element_3_21, Some(3)).unwrap(), true);
assert_eq!(hs.contains(&element_3_29, Some(3)).unwrap(), true);
hs.mark_with_sequence_number(index_3_11, 3).unwrap();
hs.mark_with_sequence_number(index_3_13, 3).unwrap();
hs.mark_with_sequence_number(index_3_21, 3).unwrap();
hs.mark_with_sequence_number(index_3_29, 3).unwrap();
assert!(matches!(
hs.insert(&element_3_11, 3),
Err(HashSetError::ElementAlreadyExists)
));
assert!(matches!(
hs.insert(&element_3_13, 3),
Err(HashSetError::ElementAlreadyExists)
));
assert!(matches!(
hs.insert(&element_3_21, 3),
Err(HashSetError::ElementAlreadyExists)
));
assert!(matches!(
hs.insert(&element_3_29, 3),
Err(HashSetError::ElementAlreadyExists)
));
// Batch 4: inserted at sequence number 3, marked at 4.
let element_4_93 = 93.to_biguint().unwrap();
// NOTE(review): the name says 65 but the value is 64 — confirm which is intended.
let element_4_65 = 64.to_biguint().unwrap();
let element_4_72 = 72.to_biguint().unwrap();
let element_4_15 = 15.to_biguint().unwrap();
let index_4_93 = hs.insert(&element_4_93, 3).unwrap();
let index_4_65 = hs.insert(&element_4_65, 3).unwrap();
let index_4_72 = hs.insert(&element_4_72, 3).unwrap();
let index_4_15 = hs.insert(&element_4_15, 3).unwrap();
assert_eq!(hs.contains(&element_4_93, Some(4)).unwrap(), true);
assert_eq!(hs.contains(&element_4_65, Some(4)).unwrap(), true);
assert_eq!(hs.contains(&element_4_72, Some(4)).unwrap(), true);
assert_eq!(hs.contains(&element_4_15, Some(4)).unwrap(), true);
hs.mark_with_sequence_number(index_4_93, 4).unwrap();
hs.mark_with_sequence_number(index_4_65, 4).unwrap();
hs.mark_with_sequence_number(index_4_72, 4).unwrap();
hs.mark_with_sequence_number(index_4_15, 4).unwrap();
// One step before their marks are reached, batches 1 and 2 are still rejected.
assert!(matches!(
hs.insert(&element_1_1, 4),
Err(HashSetError::ElementAlreadyExists)
));
assert!(matches!(
hs.insert(&element_2_3, 5),
Err(HashSetError::ElementAlreadyExists)
));
assert!(matches!(
hs.insert(&element_2_6, 5),
Err(HashSetError::ElementAlreadyExists)
));
assert!(matches!(
hs.insert(&element_2_8, 5),
Err(HashSetError::ElementAlreadyExists)
));
assert!(matches!(
hs.insert(&element_2_9, 5),
Err(HashSetError::ElementAlreadyExists)
));
// Once the current sequence number equals the mark, re-insertion succeeds.
hs.insert(&element_1_1, 5).unwrap();
hs.insert(&element_2_3, 6).unwrap();
hs.insert(&element_2_6, 6).unwrap();
hs.insert(&element_2_8, 6).unwrap();
hs.insert(&element_2_9, 6).unwrap();
}
// Inserts 24000 random values in chunks of 2400 into a set with capacity 6857
// and sequence threshold 2400, checking each cell before and after marking.
#[test]
fn test_hash_set_random() {
let mut hs = HashSet::new(6857, 2400).unwrap();
assert_eq!(hs.first(0).unwrap(), None);
let mut rng = thread_rng();
let mut seq = 0;
let nullifiers: [BigUint; 24000] =
std::array::from_fn(|_| BigUint::from(Fr::rand(&mut rng)));
for nf_chunk in nullifiers.chunks(2400) {
for nullifier in nf_chunk.iter() {
assert_eq!(hs.contains(&nullifier, Some(seq)).unwrap(), false);
let index = hs.insert(&nullifier, seq as usize).unwrap();
assert_eq!(hs.contains(&nullifier, Some(seq)).unwrap(), true);
let nullifier_bytes = bigint_to_be_bytes_array(&nullifier).unwrap();
// Freshly inserted: the cell is unmarked.
let element = hs
.find_element(&nullifier, Some(seq))
.unwrap()
.unwrap()
.0
.clone();
assert_eq!(
element,
HashSetCell {
value: bigint_to_be_bytes_array(&nullifier).unwrap(),
sequence_number: None,
}
);
assert_eq!(element.value_bytes(), nullifier_bytes);
assert_eq!(&element.value_biguint(), nullifier);
assert_eq!(element.sequence_number(), None);
assert!(!element.is_marked());
assert!(element.is_valid(seq));
// After marking, the cell stores `seq` plus the threshold of 2400.
hs.mark_with_sequence_number(index, seq).unwrap();
let element = hs
.find_element(&nullifier, Some(seq))
.unwrap()
.unwrap()
.0
.clone();
assert_eq!(
element,
HashSetCell {
value: nullifier_bytes,
sequence_number: Some(2400 + seq)
}
);
assert_eq!(element.value_bytes(), nullifier_bytes);
assert_eq!(&element.value_biguint(), nullifier);
assert_eq!(element.sequence_number(), Some(2400 + seq));
assert!(element.is_marked());
assert!(element.is_valid(seq));
// One step before the mark is reached the value is still rejected.
assert!(matches!(
hs.insert(&nullifier, seq as usize + 2399),
Err(HashSetError::ElementAlreadyExists),
));
seq += 1;
}
// Skip a full threshold between chunks.
seq += 2400;
}
}
// Applies the same `OPERATIONS` random insertions to an owned `HashSet` and to
// a `HashSetZeroCopy` backed by a byte buffer, then checks that both compare
// equal and that `HashSet::from_bytes_copy` on the buffer reproduces the set.
fn hash_set_from_bytes_copy<
const CAPACITY: usize,
const SEQUENCE_THRESHOLD: usize,
const OPERATIONS: usize,
>() {
let mut hs_1 = HashSet::new(CAPACITY, SEQUENCE_THRESHOLD).unwrap();
let mut rng = thread_rng();
let mut bytes = vec![0u8; HashSet::size_in_account(CAPACITY)];
// Start from random bytes before the buffer is initialized as a hash set.
rng.fill(bytes.as_mut_slice());
{
let mut hs_2 = unsafe {
HashSetZeroCopy::from_bytes_zero_copy_init(&mut bytes, CAPACITY, SEQUENCE_THRESHOLD)
.unwrap()
};
for seq in 0..OPERATIONS {
let value = BigUint::from(Fr::rand(&mut rng));
hs_1.insert(&value, seq).unwrap();
hs_2.insert(&value, seq).unwrap();
}
assert_eq!(hs_1, *hs_2);
}
{
let hs_2 = unsafe { HashSet::from_bytes_copy(&mut bytes).unwrap() };
assert_eq!(hs_1, hs_2);
}
}
/// Byte-copy round trip: capacity 6857, sequence threshold 2400, 3600 operations.
#[test]
fn test_hash_set_from_bytes_copy_6857_2400_3600() {
    hash_set_from_bytes_copy::<6857, 2400, 3600>();
}
/// Byte-copy round trip: capacity 9601, sequence threshold 2400, 5000 operations.
#[test]
fn test_hash_set_from_bytes_copy_9601_2400_5000() {
    hash_set_from_bytes_copy::<9601, 2400, 5000>();
}
// Repeats 100 times: fill a set until insertion fails, then check that every
// later insertion failure is `HashSetError::Full`.
fn hash_set_full<const CAPACITY: usize, const SEQUENCE_THRESHOLD: usize>() {
for _ in 0..100 {
let mut hs = HashSet::new(CAPACITY, SEQUENCE_THRESHOLD).unwrap();
let mut rng = rand::thread_rng();
// Phase 1: insert and mark at sequence number 0 until the first failure.
for i in 0..CAPACITY {
let value = BigUint::from(Fr::rand(&mut rng));
match hs.insert(&value, 0) {
Ok(index) => hs.mark_with_sequence_number(index, 0).unwrap(),
Err(e) => {
assert!(matches!(e, HashSetError::Full));
println!("initial insertions: {i}: failed, stopping");
break;
}
}
}
// Phase 2: more insertions at sequence number 0; successes are only logged.
for i in 0..1000 {
let value = BigUint::from(Fr::rand(&mut rng));
let res = hs.insert(&value, 0);
if res.is_err() {
assert!(matches!(res, Err(HashSetError::Full)));
} else {
println!("secondary insertions: {i}: apparent success with value: {value:?}");
}
}
// Phase 3: insertions at random sequence numbers below the threshold.
for i in 0..1000 {
let value = BigUint::from(Fr::rand(&mut rng));
let sequence_number = rng.gen_range(0..hs.sequence_threshold);
let res = hs.insert(&value, sequence_number);
if res.is_err() {
assert!(matches!(res, Err(HashSetError::Full)));
} else {
println!("tertiary insertions: {i}: surprising success with value: {value:?}");
}
}
// Phase 4: insertions at sequence numbers from the threshold upwards.
for i in 0..CAPACITY {
let value = BigUint::from(Fr::rand(&mut rng));
if let Err(e) = hs.insert(&value, SEQUENCE_THRESHOLD + i) {
assert!(matches!(e, HashSetError::Full));
println!("insertions after fillup: {i}: failed, stopping");
break;
}
}
}
}
/// Fill-up scenario: capacity 6857, sequence threshold 2400.
#[test]
fn test_hash_set_full_6857_2400() {
    hash_set_full::<6857, 2400>();
}
/// Fill-up scenario: capacity 9601, sequence threshold 2400.
#[test]
fn test_hash_set_full_9601_2400() {
    hash_set_full::<9601, 2400>();
}
/// Marking a bucket of an empty set fails with `ElementDoesNotExist`, while
/// marking freshly inserted elements succeeds.
#[test]
fn test_hash_set_element_does_not_exist() {
    let mut hs = HashSet::new(4800, 2400).unwrap();
    let mut rng = thread_rng();

    // Nothing has been inserted yet, so every bucket is empty.
    for _ in 0..1000 {
        let index = rng.gen_range(0..4800);
        assert!(matches!(
            hs.mark_with_sequence_number(index, 0),
            Err(HashSetError::ElementDoesNotExist)
        ));
    }

    for _ in 0..1000 {
        let value = BigUint::from(Fr::rand(&mut rng));
        let index = hs.insert(&value, 0).unwrap();
        hs.mark_with_sequence_number(index, 1).unwrap();
    }
}
// Inserts ten fixed values and checks that `iter` yields them in the expected
// bucket order, which follows from the probing function rather than from the
// insertion order.
#[test]
fn test_hash_set_iter_manual() {
let mut hs = HashSet::new(6857, 2400).unwrap();
let nullifier_1 = 945635_u32.to_biguint().unwrap();
let nullifier_2 = 3546656654734254353455_u128.to_biguint().unwrap();
let nullifier_3 = 543543656564_u64.to_biguint().unwrap();
let nullifier_4 = 43_u8.to_biguint().unwrap();
let nullifier_5 = 0_u8.to_biguint().unwrap();
let nullifier_6 = 65423_u32.to_biguint().unwrap();
let nullifier_7 = 745654665_u32.to_biguint().unwrap();
let nullifier_8 = 97664353453465354645645465_u128.to_biguint().unwrap();
let nullifier_9 = 453565465464565635475_u128.to_biguint().unwrap();
let nullifier_10 = 543645654645_u64.to_biguint().unwrap();
hs.insert(&nullifier_1, 0).unwrap();
hs.insert(&nullifier_2, 0).unwrap();
hs.insert(&nullifier_3, 0).unwrap();
hs.insert(&nullifier_4, 0).unwrap();
hs.insert(&nullifier_5, 0).unwrap();
hs.insert(&nullifier_6, 0).unwrap();
hs.insert(&nullifier_7, 0).unwrap();
hs.insert(&nullifier_8, 0).unwrap();
hs.insert(&nullifier_9, 0).unwrap();
hs.insert(&nullifier_10, 0).unwrap();
let inserted_nullifiers = hs
.iter()
.map(|(_, nullifier)| nullifier.value_biguint())
.collect::<Vec<_>>();
assert_eq!(inserted_nullifiers.len(), 10);
// Expected order by ascending bucket index.
assert_eq!(inserted_nullifiers[0], nullifier_7);
assert_eq!(inserted_nullifiers[1], nullifier_3);
assert_eq!(inserted_nullifiers[2], nullifier_10);
assert_eq!(inserted_nullifiers[3], nullifier_1);
assert_eq!(inserted_nullifiers[4], nullifier_8);
assert_eq!(inserted_nullifiers[5], nullifier_5);
assert_eq!(inserted_nullifiers[6], nullifier_4);
assert_eq!(inserted_nullifiers[7], nullifier_2);
assert_eq!(inserted_nullifiers[8], nullifier_9);
assert_eq!(inserted_nullifiers[9], nullifier_6);
}
/// Inserts `INSERTIONS` random values and checks that iterating the set yields
/// exactly those values (compared after sorting both sides).
fn hash_set_iter_random<
    const INSERTIONS: usize,
    const CAPACITY: usize,
    const SEQUENCE_THRESHOLD: usize,
>() {
    let mut hs = HashSet::new(CAPACITY, SEQUENCE_THRESHOLD).unwrap();
    let mut rng = thread_rng();
    let nullifiers: [BigUint; INSERTIONS] =
        std::array::from_fn(|_| BigUint::from(Fr::rand(&mut rng)));
    for nullifier in &nullifiers {
        hs.insert(nullifier, 0).unwrap();
    }

    let mut expected: Vec<&BigUint> = nullifiers.iter().collect();
    expected.sort();

    let mut stored: Vec<BigUint> = hs.iter().map(|(_, cell)| cell.value_biguint()).collect();
    stored.sort();
    let stored_refs: Vec<&BigUint> = stored.iter().collect();

    assert_eq!(stored_refs.len(), INSERTIONS);
    assert_eq!(expected.as_slice(), stored_refs.as_slice());
}
/// Random iteration check: 3500 insertions, capacity 6857, sequence threshold 2400.
#[test]
fn test_hash_set_iter_random_6857_2400() {
    hash_set_iter_random::<3500, 6857, 2400>();
}
/// Random iteration check: 5000 insertions, capacity 9601, sequence threshold 2400.
#[test]
fn test_hash_set_iter_random_9601_2400() {
    hash_set_iter_random::<5000, 9601, 2400>();
}
/// `get_bucket` returns the inserted cells, empty buckets everywhere else and
/// `None` for indices past the capacity.
#[test]
fn test_hash_set_get_bucket() {
    let mut hs = HashSet::new(6857, 2400).unwrap();
    for i in 0..3600 {
        let bn_i = i.to_biguint().unwrap();
        hs.insert(&bn_i, i).unwrap();
    }

    // `true` while no inserted value has been found at that index.
    let mut unused_indices = vec![true; 6857];
    for i in 0..3600 {
        let bn_i = i.to_biguint().unwrap();
        let index = hs.find_element_index(&bn_i, None).unwrap().unwrap();
        let element = hs.get_bucket(index).unwrap().unwrap();
        assert_eq!(element.value_biguint(), bn_i);
        unused_indices[index] = false;
    }

    for (index, is_unused) in unused_indices.iter().enumerate() {
        if *is_unused {
            assert!(hs.get_bucket(index).unwrap().is_none());
        }
    }

    for index in 6857..10_000 {
        assert!(hs.get_bucket(index).is_none());
    }
}
/// `get_bucket_mut` allows overwriting inserted cells, leaves other buckets
/// empty and returns `None` for indices past the capacity.
#[test]
fn test_hash_set_get_bucket_mut() {
    let mut hs = HashSet::new(6857, 2400).unwrap();
    for i in 0..3600 {
        let bn_i = i.to_biguint().unwrap();
        hs.insert(&bn_i, i).unwrap();
    }

    // `true` once an inserted value has been found (and zeroed) at that index.
    let mut used_indices = vec![false; 6857];
    for i in 0..3600 {
        let bn_i = i.to_biguint().unwrap();
        let index = hs.find_element_index(&bn_i, None).unwrap().unwrap();
        let element = hs.get_bucket_mut(index).unwrap();
        assert_eq!(element.unwrap().value_biguint(), bn_i);
        used_indices[index] = true;
        *element = Some(HashSetCell {
            value: [0_u8; 32],
            sequence_number: None,
        });
    }

    for (index, is_used) in used_indices.iter().enumerate() {
        if *is_used {
            let element = hs.get_bucket_mut(index).unwrap().unwrap();
            assert_eq!(element.value_bytes(), [0_u8; 32]);
        }
    }
    for (index, is_used) in used_indices.iter().enumerate() {
        if !*is_used {
            assert!(hs.get_bucket_mut(index).unwrap().is_none());
        }
    }

    for index in 6857..10_000 {
        assert!(hs.get_bucket_mut(index).is_none());
    }
}
/// `get_unmarked_bucket` returns inserted elements until they are marked with
/// a sequence number.
#[test]
fn test_hash_set_get_unmarked_bucket() {
    let mut hs = HashSet::new(6857, 2400).unwrap();
    for i in 0..3600 {
        let bn_i = i.to_biguint().unwrap();
        hs.insert(&bn_i, i).unwrap();
    }

    // Before marking: every inserted element is returned.
    for i in 0..3600 {
        let index = hs
            .find_element_index(&i.to_biguint().unwrap(), None)
            .unwrap()
            .unwrap();
        assert!(hs.get_unmarked_bucket(index).is_some());
    }

    for i in 0..3600 {
        let index = hs
            .find_element_index(&i.to_biguint().unwrap(), None)
            .unwrap()
            .unwrap();
        hs.mark_with_sequence_number(index, i).unwrap();
    }

    // After marking: none of them is returned anymore.
    for i in 0..3600 {
        let index = hs
            .find_element_index(&i.to_biguint().unwrap(), None)
            .unwrap()
            .unwrap();
        assert!(hs.get_unmarked_bucket(index).is_none());
    }
}
/// After each insertion of `0..3600`, `first_no_seq` must still report the
/// cell holding the value 0.
#[test]
fn test_hash_set_first_no_seq() {
    let mut hs = HashSet::new(6857, 2400).unwrap();
    let zero = 0.to_biguint().unwrap();
    for i in 0..3600 {
        let bn_i = i.to_biguint().unwrap();
        hs.insert(&bn_i, i).unwrap();
        let (cell, _index) = hs.first_no_seq().unwrap().unwrap();
        assert_eq!(cell.value_biguint(), zero);
    }
}
}