Skip to main content

ib_hook/inline/
map.rs

1use core::mem::{self, transmute, transmute_copy};
2use std::collections::{
3    HashMap,
4    hash_map::{Values, ValuesMut},
5};
6
7use bon::bon;
8use windows::core::HRESULT;
9
10use crate::{FnPtr, inline::InlineHook};
11
12/**
13A type-erased map of inline hooks indexed by target function pointer.
14
15A collection of inline hooks that can be enabled/disabled together.
16
17If you just want to store hooks of the same function type, just use
18[`HashMap<F, InlineHook<F>>`] instead of [`InlineHookMap`].
19
20## Examples
21```no_run
22// cargo add ib-hook --features inline
23use ib_hook::inline::{InlineHook, InlineHookMap};
24
25type MyFn = extern "system" fn(u32) -> u32;
26
27extern "system" fn original1(x: u32) -> u32 { x + 1 }
28extern "system" fn original2(x: u32) -> u32 { x + 2 }
29
30extern "system" fn hooked1(x: u32) -> u32 { x + 0o721 }
31extern "system" fn hooked2(x: u32) -> u32 { x + 0o722 }
32
33// Create a collection of hooks
34let mut hooks = InlineHookMap::new();
35hooks.insert::<MyFn>(original1, hooked1);
36// Insert and enable a hook
37hooks.insert::<MyFn>(original2, hooked2).enable().unwrap();
38
39// Enable all hooks at once
40hooks.enable().on_error(|target, e| eprintln!("Target {target:?} failed: {e:?}"));
41
42// Verify hooks are enabled
43assert_eq!(original1(0x100), 721); // redirected to hooked1
44assert_eq!(original2(0x100), 722); // redirected to hooked2
45
46// Disable all hooks at once
47hooks.disable().on_error(|target, e| eprintln!("Target {target:?} failed: {e:?}"));
48
49// Verify hooks are disabled
50assert_eq!(original1(0x100), 0x101); // back to original
51assert_eq!(original2(0x100), 0x102); // back to original
52
53// Access individual hooks by target function
54if let Some(hook) = hooks.get::<MyFn>(original1) {
55    println!("Hook is enabled: {}", hook.is_enabled());
56}
57```
58*/
59#[derive(Default)]
60pub struct InlineHookMap {
61    hooks: HashMap<fn(), InlineHook<fn()>>,
62    leaked: bool,
63}
64
65#[bon]
66impl InlineHookMap {
67    /// Creates a new empty [`InlineHookMap`].
68    pub fn new() -> Self {
69        Self::default()
70    }
71
72    /// Returns an iterator of the hooks.
73    pub fn hooks<'a>(&'a self) -> Values<'a, fn(), InlineHook<fn()>> {
74        self.hooks.values()
75    }
76
77    /// Returns a mutable iterator of the hooks.
78    pub fn hooks_mut<'a>(&'a mut self) -> ValuesMut<'a, fn(), InlineHook<fn()>> {
79        self.hooks.values_mut()
80    }
81
82    /// Returns the number of hooks in the collection.
83    pub fn len(&self) -> usize {
84        self.hooks.len()
85    }
86
87    /// Returns `true` if the collection is empty.
88    pub fn is_empty(&self) -> bool {
89        self.hooks.is_empty()
90    }
91
92    /// Add a new hook to the collection.
93    ///
94    /// The hook is created but not enabled. Use [`enable()`](InlineHookMap::enable) to enable it.
95    pub fn insert<'a, F: FnPtr>(&'a mut self, target: F, detour: F) -> &'a mut InlineHook<F> {
96        let hook = unsafe { InlineHook::new(target, detour).into_type_erased() };
97        // OccupiedEntry<'a, fn(), InlineHook<fn()>>
98        let entry = self.hooks.entry(hook.target()).insert_entry(hook);
99        // Likely undefined behavior, but anyway...
100        // unsafe { transmute::<OccupiedEntry<'a, F, InlineHook<F>>>(entry) }
101
102        // Only get()/get_mut()/into_mut() and remove() are useful,
103        // but there can only be one &mut, short remove() isn't quite useful.
104        let entry: &'a mut InlineHook<fn()> = entry.into_mut();
105        unsafe { entry.cast_mut() }
106    }
107
108    /// Enable all hooks in the collection.
109    ///
110    /// Errors are reported via the `on_error` callback (if provided).
111    /// Hooks that fail will remain disabled.
112    ///
113    /// TODO: Transaction
114    #[builder]
115    pub fn enable(&mut self, mut on_error: Option<impl FnMut(fn(), HRESULT)>) {
116        for (target, hook) in self.hooks.iter_mut() {
117            let hr = hook.enable();
118            if !hr.is_ok() {
119                if let Some(on_error) = on_error.as_mut() {
120                    on_error(*target, hr);
121                }
122            }
123        }
124    }
125
126    /// Disable all hooks in the collection.
127    ///
128    /// Errors are reported via the `on_error` callback (if provided).
129    /// Hooks that fail will remain enabled.
130    ///
131    /// TODO: Transaction
132    #[builder]
133    pub fn disable(&mut self, mut on_error: Option<impl FnMut(fn(), HRESULT)>) {
134        for (target, hook) in self.hooks.iter_mut() {
135            let hr = hook.disable();
136            if !hr.is_ok() {
137                if let Some(on_error) = on_error.as_mut() {
138                    on_error(*target, hr);
139                }
140            }
141        }
142    }
143
144    /// Get a reference to a specific hook by target function.
145    pub fn get<F: FnPtr>(&self, target: F) -> Option<&InlineHook<F>> {
146        let target = unsafe { transmute_copy(&target) };
147        let hook: Option<&InlineHook<fn()>> = self.hooks.get(&target);
148        unsafe { transmute(hook) }
149    }
150
151    /// Get a mutable reference to a specific hook by target function.
152    pub fn get_mut<F: FnPtr>(&mut self, target: F) -> Option<&mut InlineHook<F>> {
153        let target = unsafe { transmute_copy(&target) };
154        let hook: Option<&mut InlineHook<fn()>> = self.hooks.get_mut(&target);
155        unsafe { transmute(hook) }
156    }
157
158    /// Remove a hook from the collection by target function.
159    pub fn remove<F: FnPtr>(&mut self, target: F) -> Option<InlineHook<F>> {
160        let target = unsafe { transmute_copy(&target) };
161        self.hooks
162            .remove(&target)
163            .map(|hook| unsafe { hook.cast_into() })
164    }
165
166    /// Leak all hooks, preventing automatic [`disable()`](Self::disable) on drop.
167    pub fn leak(&mut self) {
168        self.leaked = true;
169    }
170}
171
172impl Drop for InlineHookMap {
173    fn drop(&mut self) {
174        if self.leaked {
175            let hooks = mem::take(&mut self.hooks);
176            hooks
177                .into_values()
178                .map(|hook| mem::forget(hook))
179                .for_each(|()| ());
180        }
181    }
182}
183
184#[cfg(test)]
185mod tests {
186    use crate::inline::tests::TEST_MUTEX;
187
188    use super::*;
189
190    #[test]
191    fn doc() {
192        let _guard = TEST_MUTEX.lock().unwrap();
193
194        type MyFn = extern "system" fn(u32) -> u32;
195
196        extern "system" fn original1(x: u32) -> u32 {
197            x + 1
198        }
199        extern "system" fn original2(x: u32) -> u32 {
200            x + 2
201        }
202
203        extern "system" fn hooked1(x: u32) -> u32 {
204            x + 0o721
205        }
206        extern "system" fn hooked2(x: u32) -> u32 {
207            x + 0o722
208        }
209
210        // Create a collection of hooks
211        let mut hooks = InlineHookMap::new();
212        hooks.insert::<MyFn>(original1, hooked1);
213        // Insert and enable a hook
214        hooks.insert::<MyFn>(original2, hooked2).enable().unwrap();
215
216        // Enable all hooks at once
217        hooks
218            .enable()
219            .on_error(|target, e| eprintln!("Target {target:?} failed: {e:?}"))
220            .call();
221
222        // Verify hooks are enabled
223        assert_eq!(original1(0x100), 721); // redirected to hooked1
224        assert_eq!(original2(0x100), 722); // redirected to hooked2
225
226        // Disable all hooks at once
227        hooks
228            .disable()
229            .on_error(|target, e| eprintln!("Target {target:?} failed: {e:?}"))
230            .call();
231
232        // Verify hooks are disabled
233        assert_eq!(original1(0x100), 0x101); // back to original
234        assert_eq!(original2(0x100), 0x102); // back to original
235
236        // Access individual hooks by target function
237        if let Some(hook) = hooks.get::<MyFn>(original1) {
238            println!("Hook is enabled: {}", hook.is_enabled());
239        }
240    }
241}