#[cfg(all(not(feature = "std"), feature = "alloc"))]
use alloc::{boxed::Box, format, string::String, vec, vec::Vec};
#[cfg(feature = "std")]
use std::{boxed::Box, format, string::String, vec, vec::Vec};
use crate::error::{Result, WraithError};
use crate::util::memory::ProtectionGuard;
use core::marker::PhantomData;
/// Page-protection value handed to `ProtectionGuard::new` while a vtable slot is overwritten
/// (see `write_vtable_entry`).
///
/// NOTE(review): the name and value (0x04) match the Win32 `PAGE_READWRITE` constant;
/// presumably `ProtectionGuard` forwards it to the OS protection API — confirm in
/// `util::memory`.
const PAGE_READWRITE: u32 = 0x04;
/// A single hooked vtable slot.
///
/// Records the slot's address, the entry it held before hooking, and the detour written in its
/// place, so the slot can be toggled between the two values.
pub struct VmtHook {
    /// Address of the hooked vtable slot.
    vtable_entry: usize,
    /// Entry read from the slot when the hook was created.
    original: usize,
    /// Value written into the slot while the hook is active.
    detour: usize,
    /// Whether the slot currently holds `detour`.
    active: bool,
    /// Whether `Drop` uninstalls the hook if it is still active.
    auto_restore: bool,
}
impl VmtHook {
pub unsafe fn new(object: *const (), index: usize, detour: usize) -> Result<Self> {
if object.is_null() {
return Err(WraithError::NullPointer { context: "object" });
}
let vptr = unsafe { *(object as *const usize) };
if vptr == 0 {
return Err(WraithError::NullPointer { context: "vptr" });
}
Self::new_at_vtable(vptr, index, detour)
}
pub fn new_at_vtable(vtable: usize, index: usize, detour: usize) -> Result<Self> {
if vtable == 0 {
return Err(WraithError::NullPointer { context: "vtable" });
}
let ptr_size = core::mem::size_of::<usize>();
let vtable_entry = vtable + index * ptr_size;
let original = unsafe { *(vtable_entry as *const usize) };
let mut hook = Self {
vtable_entry,
original,
detour,
active: false,
auto_restore: true,
};
hook.install()?;
Ok(hook)
}
pub fn install(&mut self) -> Result<()> {
if self.active {
return Ok(());
}
write_vtable_entry(self.vtable_entry, self.detour)?;
self.active = true;
Ok(())
}
pub fn uninstall(&mut self) -> Result<()> {
if !self.active {
return Ok(());
}
write_vtable_entry(self.vtable_entry, self.original)?;
self.active = false;
Ok(())
}
pub fn is_active(&self) -> bool {
self.active
}
pub fn original(&self) -> usize {
self.original
}
pub fn detour(&self) -> usize {
self.detour
}
pub fn vtable_entry(&self) -> usize {
self.vtable_entry
}
pub fn set_auto_restore(&mut self, restore: bool) {
self.auto_restore = restore;
}
pub fn leak(mut self) {
self.auto_restore = false;
core::mem::forget(self);
}
pub fn restore(mut self) -> Result<()> {
self.uninstall()?;
self.auto_restore = false;
Ok(())
}
}
impl Drop for VmtHook {
    /// Uninstalls the hook if it is still active and auto-restore is enabled.
    ///
    /// Any error is discarded because `drop` has no way to report it; restoration here is
    /// best-effort.
    fn drop(&mut self) {
        let needs_restore = self.active && self.auto_restore;
        if !needs_restore {
            return;
        }
        let _outcome = self.uninstall();
    }
}
// SAFETY(review): every field of `VmtHook` is a `usize` or `bool`, so these auto traits would
// already be derived; the explicit impls add no new capability. Slot writes performed through
// `install`/`uninstall` are plain (non-atomic) stores and are not synchronized with other
// threads by this type — callers must coordinate concurrent hooking of the same slot.
unsafe impl Send for VmtHook {}
unsafe impl Sync for VmtHook {}
/// Redirects an object's vtable pointer to a heap-allocated copy ("shadow") of its vtable, so
/// individual slots can be hooked without writing to the original vtable's memory.
///
/// From construction until [`ShadowVmt::restore`] (or drop), the object's first word points at
/// the shadow table. On drop, the original vtable pointer is written back when auto-restore is
/// enabled (the default). When auto-restore is disabled, the object keeps pointing at the
/// shadow table, so the table is deliberately leaked rather than freed.
///
/// `T` is a marker parameter only: it is stored as `PhantomData` and never read.
pub struct ShadowVmt<T: ?Sized = ()> {
    /// Object whose first word (its vtable pointer) was redirected.
    object: *mut (),
    /// Vtable pointer read from `object` before it was redirected.
    original_vtable: usize,
    /// Copy of the first `vtable_size` entries of the original vtable.
    shadow_vtable: Box<[usize]>,
    /// `(index, original entry)` for every currently hooked slot.
    hooks: Vec<(usize, usize)>,
    /// Whether dropping `self` writes `original_vtable` back into `object`.
    auto_restore: bool,
    /// Whether `object`'s vtable pointer currently points at `shadow_vtable`.
    installed: bool,
    _marker: PhantomData<T>,
}

