use std::collections::HashMap;
use crate::{
file::pe::Export,
utils::{to_u32, write_le_at, write_string_at},
Result,
};
/// In-memory model of a PE export table that can be edited and re-serialized.
///
/// Regular function exports and forwarders are stored in separate maps keyed by
/// ordinal; the `add_*` methods check both maps, so the two kinds share one ordinal
/// space and one name space.
#[derive(Debug, Clone)]
pub struct NativeExports {
/// Export directory header values (DLL name, base ordinal, counts, version, timestamp).
directory: ExportDirectory,
/// Non-forwarder function exports, keyed by ordinal.
functions: HashMap<u16, ExportFunction>,
/// Forwarded exports, keyed by ordinal.
forwarders: HashMap<u16, ExportForwarder>,
/// Exported name -> ordinal, covering named functions and named forwarders.
name_to_ordinal: HashMap<String, u16>,
/// One past the highest ordinal added so far; NOTE(review): not otherwise read in the
/// code shown here.
next_ordinal: u16,
/// RVA at which the serialized table will be placed; 0 means "not set".
export_table_base_rva: u32,
}
/// Header values of the export directory (`IMAGE_EXPORT_DIRECTORY`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExportDirectory {
    /// Name of the DLL, written as the directory's `Name` string.
    pub dll_name: String,
    /// Ordinal assigned to the first export address table slot.
    pub base_ordinal: u16,
    /// Export count maintained by `NativeExports` as exports are added.
    pub function_count: u32,
    /// Number of exports that have a name.
    pub name_count: u32,
    /// `TimeDateStamp` header value.
    pub timestamp: u32,
    /// `MajorVersion` header value.
    pub major_version: u16,
    /// `MinorVersion` header value.
    pub minor_version: u16,
}
/// A function export: an ordinal, an optional name and the RVA it resolves to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExportFunction {
    /// Export ordinal (never 0 when added through `NativeExports`).
    pub ordinal: u16,
    /// Exported name, or `None` for an ordinal-only export.
    pub name: Option<String>,
    /// RVA written into the export address table slot for this ordinal.
    pub address: u32,
    /// Forwarder flag; `NativeExports` models forwarders as `ExportForwarder` and sets
    /// this to `false` on every function it creates.
    pub is_forwarder: bool,
}
/// An export that forwards to a symbol in another DLL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExportForwarder {
    /// Export ordinal.
    pub ordinal: u16,
    /// Exported name, or `None` when exported by ordinal only.
    pub name: Option<String>,
    /// Forwarding target, e.g. `"kernel32.dll.GetTick"` or `"user32.dll.#120"`.
    pub target: String,
}
impl NativeExports {
    /// Creates an empty export set for the DLL named `dll_name`.
    ///
    /// Ordinals are numbered from a base of 1 and no export-table RVA is set yet:
    /// call [`Self::set_export_table_base_rva`] before [`Self::get_export_table_data`].
    #[must_use]
    pub fn new(dll_name: &str) -> Self {
        Self {
            directory: ExportDirectory {
                dll_name: dll_name.to_owned(),
                base_ordinal: 1,
                function_count: 0,
                name_count: 0,
                timestamp: 0,
                major_version: 0,
                minor_version: 0,
            },
            functions: HashMap::new(),
            forwarders: HashMap::new(),
            name_to_ordinal: HashMap::new(),
            next_ordinal: 1,
            export_table_base_rva: 0,
        }
    }

