use core::{
cell::Cell,
sync::atomic::{AtomicBool, AtomicUsize, Ordering},
};
use wdk_sys::ntddk::PsGetCurrentThreadId;
/// A spin lock that the owning thread may acquire recursively without
/// deadlocking.
///
/// `owner` holds the current holder's thread id (0 = unowned) and
/// `recursion_count` tracks how many nested guards that thread holds.
///
/// NOTE(review): `Cell<usize>` is `!Sync`, so this type is not `Sync`
/// unless an `unsafe impl Sync` exists elsewhere in this file — confirm.
/// Such an impl would be sound only because `recursion_count` is read and
/// written exclusively by the thread that currently owns the lock.
pub struct ReentrantSpinLock {
// Test-and-set flag: true while any thread holds the lock.
lock: AtomicBool,
// Thread id of the current owner; 0 means unowned.
owner: AtomicUsize,
// Nesting depth for the owning thread; only touched while owning the lock.
recursion_count: Cell<usize>,
}
/// RAII guard returned by [`ReentrantSpinLock::lock`]; dropping it releases
/// one level of the reentrant lock.
///
/// `#[must_use]` because an unbound guard (`lock.lock();`) is dropped
/// immediately, releasing the lock on the very next statement — almost
/// certainly a bug at the call site.
#[must_use = "the lock is released as soon as the guard is dropped"]
pub struct ReentrantSpinGuard<'a> {
// Back-reference to the lock this guard releases on drop.
spinlock: &'a ReentrantSpinLock,
}
impl Default for ReentrantSpinLock {
fn default() -> Self {
Self::new()
}
}
impl ReentrantSpinLock {
    /// Creates an unlocked lock with no owner (owner id 0) and zero
    /// recursion depth.
    pub fn new() -> Self {
        ReentrantSpinLock {
            lock: AtomicBool::new(false),
            owner: AtomicUsize::new(0),
            recursion_count: Cell::new(0),
        }
    }

    /// Acquires the lock, busy-waiting until it is available.
    ///
    /// If the calling thread already owns the lock, the recursion count is
    /// bumped instead of spinning, so nested calls do not deadlock. Each
    /// returned guard releases one level on drop.
    ///
    /// NOTE(review): unlike `KeAcquireSpinLock` this does not raise IRQL,
    /// so the holder can be preempted while spinning waiters burn CPU —
    /// confirm this is intended for the contexts it is used in.
    pub fn lock(&self) -> ReentrantSpinGuard {
        let current_thread_id = unsafe { PsGetCurrentThreadId() } as usize;
        // Relaxed is sufficient here: a thread can only observe its own id
        // in `owner` while it actually holds the lock (Drop resets `owner`
        // to 0 before releasing, so a thread can never read its own stale
        // id left over from an earlier acquisition).
        if self.owner.load(Ordering::Relaxed) == current_thread_id {
            self.recursion_count.set(self.recursion_count.get() + 1);
        } else {
            // Test-and-test-and-set: attempt the CAS, and while it fails
            // spin on a cheap Relaxed load instead of retrying the CAS,
            // which avoids hammering the cache line with exclusive-ownership
            // requests under contention. `compare_exchange_weak` is fine
            // here because spurious failures just loop again.
            loop {
                if self
                    .lock
                    .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
                    .is_ok()
                {
                    break;
                }
                while self.lock.load(Ordering::Relaxed) {
                    // Tell the CPU we are in a spin-wait (emits PAUSE on
                    // x86), reducing power use and sibling-hyperthread impact.
                    core::hint::spin_loop();
                }
            }
            // We now hold the lock exclusively, so these writes cannot race
            // with another acquirer; Relaxed is enough.
            self.owner.store(current_thread_id, Ordering::Relaxed);
            self.recursion_count.set(1);
        }
        ReentrantSpinGuard { spinlock: self }
    }
}
impl Drop for ReentrantSpinGuard<'_> {
    /// Releases one level of the reentrant lock.
    ///
    /// For a nested acquisition this just decrements the count; for the
    /// outermost guard it resets the owner and then frees the lock flag.
    fn drop(&mut self) {
        let count = self.spinlock.recursion_count.get();
        if count > 1 {
            // Inner guard: the calling thread still owns the lock.
            self.spinlock.recursion_count.set(count - 1);
        } else {
            // Outermost guard. Order matters: clear `owner` BEFORE the
            // Release store that frees `lock`, so the releasing thread can
            // never later observe its own stale id in `owner` (the Relaxed
            // reentrancy check in `lock()` relies on this). The Release
            // store pairs with the Acquire CAS in `lock()`, publishing the
            // owner/count reset to the next acquirer.
            self.spinlock.recursion_count.set(0);
            self.spinlock.owner.store(0, Ordering::Relaxed);
            self.spinlock.lock.store(false, Ordering::Release);
        }
    }
}
impl core::ops::Deref for ReentrantSpinGuard<'_> {
    type Target = ReentrantSpinLock;

    /// Exposes the guarded lock itself for the guard's lifetime.
    ///
    /// NOTE(review): dereferencing to the lock (rather than to protected
    /// data) lets callers re-lock through the guard — confirm intended.
    fn deref(&self) -> &Self::Target {
        let Self { spinlock } = self;
        *spinlock
    }
}