impl<T: ?Sized> ShadowVmt<T> {
    /// Copies `vtable_size` entries of `object`'s current vtable into a new shadow table and
    /// points `object` at it.
    ///
    /// NOTE(review): only slots `0..vtable_size` are copied. If the target ABI keeps data in
    /// front of slot 0 (e.g. an RTTI pointer), that data is not reproduced in front of the
    /// shadow table — confirm this is acceptable for the objects being hooked.
    ///
    /// # Safety
    ///
    /// - `object` must point to a live, writable object whose first `usize` is its vtable
    ///   pointer.
    /// - That vtable must be readable for `vtable_size` consecutive, aligned `usize` entries.
    /// - While auto-restore is enabled, `object` must still be live and writable when the
    ///   returned value is dropped, because drop writes the original vtable pointer back.
    ///
    /// # Errors
    ///
    /// - [`WraithError::NullPointer`] if `object` or its vtable pointer is null.
    /// - [`WraithError::InvalidPeFormat`] if `vtable_size` is 0 or the copied range would
    ///   overflow the address space.
    pub unsafe fn new(object: *mut (), vtable_size: usize) -> Result<Self> {
        if object.is_null() {
            return Err(WraithError::NullPointer { context: "object" });
        }
        if vtable_size == 0 {
            return Err(WraithError::InvalidPeFormat {
                reason: "vtable_size cannot be 0".into(),
            });
        }
        // SAFETY: `object` is non-null; readability is the caller's contract.
        let original_vtable = unsafe { *(object as *const usize) };
        if original_vtable == 0 {
            return Err(WraithError::NullPointer { context: "vptr" });
        }

        let entry_size = core::mem::size_of::<usize>();
        // Reject ranges whose end wraps around instead of reading from wrapped addresses.
        let range_fits = vtable_size
            .checked_mul(entry_size)
            .and_then(|bytes| original_vtable.checked_add(bytes))
            .is_some();
        if !range_fits {
            return Err(WraithError::InvalidPeFormat {
                reason: format!("vtable_size {} overflows the address space", vtable_size),
            });
        }

        let shadow_vtable: Box<[usize]> = (0..vtable_size)
            // SAFETY: the caller guarantees `vtable_size` readable entries; no overflow (checked).
            .map(|i| unsafe { *((original_vtable + i * entry_size) as *const usize) })
            .collect();

        // SAFETY: the caller guarantees `object` is live and writable. The boxed slice's heap
        // buffer does not move when `shadow_vtable` is moved into `Self`, so this pointer stays
        // valid for as long as the box is alive (or leaked, see `Drop`).
        unsafe {
            *(object as *mut usize) = shadow_vtable.as_ptr() as usize;
        }

        Ok(Self {
            object,
            original_vtable,
            shadow_vtable,
            hooks: Vec::new(),
            auto_restore: true,
            installed: true,
            _marker: PhantomData,
        })
    }

