use crate::error::{Error, ErrorType};
use crate::wrap_winapi;
use anyhow::{Context, Result};
use std::ffi::c_void;
use std::path::PathBuf;
use std::ptr::copy_nonoverlapping;
use windows_sys::Win32::System::Diagnostics::Debug::FlushInstructionCache;
use windows_sys::Win32::System::LibraryLoader::GetModuleFileNameW;
use windows_sys::Win32::System::Memory::{
VirtualProtect, VirtualQuery, MEMORY_BASIC_INFORMATION, MEM_FREE, PAGE_EXECUTE_READWRITE,
};
use windows_sys::Win32::System::Threading::GetCurrentProcess;
/// RAII guard that changes the page protection of `[addr, addr + size)` with
/// `VirtualProtect` and puts the previous protection back when dropped.
///
/// NOTE(review): `VirtualProtect` reports the old protection of only the *first* page in
/// the range, so `Drop` re-applies that single value to the whole range. If the range spans
/// pages that originally had different protections, they all end up with the first page's.
#[must_use = "dropping the guard immediately restores the previous protection"]
pub struct MemProtect {
    /// Start address of the range whose protection was changed.
    addr: usize,
    /// Length of the range in bytes.
    size: usize,
    /// Protection reported by `VirtualProtect` before the change; restored on drop.
    prot: u32,
}

impl MemProtect {
    /// Sets the protection of `[addr, addr + size)` to `prot`, or to `PAGE_EXECUTE_READWRITE`
    /// when `prot` is `None`, and returns a guard that restores the old protection on drop.
    ///
    /// # Errors
    /// Returns an error when `VirtualProtect` fails (returns 0).
    pub fn new(addr: usize, size: usize, prot: Option<u32>) -> Result<Self> {
        let new_prot = prot.unwrap_or(PAGE_EXECUTE_READWRITE);
        let mut old_prot = 0u32;
        // SAFETY: `old_prot` is a valid, writable u32 for VirtualProtect's out-parameter;
        // failure is signalled by a 0 return, which `wrap_winapi!` turns into an error.
        unsafe {
            wrap_winapi!(
                VirtualProtect(addr as *const c_void, size, new_prot, &mut old_prot),
                x == 0
            )?;
        }
        Ok(Self {
            addr,
            size,
            prot: old_prot,
        })
    }
}

impl Drop for MemProtect {
    /// Best-effort restore of the saved protection. The return value is ignored because
    /// `drop` has no way to report an error.
    fn drop(&mut self) {
        let mut _prot = 0;
        // SAFETY: same range that `new` successfully re-protected; `_prot` is a valid
        // out-parameter.
        unsafe {
            VirtualProtect(self.addr as _, self.size, self.prot, &mut _prot);
        }
    }
}
/// A byte pattern described by a predicate instead of a literal byte string.
pub struct MemoryPattern {
    /// Number of bytes associated with the pattern. `scan` does not read it; presumably
    /// callers use it as the length of the window they pass to `scan` — TODO confirm.
    pub size: usize,
    /// Predicate that decides whether a byte slice matches.
    pub pattern: fn(&[u8]) -> bool,
}

impl MemoryPattern {
    /// Bundles a window size with its matching predicate.
    pub fn new(size: usize, pattern: fn(&[u8]) -> bool) -> Self {
        Self { pattern, size }
    }

    /// Runs the stored predicate on `val` and returns its verdict.
    pub fn scan(&self, val: &[u8]) -> bool {
        let predicate = self.pattern;
        predicate(val)
    }
}
/// Copies `source` over the bytes at `ptr` and then flushes the instruction cache for that
/// range. While the copy runs, the range is made `PAGE_EXECUTE_READWRITE` via `MemProtect`.
///
/// # Safety
/// `ptr` must point to `source.len()` bytes of mapped memory in this process that may be
/// overwritten: no live Rust references to it, no concurrent readers/executors, and no
/// overlap with `source`.
///
/// # Errors
/// Returns an error if the protection of the target range cannot be changed.
pub unsafe fn write_aob(ptr: usize, source: &[u8]) -> Result<()> {
    let byte_count = source.len();
    // Bound to a name so it lives until the end of the function: the old protection comes
    // back only after both the copy and the cache flush.
    let _protection = MemProtect::new(ptr, byte_count, None)?;
    let destination = ptr as *mut u8;
    copy_nonoverlapping(source.as_ptr(), destination, byte_count);
    FlushInstructionCache(GetCurrentProcess(), destination as *const c_void, byte_count);
    Ok(())
}
/// Detours `original_function` to `new_function`. It overwrites the first `len` bytes with an
/// absolute jump and pads the rest with `NOP` (0x90).
///
/// The jump stub depends on `len`:
/// - `len < 14`: `48 B8 <target> FF E0`. On x86_64 this is `mov rax, imm64; jmp rax` and
///   clobbers `rax`. On 32-bit x86 the same bytes decode as `dec eax; mov eax, imm32; jmp eax`.
/// - `len >= 14` on x86_64: `FF 25 00 00 00 00 <target>`. This is `jmp qword ptr [rip+0]`,
///   which reads the 8-byte target stored directly after the instruction.
/// - `len >= 14` on other targets: `FF 25 <original_function + 6> <target>`. This is an
///   indirect `jmp` through the absolute address of the 4-byte target that follows it.
///
/// When `new_function_end` is `Some`, it receives `original_function + len`, the first
/// address past the patched bytes.
///
/// # Safety
/// `original_function` must point to at least `len` bytes of mapped code in this process that
/// may be overwritten, and `new_function` must be a valid jump target. The patch is not
/// atomic, so no other thread should execute the patched bytes while this runs.
///
/// # Panics
/// Panics if `len < 12`.
///
/// # Errors
/// Returns an error if the target bytes cannot be made writable.
pub unsafe fn hook_function(
    original_function: usize,
    new_function: usize,
    new_function_end: Option<&mut usize>,
    len: usize,
) -> Result<()> {
    assert!(len >= 12, "Not enough space to inject the shellcode");
    let patch = build_hook_patch(original_function, new_function, len);
    // Stub and padding go out in a single write, so the function is never observed fully
    // NOP'd between two writes. `write_aob` handles re-protection and the i-cache flush.
    write_aob(original_function, &patch)
        .with_context(|| "Couldn't write the injection to the original function")?;
    if let Some(end) = new_function_end {
        *end = original_function + len;
    }
    Ok(())
}

