use cacheguard::CacheGuard;
use core::sync::atomic::{AtomicBool, Ordering};
use lock_api::{GuardSend, RawMutex, RawMutexTimed};
use std::{
num::NonZeroUsize,
ops::{Deref, DerefMut},
sync::atomic::AtomicUsize,
thread::available_parallelism,
time::{Duration, Instant},
};
// Process-wide lazily-initialized statics.
lazy_static::lazy_static! {
    // Number of hardware threads available; falls back to 1 if the query
    // fails. spin_cond uses this to decide whether spinning can make
    // progress at all (it cannot on a single hardware thread).
    static ref PARALLELISM: usize = available_parallelism()
        .unwrap_or(NonZeroUsize::new(1).unwrap())
        .into();
    // Cumulative time (nanoseconds) that TimedMutexGuard instances have held
    // their locks; each guard adds its elapsed hold time on drop.
    // NOTE(review): AtomicUsize::new is const, so this could be a plain
    // `static` (or std::sync::OnceLock) without lazy_static — confirm callers
    // before changing the wrapper type.
    pub static ref LOCK_TIME_CNT: AtomicUsize = AtomicUsize::new(0);
}
/// A mutex backed by the spinning [`RawMutexLock`] below.
pub type Mutex<T> = lock_api::Mutex<RawMutexLock, T>;
/// RAII guard for [`Mutex`]; releases the lock on drop.
pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawMutexLock, T>;
/// A raw test-and-set spinlock for use with `lock_api`.
pub struct RawMutexLock {
    // `true` while the lock is held. Wrapped in CacheGuard — presumably to
    // pad the flag onto its own cache line and avoid false sharing; confirm
    // against the cacheguard crate's docs.
    locked: CacheGuard<AtomicBool>,
}
impl RawMutexLock {
    /// Cold acquisition path, kept out of line so the inline fast path at
    /// call sites stays small. Retries `try_lock` under `spin_cond`'s
    /// backoff policy until the lock is acquired or `deadline` (if any)
    /// passes. Returns whether the lock was acquired.
    #[inline(never)]
    fn lock_no_inline(&self, deadline: Option<Instant>) -> bool {
        let acquire = || self.try_lock();
        spin_cond(acquire, deadline)
    }
}
unsafe impl RawMutex for RawMutexLock {
    /// Unlocked initial state, used by `lock_api` for const construction.
    #[allow(clippy::declare_interior_mutable_const)]
    const INIT: RawMutexLock = RawMutexLock {
        locked: CacheGuard::new(AtomicBool::new(false)),
    };

    /// Guards may be sent to (and released on) other threads.
    type GuardMarker = GuardSend;

    #[inline(always)]
    fn lock(&self) {
        // Uncontended fast path stays inline; contention falls through to
        // the out-of-line spin loop (no deadline, so it cannot fail).
        if !self.try_lock() {
            self.lock_no_inline(None);
        }
    }

    #[inline(always)]
    fn try_lock(&self) -> bool {
        // Single CAS false -> true. Acquire on success pairs with the
        // Release store in `unlock`; failure needs no ordering.
        matches!(
            self.locked
                .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed),
            Ok(_)
        )
    }

    #[inline(always)]
    unsafe fn unlock(&self) {
        // Release publishes the critical section's writes to the next
        // thread whose Acquire CAS observes the flag cleared.
        self.locked.store(false, Ordering::Release);
    }
}
unsafe impl RawMutexTimed for RawMutexLock {
    type Duration = Duration;
    type Instant = Instant;

    /// Tries to acquire the lock, giving up after `timeout` elapses.
    ///
    /// Fix: the original computed `Instant::now() + timeout`, which panics
    /// on overflow for very large durations (e.g. `Duration::MAX`). Use
    /// `checked_add` and treat an unrepresentable deadline as "no deadline".
    #[inline(always)]
    fn try_lock_for(&self, timeout: Self::Duration) -> bool {
        match Instant::now().checked_add(timeout) {
            Some(deadline) => self.try_lock_until(deadline),
            // Deadline beyond what `Instant` can represent: effectively
            // unbounded, so spin without a timeout (always returns true).
            None => self.lock_no_inline(None),
        }
    }

