use std::cell::UnsafeCell;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::sync::Mutex;
use std::thread::{self, ThreadId};
use crossbeam_utils::CachePadded;
use num_cpus;
use parking_lot;
/// A reader-writer lock whose read side is split across several shards.
///
/// A reader locks only the shard selected by its thread index, so readers on
/// different shards touch different (cache-padded) locks. A writer has to
/// acquire every shard exclusively, which makes writes more expensive than
/// with a single lock.
pub struct RwLock<T> {
    /// One `()`-protecting lock per shard, each wrapped in `CachePadded`.
    /// The length is always a power of two (see `new`), so a shard can be
    /// picked with a bit mask.
    shards: Vec<CachePadded<parking_lot::RwLock<()>>>,
    /// The protected value; all access goes through the guards below.
    value: UnsafeCell<T>,
}

// SAFETY: sending the lock to another thread moves the owned `T` with it, so
// `T: Send` is what is required.
unsafe impl<T: Send> Send for RwLock<T> {}
// SAFETY: through `&RwLock<T>` one thread can obtain `&mut T` (write guard) and
// several threads can obtain `&T` concurrently (read guards), so both
// `T: Send` and `T: Sync` are required.
unsafe impl<T: Send + Sync> Sync for RwLock<T> {}

impl<T> RwLock<T> {
    /// Creates a new lock wrapping `value`.
    ///
    /// The shard count is `num_cpus::get()` rounded up to the next power of
    /// two.
    pub fn new(value: T) -> RwLock<T> {
        let shard_count = num_cpus::get().next_power_of_two();
        let mut shards = Vec::with_capacity(shard_count);
        for _ in 0..shard_count {
            shards.push(CachePadded::new(parking_lot::RwLock::new(())));
        }
        RwLock {
            shards,
            value: UnsafeCell::new(value),
        }
    }

    /// Acquires shared read access to the value.
    ///
    /// Only the shard selected by the current thread's index is read-locked.
    /// The returned guard keeps that shard's read lock until it is dropped.
    pub fn read(&self) -> RwLockReadGuard<T> {
        // `shards.len()` is a power of two, so masking is a cheap modulo.
        let mask = self.shards.len() - 1;
        let shard = &self.shards[thread_index() & mask];
        RwLockReadGuard {
            parent: self,
            _guard: shard.read(),
            _marker: PhantomData,
        }
    }

    /// Acquires exclusive write access to the value.
    ///
    /// Every shard is write-locked in ascending index order. Because every
    /// writer uses the same order, two writers cannot deadlock each other.
    /// The per-shard guards are forgotten here, and the locks are released
    /// later by `RwLockWriteGuard`'s `Drop` impl.
    pub fn write(&self) -> RwLockWriteGuard<T> {
        self.shards
            .iter()
            .map(|shard| shard.write())
            .for_each(mem::forget);
        RwLockWriteGuard {
            parent: self,
            _marker: PhantomData,
        }
    }
}
/// Guard for shared access, obtained from `RwLock::read`.
///
/// It holds the read lock of a single shard. That lock is released when the
/// inner `parking_lot` guard is dropped along with this guard.
pub struct RwLockReadGuard<'a, T: 'a> {
    parent: &'a RwLock<T>,
    _guard: parking_lot::RwLockReadGuard<'a, ()>,
    /// Presumably present so that the guard's auto traits and variance follow
    /// those of a `parking_lot` read guard over `T` — confirm intent.
    _marker: PhantomData<parking_lot::RwLockReadGuard<'a, T>>,
}

// SAFETY: a shared reference to this guard only exposes `&T`, which is sound to
// share across threads when `T: Sync`.
unsafe impl<'a, T: Sync> Sync for RwLockReadGuard<'a, T> {}

impl<'a, T> Deref for RwLockReadGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let cell = &self.parent.value;
        // SAFETY: this guard holds a shard's read lock, and `RwLock::write`
        // must write-lock every shard, so no `&mut T` exists while we live.
        unsafe { &*cell.get() }
    }
}
/// Guard for exclusive access, obtained from `RwLock::write`.
///
/// It owns no lock guards itself. `RwLock::write` forgets the per-shard
/// guards, and this type's `Drop` impl force-unlocks every shard.
pub struct RwLockWriteGuard<'a, T: 'a> {
    parent: &'a RwLock<T>,
    _marker: PhantomData<parking_lot::RwLockWriteGuard<'a, T>>,
}

