use crate::internal::cache_padded::CachePadded;
use core::fmt;
use parking_lot::Mutex;
use std::cell::UnsafeCell;
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// State shared between the writer and all readers.
///
/// Two copies of the value are kept. `live_idx` (0 or 1) selects the copy that
/// newly-entering readers use; the writer mutates the other copy and then flips
/// `live_idx` (see `WriteHandle::modify`).
struct LeftRightShared<T> {
    /// Index (0 or 1) of the copy readers should currently use.
    /// NOTE(review): `CachePadded` is project-local; presumably pads each
    /// atomic to its own cache line to avoid false sharing — confirm.
    live_idx: CachePadded<AtomicUsize>,
    /// Per-copy count of registered readers. It also includes readers that are
    /// about to back out and retry in `enter`. The writer waits for the old
    /// copy's count to reach zero before mutating that copy.
    active_readers: [CachePadded<AtomicUsize>; 2],
    /// Copy 0 of the value.
    data_0: UnsafeCell<T>,
    /// Copy 1 of the value.
    data_1: UnsafeCell<T>,
}
impl<T> fmt::Debug for LeftRightShared<T> {
    /// Prints only the live index. The two copies are deliberately omitted, so
    /// no `T: Debug` bound is needed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let live = self.live_idx.load(Ordering::Relaxed);
        let mut out = f.debug_struct("LeftRightShared");
        out.field("live_idx", &live);
        out.finish_non_exhaustive()
    }
}
// SAFETY: the struct owns two `T`s plus atomics. Moving it to another thread
// moves those `T`s, so `T: Send` is sufficient.
unsafe impl<T: Send> Send for LeftRightShared<T> {}
// SAFETY: through `&LeftRightShared`:
// - readers obtain `&T` to the same copy from several threads at once
//   (requires `T: Sync`);
// - the writer obtains `&mut T` from whichever thread calls `modify`
//   (requires `T: Send`).
// Reader/writer exclusion on each copy is enforced by the `live_idx` /
// `active_readers` protocol in `enter` and `modify`.
unsafe impl<T: Send + Sync> Sync for LeftRightShared<T> {}
/// A reader handle onto the shared value.
///
/// Cloning yields another handle over the same state. Reads go through `enter`,
/// which never takes the writer's lock.
#[derive(Debug)]
pub(crate) struct ReadHandle<T> {
    shared: Arc<LeftRightShared<T>>,
}
impl<T> Clone for ReadHandle<T> {
fn clone(&self) -> Self {
Self {
shared: self.shared.clone(),
}
}
}
/// The writer handle returned by `new` (one per shared state; not `Clone` here).
///
/// Updates go through `modify`.
#[derive(Debug)]
pub(crate) struct WriteHandle<T> {
    shared: Arc<LeftRightShared<T>>,
    /// Serializes concurrent `modify` calls made through this handle, for
    /// example when the handle is shared behind an `Arc`.
    writer_lock: Mutex<()>,
}
/// RAII guard granting shared access to one copy of the value.
///
/// While alive, the guard is counted in `active_readers[idx]`, which keeps the
/// writer from mutating the copy it points to. The count is released on drop.
/// Discarding the guard immediately makes the read pointless, hence `must_use`.
#[must_use = "the read slot is released as soon as the guard is dropped"]
pub(crate) struct ReadGuard<T> {
    /// Keeps the shared allocation alive while `data` may be dereferenced.
    shared: Arc<LeftRightShared<T>>,
    /// Which copy (0 or 1) this guard is registered on.
    idx: usize,
    /// Points into `shared.data_0` or `shared.data_1`, as selected by `idx`.
    data: *const T,
}
impl<T> fmt::Debug for ReadGuard<T> {
    /// Shows which copy the guard is pinned to. The value itself is not
    /// printed, so `T` need not implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ReadGuard");
        out.field("idx", &self.idx);
        out.finish_non_exhaustive()
    }
}
// SAFETY: a guard moved to another thread derefs to `&T` there (requires
// `T: Sync`). It also carries an `Arc<LeftRightShared<T>>`, which is only
// `Send` when `LeftRightShared<T>` is `Send + Sync` (requires `T: Send + Sync`).
// The raw `data` pointer is what suppresses the auto impl.
unsafe impl<T: Send + Sync> Send for ReadGuard<T> {}
// SAFETY: `&ReadGuard` only exposes `&T` (via `Deref`), which is safe to share
// across threads when `T: Sync`.
unsafe impl<T: Sync> Sync for ReadGuard<T> {}
/// Creates a connected reader/writer pair over a freshly defaulted value.
///
/// Both copies start as `T::default()` and copy 0 is live. Further readers are
/// obtained by cloning the returned `ReadHandle`.
pub(crate) fn new<T: Default>() -> (ReadHandle<T>, WriteHandle<T>) {
    let zeroed_counter = || CachePadded::new(AtomicUsize::new(0));
    let shared = Arc::new(LeftRightShared {
        live_idx: zeroed_counter(),
        active_readers: [zeroed_counter(), zeroed_counter()],
        data_0: UnsafeCell::new(T::default()),
        data_1: UnsafeCell::new(T::default()),
    });
    let writer = WriteHandle {
        shared: Arc::clone(&shared),
        writer_lock: Mutex::new(()),
    };
    let reader = ReadHandle { shared };
    (reader, writer)
}
impl<T> ReadHandle<T> {
pub(crate) fn enter(&self) -> ReadGuard<T> {
loop {
let idx = self.shared.live_idx.load(Ordering::SeqCst);
self.shared.active_readers[idx].fetch_add(1, Ordering::SeqCst);
if self.shared.live_idx.load(Ordering::SeqCst) == idx {
let data = if idx == 0 {
self.shared.data_0.get() as *const T
} else {
self.shared.data_1.get() as *const T
};
return ReadGuard {
shared: self.shared.clone(),
idx,
data,
};
}
self.shared.active_readers[idx].fetch_sub(1, Ordering::SeqCst);
}
}
}
impl<T> Deref for ReadGuard<T> {
    type Target = T;

