use std::{
cell::UnsafeCell,
panic::UnwindSafe,
sync::atomic::{
AtomicU64,
Ordering,
fence,
},
};
/// A sequence lock: a `Copy` value guarded by a version counter.
///
/// `sequence` is even when no write is in progress and odd while one is. The writer handle
/// bumps it before and after modifying `data`; the reader handle copies `data` and retries
/// when the counter was odd or changed across the copy. The constructor visible in this file,
/// `SeqLock::new`, returns a writer/reader handle pair sharing one `SeqLock` through an `Arc`.
#[derive(Debug)]
pub struct SeqLock<T: Copy> {
    // Write counter: even = idle, odd = write in progress.
    sequence: AtomicU64,
    // The protected value; accessed by the handles through `UnsafeCell::get`.
    data: UnsafeCell<T>,
}
// SAFETY: concurrent access to `data` is coordinated through `sequence` by the writer and
// reader handles. Readers receive copies of `T` on their own threads, which amounts to sending
// a `T` across threads, hence the `T: Send` bound.
unsafe impl<T: Send + Copy> Sync for SeqLock<T> {}
/// Write handle for a [`SeqLock`], as returned by `SeqLock::new`.
///
/// This type does not implement `Clone`; `SeqLock::new` creates exactly one writer handle
/// per lock.
#[derive(Debug)]
pub struct SeqLockWriter<T: Copy> {
    // Shared with every reader handle of the same lock.
    lock: std::sync::Arc<SeqLock<T>>,
}
impl<T: Copy> SeqLockWriter<T> {
    /// Runs `f` with exclusive, mutable access to the protected value.
    ///
    /// The sequence counter moves from an even value to the next odd value before `f` runs.
    /// It moves to the following even value after `f` returns *or panics*, so readers discard
    /// any copy that overlapped this call. If `f` panics, the counter is restored first and the
    /// panic then keeps propagating; the value keeps whatever `f` wrote before panicking.
    ///
    /// `SeqLockWriter` is `Sync` whenever `T: Send`, so several threads can reach `write`
    /// through a shared reference. Writers are therefore serialized: a call waits (yielding the
    /// thread) until no other write is in progress, then claims the counter with a
    /// compare-and-swap. At most one `&mut T` exists at a time.
    pub fn write<F>(&self, f: F)
    where
        F: FnOnce(&mut T) + UnwindSafe,
    {
        // Ends the write when dropped, including while unwinding out of `f`.
        struct EndWrite<'a> {
            sequence: &'a AtomicU64,
            next: u64,
        }

