use parking_lot::{RwLock, RwLockReadGuard};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
const NUM_SHARDS: usize = 128;
const SHARD_MASK: usize = NUM_SHARDS - 1;
const DEFAULT_BUCKET_COUNT: usize = 4096;
const PREFETCH_DISTANCE: usize = 2;
struct Shard<K, V> {
buckets: Vec<Option<(K, V)>>,
count: AtomicUsize,
}
impl<K: Clone + Eq + Hash, V: Clone> Shard<K, V> {
fn new(capacity: usize) -> Self {
Self {
buckets: (0..capacity).map(|_| None).collect(),
count: AtomicUsize::new(0),
}
}
fn insert(&mut self, key: K, value: V) -> Option<V> {
let hash = hash_key(&key);
let capacity = self.buckets.len();
let mut idx = (hash as usize) % capacity;
for _ in 0..capacity {
match &mut self.buckets[idx] {
slot @ None => {
*slot = Some((key, value));
self.count.fetch_add(1, Ordering::Relaxed);
return None;
}
Some((k, v)) if *k == key => {
let old = std::mem::replace(v, value);
return Some(old);
}
_ => {
idx = (idx + 1) % capacity;
}
}
}
None
}
fn get(&self, key: &K) -> Option<&V> {
let hash = hash_key(key);
let capacity = self.buckets.len();
let mut idx = (hash as usize) % capacity;
for _ in 0..capacity {
match &self.buckets[idx] {
None => return None,
Some((k, v)) if k == key => return Some(v),
_ => idx = (idx + 1) % capacity,
}
}
None
}
fn remove(&mut self, key: &K) -> Option<V> {
let hash = hash_key(key);
let capacity = self.buckets.len();
let mut idx = (hash as usize) % capacity;
for _ in 0..capacity {
match &mut self.buckets[idx] {
None => return None,
Some((k, _)) if k == key => {
let old = self.buckets[idx].take().map(|(_, v)| v);
self.count.fetch_sub(1, Ordering::Relaxed);
self.rehash_from(idx);
return old;
}
_ => idx = (idx + 1) % capacity,
}
}
None
}
fn rehash_from(&mut self, start: usize) {
let capacity = self.buckets.len();
let mut idx = (start + 1) % capacity;
while let Some((k, v)) = self.buckets[idx].take() {
let hash = hash_key(&k);
let mut new_idx = (hash as usize) % capacity;
while self.buckets[new_idx].is_some() {
new_idx = (new_idx + 1) % capacity;
}
self.buckets[new_idx] = Some((k, v));
idx = (idx + 1) % capacity;
}
}
fn len(&self) -> usize {
self.count.load(Ordering::Relaxed)
}
}
fn hash_key<K: Hash>(key: &K) -> u64 {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
hasher.finish()
}
#[inline]
fn shard_index(hash: u64) -> usize {
(hash as usize) & SHARD_MASK
}
#[inline]
fn prefetch<T>(ptr: *const T) {
#[cfg(target_arch = "x86_64")]
unsafe {
use std::arch::x86_64::_mm_prefetch;
use std::arch::x86_64::_MM_HINT_T0;
_mm_prefetch(ptr as *const i8, _MM_HINT_T0);
}
#[cfg(not(target_arch = "x86_64"))]
{
let _ = unsafe { std::ptr::read_volatile(ptr as *const u8) };
}
}
pub struct ShardCoalescedMap<K, V> {
shards: Vec<RwLock<Shard<K, V>>>,
version: AtomicU64,
}
impl<K: Clone + Eq + Hash, V: Clone> ShardCoalescedMap<K, V> {
pub fn new() -> Self {
Self::with_capacity(DEFAULT_BUCKET_COUNT)
}
pub fn with_capacity(per_shard_capacity: usize) -> Self {
let shards = (0..NUM_SHARDS)
.map(|_| RwLock::new(Shard::new(per_shard_capacity)))
.collect();
Self {
shards,
version: AtomicU64::new(0),
}
}
#[inline]
pub fn insert(&self, key: K, value: V) -> Option<V> {
let hash = hash_key(&key);
let shard_idx = shard_index(hash);
let mut shard = self.shards[shard_idx].write();
let result = shard.insert(key, value);
self.version.fetch_add(1, Ordering::Relaxed);
result
}
#[inline]
pub fn get(&self, key: &K) -> Option<V> {
let hash = hash_key(key);
let shard_idx = shard_index(hash);
let shard = self.shards[shard_idx].read();
shard.get(key).cloned()
}
#[inline]
pub fn remove(&self, key: &K) -> Option<V> {
let hash = hash_key(key);
let shard_idx = shard_index(hash);
let mut shard = self.shards[shard_idx].write();
let result = shard.remove(key);
if result.is_some() {
self.version.fetch_add(1, Ordering::Relaxed);
}
result
}
pub fn batch_insert(&self, batch: Vec<(K, V)>) -> usize {
if batch.is_empty() {
return 0;
}
let mut shard_batches: [Vec<(K, V)>; NUM_SHARDS] =
std::array::from_fn(|_| Vec::new());
for (key, value) in batch {
let hash = hash_key(&key);
let shard_idx = shard_index(hash);
shard_batches[shard_idx].push((key, value));
}
let mut inserted = 0;
for i in 0..NUM_SHARDS {
if i + PREFETCH_DISTANCE < NUM_SHARDS && !shard_batches[i + PREFETCH_DISTANCE].is_empty() {
prefetch(self.shards[i + PREFETCH_DISTANCE].data_ptr());
}
if !shard_batches[i].is_empty() {
let mut shard = self.shards[i].write();
for (key, value) in shard_batches[i].drain(..) {
if shard.insert(key, value).is_none() {
inserted += 1;
}
}
}
}
self.version.fetch_add(inserted as u64, Ordering::Relaxed);
inserted
}
pub fn batch_get(&self, keys: &[K]) -> Vec<Option<V>> {
if keys.is_empty() {
return Vec::new();
}
let mut shard_queries: [Vec<(usize, &K)>; NUM_SHARDS] =
std::array::from_fn(|_| Vec::new());
for (idx, key) in keys.iter().enumerate() {
let hash = hash_key(key);
let shard_idx = shard_index(hash);
shard_queries[shard_idx].push((idx, key));
}
let mut results = vec![None; keys.len()];
for i in 0..NUM_SHARDS {
if i + PREFETCH_DISTANCE < NUM_SHARDS && !shard_queries[i + PREFETCH_DISTANCE].is_empty() {
prefetch(self.shards[i + PREFETCH_DISTANCE].data_ptr());
}
if !shard_queries[i].is_empty() {
let shard = self.shards[i].read();
for (idx, key) in &shard_queries[i] {
results[*idx] = shard.get(key).cloned();
}
}
}
results
}
pub fn batch_remove(&self, keys: &[K]) -> Vec<Option<V>> {
if keys.is_empty() {
return Vec::new();
}
let mut shard_removes: [Vec<(usize, &K)>; NUM_SHARDS] =
std::array::from_fn(|_| Vec::new());
for (idx, key) in keys.iter().enumerate() {
let hash = hash_key(key);
let shard_idx = shard_index(hash);
shard_removes[shard_idx].push((idx, key));
}
let mut results = vec![None; keys.len()];
let mut removed = 0;
for i in 0..NUM_SHARDS {
if i + PREFETCH_DISTANCE < NUM_SHARDS && !shard_removes[i + PREFETCH_DISTANCE].is_empty() {
prefetch(self.shards[i + PREFETCH_DISTANCE].data_ptr());
}
if !shard_removes[i].is_empty() {
let mut shard = self.shards[i].write();
for (idx, key) in &shard_removes[i] {
if let Some(v) = shard.remove(key) {
results[*idx] = Some(v);
removed += 1;
}
}
}
}
if removed > 0 {
self.version.fetch_add(removed as u64, Ordering::Relaxed);
}
results
}
pub fn len(&self) -> usize {
self.shards.iter().map(|s| s.read().len()).sum()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn version(&self) -> u64 {
self.version.load(Ordering::Relaxed)
}
pub fn clear(&self) {
for shard in &self.shards {
let mut guard = shard.write();
for bucket in &mut guard.buckets {
*bucket = None;
}
guard.count.store(0, Ordering::Relaxed);
}
self.version.fetch_add(1, Ordering::Relaxed);
}
}
impl<K: Clone + Eq + Hash, V: Clone> Default for ShardCoalescedMap<K, V> {
fn default() -> Self {
Self::new()
}
}
pub struct BatchBuilder<K, V> {
entries: Vec<(K, V)>,
capacity: usize,
}
impl<K: Clone + Eq + Hash, V: Clone> BatchBuilder<K, V> {
pub fn with_capacity(capacity: usize) -> Self {
Self {
entries: Vec::with_capacity(capacity),
capacity,
}
}
pub fn push(&mut self, key: K, value: V) {
self.entries.push((key, value));
}
pub fn is_full(&self) -> bool {
self.entries.len() >= self.capacity
}
pub fn flush_to(&mut self, map: &ShardCoalescedMap<K, V>) -> usize {
let entries = std::mem::take(&mut self.entries);
self.entries = Vec::with_capacity(self.capacity);
map.batch_insert(entries)
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
}
pub struct ShardIterator<'a, K, V> {
map: &'a ShardCoalescedMap<K, V>,
current_shard: usize,
current_bucket: usize,
guard: Option<RwLockReadGuard<'a, Shard<K, V>>>,
}
impl<'a, K: Clone + Eq + Hash, V: Clone> ShardIterator<'a, K, V> {
fn new(map: &'a ShardCoalescedMap<K, V>) -> Self {
let mut iter = Self {
map,
current_shard: 0,
current_bucket: 0,
guard: None,
};
iter.advance_to_valid();
iter
}
fn advance_to_valid(&mut self) {
loop {
if self.current_shard >= NUM_SHARDS {
self.guard = None;
return;
}
if self.guard.is_none() {
self.guard = Some(self.map.shards[self.current_shard].read());
self.current_bucket = 0;
}
let guard = self.guard.as_ref().unwrap();
while self.current_bucket < guard.buckets.len() {
if guard.buckets[self.current_bucket].is_some() {
return;
}
self.current_bucket += 1;
}
self.current_shard += 1;
self.guard = None;
}
}
}
impl<'a, K: Clone + Eq + Hash, V: Clone> Iterator for ShardIterator<'a, K, V> {
type Item = (K, V);
fn next(&mut self) -> Option<Self::Item> {
if self.current_shard >= NUM_SHARDS {
return None;
}
let guard = self.guard.as_ref()?;
let entry = guard.buckets[self.current_bucket].as_ref()?;
let result = entry.clone();
self.current_bucket += 1;
self.advance_to_valid();
Some(result)
}
}
impl<K: Clone + Eq + Hash, V: Clone> ShardCoalescedMap<K, V> {
pub fn iter(&self) -> ShardIterator<'_, K, V> {
ShardIterator::new(self)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::thread;
#[test]
fn test_basic_operations() {
let map = ShardCoalescedMap::<String, i32>::new();
assert!(map.insert("key1".to_string(), 1).is_none());
assert!(map.insert("key2".to_string(), 2).is_none());
assert_eq!(map.get(&"key1".to_string()), Some(1));
assert_eq!(map.get(&"key2".to_string()), Some(2));
assert_eq!(map.get(&"key3".to_string()), None);
assert_eq!(map.insert("key1".to_string(), 10), Some(1));
assert_eq!(map.get(&"key1".to_string()), Some(10));
assert_eq!(map.remove(&"key1".to_string()), Some(10));
assert_eq!(map.get(&"key1".to_string()), None);
}
#[test]
fn test_batch_insert() {
let map = ShardCoalescedMap::<i32, i32>::new();
let batch: Vec<_> = (0..1000).map(|i| (i, i * 10)).collect();
let inserted = map.batch_insert(batch);
assert_eq!(inserted, 1000);
assert_eq!(map.len(), 1000);
for i in 0..1000 {
assert_eq!(map.get(&i), Some(i * 10));
}
}
#[test]
fn test_batch_get() {
let map = ShardCoalescedMap::<i32, i32>::new();
for i in 0..100i32 {
map.insert(i, i * 2);
}
let keys: Vec<i32> = (0..150i32).collect();
let results = map.batch_get(&keys);
assert_eq!(results.len(), 150);
for i in 0..100usize {
assert_eq!(results[i], Some((i * 2) as i32));
}
for i in 100..150usize {
assert_eq!(results[i], None);
}
}
#[test]
fn test_batch_remove() {
let map = ShardCoalescedMap::<i32, i32>::new();
for i in 0..100i32 {
map.insert(i, i);
}
let to_remove: Vec<i32> = (50..150i32).collect();
let results = map.batch_remove(&to_remove);
assert_eq!(results.len(), 100);
for i in 0..50usize {
assert_eq!(results[i], Some((i + 50) as i32));
}
for i in 50..100usize {
assert_eq!(results[i], None);
}
assert_eq!(map.len(), 50);
}
#[test]
fn test_concurrent_batch_insert() {
let map = Arc::new(ShardCoalescedMap::<i32, i32>::new());
let num_threads: usize = 8;
let batch_size: usize = 1000;
let handles: Vec<_> = (0..num_threads)
.map(|t| {
let map = Arc::clone(&map);
thread::spawn(move || {
let batch: Vec<_> = (0..batch_size)
.map(|i| ((t * batch_size + i) as i32, i as i32))
.collect();
map.batch_insert(batch)
})
})
.collect();
let total: usize = handles.into_iter().map(|h| h.join().unwrap()).sum();
assert_eq!(total, num_threads * batch_size);
assert_eq!(map.len(), num_threads * batch_size);
}
#[test]
fn test_batch_builder() {
let map = ShardCoalescedMap::<i32, i32>::new();
let mut builder = BatchBuilder::with_capacity(100);
for i in 0..100i32 {
builder.push(i, i * 2);
}
assert!(builder.is_full());
let inserted = builder.flush_to(&map);
assert_eq!(inserted, 100);
assert!(builder.is_empty());
for i in 0..100i32 {
assert_eq!(map.get(&i), Some(i * 2));
}
}
#[test]
fn test_iterator() {
let map = ShardCoalescedMap::<i32, i32>::new();
for i in 0..100i32 {
map.insert(i, i * 2);
}
let mut entries: Vec<_> = map.iter().collect();
entries.sort_by_key(|(k, _)| *k);
assert_eq!(entries.len(), 100);
for (i, (k, v)) in entries.iter().enumerate() {
assert_eq!(*k, i as i32);
assert_eq!(*v, (i * 2) as i32);
}
}
#[test]
fn test_clear() {
let map = ShardCoalescedMap::<i32, i32>::new();
for i in 0..100i32 {
map.insert(i, i);
}
assert_eq!(map.len(), 100);
map.clear();
assert_eq!(map.len(), 0);
assert!(map.is_empty());
for i in 0..100i32 {
assert!(map.get(&i).is_none());
}
}
}