use dashmap::{mapref::entry::Entry, DashMap};
use std::sync::atomic::{AtomicBool, Ordering};
use crate::{
metadata::{
exports::{native::NativeExports, Exports as CilExports},
tables::ExportedTypeRc,
token::Token,
},
Result,
};
/// Container that combines CIL exports and native exports and offers
/// by-name lookups over both.
///
/// The two `DashMap` caches are derived data: they are rebuilt from `cil`
/// and `native` on the next lookup whenever `cache_dirty` is set.
pub struct UnifiedExportContainer {
/// CIL side of the container.
cil: CilExports,
/// Native side of the container (functions and forwarders).
native: NativeExports,
/// Export name -> every CIL/native entry registered under that name.
unified_name_cache: DashMap<String, Vec<ExportEntry>>,
/// Export name -> which side(s) provide that name.
unified_function_cache: DashMap<String, ExportSource>,
/// `true` when the caches above must be rebuilt before the next lookup.
cache_dirty: AtomicBool,
}
/// One entry returned by a by-name lookup, tagged with the side it came from.
#[derive(Clone)]
pub enum ExportEntry {
/// Entry taken from the CIL exports.
Cil(ExportedTypeRc),
/// Entry taken from the native exports (function or forwarder).
Native(NativeExportRef),
}
impl std::fmt::Debug for ExportEntry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExportEntry::Cil(cil_export) => f
.debug_struct("Cil")
.field("name", &cil_export.name)
.field("namespace", &cil_export.namespace)
.finish(),
ExportEntry::Native(native_ref) => f
.debug_struct("Native")
.field("ordinal", &native_ref.ordinal)
.field("name", &native_ref.name)
.finish(),
}
}
}
/// Owned description of a native export as stored in the by-name cache.
#[derive(Clone, Debug)]
pub struct NativeExportRef {
/// Ordinal of the native export.
pub ordinal: u16,
/// Export name, if the export has one.
pub name: Option<String>,
/// Where the export points: a direct address or a forwarder target.
pub address_or_forwarder: ExportTarget,
}
/// Destination of a native export.
#[derive(Clone, Debug)]
pub enum ExportTarget {
/// The export resolves to this address value (copied from the native function entry).
Address(u32),
/// The export is forwarded to this target string (e.g. `"kernel32.dll.Func"`).
Forwarder(String),
}
/// Identifies which side(s) of the container provide an exported name.
#[derive(Clone, Debug)]
pub enum ExportSource {
    /// Provided only by a CIL export, identified by its token.
    Cil(Token),
    /// Provided only by a native export, identified by its ordinal.
    Native(u16),
    /// Provided by both a CIL export (token) and a native export (ordinal).
    Both(Token, u16),
}
/// Flattened view of one exported name, as produced by
/// `UnifiedExportContainer::get_all_exported_functions`.
#[derive(Clone, Debug)]
pub struct ExportedFunction {
/// Exported name.
pub name: String,
/// Which side(s) of the container provide this name.
pub source: ExportSource,
/// `true` when the native side reports a forwarder for this export.
pub is_forwarder: bool,
/// Forwarder target string; `Some` exactly when `is_forwarder` is `true`
/// for values built by the container.
pub forwarder_target: Option<String>,
}
/// Cloning copies the CIL and native exports but not the derived caches:
/// the clone starts with empty caches marked dirty, so they are rebuilt on
/// its first lookup.
impl Clone for UnifiedExportContainer {
    fn clone(&self) -> Self {
        let cil = self.cil.clone();
        let native = self.native.clone();

        Self {
            cil,
            native,
            unified_name_cache: DashMap::new(),
            unified_function_cache: DashMap::new(),
            cache_dirty: AtomicBool::new(true),
        }
    }
}
impl UnifiedExportContainer {
#[must_use]
pub fn new() -> Self {
Self {
cil: CilExports::new(),
native: NativeExports::new(""), unified_name_cache: DashMap::new(),
unified_function_cache: DashMap::new(),
cache_dirty: AtomicBool::new(true),
}
}
#[must_use]
pub fn with_dll_name(dll_name: &str) -> Self {
Self {
cil: CilExports::new(),
native: NativeExports::new(dll_name),
unified_name_cache: DashMap::new(),
unified_function_cache: DashMap::new(),
cache_dirty: AtomicBool::new(true),
}
}
pub fn cil(&self) -> &CilExports {
&self.cil
}
pub fn native(&self) -> &NativeExports {
&self.native
}
pub fn native_mut(&mut self) -> &mut NativeExports {
self.invalidate_cache();
&mut self.native
}
pub fn find_by_name(&self, name: &str) -> Vec<ExportEntry> {
self.ensure_cache_fresh();
if let Some(entries) = self.unified_name_cache.get(name) {
entries.value().clone()
} else {
Vec::new()
}
}
pub fn get_all_exported_functions(&self) -> Vec<ExportedFunction> {
self.ensure_cache_fresh();
self.unified_function_cache
.iter()
.map(|entry| {
let name = entry.key().clone();
let source = entry.value().clone();
let (is_forwarder, forwarder_target) = match &source {
ExportSource::Native(ordinal) => {
if let Some(forwarder) = self.native.get_forwarder_by_ordinal(*ordinal) {
(true, Some(forwarder.target.clone()))
} else {
(false, None)
}
}
_ => (false, None),
};
ExportedFunction {
name,
source,
is_forwarder,
forwarder_target,
}
})
.collect()
}
pub fn get_native_function_names(&self) -> Vec<String> {
self.native.get_exported_function_names()
}
pub fn is_empty(&self) -> bool {
self.cil.is_empty() && self.native.is_empty()
}
pub fn total_count(&self) -> usize {
self.cil.len() + self.native.function_count() + self.native.forwarder_count()
}
pub fn add_native_function(
&mut self,
function_name: &str,
ordinal: u16,
address: u32,
) -> Result<()> {
self.native.add_function(function_name, ordinal, address)?;
self.invalidate_cache();
Ok(())
}
pub fn add_native_function_by_ordinal(&mut self, ordinal: u16, address: u32) -> Result<()> {
self.native.add_function_by_ordinal(ordinal, address)?;
self.invalidate_cache();
Ok(())
}
pub fn add_native_forwarder(
&mut self,
function_name: &str,
ordinal: u16,
forwarder_target: &str,
) -> Result<()> {
self.native
.add_forwarder(function_name, ordinal, forwarder_target)?;
self.invalidate_cache();
Ok(())
}
pub fn get_export_table_data(&self) -> Result<Option<Vec<u8>>> {
if self.native.is_empty() {
Ok(None)
} else {
Ok(Some(self.native.get_export_table_data()?))
}
}
pub fn set_dll_name(&mut self, dll_name: &str) -> Result<()> {
self.native.set_dll_name(dll_name)
}
fn ensure_cache_fresh(&self) {
if self
.cache_dirty
.compare_exchange(true, false, Ordering::Acquire, Ordering::Relaxed)
.is_ok()
{
self.rebuild_unified_caches();
}
}
fn invalidate_cache(&self) {
self.cache_dirty.store(true, Ordering::Relaxed);
}
fn rebuild_unified_caches(&self) {
self.unified_name_cache.clear();
self.unified_function_cache.clear();
for export_entry in &self.cil {
let export_type = export_entry.value();
let token = *export_entry.key();
self.unified_name_cache
.entry(export_type.name.clone())
.or_default()
.push(ExportEntry::Cil(export_type.clone()));
match self.unified_function_cache.entry(export_type.name.clone()) {
Entry::Occupied(mut entry) => {
match entry.get() {
ExportSource::Native(ordinal) => {
*entry.get_mut() = ExportSource::Both(token, *ordinal);
}
ExportSource::Cil(_) | ExportSource::Both(_, _) => {
}
}
}
Entry::Vacant(entry) => {
entry.insert(ExportSource::Cil(token));
}
}
}
for function in self.native.functions() {
if let Some(ref name) = function.name {
self.unified_name_cache
.entry(name.clone())
.or_default()
.push(ExportEntry::Native(NativeExportRef {
ordinal: function.ordinal,
name: Some(name.clone()),
address_or_forwarder: ExportTarget::Address(function.address),
}));
match self.unified_function_cache.entry(name.clone()) {
Entry::Occupied(mut entry) => {
match entry.get() {
ExportSource::Cil(token) => {
*entry.get_mut() = ExportSource::Both(*token, function.ordinal);
}
ExportSource::Native(_) | ExportSource::Both(_, _) => {
}
}
}
Entry::Vacant(entry) => {
entry.insert(ExportSource::Native(function.ordinal));
}
}
}
}
for forwarder in self.native.forwarders() {
if let Some(ref name) = forwarder.name {
self.unified_name_cache
.entry(name.clone())
.or_default()
.push(ExportEntry::Native(NativeExportRef {
ordinal: forwarder.ordinal,
name: Some(name.clone()),
address_or_forwarder: ExportTarget::Forwarder(forwarder.target.clone()),
}));
self.unified_function_cache
.entry(name.clone())
.or_insert_with(|| ExportSource::Native(forwarder.ordinal));
}
}
}
}
/// `Default` yields the same empty container as `UnifiedExportContainer::new`.
impl Default for UnifiedExportContainer {
fn default() -> Self {
Self::new()
}
}
/// Summary-style `Debug` output: prints entry counts and the cache flag
/// instead of dumping every export, and marks the output as non-exhaustive.
impl std::fmt::Debug for UnifiedExportContainer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cil_count = self.cil.len();
        let function_count = self.native.function_count();
        let forwarder_count = self.native.forwarder_count();
        let dirty = self.cache_dirty.load(Ordering::Relaxed);

        let mut out = f.debug_struct("UnifiedExportContainer");
        out.field("cil_count", &cil_count);
        out.field("native_function_count", &function_count);
        out.field("native_forwarder_count", &forwarder_count);
        out.field("is_cache_dirty", &dirty);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly created container holds no exports.
    #[test]
    fn test_unified_export_container_new() {
        let container = UnifiedExportContainer::new();
        assert!(container.is_empty());
        assert_eq!(container.total_count(), 0);
    }

    // Supplying a DLL name does not add any export.
    #[test]
    fn test_unified_export_container_with_dll_name() {
        let container = UnifiedExportContainer::with_dll_name("MyLibrary.dll");
        assert!(container.is_empty());
    }

    // `Default` behaves like `new`.
    #[test]
    fn test_unified_export_container_default() {
        let container = UnifiedExportContainer::default();
        assert!(container.is_empty());
    }

    // Adding one named function makes the container non-empty.
    #[test]
    fn test_add_native_function() {
        let mut container = UnifiedExportContainer::new();
        container
            .add_native_function("TestFunction", 1, 0x1000)
            .unwrap();
        assert!(!container.is_empty());
        assert_eq!(container.total_count(), 1);
    }

    // Several named functions are all counted.
    #[test]
    fn test_add_multiple_native_functions() {
        let mut container = UnifiedExportContainer::new();
        container
            .add_native_function("Function1", 1, 0x1000)
            .unwrap();
        container
            .add_native_function("Function2", 2, 0x2000)
            .unwrap();
        container
            .add_native_function("Function3", 3, 0x3000)
            .unwrap();
        assert_eq!(container.total_count(), 3);
    }

    // Ordinal-only functions are counted as well.
    #[test]
    fn test_add_native_function_by_ordinal() {
        let mut container = UnifiedExportContainer::new();
        container
            .add_native_function_by_ordinal(100, 0x5000)
            .unwrap();
        assert!(!container.is_empty());
        assert_eq!(container.total_count(), 1);
    }

    // Ordinal 0 is rejected by the native side.
    #[test]
    fn test_add_native_function_invalid_ordinal() {
        let mut container = UnifiedExportContainer::new();
        let result = container.add_native_function("Test", 0, 0x1000);
        assert!(result.is_err());
    }

    // Re-using an ordinal is rejected by the native side.
    #[test]
    fn test_add_native_function_duplicate_ordinal() {
        let mut container = UnifiedExportContainer::new();
        container
            .add_native_function("Function1", 1, 0x1000)
            .unwrap();
        let result = container.add_native_function("Function2", 1, 0x2000);
        assert!(result.is_err());
    }

    // Adding a forwarder makes the container non-empty and is visible in the
    // native forwarder count.
    #[test]
    fn test_add_native_forwarder() {
        let mut container = UnifiedExportContainer::new();
        container
            .add_native_forwarder("ForwardedFunction", 1, "kernel32.dll.GetCurrentProcessId")
            .unwrap();
        assert!(!container.is_empty());
        assert!(container.total_count() >= 1);
        assert!(container.native().forwarder_count() >= 1);
    }

    // A named native function is found by name with its ordinal and name.
    #[test]
    fn test_find_by_name_native() {
        let mut container = UnifiedExportContainer::new();
        container
            .add_native_function("MyExport", 1, 0x1000)
            .unwrap();
        let results = container.find_by_name("MyExport");
        assert_eq!(results.len(), 1);
        let ExportEntry::Native(native_ref) = &results[0] else {
            panic!("Expected Native export entry, got {:?}", &results[0]);
        };
        assert_eq!(native_ref.ordinal, 1);
        assert_eq!(native_ref.name, Some("MyExport".to_string()));
    }

    // Unknown names yield an empty result instead of an error.
    #[test]
    fn test_find_by_name_not_found() {
        let container = UnifiedExportContainer::new();
        let results = container.find_by_name("NonExistent");
        assert!(results.is_empty());
    }

    // All named native functions are reported by name.
    #[test]
    fn test_get_native_function_names() {
        let mut container = UnifiedExportContainer::new();
        container.add_native_function("Alpha", 1, 0x1000).unwrap();
        container.add_native_function("Beta", 2, 0x2000).unwrap();
        container.add_native_function("Gamma", 3, 0x3000).unwrap();
        let names = container.get_native_function_names();
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"Alpha".to_string()));
        assert!(names.contains(&"Beta".to_string()));
        assert!(names.contains(&"Gamma".to_string()));
    }

    // A native-only function is reported with a `Native` source and ordinal.
    #[test]
    fn test_get_all_exported_functions() {
        let mut container = UnifiedExportContainer::new();
        container
            .add_native_function("NativeFunc", 1, 0x1000)
            .unwrap();
        let functions = container.get_all_exported_functions();
        assert_eq!(functions.len(), 1);
        assert_eq!(functions[0].name, "NativeFunc");
        let ExportSource::Native(ordinal) = functions[0].source else {
            panic!(
                "Expected Native export source, got {:?}",
                functions[0].source
            );
        };
        assert_eq!(ordinal, 1);
    }

    // A forwarder is flagged as such and carries its target string.
    #[test]
    fn test_get_all_exported_functions_with_forwarder() {
        let mut container = UnifiedExportContainer::new();
        container
            .add_native_forwarder("ForwardedFunc", 1, "other.dll.RealFunc")
            .unwrap();
        let functions = container.get_all_exported_functions();
        assert_eq!(functions.len(), 1);
        assert!(functions[0].is_forwarder);
        assert_eq!(
            functions[0].forwarder_target,
            Some("other.dll.RealFunc".to_string())
        );
    }

    // The CIL accessor exposes the (empty) CIL side.
    #[test]
    fn test_cil_accessor() {
        let container = UnifiedExportContainer::new();
        let cil = container.cil();
        assert!(cil.is_empty());
    }

    // The native accessor exposes the (empty) native side.
    #[test]
    fn test_native_accessor() {
        let container = UnifiedExportContainer::new();
        let native = container.native();
        assert!(native.is_empty());
    }

    // Handing out `&mut NativeExports` marks the caches dirty again.
    #[test]
    fn test_native_mut_invalidates_cache() {
        let mut container = UnifiedExportContainer::new();
        container.add_native_function("Test", 1, 0x1000).unwrap();
        // The lookup builds the caches and clears the dirty flag.
        let _ = container.find_by_name("Test");
        let _ = container.native_mut();
        assert!(container.cache_dirty.load(Ordering::Relaxed));
    }

    // A clone keeps the exports but starts with dirty (empty) caches.
    #[test]
    fn test_clone_resets_cache() {
        let mut container = UnifiedExportContainer::new();
        container.add_native_function("Test", 1, 0x1000).unwrap();
        let _ = container.find_by_name("Test");
        let cloned = container.clone();
        assert!(cloned.cache_dirty.load(Ordering::Relaxed));
        assert_eq!(cloned.total_count(), 1);
    }

    // `Debug` output names the type and includes the summary fields.
    #[test]
    fn test_debug_output() {
        let mut container = UnifiedExportContainer::new();
        container.add_native_function("Test", 1, 0x1000).unwrap();
        let debug_output = format!("{:?}", container);
        assert!(debug_output.contains("UnifiedExportContainer"));
        assert!(debug_output.contains("native_function_count"));
    }

    // `ExportTarget::Address` round-trips its value.
    #[test]
    fn test_export_target_address() {
        let target = ExportTarget::Address(0x1234);
        let ExportTarget::Address(addr) = target else {
            panic!("Expected Address variant, got {:?}", target);
        };
        assert_eq!(addr, 0x1234);
    }

    // `ExportTarget::Forwarder` round-trips its target string.
    #[test]
    fn test_export_target_forwarder() {
        let target = ExportTarget::Forwarder("kernel32.dll.Func".to_string());
        let ExportTarget::Forwarder(ref fwd) = target else {
            panic!("Expected Forwarder variant, got {:?}", target);
        };
        assert_eq!(fwd, "kernel32.dll.Func");
    }

    // Every `ExportSource` variant round-trips its payload. `let ... else`
    // is used so that a variant mismatch fails the test instead of silently
    // skipping the assertions.
    #[test]
    fn test_export_source_variants() {
        let token = Token::new(0x02000001);

        let cil_source = ExportSource::Cil(token);
        let ExportSource::Cil(t) = cil_source else {
            panic!("Expected Cil export source, got {:?}", cil_source);
        };
        assert_eq!(t, token);

        let native_source = ExportSource::Native(42);
        let ExportSource::Native(ord) = native_source else {
            panic!("Expected Native export source, got {:?}", native_source);
        };
        assert_eq!(ord, 42);

        let both_source = ExportSource::Both(token, 42);
        let ExportSource::Both(t, ord) = both_source else {
            panic!("Expected Both export source, got {:?}", both_source);
        };
        assert_eq!(t, token);
        assert_eq!(ord, 42);
    }

    // Cloning a `NativeExportRef` preserves its fields.
    #[test]
    fn test_native_export_ref_clone() {
        let export_ref = NativeExportRef {
            ordinal: 1,
            name: Some("TestFunc".to_string()),
            address_or_forwarder: ExportTarget::Address(0x1000),
        };
        let cloned = export_ref.clone();
        assert_eq!(cloned.ordinal, 1);
        assert_eq!(cloned.name, Some("TestFunc".to_string()));
    }

    // A non-forwarder `ExportedFunction` can be built and read back.
    #[test]
    fn test_exported_function_structure() {
        let func = ExportedFunction {
            name: "TestFunction".to_string(),
            source: ExportSource::Native(1),
            is_forwarder: false,
            forwarder_target: None,
        };
        assert_eq!(func.name, "TestFunction");
        assert!(!func.is_forwarder);
        assert!(func.forwarder_target.is_none());
    }

    // A forwarder `ExportedFunction` carries its target.
    #[test]
    fn test_exported_function_forwarder() {
        let func = ExportedFunction {
            name: "ForwardedFunc".to_string(),
            source: ExportSource::Native(1),
            is_forwarder: true,
            forwarder_target: Some("target.dll.RealFunc".to_string()),
        };
        assert!(func.is_forwarder);
        assert_eq!(
            func.forwarder_target,
            Some("target.dll.RealFunc".to_string())
        );
    }

    // The DLL name can be set and replaced.
    #[test]
    fn test_set_dll_name() {
        let mut container = UnifiedExportContainer::new();
        assert_eq!(container.native().dll_name(), "");
        container.set_dll_name("MyLibrary.dll").unwrap();
        assert_eq!(container.native().dll_name(), "MyLibrary.dll");
        container.set_dll_name("AnotherName.dll").unwrap();
        assert_eq!(container.native().dll_name(), "AnotherName.dll");
    }

    // Empty names and names containing NUL are rejected.
    #[test]
    fn test_set_dll_name_validation() {
        let mut container = UnifiedExportContainer::new();
        let result = container.set_dll_name("");
        assert!(result.is_err());
        let result = container.set_dll_name("test\0.dll");
        assert!(result.is_err());
        let result = container.set_dll_name("Valid.dll");
        assert!(result.is_ok());
    }

    // The dirty flag is cleared by a refresh, stays clear on a second
    // refresh, and is set again by an explicit invalidation.
    #[test]
    fn test_cache_freshness_compare_and_swap() {
        let mut container = UnifiedExportContainer::new();
        container
            .add_native_function("TestFunc", 1, 0x1000)
            .unwrap();
        assert!(container.cache_dirty.load(Ordering::Relaxed));
        container.ensure_cache_fresh();
        assert!(!container.cache_dirty.load(Ordering::Relaxed));
        container.ensure_cache_fresh();
        assert!(!container.cache_dirty.load(Ordering::Relaxed));
        container.invalidate_cache();
        assert!(container.cache_dirty.load(Ordering::Relaxed));
        container.ensure_cache_fresh();
        assert!(!container.cache_dirty.load(Ordering::Relaxed));
    }

    // A mutation after a lookup marks the caches dirty, and the next lookup
    // sees the new export.
    #[test]
    fn test_cache_invalidation_on_mutation() {
        let mut container = UnifiedExportContainer::new();
        container.add_native_function("Func1", 1, 0x1000).unwrap();
        let _ = container.find_by_name("Func1");
        assert!(!container.cache_dirty.load(Ordering::Relaxed));
        container.add_native_function("Func2", 2, 0x2000).unwrap();
        assert!(container.cache_dirty.load(Ordering::Relaxed));
        let results = container.find_by_name("Func2");
        assert_eq!(results.len(), 1);
        assert!(!container.cache_dirty.load(Ordering::Relaxed));
    }

    // Named functions, ordinal-only functions and forwarders can be mixed;
    // only the named ones show up in the by-name views.
    #[test]
    fn test_mixed_native_function_types() {
        let mut container = UnifiedExportContainer::new();
        container
            .add_native_function("NamedFunc", 1, 0x1000)
            .unwrap();
        container.add_native_function_by_ordinal(2, 0x2000).unwrap();
        container
            .add_native_forwarder("ForwardedFunc", 3, "kernel32.dll.GetCurrentProcessId")
            .unwrap();
        assert_eq!(container.native().function_count(), 3);
        assert_eq!(container.native().forwarder_count(), 1);
        let named_results = container.find_by_name("NamedFunc");
        assert_eq!(named_results.len(), 1);
        let forwarder_results = container.find_by_name("ForwardedFunc");
        assert_eq!(forwarder_results.len(), 1);
        // The ordinal-only export has no name and is therefore not listed.
        let all_functions = container.get_all_exported_functions();
        assert_eq!(all_functions.len(), 2);
        let forwarder_func = all_functions
            .iter()
            .find(|f| f.name == "ForwardedFunc")
            .unwrap();
        assert!(forwarder_func.is_forwarder);
        assert_eq!(
            forwarder_func.forwarder_target,
            Some("kernel32.dll.GetCurrentProcessId".to_string())
        );
    }

    // Export table data is produced once native exports exist.
    #[test]
    fn test_export_table_generation() {
        let mut container = UnifiedExportContainer::with_dll_name("TestLib.dll");
        container
            .add_native_function("Function1", 1, 0x1000)
            .unwrap();
        container
            .add_native_function("Function2", 2, 0x2000)
            .unwrap();
        container.native_mut().set_export_table_base_rva(0x3000);
        let data = container.get_export_table_data().unwrap();
        assert!(data.is_some());
        let table_data = data.unwrap();
        assert!(table_data.len() >= 40);
    }

    // Without native exports there is no export table.
    #[test]
    fn test_empty_container_export_table() {
        let container = UnifiedExportContainer::new();
        let data = container.get_export_table_data().unwrap();
        assert!(data.is_none());
    }

    // Functions and forwarders are all reachable through `find_by_name`.
    #[test]
    fn test_unified_find_by_name_multiple_sources() {
        let mut container = UnifiedExportContainer::new();
        container
            .add_native_function("ProcessData", 1, 0x1000)
            .unwrap();
        container
            .add_native_function("ProcessFile", 2, 0x2000)
            .unwrap();
        container
            .add_native_forwarder("ProcessMessage", 3, "other.dll.HandleMessage")
            .unwrap();
        assert_eq!(container.find_by_name("ProcessData").len(), 1);
        assert_eq!(container.find_by_name("ProcessFile").len(), 1);
        assert_eq!(container.find_by_name("ProcessMessage").len(), 1);
        assert!(container.find_by_name("NonExistent").is_empty());
    }

    // The native name list contains named functions and forwarders, but not
    // the ordinal-only export.
    #[test]
    fn test_native_function_names_list() {
        let mut container = UnifiedExportContainer::new();
        container.add_native_function("Alpha", 1, 0x1000).unwrap();
        container.add_native_function("Beta", 2, 0x2000).unwrap();
        container.add_native_function_by_ordinal(3, 0x3000).unwrap();
        container
            .add_native_forwarder("Gamma", 4, "lib.dll.Func")
            .unwrap();
        let names = container.get_native_function_names();
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"Alpha".to_string()));
        assert!(names.contains(&"Beta".to_string()));
        assert!(names.contains(&"Gamma".to_string()));
    }
}