use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use super::Handle;
/// A sharded handle-to-value store with per-shard locking.
///
/// NOTE(review): despite the name, each shard is guarded by an `RwLock`,
/// so operations are lock-based — sharding only reduces contention.
/// Confirm whether the "LockFree" name should change.
pub struct LockFreeHandleStore<T> {
    /// One guarded map per shard; values are wrapped in `Arc` so `get`
    /// can return a cheap clone without holding the lock afterwards.
    shards: Vec<RwLock<HashMap<Handle, Arc<T>>>>,
    /// Always a power of two so shard selection can use a bit mask.
    shard_count: usize,
    /// Operation counters, updated with relaxed atomics.
    stats: StoreStats,
}
/// Relaxed atomic counters tracking store operations.
#[derive(Debug, Default)]
pub struct StoreStats {
    pub inserts: AtomicU64,
    pub gets: AtomicU64,
    pub removes: AtomicU64,
    pub hits: AtomicU64,
    pub misses: AtomicU64,
}

impl StoreStats {
    /// Fraction of lookups that found a value, in `[0.0, 1.0]`.
    /// Returns 0.0 when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        let hits = self.hits.load(Ordering::Relaxed) as f64;
        let misses = self.misses.load(Ordering::Relaxed) as f64;
        let total = hits + misses;
        if total > 0.0 { hits / total } else { 0.0 }
    }
}
impl<T> LockFreeHandleStore<T> {
    /// Creates a store with the default shard count of 16.
    pub fn new() -> Self {
        Self::with_shards(16)
    }

    /// Creates a store with at least `shard_count` shards; the count is
    /// rounded up to the next power of two so shard selection can mask
    /// instead of taking a modulo.
    pub fn with_shards(shard_count: usize) -> Self {
        let shard_count = shard_count.next_power_of_two();
        Self {
            shards: (0..shard_count)
                .map(|_| RwLock::new(HashMap::new()))
                .collect(),
            shard_count,
            stats: StoreStats::default(),
        }
    }

    /// Maps a handle to its owning shard; equivalent to
    /// `handle % shard_count` because `shard_count` is a power of two.
    #[inline]
    fn shard_index(&self, handle: Handle) -> usize {
        (handle as usize) & (self.shard_count - 1)
    }

    /// Stores `value` under a freshly allocated handle and returns it.
    pub fn insert(&self, value: T) -> Handle {
        let handle = super::new_handle();
        let mut guard = self.shards[self.shard_index(handle)].write().unwrap();
        guard.insert(handle, Arc::new(value));
        self.stats.inserts.fetch_add(1, Ordering::Relaxed);
        handle
    }

    /// Looks up `handle`, returning a cheap `Arc` clone of the value.
    /// Every call bumps `gets` and either `hits` or `misses`.
    pub fn get(&self, handle: Handle) -> Option<Arc<T>> {
        let guard = self.shards[self.shard_index(handle)].read().unwrap();
        let found = guard.get(&handle).cloned();
        self.stats.gets.fetch_add(1, Ordering::Relaxed);
        let bucket = if found.is_some() {
            &self.stats.hits
        } else {
            &self.stats.misses
        };
        bucket.fetch_add(1, Ordering::Relaxed);
        found
    }

    /// Removes `handle`, returning the value if it was present.
    /// `removes` is bumped whether or not anything was removed.
    pub fn remove(&self, handle: Handle) -> Option<Arc<T>> {
        let mut guard = self.shards[self.shard_index(handle)].write().unwrap();
        let removed = guard.remove(&handle);
        self.stats.removes.fetch_add(1, Ordering::Relaxed);
        removed
    }

    /// Returns `handle` unchanged.
    /// NOTE(review): intentional no-op — presumably mirrors a C-style
    /// `keep`/refcount API; confirm intent with callers.
    pub fn keep(&self, handle: Handle) -> Handle {
        handle
    }

    /// Read-only access to the operation counters.
    pub fn stats(&self) -> &StoreStats {
        &self.stats
    }

    /// Total entry count summed shard by shard; under concurrent
    /// writers this is only a point-in-time approximation.
    pub fn len(&self) -> usize {
        self.shards
            .iter()
            .map(|shard| shard.read().unwrap().len())
            .sum()
    }

