use std::cell::UnsafeCell;
use std::fmt;
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering, fence};
use crate::{Pod, atomic_load, atomic_store};
/// State shared between the single [`Writer`] and every [`SharedReader`] of a slot.
#[repr(C)]
struct Inner<T> {
    /// Seqlock sequence counter: odd while a write is in progress, even when the
    /// slot holds a completely written value. The writer advances it by 2 per write.
    seq: AtomicUsize,
    /// Set to `false` by `Writer`'s `Drop` so readers can detect disconnection.
    writer_alive: AtomicBool,
    /// The most recently published value. Uninitialized until the first write
    /// completes; `MaybeUninit` never drops its contents.
    data: UnsafeCell<MaybeUninit<T>>,
}
// SAFETY: `Inner` holds atomics plus a `T` (inside `MaybeUninit`), so moving it to
// another thread is sound whenever `T` itself may be sent.
unsafe impl<T: Send> Send for Inner<T> {}
// SAFETY: within this file, `data` is only stored to by `Writer::write`, which
// requires `&mut Writer`. Readers copy it out with `atomic_load` and discard any
// copy whose sequence number changed during the read (seqlock protocol).
// NOTE(review): this relies on `crate::atomic_load`/`atomic_store` being
// data-race-free copies, and on those accesses only being reachable for
// `T: Pod` (this impl itself is not bounded on `Pod`). Confirm.
unsafe impl<T: Send> Sync for Inner<T> {}
/// The unique writing half of a slot created by [`shared_slot`].
///
/// `Writer<T>` is `Send` for `T: Send` without a hand-written `unsafe impl`.
/// `Arc<Inner<T>>` is `Send` exactly when `Inner<T>` is `Send + Sync`, and the
/// impls on `Inner` grant both for `T: Send`. The compiler therefore derives the
/// same bound, and the auto impl stays correct if those impls ever change.
pub struct Writer<T> {
    /// Sequence value last published by this writer; mirrors `Inner::seq`
    /// between writes.
    local_seq: usize,
    inner: Arc<Inner<T>>,
}
/// A reading half of a slot created by [`shared_slot`]; clone it for more readers.
///
/// Each reader tracks its own read position, so consuming a value through one
/// reader does not hide it from the others.
///
/// `SharedReader<T>` is `Send` for `T: Send` without a hand-written
/// `unsafe impl`. `Arc<Inner<T>>` is `Send` exactly when `Inner<T>` is
/// `Send + Sync`, which the impls on `Inner` grant for `T: Send`.
pub struct SharedReader<T> {
    /// Sequence number of the last value this reader consumed (initially the
    /// slot's starting sequence, meaning "nothing consumed yet").
    cached_seq: usize,
    inner: Arc<Inner<T>>,
}
impl<T> Clone for SharedReader<T> {
fn clone(&self) -> Self {
Self {
cached_seq: self.cached_seq,
inner: Arc::clone(&self.inner),
}
}
}
/// Creates a single-writer, multi-reader slot that holds the latest written value.
///
/// Returns the only [`Writer`] and a first [`SharedReader`]; clone the reader to
/// obtain more. Successive writes overwrite one another, so readers only ever
/// see the most recent value.
///
/// Compilation fails if `T` needs dropping: values live in a `MaybeUninit`,
/// which never runs destructors.
pub fn shared_slot<T: Pod>() -> (Writer<T>, SharedReader<T>) {
    // Even = stable, odd = write in progress. Writer, reader and the shared
    // counter all start at the same value, so the still-uninitialized slot is
    // never reported as an update.
    const INITIAL_SEQ: usize = 2;

    const {
        assert!(
            !std::mem::needs_drop::<T>(),
            "Pod types must not require drop"
        );
    };

    let shared = Arc::new(Inner {
        seq: AtomicUsize::new(INITIAL_SEQ),
        writer_alive: AtomicBool::new(true),
        data: UnsafeCell::new(MaybeUninit::uninit()),
    });
    let writer = Writer {
        local_seq: INITIAL_SEQ,
        inner: Arc::clone(&shared),
    };
    let reader = SharedReader {
        cached_seq: INITIAL_SEQ,
        inner: shared,
    };
    (writer, reader)
}
impl<T: Pod> Writer<T> {
    /// Publishes `value`, replacing whatever the slot held before.
    ///
    /// Seqlock write: the shared sequence is made odd for the duration of the
    /// copy, then advanced to the next even value. Readers that observe an odd
    /// sequence, or one that changed across their copy, retry. Never waits on
    /// readers.
    #[inline]
    pub fn write(&mut self, value: T) {
        let inner = &*self.inner;
        // Only this writer stores to `inner.seq`, so the local mirror is
        // authoritative and no load is needed.
        let seq = self.local_seq;
        // Mark the slot as "being written" (odd sequence).
        inner.seq.store(seq.wrapping_add(1), Ordering::Relaxed);
        // Orders the odd-sequence store before the data stores below. A reader
        // whose copy observes any of the new data (and then issues its acquire
        // fence) is guaranteed to see the odd sequence on its re-check and retry.
        fence(Ordering::Release);
        // SAFETY: `&mut self` on the sole writer makes this the only store to
        // `data`; readers access it only through `atomic_load`.
        // NOTE(review): assumes `crate::atomic_store` performs a data-race-free
        // copy that may overlap those loads. Confirm.
        unsafe { atomic_store(inner.data.get().cast::<T>(), &value) };
        // Makes the data stores visible to any reader that observes the even
        // sequence stored next (pairs with the reader's acquire fence).
        fence(Ordering::Release);
        // Publish: advance to the next even value (wrapping on overflow).
        self.local_seq = seq.wrapping_add(2);
        inner.seq.store(self.local_seq, Ordering::Relaxed);
    }
    /// Returns `true` once every `SharedReader` of this slot has been dropped.
    ///
    /// A strong count of 1 means only this writer still holds the shared state.
    /// Within this file, new readers come only from cloning an existing reader,
    /// so once the count reaches 1 it cannot rise again.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.inner) == 1
    }
}
impl<T> Drop for Writer<T> {
    /// Tells readers that no further writes will happen.
    ///
    /// The `Release` store pairs with the `Acquire` load in
    /// `SharedReader::is_disconnected`, so a reader that sees the flag cleared
    /// also sees every write this writer published.
    fn drop(&mut self) {
        let alive_flag = &self.inner.writer_alive;
        alive_flag.store(false, Ordering::Release);
    }
}
impl<T: Pod> fmt::Debug for Writer<T> {
    /// Prints the writer's last published sequence number. The slot contents
    /// are not shown (`finish_non_exhaustive` renders them as `..`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Writer");
        builder.field("seq", &self.local_seq);
        builder.finish_non_exhaustive()
    }
}
impl<T: Pod> SharedReader<T> {
    /// Seqlock read shared by [`read`](Self::read) and
    /// [`read_versioned`](Self::read_versioned).
    ///
    /// Returns the latest published value and the even sequence number it was
    /// published under. Returns `None` when the sequence still equals the one
    /// this reader last consumed. Spins while a write is in progress (odd
    /// sequence) and retries if a write overlapped the copy.
    ///
    /// Every even sequence other than `cached_seq` denotes a published value.
    /// That includes `0`, which the writer's wrapping counter reaches on
    /// overflow, so `0` must not be treated as a "never written" sentinel.
    /// "Nothing written yet" is already covered because reader and writer start
    /// at the same sequence.
    #[inline]
    fn read_with_seq(&mut self) -> Option<(T, usize)> {
        let inner = &*self.inner;
        loop {
            let seq1 = inner.seq.load(Ordering::Relaxed);
            if seq1 == self.cached_seq {
                return None;
            }
            if seq1 & 1 != 0 {
                // A write is in progress; wait until it is published.
                core::hint::spin_loop();
                continue;
            }
            // With the relaxed load above, this acts as an acquire load of
            // `seq`. It pairs with the writer's release fence before publishing
            // the even sequence, so that write's data is visible to the copy.
            fence(Ordering::Acquire);
            // SAFETY: `seq1` is even and differs from `cached_seq`, which starts
            // equal to the slot's initial sequence. The writer has therefore
            // completed at least one write and the slot is initialized. Copies
            // that overlap a concurrent write are discarded by the re-check below.
            // NOTE(review): soundness of the racing copy relies on
            // `crate::atomic_load` using atomic accesses and on `T: Pod`
            // accepting any bit pattern. Confirm against their definitions.
            let value = unsafe { atomic_load(inner.data.get().cast::<T>()) };
            // Keeps the data loads above from being reordered past the
            // re-check of `seq` below.
            fence(Ordering::Acquire);
            let seq2 = inner.seq.load(Ordering::Relaxed);
            if seq1 == seq2 {
                self.cached_seq = seq1;
                return Some((value, seq1));
            }
            // A write overlapped the copy; try again.
            core::hint::spin_loop();
        }
    }

