use crate::error::{ErrorCode, ExternError};
use crate::into_ffi::IntoFfi;
use std::error::Error as StdError;
use std::fmt;
use std::ops;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Mutex, RwLock};
/// A slot arena that hands out versioned [`Handle`]s instead of references.
///
/// Unoccupied slots form an intrusive singly linked free list whose head is
/// `first_free`; exactly one slot terminates that list.
#[derive(Clone, Debug)]
pub struct HandleMap<T> {
    // Identifier copied into every handle this map issues; handles carrying a
    // different id are rejected.
    id: u16,
    // Index of the slot at the head of the free list.
    first_free: u16,
    // Number of slots currently holding a value.
    num_entries: usize,
    // Slot storage; its length is the map's capacity.
    entries: Vec<Entry<T>>,
}
/// One slot of a `HandleMap`.
#[derive(Clone, Debug)]
struct Entry<T> {
    // Slot generation: starts at 1 and is bumped on every insert and remove,
    // so it is odd while the slot is free and even while it is occupied.
    version: u16,
    // Either the stored value or this slot's free-list link.
    state: EntryState<T>,
}
/// Contents of a slot: a live value, or a link in the free list.
#[derive(Clone, Debug)]
enum EntryState<T> {
    /// The slot holds a value.
    Active(T),
    /// The slot is free; the payload is the index of the next free slot.
    InFreeList(u16),
    /// The slot is free and is the last node of the free list.
    EndOfFreeList,
}
impl<T> EntryState<T> {
    /// True if this slot is the terminal node of the free list.
    /// Only compiled for the debug/test consistency checks.
    #[cfg(any(debug_assertions, test))]
    fn is_end_of_list(&self) -> bool {
        matches!(self, EntryState::EndOfFreeList)
    }

    /// True if this slot currently holds a value.
    #[inline]
    fn is_occupied(&self) -> bool {
        matches!(self, EntryState::Active(_))
    }

    /// Borrows the stored value, or `None` for a free slot.
    #[inline]
    fn get_item(&self) -> Option<&T> {
        if let EntryState::Active(item) = self {
            Some(item)
        } else {
            None
        }
    }

    /// Mutably borrows the stored value, or `None` for a free slot.
    #[inline]
    fn get_item_mut(&mut self) -> Option<&mut T> {
        if let EntryState::Active(item) = self {
            Some(item)
        } else {
            None
        }
    }
}
/// Narrows a slot index to `u16`, the width used for indices inside handles
/// and free-list links.
///
/// # Panics
///
/// Panics if `v` exceeds `u16::MAX`. Callers in this module only pass values
/// below `MAX_CAPACITY`, so a failure indicates a bug here.
#[inline]
fn to_u16(v: usize) -> u16 {
    assert!(v <= usize::from(u16::MAX), "Bug: Doesn't fit in u16: {}", v);
    v as u16
}
/// Upper bound on the number of slots a `HandleMap` may have (`2^15 - 1`).
pub const MAX_CAPACITY: usize = 0x7FFF;
/// Smallest slot count a `HandleMap` is created with.
const MIN_CAPACITY: usize = 4;
/// Reasons a [`Handle`] (or its `u64` encoding) can be rejected.
///
/// Variant order is significant because `Ord` is derived.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum HandleError {
    /// The encoded handle was `0`.
    NullHandle,
    /// The `u64` did not carry the handle magic or had an odd version.
    InvalidHandle,
    /// The slot's version no longer matches (value removed or slot reused).
    StaleVersion,
    /// The slot index is beyond the map's capacity.
    IndexPastEnd,
    /// The handle was issued by a different map.
    WrongMap,
}
impl StdError for HandleError {}
impl fmt::Display for HandleError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use HandleError::*;
match self {
NullHandle => {
f.write_str("Tried to use a null handle (this object has probably been closed)")
}
InvalidHandle => f.write_str("u64 could not encode a valid Handle"),
StaleVersion => f.write_str("Handle has stale version number"),
IndexPastEnd => f.write_str("Handle references a index past the end of this HandleMap"),
WrongMap => f.write_str("Handle is from a different map"),
}
}
}
impl From<HandleError> for ExternError {
    /// Reports any handle failure as `ErrorCode::INVALID_HANDLE`, using the
    /// error's `Display` text as the message.
    fn from(err: HandleError) -> Self {
        let message = err.to_string();
        Self::new_error(ErrorCode::INVALID_HANDLE, message)
    }
}
impl<T> HandleMap<T> {
    /// Creates an empty map with `MIN_CAPACITY` slots.
    pub fn new() -> Self {
        Self::new_with_capacity(MIN_CAPACITY)
    }