    /// `true` when no shard holds an entry.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes every entry, locking one shard at a time (the clear is
    /// not atomic across shards).
    pub fn clear(&self) {
        for shard in &self.shards {
            shard.write().unwrap().clear();
        }
    }
}
impl<T> Default for LockFreeHandleStore<T> {
fn default() -> Self {
Self::new()
}
}
/// One ring-buffer cell together with its sequence number.
///
/// The sequence number is the producer/consumer handshake of Vyukov's
/// bounded MPMC queue: for a claimed logical index `pos`,
/// `sequence == pos` means the cell is free for the producer, and
/// `sequence == pos + 1` means it holds data ready for the consumer.
struct QueueSlot<T> {
    sequence: AtomicUsize,
    value: std::cell::UnsafeCell<Option<T>>,
}

/// A bounded, lock-free multi-producer multi-consumer FIFO queue
/// (Dmitry Vyukov's bounded MPMC design).
///
/// The previous implementation CASed `tail`/`head` and then wrote/read
/// the slot non-atomically, so a consumer could claim an index whose
/// value had not been written yet, observe `None`, and advance `head`
/// past an element — losing it and desynchronizing the length counter.
/// Per-slot sequence numbers close that window: a slot is only touched
/// after its sequence proves the hand-off completed.
///
/// Capacity is rounded up to a power of two so logical indices map into
/// the ring with a mask instead of a modulo.
pub struct LockFreeQueue<T> {
    buffer: Vec<QueueSlot<T>>,
    capacity: usize,
    mask: usize,
    /// Next logical index to pop (monotonically increasing; masked into the ring).
    head: AtomicUsize,
    /// Next logical index to push (monotonically increasing; masked into the ring).
    tail: AtomicUsize,
}

// SAFETY: each slot is transferred between threads through the per-slot
// acquire/release sequence handshake, so sharing the queue only requires
// the element type itself to be Send.
unsafe impl<T: Send> Send for LockFreeQueue<T> {}
unsafe impl<T: Send> Sync for LockFreeQueue<T> {}

impl<T> LockFreeQueue<T> {
    /// Creates a queue holding at least `capacity` elements (rounded up
    /// to the next power of two).
    pub fn new(capacity: usize) -> Self {
        let capacity = capacity.next_power_of_two();
        let buffer: Vec<QueueSlot<T>> = (0..capacity)
            .map(|i| QueueSlot {
                // Slot i starts free for the producer that claims logical index i.
                sequence: AtomicUsize::new(i),
                value: std::cell::UnsafeCell::new(None),
            })
            .collect();
        Self {
            buffer,
            capacity,
            mask: capacity - 1,
            head: AtomicUsize::new(0),
            tail: AtomicUsize::new(0),
        }
    }

