use crate::account::{FixedLayout, Pod};
use hopper_runtime::error::ProgramError;
/// Bytes reserved at the start of the buffer for the map header. This module
/// stores the occupied-slot count as a little-endian `u32` in bytes `0..4`;
/// bytes `4..8` are never read or written here.
const MAP_HEADER: usize = 8;
/// Per-slot metadata that precedes each stored value: a little-endian `u32`
/// generation in bytes `0..4` and an occupied flag in byte `4` (non-zero means
/// occupied). Bytes `5..8` are never read or written here.
const SLOT_OVERHEAD: usize = 8;
/// Handle to a value stored in a `SlotMap`.
///
/// A key is accepted only while the slot at `index` is occupied and the slot's
/// stored generation still equals `generation`. Removing a value bumps the slot's
/// generation, so keys issued for earlier occupants of the same slot are rejected
/// (until the 32-bit generation wraps around).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(C)]
pub struct SlotKey {
    /// Position of the slot within the map.
    pub index: u32,
    /// Generation of the slot at the time this key was issued.
    pub generation: u32,
}
// `repr(C)` with two `u32` fields: exactly 8 bytes, 4-byte aligned, no padding.
const _: () = assert!(core::mem::size_of::<SlotKey>() == 8);
const _: () = assert!(core::mem::align_of::<SlotKey>() == 4);
/// Fixed-capacity slot map stored in place over a borrowed mutable byte buffer.
///
/// Layout: a `MAP_HEADER`-byte header followed by as many whole slots as fit,
/// each `SLOT_OVERHEAD + T::SIZE` bytes (metadata, then the value bytes). Values
/// are copied in and out with unaligned reads/writes, so the buffer needs no
/// particular alignment.
pub struct SlotMap<'a, T: Pod + FixedLayout> {
// Backing storage: header followed by the fixed-size slots.
data: &'a mut [u8],
// Ties the element type `T` to the map without storing a `T`.
_phantom: core::marker::PhantomData<T>,
}
impl<'a, T: Pod + FixedLayout> SlotMap<'a, T> {
    /// Size in bytes of one slot: per-slot metadata followed by the value bytes.
    const SLOT_SIZE: usize = SLOT_OVERHEAD + T::SIZE;

    /// Interprets `data` as a slot map in place.
    ///
    /// Only the buffer length is checked; existing header and slot bytes are used
    /// as-is (an all-zero buffer is an empty map).
    ///
    /// # Errors
    ///
    /// Returns `ProgramError::AccountDataTooSmall` if `data` is shorter than the
    /// map header.
    #[inline]
    pub fn from_bytes(data: &'a mut [u8]) -> Result<Self, ProgramError> {
        if data.len() < MAP_HEADER {
            return Err(ProgramError::AccountDataTooSmall);
        }
        Ok(Self {
            data,
            _phantom: core::marker::PhantomData,
        })
    }

    /// Number of whole slots that fit after the header. Trailing bytes too short
    /// to hold a full slot are ignored.
    #[inline(always)]
    pub fn capacity(&self) -> usize {
        (self.data.len() - MAP_HEADER) / Self::SLOT_SIZE
    }

    /// Number of occupied slots as recorded in the header. It is maintained by
    /// `insert`/`remove` and is not recomputed from the slot flags.
    #[inline(always)]
    pub fn count(&self) -> usize {
        u32::from_le_bytes([self.data[0], self.data[1], self.data[2], self.data[3]]) as usize
    }

    /// Stores `count` in the header, clamping to `u32::MAX` rather than
    /// silently truncating.
    #[inline(always)]
    fn set_count(&mut self, count: usize) {
        let stored = count.min(u32::MAX as usize) as u32;
        self.data[0..4].copy_from_slice(&stored.to_le_bytes());
    }

    /// Byte offset of slot `index`. The caller must ensure `index < capacity()`.
    #[inline(always)]
    fn slot_offset(&self, index: usize) -> usize {
        MAP_HEADER + index * Self::SLOT_SIZE
    }

    /// Generation stored in slot `index`'s metadata.
    #[inline(always)]
    fn slot_generation(&self, index: usize) -> u32 {
        let off = self.slot_offset(index);
        u32::from_le_bytes([
            self.data[off],
            self.data[off + 1],
            self.data[off + 2],
            self.data[off + 3],
        ])
    }

    /// Whether slot `index`'s occupied flag is set.
    #[inline(always)]
    fn slot_occupied(&self, index: usize) -> bool {
        self.data[self.slot_offset(index) + 4] != 0
    }

    /// Resolves `key` to a slot index if it names a live value.
    ///
    /// # Errors
    ///
    /// Returns `ProgramError::InvalidArgument` in three cases: the index is out of
    /// range, the slot is free, or the slot's generation no longer matches.
    #[inline(always)]
    fn live_index(&self, key: SlotKey) -> Result<usize, ProgramError> {
        let index = key.index as usize;
        // The range check must run first so the slot accessors stay in bounds.
        if index >= self.capacity()
            || !self.slot_occupied(index)
            || self.slot_generation(index) != key.generation
        {
            return Err(ProgramError::InvalidArgument);
        }
        Ok(index)
    }

