seminix 0.1.60

seminix 内核标准库
Documentation
//! 自旋锁

use core::{
    cell::UnsafeCell,
    fmt,
    ops::{Deref, DerefMut},
    sync::atomic::AtomicBool,
};

use crate::{
    irq::irqflags::{local_irq_disable, local_irq_enable, local_irq_restore, local_irq_save},
    processor::cpu_relax,
    sched::preempt::{preempt_disable, preempt_enable},
};

/// A busy-waiting mutual-exclusion lock.
///
/// The lock word `inner` is `true` while the lock is held. The protected value
/// in `data` is reached through one of the guard types returned by the locking
/// methods, or directly via `get_mut`/`into_inner` when the caller already has
/// exclusive access to the lock.
pub struct Spinlock<T: ?Sized> {
    /// `true` while some guard owns the lock.
    inner: AtomicBool,
    /// The protected value.
    data: UnsafeCell<T>,
}

// SAFETY: moving the lock to another thread moves the protected `T` with it,
// so requiring `T: Send` is sufficient.
unsafe impl<T> Send for Spinlock<T> where T: ?Sized + Send {}

// SAFETY: a guard is only created after the lock word is switched from
// `false` to `true`, so at most one context can access `T` at a time. Sharing
// `&Spinlock<T>` therefore only ever hands `T` from one thread to another,
// which `T: Send` permits (the same bound `std::sync::Mutex` uses).
unsafe impl<T> Sync for Spinlock<T> where T: ?Sized + Send {}

/// RAII guard returned by [`Spinlock::lock`] and [`Spinlock::try_lock`].
///
/// The lock is held and preemption is disabled for as long as the guard
/// lives. Dropping it clears the lock word and calls `preempt_enable()`. The
/// protected value is reached through `Deref`/`DerefMut`.
#[must_use = "if unused the Spinlock will immediately unlock"]
pub struct SpinlockGuard<'a, T: ?Sized + 'a> {
    lock: &'a Spinlock<T>,
}

// Dropping the guard calls `preempt_enable()` to balance the
// `preempt_disable()` done when it was created. It must therefore be dropped
// in the same execution context that created it, so it is not `Send`.
impl<T: ?Sized> !Send for SpinlockGuard<'_, T> {}
// SAFETY: through `&SpinlockGuard` only `&T` is reachable (via `Deref`);
// mutation requires `&mut SpinlockGuard`. Sharing the guard is therefore
// sound exactly when sharing `&T` is, i.e. when `T: Sync`.
unsafe impl<T: ?Sized + Sync> Sync for SpinlockGuard<'_, T> {}

impl<T> Spinlock<T> {
    /// 构造锁
    pub const fn new(t: T) -> Spinlock<T> {
        Self { inner: AtomicBool::new(false), data: UnsafeCell::new(t) }
    }
}

impl<T: ?Sized> Spinlock<T> {
    /// Makes one attempt to switch the lock word from `false` to `true` using
    /// `compare_exchange_weak`.
    ///
    /// The attempt may fail spuriously even when the lock is free, so it is
    /// only suitable inside a retry loop. `Acquire` on success keeps the
    /// critical section from being reordered before the acquisition.
    #[inline]
    fn __try_lock_weak(&self) -> Result<bool, bool> {
        self.inner.compare_exchange_weak(
            false,
            true,
            core::sync::atomic::Ordering::Acquire,
            core::sync::atomic::Ordering::Relaxed,
        )
    }

    /// Makes one attempt to switch the lock word from `false` to `true` that
    /// fails only if the lock is actually held (strong compare-exchange).
    #[inline]
    fn __try_lock_strong(&self) -> bool {
        self.inner
            .compare_exchange(
                false,
                true,
                core::sync::atomic::Ordering::Acquire,
                core::sync::atomic::Ordering::Relaxed,
            )
            .is_ok()
    }

