// hopper_core/collections/slot_map.rs
use hopper_runtime::error::ProgramError;

use crate::account::{FixedLayout, Pod};
/// Size of the map header in bytes.
///
/// Bytes `0..4` hold the live-entry count as a little-endian `u32`; bytes
/// `4..8` are not read or written by `SlotMap`.
const MAP_HEADER: usize = 4 + 4;

/// Bookkeeping bytes placed in front of every stored value.
///
/// Bytes `0..4` hold the slot's generation as a little-endian `u32`, byte `4`
/// is the occupied flag, and bytes `5..8` are not touched by `SlotMap`.
const SLOT_OVERHEAD: usize = 4 + 1 + 3;
24#[derive(Clone, Copy, PartialEq, Eq)]
27#[repr(C)]
28pub struct SlotKey {
29 pub index: u32,
30 pub generation: u32,
31}
32
33const _: () = assert!(core::mem::size_of::<SlotKey>() == 8);
34const _: () = assert!(core::mem::align_of::<SlotKey>() == 4); pub struct SlotMap<'a, T: Pod + FixedLayout> {
42 data: &'a mut [u8],
43 _phantom: core::marker::PhantomData<T>,
44}
45
46impl<'a, T: Pod + FixedLayout> SlotMap<'a, T> {
47 const SLOT_SIZE: usize = SLOT_OVERHEAD + T::SIZE;
49
50 #[inline]
52 pub fn from_bytes(data: &'a mut [u8]) -> Result<Self, ProgramError> {
53 if data.len() < MAP_HEADER {
54 return Err(ProgramError::AccountDataTooSmall);
55 }
56 Ok(Self {
57 data,
58 _phantom: core::marker::PhantomData,
59 })
60 }
61
62 #[inline(always)]
64 pub fn capacity(&self) -> usize {
65 (self.data.len() - MAP_HEADER) / Self::SLOT_SIZE
66 }
67
68 #[inline(always)]
70 pub fn count(&self) -> usize {
71 u32::from_le_bytes([self.data[0], self.data[1], self.data[2], self.data[3]]) as usize
72 }
73
74 #[inline(always)]
76 fn set_count(&mut self, count: usize) {
77 self.data[0..4].copy_from_slice(&(count as u32).to_le_bytes());
78 }
79
80 #[inline(always)]
82 fn slot_offset(&self, index: usize) -> usize {
83 MAP_HEADER + index * Self::SLOT_SIZE
84 }
85
86 #[inline(always)]
88 fn slot_generation(&self, index: usize) -> u32 {
89 let off = self.slot_offset(index);
90 u32::from_le_bytes([
91 self.data[off],
92 self.data[off + 1],
93 self.data[off + 2],
94 self.data[off + 3],
95 ])
96 }
97
98 #[inline(always)]
100 fn slot_occupied(&self, index: usize) -> bool {
101 let off = self.slot_offset(index) + 4;
102 self.data[off] != 0
103 }
104
105 #[inline]
109 pub fn insert(&mut self, value: T) -> Result<SlotKey, ProgramError> {
110 let cap = self.capacity();
111 for i in 0..cap {
112 if !self.slot_occupied(i) {
113 let off = self.slot_offset(i);
114 let gen = self.slot_generation(i);
115 self.data[off + 4] = 1;
117 let val_off = off + SLOT_OVERHEAD;
119 unsafe {
121 core::ptr::write_unaligned(
122 self.data.as_mut_ptr().add(val_off) as *mut T,
123 value,
124 );
125 }
126 self.set_count(self.count() + 1);
127 return Ok(SlotKey {
128 index: i as u32,
129 generation: gen,
130 });
131 }
132 }
133 Err(ProgramError::AccountDataTooSmall)
134 }
135
136 #[inline]
138 pub fn get(&self, key: SlotKey) -> Result<T, ProgramError> {
139 let index = key.index as usize;
140 if index >= self.capacity() {
141 return Err(ProgramError::InvalidArgument);
142 }
143 if !self.slot_occupied(index) || self.slot_generation(index) != key.generation {
144 return Err(ProgramError::InvalidArgument);
145 }
146 let off = self.slot_offset(index) + SLOT_OVERHEAD;
147 Ok(unsafe { core::ptr::read_unaligned(self.data.as_ptr().add(off) as *const T) })
149 }
150
151 #[inline]
153 pub fn remove(&mut self, key: SlotKey) -> Result<T, ProgramError> {
154 let index = key.index as usize;
155 if index >= self.capacity() {
156 return Err(ProgramError::InvalidArgument);
157 }
158 if !self.slot_occupied(index) || self.slot_generation(index) != key.generation {
159 return Err(ProgramError::InvalidArgument);
160 }
161 let off = self.slot_offset(index);
162 let val_off = off + SLOT_OVERHEAD;
163 let value =
165 unsafe { core::ptr::read_unaligned(self.data.as_ptr().add(val_off) as *const T) };
166 self.data[off + 4] = 0;
168 let new_gen = self.slot_generation(index).wrapping_add(1);
170 self.data[off..off + 4].copy_from_slice(&new_gen.to_le_bytes());
171 for byte in &mut self.data[val_off..val_off + T::SIZE] {
173 *byte = 0;
174 }
175 self.set_count(self.count() - 1);
176 Ok(value)
177 }
178
179 #[inline(always)]
181 pub const fn required_bytes(capacity: usize) -> usize {
182 MAP_HEADER + capacity * (SLOT_OVERHEAD + T::SIZE)
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::abi::WireU64;

    /// Basic round trip: insert two values, read them back, remove one.
    #[test]
    fn insert_get_remove() {
        let mut buf = [0u8; 8 + (8 + 8) * 4];
        let mut map = SlotMap::<WireU64>::from_bytes(&mut buf).unwrap();

        let k1 = map.insert(WireU64::new(100)).unwrap();
        let k2 = map.insert(WireU64::new(200)).unwrap();
        assert_eq!(map.count(), 2);

        assert_eq!(map.get(k1).unwrap().get(), 100);
        assert_eq!(map.get(k2).unwrap().get(), 200);

        let removed = map.remove(k1).unwrap();
        assert_eq!(removed.get(), 100);
        assert_eq!(map.count(), 1);

        assert!(map.get(k1).is_err());
    }

    /// A reused slot gets a new generation, so the stale key is rejected.
    #[test]
    fn generation_prevents_aba() {
        let mut buf = [0u8; 8 + (8 + 8) * 2];
        let mut map = SlotMap::<WireU64>::from_bytes(&mut buf).unwrap();

        let k1 = map.insert(WireU64::new(1)).unwrap();
        map.remove(k1).unwrap();

        let k2 = map.insert(WireU64::new(2)).unwrap();
        assert_eq!(k2.index, k1.index);
        assert_ne!(k2.generation, k1.generation);
        assert!(map.get(k1).is_err());
        assert_eq!(map.get(k2).unwrap().get(), 2);
    }

    /// A buffer shorter than the header cannot be wrapped.
    #[test]
    fn rejects_buffer_smaller_than_header() {
        let mut buf = [0u8; 4];
        assert!(SlotMap::<WireU64>::from_bytes(&mut buf).is_err());
    }

    /// Inserting into a full map fails and leaves the count unchanged.
    #[test]
    fn insert_fails_when_full() {
        let mut buf = [0u8; 8 + (8 + 8) * 2];
        let mut map = SlotMap::<WireU64>::from_bytes(&mut buf).unwrap();
        assert_eq!(map.capacity(), 2);

        map.insert(WireU64::new(1)).unwrap();
        map.insert(WireU64::new(2)).unwrap();
        assert!(map.insert(WireU64::new(3)).is_err());
        assert_eq!(map.count(), 2);
    }

    /// Out-of-range keys and repeated removals are rejected.
    #[test]
    fn out_of_range_and_double_remove_are_rejected() {
        let mut buf = [0u8; 8 + (8 + 8) * 2];
        let mut map = SlotMap::<WireU64>::from_bytes(&mut buf).unwrap();

        let bogus = SlotKey {
            index: 2,
            generation: 0,
        };
        assert!(map.get(bogus).is_err());
        assert!(map.remove(bogus).is_err());

        let key = map.insert(WireU64::new(9)).unwrap();
        map.remove(key).unwrap();
        assert!(map.remove(key).is_err());
        assert_eq!(map.count(), 0);
    }

    /// `required_bytes` matches the header-plus-slots layout the tests use.
    #[test]
    fn required_bytes_matches_layout() {
        assert_eq!(SlotMap::<WireU64>::required_bytes(0), 8);
        assert_eq!(SlotMap::<WireU64>::required_bytes(4), 8 + (8 + 8) * 4);
    }
}