light_hash_set/
zero_copy.rs

1use std::{
2    marker::PhantomData,
3    mem,
4    ops::{Deref, DerefMut},
5    ptr::NonNull,
6};
7
8use crate::{HashSet, HashSetCell, HashSetError};
9
/// A `HashSet` wrapper which can be instantiated from Solana account bytes
/// without copying them.
///
/// The wrapped set's `buckets` pointer refers directly into the byte slice
/// the wrapper was created from. The lifetime `'a` is the lifetime of that
/// mutable borrow, so the wrapper cannot outlive the buffer it points into.
#[derive(Debug)]
pub struct HashSetZeroCopy<'a> {
    /// The hash set whose storage lives in the borrowed buffer.
    ///
    /// `ManuallyDrop` prevents the destructor of `HashSet` from running
    /// automatically when the wrapper is dropped.
    pub hash_set: mem::ManuallyDrop<HashSet>,
    /// Zero-sized marker carrying the lifetime of the borrowed byte slice.
    _marker: PhantomData<&'a ()>,
}
17
18impl<'a> HashSetZeroCopy<'a> {
19    // TODO(vadorovsky): Add a non-mut method: `from_bytes_zero_copy`.
20
21    /// Casts a byte slice into `HashSet`.
22    ///
23    /// # Purpose
24    ///
25    /// This method is meant to be used mostly in Solana programs, where memory
26    /// constraints are tight and we want to make sure no data is copied.
27    ///
28    /// # Safety
29    ///
30    /// This is highly unsafe. Ensuring the alignment and that the slice
31    /// provides actual data of the hash set is the caller's responsibility.
32    ///
33    /// Calling it in async context (or anyhwere where the underlying data can
34    /// be moved in the memory) is certainly going to cause undefined behavior.
35    pub unsafe fn from_bytes_zero_copy_mut(bytes: &'a mut [u8]) -> Result<Self, HashSetError> {
36        if bytes.len() < HashSet::non_dyn_fields_size() {
37            return Err(HashSetError::BufferSize(
38                HashSet::non_dyn_fields_size(),
39                bytes.len(),
40            ));
41        }
42
43        let capacity_values = usize::from_le_bytes(bytes[0..8].try_into().unwrap());
44        let sequence_threshold = usize::from_le_bytes(bytes[8..16].try_into().unwrap());
45
46        let offset = HashSet::non_dyn_fields_size() + mem::size_of::<usize>();
47
48        let values_size = mem::size_of::<Option<HashSetCell>>() * capacity_values;
49
50        let expected_size = HashSet::non_dyn_fields_size() + values_size;
51        if bytes.len() < expected_size {
52            return Err(HashSetError::BufferSize(expected_size, bytes.len()));
53        }
54
55        let buckets =
56            NonNull::new(bytes.as_mut_ptr().add(offset) as *mut Option<HashSetCell>).unwrap();
57
58        Ok(Self {
59            hash_set: mem::ManuallyDrop::new(HashSet {
60                capacity: capacity_values,
61                sequence_threshold,
62                buckets,
63            }),
64            _marker: PhantomData,
65        })
66    }
67
68    /// Casts a byte slice into `HashSet` and then initializes it.
69    ///
70    /// * `bytes` is casted into a reference of `HashSet` and used as
71    ///   storage for the struct.
72    /// * `capacity_indices` indicates the size of the indices table. It should
73    ///   already include a desired load factor and be greater than the expected
74    ///   number of elements to avoid filling the set too early and avoid
75    ///   creating clusters.
76    /// * `capacity_values` indicates the size of the values array. It should be
77    ///   equal to the number of expected elements, without load factor.
78    /// * `sequence_threshold` indicates a difference of sequence numbers which
79    ///   make elements of the has set expired. Expiration means that they can
80    ///   be replaced during insertion of new elements with sequence numbers
81    ///   higher by at least a threshold.
82    ///
83    /// # Purpose
84    ///
85    /// This method is meant to be used mostly in Solana programs to initialize
86    /// a new account which is supposed to store the hash set.
87    ///
88    /// # Safety
89    ///
90    /// This is highly unsafe. Ensuring the alignment and that the slice has
91    /// a correct size, which is able to fit the hash set, is the caller's
92    /// responsibility.
93    ///
94    /// Calling it in async context (or anywhere where the underlying data can
95    /// be moved in memory) is certainly going to cause undefined behavior.
96    pub unsafe fn from_bytes_zero_copy_init(
97        bytes: &'a mut [u8],
98        capacity_values: usize,
99        sequence_threshold: usize,
100    ) -> Result<Self, HashSetError> {
101        if bytes.len() < HashSet::non_dyn_fields_size() {
102            return Err(HashSetError::BufferSize(
103                HashSet::non_dyn_fields_size(),
104                bytes.len(),
105            ));
106        }
107
108        bytes[0..8].copy_from_slice(&capacity_values.to_le_bytes());
109        bytes[8..16].copy_from_slice(&sequence_threshold.to_le_bytes());
110        bytes[16..24].copy_from_slice(&0_usize.to_le_bytes());
111
112        let hash_set = Self::from_bytes_zero_copy_mut(bytes)?;
113
114        for i in 0..capacity_values {
115            std::ptr::write(hash_set.hash_set.buckets.as_ptr().add(i), None);
116        }
117
118        Ok(hash_set)
119    }
120}
121
impl Drop for HashSetZeroCopy<'_> {
    fn drop(&mut self) {
        // SAFETY: Don't do anything here! Why?
        //
        // * Primitive fields of `HashSet` implement `Copy`, therefore `drop()`
        //   has no effect on them - Rust drops them when they go out of scope.
        // * Don't drop the dynamic field (`buckets`). In `HashSetZeroCopy`, it
        //   is backed by a buffer provided by the caller. That buffer is
        //   going to be eventually deallocated by its owner. Performing
        //   another `drop()` here would result in a double `free()`, which
        //   would abort the program (either with `SIGABRT` or `SIGSEGV`).
    }
}
136
137impl Deref for HashSetZeroCopy<'_> {
138    type Target = HashSet;
139
140    fn deref(&self) -> &Self::Target {
141        &self.hash_set
142    }
143}
144
145impl DerefMut for HashSetZeroCopy<'_> {
146    fn deref_mut(&mut self) -> &mut Self::Target {
147        &mut self.hash_set
148    }
149}
150
#[cfg(test)]
mod test {
    use ark_bn254::Fr;
    use ark_ff::UniformRand;
    use num_bigint::BigUint;
    use rand::{thread_rng, Rng};

