use std::collections::VecDeque;
use crate::registration::Registration;
/// One storage cell of an [`OpStore`].
struct Slot {
    /// Current generation of this cell. An id resolves to this slot only while the
    /// generation packed into the id equals this value; it is normally advanced when
    /// the slot's entry is removed.
    generation: u32,
    /// The stored registration, or `None` while the cell is vacant.
    entry: Option<Registration>,
}
/// Decoded form of an [`OpStore`] id: a slot position plus the generation the slot
/// had when the id was issued.
#[derive(Copy, Clone, Debug)]
struct Index {
    /// Slot generation at issue time; occupies the high 32 bits when packed.
    generation: u32,
    /// Position in the store's slot vector; occupies the low 32 bits when packed.
    slot: u32,
}

impl Index {
    /// Packs this index into one `u64`: generation in the high half, slot in the low half.
    pub fn as_u64(&self) -> u64 {
        let high = u64::from(self.generation) << 32;
        let low = u64::from(self.slot);
        high | low
    }

    /// Inverse of [`Index::as_u64`]. Every `u64` decodes to some index.
    pub fn from_u64(packed: u64) -> Self {
        // `as u32` keeps only the low 32 bits, which is exactly the slot half.
        let generation = (packed >> 32) as u32;
        let slot = packed as u32;
        Self { generation, slot }
    }

    /// Slot position encoded in this index.
    pub fn slot(&self) -> u32 {
        self.slot
    }

    /// Generation encoded in this index.
    pub fn generation(&self) -> u32 {
        self.generation
    }
}
/// Fixed-capacity store of [`Registration`]s addressed by generational `u64` ids.
///
/// An id packs a slot index together with that slot's generation at insertion time,
/// so an id stops resolving once its entry is removed, even if the slot is reused.
pub struct OpStore {
    /// Backing cells, one per unit of capacity.
    slots: Vec<Slot>,
    /// Indices of vacated slots awaiting reuse, consumed oldest-first.
    free_list: VecDeque<u32>,
    /// Lowest slot index that has never been handed out.
    next_slot: u32,
    /// Maximum number of slots, and therefore of simultaneously live entries.
    capacity: u32,
}
/// Error returned by `OpStore::try_insert` when no free slot remains.
#[derive(Debug)]
pub struct StoreAtCapacity;

impl std::fmt::Display for StoreAtCapacity {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "StoreAtCapacity")
    }
}

impl std::error::Error for StoreAtCapacity {}
impl OpStore {
    /// Creates a store with room for 1024 live entries (test-only convenience).
    #[cfg(test)]
    pub fn new() -> OpStore {
        Self::with_capacity(1024)
    }

    /// Creates a store that can hold at most `cap` live entries at once.
    ///
    /// Slot indices are `u32`, so `cap` is clamped to `u32::MAX`. All slots are
    /// allocated up front.
    pub fn with_capacity(cap: usize) -> OpStore {
        let capacity = cap.min(u32::MAX as usize) as u32;
        let slots = (0..capacity)
            .map(|_| Slot { generation: 0, entry: None })
            .collect();
        Self {
            slots,
            // The free list never holds more than one entry per slot, so the
            // clamped `capacity` is enough. The raw `cap` could ask for far more.
            free_list: VecDeque::with_capacity(capacity as usize),
            next_slot: 0,
            capacity,
        }
    }

    /// Picks the slot for the next insertion, or returns `None` when the store is full.
    ///
    /// The oldest freed slot is reused first (FIFO), carrying its current generation.
    /// Only when no slot is free is a never-used slot handed out, at generation 0.
    fn next_id(&mut self) -> Option<Index> {
        if let Some(slot) = self.free_list.pop_front() {
            let generation = self.slots[slot as usize].generation;
            return Some(Index { slot, generation });
        }
        if self.next_slot < self.capacity {
            let slot = self.next_slot;
            self.next_slot += 1;
            return Some(Index { slot, generation: 0 });
        }
        None
    }

