ib_hook/inline/map.rs
1use core::mem::{self, transmute, transmute_copy};
2use std::collections::{
3 HashMap,
4 hash_map::{Values, ValuesMut},
5};
6
7use bon::bon;
8use windows::core::HRESULT;
9
10use crate::{FnPtr, inline::InlineHook};
11
12/**
13A type-erased map of inline hooks indexed by target function pointer.
14
15A collection of inline hooks that can be enabled/disabled together.
16
17If you just want to store hooks of the same function type, just use
18[`HashMap<F, InlineHook<F>>`] instead of [`InlineHookMap`].
19
20## Examples
21```no_run
22// cargo add ib-hook --features inline
23use ib_hook::inline::{InlineHook, InlineHookMap};
24
25type MyFn = extern "system" fn(u32) -> u32;
26
27extern "system" fn original1(x: u32) -> u32 { x + 1 }
28extern "system" fn original2(x: u32) -> u32 { x + 2 }
29
30extern "system" fn hooked1(x: u32) -> u32 { x + 0o721 }
31extern "system" fn hooked2(x: u32) -> u32 { x + 0o722 }
32
33// Create a collection of hooks
34let mut hooks = InlineHookMap::new();
35hooks.insert::<MyFn>(original1, hooked1);
36// Insert and enable a hook
37hooks.insert::<MyFn>(original2, hooked2).enable().unwrap();
38
39// Enable all hooks at once
40hooks.enable().on_error(|target, e| eprintln!("Target {target:?} failed: {e:?}"));
41
42// Verify hooks are enabled
43assert_eq!(original1(0x100), 721); // redirected to hooked1
44assert_eq!(original2(0x100), 722); // redirected to hooked2
45
46// Disable all hooks at once
47hooks.disable().on_error(|target, e| eprintln!("Target {target:?} failed: {e:?}"));
48
49// Verify hooks are disabled
50assert_eq!(original1(0x100), 0x101); // back to original
51assert_eq!(original2(0x100), 0x102); // back to original
52
53// Access individual hooks by target function
54if let Some(hook) = hooks.get::<MyFn>(original1) {
55 println!("Hook is enabled: {}", hook.is_enabled());
56}
57```
58*/
59#[derive(Default)]
60pub struct InlineHookMap {
61 hooks: HashMap<fn(), InlineHook<fn()>>,
62 leaked: bool,
63}
64
65#[bon]
66impl InlineHookMap {
67 /// Creates a new empty [`InlineHookMap`].
68 pub fn new() -> Self {
69 Self::default()
70 }
71
72 /// Returns an iterator of the hooks.
73 pub fn hooks<'a>(&'a self) -> Values<'a, fn(), InlineHook<fn()>> {
74 self.hooks.values()
75 }
76
77 /// Returns a mutable iterator of the hooks.
78 pub fn hooks_mut<'a>(&'a mut self) -> ValuesMut<'a, fn(), InlineHook<fn()>> {
79 self.hooks.values_mut()
80 }
81
82 /// Returns the number of hooks in the collection.
83 pub fn len(&self) -> usize {
84 self.hooks.len()
85 }
86
87 /// Returns `true` if the collection is empty.
88 pub fn is_empty(&self) -> bool {
89 self.hooks.is_empty()
90 }
91
92 /// Add a new hook to the collection.
93 ///
94 /// The hook is created but not enabled. Use [`enable()`](InlineHookMap::enable) to enable it.
95 pub fn insert<'a, F: FnPtr>(&'a mut self, target: F, detour: F) -> &'a mut InlineHook<F> {
96 let hook = unsafe { InlineHook::new(target, detour).into_type_erased() };
97 // OccupiedEntry<'a, fn(), InlineHook<fn()>>
98 let entry = self.hooks.entry(hook.target()).insert_entry(hook);
99 // Likely undefined behavior, but anyway...
100 // unsafe { transmute::<OccupiedEntry<'a, F, InlineHook<F>>>(entry) }
101
102 // Only get()/get_mut()/into_mut() and remove() are useful,
103 // but there can only be one &mut, short remove() isn't quite useful.
104 let entry: &'a mut InlineHook<fn()> = entry.into_mut();
105 unsafe { entry.cast_mut() }
106 }
107
108 /// Enable all hooks in the collection.
109 ///
110 /// Errors are reported via the `on_error` callback (if provided).
111 /// Hooks that fail will remain disabled.
112 ///
113 /// TODO: Transaction
114 #[builder]
115 pub fn enable(&mut self, mut on_error: Option<impl FnMut(fn(), HRESULT)>) {
116 for (target, hook) in self.hooks.iter_mut() {
117 let hr = hook.enable();
118 if !hr.is_ok() {
119 if let Some(on_error) = on_error.as_mut() {
120 on_error(*target, hr);
121 }
122 }
123 }
124 }
125
126 /// Disable all hooks in the collection.
127 ///
128 /// Errors are reported via the `on_error` callback (if provided).
129 /// Hooks that fail will remain enabled.
130 ///
131 /// TODO: Transaction
132 #[builder]
133 pub fn disable(&mut self, mut on_error: Option<impl FnMut(fn(), HRESULT)>) {
134 for (target, hook) in self.hooks.iter_mut() {
135 let hr = hook.disable();
136 if !hr.is_ok() {
137 if let Some(on_error) = on_error.as_mut() {
138 on_error(*target, hr);
139 }
140 }
141 }
142 }
143
144 /// Get a reference to a specific hook by target function.
145 pub fn get<F: FnPtr>(&self, target: F) -> Option<&InlineHook<F>> {
146 let target = unsafe { transmute_copy(&target) };
147 let hook: Option<&InlineHook<fn()>> = self.hooks.get(&target);
148 unsafe { transmute(hook) }
149 }
150
151 /// Get a mutable reference to a specific hook by target function.
152 pub fn get_mut<F: FnPtr>(&mut self, target: F) -> Option<&mut InlineHook<F>> {
153 let target = unsafe { transmute_copy(&target) };
154 let hook: Option<&mut InlineHook<fn()>> = self.hooks.get_mut(&target);
155 unsafe { transmute(hook) }
156 }
157
158 /// Remove a hook from the collection by target function.
159 pub fn remove<F: FnPtr>(&mut self, target: F) -> Option<InlineHook<F>> {
160 let target = unsafe { transmute_copy(&target) };
161 self.hooks
162 .remove(&target)
163 .map(|hook| unsafe { hook.cast_into() })
164 }
165
166 /// Leak all hooks, preventing automatic [`disable()`](Self::disable) on drop.
167 pub fn leak(&mut self) {
168 self.leaked = true;
169 }
170}
171
172impl Drop for InlineHookMap {
173 fn drop(&mut self) {
174 if self.leaked {
175 let hooks = mem::take(&mut self.hooks);
176 hooks
177 .into_values()
178 .map(|hook| mem::forget(hook))
179 .for_each(|()| ());
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use std::hint::black_box;

    use crate::inline::tests::TEST_MUTEX;

    use super::*;

    /// Mirrors the example in the [`InlineHookMap`] documentation.
    #[test]
    fn doc() {
        let _guard = TEST_MUTEX.lock().unwrap();

        type MyFn = extern "system" fn(u32) -> u32;

        extern "system" fn original1(x: u32) -> u32 {
            x + 1
        }
        extern "system" fn original2(x: u32) -> u32 {
            x + 2
        }

        extern "system" fn hooked1(x: u32) -> u32 {
            x + 0o721
        }
        extern "system" fn hooked2(x: u32) -> u32 {
            x + 0o722
        }

        /// Calls `f` through an opaque function pointer. This stops an optimized build
        /// from inlining or constant-folding the call, which would bypass the patched
        /// function body.
        fn call(f: MyFn, x: u32) -> u32 {
            black_box(f)(x)
        }

        // Create a collection of hooks
        let mut hooks = InlineHookMap::new();
        hooks.insert::<MyFn>(original1, hooked1);
        // Insert and enable a hook
        hooks.insert::<MyFn>(original2, hooked2).enable().unwrap();
        assert_eq!(hooks.len(), 2);

        // Enable all hooks at once
        hooks
            .enable()
            .on_error(|target, e| eprintln!("Target {target:?} failed: {e:?}"))
            .call();

        // Verify hooks are enabled
        assert_eq!(call(original1, 0x100), 721); // redirected to hooked1
        assert_eq!(call(original2, 0x100), 722); // redirected to hooked2

        // Disable all hooks at once
        hooks
            .disable()
            .on_error(|target, e| eprintln!("Target {target:?} failed: {e:?}"))
            .call();

        // Verify hooks are disabled
        assert_eq!(call(original1, 0x100), 0x101); // back to original
        assert_eq!(call(original2, 0x100), 0x102); // back to original

        // Access individual hooks by target function
        let hook = hooks
            .get::<MyFn>(original1)
            .expect("hook for original1 was inserted above");
        assert!(!hook.is_enabled());
    }
}