use std::cell::UnsafeCell;
use std::fmt;
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering, fence};
use crate::{Pod, atomic_load, atomic_store};
/// State shared by the single [`Writer`] and single [`Reader`] of a slot.
///
/// `seq` is a seqlock counter: it is odd while a write is in flight and even
/// once the payload in `data` is complete. Both halves reach this through an
/// `Arc`.
// NOTE(review): the reason for `#[repr(C)]` is not visible here; field order is
// kept unchanged in case something depends on the layout.
#[repr(C)]
struct Inner<T> {
    /// Seqlock sequence number (odd = write in progress, even = stable).
    seq: AtomicUsize,
    /// The published payload; left uninitialised until the first write.
    data: UnsafeCell<MaybeUninit<T>>,
}

// SAFETY: `Inner` owns a `T` by value, so moving it across threads is fine
// whenever `T` itself may be moved across threads.
unsafe impl<T> Send for Inner<T> where T: Send {}

// SAFETY: shared access to `data` only happens through `Writer::write`
// (exclusive `&mut Writer`, one writer per slot) and the reader's seqlock loop,
// which discards any value read while `seq` changed underneath it.
// NOTE(review): soundness of the possibly-torn reads relies on the `Pod` and
// `atomic_load`/`atomic_store` contracts defined elsewhere in the crate.
unsafe impl<T> Sync for Inner<T> where T: Send {}
/// Producing half of a single-value slot created by [`slot`].
///
/// Every write overwrites the previous value; the reader only ever observes
/// the most recent one.
pub struct Writer<T> {
    /// Slot state shared with the [`Reader`].
    inner: Arc<Inner<T>>,
    /// Even sequence number of the last completed write. Between writes this
    /// equals the shared counter, so the writer never has to load it.
    local_seq: usize,
}

// SAFETY: the writer only holds an `Arc` to the shared slot and a plain
// counter; sending it to another thread is sound whenever `T: Send`.
unsafe impl<T> Send for Writer<T> where T: Send {}
/// Consuming half of a single-value slot created by [`slot`].
///
/// Each published value is returned at most once; values overwritten before
/// the reader gets to them are skipped.
pub struct Reader<T> {
    /// Slot state shared with the [`Writer`].
    inner: Arc<Inner<T>>,
    /// Sequence number of the last value this reader consumed. The slot has
    /// something new whenever the shared counter differs from this.
    cached_seq: usize,
}

// SAFETY: the reader only holds an `Arc` to the shared slot and a plain
// counter; sending it to another thread is sound whenever `T: Send`.
unsafe impl<T> Send for Reader<T> where T: Send {}
/// Creates a connected [`Writer`]/[`Reader`] pair sharing one single-value slot.
///
/// The slot starts empty: the reader reports nothing until the first write.
/// Instantiating this with a `T` that needs drop is rejected at compile time.
pub fn slot<T: Pod>() -> (Writer<T>, Reader<T>) {
    // Inline const: evaluated per monomorphisation, so a `T` with drop glue
    // fails to compile instead of leaking or double-dropping at runtime.
    const {
        assert!(
            !std::mem::needs_drop::<T>(),
            "Pod types must not require drop"
        );
    };

    // The shared counter and the reader's cache start equal, which is how the
    // reader knows nothing has been published yet. The value must be even
    // (odd means "write in progress").
    const INITIAL_SEQ: usize = 2;

    let shared = Arc::new(Inner {
        seq: AtomicUsize::new(INITIAL_SEQ),
        data: UnsafeCell::new(MaybeUninit::uninit()),
    });
    let writer = Writer {
        local_seq: INITIAL_SEQ,
        inner: Arc::clone(&shared),
    };
    let reader = Reader {
        cached_seq: INITIAL_SEQ,
        inner: shared,
    };
    (writer, reader)
}
impl<T: Pod> Writer<T> {
    /// Publishes `value`, replacing whatever the slot previously held.
    ///
    /// Never waits for the reader. While the payload is being copied in, the
    /// shared counter holds an odd value. Afterwards it is advanced to the next
    /// even value, which lets the reader detect and retry reads that overlapped
    /// this write.
    #[inline]
    pub fn write(&mut self, value: T) {
        let shared = &*self.inner;
        let in_progress = self.local_seq.wrapping_add(1);
        let published = self.local_seq.wrapping_add(2);

        // Odd counter: tells the reader a write is underway.
        shared.seq.store(in_progress, Ordering::Relaxed);
        // Keeps the odd-counter store ordered before the payload stores.
        fence(Ordering::Release);
        // SAFETY: `&mut self` plus the fact that `slot` hands out exactly one
        // `Writer` make this the only writer; the `Arc` keeps the slot alive.
        // A concurrent reader may observe a partially written value, which it
        // rejects via its sequence re-check.
        // NOTE(review): relies on `atomic_store`'s contract, defined in the crate root.
        unsafe { atomic_store(shared.data.get().cast::<T>(), &value) };
        // Keeps the payload stores ordered before the even-counter store.
        fence(Ordering::Release);
        shared.seq.store(published, Ordering::Relaxed);
        self.local_seq = published;
    }

