use ntapi::ntrtl::{
    RtlAcquireSRWLockExclusive, RtlRbInsertNodeEx, RtlReleaseSRWLockExclusive,
    RtlTryAcquireSRWLockExclusive, RTL_RB_TREE,
};
use std::{ffi::c_void, io::Result, ops::{Deref, DerefMut}, ptr, ptr::null_mut};
use winapi::{
    shared::{minwindef::DWORD, ntdef::PRTL_BALANCED_NODE},
    um::{
        handleapi::CloseHandle,
        memoryapi::VirtualProtect,
        winnt::{HANDLE, PAGE_READWRITE, RTL_SRWLOCK},
    },
};
/// Inserts `node` into `tree`, using `compare` to pick its position, and lets
/// `RtlRbInsertNodeEx` link and rebalance it.
///
/// `compare(node, existing)` returning `true` sends `node` into the left subtree of
/// `existing`; `false` sends it to the right.
///
/// # Safety
///
/// `tree` must be a valid `RTL_RB_TREE` whose reachable nodes are valid
/// `RTL_BALANCED_NODE`s, and `node` must point to a valid node that is not already
/// linked into a tree. NOTE(review): any lock that guards `tree` is presumably expected
/// to be held by the caller — confirm.
pub unsafe fn rtl_rb_tree_insert<F: Fn(PRTL_BALANCED_NODE, PRTL_BALANCED_NODE) -> bool>(
    tree: &mut RTL_RB_TREE,
    node: PRTL_BALANCED_NODE,
    compare: F,
) {
    let location = rtl_rb_tree_find_insert_location(tree, node, compare);
    let parent = location.0;
    // RtlRbInsertNodeEx takes a BOOLEAN (0/1) saying whether `node` becomes the right child.
    let as_right_child = u8::from(location.1);
    RtlRbInsertNodeEx(tree, parent, as_right_child, node);
}
/// Reads the node pointer stored in the slot at `node`, decoding it if `tree` is encoded.
///
/// When the low bit of `tree.Min` is set, a stored pointer is decoded by XOR-ing it with
/// the address of the slot that holds it. A stored null stays null. Otherwise the stored
/// value is returned as-is. NOTE(review): presumably this mirrors ntdll's pointer-encoded
/// RB trees — confirm against the target Windows version.
///
/// # Safety
///
/// `node` must be a valid, readable pointer to a `PRTL_BALANCED_NODE` slot (the tree root
/// or a node's `Left`/`Right` field).
unsafe fn rtl_rb_tree_access_node(
    tree: &mut RTL_RB_TREE,
    node: *const PRTL_BALANCED_NODE,
) -> PRTL_BALANCED_NODE {
    let stored = *node;
    let is_encoded = (tree.Min as usize & 1) == 1;
    if !is_encoded {
        return stored;
    }
    if stored.is_null() {
        null_mut()
    } else {
        (node as usize ^ stored as usize) as PRTL_BALANCED_NODE
    }
}
/// Walks down from the root of `tree` to find where `node` should be attached.
///
/// Returns `(parent, right)`. `parent` is the leaf-side node to attach to, and `right` says
/// whether `node` becomes its right child. For an empty tree this is `(null, false)`.
/// `compare(node, current)` returning `true` descends left, `false` descends right.
///
/// # Safety
///
/// `tree` and every node reachable from it must be valid for reads.
unsafe fn rtl_rb_tree_find_insert_location<
    F: Fn(PRTL_BALANCED_NODE, PRTL_BALANCED_NODE) -> bool,
>(
    tree: &mut RTL_RB_TREE,
    node: PRTL_BALANCED_NODE,
    compare: F,
) -> (PRTL_BALANCED_NODE, bool) {
    let mut parent = rtl_rb_tree_access_node(tree, &tree.Root);
    while !parent.is_null() {
        let go_left = compare(node, parent);
        let child_slot = if go_left {
            &(*parent).u.s().Left
        } else {
            &(*parent).u.s().Right
        };
        let child = rtl_rb_tree_access_node(tree, child_slot);
        if child.is_null() {
            // Empty slot found: attach on the side we were about to descend into.
            return (parent, !go_left);
        }
        parent = child;
    }
    // Only reached when the tree has no root.
    (parent, false)
}
/// Owning wrapper around a raw Win32 `HANDLE`.
///
/// On drop the handle is passed to `CloseHandle` unless `is_invalid()` reports it as
/// null/invalid.
#[derive(Debug)]
pub struct Handle {
    /// The wrapped raw handle value.
    pub handle: HANDLE,
}
/// Scope guard for a `VirtualProtect` change.
///
/// `ProtectionGuard::new` applies the new protection. Dropping the guard calls
/// `VirtualProtect` again on the same range with the saved `old_prot`.
pub struct ProtectionGuard {
    // Start address of the region whose protection was changed.
    addr: *mut c_void,
    // Size in bytes passed to `VirtualProtect`.
    size: usize,
    // Previous protection reported by the first `VirtualProtect` call; restored on drop.
    old_prot: DWORD,
}
/// Mutex-like handle pairing a value with an `RTL_SRWLOCK`, both referenced by raw pointer.
///
/// Nothing in this type frees either pointer. Cloning copies the two pointers, so every
/// clone refers to the same value and the same lock.
pub struct RtlMutex<T> {
    // Pointer to the protected value.
    val_ref: *mut T,
    // Pointer to the SRW lock guarding `*val_ref`.
    lock_ref: *mut RTL_SRWLOCK,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add a `T: Clone`
// bound, but cloning only copies two raw pointers and never clones the `T` itself.
impl<T> Clone for RtlMutex<T> {
    fn clone(&self) -> Self {
        RtlMutex {
            val_ref: self.val_ref,
            lock_ref: self.lock_ref,
        }
    }
}
/// Guard produced by `RtlMutex::lock`.
///
/// It dereferences (mutably and immutably) to the protected value and calls
/// `RtlReleaseSRWLockExclusive` on the mutex's lock when dropped.
pub struct RtlMutexGuard<'a, T> {
    // The mutex this guard was created from; borrowed for the guard's lifetime.
    mutex: &'a RtlMutex<T>,
}
pub unsafe fn protected_write<T>(addr: *mut T, val: T) -> Result<()> {
let _prot_guard = ProtectionGuard::new(
addr as *mut c_void,
std::mem::size_of_val(&val),
PAGE_READWRITE,
)?;
ptr::write_unaligned(addr, val);
Ok(())
}
impl Handle {
    /// Returns `true` if the handle is null or equals `INVALID_HANDLE_VALUE` (`(HANDLE)-1`).
    pub fn is_invalid(&self) -> bool {
        // Compare as a pointer-sized signed integer. `as i64` zero-extends a 32-bit pointer,
        // so 0xFFFFFFFF would never equal -1 and INVALID_HANDLE_VALUE would go undetected on
        // 32-bit targets.
        self.is_null() || self.handle as isize == -1
    }

