use core::cell::UnsafeCell;
use core::sync::atomic::{AtomicU32, Ordering};
use wdk_sys::{
ntddk::{
KeInitializeSpinLock, KeAcquireSpinLockRaiseToDpc, KeReleaseSpinLock,
ExAcquireFastMutex, ExReleaseFastMutex,
ExInitializeResourceLite, ExAcquireResourceExclusiveLite,
ExAcquireResourceSharedLite, ExReleaseResourceLite, ExDeleteResourceLite,
KeInitializeEvent, KeSetEvent, KeClearEvent, KeWaitForSingleObject,
},
KSPIN_LOCK, KIRQL, FAST_MUTEX, ERESOURCE, KEVENT,
LARGE_INTEGER,
};
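/// Re-implementation of the `ExInitializeFastMutex` macro from wdm.h, which is
/// not an exported routine and therefore cannot be called through the
/// `wdk_sys` bindings.
///
/// # Safety
/// `mutex` must point to nonpaged storage that outlives every acquire/release.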
pub unsafe fn ex_initialize_fast_mutex(mutex: &mut FAST_MUTEX) {
    unsafe {
        // Zero Owner, Contention, and the rest of the structure.
        core::ptr::write_bytes(mutex as *mut FAST_MUTEX as *mut u8, 0, core::mem::size_of::<FAST_MUTEX>());
        // Count = FM_LOCK_BIT (1) marks the fast mutex as unowned.
        let count_ptr = mutex as *mut FAST_MUTEX as *mut i32;
        *count_ptr = 1;
        // The macro also initializes the embedded event that contended acquirers
        // block on; a zeroed KEVENT is not a valid dispatcher object.
        KeInitializeEvent(&mut mutex.Event, wdk_sys::_EVENT_TYPE::SynchronizationEvent, 0);
    }
}
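/// Wrapper around `KSPIN_LOCK`. Acquiring raises the caller to
/// `DISPATCH_LEVEL`, so the protected section must not page-fault, block, or
/// touch pageable code or data.
///
/// A minimal usage sketch (hypothetical static; `init` must run once, e.g. in
/// `DriverEntry`, before the first acquire):
///
/// ```ignore
/// static COUNTER_LOCK: SpinLock = SpinLock::new();
/// // DriverEntry:
/// unsafe { COUNTER_LOCK.init() };
/// // Any IRQL <= DISPATCH_LEVEL:
/// let snapshot = unsafe { COUNTER_LOCK.with_lock(|| /* read/update shared state */ 0u32) };
/// ```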
pub struct SpinLock {
lock: UnsafeCell<KSPIN_LOCK>,
}
unsafe impl Send for SpinLock {}
unsafe impl Sync for SpinLock {}
impl SpinLock {
pub const fn new() -> Self {
Self {
lock: UnsafeCell::new(0),
}
}
pub unsafe fn init(&self) {
unsafe { KeInitializeSpinLock(self.lock.get()) };
}
    pub unsafe fn acquire(&self) -> KIRQL {
        // KeAcquireSpinLockRaiseToDpc raises to DISPATCH_LEVEL and returns the
        // previous IRQL, which the caller must hand back to release().
        unsafe { KeAcquireSpinLockRaiseToDpc(self.lock.get()) }
    }
pub unsafe fn release(&self, old_irql: KIRQL) {
unsafe { KeReleaseSpinLock(self.lock.get(), old_irql) };
}
pub unsafe fn with_lock<T, F: FnOnce() -> T>(&self, f: F) -> T {
let irql = unsafe { self.acquire() };
let result = f();
unsafe { self.release(irql) };
result
}
}
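/// Wrapper around `FAST_MUTEX`. `ExAcquireFastMutex` raises the caller to
/// `APC_LEVEL`, so it is only usable below `DISPATCH_LEVEL`, and fast mutexes
/// must not be acquired recursively.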
pub struct FastMutex {
mutex: UnsafeCell<FAST_MUTEX>,
}
unsafe impl Send for FastMutex {}
unsafe impl Sync for FastMutex {}
impl FastMutex {
pub const fn new() -> Self {
Self {
mutex: UnsafeCell::new(unsafe { core::mem::zeroed() }),
}
}
pub unsafe fn init(&self) {
unsafe { ex_initialize_fast_mutex(&mut *self.mutex.get()) };
}
pub unsafe fn acquire(&self) {
unsafe { ExAcquireFastMutex(self.mutex.get()) };
}
pub unsafe fn release(&self) {
unsafe { ExReleaseFastMutex(self.mutex.get()) };
}
pub unsafe fn with_lock<T, F: FnOnce() -> T>(&self, f: F) -> T {
unsafe { self.acquire() };
let result = f();
unsafe { self.release() };
result
}
}
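/// Reader/writer lock backed by `ERESOURCE`. The structure must live in
/// nonpaged memory, callers of the acquire routines must be inside a critical
/// region (`KeEnterCriticalRegion`/`KeLeaveCriticalRegion`), and the resource
/// must be deleted before its memory is freed; `Drop` performs the
/// `ExDeleteResourceLite` call for initialized instances.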
pub struct ExResource {
resource: UnsafeCell<ERESOURCE>,
initialized: AtomicU32,
}
unsafe impl Send for ExResource {}
unsafe impl Sync for ExResource {}
impl ExResource {
pub const fn new() -> Self {
Self {
resource: UnsafeCell::new(unsafe { core::mem::zeroed() }),
initialized: AtomicU32::new(0),
}
}
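    /// Initializes the underlying `ERESOURCE` exactly once. The `initialized`
    /// flag only guards against repeated initialization; run this before the
    /// resource becomes visible to concurrent callers (typically from
    /// `DriverEntry`).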
pub unsafe fn init(&self) -> Result<(), ()> {
        if self.initialized.load(Ordering::SeqCst) != 0 {
            return Ok(());
        }
let status = unsafe { ExInitializeResourceLite(self.resource.get()) };
if status == 0 {
self.initialized.store(1, Ordering::SeqCst);
Ok(())
} else {
Err(())
}
}
pub unsafe fn acquire_exclusive(&self, wait: bool) -> bool {
unsafe {
ExAcquireResourceExclusiveLite(self.resource.get(), wait as u8) != 0
}
}
pub unsafe fn acquire_shared(&self, wait: bool) -> bool {
unsafe {
ExAcquireResourceSharedLite(self.resource.get(), wait as u8) != 0
}
}
pub unsafe fn release(&self) {
unsafe { ExReleaseResourceLite(self.resource.get()) };
}
pub unsafe fn with_exclusive<T, F: FnOnce() -> T>(&self, f: F) -> Option<T> {
if unsafe { self.acquire_exclusive(true) } {
let result = f();
unsafe { self.release() };
Some(result)
} else {
None
}
}
pub unsafe fn with_shared<T, F: FnOnce() -> T>(&self, f: F) -> Option<T> {
if unsafe { self.acquire_shared(true) } {
let result = f();
unsafe { self.release() };
Some(result)
} else {
None
}
}
}
impl Drop for ExResource {
fn drop(&mut self) {
if self.initialized.load(Ordering::SeqCst) != 0 {
unsafe { let _ = ExDeleteResourceLite(self.resource.get()); };
}
}
}
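/// Wrapper around `KEVENT`. `init` must be called before any other operation,
/// since a zeroed `KEVENT` is not a valid dispatcher object.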
pub struct KernelEvent {
event: UnsafeCell<KEVENT>,
}
unsafe impl Send for KernelEvent {}
unsafe impl Sync for KernelEvent {}
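/// Mirrors the kernel's `EVENT_TYPE`: a synchronization event auto-resets
/// after releasing a single waiter, while a notification event stays signaled
/// until explicitly cleared.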
pub enum KernelEventType {
SynchronizationEvent,
NotificationEvent,
}
impl KernelEvent {
pub const fn new() -> Self {
Self {
event: UnsafeCell::new(unsafe { core::mem::zeroed() }),
}
}
pub unsafe fn init(&self, event_type: KernelEventType, signaled: bool) {
let evt_type = match event_type {
KernelEventType::SynchronizationEvent => wdk_sys::_EVENT_TYPE::SynchronizationEvent,
KernelEventType::NotificationEvent => wdk_sys::_EVENT_TYPE::NotificationEvent,
};
unsafe {
KeInitializeEvent(self.event.get(), evt_type, signaled as u8);
}
}
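    /// Signals the event. Like `KeSetEvent`, returns the previous signal state.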
pub unsafe fn set(&self) -> i32 {
unsafe { KeSetEvent(self.event.get(), 0, 0) }
}
pub unsafe fn clear(&self) {
unsafe { KeClearEvent(self.event.get()) };
}
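    /// Waits for the event to become signaled. `Some(ms)` is converted to a
    /// negative count of 100-nanosecond intervals, which the kernel treats as
    /// a relative timeout; `None` waits indefinitely. Returns `true` only when
    /// the wait was satisfied by the event (`STATUS_SUCCESS`); a timeout or any
    /// other status yields `false`. Callers must be at IRQL <= `APC_LEVEL`
    /// unless the timeout is zero.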
pub unsafe fn wait(&self, timeout_ms: Option<u64>) -> bool {
let timeout = timeout_ms.map(|ms| {
let mut t: LARGE_INTEGER = unsafe { core::mem::zeroed() };
t.QuadPart = -(ms as i64 * 10_000);
t
});
let timeout_ptr = timeout
.as_ref()
.map(|t| t as *const _)
.unwrap_or(core::ptr::null());
let status = unsafe {
KeWaitForSingleObject(
self.event.get() as *mut _,
wdk_sys::_KWAIT_REASON::Executive,
wdk_sys::_MODE::KernelMode as wdk_sys::KPROCESSOR_MODE,
0, timeout_ptr as *mut _,
)
};
        status == 0
    }
}
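/// Thin equivalents of the kernel's `Interlocked*` routines built on `core`
/// atomics. As with the native APIs, `increment`, `decrement`, and `add`
/// return the resulting value, while `exchange` and `compare_exchange` return
/// the previously stored value.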
pub mod interlocked {
use core::sync::atomic::{AtomicI32, AtomicI64, Ordering};
pub fn increment(value: &AtomicI32) -> i32 {
value.fetch_add(1, Ordering::SeqCst) + 1
}
pub fn decrement(value: &AtomicI32) -> i32 {
value.fetch_sub(1, Ordering::SeqCst) - 1
}
pub fn exchange(value: &AtomicI32, new_value: i32) -> i32 {
value.swap(new_value, Ordering::SeqCst)
}
pub fn compare_exchange(
value: &AtomicI32,
expected: i32,
new_value: i32,
) -> i32 {
match value.compare_exchange(expected, new_value, Ordering::SeqCst, Ordering::SeqCst) {
Ok(v) => v,
Err(v) => v,
}
}
pub fn add(value: &AtomicI64, addend: i64) -> i64 {
value.fetch_add(addend, Ordering::SeqCst) + addend
}
}