use std::{
ffi::{c_void, CStr},
ptr,
};
use thiserror::Error;
use windows::{
core::PCSTR,
Win32::{
Foundation::HMODULE,
System::{
Diagnostics::Debug::{IMAGE_DIRECTORY_ENTRY_IMPORT, IMAGE_NT_HEADERS64},
LibraryLoader::{GetProcAddress, LoadLibraryA},
Memory::{VirtualProtect, PAGE_PROTECTION_FLAGS, PAGE_READWRITE},
SystemServices::{
IMAGE_DOS_HEADER, IMAGE_DOS_SIGNATURE, IMAGE_IMPORT_DESCRIPTOR, IMAGE_NT_SIGNATURE,
IMAGE_ORDINAL_FLAG64,
},
},
},
};
#[derive(Debug, Error)]
pub enum Error {
#[error("Invalid DOS signature")]
InvalidDosSignature,
#[error("Invalid NT signature")]
InvalidNtSignature,
#[error("No import directory found")]
NoImportDirectory,
#[error("Failed to load dependency: {0}")]
FailedToLoadDependency(String),
#[error("Failed to resolve function: {0}")]
FailedToResolveFunction(String),
#[error("Failed to change memory protection")]
FailedToChangeProtection,
}
pub unsafe fn patch_iat(module: HMODULE) -> Result<(), Error> {
let base = module.0;
let dos_header = &*(base as *const IMAGE_DOS_HEADER);
if dos_header.e_magic != IMAGE_DOS_SIGNATURE {
return Err(Error::InvalidDosSignature);
}
let nt_headers = &*(base.offset(dos_header.e_lfanew as isize) as *const IMAGE_NT_HEADERS64);
if nt_headers.Signature != IMAGE_NT_SIGNATURE {
return Err(Error::InvalidNtSignature);
}
let import_dir =
&nt_headers.OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT.0 as usize];
if import_dir.VirtualAddress == 0 {
return Ok(());
}
let mut import_desc =
base.offset(import_dir.VirtualAddress as isize) as *const IMAGE_IMPORT_DESCRIPTOR;
while (*import_desc).Name != 0 {
let dll_name_ptr = base.offset((*import_desc).Name as isize) as *const i8;
let Ok(dll_handle) = LoadLibraryA(PCSTR(dll_name_ptr as *const u8)) else {
let dll_name = CStr::from_ptr(dll_name_ptr);
return Err(Error::FailedToLoadDependency(
dll_name.to_string_lossy().into_owned(),
));
};
let mut int_entry =
base.offset((*import_desc).Anonymous.OriginalFirstThunk as isize) as *const u64;
let mut iat_entry = base.offset((*import_desc).FirstThunk as isize) as *mut u64;
while *int_entry != 0 {
let func_addr = if (*int_entry & IMAGE_ORDINAL_FLAG64) != 0 {
let ordinal = (*int_entry & 0xFFFF) as u16;
let Some(func_addr) =
GetProcAddress(dll_handle, PCSTR(ordinal as usize as *const u8))
else {
let name = format!("ordinal {}", *int_entry & 0xFFFF);
return Err(Error::FailedToResolveFunction(name));
};
func_addr
} else {
let import_by_name = base.offset(*int_entry as isize);
let func_name_ptr = import_by_name.offset(2) as *const i8;
let Some(func_addr) = GetProcAddress(dll_handle, PCSTR(func_name_ptr as *const u8))
else {
let name = CStr::from_ptr(func_name_ptr).to_string_lossy().into_owned();
return Err(Error::FailedToResolveFunction(name));
};
func_addr
};
let mut old_protect = PAGE_PROTECTION_FLAGS::default();
if VirtualProtect(
iat_entry as *const c_void,
std::mem::size_of::<u64>(),
PAGE_READWRITE,
&mut old_protect,
)
.is_err()
{
return Err(Error::FailedToChangeProtection);
}
ptr::write_volatile(iat_entry, func_addr as usize as u64);
let mut dummy = PAGE_PROTECTION_FLAGS::default();
let _ = VirtualProtect(
iat_entry as *const c_void,
std::mem::size_of::<u64>(),
old_protect,
&mut dummy,
);
int_entry = int_entry.offset(1);
iat_entry = iat_entry.offset(1);
}
import_desc = import_desc.offset(1);
}
Ok(())
}