#![deny(unsafe_code)]
use std::fmt;
#[cfg(not(feature = "loom"))]
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release};
use std::thread::yield_now;
#[cfg(feature = "loom")]
use loom::sync::atomic::AtomicUsize;
use crate::opcode::Opcode;
use crate::sync_primitive::SyncPrimitive;
use crate::wait_queue::WaitQueue;
/// A combined async/sync reader-writer lock whose entire state lives in a
/// single packed atomic word interpreted via [`WaitQueue`]'s bit masks.
///
/// `Default` yields the unlocked state (`state == 0`).
#[derive(Default)]
pub struct Lock {
// Packed state word: `WaitQueue::DATA_MASK` bits = ownership count
// (all-ones = exclusive), `WaitQueue::ADDR_MASK` bits = wait-queue tail
// entry address, `WaitQueue::LOCKED_FLAG` = queue being processed /
// (alone) poisoned. See the `impl Lock` helpers for the exact protocol.
state: AtomicUsize,
}
/// Lock operations over the packed `AtomicUsize` state word.
///
/// State-word layout (as evidenced by the masks used below):
/// - the `WaitQueue::DATA_MASK` bits hold the ownership count: `0` = free,
///   `1..=MAX_SHARED_OWNERS` = shared owners, all data bits set
///   (`DATA_MASK` itself) = exclusively locked;
/// - the `WaitQueue::ADDR_MASK` bits hold the address of the tail
///   wait-queue entry (non-zero means waiters are queued);
/// - `WaitQueue::LOCKED_FLAG` marks the wait queue as being processed; a
///   state that is *exactly* `LOCKED_FLAG` (no data or address bits)
///   encodes a poisoned lock.
impl Lock {
/// Maximum number of simultaneous shared owners. The `DATA_MASK` value
/// itself is reserved to represent exclusive ownership, hence the `- 1`.
pub const MAX_SHARED_OWNERS: usize = WaitQueue::DATA_MASK - 1;
// Sentinel state for a poisoned lock: the `LOCKED_FLAG` bit alone, with
// no owner count and no queued waiters.
const POISONED_STATE: usize = WaitQueue::LOCKED_FLAG;
// Internal result codes exchanged with the wait-queue machinery
// (see `entry.set_result(..)` in `poison_lock_internal`).
const ACQUIRED: u8 = 0_u8;
const NOT_ACQUIRED: u8 = 1_u8;
const POISONED: u8 = 2_u8;
/// Returns `true` if the lock is neither owned (no data bits set) nor
/// poisoned, using the supplied memory ordering for the load.
#[inline]
pub fn is_free(&self, mo: Ordering) -> bool {
let state = self.state.load(mo);
state != Self::POISONED_STATE && (state & WaitQueue::DATA_MASK) == 0
}
/// Returns `true` if the lock is held exclusively (all data bits set).
#[inline]
pub fn is_locked(&self, mo: Ordering) -> bool {
(self.state.load(mo) & WaitQueue::DATA_MASK) == WaitQueue::DATA_MASK
}
/// Returns `true` if the lock is held in shared mode: the data bits hold
/// a non-zero count that is not the reserved exclusive pattern.
#[inline]
pub fn is_shared(&self, mo: Ordering) -> bool {
let share_state = self.state.load(mo) & WaitQueue::DATA_MASK;
share_state != 0 && share_state != WaitQueue::DATA_MASK
}
/// Returns `true` if the lock has been poisoned via [`Self::poison_lock`].
#[inline]
pub fn is_poisoned(&self, mo: Ordering) -> bool {
self.state.load(mo) == Self::POISONED_STATE
}
/// Acquires the lock exclusively, asynchronously waiting if necessary.
///
/// Returns `false` if the lock is poisoned.
#[inline]
pub async fn lock_async(&self) -> bool {
self.lock_async_with(|| ()).await
}
/// Acquires the lock exclusively, asynchronously waiting if necessary.
///
/// `begin_wait` is invoked once when the caller actually starts waiting;
/// if enqueueing does not happen, the closure is handed back via `Err`
/// and retried on the next loop iteration. Returns `false` on poison.
#[inline]
pub async fn lock_async_with<F: FnOnce()>(&self, mut begin_wait: F) -> bool {
loop {
let (result, state) = self.try_lock_internal();
if result == Self::ACQUIRED {
return true;
} else if result == Self::POISONED {
return false;
}
debug_assert_eq!(result, Self::NOT_ACQUIRED);
// `wait_resources_async` is provided by the `SyncPrimitive` trait
// (implemented below); it parks this task on the wait queue.
match self
.wait_resources_async(state, Opcode::Exclusive, begin_wait)
.await
{
Ok(result) => {
debug_assert!(result == Self::ACQUIRED || result == Self::POISONED);
return result == Self::ACQUIRED;
}
// Waiting did not start; reclaim the closure and retry.
Err(returned) => begin_wait = returned,
}
}
}
/// Acquires the lock exclusively, blocking the current thread if
/// necessary. Returns `false` if the lock is poisoned.
#[inline]
pub fn lock_sync(&self) -> bool {
self.lock_sync_with(|| ())
}
/// Blocking counterpart of [`Self::lock_async_with`]; `begin_wait` is
/// invoked once when the thread actually starts waiting.
#[inline]
pub fn lock_sync_with<F: FnOnce()>(&self, mut begin_wait: F) -> bool {
loop {
let (result, state) = self.try_lock_internal();
if result == Self::ACQUIRED {
return true;
} else if result == Self::POISONED {
return false;
}
debug_assert_eq!(result, Self::NOT_ACQUIRED);
match self.wait_resources_sync(state, Opcode::Exclusive, begin_wait) {
Ok(result) => {
debug_assert!(result == Self::ACQUIRED || result == Self::POISONED);
return result == Self::ACQUIRED;
}
// Waiting did not start; reclaim the closure and retry.
Err(returned) => begin_wait = returned,
}
}
}
/// Attempts to acquire the lock exclusively without waiting.
#[inline]
pub fn try_lock(&self) -> bool {
self.try_lock_internal().0 == Self::ACQUIRED
}
/// Acquires the lock in shared mode, asynchronously waiting if necessary.
///
/// Returns `false` if the lock is poisoned.
#[inline]
pub async fn share_async(&self) -> bool {
self.share_async_with(|| ()).await
}
/// Shared-mode counterpart of [`Self::lock_async_with`]: waits while the
/// lock is exclusively held, the wait queue is non-empty, or the shared
/// owner count is saturated. Returns `false` on poison.
#[inline]
pub async fn share_async_with<F: FnOnce()>(&self, mut begin_wait: F) -> bool {
loop {
let (result, state) = self.try_share_internal();
if result == Self::ACQUIRED {
return true;
} else if result == Self::POISONED {
return false;
}
debug_assert_eq!(result, Self::NOT_ACQUIRED);
match self
.wait_resources_async(state, Opcode::Shared, begin_wait)
.await
{
Ok(result) => {
debug_assert!(result == Self::ACQUIRED || result == Self::POISONED);
return result == Self::ACQUIRED;
}
// Waiting did not start; reclaim the closure and retry.
Err(returned) => begin_wait = returned,
}
}
}
/// Acquires the lock in shared mode, blocking the current thread if
/// necessary. Returns `false` if the lock is poisoned.
#[inline]
pub fn share_sync(&self) -> bool {
self.share_sync_with(|| ())
}
/// Blocking counterpart of [`Self::share_async_with`].
#[inline]
pub fn share_sync_with<F: FnOnce()>(&self, mut begin_wait: F) -> bool {
loop {
let (result, state) = self.try_share_internal();
if result == Self::ACQUIRED {
return true;
} else if result == Self::POISONED {
return false;
}
debug_assert_eq!(result, Self::NOT_ACQUIRED);
match self.wait_resources_sync(state, Opcode::Shared, begin_wait) {
Ok(result) => {
debug_assert!(result == Self::ACQUIRED || result == Self::POISONED);
return result == Self::ACQUIRED;
}
// Waiting did not start; reclaim the closure and retry.
Err(returned) => begin_wait = returned,
}
}
}
/// Attempts to acquire the lock in shared mode without waiting.
#[inline]
pub fn try_share(&self) -> bool {
self.try_share_internal().0 == Self::ACQUIRED
}
/// Releases exclusive ownership, returning `true` on success.
///
/// Fast path: no waiters queued, state is exactly the exclusive pattern.
/// Otherwise falls back to `release_loop` (wait-queue hand-off provided
/// by `SyncPrimitive`).
#[inline]
pub fn release_lock(&self) -> bool {
match self
.state
.compare_exchange(WaitQueue::DATA_MASK, 0, Release, Relaxed)
{
Ok(_) => true,
Err(state) => self.release_loop(state, Opcode::Exclusive),
}
}
/// Poisons an exclusively-held lock instead of releasing it, returning
/// `true` on success (`false` if the lock is not exclusively held or is
/// already poisoned). All queued waiters observe the poisoned result.
#[inline]
pub fn poison_lock(&self) -> bool {
match self.state.compare_exchange(
WaitQueue::DATA_MASK,
Self::POISONED_STATE,
Release,
Relaxed,
) {
Ok(_) => true,
// Waiters are queued or the queue is being processed: take the
// slow path that also notifies every queued entry.
Err(state) => self.poison_lock_internal(state),
}
}
/// Clears poison, restoring the free state. Returns `false` if the lock
/// was not poisoned.
#[inline]
pub fn clear_poison(&self) -> bool {
self.state
.compare_exchange(Self::POISONED_STATE, 0, Release, Relaxed)
.is_ok()
}
/// Releases one shared owner, returning `true` on success.
///
/// Fast path: this is the sole shared owner and no waiters are queued
/// (state is exactly `1`); otherwise `release_loop` decrements and hands
/// off to waiters.
#[inline]
pub fn release_share(&self) -> bool {
match self.state.compare_exchange(1, 0, Release, Relaxed) {
Ok(_) => true,
Err(state) => self.release_loop(state, Opcode::Shared),
}
}
/// Single-shot exclusive acquisition attempt.
///
/// Returns `(ACQUIRED, 0)` on success, otherwise `(NOT_ACQUIRED |
/// POISONED, observed_state)` so callers can pass the state to the
/// wait-queue machinery.
#[inline]
fn try_lock_internal(&self) -> (u8, usize) {
// Fast path: free lock with an empty queue (state == 0).
let Err(state) = self
.state
.compare_exchange(0, WaitQueue::DATA_MASK, Acquire, Acquire)
else {
return (Self::ACQUIRED, 0);
};
self.try_lock_internal_slow(state)
}
/// Slow path of [`Self::try_lock_internal`]: re-examines the freshly
/// observed state and retries the CAS until a definite outcome.
fn try_lock_internal_slow(&self, mut state: usize) -> (u8, usize) {
loop {
if state == Self::POISONED_STATE {
return (Self::POISONED, state);
} else if state & WaitQueue::ADDR_MASK != 0 || state & WaitQueue::DATA_MASK != 0 {
// Waiters are queued or the lock is owned: caller must wait.
return (Self::NOT_ACQUIRED, state);
}
// Defensive re-check: by the branches above the data bits are
// already known to be zero here, so the CAS below is attempted
// from a state with no owners and no queued waiters.
if state & WaitQueue::DATA_MASK == 0 {
match self.state.compare_exchange(
state,
state | WaitQueue::DATA_MASK,
Acquire,
Acquire,
) {
Ok(_) => return (Self::ACQUIRED, 0),
Err(new_state) => state = new_state,
}
}
}
}
/// Single-shot shared acquisition attempt; same result convention as
/// [`Self::try_lock_internal`].
#[inline]
fn try_share_internal(&self) -> (u8, usize) {
// Fast path: free lock with an empty queue becomes one shared owner.
let Err(state) = self.state.compare_exchange(0, 1, Acquire, Acquire) else {
return (Self::ACQUIRED, 0);
};
self.try_share_internal_slow(state)
}
/// Slow path of [`Self::try_share_internal`]: increments the shared
/// owner count unless waiters are queued, the lock is exclusive, or the
/// count is saturated.
fn try_share_internal_slow(&self, mut state: usize) -> (u8, usize) {
loop {
if state == Self::POISONED_STATE {
return (Self::POISONED, state);
} else if state & WaitQueue::ADDR_MASK != 0
|| state & WaitQueue::DATA_MASK >= Self::MAX_SHARED_OWNERS
{
// Queued waiters take priority, and the count must never reach
// the reserved exclusive pattern (`DATA_MASK`).
return (Self::NOT_ACQUIRED, state);
}
match self
.state
.compare_exchange(state, state + 1, Acquire, Acquire)
{
Ok(_) => return (Self::ACQUIRED, 0),
Err(new_state) => state = new_state,
}
}
}
/// Slow path of [`Self::poison_lock`]: poisons an exclusively-held lock
/// that has queued waiters, notifying every entry of the poison.
fn poison_lock_internal(&self, mut state: usize) -> bool {
loop {
// Fail if already poisoned or not exclusively held.
if state == Self::POISONED_STATE || state & WaitQueue::DATA_MASK != WaitQueue::DATA_MASK
{
return false;
}
// The wait queue is currently being processed by another thread:
// spin (yielding) until the flag clears before swapping the state.
if state & WaitQueue::LOCKED_FLAG == WaitQueue::LOCKED_FLAG {
yield_now();
state = self.state.load(Relaxed);
continue;
}
match self
.state
.compare_exchange(state, Self::POISONED_STATE, AcqRel, Relaxed)
{
Ok(prev_state) => {
// Wake every queued waiter with the poisoned result; the
// `false` return from the visitor keeps the entries intact
// rather than consuming them — TODO confirm against
// `WaitQueue::iter_forward`'s contract.
let entry_addr = prev_state & WaitQueue::ADDR_MASK;
if entry_addr != 0 {
WaitQueue::iter_forward(
WaitQueue::addr_to_ptr(entry_addr),
false,
|entry, _| {
entry.set_result(Self::POISONED);
false
},
);
}
return true;
}
Err(new_state) => state = new_state,
}
}
}
}
impl fmt::Debug for Lock {
    /// Formats a point-in-time snapshot of the lock, decoding the packed
    /// atomic word into its logical fields.
    ///
    /// The state is loaded with `Relaxed` ordering, so the snapshot may be
    /// stale relative to concurrent operations.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let state = self.state.load(Relaxed);
        let lock_share_state = state & WaitQueue::DATA_MASK;
        // All data bits set encodes exclusive ownership; any other non-zero
        // value is the shared-owner count.
        let locked = lock_share_state == WaitQueue::DATA_MASK;
        let share_count = if locked { 0 } else { lock_share_state };
        let poisoned = state == Self::POISONED_STATE;
        let wait_queue_being_processed = state & WaitQueue::LOCKED_FLAG == WaitQueue::LOCKED_FLAG;
        let wait_queue_tail_addr = state & WaitQueue::ADDR_MASK;
        // Bug fix: report this type's own name; the label previously read
        // "WaitQueue", a copy-paste error that misidentified the type in
        // debug output.
        f.debug_struct("Lock")
            .field("state", &state)
            .field("locked", &locked)
            .field("share_count", &share_count)
            .field("poisoned", &poisoned)
            .field("wait_queue_being_processed", &wait_queue_being_processed)
            .field("wait_queue_tail_addr", &wait_queue_tail_addr)
            .finish()
    }
}
impl SyncPrimitive for Lock {
#[inline]
fn state(&self) -> &AtomicUsize {
&self.state
}
#[inline]
fn max_shared_owners() -> usize {
Self::MAX_SHARED_OWNERS
}
}