use std::alloc::{alloc, dealloc, handle_alloc_error, Layout};
use std::marker::PhantomData;
use std::mem::{self, MaybeUninit};
use std::ptr::{self, NonNull};
use std::sync::atomic::{AtomicU32, Ordering};
/// Key value reserved to mark an unoccupied slot; it can never be stored as a real key.
const EMPTY: i64 = i64::MIN;
/// Smallest table size (in slots) that is ever allocated.
const MIN_CAPACITY: usize = 8;
/// After a removal the table shrinks once fewer than `capacity / SHRINK_DIVISOR` slots are used.
const SHRINK_DIVISOR: usize = 4;
/// Tables with at most this many slots are never shrunk automatically after a removal.
const MIN_SHRINK_CAPACITY: usize = 64;
/// Maximum fill ratio `(numerator, denominator)` permitted for a table of `capacity` slots.
///
/// Larger tables tolerate denser packing: 3/4 up to 1024 slots, 7/8 up to 65536 slots,
/// and 15/16 beyond that.
#[inline]
fn load_factor_for_capacity(capacity: usize) -> (usize, usize) {
    match capacity {
        0..=1024 => (3, 4),
        1025..=65536 => (7, 8),
        _ => (15, 16),
    }
}
/// Returns the smallest power-of-two slot count (at least `MIN_CAPACITY`) whose load-factor
/// limit admits `n` entries.
///
/// # Panics
///
/// Panics with "CowHashMap capacity overflow" if no such `usize` exists. Previously
/// `next_power_of_two` wrapped to 0 in release builds and the loop never terminated.
#[inline]
fn capacity_for_entries(n: usize) -> usize {
    const OVERFLOW: &str = "CowHashMap capacity overflow";
    if n == 0 {
        return MIN_CAPACITY;
    }
    let mut cap = n.checked_next_power_of_two().expect(OVERFLOW);
    loop {
        let (num, den) = load_factor_for_capacity(cap);
        let max_entries = cap.checked_mul(num).expect(OVERFLOW) / den;
        if max_entries >= n {
            return cap.max(MIN_CAPACITY);
        }
        cap = cap.checked_mul(2).expect(OVERFLOW);
    }
}
/// One table slot. `value` is only read or dropped when `key != EMPTY`.
#[repr(C)]
struct Slot<V> {
    /// The stored key, or `EMPTY` when the slot is vacant.
    key: i64,
    /// The payload; treated as initialised only for occupied slots.
    value: MaybeUninit<V>,
}
/// Metadata stored at offset 0 of every table allocation; the slot array follows it.
#[repr(C)]
struct Header {
    /// Number of slots in the table (a power of two, as produced by `capacity_for_entries`).
    capacity: u32,
    /// Number of occupied slots.
    len: u32,
    /// Number of live `CowHashMap` handles that share this allocation.
    drop_count: AtomicU32,
}
impl Header {
    /// The slot count widened to `usize`.
    #[inline]
    fn get_capacity(&self) -> usize {
        self.capacity as usize
    }
}
/// Open-addressing hash map from `i64` keys to `V` values with copy-on-write sharing.
///
/// `clone` is O(1): handles share one table and a mutating method copies the table first
/// if other handles still reference it. `i64::MIN` is reserved and cannot be used as a key.
pub struct CowHashMap<V: Clone> {
    /// Points at a `Header` followed by the slot array (laid out by `layout_for_capacity`).
    ptr: NonNull<u8>,
    /// Records logical ownership of `V` values for drop checking and variance.
    _marker: PhantomData<V>,
}
// SAFETY: handles produced by `clone` share one heap table. One thread can therefore read a
// `V` (through `get`/`iter`) while another thread holds a handle to the same table, so `V` must
// be `Sync`. Whichever handle is dropped last drops every `V` on its own thread, so `V` must be
// `Send`. This is exactly `Arc<T>`'s situation, so both impls require `V: Send + Sync`, matching
// `Arc`. The previous bounds (`Send`-only / `Sync`-only) allowed data races for values such as
// `Cell<T>`.
unsafe impl<V: Clone + Send + Sync> Send for CowHashMap<V> {}
unsafe impl<V: Clone + Send + Sync> Sync for CowHashMap<V> {}
impl<V: Clone> Clone for CowHashMap<V> {
    /// Returns another handle to the same table in O(1) by incrementing the shared handle count.
    /// No entries are copied here; mutating methods un-share the table when they need to.
    #[inline]
    fn clone(&self) -> Self {
        // Counts above this are treated as a leak. The headroom below `u32::MAX` means that
        // concurrent clones cannot race the counter all the way round before one of them aborts.
        const MAX_HANDLES: u32 = i32::MAX as u32;
        // Relaxed is sufficient: a new handle is always derived from a live one, so the table
        // cannot be freed concurrently (the same argument `Arc::clone` relies on).
        let previous = self.header().drop_count.fetch_add(1, Ordering::Relaxed);
        if previous > MAX_HANDLES {
            // A wrapped counter would let `Drop` free the table while handles still point at it.
            // Safe code can reach this via repeated `mem::forget(map.clone())`, so abort like `Arc`.
            std::process::abort();
        }
        Self {
            ptr: self.ptr,
            _marker: PhantomData,
        }
    }
}
impl<V: Clone> Drop for CowHashMap<V> {
    /// Releases this handle. The last handle drops every stored value and frees the allocation.
    fn drop(&mut self) {
        // AcqRel: the Release half publishes this handle's accesses, and the Acquire half (which
        // matters on the final decrement) makes every other handle's accesses visible before
        // the table is torn down.
        let handles_before = self.header().drop_count.fetch_sub(1, Ordering::AcqRel);
        if handles_before == 1 {
            let slot_count = self.header().get_capacity();
            if mem::needs_drop::<V>() {
                self.slots_mut()
                    .iter_mut()
                    .filter(|slot| slot.key != EMPTY)
                    // SAFETY: occupied slots always hold an initialised value, dropped once here.
                    .for_each(|slot| unsafe { ptr::drop_in_place(slot.value.as_mut_ptr()) });
            }
            let (layout, _) = Self::layout_for_capacity::<V>(slot_count);
            // SAFETY: `ptr` was allocated with this same layout in `with_capacity`.
            unsafe { dealloc(self.ptr.as_ptr(), layout) };
        }
    }
}
impl<V: Clone> Default for CowHashMap<V> {
fn default() -> Self {
Self::new()
}
}
/// Panics with a descriptive message if `key` is the reserved `EMPTY` sentinel.
///
/// Kept out of line and marked cold so the panic machinery stays off callers' hot paths.
#[cold]
#[inline(never)]
fn assert_valid_key(key: i64) {
    assert!(
        key != EMPTY,
        "i64::MIN cannot be used as a key in CowHashMap (reserved as empty sentinel)"
    );
}
/// Rejects the reserved sentinel key. The valid-key path is a single inlined comparison;
/// the panic is delegated to the cold `assert_valid_key`.
#[inline(always)]
fn check_key(key: i64) {
    if key != EMPTY {
        return;
    }
    assert_valid_key(key);
}
impl<V: Clone> CowHashMap<V> {
#[inline]
pub fn new() -> Self {
Self::with_capacity(0)
}
pub fn with_capacity(capacity: usize) -> Self {
let cap = capacity_for_entries(capacity);
if cap > u32::MAX as usize {
panic!("CowHashMap capacity overflow: {} > u32::MAX", cap);
}
let (layout, slots_offset) = Self::layout_for_capacity::<V>(cap);
let ptr = unsafe { alloc(layout) };
let ptr = match NonNull::new(ptr) {
Some(p) => p,
None => handle_alloc_error(layout),
};
unsafe {
let header = ptr.as_ptr() as *mut Header;
(*header).capacity = cap as u32;
(*header).len = 0;
(*header).drop_count = AtomicU32::new(1);
}
unsafe {
let slots_ptr = ptr.as_ptr().add(slots_offset) as *mut Slot<V>;
for i in 0..cap {
(*slots_ptr.add(i)).key = EMPTY;
}
}
Self {
ptr,
_marker: PhantomData,
}
}
fn layout_for_capacity<T>(capacity: usize) -> (Layout, usize) {
let header_layout = Layout::new::<Header>();
let slot_layout = Layout::new::<Slot<T>>();
let (layout, offset) = header_layout
.extend(
Layout::from_size_align(slot_layout.size() * capacity, slot_layout.align())
.unwrap(),
)
.unwrap();
(layout, offset)
}
#[inline]
fn slots_offset<T>() -> usize {
let header_layout = Layout::new::<Header>();
let slot_layout = Layout::new::<Slot<T>>();
let (_, offset) = header_layout.extend(slot_layout).unwrap();
offset
}
#[inline]
fn header(&self) -> &Header {
unsafe { &*(self.ptr.as_ptr() as *const Header) }
}
#[inline]
fn header_mut(&mut self) -> &mut Header {
unsafe { &mut *(self.ptr.as_ptr() as *mut Header) }
}
#[inline]
fn capacity(&self) -> usize {
self.header().capacity as usize
}
#[inline]
fn mask(&self) -> usize {
self.capacity() - 1
}
#[inline]
pub fn len(&self) -> usize {
self.header().len as usize
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[inline]
fn slots(&self) -> &[Slot<V>] {
let capacity = self.capacity();
let offset = Self::slots_offset::<V>();
unsafe {
let slots_ptr = self.ptr.as_ptr().add(offset) as *const Slot<V>;
std::slice::from_raw_parts(slots_ptr, capacity)
}
}
#[inline]
fn slots_mut(&mut self) -> &mut [Slot<V>] {
let capacity = self.capacity();
let offset = Self::slots_offset::<V>();
unsafe {
let slots_ptr = self.ptr.as_ptr().add(offset) as *mut Slot<V>;
std::slice::from_raw_parts_mut(slots_ptr, capacity)
}
}
#[inline(always)]
fn hash(key: i64) -> usize {
let k = key as u64;
let k = k ^ (k >> 16); k.wrapping_mul(0x517cc1b727220a95) as usize
}
#[inline]
fn ensure_unique(&mut self) {
let drop_count = self.header().drop_count.load(Ordering::Acquire);
if drop_count > 1 {
*self = self.deep_clone();
}
}
fn deep_clone(&self) -> Self {
let capacity = self.capacity();
let len = self.len();
let mut new_map = Self::with_capacity(len);
if new_map.capacity() != capacity {
let slots = self.slots();
for slot in slots.iter() {
if slot.key != EMPTY {
let value = unsafe { (*slot.value.as_ptr()).clone() };
new_map.insert_internal(slot.key, value);
}
}
} else {
new_map.header_mut().len = len as u32;
let src_slots = self.slots();
let dst_slots = new_map.slots_mut();
for i in 0..capacity {
if src_slots[i].key != EMPTY {
unsafe {
let value = (*src_slots[i].value.as_ptr()).clone();
dst_slots[i].value = MaybeUninit::new(value);
dst_slots[i].key = src_slots[i].key;
}
}
}
}
new_map
}
fn insert_internal(&mut self, key: i64, value: V) -> Option<V> {
let mask = self.mask();
let mut idx = Self::hash(key) & mask;
let slots = self.slots_mut();
loop {
if slots[idx].key == EMPTY {
slots[idx].key = key;
slots[idx].value = MaybeUninit::new(value);
self.header_mut().len += 1;
return None;
}
if slots[idx].key == key {
let old = unsafe { ptr::read(slots[idx].value.as_ptr()) };
slots[idx].value = MaybeUninit::new(value);
return Some(old);
}
idx = (idx + 1) & mask;
}
}
fn grow(&mut self) {
let old_capacity = self.capacity();
let mut new_map = Self::with_capacity(self.len() * 2);
for slot in self.slots_mut().iter_mut().take(old_capacity) {
if slot.key != EMPTY {
let key = slot.key;
let value = unsafe { ptr::read(slot.value.as_ptr()) };
slot.key = EMPTY; new_map.insert_internal(key, value);
}
}
self.header_mut().len = 0;
*self = new_map;
}
#[inline]
pub fn get(&self, key: i64) -> Option<&V> {
check_key(key);
let mask = self.mask();
let mut idx = Self::hash(key) & mask;
let slots = self.slots();
loop {
if slots[idx].key == EMPTY {
return None;
}
if slots[idx].key == key {
return Some(unsafe { &*slots[idx].value.as_ptr() });
}
idx = (idx + 1) & mask;
}
}
#[inline]
pub fn contains_key(&self, key: i64) -> bool {
self.get(key).is_some()
}
#[inline]
fn needs_grow(&self) -> bool {
let cap = self.capacity();
let (num, den) = load_factor_for_capacity(cap);
self.len() * den >= cap * num
}
#[inline]
fn should_shrink(&self) -> bool {
let cap = self.capacity();
cap > MIN_SHRINK_CAPACITY && self.len() < cap / SHRINK_DIVISOR
}
fn shrink(&mut self) {
let new_cap = capacity_for_entries(self.len());
if new_cap >= self.capacity() {
return; }
let mut new_map = Self::with_capacity(self.len());
for slot in self.slots_mut().iter_mut() {
if slot.key != EMPTY {
let key = slot.key;
let value = unsafe { ptr::read(slot.value.as_ptr()) };
slot.key = EMPTY; new_map.insert_internal(key, value);
}
}
self.header_mut().len = 0;
*self = new_map;
}
pub fn shrink_to_fit(&mut self) {
self.ensure_unique();
self.shrink();
}
pub fn insert(&mut self, key: i64, value: V) -> Option<V> {
check_key(key);
self.ensure_unique();
if self.needs_grow() {
self.grow();
}
self.insert_internal(key, value)
}
pub fn remove(&mut self, key: i64) -> Option<V> {
check_key(key);
self.ensure_unique();
let capacity = self.capacity();
let mask = capacity - 1;
let mut idx = Self::hash(key) & mask;
{
let slots = self.slots();
loop {
if slots[idx].key == EMPTY {
return None;
}
if slots[idx].key == key {
break;
}
idx = (idx + 1) & mask;
}
}
self.header_mut().len -= 1;
let slots = self.slots_mut();
let value = unsafe { ptr::read(slots[idx].value.as_ptr()) };
let mut empty_idx = idx;
loop {
let next_idx = (empty_idx + 1) & mask;
if slots[next_idx].key == EMPTY {
slots[empty_idx].key = EMPTY;
break;
}
let next_natural = Self::hash(slots[next_idx].key) & mask;
let dist_to_empty = if empty_idx >= next_natural {
empty_idx - next_natural
} else {
capacity - next_natural + empty_idx
};
let dist_to_next = if next_idx >= next_natural {
next_idx - next_natural
} else {
capacity - next_natural + next_idx
};
if dist_to_empty <= dist_to_next {
unsafe {
let base = slots.as_mut_ptr();
let src = base.add(next_idx);
let dst = base.add(empty_idx);
(*dst).key = (*src).key;
ptr::copy_nonoverlapping((*src).value.as_ptr(), (*dst).value.as_mut_ptr(), 1);
}
empty_idx = next_idx;
} else {
slots[empty_idx].key = EMPTY;
break;
}
}
if self.should_shrink() {
self.shrink();
}
Some(value)
}
#[inline]
pub fn get_mut(&mut self, key: i64) -> Option<&mut V> {
check_key(key);
self.ensure_unique();
let mask = self.mask();
let mut idx = Self::hash(key) & mask;
let slots = self.slots_mut();
loop {
if slots[idx].key == EMPTY {
return None;
}
if slots[idx].key == key {
return Some(unsafe { &mut *slots[idx].value.as_mut_ptr() });
}
idx = (idx + 1) & mask;
}
}
pub fn iter(&self) -> impl Iterator<Item = (i64, &V)> {
let slots = self.slots();
slots.iter().filter_map(|slot| {
if slot.key != EMPTY {
Some((slot.key, unsafe { &*slot.value.as_ptr() }))
} else {
None
}
})
}
pub fn keys_iter(&self) -> impl Iterator<Item = i64> + '_ {
self.iter().map(|(k, _)| k)
}
pub fn values_iter(&self) -> impl Iterator<Item = &V> {
self.iter().map(|(_, v)| v)
}
pub fn clear(&mut self) {
self.ensure_unique();
for slot in self.slots_mut() {
if slot.key != EMPTY {
if mem::needs_drop::<V>() {
unsafe {
ptr::drop_in_place(slot.value.as_mut_ptr());
}
}
slot.key = EMPTY;
}
}
self.header_mut().len = 0;
}
pub fn entry(&mut self, key: i64) -> Entry<'_, V> {
check_key(key);
self.ensure_unique();
let mask = self.mask();
let mut idx = Self::hash(key) & mask;
let slots = self.slots();
loop {
if slots[idx].key == EMPTY {
return Entry::Vacant(VacantEntry { map: self, key });
}
if slots[idx].key == key {
return Entry::Occupied(OccupiedEntry { map: self, idx });
}
idx = (idx + 1) & mask;
}
}
}
/// A view of one key's slot, returned by `CowHashMap::entry`.
pub enum Entry<'a, V: Clone> {
    /// The key is present in the map.
    Occupied(OccupiedEntry<'a, V>),
    /// The key is absent from the map.
    Vacant(VacantEntry<'a, V>),
}
/// An entry whose key is present; it holds the index of the occupied slot.
pub struct OccupiedEntry<'a, V: Clone> {
    /// The map, already un-shared by `CowHashMap::entry`.
    map: &'a mut CowHashMap<V>,
    /// The slot index where the key was found.
    idx: usize,
}
/// An entry whose key is absent; inserting through it adds the key.
pub struct VacantEntry<'a, V: Clone> {
    /// The map, already un-shared by `CowHashMap::entry`.
    map: &'a mut CowHashMap<V>,
    /// The key to insert.
    key: i64,
}
impl<'a, V: Clone> Entry<'a, V> {
    /// Returns the entry's value, inserting `default` first if the entry is vacant.
    pub fn or_insert(self, default: V) -> &'a mut V {
        self.or_insert_with(|| default)
    }

    /// Returns the entry's value, inserting the result of `default()` first if the entry is
    /// vacant. `default` is only called in the vacant case.
    pub fn or_insert_with<F: FnOnce() -> V>(self, default: F) -> &'a mut V {
        match self {
            Entry::Vacant(vacant) => vacant.insert(default()),
            Entry::Occupied(occupied) => occupied.into_mut(),
        }
    }
}
impl<'a, V: Clone> OccupiedEntry<'a, V> {
pub fn get(&self) -> &V {
unsafe { &*self.map.slots()[self.idx].value.as_ptr() }
}
pub fn get_mut(&mut self) -> &mut V {
unsafe { &mut *self.map.slots_mut()[self.idx].value.as_mut_ptr() }
}
pub fn into_mut(self) -> &'a mut V {
unsafe { &mut *self.map.slots_mut()[self.idx].value.as_mut_ptr() }
}
}
impl<'a, V: Clone> VacantEntry<'a, V> {
pub fn insert(self, value: V) -> &'a mut V {
if self.map.needs_grow() {
self.map.grow();
}
let mask = self.map.mask();
let mut idx = CowHashMap::<V>::hash(self.key) & mask;
{
let slots = self.map.slots_mut();
while slots[idx].key != EMPTY {
idx = (idx + 1) & mask;
}
slots[idx].key = self.key;
slots[idx].value = MaybeUninit::new(value);
}
self.map.header_mut().len += 1;
unsafe { &mut *self.map.slots_mut()[idx].value.as_mut_ptr() }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::RwLock;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_basic_operations() {
        let mut map: CowHashMap<String> = CowHashMap::new();
        assert!(map.is_empty());
        assert_eq!(map.len(), 0);
        assert_eq!(map.insert(1, "one".to_string()), None);
        assert_eq!(map.insert(2, "two".to_string()), None);
        assert_eq!(map.insert(3, "three".to_string()), None);
        assert_eq!(map.len(), 3);
        assert!(!map.is_empty());
        assert_eq!(map.get(1), Some(&"one".to_string()));
        assert_eq!(map.get(2), Some(&"two".to_string()));
        assert_eq!(map.get(3), Some(&"three".to_string()));
        assert_eq!(map.get(4), None);
        assert!(map.contains_key(1));
        assert!(!map.contains_key(4));
        assert_eq!(map.insert(1, "ONE".to_string()), Some("one".to_string()));
        assert_eq!(map.get(1), Some(&"ONE".to_string()));
        assert_eq!(map.remove(2), Some("two".to_string()));
        assert_eq!(map.get(2), None);
        assert_eq!(map.len(), 2);
    }