    /// Appends `value` at the tail; returns it back as `Err` when the
    /// queue is full.
    pub fn push(&self, value: T) -> Result<(), T> {
        let mut pos = self.tail.load(Ordering::Relaxed);
        loop {
            let slot = &self.buffer[pos & self.mask];
            let seq = slot.sequence.load(Ordering::Acquire);
            let diff = seq as isize - pos as isize;
            if diff == 0 {
                // Slot is free for index `pos`; try to claim it.
                match self.tail.compare_exchange_weak(
                    pos,
                    pos.wrapping_add(1),
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // SAFETY: the successful CAS makes this thread the
                        // unique owner of the slot until the sequence store
                        // below publishes it to consumers.
                        unsafe { *slot.value.get() = Some(value) };
                        // Publish: consumers wait for sequence == pos + 1.
                        slot.sequence.store(pos.wrapping_add(1), Ordering::Release);
                        return Ok(());
                    }
                    Err(current) => pos = current,
                }
            } else if diff < 0 {
                // Slot still holds the element from one lap behind: full.
                return Err(value);
            } else {
                // Another producer already claimed this index; retry.
                pos = self.tail.load(Ordering::Relaxed);
            }
            std::hint::spin_loop();
        }
    }

    /// Removes and returns the oldest element, or `None` when empty.
    pub fn pop(&self) -> Option<T> {
        let mut pos = self.head.load(Ordering::Relaxed);
        loop {
            let slot = &self.buffer[pos & self.mask];
            let seq = slot.sequence.load(Ordering::Acquire);
            let diff = seq as isize - pos.wrapping_add(1) as isize;
            if diff == 0 {
                // Slot holds data for index `pos`; try to claim it.
                match self.head.compare_exchange_weak(
                    pos,
                    pos.wrapping_add(1),
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // SAFETY: unique ownership as in `push`; the Acquire
                        // load of `sequence` above guarantees the producer's
                        // write to the slot is visible, so this is Some.
                        let value = unsafe { (*slot.value.get()).take() };
                        // Free the slot for the producer one lap ahead.
                        slot.sequence.store(
                            pos.wrapping_add(self.mask).wrapping_add(1),
                            Ordering::Release,
                        );
                        return value;
                    }
                    Err(current) => pos = current,
                }
            } else if diff < 0 {
                // No published element at `pos`: queue is empty.
                return None;
            } else {
                // Another consumer already claimed this index; retry.
                pos = self.head.load(Ordering::Relaxed);
            }
            std::hint::spin_loop();
        }
    }

    /// `true` when no elements are buffered. Under concurrency this is
    /// only a point-in-time snapshot.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of buffered elements (approximate under concurrency).
    pub fn len(&self) -> usize {
        let head = self.head.load(Ordering::Acquire);
        let tail = self.tail.load(Ordering::Acquire);
        tail.wrapping_sub(head)
    }

    /// The fixed (power-of-two) capacity chosen at construction.
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}
impl<T> Default for LockFreeQueue<T> {
fn default() -> Self {
Self::new(1024)
}
}
/// A hash map split across multiple independently locked `HashMap`s to
/// reduce lock contention: each key hashes to exactly one shard.
pub struct ShardedMap<K, V> {
    shards: Vec<RwLock<HashMap<K, V>>>,
    /// Always a power of two so a hash can be masked into a shard index.
    shard_count: usize,
}

impl<K: std::hash::Hash + Eq, V> ShardedMap<K, V> {
    /// Creates a map with the default shard count of 16.
    pub fn new() -> Self {
        Self::with_shards(16)
    }

    /// Creates a map with at least `shard_count` shards, rounded up to
    /// the next power of two.
    pub fn with_shards(shard_count: usize) -> Self {
        let shard_count = shard_count.next_power_of_two();
        Self {
            shards: (0..shard_count)
                .map(|_| RwLock::new(HashMap::new()))
                .collect(),
            shard_count,
        }
    }

    /// Computes the shard index for any hashable borrow of the key.
    ///
    /// Generalized over `Q` so the same logic serves `insert` (which
    /// hashes `K`) and the borrowed-key accessors (which hash `Q`);
    /// previously this was copy-pasted in `get`/`remove`/`contains_key`.
    /// `K: Borrow<Q>` requires both forms to hash identically, so a
    /// borrowed lookup always lands on the shard the owned key used.
    #[inline]
    fn shard_index<Q>(&self, key: &Q) -> usize
    where
        Q: std::hash::Hash + ?Sized,
    {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        key.hash(&mut hasher);
        (hasher.finish() as usize) & (self.shard_count - 1)
    }

    /// Inserts `key -> value`, returning the previous value if any.
    pub fn insert(&self, key: K, value: V) -> Option<V> {
        let idx = self.shard_index(&key);
        self.shards[idx].write().unwrap().insert(key, value)
    }

    /// Returns a clone of the value for `key`, if present.
    pub fn get<Q>(&self, key: &Q) -> Option<V>
    where
        K: std::borrow::Borrow<Q>,
        Q: std::hash::Hash + Eq + ?Sized,
        V: Clone,
    {
        let idx = self.shard_index(key);
        self.shards[idx].read().unwrap().get(key).cloned()
    }

    /// Removes `key`, returning its value if it was present.
    pub fn remove<Q>(&self, key: &Q) -> Option<V>
    where
        K: std::borrow::Borrow<Q>,
        Q: std::hash::Hash + Eq + ?Sized,
    {
        let idx = self.shard_index(key);
        self.shards[idx].write().unwrap().remove(key)
    }

    /// `true` if `key` is present.
    pub fn contains_key<Q>(&self, key: &Q) -> bool
    where
        K: std::borrow::Borrow<Q>,
        Q: std::hash::Hash + Eq + ?Sized,
    {
        let idx = self.shard_index(key);
        self.shards[idx].read().unwrap().contains_key(key)
    }

