use alloc::{
format, vec::Vec, vec,
string::{String, ToString},
};
use core::{
ffi::{c_void, CStr},
ptr::null_mut,
slice::from_raw_parts
};
use obfstr::obfstr as s;
use crate::{types::*, helper::PE};
use crate::hash::{crc32ba, murmur3};
use crate::winapis::{LoadLibraryA, NtCurrentPeb};
static NTDLL: spin::Once<u64> = spin::Once::new();
/// Returns the base address of `ntdll.dll` in the current process.
///
/// The first call looks the module up by hash through [`get_module_address`]
/// and stores the result in `NTDLL`; every later call returns that cached
/// value (including a cached null if the first lookup failed).
#[inline(always)]
pub fn get_ntdll_address() -> *mut c_void {
    // NOTE(review): 2788516083 is presumably the murmur3 hash of the
    // upper-cased module name ("NTDLL.DLL") — confirm.
    let cached: &u64 = NTDLL.call_once(|| {
        let base = get_module_address(2788516083u32, Some(murmur3));
        base as u64
    });
    *cached as *mut c_void
}
/// Finds a module loaded in the current process by walking the PEB loader
/// list (`InMemoryOrderModuleList`) and returns its base address.
///
/// The string form of `module` selects the lookup mode:
/// * empty - returns the PEB's `ImageBaseAddress`;
/// * parses as `u32` - treated as a hash: each entry's upper-cased DLL name
///   is hashed with `hash` (default [`crc32ba`]) and compared;
/// * otherwise - compared by name after [`canonicalize_module`] on both sides
///   (directory prefix, ASCII case and trailing `.DLL` are ignored).
///
/// Returns null when no loaded module matches.
pub fn get_module_address<T>(
    module: T,
    hash: Option<fn(&str) -> u32>
) -> HMODULE
where
    T: ToString
{
    let hash = hash.unwrap_or(crc32ba);
    let module = module.to_string();
    // Decide the lookup mode once instead of re-stringifying, re-parsing and
    // re-canonicalizing `module` for every loader entry.
    let wanted_hash = module.parse::<u32>().ok();
    let wanted_name = canonicalize_module(&module);

    unsafe {
        let peb = NtCurrentPeb();
        if module.is_empty() {
            return (*peb).ImageBaseAddress;
        }

        let ldr_data = (*peb).Ldr;
        // The list is circular and its head lives inside PEB_LDR_DATA, not in a
        // loader entry. Stop when we get back to it; interpreting the head as an
        // entry reads memory past the end of PEB_LDR_DATA.
        let head = core::ptr::addr_of!((*ldr_data).InMemoryOrderModuleList) as usize;
        let mut list_node = (*ldr_data).InMemoryOrderModuleList.Flink;

        while !list_node.is_null() && list_node as usize != head {
            // NOTE(review): `list_node` points at the entry's InMemoryOrderLinks,
            // not at the start of the entry, so the field names below are
            // presumably shifted (with the winternl.h layout, `Reserved2[0]`
            // reads DllBase and `FullDllName` reads BaseDllName) — confirm
            // against the LDR_DATA_TABLE_ENTRY definition in `types`.
            let entry = list_node as *const LDR_DATA_TABLE_ENTRY;
            let buffer = (*entry).FullDllName.Buffer;
            let length = (*entry).FullDllName.Length;

            // Entries without a name are skipped rather than ending the walk.
            if !buffer.is_null() && length != 0 {
                // Length is in bytes; the buffer holds UTF-16 code units.
                let wide = from_raw_parts(buffer, (length / 2) as usize);
                let dll_file_name = String::from_utf16_lossy(wide).to_uppercase();
                let matched = match wanted_hash {
                    Some(expected) => expected == hash(&dll_file_name),
                    None => canonicalize_module(&dll_file_name) == wanted_name,
                };
                if matched {
                    return (*entry).Reserved2[0];
                }
            }

            list_node = (*list_node).Flink;
        }

        null_mut()
    }
}
/// Resolves an export of `h_module` by parsing its PE export directory.
///
/// The string form of `function` selects the lookup:
/// * a decimal value `<= 0xFFFF` - export ordinal (indexed as
///   `ordinal - Base` into `AddressOfFunctions`);
/// * any other value parsing as `u32` - hash of the export name, computed
///   with `hash` (default [`crc32ba`]);
/// * anything else - exact export name.
///
/// Forwarded exports (by name or ordinal) are followed to their target
/// module via `get_forwarded_address`.
///
/// Returns null if `h_module` is null, has no export directory, the export
/// is not found, or the export slot is empty.
pub fn get_proc_address<T>(
    h_module: HMODULE,
    function: T,
    hash: Option<fn(&str) -> u32>
) -> *mut c_void
where
    T: ToString,
{
    if h_module.is_null() {
        return null_mut();
    }

    unsafe {
        let h_module = h_module as usize;
        let pe = PE::parse(h_module as *mut c_void);
        let Some((nt_header, export_dir)) = pe.nt_header().zip(pe.exports().directory()) else {
            return null_mut();
        };

        // Size of the export directory range; export RVAs that fall inside it
        // are forwarder strings rather than code.
        let export_size = (*nt_header).OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_EXPORT as usize].Size as usize;
        let names = from_raw_parts(
            (h_module + (*export_dir).AddressOfNames as usize) as *const u32,
            (*export_dir).NumberOfNames as usize
        );
        let functions = from_raw_parts(
            (h_module + (*export_dir).AddressOfFunctions as usize) as *const u32,
            (*export_dir).NumberOfFunctions as usize
        );
        let ordinals = from_raw_parts(
            (h_module + (*export_dir).AddressOfNameOrdinals as usize) as *const u16,
            (*export_dir).NumberOfNames as usize
        );

        let hash = hash.unwrap_or(crc32ba);
        let api_name = function.to_string();
        // Parse once up front instead of once per exported name.
        let requested = api_name.parse::<u32>().ok();

        let rva = if let Some(ordinal) = requested.filter(|&o| o <= 0xFFFF) {
            // checked_sub + get() reject ordinals outside
            // [Base, Base + NumberOfFunctions) without overflow or panic.
            match ordinal
                .checked_sub((*export_dir).Base)
                .and_then(|index| functions.get(index as usize))
            {
                Some(&rva) => rva,
                None => return null_mut(),
            }
        } else {
            let mut found = None;
            for (&name_rva, &name_ordinal) in names.iter().zip(ordinals) {
                let name = CStr::from_ptr((h_module + name_rva as usize) as *const i8)
                    .to_str()
                    .unwrap_or("");
                let is_match = match requested {
                    Some(api_hash) => hash(name) == api_hash,
                    None => name == api_name,
                };
                if is_match {
                    // get() guards against a malformed name-ordinal table.
                    found = functions.get(name_ordinal as usize).copied();
                    break;
                }
            }
            match found {
                Some(rva) => rva,
                None => return null_mut(),
            }
        };

        // An RVA of 0 marks an unused slot in AddressOfFunctions; returning
        // the module base for it would not be a valid function address.
        if rva == 0 {
            return null_mut();
        }

        let dll_name = {
            let ptr = (h_module + (*export_dir).Name as usize) as *const i8;
            CStr::from_ptr(ptr).to_string_lossy().into_owned()
        };
        let address = (h_module + rva as usize) as *mut c_void;
        // Both name and ordinal lookups may hit a forwarder.
        get_forwarded_address(&dll_name, address, export_dir, export_size, hash)
    }
}
fn get_forwarded_address(
module: &str,
address: *mut c_void,
export_dir: *const IMAGE_EXPORT_DIRECTORY,
export_size: usize,
hash: fn(&str) -> u32,
) -> *mut c_void {
if (address as usize) >= export_dir as usize &&
(address as usize) < (export_dir as usize + export_size)
{
let cstr = unsafe { CStr::from_ptr(address as *const i8) };
let forwarder = cstr.to_str().unwrap_or_default();
let (module_name, function_name) = forwarder.split_once('.')
.unwrap_or(("", ""));
let module_resolved = if module_name.starts_with(s!("api-ms")) || module_name.starts_with(s!("ext-ms")) {
let base_contract = module_name.rsplit_once('-').map(|(b, _)| b).unwrap_or(module_name);
resolve_api_set_map(module, base_contract)
} else {
Some(vec![format!("{}.dll", module_name)])
};
if let Some(modules) = module_resolved {
for module in modules {
let mut addr = get_module_address(module.as_str(), None);
if addr.is_null() {
addr = LoadLibraryA(module.as_str());
}
if !addr.is_null() {
let resolved = get_proc_address(addr, hash(function_name), Some(hash));
if !resolved.is_null() {
return resolved;
}
}
}
}
}
address
}
/// Resolves an API-set contract (`api-ms-*` / `ext-ms-*`) to its host DLL
/// name(s) using the API set map referenced by the PEB.
///
/// * `host_name` - module that owns the forwarder being resolved. When a
///   contract has several values, values equal to this name (ASCII
///   case-insensitive) are dropped so resolution does not point back at it.
/// * `contract_name` - matched as a prefix of each namespace entry's name.
///
/// Returns the host names of the first matching entry that yields any, or
/// `None` when the map is absent or no entry matches.
fn resolve_api_set_map(
    host_name: &str,
    contract_name: &str
) -> Option<Vec<String>> {
    unsafe {
        let peb = NtCurrentPeb();
        let map = (*peb).ApiSetMap;
        if map.is_null() {
            return None;
        }
        let base = map as usize;
        // Offsets in the map are relative to its base; lengths are byte
        // counts of UTF-16 text.
        let read_utf16 = |offset: usize, bytes: usize| {
            String::from_utf16_lossy(from_raw_parts((base + offset) as *const u16, bytes / 2))
        };

        let entries = from_raw_parts(
            (base + (*map).EntryOffset as usize) as *const API_SET_NAMESPACE_ENTRY,
            (*map).Count as usize
        );
        for entry in entries {
            let name = read_utf16(entry.NameOffset as usize, entry.NameLength as usize);
            if !name.starts_with(contract_name) {
                continue;
            }

            let values = from_raw_parts(
                (base + entry.ValueOffset as usize) as *const API_SET_VALUE_ENTRY,
                entry.ValueCount as usize
            );
            // Decode each host name exactly once.
            let hosts: Vec<String> = values
                .iter()
                .map(|value| read_utf16(value.ValueOffset as usize, value.ValueLength as usize))
                .collect();

            // A single host is used as-is, even if it is `host_name` itself.
            if hosts.len() == 1 {
                return Some(hosts);
            }
            let others: Vec<String> = hosts
                .into_iter()
                .filter(|host| !host.eq_ignore_ascii_case(host_name))
                .collect();
            if !others.is_empty() {
                return Some(others);
            }
        }
        None
    }
}
/// Normalizes a module name or path for comparison.
///
/// Keeps only the text after the last `\` or `/`, upper-cases it (ASCII only)
/// and strips every trailing `.DLL`, so `C:\Windows\System32\kernel32.dll`,
/// `KERNEL32.dll` and `kernel32` all become `KERNEL32`.
pub fn canonicalize_module(name: &str) -> String {
    // Separators are single-byte ASCII, so `idx + 1` is a char boundary.
    let base = match name.rfind(['\\', '/']) {
        Some(idx) => &name[idx + 1..],
        None => name,
    };
    let mut stem = base.to_ascii_uppercase();
    while stem.ends_with(".DLL") {
        stem.truncate(stem.len() - ".DLL".len());
    }
    stem
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::ptr::null_mut;

    /// Asserts that every export listed in `exports` resolves to a non-null
    /// address in `module`.
    fn assert_exports_resolve(module: HMODULE, exports: &[&str]) {
        for &export in exports {
            assert_ne!(get_proc_address(module, export, None), null_mut());
        }
    }

    #[test]
    fn test_modules() {
        // Name lookup is expected to ignore case and an optional ".dll" suffix.
        let spellings = ["kernel32.dll", "kernel32.DLL", "kernel32", "KERNEL32.dll", "KERNEL32"];
        for spelling in spellings {
            assert_ne!(get_module_address(spelling, None), null_mut());
        }
    }

    #[test]
    fn test_function() {
        let kernel32 = get_module_address("KERNEL32.dll", None);
        assert_ne!(kernel32, null_mut());
        assert_ne!(get_proc_address(kernel32, "VirtualAlloc", None), null_mut());
    }

    #[test]
    fn test_forwarded() {
        let kernel32 = get_module_address("KERNEL32.dll", None);
        assert_ne!(kernel32, null_mut());
        assert_exports_resolve(kernel32, &[
            "SetIoRingCompletionEvent",
            "SetProtectedPolicy",
            "SetProcessDefaultCpuSetMasks",
            "SetDefaultDllDirectories",
            "SetProcessDefaultCpuSets",
            "InitializeProcThreadAttributeList",
        ]);

        let advapi32 = LoadLibraryA("advapi32.dll");
        assert_ne!(advapi32, null_mut());
        assert_exports_resolve(advapi32, &[
            "SystemFunction028",
            "PerfIncrementULongCounterValue",
            "PerfSetCounterRefValue",
            "I_QueryTagInformation",
            "TraceQueryInformation",
            "TraceMessage",
        ]);
    }
}