#![cfg_attr(not(feature = "std"), no_std)]
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(all(not(feature = "std"), test))]
extern crate std;
mod shim;
#[cfg(test)]
mod tests;
use crate::shim::{
Arc, AtomicPtr, AtomicUsize, Box, Cell, Mutex, Ordering, Vec, VecDeque, heavy_barrier,
light_barrier,
};
use core::{fmt, marker::PhantomData, ops::Deref};
pub(crate) const AUTO_RECLAIM_THRESHOLD: usize = 16;
pub(crate) const INACTIVE_VERSION: usize = usize::MAX;
/// A single-writer, multi-reader cell: one owner mutates via `store`/`update`
/// while any number of `LocalReader`s take pinned snapshots of the current
/// value. Replaced values are retired into `garbage` and freed by `collect`
/// once no reader can still observe them. `RP` selects the "read-preferred"
/// barrier strategy (see `shim::light_barrier`/`heavy_barrier`).
pub struct SwmrCell<T: 'static, const RP: bool = false> {
    // State shared with all readers: current pointer, version counters,
    // and the reader-slot registry.
    shared: Arc<SharedState<T, RP>>,
    // Writer-local queue of retired values awaiting reclamation.
    garbage: GarbageSet<T>,
    // `store` triggers an automatic `collect` once the garbage backlog
    // exceeds this; `None` disables automatic reclamation.
    auto_reclaim_threshold: Option<usize>,
}
impl<T: 'static> SwmrCell<T, false> {
#[inline]
pub fn new(data: T) -> Self {
Self::builder().build(data)
}
#[inline]
pub fn builder() -> SwmrCellBuilder<T, false> {
SwmrCellBuilder::default()
}
}
impl<T: 'static, const RP: bool> SwmrCell<T, RP> {
    /// Creates a new reader registered against this cell.
    #[inline]
    pub fn local_reader(&self) -> LocalReader<T, RP> {
        LocalReader::new(self.shared.clone())
    }

    /// Returns a cloneable factory for minting readers (e.g. on other threads).
    #[inline]
    pub fn reader_factory(&self) -> SwmrReaderFactory<T, RP> {
        SwmrReaderFactory {
            shared: self.shared.clone(),
        }
    }

    /// Publishes `data` as the new current value and retires the old one.
    ///
    /// May trigger an automatic `collect` when the garbage backlog exceeds
    /// `auto_reclaim_threshold`.
    pub fn store(&mut self, data: T) {
        let new_ptr = Box::into_raw(Box::new(data));
        // Swap the new pointer in first, then bump the version counter.
        let old_ptr = self.shared.ptr.swap(new_ptr, Ordering::Release);
        // `fetch_add` returns the pre-increment value: the last version at
        // which `old_ptr` was current. The old value is retired under it.
        let old_version = self.shared.global_version.fetch_add(1, Ordering::AcqRel);
        if !old_ptr.is_null() {
            // SAFETY: `old_ptr` came from `Box::into_raw` (in `build` or a
            // previous `store`) and has just been swapped out, so the writer
            // is its unique owner. The box is retired (not dropped) until
            // `collect` proves no reader still pins it.
            unsafe {
                self.garbage.add(Box::from_raw(old_ptr), old_version);
            }
        }
        if let Some(threshold) = self.auto_reclaim_threshold
            && self.garbage.len() > threshold
        {
            self.collect();
        }
    }

    /// The most recently retired (previous) value, if not yet reclaimed.
    #[inline]
    pub fn previous(&self) -> Option<&T> {
        self.garbage.back()
    }

    /// Borrows the current value from the writer side; no pin is needed
    /// because only `&mut self` methods can retire it.
    #[inline]
    pub fn get(&self) -> &T {
        // SAFETY: `ptr` is non-null after construction, and the value it
        // points to is only freed via the garbage set, which requires
        // `&mut self` — so it outlives this `&self` borrow.
        unsafe { &*self.shared.ptr.load(Ordering::Acquire) }
    }

    /// Replaces the value with `f(current)`.
    #[inline]
    pub fn update<F>(&mut self, f: F)
    where
        F: FnOnce(&T) -> T,
    {
        let new_value = f(self.get());
        self.store(new_value);
    }

    /// Current global version (incremented once per `store`).
    #[inline]
    pub fn version(&self) -> usize {
        self.shared.global_version.load(Ordering::Acquire)
    }

    /// Number of retired values not yet reclaimed.
    #[inline]
    pub fn garbage_count(&self) -> usize {
        self.garbage.len()
    }

    /// Reclaims retired values that no active reader can still observe.
    ///
    /// Scans every registered reader slot for its published pin version,
    /// prunes slots whose `LocalReader` has been dropped, and frees garbage
    /// retired strictly before the minimum active version — additionally
    /// capped at `current_version - 2`. NOTE(review): the `-2` margin
    /// presumably tolerates the ptr-swap/version-bump window in `store`
    /// racing with a concurrent `pin`; confirm against the design notes.
    pub fn collect(&mut self) {
        let current_version = self.shared.global_version.load(Ordering::Acquire);
        let safety_limit = current_version.saturating_sub(2);
        let mut min_active = current_version;
        // Pairs with `light_barrier` in `LocalReader::pin` so a concurrent
        // pin is either observed below or sees our published threshold.
        heavy_barrier::<RP>();
        let mut shared_readers = self.shared.readers.lock();
        for arc_slot in shared_readers.iter() {
            let version = arc_slot.active_version.load(Ordering::Acquire);
            if version != INACTIVE_VERSION {
                min_active = min_active.min(version);
            }
        }
        // A strong count of 1 means only the registry's Arc remains, i.e.
        // the owning `LocalReader` is gone — drop the stale entry.
        shared_readers.retain(|arc_slot| Arc::strong_count(arc_slot) > 1);
        drop(shared_readers);
        let reclaim_threshold = min_active.min(safety_limit);
        // Publish the threshold so in-flight `pin`s can detect they raced.
        self.shared
            .min_active_version
            .store(reclaim_threshold, Ordering::Release);
        self.garbage.collect(reclaim_threshold, current_version);
    }
}
/// Cheaply cloneable handle for creating `LocalReader`s on a cell,
/// usable without access to the writer.
pub struct SwmrReaderFactory<T: 'static, const RP: bool = false> {
    shared: Arc<SharedState<T, RP>>,
}
impl<T: 'static, const RP: bool> SwmrReaderFactory<T, RP> {
#[inline]
pub fn local_reader(&self) -> LocalReader<T, RP> {
LocalReader::new(self.shared.clone())
}
}
impl<T: 'static, const RP: bool> Clone for SwmrReaderFactory<T, RP> {
#[inline]
fn clone(&self) -> Self {
Self {
shared: self.shared.clone(),
}
}
}
impl<T: 'static, const RP: bool> fmt::Debug for SwmrReaderFactory<T, RP> {
    // Only reports the barrier strategy; the value type may not be Debug.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("SwmrReaderFactory");
        dbg.field("read_preferred", &RP);
        dbg.finish()
    }
}
/// Builder for `SwmrCell`; configures the auto-reclaim threshold and,
/// via `read_preferred`, the barrier strategy `RP`.
pub struct SwmrCellBuilder<T, const RP: bool = false> {
    auto_reclaim_threshold: Option<usize>,
    // Ties the builder to `T` without storing a value.
    marker: PhantomData<T>,
}
impl<T: 'static, const RP: bool> SwmrCellBuilder<T, RP> {
    /// Sets how many retired values may accumulate before `store` triggers
    /// an automatic `collect`; `None` disables automatic reclamation.
    #[inline]
    pub fn auto_reclaim_threshold(self, threshold: Option<usize>) -> Self {
        Self {
            auto_reclaim_threshold: threshold,
            ..self
        }
    }

    /// Finalizes the builder, heap-allocating the initial value.
    pub fn build(self, data: T) -> SwmrCell<T, RP> {
        let initial = Box::into_raw(Box::new(data));
        let state = SharedState {
            global_version: AtomicUsize::new(0),
            min_active_version: AtomicUsize::new(0),
            ptr: AtomicPtr::new(initial),
            readers: Mutex::new(Vec::new()),
        };
        SwmrCell {
            shared: Arc::new(state),
            garbage: GarbageSet::new(),
            auto_reclaim_threshold: self.auto_reclaim_threshold,
        }
    }
}
impl<T: 'static, const RP: bool> Default for SwmrCellBuilder<T, RP> {
    /// Default builder: automatic reclamation enabled at
    /// `AUTO_RECLAIM_THRESHOLD` retired entries.
    fn default() -> Self {
        Self {
            marker: PhantomData,
            auto_reclaim_threshold: Some(AUTO_RECLAIM_THRESHOLD),
        }
    }
}
impl<T: 'static> SwmrCellBuilder<T, false> {
    /// Switches the builder to the read-preferred barrier strategy,
    /// carrying over any configured reclaim threshold.
    #[inline]
    pub fn read_preferred(self) -> SwmrCellBuilder<T, true> {
        let threshold = self.auto_reclaim_threshold;
        SwmrCellBuilder {
            auto_reclaim_threshold: threshold,
            marker: PhantomData,
        }
    }
}
/// FIFO of retired values awaiting reclamation, each tagged with the global
/// version at which it was replaced. Entries are pushed in version order,
/// so the front of the queue is always the oldest retiree.
struct GarbageSet<T> {
    queue: VecDeque<(usize, Box<T>)>,
}