        impl Drop for EndWrite<'_> {
            fn drop(&mut self) {
                // Release: the writes made by `f` happen-before any Acquire load that observes
                // this even value (readers, and the next writer's claim below).
                self.sequence.store(self.next, Ordering::Release);
            }
        }

        let lock = &*self.lock;

        // Claim the lock by moving the counter from even to odd. A plain `fetch_add` would let
        // two concurrent writers turn the counter even again mid-write, and readers would then
        // accept torn data. Acquire on success makes the previous writer's data visible here.
        let start = loop {
            let current = lock.sequence.load(Ordering::Relaxed);
            if current.is_multiple_of(2)
                && lock
                    .sequence
                    .compare_exchange_weak(
                        current,
                        current.wrapping_add(1),
                        Ordering::Acquire,
                        Ordering::Relaxed,
                    )
                    .is_ok()
            {
                break current;
            }
            std::thread::yield_now();
        };

        // Release fence after the odd store: paired with the reader's Acquire fence, a reader
        // that observed any data write below will also observe the odd (or a later) counter.
        // This is the standard seqlock writer protocol; formally the pairing needs atomic data
        // accesses.
        fence(Ordering::Release);

        let _end = EndWrite {
            sequence: &lock.sequence,
            next: start.wrapping_add(2),
        };

        // SAFETY: the CAS above moved the counter from even to odd, which only one writer can do
        // at a time, so this is the only `&mut T` to the data until `_end` is dropped. Readers
        // never create references to the data; they copy it and discard the copy if the counter
        // was odd or changed.
        f(unsafe { &mut *lock.data.get() });
    }
}
/// Read handle for a [`SeqLock`], as returned by `SeqLock::new`.
///
/// Cloning produces another handle to the same lock; the `Arc` is shared, not the value.
#[derive(Clone, Debug)]
pub struct SeqLockReader<T: Copy> {
    // Shared with the writer handle and every other reader handle of the same lock.
    lock: std::sync::Arc<SeqLock<T>>,
}
impl<T: Copy> SeqLockReader<T> {
    /// Returns a consistent copy of the protected value.
    ///
    /// Yields the thread while a write is in progress (odd counter). Retries whenever the
    /// counter changed while the value was being copied. The returned value was therefore
    /// produced by a completed write, or is the initial value; it is never a mix of the two.
    pub fn read(&self) -> T {
        let lock = &*self.lock;
        loop {
            // Acquire: pairs with the writer's Release store of the even value, so the data
            // copied below is at least as new as that write.
            let start = lock.sequence.load(Ordering::Acquire);
            if !start.is_multiple_of(2) {
                // A writer is mid-update; let it make progress.
                std::thread::yield_now();
                continue;
            }

            // SAFETY: the pointer comes from the `UnsafeCell`, so it is valid and aligned for
            // `T`. The copy may race with a writer and see a torn mix of bytes, so it is taken
            // as `MaybeUninit<T>`, which has no validity invariant, instead of as `T`. It is
            // volatile so the compiler neither elides nor duplicates it. It is only interpreted
            // as `T` below, once the counter check proves no write overlapped it. (Under Rust's
            // memory model a racing non-atomic read is still formally a data race; this is the
            // conventional seqlock mitigation.)
            let snapshot = unsafe {
                std::ptr::read_volatile(lock.data.get().cast::<std::mem::MaybeUninit<T>>())
            };

            // Acquire fence: keeps the copy above from being reordered after the counter load
            // below, and pairs with the writer's Release fence.
            fence(Ordering::Acquire);
            let end = lock.sequence.load(Ordering::Relaxed);

            if start == end {
                // SAFETY: the counter was even before the copy and unchanged after it, so no
                // write overlapped the copy and `snapshot` holds a fully written `T`.
                return unsafe { snapshot.assume_init() };
            }
        }
    }
}
impl<T: Copy> SeqLock<T> {
    /// Builds a lock holding `data` and hands back its writer and reader handles.
    ///
    /// Both handles point at the same heap-allocated `SeqLock` through an `Arc`, so the
    /// returned pair starts with a strong count of 2. The sequence counter starts at 0, meaning
    /// no write is in progress. Further readers can be obtained by cloning the reader.
    ///
    /// # Safety
    ///
    /// NOTE(review): this function is `unsafe`, but nothing in this file states what the
    /// caller must guarantee. Confirm the intended contract and document it here.
    // `new` intentionally returns the handle pair rather than `Self`.
    #[allow(clippy::new_ret_no_self)]
    pub unsafe fn new(data: T) -> (SeqLockWriter<T>, SeqLockReader<T>) {
        let shared = std::sync::Arc::new(Self {
            data: UnsafeCell::new(data),
            sequence: AtomicU64::new(0),
        });
        let writer = SeqLockWriter {
            lock: std::sync::Arc::clone(&shared),
        };
        let reader = SeqLockReader { lock: shared };
        (writer, reader)
    }
}
// Test names use a `subject__behavior` double-underscore convention, which the
// `non_snake_case` lint would otherwise reject.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// A reader running alongside a writer must observe a non-decreasing series of values, all
    /// of which were actually written, and must see the last write once the writer is done.
    #[test]
    fn test_seqlock__provides_correct_values_in_order() {
        // Start below every written value so "in order" really means non-decreasing.
        let (writer, reader) = unsafe { SeqLock::new(0) };
        let iterations = 100;
        let final_reader = reader.clone();
        let writer = thread::spawn(move || {
            for i in 1..=iterations {
                writer.write(|data| *data = i);
            }
        });
        let reader = thread::spawn(move || {
            let mut seen = 0;
            for _ in 0..iterations {
                let value = reader.read();
                assert!(value >= seen, "value went backwards: {value} after {seen}");
                assert!(value <= iterations, "value {value} was never written");
                seen = value;
            }
        });
        writer.join().unwrap();
        reader.join().unwrap();
        assert_eq!(final_reader.read(), iterations);
    }

    /// Without concurrency, a read returns exactly what the preceding write stored.
    #[test]
    fn test_seqlock__single_threaded() {
        let (writer, reader) = unsafe { SeqLock::new(42) };
        writer.write(|data| {
            *data = 100;
        });
        let value = reader.read();
        assert_eq!(value, 100);
    }

    /// A writer closure that panics must still leave the lock readable; if the sequence counter
    /// stayed odd, the read below would spin forever.
    #[test]
    fn test_seqlock__panicking_write_leaves_lock_readable() {
        let (writer, reader) = unsafe { SeqLock::new(1) };
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            writer.write(|data| {
                *data = 2;
                panic!("writer failed");
            });
        }));
        assert!(result.is_err());
        assert_eq!(reader.read(), 2);
    }

    /// Readers must never observe a half-written multi-word value.
    #[test]
    fn test_seqlock__reads_are_never_torn() {
        let (writer, reader) = unsafe { SeqLock::new((0u64, 0u64)) };
        let last = 2_000u64;
        let writer = thread::spawn(move || {
            for i in 1..=last {
                writer.write(|pair| *pair = (i, i));
            }
        });
        loop {
            let (a, b) = reader.read();
            assert_eq!(a, b, "torn read: ({a}, {b})");
            if a == last {
                break;
            }
        }
        writer.join().unwrap();
    }
}