Skip to main content

procmod_hook/
hook.rs

1use crate::alloc;
2use crate::error::{Error, Result};
3use crate::jump;
4use crate::protect;
5use crate::relocate;
6
/// Size in bytes of every trampoline allocation (relocated prologue followed by the jump back).
const TRAMPOLINE_SIZE: usize = 64;

/// An inline hook that diverts a target function to a detour until it is removed.
///
/// Installing the hook replaces the leading bytes of the target with a jump to the detour.
/// The instructions that were replaced are relocated into a trampoline allocated near the
/// target and followed by a jump back into the target, so the detour can still reach the
/// original behavior by calling through the trampoline.
///
/// Dropping a `Hook` removes it automatically.
pub struct Hook {
    /// Start of the hooked function; its first `stolen_len` bytes are patched while installed.
    target: *mut u8,
    /// Trampoline memory holding the relocated prologue plus the jump back.
    trampoline_ptr: *mut u8,
    /// Copy of the target bytes that the patch overwrote, written back on unhook.
    original_bytes: Vec<u8>,
    /// Number of target bytes overwritten (the jump plus any NOP padding).
    stolen_len: usize,
    /// Whether the patch is currently applied to `target`.
    installed: bool,
}

// SAFETY: `Hook` only holds raw pointers to the hooked code and to its own trampoline
// allocation. In this module the sole `&self` method (`trampoline`) merely copies a pointer
// out; everything that writes through the pointers (`unhook`, `Drop`) needs `&mut self` or
// ownership. Sharing or moving a `Hook` between threads therefore adds no data race beyond
// the contracts already stated on the `unsafe` functions `install` and `unhook`.
unsafe impl Sync for Hook {}
unsafe impl Send for Hook {}
26
27impl Hook {
28    /// Install an inline hook at `target`, redirecting calls to `detour`.
29    ///
30    /// Returns a `Hook` whose [`trampoline`](Hook::trampoline) can be used
31    /// to call the original function.
32    ///
33    /// # Safety
34    ///
35    /// - `target` must point to the start of a callable function in executable memory.
36    /// - `detour` must be a function with the same calling convention and signature.
37    /// - No thread may be executing the first 14 bytes of `target` during this call.
38    pub unsafe fn install(target: *const u8, detour: *const u8) -> Result<Self> {
39        let target_addr = target as u64;
40        let detour_addr = detour as u64;
41
42        let (patch, patch_len) = if let Some(rel) = jump::encode_rel32(target_addr, detour_addr) {
43            (rel.to_vec(), jump::REL32_LEN)
44        } else {
45            (jump::encode_abs64(detour_addr).to_vec(), jump::ABS64_LEN)
46        };
47
48        let read_len = patch_len.max(16);
49        let original_code = std::slice::from_raw_parts(target, read_len);
50
51        let trampoline = alloc::alloc_near(target as usize, TRAMPOLINE_SIZE)?;
52        let trampoline_addr = trampoline as u64;
53
54        let relocated =
55            match relocate::relocate(original_code, target_addr, trampoline_addr, patch_len) {
56                Ok(r) => r,
57                Err(e) => {
58                    alloc::free(trampoline, TRAMPOLINE_SIZE);
59                    return Err(e);
60                }
61            };
62
63        let jump_back_rip = trampoline_addr + relocated.bytes.len() as u64;
64        let continue_addr = target_addr + relocated.stolen_len as u64;
65        let jump_back = match jump::encode_rel32(jump_back_rip, continue_addr) {
66            Some(jb) => jb,
67            None => {
68                alloc::free(trampoline, TRAMPOLINE_SIZE);
69                return Err(Error::RelocationFailed);
70            }
71        };
72
73        std::ptr::copy_nonoverlapping(relocated.bytes.as_ptr(), trampoline, relocated.bytes.len());
74        std::ptr::copy_nonoverlapping(
75            jump_back.as_ptr(),
76            trampoline.add(relocated.bytes.len()),
77            jump::REL32_LEN,
78        );
79
80        let original_bytes = original_code[..relocated.stolen_len].to_vec();
81
82        let old_prot = match protect::make_writable(target as usize, relocated.stolen_len) {
83            Ok(p) => p,
84            Err(e) => {
85                alloc::free(trampoline, TRAMPOLINE_SIZE);
86                return Err(e);
87            }
88        };
89
90        std::ptr::copy_nonoverlapping(patch.as_ptr(), target as *mut u8, patch_len);
91
92        if relocated.stolen_len > patch_len {
93            std::ptr::write_bytes(
94                (target as *mut u8).add(patch_len),
95                0x90,
96                relocated.stolen_len - patch_len,
97            );
98        }
99
100        let _ = protect::restore_protection(target as usize, relocated.stolen_len, old_prot);
101
102        Ok(Hook {
103            target: target as *mut u8,
104            trampoline_ptr: trampoline,
105            original_bytes,
106            stolen_len: relocated.stolen_len,
107            installed: true,
108        })
109    }
110
111    /// Returns a pointer to the trampoline that calls the original function.
112    ///
113    /// Transmute this to the original function's type to call it:
114    ///
115    /// ```ignore
116    /// let original: extern "C" fn(i32) -> i32 = std::mem::transmute(hook.trampoline());
117    /// ```
118    pub fn trampoline(&self) -> *const u8 {
119        self.trampoline_ptr as *const u8
120    }
121
122    /// Remove the hook, restoring the original function bytes.
123    ///
124    /// # Safety
125    ///
126    /// No thread may be executing the trampoline or the patched region of the target.
127    pub unsafe fn unhook(&mut self) -> Result<()> {
128        if !self.installed {
129            return Err(Error::NotInstalled);
130        }
131
132        let old_prot = protect::make_writable(self.target as usize, self.stolen_len)?;
133        std::ptr::copy_nonoverlapping(self.original_bytes.as_ptr(), self.target, self.stolen_len);
134        let _ = protect::restore_protection(self.target as usize, self.stolen_len, old_prot);
135
136        alloc::free(self.trampoline_ptr, TRAMPOLINE_SIZE);
137        self.installed = false;
138
139        Ok(())
140    }
141}
142
143impl Drop for Hook {
144    fn drop(&mut self) {
145        if self.installed {
146            unsafe {
147                let _ = self.unhook();
148            }
149        }
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicPtr, Ordering};
    use std::sync::{Mutex, MutexGuard};

    static TRAMPOLINE_A: AtomicPtr<u8> = AtomicPtr::new(std::ptr::null_mut());
    static TRAMPOLINE_B: AtomicPtr<u8> = AtomicPtr::new(std::ptr::null_mut());

    // The test harness runs tests on parallel threads by default. `hook_and_unhook` and
    // `unhook_twice_fails` both patch `add_one` and publish through `TRAMPOLINE_A`, so
    // they must not overlap. Every test that installs a hook holds this lock throughout.
    static PATCH_LOCK: Mutex<()> = Mutex::new(());

    /// Takes `PATCH_LOCK`, recovering from poisoning so one failing test does not make
    /// the others fail with a `PoisonError`.
    fn serialize() -> MutexGuard<'static, ()> {
        PATCH_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[inline(never)]
    extern "C" fn add_one(x: i32) -> i32 {
        std::hint::black_box(std::hint::black_box(x) + 1)
    }

    /// Detour for `add_one`: calls the original through `TRAMPOLINE_A` and adds 100.
    extern "C" fn add_one_detour(x: i32) -> i32 {
        let original: extern "C" fn(i32) -> i32 =
            unsafe { std::mem::transmute(TRAMPOLINE_A.load(Ordering::SeqCst)) };
        original(x) + 100
    }

    #[inline(never)]
    extern "C" fn double(x: i32) -> i32 {
        std::hint::black_box(std::hint::black_box(x) * 2)
    }

    /// Detour for `double`: calls the original through `TRAMPOLINE_B` and adds 1000.
    extern "C" fn double_detour(x: i32) -> i32 {
        let original: extern "C" fn(i32) -> i32 =
            unsafe { std::mem::transmute(TRAMPOLINE_B.load(Ordering::SeqCst)) };
        original(x) + 1000
    }

    #[test]
    fn hook_and_unhook() {
        let _guard = serialize();

        assert_eq!(add_one(5), 6);

        let mut hook =
            unsafe { Hook::install(add_one as *const u8, add_one_detour as *const u8) }.unwrap();

        TRAMPOLINE_A.store(hook.trampoline() as *mut u8, Ordering::SeqCst);

        assert_eq!(add_one(5), 106);

        unsafe { hook.unhook().unwrap() };

        assert_eq!(add_one(5), 6);
    }

    #[test]
    fn hook_auto_unhook_on_drop() {
        let _guard = serialize();

        assert_eq!(double(7), 14);

        {
            let hook =
                unsafe { Hook::install(double as *const u8, double_detour as *const u8) }.unwrap();

            TRAMPOLINE_B.store(hook.trampoline() as *mut u8, Ordering::SeqCst);

            assert_eq!(double(7), 1014);
        }

        assert_eq!(double(7), 14);
    }

    #[test]
    fn unhook_twice_fails() {
        let _guard = serialize();

        let mut hook =
            unsafe { Hook::install(add_one as *const u8, add_one_detour as *const u8) }.unwrap();

        TRAMPOLINE_A.store(hook.trampoline() as *mut u8, Ordering::SeqCst);

        unsafe { hook.unhook().unwrap() };

        let result = unsafe { hook.unhook() };
        assert!(result.is_err());
    }
}