use std::ffi::{c_void, OsString};
use std::fmt::Display;
use std::mem::{size_of, ManuallyDrop};
use std::os::windows::ffi::OsStringExt;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use parking_lot::{RwLock, RwLockReadGuard};
use tracing::{debug, error};
use windows::core::s;
use windows::Win32::Foundation::{
    CloseHandle, HANDLE, HMODULE, HWND, MAX_PATH, RECT, WAIT_FAILED,
};
use windows::Win32::Graphics::Direct3D::ID3DBlob;
use windows::Win32::Graphics::Direct3D12::{
    D3D12GetDebugInterface, ID3D12Debug, ID3D12Device, ID3D12Fence, ID3D12Resource,
    D3D12_FENCE_FLAG_NONE, D3D12_RESOURCE_BARRIER, D3D12_RESOURCE_BARRIER_0,
    D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES, D3D12_RESOURCE_BARRIER_FLAG_NONE,
    D3D12_RESOURCE_BARRIER_TYPE_TRANSITION, D3D12_RESOURCE_STATES,
    D3D12_RESOURCE_TRANSITION_BARRIER,
};
use windows::Win32::Graphics::Dxgi::{
    DXGIGetDebugInterface1, IDXGIInfoQueue, DXGI_DEBUG_ALL, DXGI_INFO_QUEUE_MESSAGE,
};
use windows::Win32::System::LibraryLoader::{
    GetModuleFileNameW, GetModuleHandleExA, GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
    GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
};
use windows::Win32::System::Memory::{
    VirtualQuery, MEMORY_BASIC_INFORMATION, MEM_COMMIT, PAGE_EXECUTE_READ,
    PAGE_EXECUTE_READWRITE, PAGE_EXECUTE_WRITECOPY, PAGE_GUARD, PAGE_PROTECTION_FLAGS,
    PAGE_READONLY, PAGE_READWRITE, PAGE_WRITECOPY,
};
use windows::Win32::System::SystemInformation::{GetSystemInfo, SYSTEM_INFO};
use windows::Win32::System::Threading::{CreateEventExW, WaitForSingleObjectEx, CREATE_EVENT};
use windows::Win32::UI::WindowsAndMessaging::GetClientRect;
/// Runs `f` against a default-initialized out-parameter and hands back the
/// filled-in value when `f` succeeds.
///
/// Whatever `f` returns in its `Ok` variant is discarded; an `Err` from `f` is
/// returned unchanged.
pub fn try_out_param<T, F, E, O>(mut f: F) -> Result<T, E>
where
    T: Default,
    F: FnMut(&mut T) -> Result<O, E>,
{
    let mut out = T::default();
    f(&mut out).map(|_| out)
}
/// Runs `f` against an empty `Option<T>` slot (the shape of a COM `**T` out
/// pointer) and returns the object it stored on success.
///
/// An `Err` from `f` is propagated unchanged; its `Ok` payload is ignored.
///
/// # Panics
///
/// Panics if `f` returns `Ok` without filling the slot.
pub fn try_out_ptr<T, F, E, O>(mut f: F) -> Result<T, E>
where
    F: FnMut(&mut Option<T>) -> Result<O, E>,
{
    let mut slot = None;
    f(&mut slot)?;
    Ok(slot.unwrap())
}
/// Runs `f` with two empty out-slots: one for the result object and one for an
/// error payload (for example a compiler error blob).
///
/// On success, returns the object `f` stored in the first slot. On failure,
/// returns `f`'s error paired with the payload it stored in the second slot.
///
/// # Panics
///
/// Panics if `f` returns `Ok` without filling the first slot, or `Err` without
/// filling the second.
pub fn try_out_err_blob<T1, T2, F, E, O>(mut f: F) -> Result<T1, (E, T2)>
where
    F: FnMut(&mut Option<T1>, &mut Option<T2>) -> Result<O, E>,
{
    let mut value = None;
    let mut payload = None;
    f(&mut value, &mut payload)
        .map(|_| value.unwrap())
        .map_err(|e| (e, payload.unwrap()))
}
/// Creates a default `T`, lets `f` fill it in through a mutable reference, and
/// returns the result.
///
/// This is the infallible counterpart of `try_out_param`.
pub fn out_param<T: Default, F>(f: F) -> T
where
    F: FnOnce(&mut T),
{
    let mut out = T::default();
    f(&mut out);
    out
}
pub fn print_error_blob<D: Display, E>(msg: D) -> impl Fn((E, ID3DBlob)) -> E {
move |(e, err_blob): (E, ID3DBlob)| {
let buf_ptr = unsafe { err_blob.GetBufferPointer() } as *mut u8;
let buf_size = unsafe { err_blob.GetBufferSize() };
let s = unsafe { String::from_raw_parts(buf_ptr, buf_size, buf_size + 1) };
error!("{msg}: {s}");
e
}
}
pub fn enable_debug_interface() {
let debug_interface: Result<ID3D12Debug, _> =
try_out_ptr(|v| unsafe { D3D12GetDebugInterface(v) });
match debug_interface {
Ok(debug_interface) => unsafe { debug_interface.EnableDebugLayer() },
Err(e) => {
error!("Could not create debug interface: {e:?}")
},
}
}
pub fn print_dxgi_debug_messages() {
let Ok(diq): Result<IDXGIInfoQueue, _> = (unsafe { DXGIGetDebugInterface1(0) }) else {
return;
};
let n = unsafe { diq.GetNumStoredMessages(DXGI_DEBUG_ALL) };
for i in 0..n {
let mut msg_len: usize = 0;
unsafe { diq.GetMessage(DXGI_DEBUG_ALL, i, None, &mut msg_len as _).unwrap() };
let diqm = vec![0u8; msg_len];
let pdiqm = diqm.as_ptr() as *mut DXGI_INFO_QUEUE_MESSAGE;
unsafe { diq.GetMessage(DXGI_DEBUG_ALL, i, Some(pdiqm), &mut msg_len as _).unwrap() };
let diqm = unsafe { pdiqm.as_ref().unwrap() };
debug!(
"[DIQ] {}",
String::from_utf8_lossy(unsafe {
std::slice::from_raw_parts(diqm.pDescription, diqm.DescriptionByteLength - 1)
})
);
}
unsafe { diq.ClearStoredMessages(DXGI_DEBUG_ALL) };
}
pub fn win_size(hwnd: HWND) -> (i32, i32) {
let mut rect = RECT::default();
unsafe { GetClientRect(hwnd, &mut rect).unwrap() };
(rect.right - rect.left, rect.bottom - rect.top)
}
pub fn get_dll_path() -> Option<PathBuf> {
let mut hmodule = HMODULE(std::ptr::null_mut());
if let Err(e) = unsafe {
GetModuleHandleExA(
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT | GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
s!("DllMain"),
&mut hmodule,
)
} {
error!("get_dll_path: GetModuleHandleExA error: {e:?}");
return None;
}
let mut sz_filename = [0u16; MAX_PATH as usize];
let len = unsafe { GetModuleFileNameW(Some(hmodule), &mut sz_filename) } as usize;
Some(OsString::from_wide(&sz_filename[..len]).into())
}
/// Builds a transition barrier that moves every subresource of `resource` from
/// state `before` to state `after`.
///
/// The barrier stores its own clone of `resource` inside `ManuallyDrop`, so
/// that reference is never released automatically. Pass the barrier to
/// `drop_barrier` once it is no longer needed, otherwise the reference leaks.
pub fn create_barrier(
    resource: &ID3D12Resource,
    before: D3D12_RESOURCE_STATES,
    after: D3D12_RESOURCE_STATES,
) -> D3D12_RESOURCE_BARRIER {
    let transition = D3D12_RESOURCE_TRANSITION_BARRIER {
        pResource: ManuallyDrop::new(Some(resource.clone())),
        Subresource: D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES,
        StateBefore: before,
        StateAfter: after,
    };
    let anonymous = D3D12_RESOURCE_BARRIER_0 { Transition: ManuallyDrop::new(transition) };
    D3D12_RESOURCE_BARRIER {
        Type: D3D12_RESOURCE_BARRIER_TYPE_TRANSITION,
        Flags: D3D12_RESOURCE_BARRIER_FLAG_NONE,
        Anonymous: anonymous,
    }
}
pub fn drop_barrier(barrier: D3D12_RESOURCE_BARRIER) {
let transition = ManuallyDrop::into_inner(unsafe { barrier.Anonymous.Transition });
let _ = ManuallyDrop::into_inner(transition.pResource);
}
pub struct Fence {
fence: ID3D12Fence,
value: AtomicU64,
event: HANDLE,
}
impl Fence {
pub fn new(device: &ID3D12Device) -> windows::core::Result<Self> {
let fence = unsafe { device.CreateFence(0, D3D12_FENCE_FLAG_NONE) }?;
let value = AtomicU64::new(0);
let event = unsafe { CreateEventExW(None, None, CREATE_EVENT(0), 0x1f0003) }?;
Ok(Fence { fence, value, event })
}
pub fn fence(&self) -> &ID3D12Fence {
&self.fence
}
pub fn value(&self) -> u64 {
self.value.load(Ordering::SeqCst)
}
pub fn incr(&self) -> u64 {
self.value.fetch_add(1, Ordering::SeqCst)
}
pub fn wait(&self) -> windows::core::Result<()> {
let value = self.value();
self.wait_for_value(value)
}
pub fn wait_for_value(&self, value: u64) -> windows::core::Result<()> {
unsafe {
if self.fence.GetCompletedValue() < value {
self.fence.SetEventOnCompletion(value, self.event)?;
WaitForSingleObjectEx(self.event, u32::MAX, false);
}
}
Ok(())
}
}
pub unsafe fn readable_region<T>(ptr: *const T, limit: usize) -> &'static [T] {
unsafe fn is_readable(
ptr: *const c_void,
memory_basic_info: &mut MEMORY_BASIC_INFORMATION,
) -> bool {
const PAGE_READABLE: PAGE_PROTECTION_FLAGS = PAGE_PROTECTION_FLAGS(
PAGE_READONLY.0 | PAGE_READWRITE.0 | PAGE_EXECUTE_READ.0 | PAGE_EXECUTE_READWRITE.0,
);
(unsafe {
VirtualQuery(Some(ptr), memory_basic_info, size_of::<MEMORY_BASIC_INFORMATION>())
} != 0)
&& (memory_basic_info.Protect & PAGE_READABLE).0 != 0
}
let page_size_bytes = {
let mut system_info = SYSTEM_INFO::default();
unsafe { GetSystemInfo(&mut system_info) };
system_info.dwPageSize as usize
};
let page_align_mask = page_size_bytes - 1;
let first_page_addr = (ptr as usize) & !page_align_mask;
let last_page_addr = (ptr as usize + (limit * size_of::<T>()) - 1) & !page_align_mask;
let mut memory_basic_info = MEMORY_BASIC_INFORMATION::default();
for page_addr in (first_page_addr..=last_page_addr).step_by(page_size_bytes) {
if unsafe { is_readable(page_addr as _, &mut memory_basic_info) } {
continue;
}
let num_readable = page_addr.saturating_sub(ptr as usize) / size_of::<T>();
return std::slice::from_raw_parts(ptr, num_readable);
}
std::slice::from_raw_parts(ptr, limit)
}
/// A reader/writer barrier used to coordinate hook ejection.
///
/// Each `acquire_ejection_guard` call hands out a shared (read) guard.
/// `wait_for_all_guards` takes the exclusive (write) side, so it blocks until
/// every outstanding guard has been dropped. Presumably, hooks hold a guard
/// while they run and the ejection path waits on the barrier. Confirm this
/// against the callers.
pub struct HookEjectionBarrier(RwLock<()>);