    /// Redirects slot `index` of the shadow table to `detour`.
    ///
    /// Hooking an already-hooked slot replaces its detour but keeps the originally saved entry.
    ///
    /// # Errors
    ///
    /// [`WraithError::InvalidPeFormat`] if `index` is not less than the shadow table size.
    pub fn hook(&mut self, index: usize, detour: usize) -> Result<()> {
        if index >= self.shadow_vtable.len() {
            return Err(WraithError::InvalidPeFormat {
                reason: format!(
                    "vtable index {} out of bounds (size {})",
                    index,
                    self.shadow_vtable.len()
                ),
            });
        }
        if !self.is_hooked(index) {
            self.hooks.push((index, self.shadow_vtable[index]));
        }
        self.shadow_vtable[index] = detour;
        Ok(())
    }

    /// Restores slot `index` to its saved original entry. Unhooking a slot that is not hooked
    /// is a no-op; this implementation never returns an error.
    pub fn unhook(&mut self, index: usize) -> Result<()> {
        if let Some(pos) = self.hooks.iter().position(|&(i, _)| i == index) {
            let (_, original) = self.hooks.remove(pos);
            if let Some(slot) = self.shadow_vtable.get_mut(index) {
                *slot = original;
            }
        }
        Ok(())
    }

    /// Restores every hooked slot to its saved original entry.
    pub fn unhook_all(&mut self) {
        for (index, original) in self.hooks.drain(..) {
            if let Some(slot) = self.shadow_vtable.get_mut(index) {
                *slot = original;
            }
        }
    }

    /// Returns the original entry of slot `index`: the saved value if the slot is hooked,
    /// otherwise the shadow table's current value, or `None` if `index` is out of range.
    pub fn original(&self, index: usize) -> Option<usize> {
        self.hooks
            .iter()
            .find(|&&(i, _)| i == index)
            .map(|&(_, original)| original)
            .or_else(|| self.shadow_vtable.get(index).copied())
    }

    /// Returns the vtable pointer the object held before it was redirected.
    pub fn original_vtable(&self) -> usize {
        self.original_vtable
    }

    /// Returns the address of the shadow table.
    pub fn shadow_vtable(&self) -> usize {
        self.shadow_vtable.as_ptr() as usize
    }

    /// Returns the number of entries in the shadow table.
    pub fn vtable_size(&self) -> usize {
        self.shadow_vtable.len()
    }

    /// Returns whether slot `index` is currently hooked.
    pub fn is_hooked(&self, index: usize) -> bool {
        self.hooks.iter().any(|&(i, _)| i == index)
    }

    /// Returns the number of currently hooked slots.
    pub fn hook_count(&self) -> usize {
        self.hooks.len()
    }

    /// Controls whether dropping `self` points the object back at its original vtable.
    ///
    /// When disabled, the object keeps using the shadow table after drop; the table is then
    /// leaked so the object never points at freed memory.
    pub fn set_auto_restore(&mut self, restore: bool) {
        self.auto_restore = restore;
    }

    /// Points the object back at its original vtable and consumes `self`, freeing the shadow
    /// table. The `Result` is part of the signature; this implementation has no error path.
    pub fn restore(mut self) -> Result<()> {
        self.restore_internal()?;
        self.auto_restore = false;
        Ok(())
    }

    /// Writes the original vtable pointer back into the object if it still points at the
    /// shadow table.
    fn restore_internal(&mut self) -> Result<()> {
        if self.installed {
            // SAFETY: `new`'s contract requires the object to be live and writable while the
            // shadow table is installed.
            unsafe {
                *(self.object as *mut usize) = self.original_vtable;
            }
            self.installed = false;
        }
        Ok(())
    }
}

impl<T: ?Sized> Drop for ShadowVmt<T> {
    /// Restores the object's vtable pointer (auto-restore on), or leaks the shadow table so the
    /// still-installed pointer stays valid (auto-restore off).
    fn drop(&mut self) {
        if !self.installed {
            return;
        }
        if self.auto_restore {
            // Best-effort: `drop` cannot propagate errors.
            let _ = self.restore_internal();
        } else {
            // The object still points into this allocation; freeing it would leave a dangling
            // vptr, so give up ownership of the buffer instead.
            core::mem::forget(core::mem::take(&mut self.shadow_vtable));
        }
    }
}

