use std::collections::HashMap;
use crate::{
file::pe::Import,
utils::{write_le_at, write_string_at},
Result,
};
/// Mutable model of a PE import table: per-DLL import descriptors plus a map of IAT entries.
///
/// Populated either from already-parsed imports (`NativeImports::from_pe_imports`) or
/// incrementally via `add_dll` / `add_function` / `add_function_by_ordinal`, and serialized
/// with `build_import_table` / `build_iat_bytes`.
#[derive(Debug, Clone)]
pub struct NativeImports {
    /// Import descriptors keyed by DLL name exactly as supplied (case-sensitive lookup).
    descriptors: HashMap<String, ImportDescriptor>,
    /// IAT bookkeeping entries keyed by their RVA.
    iat_entries: HashMap<u32, ImportAddressEntry>,
    /// Next placeholder RVA handed to functions added after construction; starts at `0x1000`
    /// and advances by `iat_entry_size()` per allocation.
    next_iat_rva: u32,
    /// `true` for PE32+ (8-byte IAT slots), `false` for PE32 (4-byte slots).
    is_pe32_plus: bool,
}
/// All imports taken from a single DLL.
#[derive(Debug, Clone)]
pub struct ImportDescriptor {
    /// DLL name as registered; also the key under which the descriptor is stored.
    pub dll_name: String,
    /// Initialized to 0 by `NativeImports`; `build_import_table` computes the ILT RVA itself
    /// and does not read this field.
    pub original_first_thunk: u32,
    /// Initialized to 0 by `NativeImports`; `build_import_table` derives the IAT RVA from its
    /// `iat_rva` argument and does not read this field.
    pub first_thunk: u32,
    /// Functions imported from this DLL, in insertion order.
    pub functions: Vec<Import>,
    /// Initialized to 0; the serialized descriptor always carries a zero timestamp.
    pub timestamp: u32,
    /// Initialized to 0; the serialized descriptor always carries a zero forwarder chain.
    pub forwarder_chain: u32,
}
/// Bookkeeping for one IAT slot.
#[derive(Debug, Clone)]
pub struct ImportAddressEntry {
    /// RVA of the slot; matches the key under which the entry is stored.
    pub rva: u32,
    /// DLL the function is imported from.
    pub dll_name: String,
    /// The function name when non-empty, otherwise `#<ordinal>`, or `"unknown"` when a parsed
    /// import carries neither.
    pub function_identifier: String,
    /// Value recorded at creation: 0 for parsed and by-name imports, the ILT value for imports
    /// added by ordinal.
    pub original_value: u64,
}
impl NativeImports {
    /// Size in bytes of one import directory entry (five little-endian `u32` fields).
    const DESCRIPTOR_SIZE: usize = 20;

