use std::panic::{RefUnwindSafe, UnwindSafe};
use std::sync::atomic::Ordering;
use crate::loom_exports::cell::UnsafeCell;
use crate::loom_exports::sync::atomic::AtomicUsize;
/// Bit set in a slot's lock word by `try_write` when it tries to access that slot.
const WRITER_LOCK: usize = 1 << 0;
/// Bit set in a slot's lock word by `try_read` when it tries to access that slot.
const READER_LOCK: usize = 1 << 1;
/// A two-slot cell whose slots are each guarded by an atomic lock word.
///
/// Values are stored with `try_write` and moved out with `try_read`. Neither
/// operation blocks: each one gives up when the lock word of the slot it needs
/// is already set.
pub(crate) struct SharedCell<T> {
// Value storage; a slot is only accessed after its lock word was acquired.
slots: [UnsafeCell<Option<T>>; 2],
// Index (0 or 1) of the slot that `try_write` tries first and that `try_read`
// reads from.
slot_idx: AtomicUsize,
// One lock word per slot: 0 when free, otherwise a combination of
// `WRITER_LOCK` and `READER_LOCK`.
locks: [AtomicUsize; 2],
}
impl<T> SharedCell<T> {
    /// Creates an empty cell: both slots hold `None`, both lock words are 0 and
    /// slot 0 is the active slot.
    pub(crate) fn new() -> Self {
        Self {
            slots: [UnsafeCell::new(None), UnsafeCell::new(None)],
            slot_idx: AtomicUsize::new(0),
            locks: [AtomicUsize::new(0), AtomicUsize::new(0)],
        }
    }

    /// Attempts to store `value` without blocking, replacing any value already
    /// held by the slot that is written.
    ///
    /// The active slot is tried first. If its lock word already had the reader
    /// bit set, the other slot is tried instead and, on success, becomes the
    /// active slot.
    ///
    /// A replaced value is dropped only after the slot lock has been released,
    /// so a panicking destructor cannot leave the slot locked.
    ///
    /// # Errors
    ///
    /// Returns `Err(value)`, handing the value back, if the active slot is
    /// locked by a writer only, or if the fallback slot is locked as well.
    pub(crate) fn try_write(&self, value: T) -> Result<(), T> {
        let slot_idx = self.slot_idx.load(Ordering::Acquire);
        let lock = &self.locks[slot_idx];

        match lock.fetch_or(WRITER_LOCK, Ordering::Acquire) {
            0 => {
                // SAFETY: the lock word was 0 and now carries our writer bit.
                // Every other `try_write`/`try_read` call targeting this slot
                // sees a non-zero lock word and leaves the slot alone until we
                // store 0 below, so this access is exclusive.
                let previous = unsafe {
                    self.slots[slot_idx].with_mut(|v| (*v).replace(value))
                };
                lock.store(0, Ordering::Release);
                // Run the old value's destructor outside the critical section.
                drop(previous);

                return Ok(());
            }
            // Only the writer bit was set: another write to this slot is in
            // progress, so give the value back.
            WRITER_LOCK => return Err(value),
            // The reader bit was set. The writer bit added above is cleared
            // when the current lock holder stores 0; fall through to the other
            // slot.
            _ => {}
        }

        let slot_idx = 1 - slot_idx;
        let lock = &self.locks[slot_idx];

        if lock.fetch_or(WRITER_LOCK, Ordering::Acquire) == 0 {
            // SAFETY: same argument as above, for the other slot.
            let previous = unsafe {
                self.slots[slot_idx].with_mut(|v| (*v).replace(value))
            };
            // Publish the new active slot before releasing its lock.
            self.slot_idx.store(slot_idx, Ordering::Release);
            lock.store(0, Ordering::Release);
            // Run the old value's destructor outside the critical section.
            drop(previous);

            return Ok(());
        }

        Err(value)
    }

