use aliasable::boxed::AliasableBox;
use foldhash::fast::RandomState;
use hashbrown::hash_table::Entry;
use hashbrown::HashTable;
use parking_lot::lock_api::RawMutex as _;
use parking_lot::RawMutex;
use std::borrow::Borrow;
use std::cell::UnsafeCell;
use std::hash::{BuildHasher, Hash};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::OnceLock;
/// Packed view of a `State`'s atomic flag word.
///
/// Bit 31 records whether the slot currently holds a value; bits 0..31 hold
/// the number of live guards referencing the entry.
struct StateFlags(u32);
impl StateFlags {
    const HAS_VALUE_FLAG: u32 = 1 << 31;
    const REFCNT_MASK: u32 = !Self::HAS_VALUE_FLAG;
    /// Packs a reference count and a has-value bit into a single word.
    fn new(refcnt: u32, has_value: bool) -> Self {
        let flag = if has_value { Self::HAS_VALUE_FLAG } else { 0 };
        Self((refcnt & Self::REFCNT_MASK) | flag)
    }
    /// Number of live guards referencing the state.
    fn refcnt(&self) -> u32 {
        self.0 & Self::REFCNT_MASK
    }
    /// Whether the slot currently stores a value.
    fn has_value(&self) -> bool {
        self.0 & Self::HAS_VALUE_FLAG != 0
    }
    /// True when no guards remain and no value is stored — the entry can be
    /// unlinked from the list and dropped.
    fn pending_cleanup(&self) -> bool {
        self.0 == 0
    }
}
/// Shared per-key state. Heap-pinned via `AliasableBox` so the raw pointers
/// held by the LRU list and by guards stay valid while the entry lives in
/// the table.
struct State<K, V> {
    key: K,
    // Cached hash of `key`, reused for shard selection and table lookups.
    hash: u64,
    // Packed `StateFlags`: guard refcount plus the has-value bit.
    flags: AtomicU32,
    // Per-entry lock held by an `LruEntry` guard while the value is accessed.
    mutex: RawMutex,
    // Accessed either under `mutex` (by guards) or under the shard lock when
    // the refcount is zero — see `LruLockMap::get`/`insert`.
    value: UnsafeCell<Option<V>>,
    // Intrusive LRU list links; only touched while the shard lock is held.
    prev: UnsafeCell<*mut Self>,
    next: UnsafeCell<*mut Self>,
}
impl<K, V> State<K, V> {
    /// Allocates a new heap-pinned state.
    ///
    /// `refcnt` is the initial guard count: 1 when the caller immediately
    /// wraps the state in a guard, 0 otherwise.
    fn new(key: K, value: Option<V>, refcnt: u32, hash: u64) -> AliasableBox<Self> {
        AliasableBox::from_unique(Box::new(Self {
            key,
            hash,
            flags: AtomicU32::new(StateFlags::new(refcnt, value.is_some()).0),
            mutex: RawMutex::INIT,
            value: UnsafeCell::new(value),
            prev: UnsafeCell::new(std::ptr::null_mut()),
            next: UnsafeCell::new(std::ptr::null_mut()),
        }))
    }
    /// Snapshot of the packed flags word.
    fn flags(&self) -> StateFlags {
        StateFlags(self.flags.load(Ordering::Acquire))
    }
    /// Increments the guard refcount, returning the updated flags.
    fn inc_ref(&self) -> StateFlags {
        StateFlags(self.flags.fetch_add(1, Ordering::AcqRel) + 1)
    }
    /// Decrements the guard refcount, returning the updated flags.
    fn dec_ref(&self) -> StateFlags {
        StateFlags(self.flags.fetch_sub(1, Ordering::AcqRel) - 1)
    }
    /// Sets or clears the has-value bit without touching the refcount.
    fn set_value_state(&self, has_value: bool) {
        if has_value {
            self.flags
                .fetch_or(StateFlags::HAS_VALUE_FLAG, Ordering::Release);
        } else {
            self.flags
                .fetch_and(!StateFlags::HAS_VALUE_FLAG, Ordering::Release);
        }
    }
    // SAFETY: callers in this file only invoke this while holding `mutex`
    // (guards) or while holding the shard lock with refcnt == 0, so there is
    // no concurrent mutable access.
    unsafe fn value_ref(&self) -> &Option<V> {
        &*self.value.get()
    }
    // SAFETY: callers must have exclusive access to the value — they hold
    // `mutex`, or hold the shard lock while refcnt == 0.
    #[allow(clippy::mut_from_ref)]
    unsafe fn value_mut(&self) -> &mut Option<V> {
        &mut *self.value.get()
    }
}
/// One shard: a hash table of pinned states plus an intrusive doubly-linked
/// LRU list threaded through them (`head` = most recently used, `tail` =
/// least recently used).
struct LruShardInner<K, V> {
    table: HashTable<AliasableBox<State<K, V>>>,
    head: *mut State<K, V>,
    tail: *mut State<K, V>,
    // Eviction threshold for this shard's table.
    max_size: usize,
}
// SAFETY: the raw `head`/`tail`/link pointers only reference `State` nodes
// owned by `table`, and the shard is always accessed behind a mutex (see
// `LruShardMap`); with `K: Send` and `V: Send` the owned data may move
// between threads.
unsafe impl<K: Send, V: Send> Send for LruShardInner<K, V> {}
impl<K, V> LruShardInner<K, V> {
    /// Creates an empty shard with the given eviction threshold and initial
    /// table capacity.
    fn with_capacity(max_size: usize, capacity: usize) -> Self {
        Self {
            table: HashTable::with_capacity(capacity),
            head: std::ptr::null_mut(),
            tail: std::ptr::null_mut(),
            max_size,
        }
    }
    /// Unlinks `node` from the LRU list and nulls its link pointers.
    ///
    /// # Safety
    /// `node` must point to a live `State` currently linked in this shard's
    /// list, and the caller must hold the shard lock.
    unsafe fn detach(&mut self, node: *mut State<K, V>) {
        let prev = *(*node).prev.get();
        let next = *(*node).next.get();
        if !prev.is_null() {
            *(*prev).next.get() = next;
        } else {
            // `node` was the head.
            self.head = next;
        }
        if !next.is_null() {
            *(*next).prev.get() = prev;
        } else {
            // `node` was the tail.
            self.tail = prev;
        }
        *(*node).prev.get() = std::ptr::null_mut();
        *(*node).next.get() = std::ptr::null_mut();
    }
    /// Links `node` at the head (most-recently-used end) of the list.
    ///
    /// # Safety
    /// `node` must point to a live, currently unlinked `State`, and the
    /// caller must hold the shard lock.
    unsafe fn push_front(&mut self, node: *mut State<K, V>) {
        *(*node).next.get() = self.head;
        *(*node).prev.get() = std::ptr::null_mut();
        if !self.head.is_null() {
            *(*self.head).prev.get() = node;
        }
        self.head = node;
        if self.tail.is_null() {
            self.tail = node;
        }
    }
    /// Moves an already-linked node to the head of the list.
    ///
    /// # Safety
    /// Same contract as `detach`.
    unsafe fn move_to_front(&mut self, node: *mut State<K, V>) {
        if self.head == node {
            return;
        }
        self.detach(node);
        self.push_front(node);
    }
    /// Evicts least-recently-used entries, walking from the tail, until the
    /// table is within `max_size`.
    ///
    /// Entries with live guards (refcnt > 0) are skipped. The walk stops on
    /// reaching `current`; callers position `current` at the head before
    /// calling, so reaching it means the whole list has been scanned.
    fn try_evict(&mut self, current: *mut State<K, V>) {
        let mut cursor = self.tail;
        while self.table.len() > self.max_size && !cursor.is_null() && cursor != current {
            // Read the predecessor before the node can be freed below.
            let prev = unsafe { *(*cursor).prev.get() };
            let state = unsafe { &*cursor };
            if state.flags().refcnt() > 0 {
                // In use by a guard: leave it alone.
                cursor = prev;
                continue;
            }
            unsafe { self.detach(cursor) };
            let hash = state.hash;
            if let Ok(entry) = self.table.find_entry(hash, |s| std::ptr::eq(&**s, cursor)) {
                let _ = entry.remove();
            }
            cursor = prev;
        }
    }
}
/// A single shard: the inner table and LRU list behind a blocking mutex.
struct LruShardMap<K, V> {
    inner: std::sync::Mutex<LruShardInner<K, V>>,
}
impl<K, V> LruShardMap<K, V> {
    /// Builds a shard with the given eviction threshold and table capacity.
    fn with_capacity(max_size: usize, capacity: usize) -> Self {
        let inner = LruShardInner::with_capacity(max_size, capacity);
        Self {
            inner: std::sync::Mutex::new(inner),
        }
    }
    /// Number of entries currently stored in this shard.
    fn len(&self) -> usize {
        let guard = self.inner.lock().unwrap();
        guard.table.len()
    }
    /// True when this shard holds no entries.
    fn is_empty(&self) -> bool {
        let guard = self.inner.lock().unwrap();
        guard.table.is_empty()
    }
    /// This shard's eviction threshold.
    fn max_size(&self) -> usize {
        let guard = self.inner.lock().unwrap();
        guard.max_size
    }
    /// Updates this shard's eviction threshold.
    fn set_max_size(&self, max_size: usize) {
        let mut guard = self.inner.lock().unwrap();
        guard.max_size = max_size;
    }
}
/// A sharded LRU map whose individual entries can be locked via guards
/// (`LruEntry`), allowing per-key critical sections without holding a
/// shard-wide lock.
pub struct LruLockMap<K, V> {
    shards: Vec<LruShardMap<K, V>>,
    // Single hasher shared by all operations so a key always hashes the same
    // regardless of which method computes it.
    hasher: RandomState,
}
/// Default shard count: four times the available parallelism, rounded up to
/// a power of two, computed once and cached for the process lifetime.
fn default_shard_amount() -> usize {
    static DEFAULT_SHARD_AMOUNT: OnceLock<usize> = OnceLock::new();
    *DEFAULT_SHARD_AMOUNT.get_or_init(|| {
        let threads = std::thread::available_parallelism().map_or(1, usize::from);
        (threads * 4).next_power_of_two()
    })
}
impl<K: Eq + Hash, V> LruLockMap<K, V> {
    /// Creates a map bounded at roughly `max_size` entries, using the
    /// default shard count.
    pub fn new(max_size: usize) -> Self {
        Self::with_options(max_size, 0, default_shard_amount())
    }
    /// Creates a map with explicit sizing.
    ///
    /// Budgets are split per shard with ceiling division, so the effective
    /// total capacity may exceed `max_size` (see `max_size()`).
    ///
    /// # Panics
    /// Panics if `shard_amount` is 0.
    pub fn with_options(max_size: usize, initial_capacity: usize, shard_amount: usize) -> Self {
        assert!(shard_amount > 0, "shard_amount must be greater than 0");
        let per_shard_max_size = max_size.div_ceil(shard_amount);
        let per_shard_capacity = initial_capacity.div_ceil(shard_amount);
        Self {
            shards: (0..shard_amount)
                .map(|_| LruShardMap::with_capacity(per_shard_max_size, per_shard_capacity))
                .collect(),
            hasher: RandomState::default(),
        }
    }
    /// Total number of entries across all shards.
    pub fn len(&self) -> usize {
        self.shards.iter().map(|s| s.len()).sum()
    }
    /// True when every shard is empty.
    pub fn is_empty(&self) -> bool {
        self.shards.iter().all(|s| s.is_empty())
    }
    /// Effective capacity: per-shard budget times shard count. This may be
    /// larger than the value passed to the constructor because the per-shard
    /// budget was rounded up.
    pub fn max_size(&self) -> usize {
        let max_size = self.shards.first().map(|s| s.max_size()).unwrap_or(0);
        self.shards.len() * max_size
    }
    /// Rebalances the capacity budget across shards (ceiling division, so
    /// the effective total may round up).
    pub fn set_max_size(&self, max_size: usize) {
        let per_shard_max_size = max_size.div_ceil(self.shards.len());
        for shard in &self.shards {
            shard.set_max_size(per_shard_max_size);
        }
    }
    /// Picks a shard from the upper 32 bits of the hash.
    // NOTE(review): presumably the high bits are used so shard selection
    // stays decorrelated from the table's bucket selection — confirm.
    #[inline(always)]
    fn shard_index(&self, hash: u64) -> usize {
        ((hash >> 32) as usize) % self.shards.len()
    }
    /// Hasher handed to `HashTable` for rehash-on-resize: returns the hash
    /// cached in each `State` instead of rehashing the key.
    #[inline(always)]
    fn state_hasher() -> impl Fn(&AliasableBox<State<K, V>>) -> u64 {
        |s| s.hash
    }
    /// Locks the slot for `key`, creating an empty slot if absent.
    ///
    /// The entry is promoted to most-recently-used and its guard refcount is
    /// raised before the shard lock is released, so eviction will skip it.
    /// Blocks until the per-entry mutex becomes available.
    pub fn entry(&self, key: K) -> LruEntry<'_, K, V> {
        let hash = self.hasher.hash_one(&key);
        let shard = &self.shards[self.shard_index(hash)];
        let ptr: *mut State<K, V> = {
            let mut inner = shard.inner.lock().unwrap();
            let ptr =
                match inner
                    .table
                    .entry(hash, |s| s.key.borrow() == &key, Self::state_hasher())
                {
                    Entry::Occupied(occupied) => {
                        let ptr = &**occupied.get() as *const State<K, V> as *mut State<K, V>;
                        // Pin the entry before dropping the shard lock.
                        unsafe { &*ptr }.inc_ref();
                        unsafe { inner.move_to_front(ptr) };
                        ptr
                    }
                    Entry::Vacant(vacant) => {
                        // New empty slot, born with refcnt 1 for our guard.
                        let state = State::new(key, None, 1, hash);
                        let ptr = &*state as *const State<K, V> as *mut State<K, V>;
                        vacant.insert(state);
                        unsafe { inner.push_front(ptr) };
                        ptr
                    }
                };
            inner.try_evict(ptr);
            ptr
        };
        self.guard(ptr)
    }
    /// Like [`entry`](Self::entry), but only clones the key (via `From<&Q>`)
    /// when a new slot must be created.
    pub fn entry_by_ref<Q>(&self, key: &Q) -> LruEntry<'_, K, V>
    where
        K: Borrow<Q> + for<'c> From<&'c Q>,
        Q: Eq + Hash + ?Sized,
    {
        let hash = self.hasher.hash_one(key);
        let shard = &self.shards[self.shard_index(hash)];
        let ptr: *mut State<K, V> = {
            let mut inner = shard.inner.lock().unwrap();
            let ptr = match inner
                .table
                .entry(hash, |s| s.key.borrow() == key, Self::state_hasher())
            {
                Entry::Occupied(occupied) => {
                    let ptr = &**occupied.get() as *const State<K, V> as *mut State<K, V>;
                    // Pin the entry before dropping the shard lock.
                    unsafe { &*ptr }.inc_ref();
                    unsafe { inner.move_to_front(ptr) };
                    ptr
                }
                Entry::Vacant(vacant) => {
                    let owned_key: K = key.into();
                    // New empty slot, born with refcnt 1 for our guard.
                    let state = State::new(owned_key, None, 1, hash);
                    let ptr = &*state as *const State<K, V> as *mut State<K, V>;
                    vacant.insert(state);
                    unsafe { inner.push_front(ptr) };
                    ptr
                }
            };
            inner.try_evict(ptr);
            ptr
        };
        self.guard(ptr)
    }
    /// Returns a clone of the value for `key`, if present, promoting the
    /// entry to most-recently-used.
    ///
    /// Fast path: when no guard holds the entry (refcnt == 0) the value is
    /// cloned directly under the shard lock. Otherwise a guard is taken so
    /// the read is serialized with the current holder.
    pub fn get<Q>(&self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        V: Clone,
        Q: Eq + Hash + ?Sized,
    {
        let hash = self.hasher.hash_one(key);
        let shard = &self.shards[self.shard_index(hash)];
        let mut ptr: *mut State<K, V> = std::ptr::null_mut();
        let value = {
            let mut inner = shard.inner.lock().unwrap();
            let p = inner
                .table
                .find(hash, |s| s.key.borrow() == key)
                .map(|s| &**s as *const State<K, V> as *mut State<K, V>)
                .unwrap_or(std::ptr::null_mut());
            if !p.is_null() {
                unsafe { inner.move_to_front(p) };
                let state = unsafe { &*p };
                if state.flags().refcnt() == 0 {
                    // SAFETY: no guards exist and refcnt can only be raised
                    // under the shard lock we hold, so access is exclusive.
                    unsafe { state.value_ref() }.clone()
                } else {
                    state.inc_ref();
                    ptr = p;
                    None
                }
            } else {
                None
            }
        };
        if ptr.is_null() {
            return value;
        }
        // Slow path: wait for the current guard, then read under our own.
        self.guard(ptr).get().clone()
    }
    /// Inserts `value` for `key`, returning the previous value if any.
    pub fn insert(&self, key: K, value: V) -> Option<V> {
        let hash = self.hasher.hash_one(&key);
        let shard = &self.shards[self.shard_index(hash)];
        let (ptr, old) = {
            let mut inner = shard.inner.lock().unwrap();
            match inner
                .table
                .entry(hash, |s| s.key.borrow() == &key, Self::state_hasher())
            {
                Entry::Occupied(occupied) => {
                    let p = &**occupied.get() as *const State<K, V> as *mut State<K, V>;
                    unsafe { inner.move_to_front(p) };
                    let state = unsafe { &*p };
                    let flags = state.flags();
                    if flags.refcnt() == 0 {
                        // SAFETY: unguarded and shard lock held — exclusive.
                        let old = unsafe { state.value_mut() }.replace(value);
                        state.set_value_state(true);
                        (std::ptr::null_mut(), old)
                    } else {
                        // Entry is guarded: hand the new value over through a
                        // guard below (`old` temporarily carries it).
                        state.inc_ref();
                        (p, Some(value))
                    }
                }
                Entry::Vacant(vacant) => {
                    let state = State::new(key, Some(value), 0, hash);
                    let new_ptr = &*state as *const State<K, V> as *mut State<K, V>;
                    vacant.insert(state);
                    unsafe { inner.push_front(new_ptr) };
                    inner.try_evict(new_ptr);
                    (std::ptr::null_mut(), None)
                }
            }
        };
        if ptr.is_null() {
            return old;
        }
        // Guarded path: swap the new value in, returning the previous one.
        self.guard(ptr).swap(old)
    }
    /// Like [`insert`](Self::insert), but only clones the key (via
    /// `From<&Q>`) when a new slot must be created.
    pub fn insert_by_ref<Q>(&self, key: &Q, value: V) -> Option<V>
    where
        K: Borrow<Q> + for<'c> From<&'c Q>,
        Q: Eq + Hash + ?Sized,
    {
        let hash = self.hasher.hash_one(key);
        let shard = &self.shards[self.shard_index(hash)];
        let (ptr, old) = {
            let mut inner = shard.inner.lock().unwrap();
            match inner
                .table
                .entry(hash, |s| s.key.borrow() == key, Self::state_hasher())
            {
                Entry::Occupied(occupied) => {
                    let p = &**occupied.get() as *const State<K, V> as *mut State<K, V>;
                    unsafe { inner.move_to_front(p) };
                    let state = unsafe { &*p };
                    let flags = state.flags();
                    if flags.refcnt() == 0 {
                        // SAFETY: unguarded and shard lock held — exclusive.
                        let old = unsafe { state.value_mut() }.replace(value);
                        state.set_value_state(true);
                        (std::ptr::null_mut(), old)
                    } else {
                        // Entry is guarded: hand the new value over through a
                        // guard below (`old` temporarily carries it).
                        state.inc_ref();
                        (p, Some(value))
                    }
                }
                Entry::Vacant(vacant) => {
                    let owned_key: K = key.into();
                    let state = State::new(owned_key, Some(value), 0, hash);
                    let new_ptr = &*state as *const State<K, V> as *mut State<K, V>;
                    vacant.insert(state);
                    unsafe { inner.push_front(new_ptr) };
                    inner.try_evict(new_ptr);
                    (std::ptr::null_mut(), None)
                }
            }
        };
        if ptr.is_null() {
            return old;
        }
        // Guarded path: swap the new value in, returning the previous one.
        self.guard(ptr).swap(old)
    }
    /// Whether `key` currently maps to a value; the entry is promoted to
    /// most-recently-used when found. A guarded entry is checked through a
    /// guard, i.e. after the current holder releases it.
    pub fn contains_key<Q>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: Eq + Hash + ?Sized,
    {
        let hash = self.hasher.hash_one(key);
        let shard = &self.shards[self.shard_index(hash)];
        let mut ptr: *mut State<K, V> = std::ptr::null_mut();
        let found = {
            let mut inner = shard.inner.lock().unwrap();
            let p = inner
                .table
                .find(hash, |s| s.key.borrow() == key)
                .map(|s| &**s as *const State<K, V> as *mut State<K, V>)
                .unwrap_or(std::ptr::null_mut());
            if !p.is_null() {
                unsafe { inner.move_to_front(p) };
                let state = unsafe { &*p };
                if state.flags().refcnt() == 0 {
                    // SAFETY: unguarded and shard lock held — exclusive.
                    unsafe { state.value_ref() }.is_some()
                } else {
                    state.inc_ref();
                    ptr = p;
                    false
                }
            } else {
                false
            }
        };
        if ptr.is_null() {
            return found;
        }
        self.guard(ptr).get().is_some()
    }
    /// Removes and returns the value for `key`.
    ///
    /// Unguarded entries are unlinked and freed immediately; guarded ones
    /// have their value taken through a guard, and the slot itself is
    /// reclaimed by the last guard's `Drop`.
    pub fn remove<Q>(&self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: Eq + Hash + ?Sized,
    {
        let hash = self.hasher.hash_one(key);
        let shard = &self.shards[self.shard_index(hash)];
        let ptr = {
            let mut inner = shard.inner.lock().unwrap();
            let p = match inner.table.find_entry(hash, |s| s.key.borrow() == key) {
                Ok(occupied) => {
                    let p = &**occupied.get() as *const State<K, V> as *mut State<K, V>;
                    let state = unsafe { &*p };
                    if state.flags().refcnt() == 0 {
                        // SAFETY: unguarded and shard lock held — exclusive.
                        let value = unsafe { state.value_mut() }.take();
                        let (state_box, _) = occupied.remove();
                        unsafe { inner.detach(p) };
                        drop(state_box);
                        return value;
                    }
                    state.inc_ref();
                    p
                }
                Err(_) => return None,
            };
            p
        };
        self.guard(ptr).remove()
    }
    /// Wraps `ptr` in a guard, blocking on the per-entry mutex.
    ///
    /// Callers must already hold a reference (they incremented the refcount,
    /// or created the state with refcnt 1), so the state cannot be freed
    /// while we wait for the lock.
    fn guard(&self, ptr: *mut State<K, V>) -> LruEntry<'_, K, V> {
        unsafe { (*ptr).mutex.lock() };
        LruEntry {
            map: self,
            state: ptr,
        }
    }
}
impl<K: Eq + Hash, V> Default for LruLockMap<K, V> {
    /// An effectively unbounded map (`usize::MAX` capacity) with the default
    /// shard count.
    fn default() -> Self {
        Self::new(usize::MAX)
    }
}
impl<K, V> std::fmt::Debug for LruLockMap<K, V> {
    /// Formats as an opaque `LruLockMap`; entries are not enumerated.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("LruLockMap");
        builder.finish()
    }
}
/// A locked handle to a single slot in the map.
///
/// While the guard is alive it holds the slot's per-entry mutex, giving
/// exclusive access to the value; dropping it releases the lock and (see the
/// `Drop` impl) removes the slot entirely if it ends up empty and
/// unreferenced.
pub struct LruEntry<'a, K: Eq + Hash, V> {
    map: &'a LruLockMap<K, V>,
    state: *mut State<K, V>,
}
// SAFETY: a shared `&LruEntry` only hands out `&K` and `&Option<V>`, which is
// sound to share across threads when `K: Sync` and `V: Sync`; all mutation
// goes through `&mut self`.
unsafe impl<K: Eq + Hash + Sync, V: Sync> Sync for LruEntry<'_, K, V> {}
impl<K: Eq + Hash, V> LruEntry<'_, K, V> {
    /// The key this guard is locked on.
    pub fn key(&self) -> &K {
        let state = unsafe { &*self.state };
        &state.key
    }
    /// Shared access to the slot's value (`None` when the slot is empty).
    pub fn get(&self) -> &Option<V> {
        let state = unsafe { &*self.state };
        unsafe { state.value_ref() }
    }
    /// Exclusive access to the slot's value.
    pub fn get_mut(&mut self) -> &mut Option<V> {
        let state = unsafe { &*self.state };
        unsafe { state.value_mut() }
    }
    /// Stores `value`, returning the previous value if any.
    pub fn insert(&mut self, value: V) -> Option<V> {
        self.get_mut().replace(value)
    }
    /// Replaces the slot's contents with `value`, returning the old contents.
    pub fn swap(&mut self, value: Option<V>) -> Option<V> {
        std::mem::replace(self.get_mut(), value)
    }
    /// Empties the slot, returning the removed value if any.
    pub fn remove(&mut self) -> Option<V> {
        self.get_mut().take()
    }
}
impl<K: Eq + Hash + std::fmt::Debug, V: std::fmt::Debug> std::fmt::Debug for LruEntry<'_, K, V> {
    /// Shows the guarded key and the slot's current value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("LruEntry");
        builder.field("key", self.key());
        builder.field("value", self.get());
        builder.finish()
    }
}
impl<K: Eq + Hash, V> Drop for LruEntry<'_, K, V> {
    /// Releases the entry lock and drops this guard's reference.
    ///
    /// Fast path: publish the has-value bit, unlock, and CAS-decrement the
    /// refcount as long as the entry will stay alive afterwards (other
    /// guards remain, or a value is stored). Slow path: when this looks like
    /// the last guard on an empty slot, take the shard lock, decrement, and
    /// if the entry is now fully dead unlink it from the list and the table.
    fn drop(&mut self) {
        // Record whether a value is present while we still hold the mutex.
        let has_value = self.get().is_some();
        let state_ref = unsafe { &*self.state };
        state_ref.set_value_state(has_value);
        // After this unlock other threads may lock the entry again.
        unsafe { state_ref.mutex.unlock() };
        let mut current = state_ref.flags.load(Ordering::Acquire);
        loop {
            let flags = StateFlags(current);
            if flags.refcnt() == 1 && !flags.has_value() {
                // We appear to be the last guard on an empty slot: fall
                // through to the locked cleanup path below.
                break;
            }
            let new_flags = StateFlags::new(flags.refcnt() - 1, flags.has_value());
            match state_ref.flags.compare_exchange_weak(
                current,
                new_flags.0,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return,
                Err(actual) => current = actual,
            }
        }
        let shard_idx = self.map.shard_index(state_ref.hash);
        let shard = &self.map.shards[shard_idx];
        let mut inner = shard.inner.lock().unwrap();
        // Re-check under the shard lock: another thread may have taken a new
        // reference (or stored a value) since the CAS loop broke out, in
        // which case `pending_cleanup` is false and the entry survives.
        let final_flags = state_ref.dec_ref();
        if final_flags.pending_cleanup() {
            unsafe { inner.detach(self.state) };
            let state_ptr = self.state as *const State<K, V>;
            if let Ok(entry) = inner
                .table
                .find_entry(state_ref.hash, |s| std::ptr::eq(&**s, state_ptr))
            {
                let _ = entry.remove();
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    };

    /// Single-threaded insert/get/overwrite/remove round trip.
    #[test]
    fn test_basic_insert_get_remove() {
        let cache = LruLockMap::<String, u32>::new(100);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        cache.insert("a".to_string(), 1);
        assert_eq!(cache.get("a"), Some(1));
        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.insert("a".to_string(), 2), Some(1));
        assert_eq!(cache.get("a"), Some(2));
        assert_eq!(cache.remove("a"), Some(2));
        assert_eq!(cache.get("a"), None);
        assert!(cache.is_empty());
    }

    /// `insert_by_ref` creates and overwrites without an owned key.
    #[test]
    fn test_insert_by_ref() {
        let cache = LruLockMap::<String, u32>::new(100);
        cache.insert_by_ref("key", 42);
        assert_eq!(cache.get("key"), Some(42));
        assert_eq!(cache.insert_by_ref("key", 99), Some(42));
        assert_eq!(cache.get("key"), Some(99));
    }

    /// `contains_key` tracks insert and remove.
    #[test]
    fn test_contains_key() {
        let cache = LruLockMap::<String, u32>::new(100);
        assert!(!cache.contains_key("x"));
        cache.insert("x".to_string(), 7);
        assert!(cache.contains_key("x"));
        cache.remove("x");
        assert!(!cache.contains_key("x"));
    }

    /// Guard API with an owned key: create empty, fill, read back, remove.
    #[test]
    fn test_entry_by_val() {
        let cache = LruLockMap::<u32, u32>::new(100);
        {
            let mut entry = cache.entry(1);
            assert_eq!(*entry.key(), 1);
            assert!(entry.get().is_none());
            entry.insert(42);
            assert_eq!(*entry.get(), Some(42));
            println!("{:?}", entry);
        }
        assert_eq!(cache.get(&1), Some(42));
        {
            let mut entry = cache.entry(1);
            assert_eq!(entry.remove(), Some(42));
        }
        assert_eq!(cache.get(&1), None);
    }

    /// Guard API with a borrowed key.
    #[test]
    fn test_entry_by_ref() {
        let cache = LruLockMap::<String, u32>::new(100);
        {
            let mut entry = cache.entry_by_ref("key");
            assert_eq!(entry.key(), "key");
            entry.insert(7);
            println!("{:?}", entry);
        }
        assert_eq!(cache.get("key"), Some(7));
        {
            let mut entry = cache.entry_by_ref("key");
            assert_eq!(entry.get_mut().take(), Some(7));
        }
        assert_eq!(cache.get("key"), None);
    }

    /// `Default` builds an unbounded map; `Debug` does not panic.
    #[test]
    fn test_default_and_debug() {
        let cache = LruLockMap::<u32, u32>::default();
        println!("{:?}", cache);
        assert!(cache.is_empty());
    }

    /// With max_size 0 the per-shard budget rounds to 0, but the entry being
    /// inserted is itself never evicted, so exactly one entry survives.
    #[test]
    fn test_lru_zero_capacity() {
        let cache = LruLockMap::<u32, u32>::with_options(0, 0, 1);
        assert!(cache.is_empty());
        assert_eq!(cache.insert(1, 10), None);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.insert(2, 20), None);
        assert_eq!(cache.len(), 1);
    }

    /// Documents the ceiling-division rounding of the capacity budget.
    #[test]
    fn test_set_max_size() {
        let cache = LruLockMap::<u32, u32>::with_options(3, 3, 4);
        assert_eq!(cache.max_size(), 4);
        cache.set_max_size(6);
        assert_eq!(cache.max_size(), 8);
    }

    /// Inserting past capacity evicts the least-recently-used key.
    #[test]
    fn test_lru_eviction_basic() {
        let cache = LruLockMap::<u32, u32>::with_options(3, 3, 1);
        cache.insert(1, 10);
        cache.insert(2, 20);
        cache.insert(3, 30);
        assert_eq!(cache.len(), 3);
        cache.insert(4, 40);
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(&1), None);
        assert_eq!(cache.get(&2), Some(20));
        assert_eq!(cache.get(&3), Some(30));
        assert_eq!(cache.get(&4), Some(40));
    }

    /// `get` promotes an entry, shifting eviction to the next-oldest key.
    #[test]
    fn test_lru_access_promotes() {
        let cache = LruLockMap::<u32, u32>::with_options(3, 3, 1);
        cache.insert(1, 10);
        cache.insert(2, 20);
        cache.insert(3, 30);
        assert_eq!(cache.get(&1), Some(10));
        cache.insert(4, 40);
        assert_eq!(cache.get(&2), None);
        assert_eq!(cache.get(&1), Some(10));
        assert_eq!(cache.get(&3), Some(30));
        assert_eq!(cache.get(&4), Some(40));
    }

    /// Taking a guard also counts as a use for LRU ordering.
    #[test]
    fn test_lru_entry_promotes() {
        let cache = LruLockMap::<u32, u32>::with_options(3, 3, 1);
        cache.insert(1, 10);
        cache.insert(2, 20);
        cache.insert(3, 30);
        {
            let entry = cache.entry(1);
            assert_eq!(*entry.get(), Some(10));
        }
        cache.insert(4, 40);
        assert_eq!(cache.get(&2), None);
        assert_eq!(cache.get(&1), Some(10));
    }

    /// An entry held by a live guard is never evicted, even from another
    /// thread.
    #[test]
    fn test_lru_skip_in_use_entry() {
        let cache = Arc::new(LruLockMap::<u32, u32>::with_options(3, 3, 1));
        cache.insert(1, 10);
        cache.insert(2, 20);
        cache.insert(3, 30);
        let _entry = cache.entry(1);
        let cache2 = cache.clone();
        let t = std::thread::spawn(move || {
            cache2.insert(4, 40);
        });
        t.join().unwrap();
        assert_eq!(*_entry.get(), Some(10));
        assert_eq!(cache.get(&2), None);
        assert_eq!(cache.get(&3), Some(30));
        assert_eq!(cache.get(&4), Some(40));
        drop(_entry);
        assert!(cache.len() <= 4);
    }

    /// Eviction skips every guarded entry and falls through to the oldest
    /// unguarded one.
    #[test]
    fn test_lru_evict_skips_multiple_in_use() {
        let cache = LruLockMap::<u32, u32>::with_options(3, 3, 1);
        cache.insert(1, 10);
        cache.insert(2, 20);
        cache.insert(3, 30);
        let _entry1 = cache.entry(1);
        let _entry2 = cache.entry(2);
        cache.insert(4, 40);
        assert_eq!(*_entry1.get(), Some(10));
        assert_eq!(*_entry2.get(), Some(20));
        assert_eq!(cache.get(&3), None);
        assert_eq!(cache.get(&4), Some(40));
        drop(_entry2);
        drop(_entry1);
    }

    /// Overwriting an existing key does not change the size, so nothing is
    /// evicted.
    #[test]
    fn test_lru_insert_overwrite_no_evict() {
        let cache = LruLockMap::<u32, u32>::with_options(3, 3, 1);
        cache.insert(1, 10);
        cache.insert(2, 20);
        cache.insert(3, 30);
        cache.insert(2, 200);
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(&1), Some(10));
        assert_eq!(cache.get(&2), Some(200));
        assert_eq!(cache.get(&3), Some(30));
    }

    /// Removing a key frees a slot, so the next insert evicts nothing.
    #[test]
    fn test_lru_remove_frees_slot() {
        let cache = LruLockMap::<u32, u32>::with_options(3, 3, 1);
        cache.insert(1, 10);
        cache.insert(2, 20);
        cache.insert(3, 30);
        cache.remove(&2);
        assert_eq!(cache.len(), 2);
        cache.insert(4, 40);
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(&1), Some(10));
        assert_eq!(cache.get(&3), Some(30));
        assert_eq!(cache.get(&4), Some(40));
    }

    /// Guards on the same key are mutually exclusive: the counter asserts no
    /// two threads are inside the critical section at once.
    #[test]
    fn test_concurrent_same_key() {
        let cache = Arc::new(LruLockMap::<usize, usize>::new(1024));
        let counter = Arc::new(AtomicU32::default());
        #[cfg(not(miri))]
        const N: usize = 1 << 16;
        #[cfg(miri)]
        const N: usize = 1 << 6;
        const M: usize = 4;
        cache.insert(0, 0);
        let threads: Vec<_> = (0..M)
            .map(|_| {
                let cache = cache.clone();
                let counter = counter.clone();
                std::thread::spawn(move || {
                    for _ in 0..N {
                        let mut entry = cache.entry(0);
                        let now = counter.fetch_add(1, Ordering::AcqRel);
                        assert_eq!(now, 0);
                        let v = entry.get_mut().as_mut().unwrap();
                        *v += 1;
                        let now = counter.fetch_sub(1, Ordering::AcqRel);
                        assert_eq!(now, 1);
                    }
                })
            })
            .collect();
        threads.into_iter().for_each(|t| t.join().unwrap());
        let entry = cache.entry(0);
        assert_eq!(*entry.get(), Some(N * M));
    }

    /// Same mutual-exclusion property via `entry_by_ref`.
    #[test]
    fn test_concurrent_same_key_by_ref() {
        let cache = Arc::new(LruLockMap::<String, usize>::new(1024));
        let counter = Arc::new(AtomicU32::default());
        #[cfg(not(miri))]
        const N: usize = 1 << 16;
        #[cfg(miri)]
        const N: usize = 1 << 6;
        const M: usize = 4;
        cache.insert_by_ref("hello", 0);
        let threads: Vec<_> = (0..M)
            .map(|_| {
                let cache = cache.clone();
                let counter = counter.clone();
                std::thread::spawn(move || {
                    for _ in 0..N {
                        let mut entry = cache.entry_by_ref("hello");
                        let now = counter.fetch_add(1, Ordering::AcqRel);
                        assert_eq!(now, 0);
                        let v = entry.get_mut().as_mut().unwrap();
                        *v += 1;
                        let now = counter.fetch_sub(1, Ordering::AcqRel);
                        assert_eq!(now, 1);
                    }
                })
            })
            .collect();
        threads.into_iter().for_each(|t| t.join().unwrap());
        let entry = cache.entry_by_ref("hello");
        assert_eq!(*entry.get(), Some(N * M));
    }

    /// Guards on random keys always see the empty slot they created.
    #[test]
    fn test_concurrent_random_keys() {
        let cache = Arc::new(LruLockMap::<u32, u32>::with_options(256, 16, 1));
        let total = Arc::new(AtomicU32::default());
        #[cfg(not(miri))]
        const N: usize = 1 << 12;
        #[cfg(miri)]
        const N: usize = 1 << 6;
        const M: usize = 8;
        let threads: Vec<_> = (0..M)
            .map(|_| {
                let cache = cache.clone();
                let total = total.clone();
                std::thread::spawn(move || {
                    for _ in 0..N {
                        let key = rand::random::<u32>() % 32;
                        let mut entry = cache.entry(key);
                        assert!(entry.get().is_none());
                        entry.insert(1);
                        total.fetch_add(1, Ordering::AcqRel);
                        entry.remove();
                    }
                })
            })
            .collect();
        threads.into_iter().for_each(|t| t.join().unwrap());
        assert_eq!(total.load(Ordering::Acquire) as usize, N * M);
    }

    /// Mixed guard/insert/remove/get traffic: readers never observe a value
    /// that writers could not have stored (< 16 means removed).
    #[test]
    fn test_concurrent_get_set() {
        let cache = Arc::new(LruLockMap::<u32, u32>::with_options(256, 16, 1));
        #[cfg(not(miri))]
        const N: usize = 1 << 16;
        #[cfg(miri)]
        const N: usize = 1 << 6;
        let entry_thread = {
            let cache = cache.clone();
            std::thread::spawn(move || {
                for _ in 0..N {
                    let key = rand::random::<u32>() % 32;
                    let value = rand::random::<u32>() % 32;
                    let mut entry = cache.entry(key);
                    if value < 16 {
                        entry.get_mut().take();
                    } else {
                        entry.get_mut().replace(value);
                    }
                }
            })
        };
        let set_thread = {
            let cache = cache.clone();
            std::thread::spawn(move || {
                for _ in 0..N {
                    let key = rand::random::<u32>() % 32;
                    let value = rand::random::<u32>() % 32;
                    if value < 16 {
                        cache.remove(&key);
                    } else {
                        cache.insert(key, value);
                    }
                }
            })
        };
        let get_thread = {
            let cache = cache.clone();
            std::thread::spawn(move || {
                for _ in 0..N {
                    let key = rand::random::<u32>() % 32;
                    let value = cache.get(&key);
                    if let Some(v) = value {
                        assert!(v >= 16);
                    }
                }
            })
        };
        entry_thread.join().unwrap();
        set_thread.join().unwrap();
        get_thread.join().unwrap();
    }

    /// Same invariant as above through the `_by_ref` API.
    #[test]
    fn test_concurrent_get_set_by_ref() {
        let cache = Arc::new(LruLockMap::<String, u32>::with_options(256, 16, 1));
        #[cfg(not(miri))]
        const N: usize = 1 << 14;
        #[cfg(miri)]
        const N: usize = 1 << 6;
        let entry_thread = {
            let cache = cache.clone();
            std::thread::spawn(move || {
                for _ in 0..N {
                    let key = (rand::random::<u32>() % 32).to_string();
                    let value = rand::random::<u32>() % 32;
                    let mut entry = cache.entry_by_ref(&key);
                    if value < 16 {
                        entry.get_mut().take();
                    } else {
                        entry.get_mut().replace(value);
                    }
                }
            })
        };
        let set_thread = {
            let cache = cache.clone();
            std::thread::spawn(move || {
                for _ in 0..N {
                    let key = (rand::random::<u32>() % 32).to_string();
                    let value = rand::random::<u32>() % 32;
                    if value < 16 {
                        cache.remove(&key);
                    } else {
                        cache.insert_by_ref(&key, value);
                    }
                }
            })
        };
        let get_thread = {
            let cache = cache.clone();
            std::thread::spawn(move || {
                for _ in 0..N {
                    let key = (rand::random::<u32>() % 32).to_string();
                    let value = cache.get(&key);
                    if let Some(v) = value {
                        assert!(v >= 16);
                    }
                }
            })
        };
        entry_thread.join().unwrap();
        set_thread.join().unwrap();
        get_thread.join().unwrap();
    }

    /// Stress all operations while eviction is active; size stays bounded by
    /// the key range.
    #[test]
    fn test_concurrent_with_eviction() {
        let cache = Arc::new(LruLockMap::<u32, u32>::with_options(32, 4, 1));
        #[cfg(not(miri))]
        const N: usize = 1 << 14;
        #[cfg(miri)]
        const N: usize = 1 << 6;
        const M: usize = 8;
        let threads: Vec<_> = (0..M)
            .map(|_| {
                let cache = cache.clone();
                std::thread::spawn(move || {
                    for _ in 0..N {
                        let key = rand::random::<u32>() % 64;
                        let op = rand::random::<u32>() % 4;
                        match op {
                            0 => {
                                cache.insert(key, key);
                            }
                            1 => {
                                let _ = cache.get(&key);
                            }
                            2 => {
                                let _ = cache.remove(&key);
                            }
                            _ => {
                                let mut entry = cache.entry(key);
                                entry.insert(key);
                                drop(entry);
                            }
                        }
                    }
                })
            })
            .collect();
        for t in threads {
            t.join().unwrap();
        }
        assert!(cache.len() <= 64);
    }

    /// `LruEntry::swap` exchanges contents and returns the old value.
    #[test]
    fn test_swap() {
        let cache = LruLockMap::<u32, u32>::new(100);
        cache.insert(1, 10);
        {
            let mut entry = cache.entry(1);
            let old = entry.swap(Some(20));
            assert_eq!(old, Some(10));
        }
        assert_eq!(cache.get(&1), Some(20));
    }

    /// Long-running lock-map workload: guarded increments on one key sum up
    /// exactly.
    #[test]
    fn test_lockmap_same_key_by_ref() {
        let lock_map = Arc::new(LruLockMap::<String, usize>::new(1 << 20));
        let current = Arc::new(AtomicU32::default());
        #[cfg(not(miri))]
        const N: usize = 1 << 20;
        #[cfg(miri)]
        const N: usize = 1 << 6;
        const M: usize = 4;
        const S: &str = "hello";
        lock_map.insert_by_ref(S, 0);
        let threads = (0..M)
            .map(|_| {
                let lock_map = lock_map.clone();
                let current = current.clone();
                std::thread::spawn(move || {
                    for _ in 0..N {
                        let mut entry = lock_map.entry_by_ref(S);
                        let now = current.fetch_add(1, Ordering::AcqRel);
                        assert_eq!(now, 0);
                        let v = entry.get_mut().as_mut().unwrap();
                        *v += 1;
                        let now = current.fetch_sub(1, Ordering::AcqRel);
                        assert_eq!(now, 1);
                    }
                })
            })
            .collect::<Vec<_>>();
        threads.into_iter().for_each(|t| t.join().unwrap());
        let mut entry = lock_map.entry_by_ref(S);
        println!("{:?}", entry);
        assert_eq!(entry.key(), S);
        assert_eq!(*entry.get(), Some(N * M));
        assert_eq!(entry.insert(0).unwrap(), N * M);
    }

    /// Lock-map variant of the get/set invariant test (no eviction pressure).
    #[test]
    fn test_lockmap_get_set_by_ref() {
        let lock_map = Arc::new(LruLockMap::<String, u32>::with_options(1 << 20, 16, 1));
        #[cfg(not(miri))]
        const N: usize = 1 << 18;
        #[cfg(miri)]
        const N: usize = 1 << 6;
        let entry_thread = {
            let lock_map = lock_map.clone();
            std::thread::spawn(move || {
                for _ in 0..N {
                    let key = (rand::random::<u32>() % 32).to_string();
                    let value = rand::random::<u32>() % 32;
                    let mut entry = lock_map.entry_by_ref(&key);
                    if value < 16 {
                        entry.get_mut().take();
                    } else {
                        entry.get_mut().replace(value);
                    }
                }
            })
        };
        let set_thread = {
            let lock_map = lock_map.clone();
            std::thread::spawn(move || {
                for _ in 0..N {
                    let key = (rand::random::<u32>() % 32).to_string();
                    let value = rand::random::<u32>() % 32;
                    if value < 16 {
                        lock_map.remove(&key);
                    } else {
                        lock_map.insert_by_ref(&key, value);
                    }
                }
            })
        };
        let get_thread = {
            let lock_map = lock_map.clone();
            std::thread::spawn(move || {
                for _ in 0..N {
                    let key = (rand::random::<u32>() % 32).to_string();
                    let value = lock_map.get(&key);
                    if let Some(v) = value {
                        assert!(v >= 16)
                    }
                }
            })
        };
        entry_thread.join().unwrap();
        set_thread.join().unwrap();
        get_thread.join().unwrap();
    }

    /// Guard-based removal racing plain inserts must not deadlock or crash.
    #[test]
    fn test_lockmap_insert_remove() {
        let lock_map = Arc::new(LruLockMap::<String, u32>::with_options(1 << 20, 16, 1));
        #[cfg(not(miri))]
        const N: usize = 1 << 22;
        #[cfg(miri)]
        const N: usize = 1 << 6;
        let entry_thread = {
            let lock_map = lock_map.clone();
            std::thread::spawn(move || {
                for _ in 0..N {
                    let key = (rand::random::<u32>() % 32).to_string();
                    let mut entry = lock_map.entry_by_ref(&key);
                    entry.remove();
                }
            })
        };
        let set_thread = {
            let lock_map = lock_map.clone();
            std::thread::spawn(move || {
                for _ in 0..N {
                    let key = (rand::random::<u32>() % 32).to_string();
                    let value = rand::random::<u32>() % 32;
                    lock_map.insert_by_ref(&key, value);
                }
            })
        };
        entry_thread.join().unwrap();
        set_thread.join().unwrap();
    }

    /// Heavy contention on a few hot keys: every guarded update is counted,
    /// and the key must still exist right after its guard is dropped.
    #[test]
    fn test_lockmap_heavy_contention() {
        let lock_map = Arc::new(LruLockMap::<u32, u32>::new(1 << 20));
        #[cfg(not(miri))]
        const THREADS: usize = 16;
        #[cfg(miri)]
        const THREADS: usize = 4;
        #[cfg(not(miri))]
        const OPS_PER_THREAD: usize = 10000;
        #[cfg(miri)]
        const OPS_PER_THREAD: usize = 10;
        const HOT_KEYS: u32 = 5;
        let counter = Arc::new(AtomicU32::new(0));
        let threads: Vec<_> = (0..THREADS)
            .map(|_| {
                let lock_map = lock_map.clone();
                let counter = counter.clone();
                std::thread::spawn(move || {
                    for _ in 0..OPS_PER_THREAD {
                        let key = rand::random::<u32>() % HOT_KEYS;
                        let mut entry = lock_map.entry(key);
                        // Widen the critical section to increase contention.
                        std::thread::sleep(std::time::Duration::from_nanos(10));
                        match entry.get_mut() {
                            Some(value) => {
                                *value = value.wrapping_add(1);
                                counter.fetch_add(1, Ordering::Relaxed);
                            }
                            None => {
                                entry.insert(1);
                                counter.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                        drop(entry);
                        assert!(lock_map.contains_key(&key), "Key {} should exist", key);
                    }
                })
            })
            .collect();
        for thread in threads {
            thread.join().unwrap();
        }
        assert_eq!(
            counter.load(Ordering::Relaxed),
            THREADS as u32 * OPS_PER_THREAD as u32
        );
    }
}