    /// Builds an export set from parsed PE export entries.
    ///
    /// Entries whose RVA is 0 are skipped. Named entries become named functions and
    /// unnamed entries become ordinal-only functions; forwarders are not reconstructed.
    ///
    /// # Errors
    ///
    /// Returns an error if an entry's ordinal does not fit in a `u16`, or if
    /// [`Self::add_function`] / [`Self::add_function_by_ordinal`] rejects an entry
    /// (ordinal 0, duplicate ordinal, duplicate or invalid name).
    pub fn from_pe_exports(pe_exports: &[Export]) -> Result<Self> {
        let mut exports = Self::new("");
        for export in pe_exports {
            // NOTE(review): the ordinal is read from `export.offset` (0 when absent, which
            // the `add_*` methods reject); confirm that field carries the ordinal.
            let ordinal = u16::try_from(export.offset.unwrap_or(0))
                .map_err(|_| malformed_error!("Export ordinal exceeds u16 range"))?;
            // Entries without an address have nothing to re-emit.
            if export.rva == 0 {
                continue;
            }
            match &export.name {
                Some(name) => {
                    // NOTE(review): this seeds the DLL name from the first exported *symbol*
                    // name, which looks unintended; kept as-is for compatibility.
                    if exports.directory.dll_name.is_empty() {
                        exports.directory.dll_name.clone_from(name);
                    }
                    exports.add_function(name, ordinal, export.rva)?;
                }
                None => exports.add_function_by_ordinal(ordinal, export.rva)?,
            }
        }
        Ok(exports)
    }

    /// Adds a named function export at `ordinal` resolving to the RVA `address`.
    ///
    /// # Errors
    ///
    /// Returns an error if `name` is empty or contains a NUL byte, if `ordinal` is 0,
    /// if `ordinal` is already used by a function or forwarder, or if `name` is already
    /// exported.
    pub fn add_function(&mut self, name: &str, ordinal: u16, address: u32) -> Result<()> {
        if name.is_empty() {
            return Err(malformed_error!("Function name cannot be empty"));
        }
        // Names are emitted as NUL-terminated strings; an embedded NUL would silently
        // truncate the name in the serialized table.
        if name.contains('\0') {
            return Err(malformed_error!("Function name cannot contain null bytes"));
        }
        if ordinal == 0 {
            return Err(malformed_error!("Ordinal cannot be 0"));
        }
        self.ensure_ordinal_free(ordinal)?;
        self.ensure_name_free(name)?;

        self.functions.insert(
            ordinal,
            ExportFunction {
                ordinal,
                name: Some(name.to_owned()),
                address,
                is_forwarder: false,
            },
        );
        self.name_to_ordinal.insert(name.to_owned(), ordinal);
        self.sync_after_insert(ordinal)
    }

    /// Adds an unnamed (ordinal-only) function export resolving to the RVA `address`.
    ///
    /// # Errors
    ///
    /// Returns an error if `ordinal` is 0 or already used by a function or forwarder.
    pub fn add_function_by_ordinal(&mut self, ordinal: u16, address: u32) -> Result<()> {
        if ordinal == 0 {
            return Err(malformed_error!("Ordinal cannot be 0"));
        }
        self.ensure_ordinal_free(ordinal)?;

        self.functions.insert(
            ordinal,
            ExportFunction {
                ordinal,
                name: None,
                address,
                is_forwarder: false,
            },
        );
        self.sync_after_insert(ordinal)
    }

    /// Adds a forwarded export at `ordinal` that resolves to `target` in another DLL.
    ///
    /// An empty `name` exports the forwarder by ordinal only. `target` must look like
    /// `DllName.FunctionName` or `DllName.#Ordinal` (only the presence of a `.` is checked).
    ///
    /// # Errors
    ///
    /// Returns an error if `ordinal` is 0; if `target` is empty, has no `.`, or contains
    /// a NUL byte; if `name` contains a NUL byte; if `ordinal` is already in use; or if a
    /// non-empty `name` is already exported.
    pub fn add_forwarder(&mut self, name: &str, ordinal: u16, target: &str) -> Result<()> {
        if ordinal == 0 {
            return Err(malformed_error!("Ordinal cannot be 0"));
        }
        if target.is_empty() {
            return Err(malformed_error!("Forwarder target cannot be empty"));
        }
        if !target.contains('.') {
            return Err(malformed_error!(
                "Forwarder target '{}' must be in format 'DllName.FunctionName' or 'DllName.#Ordinal'",
                target
            ));
        }
        if target.contains('\0') {
            return Err(malformed_error!(
                "Forwarder target cannot contain null bytes"
            ));
        }
        if name.contains('\0') {
            return Err(malformed_error!("Function name cannot contain null bytes"));
        }
        self.ensure_ordinal_free(ordinal)?;
        if !name.is_empty() {
            self.ensure_name_free(name)?;
        }

        let name = (!name.is_empty()).then(|| name.to_owned());
        if let Some(exported_name) = &name {
            self.name_to_ordinal.insert(exported_name.clone(), ordinal);
        }
        self.forwarders.insert(
            ordinal,
            ExportForwarder {
                ordinal,
                name,
                target: target.to_owned(),
            },
        );
        self.sync_after_insert(ordinal)
    }

