#[cfg(feature = "alloc")]
use alloc::sync::Arc;
#[cfg(feature = "std")]
use std::sync::Mutex;
/// A read-copy-update (RCU) style cell.
///
/// Readers take cheap snapshots by cloning the inner [`Arc`]; writers swap in
/// a freshly allocated replacement value instead of mutating in place. A
/// snapshot obtained from `read` therefore never changes, even if the cell is
/// written afterwards.
///
/// The `Mutex` serializes writers (and the brief pointer clone in `read`);
/// it is held only for the duration of the swap, never while callers use the
/// returned snapshot.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct RcuCell<T> {
    // Current value. Writers replace the whole `Arc` rather than mutating
    // through it, so `Arc` clones previously handed out by `read` stay stable.
    inner: Mutex<Arc<T>>,
}
#[cfg(feature = "std")]
impl<T> RcuCell<T> {
    /// Creates a new cell holding `value`.
    #[must_use]
    pub fn new(value: T) -> Self {
        Self {
            inner: Mutex::new(Arc::new(value)),
        }
    }

    /// Acquires the inner lock, recovering the guard if a previous writer
    /// panicked.
    ///
    /// Poisoning is deliberately ignored: the slot only ever holds a fully
    /// constructed `Arc<T>`. A writer closure that panics does so *before*
    /// the `Arc` is replaced, so the stored value is never torn. This single
    /// helper keeps the recovery policy consistent between `read` and
    /// `write_with`.
    fn lock_ignoring_poison(&self) -> std::sync::MutexGuard<'_, Arc<T>> {
        match self.inner.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Returns a snapshot of the current value.
    ///
    /// The returned `Arc` is decoupled from later writes: subsequent
    /// `write_with`/`modify` calls install a new `Arc` rather than mutating
    /// the old one, so this snapshot stays valid and unchanged.
    #[must_use]
    pub fn read(&self) -> Arc<T> {
        Arc::clone(&self.lock_ignoring_poison())
    }

    /// Replaces the stored value with `f(&current)`.
    ///
    /// `f` receives a reference to the current value and returns the
    /// replacement. Readers holding snapshots taken before this call are
    /// unaffected.
    ///
    /// Note: no `T: Clone` bound is required — the closure produces a new
    /// owned `T`; nothing here clones the payload. (The previous bound was
    /// an over-constraint and has been relaxed, which is backward
    /// compatible.)
    pub fn write_with(&self, f: impl FnOnce(&T) -> T) {
        let mut guard = self.lock_ignoring_poison();
        // Build the replacement while holding the lock so concurrent writers
        // cannot interleave between the read of `guard` and the swap.
        let next = f(&guard);
        *guard = Arc::new(next);
    }

    /// Clones the current value, applies `f` to the clone, and stores the
    /// result as the new value.
    pub fn modify(&self, f: impl FnOnce(&mut T))
    where
        T: Clone,
    {
        self.write_with(|current| {
            let mut next = current.clone();
            f(&mut next);
            next
        });
    }
}
// Gated on the `std` feature as well as `test`: the items under test
// (`RcuCell`) and the `alloc` imports below only exist when `std` is
// enabled, so a plain `#[cfg(test)]` would break
// `cargo test --no-default-features`.
#[cfg(all(test, feature = "std"))]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use alloc::string::ToString;
    use alloc::vec::Vec;

    #[test]
    fn new_cell_returns_initial_value() {
        let cell = RcuCell::new(42u32);
        assert_eq!(*cell.read(), 42);
    }

    #[test]
    fn write_with_replaces_value() {
        let cell = RcuCell::new(1u32);
        cell.write_with(|c| c + 10);
        assert_eq!(*cell.read(), 11);
    }

    #[test]
    fn old_snapshot_is_immutable_after_write() {
        let cell = RcuCell::new(1u32);
        let snap = cell.read();
        cell.write_with(|c| c + 100);
        assert_eq!(*snap, 1, "RCU snapshot is decoupled from later writes");
        assert_eq!(*cell.read(), 101);
    }

    #[test]
    fn multiple_readers_share_arc() {
        let cell = RcuCell::new("hello".to_string());
        let a = cell.read();
        let b = cell.read();
        // Both snapshots alias the same allocation until the next write.
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn modify_uses_clone_and_mutator() {
        let cell = RcuCell::new(alloc::vec![1u32, 2, 3]);
        cell.modify(|v| v.push(4));
        let r = cell.read();
        assert_eq!(*r, alloc::vec![1, 2, 3, 4]);
    }

    #[test]
    fn concurrent_readers_writers_smoke() {
        use std::sync::Arc as StdArc;
        use std::thread;
        use std::time::Duration;
        let cell: StdArc<RcuCell<u64>> = StdArc::new(RcuCell::new(0));
        let stop: StdArc<std::sync::atomic::AtomicBool> =
            StdArc::new(std::sync::atomic::AtomicBool::new(false));
        // One writer publishes a strictly increasing counter.
        let writer = {
            let cell = StdArc::clone(&cell);
            let stop = StdArc::clone(&stop);
            thread::spawn(move || {
                let mut i = 0u64;
                while !stop.load(std::sync::atomic::Ordering::Relaxed) {
                    cell.write_with(|_| i);
                    i = i.wrapping_add(1);
                }
                i
            })
        };
        // Four readers check that observed values are monotonically
        // non-decreasing (writes are serialized and increasing).
        let mut readers = Vec::new();
        for _ in 0..4 {
            let cell = StdArc::clone(&cell);
            let stop = StdArc::clone(&stop);
            readers.push(thread::spawn(move || {
                let mut last = 0u64;
                while !stop.load(std::sync::atomic::Ordering::Relaxed) {
                    let v = *cell.read();
                    assert!(v >= last);
                    last = v;
                }
                last
            }));
        }
        thread::sleep(Duration::from_millis(10));
        stop.store(true, std::sync::atomic::Ordering::Relaxed);
        let writer_count = writer.join().unwrap();
        for r in readers {
            // The last value any reader saw can be at most the count of
            // writes performed (the final published value is count - 1).
            let last = r.join().unwrap();
            assert!(last <= writer_count);
        }
    }

    #[test]
    fn write_with_recovers_from_poisoned_mutex() {
        use std::sync::Arc as StdArc;
        let cell: StdArc<RcuCell<u32>> = StdArc::new(RcuCell::new(7));
        let cell_p = StdArc::clone(&cell);
        // Panic inside the writer closure while the guard is held, poisoning
        // the mutex; the cell must keep working afterwards.
        let _ = std::thread::spawn(move || {
            cell_p.write_with(|_| panic!("intentional"));
        })
        .join();
        // The panic happened before the swap, so the old value survives.
        assert_eq!(*cell.read(), 7);
        cell.write_with(|c| c + 1);
        assert_eq!(*cell.read(), 8);
    }
}