    /// Creates an empty map with `request` slots (raised to `MIN_CAPACITY`
    /// if smaller).
    ///
    /// The map gets a fresh id from `next_handle_map_id`. That id is copied
    /// into every handle the map issues, so handles from other maps are
    /// rejected. Every slot starts at version 1 (odd, i.e. free), and the
    /// slots are chained into the free list `0 -> 1 -> ... -> capacity - 1`.
    ///
    /// # Panics
    ///
    /// Panics if `request` is larger than [`MAX_CAPACITY`].
    pub fn new_with_capacity(request: usize) -> Self {
        assert!(
            request <= MAX_CAPACITY,
            "HandleMap capacity is limited to {} (request was {})",
            MAX_CAPACITY,
            request
        );
        let capacity = request.max(MIN_CAPACITY);
        let id = next_handle_map_id();
        let mut entries = Vec::with_capacity(capacity);
        // Slot i links to slot i + 1 ...
        for i in 0..(capacity - 1) {
            entries.push(Entry {
                version: 1,
                state: EntryState::InFreeList(to_u16(i + 1)),
            });
        }
        // ... and the final slot terminates the free list.
        entries.push(Entry {
            version: 1,
            state: EntryState::EndOfFreeList,
        });
        Self {
            id,
            first_free: 0,
            num_entries: 0,
            entries,
        }
    }

    /// Number of slots currently holding a value.
    #[inline]
    pub fn len(&self) -> usize {
        self.num_entries
    }