/// Builds the `len`-byte patch for `hook_function`: the jump stub followed by `NOP` padding.
fn build_hook_patch(original_function: usize, new_function: usize, len: usize) -> Vec<u8> {
    let target: [u8; std::mem::size_of::<usize>()] = new_function.to_le_bytes();
    let stub: Vec<u8> = if len < 14 {
        // mov rax/eax, imm ; jmp rax/eax
        let mut v = vec![0x48, 0xb8];
        v.extend_from_slice(&target);
        v.extend_from_slice(&[0xff, 0xe0]);
        v
    } else {
        let mut v = if cfg!(target_arch = "x86_64") {
            // jmp qword ptr [rip+0]
            vec![0xff, 0x25, 0x00, 0x00, 0x00, 0x00]
        } else {
            // jmp dword ptr [abs]: the target slot sits right after this 6-byte instruction.
            let mut v = vec![0xff, 0x25];
            v.extend_from_slice(&(original_function + 6).to_le_bytes());
            v
        };
        v.extend_from_slice(&target);
        v
    };
    // The stub is at most 14 bytes on x86_64 and 10 on x86, so it fits whenever `len >= 12`.
    let mut patch = vec![0x90u8; len];
    patch[..stub.len()].copy_from_slice(&stub);
    patch
}
/// Verifies that no part of `[start_address, start_address + len)` is unallocated
/// (`MEM_FREE`). It walks the range one `VirtualQuery` region at a time.
///
/// # Errors
/// Returns an error if any of the following holds:
/// - `start_address` or `len` is 0;
/// - `start_address + len` overflows;
/// - `VirtualQuery` fails;
/// - any region overlapping the range is `MEM_FREE`.
///
/// NOTE(review): only `MEM_FREE` is rejected. Reserved-but-uncommitted (`MEM_RESERVE`) and
/// `PAGE_NOACCESS`/guard pages pass this check but still fault when read. Confirm whether
/// callers that scan the region need `MEM_COMMIT` plus a readable protection.
pub fn check_valid_region(start_address: usize, len: usize) -> Result<()> {
    if start_address == 0x0 {
        return Err(Error::new(ErrorType::Internal, "start_address can't be 0".into()).into());
    }
    if len == 0x0 {
        return Err(Error::new(ErrorType::Internal, "len can't be 0".into()).into());
    }
    let end = match start_address.checked_add(len) {
        Some(end) => end,
        None => {
            return Err(Error::new(
                ErrorType::Internal,
                "start_address + len overflows the address space".into(),
            )
            .into())
        }
    };
    let size_mem_inf = std::mem::size_of::<MEMORY_BASIC_INFORMATION>();
    let mut cursor = start_address;
    while cursor < end {
        // SAFETY: MEMORY_BASIC_INFORMATION is a plain C struct; all-zero bytes are valid.
        let mut information: MEMORY_BASIC_INFORMATION = unsafe { std::mem::zeroed() };
        // SAFETY: `information` is a valid, writable buffer of `size_mem_inf` bytes.
        unsafe {
            wrap_winapi!(
                VirtualQuery(cursor as *const c_void, &mut information, size_mem_inf),
                x == 0
            )?;
        }
        if information.State == MEM_FREE {
            return Err(Error::new(
                ErrorType::Internal,
                "The region to scan is invalid".to_string(),
            )
            .into());
        }
        // `RegionSize` is measured from `BaseAddress`, the page containing `cursor`, not from
        // `cursor` itself, so resume at the region's real end. Adding it to a running length
        // would overshoot by `cursor`'s in-page offset and could skip the range's tail.
        let next = (information.BaseAddress as usize).saturating_add(information.RegionSize);
        if next <= cursor {
            // Defensive: a region that does not advance the cursor would loop forever.
            return Err(Error::new(
                ErrorType::Internal,
                "The region to scan is invalid".to_string(),
            )
            .into());
        }
        cursor = next;
    }
    Ok(())
}
pub unsafe fn resolve_module_path(lib: *const c_void) -> Result<PathBuf> {
let mut buf: Vec<u16> = vec![0x0; 255];
wrap_winapi!(GetModuleFileNameW(lib as _, buf.as_mut_ptr(), 255), x == 0)?;
let end_ix = buf
.iter()
.position(|&x| x == 0)
.expect("Invalid utf16 name");
let name = String::from_utf16(&buf[..end_ix]).unwrap();
let mut path: PathBuf = name.into();
path.pop();
Ok(path)
}