use core::mem::{self, transmute, transmute_copy};
use std::collections::{
HashMap,
hash_map::{Values, ValuesMut},
};
use bon::bon;
use windows::core::HRESULT;
use crate::{FnPtr, inline::InlineHook};
/// A collection of inline hooks keyed by their type-erased target function.
///
/// Hooks are stored as `InlineHook<fn()>` and cast back to their concrete signature by the
/// typed accessors (`insert`, `get`, `get_mut`, `remove`). Unless `leak` has been called, the
/// hooks are dropped together with the map.
#[derive(Default)]
pub struct InlineHookMap {
    /// Set by `leak`; when `true`, `Drop` forgets the hooks instead of dropping them.
    leaked: bool,
    /// Type-erased hooks, keyed by their target function pointer.
    hooks: HashMap<fn(), InlineHook<fn()>>,
}
#[bon]
impl InlineHookMap {
pub fn new() -> Self {
Self::default()
}
pub fn hooks<'a>(&'a self) -> Values<'a, fn(), InlineHook<fn()>> {
self.hooks.values()
}
pub fn hooks_mut<'a>(&'a mut self) -> ValuesMut<'a, fn(), InlineHook<fn()>> {
self.hooks.values_mut()
}
pub fn len(&self) -> usize {
self.hooks.len()
}
pub fn is_empty(&self) -> bool {
self.hooks.is_empty()
}
pub fn insert<'a, F: FnPtr>(&'a mut self, target: F, detour: F) -> &'a mut InlineHook<F> {
let hook = unsafe { InlineHook::new(target, detour).into_type_erased() };
let entry = self.hooks.entry(hook.target()).insert_entry(hook);
let entry: &'a mut InlineHook<fn()> = entry.into_mut();
unsafe { entry.cast_mut() }
}
#[builder]
pub fn enable(&mut self, mut on_error: Option<impl FnMut(fn(), HRESULT)>) {
for (target, hook) in self.hooks.iter_mut() {
let hr = hook.enable();
if !hr.is_ok() {
if let Some(on_error) = on_error.as_mut() {
on_error(*target, hr);
}
}
}
}
#[builder]
pub fn disable(&mut self, mut on_error: Option<impl FnMut(fn(), HRESULT)>) {
for (target, hook) in self.hooks.iter_mut() {
let hr = hook.disable();
if !hr.is_ok() {
if let Some(on_error) = on_error.as_mut() {
on_error(*target, hr);
}
}
}
}
pub fn get<F: FnPtr>(&self, target: F) -> Option<&InlineHook<F>> {
let target = unsafe { transmute_copy(&target) };
let hook: Option<&InlineHook<fn()>> = self.hooks.get(&target);
unsafe { transmute(hook) }
}
pub fn get_mut<F: FnPtr>(&mut self, target: F) -> Option<&mut InlineHook<F>> {
let target = unsafe { transmute_copy(&target) };
let hook: Option<&mut InlineHook<fn()>> = self.hooks.get_mut(&target);
unsafe { transmute(hook) }
}
pub fn remove<F: FnPtr>(&mut self, target: F) -> Option<InlineHook<F>> {
let target = unsafe { transmute_copy(&target) };
self.hooks
.remove(&target)
.map(|hook| unsafe { hook.cast_into() })
}
pub fn leak(&mut self) {
self.leaked = true;
}
}
impl Drop for InlineHookMap {
    // If `leak` was called, every hook is moved out and passed to `mem::forget`, so the hooks'
    // own `Drop` never runs. Otherwise nothing happens here, and the hooks are dropped normally
    // along with the `hooks` field once this returns.
    fn drop(&mut self) {
        if self.leaked {
            // `drain` moves the hooks out but leaves the (now empty) map in place. Only the
            // hooks are forgotten; the map's own allocation is still freed normally.
            for (_target, hook) in self.hooks.drain() {
                mem::forget(hook);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::hint::black_box;

    use crate::inline::tests::TEST_MUTEX;

    use super::*;

    /// End-to-end walkthrough of the map API: insert two hooks, enable and disable them as a
    /// group, then look one up by its target.
    #[test]
    fn doc() {
        // Serialize with the other inline-hook tests that share `TEST_MUTEX`.
        let _guard = TEST_MUTEX.lock().unwrap();

        type MyFn = extern "system" fn(u32) -> u32;

        // Two measures keep the optimizer from inlining or constant-folding calls to the
        // targets: `#[inline(never)]` on the targets, and calling them through `black_box`ed
        // function pointers (see `call`). Without them, an optimized build could compute
        // `x + 1` directly, never run the patched code, and fail the assertions below.
        #[inline(never)]
        extern "system" fn original1(x: u32) -> u32 {
            x + 1
        }
        #[inline(never)]
        extern "system" fn original2(x: u32) -> u32 {
            x + 2
        }
        // 0x100 + 0o721 == 256 + 465 == 721 (decimal), so a hooked call reads as "721".
        extern "system" fn hooked1(x: u32) -> u32 {
            x + 0o721
        }
        extern "system" fn hooked2(x: u32) -> u32 {
            x + 0o722
        }
        // Indirect, optimization-opaque call through a function pointer.
        fn call(f: MyFn, x: u32) -> u32 {
            black_box(f)(black_box(x))
        }

        let mut hooks = InlineHookMap::new();
        hooks.insert::<MyFn>(original1, hooked1);
        hooks.insert::<MyFn>(original2, hooked2).enable().unwrap();

        // NOTE(review): `original2` was already enabled above. Depending on
        // `InlineHook::enable`, re-enabling it may be reported through `on_error`.
        hooks
            .enable()
            .on_error(|target, e| eprintln!("Target {target:?} failed: {e:?}"))
            .call();
        assert_eq!(call(original1, 0x100), 721);
        assert_eq!(call(original2, 0x100), 722);

        hooks
            .disable()
            .on_error(|target, e| eprintln!("Target {target:?} failed: {e:?}"))
            .call();
        assert_eq!(call(original1, 0x100), 0x101);
        assert_eq!(call(original2, 0x100), 0x102);

        if let Some(hook) = hooks.get::<MyFn>(original1) {
            println!("Hook is enabled: {}", hook.is_enabled());
        }
    }
}