//! A spinlock-backed mutex built on [`lock_api`], with timed locking support
//! and a guard wrapper that records how long locks are held.
use cacheguard::CacheGuard;
use core::sync::atomic::{AtomicBool, Ordering};
use lock_api::{GuardSend, RawMutex, RawMutexTimed};
use std::{
    num::NonZeroUsize,
    ops::{Deref, DerefMut},
    sync::atomic::AtomicUsize,
    thread::available_parallelism,
    time::{Duration, Instant},
};

lazy_static::lazy_static! {
    static ref PARALLELISM: usize = available_parallelism()
        .unwrap_or(NonZeroUsize::new(1).unwrap())
        .into();
    pub static ref LOCK_TIME_CNT: AtomicUsize = AtomicUsize::new(0);
}

pub type Mutex<T> = lock_api::Mutex<RawMutexLock, T>;
pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawMutexLock, T>;
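
// A minimal usage sketch for the `Mutex` alias above, going through the
// standard `lock_api` surface. The module name and values are illustrative
// only.
#[cfg(test)]
mod mutex_usage_example {
    use super::Mutex;

    #[test]
    fn lock_and_mutate() {
        let m = Mutex::new(0u32);
        {
            // `lock` spins via `spin_cond` until the raw lock is acquired.
            let mut guard = m.lock();
            *guard += 1;
        } // Guard drops here, releasing the lock via `RawMutex::unlock`.
        assert_eq!(*m.lock(), 1);
    }
}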

/// A test-and-set spinlock whose flag is kept on its own cache line (via
/// `CacheGuard`) to avoid false sharing.
pub struct RawMutexLock {
    locked: CacheGuard<AtomicBool>,
}

impl RawMutexLock {
    /// Slow path: spins until the lock is acquired or the optional deadline
    /// passes. Kept out-of-line so the fast path in `lock` stays small.
    #[inline(never)]
    fn lock_no_inline(&self, timeout: Option<Instant>) -> bool {
        spin_cond(|| self.try_lock(), timeout)
    }
}

unsafe impl RawMutex for RawMutexLock {
    #[allow(clippy::declare_interior_mutable_const)]
    const INIT: RawMutexLock = RawMutexLock {
        locked: CacheGuard::new(AtomicBool::new(false)),
    };

    type GuardMarker = GuardSend;

    #[inline(always)]
    fn lock(&self) {
        if self.try_lock() {
            return;
        }
        self.lock_no_inline(None);
    }

    #[inline(always)]
    fn try_lock(&self) -> bool {
        self.locked
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
    }

    #[inline(always)]
    unsafe fn unlock(&self) {
        self.locked.store(false, Ordering::Release);
    }
}
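
// A low-level sketch exercising `RawMutexLock` directly, without the
// `lock_api` wrapper. `unlock` is `unsafe` because the caller must guarantee
// the lock is actually held. Module and test names are illustrative.
#[cfg(test)]
mod raw_lock_example {
    use super::RawMutexLock;
    use lock_api::RawMutex;

    #[test]
    fn raw_lock_cycle() {
        let raw = RawMutexLock::INIT;
        raw.lock();
        // A second attempt must fail while the lock is held.
        assert!(!raw.try_lock());
        // SAFETY: we acquired the lock above and still hold it.
        unsafe { raw.unlock() };
        assert!(raw.try_lock());
        // SAFETY: re-acquired by the `try_lock` call directly above.
        unsafe { raw.unlock() };
    }
}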

unsafe impl RawMutexTimed for RawMutexLock {
    type Duration = Duration;
    type Instant = Instant;

    #[inline(always)]
    fn try_lock_for(&self, timeout: Self::Duration) -> bool {
        self.try_lock_until(Instant::now() + timeout)
    }

    #[inline(always)]
    fn try_lock_until(&self, timeout: Self::Instant) -> bool {
        if self.try_lock() {
            return true;
        }

        self.lock_no_inline(Some(timeout))
    }
}
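
// A hedged sketch of the timed API: with `RawMutexTimed` implemented,
// `lock_api::Mutex` exposes `try_lock_for` and `try_lock_until`. The timeout
// values below are arbitrary and illustrative.
#[cfg(test)]
mod timed_lock_example {
    use super::Mutex;
    use std::time::Duration;

    #[test]
    fn try_lock_with_timeout() {
        let m = Mutex::new(());
        // Uncontended: acquisition should succeed well within the timeout.
        let guard = m.try_lock_for(Duration::from_millis(10));
        assert!(guard.is_some());

        // While the first guard is still held, a second timed attempt spins
        // until the deadline and then gives up.
        let second = m.try_lock_for(Duration::from_millis(1));
        assert!(second.is_none());
    }
}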

/// A [`MutexGuard`] wrapper that measures how long the lock is held and adds
/// the elapsed nanoseconds to [`LOCK_TIME_CNT`] when dropped.
#[allow(dead_code)]
pub struct TimedMutexGuard<'a, T> {
    guard: MutexGuard<'a, T>,
    instant: Instant,
}

impl<'a, T> From<MutexGuard<'a, T>> for TimedMutexGuard<'a, T> {
    fn from(value: MutexGuard<'a, T>) -> Self {
        Self {
            guard: value,
            instant: Instant::now(),
        }
    }
}

impl<'a, T> Deref for TimedMutexGuard<'a, T> {
    type Target = MutexGuard<'a, T>;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}

impl<'a, T> DerefMut for TimedMutexGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}

impl<'a, T> Drop for TimedMutexGuard<'a, T> {
    fn drop(&mut self) {
        LOCK_TIME_CNT.fetch_add(
            self.instant.elapsed().as_nanos() as usize,
            Ordering::Relaxed,
        );
    }
}
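
// A sketch of how `TimedMutexGuard` is meant to be used: wrap a freshly
// acquired guard, and the hold time is accumulated into `LOCK_TIME_CNT`
// (in nanoseconds) when the wrapper drops. All names come from this module;
// the assertion is a smoke check, not a timing guarantee.
#[cfg(test)]
mod timed_guard_example {
    use super::{Mutex, TimedMutexGuard, LOCK_TIME_CNT};
    use core::sync::atomic::Ordering;

    #[test]
    fn guard_records_hold_time() {
        let m = Mutex::new(0u32);
        let before = LOCK_TIME_CNT.load(Ordering::Relaxed);
        {
            let mut guard = TimedMutexGuard::from(m.lock());
            // Deref passes through to the inner `MutexGuard`.
            **guard += 1;
        } // Drop records the elapsed hold time.
        assert!(LOCK_TIME_CNT.load(Ordering::Relaxed) >= before);
    }
}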

/// Spins in a loop for the given number of iterations.
#[inline(always)]
fn spin_wait(count: u8) {
    for _ in 0..count {
        std::hint::spin_loop();
    }
}

/// Spins until the specified condition becomes true, or until the optional
/// deadline passes.
///
/// This function uses a combination of spinning, yielding, and zero-length
/// sleeping to reduce busy waiting and improve the efficiency of concurrent
/// systems.
///
/// On single-core systems (detected via `available_parallelism`), spinning
/// cannot help another thread make progress, so the function simply yields to
/// the scheduler between checks.
///
/// Otherwise, it starts with a short spinning phase, then loops over a
/// randomized spinning phase followed by a phase that sleeps for zero
/// nanoseconds (yielding the timeslice to the operating system) between
/// batches of condition checks.
///
/// A geometric backoff strategy doubles the number of condition checks per
/// batch after each unsuccessful round, starting at 8 and capped at 2^30. A
/// simple randomization strategy (via `fastrand`) varies the spin time to
/// desynchronize competing spinners.
///
/// The closure returns a boolean indicating whether the condition has been
/// met. `spin_cond` returns `true` once the condition holds, or `false` if
/// the `timeout` deadline is reached first.
#[inline(always)]
pub fn spin_cond<F: Fn() -> bool>(cond: F, timeout: Option<Instant>) -> bool {
    macro_rules! timeout_check {
        () => {
            if let Some(timeout) = timeout {
                if Instant::now() >= timeout {
                    return false;
                }
            }
        };
    }

    if *PARALLELISM == 1 {
        // On single-core environments, such as small Virtual Private Servers
        // (VPS), active spinning wastes CPU time without any benefit: with
        // only one thread of execution, no other thread can make progress
        // during the spin-wait period, so we simply yield between checks.
        loop {
            if cond() {
                return true;
            } else {
                timeout_check!();
                std::thread::yield_now();
            }
        }
    }

    // Rounds of pure spinning before entering the main loop.
    const NO_YIELD: usize = 1;
    // Rounds of randomized spinning per iteration of the main loop.
    const SPIN_YIELD: usize = 1;
    // Rounds of zero-length sleeping per iteration of the main loop.
    const ZERO_SLEEP: usize = 2;
    // Initial number of condition checks per round; doubled by the backoff.
    const SPINS: u32 = 8;
    let mut spins: u32 = SPINS;

    // Short spinning phase
    for _ in 0..NO_YIELD {
        for _ in 0..SPINS / 2 {
            if cond() {
                return true;
            }
            std::hint::spin_loop();
        }
    }

    // Longer spinning and yielding phase
    loop {
        for _ in 0..SPIN_YIELD {
            // Pause for a random number of spin hints (0..0x7f) to
            // desynchronize competing spinners.
            spin_wait(fastrand::u8(..0x7f));

            for _ in 0..spins {
                if cond() {
                    return true;
                }
            }
        }

        // Phase with zero-length sleeping and yielding
        for _ in 0..ZERO_SLEEP {
            std::thread::sleep(Duration::from_nanos(0));

            for _ in 0..spins {
                if cond() {
                    return true;
                }
            }
        }

        // Geometric backoff
        if spins < (1 << 30) {
            spins <<= 1;
        }

        timeout_check!();
        // Back off roughly 1 ms between rounds.
        // Disabled for now: 1 ms is a long pause; a shorter interval should be
        // tried and its impact measured before re-enabling.
        // std::thread::sleep(Duration::from_nanos(1 << 20));
    }
}
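
// A hedged usage sketch for `spin_cond` on its own: spin until another thread
// flips a flag, under a deadline. Thread timing and the deadline value are
// illustrative.
#[cfg(test)]
mod spin_cond_example {
    use super::spin_cond;
    use core::sync::atomic::{AtomicBool, Ordering};
    use std::time::{Duration, Instant};

    #[test]
    fn waits_for_flag() {
        static FLAG: AtomicBool = AtomicBool::new(false);

        let setter = std::thread::spawn(|| {
            std::thread::sleep(Duration::from_millis(5));
            FLAG.store(true, Ordering::Release);
        });

        // Spin with a generous deadline; returns `true` once the flag is set.
        let ok = spin_cond(
            || FLAG.load(Ordering::Acquire),
            Some(Instant::now() + Duration::from_secs(1)),
        );
        assert!(ok);
        setter.join().unwrap();
    }
}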