    /// Returns `true` once the [`Reader`] has been dropped.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        // We hold one strong reference ourselves, so the count is always >= 1.
        let other_handles = Arc::strong_count(&self.inner) - 1;
        other_handles == 0
    }
}
impl<T: Pod> fmt::Debug for Writer<T> {
    /// Formats as `Writer { seq: <n>, .. }`; the payload itself is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Writer");
        out.field("seq", &self.local_seq);
        out.finish_non_exhaustive()
    }
}
impl<T: Pod> Reader<T> {
    /// Returns the most recently written value if this reader has not consumed
    /// it yet, or `None` if nothing new has been published.
    ///
    /// If a write is in progress, this spins until it completes. Values that
    /// were overwritten before this call are never seen (the slot conflates).
    #[inline]
    pub fn read(&mut self) -> Option<T> {
        self.read_with_seq().map(|(value, _seq)| value)
    }

    /// Like [`Reader::read`], but also returns a version number (`seq / 2`).
    ///
    /// Each write advances the version by 1 (barring counter wrap-around).
    /// A gap larger than 1 between two successful reads therefore means
    /// intermediate writes were conflated away.
    #[inline]
    pub fn read_versioned(&mut self) -> Option<(T, u64)> {
        self.read_with_seq()
            .map(|(value, seq)| (value, seq as u64 / 2))
    }

    /// Returns `true` if a completed write is available that this reader has
    /// not consumed yet. Does not consume it.
    ///
    /// Returns `false` while a write is still in progress.
    #[inline]
    pub fn has_update(&self) -> bool {
        let seq = self.inner.seq.load(Ordering::Relaxed);
        seq != self.cached_seq && seq & 1 == 0
    }

    /// Returns `true` once the [`Writer`] has been dropped.
    ///
    /// When this returns `true`, the writer's completed writes are made
    /// visible to later calls on this reader (see the fence below). That makes
    /// `is_disconnected() && !has_update()` usable as a "fully drained" check.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        if Arc::strong_count(&self.inner) != 1 {
            return false;
        }
        // `Arc::strong_count` is not documented to be an Acquire load, so on
        // its own it does not synchronize with the writer. `Arc`'s drop
        // decrements the count with Release ordering (std implementation;
        // NOTE(review): not a documented API guarantee). This Acquire fence
        // pairs with that decrement, so the writer's final `seq` store
        // happens-before our subsequent loads.
        fence(Ordering::Acquire);
        true
    }

    /// Seqlock read loop shared by [`Reader::read`] and
    /// [`Reader::read_versioned`].
    ///
    /// On success, marks the value as consumed and returns it together with
    /// the (even) sequence number it was published under.
    #[inline]
    fn read_with_seq(&mut self) -> Option<(T, usize)> {
        let inner = &*self.inner;
        loop {
            let seq1 = inner.seq.load(Ordering::Relaxed);
            // `slot` starts both the counter and `cached_seq` at 2, so "never
            // written" is already covered by this comparison. A 0 can only
            // appear after the counter wraps, and it is a genuine new write.
            if seq1 == self.cached_seq {
                return None;
            }
            if seq1 & 1 != 0 {
                // Writer is mid-update; wait for it to finish.
                core::hint::spin_loop();
                continue;
            }
            // Pairs with the writer's Release fences: the payload load below
            // cannot be reordered before the `seq1` load.
            fence(Ordering::Acquire);
            // SAFETY: the `Arc` keeps the slot alive, and `seq1` differs from the
            // initial counter value, so at least one write has completed. A
            // concurrent write may tear the value; that is detected by the
            // `seq2` re-check below and the value is discarded.
            // NOTE(review): relies on the `Pod`/`atomic_load` contracts defined
            // in the crate root for reading a possibly-torn `T`.
            let value = unsafe { atomic_load(inner.data.get().cast::<T>()) };
            // Keeps the payload load ordered before the `seq2` re-check.
            fence(Ordering::Acquire);
            let seq2 = inner.seq.load(Ordering::Relaxed);
            if seq1 == seq2 {
                self.cached_seq = seq1;
                return Some((value, seq1));
            }
            // A write overlapped our copy; retry.
            core::hint::spin_loop();
        }
    }
}
impl<T: Pod> fmt::Debug for Reader<T> {
    /// Formats as `Reader { cached_seq: <n>, .. }`; the payload is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Reader");
        out.field("cached_seq", &self.cached_seq);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Plain two-word payload used by the single-threaded tests.
    #[derive(Clone, Default, PartialEq, Debug)]
    #[repr(C)]
    struct TestData {
        a: u64,
        b: u64,
    }
    // SAFETY: `#[repr(C)]` struct of two `u64`s: no padding, no pointers, no
    // drop glue. NOTE(review): confirm against `Pod`'s documented contract.
    unsafe impl Pod for TestData {}