    /// True if no slot holds a value.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total number of slots, occupied or free.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.entries.len()
    }

    /// Grows the slot vector so that more than `cap_at_least` slots exist.
    ///
    /// The target size doubles until it exceeds `cap_at_least` and is then
    /// clamped to `MAX_CAPACITY`. The vector is filled to one less than that
    /// target. New slots are pushed onto the front of the free list one at
    /// a time. The `EndOfFreeList` slot is never handed out by `insert`, so
    /// at most `capacity() - 1` values can be stored.
    ///
    /// # Panics
    ///
    /// Panics with "HandleMap overfilled" if the map cannot provide more than
    /// `cap_at_least` slots.
    fn ensure_capacity(&mut self, cap_at_least: usize) {
        assert_ne!(self.len(), self.capacity(), "Bug: should have grown by now");
        assert!(cap_at_least <= MAX_CAPACITY, "HandleMap overfilled");
        if self.capacity() > cap_at_least {
            return;
        }
        let mut next_cap = self.capacity();
        while next_cap <= cap_at_least {
            next_cap *= 2;
        }
        next_cap = next_cap.min(MAX_CAPACITY);
        let need_extra = next_cap.saturating_sub(self.entries.capacity());
        self.entries.reserve(need_extra);
        assert!(
            !self.entries[self.first_free as usize].state.is_occupied(),
            "Bug: HandleMap.first_free points at occupied index"
        );
        while self.entries.len() < next_cap - 1 {
            // Each new slot becomes the new head of the free list.
            self.entries.push(Entry {
                version: 1,
                state: EntryState::InFreeList(self.first_free),
            });
            self.first_free = to_u16(self.entries.len() - 1);
        }
        // At the MAX_CAPACITY clamp the loop above may add nothing. Fail here
        // with a clear message instead of letting `insert` hit the
        // end-of-list sentinel and report an internal "Bug".
        assert!(self.capacity() > cap_at_least, "HandleMap overfilled");
        self.debug_check_valid();
    }

    /// Runs `assert_valid` in debug and test builds; a no-op otherwise.
    #[inline]
    fn debug_check_valid(&self) {
        #[cfg(any(debug_assertions, test))]
        {
            self.assert_valid();
        }
    }

    /// Exhaustively checks the map's internal invariants:
    /// - there is exactly one end-of-list slot;
    /// - the free list visits every free slot exactly once and no occupied
    ///   slot;
    /// - `num_entries` equals the number of occupied slots.
    ///
    /// # Panics
    ///
    /// Panics on any violated invariant.
    #[cfg(any(debug_assertions, test))]
    fn assert_valid(&self) {
        assert_ne!(self.len(), self.capacity());
        assert!(self.capacity() <= MAX_CAPACITY, "Entries too large");
        let number_of_ends = self
            .entries
            .iter()
            .filter(|e| e.state.is_end_of_list())
            .count();
        assert_eq!(
            number_of_ends, 1,
            "More than one entry think's it's the end of the list, or no entries do"
        );
        // Per slot: (should be on the free list, was reached via the free list).
        let mut free_indices = vec![(false, false); self.capacity()];
        for (i, e) in self.entries.iter().enumerate() {
            if !e.state.is_occupied() {
                free_indices[i].0 = true;
            }
        }
        let mut next = self.first_free;
        loop {
            let ni = next as usize;
            // Strict `<`: `ni == len` is already out of bounds.
            assert!(
                ni < free_indices.len(),
                "Free list contains out of bounds index!"
            );
            assert!(
                free_indices[ni].0,
                "Free list has an index that shouldn't be free! {}",
                ni
            );
            assert!(
                !free_indices[ni].1,
                "Free list hit an index ({}) more than once! Cycle detected!",
                ni
            );
            free_indices[ni].1 = true;
            match &self.entries[ni].state {
                EntryState::InFreeList(next_index) => next = *next_index,
                EntryState::EndOfFreeList => break,
                EntryState::Active(..) => unreachable!("Bug: Active item in free list at {}", next),
            }
        }
        let mut occupied_count = 0;
        for (i, &(should_be_free, is_free)) in free_indices.iter().enumerate() {
            assert_eq!(
                should_be_free, is_free,
                "Free list missed item, or contains an item it shouldn't: {}",
                i
            );
            if !should_be_free {
                occupied_count += 1;
            }
        }
        assert_eq!(
            self.num_entries, occupied_count,
            "num_entries doesn't reflect the actual number of entries"
        );
    }

    /// Stores `v` in the slot at the head of the free list and returns a
    /// handle to it.
    ///
    /// The slot's version is bumped from odd (free) to even (occupied). The
    /// returned handle records that version, so it stops working once the
    /// value is removed.
    ///
    /// # Panics
    ///
    /// Panics if the map is full (see `ensure_capacity`).
    pub fn insert(&mut self, v: T) -> Handle {
        let need_cap = self.len() + 1;
        self.ensure_capacity(need_cap);
        let index = self.first_free;
        let result = {
            let entry = &mut self.entries[index as usize];
            let new_first_free = match entry.state {
                EntryState::InFreeList(i) => i,
                _ => panic!("Bug: next_index pointed at non-free list entry (or end of list)"),
            };
            // A slot reused ~32k times reaches version 65535. Plain `+= 1`
            // would then panic with overflow in debug builds, so wrap
            // explicitly. Version 0 is skipped; the next even version is 2.
            entry.version = entry.version.wrapping_add(1);
            if entry.version == 0 {
                entry.version = 2;
            }
            entry.state = EntryState::Active(v);
            self.first_free = new_first_free;
            self.num_entries += 1;
            Handle {
                map_id: self.id,
                version: entry.version,
                index,
            }
        };
        self.debug_check_valid();
        result
    }

    /// Checks that `h` was issued by this map and still refers to a live
    /// slot, returning that slot's index.
    ///
    /// # Errors
    ///
    /// - `WrongMap` if the map id differs.
    /// - `IndexPastEnd` if the index is outside the slot vector.
    /// - `StaleVersion` if the slot's version differs from the handle's, or
    ///   the version is odd (a free slot).
    fn check_handle(&self, h: Handle) -> Result<usize, HandleError> {
        if h.map_id != self.id {
            log::info!(
                "HandleMap access with handle having wrong map id: {:?} (our map id is {})",
                h,
                self.id
            );
            return Err(HandleError::WrongMap);
        }
        let index = h.index as usize;
        if index >= self.entries.len() {
            log::info!("HandleMap accessed with handle past end of map: {:?}", h);
            return Err(HandleError::IndexPastEnd);
        }
        if self.entries[index].version != h.version {
            log::info!(
                "HandleMap accessed with handle with wrong version {:?} (entry version is {})",
                h,
                self.entries[index].version
            );
            return Err(HandleError::StaleVersion);
        }
        if (h.version % 2) != 0 {
            log::info!(
                "HandleMap given handle with matching but illegal version: {:?}",
                h,
            );
            return Err(HandleError::StaleVersion);
        }
        Ok(index)
    }

    /// Removes and drops the value referenced by `h`.
    ///
    /// # Errors
    ///
    /// Same as `check_handle`.
    pub fn delete(&mut self, h: Handle) -> Result<(), HandleError> {
        self.remove(h).map(drop)
    }

    /// Removes the value referenced by `h` and returns it. The slot goes back
    /// onto the front of the free list.
    ///
    /// # Errors
    ///
    /// Same as `check_handle`.
    pub fn remove(&mut self, h: Handle) -> Result<T, HandleError> {
        let index = self.check_handle(h)?;
        let prev = {
            let entry = &mut self.entries[index];
            // Occupied versions are even (so at most 65534); this cannot
            // overflow. It makes the version odd again, invalidating `h` and
            // any copies of it.
            entry.version += 1;
            let last_state =
                std::mem::replace(&mut entry.state, EntryState::InFreeList(self.first_free));
            self.num_entries -= 1;
            self.first_free = h.index;
            if let EntryState::Active(value) = last_state {
                value
            } else {
                unreachable!(
                    "Handle {:?} passed validation but references unoccupied entry",
                    h
                )
            }
        };
        self.debug_check_valid();
        Ok(prev)
    }

    /// Borrows the value referenced by `h`.
    ///
    /// # Errors
    ///
    /// Same as `check_handle`. Returns `InvalidHandle` if the validated slot
    /// is unexpectedly empty.
    pub fn get(&self, h: Handle) -> Result<&T, HandleError> {
        let idx = self.check_handle(h)?;
        self.entries[idx]
            .state
            .get_item()
            .ok_or(HandleError::InvalidHandle)
    }

    /// Mutably borrows the value referenced by `h`.
    ///
    /// # Errors
    ///
    /// Same as [`HandleMap::get`].
    pub fn get_mut(&mut self, h: Handle) -> Result<&mut T, HandleError> {
        let idx = self.check_handle(h)?;
        self.entries[idx]
            .state
            .get_item_mut()
            .ok_or(HandleError::InvalidHandle)
    }
}
impl<T> Default for HandleMap<T> {
#[inline]
fn default() -> Self {
HandleMap::new()
}
}
impl<T> ops::Index<Handle> for HandleMap<T> {
    type Output = T;