impl HookEjectionBarrier {
    /// Creates a barrier with no outstanding guards.
    pub const fn new() -> Self {
        HookEjectionBarrier(RwLock::new(()))
    }

    /// Takes a shared guard. Ejection cannot proceed while it is held.
    pub fn acquire_ejection_guard(&self) -> RwLockReadGuard<'_, ()> {
        let HookEjectionBarrier(lock) = self;
        lock.read()
    }

    /// Blocks until every guard handed out by `acquire_ejection_guard` is released.
    pub fn wait_for_all_guards(&self) {
        drop(self.0.write());
    }
}

impl Default for HookEjectionBarrier {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use windows::Win32::System::Memory::{
        VirtualAlloc, VirtualFree, VirtualProtect, MEM_COMMIT, MEM_RELEASE, MEM_RESERVE,
        PAGE_NOACCESS,
    };

    use super::*;

    /// Releases a `VirtualAlloc` allocation on drop, so the test does not leak
    /// even if an assertion fails part-way through.
    struct Allocation(*mut c_void);

    impl Drop for Allocation {
        fn drop(&mut self) {
            // SAFETY: `self.0` is the base address returned by VirtualAlloc and is
            // released exactly once, here.
            let _ = unsafe { VirtualFree(self.0, 0, MEM_RELEASE) };
        }
    }