    /// Fails if `ordinal` is already assigned to a function or a forwarder.
    fn ensure_ordinal_free(&self, ordinal: u16) -> Result<()> {
        if self.functions.contains_key(&ordinal) || self.forwarders.contains_key(&ordinal) {
            return Err(malformed_error!("Ordinal {ordinal} is already in use"));
        }
        Ok(())
    }

    /// Fails if `name` is already exported by a function or a forwarder.
    fn ensure_name_free(&self, name: &str) -> Result<()> {
        if self.name_to_ordinal.contains_key(name) {
            return Err(malformed_error!(
                "Function name '{name}' is already exported"
            ));
        }
        Ok(())
    }

    /// Re-syncs the directory counters and `next_ordinal` after an export was inserted.
    fn sync_after_insert(&mut self, ordinal: u16) -> Result<()> {
        // Functions and forwarders both occupy export address table slots, so both are
        // counted (previously only `add_forwarder` included forwarders).
        self.directory.function_count = to_u32(self.function_count())?;
        self.directory.name_count = to_u32(self.name_to_ordinal.len())?;
        // `saturating_add` keeps ordinal `u16::MAX` from overflowing (a panic in debug).
        self.next_ordinal = self.next_ordinal.max(ordinal.saturating_add(1));
        Ok(())
    }

    /// Name of the DLL written into the export directory.
    #[must_use]
    pub fn dll_name(&self) -> &str {
        &self.directory.dll_name
    }

    /// Total number of exports: functions plus forwarders.
    #[must_use]
    pub fn function_count(&self) -> usize {
        self.functions.len() + self.forwarders.len()
    }

    /// Number of forwarded exports.
    #[must_use]
    pub fn forwarder_count(&self) -> usize {
        self.forwarders.len()
    }