    /// Panicking counterpart of `HandleMap::get`.
    ///
    /// # Panics
    ///
    /// Panics if `h` is rejected by the map.
    #[inline]
    fn index(&self, h: Handle) -> &T {
        let lookup = self.get(h);
        lookup.expect("Indexed into HandleMap with invalid handle!")
    }
}
/// A copyable reference to a value stored in a [`HandleMap`].
///
/// A handle records three things:
/// - which map issued it (`map_id`);
/// - which slot it names (`index`);
/// - the slot's version when the value was inserted (`version`).
///
/// The map rejects it once the slot is freed or reused. It round-trips
/// through `u64` via `Handle::into_u64` / `Handle::from_u64`.
///
/// `Eq` and `Hash` are derived, in addition to `PartialEq`, so handles can
/// be used as keys in hashed collections.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Handle {
    map_id: u16,
    version: u16,
    index: u16,
}
/// Tag stored in the top 16 bits of every encoded handle (the two bytes are
/// ASCII `AS`).
const HANDLE_MAGIC: u16 = 0x4153;
impl Handle {
#[inline]
pub fn into_u64(self) -> u64 {
let map_id = u64::from(self.map_id);
let version = u64::from(self.version);
let index = u64::from(self.index);
let magic = u64::from(HANDLE_MAGIC);
(magic << 48) | (map_id << 32) | (index << 16) | version
}
pub fn from_u64(v: u64) -> Result<Self, HandleError> {
if !Handle::is_valid(v) {
log::warn!("Illegal handle! {:x}", v);
if v == 0 {
Err(HandleError::NullHandle)
} else {
Err(HandleError::InvalidHandle)
}
} else {
let map_id = (v >> 32) as u16;
let index = (v >> 16) as u16;
let version = v as u16;
Ok(Self {
map_id,
version,
index,
})
}
}
#[inline]
pub fn is_valid(v: u64) -> bool {
(v >> 48) == u64::from(HANDLE_MAGIC) &&
((v & 1) == 0)
}
}
impl From<u64> for Handle {
fn from(u: u64) -> Self {
Handle::from_u64(u).expect("Illegal handle!")
}
}
impl From<Handle> for u64 {
#[inline]
fn from(h: Handle) -> u64 {
h.into_u64()
}
}
unsafe impl IntoFfi for Handle {
type Value = u64;
#[inline]
fn ffi_default() -> u64 {
0u64
}
#[inline]
fn into_ffi_value(self) -> u64 {
self.into_u64()
}
}
/// A `HandleMap` that can be shared between threads.
///
/// The whole map sits behind an `RwLock`, and each stored value gets its own
/// `Mutex`. Lookups therefore need only the read lock plus the value's mutex.
pub struct ConcurrentHandleMap<T> {
    /// The underlying map. It is public; the tests use it to inspect lock
    /// poisoning directly.
    pub map: RwLock<HandleMap<Mutex<T>>>,
}
impl<T> ConcurrentHandleMap<T> {
pub fn new() -> Self {
Self {
map: RwLock::new(HandleMap::new()),
}
}
#[inline]
pub fn len(&self) -> usize {
let map = self.map.read().unwrap();
map.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn insert(&self, v: T) -> Handle {
let mut map = self.map.write().unwrap();
map.insert(Mutex::new(v))
}
pub fn delete(&self, h: Handle) -> Result<(), HandleError> {
let v = {
let mut map = self.map.write().unwrap();
map.remove(h)
};
v.map(drop)
}
pub fn delete_u64(&self, h: u64) -> Result<(), HandleError> {
self.delete(Handle::from_u64(h)?)
}
pub fn remove(&self, h: Handle) -> Result<Option<T>, HandleError> {
let mut map = self.map.write().unwrap();
let mutex = map.remove(h)?;
Ok(mutex.into_inner().ok())
}
pub fn remove_u64(&self, h: u64) -> Result<Option<T>, HandleError> {
self.remove(Handle::from_u64(h)?)
}
pub fn get<F, E, R>(&self, h: Handle, callback: F) -> Result<R, E>
where
F: FnOnce(&T) -> Result<R, E>,
E: From<HandleError>,
{
self.get_mut(h, |v| callback(v))
}
pub fn get_mut<F, E, R>(&self, h: Handle, callback: F) -> Result<R, E>
where
F: FnOnce(&mut T) -> Result<R, E>,
E: From<HandleError>,
{
let map = self.map.read().unwrap();
let mtx = map.get(h)?;
let mut hm = mtx.lock().unwrap();
callback(&mut *hm)
}
pub fn get_u64<F, E, R>(&self, u: u64, callback: F) -> Result<R, E>
where
F: FnOnce(&T) -> Result<R, E>,
E: From<HandleError>,
{
self.get(Handle::from_u64(u)?, callback)
}
pub fn get_mut_u64<F, E, R>(&self, u: u64, callback: F) -> Result<R, E>
where
F: FnOnce(&mut T) -> Result<R, E>,
E: From<HandleError>,
{
self.get_mut(Handle::from_u64(u)?, callback)
}
pub fn call_with_result_mut<R, E, F>(
&self,
out_error: &mut ExternError,
h: u64,
callback: F,
) -> R::Value
where
F: std::panic::UnwindSafe + FnOnce(&mut T) -> Result<R, E>,
ExternError: From<E>,
R: IntoFfi,
{
use crate::call_with_result;
call_with_result(out_error, || -> Result<_, ExternError> {
let h = Handle::from_u64(h)?;
let map = self.map.read().unwrap();
let mtx = map.get(h)?;
let mut hm = mtx.lock().unwrap();
Ok(callback(&mut *hm)?)
})
}
pub fn call_with_result<R, E, F>(
&self,
out_error: &mut ExternError,
h: u64,
callback: F,
) -> R::Value
where
F: std::panic::UnwindSafe + FnOnce(&T) -> Result<R, E>,
ExternError: From<E>,
R: IntoFfi,
{
self.call_with_result_mut(out_error, h, |r| callback(r))
}
pub fn call_with_output<R, F>(
&self,
out_error: &mut ExternError,
h: u64,
callback: F,
) -> R::Value
where
F: std::panic::UnwindSafe + FnOnce(&T) -> R,
R: IntoFfi,
{
self.call_with_result(out_error, h, |r| -> Result<_, HandleError> {
Ok(callback(r))
})
}
pub fn call_with_output_mut<R, F>(
&self,
out_error: &mut ExternError,
h: u64,
callback: F,
) -> R::Value
where
F: std::panic::UnwindSafe + FnOnce(&mut T) -> R,
R: IntoFfi,
{
self.call_with_result_mut(out_error, h, |r| -> Result<_, HandleError> {
Ok(callback(r))
})
}
pub fn insert_with_result<E, F>(&self, out_error: &mut ExternError, constructor: F) -> u64
where
F: std::panic::UnwindSafe + FnOnce() -> Result<T, E>,
ExternError: From<E>,
{
use crate::call_with_result;
call_with_result(out_error, || -> Result<_, ExternError> {
let to_insert = constructor()?;
Ok(self.insert(to_insert))
})
}
pub fn insert_with_output<F>(&self, out_error: &mut ExternError, constructor: F) -> u64
where
F: std::panic::UnwindSafe + FnOnce() -> T,
{
self.insert_with_result(out_error, || -> Result<_, HandleError> {
Ok(constructor())
})
}
}
impl<T> Default for ConcurrentHandleMap<T> {
#[inline]
fn default() -> Self {
Self::new()
}
}
/// Returns the next map id from the process-wide counter.
///
/// The counter wraps, and only its low 16 bits are kept, so ids eventually
/// repeat.
fn next_handle_map_id() -> u16 {
    let previous = HANDLE_MAP_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
    previous.wrapping_add(1) as u16
}
lazy_static::lazy_static! {
    // Source of map ids. It is seeded from a freshly keyed `RandomState`
    // hasher, so ids presumably start at an unpredictable value each process.
    static ref HANDLE_MAP_ID_COUNTER: AtomicUsize = {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};
        let seed = RandomState::new().build_hasher().finish();
        AtomicUsize::new(seed as usize)
    };
}
// Unit tests for `HandleMap`, `Handle` encoding, and `ConcurrentHandleMap`.
#[cfg(test)]
mod test {
    use super::*;

