#![warn(missing_debug_implementations, missing_docs, unused_import_braces)]
#![cfg_attr(not(test), no_std)]
pub mod option;
use core::cell::UnsafeCell;
use core::hint::spin_loop;
#[cfg(not(debug_assertions))]
use core::hint::unreachable_unchecked;
use core::mem::swap;
use core::sync::atomic::{AtomicUsize, Ordering};
#[derive(Debug)]
pub struct AtomicSwap<T> {
state: AtomicUsize,
data: UnsafeCell<T>,
}
impl<T> AtomicSwap<T> {
pub const fn new(value: T) -> Self {
Self {
state: AtomicUsize::new(AtomicSwapState::Assigned(0).into_usize()),
data: UnsafeCell::new(value),
}
}
#[inline]
pub fn swap(&self, value: T) -> T {
self.swap_hint(value, spin_loop)
}
pub fn swap_hint(&self, mut value: T, mut spin_loop_hint: impl FnMut()) -> T {
let mut state = self.state.load(Ordering::Acquire);
loop {
let state_enum = AtomicSwapState::from_usize(state);
match &state_enum {
AtomicSwapState::Assigned(0) => {
match self.state.compare_exchange_weak(
state,
AtomicSwapState::Assigning.into_usize(),
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => {
unsafe {
swap(&mut value, &mut *self.data.get());
};
self.state
.compare_exchange(
AtomicSwapState::Assigning.into_usize(),
AtomicSwapState::Assigned(0).into_usize(),
Ordering::AcqRel,
Ordering::Acquire,
)
.expect("Assigning was changed improperly!");
return value;
}
Err(new_state) => state = new_state,
}
}
AtomicSwapState::Assigned(_) | AtomicSwapState::Assigning => {
spin_loop_hint();
state = self.state.load(Ordering::Acquire);
}
}
}
}
#[inline]
pub fn clone_inner(&self) -> T
where
T: Clone + Sync,
{
self.clone_inner_hint(spin_loop)
}
pub fn clone_inner_hint(&self, mut spin_loop_hint: impl FnMut()) -> T
where
T: Clone + Sync,
{
let mut state = self.state.load(Ordering::Acquire);
loop {
match AtomicSwapState::from_usize(state) {
AtomicSwapState::Assigned(readers) => match self.state.compare_exchange_weak(
state,
AtomicSwapState::Assigned(readers + 1).into_usize(),
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => unsafe {
let out = (&*(self.data.get())).clone();
let result = self.state.fetch_sub(1, Ordering::AcqRel);
debug_assert_ne!(result, 0);
return out;
},
Err(new_state) => state = new_state,
},
AtomicSwapState::Assigning => {
spin_loop_hint();
state = self.state.load(Ordering::Acquire);
}
}
}
}
pub fn get_mut(&mut self) -> &mut T {
match AtomicSwapState::from_usize(self.state.load(Ordering::Acquire)) {
AtomicSwapState::Assigned(0) => unsafe { &mut *(self.data.get() as *mut T) },
AtomicSwapState::Assigning | AtomicSwapState::Assigned(_) => {
#[cfg(debug_assertions)]
unreachable!("We should have exclusive access and therefore nothing should be assigning/reading");
#[cfg(not(debug_assertions))]
unsafe {
unreachable_unchecked()
}
}
}
}
}
impl<T> Drop for AtomicSwap<T> {
fn drop(&mut self) {
match AtomicSwapState::from_usize(self.state.load(Ordering::Acquire)) {
AtomicSwapState::Assigning => {
#[cfg(debug_assertions)]
unreachable!("Should not have dropped while assigning!");
#[cfg(not(debug_assertions))]
unsafe {
unreachable_unchecked()
}
}
AtomicSwapState::Assigned(0) => {}
AtomicSwapState::Assigned(_) => {
#[cfg(debug_assertions)]
unreachable!("Should not drop while has readers!");
#[cfg(not(debug_assertions))]
unsafe {
unreachable_unchecked()
}
}
}
}
}
impl<T> EnsureSend for AtomicSwap<T> where T: Send {}
unsafe impl<T> Sync for AtomicSwap<T> where T: Send {}
impl<T> Default for AtomicSwap<T>
where
T: Default,
{
#[inline]
fn default() -> Self {
Self::new(T::default())
}
}
trait EnsureSend: Send {}
#[derive(Copy, Clone, Debug)]
enum AtomicSwapState {
Assigning,
Assigned(usize),
}
impl AtomicSwapState {
pub const fn into_usize(self) -> usize {
match self {
AtomicSwapState::Assigning => 0,
AtomicSwapState::Assigned(readers) => 1 + readers,
}
}
pub const fn from_usize(size: usize) -> Self {
match size {
0 => Self::Assigning,
x => Self::Assigned(x - 1),
}
}
}
#[cfg(test)]
pub mod test {
use crate::AtomicSwap;
use rand::{thread_rng, Rng};
use std::string::String;
use std::sync::Arc;
use std::thread::spawn;
use std::vec::Vec;
#[derive(Default, Debug, Eq, PartialEq, Clone)]
pub struct ComplexType {
string: String,
number: usize,
vector: Vec<i64>,
}
impl ComplexType {
pub fn generate<R: Rng + ?Sized>(rng: &mut R) -> Self {
let string_length = rng.gen_range(10..100);
let mut string = String::with_capacity(string_length);
let mut temp = [0; 4];
for _ in 0..string_length {
string += rng.gen::<char>().encode_utf8(&mut temp);
}
let vec_length = rng.gen_range(100..1000);
let mut vector = Vec::with_capacity(vec_length);
for _ in 0..vec_length {
vector.push(rng.gen());
}
Self {
string,
number: rng.gen(),
vector,
}
}
pub fn generate_option<R: Rng + ?Sized>(rng: &mut R) -> Option<Self> {
if rng.gen() {
Some(Self::generate(rng))
} else {
None
}
}
}
#[test]
fn slam_test() {
const OPS_PER_THREAD: usize = 1 << 10;
const NUM_THREADS: usize = 1 << 4;
let mut rng = thread_rng();
for _round_num in 0..4 {
let swap = Arc::new(AtomicSwap::new(ComplexType::generate(&mut rng)));
let mut threads = Vec::with_capacity(NUM_THREADS);
for _thread_num in 0..NUM_THREADS {
let swap_clone = swap.clone();
threads.push(spawn(move || {
let mut rng = thread_rng();
for _op_num in 0..OPS_PER_THREAD {
let operation = rng.gen_range(0..=1);
match operation {
0 => {
swap_clone.swap(ComplexType::generate(&mut rng));
}
1 => {
swap_clone.clone_inner();
}
_ => unreachable!(),
}
}
}));
}
for thread in threads {
thread.join().expect("Could not join");
}
}
}
#[test]
fn value_test() {
const NUM_OPS: usize = 1 << 10;
let mut rng = thread_rng();
for _round_num in 0..4 {
let mut last_value = ComplexType::generate(&mut rng);
let mut swap = AtomicSwap::new(last_value.clone());
let mut last_op = -1;
for _op_num in 0..NUM_OPS {
let operation = rng.gen_range(0..=2);
match operation {
0 => {
let new_value = ComplexType::generate(&mut rng);
assert_eq!(
last_value,
swap.swap(new_value.clone()),
"Last op: {}",
last_op
);
last_value = new_value;
}
1 => {
assert_eq!(last_value, swap.clone_inner(), "Last op: {}", last_op);
}
2 => {
let swap_ref = swap.get_mut();
assert_eq!(swap_ref, &mut last_value, "Last op: {}", last_op);
*swap_ref = ComplexType::generate(&mut rng);
last_value = swap_ref.clone();
}
_ => unreachable!(),
}
last_op = operation;
}
}
}
}