    /// `true` when there are neither functions nor forwarders.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.functions.is_empty() && self.forwarders.is_empty()
    }

    /// `true` if `name` is exported by a function or a forwarder.
    #[must_use]
    pub fn has_function(&self, name: &str) -> bool {
        self.name_to_ordinal.contains_key(name)
    }

    /// Looks up a (non-forwarder) function export by ordinal.
    #[must_use]
    pub fn get_function_by_ordinal(&self, ordinal: u16) -> Option<&ExportFunction> {
        self.functions.get(&ordinal)
    }

    /// Looks up a forwarded export by ordinal.
    #[must_use]
    pub fn get_forwarder_by_ordinal(&self, ordinal: u16) -> Option<&ExportForwarder> {
        self.forwarders.get(&ordinal)
    }

    /// Ordinal of the function or forwarder exported as `name`, if any.
    #[must_use]
    pub fn get_ordinal_by_name(&self, name: &str) -> Option<u16> {
        self.name_to_ordinal.get(name).copied()
    }

    /// Iterates over function exports in no particular order.
    pub fn functions(&self) -> impl Iterator<Item = &ExportFunction> {
        self.functions.values()
    }

    /// Iterates over forwarded exports in no particular order.
    pub fn forwarders(&self) -> impl Iterator<Item = &ExportForwarder> {
        self.forwarders.values()
    }

    /// All exported names (functions and forwarders), in no particular order.
    #[must_use]
    pub fn get_exported_function_names(&self) -> Vec<String> {
        self.name_to_ordinal.keys().cloned().collect()
    }

    /// Serializes the export table at the RVA set by [`Self::set_export_table_base_rva`].
    ///
    /// Returns an empty vector when there are no exports.
    ///
    /// # Errors
    ///
    /// Returns an error if there are exports but no base RVA has been set, or if
    /// [`Self::get_export_table_data_with_base_rva`] fails.
    pub fn get_export_table_data(&self) -> Result<Vec<u8>> {
        if self.is_empty() {
            return Ok(Vec::new());
        }
        if self.export_table_base_rva == 0 {
            return Err(malformed_error!("Export table base RVA not set"));
        }
        self.get_export_table_data_with_base_rva(self.export_table_base_rva)
    }

    /// Serializes the export table as if it were placed at RVA `base_rva`.
    ///
    /// Layout: the 40-byte export directory, the export address table (one `u32` per
    /// ordinal from the base to the highest ordinal in use), the name pointer table
    /// (`u32` RVAs, names in ascending byte order), the ordinal table (`u16` EAT indices
    /// parallel to the names), then a string pool holding the DLL name, the export names
    /// and the forwarder targets. Returns an empty vector when there are no exports.
    ///
    /// # Errors
    ///
    /// Returns an error if the table would not fit in the 32-bit RVA space, or if a
    /// low-level write fails.
    pub fn get_export_table_data_with_base_rva(&self, base_rva: u32) -> Result<Vec<u8>> {
        /// Size of `IMAGE_EXPORT_DIRECTORY` in bytes.
        const EXPORT_DIRECTORY_SIZE: u32 = 40;

        if self.is_empty() {
            return Ok(Vec::new());
        }
        let base_ordinal = self.directory.base_ordinal;

        // The loader binary-searches the name pointer table, so names must be in
        // ascending byte order; `str` ordering is byte-wise. Names are unique keys.
        let mut named_exports: Vec<(&str, u16)> = self
            .name_to_ordinal
            .iter()
            .map(|(name, &ordinal)| (name.as_str(), ordinal))
            .collect();
        named_exports.sort_unstable_by_key(|&(name, _)| name);

        // Emit forwarder strings in ordinal order so the output bytes do not depend on
        // `HashMap` iteration order.
        let mut forwarders: Vec<&ExportForwarder> = self.forwarders.values().collect();
        forwarders.sort_unstable_by_key(|forwarder| forwarder.ordinal);

        // The EAT spans every ordinal from the base up to the highest one in use; gaps
        // are written as 0. Computed in u32 so the width itself cannot overflow.
        let max_ordinal = self
            .functions
            .keys()
            .chain(self.forwarders.keys())
            .copied()
            .max()
            .unwrap_or(0);
        let eat_entry_count = if max_ordinal >= base_ordinal {
            u32::from(max_ordinal - base_ordinal) + 1
        } else {
            0
        };
        let name_count = to_u32(named_exports.len())?;

        // String pool: DLL name, export names (table order), forwarder targets.
        // Offsets below are relative to the start of the pool.
        let first_name_offset = self.directory.dll_name.len() + 1;
        let mut pool_size = first_name_offset
            + named_exports
                .iter()
                .map(|&(name, _)| name.len() + 1)
                .sum::<usize>();
        let mut forwarder_string_offsets: HashMap<u16, usize> =
            HashMap::with_capacity(forwarders.len());
        for forwarder in &forwarders {
            forwarder_string_offsets.insert(forwarder.ordinal, pool_size);
            pool_size += forwarder.target.len() + 1;
        }
        let strings_size = to_u32(pool_size)?;

        // Section offsets within the table. These sums are small (at most 65 536 EAT
        // entries and 65 535 names) and cannot overflow u32.
        let eat_offset = EXPORT_DIRECTORY_SIZE;
        let name_table_offset = eat_offset + eat_entry_count * 4;
        let ordinal_table_offset = name_table_offset + name_count * 4;
        let strings_offset = ordinal_table_offset + name_count * 2;
        let total_size = strings_offset
            .checked_add(strings_size)
            .ok_or_else(|| malformed_error!("Export table exceeds the 32-bit size limit"))?;
        // Every RVA written below is `base_rva + x` with `x < total_size`, so checking the
        // end of the table once rules out overflow for all of them.
        base_rva.checked_add(total_size).ok_or_else(|| {
            malformed_error!("Export table does not fit in the 32-bit RVA space")
        })?;
        let eat_rva = base_rva + eat_offset;
        let name_table_rva = base_rva + name_table_offset;
        let ordinal_table_rva = base_rva + ordinal_table_offset;
        let strings_rva = base_rva + strings_offset;

        let mut data = vec![0u8; total_size as usize];
        let mut offset = 0;

        // IMAGE_EXPORT_DIRECTORY
        write_le_at(&mut data, &mut offset, 0u32)?; // Characteristics (reserved)
        write_le_at(&mut data, &mut offset, self.directory.timestamp)?; // TimeDateStamp
        write_le_at(&mut data, &mut offset, self.directory.major_version)?; // MajorVersion
        write_le_at(&mut data, &mut offset, self.directory.minor_version)?; // MinorVersion
        write_le_at(&mut data, &mut offset, strings_rva)?; // Name: DLL name heads the pool
        write_le_at(&mut data, &mut offset, u32::from(base_ordinal))?; // Base
        write_le_at(&mut data, &mut offset, eat_entry_count)?; // NumberOfFunctions
        write_le_at(&mut data, &mut offset, name_count)?; // NumberOfNames
        write_le_at(&mut data, &mut offset, eat_rva)?; // AddressOfFunctions
        write_le_at(&mut data, &mut offset, name_table_rva)?; // AddressOfNames
        write_le_at(&mut data, &mut offset, ordinal_table_rva)?; // AddressOfNameOrdinals

        // Export address table: one u32 per ordinal in [base, max]; empty when max < base.
        for ordinal in base_ordinal..=max_ordinal {
            let entry = if let Some(function) = self.functions.get(&ordinal) {
                function.address
            } else if let Some(&string_offset) = forwarder_string_offsets.get(&ordinal) {
                // Per the PE spec, an EAT entry pointing inside the export data directory
                // is a forwarder string ("Dll.Target") rather than code.
                strings_rva + to_u32(string_offset)?
            } else {
                0
            };
            write_le_at(&mut data, &mut offset, entry)?;
        }

        // Name pointer table: RVA of each export name, in sorted order.
        let mut name_offset = first_name_offset;
        for &(name, _) in &named_exports {
            write_le_at(&mut data, &mut offset, strings_rva + to_u32(name_offset)?)?;
            name_offset += name.len() + 1;
        }

        // Ordinal table: EAT index (ordinal - base) for each name, parallel to the names.
        for &(_, ordinal) in &named_exports {
            let eat_index = ordinal.checked_sub(base_ordinal).ok_or_else(|| {
                malformed_error!("Ordinal {ordinal} is below the base ordinal {base_ordinal}")
            })?;
            write_le_at(&mut data, &mut offset, eat_index)?;
        }

        // String pool, in exactly the order its offsets were computed above.
        write_string_at(&mut data, &mut offset, &self.directory.dll_name)?;
        for &(name, _) in &named_exports {
            write_string_at(&mut data, &mut offset, name)?;
        }
        for forwarder in &forwarders {
            write_string_at(&mut data, &mut offset, &forwarder.target)?;
        }

        Ok(data)
    }

    /// Sets the RVA the table will be serialized at; 0 is treated as "not set" by
    /// [`Self::get_export_table_data`].
    pub fn set_export_table_base_rva(&mut self, base_rva: u32) {
        self.export_table_base_rva = base_rva;
    }

    /// Replaces the DLL name written into the export directory.
    ///
    /// # Errors
    ///
    /// Returns an error (leaving the current name unchanged) if `dll_name` is empty or
    /// contains a NUL byte.
    pub fn set_dll_name(&mut self, dll_name: &str) -> Result<()> {
        if dll_name.is_empty() {
            return Err(malformed_error!("DLL name cannot be empty"));
        }
        if dll_name.contains('\0') {
            return Err(malformed_error!("DLL name cannot contain null bytes"));
        }
        dll_name.clone_into(&mut self.directory.dll_name);
        Ok(())
    }

    /// Export directory header values.
    #[must_use]
    pub fn directory(&self) -> &ExportDirectory {
        &self.directory
    }
}
impl ExportFunction {
    /// Returns the stored forwarder flag.
    ///
    /// Functions created by `NativeExports` in this module always carry `false`;
    /// forwarded exports are modelled separately as `ExportForwarder`.
    #[must_use]
    pub fn is_forwarder(&self) -> bool {
        let Self { is_forwarder, .. } = *self;
        is_forwarder
    }
}
impl Default for NativeExports {
fn default() -> Self {
Self::new("Unknown.dll")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads a little-endian `u32` at `offset`.
    fn u32_at(data: &[u8], offset: usize) -> u32 {
        u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap())
    }