    /// Copies the value out of slot `index`. The caller must ensure
    /// `index < capacity()`.
    #[inline(always)]
    fn read_value(&self, index: usize) -> T {
        let start = self.slot_offset(index) + SLOT_OVERHEAD;
        let src = &self.data[start..start + T::SIZE];
        debug_assert!(core::mem::size_of::<T>() <= src.len());
        // SAFETY: `src` is an in-bounds region of `T::SIZE` bytes. That covers the
        // `size_of::<T>()` bytes read (checked above in debug builds).
        // `read_unaligned` has no alignment requirement. `T: Pod` is assumed to
        // accept any bit pattern, including the zeroes that `remove` leaves behind.
        unsafe { core::ptr::read_unaligned(src.as_ptr().cast::<T>()) }
    }

    /// Stores `value` in the lowest-indexed free slot and returns its key.
    ///
    /// Slots are scanned from index 0, so insertion is O(capacity).
    ///
    /// # Errors
    ///
    /// Returns `ProgramError::AccountDataTooSmall` when every slot is occupied.
    #[inline]
    pub fn insert(&mut self, value: T) -> Result<SlotKey, ProgramError> {
        let index = (0..self.capacity())
            .find(|&i| !self.slot_occupied(i))
            .ok_or(ProgramError::AccountDataTooSmall)?;
        let off = self.slot_offset(index);
        // A free slot keeps the generation that `remove` bumped. The new key
        // therefore differs from keys issued for the slot's earlier occupants.
        let generation = self.slot_generation(index);
        self.data[off + 4] = 1;
        let start = off + SLOT_OVERHEAD;
        let dst = &mut self.data[start..start + T::SIZE];
        debug_assert!(core::mem::size_of::<T>() <= dst.len());
        // SAFETY: `dst` is an in-bounds region of `T::SIZE` bytes. That covers the
        // `size_of::<T>()` bytes written (checked above in debug builds).
        // `write_unaligned` has no alignment requirement.
        unsafe { core::ptr::write_unaligned(dst.as_mut_ptr().cast::<T>(), value) };
        self.set_count(self.count().saturating_add(1));
        Ok(SlotKey {
            // NOTE(review): truncates if capacity ever exceeds u32::MAX slots.
            index: index as u32,
            generation,
        })
    }

    /// Returns a copy of the value referenced by `key`.
    ///
    /// # Errors
    ///
    /// Returns `ProgramError::InvalidArgument` if `key` is out of range, refers to a
    /// free slot, or is stale (its generation no longer matches).
    #[inline]
    pub fn get(&self, key: SlotKey) -> Result<T, ProgramError> {
        let index = self.live_index(key)?;
        Ok(self.read_value(index))
    }

    /// Removes and returns the value referenced by `key`.
    ///
    /// This clears the occupied flag, bumps the slot's generation (wrapping) so
    /// `key` becomes stale, and zeroes the value bytes.
    ///
    /// # Errors
    ///
    /// Returns `ProgramError::InvalidArgument` under the same conditions as `get`.
    #[inline]
    pub fn remove(&mut self, key: SlotKey) -> Result<T, ProgramError> {
        let index = self.live_index(key)?;
        let value = self.read_value(index);
        let off = self.slot_offset(index);
        self.data[off + 4] = 0;
        // `live_index` verified that `key.generation` equals the stored generation.
        let next_generation = key.generation.wrapping_add(1);
        self.data[off..off + 4].copy_from_slice(&next_generation.to_le_bytes());
        let start = off + SLOT_OVERHEAD;
        self.data[start..start + T::SIZE].fill(0);
        // Saturate so an inconsistent header (count 0 while a slot is occupied)
        // cannot panic or wrap around to u32::MAX.
        self.set_count(self.count().saturating_sub(1));
        Ok(value)
    }

    /// Buffer length needed for a map with room for `capacity` values.
    #[inline(always)]
    pub const fn required_bytes(capacity: usize) -> usize {
        MAP_HEADER + capacity * Self::SLOT_SIZE
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::abi::WireU64;

    /// Buffer length for `slots` `WireU64` entries: 8-byte header plus 16-byte slots.
    const fn buf_len(slots: usize) -> usize {
        8 + (8 + 8) * slots
    }

    #[test]
    fn insert_get_remove() {
        let mut storage = [0u8; buf_len(4)];
        let mut slots = SlotMap::<WireU64>::from_bytes(&mut storage).unwrap();

        let first = slots.insert(WireU64::new(100)).unwrap();
        let second = slots.insert(WireU64::new(200)).unwrap();
        assert_eq!(slots.count(), 2);
        assert_eq!(slots.get(first).unwrap().get(), 100);
        assert_eq!(slots.get(second).unwrap().get(), 200);

        // Removing hands back the stored value and invalidates its key.
        assert_eq!(slots.remove(first).unwrap().get(), 100);
        assert_eq!(slots.count(), 1);
        assert!(slots.get(first).is_err());
    }

    #[test]
    fn generation_prevents_aba() {
        let mut storage = [0u8; buf_len(2)];
        let mut slots = SlotMap::<WireU64>::from_bytes(&mut storage).unwrap();

        let stale = slots.insert(WireU64::new(1)).unwrap();
        slots.remove(stale).unwrap();
        let fresh = slots.insert(WireU64::new(2)).unwrap();

        // The slot is reused, but the generation differs, so the old key is rejected.
        assert_eq!(fresh.index, stale.index);
        assert_ne!(fresh.generation, stale.generation);
        assert!(slots.get(stale).is_err());
        assert_eq!(slots.get(fresh).unwrap().get(), 2);
    }
}