use crate::buffer::{allocate_buffer, is_executable_address};
use crate::disasm::decode_instruction;
use crate::error::{HookError, Result};
use crate::instruction::*;
use std::ffi::c_void;
use std::ptr;
// Every stub emitted by this module (RIP-relative absolute jumps/calls) is x86-64 encoding,
// so any other target architecture is rejected at compile time.
#[cfg(not(target_arch = "x86_64"))]
compile_error!("Trampoline module only supports x86_64 architecture");

/// Upper bound on the bytes of relocated code a trampoline may hold: 64 bytes (presumably the
/// size of one trampoline slot from the buffer allocator) minus room for the trailing relay
/// `JmpAbs` stub written after the relocated code.
const TRAMPOLINE_MAX_SIZE: usize = 64 - std::mem::size_of::<JmpAbs>();

/// Bytes that must be available at the start of the target for the hook jump
/// (opcode byte + 32-bit relative displacement, i.e. a `jmp rel32`).
const MIN_HOOK_SIZE: usize = 1 + 4;

/// Bytes needed for the fallback short jump (opcode byte + 8-bit relative displacement).
const SHORT_JMP_SIZE: usize = 1 + 1;
/// Reports whether `code` consists solely of one repeated filler byte.
///
/// The filler must be `0x00`, `0x90` (`nop`) or `0xCC` (`int3`), and every byte of the slice
/// must equal it. An empty slice is never considered padding.
fn is_code_padding(code: &[u8]) -> bool {
    match code.split_first() {
        Some((&filler, rest)) if matches!(filler, 0x00 | 0x90 | 0xCC) => {
            rest.iter().all(|&byte| byte == filler)
        }
        _ => false,
    }
}
pub fn create_trampoline_function(trampoline: &mut Trampoline) -> Result<()> {
let call_abs = CallAbs {
opcode0: 0xFF,
opcode1: 0x15,
dummy0: 0x00000002, dummy1: 0xEB, dummy2: 0x08,
address: 0x0000000000000000,
};
let jmp_abs = JmpAbs {
opcode0: 0xFF,
opcode1: 0x25,
dummy: 0x00000000, address: 0x0000000000000000,
};
let jcc_abs = JccAbs {
opcode: 0x70,
dummy0: 0x0E, dummy1: 0xFF, dummy2: 0x25,
dummy3: 0x00000000,
address: 0x0000000000000000,
};
let mut old_pos = 0u8;
let mut new_pos = 0u8;
let mut jmp_dest = 0usize; let mut finished = false;
let mut inst_buf = [0u8; 16];
trampoline.patch_above = false;
trampoline.n_ip = 0;
loop {
let old_inst_addr = trampoline.target as usize + old_pos as usize;
let new_inst_addr = trampoline.trampoline as usize + new_pos as usize;
let code_slice = unsafe { std::slice::from_raw_parts(old_inst_addr as *const u8, 16) };
let inst = decode_instruction(code_slice);
if inst.error {
return Err(HookError::UnsupportedFunction);
}
let mut copy_size = inst.len as usize;
let mut copy_src = old_inst_addr as *const u8;
if old_pos >= MIN_HOOK_SIZE as u8 {
let mut final_jmp = jmp_abs;
final_jmp.address = old_inst_addr as u64;
unsafe {
ptr::copy_nonoverlapping(
&final_jmp as *const JmpAbs as *const u8,
inst_buf.as_mut_ptr(),
std::mem::size_of::<JmpAbs>(),
);
}
copy_src = inst_buf.as_ptr();
copy_size = std::mem::size_of::<JmpAbs>();
finished = true;
}
else if inst.is_rip_relative() {
unsafe {
ptr::copy_nonoverlapping(
old_inst_addr as *const u8,
inst_buf.as_mut_ptr(),
inst.len as usize,
);
}
let rel_addr_ptr =
unsafe { inst_buf.as_mut_ptr().add(inst.len as usize - 4) as *mut u32 };
let original_target = old_inst_addr + inst.len as usize + inst.displacement as usize;
let new_relative =
(original_target as i64 - (new_inst_addr + inst.len as usize) as i64) as u32;
unsafe {
*rel_addr_ptr = new_relative;
}
copy_src = inst_buf.as_ptr();
if inst.opcode == 0xFF && (inst.modrm >> 3 & 7) == 4 {
finished = true;
}
}
else if inst.opcode == 0xE8 {
let dest = old_inst_addr + inst.len as usize + inst.immediate as usize;
let mut call = call_abs;
call.address = dest as u64;
unsafe {
ptr::copy_nonoverlapping(
&call as *const CallAbs as *const u8,
inst_buf.as_mut_ptr(),
std::mem::size_of::<CallAbs>(),
);
}
copy_src = inst_buf.as_ptr();
copy_size = std::mem::size_of::<CallAbs>();
}
else if (inst.opcode & 0xFD) == 0xE9 {
let dest = old_inst_addr + inst.len as usize;
let dest = if inst.opcode == 0xEB {
dest + (inst.immediate as i8) as usize
} else {
dest + inst.immediate as usize
};
if (trampoline.target as usize) <= dest
&& dest < (trampoline.target as usize + MIN_HOOK_SIZE)
{
if jmp_dest < dest {
jmp_dest = dest;
}
} else {
let mut jmp = jmp_abs;
jmp.address = dest as u64;
unsafe {
ptr::copy_nonoverlapping(
&jmp as *const JmpAbs as *const u8,
inst_buf.as_mut_ptr(),
std::mem::size_of::<JmpAbs>(),
);
}
copy_src = inst_buf.as_ptr();
copy_size = std::mem::size_of::<JmpAbs>();
finished = old_inst_addr >= jmp_dest;
}
}
else if (inst.opcode & 0xF0) == 0x70
|| (inst.opcode & 0xFC) == 0xE0
|| (inst.opcode2 & 0xF0) == 0x80
{
let dest = old_inst_addr + inst.len as usize;
let dest = if (inst.opcode & 0xF0) == 0x70 || (inst.opcode & 0xFC) == 0xE0 {
dest + (inst.immediate as i8) as usize
} else {
dest + inst.immediate as usize
};
if (trampoline.target as usize) <= dest
&& dest < (trampoline.target as usize + MIN_HOOK_SIZE)
{
if jmp_dest < dest {
jmp_dest = dest;
}
} else if (inst.opcode & 0xFC) == 0xE0 {
return Err(HookError::UnsupportedFunction);
} else {
let condition = if inst.opcode != 0x0F {
inst.opcode
} else {
inst.opcode2
} & 0x0F;
let mut jcc = jcc_abs;
jcc.opcode = 0x71 ^ condition;
jcc.address = dest as u64;
unsafe {
ptr::copy_nonoverlapping(
&jcc as *const JccAbs as *const u8,
inst_buf.as_mut_ptr(),
std::mem::size_of::<JccAbs>(),
);
}
copy_src = inst_buf.as_ptr();
copy_size = std::mem::size_of::<JccAbs>();
}
}
else if (inst.opcode & 0xFE) == 0xC2 {
finished = old_inst_addr >= jmp_dest;
}
if old_inst_addr < jmp_dest && copy_size != inst.len as usize {
return Err(HookError::UnsupportedFunction);
}
if new_pos as usize + copy_size > TRAMPOLINE_MAX_SIZE {
return Err(HookError::UnsupportedFunction);
}
if trampoline.n_ip >= 8 {
return Err(HookError::UnsupportedFunction);
}
trampoline.old_ips[trampoline.n_ip as usize] = old_pos;
trampoline.new_ips[trampoline.n_ip as usize] = new_pos;
trampoline.n_ip += 1;
unsafe {
ptr::copy_nonoverlapping(
copy_src,
(trampoline.trampoline as *mut u8).add(new_pos as usize),
copy_size,
);
}
new_pos += copy_size as u8;
old_pos += inst.len;
if finished {
break;
}
}
if (old_pos as usize) < MIN_HOOK_SIZE {
let remaining = MIN_HOOK_SIZE - old_pos as usize;
let padding_addr = unsafe { (trampoline.target as *const u8).add(old_pos as usize) };
let padding_slice = unsafe { std::slice::from_raw_parts(padding_addr, remaining) };
if !is_code_padding(padding_slice) {
if (old_pos as usize) < SHORT_JMP_SIZE {
let short_remaining = SHORT_JMP_SIZE - old_pos as usize;
let short_padding_slice =
unsafe { std::slice::from_raw_parts(padding_addr, short_remaining) };
if !is_code_padding(short_padding_slice) {
return Err(HookError::UnsupportedFunction);
}
}
let above_addr = unsafe { (trampoline.target as *const u8).sub(MIN_HOOK_SIZE) };
if !is_executable_address(above_addr as *mut c_void) {
return Err(HookError::UnsupportedFunction);
}
let above_slice = unsafe { std::slice::from_raw_parts(above_addr, MIN_HOOK_SIZE) };
if !is_code_padding(above_slice) {
return Err(HookError::UnsupportedFunction);
}
trampoline.patch_above = true;
}
}
let mut relay_jmp = jmp_abs;
relay_jmp.address = trampoline.detour as u64;
trampoline.relay =
unsafe { (trampoline.trampoline as *mut u8).add(new_pos as usize) as *mut c_void };
unsafe {
ptr::copy_nonoverlapping(
&relay_jmp as *const JmpAbs as *const u8,
trampoline.relay as *mut u8,
std::mem::size_of::<JmpAbs>(),
);
}
Ok(())
}
/// Allocates a trampoline slot near `target` and builds a trampoline redirecting to `detour`.
///
/// The slot comes from `allocate_buffer(target)` and is filled by
/// `create_trampoline_function`. If building fails, the slot is released with
/// `crate::buffer::free_buffer` before the error is returned.
///
/// # Errors
///
/// Propagates any error from `allocate_buffer` or `create_trampoline_function`.
pub fn allocate_trampoline(target: *mut c_void, detour: *mut c_void) -> Result<Trampoline> {
    let slot = allocate_buffer(target)?;
    let mut built = Trampoline::new(target, detour, slot);
    if let Err(failure) = create_trampoline_function(&mut built) {
        // Give the slot back so a rejected target does not leak it.
        crate::buffer::free_buffer(slot);
        return Err(failure);
    }
    Ok(built)
}