use crate::buffer::{allocate_buffer, is_executable_address};
use crate::disasm::{F_ERROR, decode_instruction};
use crate::error::{HookError, Result};
use crate::instruction::*;
use std::ffi::c_void;
use std::ptr;
#[cfg(not(target_arch = "x86_64"))]
compile_error!("Trampoline module only supports x86_64 architecture");
/// Usable space for relocated instructions inside the 64-byte trampoline
/// buffer; the tail is reserved for the absolute "relay" jump to the detour
/// (written at the end of `create_trampoline_function`).
const TRAMPOLINE_MAX_SIZE: usize = 64 - size_of::<JmpAbs>();
/// Size of a near relative jump: E9 + rel32 (5 bytes) — the patch written
/// over the target's prologue.
const JMP_REL_SIZE: usize = 5;
/// Size of a short relative jump: EB + rel8 (2 bytes) — fallback used by the
/// "patch above" strategy when the prologue is too short for a near jump.
const JMP_REL_SHORT_SIZE: usize = 2;
/// Returns `true` when `inst` consists entirely of one repeated padding byte:
/// 0x00, 0x90 (`nop`), or 0xCC (`int3`). An empty slice is not padding.
fn is_code_padding(inst: &[u8]) -> bool {
    let pad = match inst.first() {
        Some(&b) if b == 0x00 || b == 0x90 || b == 0xCC => b,
        _ => return false,
    };
    inst.iter().all(|&byte| byte == pad)
}
/// Builds the trampoline for `trampoline.target`: decodes and copies the
/// target's first instructions into the trampoline buffer until at least
/// `JMP_REL_SIZE` (5) bytes of the original have been consumed, then appends
/// an absolute "relay" jump to the detour after the relocated code.
///
/// Instructions that encode position-relative operands (RIP-relative memory
/// operands, `call rel32`, `jmp rel8/rel32`, conditional jumps) cannot be
/// copied verbatim — they would resolve to the wrong address when executed
/// from the trampoline — so they are rewritten into equivalent absolute forms
/// built from the `CallAbs`/`JmpAbs`/`JccAbs` templates.
///
/// Also records each old/new instruction boundary in
/// `trampoline.old_ips`/`new_ips` (used elsewhere to fix up instruction
/// pointers of suspended threads) and sets `trampoline.patch_above` when the
/// prologue is too short for a 5-byte patch but the padding above the
/// function can host it.
///
/// # Errors
/// Returns `HookError::UnsupportedFunction` when the prologue cannot be
/// relocated safely: undecodable bytes, a rewritten (resized) instruction
/// inside a region an internal jump lands in, a `loop`/`jcxz`-family
/// instruction jumping out of the patched region, more than 8 instructions,
/// exceeding `TRAMPOLINE_MAX_SIZE`, or no usable padding for a short patch.
pub fn create_trampoline_function(trampoline: &mut Trampoline) -> Result<()> {
    // Template: FF 15 02 00 00 00   call [rip+2]
    //           EB 08               jmp +8 (skip the 8-byte address literal)
    //           <8-byte absolute address>
    let call_abs = CallAbs {
        opcode0: 0xFF,
        opcode1: 0x15,
        dummy0: 0x00000002,
        dummy1: 0xEB,
        dummy2: 0x08,
        address: 0x0000000000000000,
    };
    // Template: FF 25 00 00 00 00   jmp [rip+0]
    //           <8-byte absolute address>
    let jmp_abs = JmpAbs {
        opcode0: 0xFF,
        opcode1: 0x25,
        dummy: 0x00000000,
        address: 0x0000000000000000,
    };
    // Template: 7? 0E               j<inverted cc> +14 (skip the jmp below)
    //           FF 25 00 00 00 00   jmp [rip+0]
    //           <8-byte absolute address>
    let jcc_abs = JccAbs {
        opcode: 0x70,
        dummy0: 0x0E,
        dummy1: 0xFF,
        dummy2: 0x25,
        dummy3: 0x00000000,
        address: 0x0000000000000000,
    };

    let mut old_pos = 0u8; // bytes consumed from the target function
    let mut new_pos = 0u8; // bytes emitted into the trampoline buffer
    // Highest destination of an internal jump that lands back inside the
    // patched region; while old_pos is below it, instructions must keep
    // their original size.
    let mut jmp_dest = 0usize;
    let mut finished = false;
    let mut inst_buf = [0u8; 16]; // scratch space for rewritten instructions

    trampoline.patch_above = false;
    trampoline.n_ip = 0;

    loop {
        let p_old_inst = trampoline.target as usize + old_pos as usize;
        let p_new_inst = trampoline.trampoline as usize + new_pos as usize;

        // SAFETY: reads up to 16 bytes (the maximum x86 instruction length)
        // from the target's executable region for decoding only.
        let code_slice = unsafe { std::slice::from_raw_parts(p_old_inst as *const u8, 16) };
        let hs = decode_instruction(code_slice);
        if (hs.flags & F_ERROR) != 0 {
            return Err(HookError::UnsupportedFunction);
        }

        let mut copy_size = hs.len as usize;
        let mut copy_src = p_old_inst as *const u8;

        if old_pos >= JMP_REL_SIZE as u8 {
            // Enough of the prologue is relocated: emit an absolute jump back
            // to the remainder of the original function and stop.
            let mut jmp = jmp_abs;
            jmp.address = p_old_inst as u64;
            unsafe {
                ptr::copy_nonoverlapping(
                    &jmp as *const JmpAbs as *const u8,
                    inst_buf.as_mut_ptr(),
                    size_of::<JmpAbs>(),
                );
            }
            copy_src = inst_buf.as_ptr();
            copy_size = size_of::<JmpAbs>();
            finished = true;
        } else if (hs.modrm & 0xC7) == 0x05 {
            // RIP-relative memory operand (mod == 00, r/m == 101): copy the
            // instruction, then patch its 32-bit displacement so it still
            // resolves to the same absolute address from the new location.
            unsafe {
                ptr::copy_nonoverlapping(p_old_inst as *const u8, inst_buf.as_mut_ptr(), copy_size);
            }
            copy_src = inst_buf.as_ptr();

            // The displacement sits at (len - immediate_length - 4);
            // (flags & 0x3C) >> 2 encodes the immediate length.
            let imm_len = ((hs.flags & 0x3C) >> 2) as usize;
            let rel_addr_offset = hs.len as usize - imm_len - 4;

            // wrapping_add: the displacement may be negative (data below the
            // code); plain `+` would overflow-panic in debug builds.
            // NOTE(review): assumes `hs.displacement as usize` sign-extends
            // negative 32-bit displacements — confirm its type in `disasm`.
            let original_target =
                (p_old_inst + hs.len as usize).wrapping_add(hs.displacement as usize);
            let new_relative =
                (original_target as i64 - (p_new_inst + hs.len as usize) as i64) as u32;
            inst_buf[rel_addr_offset..rel_addr_offset + 4]
                .copy_from_slice(&new_relative.to_le_bytes());

            // FF /4 with a RIP-relative operand is `jmp [rip+disp]`: control
            // never falls through, so relocation can end here.
            if hs.opcode == 0xFF && hs.modrm_reg == 4 {
                finished = true;
            }
        } else if hs.opcode == 0xE8 {
            // Direct `call rel32` -> absolute call template.
            // wrapping_add: the rel32 is commonly negative (call to a lower
            // address); plain `+` would overflow-panic in debug builds.
            let dest = (p_old_inst + hs.len as usize).wrapping_add(hs.immediate as usize);
            let mut call = call_abs;
            call.address = dest as u64;
            unsafe {
                ptr::copy_nonoverlapping(
                    &call as *const CallAbs as *const u8,
                    inst_buf.as_mut_ptr(),
                    size_of::<CallAbs>(),
                );
            }
            copy_src = inst_buf.as_ptr();
            copy_size = size_of::<CallAbs>();
        } else if (hs.opcode & 0xFD) == 0xE9 {
            // Direct jump: E9 (jmp rel32) or EB (jmp rel8).
            let mut dest = p_old_inst + hs.len as usize;
            if hs.opcode == 0xEB {
                // Sign-extend the 8-bit displacement.
                dest = dest.wrapping_add((hs.immediate as i8) as usize);
            } else {
                dest = dest.wrapping_add(hs.immediate as usize);
            }
            if (trampoline.target as usize) <= dest
                && dest < (trampoline.target as usize + JMP_REL_SIZE)
            {
                // Jump back into the patched region: copy verbatim and
                // remember how far relocation must continue.
                if jmp_dest < dest {
                    jmp_dest = dest;
                }
            } else {
                let mut jmp = jmp_abs;
                jmp.address = dest as u64;
                unsafe {
                    ptr::copy_nonoverlapping(
                        &jmp as *const JmpAbs as *const u8,
                        inst_buf.as_mut_ptr(),
                        size_of::<JmpAbs>(),
                    );
                }
                copy_src = inst_buf.as_ptr();
                copy_size = size_of::<JmpAbs>();
                // An unconditional jump ends relocation unless an earlier
                // internal jump still targets bytes beyond this point.
                finished = p_old_inst >= jmp_dest;
            }
        } else if (hs.opcode & 0xF0) == 0x70
            || (hs.opcode & 0xFC) == 0xE0
            || (hs.opcode2 & 0xF0) == 0x80
        {
            // Conditional flow: Jcc rel8 (70-7F), LOOPcc/JCXZ (E0-E3), or
            // Jcc rel32 (0F 80-8F).
            let mut dest = p_old_inst + hs.len as usize;
            if (hs.opcode & 0xF0) == 0x70 || (hs.opcode & 0xFC) == 0xE0 {
                // Sign-extend the 8-bit displacement.
                dest = dest.wrapping_add((hs.immediate as i8) as usize);
            } else {
                dest = dest.wrapping_add(hs.immediate as usize);
            }
            if (trampoline.target as usize) <= dest
                && dest < (trampoline.target as usize + JMP_REL_SIZE)
            {
                if jmp_dest < dest {
                    jmp_dest = dest;
                }
            } else if (hs.opcode & 0xFC) == 0xE0 {
                // LOOPcc/JCXZ have no rel32 form and depend on RCX, so they
                // cannot be rewritten as an absolute jump.
                return Err(HookError::UnsupportedFunction);
            } else {
                // Invert the condition (low bit of the cc nibble) and hop
                // over an absolute jmp to the real destination.
                let cond = if hs.opcode != 0x0F {
                    hs.opcode
                } else {
                    hs.opcode2
                } & 0x0F;
                let mut jcc = jcc_abs;
                jcc.opcode = 0x71 ^ cond;
                jcc.address = dest as u64;
                unsafe {
                    ptr::copy_nonoverlapping(
                        &jcc as *const JccAbs as *const u8,
                        inst_buf.as_mut_ptr(),
                        size_of::<JccAbs>(),
                    );
                }
                copy_src = inst_buf.as_ptr();
                copy_size = size_of::<JccAbs>();
            }
        } else if (hs.opcode & 0xFE) == 0xC2 {
            // RET (C2 imm16 / C3): the function ends here unless an internal
            // jump targets bytes beyond this instruction.
            finished = p_old_inst >= jmp_dest;
        }

        // Inside a region an internal jump lands in, instructions must keep
        // their original size, or the jump would land mid-instruction.
        if p_old_inst < jmp_dest && copy_size != hs.len as usize {
            return Err(HookError::UnsupportedFunction);
        }
        if (new_pos as usize + copy_size) > TRAMPOLINE_MAX_SIZE {
            return Err(HookError::UnsupportedFunction);
        }
        if trampoline.n_ip >= 8 {
            return Err(HookError::UnsupportedFunction);
        }

        // Record the instruction-boundary mapping old offset -> new offset.
        trampoline.old_ips[trampoline.n_ip as usize] = old_pos;
        trampoline.new_ips[trampoline.n_ip as usize] = new_pos;
        trampoline.n_ip += 1;

        // SAFETY: copy_size was bounds-checked against TRAMPOLINE_MAX_SIZE
        // above; copy_src points either at the target code or at inst_buf.
        unsafe {
            ptr::copy_nonoverlapping(
                copy_src,
                (trampoline.trampoline as *mut u8).add(new_pos as usize),
                copy_size,
            );
        }
        new_pos += copy_size as u8;
        old_pos += hs.len;
        if finished {
            break;
        }
    }

    // Prologue shorter than the 5-byte patch: check whether padding around
    // the function can absorb it.
    if (old_pos as usize) < JMP_REL_SIZE {
        let remaining = JMP_REL_SIZE - old_pos as usize;
        // SAFETY: reads the bytes immediately after the relocated prologue,
        // still within the target's executable region.
        let padding_slice = unsafe {
            std::slice::from_raw_parts(
                (trampoline.target as *const u8).add(old_pos as usize),
                remaining,
            )
        };
        if !is_code_padding(padding_slice) {
            // Not enough trailing padding for a jmp rel32. Fall back to the
            // "patch above" strategy: a 2-byte short jump at the function
            // entry to a jmp rel32 placed in the padding above the function.
            if (old_pos as usize) < JMP_REL_SHORT_SIZE {
                let short_remaining = JMP_REL_SHORT_SIZE - old_pos as usize;
                // SAFETY: same region as padding_slice, shorter length.
                let short_padding_slice = unsafe {
                    std::slice::from_raw_parts(
                        (trampoline.target as *const u8).add(old_pos as usize),
                        short_remaining,
                    )
                };
                if !is_code_padding(short_padding_slice) {
                    return Err(HookError::UnsupportedFunction);
                }
            }
            // The 5 bytes above the entry point must be executable padding.
            let above_addr =
                unsafe { (trampoline.target as *const u8).sub(JMP_REL_SIZE) as *mut c_void };
            if !is_executable_address(above_addr) {
                return Err(HookError::UnsupportedFunction);
            }
            // SAFETY: above_addr was just verified to be executable.
            let above_slice =
                unsafe { std::slice::from_raw_parts(above_addr as *const u8, JMP_REL_SIZE) };
            if !is_code_padding(above_slice) {
                return Err(HookError::UnsupportedFunction);
            }
            trampoline.patch_above = true;
        }
    }

    // Append the relay: an absolute jump to the detour, placed right after
    // the relocated code inside the trampoline buffer.
    let mut relay_jmp = jmp_abs;
    relay_jmp.address = trampoline.detour as u64;
    trampoline.relay =
        unsafe { (trampoline.trampoline as *mut u8).add(new_pos as usize) as *mut c_void };
    // SAFETY: TRAMPOLINE_MAX_SIZE reserves size_of::<JmpAbs>() bytes at the
    // end of the 64-byte buffer precisely for this relay jump.
    unsafe {
        ptr::copy_nonoverlapping(
            &relay_jmp as *const JmpAbs as *const u8,
            trampoline.relay as *mut u8,
            size_of::<JmpAbs>(),
        );
    }
    Ok(())
}
/// Allocates an executable buffer near `target` and constructs the trampoline
/// in it. On construction failure the buffer is released before the error is
/// propagated, so no allocation leaks.
pub fn allocate_trampoline(target: *mut c_void, detour: *mut c_void) -> Result<Trampoline> {
    let buffer = allocate_buffer(target)?;
    let mut trampoline = Trampoline::new(target, detour, buffer);
    if let Err(err) = create_trampoline_function(&mut trampoline) {
        // Construction failed: give the buffer back before reporting.
        crate::buffer::free_buffer(buffer);
        return Err(err);
    }
    Ok(trampoline)
}