    /// Weak acquisition attempt wrapped in a guard.
    ///
    /// Does not touch the preemption count; the caller must already have
    /// disabled preemption.
    #[inline]
    fn try_lock_weak(&self) -> Option<SpinlockGuard<'_, T>> {
        if self.__try_lock_weak().is_ok() { Some(SpinlockGuard { lock: self }) } else { None }
    }

    /// Acquires the lock, spinning until it is available.
    ///
    /// Preemption is disabled before spinning and stays disabled until the
    /// returned guard is dropped. Interrupts are not disabled; use
    /// `lock_irq` / `lock_irq_save` for that.
    pub fn lock(&self) -> SpinlockGuard<'_, T> {
        preempt_disable();
        loop {
            if let Some(guard) = self.try_lock_weak() {
                break guard;
            }
            // Wait with plain loads until the lock looks free, then retry the
            // compare-exchange (test-and-test-and-set).
            while self.is_locked() {
                cpu_relax();
            }
        }
    }

    /// Attempts to acquire the lock without spinning.
    ///
    /// Returns `None` only if the lock is currently held. In that case the
    /// preemption count is restored before returning. A strong compare-exchange
    /// is used because `compare_exchange_weak` may fail spuriously on some
    /// platforms, which would wrongly report a free lock as busy.
    pub fn try_lock(&self) -> Option<SpinlockGuard<'_, T>> {
        preempt_disable();
        if self.__try_lock_strong() {
            Some(SpinlockGuard { lock: self })
        } else {
            preempt_enable();
            None
        }
    }

    /// Releases the lock held by `guard` (equivalent to dropping it).
    #[inline]
    pub fn unlock(guard: SpinlockGuard<'_, T>) {
        drop(guard);
    }

    /// Returns whether the lock is currently held.
    ///
    /// This is a `Relaxed` snapshot. It may be stale by the time the caller
    /// acts on it and does not synchronize with the holder.
    #[inline]
    pub fn is_locked(&self) -> bool {
        self.inner.load(core::sync::atomic::Ordering::Relaxed)
    }

    /// Consumes the lock and returns the protected value.
    #[inline]
    pub fn into_inner(self) -> T
    where
        T: Sized,
    {
        self.data.into_inner()
    }

    /// Returns a mutable reference to the protected value.
    ///
    /// No locking is needed: `&mut self` already proves exclusive access.
    #[inline]
    pub fn get_mut(&mut self) -> &mut T {
        self.data.get_mut()
    }
}

impl<T> From<T> for Spinlock<T> {
    fn from(value: T) -> Self {
        Spinlock::new(value)
    }
}

impl<T: Default> Default for Spinlock<T> {
    fn default() -> Self {
        Spinlock::new(Default::default())
    }
}

impl<T: ?Sized + fmt::Debug> fmt::Debug for Spinlock<T> {
    /// Formats as `Spinlock { data: <value>, .. }`, or as
    /// `Spinlock { data: <locked>, .. }` when the lock is currently held.
    ///
    /// Uses `try_lock`, so formatting never spins, even when the lock is held
    /// by the formatting context itself.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut d = f.debug_struct("Spinlock");
        match self.try_lock() {
            Some(guard) => {
                d.field("data", &&*guard);
            }
            None => {
                // Mirror `std::sync::Mutex`: show that the value exists but
                // could not be read, instead of silently omitting the field.
                d.field("data", &format_args!("<locked>"));
            }
        }
        d.finish_non_exhaustive()
    }
}

impl<T: ?Sized> Deref for SpinlockGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value for as long as the guard is borrowed.
    fn deref(&self) -> &T {
        let cell = &self.lock.data;
        // SAFETY: a guard only exists while the lock word is `true` (it is
        // created after a successful `false -> true` exchange and cleared on
        // drop), so no other guard can alias the value mutably.
        unsafe { &*cell.get() }
    }
}