    /// Reads a little-endian `u16` at `offset`.
    fn u16_at(data: &[u8], offset: usize) -> u16 {
        u16::from_le_bytes(data[offset..offset + 2].try_into().unwrap())
    }

    #[test]
    fn new_native_exports_is_empty() {
        let exports = NativeExports::new("Test.dll");
        assert!(exports.is_empty());
        assert_eq!(exports.function_count(), 0);
        assert_eq!(exports.forwarder_count(), 0);
        assert_eq!(exports.dll_name(), "Test.dll");
    }

    #[test]
    fn add_function_works() {
        let mut exports = NativeExports::new("Test.dll");
        exports.add_function("MyFunction", 1, 0x1000).unwrap();
        assert!(!exports.is_empty());
        assert_eq!(exports.function_count(), 1);
        assert!(exports.has_function("MyFunction"));
        let function = exports.get_function_by_ordinal(1).unwrap();
        assert_eq!(function.name, Some("MyFunction".to_string()));
        assert_eq!(function.address, 0x1000);
        assert!(!function.is_forwarder());
    }

    #[test]
    fn add_function_with_empty_name_fails() {
        let mut exports = NativeExports::new("Test.dll");
        let result = exports.add_function("", 1, 0x1000);
        assert!(result.is_err());
    }

