use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use parking_lot_core::SpinWait;
use std::{
sync::atomic::{AtomicUsize, Ordering}
};
use std::cell::UnsafeCell;
use crate::error;
/// A hybrid latch combining a pessimistic `RwLock` with a seqlock-style
/// version counter, allowing lock-free optimistic reads that are validated
/// after the fact.
///
/// Protocol (as implemented below): the version is even when no writer is
/// active and odd while an exclusive guard holds the write lock.
pub struct HybridLatch<T: ?Sized> {
    // Version counter; odd = write in progress, even = quiescent.
    version: AtomicUsize,
    // Pessimistic lock used by shared/exclusive guards; the protected data
    // itself lives in `data`, so the lock guards the unit type only.
    lock: RwLock<()>,
    // The protected value; accessed through raw pointers by the guards.
    data: UnsafeCell<T>
}
// SAFETY: the latch owns its data, so it can move between threads whenever
// the data itself can be sent.
unsafe impl<T: ?Sized + Send> Send for HybridLatch<T> {}
// SAFETY: concurrent access to `data` is mediated by `lock` and the version
// protocol; sharing requires `T: Send + Sync` because guards hand out both
// shared and (for exclusive guards) mutable references across threads.
unsafe impl<T: ?Sized + Send + Sync> Sync for HybridLatch<T> {}
impl<T> HybridLatch<T> {
    /// Creates a latch around `data` with version 0 (even, i.e. unlocked).
    #[inline]
    pub fn new(data: T) -> HybridLatch<T> {
        HybridLatch {
            version: AtomicUsize::new(0),
            data: UnsafeCell::new(data),
            lock: RwLock::new(()),
        }
    }
    /// Acquires the latch exclusively, blocking until the write lock is held.
    ///
    /// Bumps the version to an odd value so optimistic readers see a write
    /// in progress; the matching even bump happens in `ExclusiveGuard::drop`.
    #[inline]
    pub fn exclusive(&self) -> ExclusiveGuard<'_, T> {
        let guard = self.lock.write();
        // Relaxed load is sufficient here: holding the write lock excludes
        // all other writers, so no one else can be mutating `version`.
        let version = self.version.load(Ordering::Relaxed) + 1;
        // Release store publishes the odd version to optimistic readers.
        self.version.store(version, Ordering::Release);
        ExclusiveGuard {
            latch: self,
            guard,
            data: self.data.get(),
            version
        }
    }
    /// Acquires the latch in shared (read-locked) mode, blocking if needed.
    #[inline]
    pub fn shared(&self) -> SharedGuard<'_, T> {
        let guard = self.lock.read();
        // Under the read lock no writer can be active, so the version is
        // stable (and even); Relaxed is sufficient.
        let version = self.version.load(Ordering::Relaxed);
        SharedGuard {
            latch: self,
            guard,
            data: self.data.get(),
            version
        }
    }
    /// Returns an optimistic guard, spinning (without ever taking the lock)
    /// while a writer is active.
    ///
    /// NOTE(review): when `SpinWait` is exhausted it is reset and spinning
    /// restarts, so this can busy-wait indefinitely against a slow writer —
    /// presumably intentional for latency; confirm against usage.
    #[inline(never)]
    pub fn optimistic_or_spin(&self) -> OptimisticGuard<'_, T> {
        let mut version = self.version.load(Ordering::Acquire);
        // Odd version means a writer holds the latch; wait for it to finish.
        if (version & 1) == 1 {
            let mut spinwait = SpinWait::new();
            loop {
                version = self.version.load(Ordering::Acquire);
                if (version & 1) == 1 {
                    let result = spinwait.spin();
                    if !result {
                        // Spin budget exhausted: start a fresh spin cycle
                        // rather than parking the thread.
                        spinwait.reset();
                    }
                    continue
                } else {
                    break
                }
            }
        }
        OptimisticGuard {
            latch: self,
            data: self.data.get(),
            version
        }
    }
    /// Returns an optimistic guard, or `Err(Unwind)` immediately if a writer
    /// is currently active (caller is expected to retry from a safe point).
    #[inline]
    pub fn optimistic_or_unwind(&self) -> error::Result<OptimisticGuard<'_, T>> {
        let version = self.version.load(Ordering::Acquire);
        if (version & 1) == 1 {
            return Err(error::Error::Unwind)
        }
        Ok(OptimisticGuard {
            latch: self,
            data: self.data.get(),
            version
        })
    }
    /// Returns an optimistic guard if no writer is active, otherwise falls
    /// back to blocking on the shared (read) lock.
    #[inline]
    pub fn optimistic_or_shared(&self) -> OptimisticOrShared<'_, T> {
        let version = self.version.load(Ordering::Acquire);
        if (version & 1) == 1 {
            let guard = self.lock.read();
            // Re-read under the read lock: the writer observed above has
            // finished, and no new one can start while we hold the lock.
            let version = self.version.load(Ordering::Relaxed);
            OptimisticOrShared::Shared(SharedGuard {
                latch: self,
                guard,
                data: self.data.get(),
                version
            })
        } else {
            OptimisticOrShared::Optimistic(OptimisticGuard {
                latch: self,
                data: self.data.get(),
                version
            })
        }
    }
    /// Returns an optimistic guard if no writer is active, otherwise falls
    /// back to blocking on the exclusive (write) lock.
    #[inline]
    pub fn optimistic_or_exclusive(&self) -> OptimisticOrExclusive<'_, T> {
        let version = self.version.load(Ordering::Acquire);
        if (version & 1) == 1 {
            let guard = self.lock.write();
            // Same sequence as `exclusive()`: bump to odd under the lock.
            let version = self.version.load(Ordering::Relaxed) + 1;
            self.version.store(version, Ordering::Release);
            OptimisticOrExclusive::Exclusive(ExclusiveGuard {
                latch: self,
                guard,
                data: self.data.get(),
                version
            })
        } else {
            OptimisticOrExclusive::Optimistic(OptimisticGuard {
                latch: self,
                data: self.data.get(),
                version
            })
        }
    }
}
impl<T> std::convert::AsMut<T> for HybridLatch<T> {
    /// Direct mutable access to the protected data without taking the latch.
    ///
    /// `&mut self` statically guarantees no guard (and no other reference to
    /// this latch) exists, so no synchronization is needed.
    #[inline]
    fn as_mut(&mut self) -> &mut T {
        // `UnsafeCell::get_mut` is the safe equivalent of the former
        // `unsafe { &mut *self.data.get() }` — exclusive access is proven
        // by the `&mut self` borrow, so no `unsafe` block is required.
        self.data.get_mut()
    }
}
/// Common interface over the three guard flavors (optimistic, shared,
/// exclusive), letting callers validate and introspect a guard generically.
pub trait HybridGuard<T: ?Sized> {
    /// Borrows the protected value. For optimistic guards the read is
    /// speculative and must be confirmed with `recheck`.
    fn inner(&self) -> &T;
    /// Validates the guard; `Err(Unwind)` means the read must be retried.
    fn recheck(&self) -> error::Result<()>;
    /// The latch this guard was created from.
    fn latch(&self) -> &HybridLatch<T>;
}
/// A lock-free guard: holds no lock at all, only a snapshot of the latch
/// version. Reads through it are speculative and only valid if a subsequent
/// `recheck` confirms the version has not changed.
pub struct OptimisticGuard<'a, T: ?Sized> {
    latch: &'a HybridLatch<T>,
    // Raw pointer into the latch's `UnsafeCell`; dereferenced without any
    // lock held, hence the mandatory recheck protocol.
    data: *const T,
    // Version observed when the guard was created (always even).
    version: usize
}
unsafe impl<'a, T: ?Sized + Sync> Sync for OptimisticGuard<'a, T> {}
impl<'a, T> OptimisticGuard<'a, T> {
    /// Validates all reads made through this guard so far.
    ///
    /// Fails with `Unwind` if any writer acquired the latch since the guard
    /// was created (the version moved), in which case data read through the
    /// guard must be discarded and the operation retried.
    #[inline]
    pub fn recheck(&self) -> error::Result<()> {
        if self.version != self.latch.version.load(Ordering::Acquire) {
            return Err(error::Error::Unwind)
        }
        Ok(())
    }
    /// Attempts to upgrade to an exclusive guard without invalidating the
    /// reads already made: takes the write lock, then atomically bumps the
    /// version from the snapshot to odd. If the version moved in the
    /// meantime, the lock is released and `Unwind` is returned.
    #[inline]
    pub fn to_exclusive(self) -> error::Result<ExclusiveGuard<'a, T>> {
        let new_version = self.version + 1;
        let expected = self.version;
        let locked = self.latch.lock.write();
        // NOTE(review): the success ordering here is `Acquire`, while
        // `exclusive()` publishes its odd version with a `Release` store —
        // confirm whether `AcqRel` was intended for consistency.
        if self.latch.version
            .compare_exchange(
                expected,
                new_version,
                Ordering::Acquire,
                Ordering::Acquire).is_err()
        {
            drop(locked);
            return Err(error::Error::Unwind)
        }
        Ok(ExclusiveGuard {
            latch: self.latch,
            guard: locked,
            data: self.data as *mut _,
            version: new_version
        })
    }
    /// Attempts to downgrade-in-place to a shared guard: only succeeds if
    /// the read lock is immediately available and the version is unchanged
    /// under it; otherwise returns `Unwind` without blocking.
    #[inline]
    pub fn to_shared(self) -> error::Result<SharedGuard<'a, T>> {
        if let Some(guard) = self.latch.lock.try_read() {
            // Under the read lock the version is stable, so Relaxed suffices.
            if self.version != self.latch.version.load(Ordering::Relaxed) {
                return Err(error::Error::Unwind)
            }
            Ok(SharedGuard {
                latch: self.latch,
                guard,
                data: self.data,
                version: self.version
            })
        } else {
            return Err(error::Error::Unwind)
        }
    }
    /// The latch this guard was created from (with the guard's lifetime).
    pub fn latch(&self) -> &'a HybridLatch<T> {
        self.latch
    }
}
impl<'a, T> std::ops::Deref for OptimisticGuard<'a, T> {
    type Target = T;
    fn deref(&self) -> &T {
        // SAFETY(review): this dereference is NOT synchronized with writers;
        // soundness relies on callers validating every read with `recheck`
        // before acting on it, per the optimistic-latch protocol.
        unsafe { &*self.data }
    }
}
impl<'a, T> HybridGuard<T> for OptimisticGuard<'a, T> {
fn inner(&self) -> &T {
self
}
fn recheck(&self) -> error::Result<()> {
self.recheck()
}
fn latch(&self) -> &HybridLatch<T> {
self.latch()
}
}
/// A write guard: holds the latch's write lock for its whole lifetime.
/// The latch version is odd while this guard exists; `Drop` bumps it back
/// to even, invalidating every optimistic guard taken before the write.
pub struct ExclusiveGuard<'a, T: ?Sized> {
    latch: &'a HybridLatch<T>,
    // Held only for its RAII effect (keeps the write lock locked).
    #[allow(dead_code)]
    guard: RwLockWriteGuard<'a, ()>,
    // Mutable pointer into the latch's `UnsafeCell`.
    data: *mut T,
    // The (odd) version stored when this guard was created.
    version: usize
}
unsafe impl<'a, T: ?Sized + Sync> Sync for ExclusiveGuard<'a, T> {}
impl<'a, T> ExclusiveGuard<'a, T> {
    /// Sanity check: while the write lock is held no one else can change the
    /// version, so this assertion should never fire.
    #[inline]
    pub fn recheck(&self) {
        assert!(self.version == self.latch.version.load(Ordering::Relaxed));
    }
    /// Downgrades to an optimistic guard.
    ///
    /// `drop(self)` runs this type's `Drop` impl, which stores
    /// `self.version + 1` (even) and releases the write lock; the returned
    /// guard snapshots that same post-drop version, so its `recheck`
    /// succeeds until the next writer intervenes.
    #[inline]
    pub fn unlock(self) -> OptimisticGuard<'a, T> {
        let new_version = self.version + 1;
        // Copies (reference and raw pointer are `Copy`), taken before `self`
        // is dropped below.
        let latch = self.latch;
        let data = self.data;
        drop(self);
        OptimisticGuard {
            latch,
            data,
            version: new_version
        }
    }
    /// The latch this guard was created from (with the guard's lifetime).
    pub fn latch(&self) -> &'a HybridLatch<T> {
        self.latch
    }
}
impl<'a, T: ?Sized> Drop for ExclusiveGuard<'a, T> {
    /// Ends the write: publishes the even `version + 1` with `Release`
    /// (making the writes visible to optimistic readers' `Acquire` loads)
    /// and then releases the write lock when `guard` is dropped.
    #[inline]
    fn drop(&mut self) {
        let new_version = self.version + 1;
        self.latch.version.store(new_version, Ordering::Release);
    }
}
impl<'a, T> std::ops::Deref for ExclusiveGuard<'a, T> {
    type Target = T;
    #[inline]
    fn deref(&self) -> &T {
        // SAFETY: the write lock held by this guard excludes all other
        // lock-holding accessors for the guard's lifetime.
        unsafe { &*self.data }
    }
}
impl<'a, T> std::ops::DerefMut for ExclusiveGuard<'a, T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: exclusive access is guaranteed by the held write lock and
        // the `&mut self` borrow of the guard itself.
        unsafe { &mut *self.data }
    }
}
impl<'a, T> std::convert::AsMut<T> for ExclusiveGuard<'a, T> {
    /// Mutable access to the protected value; identical to `DerefMut`.
    #[inline]
    fn as_mut(&mut self) -> &mut T {
        // Route through `DerefMut` instead of repeating the raw-pointer
        // dereference, keeping the unsafe code in one place.
        &mut **self
    }
}
impl<'a, T> HybridGuard<T> for ExclusiveGuard<'a, T> {
fn inner(&self) -> &T {
self
}
fn recheck(&self) -> error::Result<()> {
self.recheck();
Ok(())
}
fn latch(&self) -> &HybridLatch<T> {
self.latch()
}
}
/// A read guard: holds the latch's read lock, excluding writers for its
/// whole lifetime, so reads need no validation.
pub struct SharedGuard<'a, T: ?Sized> {
    latch: &'a HybridLatch<T>,
    // Held only for its RAII effect (keeps the read lock locked).
    #[allow(dead_code)]
    guard: RwLockReadGuard<'a, ()>,
    // Pointer into the latch's `UnsafeCell`.
    data: *const T,
    // The (even) version observed under the read lock.
    version: usize
}
unsafe impl<'a, T: ?Sized + Sync> Sync for SharedGuard<'a, T> {}
impl<'a, T> SharedGuard<'a, T> {
    /// Sanity check: while the read lock is held writers are excluded, so
    /// the version cannot move and this assertion should never fire.
    #[inline]
    pub fn recheck(&self) {
        assert!(self.version == self.latch.version.load(Ordering::Relaxed));
    }
    /// Downgrades to an optimistic guard, releasing the read lock (dropped
    /// with `self`) while keeping the observed version for later rechecks.
    #[inline]
    pub fn unlock(self) -> OptimisticGuard<'a, T> {
        OptimisticGuard {
            latch: self.latch,
            data: self.data,
            version: self.version
        }
    }
    /// Borrows an optimistic *view* of this guard without giving up the read
    /// lock; the shorter lifetime `'b` ties the view to the shared guard.
    #[inline]
    pub fn as_optimistic<'b>(&'b self) -> OptimisticGuard<'b, T> {
        OptimisticGuard {
            latch: self.latch,
            data: self.data,
            version: self.version
        }
    }
    /// The latch this guard was created from (with the guard's lifetime).
    pub fn latch(&self) -> &'a HybridLatch<T> {
        self.latch
    }
}
impl<'a, T> std::ops::Deref for SharedGuard<'a, T> {
    type Target = T;
    #[inline]
    fn deref(&self) -> &T {
        // SAFETY: the read lock held by this guard excludes writers for the
        // guard's lifetime, so the data cannot be mutated concurrently.
        unsafe { &*self.data }
    }
}
impl<'a, T> HybridGuard<T> for SharedGuard<'a, T> {
fn inner(&self) -> &T {
self
}
fn recheck(&self) -> error::Result<()> {
self.recheck();
Ok(())
}
fn latch(&self) -> &HybridLatch<T> {
self.latch()
}
}
/// Result of `HybridLatch::optimistic_or_shared`: an optimistic guard on the
/// fast path, or a shared guard when a writer forced the fallback.
pub enum OptimisticOrShared<'a, T> {
    Optimistic(OptimisticGuard<'a, T>),
    Shared(SharedGuard<'a, T>)
}
impl<'a, T> OptimisticOrShared<'a, T> {
    /// Validates the underlying guard. Shared guards hold the read lock and
    /// therefore always validate; optimistic guards may fail with `Unwind`.
    #[inline]
    pub fn recheck(&self) -> error::Result<()> {
        match self {
            Self::Shared(guard) => {
                guard.recheck();
                Ok(())
            }
            Self::Optimistic(guard) => guard.recheck(),
        }
    }
}
/// Result of `HybridLatch::optimistic_or_exclusive`: an optimistic guard on
/// the fast path, or an exclusive guard when a writer forced the fallback.
pub enum OptimisticOrExclusive<'a, T> {
    Optimistic(OptimisticGuard<'a, T>),
    Exclusive(ExclusiveGuard<'a, T>)
}
impl<'a, T> OptimisticOrExclusive<'a, T> {
    /// Validates the underlying guard. Exclusive guards hold the write lock
    /// and always validate; optimistic guards may fail with `Unwind`.
    #[inline]
    pub fn recheck(&self) -> error::Result<()> {
        match self {
            Self::Exclusive(guard) => {
                guard.recheck();
                Ok(())
            }
            Self::Optimistic(guard) => guard.recheck(),
        }
    }
}
#[cfg(test)]
mod tests {
    //! Baseline and stress tests. The concurrent tests intentionally race
    //! raw writes against optimistic readers to exercise the recheck/retry
    //! protocol; the data array is written to flip element 3 away from its
    //! neutral value, so any reader that trusts a torn read without a
    //! successful `recheck` would compute a result != 1 and fail.
    use super::HybridLatch;
    use crate::error;
    use std::cell::UnsafeCell;
    use std::sync::Arc;
    use std::thread;
    use serial_test::serial;
    // Shareable unsynchronized array; the latch (over `()`) provides the
    // synchronization protocol, mirroring external-data latching usage.
    struct Wrapper<T>(UnsafeCell<[T; 1000]>);
    unsafe impl<T: Send> Send for Wrapper<T> {}
    unsafe impl<T: Send + Sync> Sync for Wrapper<T> {}
    // Timing baseline: same workload as the concurrent reader test but with
    // no latch, for manual comparison of the printed durations.
    #[test]
    #[serial]
    fn single_threaded_reader_baseline() {
        let data = [1usize; 1000];
        let mut result = 1usize;
        let t0 = std::time::Instant::now();
        for _i in 0..4000000 {
            for j in 0..1000 {
                result = result.saturating_mul(data[j]);
            }
        }
        println!("Single threaded reader done in {:?}", t0.elapsed());
        assert!(result == 1);
    }
    #[test]
    #[serial]
    fn concurrent_reading_and_writing() {
        let data = Arc::new(Wrapper(UnsafeCell::new([1usize; 1000])));
        let latch = Arc::new(HybridLatch::new(()));
        let n_readers = 3;
        let n_writers = 1;
        let n = n_readers + n_writers;
        // +1 so the main thread participates in the start barrier.
        let barrier = Arc::new(std::sync::Barrier::new(n + 1));
        let mut readers = vec![];
        for _i in 0..n_readers {
            let data = data.clone();
            let latch = latch.clone();
            let barrier = barrier.clone();
            let handle = thread::spawn(move || {
                barrier.wait();
                let mut result = 1usize;
                for _i in 0..4000000 {
                    // Optimistic read with retry-on-Unwind loop.
                    loop {
                        let res = {
                            let attempt = || {
                                let locked = latch.optimistic_or_spin();
                                let arr = data.0.get();
                                let mut result = 1usize;
                                for j in 0..1000 {
                                    result = result.saturating_mul(unsafe { (*arr)[j] });
                                }
                                // Discard the product unless no writer ran.
                                locked.recheck()?;
                                error::Result::Ok(result)
                            };
                            attempt()
                        };
                        match res {
                            Ok(v) => {
                                result *= v;
                                break;
                            }
                            Err(_) => {
                                continue;
                            }
                        }
                    }
                    // A validated read must never observe the writer's
                    // transient `2` at index 3.
                    assert!(result == 1);
                }
                assert!(result == 1);
            });
            readers.push(handle);
        }
        let mut writers = vec![];
        for _i in 0..n_writers {
            let data = data.clone();
            let latch = latch.clone();
            let barrier = barrier.clone();
            let handle = thread::spawn(move || {
                barrier.wait();
                // Write ~100 times/second for ~10 seconds, holding the
                // exclusive latch for ~1ms each time.
                let seconds = 10f64;
                let micros_per_sec = 1_000_000;
                let freq = 100;
                let critical = 1000;
                for _i in 0..(seconds * freq as f64) as usize {
                    thread::sleep(std::time::Duration::from_micros((micros_per_sec / freq) - critical));
                    {
                        let _locked = latch.exclusive();
                        // Make the data transiently "wrong" while exclusive.
                        unsafe { (*data.0.get())[3] = 2 };
                        thread::sleep(std::time::Duration::from_micros(critical));
                        unsafe { (*data.0.get())[3] = 1 };
                    }
                }
            });
            writers.push(handle);
        }
        barrier.wait();
        let t0 = std::time::Instant::now();
        for handle in readers {
            handle.join().unwrap();
        }
        println!("Readers done in {:?}", t0.elapsed());
        for handle in writers {
            handle.join().unwrap();
        }
        println!("Writers done in at most {:?}", t0.elapsed());
    }
    // Baseline for the Option-typed variant below (tests that niche-optimized
    // data behaves the same under the latch protocol).
    #[test]
    #[serial]
    fn single_threaded_option_reader_baseline() {
        let data = [Some(1usize); 1000];
        let mut result = 1usize;
        let t0 = std::time::Instant::now();
        for _i in 0..4000000 {
            for j in 0..1000 {
                let opt = &data[j];
                if let Some(n) = opt {
                    result = result.saturating_mul(*n);
                } else {
                    result = 0;
                }
            }
        }
        println!("Single threaded option reader done in {:?}", t0.elapsed());
        assert!(result == 1);
    }
    // Same as `concurrent_reading_and_writing` but the writer flips index 3
    // between `Some(1)` and `None`; a reader that acts on a torn `None`
    // without a successful recheck would zero its product and fail.
    #[test]
    #[serial]
    fn concurrent_option_reading_and_writing() {
        let data = Arc::new(Wrapper(UnsafeCell::new([Some(1usize); 1000])));
        let latch = Arc::new(HybridLatch::new(()));
        let n_readers = 3;
        let n_writers = 1;
        let n = n_readers + n_writers;
        let barrier = Arc::new(std::sync::Barrier::new(n + 1));
        let mut readers = vec![];
        for _i in 0..n_readers {
            let data = data.clone();
            let latch = latch.clone();
            let barrier = barrier.clone();
            let handle = thread::spawn(move || {
                barrier.wait();
                let mut result = 1usize;
                for _i in 0..4000000 {
                    loop {
                        let res = {
                            let attempt = || {
                                let locked = latch.optimistic_or_spin();
                                let arr = data.0.get();
                                let mut result = 1usize;
                                for j in 0..1000 {
                                    let opt = unsafe { &(*arr)[j] };
                                    if let Some(n) = opt {
                                        result = result.saturating_mul(*n);
                                    } else {
                                        result = 0;
                                    }
                                }
                                locked.recheck()?;
                                error::Result::Ok(result)
                            };
                            attempt()
                        };
                        match res {
                            Ok(v) => {
                                result *= v;
                                break;
                            }
                            Err(_) => {
                                continue;
                            }
                        }
                    }
                    assert!(result == 1);
                }
                assert!(result == 1);
            });
            readers.push(handle);
        }
        let mut writers = vec![];
        for _i in 0..n_writers {
            let data = data.clone();
            let latch = latch.clone();
            let barrier = barrier.clone();
            let handle = thread::spawn(move || {
                barrier.wait();
                let seconds = 10f64;
                let micros_per_sec = 1_000_000;
                let freq = 100;
                let critical = 1000;
                for _i in 0..(seconds * freq as f64) as usize {
                    thread::sleep(std::time::Duration::from_micros((micros_per_sec / freq) - critical));
                    {
                        let _locked = latch.exclusive();
                        unsafe { (*data.0.get())[3] = None };
                        thread::sleep(std::time::Duration::from_micros(critical));
                        unsafe { (*data.0.get())[3] = Some(1) };
                    }
                }
            });
            writers.push(handle);
        }
        barrier.wait();
        let t0 = std::time::Instant::now();
        for handle in readers {
            handle.join().unwrap();
        }
        println!("Readers done in {:?}", t0.elapsed());
        for handle in writers {
            handle.join().unwrap();
        }
        println!("Writers done in at most {:?}", t0.elapsed());
    }
}