    // Simple payload type with equality, used to check stored values.
    #[derive(PartialEq, Debug)]
    pub(super) struct Foobar(usize);

    // Decoding checks: 0 -> NullHandle; an odd version -> InvalidHandle; a
    // well-formed value decodes into its map_id / index / version fields.
    #[test]
    fn test_invalid_handle() {
        assert_eq!(Handle::from_u64(0), Err(HandleError::NullHandle));
        assert_eq!(
            Handle::from_u64((u64::from(HANDLE_MAGIC) << 48) | 0x1234_0012_0001),
            Err(HandleError::InvalidHandle)
        );
        assert_eq!(
            Handle::from_u64((u64::from(HANDLE_MAGIC) << 48) | 0x1234_0012_0002),
            Ok(Handle {
                version: 0x0002,
                index: 0x0012,
                map_id: 0x1234,
            })
        );
    }

    // A deleted value's handle is rejected as stale.
    #[test]
    fn test_correct_value_single() {
        let mut map = HandleMap::new();
        let handle = map.insert(Foobar(1234));
        assert_eq!(map.get(handle).unwrap(), &Foobar(1234));
        map.delete(handle).unwrap();
        assert_eq!(map.get(handle), Err(HandleError::StaleVersion));
    }

    // Deleting one value does not disturb another.
    #[test]
    fn test_correct_value_multiple() {
        let mut map = HandleMap::new();
        let handle1 = map.insert(Foobar(1234));
        let handle2 = map.insert(Foobar(4321));
        assert_eq!(map.get(handle1).unwrap(), &Foobar(1234));
        assert_eq!(map.get(handle2).unwrap(), &Foobar(4321));
        map.delete(handle1).unwrap();
        assert_eq!(map.get(handle1), Err(HandleError::StaleVersion));
        assert_eq!(map.get(handle2).unwrap(), &Foobar(4321));
    }