    #[test]
    fn add_function_with_zero_ordinal_fails() {
        let mut exports = NativeExports::new("Test.dll");
        let result = exports.add_function("MyFunction", 0, 0x1000);
        assert!(result.is_err());
    }

    #[test]
    fn add_duplicate_function_name_fails() {
        let mut exports = NativeExports::new("Test.dll");
        exports.add_function("MyFunction", 1, 0x1000).unwrap();
        let result = exports.add_function("MyFunction", 2, 0x2000);
        assert!(result.is_err());
    }

    #[test]
    fn add_duplicate_ordinal_fails() {
        let mut exports = NativeExports::new("Test.dll");
        exports.add_function("Function1", 1, 0x1000).unwrap();
        let result = exports.add_function("Function2", 1, 0x2000);
        assert!(result.is_err());
    }

    #[test]
    fn add_function_by_ordinal_works() {
        let mut exports = NativeExports::new("Test.dll");
        exports.add_function_by_ordinal(1, 0x1000).unwrap();
        assert_eq!(exports.function_count(), 1);
        let function = exports.get_function_by_ordinal(1).unwrap();
        assert_eq!(function.name, None);
        assert_eq!(function.address, 0x1000);
    }

    #[test]
    fn add_forwarder_works() {
        let mut exports = NativeExports::new("Test.dll");
        exports
            .add_forwarder("ForwardedFunc", 1, "kernel32.dll.GetCurrentProcessId")
            .unwrap();
        assert_eq!(exports.function_count(), 1);
        assert_eq!(exports.forwarder_count(), 1);
        assert!(exports.has_function("ForwardedFunc"));
        let forwarder = exports.get_forwarder_by_ordinal(1).unwrap();
        assert_eq!(forwarder.name, Some("ForwardedFunc".to_string()));
        assert_eq!(forwarder.target, "kernel32.dll.GetCurrentProcessId");
    }