    /// Returns `true` if the wrapped handle is a null pointer.
    pub fn is_null(&self) -> bool {
        self.handle.is_null()
    }
}
impl From<HANDLE> for Handle {
fn from(handle: HANDLE) -> Self {
Self { handle }
}
}
impl Drop for Handle {
    /// Closes the handle with `CloseHandle` unless `is_invalid()` says there is nothing to
    /// close. The `CloseHandle` result is ignored because `drop` cannot report failure.
    fn drop(&mut self) {
        if !self.is_invalid() {
            unsafe {
                CloseHandle(self.handle);
            }
        }
    }
}
impl ProtectionGuard {
    /// Changes the protection of `[addr, addr + size)` to `prot` and returns a guard that
    /// restores the previous protection on drop.
    ///
    /// # Errors
    ///
    /// Returns `std::io::Error::last_os_error()` if `VirtualProtect` fails.
    ///
    /// NOTE(review): this is a safe fn, yet it can change the protection of arbitrary memory.
    /// Callers are trusted to pass a sensible range.
    pub fn new(addr: *mut c_void, size: usize, prot: DWORD) -> Result<Self> {
        let mut old_prot: DWORD = 0;
        let status = unsafe { VirtualProtect(addr, size, prot, &mut old_prot) };
        if status == 0 {
            // Read the error immediately, before any other API call can overwrite it.
            return Err(std::io::Error::last_os_error());
        }
        Ok(Self {
            addr,
            size,
            old_prot,
        })
    }
}
impl Drop for ProtectionGuard {
fn drop(&mut self) {
let mut dummy = 0;
unsafe {
VirtualProtect(self.addr, self.size, self.old_prot, &mut dummy);
}
}
}
impl<T> RtlMutex<T> {
pub fn lock(&self) -> RtlMutexGuard<T> {
unsafe {
RtlTryAcquireSRWLockExclusive(self.lock_ref);
}
RtlMutexGuard { mutex: self }
}
pub fn from_ref(val_ref: *mut T, lock_ref: *mut RTL_SRWLOCK) -> RtlMutex<T> {
RtlMutex { val_ref, lock_ref }
}
}
// NOTE(review): these impls are unconditional in `T`. By contrast, `std::sync::Mutex<T>`
// requires `T: Send` for both. `RtlMutex` holds only raw pointers, which are neither `Send`
// nor `Sync` on their own, so these impls assert that sharing them across threads is sound.
// That rests on two assumptions: callers of `from_ref` pass pointers that remain valid, and
// every access to the value goes through the SRW lock. Confirm this for each `T` used.
unsafe impl<T> Send for RtlMutex<T> {}
unsafe impl<T> Sync for RtlMutex<T> {}
impl<'a, T> Deref for RtlMutexGuard<'a, T> {
    type Target = T;

    /// Shared access to the value the mutex points at.
    fn deref(&self) -> &T {
        let value_ptr: *const T = self.mutex.val_ref;
        // SAFETY: `val_ref` is assumed valid (see `RtlMutex::from_ref`). The guard's existence
        // is meant to indicate the lock is held.
        unsafe { &*value_ptr }
    }
}
impl<'a, T> DerefMut for RtlMutexGuard<'a, T> {
    /// Exclusive access to the value the mutex points at.
    fn deref_mut(&mut self) -> &mut T {
        let value_ptr = self.mutex.val_ref;
        // SAFETY: `val_ref` is assumed valid (see `RtlMutex::from_ref`). Exclusivity relies on
        // the SRW lock being held exclusively for this guard's lifetime.
        unsafe { &mut *value_ptr }
    }
}
impl<'a, T> Drop for RtlMutexGuard<'a, T> {
    /// Releases the mutex's SRW lock with `RtlReleaseSRWLockExclusive`.
    ///
    /// This assumes the guard was only created while the lock was held exclusively.
    fn drop(&mut self) {
        let lock = self.mutex.lock_ref;
        unsafe {
            RtlReleaseSRWLockExclusive(lock);
        }
    }
}