    /// Inserts `reg` and returns its packed id.
    ///
    /// # Panics
    /// Panics if the store is at capacity; use [`OpStore::try_insert`] to handle that.
    pub fn insert(&mut self, reg: Registration) -> u64 {
        self.try_insert(reg).expect("at capacity")
    }

    /// Inserts `reg` and returns its packed id.
    ///
    /// # Errors
    /// Returns [`StoreAtCapacity`] if no slot is free; `reg` is dropped in that case.
    pub fn try_insert(&mut self, reg: Registration) -> Result<u64, StoreAtCapacity> {
        let index = self.next_id().ok_or(StoreAtCapacity)?;
        let slot = &mut self.slots[index.slot() as usize];
        // Internal invariants: a slot handed out by `next_id` is vacant, and its
        // generation matches the id about to be returned.
        assert!(
            slot.entry.is_none(),
            "OpStore: slot {} should be empty",
            index.slot()
        );
        assert_eq!(
            slot.generation,
            index.generation(),
            "OpStore: generation mismatch"
        );
        slot.entry = Some(reg);
        Ok(index.as_u64())
    }

    /// Removes and drops the entry addressed by `id`.
    ///
    /// Returns `true` if an entry was removed. Returns `false` if `id` is out of range,
    /// stale (generation mismatch), or names a slot that is currently vacant.
    /// On success the slot's generation is advanced, so `id` no longer resolves.
    pub fn remove(&mut self, id: u64) -> bool {
        let index = Index::from_u64(id);
        let Some(slot) = self.raw_get_mut_slot(index) else {
            return false;
        };
        // A matching generation alone does not prove the id is live. Never-issued
        // slots sit at generation 0, and a freed slot already carries the generation
        // its next occupant will receive. Freeing a vacant slot would queue it twice
        // (or behind `next_slot`'s back), so a later insert would hit the invariant
        // asserts in `try_insert`.
        if slot.entry.take().is_none() {
            return false;
        }
        if let Some(next) = slot.generation.checked_add(1) {
            slot.generation = next;
            self.free_list.push_back(index.slot());
        }
        // Otherwise the generation space is exhausted. The slot is retired (never
        // re-queued), so no future id can alias one that was already issued.
        true
    }

    /// Returns a mutable reference to the entry for `id`, if `id` is current and occupied.
    pub fn get_mut(&mut self, id: u64) -> Option<&mut Registration> {
        self.raw_get_mut_slot(Index::from_u64(id))?.entry.as_mut()
    }

    /// Returns a shared reference to the entry for `id`, if `id` is current and occupied.
    pub fn get(&self, id: u64) -> Option<&Registration> {
        self.raw_get_slot(Index::from_u64(id))?.entry.as_ref()
    }

    /// Looks up the slot for `index` when it is in range and its generation matches.
    fn raw_get_slot(&self, index: Index) -> Option<&Slot> {
        self.slots
            .get(index.slot() as usize)
            .filter(|slot| slot.generation == index.generation())
    }