// SAFETY: through `&RwLockWriteGuard` only `Deref` is reachable (`DerefMut`
// needs `&mut self`), which exposes `&T`; that is sound when `T: Sync`.
unsafe impl<'a, T: Sync> Sync for RwLockWriteGuard<'a, T> {}

impl<'a, T> Drop for RwLockWriteGuard<'a, T> {
    /// Releases every shard's write lock, starting from the last shard.
    fn drop(&mut self) {
        self.parent.shards.iter().rev().for_each(|shard| {
            // SAFETY: `RwLock::write` write-locked every shard and forgot the
            // guards, so this guard is the owner of each shard's write lock.
            unsafe { shard.force_unlock_write() }
        });
    }
}

impl<'a, T> Deref for RwLockWriteGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let ptr = self.parent.value.get();
        // SAFETY: every shard is write-locked for this guard's lifetime, so no
        // read guard or other write guard can reach the value.
        unsafe { &*ptr }
    }
}

impl<'a, T> DerefMut for RwLockWriteGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut T {
        let ptr = self.parent.value.get();
        // SAFETY: same exclusivity as `deref`; `&mut self` additionally ensures
        // this is the only live borrow derived from this guard.
        unsafe { &mut *ptr }
    }
}
/// Returns the integer index assigned to the calling thread.
///
/// Indices come from the `REGISTRATION` thread-local. If that thread-local
/// can no longer be accessed (`LocalKey::try_with` reports an error, e.g.
/// while the thread's TLS is being destroyed), `0` is returned instead.
#[inline]
pub fn thread_index() -> usize {
    match REGISTRATION.try_with(|registration| registration.index) {
        Ok(index) => index,
        Err(_) => 0,
    }
}
/// Process-wide bookkeeping for per-thread indices.
struct ThreadIndices {
    /// Index currently assigned to each registered thread.
    // NOTE(review): only inserted into / removed from in this chunk; never read here.
    mapping: HashMap<ThreadId, usize>,
    /// Indices returned by dropped `Registration`s, reused before new ones.
    free_list: Vec<usize>,
    /// Smallest index that has never been handed out.
    next_index: usize,
}

lazy_static! {
    /// The single registry shared by all threads, guarded by a mutex.
    static ref THREAD_INDICES: Mutex<ThreadIndices> = Mutex::new(ThreadIndices {
        next_index: 0,
        free_list: Vec::default(),
        mapping: HashMap::default(),
    });
}
/// A thread's claim on an index in `THREAD_INDICES`. Dropping it gives the
/// index back to the free list.
struct Registration {
    /// The index assigned to this thread.
    index: usize,
    /// The owning thread's id, used as the key in `ThreadIndices::mapping`.
    thread_id: ThreadId,
}

impl Drop for Registration {
    /// Removes this thread from the registry and recycles its index.
    ///
    /// This runs as a thread-local destructor, where a panic would typically
    /// abort the process. A poisoned mutex is therefore recovered rather than
    /// unwrapped. The bookkeeping consists of independent collection updates,
    /// so the data stays usable even if a previous holder panicked (at worst
    /// an index is never recycled).
    fn drop(&mut self) {
        let mut indices = THREAD_INDICES
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        indices.mapping.remove(&self.thread_id);
        indices.free_list.push(self.index);
    }
}
thread_local! {
    /// The calling thread's registration, created lazily on first access.
    ///
    /// A previously freed index is reused when available; otherwise a fresh
    /// one is minted from `next_index`. A poisoned `THREAD_INDICES` mutex is
    /// recovered rather than unwrapped. The bookkeeping stays usable, and a
    /// panic here would propagate out of `thread_index` (and so out of
    /// `RwLock::read`).
    static REGISTRATION: Registration = {
        let thread_id = thread::current().id();
        let mut indices = THREAD_INDICES
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let index = match indices.free_list.pop() {
            Some(recycled) => recycled,
            None => {
                let fresh = indices.next_index;
                indices.next_index += 1;
                fresh
            }
        };
        indices.mapping.insert(thread_id, index);
        Registration { index, thread_id }
    };
}