    /// Total entry count across all shards; under concurrent writers
    /// this is only a point-in-time approximation.
    pub fn len(&self) -> usize {
        self.shards.iter().map(|s| s.read().unwrap().len()).sum()
    }

    /// `true` when every shard is empty (short-circuits on the first
    /// non-empty shard).
    pub fn is_empty(&self) -> bool {
        self.shards.iter().all(|s| s.read().unwrap().is_empty())
    }

    /// Clears all shards, locking one at a time (not atomic across shards).
    pub fn clear(&self) {
        for shard in &self.shards {
            shard.write().unwrap().clear();
        }
    }
}
impl<K: std::hash::Hash + Eq, V> Default for ShardedMap<K, V> {
fn default() -> Self {
Self::new()
}
}
/// A standalone atomic reference counter using the same ordering scheme
/// `std::sync::Arc` documents: relaxed increments, release decrements,
/// and an acquire fence when the count reaches zero.
#[derive(Debug)]
pub struct AtomicRefCount {
    count: AtomicUsize,
}

impl AtomicRefCount {
    /// Creates a counter starting at `initial`.
    pub const fn new(initial: usize) -> Self {
        Self {
            count: AtomicUsize::new(initial),
        }
    }

    /// Bumps the count and returns the new value.
    ///
    /// Relaxed suffices here, mirroring `Arc::clone`: taking a new
    /// reference implies an existing one is already held.
    #[inline]
    pub fn increment(&self) -> usize {
        self.count.fetch_add(1, Ordering::Relaxed) + 1
    }

    /// Drops one reference and returns the new count.
    ///
    /// The release decrement plus the acquire fence on the final
    /// decrement ensure all writes made while a reference was held are
    /// visible to the thread that observes zero.
    #[inline]
    pub fn decrement(&self) -> usize {
        let previous = self.count.fetch_sub(1, Ordering::Release);
        if previous == 1 {
            std::sync::atomic::fence(Ordering::Acquire);
        }
        previous - 1
    }

    /// Current count — a racy snapshot under concurrent updates.
    #[inline]
    pub fn get(&self) -> usize {
        self.count.load(Ordering::Relaxed)
    }
}
impl Default for AtomicRefCount {
fn default() -> Self {
Self::new(0)
}
}
use std::ffi::c_int;
/// Plain-old-data snapshot of [`StoreStats`] with a C-compatible layout,
/// suitable for handing across the FFI boundary by value.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct FfiStoreStats {
    pub inserts: u64,
    pub gets: u64,
    pub removes: u64,
    pub hits: u64,
    pub misses: u64,
    /// hits / (hits + misses); 0.0 when no lookups were recorded.
    pub hit_rate: f64,
}
impl From<&StoreStats> for FfiStoreStats {
    /// Takes a point-in-time snapshot of the live counters.
    ///
    /// Each counter is loaded independently with relaxed ordering, so
    /// under concurrent updates the snapshot is not guaranteed to be
    /// mutually consistent (e.g. `gets` may briefly differ from
    /// `hits + misses`).
    fn from(stats: &StoreStats) -> Self {
        let hits = stats.hits.load(Ordering::Relaxed);
        let misses = stats.misses.load(Ordering::Relaxed);
        Self {
            inserts: stats.inserts.load(Ordering::Relaxed),
            gets: stats.gets.load(Ordering::Relaxed),
            removes: stats.removes.load(Ordering::Relaxed),
            hits,
            misses,
            hit_rate: stats.hit_rate(),
        }
    }
}
/// Process-wide task queue backing the `fz_lockfree_queue_*` FFI entry
/// points; lazily initialized on first use with a fixed capacity of 4096.
static TASK_QUEUE: std::sync::LazyLock<LockFreeQueue<u64>> =
    std::sync::LazyLock::new(|| LockFreeQueue::new(4096));
