use std::collections::HashMap;
use crate::{
file::pe::Import,
utils::{to_u32, write_le_at, write_string_at},
Result,
};
/// In-memory model of a PE image's native import table.
///
/// Holds one [`ImportDescriptor`] per imported DLL plus per-slot IAT
/// bookkeeping, and can serialize the import directory
/// (`build_import_table`) and the initial IAT contents (`build_iat_bytes`).
#[derive(Debug, Clone)]
pub struct NativeImports {
/// Import descriptors keyed by the exact (case-sensitive) DLL name.
descriptors: HashMap<String, ImportDescriptor>,
/// IAT slot metadata keyed by slot RVA; every entry's `rva` equals its key.
iat_entries: HashMap<u32, ImportAddressEntry>,
/// Next RVA handed out by `allocate_iat_rva` for newly added functions.
next_iat_rva: u32,
/// `true` for PE32+ (8-byte thunks), `false` for PE32 (4-byte thunks).
is_pe32_plus: bool,
}
/// One imported DLL and the functions imported from it.
///
/// Field names follow the PE import directory entry. Within this module the
/// thunk, timestamp and forwarder fields are only ever initialized to 0;
/// `build_import_table` computes the thunk RVAs itself and writes 0 for the
/// timestamp and forwarder chain rather than reading these fields.
#[derive(Debug, Clone)]
pub struct ImportDescriptor {
/// DLL name, emitted verbatim into the import table's name strings.
pub dll_name: String,
/// Import lookup table RVA; initialized to 0 in this module.
pub original_first_thunk: u32,
/// IAT RVA for this DLL; initialized to 0 in this module.
pub first_thunk: u32,
/// Imported functions in insertion order; serialization preserves this order.
pub functions: Vec<Import>,
/// Time/date stamp; initialized to 0 in this module.
pub timestamp: u32,
/// Forwarder chain; initialized to 0 in this module.
pub forwarder_chain: u32,
}
/// Bookkeeping for a single IAT slot tracked by `NativeImports`.
#[derive(Debug, Clone)]
pub struct ImportAddressEntry {
/// RVA of the slot; always equal to the key it is stored under.
pub rva: u32,
/// DLL that provides the imported function.
pub dll_name: String,
/// Function name, `#<ordinal>` for ordinal imports, or `"unknown"` when the
/// source import had neither a non-empty name nor an ordinal.
pub function_identifier: String,
/// Slot value recorded at insertion: 0 for by-name imports and imports
/// loaded through `from_pe_imports`, the ordinal lookup value for entries
/// created by `add_function_by_ordinal`.
pub original_value: u64,
}
impl NativeImports {
/// Creates an empty import model that targets PE32 (4-byte thunks).
///
/// Newly added functions receive IAT slot RVAs starting at `0x1000`.
#[must_use]
pub fn new() -> Self {
    // First RVA returned by `allocate_iat_rva` on a fresh instance.
    const INITIAL_IAT_RVA: u32 = 0x1000;

    Self {
        is_pe32_plus: false,
        next_iat_rva: INITIAL_IAT_RVA,
        iat_entries: HashMap::new(),
        descriptors: HashMap::new(),
    }
}
/// Chooses the thunk width used when allocating new IAT slots.
///
/// `true` selects PE32+ (8-byte slots) and `false` selects PE32 (4-byte
/// slots); this drives `iat_entry_size`. The serialization methods take the
/// format as an explicit argument instead of reading this flag.
pub fn set_pe_format(&mut self, is_pe32_plus: bool) {
    let flag = &mut self.is_pe32_plus;
    *flag = is_pe32_plus;
}
#[must_use]
pub fn is_pe32_plus(&self) -> bool {
self.is_pe32_plus
}
/// Size in bytes of one IAT slot for the configured PE format.
///
/// Returns 8 for PE32+ and 4 for PE32.
#[must_use]
pub fn iat_entry_size(&self) -> u32 {
    const PE32_SLOT_BYTES: u32 = 4;
    const PE32_PLUS_SLOT_BYTES: u32 = 8;

    if self.is_pe32_plus {
        PE32_PLUS_SLOT_BYTES
    } else {
        PE32_SLOT_BYTES
    }
}
/// Builds a `NativeImports` from imports parsed out of an existing PE image.
///
/// Imports are grouped by their `dll` field. Each group becomes one
/// [`ImportDescriptor`] whose `functions` keep the order they had in
/// `pe_imports`. Every import also gets an IAT entry keyed by its `rva`.
///
/// The IAT allocator is moved past the highest loaded import RVA. Without
/// this, functions added later via `add_function` /
/// `add_function_by_ordinal` could be handed an RVA that is already in use.
/// That would silently overwrite the loaded import's IAT entry.
///
/// # Errors
///
/// Currently never fails. The `Result` is kept for API compatibility.
pub fn from_pe_imports(pe_imports: &[Import], is_pe32_plus: bool) -> Result<Self> {
    let mut native = Self::new();
    native.is_pe32_plus = is_pe32_plus;

    let mut imports_by_dll: HashMap<&str, Vec<&Import>> = HashMap::new();
    for import in pe_imports {
        imports_by_dll
            .entry(import.dll.as_str())
            .or_default()
            .push(import);
    }

    for (dll_name, dll_imports) in imports_by_dll {
        let mut descriptor = ImportDescriptor {
            dll_name: dll_name.to_owned(),
            original_first_thunk: 0,
            first_thunk: 0,
            functions: Vec::with_capacity(dll_imports.len()),
            timestamp: 0,
            forwarder_chain: 0,
        };
        for pe_import in dll_imports {
            native.iat_entries.insert(
                pe_import.rva,
                ImportAddressEntry {
                    rva: pe_import.rva,
                    dll_name: dll_name.to_owned(),
                    function_identifier: Self::build_function_identifier(pe_import),
                    original_value: 0,
                },
            );
            descriptor.functions.push(pe_import.clone());
        }
        native.descriptors.insert(dll_name.to_owned(), descriptor);
    }

    // Keep newly allocated slots clear of every loaded RVA. Saturating is
    // safe: at u32::MAX the allocator refuses to hand out another slot (its
    // checked_add fails) instead of reusing an occupied one.
    if let Some(max_rva) = pe_imports.iter().map(|import| import.rva).max() {
        let first_free = max_rva.saturating_add(native.iat_entry_size());
        native.next_iat_rva = native.next_iat_rva.max(first_free);
    }

    Ok(native)
}
fn build_function_identifier(import: &Import) -> String {
if let Some(ref name) = import.name {
if !name.is_empty() {
return name.clone();
}
}
import
.ordinal
.map_or_else(|| "unknown".to_string(), |ord| format!("#{ord}"))
}
/// Registers `dll_name` as an imported DLL with no functions yet.
///
/// Re-registering an existing name (exact, case-sensitive match) is a no-op.
///
/// # Errors
///
/// Returns an error if `dll_name` is empty.
pub fn add_dll(&mut self, dll_name: &str) -> Result<()> {
    if dll_name.is_empty() {
        return Err(malformed_error!("DLL name cannot be empty"));
    }
    self.descriptors
        .entry(dll_name.to_owned())
        .or_insert_with(|| ImportDescriptor {
            dll_name: dll_name.to_owned(),
            original_first_thunk: 0,
            first_thunk: 0,
            functions: Vec::new(),
            timestamp: 0,
            forwarder_chain: 0,
        });
    Ok(())
}
/// Appends a by-name import of `function_name` to the registered DLL `dll_name`.
///
/// A fresh IAT slot RVA is taken from the allocator and recorded both on the
/// new [`Import`] (hint 0, lookup value 0) and in the IAT entry map.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `function_name` is empty.
/// - `dll_name` has not been registered with `add_dll`.
/// - The same name is already imported from that DLL.
/// - The IAT RVA allocator overflows.
pub fn add_function(&mut self, dll_name: &str, function_name: &str) -> Result<()> {
    if function_name.is_empty() {
        return Err(malformed_error!("Function name cannot be empty"));
    }

    let existing = self
        .descriptors
        .get(dll_name)
        .ok_or_else(|| malformed_error!("DLL '{dll_name}' not found in import table"))?;
    let duplicate = existing
        .functions
        .iter()
        .filter_map(|f| f.name.as_deref())
        .any(|name| name == function_name);
    if duplicate {
        return Err(malformed_error!(
            "Function '{function_name}' already imported from '{dll_name}'"
        ));
    }

    let slot_rva = self.allocate_iat_rva()?;
    let target = self
        .descriptors
        .get_mut(dll_name)
        .ok_or_else(|| malformed_error!("DLL '{dll_name}' disappeared from import table"))?;
    target.functions.push(Import {
        dll: dll_name.to_owned(),
        name: Some(function_name.to_owned()),
        ordinal: None,
        rva: slot_rva,
        hint: 0,
        ilt_value: 0,
    });
    self.iat_entries.insert(
        slot_rva,
        ImportAddressEntry {
            rva: slot_rva,
            dll_name: dll_name.to_owned(),
            function_identifier: function_name.to_owned(),
            original_value: 0,
        },
    );
    Ok(())
}
/// Appends an import-by-ordinal of `ordinal` to the registered DLL `dll_name`.
///
/// A fresh IAT slot RVA is allocated for the function. The recorded lookup
/// value (`Import::ilt_value` and the IAT entry's `original_value`) uses the
/// ordinal flag for the configured PE format: bit 63 for PE32+ and bit 31
/// for PE32. This is the same encoding `build_iat_bytes` and
/// `build_import_table` emit.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `ordinal` is 0.
/// - `dll_name` has not been registered with `add_dll`.
/// - The ordinal is already imported from that DLL.
/// - The IAT RVA allocator overflows.
pub fn add_function_by_ordinal(&mut self, dll_name: &str, ordinal: u16) -> Result<()> {
    if ordinal == 0 {
        return Err(malformed_error!("Ordinal cannot be 0"));
    }

    let existing = self
        .descriptors
        .get(dll_name)
        .ok_or_else(|| malformed_error!("DLL '{dll_name}' not found in import table"))?;
    if existing
        .functions
        .iter()
        .any(|f| f.ordinal == Some(ordinal))
    {
        return Err(malformed_error!(
            "Ordinal {ordinal} already imported from '{dll_name}'"
        ));
    }

    // Previously bit 63 was used unconditionally. For PE32 that value is
    // wrong and loses the flag entirely when truncated to a 32-bit thunk.
    let ordinal_flag: u64 = if self.is_pe32_plus {
        0x8000_0000_0000_0000
    } else {
        0x8000_0000
    };
    let ilt_value = ordinal_flag | u64::from(ordinal);

    let iat_rva = self.allocate_iat_rva()?;
    let descriptor = self
        .descriptors
        .get_mut(dll_name)
        .ok_or_else(|| malformed_error!("DLL '{dll_name}' disappeared from import table"))?;
    descriptor.functions.push(Import {
        dll: dll_name.to_owned(),
        name: None,
        ordinal: Some(ordinal),
        rva: iat_rva,
        hint: 0,
        ilt_value,
    });
    self.iat_entries.insert(
        iat_rva,
        ImportAddressEntry {
            rva: iat_rva,
            dll_name: dll_name.to_owned(),
            function_identifier: format!("#{ordinal}"),
            original_value: ilt_value,
        },
    );
    Ok(())
}
/// Looks up the descriptor registered under exactly `dll_name`, if any.
#[must_use]
pub fn get_descriptor(&self, dll_name: &str) -> Option<&ImportDescriptor> {
    let found = self.descriptors.get(dll_name);
    found
}
/// Iterates over every registered descriptor in unspecified (hash map) order.
pub fn descriptors(&self) -> impl Iterator<Item = &ImportDescriptor> {
    let all = &self.descriptors;
    all.values()
}
/// Returns `true` if `dll_name` (exact, case-sensitive) is registered.
#[must_use]
pub fn has_dll(&self, dll_name: &str) -> bool {
    self.get_descriptor(dll_name).is_some()
}
/// Number of registered DLLs, including ones with no functions yet.
#[must_use]
pub fn dll_count(&self) -> usize {
    let registered = &self.descriptors;
    registered.len()
}
/// Total number of imported functions summed across all DLLs.
#[must_use]
pub fn total_function_count(&self) -> usize {
    self.descriptors
        .values()
        .fold(0, |count, descriptor| count + descriptor.functions.len())
}
/// Returns `true` when no DLL is registered.
///
/// A registered DLL with no functions still makes this `false`.
#[must_use]
pub fn is_empty(&self) -> bool {
    let registered = &self.descriptors;
    registered.is_empty()
}
/// Owned copies of every registered DLL name, in unspecified (hash map) order.
#[must_use]
pub fn get_dll_names(&self) -> Vec<String> {
    let mut names = Vec::with_capacity(self.descriptors.len());
    names.extend(self.descriptors.keys().cloned());
    names
}
/// Shifts every tracked IAT RVA by `rva_delta`.
///
/// The following are all adjusted:
/// - the IAT entry map (keys and each entry's `rva`);
/// - the `rva` of every imported function in every descriptor;
/// - the allocator's next RVA.
///
/// The update is all-or-nothing. Every shifted value is validated before
/// anything is modified, so on error `self` is left exactly as it was.
/// Previously, a failure partway through the `drain()` loop emptied the IAT
/// map and could leave the descriptors half-shifted.
///
/// # Errors
///
/// Returns an error if the magnitude of `rva_delta` does not fit in `u32`, or
/// if any shifted RVA would leave the `u32` range.
pub fn update_iat_rvas(&mut self, rva_delta: i64) -> Result<()> {
    // Validation pass: nothing is mutated here. Each IAT entry's `rva` equals
    // its map key, so validating the keys also covers the entries.
    for &rva in self.iat_entries.keys() {
        adjust_rva(rva, rva_delta)?;
    }
    for function in self
        .descriptors
        .values()
        .flat_map(|descriptor| descriptor.functions.iter())
    {
        adjust_rva(function.rva, rva_delta)?;
    }
    let new_next_iat_rva = adjust_rva(self.next_iat_rva, rva_delta)?;

    // Commit pass: every adjustment was validated above, so the `?`s below
    // cannot fire and the state cannot be left half-updated.
    let mut updated_entries = HashMap::with_capacity(self.iat_entries.len());
    for (old_rva, mut entry) in self.iat_entries.drain() {
        let new_rva = adjust_rva(old_rva, rva_delta)?;
        entry.rva = new_rva;
        updated_entries.insert(new_rva, entry);
    }
    self.iat_entries = updated_entries;
    for descriptor in self.descriptors.values_mut() {
        for function in &mut descriptor.functions {
            function.rva = adjust_rva(function.rva, rva_delta)?;
        }
    }
    self.next_iat_rva = new_next_iat_rva;
    Ok(())
}
/// Reserves the next IAT slot RVA and advances the allocator by one slot
/// (`iat_entry_size` bytes).
///
/// # Errors
///
/// Returns an error if advancing would overflow `u32`. The allocator is left
/// unchanged in that case.
fn allocate_iat_rva(&mut self) -> Result<u32> {
    let reserved = self.next_iat_rva;
    match reserved.checked_add(self.iat_entry_size()) {
        Some(advanced) => {
            self.next_iat_rva = advanced;
            Ok(reserved)
        }
        None => Err(malformed_error!("IAT RVA counter overflow")),
    }
}
/// Total bytes of the IAT produced by `build_iat_bytes` for the given format.
///
/// Each DLL contributes one slot per function plus a null terminator slot.
/// Slots are 8 bytes for PE32+ and 4 bytes for PE32.
///
/// # Errors
///
/// Returns an error if the slot count or byte size overflows `usize`.
pub fn iat_byte_size(&self, is_pe32_plus: bool) -> Result<usize> {
    let slot_bytes: usize = if is_pe32_plus { 8 } else { 4 };
    let slot_count = self
        .descriptors
        .values()
        .try_fold(0usize, |running, descriptor| {
            descriptor
                .functions
                .len()
                .checked_add(1)
                .and_then(|slots| running.checked_add(slots))
                .ok_or_else(|| malformed_error!("IAT entry count overflow"))
        })?;
    slot_count
        .checked_mul(slot_bytes)
        .ok_or_else(|| malformed_error!("IAT byte size overflow"))
}
/// Serializes the initial contents of the Import Address Table.
///
/// DLLs are emitted in order of their lowercased names. For each DLL there is
/// one thunk per function (in `functions` order), followed by a zero
/// terminator. Thunks are 8 bytes for PE32+ and 4 bytes for PE32. The result
/// is empty when no DLL is registered.
///
/// Thunk values:
/// - Ordinal-only imports (ordinal set, no name) get the ordinal flag
///   (bit 63 for PE32+, bit 31 for PE32) OR'd with the ordinal.
/// - All other imports get the RVA of their hint/name entry.
///
/// That RVA is computed from the layout `build_import_table` writes at
/// `import_table_rva`, which is:
/// 1. the descriptor table ((DLL count + 1) * 20 bytes);
/// 2. the ILT;
/// 3. the NUL-terminated DLL names;
/// 4. the hint/name entries (2-byte hint + name + NUL each).
///
/// NOTE(review): `import_table_rva` presumably must equal the `table_rva`
/// passed to `build_import_table`, and both methods must see the same
/// descriptor order; confirm at call sites.
///
/// # Errors
///
/// Returns an error if any size or RVA computation overflows, or if a size
/// is rejected by `to_u32`.
pub fn build_iat_bytes(&self, is_pe32_plus: bool, import_table_rva: u32) -> Result<Vec<u8>> {
if self.is_empty() {
return Ok(Vec::new());
}
let entry_size: usize = if is_pe32_plus { 8 } else { 4 };
let mut iat_bytes = Vec::with_capacity(self.iat_byte_size(is_pe32_plus)?);
// Must match the ordering used by `build_import_table` so each
// descriptor's FirstThunk points at its own run of slots.
let mut descriptors_sorted: Vec<_> = self.descriptors.values().collect();
descriptors_sorted.sort_by_key(|d| d.dll_name.to_lowercase());
// Descriptor table: one 20-byte entry per DLL plus a null entry.
let descriptor_count_with_null = self
.descriptors
.len()
.checked_add(1)
.ok_or_else(|| malformed_error!("Descriptor count overflow"))?;
let descriptor_size = descriptor_count_with_null
.checked_mul(20)
.ok_or_else(|| malformed_error!("Descriptor table size overflow"))?;
// ILT: one entry per function plus a terminator per DLL.
let mut total_ilt_entries: usize = 0;
for desc in &descriptors_sorted {
let dll_entries = desc
.functions
.len()
.checked_add(1)
.ok_or_else(|| malformed_error!("ILT entry count overflow"))?;
total_ilt_entries = total_ilt_entries
.checked_add(dll_entries)
.ok_or_else(|| malformed_error!("ILT entry count overflow"))?;
}
let ilt_size = total_ilt_entries
.checked_mul(entry_size)
.ok_or_else(|| malformed_error!("ILT byte size overflow"))?;
let header_size = descriptor_size
.checked_add(ilt_size)
.ok_or_else(|| malformed_error!("Import table header size overflow"))?;
// String area begins right after the descriptor table and the ILT.
let strings_start_rva = import_table_rva
.checked_add(to_u32(header_size)?)
.ok_or_else(|| malformed_error!("Strings start RVA overflow"))?;
let mut current_string_rva = strings_start_rva;
// Skip over the DLL names (name + NUL each); hint/name entries follow them.
let mut dll_name_end_rva = current_string_rva;
for desc in &descriptors_sorted {
let dll_name_size = desc
.dll_name
.len()
.checked_add(1)
.ok_or_else(|| malformed_error!("DLL name size overflow"))?;
dll_name_end_rva = dll_name_end_rva
.checked_add(to_u32(dll_name_size)?)
.ok_or_else(|| malformed_error!("DLL name RVA overflow"))?;
}
current_string_rva = dll_name_end_rva;
for descriptor in &descriptors_sorted {
for function in &descriptor.functions {
// Ordinal-only imports carry the ordinal flag. Everything else points
// at the next hint/name entry.
// NOTE(review): a function with neither a name nor an ordinal also gets
// `current_string_rva` here, while `build_import_table` writes 0 into
// the ILT for it; confirm such imports cannot occur.
let thunk_value = if let Some(ordinal) = function.ordinal {
if function.name.is_none() {
if is_pe32_plus {
0x8000_0000_0000_0000u64 | u64::from(ordinal)
} else {
0x8000_0000u64 | u64::from(ordinal)
}
} else {
u64::from(current_string_rva)
}
} else {
u64::from(current_string_rva)
};
// Truncation is intended for PE32: every value computed above fits in
// 32 bits there (bit-31 ordinal flag or a u32 RVA).
if is_pe32_plus {
iat_bytes.extend_from_slice(&thunk_value.to_le_bytes());
} else {
#[allow(clippy::cast_possible_truncation)]
iat_bytes.extend_from_slice(&(thunk_value as u32).to_le_bytes());
}
// Only named functions own a hint/name entry:
// 2-byte hint + name + NUL.
if let Some(function_name) = function.name.as_ref() {
let name_size = function_name
.len()
.checked_add(1)
.ok_or_else(|| malformed_error!("Function name size overflow"))?;
let advance = name_size
.checked_add(2)
.ok_or_else(|| malformed_error!("Hint/name advance overflow"))?;
current_string_rva = current_string_rva
.checked_add(to_u32(advance)?)
.ok_or_else(|| malformed_error!("String RVA overflow"))?;
}
}
// Null slot terminating this DLL's run.
if is_pe32_plus {
iat_bytes.extend_from_slice(&0u64.to_le_bytes());
} else {
iat_bytes.extend_from_slice(&0u32.to_le_bytes());
}
}
Ok(iat_bytes)
}
/// Serializes the import directory to be placed at `table_rva`.
///
/// The output layout, in order:
/// 1. Descriptor table: one 20-byte entry per DLL (sorted by lowercased
///    name), then a 20-byte null entry.
/// 2. ILT: for each DLL, one thunk per function plus a 0 terminator.
///    Thunks are 8 bytes for PE32+ and 4 bytes for PE32.
/// 3. DLL name strings.
/// 4. Hint/name entries (u16 hint + name) for named functions only.
/// 5. Zero padding to a 4-byte boundary.
///
/// Each descriptor is written as:
/// - OriginalFirstThunk = this DLL's ILT RVA;
/// - TimeDateStamp = 0;
/// - ForwarderChain = 0;
/// - Name = RVA of its name string;
/// - FirstThunk = `iat_rva` + the running IAT offset, with
///   (function count + 1) * thunk size bytes per DLL. This matches the
///   layout produced by `build_iat_bytes`.
///
/// The descriptor's own `original_first_thunk`, `first_thunk`, `timestamp`
/// and `forwarder_chain` fields are not read. Returns an empty buffer when
/// no DLL is registered.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - a size or RVA computation overflows;
/// - a size is rejected by `to_u32`;
/// - `write_le_at` / `write_string_at` fail;
/// - an internal index lookup is out of bounds.
pub fn build_import_table(
&self,
is_pe32_plus: bool,
iat_rva: u32,
table_rva: u32,
) -> Result<Vec<u8>> {
if self.is_empty() {
return Ok(Vec::new());
}
let entry_size: usize = if is_pe32_plus { 8 } else { 4 };
// Same ordering as `build_iat_bytes`; the FirstThunk offsets depend on it.
let mut descriptors_sorted: Vec<_> = self.descriptors.values().collect();
descriptors_sorted.sort_by_key(|d| d.dll_name.to_lowercase());
let descriptor_count_with_null = descriptors_sorted
.len()
.checked_add(1)
.ok_or_else(|| malformed_error!("Descriptor count overflow"))?;
let descriptor_table_size = descriptor_count_with_null
.checked_mul(20)
.ok_or_else(|| malformed_error!("Descriptor table size overflow"))?;
let mut total_ilt_entries: usize = 0;
for desc in &descriptors_sorted {
let dll_entries = desc
.functions
.len()
.checked_add(1)
.ok_or_else(|| malformed_error!("ILT entry count overflow"))?;
total_ilt_entries = total_ilt_entries
.checked_add(dll_entries)
.ok_or_else(|| malformed_error!("ILT entry count overflow"))?;
}
let ilt_size = total_ilt_entries
.checked_mul(entry_size)
.ok_or_else(|| malformed_error!("ILT byte size overflow"))?;
// String area size: (name + NUL) per DLL, and (2-byte hint + name + NUL)
// per named function.
let mut total_string_size: usize = 0;
for desc in &descriptors_sorted {
let dll_size = desc
.dll_name
.len()
.checked_add(1)
.ok_or_else(|| malformed_error!("DLL name size overflow"))?;
total_string_size = total_string_size
.checked_add(dll_size)
.ok_or_else(|| malformed_error!("String table size overflow"))?;
for func in &desc.functions {
if let Some(ref name) = func.name {
let name_size = name
.len()
.checked_add(3)
.ok_or_else(|| malformed_error!("Function name size overflow"))?;
total_string_size = total_string_size
.checked_add(name_size)
.ok_or_else(|| malformed_error!("String table size overflow"))?;
}
}
}
// 16 bytes of headroom. The alignment padding at the end needs at most 3,
// and the buffer is truncated to the written length before returning.
let total_size = descriptor_table_size
.checked_add(ilt_size)
.and_then(|s| s.checked_add(total_string_size))
.and_then(|s| s.checked_add(16))
.ok_or_else(|| malformed_error!("Import table total size overflow"))?;
let mut data = vec![0u8; total_size];
// Write cursor. Presumably advanced by write_le_at/write_string_at, which
// take it by &mut and are called back-to-back.
let mut offset = 0;
let ilt_start_rva = table_rva
.checked_add(to_u32(descriptor_table_size)?)
.ok_or_else(|| malformed_error!("ILT start RVA overflow"))?;
let strings_start_rva = ilt_start_rva
.checked_add(to_u32(ilt_size)?)
.ok_or_else(|| malformed_error!("Strings start RVA overflow"))?;
let mut ilt_rva = ilt_start_rva;
let mut iat_offset: u32 = 0;
// Pre-compute the RVA of each DLL name string (indexed like descriptors_sorted).
let mut dll_name_rvas = Vec::with_capacity(descriptors_sorted.len());
let mut current_dll_name_rva = strings_start_rva;
for desc in &descriptors_sorted {
dll_name_rvas.push(current_dll_name_rva);
let dll_size = desc
.dll_name
.len()
.checked_add(1)
.ok_or_else(|| malformed_error!("DLL name size overflow"))?;
current_dll_name_rva = current_dll_name_rva
.checked_add(to_u32(dll_size)?)
.ok_or_else(|| malformed_error!("DLL name RVA overflow"))?;
}
// Pre-compute the hint/name entry RVA of each function. Hint/name entries
// start right after the last DLL name.
let mut current_func_name_rva = current_dll_name_rva;
let mut func_name_rvas: Vec<Vec<u64>> = Vec::with_capacity(descriptors_sorted.len());
for desc in &descriptors_sorted {
let mut rvas = Vec::with_capacity(desc.functions.len());
for func in &desc.functions {
if let Some(function_name) = func.name.as_ref() {
rvas.push(u64::from(current_func_name_rva));
let name_size = function_name
.len()
.checked_add(1)
.ok_or_else(|| malformed_error!("Function name size overflow"))?;
let advance = name_size
.checked_add(2)
.ok_or_else(|| malformed_error!("Hint/name advance overflow"))?;
current_func_name_rva = current_func_name_rva
.checked_add(to_u32(advance)?)
.ok_or_else(|| malformed_error!("Function name RVA overflow"))?;
} else {
// Placeholder that keeps indices aligned with `desc.functions`.
// Unnamed entries never read it.
rvas.push(0); }
}
func_name_rvas.push(rvas);
}
// Descriptor entries: OriginalFirstThunk, TimeDateStamp, ForwarderChain,
// Name, FirstThunk.
for (i, desc) in descriptors_sorted.iter().enumerate() {
let desc_ilt_rva = ilt_rva;
let desc_iat_rva = iat_rva
.checked_add(iat_offset)
.ok_or_else(|| malformed_error!("IAT RVA overflow"))?;
let dll_name_rva = *dll_name_rvas.get(i).ok_or(out_of_bounds_error!())?;
write_le_at::<u32>(&mut data, &mut offset, desc_ilt_rva)?;
write_le_at::<u32>(&mut data, &mut offset, 0)?;
write_le_at::<u32>(&mut data, &mut offset, 0)?;
write_le_at::<u32>(&mut data, &mut offset, dll_name_rva)?;
write_le_at::<u32>(&mut data, &mut offset, desc_iat_rva)?;
// ILT and IAT runs have identical sizes: (functions + 1) thunks per DLL.
let entries_for_dll = desc
.functions
.len()
.checked_add(1)
.ok_or_else(|| malformed_error!("ILT entry count overflow"))?;
let dll_size = entries_for_dll
.checked_mul(entry_size)
.ok_or_else(|| malformed_error!("ILT DLL size overflow"))?;
let dll_size_u32 = to_u32(dll_size)?;
ilt_rva = ilt_rva
.checked_add(dll_size_u32)
.ok_or_else(|| malformed_error!("ILT RVA overflow"))?;
iat_offset = iat_offset
.checked_add(dll_size_u32)
.ok_or_else(|| malformed_error!("IAT offset overflow"))?;
}
// Null descriptor (5 zero u32 fields) terminating the table.
for _ in 0..5 {
write_le_at::<u32>(&mut data, &mut offset, 0)?;
}
// ILT runs, one per DLL in sorted order.
for (i, desc) in descriptors_sorted.iter().enumerate() {
let dll_func_rvas = func_name_rvas.get(i).ok_or(out_of_bounds_error!())?;
for (j, func) in desc.functions.iter().enumerate() {
// Ordinal-only: flag | ordinal. Named: hint/name RVA.
// Neither: 0, the same value used as the run terminator below.
let ilt_value = if func.name.is_none() {
if let Some(ordinal) = func.ordinal {
if is_pe32_plus {
0x8000_0000_0000_0000u64 | u64::from(ordinal)
} else {
0x8000_0000u64 | u64::from(ordinal)
}
} else {
0
}
} else {
*dll_func_rvas.get(j).ok_or(out_of_bounds_error!())?
};
// PE32 values fit in 32 bits (bit-31 flag or u32 RVA), so the cast is lossless.
if is_pe32_plus {
write_le_at::<u64>(&mut data, &mut offset, ilt_value)?;
} else {
#[allow(clippy::cast_possible_truncation)]
write_le_at::<u32>(&mut data, &mut offset, ilt_value as u32)?;
}
}
if is_pe32_plus {
write_le_at::<u64>(&mut data, &mut offset, 0)?;
} else {
write_le_at::<u32>(&mut data, &mut offset, 0)?;
}
}
// DLL name strings. Presumably NUL-terminated by write_string_at, which is
// consistent with the len + 1 accounting above.
for desc in &descriptors_sorted {
write_string_at(&mut data, &mut offset, &desc.dll_name)?;
}
// Hint/name entries for named functions, in the same order as their pre-computed RVAs.
for desc in &descriptors_sorted {
for func in &desc.functions {
if let Some(ref name) = func.name {
write_le_at::<u16>(&mut data, &mut offset, func.hint)?;
write_string_at(&mut data, &mut offset, name)?;
}
}
}
// Zero-pad to a 4-byte boundary, then drop the unused headroom.
while offset % 4 != 0 {
if let Some(slot) = data.get_mut(offset) {
*slot = 0;
}
offset = offset
.checked_add(1)
.ok_or_else(|| malformed_error!("Alignment offset overflow"))?;
}
data.truncate(offset);
Ok(data)
}
}
impl Default for NativeImports {
fn default() -> Self {
Self::new()
}
}
/// Applies a signed `delta` to `rva`, failing instead of wrapping.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `delta` is `i64::MIN`, whose magnitude cannot be negated;
/// - the magnitude of `delta` exceeds `u32`;
/// - the result would fall outside `0..=u32::MAX`.
fn adjust_rva(rva: u32, delta: i64) -> Result<u32> {
    if delta < 0 {
        let magnitude = delta
            .checked_neg()
            .ok_or_else(|| malformed_error!("RVA delta magnitude overflow"))?;
        let magnitude = u32::try_from(magnitude)
            .map_err(|_| malformed_error!("RVA delta exceeds u32 range"))?;
        return rva
            .checked_sub(magnitude)
            .ok_or_else(|| malformed_error!("RVA delta would cause overflow"));
    }

    let magnitude =
        u32::try_from(delta).map_err(|_| malformed_error!("RVA delta exceeds u32 range"))?;
    rva.checked_add(magnitude)
        .ok_or_else(|| malformed_error!("RVA delta would cause overflow"))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_native_imports_is_empty() {
        let fresh = NativeImports::new();
        assert_eq!(fresh.dll_count(), 0);
        assert!(fresh.is_empty());
    }

    #[test]
    fn add_dll_works() {
        let mut table = NativeImports::new();
        table.add_dll("kernel32.dll").unwrap();
        assert!(table.has_dll("kernel32.dll"));
        assert_eq!(table.dll_count(), 1);
        assert!(!table.is_empty());

        // Registering the same DLL again must not create a second descriptor.
        table.add_dll("kernel32.dll").unwrap();
        assert_eq!(table.dll_count(), 1);
    }
}