    #[test]
    fn add_forwarder_with_empty_target_fails() {
        let mut exports = NativeExports::new("Test.dll");
        let result = exports.add_forwarder("ForwardedFunc", 1, "");
        assert!(result.is_err());
    }

    #[test]
    fn get_ordinal_by_name_works() {
        let mut exports = NativeExports::new("Test.dll");
        exports.add_function("Function1", 5, 0x1000).unwrap();
        exports
            .add_forwarder("Function2", 10, "kernel32.dll.SomeFunc")
            .unwrap();
        assert_eq!(exports.get_ordinal_by_name("Function1"), Some(5));
        assert_eq!(exports.get_ordinal_by_name("Function2"), Some(10));
        assert_eq!(exports.get_ordinal_by_name("MissingFunction"), None);
    }

    #[test]
    fn get_exported_function_names_works() {
        let mut exports = NativeExports::new("Test.dll");
        exports.add_function("Function1", 1, 0x1000).unwrap();
        exports.add_function("Function2", 2, 0x2000).unwrap();
        exports.add_function_by_ordinal(3, 0x3000).unwrap();
        let names = exports.get_exported_function_names();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"Function1".to_string()));
        assert!(names.contains(&"Function2".to_string()));
    }

    #[test]
    fn get_export_table_data_empty_returns_empty() {
        let exports = NativeExports::new("Test.dll");
        let data = exports.get_export_table_data().unwrap();
        assert!(data.is_empty());
    }

    #[test]
    fn get_export_table_data_without_base_rva_fails() {
        let mut exports = NativeExports::new("Test.dll");
        exports.add_function("MyFunction", 1, 0x1000).unwrap();
        let result = exports.get_export_table_data();
        assert!(result.is_err());
    }

    #[test]
    fn get_export_table_data_with_exports_returns_data() {
        let mut exports = NativeExports::new("Test.dll");
        exports.set_export_table_base_rva(0x3000);
        exports.add_function("MyFunction", 1, 0x1000).unwrap();
        let data = exports.get_export_table_data().unwrap();
        assert!(!data.is_empty());
        assert!(data.len() >= 40);
    }

    #[test]
    fn export_table_layout_for_named_functions() {
        let mut exports = NativeExports::new("Test.dll");
        exports.set_export_table_base_rva(0x3000);
        // Added out of order: the name table must still come out sorted.
        exports.add_function("B", 2, 0x2000).unwrap();
        exports.add_function("A", 1, 0x1000).unwrap();
        let data = exports.get_export_table_data().unwrap();

        // 40 (directory) + 8 (EAT) + 8 (names) + 4 (ordinals) + 13 (strings)
        assert_eq!(data.len(), 73);
        assert_eq!(u32_at(&data, 0), 0); // Characteristics
        assert_eq!(u32_at(&data, 12), 0x303C); // Name -> string pool
        assert_eq!(u32_at(&data, 16), 1); // Base
        assert_eq!(u32_at(&data, 20), 2); // NumberOfFunctions
        assert_eq!(u32_at(&data, 24), 2); // NumberOfNames
        assert_eq!(u32_at(&data, 28), 0x3028); // AddressOfFunctions
        assert_eq!(u32_at(&data, 32), 0x3030); // AddressOfNames
        assert_eq!(u32_at(&data, 36), 0x3038); // AddressOfNameOrdinals
        assert_eq!(u32_at(&data, 40), 0x1000); // EAT[ordinal 1]
        assert_eq!(u32_at(&data, 44), 0x2000); // EAT[ordinal 2]
        assert_eq!(u32_at(&data, 48), 0x3045); // "A"
        assert_eq!(u32_at(&data, 52), 0x3047); // "B"
        assert_eq!(u16_at(&data, 56), 0); // "A" -> EAT index 0
        assert_eq!(u16_at(&data, 58), 1); // "B" -> EAT index 1
        assert_eq!(&data[60..], b"Test.dll\0A\0B\0");
    }

    #[test]
    fn export_table_forwarder_entry_points_at_target_string() {
        let mut exports = NativeExports::new("Test.dll");
        exports.set_export_table_base_rva(0x3000);
        exports.add_function("Func", 1, 0x1000).unwrap();
        exports.add_forwarder("Fwd", 2, "k.X").unwrap();
        let data = exports.get_export_table_data().unwrap();

        assert_eq!(data.len(), 82);
        assert_eq!(u32_at(&data, 40), 0x1000);
        let forwarder_rva = u32_at(&data, 44);
        assert_eq!(forwarder_rva, 0x304E);
        let start = (forwarder_rva - 0x3000) as usize;
        assert_eq!(&data[start..], b"k.X\0");
    }

    #[test]
    fn function_iteration_works() {
        let mut exports = NativeExports::new("Test.dll");
        exports.add_function("Function1", 1, 0x1000).unwrap();
        exports.add_function("Function2", 2, 0x2000).unwrap();
        let functions: Vec<&ExportFunction> = exports.functions().collect();
        assert_eq!(functions.len(), 2);
    }

    #[test]
    fn forwarder_iteration_works() {
        let mut exports = NativeExports::new("Test.dll");
        exports
            .add_forwarder("Forwarder1", 1, "kernel32.dll.Func1")
            .unwrap();
        exports
            .add_forwarder("Forwarder2", 2, "user32.dll.Func2")
            .unwrap();
        let forwarders: Vec<&ExportForwarder> = exports.forwarders().collect();
        assert_eq!(forwarders.len(), 2);
    }

    #[test]
    fn mixed_functions_and_forwarders() {
        let mut exports = NativeExports::new("Test.dll");
        exports.add_function("RegularFunc", 1, 0x1000).unwrap();
        exports
            .add_forwarder("ForwardedFunc", 2, "kernel32.dll.GetTick")
            .unwrap();
        exports.add_function_by_ordinal(3, 0x3000).unwrap();
        assert_eq!(exports.function_count(), 3);
        assert_eq!(exports.forwarders().count(), 1);
        assert_eq!(exports.functions().count(), 2);
        let names = exports.get_exported_function_names();
        assert_eq!(names.len(), 2);
    }

    #[test]
    fn set_dll_name_works() {
        let mut exports = NativeExports::new("OldName.dll");
        assert_eq!(exports.dll_name(), "OldName.dll");
        exports.set_dll_name("NewName.dll").unwrap();
        assert_eq!(exports.dll_name(), "NewName.dll");
        assert_eq!(exports.directory().dll_name, "NewName.dll");
    }

    #[test]
    fn set_dll_name_empty_fails() {
        let mut exports = NativeExports::new("Original.dll");
        let result = exports.set_dll_name("");
        assert!(result.is_err());
        assert_eq!(exports.dll_name(), "Original.dll");
    }

    #[test]
    fn set_dll_name_with_null_byte_fails() {
        let mut exports = NativeExports::new("Original.dll");
        let result = exports.set_dll_name("Invalid\0Name.dll");
        assert!(result.is_err());
        assert_eq!(exports.dll_name(), "Original.dll");
    }

    #[test]
    fn add_forwarder_invalid_format_fails() {
        let mut exports = NativeExports::new("Test.dll");
        let result = exports.add_forwarder("BadForward", 1, "kernel32GetTick");
        assert!(result.is_err());
        assert_eq!(exports.forwarder_count(), 0);
    }

    #[test]
    fn add_forwarder_with_null_byte_fails() {
        let mut exports = NativeExports::new("Test.dll");
        let result = exports.add_forwarder("BadForward", 1, "kernel32\0.dll.GetTick");
        assert!(result.is_err());
        assert_eq!(exports.forwarder_count(), 0);
    }

    #[test]
    fn add_forwarder_valid_formats() {
        let mut exports = NativeExports::new("Test.dll");
        exports
            .add_forwarder("Forward1", 1, "kernel32.dll.GetCurrentProcessId")
            .unwrap();
        exports
            .add_forwarder("Forward2", 2, "user32.dll.#120")
            .unwrap();
        exports.add_forwarder("Forward3", 3, "api.Func").unwrap();
        assert_eq!(exports.forwarder_count(), 3);
    }
}