    /// Attempts to move the value out of the active slot without blocking.
    ///
    /// Returns `Some(value)` if the slot lock could be acquired and the slot
    /// held a value; the slot is left empty. Returns `None` if the slot was
    /// empty or if its lock word was already set by another call.
    pub(crate) fn try_read(&self) -> Option<T> {
        let slot_idx = self.slot_idx.load(Ordering::Acquire);
        let lock = &self.locks[slot_idx];

        if lock.fetch_or(READER_LOCK, Ordering::Acquire) == 0 {
            // SAFETY: the lock word was 0 and now carries our reader bit, so no
            // other call accesses this slot until we store 0 below.
            let value = unsafe { self.slots[slot_idx].with_mut(|v| (*v).take()) };
            lock.store(0, Ordering::Release);

            value
        } else {
            // The reader bit added above is cleared when the current lock
            // holder stores 0.
            None
        }
    }
}
impl<T> Default for SharedCell<T> {
fn default() -> Self {
Self::new()
}
}
// SAFETY: as far as this file shows, a slot is only accessed after the
// `fetch_or` on its lock word returned 0 and before that lock word is reset to
// 0, so accesses to a given slot never overlap. `T: Send` is required because a
// value stored through `try_write` on one thread may be moved out through
// `try_read`, or dropped by a later `try_write`, on another thread.
unsafe impl<T: Send> Sync for SharedCell<T> {}
// Declared unconditionally, for any `T`.
// NOTE(review): nothing in this file resets a lock word during unwinding, so
// these impls rely on no panic occurring while a slot is locked — confirm.
impl<T> UnwindSafe for SharedCell<T> {}
impl<T> RefUnwindSafe for SharedCell<T> {}
#[cfg(all(test, not(nexosim_loom)))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// A written value can be read exactly once; the cell is empty before and
    /// after.
    #[test]
    fn shared_cell_smoke_test() {
        let cell = SharedCell::<i32>::new();

        assert!(cell.try_read().is_none());
        assert!(cell.try_write(123).is_ok());
        assert_eq!(cell.try_read(), Some(123));
        assert!(cell.try_read().is_none());
    }

    /// A second write replaces the first one when no read happens in between.
    #[test]
    fn shared_cell_overwrite_test() {
        let cell = SharedCell::<i32>::new();

        assert!(cell.try_write(123).is_ok());
        assert!(cell.try_write(42).is_ok());
        assert_eq!(cell.try_read(), Some(42));
        assert!(cell.try_read().is_none());
    }

    /// A value written from another thread eventually becomes readable.
    #[test]
    fn shared_cell_multi_threaded_write() {
        let cell = Arc::new(SharedCell::new());

        let writer_cell = Arc::clone(&cell);
        thread::spawn(move || {
            assert!(writer_cell.try_write(123).is_ok());
        });

        // Spin until the writer's value shows up.
        let received = loop {
            if let Some(v) = cell.try_read() {
                break v;
            }
        };
        assert_eq!(received, 123);
    }
}
#[cfg(all(test, nexosim_loom))]
mod tests {
    use super::*;
    use crate::loom_exports::loom_builder;
    use std::sync::Arc;
    use loom::thread;

    /// Preemption bound passed to `loom_builder` by every test in this module.
    const DEFAULT_PREEMPTION_BOUND: usize = 4;

    /// One writer racing one reader: a successful read must yield the written
    /// value, and the write must always succeed.
    #[test]
    fn loom_shared_cell_write() {
        let builder = loom_builder(DEFAULT_PREEMPTION_BOUND);

        builder.check(move || {
            let cell = Arc::new(SharedCell::new());

            let writer_cell = Arc::clone(&cell);
            let writer = thread::spawn(move || {
                assert!(writer_cell.try_write(42).is_ok());
            });

            if let Some(v) = cell.try_read() {
                assert_eq!(v, 42);
            }

            writer.join().unwrap();
        });
    }

    /// Two successive writes racing one reader: the reader sees nothing, the
    /// first value (with the second still readable afterwards), or the second
    /// value.
    #[test]
    fn loom_shared_cell_overwrite() {
        let builder = loom_builder(DEFAULT_PREEMPTION_BOUND);

        builder.check(move || {
            let cell = Arc::new(SharedCell::new());

            let writer_cell = Arc::clone(&cell);
            let writer = thread::spawn(move || {
                assert!(writer_cell.try_write(42).is_ok());
                assert!(writer_cell.try_write(123).is_ok());
            });

            match cell.try_read() {
                Some(42) => {
                    // The first value was consumed, so the second one must
                    // still be available once the writer is done.
                    writer.join().unwrap();
                    assert_eq!(cell.try_read(), Some(123));
                }
                Some(v) => {
                    assert_eq!(v, 123);
                    writer.join().unwrap();
                }
                None => {
                    writer.join().unwrap();
                }
            }
        });
    }

    /// Two racing writers: either write may fail, the results are ignored.
    #[test]
    fn loom_shared_cell_concurrent_writers() {
        let builder = loom_builder(DEFAULT_PREEMPTION_BOUND);

        builder.check(move || {
            let cell = Arc::new(SharedCell::new());

            let first_cell = Arc::clone(&cell);
            let first = thread::spawn(move || {
                let _ = first_cell.try_write(42);
            });

            let second_cell = Arc::clone(&cell);
            let second = thread::spawn(move || {
                let _ = second_cell.try_write(123);
            });

            first.join().unwrap();
            second.join().unwrap();
        });
    }

    /// Two racing readers while the main thread writes: the write must
    /// succeed, the read results are ignored.
    #[test]
    fn loom_shared_cell_concurrent_readers() {
        let builder = loom_builder(DEFAULT_PREEMPTION_BOUND);

        builder.check(move || {
            let cell = Arc::new(SharedCell::new());

            let first_cell = Arc::clone(&cell);
            let first = thread::spawn(move || first_cell.try_read());

            let second_cell = Arc::clone(&cell);
            let second = thread::spawn(move || second_cell.try_read());

            assert!(cell.try_write(42).is_ok());

            let _v1 = first.join().unwrap();
            let _v2 = second.join().unwrap();
        });
    }
}