    /// Returns the most recently published value if this reader has not
    /// consumed it yet, otherwise `None`.
    ///
    /// Intermediate writes are conflated: only the latest value is returned.
    /// Consuming through this reader does not affect its clones.
    #[inline]
    pub fn read(&mut self) -> Option<T> {
        self.read_with_seq().map(|(value, _)| value)
    }

    /// Like [`read`](Self::read), but also returns the value's version: the
    /// writer's sequence number divided by two, which grows by one per write.
    ///
    /// The version comes from a wrapping `usize` counter, so it wraps around
    /// after roughly `usize::MAX / 2` writes. In practice that is only reachable
    /// on targets whose `usize` is 32 bits or narrower.
    #[inline]
    pub fn read_versioned(&mut self) -> Option<(T, u64)> {
        // `usize` -> `u64` is lossless on targets whose `usize` is at most 64 bits.
        self.read_with_seq()
            .map(|(value, seq)| (value, seq as u64 / 2))
    }

    /// Returns `true` if a published value is waiting that this reader has not
    /// consumed yet, without consuming it.
    ///
    /// Returns `false` while a write is still in progress (odd sequence).
    #[inline]
    pub fn has_update(&self) -> bool {
        let seq = self.inner.seq.load(Ordering::Relaxed);
        seq != self.cached_seq && seq & 1 == 0
    }