/// C entry point: enqueue `task_id` on the global task queue.
///
/// Returns 0 on success and -1 when the queue is full.
#[unsafe(no_mangle)]
pub extern "C" fn fz_lockfree_queue_push(task_id: u64) -> c_int {
    if TASK_QUEUE.push(task_id).is_ok() { 0 } else { -1 }
}
/// C entry point: dequeue the oldest task id.
///
/// Returns 0 both when the queue is empty and when the stored id is
/// itself 0 — callers cannot distinguish the two (pre-existing API
/// limitation of the sentinel return).
#[unsafe(no_mangle)]
pub extern "C" fn fz_lockfree_queue_pop() -> u64 {
    match TASK_QUEUE.pop() {
        Some(id) => id,
        None => 0,
    }
}
/// C entry point: returns 1 when the global task queue is empty, else 0.
#[unsafe(no_mangle)]
pub extern "C" fn fz_lockfree_queue_is_empty() -> c_int {
    c_int::from(TASK_QUEUE.is_empty())
}
/// C entry point: number of queued task ids (a racy snapshot under
/// concurrent producers/consumers).
#[unsafe(no_mangle)]
pub extern "C" fn fz_lockfree_queue_len() -> usize {
    TASK_QUEUE.len()
}
/// C entry point: fixed capacity of the global task queue (4096).
#[unsafe(no_mangle)]
pub extern "C" fn fz_lockfree_queue_capacity() -> usize {
    TASK_QUEUE.capacity()
}
#[cfg(test)]
mod tests {
    //! Unit tests for the handle store, MPMC queue, sharded map, atomic
    //! refcount, and the C FFI wrappers. Tests touching the process-global
    //! `TASK_QUEUE` are serialized via `serial_test` so they do not
    //! interleave on the shared static.
    use super::*;
    use serial_test::serial;
    use std::thread;

    /// Insert/get/remove round-trip through the store.
    #[test]
    fn test_lockfree_store_basic() {
        let store: LockFreeHandleStore<i32> = LockFreeHandleStore::new();
        let h1 = store.insert(42);
        let h2 = store.insert(100);
        assert!(store.get(h1).is_some());
        assert_eq!(*store.get(h1).unwrap(), 42);
        assert_eq!(*store.get(h2).unwrap(), 100);
        store.remove(h1);
        assert!(store.get(h1).is_none());
    }

    /// Eight reader threads hammer the same 100 live handles; every
    /// lookup should hit, so the recorded hit rate must stay near 1.0.
    #[test]
    fn test_lockfree_store_concurrent_reads() {
        let store = Arc::new(LockFreeHandleStore::new());
        let handles: Vec<_> = (0..100).map(|i| store.insert(i)).collect();
        let threads: Vec<_> = (0..8)
            .map(|_| {
                let store = store.clone();
                let handles = handles.clone();
                thread::spawn(move || {
                    for _ in 0..1000 {
                        for &h in &handles {
                            let _ = store.get(h);
                        }
                    }
                })
            })
            .collect();
        for t in threads {
            t.join().unwrap();
        }
        let stats = store.stats();
        assert!(stats.gets.load(Ordering::Relaxed) > 0);
        assert!(stats.hit_rate() > 0.99);
    }

    /// Four writer threads insert disjoint values; afterwards every
    /// returned handle must resolve and the totals must add up.
    #[test]
    fn test_lockfree_store_concurrent_writes() {
        let store = Arc::new(LockFreeHandleStore::new());
        let threads: Vec<_> = (0..4)
            .map(|t| {
                let store = store.clone();
                thread::spawn(move || {
                    let handles: Vec<_> = (0..100).map(|i| store.insert(t * 100 + i)).collect();
                    handles
                })
            })
            .collect();
        let all_handles: Vec<Vec<_>> = threads.into_iter().map(|t| t.join().unwrap()).collect();
        for handles in &all_handles {
            for &h in handles {
                assert!(store.get(h).is_some());
            }
        }
        assert_eq!(store.len(), 400);
    }

    /// Single-threaded FIFO ordering.
    #[test]
    fn test_lockfree_queue_basic() {
        let queue: LockFreeQueue<i32> = LockFreeQueue::new(16);
        queue.push(1).unwrap();
        queue.push(2).unwrap();
        queue.push(3).unwrap();
        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
    }

    /// Pushing past capacity must fail (and hand the value back via Err).
    #[test]
    fn test_lockfree_queue_full() {
        let queue: LockFreeQueue<i32> = LockFreeQueue::new(4);
        assert!(queue.push(1).is_ok());
        assert!(queue.push(2).is_ok());
        assert!(queue.push(3).is_ok());
        assert!(queue.push(4).is_ok());
        assert!(queue.push(5).is_err());
    }