    /// The system page size, which is the granularity `readable_region` probes at.
    fn page_size() -> usize {
        let mut system_info = SYSTEM_INFO::default();
        unsafe { GetSystemInfo(&mut system_info) };
        system_info.dwPageSize as usize
    }

    #[test]
    fn test_readable_region() -> windows::core::Result<()> {
        let page_size = page_size();
        let region = unsafe {
            VirtualAlloc(None, 2 * page_size, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE)
        };
        if region.is_null() {
            return Err(windows::core::Error::from_thread());
        }
        let allocation = Allocation(region);
        let base = allocation.0 as usize;

        // Layout: one readable page, then one inaccessible page.
        let mut old_protect = PAGE_PROTECTION_FLAGS::default();
        unsafe {
            VirtualProtect((base + page_size) as _, page_size, PAGE_NOACCESS, &mut old_protect)
        }?;
        assert_eq!(old_protect, PAGE_READWRITE);

        // Entirely inside the readable page.
        let slice = unsafe { readable_region::<u8>(base as _, page_size) };
        assert_eq!(slice.len(), page_size);

        // One byte past the boundary is cut off.
        let slice = unsafe { readable_region::<u8>(base as _, page_size + 1) };
        assert_eq!(slice.len(), page_size);

        // Starting inside the inaccessible page yields nothing.
        let slice = unsafe { readable_region::<u8>((base + page_size) as _, 1) };
        assert!(slice.is_empty());

        // Straddling the boundary keeps only the readable part.
        let slice = unsafe { readable_region::<u8>((base + page_size - 1) as _, 2) };
        assert_eq!(slice.len(), 1);

        // Multi-byte elements: only the element wholly before the boundary is kept.
        let slice = unsafe { readable_region::<u32>((base + page_size - 4) as _, 2) };
        assert_eq!(slice.len(), 1);

        // A zero-length request is always empty.
        let slice = unsafe { readable_region::<u8>(base as _, 0) };
        assert!(slice.is_empty());

        Ok(())
    }
}