vmt_hook/
lib.rs

//! This library provides the ability to hook Virtual Method Tables (VMT).
//! It works by copying the original VMT and then swapping it out with the modified version.
3
4use std::cell::UnsafeCell;
5
/// Hooks the virtual method table (VMT / VTable) of an object and manages the swap.
///
/// Constructing a hook copies the object's original VTable into a table owned by the hook and
/// redirects the object's VTable pointer to that copy. Entries of the copy can then be replaced
/// or restored individually. Dropping the hook writes the original VTable pointer back into the
/// object, so the hook must be kept alive for as long as the replacements should stay installed.
///
/// # Example
///
/// The example is not compiled as a doctest: it depends on the `windows` crate and on a device
/// pointer that only exists inside a running Direct3D 9 application.
///
/// ```rust,ignore
/// use vmt_hook::VTableHook;
///
/// use windows::{
///     core::HRESULT,
///     Win32::{
///         Foundation::{HWND, RECT},
///         Graphics::{
///             Direct3D9::IDirect3DDevice9,
///             Gdi::RGNDATA,
///         },
///     },
/// };
///
/// type FnPresent = extern "stdcall" fn(
///     IDirect3DDevice9,
///     *const RECT,
///     *const RECT,
///     HWND,
///     *const RGNDATA,
/// ) -> HRESULT;
///
/// static mut ORIGINAL_PRESENT: Option<FnPresent> = None;
///
/// extern "stdcall" fn hk_present(
///     device: IDirect3DDevice9,
///     source_rect: *const RECT,
///     dest_rect: *const RECT,
///     dest_window_override: HWND,
///     dirty_region: *const RGNDATA,
/// ) -> HRESULT {
///     // Your code.
///
///     unsafe {
///         let original_present = ORIGINAL_PRESENT.unwrap();
///         original_present(device, source_rect, dest_rect, dest_window_override, dirty_region)
///     }
/// }
///
/// unsafe fn install_d3d9_hook() {
///     let device: IDirect3DDevice9 = /* Your ptr. */;
///
///     // Creating a hook with automatic detection of the number of methods in its VMT.
///     // let hook = VTableHook::new(device);
///
///     // If you know the number of methods in the table, you can specify it manually.
///     let hook = VTableHook::with_count(device, 119);
///
///     // Getting the original method.
///     ORIGINAL_PRESENT = Some(std::mem::transmute(hook.get_original_method(17)));
///
///     // Replacing the method at index 17 in the VMT with our function.
///     hook.replace_method(17, hk_present as usize);
///
///     // Dropping the hook restores the original VTable, so it has to outlive this function.
///     // Here it is deliberately leaked; store it somewhere instead if you want to unhook later.
///     std::mem::forget(hook);
/// }
/// ```
#[must_use = "dropping the hook immediately restores the original VTable"]
pub struct VTableHook<T> {
    /// Handle to the object whose VTable is being hooked. Its bits are reinterpreted as a
    /// pointer to the object's VTable pointer.
    object: T,
    /// The original VTable. The `'static` lifetime is asserted by the constructor, not checked.
    original_vtbl: &'static [usize],
    /// Copy of the VTable that the object is redirected to; holds the hooked function addresses.
    new_vtbl: UnsafeCell<Vec<usize>>,
}
73
74impl<T> Drop for VTableHook<T> {
75    /// Restoring the original VTable.
76    fn drop(&mut self) {
77        unsafe {
78            *std::mem::transmute_copy::<_, *mut *const usize>(&self.object) = self.original_vtbl.as_ptr();
79        }
80    }
81}
82
83impl<T> VTableHook<T> {
84    /// Creates a new VTableHook instance for the provided object and replaces its VTable with the hooked VTable.
85    /// The count of methods is automatically determined.
86    pub unsafe fn new(object: T) -> Self {
87        Self::init(object, |vtable| Self::detect_vtable_methods_count(vtable))
88    }
89
90    /// Creates a new VTableHook instance for the provided object with a specified method count
91    /// and replaces its VTable with the hooked VTable.
92    pub unsafe fn with_count(object: T, count: usize) -> Self {
93        Self::init(object, |_| count)
94    }
95
96    unsafe fn init<F>(object: T, count_fn: F) -> Self
97    where
98        F: FnOnce(*const usize) -> usize
99    {
100        let object_ptr = std::mem::transmute_copy::<T, *mut *const usize>(&object);
101        let original_vtbl = *object_ptr;
102        let count = count_fn(original_vtbl);
103        let original_vtbl = std::slice::from_raw_parts(original_vtbl, count);
104        let new_vtbl = original_vtbl.to_vec();
105
106        *object_ptr = new_vtbl.as_ptr();
107
108        Self {
109            object,
110            original_vtbl,
111            new_vtbl: UnsafeCell::new(new_vtbl),
112        }
113    }
114
115    /// Detects the number of methods in the provided VTable.
116    unsafe fn detect_vtable_methods_count(vtable: *const usize) -> usize {
117        let mut vmt = vtable;
118
119        // Todo: Maybe add a memory region length check?
120        while std::ptr::read(vmt) != 0 {
121            vmt = vmt.add(1);
122        }
123
124        (vmt as usize - vtable as usize) / std::mem::size_of::<usize>()
125    }
126
127    /// Returns our hooked vtable.
128    fn vtbl(&self) -> &mut Vec<usize> {
129        unsafe { &mut *self.new_vtbl.get() }
130    }
131
132    /// Returns the original method address at the specified index in the VTable.
133    pub fn get_original_method(&self, id: usize) -> usize {
134        self.original_vtbl[id]
135    }
136
137    /// Returns the replaced method address at the specified index in the VTable.
138    pub fn get_replaced_method(&self, id: usize) -> usize {
139        self.vtbl()[id]
140    }
141
142    /// Hooks the method at the specified index in the VTable with a new function address.
143    pub unsafe fn replace_method(&self, id: usize, func: usize) {
144        self.vtbl()[id] = func;
145    }
146
147    /// Restores the original method at the specified index in the VTable.
148    pub unsafe fn restore_method(&self, id: usize) {
149        self.vtbl()[id] = self.get_original_method(id);
150    }
151
152    /// Restores all methods in the VTable to their original address.
153    pub unsafe fn restore_all_methods(&self) {
154        self.vtbl().copy_from_slice(self.original_vtbl);
155    }
156
157    /// Returns the original object.
158    pub fn object(&self) -> &T {
159        &self.object
160    }
161}