    // Handles are only accepted by the map that issued them.
    #[test]
    fn test_wrong_map() {
        let mut map1 = HandleMap::new();
        let mut map2 = HandleMap::new();
        let handle1 = map1.insert(Foobar(1234));
        let handle2 = map2.insert(Foobar(1234));
        assert_eq!(map1.get(handle1).unwrap(), &Foobar(1234));
        assert_eq!(map2.get(handle2).unwrap(), &Foobar(1234));
        assert_eq!(map1.get(handle2), Err(HandleError::WrongMap));
        assert_eq!(map2.get(handle1), Err(HandleError::WrongMap));
    }

    // An index beyond the default capacity is reported as IndexPastEnd.
    #[test]
    fn test_bad_index() {
        let map: HandleMap<Foobar> = HandleMap::new();
        assert_eq!(
            map.get(Handle {
                map_id: map.id,
                version: 2,
                index: 100
            }),
            Err(HandleError::IndexPastEnd)
        );
    }

    // Inserting 1000 values forces repeated growth. After removing them all
    // and inserting 1000 more, the old handles are stale and the new ones
    // survive a u64 round-trip.
    #[test]
    fn test_resizing() {
        let mut map = HandleMap::new();
        let mut handles = vec![];
        for i in 0..1000 {
            handles.push(map.insert(Foobar(i)))
        }
        for (i, &h) in handles.iter().enumerate() {
            assert_eq!(map.get(h).unwrap(), &Foobar(i));
            assert_eq!(map.remove(h).unwrap(), Foobar(i));
        }
        let mut handles2 = vec![];
        for i in 1000..2000 {
            let h = map.insert(Foobar(i));
            let hu = h.into_u64();
            assert_eq!(Handle::from_u64(hu).unwrap(), h);
            handles2.push(hu);
        }
        for (i, (&h0, h1u)) in handles.iter().zip(handles2).enumerate() {
            assert_eq!(map.get(h0), Err(HandleError::StaleVersion));
            let h1 = Handle::from_u64(h1u).unwrap();
            assert_eq!(map.get(h1).unwrap(), &Foobar(i + 1000));
        }
    }

