#[cfg(all(not(feature = "std"), feature = "alloc"))]
use alloc::vec::Vec;
#[cfg(feature = "std")]
use std::vec::Vec;
use crate::error::{Result, WraithError};
use crate::util::memory::ProtectionGuard;
use super::arch::Architecture;
use super::trampoline::ExecutableMemory;
use core::marker::PhantomData;
/// Page protection requested (through `ProtectionGuard`) while the target code
/// is being overwritten. Mirrors the Win32 `PAGE_EXECUTE_READWRITE` value, 0x40.
const PAGE_EXECUTE_READWRITE: u32 = 0x40;
/// Guard for a patched code location.
///
/// Remembers the patched address, the detour address, the bytes that were at
/// the target before patching, and an optional trampoline. While auto-restore
/// is enabled (the default set by `new`), dropping the guard writes the
/// original bytes back to the target.
#[must_use = "dropping a HookGuard immediately writes the original bytes back to the target"]
pub struct HookGuard<A: Architecture> {
    /// Address of the patched code.
    target: usize,
    /// Detour address supplied when the guard was created.
    detour: usize,
    /// Bytes that were at `target` before patching; written back on restore.
    original_bytes: Vec<u8>,
    /// Optional executable trampoline memory; its base is exposed by `trampoline()`.
    trampoline: Option<ExecutableMemory>,
    /// Whether `Drop` should write `original_bytes` back to `target`.
    auto_restore: bool,
    _arch: PhantomData<A>,
}
impl<A: Architecture> HookGuard<A> {
    /// Creates a guard for the patch at `target`.
    ///
    /// `original_bytes` are the bytes to write back when the hook is restored
    /// or disabled. Auto-restore starts enabled.
    pub(crate) fn new(
        target: usize,
        detour: usize,
        original_bytes: Vec<u8>,
        trampoline: Option<ExecutableMemory>,
    ) -> Self {
        Self {
            target,
            detour,
            original_bytes,
            trampoline,
            auto_restore: true,
            _arch: PhantomData,
        }
    }

    /// Address of the patched code.
    pub fn target(&self) -> usize {
        self.target
    }

    /// Detour address supplied at construction.
    pub fn detour(&self) -> usize {
        self.detour
    }

    /// Base address of the trampoline, if one was allocated.
    pub fn trampoline(&self) -> Option<usize> {
        self.trampoline.as_ref().map(|t| t.base())
    }

    /// Bytes that were at the target before patching.
    pub fn original_bytes(&self) -> &[u8] {
        &self.original_bytes
    }

    /// Whether dropping this guard will write the original bytes back.
    pub fn will_auto_restore(&self) -> bool {
        self.auto_restore
    }

    /// Enables or disables restoration of the original bytes on drop.
    pub fn set_auto_restore(&mut self, restore: bool) {
        self.auto_restore = restore;
    }

    /// Consumes the guard and leaves the patch in place.
    ///
    /// The original bytes are NOT written back. The trampoline, if any, is
    /// handed to `ExecutableMemory::leak` (presumably so it is never freed —
    /// confirm against that type). The saved copy of the original bytes is
    /// released normally instead of being leaked.
    pub fn leak(mut self) {
        self.auto_restore = false;
        if let Some(trampoline) = self.trampoline.take() {
            trampoline.leak();
        }
        // `self` is dropped here with `auto_restore == false`, so `Drop`
        // leaves the target untouched; only `original_bytes` is freed.
    }

    /// Consumes the guard, writing the original bytes back to the target.
    ///
    /// On success the guard's resources (saved bytes, trampoline) are released
    /// through normal drop glue, exactly as in the auto-restore `Drop` path.
    ///
    /// # Errors
    /// Propagates failures from changing page protection or flushing the
    /// instruction cache. On error the guard is dropped with auto-restore
    /// still set, so `Drop` makes one more best-effort restore attempt.
    pub fn restore(mut self) -> Result<()> {
        self.restore_internal()?;
        // The bytes are already back; disarm `Drop` so it doesn't rewrite them.
        // Unlike `mem::forget`, this still frees `original_bytes` and the trampoline.
        self.auto_restore = false;
        Ok(())
    }

    /// Temporarily writes the original bytes back without consuming the guard.
    ///
    /// The patch can be re-applied with [`enable`](Self::enable).
    ///
    /// # Errors
    /// Propagates failures from changing page protection or flushing the
    /// instruction cache.
    pub fn disable(&mut self) -> Result<()> {
        self.restore_internal()
    }

    /// Writes `hook_bytes` over the target.
    ///
    /// # Errors
    /// Returns `WraithError::WriteFailed` if `hook_bytes` is not exactly as
    /// long as the saved original bytes. Otherwise propagates failures from
    /// changing page protection or flushing the instruction cache.
    pub fn enable(&mut self, hook_bytes: &[u8]) -> Result<()> {
        if hook_bytes.len() != self.original_bytes.len() {
            return Err(WraithError::WriteFailed {
                address: self.target as u64,
                size: hook_bytes.len(),
            });
        }
        self.write_target(hook_bytes)
    }

    /// Writes the saved original bytes back to the target.
    fn restore_internal(&self) -> Result<()> {
        self.write_target(&self.original_bytes)
    }