    #[test]
    fn test_cow_semantics() {
        let mut map: CowHashMap<i64> = CowHashMap::new();
        map.insert(1, 100);
        map.insert(2, 200);
        let snapshot = map.clone();
        assert_eq!(snapshot.get(1), Some(&100));
        assert_eq!(snapshot.get(2), Some(&200));
        map.insert(3, 300);
        map.insert(1, 111);
        assert_eq!(map.get(1), Some(&111));
        assert_eq!(map.get(3), Some(&300));
        assert_eq!(snapshot.get(1), Some(&100));
        assert_eq!(snapshot.get(3), None);
    }

    #[test]
    fn test_get_mut_copies_shared_table() {
        let mut map: CowHashMap<i64> = CowHashMap::new();
        map.insert(1, 1);
        let snapshot = map.clone();
        *map.get_mut(1).unwrap() = 2;
        assert_eq!(map.get(1), Some(&2));
        assert_eq!(snapshot.get(1), Some(&1));
    }

    #[test]
    fn test_growth() {
        let mut map: CowHashMap<i64> = CowHashMap::new();
        for i in 0..1000 {
            map.insert(i, i * 2);
        }
        assert_eq!(map.len(), 1000);
        for i in 0..1000 {
            assert_eq!(map.get(i), Some(&(i * 2)));
        }
    }