    /// Borrows the copy this guard is registered on.
    fn deref(&self) -> &T {
        let ptr: *const T = self.data;
        // SAFETY: `ptr` points into the allocation kept alive by `self.shared`.
        // This guard is counted in `active_readers[self.idx]`, and the writer
        // does not create `&mut T` to that copy while the count is non-zero.
        // So no mutable alias exists for the lifetime of `&self`.
        unsafe { &*ptr }
    }
}
impl<T> Drop for ReadGuard<T> {
    /// Deregisters this reader. Once the counter reaches zero, a writer waiting
    /// in `modify` may proceed.
    fn drop(&mut self) {
        let counter = &self.shared.active_readers[self.idx];
        counter.fetch_sub(1, Ordering::SeqCst);
    }
}
impl<T> WriteHandle<T> {
    /// Applies `f` to the shared value and publishes the result to readers.
    ///
    /// The update is applied once to each copy:
    ///
    /// 1. `f` runs on the copy that readers are *not* currently directed to.
    /// 2. `live_idx` is flipped, so newly-entering readers see the updated copy.
    /// 3. The writer waits until every guard on the previously-live copy has
    ///    been dropped. It spins briefly, then yields its time slice.
    /// 4. `f` runs again on that now-unobserved copy, so both copies agree.
    ///
    /// Concurrent calls are serialized by `writer_lock`.
    ///
    /// # Contract on `f`
    ///
    /// `f` is invoked twice, on two copies that started out equal, so it must
    /// be deterministic: both invocations must leave their copy in the same
    /// state. A closure whose effect depends on its own captured mutable state
    /// (e.g. a counter it increments) makes the copies diverge. Readers would
    /// then see different values depending on which copy is live.
    ///
    /// # Deadlock
    ///
    /// Step 3 blocks until outstanding `ReadGuard`s on the old copy are dropped.
    /// Calling `modify` from a thread that itself holds such a guard never
    /// returns.
    ///
    /// # Panics
    ///
    /// A panic in `f` propagates to the caller and is not recovered from. The
    /// two copies may be left out of sync: either one partially modified, or
    /// one updated and the other not.
    pub(crate) fn modify<F>(&self, mut f: F)
    where
        F: FnMut(&mut T),
    {
        // Busy-wait iterations before the writer starts yielding while it waits
        // for readers. Guards can be held arbitrarily long; spinning forever
        // would burn a core and could starve the very reader being waited on.
        const SPIN_LIMIT: u32 = 64;

        let _lock = self.writer_lock.lock();
        let shared = &*self.shared;
        let slot = |idx: usize| -> *mut T {
            if idx == 0 {
                shared.data_0.get()
            } else {
                shared.data_1.get()
            }
        };

        let live_idx = shared.live_idx.load(Ordering::SeqCst);
        let write_idx = 1 - live_idx;

        // SAFETY: `write_idx` is not the live copy. A reader only keeps a guard
        // on a copy after re-reading it as live, and the previous `modify`
        // waited for every guard on this copy to drain after flipping away from
        // it. `writer_lock` excludes other writers. No other reference exists.
        f(unsafe { &mut *slot(write_idx) });

        shared.live_idx.store(write_idx, Ordering::SeqCst);

        let mut spins = 0;
        while shared.active_readers[live_idx].load(Ordering::SeqCst) > 0 {
            if spins < SPIN_LIMIT {
                spins += 1;
                std::hint::spin_loop();
            } else {
                std::thread::yield_now();
            }
        }

        // SAFETY: readers were redirected away from `live_idx` above and its
        // reader count has reached zero. Any reader registering on it from now
        // on observes the new index and backs out without dereferencing.
        // `writer_lock` excludes other writers.
        f(unsafe { &mut *slot(live_idx) });
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;
    use std::thread;
    use std::time::Duration;

    /// Both copies start out as `T::default()`.
    #[test]
    fn initial_state_is_default() {
        let (rh, _wh) = new::<i32>();
        assert_eq!(*rh.enter(), 0);
    }

    /// Successive writes are visible to subsequent reads.
    #[test]
    fn write_and_read_back() {
        let (rh, wh) = new::<String>();
        assert_eq!(*rh.enter(), "");
        wh.modify(|s| *s = "hello".to_string());
        assert_eq!(*rh.enter(), "hello");
        wh.modify(|s| s.push_str(" world"));
        assert_eq!(*rh.enter(), "hello world");
    }

    /// Clones of a `ReadHandle` observe the same shared state.
    #[test]
    fn cloned_read_handle_sees_updates() {
        let (rh, wh) = new::<usize>();
        let rh2 = rh.clone();
        assert_eq!(*rh.enter(), 0);
        assert_eq!(*rh2.enter(), 0);
        wh.modify(|val| *val = 100);
        assert_eq!(*rh.enter(), 100);
        assert_eq!(*rh2.enter(), 100);
    }

    /// A writer must not touch the copy an outstanding guard points at. It
    /// blocks (spinning, then yielding) until the guard is dropped, while the
    /// guard keeps observing the pre-write value.
    #[test]
    fn writer_waits_for_outstanding_guard() {
        let (rh, wh) = new::<u32>();
        let guard = rh.enter();
        let done = Arc::new(AtomicBool::new(false));
        let writer = {
            let done = done.clone();
            thread::spawn(move || {
                wh.modify(|v| *v += 1);
                done.store(true, Ordering::Release);
            })
        };
        // Long enough for the writer to exhaust its spin budget and start yielding.
        thread::sleep(Duration::from_millis(10));
        assert!(
            !done.load(Ordering::Acquire),
            "writer finished while a guard on the old copy was held"
        );
        assert_eq!(*guard, 0);
        drop(guard);
        writer.join().expect("writer panicked");
        assert!(done.load(Ordering::Acquire));
        assert_eq!(*rh.enter(), 1);
    }

    /// Payload whose fields must all equal `version` in any consistent snapshot.
    #[derive(Default, Debug, PartialEq, Eq)]
    struct TestData {
        version: usize,
        data: [usize; 8],
    }

    impl TestData {
        /// True when every element matches `version`, i.e. no torn write is visible.
        fn is_consistent(&self) -> bool {
            self.data.iter().all(|&x| x == self.version)
        }
    }

    /// Readers racing a writer never see torn data, and never see versions go backwards.
    #[test]
    fn concurrent_reads_and_writes_are_consistent() {
        let (rh, wh) = new::<TestData>();
        let rh = Arc::new(rh);
        let wh = Arc::new(wh);
        let writer_finished = Arc::new(AtomicBool::new(false));
        let mut reader_handles = Vec::new();
        for _ in 0..4 {
            let rh = rh.clone();
            let finished = writer_finished.clone();
            reader_handles.push(thread::spawn(move || {
                let mut last_version = 0;
                while !finished.load(Ordering::Acquire) {
                    let guard = rh.enter();
                    assert!(
                        guard.is_consistent(),
                        "Inconsistent read: {:?}",
                        *guard
                    );
                    assert!(
                        guard.version >= last_version,
                        "Version went backwards!"
                    );
                    last_version = guard.version;
                    thread::yield_now();
                }
            }));
        }
        let writer = thread::spawn(move || {
            for i in 1..=100 {
                wh.modify(|d| {
                    d.version = i;
                    d.data.iter_mut().for_each(|x| *x = i);
                });
                thread::sleep(Duration::from_micros(10));
            }
            writer_finished.store(true, Ordering::Release);
        });
        writer.join().expect("writer panicked");
        for h in reader_handles {
            h.join().expect("reader panicked");
        }
        let final_guard = rh.enter();
        assert_eq!(final_guard.version, 100);
        assert!(final_guard.is_consistent());
    }
}