    /// Overwrites the code at `self.target` with `bytes`, then flushes the
    /// instruction cache for that range.
    fn write_target(&self, bytes: &[u8]) -> Result<()> {
        // Held until the end of this function (including the flush).
        // Presumably it reverts the protection change on drop — confirm in
        // `crate::util::memory`.
        let _guard = ProtectionGuard::new(self.target, bytes.len(), PAGE_EXECUTE_READWRITE)?;
        // SAFETY: `self.target` is assumed to address at least `bytes.len()`
        // bytes of mapped code (the range this guard patched), which
        // `ProtectionGuard` has just made writable. `bytes` is assumed not to
        // alias the target code.
        unsafe {
            core::ptr::copy_nonoverlapping(bytes.as_ptr(), self.target as *mut u8, bytes.len());
        }
        flush_icache(self.target, bytes.len())
    }
}
impl<A: Architecture> Drop for HookGuard<A> {
    /// Writes the original bytes back to the target if auto-restore is on.
    fn drop(&mut self) {
        if !self.auto_restore {
            return;
        }
        // `drop` cannot report errors, so this is best-effort; callers that
        // need the failure should call `HookGuard::restore` explicitly.
        let _ = self.restore_internal();
    }
}
// SAFETY: NOTE(review): these impls apply to every `A: Architecture`, overriding
// the auto-trait inference that `PhantomData<A>` and `ExecutableMemory` would
// otherwise perform. They are sound only if `ExecutableMemory` may be moved to
// and shared between threads, and `A` carries no thread-affine data — confirm.
// The public `&self` methods are read-only getters; writes to the target code
// happen via `&mut self`, by-value methods, or `Drop`.
unsafe impl<A: Architecture> Send for HookGuard<A> {}
unsafe impl<A: Architecture> Sync for HookGuard<A> {}
/// Whether a stateful hook guard currently considers its hook applied.
#[derive(Debug)]
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum HookState {
    /// The hook bytes were last written to the target (the initial state).
    Enabled,
    /// The original bytes were last written back to the target.
    Disabled,
}
/// A `HookGuard` paired with its hook bytes so the hook can be toggled at runtime.
#[must_use = "dropping a StatefulHookGuard drops its HookGuard, which by default writes the original bytes back"]
pub struct StatefulHookGuard<A: Architecture> {
    /// Underlying guard holding the original bytes and trampoline.
    guard: HookGuard<A>,
    /// Bytes written to the target when the hook is (re-)enabled.
    hook_bytes: Vec<u8>,
    /// Last state this wrapper applied to the target.
    state: HookState,
}
impl<A: Architecture> StatefulHookGuard<A> {
    /// Wraps `guard` together with the bytes that re-apply the hook.
    /// The tracked state starts as `Enabled`.
    pub(crate) fn new(guard: HookGuard<A>, hook_bytes: Vec<u8>) -> Self {
        let state = HookState::Enabled;
        Self {
            guard,
            hook_bytes,
            state,
        }
    }

    /// The state this wrapper last applied.
    pub fn state(&self) -> HookState {
        self.state
    }

    /// `true` while the tracked state is `Enabled`.
    pub fn is_enabled(&self) -> bool {
        matches!(self.state, HookState::Enabled)
    }

    /// Writes the original bytes back if the hook is enabled; otherwise a no-op.
    /// The tracked state only changes when the write succeeds.
    pub fn disable(&mut self) -> Result<()> {
        if !self.is_enabled() {
            return Ok(());
        }
        self.guard.disable()?;
        self.state = HookState::Disabled;
        Ok(())
    }

    /// Re-writes the hook bytes if the hook is disabled; otherwise a no-op.
    /// The tracked state only changes when the write succeeds.
    pub fn enable(&mut self) -> Result<()> {
        if self.is_enabled() {
            return Ok(());
        }
        self.guard.enable(&self.hook_bytes)?;
        self.state = HookState::Enabled;
        Ok(())
    }

    /// Flips between enabled and disabled.
    pub fn toggle(&mut self) -> Result<()> {
        if self.is_enabled() {
            self.disable()
        } else {
            self.enable()
        }
    }

    /// Address of the patched code.
    pub fn target(&self) -> usize {
        self.guard.target()
    }

    /// Detour address of the underlying guard.
    pub fn detour(&self) -> usize {
        self.guard.detour()
    }

    /// Base address of the trampoline, if any.
    pub fn trampoline(&self) -> Option<usize> {
        self.guard.trampoline()
    }

    /// Consumes the wrapper and leaks the underlying guard (see `HookGuard::leak`).
    pub fn leak(self) {
        let Self { guard, .. } = self;
        guard.leak();
    }

    /// Consumes the wrapper and restores the original bytes (see `HookGuard::restore`).
    pub fn restore(self) -> Result<()> {
        let Self { guard, .. } = self;
        guard.restore()
    }
}
/// Flushes the instruction cache for `size` bytes at `address` in the current
/// process. Called after code bytes at a hook target are overwritten.
///
/// # Errors
/// Returns `WraithError::from_last_error("FlushInstructionCache")` when the
/// Win32 call returns 0.
fn flush_icache(address: usize, size: usize) -> Result<()> {
    // SAFETY: both are plain kernel32 calls. The address range is handed to the
    // OS by value; no Rust reference is formed to it here.
    let flushed = unsafe {
        let current = GetCurrentProcess();
        FlushInstructionCache(current, address as *const _, size)
    };
    match flushed {
        0 => Err(WraithError::from_last_error("FlushInstructionCache")),
        _ => Ok(()),
    }
}
// kernel32 imports used by `flush_icache`.
#[link(name = "kernel32")]
extern "system" {
    /// Flushes the instruction cache for `dwSize` bytes at `lpBaseAddress` in
    /// the process `hProcess`. `flush_icache` treats a return of 0 as failure.
    fn FlushInstructionCache(
        hProcess: *mut core::ffi::c_void,
        lpBaseAddress: *const core::ffi::c_void,
        dwSize: usize,
    ) -> i32;
    /// Returns a handle for the current process; passed straight to
    /// `FlushInstructionCache` by `flush_icache`.
    fn GetCurrentProcess() -> *mut core::ffi::c_void;
}