use crate::sync::Backoff;
use crossbeam_utils::CachePadded;
use std::cell::UnsafeCell;
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
type State = usize;

// Lock state layout: the two low bits are hints that readers or writers are
// spinning ("parked"), and the remaining high bits count shared holders in
// units of ONE_READER. A fully saturated counter region (ONE_WRITER) marks
// an exclusive holder.
const UNLOCKED: State = 0;
const READERS_PARKED: State = 0b0001;
const WRITERS_PARKED: State = 0b0010;
const ONE_READER: State = 0b0100;
const ONE_WRITER: State = !(READERS_PARKED | WRITERS_PARKED);
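// A compile-time sanity check of the layout above (an illustrative addition,
// not load-bearing): the writer mask must not overlap the hint bits, and a
// three-reader state keeps its counter intact under the mask.
const _: () = {
    assert!(ONE_WRITER & (READERS_PARKED | WRITERS_PARKED) == 0);
    assert!((3 * ONE_READER | WRITERS_PARKED) & ONE_WRITER == 3 * ONE_READER);
};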
/// A reader-writer spin lock. Each field is `CachePadded` so the two atomics
/// and the protected value sit on separate cache lines, avoiding false
/// sharing between contending cores.
pub struct SpinCell<T> {
    state: CachePadded<AtomicUsize>,
    ref_count: CachePadded<AtomicUsize>,
    val: CachePadded<UnsafeCell<T>>,
}
/// `true` when the counter region is saturated, i.e. a writer holds the lock.
#[inline]
const fn is_writer_locked(state: State) -> bool {
    (state & ONE_WRITER) == ONE_WRITER
}

/// Number of shared holders encoded in `state` (zero while a writer holds it).
#[inline]
const fn readers_count(state: State) -> usize {
    if is_writer_locked(state) {
        0
    } else {
        // Strip the two hint bits, then scale down by ONE_READER (1 << 2).
        (state & ONE_WRITER) >> 2
    }
}
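// A small worked example of the decoding helpers (illustrative addition):
// three readers plus a waiting writer decode to a count of three, while a
// writer-held state decodes to zero readers.
const _: () = {
    assert!(readers_count(3 * ONE_READER | WRITERS_PARKED) == 3);
    assert!(readers_count(ONE_WRITER) == 0);
    assert!(is_writer_locked(ONE_WRITER | READERS_PARKED));
};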
unsafe impl<T: Send> Send for SpinCell<T> {}
// Shared guards hand out `&T` on multiple threads at once, so `Sync` needs
// `T: Sync` in addition to `T: Send` (mirroring `std::sync::RwLock`).
unsafe impl<T: Send + Sync> Sync for SpinCell<T> {}
impl<T> SpinCell<T> {
pub fn new(val: T) -> Self {
Self {
state: CachePadded::new(AtomicUsize::new(UNLOCKED)),
ref_count: CachePadded::new(AtomicUsize::new(1)),
val: CachePadded::new(UnsafeCell::new(val)),
}
}
    /// Returns the auxiliary reference count (initialized to 1; the lock
    /// operations in this module never modify it).
    pub fn get_ref_count(&self) -> usize {
        self.ref_count.load(Ordering::Acquire)
    }

    /// Returns the number of shared holders currently encoded in the state.
    pub fn get_shared_locked(&self) -> usize {
        readers_count(self.state.load(Ordering::Acquire))
    }
#[inline]
pub fn is_locked_exclusive(&self) -> bool {
is_writer_locked(self.state.load(Ordering::Relaxed))
}
#[inline]
pub fn is_locked_shared(&self) -> bool {
let s = self.state.load(Ordering::Relaxed);
!is_writer_locked(s) && readers_count(s) > 0
}
pub fn is_locked(&self) -> bool {
self.state.load(Ordering::Relaxed) != UNLOCKED
}
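    /// Acquires the write lock, spinning (with backoff) until no readers or
    /// writer hold it; the lock is released when the returned guard drops.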
pub fn lock_exclusive(&self) -> ExclusiveGuard<'_, T> {
if self
.state
.compare_exchange(UNLOCKED, ONE_WRITER, Ordering::Acquire, Ordering::Relaxed)
.is_ok()
{
return ExclusiveGuard { lock: self };
}
self.lock_exclusive_slow()
}
    /// Releases the write lock.
    pub fn unlock_exclusive(&self) {
        // Fast path: nobody set a hint bit while we held the lock.
        if self
            .state
            .compare_exchange(ONE_WRITER, UNLOCKED, Ordering::Release, Ordering::Relaxed)
            .is_ok()
        {
            return;
        }
        // The strong CAS can only have failed because a hint bit was set:
        // while we hold the lock, the counter region cannot change under us.
        let state = self.state.load(Ordering::Relaxed);
        let parked = state & (READERS_PARKED | WRITERS_PARKED);
        if parked & READERS_PARKED != 0 {
            // Clear the reader hint; if writers are also waiting, keep their
            // hint set so spinning readers continue to yield to them.
            self.state.store(
                if parked & WRITERS_PARKED != 0 {
                    WRITERS_PARKED
                } else {
                    UNLOCKED
                },
                Ordering::Release,
            );
        } else if parked & WRITERS_PARKED != 0 {
            // Only writers are waiting; a plain unlock lets one of them win
            // the race in `lock_exclusive_slow`.
            self.state.store(UNLOCKED, Ordering::Release);
        }
    }
    fn lock_exclusive_slow(&self) -> ExclusiveGuard<'_, T> {
        let backoff = Backoff::new();
        let mut state = self.state.load(Ordering::Relaxed);
        loop {
            // While the counter region is empty (no readers, no writer), try
            // to claim it wholesale, preserving any hint bits.
            while state & ONE_WRITER == 0 {
                match self.state.compare_exchange_weak(
                    state,
                    state | ONE_WRITER,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return ExclusiveGuard { lock: self },
                    Err(e) => state = e,
                }
            }
            // The lock is held: advertise a waiting writer so readers yield
            // and the unlock paths hand the hint back.
            if state & WRITERS_PARKED == 0 {
                match self.state.compare_exchange_weak(
                    state,
                    state | WRITERS_PARKED,
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => state |= WRITERS_PARKED,
                    Err(e) => {
                        state = e;
                        continue;
                    }
                }
            }
            backoff.snooze();
            state = self.state.load(Ordering::Relaxed);
        }
    }
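    /// Acquires a read lock. The fast path bumps the reader count with a
    /// single CAS and falls back to the spinning slow path when a writer
    /// holds the lock or the CAS loses a race.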
    pub fn lock_shared(&self) -> SharedGuard<'_, T> {
        let state = self.state.load(Ordering::Relaxed);
        // `checked_add` rejects reader-count overflow, and the mask test
        // rejects states in which a writer holds the lock.
        if let Some(new_state) = state.checked_add(ONE_READER) {
            if new_state & ONE_WRITER != ONE_WRITER {
                if self
                    .state
                    .compare_exchange(state, new_state, Ordering::Acquire, Ordering::Relaxed)
                    .is_ok()
                {
                    return SharedGuard { lock: self };
                }
            }
        }
        self.lock_shared_slow()
    }
    fn lock_shared_slow(&self) -> SharedGuard<'_, T> {
        let backoff = Backoff::new();
        loop {
            let mut state = self.state.load(Ordering::Relaxed);
            // While no writer holds or waits for the lock, keep trying to
            // bump the reader count.
            while !is_writer_locked(state) && (state & WRITERS_PARKED == 0) {
                if let Some(new_state) = state.checked_add(ONE_READER) {
                    if self
                        .state
                        .compare_exchange_weak(state, new_state, Ordering::Acquire, Ordering::Relaxed)
                        .is_ok()
                    {
                        return SharedGuard { lock: self };
                    }
                }
                state = self.state.load(Ordering::Relaxed);
            }
            // Back off for a while before advertising ourselves as parked.
            if !backoff.is_completed() {
                backoff.snooze();
                continue;
            }
            // Set the reader hint so `unlock_exclusive` knows readers wait.
            if state & READERS_PARKED == 0 {
                if self
                    .state
                    .compare_exchange_weak(
                        state,
                        state | READERS_PARKED,
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    )
                    .is_err()
                {
                    continue;
                }
            }
            // "Park": spin until the writer releases, then start over.
            backoff.reset();
            while is_writer_locked(self.state.load(Ordering::Relaxed)) {
                backoff.spin();
            }
            backoff.reset();
        }
    }
    /// Releases one read lock.
    #[inline]
    pub fn unlock_shared(&self) {
        let prev_state = self.state.fetch_sub(ONE_READER, Ordering::Release);
        // If this was the last reader and only the writer hint remains, try
        // to clear the hint so the state can return to UNLOCKED.
        if prev_state == (ONE_READER | WRITERS_PARKED) {
            let _ = self.state.compare_exchange(
                WRITERS_PARKED,
                UNLOCKED,
                Ordering::Release,
                Ordering::Relaxed,
            );
        }
    }
    /// Releases every currently held read lock at once, leaving the parked
    /// hint bits in place.
    #[inline]
    pub fn unlock_all_shared(&self) {
        loop {
            let state = self.state.load(Ordering::Acquire);
            let readers = readers_count(state);
            if readers == 0 {
                return;
            }
            // Subtracting the full count clears the counter region and keeps
            // the two low hint bits untouched.
            let new_state = state - ONE_READER * readers;
            if self
                .state
                .compare_exchange(state, new_state, Ordering::Release, Ordering::Relaxed)
                .is_ok()
            {
                break;
            }
        }
    }
}
pub struct ExclusiveGuard<'a, T> {
lock: &'a SpinCell<T>,
}
impl<'a, T> Deref for ExclusiveGuard<'a, T> {
type Target = T;
#[inline]
fn deref(&self) -> &Self::Target {
unsafe { &*self.lock.val.get() }
}
}
impl<'a, T> DerefMut for ExclusiveGuard<'a, T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.lock.val.get() }
}
}
impl<'a, T> Drop for ExclusiveGuard<'a, T> {
#[inline]
fn drop(&mut self) {
self.lock.unlock_exclusive();
}
}
#[derive(Debug)]
pub struct SharedGuard<'a, T> {
lock: &'a SpinCell<T>,
}
impl<'a, T> Deref for SharedGuard<'a, T> {
type Target = T;
#[inline]
fn deref(&self) -> &Self::Target {
unsafe { &*self.lock.val.get() }
}
}
impl<'a, T> Drop for SharedGuard<'a, T> {
#[inline]
fn drop(&mut self) {
self.lock.unlock_shared();
}
}
impl<T: fmt::Debug> fmt::Debug for SpinCell<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let state = self.state.load(Ordering::Acquire);
        f.debug_struct("SpinCell")
            .field("state", &format!("{:b}", state))
            .field("exclusive_locked", &is_writer_locked(state))
            .field("readers_count", &readers_count(state))
            .field("readers_parked", &(state & READERS_PARKED != 0))
            .field("writers_parked", &(state & WRITERS_PARKED != 0))
            .finish()
    }
}
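
// A minimal usage sketch, added for illustration: exclusive guards serialize
// writers across threads, and shared guards admit concurrent readers. The
// scenario (an `Arc`-shared counter) is an assumption for the example, not
// part of this module's API.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn exclusive_guards_serialize_writers() {
        let cell = Arc::new(SpinCell::new(0usize));
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let cell = Arc::clone(&cell);
                thread::spawn(move || {
                    for _ in 0..1000 {
                        // Dropping the guard at the end of the statement
                        // releases the write lock.
                        *cell.lock_exclusive() += 1;
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(*cell.lock_shared(), 4000);
    }

    #[test]
    fn shared_guards_admit_concurrent_readers() {
        let cell = SpinCell::new(7u32);
        let a = cell.lock_shared();
        let b = cell.lock_shared();
        assert_eq!((*a, *b), (7, 7));
        assert_eq!(cell.get_shared_locked(), 2);
        drop(a);
        drop(b);
        assert!(!cell.is_locked());
    }
}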