    #[test]
    fn read_before_write_returns_none() {
        let (_, mut reader) = slot::<TestData>();
        assert!(reader.read().is_none());
    }

    #[test]
    fn read_consumes_value() {
        let (mut writer, mut reader) = slot::<TestData>();
        writer.write(TestData { a: 1, b: 2 });
        assert_eq!(reader.read(), Some(TestData { a: 1, b: 2 }));
        assert!(reader.read().is_none());
        assert!(reader.read().is_none());
    }

    #[test]
    fn new_write_makes_data_available_again() {
        let (mut writer, mut reader) = slot::<TestData>();
        writer.write(TestData { a: 1, b: 0 });
        assert_eq!(reader.read(), Some(TestData { a: 1, b: 0 }));
        assert!(reader.read().is_none());
        writer.write(TestData { a: 2, b: 0 });
        assert_eq!(reader.read(), Some(TestData { a: 2, b: 0 }));
        assert!(reader.read().is_none());
    }

    #[test]
    fn multiple_writes_before_read_conflates() {
        let (mut writer, mut reader) = slot::<TestData>();
        writer.write(TestData { a: 1, b: 0 });
        writer.write(TestData { a: 2, b: 0 });
        writer.write(TestData { a: 3, b: 0 });
        // Only the latest value survives.
        assert_eq!(reader.read(), Some(TestData { a: 3, b: 0 }));
        assert!(reader.read().is_none());
    }

    #[test]
    fn has_update_does_not_consume() {
        let (mut writer, mut reader) = slot::<TestData>();
        assert!(!reader.has_update());
        writer.write(TestData { a: 1, b: 0 });
        assert!(reader.has_update());
        assert!(reader.has_update());
        assert!(reader.has_update());
        reader.read();
        assert!(!reader.has_update());
    }

    #[test]
    fn writer_detects_disconnect() {
        let (writer, reader) = slot::<TestData>();
        assert!(!writer.is_disconnected());
        drop(reader);
        assert!(writer.is_disconnected());
    }

    #[test]
    fn reader_detects_disconnect() {
        let (writer, reader) = slot::<TestData>();
        assert!(!reader.is_disconnected());
        drop(writer);
        assert!(reader.is_disconnected());
    }

    #[test]
    fn can_read_after_writer_disconnect() {
        let (mut writer, mut reader) = slot::<TestData>();
        writer.write(TestData { a: 42, b: 0 });
        drop(writer);
        assert!(reader.is_disconnected());
        assert_eq!(reader.read(), Some(TestData { a: 42, b: 0 }));
    }

    #[test]
    fn cross_thread_write_read() {
        use std::thread;
        let (mut writer, mut reader) = slot::<TestData>();
        let handle = thread::spawn(move || {
            while reader.read().is_none() {
                core::hint::spin_loop();
            }
        });
        writer.write(TestData { a: 1, b: 2 });
        handle.join().unwrap();
    }

