vmt_hook/lib.rs
//! This library provides the ability to hook Virtual Method Tables (VMTs).
//! It works by copying an object's original VMT and pointing the object at the modified
//! copy, so individual methods can be replaced without touching the original table.
3
4use std::cell::UnsafeCell;
5
/// Hooks and manages the virtual function table (VTable) of an object.
///
/// On construction the object's VTable is copied into a heap buffer owned by the hook, and
/// the object is redirected to that copy; individual entries of the copy can then be
/// replaced. Dropping the hook points the object back at its original VTable, so the hook
/// must be kept alive for as long as the replaced methods should stay installed.
///
/// # Example
///
/// The example depends on the `windows` crate and on a live `IDirect3DDevice9`, so it is
/// not compiled as a doctest.
///
/// ```rust,ignore
/// use vmt_hook::VTableHook;
///
/// use windows::{
///     core::HRESULT,
///     Win32::{
///         Foundation::{HWND, RECT},
///         Graphics::{Direct3D9::IDirect3DDevice9, Gdi::RGNDATA},
///     },
/// };
///
/// type FnPresent = extern "stdcall" fn(
///     IDirect3DDevice9,
///     *const RECT,
///     *const RECT,
///     HWND,
///     *const RGNDATA,
/// ) -> HRESULT;
///
/// static mut ORIGINAL_PRESENT: Option<FnPresent> = None;
///
/// extern "stdcall" fn hk_present(
///     device: IDirect3DDevice9,
///     source_rect: *const RECT,
///     dest_rect: *const RECT,
///     dest_window_override: HWND,
///     dirty_region: *const RGNDATA,
/// ) -> HRESULT {
///     // Your code.
///
///     unsafe {
///         let original_present = ORIGINAL_PRESENT.unwrap();
///         original_present(device, source_rect, dest_rect, dest_window_override, dirty_region)
///     }
/// }
///
/// // Keep the returned hook alive for as long as the detour should stay installed:
/// // dropping it restores the original VTable.
/// unsafe fn install_d3d9_hook(device: IDirect3DDevice9) -> VTableHook<IDirect3DDevice9> {
///     // Creating a hook with automatic detection of the number of methods in its VMT.
///     // let hook = VTableHook::new(device);
///
///     // If you know the number of methods in the table, you can specify it manually.
///     let hook = VTableHook::with_count(device, 119);
///
///     // Getting the original method.
///     ORIGINAL_PRESENT = Some(std::mem::transmute(hook.get_original_method(17)));
///
///     // Replacing the method at index 17 in the VMT with our function.
///     hook.replace_method(17, hk_present as usize);
///
///     hook
/// }
/// ```
pub struct VTableHook<T> {
    /// The hooked object. Its leading pointer-sized bytes are treated as a pointer to the
    /// object's VTable pointer.
    object: T,
    /// The original VTable, used to look up original methods and to restore the object.
    original_vtbl: &'static [usize],
    /// Heap copy of the VTable that the object dispatches through while hooked.
    new_vtbl: UnsafeCell<Vec<usize>>,
}
73
74impl<T> Drop for VTableHook<T> {
75 /// Restoring the original VTable.
76 fn drop(&mut self) {
77 unsafe {
78 *std::mem::transmute_copy::<_, *mut *const usize>(&self.object) = self.original_vtbl.as_ptr();
79 }
80 }
81}
82
83impl<T> VTableHook<T> {
84 /// Creates a new VTableHook instance for the provided object and replaces its VTable with the hooked VTable.
85 /// The count of methods is automatically determined.
86 pub unsafe fn new(object: T) -> Self {
87 Self::init(object, |vtable| Self::detect_vtable_methods_count(vtable))
88 }
89
90 /// Creates a new VTableHook instance for the provided object with a specified method count
91 /// and replaces its VTable with the hooked VTable.
92 pub unsafe fn with_count(object: T, count: usize) -> Self {
93 Self::init(object, |_| count)
94 }
95
96 unsafe fn init<F>(object: T, count_fn: F) -> Self
97 where
98 F: FnOnce(*const usize) -> usize
99 {
100 let object_ptr = std::mem::transmute_copy::<T, *mut *const usize>(&object);
101 let original_vtbl = *object_ptr;
102 let count = count_fn(original_vtbl);
103 let original_vtbl = std::slice::from_raw_parts(original_vtbl, count);
104 let new_vtbl = original_vtbl.to_vec();
105
106 *object_ptr = new_vtbl.as_ptr();
107
108 Self {
109 object,
110 original_vtbl,
111 new_vtbl: UnsafeCell::new(new_vtbl),
112 }
113 }
114
115 /// Detects the number of methods in the provided VTable.
116 unsafe fn detect_vtable_methods_count(vtable: *const usize) -> usize {
117 let mut vmt = vtable;
118
119 // Todo: Maybe add a memory region length check?
120 while std::ptr::read(vmt) != 0 {
121 vmt = vmt.add(1);
122 }
123
124 (vmt as usize - vtable as usize) / std::mem::size_of::<usize>()
125 }
126
127 /// Returns our hooked vtable.
128 fn vtbl(&self) -> &mut Vec<usize> {
129 unsafe { &mut *self.new_vtbl.get() }
130 }
131
132 /// Returns the original method address at the specified index in the VTable.
133 pub fn get_original_method(&self, id: usize) -> usize {
134 self.original_vtbl[id]
135 }
136
137 /// Returns the replaced method address at the specified index in the VTable.
138 pub fn get_replaced_method(&self, id: usize) -> usize {
139 self.vtbl()[id]
140 }
141
142 /// Hooks the method at the specified index in the VTable with a new function address.
143 pub unsafe fn replace_method(&self, id: usize, func: usize) {
144 self.vtbl()[id] = func;
145 }
146
147 /// Restores the original method at the specified index in the VTable.
148 pub unsafe fn restore_method(&self, id: usize) {
149 self.vtbl()[id] = self.get_original_method(id);
150 }
151
152 /// Restores all methods in the VTable to their original address.
153 pub unsafe fn restore_all_methods(&self) {
154 self.vtbl().copy_from_slice(self.original_vtbl);
155 }
156
157 /// Returns the original object.
158 pub fn object(&self) -> &T {
159 &self.object
160 }
161}