impl<T: ?Sized> DerefMut for SpinlockGuard<'_, T> {
    /// Mutably borrows the protected value for as long as the guard is
    /// mutably borrowed.
    fn deref_mut(&mut self) -> &mut T {
        let ptr = self.lock.data.get();
        // SAFETY: the lock is held for the guard's lifetime and `&mut self`
        // rules out any other borrow obtained through this guard.
        unsafe { &mut *ptr }
    }
}

impl<T: ?Sized> Drop for SpinlockGuard<'_, T> {
    /// Releases the lock, then re-enables preemption.
    fn drop(&mut self) {
        use core::sync::atomic::Ordering;
        let word = &self.lock.inner;
        // `Release` publishes the critical section's writes to the next owner.
        word.store(false, Ordering::Release);
        // Preemption is restored only after the lock word has been cleared.
        preempt_enable();
    }
}

impl<T: ?Sized + fmt::Debug> fmt::Debug for SpinlockGuard<'_, T> {
    /// Formats the protected value directly, as if the guard were transparent.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Debug>::fmt(value, f)
    }
}

impl<T: ?Sized + fmt::Display> fmt::Display for SpinlockGuard<'_, T> {
    /// Displays the protected value, honouring the caller's format flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Display>::fmt(value, f)
    }
}

/// RAII guard returned by [`Spinlock::lock_irq`] and [`Spinlock::try_lock_irq`].
///
/// While it lives, the lock is held with local interrupts and preemption
/// disabled. Dropping it clears the lock word, then calls
/// `local_irq_enable()` and `preempt_enable()`.
#[must_use = "if unused the Spinlock will immediately unlock and re-enable interrupts"]
pub struct SpinlockIrqGuard<'a, T: ?Sized + 'a> {
    lock: &'a Spinlock<T>,
}

// Dropping the guard re-enables interrupts and preemption for the context
// that disabled them, so it must be dropped in that same context; hence the
// guard is not `Send`.
impl<T: ?Sized> !Send for SpinlockIrqGuard<'_, T> {}
// SAFETY: through `&SpinlockIrqGuard` only `&T` is reachable (via `Deref`),
// so sharing the guard is sound exactly when `T: Sync`.
unsafe impl<T: ?Sized + Sync> Sync for SpinlockIrqGuard<'_, T> {}

impl<T: ?Sized> Spinlock<T> {
    /// Weak acquisition attempt wrapped in an IRQ guard.
    ///
    /// Touches neither the interrupt state nor the preemption count; the
    /// caller must already have disabled both.
    #[inline]
    fn try_lock_irq_weak(&self) -> Option<SpinlockIrqGuard<'_, T>> {
        if self.__try_lock_weak().is_ok() { Some(SpinlockIrqGuard { lock: self }) } else { None }
    }

    /// Disables local interrupts and preemption, then spins until the lock
    /// is acquired.
    ///
    /// Dropping the returned guard re-enables interrupts unconditionally, so
    /// the previous interrupt state is not preserved. Use `lock_irq_save` when
    /// it must be.
    pub fn lock_irq(&self) -> SpinlockIrqGuard<'_, T> {
        local_irq_disable();
        preempt_disable();
        loop {
            if let Some(guard) = self.try_lock_irq_weak() {
                break guard;
            }
            // Spin on plain loads until the lock looks free, then retry.
            while self.is_locked() {
                cpu_relax();
            }
        }
    }

    /// Like `lock_irq`, but gives up immediately if the lock is held.
    ///
    /// On failure, interrupts are re-enabled and preemption is restored before
    /// returning `None`. A strong compare-exchange is used so that a free lock
    /// is never reported as busy because of a spurious failure.
    pub fn try_lock_irq(&self) -> Option<SpinlockIrqGuard<'_, T>> {
        local_irq_disable();
        preempt_disable();
        let acquired = self
            .inner
            .compare_exchange(
                false,
                true,
                core::sync::atomic::Ordering::Acquire,
                core::sync::atomic::Ordering::Relaxed,
            )
            .is_ok();
        if acquired {
            Some(SpinlockIrqGuard { lock: self })
        } else {
            local_irq_enable();
            preempt_enable();
            None
        }
    }

    /// Releases the lock held by `guard` and re-enables interrupts
    /// (equivalent to dropping it).
    #[inline]
    pub fn unlock_irq(guard: SpinlockIrqGuard<'_, T>) {
        drop(guard);
    }
}