    use super::*;

    /// Initializes a hash set on top of a random byte buffer, fills it,
    /// re-opens it without copying, and finally loads a copy of it, checking
    /// the contents at every stage.
    #[test]
    fn test_load_from_bytes() {
        const VALUES: usize = 4800;
        const SEQUENCE_THRESHOLD: usize = 2400;
        const NULLIFIERS: usize = 2400;

        // Start from random garbage so that the initialization has to
        // overwrite every bucket.
        let mut account = vec![0u8; HashSet::size_in_account(VALUES)];
        thread_rng().fill(account.as_mut_slice());

        // Random elements to be inserted.
        let mut rng = thread_rng();
        let nullifiers: Vec<BigUint> = (0..NULLIFIERS)
            .map(|_| BigUint::from(Fr::rand(&mut rng)))
            .collect();

        // Stage 1: initialize the set in place and insert all elements.
        {
            let init_result = unsafe {
                HashSetZeroCopy::from_bytes_zero_copy_init(
                    account.as_mut_slice(),
                    VALUES,
                    SEQUENCE_THRESHOLD,
                )
            };
            let mut hs = init_result.unwrap();

            // The header and every bucket must have been initialized.
            assert_eq!(hs.hash_set.get_capacity(), VALUES);
            assert_eq!(hs.hash_set.sequence_threshold, SEQUENCE_THRESHOLD);
            let first_bucket = hs.hash_set.buckets.as_ptr();
            for offset in 0..VALUES {
                let bucket = unsafe { &*first_bucket.add(offset) };
                assert!(bucket.is_none());
            }

            for (seq, nullifier) in nullifiers.iter().enumerate() {
                let index = hs.insert(nullifier, seq).unwrap();
                hs.mark_with_sequence_number(index, seq).unwrap();
            }
        }

        // Stage 2: re-open the same bytes without copying.
        {
            let reopened =
                unsafe { HashSetZeroCopy::from_bytes_zero_copy_mut(account.as_mut_slice()) };
            let mut hs = reopened.unwrap();

            for (seq, nullifier) in nullifiers.iter().enumerate() {
                assert!(hs.contains(nullifier, Some(seq)).unwrap());
            }

            for (seq, nullifier) in nullifiers.iter().enumerate() {
                hs.insert(nullifier, 2400 + seq).unwrap();
            }
            drop(hs);
        }

        // Stage 3: load an owned copy from the same bytes.
        {
            let copied = unsafe { HashSet::from_bytes_copy(account.as_mut_slice()) };
            let hs = copied.unwrap();

            for (seq, nullifier) in nullifiers.iter().enumerate() {
                assert!(hs.contains(nullifier, Some(2400 + seq)).unwrap());
            }
        }
    }

    /// A buffer far too small for the requested capacity must be rejected
    /// with `HashSetError::BufferSize`.
    #[test]
    fn test_buffer_size_error() {
        const VALUES: usize = 4800;
        const SEQUENCE_THRESHOLD: usize = 2400;

        let mut too_small = vec![0_u8; 256];

        let outcome = unsafe {
            HashSetZeroCopy::from_bytes_zero_copy_init(
                too_small.as_mut_slice(),
                VALUES,
                SEQUENCE_THRESHOLD,
            )
        };
        assert!(matches!(outcome, Err(HashSetError::BufferSize(_, _))));
    }
}
239}