    /// Returns `true` once the [`Writer`] has been dropped.
    ///
    /// The `Acquire` load pairs with the `Release` store in `Writer`'s `Drop`.
    /// After this returns `true`, the writer's final sequence update is visible
    /// to [`has_update`](Self::has_update) and [`read`](Self::read). Values
    /// published before the drop remain readable.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        !self.inner.writer_alive.load(Ordering::Acquire)
    }
}
impl<T: Pod> fmt::Debug for SharedReader<T> {
    /// Prints this reader's read position. The slot contents are not shown
    /// (`finish_non_exhaustive` renders them as `..`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("SharedReader");
        builder.field("cached_seq", &self.cached_seq);
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Default, PartialEq, Debug)]
    #[repr(C)]
    struct TestData {
        a: u64,
        b: u64,
    }

    unsafe impl Pod for TestData {}

    /// Reads from `reader` until the writer has been dropped and no unread value
    /// remains, passing every value obtained to `on_value`.
    ///
    /// Busy-polls; intended only for the cross-thread tests below.
    fn drain_until_disconnected<T: Pod>(
        reader: &mut SharedReader<T>,
        mut on_value: impl FnMut(T),
    ) {
        loop {
            if reader.is_disconnected() && !reader.has_update() {
                break;
            }
            if let Some(value) = reader.read() {
                on_value(value);
            }
        }
    }

    #[test]
    fn read_before_write_returns_none() {
        let (_, mut reader) = shared_slot::<TestData>();
        assert!(reader.read().is_none());
    }

    #[test]
    fn read_consumes_value() {
        let (mut writer, mut reader) = shared_slot::<TestData>();
        writer.write(TestData { a: 1, b: 2 });
        assert_eq!(reader.read(), Some(TestData { a: 1, b: 2 }));
        assert!(reader.read().is_none());
    }

    #[test]
    fn multiple_writes_conflate() {
        let (mut writer, mut reader) = shared_slot::<TestData>();
        writer.write(TestData { a: 1, b: 0 });
        writer.write(TestData { a: 2, b: 0 });
        writer.write(TestData { a: 3, b: 0 });
        assert_eq!(reader.read(), Some(TestData { a: 3, b: 0 }));
        assert!(reader.read().is_none());
    }

    #[test]
    fn has_update_does_not_consume() {
        let (mut writer, mut reader) = shared_slot::<TestData>();
        assert!(!reader.has_update());
        writer.write(TestData { a: 1, b: 0 });
        assert!(reader.has_update());
        assert!(reader.has_update());
        reader.read();
        assert!(!reader.has_update());
    }

    #[test]
    fn read_versioned_increments_per_write() {
        let (mut writer, mut reader) = shared_slot::<u64>();
        writer.write(5);
        let (first, first_version) = reader.read_versioned().unwrap();
        writer.write(6);
        let (second, second_version) = reader.read_versioned().unwrap();
        assert_eq!((first, second), (5, 6));
        assert_eq!(second_version, first_version + 1);
        assert!(reader.read_versioned().is_none());
    }

    #[test]
    fn two_readers_independent_consumption() {
        let (mut writer, mut reader1) = shared_slot::<u64>();
        let mut reader2 = reader1.clone();
        writer.write(42);
        assert_eq!(reader1.read(), Some(42));
        assert_eq!(reader2.read(), Some(42));
        assert!(reader1.read().is_none());
        assert!(reader2.read().is_none());
    }

    #[test]
    fn clone_after_read_starts_at_parent_position() {
        let (mut writer, mut reader1) = shared_slot::<u64>();
        writer.write(1);
        assert_eq!(reader1.read(), Some(1));
        let mut reader2 = reader1.clone();
        assert!(reader1.read().is_none());
        assert!(reader2.read().is_none());
        writer.write(2);
        assert_eq!(reader1.read(), Some(2));
        assert_eq!(reader2.read(), Some(2));
    }

    #[test]
    fn clone_before_read_both_see_value() {
        let (mut writer, mut reader1) = shared_slot::<u64>();
        let mut reader2 = reader1.clone();
        writer.write(99);
        assert_eq!(reader1.read(), Some(99));
        assert_eq!(reader2.read(), Some(99));
    }

    #[test]
    fn reader1_consumes_reader2_unaffected() {
        let (mut writer, mut reader1) = shared_slot::<u64>();
        let mut reader2 = reader1.clone();
        writer.write(10);
        assert_eq!(reader1.read(), Some(10));
        assert!(reader1.read().is_none());
        assert!(reader2.has_update());
        assert_eq!(reader2.read(), Some(10));
    }

    #[test]
    fn many_readers() {
        let (mut writer, reader) = shared_slot::<u64>();
        let mut readers: Vec<_> = (0..10).map(|_| reader.clone()).collect();
        drop(reader);
        writer.write(42);
        for r in &mut readers {
            assert_eq!(r.read(), Some(42));
        }
    }

    #[test]
    fn writer_detects_all_readers_dropped() {
        let (writer, reader1) = shared_slot::<TestData>();
        let reader2 = reader1.clone();
        assert!(!writer.is_disconnected());
        drop(reader1);
        assert!(!writer.is_disconnected());
        drop(reader2);
        assert!(writer.is_disconnected());
    }

    #[test]
    fn reader_detects_writer_dropped() {
        let (writer, reader) = shared_slot::<TestData>();
        assert!(!reader.is_disconnected());
        drop(writer);
        assert!(reader.is_disconnected());
    }

    #[test]
    fn cloned_reader_detects_writer_dropped() {
        let (writer, reader1) = shared_slot::<TestData>();
        let reader2 = reader1.clone();
        drop(writer);
        assert!(reader1.is_disconnected());
        assert!(reader2.is_disconnected());
    }

    #[test]
    fn can_read_after_writer_disconnect() {
        let (mut writer, mut reader) = shared_slot::<TestData>();
        writer.write(TestData { a: 42, b: 0 });
        drop(writer);
        assert!(reader.is_disconnected());
        assert_eq!(reader.read(), Some(TestData { a: 42, b: 0 }));
    }

    #[test]
    fn cross_thread_two_readers() {
        use std::thread;
        let (mut writer, mut reader1) = shared_slot::<u64>();
        let mut reader2 = reader1.clone();
        let h1 = thread::spawn(move || {
            let mut last = 0;
            drain_until_disconnected(&mut reader1, |v| {
                assert!(v >= last, "reader1: monotonicity violation");
                last = v;
            });
            last
        });
        let h2 = thread::spawn(move || {
            let mut last = 0;
            drain_until_disconnected(&mut reader2, |v| {
                assert!(v >= last, "reader2: monotonicity violation");
                last = v;
            });
            last
        });
        for i in 0..100_000u64 {
            writer.write(i);
        }
        drop(writer);
        let last1 = h1.join().unwrap();
        let last2 = h2.join().unwrap();
        assert_eq!(last1, 99_999);
        assert_eq!(last2, 99_999);
    }

    #[test]
    fn cross_thread_data_integrity() {
        use std::thread;

        #[derive(Clone)]
        #[repr(C)]
        struct Checkable {
            value: u64,
            check: u64,
        }

        unsafe impl Pod for Checkable {}

        let (mut writer, mut reader1) = shared_slot::<Checkable>();
        let mut reader2 = reader1.clone();
        let h1 = thread::spawn(move || {
            drain_until_disconnected(&mut reader1, |data| {
                assert_eq!(data.check, !data.value, "reader1: torn read!");
            });
        });
        let h2 = thread::spawn(move || {
            drain_until_disconnected(&mut reader2, |data| {
                assert_eq!(data.check, !data.value, "reader2: torn read!");
            });
        });
        for i in 0..100_000u64 {
            writer.write(Checkable {
                value: i,
                check: !i,
            });
        }
        drop(writer);
        h1.join().unwrap();
        h2.join().unwrap();
    }
}