light_hash_set/
zero_copy.rs1use std::{
2 marker::PhantomData,
3 mem,
4 ops::{Deref, DerefMut},
5 ptr::NonNull,
6};
7
8use crate::{HashSet, HashSetCell, HashSetError};
9
10#[derive(Debug)]
13pub struct HashSetZeroCopy<'a> {
14 pub hash_set: mem::ManuallyDrop<HashSet>,
15 _marker: PhantomData<&'a ()>,
16}
17
18impl<'a> HashSetZeroCopy<'a> {
19 pub unsafe fn from_bytes_zero_copy_mut(bytes: &'a mut [u8]) -> Result<Self, HashSetError> {
36 if bytes.len() < HashSet::non_dyn_fields_size() {
37 return Err(HashSetError::BufferSize(
38 HashSet::non_dyn_fields_size(),
39 bytes.len(),
40 ));
41 }
42
43 let capacity_values = usize::from_le_bytes(bytes[0..8].try_into().unwrap());
44 let sequence_threshold = usize::from_le_bytes(bytes[8..16].try_into().unwrap());
45
46 let offset = HashSet::non_dyn_fields_size() + mem::size_of::<usize>();
47
48 let values_size = mem::size_of::<Option<HashSetCell>>() * capacity_values;
49
50 let expected_size = HashSet::non_dyn_fields_size() + values_size;
51 if bytes.len() < expected_size {
52 return Err(HashSetError::BufferSize(expected_size, bytes.len()));
53 }
54
55 let buckets =
56 NonNull::new(bytes.as_mut_ptr().add(offset) as *mut Option<HashSetCell>).unwrap();
57
58 Ok(Self {
59 hash_set: mem::ManuallyDrop::new(HashSet {
60 capacity: capacity_values,
61 sequence_threshold,
62 buckets,
63 }),
64 _marker: PhantomData,
65 })
66 }
67
68 pub unsafe fn from_bytes_zero_copy_init(
97 bytes: &'a mut [u8],
98 capacity_values: usize,
99 sequence_threshold: usize,
100 ) -> Result<Self, HashSetError> {
101 if bytes.len() < HashSet::non_dyn_fields_size() {
102 return Err(HashSetError::BufferSize(
103 HashSet::non_dyn_fields_size(),
104 bytes.len(),
105 ));
106 }
107
108 bytes[0..8].copy_from_slice(&capacity_values.to_le_bytes());
109 bytes[8..16].copy_from_slice(&sequence_threshold.to_le_bytes());
110 bytes[16..24].copy_from_slice(&0_usize.to_le_bytes());
111
112 let hash_set = Self::from_bytes_zero_copy_mut(bytes)?;
113
114 for i in 0..capacity_values {
115 std::ptr::write(hash_set.hash_set.buckets.as_ptr().add(i), None);
116 }
117
118 Ok(hash_set)
119 }
120}
121
122impl Drop for HashSetZeroCopy<'_> {
123 fn drop(&mut self) {
124 }
135}
136
137impl Deref for HashSetZeroCopy<'_> {
138 type Target = HashSet;
139
140 fn deref(&self) -> &Self::Target {
141 &self.hash_set
142 }
143}
144
145impl DerefMut for HashSetZeroCopy<'_> {
146 fn deref_mut(&mut self) -> &mut Self::Target {
147 &mut self.hash_set
148 }
149}
150
#[cfg(test)]
mod test {
    use ark_bn254::Fr;
    use ark_ff::UniformRand;
    use num_bigint::BigUint;
    use rand::{thread_rng, Rng};

    use super::*;

    /// Round-trips a hash set through a single byte buffer:
    ///
    /// 1. Initialize it in place.
    /// 2. Reopen it zero-copy.
    /// 3. Load an owned copy with `HashSet::from_bytes_copy`.
    ///
    /// Each step checks that the state written by the previous one is visible.
    #[test]
    fn test_load_from_bytes() {
        const VALUES: usize = 4800;
        const SEQUENCE_THRESHOLD: usize = 2400;
        const NULLIFIERS: usize = 2400;

        let mut rng = thread_rng();

        // Start from random garbage so the test catches any reliance on a
        // zeroed buffer in `from_bytes_zero_copy_init`.
        let mut bytes = vec![0u8; HashSet::size_in_account(VALUES)];
        rng.fill(bytes.as_mut_slice());

        let nullifiers: [BigUint; NULLIFIERS] =
            std::array::from_fn(|_| BigUint::from(Fr::rand(&mut rng)));

        // Initialize the set and check every bucket is empty. Then insert each
        // nullifier and mark it with its own sequence number.
        {
            let mut hs = unsafe {
                HashSetZeroCopy::from_bytes_zero_copy_init(
                    bytes.as_mut_slice(),
                    VALUES,
                    SEQUENCE_THRESHOLD,
                )
                .unwrap()
            };

            assert_eq!(hs.hash_set.get_capacity(), VALUES);
            assert_eq!(hs.hash_set.sequence_threshold, SEQUENCE_THRESHOLD);
            for i in 0..VALUES {
                assert!(unsafe { &*hs.hash_set.buckets.as_ptr().add(i) }.is_none());
            }

            for (seq, nullifier) in nullifiers.iter().enumerate() {
                let index = hs.insert(nullifier, seq).unwrap();
                hs.mark_with_sequence_number(index, seq).unwrap();
            }
        }

        // Reopen the same buffer without re-initializing: the marked nullifiers
        // must still be found. Each is then inserted again with its sequence
        // number shifted by `SEQUENCE_THRESHOLD`.
        {
            let mut hs =
                unsafe { HashSetZeroCopy::from_bytes_zero_copy_mut(bytes.as_mut_slice()).unwrap() };

            for (seq, nullifier) in nullifiers.iter().enumerate() {
                assert!(hs.contains(nullifier, Some(seq)).unwrap());
            }

            for (seq, nullifier) in nullifiers.iter().enumerate() {
                hs.insert(nullifier, SEQUENCE_THRESHOLD + seq).unwrap();
            }
        }

        // An owned copy loaded from the buffer must reflect the re-insertions.
        {
            let hs = unsafe { HashSet::from_bytes_copy(bytes.as_mut_slice()).unwrap() };

            for (seq, nullifier) in nullifiers.iter().enumerate() {
                assert!(hs.contains(nullifier, Some(SEQUENCE_THRESHOLD + seq)).unwrap());
            }
        }
    }

    /// A 256-byte buffer cannot hold 4800 buckets, so initialization must
    /// fail with `HashSetError::BufferSize`.
    #[test]
    fn test_buffer_size_error() {
        const VALUES: usize = 4800;
        const SEQUENCE_THRESHOLD: usize = 2400;

        let mut invalid_bytes = vec![0_u8; 256];

        let res = unsafe {
            HashSetZeroCopy::from_bytes_zero_copy_init(
                invalid_bytes.as_mut_slice(),
                VALUES,
                SEQUENCE_THRESHOLD,
            )
        };
        assert!(matches!(res, Err(HashSetError::BufferSize(_, _))));
    }
}