impl<T> GarbageSet<T> {
    /// Creates an empty garbage set.
    fn new() -> Self {
        Self {
            queue: VecDeque::new(),
        }
    }

    /// Number of retired values not yet reclaimed.
    #[inline]
    fn len(&self) -> usize {
        self.queue.len()
    }

    /// The most recently retired value, if any.
    #[inline]
    fn back(&self) -> Option<&T> {
        self.queue.back().map(|(_, boxed)| boxed.as_ref())
    }

    /// Retires `node`, tagging it with the version at which it was replaced.
    #[inline]
    fn add(&mut self, node: Box<T>, current_version: usize) {
        self.queue.push_back((current_version, node));
    }

    /// Drops every entry retired strictly before `min_active_version`.
    ///
    /// Relies on the queue being version-ordered: stops at the first
    /// survivor. `_current_version` is currently unused; it is kept for
    /// call-site compatibility with `SwmrCell::collect`.
    #[inline]
    fn collect(&mut self, min_active_version: usize, _current_version: usize) {
        while self
            .queue
            .front()
            .is_some_and(|(version, _)| *version < min_active_version)
        {
            self.queue.pop_front();
        }
    }
}
/// Per-reader registry entry. Cache-line aligned (64 bytes) so readers
/// publishing their active versions do not falsely share a line.
#[derive(Debug)]
#[repr(align(64))]
pub(crate) struct ReaderSlot {
    // Version currently pinned by the owning reader, or `INACTIVE_VERSION`
    // when the reader holds no pin.
    pub(crate) active_version: AtomicUsize,
}
/// State shared between the writer and all readers; cache-line aligned.
#[repr(align(64))]
pub(crate) struct SharedState<T: 'static, const RP: bool = false> {
    // Monotonic counter, bumped once per `store`.
    pub(crate) global_version: AtomicUsize,
    // Reclaim threshold last published by `collect`; `pin` spins until its
    // announced version is not older than this.
    pub(crate) min_active_version: AtomicUsize,
    // Current value; never null after construction (set in `build`).
    pub(crate) ptr: AtomicPtr<T>,
    // Registry of reader slots; `collect` prunes entries whose reader is gone.
    pub(crate) readers: Mutex<Vec<Arc<ReaderSlot>>>,
}
impl<T: 'static, const RP: bool> Drop for SharedState<T, RP> {
    /// Frees the current value. Runs only when the last `Arc<SharedState>`
    /// drops, so no reader or writer can still dereference `ptr`.
    fn drop(&mut self) {
        let ptr = self.ptr.load(Ordering::Acquire);
        if !ptr.is_null() {
            // SAFETY: `ptr` originates from `Box::into_raw` (in `build` or
            // `store`); any previously swapped-out pointer was handed to the
            // garbage set instead, so this is the sole owner of `ptr`.
            unsafe {
                drop(Box::from_raw(ptr));
            }
        }
    }
}
/// A thread-local reading handle. `pin` publishes this reader's active
/// version so the writer's `collect` will not free the snapshot it sees.
/// Not `Sync` (uses `Cell`). NOTE(review): the type appears to be
/// auto-`Send` for any `T`, since `T` only occurs behind atomics inside
/// `Arc<SharedState<T>>` — confirm whether a `T: Send + Sync` bound (e.g.
/// via `PhantomData<T>`) is needed before readers cross threads.
pub struct LocalReader<T: 'static, const RP: bool = false> {
    // This reader's entry in the shared registry.
    slot: Arc<ReaderSlot>,
    shared: Arc<SharedState<T, RP>>,
    // Reentrant pin depth; only the outermost pin/unpin touches the slot.
    pin_count: Cell<usize>,
}
impl<T: 'static, const RP: bool> LocalReader<T, RP> {
    /// Registers a fresh, initially-inactive slot in the shared registry.
    fn new(shared: Arc<SharedState<T, RP>>) -> Self {
        let slot = Arc::new(ReaderSlot {
            active_version: AtomicUsize::new(INACTIVE_VERSION),
        });
        shared.readers.lock().push(Arc::clone(&slot));
        LocalReader {
            slot,
            shared,
            pin_count: Cell::new(0),
        }
    }

    /// Whether this reader currently holds at least one `PinGuard`.
    #[inline]
    pub fn is_pinned(&self) -> bool {
        self.pin_count.get() > 0
    }

    /// Current global version of the cell.
    #[inline]
    pub fn version(&self) -> usize {
        self.shared.global_version.load(Ordering::Acquire)
    }

    /// Pins the current value, returning a guard that dereferences to it.
    ///
    /// Reentrant: a nested pin only bumps the depth and reuses the version
    /// already published in the slot. The outermost pin announces the
    /// observed global version, then re-checks it against the writer's
    /// published reclaim threshold and retries if it raced with `collect`.
    /// NOTE(review): soundness also depends on the pairing of
    /// `light_barrier` here with `heavy_barrier` in `collect` — confirm
    /// against `shim`'s barrier contract.
    #[inline]
    pub fn pin(&self) -> PinGuard<'_, T, RP> {
        let pin_count = self.pin_count.get();
        if pin_count > 0 {
            // Already pinned: the slot keeps protecting us; just nest.
            self.pin_count.set(pin_count + 1);
            let ptr = self.shared.ptr.load(Ordering::Acquire);
            let version = self.slot.active_version.load(Ordering::Acquire);
            return PinGuard {
                local: self,
                ptr,
                version,
            };
        }
        loop {
            let current_version = self.shared.global_version.load(Ordering::Acquire);
            // Announce the version we intend to read under.
            self.slot
                .active_version
                .store(current_version, Ordering::Release);
            light_barrier::<RP>();
            // If the writer already reclaimed past the version we announced,
            // our snapshot could be stale; retry with a newer version.
            let min_active = self.shared.min_active_version.load(Ordering::Acquire);
            if current_version >= min_active {
                break;
            }
            core::hint::spin_loop();
        }
        self.pin_count.set(1);
        let ptr = self.shared.ptr.load(Ordering::Acquire);
        let version = self.slot.active_version.load(Ordering::Acquire);
        PinGuard {
            local: self,
            ptr,
            version,
        }
    }

    /// Factory handle on the same cell (cheap `Arc` clone).
    #[inline]
    pub fn reader_factory(&self) -> SwmrReaderFactory<T, RP> {
        SwmrReaderFactory {
            shared: self.shared.clone(),
        }
    }

    /// Consumes this reader, returning a factory on the same cell. The
    /// reader's registry slot is pruned lazily by the next `collect`.
    #[inline]
    pub fn into_swmr(self) -> SwmrReaderFactory<T, RP> {
        SwmrReaderFactory {
            shared: self.shared.clone(),
        }
    }
}
impl<T: 'static, const RP: bool> Clone for LocalReader<T, RP> {
#[inline]
fn clone(&self) -> Self {
Self::new(self.shared.clone())
}
}
/// RAII guard for a pinned snapshot; dereferences to the value that was
/// current at pin time. Dropping the outermost guard marks the reader
/// inactive again.
#[must_use]
pub struct PinGuard<'a, T: 'static, const RP: bool = false> {
    local: &'a LocalReader<T, RP>,
    // Snapshot pointer captured at pin time; kept alive by the reader's
    // published active version until this guard (and any clones) drop.
    ptr: *const T,
    version: usize,
}
impl<T: 'static, const RP: bool> PinGuard<'_, T, RP> {
    /// The global version that was current when this pin was taken.
    #[inline]
    pub fn version(&self) -> usize {
        self.version
    }
}
impl<'a, T, const RP: bool> Deref for PinGuard<'a, T, RP> {
    type Target = T;
    #[inline]
    fn deref(&self) -> &T {
        // SAFETY: `ptr` was the live value when the pin was taken, and the
        // reader's published active version prevents `collect` from freeing
        // it while this guard exists.
        unsafe { &*self.ptr }
    }
}
impl<'a, T, const RP: bool> Clone for PinGuard<'a, T, RP> {
    /// Cheap clone: bumps the local pin depth and shares the same snapshot.
    #[inline]
    fn clone(&self) -> Self {
        let pin_count = self.local.pin_count.get();
        // A live guard implies a positive pin depth; zero here means the
        // pin bookkeeping was corrupted.
        assert!(
            pin_count > 0,
            "BUG: Cloning a PinGuard in an unpinned state (pin_count = 0). \
            This indicates incorrect API usage or a library bug."
        );
        self.local.pin_count.set(pin_count + 1);
        PinGuard {
            local: self.local,
            ptr: self.ptr,
            version: self.version,
        }
    }
}
impl<'a, T, const RP: bool> Drop for PinGuard<'a, T, RP> {
    /// Decrements the pin depth; the outermost drop republishes
    /// `INACTIVE_VERSION` so `collect` no longer counts this reader.
    #[inline]
    fn drop(&mut self) {
        let pin_count = self.local.pin_count.get();
        assert!(
            pin_count > 0,
            "BUG: Dropping a PinGuard in an unpinned state (pin_count = 0). \
            This indicates incorrect API usage or a library bug."
        );
        if pin_count == 1 {
            // Last guard on this reader: mark the slot inactive before the
            // depth reaches zero.
            self.local
                .slot
                .active_version
                .store(INACTIVE_VERSION, Ordering::Release);
        }
        self.local.pin_count.set(pin_count - 1);
    }
}
impl<T: 'static, const RP: bool> AsRef<T> for PinGuard<'_, T, RP> {
    /// Borrows the pinned value, going through the `Deref` impl.
    #[inline]
    fn as_ref(&self) -> &T {
        &**self
    }
}
impl<T: Default + 'static, const RP: bool> Default for SwmrCell<T, RP> {
#[inline]
fn default() -> Self {
SwmrCell::from(T::default())
}
}
impl<T: 'static, const RP: bool> From<T> for SwmrCell<T, RP> {
#[inline]
fn from(value: T) -> Self {
SwmrCellBuilder::default().build(value)
}
}
impl<T: fmt::Debug + 'static, const RP: bool> fmt::Debug for SwmrCell<T, RP> {
    // Reports the live value plus reclamation bookkeeping.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("SwmrCell");
        dbg.field("value", self.get());
        dbg.field("version", &self.version());
        dbg.field("garbage_count", &self.garbage_count());
        dbg.field("reclaim_threshold", &self.auto_reclaim_threshold);
        dbg.field("is_read-preferred", &RP);
        dbg.finish()
    }
}
impl<T: 'static, const RP: bool> fmt::Debug for LocalReader<T, RP> {
    // Reports pin state only; the value type need not be Debug.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("LocalReader");
        dbg.field("is_pinned", &self.is_pinned());
        dbg.field("version", &self.version());
        dbg.field("is_read-preferred", &RP);
        dbg.finish()
    }
}
impl<T: fmt::Debug + 'static, const RP: bool> fmt::Debug for PinGuard<'_, T, RP> {
    // Reports the pinned snapshot and the version it was taken at.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = &**self;
        let mut dbg = f.debug_struct("PinGuard");
        dbg.field("value", &value);
        dbg.field("version", &self.version);
        dbg.field("is_read-preferred", &RP);
        dbg.finish()
    }
}