1use crate::alloc;
2use crate::error::{Error, Result};
3use crate::jump;
4use crate::protect;
5use crate::relocate;
6
/// Size in bytes of every trampoline allocation; the same value is passed to
/// `alloc::alloc_near` when allocating and to `alloc::free` when releasing.
const TRAMPOLINE_SIZE: usize = 0x40;
8
/// An installed inline hook: the start of `target` is overwritten with a jump
/// to a detour, and a trampoline holds the displaced instructions so the
/// original function can still be called.
#[derive(Debug)]
pub struct Hook {
    /// First byte of the hooked function (the patched location).
    target: *mut u8,
    /// Start of the trampoline: relocated original instructions followed by a
    /// jump back into `target`.
    trampoline_ptr: *mut u8,
    /// Exact bytes of `target` that were overwritten, restored on unhook.
    original_bytes: Vec<u8>,
    /// Number of bytes taken from `target` (whole instructions covering the patch).
    stolen_len: usize,
    /// Whether the patch is currently in place (cleared by `unhook`).
    installed: bool,
}

// SAFETY: the raw pointers refer to process-wide code/trampoline memory, not
// to thread-local data, so the owning `Hook` may be moved to and dropped on
// another thread.
unsafe impl Send for Hook {}
// SAFETY: the only `&self` access is reading the trampoline address; all
// mutation of the patched code goes through `&mut self` (`unhook`/`drop`).
unsafe impl Sync for Hook {}
26
27impl Hook {
28 pub unsafe fn install(target: *const u8, detour: *const u8) -> Result<Self> {
39 let target_addr = target as u64;
40 let detour_addr = detour as u64;
41
42 let (patch, patch_len) = if let Some(rel) = jump::encode_rel32(target_addr, detour_addr) {
43 (rel.to_vec(), jump::REL32_LEN)
44 } else {
45 (jump::encode_abs64(detour_addr).to_vec(), jump::ABS64_LEN)
46 };
47
48 let read_len = patch_len.max(16);
49 let original_code = std::slice::from_raw_parts(target, read_len);
50
51 let trampoline = alloc::alloc_near(target as usize, TRAMPOLINE_SIZE)?;
52 let trampoline_addr = trampoline as u64;
53
54 let relocated =
55 match relocate::relocate(original_code, target_addr, trampoline_addr, patch_len) {
56 Ok(r) => r,
57 Err(e) => {
58 alloc::free(trampoline, TRAMPOLINE_SIZE);
59 return Err(e);
60 }
61 };
62
63 let jump_back_rip = trampoline_addr + relocated.bytes.len() as u64;
64 let continue_addr = target_addr + relocated.stolen_len as u64;
65 let jump_back = match jump::encode_rel32(jump_back_rip, continue_addr) {
66 Some(jb) => jb,
67 None => {
68 alloc::free(trampoline, TRAMPOLINE_SIZE);
69 return Err(Error::RelocationFailed);
70 }
71 };
72
73 std::ptr::copy_nonoverlapping(relocated.bytes.as_ptr(), trampoline, relocated.bytes.len());
74 std::ptr::copy_nonoverlapping(
75 jump_back.as_ptr(),
76 trampoline.add(relocated.bytes.len()),
77 jump::REL32_LEN,
78 );
79
80 let original_bytes = original_code[..relocated.stolen_len].to_vec();
81
82 let old_prot = match protect::make_writable(target as usize, relocated.stolen_len) {
83 Ok(p) => p,
84 Err(e) => {
85 alloc::free(trampoline, TRAMPOLINE_SIZE);
86 return Err(e);
87 }
88 };
89
90 std::ptr::copy_nonoverlapping(patch.as_ptr(), target as *mut u8, patch_len);
91
92 if relocated.stolen_len > patch_len {
93 std::ptr::write_bytes(
94 (target as *mut u8).add(patch_len),
95 0x90,
96 relocated.stolen_len - patch_len,
97 );
98 }
99
100 let _ = protect::restore_protection(target as usize, relocated.stolen_len, old_prot);
101
102 Ok(Hook {
103 target: target as *mut u8,
104 trampoline_ptr: trampoline,
105 original_bytes,
106 stolen_len: relocated.stolen_len,
107 installed: true,
108 })
109 }
110
111 pub fn trampoline(&self) -> *const u8 {
119 self.trampoline_ptr as *const u8
120 }
121
122 pub unsafe fn unhook(&mut self) -> Result<()> {
128 if !self.installed {
129 return Err(Error::NotInstalled);
130 }
131
132 let old_prot = protect::make_writable(self.target as usize, self.stolen_len)?;
133 std::ptr::copy_nonoverlapping(self.original_bytes.as_ptr(), self.target, self.stolen_len);
134 let _ = protect::restore_protection(self.target as usize, self.stolen_len, old_prot);
135
136 alloc::free(self.trampoline_ptr, TRAMPOLINE_SIZE);
137 self.installed = false;
138
139 Ok(())
140 }
141}
142
143impl Drop for Hook {
144 fn drop(&mut self) {
145 if self.installed {
146 unsafe {
147 let _ = self.unhook();
148 }
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicPtr, Ordering};
    use std::sync::{Mutex, MutexGuard, PoisonError};

    static TRAMPOLINE_A: AtomicPtr<u8> = AtomicPtr::new(std::ptr::null_mut());
    static TRAMPOLINE_B: AtomicPtr<u8> = AtomicPtr::new(std::ptr::null_mut());

    /// Serializes every test that patches code.
    ///
    /// The test harness runs tests on parallel threads by default. Two tests
    /// hook `add_one` and share `TRAMPOLINE_A`, so without this lock one test
    /// could install over the other's patch (saving the detour jump as the
    /// "original" bytes) or call through a trampoline the other already freed.
    /// Patching also toggles page protection, which may race even between
    /// hooks on different functions that share a page.
    static PATCH_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `PATCH_LOCK`, ignoring poisoning so one failing test does not
    /// cascade into the others.
    fn lock_patching() -> MutexGuard<'static, ()> {
        PATCH_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    #[inline(never)]
    extern "C" fn add_one(x: i32) -> i32 {
        std::hint::black_box(std::hint::black_box(x) + 1)
    }

    extern "C" fn add_one_detour(x: i32) -> i32 {
        let original: extern "C" fn(i32) -> i32 =
            unsafe { std::mem::transmute(TRAMPOLINE_A.load(Ordering::SeqCst)) };
        original(x) + 100
    }

    #[inline(never)]
    extern "C" fn double(x: i32) -> i32 {
        std::hint::black_box(std::hint::black_box(x) * 2)
    }

    extern "C" fn double_detour(x: i32) -> i32 {
        let original: extern "C" fn(i32) -> i32 =
            unsafe { std::mem::transmute(TRAMPOLINE_B.load(Ordering::SeqCst)) };
        original(x) + 1000
    }

    #[test]
    fn hook_and_unhook() {
        // Declared first so it is released after `hook` is dropped.
        let _guard = lock_patching();
        assert_eq!(add_one(5), 6);

        let mut hook =
            unsafe { Hook::install(add_one as *const u8, add_one_detour as *const u8) }.unwrap();

        TRAMPOLINE_A.store(hook.trampoline() as *mut u8, Ordering::SeqCst);

        assert_eq!(add_one(5), 106);

        unsafe { hook.unhook().unwrap() };

        assert_eq!(add_one(5), 6);
    }

    #[test]
    fn hook_auto_unhook_on_drop() {
        let _guard = lock_patching();
        assert_eq!(double(7), 14);

        {
            let hook =
                unsafe { Hook::install(double as *const u8, double_detour as *const u8) }.unwrap();

            TRAMPOLINE_B.store(hook.trampoline() as *mut u8, Ordering::SeqCst);

            assert_eq!(double(7), 1014);
        }

        assert_eq!(double(7), 14);
    }

    #[test]
    fn unhook_twice_fails() {
        let _guard = lock_patching();
        let mut hook =
            unsafe { Hook::install(add_one as *const u8, add_one_detour as *const u8) }.unwrap();

        TRAMPOLINE_A.store(hook.trampoline() as *mut u8, Ordering::SeqCst);

        unsafe { hook.unhook().unwrap() };

        let result = unsafe { hook.unhook() };
        assert!(matches!(result, Err(Error::NotInstalled)));
    }
}