use bitflags::bitflags;
use iced_x86::{
    BlockEncoder, BlockEncoderOptions, Decoder, DecoderOptions, Instruction, InstructionBlock,
};
use std::io::{Cursor, Seek, SeekFrom, Write};
use std::slice;

#[cfg(windows)]
use core::ffi::c_void;
#[cfg(windows)]
use windows_sys::Win32::Foundation::GetLastError;
#[cfg(windows)]
use windows_sys::Win32::System::Memory::VirtualProtect;

#[cfg(unix)]
use libc::{__errno_location, _SC_PAGESIZE, c_void, mprotect, sysconf};

use crate::err::HookError;

// The longest valid x86 instruction is 15 bytes.
const MAX_INST_LEN: usize = 15;
// A near `jmp rel32` (0xe9 + 4-byte offset) takes 5 bytes.
const JMP_INST_SIZE: usize = 5;

/// Callback for [`HookType::JmpBack`].
///
/// After it returns, execution continues with the original (relocated) code.
/// `regs` points to the captured register context and may be modified;
/// `user_data` is the value passed to [`Hooker::new`].
pub type JmpBackRoutine = unsafe extern "cdecl" fn(regs: *mut Registers, user_data: usize);

/// Callback for [`HookType::Retn`].
///
/// `ori_func_ptr` is the address of the relocated original prologue, so it can be
/// called as the original function. The returned value is written into the saved
/// `eax` before the trampoline returns to the caller.
pub type RetnRoutine =
    unsafe extern "cdecl" fn(regs: *mut Registers, ori_func_ptr: usize, user_data: usize) -> usize;

/// Callback for [`HookType::JmpToAddr`].
///
/// After it returns, execution jumps to the destination configured in the hook type.
pub type JmpToAddrRoutine =
    unsafe extern "cdecl" fn(regs: *mut Registers, ori_func_ptr: usize, user_data: usize);

/// Callback for [`HookType::JmpToRet`].
///
/// The returned value is used as the address to jump to after the registers are restored.
pub type JmpToRetRoutine =
    unsafe extern "cdecl" fn(regs: *mut Registers, ori_func_ptr: usize, user_data: usize) -> usize;

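/// The available hook flavours.
///
/// A small usage sketch (illustrative only): assuming `on_calc` matches
/// [`RetnRoutine`] and the target is a `stdcall` function taking one 4-byte
/// argument, so 4 bytes must be popped on return:
///
/// ```ignore
/// let hook_type = HookType::Retn(4, on_calc);
/// ```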
pub enum HookType {
    /// Run the callback, then continue with the original code.
    JmpBack(JmpBackRoutine),

    /// Run the callback, then return to the caller; the `usize` is the number of
    /// bytes to pop off the stack on return (`0` emits `ret`, otherwise `ret n`).
    Retn(usize, RetnRoutine),

    /// Run the callback, then jump to the given destination address.
    JmpToAddr(usize, JmpToAddrRoutine),

    /// Run the callback, then jump to the address it returns.
    JmpToRet(JmpToRetRoutine),
}

/// The register context captured by the trampoline (`pushad` + `pushfd` order).
///
/// `esp` holds the stack pointer value at the hooked location, so for a hook at a
/// function entry it points at the return address followed by the stack arguments.
#[repr(C)]
#[derive(Debug)]
pub struct Registers {
    pub eflags: u32,
    pub edi: u32,
    pub esi: u32,
    pub ebp: u32,
    pub esp: u32,
    pub ebx: u32,
    pub edx: u32,
    pub ecx: u32,
    pub eax: u32,
}

impl Registers {
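    /// Returns the `cnt`-th 32-bit value above the saved stack pointer.
    ///
    /// For a hook placed at a function entry, slot `0` is the return address and
    /// slot `1` is the first stack argument. A minimal sketch of reading the first
    /// argument inside a callback (`on_hit` is a hypothetical routine, shown only
    /// for illustration):
    ///
    /// ```ignore
    /// unsafe extern "cdecl" fn on_hit(regs: *mut Registers, _user_data: usize) {
    ///     let first_arg = unsafe { (*regs).get_arg(1) };
    ///     println!("first argument: {}", first_arg);
    /// }
    /// ```
    ///
    /// # Safety
    ///
    /// `esp + cnt * 4` must point to readable memory.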
    #[must_use]
    pub unsafe fn get_arg(&self, cnt: usize) -> u32 {
        *((self.esp as usize + cnt * 4) as *mut u32)
    }
}

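/// Callbacks that run right before and right after the target bytes are patched.
///
/// A typical implementation suspends the process's other threads in `pre`
/// (returning `false` aborts with [`HookError::PreHook`]) and resumes them in
/// `post`. A minimal sketch with no thread management; `NoopGuard` is a
/// hypothetical name used only for illustration:
///
/// ```ignore
/// struct NoopGuard;
///
/// impl ThreadCallback for NoopGuard {
///     fn pre(&self) -> bool {
///         true
///     }
///     fn post(&self) {}
/// }
///
/// let cb = CallbackOption::Some(Box::new(NoopGuard));
/// ```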
pub trait ThreadCallback {
    /// Runs before the code is patched; return `false` to abort the operation.
    fn pre(&self) -> bool;
    /// Runs after the code has been patched.
    fn post(&self);
}

/// Whether to run [`ThreadCallback`]s around patching and unpatching.
pub enum CallbackOption {
    /// Use the given callback.
    Some(Box<dyn ThreadCallback>),
    /// Patch without any callback.
    None,
}

bitflags! {
    pub struct HookFlags: u32 {
        /// Do not change the memory protection of the hooked address; the caller
        /// guarantees it is already writable.
        const NOT_MODIFY_MEMORY_PROTECT = 0x1;
    }
}

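/// Describes a hook to be installed: the target address, the hook type, an optional
/// thread callback, user data and flags. Call [`Hooker::hook`] to install it and get
/// a [`HookPoint`] back.
///
/// A usage sketch; `0x0040_1000` and `on_hit` are hypothetical and only for
/// illustration:
///
/// ```ignore
/// unsafe extern "cdecl" fn on_hit(regs: *mut Registers, user_data: usize) {
///     println!("eax = {:#x}, user_data = {}", (*regs).eax, user_data);
/// }
///
/// let hooker = Hooker::new(
///     0x0040_1000,
///     HookType::JmpBack(on_hit),
///     CallbackOption::None,
///     0,
///     HookFlags::empty(),
/// );
/// let hook_point = unsafe { hooker.hook().expect("hook failed") };
/// // The hook stays active until the HookPoint is unhooked or dropped.
/// unsafe { hook_point.unhook().expect("unhook failed") };
/// ```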
pub struct Hooker {
    addr: usize,
    hook_type: HookType,
    thread_cb: CallbackOption,
    flags: HookFlags,
    user_data: usize,
}

/// An installed hook.
///
/// Keeps the trampoline buffer alive and remembers the overwritten bytes; dropping
/// it (or calling [`HookPoint::unhook`]) restores the original code.
pub struct HookPoint {
    addr: usize,
    trampoline: Box<[u8; 100]>,
    trampoline_prot: u32,
    origin: Vec<u8>,
    thread_cb: CallbackOption,
    flags: HookFlags,
}

#[cfg(not(target_arch = "x86"))]
fn env_lock() {
    panic!("This crate should only be used in arch x86_32!")
}
#[cfg(target_arch = "x86")]
fn env_lock() {}

impl Hooker {
    /// Creates a new [`Hooker`].
    ///
    /// `addr` is the address to patch, `hook_type` selects what happens when the
    /// hook fires, `thread_cb` optionally wraps the patch in pre/post callbacks,
    /// `user_data` is forwarded to the callback routine, and `flags` tweaks how
    /// the patch is applied.
    ///
    /// # Panics
    ///
    /// Panics when the crate is compiled for any architecture other than x86_32.
    #[must_use]
    pub fn new(
        addr: usize,
        hook_type: HookType,
        thread_cb: CallbackOption,
        user_data: usize,
        flags: HookFlags,
    ) -> Self {
        env_lock();
        Self {
            addr,
            hook_type,
            thread_cb,
            user_data,
            flags,
        }
    }

    /// Installs the hook.
    ///
    /// Generates a trampoline for the configured hook type, then overwrites the
    /// first 5 bytes at `addr` with a `jmp` into it. The returned [`HookPoint`]
    /// restores the original bytes when unhooked or dropped.
    ///
    /// # Safety
    ///
    /// `addr` must point to valid, executable instructions covering at least 5
    /// bytes, and no other thread should be executing them while the patch is
    /// written (use [`CallbackOption::Some`] to suspend and resume threads).
    ///
    /// # Errors
    ///
    /// Returns a [`HookError`] when disassembling the target or changing memory
    /// protection fails.
    pub unsafe fn hook(self) -> Result<HookPoint, HookError> {
        let (moving_insts, origin) = get_moving_insts(self.addr)?;
        let trampoline =
            generate_trampoline(&self, moving_insts, origin.len() as u8, self.user_data)?;
        let trampoline_prot = modify_mem_protect(trampoline.as_ptr() as usize, trampoline.len())?;
        if !self.flags.contains(HookFlags::NOT_MODIFY_MEMORY_PROTECT) {
            let old_prot = modify_mem_protect(self.addr, JMP_INST_SIZE)?;
            let ret = modify_jmp_with_thread_cb(&self, trampoline.as_ptr() as usize);
            recover_mem_protect(self.addr, JMP_INST_SIZE, old_prot);
            ret?;
        } else {
            modify_jmp_with_thread_cb(&self, trampoline.as_ptr() as usize)?;
        }
        Ok(HookPoint {
            addr: self.addr,
            trampoline,
            trampoline_prot,
            origin,
            thread_cb: self.thread_cb,
            flags: self.flags,
        })
    }
}

impl HookPoint {
    /// Consumes the hook point and restores the original bytes.
    ///
    /// # Safety
    ///
    /// No thread may be executing the patched bytes while they are restored.
    pub unsafe fn unhook(self) -> Result<(), HookError> {
        self.unhook_by_ref()
    }

    fn unhook_by_ref(&self) -> Result<(), HookError> {
        let ret: Result<(), HookError>;
        if !self.flags.contains(HookFlags::NOT_MODIFY_MEMORY_PROTECT) {
            let old_prot = modify_mem_protect(self.addr, JMP_INST_SIZE)?;
            ret = recover_jmp_with_thread_cb(self);
            recover_mem_protect(self.addr, JMP_INST_SIZE, old_prot);
        } else {
            ret = recover_jmp_with_thread_cb(self)
        }
        recover_mem_protect(
            self.trampoline.as_ptr() as usize,
            self.trampoline.len(),
            self.trampoline_prot,
        );
        ret
    }
}

impl Drop for HookPoint {
    fn drop(&mut self) {
        // Errors while unhooking on drop are deliberately ignored.
        self.unhook_by_ref().unwrap_or_default();
    }
}

/// Decodes whole instructions at `addr` until at least `JMP_INST_SIZE` bytes are
/// covered, returning the decoded instructions and the raw bytes they occupy.
fn get_moving_insts(addr: usize) -> Result<(Vec<Instruction>, Vec<u8>), HookError> {
    let code_slice =
        unsafe { slice::from_raw_parts(addr as *const u8, MAX_INST_LEN * JMP_INST_SIZE) };
    let mut decoder = Decoder::new(32, code_slice, DecoderOptions::NONE);
    decoder.set_ip(addr as u64);

    let mut total_bytes = 0;
    let mut ori_insts: Vec<Instruction> = vec![];
    for inst in &mut decoder {
        if inst.is_invalid() {
            return Err(HookError::Disassemble);
        }
        ori_insts.push(inst);
        total_bytes += inst.len();
        if total_bytes >= JMP_INST_SIZE {
            break;
        }
    }

    Ok((ori_insts, code_slice[0..decoder.position()].into()))
}

#[cfg(windows)]
fn modify_mem_protect(addr: usize, len: usize) -> Result<u32, HookError> {
    let mut old_prot: u32 = 0;
    let old_prot_ptr = std::ptr::addr_of_mut!(old_prot);
    // 0x40 == PAGE_EXECUTE_READWRITE
    let ret = unsafe { VirtualProtect(addr as *const c_void, len, 0x40, old_prot_ptr) };
    if ret == 0 {
        Err(HookError::MemoryProtect(unsafe { GetLastError() }))
    } else {
        Ok(old_prot)
    }
}

#[cfg(unix)]
fn modify_mem_protect(addr: usize, len: usize) -> Result<u32, HookError> {
    let page_size = unsafe { sysconf(_SC_PAGESIZE) };
    if len > page_size.try_into().unwrap() {
        Err(HookError::InvalidParameter)
    } else {
        // Make the whole page containing `addr` readable, writable and executable
        // (7 == PROT_READ | PROT_WRITE | PROT_EXEC).
        let ret = unsafe {
            mprotect(
                (addr & !(page_size as usize - 1)) as *mut c_void,
                page_size as usize,
                7,
            )
        };
        if ret != 0 {
            let err = unsafe { *(__errno_location()) };
            Err(HookError::MemoryProtect(err as u32))
        } else {
            // The previous protection is not queried; report RWX as the "old" value.
            Ok(7)
        }
    }
}

#[cfg(windows)]
fn recover_mem_protect(addr: usize, len: usize, old: u32) {
    let mut old_prot: u32 = 0;
    let old_prot_ptr = std::ptr::addr_of_mut!(old_prot);
    unsafe { VirtualProtect(addr as *const c_void, len, old, old_prot_ptr) };
}

#[cfg(unix)]
fn recover_mem_protect(addr: usize, _: usize, old: u32) {
    let page_size = unsafe { sysconf(_SC_PAGESIZE) };
    unsafe {
        mprotect(
            (addr & !(page_size as usize - 1)) as *mut c_void,
            page_size as usize,
            old as i32,
        )
    };
}

/// Writes a 32-bit relative offset so that the `call`/`jmp` being encoded at the
/// current buffer position (within a trampoline based at `base_addr`) lands on
/// `dst_addr`.
fn write_relative_off<T: Write + Seek>(
    buf: &mut T,
    base_addr: u32,
    dst_addr: u32,
) -> Result<(), HookError> {
    let dst_addr = dst_addr as i32;
    let cur_pos = buf.stream_position().unwrap() as i32;
    // rel32 = destination - address of the next instruction (offset field + 4).
    let call_off = dst_addr - (base_addr as i32 + cur_pos + 4);
    buf.write(&call_off.to_le_bytes())?;
    Ok(())
}

/// Re-encodes the relocated instructions so they are valid at `dest_addr`.
fn move_code_to_addr(ori_insts: &Vec<Instruction>, dest_addr: u32) -> Result<Vec<u8>, HookError> {
    let block = InstructionBlock::new(ori_insts, u64::from(dest_addr));
    let encoded = BlockEncoder::encode(32, block, BlockEncoderOptions::NONE)
        .map_err(|_| HookError::MoveCode)?;
    Ok(encoded.code_buffer)
}

/// Patches the placeholder `push` immediate at `ori_func_addr_off` with the address
/// of the relocated original code, then restores the buffer position.
fn write_ori_func_addr<T: Write + Seek>(buf: &mut T, ori_func_addr_off: u32, ori_func_off: u32) {
    let pos = buf.stream_position().unwrap();
    buf.seek(SeekFrom::Start(u64::from(ori_func_addr_off)))
        .unwrap();
    buf.write(&ori_func_off.to_le_bytes()).unwrap();
    buf.seek(SeekFrom::Start(pos)).unwrap();
}

fn generate_jmp_back_trampoline<T: Write + Seek>(
    buf: &mut T,
    trampoline_base_addr: u32,
    moving_code: &Vec<Instruction>,
    ori_addr: u32,
    cb: JmpBackRoutine,
    ori_len: u8,
    user_data: usize,
) -> Result<(), HookError> {
    // push user_data
    buf.write(&[0x68])?;
    buf.write(&user_data.to_le_bytes())?;

    // push ebp (regs); call cb
    buf.write(&[0x55, 0xe8])?;
    write_relative_off(buf, trampoline_base_addr, cb as u32)?;

    // add esp, 8 (discard the two arguments)
    buf.write(&[0x83, 0xc4, 0x08])?;
    // popfd; popad
    buf.write(&[0x9d, 0x61])?;

    // Relocated original instructions, then jmp back to ori_addr + ori_len.
    let cur_pos = buf.stream_position().unwrap() as u32;
    buf.write(&move_code_to_addr(
        moving_code,
        trampoline_base_addr + cur_pos,
    )?)?;
    buf.write(&[0xe9])?;
    write_relative_off(buf, trampoline_base_addr, ori_addr + u32::from(ori_len))
}

fn generate_retn_trampoline<T: Write + Seek>(
    buf: &mut T,
    trampoline_base_addr: u32,
    moving_code: &Vec<Instruction>,
    ori_addr: u32,
    retn_val: u16,
    cb: RetnRoutine,
    ori_len: u8,
    user_data: usize,
) -> Result<(), HookError> {
    // push user_data
    buf.write(&[0x68])?;
    buf.write(&user_data.to_le_bytes())?;

    // push ori_func_addr (patched later); push ebp (regs); call cb
    let ori_func_addr_off = buf.stream_position().unwrap() + 1;
    buf.write(&[0x68, 0, 0, 0, 0, 0x55, 0xe8])?;
    write_relative_off(buf, trampoline_base_addr, cb as u32)?;

    // add esp, 0x0c (discard the three arguments)
    buf.write(&[0x83, 0xc4, 0x0c])?;
    // mov [esp+0x20], eax: place the callback's return value in the saved eax
    buf.write(&[0x89, 0x44, 0x24, 0x20])?;
    // popfd; popad
    buf.write(&[0x9d, 0x61])?;
    // ret / ret retn_val
    if retn_val == 0 {
        buf.write(&[0xc3])?;
    } else {
        buf.write(&[0xc2])?;
        buf.write(&retn_val.to_le_bytes())?;
    }
    // Patch the placeholder push with the address of the relocated original code.
    let ori_func_off = buf.stream_position().unwrap() as u32;
    write_ori_func_addr(
        buf,
        ori_func_addr_off as u32,
        trampoline_base_addr + ori_func_off,
    );

    // Relocated original instructions, then jmp back to ori_addr + ori_len.
    let cur_pos = buf.stream_position().unwrap() as u32;
    buf.write(&move_code_to_addr(
        moving_code,
        trampoline_base_addr + cur_pos,
    )?)?;

    buf.write(&[0xe9])?;
    write_relative_off(buf, trampoline_base_addr, ori_addr + u32::from(ori_len))
}

fn generate_jmp_addr_trampoline<T: Write + Seek>(
    buf: &mut T,
    trampoline_base_addr: u32,
    moving_code: &Vec<Instruction>,
    ori_addr: u32,
    dest_addr: u32,
    cb: JmpToAddrRoutine,
    ori_len: u8,
    user_data: usize,
) -> Result<(), HookError> {
    // push user_data
    buf.write(&[0x68])?;
    buf.write(&user_data.to_le_bytes())?;

    // push ori_func_addr (patched later); push ebp (regs); call cb
    let ori_func_addr_off = buf.stream_position().unwrap() + 1;
    buf.write(&[0x68, 0, 0, 0, 0, 0x55, 0xe8])?;
    write_relative_off(buf, trampoline_base_addr, cb as u32)?;

    // add esp, 0x0c (discard the three arguments)
    buf.write(&[0x83, 0xc4, 0x0c])?;
    // popfd; popad
    buf.write(&[0x9d, 0x61])?;
    // jmp dest_addr + ori_len
    buf.write(&[0xe9])?;
    write_relative_off(buf, trampoline_base_addr, dest_addr + u32::from(ori_len))?;

    // Patch the placeholder push with the address of the relocated original code.
    let ori_func_off = buf.stream_position().unwrap() as u32;
    write_ori_func_addr(
        buf,
        ori_func_addr_off as u32,
        trampoline_base_addr + ori_func_off,
    );

    // Relocated original instructions, then jmp back to ori_addr + ori_len.
    let cur_pos = buf.stream_position().unwrap() as u32;
    buf.write(&move_code_to_addr(
        moving_code,
        trampoline_base_addr + cur_pos,
    )?)?;

    buf.write(&[0xe9])?;
    write_relative_off(buf, trampoline_base_addr, ori_addr + u32::from(ori_len))
}

fn generate_jmp_ret_trampoline<T: Write + Seek>(
    buf: &mut T,
    trampoline_base_addr: u32,
    moving_code: &Vec<Instruction>,
    ori_addr: u32,
    cb: JmpToRetRoutine,
    ori_len: u8,
    user_data: usize,
) -> Result<(), HookError> {
    // push user_data
    buf.write(&[0x68])?;
    buf.write(&user_data.to_le_bytes())?;

    // push ori_func_addr (patched later); push ebp (regs); call cb
    let ori_func_addr_off = buf.stream_position().unwrap() + 1;
    buf.write(&[0x68, 0, 0, 0, 0, 0x55, 0xe8])?;
    write_relative_off(buf, trampoline_base_addr, cb as u32)?;

    // add esp, 0x0c (discard the three arguments)
    buf.write(&[0x83, 0xc4, 0x0c])?;
    // mov [esp-4], eax: stash the callback's return value just below the stack top
    buf.write(&[0x89, 0x44, 0x24, 0xfc])?;
    // popfd; popad
    buf.write(&[0x9d, 0x61])?;
    // jmp dword ptr [esp-0x28]: jump to the stashed return value
    buf.write(&[0xff, 0x64, 0x24, 0xd8])?;

    // Patch the placeholder push with the address of the relocated original code.
    let ori_func_off = buf.stream_position().unwrap() as u32;
    write_ori_func_addr(
        buf,
        ori_func_addr_off as u32,
        trampoline_base_addr + ori_func_off,
    );

    // Relocated original instructions, then jmp back to ori_addr + ori_len.
    let cur_pos = buf.stream_position().unwrap() as u32;
    buf.write(&move_code_to_addr(
        moving_code,
        trampoline_base_addr + cur_pos,
    )?)?;

    buf.write(&[0xe9])?;
    write_relative_off(buf, trampoline_base_addr, ori_addr + u32::from(ori_len))
}

/// Allocates the trampoline buffer and fills it according to the hook type.
fn generate_trampoline(
    hooker: &Hooker,
    moving_code: Vec<Instruction>,
    ori_len: u8,
    user_data: usize,
) -> Result<Box<[u8; 100]>, HookError> {
    let mut raw_buffer = Box::new([0u8; 100]);
    let trampoline_addr = raw_buffer.as_ptr() as u32;
    let mut buf = Cursor::new(&mut raw_buffer[..]);

    // pushad; pushfd; mov ebp, esp (ebp now points at the saved Registers)
    buf.write(&[0x60, 0x9c, 0x8b, 0xec])?;

    match hooker.hook_type {
        HookType::JmpBack(cb) => generate_jmp_back_trampoline(
            &mut buf,
            trampoline_addr,
            &moving_code,
            hooker.addr as u32,
            cb,
            ori_len,
            user_data,
        ),
        HookType::Retn(val, cb) => generate_retn_trampoline(
            &mut buf,
            trampoline_addr,
            &moving_code,
            hooker.addr as u32,
            val as u16,
            cb,
            ori_len,
            user_data,
        ),
        HookType::JmpToAddr(dest, cb) => generate_jmp_addr_trampoline(
            &mut buf,
            trampoline_addr,
            &moving_code,
            hooker.addr as u32,
            dest as u32,
            cb,
            ori_len,
            user_data,
        ),
        HookType::JmpToRet(cb) => generate_jmp_ret_trampoline(
            &mut buf,
            trampoline_addr,
            &moving_code,
            hooker.addr as u32,
            cb,
            ori_len,
            user_data,
        ),
    }?;

    Ok(raw_buffer)
}

/// Overwrites the first 5 bytes at `dest_addr` with `jmp trampoline_addr`.
fn modify_jmp(dest_addr: usize, trampoline_addr: usize) -> Result<(), HookError> {
    let buf = unsafe { slice::from_raw_parts_mut(dest_addr as *mut u8, JMP_INST_SIZE) };
    // jmp rel32
    buf[0] = 0xe9;
    let rel_off = trampoline_addr as i32 - (dest_addr as i32 + 5);
    buf[1..5].copy_from_slice(&rel_off.to_le_bytes());
    Ok(())
}

fn modify_jmp_with_thread_cb(hook: &Hooker, trampoline_addr: usize) -> Result<(), HookError> {
    if let CallbackOption::Some(cbs) = &hook.thread_cb {
        if !cbs.pre() {
            return Err(HookError::PreHook);
        }
        let ret = modify_jmp(hook.addr, trampoline_addr);
        cbs.post();
        ret
    } else {
        modify_jmp(hook.addr, trampoline_addr)
    }
}

/// Restores the original bytes that were overwritten by the `jmp`.
fn recover_jmp(dest_addr: usize, origin: &[u8]) {
    let buf = unsafe { slice::from_raw_parts_mut(dest_addr as *mut u8, origin.len()) };
    buf.copy_from_slice(origin);
}

fn recover_jmp_with_thread_cb(hook: &HookPoint) -> Result<(), HookError> {
    if let CallbackOption::Some(cbs) = &hook.thread_cb {
        if !cbs.pre() {
            return Err(HookError::PreHook);
        }
        recover_jmp(hook.addr, &hook.origin);
        cbs.post();
    } else {
        recover_jmp(hook.addr, &hook.origin);
    }
    Ok(())
}

#[cfg(target_arch = "x86")]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    #[cfg(test)]
    #[inline(never)]
    fn foo(x: u32) -> u32 {
        println!("original foo, x:{}", x);
        x * x
    }
    #[cfg(test)]
    unsafe extern "cdecl" fn on_foo(
        reg: *mut Registers,
        old_func: usize,
        user_data: usize,
    ) -> usize {
        let old_func = std::mem::transmute::<usize, fn(u32) -> u32>(old_func);
        old_func((*reg).get_arg(1)) as usize + user_data
    }

    #[test]
    fn test_hook_function_cdecl() {
        assert_eq!(foo(5), 25);
        let hooker = Hooker::new(
            foo as usize,
            HookType::Retn(0, on_foo),
            CallbackOption::None,
            100,
            HookFlags::empty(),
        );
        let info = unsafe { hooker.hook().unwrap() };
        assert_eq!(foo(5), 125);
        unsafe { info.unhook().unwrap() };
        assert_eq!(foo(5), 25);
    }

    #[cfg(test)]
    #[inline(never)]
    extern "stdcall" fn foo2(x: u32) -> u32 {
        println!("original foo2, x:{}", x);
        x * x
    }
    #[cfg(test)]
    unsafe extern "cdecl" fn on_foo2(
        reg: *mut Registers,
        old_func: usize,
        user_data: usize,
    ) -> usize {
        let old_func = std::mem::transmute::<usize, extern "stdcall" fn(u32) -> u32>(old_func);
        old_func((*reg).get_arg(1)) as usize + user_data
    }
    #[test]
    fn test_hook_function_stdcall() {
        assert_eq!(foo2(5), 25);
        let hooker = Hooker::new(
            foo2 as usize,
            HookType::Retn(4, on_foo2),
            CallbackOption::None,
            100,
            HookFlags::empty(),
        );
        let info = unsafe { hooker.hook().unwrap() };
        assert_eq!(foo2(5), 125);
        unsafe { info.unhook().unwrap() };
        assert_eq!(foo2(5), 25);
    }
}