    /// Mutable counterpart of `raw_get_slot`.
    fn raw_get_mut_slot(&mut self, index: Index) -> Option<&mut Slot> {
        self.slots
            .get_mut(index.slot() as usize)
            .filter(|slot| slot.generation == index.generation())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds a registration around a waker that does nothing when woken.
    fn dummy_stored_op() -> Registration {
        // `Waker::noop` (std, Rust 1.85+) replaces a hand-rolled no-op
        // `RawWakerVTable`. This removes the only `unsafe` code in this module.
        Registration::new_waker(std::task::Waker::noop().clone())
    }

    #[test]
    fn test_basic_insert_and_remove() {
        let mut store = OpStore::new();
        let id = store.insert(dummy_stored_op());
        assert!(store.remove(id));
        // The entry is already gone, so a second removal must fail.
        assert!(!store.remove(id));
    }

    #[test]
    fn test_sequential_ids_are_unique() {
        let mut store = OpStore::new();
        let mut ids = HashSet::new();
        for _ in 0..1000 {
            let id = store.insert(dummy_stored_op());
            assert!(ids.insert(id), "Generated duplicate ID: {}", id);
        }
    }

    #[test]
    fn test_slot_reuse_increments_generation() {
        let mut store = OpStore::new();
        let id1 = store.insert(dummy_stored_op());
        let index1 = Index::from_u64(id1);
        assert_eq!(index1.generation(), 0);
        assert_eq!(index1.slot(), 0);
        assert!(store.remove(id1));
        let id2 = store.insert(dummy_stored_op());
        let index2 = Index::from_u64(id2);
        assert_eq!(index2.slot(), 0, "Slot should be reused");
        assert_eq!(index2.generation(), 1, "Generation should increment");
    }

    #[test]
    fn test_stale_id_rejected_on_remove() {
        let mut store = OpStore::new();
        let id1 = store.insert(dummy_stored_op());
        assert!(store.remove(id1));
        let id2 = store.insert(dummy_stored_op());
        assert!(!store.remove(id1), "Stale ID should be rejected");
        assert!(store.remove(id2));
    }

    #[test]
    fn test_stale_id_rejected_on_get_mut() {
        let mut store = OpStore::new();
        let id1 = store.insert(dummy_stored_op());
        assert!(store.remove(id1));
        let id2 = store.insert(dummy_stored_op());
        assert!(store.get_mut(id1).is_none(), "Stale ID should return None");
        assert!(store.get_mut(id2).is_some());
    }

    #[test]
    fn test_get_mut_works() {
        let mut store = OpStore::new();
        let id = store.insert(dummy_stored_op());
        assert!(store.get_mut(id).is_some());
    }

    #[test]
    fn test_get_works() {
        let mut store = OpStore::new();
        let id = store.insert(dummy_stored_op());
        assert!(store.get(id).is_some());
    }

    #[test]
    fn test_index_packing_unpacking() {
        let index = Index { slot: 42, generation: 123 };
        let unpacked = Index::from_u64(index.as_u64());
        assert_eq!(unpacked.slot(), 42);
        assert_eq!(unpacked.generation(), 123);
    }

    #[test]
    fn test_capacity_limit() {
        let mut store = OpStore::with_capacity(4);
        let ids: Vec<_> = (0..4).map(|_| store.insert(dummy_stored_op())).collect();
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            store.insert(dummy_stored_op());
        }));
        assert!(result.is_err(), "Should panic when over capacity");
        assert!(store.remove(ids[0]));
        // The freed slot must be handed out again.
        let reused = store.insert(dummy_stored_op());
        assert_eq!(Index::from_u64(reused).slot(), Index::from_u64(ids[0]).slot());
    }

    #[test]
    fn test_try_insert_reports_capacity() {
        let mut store = OpStore::with_capacity(2);
        assert!(store.try_insert(dummy_stored_op()).is_ok());
        assert!(store.try_insert(dummy_stored_op()).is_ok());
        assert!(store.try_insert(dummy_stored_op()).is_err());
    }

    #[test]
    fn test_out_of_range_id_is_rejected() {
        let mut store = OpStore::with_capacity(4);
        let bogus = Index { generation: 0, slot: 4 }.as_u64();
        assert!(store.get(bogus).is_none());
        assert!(store.get_mut(bogus).is_none());
        assert!(!store.remove(bogus));
    }

    #[test]
    fn test_aba_protection() {
        let mut store = OpStore::with_capacity(4);
        let id1 = store.insert(dummy_stored_op());
        assert!(store.remove(id1));
        let id2 = store.insert(dummy_stored_op());
        assert!(store.remove(id2));
        let id3 = store.insert(dummy_stored_op());
        let idx1 = Index::from_u64(id1);
        let idx2 = Index::from_u64(id2);
        let idx3 = Index::from_u64(id3);
        assert_eq!(idx1.slot(), 0);
        assert_eq!(idx2.slot(), 0);
        assert_eq!(idx3.slot(), 0);
        assert_eq!(idx1.generation(), 0);
        assert_eq!(idx2.generation(), 1);
        assert_eq!(idx3.generation(), 2);
        assert!(store.get(id1).is_none());
        assert!(store.get(id2).is_none());
        assert!(store.get(id3).is_some());
    }
}