    /// Creates an empty import set configured for PE32.
    ///
    /// Placeholder IAT RVAs handed out by `add_function` / `add_function_by_ordinal` start at
    /// `0x1000`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            descriptors: HashMap::new(),
            iat_entries: HashMap::new(),
            next_iat_rva: 0x1000,
            is_pe32_plus: false,
        }
    }

    /// Selects PE32+ (`true`) or PE32 (`false`).
    ///
    /// This determines the IAT slot size used for allocation and the ordinal flag recorded for
    /// imports added by ordinal.
    pub fn set_pe_format(&mut self, is_pe32_plus: bool) {
        self.is_pe32_plus = is_pe32_plus;
    }

    /// Returns `true` when configured for PE32+.
    #[must_use]
    pub fn is_pe32_plus(&self) -> bool {
        self.is_pe32_plus
    }

    /// Size in bytes of one IAT slot for the configured format (8 for PE32+, 4 for PE32).
    #[must_use]
    pub fn iat_entry_size(&self) -> u32 {
        if self.is_pe32_plus {
            8
        } else {
            4
        }
    }

    /// Builds an import set from already-parsed PE imports, grouped by DLL name.
    ///
    /// Each import keeps its parsed `rva`, which also becomes the key of its IAT entry.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`; the `Result` is kept for API stability.
    pub fn from_pe_imports(pe_imports: &[Import], is_pe32_plus: bool) -> Result<Self> {
        let mut native = Self::new();
        native.set_pe_format(is_pe32_plus);

        let mut imports_by_dll: HashMap<&str, Vec<&Import>> = HashMap::new();
        for import in pe_imports {
            imports_by_dll
                .entry(import.dll.as_str())
                .or_default()
                .push(import);
        }

        for (dll_name, dll_imports) in imports_by_dll {
            for pe_import in &dll_imports {
                native.iat_entries.insert(
                    pe_import.rva,
                    ImportAddressEntry {
                        rva: pe_import.rva,
                        dll_name: dll_name.to_owned(),
                        function_identifier: Self::build_function_identifier(pe_import),
                        original_value: 0,
                    },
                );
            }
            let descriptor = ImportDescriptor {
                dll_name: dll_name.to_owned(),
                original_first_thunk: 0,
                first_thunk: 0,
                functions: dll_imports.into_iter().cloned().collect(),
                timestamp: 0,
                forwarder_chain: 0,
            };
            native.descriptors.insert(dll_name.to_owned(), descriptor);
        }
        Ok(native)
    }

    /// Human-readable identifier for an import.
    ///
    /// Returns the name when present and non-empty, otherwise `#<ordinal>`, otherwise
    /// `"unknown"`.
    fn build_function_identifier(import: &Import) -> String {
        match import.name.as_deref() {
            Some(name) if !name.is_empty() => name.to_owned(),
            _ => import
                .ordinal
                .map_or_else(|| "unknown".to_string(), |ord| format!("#{ord}")),
        }
    }

    /// Registers `dll_name` with an empty function list.
    ///
    /// Re-adding an existing name (compared exactly, case-sensitively) is a no-op.
    ///
    /// # Errors
    ///
    /// Returns an error if `dll_name` is empty.
    pub fn add_dll(&mut self, dll_name: &str) -> Result<()> {
        if dll_name.is_empty() {
            return Err(malformed_error!("DLL name cannot be empty"));
        }
        if !self.descriptors.contains_key(dll_name) {
            self.descriptors.insert(
                dll_name.to_owned(),
                ImportDescriptor {
                    dll_name: dll_name.to_owned(),
                    original_first_thunk: 0,
                    first_thunk: 0,
                    functions: Vec::new(),
                    timestamp: 0,
                    forwarder_chain: 0,
                },
            );
        }
        Ok(())
    }

    /// Imports `function_name` by name from the already-registered `dll_name`.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `function_name` is empty;
    /// - the DLL has not been added;
    /// - the function is already imported by name from that DLL;
    /// - the placeholder IAT RVA space is exhausted.
    pub fn add_function(&mut self, dll_name: &str, function_name: &str) -> Result<()> {
        if function_name.is_empty() {
            return Err(malformed_error!("Function name cannot be empty"));
        }
        let descriptor = self
            .descriptors
            .get(dll_name)
            .ok_or_else(|| malformed_error!("DLL '{dll_name}' not found in import table"))?;
        if descriptor
            .functions
            .iter()
            .any(|f| f.name.as_deref() == Some(function_name))
        {
            return Err(malformed_error!(
                "Function '{function_name}' already imported from '{dll_name}'"
            ));
        }

        // Allocate only after validation so a rejected call does not consume an RVA.
        let iat_rva = self.allocate_iat_rva()?;
        let descriptor = self
            .descriptors
            .get_mut(dll_name)
            .ok_or_else(|| malformed_error!("DLL '{dll_name}' disappeared from import table"))?;
        descriptor.functions.push(Import {
            dll: dll_name.to_owned(),
            name: Some(function_name.to_owned()),
            ordinal: None,
            rva: iat_rva,
            hint: 0,
            ilt_value: 0,
        });
        self.iat_entries.insert(
            iat_rva,
            ImportAddressEntry {
                rva: iat_rva,
                dll_name: dll_name.to_owned(),
                function_identifier: function_name.to_owned(),
                original_value: 0,
            },
        );
        Ok(())
    }

    /// Imports `ordinal` (by ordinal, without a name) from the already-registered `dll_name`.
    ///
    /// The recorded ILT value uses the ordinal flag of the configured format: bit 63 for PE32+,
    /// bit 31 for PE32.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `ordinal` is 0;
    /// - the DLL has not been added;
    /// - the ordinal is already imported from that DLL;
    /// - the placeholder IAT RVA space is exhausted.
    pub fn add_function_by_ordinal(&mut self, dll_name: &str, ordinal: u16) -> Result<()> {
        if ordinal == 0 {
            return Err(malformed_error!("Ordinal cannot be 0"));
        }
        let descriptor = self
            .descriptors
            .get(dll_name)
            .ok_or_else(|| malformed_error!("DLL '{dll_name}' not found in import table"))?;
        if descriptor
            .functions
            .iter()
            .any(|f| f.ordinal == Some(ordinal))
        {
            return Err(malformed_error!(
                "Ordinal {ordinal} already imported from '{dll_name}'"
            ));
        }

        // Previously the PE32+ flag was recorded even for PE32 images.
        let ilt_value = Self::ordinal_flag(self.is_pe32_plus) | u64::from(ordinal);
        let iat_rva = self.allocate_iat_rva()?;
        let descriptor = self
            .descriptors
            .get_mut(dll_name)
            .ok_or_else(|| malformed_error!("DLL '{dll_name}' disappeared from import table"))?;
        descriptor.functions.push(Import {
            dll: dll_name.to_owned(),
            name: None,
            ordinal: Some(ordinal),
            rva: iat_rva,
            hint: 0,
            ilt_value,
        });
        self.iat_entries.insert(
            iat_rva,
            ImportAddressEntry {
                rva: iat_rva,
                dll_name: dll_name.to_owned(),
                function_identifier: format!("#{ordinal}"),
                original_value: ilt_value,
            },
        );
        Ok(())
    }

    /// Returns the descriptor registered under exactly `dll_name`, if any.
    #[must_use]
    pub fn get_descriptor(&self, dll_name: &str) -> Option<&ImportDescriptor> {
        self.descriptors.get(dll_name)
    }

    /// Iterates over all descriptors in unspecified order.
    pub fn descriptors(&self) -> impl Iterator<Item = &ImportDescriptor> {
        self.descriptors.values()
    }

    /// Returns `true` if a descriptor is registered under exactly `dll_name`.
    #[must_use]
    pub fn has_dll(&self, dll_name: &str) -> bool {
        self.descriptors.contains_key(dll_name)
    }

    /// Number of registered DLLs.
    #[must_use]
    pub fn dll_count(&self) -> usize {
        self.descriptors.len()
    }

    /// Number of imported functions across all DLLs.
    #[must_use]
    pub fn total_function_count(&self) -> usize {
        self.descriptors.values().map(|d| d.functions.len()).sum()
    }

    /// Returns `true` when no DLL is registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.descriptors.is_empty()
    }

    /// Names of all registered DLLs, in unspecified order.
    #[must_use]
    pub fn get_dll_names(&self) -> Vec<String> {
        self.descriptors.keys().cloned().collect()
    }

    /// Shifts every tracked IAT RVA by `rva_delta`.
    ///
    /// This covers the IAT entries (and their map keys), every function's `rva`, and the next
    /// allocation RVA.
    ///
    /// # Errors
    ///
    /// Returns an error if any shifted RVA would fall outside the `u32` range. In that case
    /// `self` is left unmodified.
    pub fn update_iat_rvas(&mut self, rva_delta: i64) -> Result<()> {
        let shift = |rva: u32| {
            Self::shift_rva(rva, rva_delta)
                .ok_or_else(|| malformed_error!("RVA delta would cause overflow"))
        };

        // Build the shifted state first and commit only once every shift has succeeded.
        // The previous drain-based loop discarded all IAT entries on error.
        let mut shifted_entries = HashMap::with_capacity(self.iat_entries.len());
        for (&old_rva, entry) in &self.iat_entries {
            let rva = shift(old_rva)?;
            shifted_entries.insert(
                rva,
                ImportAddressEntry {
                    rva,
                    ..entry.clone()
                },
            );
        }
        let mut shifted_descriptors = self.descriptors.clone();
        for descriptor in shifted_descriptors.values_mut() {
            for function in &mut descriptor.functions {
                function.rva = shift(function.rva)?;
            }
        }
        let shifted_next = shift(self.next_iat_rva)?;

        self.iat_entries = shifted_entries;
        self.descriptors = shifted_descriptors;
        self.next_iat_rva = shifted_next;
        Ok(())
    }

    /// Adds a signed delta to an RVA.
    ///
    /// Returns `None` if the result leaves the `u32` range. Unlike an `as u32` cast, this does
    /// not truncate large deltas.
    fn shift_rva(rva: u32, rva_delta: i64) -> Option<u32> {
        i64::from(rva)
            .checked_add(rva_delta)
            .and_then(|shifted| u32::try_from(shifted).ok())
    }

    /// Hands out the next free placeholder IAT RVA.
    ///
    /// RVAs already used by existing IAT entries (for example ones loaded by
    /// `from_pe_imports`) are skipped, so a new entry never overwrites an existing one.
    fn allocate_iat_rva(&mut self) -> Result<u32> {
        let entry_size = self.iat_entry_size();
        loop {
            let rva = self.next_iat_rva;
            self.next_iat_rva = rva
                .checked_add(entry_size)
                .ok_or_else(|| malformed_error!("IAT RVA space exhausted"))?;
            if !self.iat_entries.contains_key(&rva) {
                return Ok(rva);
            }
        }
    }

    /// Byte size of the IAT produced by `build_iat_bytes`.
    ///
    /// That is one slot per function plus a null terminator slot per DLL, with 8-byte slots
    /// for PE32+ and 4-byte slots otherwise.
    #[must_use]
    pub fn iat_byte_size(&self, is_pe32_plus: bool) -> usize {
        let entry_size = if is_pe32_plus { 8 } else { 4 };
        self.descriptors
            .values()
            .map(|d| d.functions.len() + 1)
            .sum::<usize>()
            * entry_size
    }

    /// Builds the IAT contents: one thunk per function plus a null terminator per DLL.
    ///
    /// DLLs appear in the same order as in `build_import_table`. Thunks hold the same values as
    /// the ILT written by `build_import_table` when both use the same import table RVA
    /// (`import_table_rva` here, `table_rva` there).
    ///
    /// # Errors
    ///
    /// Returns an error if the computed layout exceeds the 32-bit RVA space.
    pub fn build_iat_bytes(&self, is_pe32_plus: bool, import_table_rva: u32) -> Result<Vec<u8>> {
        if self.is_empty() {
            return Ok(Vec::new());
        }
        let entry_size: usize = if is_pe32_plus { 8 } else { 4 };
        let sorted = self.sorted_descriptors();
        let descriptor_table_size = (sorted.len() + 1) * Self::DESCRIPTOR_SIZE;
        let ilt_size = Self::thunk_slot_count(&sorted) * entry_size;
        let dll_names_rva =
            Self::offset_rva(import_table_rva, descriptor_table_size + ilt_size)?;
        let hint_name_rvas = Self::hint_name_rvas(&sorted, dll_names_rva)?;

        let mut iat_bytes = Vec::with_capacity(self.iat_byte_size(is_pe32_plus));
        for (descriptor, rvas) in sorted.iter().zip(&hint_name_rvas) {
            for (function, &rva) in descriptor.functions.iter().zip(rvas) {
                let thunk = Self::thunk_value(function, rva, is_pe32_plus);
                Self::push_thunk(&mut iat_bytes, thunk, is_pe32_plus);
            }
            // Null thunk terminating this DLL's array.
            Self::push_thunk(&mut iat_bytes, 0, is_pe32_plus);
        }
        Ok(iat_bytes)
    }

    /// Serializes the import directory into a buffer located at `table_rva`.
    ///
    /// The buffer contains, in order:
    /// 1. the descriptor table, terminated by a null descriptor;
    /// 2. the ILT;
    /// 3. the DLL name strings;
    /// 4. the hint/name entries.
    ///
    /// The IAT is expected at `iat_rva` with the layout of `build_iat_bytes`. The result is
    /// padded with zeros to a multiple of 4 bytes.
    ///
    /// # Errors
    ///
    /// Returns an error if the layout exceeds the 32-bit RVA space, or if a write helper
    /// reports an error.
    pub fn build_import_table(
        &self,
        is_pe32_plus: bool,
        iat_rva: u32,
        table_rva: u32,
    ) -> Result<Vec<u8>> {
        if self.is_empty() {
            return Ok(Vec::new());
        }
        let entry_size: usize = if is_pe32_plus { 8 } else { 4 };
        let sorted = self.sorted_descriptors();
        let descriptor_table_size = (sorted.len() + 1) * Self::DESCRIPTOR_SIZE;
        let ilt_size = Self::thunk_slot_count(&sorted) * entry_size;
        let string_size: usize = sorted
            .iter()
            .map(|d| {
                let names: usize = d
                    .functions
                    .iter()
                    .filter_map(|f| f.name.as_ref())
                    .map(|name| 2 + name.len() + 1)
                    .sum();
                d.dll_name.len() + 1 + names
            })
            .sum();

        // 16 bytes of slack; the buffer is resized to the aligned written length at the end.
        let mut data = vec![0u8; descriptor_table_size + ilt_size + string_size + 16];
        let mut offset = 0;

        let ilt_start_rva = Self::offset_rva(table_rva, descriptor_table_size)?;
        let dll_names_rva = Self::offset_rva(ilt_start_rva, ilt_size)?;
        let hint_name_rvas = Self::hint_name_rvas(&sorted, dll_names_rva)?;

        // Descriptor table. Each DLL's ILT and IAT arrays sit at the same offset within their
        // respective regions.
        let mut dll_name_rva = dll_names_rva;
        let mut thunk_array_offset = 0usize;
        for descriptor in &sorted {
            let desc_ilt_rva = Self::offset_rva(ilt_start_rva, thunk_array_offset)?;
            let desc_iat_rva = Self::offset_rva(iat_rva, thunk_array_offset)?;
            write_le_at::<u32>(&mut data, &mut offset, desc_ilt_rva)?; // OriginalFirstThunk
            write_le_at::<u32>(&mut data, &mut offset, 0)?; // TimeDateStamp
            write_le_at::<u32>(&mut data, &mut offset, 0)?; // ForwarderChain
            write_le_at::<u32>(&mut data, &mut offset, dll_name_rva)?; // Name
            write_le_at::<u32>(&mut data, &mut offset, desc_iat_rva)?; // FirstThunk
            dll_name_rva = Self::offset_rva(dll_name_rva, descriptor.dll_name.len() + 1)?;
            thunk_array_offset += (descriptor.functions.len() + 1) * entry_size;
        }
        // Null descriptor terminating the table.
        for _ in 0..Self::DESCRIPTOR_SIZE / 4 {
            write_le_at::<u32>(&mut data, &mut offset, 0)?;
        }

        // ILT: one thunk per function, then a null thunk per DLL.
        for (descriptor, rvas) in sorted.iter().zip(&hint_name_rvas) {
            let thunks = descriptor
                .functions
                .iter()
                .zip(rvas)
                .map(|(function, &rva)| Self::thunk_value(function, rva, is_pe32_plus))
                .chain(std::iter::once(0));
            for thunk in thunks {
                if is_pe32_plus {
                    write_le_at::<u64>(&mut data, &mut offset, thunk)?;
                } else {
                    // PE32 thunks use the 32-bit ordinal flag or a u32 RVA, so this is lossless.
                    #[allow(clippy::cast_possible_truncation)]
                    write_le_at::<u32>(&mut data, &mut offset, thunk as u32)?;
                }
            }
        }

        for descriptor in &sorted {
            write_string_at(&mut data, &mut offset, &descriptor.dll_name)?;
        }
        for descriptor in &sorted {
            for function in &descriptor.functions {
                if let Some(name) = &function.name {
                    write_le_at::<u16>(&mut data, &mut offset, function.hint)?;
                    write_string_at(&mut data, &mut offset, name)?;
                }
            }
        }

        // Round the written length up to a 4-byte boundary. `resize` also extends if the
        // padding would exceed the buffer, which the old truncate-based loop could not do.
        let aligned_len = offset + (4 - offset % 4) % 4;
        data.resize(aligned_len, 0);
        Ok(data)
    }

    /// Returns the descriptors in emission order.
    ///
    /// Ordering is case-insensitive by DLL name, with the exact name as a tie-breaker so the
    /// output never depends on `HashMap` iteration order.
    fn sorted_descriptors(&self) -> Vec<&ImportDescriptor> {
        let mut sorted: Vec<&ImportDescriptor> = self.descriptors.values().collect();
        sorted.sort_by(|a, b| {
            a.dll_name
                .to_lowercase()
                .cmp(&b.dll_name.to_lowercase())
                .then_with(|| a.dll_name.cmp(&b.dll_name))
        });
        sorted
    }

    /// Number of thunk slots: one per function plus a null terminator per DLL.
    fn thunk_slot_count(descriptors: &[&ImportDescriptor]) -> usize {
        descriptors.iter().map(|d| d.functions.len() + 1).sum()
    }

    /// Adds a byte offset to an RVA.
    ///
    /// Fails instead of truncating or wrapping when the result leaves the 32-bit RVA space.
    fn offset_rva(base: u32, offset: usize) -> Result<u32> {
        u32::try_from(offset)
            .ok()
            .and_then(|offset| base.checked_add(offset))
            .ok_or_else(|| malformed_error!("Import table layout exceeds the 32-bit RVA space"))
    }

    /// Returns the RVA of each function's hint/name entry, or `None` for unnamed functions.
    ///
    /// This mirrors the string area written by `build_import_table`:
    /// - all DLL names, each counted as `len + 1` bytes, starting at `dll_names_rva`;
    /// - then one entry per named function: a 2-byte hint plus the name, counted as
    ///   `2 + len + 1` bytes.
    fn hint_name_rvas(
        descriptors: &[&ImportDescriptor],
        dll_names_rva: u32,
    ) -> Result<Vec<Vec<Option<u32>>>> {
        let dll_names_size: usize = descriptors.iter().map(|d| d.dll_name.len() + 1).sum();
        let mut next_rva = Self::offset_rva(dll_names_rva, dll_names_size)?;
        let mut all_rvas = Vec::with_capacity(descriptors.len());
        for descriptor in descriptors {
            let mut rvas = Vec::with_capacity(descriptor.functions.len());
            for function in &descriptor.functions {
                match &function.name {
                    Some(name) => {
                        rvas.push(Some(next_rva));
                        next_rva = Self::offset_rva(next_rva, 2 + name.len() + 1)?;
                    }
                    None => rvas.push(None),
                }
            }
            all_rvas.push(rvas);
        }
        Ok(all_rvas)
    }

    /// Ordinal-import flag: the top bit of a 64-bit thunk (PE32+) or a 32-bit thunk (PE32).
    fn ordinal_flag(is_pe32_plus: bool) -> u64 {
        if is_pe32_plus {
            0x8000_0000_0000_0000
        } else {
            0x8000_0000
        }
    }

    /// Value of one ILT/IAT thunk.
    ///
    /// - Named imports get their hint/name RVA.
    /// - Ordinal-only imports get the flagged ordinal.
    /// - Imports with neither get 0, exactly as the ILT has always encoded them. The IAT
    ///   previously pointed such slots at another function's hint/name entry instead.
    fn thunk_value(function: &Import, hint_name_rva: Option<u32>, is_pe32_plus: bool) -> u64 {
        match (hint_name_rva, function.ordinal) {
            (Some(rva), _) => u64::from(rva),
            (None, Some(ordinal)) => Self::ordinal_flag(is_pe32_plus) | u64::from(ordinal),
            (None, None) => 0,
        }
    }

    /// Appends one little-endian thunk: 8 bytes for PE32+, 4 bytes for PE32.
    fn push_thunk(out: &mut Vec<u8>, value: u64, is_pe32_plus: bool) {
        if is_pe32_plus {
            out.extend_from_slice(&value.to_le_bytes());
        } else {
            // PE32 thunk values produced by `thunk_value` always fit in 32 bits.
            #[allow(clippy::cast_possible_truncation)]
            out.extend_from_slice(&(value as u32).to_le_bytes());
        }
    }
}
impl Default for NativeImports {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed import set holds no DLLs.
    #[test]
    fn new_native_imports_is_empty() {
        let fresh = NativeImports::new();
        assert_eq!(fresh.dll_count(), 0);
        assert!(fresh.is_empty());
    }

    /// Registering a DLL makes it visible, and registering it again keeps one descriptor.
    #[test]
    fn add_dll_works() {
        let dll = "kernel32.dll";
        let mut table = NativeImports::new();

        table.add_dll(dll).unwrap();
        assert_eq!(table.dll_count(), 1);
        assert!(table.has_dll(dll));
        assert!(!table.is_empty());

        table.add_dll(dll).unwrap();
        assert_eq!(table.dll_count(), 1);
    }
}