    #[test]
    fn test_negative_keys() {
        let mut map: CowHashMap<String> = CowHashMap::new();
        map.insert(-1, "negative one".to_string());
        map.insert(0, "zero".to_string());
        map.insert(i64::MIN + 1, "near_min".to_string());
        map.insert(i64::MAX, "max".to_string());
        assert_eq!(map.get(-1), Some(&"negative one".to_string()));
        assert_eq!(map.get(0), Some(&"zero".to_string()));
        assert_eq!(map.get(i64::MIN + 1), Some(&"near_min".to_string()));
        assert_eq!(map.get(i64::MAX), Some(&"max".to_string()));
    }

    #[test]
    fn test_concurrent_snapshots() {
        let map: Arc<RwLock<CowHashMap<i64>>> = Arc::new(RwLock::new(CowHashMap::new()));
        {
            let mut guard = map.write();
            for i in 0..100 {
                guard.insert(i, i * 10);
            }
        }
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let map = Arc::clone(&map);
                thread::spawn(move || {
                    for _ in 0..100 {
                        let snapshot = map.read().clone();
                        // The writer inserts keys in ascending order, one per write lock, so
                        // every snapshot must hold exactly keys 0..len, each mapped to key * 10.
                        let len = snapshot.len() as i64;
                        assert!((100..=200).contains(&len), "unexpected snapshot len {}", len);
                        let sum: i64 = snapshot.iter().map(|(_, v)| *v).sum();
                        assert_eq!(sum, 10 * len * (len - 1) / 2);
                        for k in 0..len {
                            assert_eq!(snapshot.get(k), Some(&(k * 10)));
                        }
                    }
                })
            })
            .collect();
        let writer_map = Arc::clone(&map);
        let writer = thread::spawn(move || {
            for i in 100..200 {
                writer_map.write().insert(i, i * 10);
            }
        });
        writer.join().unwrap();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(map.read().len(), 200);
    }

    #[test]
    fn test_clear() {
        let mut map: CowHashMap<String> = CowHashMap::new();
        for i in 0..100 {
            map.insert(i, format!("value_{}", i));
        }
        assert_eq!(map.len(), 100);
        map.clear();
        assert_eq!(map.len(), 0);
        assert!(map.is_empty());
        assert_eq!(map.get(0), None);
    }

    #[test]
    fn test_clear_shared_map_keeps_snapshot() {
        let mut map: CowHashMap<String> = CowHashMap::new();
        for i in 0..20 {
            map.insert(i, i.to_string());
        }
        let snapshot = map.clone();
        map.clear();
        assert!(map.is_empty());
        assert_eq!(snapshot.len(), 20);
        assert_eq!(snapshot.get(7), Some(&"7".to_string()));
    }

    #[test]
    fn test_backward_shift_delete() {
        let mut map: CowHashMap<i64> = CowHashMap::new();
        for i in 0..50 {
            map.insert(i, i);
        }
        for i in (0..50).step_by(2) {
            map.remove(i);
        }
        assert_eq!(map.len(), 25);
        for i in (1..50).step_by(2) {
            assert_eq!(map.get(i), Some(&i));
        }
        for i in (0..50).step_by(2) {
            assert_eq!(map.get(i), None);
        }
    }

    #[test]
    fn test_iter() {
        let mut map: CowHashMap<i64> = CowHashMap::new();
        for i in 0..10 {
            map.insert(i, i * i);
        }
        let mut pairs: Vec<_> = map.iter().map(|(k, v)| (k, *v)).collect();
        pairs.sort_by_key(|(k, _)| *k);
        assert_eq!(pairs.len(), 10);
        for (i, (k, v)) in pairs.iter().enumerate() {
            assert_eq!(*k, i as i64);
            assert_eq!(*v, (i * i) as i64);
        }
    }

    #[test]
    fn test_entry_api() {
        let mut map: CowHashMap<i64> = CowHashMap::new();
        *map.entry(1).or_insert(0) += 10;
        assert_eq!(map.get(1), Some(&10));
        *map.entry(1).or_insert(0) += 5;
        assert_eq!(map.get(1), Some(&15));
        map.entry(2).or_insert_with(|| 100);
        assert_eq!(map.get(2), Some(&100));
    }

    #[test]
    #[should_panic(expected = "i64::MIN cannot be used")]
    fn test_min_key_panics() {
        let mut map: CowHashMap<i64> = CowHashMap::new();
        map.insert(i64::MIN, 0);
    }

    #[test]
    fn test_shrink_after_delete() {
        let mut map: CowHashMap<i64> = CowHashMap::new();
        for i in 0..1000 {
            map.insert(i, i * 2);
        }
        let capacity_after_insert = map.capacity();
        assert!(capacity_after_insert >= 1000);
        for i in 10..1000 {
            map.remove(i);
        }
        assert_eq!(map.len(), 10);
        let capacity_after_remove = map.capacity();
        assert!(
            capacity_after_remove < capacity_after_insert,
            "capacity should shrink: {} < {}",
            capacity_after_remove,
            capacity_after_insert
        );
        for i in 0..10 {
            assert_eq!(map.get(i), Some(&(i * 2)));
        }
    }

    #[test]
    fn test_shrink_to_fit() {
        let mut map: CowHashMap<i64> = CowHashMap::with_capacity(1000);
        for i in 0..10 {
            map.insert(i, i);
        }
        let initial_capacity = map.capacity();
        assert!(initial_capacity >= 1000);
        map.shrink_to_fit();
        let after_shrink = map.capacity();
        assert!(
            after_shrink < initial_capacity,
            "capacity should shrink: {} < {}",
            after_shrink,
            initial_capacity
        );
        for i in 0..10 {
            assert_eq!(map.get(i), Some(&i));
        }
    }
}