    /// Producers are joined before consumers start, so the consumers
    /// must fully drain the queue.
    #[test]
    fn test_lockfree_queue_concurrent() {
        let queue = Arc::new(LockFreeQueue::new(1024));
        let sum = Arc::new(AtomicU64::new(0));
        let producers: Vec<_> = (0..4)
            .map(|t| {
                let queue = queue.clone();
                thread::spawn(move || {
                    for i in 0..100 {
                        let _ = queue.push(t * 100 + i);
                    }
                })
            })
            .collect();
        for p in producers {
            p.join().unwrap();
        }
        let consumers: Vec<_> = (0..4)
            .map(|_| {
                let queue = queue.clone();
                let sum = sum.clone();
                thread::spawn(move || {
                    let mut local_sum = 0u64;
                    while let Some(v) = queue.pop() {
                        local_sum += v as u64;
                    }
                    sum.fetch_add(local_sum, Ordering::Relaxed);
                })
            })
            .collect();
        for c in consumers {
            c.join().unwrap();
        }
        assert!(queue.is_empty());
    }

    /// Basic insert/get/len behavior of the sharded map.
    #[test]
    fn test_sharded_map_basic() {
        let map: ShardedMap<String, i32> = ShardedMap::new();
        map.insert("a".to_string(), 1);
        map.insert("b".to_string(), 2);
        map.insert("c".to_string(), 3);
        assert_eq!(map.get("a"), Some(1));
        assert_eq!(map.get("b"), Some(2));
        assert_eq!(map.get("c"), Some(3));
        assert_eq!(map.get("d"), None);
        assert_eq!(map.len(), 3);
    }

    /// Concurrent writers with disjoint key spaces, then concurrent
    /// readers verifying every key landed.
    #[test]
    fn test_sharded_map_concurrent() {
        let map = Arc::new(ShardedMap::new());
        let writers: Vec<_> = (0..4)
            .map(|t| {
                let map = map.clone();
                thread::spawn(move || {
                    for i in 0..100 {
                        let key = format!("{}_{}", t, i);
                        map.insert(key, t * 100 + i);
                    }
                })
            })
            .collect();
        for w in writers {
            w.join().unwrap();
        }
        assert_eq!(map.len(), 400);
        let readers: Vec<_> = (0..4)
            .map(|t| {
                let map = map.clone();
                thread::spawn(move || {
                    for i in 0..100 {
                        let key = format!("{}_{}", t, i);
                        assert!(map.contains_key(&key));
                    }
                })
            })
            .collect();
        for r in readers {
            r.join().unwrap();
        }
    }

    /// increment/decrement return the post-operation count.
    #[test]
    fn test_atomic_ref_count() {
        let rc = AtomicRefCount::new(1);
        assert_eq!(rc.get(), 1);
        assert_eq!(rc.increment(), 2);
        assert_eq!(rc.increment(), 3);
        assert_eq!(rc.get(), 3);
        assert_eq!(rc.decrement(), 2);
        assert_eq!(rc.decrement(), 1);
        assert_eq!(rc.decrement(), 0);
    }

    /// FFI round-trip on the shared global queue.
    #[test]
    #[serial(lockfree_queue)]
    fn test_ffi_queue() {
        // Drain anything left behind by other serial tests; relies on
        // 0 being the "empty" sentinel (test ids are non-zero).
        while fz_lockfree_queue_pop() != 0 {}
        assert_eq!(fz_lockfree_queue_is_empty(), 1);
        assert_eq!(fz_lockfree_queue_push(100), 0);
        assert_eq!(fz_lockfree_queue_push(200), 0);
        assert_eq!(fz_lockfree_queue_len(), 2);
        assert_eq!(fz_lockfree_queue_is_empty(), 0);
        assert_eq!(fz_lockfree_queue_pop(), 100);
        assert_eq!(fz_lockfree_queue_pop(), 200);
        assert_eq!(fz_lockfree_queue_pop(), 0);
    }