impl<T: ?Sized> Deref for SpinlockIrqGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value for as long as the guard is borrowed.
    fn deref(&self) -> &T {
        let cell = &self.lock.data;
        // SAFETY: the guard exists only while the lock word is `true`, so no
        // other guard can alias the value mutably.
        unsafe { &*cell.get() }
    }
}

impl<T: ?Sized> DerefMut for SpinlockIrqGuard<'_, T> {
    /// Mutably borrows the protected value for as long as the guard is
    /// mutably borrowed.
    fn deref_mut(&mut self) -> &mut T {
        let ptr = self.lock.data.get();
        // SAFETY: the lock is held for the guard's lifetime and `&mut self`
        // rules out any other borrow obtained through this guard.
        unsafe { &mut *ptr }
    }
}

impl<T: ?Sized> Drop for SpinlockIrqGuard<'_, T> {
    /// Releases the lock, re-enables local interrupts, then re-enables
    /// preemption, in that order.
    fn drop(&mut self) {
        use core::sync::atomic::Ordering;
        let word = &self.lock.inner;
        // `Release` publishes the critical section's writes to the next owner.
        word.store(false, Ordering::Release);
        local_irq_enable();
        preempt_enable();
    }
}

impl<T: ?Sized + fmt::Debug> fmt::Debug for SpinlockIrqGuard<'_, T> {
    /// Formats the protected value directly, as if the guard were transparent.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Debug>::fmt(value, f)
    }
}

impl<T: ?Sized + fmt::Display> fmt::Display for SpinlockIrqGuard<'_, T> {
    /// Displays the protected value, honouring the caller's format flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Display>::fmt(value, f)
    }
}

/// RAII guard returned by [`Spinlock::lock_irq_save`] and
/// [`Spinlock::try_lock_irq_save`].
///
/// Holds the lock together with the interrupt state (`flags`) returned by
/// `local_irq_save()` before the lock was taken. Dropping the guard clears the
/// lock word, passes `flags` to `local_irq_restore()`, then calls
/// `preempt_enable()`.
#[must_use = "if unused the Spinlock will immediately unlock and restore interrupts"]
pub struct SpinlockIrqFlagsGuard<'a, T: ?Sized + 'a> {
    lock: &'a Spinlock<T>,
    /// Value returned by `local_irq_save()`; its encoding is opaque here.
    flags: usize,
}

// Dropping the guard restores the interrupt state and preempt count of the
// context that acquired it, so it must be dropped in that same context;
// hence the guard is not `Send`.
impl<T: ?Sized> !Send for SpinlockIrqFlagsGuard<'_, T> {}
// SAFETY: through `&SpinlockIrqFlagsGuard` only `&T` (via `Deref`) and a copy
// of `flags` are reachable, so sharing the guard is sound exactly when
// `T: Sync`.
unsafe impl<T: ?Sized + Sync> Sync for SpinlockIrqFlagsGuard<'_, T> {}

impl<T: ?Sized> SpinlockIrqFlagsGuard<'_, T> {
    /// Returns the interrupt state captured by `local_irq_save()` when the
    /// lock was taken.
    #[inline]
    pub fn irq_flags(&self) -> usize {
        let Self { flags, .. } = self;
        *flags
    }
}