    // Panic-safety tests for `ConcurrentHandleMap`.
    // NOTE(review): why these are excluded under `cfg(coverage)` is not
    // visible here.
    #[cfg(not(coverage))]
    mod panic_tests {
        use super::*;

        // A value whose destructor always panics.
        struct PanicOnDrop(());
        impl Drop for PanicOnDrop {
            fn drop(&mut self) {
                panic!("intentional panic (drop)");
            }
        }

        // A panicking Drop during delete is reported as PANIC. It leaves the
        // outer RwLock unpoisoned and the entry removed.
        #[test]
        fn test_panicking_drop() {
            let map = ConcurrentHandleMap::new();
            let h = map.insert(PanicOnDrop(())).into_u64();
            let mut e = ExternError::success();
            crate::call_with_result(&mut e, || map.delete_u64(h));
            assert_eq!(e.get_code(), crate::ErrorCode::PANIC);
            let _ = unsafe { e.get_and_consume_message() };
            assert!(!map.map.is_poisoned());
            let inner = map.map.read().unwrap();
            inner.assert_valid();
            assert_eq!(inner.len(), 0);
        }

        // A panicking callback poisons only that value's Mutex, not the outer
        // RwLock, and the poisoned entry can still be deleted.
        #[test]
        fn test_panicking_call_with() {
            let map = ConcurrentHandleMap::new();
            let h = map.insert(Foobar(0)).into_u64();
            let mut e = ExternError::success();
            map.call_with_output(&mut e, h, |_thing| {
                panic!("intentional panic (call_with_output)");
            });
            assert_eq!(e.get_code(), crate::ErrorCode::PANIC);
            let _ = unsafe { e.get_and_consume_message() };
            {
                assert!(!map.map.is_poisoned());
                let inner = map.map.read().unwrap();
                inner.assert_valid();
                assert_eq!(inner.len(), 1);
                // Exactly one active entry, and its Mutex is poisoned.
                let mut seen = false;
                for e in &inner.entries {
                    if let EntryState::Active(v) = &e.state {
                        assert!(!seen);
                        assert!(v.is_poisoned());
                        seen = true;
                    }
                }
            }
            assert!(map.delete_u64(h).is_ok());
            assert!(!map.map.is_poisoned());
            let inner = map.map.read().unwrap();
            inner.assert_valid();
            assert_eq!(inner.len(), 0);
        }

        // A panicking constructor yields the FFI default handle (0) and
        // inserts nothing.
        #[test]
        fn test_panicking_insert_with() {
            let map = ConcurrentHandleMap::new();
            let mut e = ExternError::success();
            let res = map.insert_with_output(&mut e, || {
                panic!("intentional panic (insert_with_output)");
            });
            assert_eq!(e.get_code(), crate::ErrorCode::PANIC);
            let _ = unsafe { e.get_and_consume_message() };
            assert_eq!(res, 0);
            assert!(!map.map.is_poisoned());
            let inner = map.map.read().unwrap();
            inner.assert_valid();
            assert_eq!(inner.len(), 0);
        }
    }
}