// SAFETY(review): `object` is a raw pointer, so these impls are not derived automatically. The
// type never dereferences `object` through `&self`; writes happen through `&mut self`/drop only.
// Whether the hooked object itself may be touched from other threads is the caller's concern.
unsafe impl<T: ?Sized> Send for ShadowVmt<T> {}
unsafe impl<T: ?Sized> Sync for ShadowVmt<T> {}
/// Alias for [`VmtHook`]. A `VmtHook` already uninstalls itself on drop while auto-restore is
/// enabled, so it can be used directly as a scope guard.
pub type VmtHookGuard = VmtHook;
/// Returns the vtable pointer stored in the first word of `object`.
///
/// # Safety
///
/// When `object` is non-null it must point to readable, `usize`-aligned memory.
///
/// # Errors
///
/// [`WraithError::NullPointer`] with context `"object"` if `object` is null, or `"vptr"` if the
/// first word is 0.
pub unsafe fn get_vtable(object: *const ()) -> Result<usize> {
    if object.is_null() {
        return Err(WraithError::NullPointer { context: "object" });
    }
    // SAFETY: non-null checked above; readability and alignment are the caller's contract.
    match unsafe { object.cast::<usize>().read() } {
        0 => Err(WraithError::NullPointer { context: "vptr" }),
        vptr => Ok(vptr),
    }
}
/// Reads entry `index` of the vtable at address `vtable`.
///
/// # Safety
///
/// `vtable` must point to readable, `usize`-aligned memory holding at least `index + 1`
/// entries.
///
/// # Errors
///
/// - [`WraithError::NullPointer`] if `vtable` is 0.
/// - [`WraithError::InvalidPeFormat`] if the address of entry `index` overflows `usize`.
pub unsafe fn get_vtable_entry(vtable: usize, index: usize) -> Result<usize> {
    if vtable == 0 {
        return Err(WraithError::NullPointer { context: "vtable" });
    }
    // Checked arithmetic: a wrapped address would read from an unrelated location.
    let entry_addr = index
        .checked_mul(core::mem::size_of::<usize>())
        .and_then(|offset| vtable.checked_add(offset))
        .ok_or_else(|| WraithError::InvalidPeFormat {
            reason: format!("vtable index {} overflows the address space", index),
        })?;
    // SAFETY: upheld by the caller per this function's contract.
    let entry = unsafe { *(entry_addr as *const usize) };
    Ok(entry)
}
/// Heuristically counts the leading vtable entries that look like code pointers.
///
/// Entries are read one at a time starting at `vtable`. Scanning stops at the first entry that
/// is 0 or falls outside the accepted address range, or after `max_scan` entries:
///
/// - `x86_64`: entries must lie in `0x10000..=0x7FFF_FFFF_FFFF`;
/// - `x86`: entries must be at least `0x10000`;
/// - other targets: any non-zero entry is accepted.
///
/// Returns 0 when `vtable` is 0.
///
/// # Safety
///
/// `vtable` must point to readable, `usize`-aligned memory covering every entry the scan
/// reaches (at most `max_scan` entries).
pub unsafe fn estimate_vtable_size(vtable: usize, max_scan: usize) -> usize {
    /// Acceptance test applied to every scanned entry.
    fn looks_like_code_pointer(value: usize) -> bool {
        #[cfg(target_arch = "x86_64")]
        let in_range = (0x10000..=0x7FFF_FFFF_FFFF).contains(&value);
        #[cfg(target_arch = "x86")]
        let in_range = value >= 0x10000;
        #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
        let in_range = true;
        value != 0 && in_range
    }

    if vtable == 0 {
        return 0;
    }
    let stride = core::mem::size_of::<usize>();
    (0..max_scan)
        .take_while(|&slot| {
            // SAFETY: the caller guarantees every entry up to the first rejected one is readable.
            let value = unsafe { *((vtable + slot * stride) as *const usize) };
            looks_like_code_pointer(value)
        })
        .count()
}
/// Overwrites the `usize` at address `entry` with `value`.
///
/// The slot's page protection is changed via `ProtectionGuard` (to `PAGE_READWRITE`) before the
/// write, and the guard is held until this function returns.
/// NOTE(review): presumably the guard restores the previous protection when dropped — confirm
/// in `util::memory`.
///
/// # Errors
///
/// Propagates any error from `ProtectionGuard::new`; nothing is written in that case.
fn write_vtable_entry(entry: usize, value: usize) -> Result<()> {
    let slot = entry as *mut usize;
    let width = core::mem::size_of::<usize>();
    let _protection = ProtectionGuard::new(entry, width, PAGE_READWRITE)?;
    // SAFETY(review): `entry` is not validated here; callers pass slot addresses they have
    // already read from, and the page was just made writable.
    unsafe { slot.write(value) };
    Ok(())
}
/// Gives implementors a `vtable()` accessor that reads their vtable pointer.
///
/// NOTE(review): the default method reads the first machine word of `self` unchecked. It is
/// only meaningful for types whose layout starts with a vtable pointer (e.g. a `#[repr(C)]`
/// struct whose first field is the vptr) and that are at least `usize`-sized and
/// `usize`-aligned. Because the method is safe, implementors take on that obligation.
pub trait VmtObject {
    /// Returns the first `usize` stored at the address of `self`.
    fn vtable(&self) -> usize {
        let first_word = (self as *const Self).cast::<usize>();
        // SAFETY: implementors guarantee the first word of `Self` is a readable, aligned
        // `usize` (the vtable pointer) — see the trait-level note.
        unsafe { first_word.read() }
    }
}
/// Builder for [`VmtHook`] that targets either an object pointer or an explicit vtable address.
///
/// If both are set, `build` uses the explicit vtable and ignores the object.
pub struct VmtHookBuilder {
    /// Object whose first word is read as the vtable pointer when no vtable is given.
    object: Option<*const ()>,
    /// Explicit vtable address; takes precedence over `object`.
    vtable: Option<usize>,
    /// Slot to hook.
    index: Option<usize>,
    /// Value written into the slot while the hook is active.
    detour: Option<usize>,
}
impl VmtHookBuilder {
    /// Creates a builder with nothing configured.
    pub fn new() -> Self {
        Self {
            object: None,
            vtable: None,
            index: None,
            detour: None,
        }
    }