    /// Counter bookkeeping: 1 insert, 3 gets = 2 hits + 1 miss.
    #[test]
    fn test_store_stats() {
        let store: LockFreeHandleStore<i32> = LockFreeHandleStore::new();
        let h = store.insert(42);
        store.get(h);
        store.get(h);
        store.get(999);
        let stats = store.stats();
        assert_eq!(stats.inserts.load(Ordering::Relaxed), 1);
        assert_eq!(stats.gets.load(Ordering::Relaxed), 3);
        assert_eq!(stats.hits.load(Ordering::Relaxed), 2);
        assert_eq!(stats.misses.load(Ordering::Relaxed), 1);
        assert!((stats.hit_rate() - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_store_with_shards() {
        let store: LockFreeHandleStore<i32> = LockFreeHandleStore::with_shards(8);
        let h = store.insert(1);
        assert_eq!(*store.get(h).unwrap(), 1);
    }

    /// `keep` is a pass-through and must return the same handle.
    #[test]
    fn test_store_keep() {
        let store: LockFreeHandleStore<i32> = LockFreeHandleStore::new();
        let h = store.insert(99);
        assert_eq!(store.keep(h), h);
    }

    #[test]
    fn test_store_len_is_empty_clear() {
        let store: LockFreeHandleStore<i32> = LockFreeHandleStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        store.insert(1);
        store.insert(2);
        assert_eq!(store.len(), 2);
        assert!(!store.is_empty());
        store.clear();
        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
    }

    #[test]
    fn test_store_default() {
        let store: LockFreeHandleStore<i32> = LockFreeHandleStore::default();
        assert!(store.is_empty());
    }

    /// hit_rate must not divide by zero when no lookups were recorded.
    #[test]
    fn test_store_stats_hit_rate_zero() {
        let stats = StoreStats::default();
        assert_eq!(stats.hit_rate(), 0.0);
    }

    #[test]
    fn test_ffi_store_stats() {
        let store: LockFreeHandleStore<i32> = LockFreeHandleStore::new();
        let h = store.insert(1);
        store.get(h);
        let ffi_stats = FfiStoreStats::from(store.stats());
        assert_eq!(ffi_stats.inserts, 1);
        assert_eq!(ffi_stats.gets, 1);
        assert!(ffi_stats.hit_rate > 0.0);
    }

    #[test]
    fn test_queue_default() {
        let queue: LockFreeQueue<i32> = LockFreeQueue::default();
        assert_eq!(queue.capacity(), 1024);
        assert!(queue.is_empty());
    }

    #[test]
    fn test_queue_len() {
        let queue: LockFreeQueue<i32> = LockFreeQueue::new(8);
        queue.push(1).unwrap();
        queue.push(2).unwrap();
        assert_eq!(queue.len(), 2);
        queue.pop();
        assert_eq!(queue.len(), 1);
    }

    /// Fill the shared FFI queue until it reports full; also checks the
    /// advertised capacity. Runs serially with the other FFI test.
    #[test]
    #[serial(lockfree_queue)]
    fn test_ffi_queue_full() {
        // Drain leftovers from other serial tests first.
        while fz_lockfree_queue_pop() != 0 {}
        for _ in 0..4096 {
            let r = fz_lockfree_queue_push(1);
            if r == -1 {
                break;
            }
        }
        assert_eq!(fz_lockfree_queue_capacity(), 4096);
    }

    #[test]
    fn test_sharded_map_remove() {
        let map: ShardedMap<String, i32> = ShardedMap::new();
        map.insert("a".to_string(), 1);
        assert_eq!(map.remove("a"), Some(1));
        assert_eq!(map.get("a"), None);
    }

    #[test]
    fn test_sharded_map_clear() {
        let map: ShardedMap<String, i32> = ShardedMap::new();
        map.insert("a".to_string(), 1);
        map.clear();
        assert_eq!(map.len(), 0);
        assert!(map.is_empty());
    }

    #[test]
    fn test_sharded_map_contains_key() {
        let map: ShardedMap<String, i32> = ShardedMap::new();
        map.insert("x".to_string(), 42);
        assert!(map.contains_key("x"));
        assert!(!map.contains_key("y"));
    }

    #[test]
    fn test_sharded_map_default() {
        let map: ShardedMap<String, i32> = ShardedMap::default();
        assert!(map.is_empty());
    }

    #[test]
    fn test_atomic_ref_count_default() {
        let rc = AtomicRefCount::default();
        assert_eq!(rc.get(), 0);
    }

    #[test]
    fn test_atomic_ref_count_decrement_to_zero() {
        let rc = AtomicRefCount::new(1);
        assert_eq!(rc.decrement(), 0);
    }
}