use crate::util::CachePadded;
use core::hash::{Hash, Hasher};
use core::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
use parking_lot::Mutex;
use crate::metrics::MetricsCollector;
#[cfg(feature = "std")]
use std::alloc::{self, Layout};
#[cfg(feature = "std")]
use std::boxed::Box;
#[cfg(feature = "std")]
use std::hash::{BuildHasher, RandomState};
#[cfg(feature = "std")]
use std::vec::Vec;
/// Initial bucket count used by `ConcurrentHashMap::new`.
const DEFAULT_CAPACITY: usize = 16;
/// A resize is attempted once `size > capacity * MAX_LOAD_FACTOR`
/// (see `should_resize`).
const MAX_LOAD_FACTOR: f64 = 0.75;
/// Number of writer lock stripes; the upper hash bits pick a stripe
/// (see `stripe_index`).
const STRIPE_COUNT: usize = 16;
/// NOTE(review): never read — presumably reserved for a Robin-Hood
/// probe-distance encoding (`Entry::distance` is likewise write-only).
/// TODO confirm intent or remove.
#[allow(dead_code)]
const DISTANCE_BITS: u32 = 6;
/// A striped-lock hash map intended for concurrent use.
///
/// Writers (`insert`/`remove`/`clear`) serialize through one of
/// `STRIPE_COUNT` mutexes chosen from the key's hash; readers (`get`)
/// take no lock. Buckets live in a single manually allocated array
/// pointed to by `table`, which is swapped atomically during a resize.
#[derive(Debug)]
pub struct ConcurrentHashMap<K, V> {
    // Pointer to a `capacity`-long array of `Bucket`s (see `allocate_table`).
    table: CachePadded<AtomicPtr<Bucket<K, V>>>,
    // Length of the `table` array; always a power of two.
    capacity: AtomicUsize,
    // Live-entry count, maintained with relaxed atomics (racy snapshot).
    size: AtomicUsize,
    // Writer locks; a hash maps to a stripe via `stripe_index`.
    stripes: [CachePadded<Mutex<()>>; STRIPE_COUNT],
    // Null when no resize is in flight; otherwise points to a leaked
    // `ResizeState` box shared by all helping threads.
    resize_state: CachePadded<AtomicPtr<ResizeState<K, V>>>,
}
/// One cache-line-aligned slot of the table: a fixed array of up to 16
/// entries plus the count of occupied slots.
#[repr(align(64))]
struct Bucket<K, V> {
    // Live entries are kept densely packed in the first `len` slots
    // (`remove_locked` shifts later entries down to preserve this).
    entries: [Option<Entry<K, V>>; 16],
    len: usize,
}
/// A single key/value pair stored inside a bucket.
#[derive(Debug)]
struct Entry<K, V> {
    key: K,
    value: V,
    // Cached full 64-bit hash; compared before the (potentially
    // expensive) `key` equality check.
    hash: u64,
    // NOTE(review): write-only — always set to 0 and never read.
    // Presumably intended for Robin-Hood probing (see `DISTANCE_BITS`).
    distance: u32,
}
/// Bookkeeping for an in-flight table growth, shared between the thread
/// that started the resize and any threads helping via `help_resize`.
struct ResizeState<K, V> {
    old_table: *mut Bucket<K, V>,
    new_table: *mut Bucket<K, V>,
    old_capacity: usize,
    new_capacity: usize,
    // Count of old buckets claimed for migration; advanced with
    // `fetch_add` by helper threads.
    progress: AtomicUsize,
}
/// Core operations.
///
/// NOTE(review): `alloc`, `Box`, `Vec`, and `RandomState` are imported
/// under `#[cfg(feature = "std")]` but used unconditionally here, so this
/// module presumably only builds with the `std` feature enabled — confirm.
impl<K, V> ConcurrentHashMap<K, V>
where
    K: Hash + Eq + Send + Sync + Clone + 'static,
    V: Send + Sync + 'static,
{
    /// Creates an empty map with `DEFAULT_CAPACITY` buckets.
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_CAPACITY)
    }

    /// Creates an empty map with at least `capacity` buckets, rounded up
    /// to the next power of two so `bucket_index` can mask instead of mod.
    pub fn with_capacity(capacity: usize) -> Self {
        let capacity = if capacity.is_power_of_two() {
            capacity
        } else {
            // `0` rounds up to `1`, so the allocation below is never
            // zero-sized.
            capacity.next_power_of_two()
        };
        let table = Self::allocate_table(capacity);
        Self {
            table: CachePadded::new(AtomicPtr::new(table)),
            capacity: AtomicUsize::new(capacity),
            size: AtomicUsize::new(0),
            stripes: Self::new_stripes(),
            resize_state: CachePadded::new(AtomicPtr::new(core::ptr::null_mut())),
        }
    }

    /// Inserts `key -> value`, returning the previous value for the key if
    /// one existed.
    ///
    /// NOTE(review): `hash_key` reseeds its hasher on every call (see
    /// below), so successive operations on the same key generally disagree
    /// about which bucket it lives in.
    pub fn insert(&self, key: K, value: V) -> Option<V> {
        let hash = self.hash_key(&key);
        let capacity = self.capacity.load(Ordering::Acquire);
        let stripe = self.stripe_index(hash, capacity);
        // Opportunistically grow before taking the stripe lock.
        // NOTE(review): the resize migrates buckets without holding any
        // stripe lock, so it races with the locked insert below.
        if self.should_resize() {
            self.try_resize();
        }
        let _lock = self.stripes[stripe].lock();
        if let Some(old_value) = self.insert_locked(key, value, hash) {
            Some(old_value)
        } else {
            self.size.fetch_add(1, Ordering::Relaxed);
            None
        }
    }

    /// Looks up `key`, returning a reference to its value if present.
    ///
    /// NOTE(review): the returned `&V` is not protected by any lock, epoch,
    /// or hazard scheme — a concurrent `remove`, `clear`, or resize can
    /// drop/free the value while the reference is still alive.
    pub fn get(&self, key: &K) -> Option<&V> {
        let hash = self.hash_key(key);
        let capacity = self.capacity.load(Ordering::Acquire);
        let resize_state = self.resize_state.load(Ordering::Acquire);
        if !resize_state.is_null() {
            // Help finish the resize, then retry (recursively).
            // NOTE(review): `complete_resize` may free this pointer between
            // the load above and the dereference inside `help_resize`
            // (use-after-free window).
            self.help_resize(resize_state);
            return self.get(key);
        }
        self.get_locked(key, hash, capacity)
    }

    /// Removes `key`, returning its value if it was present.
    pub fn remove(&self, key: &K) -> Option<V> {
        let hash = self.hash_key(key);
        let capacity = self.capacity.load(Ordering::Acquire);
        let stripe = self.stripe_index(hash, capacity);
        let _lock = self.stripes[stripe].lock();
        if let Some(value) = self.remove_locked(key, hash, capacity) {
            self.size.fetch_sub(1, Ordering::Relaxed);
            Some(value)
        } else {
            None
        }
    }

    /// Number of entries currently in the map (racy snapshot).
    pub fn len(&self) -> usize {
        self.size.load(Ordering::Relaxed)
    }

    /// True when `len() == 0`.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Current bucket-array length (always a power of two).
    pub fn capacity(&self) -> usize {
        self.capacity.load(Ordering::Relaxed)
    }

    /// Drops every entry in place, leaving capacity unchanged. Holds all
    /// stripe locks to exclude writers (lock-free readers are still
    /// unprotected — see `get`).
    pub fn clear(&self) {
        let locks: Vec<_> = self.stripes.iter().map(|stripe| stripe.lock()).collect();
        let capacity = self.capacity.load(Ordering::Relaxed);
        let table = self.table.load(Ordering::Relaxed);
        unsafe {
            for i in 0..capacity {
                let bucket = table.add(i);
                (*bucket).len = 0;
                // Overwriting with `None` runs each live entry's destructor.
                for entry in &mut (*bucket).entries {
                    *entry = None;
                }
            }
        }
        self.size.store(0, Ordering::Relaxed);
        drop(locks);
    }

    /// Builds the fixed array of stripe mutexes.
    fn new_stripes() -> [CachePadded<Mutex<()>>; STRIPE_COUNT] {
        let stripes: [CachePadded<Mutex<()>>; STRIPE_COUNT] =
            core::array::from_fn(|_| CachePadded::new(Mutex::new(())));
        stripes
    }

    /// Allocates a `capacity`-long bucket array and initializes every
    /// bucket to empty.
    ///
    /// NOTE(review): `(*bucket).len = 0` and `(*bucket).entries = ...` are
    /// ordinary assignments into *uninitialized* memory; assignment first
    /// drops the previous (garbage) value, which is undefined behavior for
    /// the `entries` array. These writes should use `core::ptr::write`
    /// (or `alloc_zeroed` / `MaybeUninit`).
    fn allocate_table(capacity: usize) -> *mut Bucket<K, V> {
        let table = unsafe {
            alloc::alloc(
                Layout::from_size_align(capacity * core::mem::size_of::<Bucket<K, V>>(), 64)
                    .unwrap(),
            ) as *mut Bucket<K, V>
        };
        if table.is_null() {
            alloc::handle_alloc_error(
                Layout::from_size_align(capacity * core::mem::size_of::<Bucket<K, V>>(), 64)
                    .unwrap(),
            );
        }
        for i in 0..capacity {
            unsafe {
                let bucket = table.add(i);
                (*bucket).len = 0;
                (*bucket).entries = [const { None }; 16];
            }
        }
        table
    }

    /// Hashes a key to 64 bits.
    ///
    /// NOTE(review): `RandomState::new()` draws fresh random seeds on
    /// every call, so the *same key hashes differently every time*.
    /// `get`/`remove` therefore cannot reliably find what `insert` stored.
    /// The `BuildHasher` must be created once, stored in the map, and
    /// reused. This is most likely why all tests below are `#[ignore]`d.
    fn hash_key(&self, key: &K) -> u64 {
        let mut hasher = RandomState::new().build_hasher();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Maps a hash to one of the `STRIPE_COUNT` writer locks. Uses the
    /// upper 32 bits so the stripe choice is decoupled from `bucket_index`
    /// (which uses the low bits).
    fn stripe_index(&self, hash: u64, _capacity: usize) -> usize {
        ((hash >> 32) as usize) % STRIPE_COUNT
    }

    /// Maps a hash to a bucket slot; relies on `capacity` being a power
    /// of two so the mask is equivalent to `% capacity`.
    fn bucket_index(&self, hash: u64, capacity: usize) -> usize {
        (hash as usize) & (capacity - 1)
    }

    /// True once the load factor exceeds `MAX_LOAD_FACTOR`.
    fn should_resize(&self) -> bool {
        let size = self.size.load(Ordering::Relaxed);
        let capacity = self.capacity.load(Ordering::Relaxed);
        size as f64 > capacity as f64 * MAX_LOAD_FACTOR
    }

    /// Starts a table doubling if none appears to be in progress.
    ///
    /// NOTE(review): this CAS exchanges null for null, so "winning" it
    /// publishes nothing — *every* thread that observes a null state
    /// succeeds, and concurrent callers can each allocate a new table and
    /// overwrite each other's `resize_state` store below (leaked tables,
    /// racing migrations). The fix is to allocate the `ResizeState` first
    /// and CAS the state pointer itself from null to that pointer.
    fn try_resize(&self) {
        if self
            .resize_state
            .compare_exchange(
                core::ptr::null_mut(),
                core::ptr::null_mut(),
                Ordering::Acquire,
                Ordering::Relaxed,
            )
            .is_ok()
        {
            let old_capacity = self.capacity.load(Ordering::Relaxed);
            let new_capacity = old_capacity * 2;
            let old_table = self.table.load(Ordering::Relaxed);
            let new_table = Self::allocate_table(new_capacity);
            let resize_state = Box::into_raw(Box::new(ResizeState {
                old_table,
                new_table,
                old_capacity,
                new_capacity,
                progress: AtomicUsize::new(0),
            }));
            self.resize_state.store(resize_state, Ordering::Release);
            self.help_resize(resize_state);
        }
    }

    /// Migrates old buckets in chunks of 16, then finishes the resize.
    ///
    /// NOTE(review): buckets are migrated *before* `progress` is advanced,
    /// so two helpers that read the same starting value migrate the same
    /// buckets twice (duplicating entries). `fetch_add` also returns the
    /// counter's *previous* value, so `migrated` can step backwards
    /// relative to `next_migrated`. Finally, any helper that observes
    /// `migrated >= old_capacity` calls `complete_resize`, which can
    /// therefore run more than once (double free of the old table and
    /// the state box).
    fn help_resize(&self, resize_state: *mut ResizeState<K, V>) {
        unsafe {
            let state = &*resize_state;
            let old_capacity = state.old_capacity;
            let _new_capacity = state.new_capacity;
            let mut migrated = state.progress.load(Ordering::Relaxed);
            while migrated < old_capacity {
                let next_migrated = (migrated + 16).min(old_capacity);
                for i in migrated..next_migrated {
                    self.migrate_bucket(state, i);
                }
                migrated = state
                    .progress
                    .fetch_add(next_migrated - migrated, Ordering::Relaxed);
            }
            if migrated >= old_capacity {
                self.complete_resize(state);
            }
        }
    }

    /// Copies the live entries of one old bucket into the new table.
    ///
    /// NOTE(review): `ptr::read` bitwise-duplicates `key`/`value` without
    /// vacating the old slot; soundness then depends on `complete_resize`
    /// freeing the old table *without* running destructors. Entries that
    /// do not fit in a full destination bucket are silently lost from the
    /// map (and their heap data leaks when the old table is freed). No
    /// stripe lock is held here, so this also races with writers that are
    /// mutating the old bucket.
    fn migrate_bucket(&self, resize_state: &ResizeState<K, V>, bucket_index: usize) {
        unsafe {
            let old_bucket = resize_state.old_table.add(bucket_index);
            let old_len = (*old_bucket).len;
            for i in 0..old_len {
                if let Some(entry) = &(*old_bucket).entries[i] {
                    let new_bucket_index = self.bucket_index(entry.hash, resize_state.new_capacity);
                    let new_bucket = resize_state.new_table.add(new_bucket_index);
                    if (*new_bucket).len < 16 {
                        (*new_bucket).entries[(*new_bucket).len] = Some(Entry {
                            key: core::ptr::read(&entry.key),
                            value: core::ptr::read(&entry.value),
                            hash: entry.hash,
                            distance: 0,
                        });
                        (*new_bucket).len += 1;
                    }
                }
            }
        }
    }

    /// Publishes the new table, then frees the old table and the state.
    ///
    /// NOTE(review): the old table is deallocated without dropping its
    /// entries — deliberate, since `migrate_bucket` bitwise-copied them —
    /// but lock-free readers that loaded the old table pointer may still
    /// be traversing it when it is freed.
    fn complete_resize(&self, resize_state: &ResizeState<K, V>) {
        unsafe {
            self.table.store(resize_state.new_table, Ordering::Release);
            self.capacity
                .store(resize_state.new_capacity, Ordering::Release);
            self.resize_state
                .store(core::ptr::null_mut(), Ordering::Release);
            alloc::dealloc(
                resize_state.old_table as *mut u8,
                Layout::from_size_align(
                    resize_state.old_capacity * core::mem::size_of::<Bucket<K, V>>(),
                    64,
                )
                .unwrap(),
            );
            drop(Box::from_raw(
                resize_state as *const ResizeState<K, V> as *mut ResizeState<K, V>,
            ));
        }
    }

    /// Insert under the caller's stripe lock; returns the previous value
    /// on update, `None` on fresh insert.
    ///
    /// NOTE(review): on the update path, `ptr::read(&entry.value)` copies
    /// the old value out, and the subsequent assignment to `entries[i]`
    /// drops the old `Entry` — *including that same value* — so the
    /// returned value is dropped twice (double free for non-trivial `V`).
    /// The slot should be `take()`n / `mem::replace`d instead. A full
    /// bucket panics rather than probing or forcing a resize.
    fn insert_locked(&self, key: K, value: V, hash: u64) -> Option<V> {
        let capacity = self.capacity.load(Ordering::Relaxed);
        let bucket_index = self.bucket_index(hash, capacity);
        let table = self.table.load(Ordering::Relaxed);
        unsafe {
            let bucket = table.add(bucket_index);
            for i in 0..(*bucket).len {
                if let Some(entry) = &(*bucket).entries[i] {
                    if entry.hash == hash && entry.key == key {
                        let old_value = core::ptr::read(&entry.value);
                        (*bucket).entries[i] = Some(Entry {
                            key: key.clone(),
                            value,
                            hash,
                            distance: entry.distance,
                        });
                        return Some(old_value);
                    }
                }
            }
            if (*bucket).len < 16 {
                (*bucket).entries[(*bucket).len] = Some(Entry {
                    key,
                    value,
                    hash,
                    distance: 0,
                });
                (*bucket).len += 1;
                None
            } else {
                panic!("Bucket overflow - should trigger resize");
            }
        }
    }

    /// Lookup helper. NOTE(review): despite the name, no lock is taken —
    /// this scans the bucket concurrently with writers.
    fn get_locked(&self, key: &K, hash: u64, capacity: usize) -> Option<&V> {
        let bucket_index = self.bucket_index(hash, capacity);
        let table = self.table.load(Ordering::Acquire);
        unsafe {
            let bucket = table.add(bucket_index);
            for i in 0..(*bucket).len {
                if let Some(entry) = &(*bucket).entries[i] {
                    if entry.hash == hash && entry.key == *key {
                        return Some(&entry.value);
                    }
                }
            }
        }
        None
    }

    /// Remove under the caller's stripe lock; shifts later entries down
    /// one slot so the occupied prefix stays dense.
    fn remove_locked(&self, key: &K, hash: u64, capacity: usize) -> Option<V> {
        let bucket_index = self.bucket_index(hash, capacity);
        let table = self.table.load(Ordering::Relaxed);
        unsafe {
            let bucket = table.add(bucket_index);
            for i in 0..(*bucket).len {
                if let Some(entry) = &(*bucket).entries[i] {
                    if entry.hash == hash && entry.key == *key {
                        let entry = (*bucket).entries[i].take().unwrap();
                        for j in i..(*bucket).len - 1 {
                            (*bucket).entries[j] = (*bucket).entries[j + 1].take();
                        }
                        (*bucket).len -= 1;
                        return Some(entry.value);
                    }
                }
            }
        }
        None
    }
}
impl<K, V> Default for ConcurrentHashMap<K, V>
where
K: Hash + Eq + Send + Sync + Clone + 'static,
V: Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<K, V> Clone for ConcurrentHashMap<K, V>
where
    K: Hash + Eq + Send + Sync + Clone + 'static,
    V: Send + Sync + Clone + 'static,
{
    /// Deep-copies every entry into a freshly allocated map of the same
    /// capacity.
    ///
    /// NOTE(review): the source buckets are traversed without taking any
    /// stripe lock, so cloning races with concurrent writers; the table
    /// pointer is also reloaded on every bucket iteration and can change
    /// mid-loop if a resize completes, mixing old- and new-table reads.
    fn clone(&self) -> Self {
        let new_map = Self::with_capacity(self.capacity());
        for bucket_index in 0..self.capacity() {
            let table = self.table.load(Ordering::Acquire);
            unsafe {
                let bucket = table.add(bucket_index);
                for i in 0..(*bucket).len {
                    if let Some(entry) = &(*bucket).entries[i] {
                        new_map.insert(entry.key.clone(), entry.value.clone());
                    }
                }
            }
        }
        new_map
    }
}
impl<K, V> Drop for ConcurrentHashMap<K, V> {
    /// Runs every entry's destructor in place, frees the bucket array,
    /// and releases any leftover resize state. `&mut self` guarantees
    /// exclusive access here, so no locking is needed.
    ///
    /// NOTE(review): if a resize was still in flight, only the
    /// `ResizeState` box itself is freed — its `new_table` allocation
    /// (and any entries already migrated into it) leaks.
    fn drop(&mut self) {
        let table = self.table.load(Ordering::Relaxed);
        let capacity = self.capacity.load(Ordering::Relaxed);
        if !table.is_null() {
            unsafe {
                for i in 0..capacity {
                    let bucket = table.add(i);
                    // `dealloc` does not run destructors, so drop each
                    // entry explicitly by overwriting it with `None`.
                    for entry in &mut (*bucket).entries {
                        *entry = None;
                    }
                }
                alloc::dealloc(
                    table as *mut u8,
                    // Layout must match the one used in `allocate_table`.
                    Layout::from_size_align(capacity * core::mem::size_of::<Bucket<K, V>>(), 64)
                        .unwrap(),
                );
            }
        }
        let resize_state = self.resize_state.load(Ordering::Relaxed);
        if !resize_state.is_null() {
            unsafe {
                drop(Box::from_raw(resize_state));
            }
        }
    }
}
/// Stub `MetricsCollector` implementation: this map records no metrics,
/// so every method reports the default/disabled state.
#[cfg(feature = "std")]
impl<K, V> MetricsCollector for ConcurrentHashMap<K, V>
where
    K: Hash + Eq + Send + Sync + Clone + 'static,
    V: Send + Sync + 'static,
{
    // Always the default (empty) metrics snapshot.
    fn metrics(&self) -> crate::metrics::PerformanceMetrics {
        crate::metrics::PerformanceMetrics::default()
    }
    fn reset_metrics(&self) {
        // Intentionally a no-op: nothing is recorded.
    }
    fn set_metrics_enabled(&self, _enabled: bool) {
        // Intentionally a no-op: metrics cannot be enabled.
    }
    fn is_metrics_enabled(&self) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    //! NOTE(review): every test is `#[ignore]`d — presumably because
    //! `hash_key` reseeds `RandomState` per call, which makes `get`/
    //! `remove` unable to find inserted keys (see `hash_key`). Re-enable
    //! these once the hasher is stored in the map. TODO confirm.
    use super::*;
    use std::format;
    use std::string::String;
    use std::string::ToString;
    use std::sync::Arc;
    use std::thread;
    use std::vec;

    // Single-threaded insert / overwrite / remove round trip.
    #[test]
    #[ignore]
    fn test_basic_operations() {
        let map: ConcurrentHashMap<i32, String> = ConcurrentHashMap::new();
        assert_eq!(map.len(), 0);
        assert!(map.is_empty());
        assert_eq!(map.get(&1), None);
        assert_eq!(map.insert(1, "hello".to_string()), None);
        assert_eq!(map.len(), 1);
        assert!(!map.is_empty());
        assert_eq!(map.get(&1), Some(&"hello".to_string()));
        assert_eq!(
            map.insert(1, "world".to_string()),
            Some("hello".to_string())
        );
        assert_eq!(map.get(&1), Some(&"world".to_string()));
        assert_eq!(map.remove(&1), Some("world".to_string()));
        assert_eq!(map.len(), 0);
        assert_eq!(map.get(&1), None);
    }

    // 4 writer threads with disjoint key ranges plus 4 concurrent readers;
    // afterwards every written key must be retrievable.
    #[test]
    #[ignore]
    fn test_concurrent_access() {
        let map = Arc::new(ConcurrentHashMap::new());
        let num_writers = 4;
        let num_readers = 4;
        let items_per_writer = 1000;
        let mut writer_handles = vec![];
        for writer_id in 0..num_writers {
            let map = Arc::clone(&map);
            let handle = thread::spawn(move || {
                for i in 0..items_per_writer {
                    let key = writer_id * items_per_writer + i;
                    map.insert(key, format!("value_{}", key));
                }
            });
            writer_handles.push(handle);
        }
        let mut reader_handles = vec![];
        for _ in 0..num_readers {
            let map = Arc::clone(&map);
            let handle = thread::spawn(move || {
                let mut count = 0;
                for i in 0..num_writers * items_per_writer {
                    if let Some(_value) = map.get(&i) {
                        count += 1;
                    }
                    thread::yield_now();
                }
                count
            });
            reader_handles.push(handle);
        }
        for handle in writer_handles {
            handle.join().unwrap();
        }
        let mut _total_reads = 0;
        for handle in reader_handles {
            _total_reads += handle.join().unwrap();
        }
        for i in 0..num_writers * items_per_writer {
            assert!(map.get(&i).is_some(), "Missing key: {}", i);
        }
    }

    // Growing past the load factor must raise capacity and keep entries.
    #[test]
    #[ignore]
    fn test_resize_behavior() {
        let map: ConcurrentHashMap<i32, i32> = ConcurrentHashMap::with_capacity(4);
        let initial_capacity = map.capacity();
        for i in 0..10 {
            map.insert(i, i * 2);
        }
        assert!(map.capacity() > initial_capacity);
        for i in 0..10 {
            assert_eq!(map.get(&i), Some(&(i * 2)));
        }
    }

    // `clear` empties the map without touching capacity.
    #[test]
    #[ignore]
    fn test_clear() {
        let map: ConcurrentHashMap<i32, String> = ConcurrentHashMap::new();
        for i in 0..10 {
            map.insert(i, format!("value_{}", i));
        }
        assert_eq!(map.len(), 10);
        map.clear();
        assert_eq!(map.len(), 0);
        assert!(map.is_empty());
        for i in 0..10 {
            assert_eq!(map.get(&i), None);
        }
    }

    // Clones must be deep: mutating the original leaves the clone intact.
    #[test]
    #[ignore]
    fn test_clone() {
        let map1: ConcurrentHashMap<i32, String> = ConcurrentHashMap::new();
        for i in 0..10 {
            map1.insert(i, format!("value_{}", i));
        }
        let map2 = map1.clone();
        assert_eq!(map1.len(), map2.len());
        for i in 0..10 {
            assert_eq!(map1.get(&i), map2.get(&i));
        }
        map1.insert(10, "new_value".to_string());
        assert_eq!(map1.get(&10), Some(&"new_value".to_string()));
        assert_eq!(map2.get(&10), None);
    }
}