    /// Tries to acquire the lock, giving up once `timeout` has passed.
    #[inline(always)]
    fn try_lock_until(&self, timeout: Self::Instant) -> bool {
        // Inline fast path, mirroring `lock`.
        if self.try_lock() {
            return true;
        }
        self.lock_no_inline(Some(timeout))
    }
}
/// A [`MutexGuard`] wrapper that measures how long the lock is held and adds
/// the elapsed nanoseconds to the global `LOCK_TIME_CNT` counter on drop.
#[allow(dead_code)]
pub struct TimedMutexGuard<'a, T> {
    // The wrapped guard; the lock is released when this struct drops.
    guard: MutexGuard<'a, T>,
    // Timestamp taken when the guard was wrapped (start of the hold window).
    instant: Instant,
}
impl<'a, T> From<MutexGuard<'a, T>> for TimedMutexGuard<'a, T> {
    /// Wraps a guard, stamping the moment timing starts.
    fn from(value: MutexGuard<'a, T>) -> Self {
        let instant = Instant::now();
        let guard = value;
        Self { guard, instant }
    }
}
// Forward Deref/DerefMut to the wrapped guard so the timed wrapper is a
// drop-in replacement for a plain MutexGuard at use sites.
impl<'a, T> Deref for TimedMutexGuard<'a, T> {
    type Target = MutexGuard<'a, T>;
    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
impl<'a, T> DerefMut for TimedMutexGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
impl<'a, T> Drop for TimedMutexGuard<'a, T> {
    /// Accumulates this guard's hold time into the global counter; the inner
    /// guard then drops and releases the lock.
    fn drop(&mut self) {
        // NOTE(review): `as usize` truncates on 32-bit targets once the hold
        // exceeds ~4.29s of nanoseconds — confirm 64-bit-only deployment.
        let held_ns = self.instant.elapsed().as_nanos() as usize;
        LOCK_TIME_CNT.fetch_add(held_ns, Ordering::Relaxed);
    }
}
/// Issues `count` CPU pause hints back-to-back without yielding the thread.
#[inline(always)]
fn spin_wait(count: u8) {
    (0..count).for_each(|_| std::hint::spin_loop());
}
/// Spins until `cond` returns `true`, or until the `timeout` deadline (if
/// given) passes.
///
/// Returns `true` once `cond` succeeded, `false` on timeout. On a single
/// hardware thread spinning can never make progress (the thread that would
/// satisfy `cond` cannot run), so the function yields to the scheduler
/// between checks instead. On multi-core machines it escalates through
/// backoff phases, doubling the checks-per-burst each pass of the outer loop.
///
/// Fix: the internal macro was named `timout_check` (typo); renamed to
/// `timeout_check`. No behavioral change.
#[inline(always)]
pub fn spin_cond<F: Fn() -> bool>(cond: F, timeout: Option<Instant>) -> bool {
    macro_rules! timeout_check {
        () => {
            if let Some(deadline) = timeout {
                if Instant::now() >= deadline {
                    return false;
                }
            }
        };
    }

    // Single hardware thread: yield so the thread that will satisfy `cond`
    // can actually run.
    if *PARALLELISM == 1 {
        loop {
            if cond() {
                return true;
            } else {
                timeout_check!();
                std::thread::yield_now();
            }
        }
    }

    // Backoff tuning knobs: rounds of each phase per outer-loop pass.
    const NO_YIELD: usize = 1;
    const SPIN_YIELD: usize = 1;
    const ZERO_SLEEP: usize = 2;
    const SPINS: u32 = 8;

    let mut spins: u32 = SPINS;

    // Phase 1: brief pure spin, betting the condition clears almost
    // immediately (the common uncontended-handoff case).
    for _ in 0..NO_YIELD {
        for _ in 0..SPINS / 2 {
            if cond() {
                return true;
            }
            std::hint::spin_loop();
        }
    }
    loop {
        // Phase 2: random-length pause before each burst of checks, to
        // decorrelate threads hammering the same cache line.
        for _ in 0..SPIN_YIELD {
            spin_wait(fastrand::u8(..0x7f));
            for _ in 0..spins {
                if cond() {
                    return true;
                }
            }
        }
        // Phase 3: zero-length sleeps — these still enter the scheduler,
        // letting other runnable threads (e.g. the lock holder) run.
        for _ in 0..ZERO_SLEEP {
            std::thread::sleep(Duration::from_nanos(0));
            for _ in 0..spins {
                if cond() {
                    return true;
                }
            }
        }
        // Exponentially grow the burst size, capped so the shift can't
        // overflow u32.
        if spins < (1 << 30) {
            spins <<= 1;
        }
        // Deadline is only checked once per outer pass; bursts may overshoot
        // it slightly, which callers of a spin primitive tolerate.
        timeout_check!();
    }
}