impl<T: ?Sized> Spinlock<T> {
    /// Weak acquisition attempt wrapped in a flags guard carrying `flags`.
    ///
    /// Touches neither the interrupt state nor the preemption count; the
    /// caller must already have saved/disabled interrupts and disabled
    /// preemption.
    #[inline]
    fn try_lock_irq_save_weak(&self, flags: usize) -> Option<SpinlockIrqFlagsGuard<'_, T>> {
        if self.__try_lock_weak().is_ok() {
            Some(SpinlockIrqFlagsGuard { lock: self, flags })
        } else {
            None
        }
    }

    /// Saves and disables local interrupts, disables preemption, then spins
    /// until the lock is acquired.
    ///
    /// The saved interrupt state travels in the returned guard and is
    /// restored (not unconditionally re-enabled) when the guard is dropped.
    pub fn lock_irq_save(&self) -> SpinlockIrqFlagsGuard<'_, T> {
        let flags = local_irq_save();
        preempt_disable();
        loop {
            if let Some(guard) = self.try_lock_irq_save_weak(flags) {
                break guard;
            }
            // Spin on plain loads until the lock looks free, then retry.
            while self.is_locked() {
                cpu_relax();
            }
        }
    }

    /// Like `lock_irq_save`, but gives up immediately if the lock is held.
    ///
    /// On failure, the saved interrupt state is restored and preemption is
    /// re-enabled before returning `None`. A strong compare-exchange is used
    /// so that a free lock is never reported as busy because of a spurious
    /// failure.
    pub fn try_lock_irq_save(&self) -> Option<SpinlockIrqFlagsGuard<'_, T>> {
        let flags = local_irq_save();
        preempt_disable();
        let acquired = self
            .inner
            .compare_exchange(
                false,
                true,
                core::sync::atomic::Ordering::Acquire,
                core::sync::atomic::Ordering::Relaxed,
            )
            .is_ok();
        if acquired {
            Some(SpinlockIrqFlagsGuard { lock: self, flags })
        } else {
            local_irq_restore(flags);
            preempt_enable();
            None
        }
    }

    /// Releases the lock and restores the interrupt state saved in `guard`
    /// (equivalent to dropping it).
    #[inline]
    pub fn unlock_irq_restore(guard: SpinlockIrqFlagsGuard<'_, T>) {
        drop(guard);
    }
}

impl<T: ?Sized> Deref for SpinlockIrqFlagsGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value for as long as the guard is borrowed.
    fn deref(&self) -> &T {
        let cell = &self.lock.data;
        // SAFETY: the guard exists only while the lock word is `true`, so no
        // other guard can alias the value mutably.
        unsafe { &*cell.get() }
    }
}

impl<T: ?Sized> DerefMut for SpinlockIrqFlagsGuard<'_, T> {
    /// Mutably borrows the protected value for as long as the guard is
    /// mutably borrowed.
    fn deref_mut(&mut self) -> &mut T {
        let ptr = self.lock.data.get();
        // SAFETY: the lock is held for the guard's lifetime and `&mut self`
        // rules out any other borrow obtained through this guard.
        unsafe { &mut *ptr }
    }
}

impl<T: ?Sized> Drop for SpinlockIrqFlagsGuard<'_, T> {
    /// Releases the lock, restores the saved interrupt state, then re-enables
    /// preemption, in that order.
    fn drop(&mut self) {
        use core::sync::atomic::Ordering;
        let Self { lock, flags } = self;
        // `Release` publishes the critical section's writes to the next owner.
        lock.inner.store(false, Ordering::Release);
        local_irq_restore(*flags);
        preempt_enable();
    }
}

impl<T: ?Sized + fmt::Debug> fmt::Debug for SpinlockIrqFlagsGuard<'_, T> {
    /// Formats the protected value directly; the saved flags are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Debug>::fmt(value, f)
    }
}

impl<T: ?Sized + fmt::Display> fmt::Display for SpinlockIrqFlagsGuard<'_, T> {
    /// Displays the protected value, honouring the caller's format flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Display>::fmt(value, f)
    }
}