    /// Targets the vtable referenced by `object`'s first word.
    ///
    /// # Safety
    ///
    /// Unless an explicit vtable is also set, `build` reads the first word of `object`, so
    /// `object` must then satisfy the requirements of [`get_vtable`].
    pub unsafe fn object(self, object: *const ()) -> Self {
        Self {
            object: Some(object),
            ..self
        }
    }

    /// Targets the vtable at address `vtable`; takes precedence over [`VmtHookBuilder::object`].
    pub fn vtable(self, vtable: usize) -> Self {
        Self {
            vtable: Some(vtable),
            ..self
        }
    }

    /// Sets the slot to hook.
    pub fn index(self, index: usize) -> Self {
        Self {
            index: Some(index),
            ..self
        }
    }

    /// Sets the value written into the slot.
    pub fn detour(self, detour: usize) -> Self {
        Self {
            detour: Some(detour),
            ..self
        }
    }

    /// Resolves the target vtable and creates (and installs) the hook.
    ///
    /// # Errors
    ///
    /// [`WraithError::NullPointer`] if neither a vtable nor an object was set, or if the index
    /// or detour is missing; otherwise any error from [`get_vtable`] or
    /// [`VmtHook::new_at_vtable`].
    pub fn build(self) -> Result<VmtHook> {
        let vtable = match (self.vtable, self.object) {
            (Some(explicit), _) => explicit,
            // SAFETY: the caller vouched for `object` when calling the unsafe setter.
            (None, Some(object)) => unsafe { get_vtable(object)? },
            (None, None) => {
                return Err(WraithError::NullPointer {
                    context: "neither object nor vtable set",
                });
            }
        };
        let index = self.index.ok_or_else(|| WraithError::NullPointer {
            context: "index not set",
        })?;
        let detour = self.detour.ok_or_else(|| WraithError::NullPointer {
            context: "detour not set",
        })?;
        VmtHook::new_at_vtable(vtable, index, detour)
    }
}
impl Default for VmtHookBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fake code addresses used as vtable entries. They are only compared, never called, and sit
    // inside the range `estimate_vtable_size` accepts on x86/x86_64. Plain integers are used
    // because casting a function to `usize` is not permitted in a `static` initializer
    // (pointer-to-integer casts are rejected during const evaluation).
    const FN1: usize = 0x0040_1000;
    const FN2: usize = 0x0040_2000;
    const FN3: usize = 0x0040_3000;

    #[repr(C)]
    struct TestVtable {
        func1: usize,
        func2: usize,
        func3: usize,
    }

    // `vptr` is only read through raw pointers by the code under test.
    #[allow(dead_code)]
    #[repr(C)]
    struct TestObject {
        vptr: *const TestVtable,
    }

    impl VmtObject for TestObject {}

    static VTABLE: TestVtable = TestVtable {
        func1: FN1,
        func2: FN2,
        func3: FN3,
    };

    #[test]
    fn test_get_vtable() {
        let obj = TestObject { vptr: &VTABLE };
        let vptr = unsafe { get_vtable(&obj as *const _ as *const ()) }
            .expect("should get vtable");
        assert_eq!(vptr, &VTABLE as *const _ as usize);
    }

    #[test]
    fn test_get_vtable_rejects_null() {
        assert!(unsafe { get_vtable(core::ptr::null()) }.is_err());
        let obj = TestObject {
            vptr: core::ptr::null(),
        };
        assert!(unsafe { get_vtable(&obj as *const _ as *const ()) }.is_err());
    }

    #[test]
    fn test_get_vtable_entry() {
        let vtable = &VTABLE as *const _ as usize;
        let entry0 = unsafe { get_vtable_entry(vtable, 0) }.expect("should get entry");
        let entry1 = unsafe { get_vtable_entry(vtable, 1) }.expect("should get entry");
        let entry2 = unsafe { get_vtable_entry(vtable, 2) }.expect("should get entry");
        assert_eq!(entry0, VTABLE.func1);
        assert_eq!(entry1, VTABLE.func2);
        assert_eq!(entry2, VTABLE.func3);
        assert!(unsafe { get_vtable_entry(0, 0) }.is_err());
    }

    #[test]
    fn test_estimate_vtable_size() {
        static TABLE: [usize; 5] = [FN1, FN2, FN3, 0, 0];
        let base = TABLE.as_ptr() as usize;
        assert_eq!(unsafe { estimate_vtable_size(base, TABLE.len()) }, 3);
        assert_eq!(unsafe { estimate_vtable_size(base, 2) }, 2);
        assert_eq!(unsafe { estimate_vtable_size(0, 10) }, 0);
    }

    #[test]
    fn test_vmt_object_reads_vptr() {
        let obj = TestObject { vptr: &VTABLE };
        assert_eq!(obj.vtable(), &VTABLE as *const _ as usize);
    }

    #[test]
    fn test_builder_validates_before_touching_memory() {
        assert!(VmtHookBuilder::new().index(0).detour(FN1).build().is_err());
        assert!(VmtHookBuilder::new().vtable(0).index(0).detour(FN1).build().is_err());
        assert!(VmtHookBuilder::new().vtable(FN1).detour(FN1).build().is_err());
        assert!(VmtHookBuilder::new().vtable(FN1).index(0).build().is_err());
    }

    #[test]
    fn test_shadow_vmt_swaps_and_restores_vptr() {
        let mut obj = TestObject { vptr: &VTABLE };
        let obj_ptr = &mut obj as *mut TestObject as *mut ();
        let original = &VTABLE as *const TestVtable as usize;
        {
            let mut shadow: ShadowVmt =
                unsafe { ShadowVmt::new(obj_ptr, 3) }.expect("should create shadow vmt");
            assert_eq!(shadow.original_vtable(), original);
            assert_eq!(
                unsafe { get_vtable(obj_ptr) }.expect("should get vtable"),
                shadow.shadow_vtable()
            );

            shadow.hook(1, 0x0050_0000).expect("index 1 is in bounds");
            assert!(shadow.is_hooked(1));
            assert_eq!(shadow.hook_count(), 1);
            assert_eq!(shadow.original(1), Some(FN2));
            assert_eq!(
                unsafe { get_vtable_entry(shadow.shadow_vtable(), 1) }.expect("should get entry"),
                0x0050_0000
            );
            assert!(shadow.hook(3, 0x0050_0000).is_err());
        }
        // Dropping with auto-restore enabled writes the original vtable pointer back.
        assert_eq!(
            unsafe { get_vtable(obj_ptr) }.expect("should get vtable"),
            original
        );
    }
}