    #[test]
    fn cross_thread_conflation() {
        use std::thread;
        let (mut writer, mut reader) = slot::<u64>();
        let handle = thread::spawn(move || {
            let mut last = 0;
            let mut count = 0;
            loop {
                // Stop only once the writer is gone AND its final value was consumed.
                if reader.is_disconnected() && !reader.has_update() {
                    break;
                }
                if let Some(v) = reader.read() {
                    assert!(v >= last, "must be monotonic");
                    last = v;
                    count += 1;
                }
            }
            (last, count)
        });
        for i in 0..100_000u64 {
            writer.write(i);
        }
        drop(writer);
        let (last, count) = handle.join().unwrap();
        assert_eq!(last, 99_999);
        assert!(count <= 100_000);
        assert!(count >= 1);
    }

    #[test]
    fn data_integrity() {
        use std::thread;

        // `check` is always `!value`; any mismatch means a torn read leaked out.
        #[derive(Clone)]
        #[repr(C)]
        struct Checkable {
            value: u64,
            check: u64,
        }
        // SAFETY: `#[repr(C)]` struct of two `u64`s: no padding, no drop glue.
        unsafe impl Pod for Checkable {}

        let (mut writer, mut reader) = slot::<Checkable>();
        let handle = thread::spawn(move || {
            loop {
                if reader.is_disconnected() && !reader.has_update() {
                    break;
                }
                if let Some(data) = reader.read() {
                    assert_eq!(data.check, !data.value, "torn read!");
                }
            }
        });
        for i in 0..100_000u64 {
            writer.write(Checkable {
                value: i,
                check: !i,
            });
        }
        drop(writer);
        handle.join().unwrap();
    }

    #[test]
    fn large_struct_integrity() {
        use std::thread;

        // 256-byte payload so a copy spans many words; every word must match `seq`.
        #[derive(Clone)]
        #[repr(C)]
        struct Large {
            seq: u64,
            data: [u64; 31],
        }
        // SAFETY: `#[repr(C)]` struct of `u64`s only: no padding, no drop glue.
        unsafe impl Pod for Large {}

        let (mut writer, mut reader) = slot::<Large>();
        let handle = thread::spawn(move || {
            loop {
                if reader.is_disconnected() && !reader.has_update() {
                    break;
                }
                if let Some(d) = reader.read() {
                    for &val in &d.data {
                        assert_eq!(val, d.seq, "torn read");
                    }
                }
            }
        });
        for i in 0..10_000u64 {
            writer.write(Large {
                seq: i,
                data: [i; 31],
            });
        }
        drop(writer);
        handle.join().unwrap();
    }

    #[test]
    fn stress_writes_then_single_read() {
        let (mut writer, mut reader) = slot::<u64>();
        for i in 0..1_000_000 {
            writer.write(i);
        }
        assert_eq!(reader.read(), Some(999_999));
        assert!(reader.read().is_none());
    }

    #[test]
    fn ping_pong() {
        use std::thread;
        let (mut w1, mut r1) = slot::<u64>();
        let (mut w2, mut r2) = slot::<u64>();
        let handle = thread::spawn(move || {
            for i in 0..10_000u64 {
                while r1.read().is_none() {
                    core::hint::spin_loop();
                }
                w2.write(i);
            }
        });
        for i in 0..10_000u64 {
            w1.write(i);
            while r2.read().is_none() {
                core::hint::spin_loop();
            }
        }
        handle.join().unwrap();
    }

    #[test]
    fn read_versioned_returns_version() {
        let (mut writer, mut reader) = slot::<TestData>();
        writer.write(TestData { a: 1, b: 2 });
        let (val, ver1) = reader.read_versioned().unwrap();
        assert_eq!(val.a, 1);
        writer.write(TestData { a: 3, b: 4 });
        let (val, ver2) = reader.read_versioned().unwrap();
        assert_eq!(val.a, 3);
        assert_eq!(ver2.wrapping_sub(ver1), 1);
    }

    #[test]
    fn read_versioned_detects_conflation() {
        let (mut writer, mut reader) = slot::<TestData>();
        for i in 0..5 {
            writer.write(TestData { a: i, b: 0 });
        }
        let (val, ver1) = reader.read_versioned().unwrap();
        assert_eq!(val.a, 4);
        assert!(reader.read_versioned().is_none());
        writer.write(TestData { a: 99, b: 0 });
        let (val, ver2) = reader.read_versioned().unwrap();
        assert_eq!(val.a, 99);
        assert_eq!(ver2.wrapping_sub(ver1), 1);
    }

    #[test]
    fn read_versioned_none_before_first_write() {
        let (_writer, mut reader) = slot::<TestData>();
        assert!(reader.read_versioned().is_none());
    }
}