use dashmap::{mapref::entry::Entry, DashMap};
use std::{
collections::HashSet,
sync::atomic::{AtomicBool, Ordering},
};
use crate::{
metadata::{
imports::{
native::NativeImports, Import, ImportRc, ImportSourceId, ImportType,
Imports as CilImports,
},
token::Token,
},
Result,
};
/// Holds a module's CIL imports and native imports side by side, together
/// with two lookup caches that combine both sources.
///
/// The caches are derived data: they are cleared and rebuilt on demand
/// whenever `cache_dirty` is set.
pub struct UnifiedImportContainer {
/// CIL-side imports.
cil: CilImports,
/// Native-side imports, grouped into per-DLL descriptors.
native: NativeImports,
/// Import name -> every entry (CIL or native) recorded under that exact name.
unified_name_cache: DashMap<String, Vec<ImportEntry>>,
/// DLL name -> which side(s) reference that DLL.
unified_dll_cache: DashMap<String, DllSource>,
/// `true` when both caches must be rebuilt before the next query.
cache_dirty: AtomicBool,
}
/// One result of a name lookup in the unified view: either a CIL import or a
/// native import.
#[derive(Clone)]
pub enum ImportEntry {
/// A CIL import, shared by reference-counted handle.
Cil(ImportRc),
/// A native import, described by a copied-out snapshot.
Native(NativeImportRef),
}
/// Snapshot of a single native import as stored in the unified name cache.
///
/// The values are copies taken when the cache is built; they do not track
/// later changes to the native import data on their own.
#[derive(Clone, Debug)]
pub struct NativeImportRef {
/// Name of the DLL descriptor the function belongs to.
pub dll_name: String,
/// Function name, if the import has one.
pub function_name: Option<String>,
/// Import ordinal, if the import has one.
pub ordinal: Option<u16>,
/// RVA recorded for the import (copied from the native function's `rva`).
pub iat_rva: u32,
}
/// Records which import side(s) reference a DLL.
#[derive(Clone, Debug)]
pub enum DllSource {
/// Referenced only by CIL imports; holds the tokens of those imports.
Cil(Vec<Token>),
/// Referenced only by native imports.
Native,
/// Referenced by both sides; holds the tokens of the CIL imports.
Both(Vec<Token>),
}
/// Summary of one DLL the module depends on.
#[derive(Clone, Debug)]
pub struct DllDependency {
/// DLL name, as stored in the unified DLL cache.
pub name: String,
/// Which import side(s) reference the DLL.
pub source: DllSource,
/// Imported function names; ordinal-only native imports appear as `#<ordinal>`.
pub functions: Vec<String>,
}
impl Clone for UnifiedImportContainer {
    /// Clones the CIL and native import data only.
    ///
    /// The lookup caches are derived data and are not copied: the clone gets
    /// empty caches flagged dirty, so they are rebuilt on its first query.
    fn clone(&self) -> Self {
        let Self { cil, native, .. } = self;
        Self {
            cil: cil.clone(),
            native: native.clone(),
            unified_name_cache: DashMap::new(),
            unified_dll_cache: DashMap::new(),
            cache_dirty: AtomicBool::new(true),
        }
    }
}
impl UnifiedImportContainer {
    /// Creates an empty container.
    ///
    /// Both caches start empty and flagged dirty, so the first query builds them.
    #[must_use]
    pub fn new() -> Self {
        Self {
            cil: CilImports::new(),
            native: NativeImports::new(),
            unified_name_cache: DashMap::new(),
            unified_dll_cache: DashMap::new(),
            cache_dirty: AtomicBool::new(true),
        }
    }

    /// Returns the CIL import collection.
    ///
    /// NOTE(review): this accessor does not mark the caches dirty. If
    /// `CilImports` is modified through the shared reference after a query has
    /// built the caches, later queries keep returning the old cached view until
    /// a native mutation invalidates it.
    pub fn cil(&self) -> &CilImports {
        &self.cil
    }

    /// Returns the native import collection for read-only access.
    pub fn native(&self) -> &NativeImports {
        &self.native
    }

    /// Returns mutable access to the native import collection.
    ///
    /// The caches are marked dirty up front because the caller may change
    /// anything through the returned reference.
    pub fn native_mut(&mut self) -> &mut NativeImports {
        self.invalidate_cache();
        &mut self.native
    }

    /// Returns every import recorded under exactly `name`.
    ///
    /// The lookup key is case-sensitive. A CIL P/Invoke import whose DLL and
    /// function name match a native import (ASCII case-insensitively) is left
    /// out, so only the native entry is returned for such pairs. Returns an
    /// empty vector when nothing is recorded under `name`.
    pub fn find_by_name(&self, name: &str) -> Vec<ImportEntry> {
        self.ensure_cache_fresh();
        self.unified_name_cache
            .get(name)
            .map(|entries| entries.value().clone())
            .unwrap_or_default()
    }

    /// Returns one [`DllDependency`] per DLL in the unified DLL cache, sorted by
    /// DLL name so the result is stable between runs.
    pub fn get_all_dll_dependencies(&self) -> Vec<DllDependency> {
        self.ensure_cache_fresh();
        let mut dependencies: Vec<DllDependency> = self
            .unified_dll_cache
            .iter()
            .map(|entry| {
                let dll_name = entry.key();
                DllDependency {
                    name: dll_name.clone(),
                    source: entry.value().clone(),
                    functions: self.get_functions_for_dll(dll_name),
                }
            })
            .collect();
        // The map iterates in hash order; sort for reproducible output.
        dependencies.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        dependencies
    }

    /// Returns the names of all DLLs in the unified DLL cache, sorted so the
    /// result is stable between runs.
    pub fn get_all_dll_names(&self) -> Vec<String> {
        self.ensure_cache_fresh();
        let mut names: Vec<String> = self
            .unified_dll_cache
            .iter()
            .map(|entry| entry.key().clone())
            .collect();
        names.sort_unstable();
        names
    }

    /// Returns `true` when there are neither CIL nor native imports.
    pub fn is_empty(&self) -> bool {
        self.cil.is_empty() && self.native.is_empty()
    }

    /// Returns the number of CIL imports plus the number of native functions.
    ///
    /// This is a plain sum; the deduplication applied to name lookups is not
    /// applied here.
    pub fn total_count(&self) -> usize {
        self.cil.len() + self.native.total_function_count()
    }

    /// Adds a named native import, registering `dll_name` first.
    ///
    /// # Errors
    ///
    /// Propagates any error from `NativeImports::add_dll` or
    /// `NativeImports::add_function`.
    pub fn add_native_function(&mut self, dll_name: &str, function_name: &str) -> Result<()> {
        // Invalidate before mutating so a failure halfway through cannot leave
        // a cache that is considered fresh but no longer matches `native`.
        self.invalidate_cache();
        self.native.add_dll(dll_name)?;
        self.native.add_function(dll_name, function_name)?;
        Ok(())
    }

    /// Adds an ordinal-only native import, registering `dll_name` first.
    ///
    /// # Errors
    ///
    /// Propagates any error from `NativeImports::add_dll` or
    /// `NativeImports::add_function_by_ordinal`.
    pub fn add_native_function_by_ordinal(&mut self, dll_name: &str, ordinal: u16) -> Result<()> {
        self.invalidate_cache();
        self.native.add_dll(dll_name)?;
        self.native.add_function_by_ordinal(dll_name, ordinal)?;
        Ok(())
    }

    /// Forwards `rva_delta` to `NativeImports::update_iat_rvas`.
    ///
    /// The name cache stores a copy of each native function's RVA
    /// (`NativeImportRef::iat_rva`), so the caches are invalidated here;
    /// otherwise `find_by_name` would keep returning the old RVAs.
    ///
    /// # Errors
    ///
    /// Propagates any error from `NativeImports::update_iat_rvas`.
    pub fn update_iat_rvas(&mut self, rva_delta: i64) -> Result<()> {
        self.invalidate_cache();
        self.native.update_iat_rvas(rva_delta)?;
        Ok(())
    }

    /// Rebuilds both caches if they are flagged dirty, then clears the flag.
    ///
    /// NOTE(review): the check and the rebuild are not one atomic step. Two
    /// threads that both observe the dirty flag would rebuild concurrently and
    /// could push duplicate entries; confirm whether concurrent queries are
    /// expected.
    fn ensure_cache_fresh(&self) {
        if self.cache_dirty.load(Ordering::Relaxed) {
            self.rebuild_unified_caches();
            self.cache_dirty.store(false, Ordering::Relaxed);
        }
    }

    /// Flags the caches as needing a rebuild before the next query.
    fn invalidate_cache(&self) {
        self.cache_dirty.store(true, Ordering::Relaxed);
    }

    /// Clears and repopulates the name cache and the DLL cache.
    ///
    /// CIL imports are processed first, then native imports. A CIL P/Invoke
    /// import is skipped in the name cache when a native import with the same
    /// DLL and function name exists (compared ASCII case-insensitively), but
    /// its token is still recorded in the DLL cache. DLL cache keys keep their
    /// original spelling, so the same DLL spelled with different case on the
    /// two sides produces two separate entries.
    fn rebuild_unified_caches(&self) {
        self.unified_name_cache.clear();
        self.unified_dll_cache.clear();

        // Lower-cased (dll, function) pairs of every named native import,
        // used to detect CIL P/Invoke duplicates below.
        let mut native_import_set: HashSet<(String, String)> = HashSet::new();
        for descriptor in self.native.descriptors() {
            let dll_lower = descriptor.dll_name.to_ascii_lowercase();
            for function in &descriptor.functions {
                if let Some(ref func_name) = function.name {
                    native_import_set.insert((dll_lower.clone(), func_name.to_ascii_lowercase()));
                }
            }
        }

        for import_entry in &self.cil {
            let import = import_entry.value();
            let token = *import_entry.key();

            // `None` unless this is a method import bound to a resolvable
            // ModuleRef; resolved once and reused for both caches.
            let pinvoke_dll = Self::extract_dll_from_pinvoke_import(import, &self.cil);

            let is_duplicate = match &pinvoke_dll {
                Some(dll_name) => native_import_set.contains(&(
                    dll_name.to_ascii_lowercase(),
                    import.name.to_ascii_lowercase(),
                )),
                None => false,
            };

            if !is_duplicate {
                self.unified_name_cache
                    .entry(import.name.clone())
                    .or_default()
                    .push(ImportEntry::Cil(import.clone()));
            }

            if let Some(dll_name) = pinvoke_dll {
                match self.unified_dll_cache.entry(dll_name) {
                    Entry::Occupied(mut entry) => match entry.get_mut() {
                        DllSource::Cil(tokens) | DllSource::Both(tokens) => tokens.push(token),
                        DllSource::Native => {
                            *entry.get_mut() = DllSource::Both(vec![token]);
                        }
                    },
                    Entry::Vacant(entry) => {
                        entry.insert(DllSource::Cil(vec![token]));
                    }
                }
            }
        }

        for descriptor in self.native.descriptors() {
            let dll_name = &descriptor.dll_name;
            // The DLL is registered when its first function is seen, so a
            // descriptor without functions never reaches the DLL cache.
            // Repeating the registration per function would be a no-op.
            let mut dll_registered = false;
            for function in &descriptor.functions {
                if let Some(ref func_name) = function.name {
                    self.unified_name_cache
                        .entry(func_name.clone())
                        .or_default()
                        .push(ImportEntry::Native(NativeImportRef {
                            dll_name: dll_name.clone(),
                            function_name: Some(func_name.clone()),
                            ordinal: function.ordinal,
                            iat_rva: function.rva,
                        }));
                }
                if !dll_registered {
                    self.mark_dll_as_native(dll_name);
                    dll_registered = true;
                }
            }
        }
    }

    /// Records in the DLL cache that `dll_name` is referenced by native imports:
    /// `Cil` becomes `Both` (keeping its tokens), a missing entry becomes
    /// `Native`, and `Native`/`Both` are left untouched.
    fn mark_dll_as_native(&self, dll_name: &str) {
        match self.unified_dll_cache.entry(dll_name.to_owned()) {
            Entry::Occupied(mut entry) => {
                if let DllSource::Cil(tokens) = entry.get_mut() {
                    // Move the tokens into the new variant instead of cloning.
                    let tokens = std::mem::take(tokens);
                    *entry.get_mut() = DllSource::Both(tokens);
                }
            }
            Entry::Vacant(entry) => {
                entry.insert(DllSource::Native);
            }
        }
    }

    /// Resolves the DLL name behind a CIL method import.
    ///
    /// Returns `Some(name)` only when `import` is a method import whose source
    /// is a `ModuleRef` that `cil_imports` can resolve; otherwise `None`.
    fn extract_dll_from_pinvoke_import(
        import: &Import,
        cil_imports: &CilImports,
    ) -> Option<String> {
        if !matches!(import.import, ImportType::Method(_)) {
            return None;
        }
        if let ImportSourceId::ModuleRef(token) = import.source_id {
            return cil_imports
                .get_module_ref(token)
                .map(|module_ref| module_ref.name.clone());
        }
        None
    }

    /// Collects the distinct function names imported from `dll_name`, sorted.
    ///
    /// Native imports come from the descriptor looked up by `dll_name`;
    /// ordinal-only imports are rendered as `#<ordinal>`. CIL P/Invoke imports
    /// are included when their DLL name equals `dll_name` ignoring ASCII case.
    fn get_functions_for_dll(&self, dll_name: &str) -> Vec<String> {
        let mut unique = HashSet::new();
        if let Some(descriptor) = self.native.get_descriptor(dll_name) {
            for function in &descriptor.functions {
                if let Some(ref name) = function.name {
                    unique.insert(name.clone());
                } else if let Some(ordinal) = function.ordinal {
                    unique.insert(format!("#{ordinal}"));
                }
            }
        }
        for import_entry in &self.cil {
            let import = import_entry.value();
            if let Some(import_dll) = Self::extract_dll_from_pinvoke_import(import, &self.cil) {
                if import_dll.eq_ignore_ascii_case(dll_name) {
                    unique.insert(import.name.clone());
                }
            }
        }
        // The set iterates in hash order; sort for reproducible output.
        let mut functions: Vec<String> = unique.into_iter().collect();
        functions.sort_unstable();
        functions
    }
}
impl Default for UnifiedImportContainer {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for UnifiedImportContainer {
    /// Prints summary counters and the cache-dirty flag instead of dumping the
    /// import collections and caches themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cil_count = self.cil.len();
        let native_dll_count = self.native.dll_count();
        let native_function_count = self.native.total_function_count();
        let is_cache_dirty = self.cache_dirty.load(Ordering::Relaxed);

        let mut summary = f.debug_struct("UnifiedImportContainer");
        summary.field("cil_count", &cil_count);
        summary.field("native_dll_count", &native_dll_count);
        summary.field("native_function_count", &native_function_count);
        summary.field("is_cache_dirty", &is_cache_dirty);
        summary.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{create_method, create_module_ref};

    /// A container built with `new` reports no imports at all.
    #[test]
    fn test_unified_import_container_new() {
        let unified = UnifiedImportContainer::new();
        assert!(unified.is_empty());
        assert_eq!(unified.total_count(), 0);
    }

    /// `Default` produces an empty container as well.
    #[test]
    fn test_unified_import_container_default() {
        let unified = UnifiedImportContainer::default();
        assert!(unified.is_empty());
    }

    /// Adding one named native function makes the container non-empty.
    #[test]
    fn test_add_native_function() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function("kernel32.dll", "GetCurrentProcessId")
            .unwrap();
        assert!(!unified.is_empty());
        assert!(unified.total_count() >= 1);
    }

    /// Several functions can be added to the same DLL.
    #[test]
    fn test_add_multiple_native_functions_same_dll() {
        let mut unified = UnifiedImportContainer::new();
        for function in ["GetCurrentProcessId", "GetCurrentThreadId", "GetLastError"] {
            unified
                .add_native_function("kernel32.dll", function)
                .unwrap();
        }
        assert!(!unified.is_empty());
        assert!(unified.total_count() >= 3);
    }

    /// Functions can be added to different DLLs.
    #[test]
    fn test_add_native_functions_multiple_dlls() {
        let mut unified = UnifiedImportContainer::new();
        for (dll, function) in [
            ("kernel32.dll", "GetCurrentProcessId"),
            ("user32.dll", "MessageBoxW"),
        ] {
            unified.add_native_function(dll, function).unwrap();
        }
        assert!(!unified.is_empty());
        assert!(unified.total_count() >= 2);
    }

    /// An ordinal-only import counts as an import.
    #[test]
    fn test_add_native_function_by_ordinal() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function_by_ordinal("user32.dll", 100)
            .unwrap();
        assert!(!unified.is_empty());
    }

    /// Ordinal 0 is rejected.
    #[test]
    fn test_add_native_function_by_ordinal_invalid() {
        let mut unified = UnifiedImportContainer::new();
        assert!(unified
            .add_native_function_by_ordinal("user32.dll", 0)
            .is_err());
    }

    /// A native import can be found by its function name.
    #[test]
    fn test_find_by_name_native() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function("kernel32.dll", "TestImport")
            .unwrap();

        let matches = unified.find_by_name("TestImport");
        assert_eq!(matches.len(), 1);
        match &matches[0] {
            ImportEntry::Native(native_ref) => {
                assert_eq!(native_ref.dll_name, "kernel32.dll");
                assert_eq!(native_ref.function_name.as_deref(), Some("TestImport"));
            }
            ImportEntry::Cil(_) => panic!("Expected Native import entry"),
        }
    }

    /// Looking up an unknown name yields an empty result.
    #[test]
    fn test_find_by_name_not_found() {
        let unified = UnifiedImportContainer::new();
        assert!(unified.find_by_name("NonExistent").is_empty());
    }

    /// Every DLL with a native function is listed.
    #[test]
    fn test_get_all_dll_names() {
        let mut unified = UnifiedImportContainer::new();
        for (dll, function) in [
            ("kernel32.dll", "Func1"),
            ("user32.dll", "Func2"),
            ("advapi32.dll", "Func3"),
        ] {
            unified.add_native_function(dll, function).unwrap();
        }

        let dll_names = unified.get_all_dll_names();
        assert_eq!(dll_names.len(), 3);
        for expected in ["kernel32.dll", "user32.dll", "advapi32.dll"] {
            assert!(dll_names.contains(&expected.to_string()));
        }
    }

    /// A DLL dependency lists the native functions imported from it.
    #[test]
    fn test_get_all_dll_dependencies() {
        let mut unified = UnifiedImportContainer::new();
        for function in ["GetCurrentProcessId", "GetLastError"] {
            unified
                .add_native_function("kernel32.dll", function)
                .unwrap();
        }

        let dependencies = unified.get_all_dll_dependencies();
        assert!(!dependencies.is_empty());

        let found = dependencies.iter().find(|dep| dep.name == "kernel32.dll");
        assert!(found.is_some());
        let kernel32 = found.unwrap();
        assert!(kernel32.functions.len() >= 2);
        assert!(kernel32.functions.iter().any(|f| f == "GetCurrentProcessId"));
        assert!(kernel32.functions.iter().any(|f| f == "GetLastError"));
    }

    /// The CIL accessor exposes an empty collection on a new container.
    #[test]
    fn test_cil_accessor() {
        let unified = UnifiedImportContainer::new();
        assert!(unified.cil().is_empty());
    }

    /// The native accessor exposes an empty collection on a new container.
    #[test]
    fn test_native_accessor() {
        let unified = UnifiedImportContainer::new();
        assert!(unified.native().is_empty());
    }

    /// Handing out mutable native access marks the caches dirty again.
    #[test]
    fn test_native_mut_invalidates_cache() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function("test.dll", "TestFunc")
            .unwrap();
        // Force a cache build so the dirty flag is cleared first.
        let _ = unified.find_by_name("TestFunc");
        let _ = unified.native_mut();
        assert!(unified.cache_dirty.load(Ordering::Relaxed));
    }

    /// A clone keeps the imports but starts with dirty caches.
    #[test]
    fn test_clone_resets_cache() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function("test.dll", "TestFunc")
            .unwrap();
        // Force a cache build on the source before cloning.
        let _ = unified.find_by_name("TestFunc");

        let copy = unified.clone();
        assert!(copy.cache_dirty.load(Ordering::Relaxed));
        assert!(!copy.is_empty());
    }

    /// The `Debug` output names the type and its native counters.
    #[test]
    fn test_debug_output() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function("test.dll", "TestFunc")
            .unwrap();

        let rendered = format!("{:?}", unified);
        for needle in [
            "UnifiedImportContainer",
            "native_dll_count",
            "native_function_count",
        ] {
            assert!(rendered.contains(needle));
        }
    }

    /// `NativeImportRef` stores the values it was constructed with.
    #[test]
    fn test_native_import_ref_structure() {
        let import_ref = NativeImportRef {
            dll_name: String::from("kernel32.dll"),
            function_name: Some(String::from("GetCurrentProcessId")),
            ordinal: None,
            iat_rva: 0x1000,
        };
        assert_eq!(import_ref.dll_name, "kernel32.dll");
        assert_eq!(
            import_ref.function_name.as_deref(),
            Some("GetCurrentProcessId")
        );
        assert!(import_ref.ordinal.is_none());
        assert_eq!(import_ref.iat_rva, 0x1000);
    }

    /// An ordinal-only `NativeImportRef` has no function name.
    #[test]
    fn test_native_import_ref_ordinal() {
        let import_ref = NativeImportRef {
            dll_name: String::from("user32.dll"),
            function_name: None,
            ordinal: Some(120),
            iat_rva: 0x2000,
        };
        assert!(import_ref.function_name.is_none());
        assert_eq!(import_ref.ordinal, Some(120));
    }

    /// Cloning a `NativeImportRef` copies its fields.
    #[test]
    fn test_native_import_ref_clone() {
        let import_ref = NativeImportRef {
            dll_name: String::from("test.dll"),
            function_name: Some(String::from("TestFunc")),
            ordinal: None,
            iat_rva: 0x3000,
        };
        let copy = import_ref.clone();
        assert_eq!(copy.dll_name, "test.dll");
        assert_eq!(copy.function_name.as_deref(), Some("TestFunc"));
    }

    /// `DllSource::Cil` carries its tokens.
    #[test]
    fn test_dll_source_cil() {
        let token = Token::new(0x06000001);
        match DllSource::Cil(vec![token]) {
            DllSource::Cil(tokens) => {
                assert_eq!(tokens.len(), 1);
                assert_eq!(tokens[0], token);
            }
            _ => panic!("Expected Cil variant"),
        }
    }

    /// `DllSource::Native` matches as the native variant.
    #[test]
    fn test_dll_source_native() {
        let source = DllSource::Native;
        assert!(matches!(source, DllSource::Native));
    }

    /// `DllSource::Both` carries its tokens.
    #[test]
    fn test_dll_source_both() {
        let token = Token::new(0x06000001);
        match DllSource::Both(vec![token]) {
            DllSource::Both(tokens) => assert_eq!(tokens.len(), 1),
            _ => panic!("Expected Both variant"),
        }
    }

    /// `DllDependency` stores the values it was constructed with.
    #[test]
    fn test_dll_dependency_structure() {
        let functions = vec![
            String::from("GetCurrentProcessId"),
            String::from("GetLastError"),
        ];
        let dep = DllDependency {
            name: String::from("kernel32.dll"),
            source: DllSource::Native,
            functions,
        };
        assert_eq!(dep.name, "kernel32.dll");
        assert!(matches!(dep.source, DllSource::Native));
        assert_eq!(dep.functions.len(), 2);
    }

    /// Shifting RVAs by a positive delta succeeds.
    #[test]
    fn test_update_iat_rvas_positive() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function("test.dll", "TestFunc")
            .unwrap();
        assert!(unified.update_iat_rvas(0x1000).is_ok());
    }

    /// Shifting RVAs by a negative delta succeeds.
    #[test]
    fn test_update_iat_rvas_negative() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function("test.dll", "TestFunc")
            .unwrap();
        assert!(unified.update_iat_rvas(-0x100).is_ok());
    }

    /// A CIL P/Invoke import makes its DLL show up in the DLL list.
    #[test]
    fn test_cil_pinvoke_dll_extraction() {
        let unified = UnifiedImportContainer::new();
        let kernel32 = create_module_ref(1, "kernel32.dll");
        let pinvoke = create_method("GetProcessId");
        let token = Token::new(0x0A000001);
        unified
            .cil
            .add_method(String::from("GetProcessId"), &token, pinvoke, &kernel32)
            .expect("Failed to add method import");

        let dll_names = unified.get_all_dll_names();
        assert!(
            dll_names.iter().any(|n| n == "kernel32.dll"),
            "kernel32.dll should appear in DLL dependencies. Found: {:?}",
            dll_names
        );
    }

    /// All CIL P/Invoke imports of a DLL are listed among its functions.
    #[test]
    fn test_cil_pinvoke_functions_for_dll() {
        let unified = UnifiedImportContainer::new();
        let kernel32 = create_module_ref(1, "kernel32.dll");
        let first = create_method("GetProcessId");
        let second = create_method("GetCurrentProcess");
        let third = create_method("ExitProcess");

        let pending = [
            ("GetProcessId", Token::new(0x0A000001), first),
            ("GetCurrentProcess", Token::new(0x0A000002), second),
            ("ExitProcess", Token::new(0x0A000003), third),
        ];
        for (name, token, method) in pending {
            unified
                .cil
                .add_method(name.to_string(), &token, method, &kernel32)
                .expect("Failed to add method import");
        }

        let dependencies = unified.get_all_dll_dependencies();
        let kernel32_dep = dependencies
            .iter()
            .find(|dep| dep.name == "kernel32.dll")
            .expect("kernel32.dll should be in dependencies");

        assert!(
            kernel32_dep.functions.contains(&"GetProcessId".to_string()),
            "GetProcessId should be in functions. Found: {:?}",
            kernel32_dep.functions
        );
        assert!(
            kernel32_dep
                .functions
                .contains(&"GetCurrentProcess".to_string()),
            "GetCurrentProcess should be in functions"
        );
        assert!(
            kernel32_dep.functions.contains(&"ExitProcess".to_string()),
            "ExitProcess should be in functions"
        );
    }

    /// A CIL P/Invoke import can be found by name and comes back as a CIL entry.
    #[test]
    fn test_cil_pinvoke_find_by_name() {
        let unified = UnifiedImportContainer::new();
        let kernel32 = create_module_ref(1, "kernel32.dll");
        let pinvoke = create_method("TestPInvokeMethod");
        let token = Token::new(0x0A000001);
        unified
            .cil
            .add_method(String::from("TestPInvokeMethod"), &token, pinvoke, &kernel32)
            .expect("Failed to add method import");

        let matches = unified.find_by_name("TestPInvokeMethod");
        assert_eq!(matches.len(), 1, "Should find exactly one import");
        match &matches[0] {
            ImportEntry::Cil(cil_import) => {
                assert_eq!(cil_import.name, "TestPInvokeMethod");
                assert_eq!(cil_import.token, token);
            }
            ImportEntry::Native(_) => panic!("Expected CIL import entry, got Native"),
        }
    }

    /// A DLL used by both sides lists both functions and reports `Both`.
    #[test]
    fn test_mixed_cil_and_native_same_dll() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function("kernel32.dll", "GetLastError")
            .expect("Failed to add native function");

        let kernel32 = create_module_ref(1, "kernel32.dll");
        let pinvoke = create_method("GetProcessId");
        unified
            .cil
            .add_method(
                String::from("GetProcessId"),
                &Token::new(0x0A000001),
                pinvoke,
                &kernel32,
            )
            .expect("Failed to add method import");

        let dependencies = unified.get_all_dll_dependencies();
        let kernel32_dep = dependencies
            .iter()
            .find(|dep| dep.name == "kernel32.dll")
            .expect("kernel32.dll should be in dependencies");

        assert!(
            kernel32_dep.functions.contains(&"GetLastError".to_string()),
            "GetLastError should be in functions (native)"
        );
        assert!(
            kernel32_dep.functions.contains(&"GetProcessId".to_string()),
            "GetProcessId should be in functions (CIL P/Invoke)"
        );
        assert!(
            matches!(kernel32_dep.source, DllSource::Both(_)),
            "Source should be Both since both CIL and native use kernel32.dll. Got: {:?}",
            kernel32_dep.source
        );
    }

    /// The per-DLL function lookup ignores ASCII case for CIL DLL names.
    #[test]
    fn test_cil_pinvoke_case_insensitive_dll_lookup() {
        let unified = UnifiedImportContainer::new();
        let upper_case_ref = create_module_ref(1, "KERNEL32.DLL");
        let pinvoke = create_method("TestFunc");
        unified
            .cil
            .add_method(
                String::from("TestFunc"),
                &Token::new(0x0A000001),
                pinvoke,
                &upper_case_ref,
            )
            .expect("Failed to add method import");

        let functions = unified.get_functions_for_dll("kernel32.dll");
        assert!(
            functions.contains(&"TestFunc".to_string()),
            "Should find function with case-insensitive DLL name lookup"
        );
    }

    /// The same function imported on both sides is reported once, as native.
    #[test]
    fn test_deduplication_cil_and_native_same_function() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function("kernel32.dll", "GetLastError")
            .expect("Failed to add native function");

        let kernel32 = create_module_ref(1, "kernel32.dll");
        let pinvoke = create_method("GetLastError");
        unified
            .cil
            .add_method(
                String::from("GetLastError"),
                &Token::new(0x0A000001),
                pinvoke,
                &kernel32,
            )
            .expect("Failed to add method import");

        let matches = unified.find_by_name("GetLastError");
        assert_eq!(
            matches.len(),
            1,
            "Should deduplicate and return only one entry. Found: {}",
            matches.len()
        );
        assert!(
            matches!(&matches[0], ImportEntry::Native(_)),
            "The deduplicated entry should be the Native import (has IAT info)"
        );
    }

    /// Duplicate detection ignores ASCII case in DLL and function names.
    #[test]
    fn test_deduplication_case_insensitive() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function("KERNEL32.DLL", "GetLastError")
            .expect("Failed to add native function");

        let kernel32 = create_module_ref(1, "kernel32.dll");
        let pinvoke = create_method("GETLASTERROR");
        unified
            .cil
            .add_method(
                String::from("GETLASTERROR"),
                &Token::new(0x0A000001),
                pinvoke,
                &kernel32,
            )
            .expect("Failed to add method import");

        let cil_matches = unified.find_by_name("GETLASTERROR");
        assert_eq!(
            cil_matches.len(),
            0,
            "CIL import with same function (case-insensitive) should be deduplicated"
        );

        let native_matches = unified.find_by_name("GetLastError");
        assert_eq!(native_matches.len(), 1, "Native import should be present");
    }

    /// A CIL import with no native counterpart is kept.
    #[test]
    fn test_deduplication_preserves_non_duplicate_cil() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function("kernel32.dll", "GetLastError")
            .expect("Failed to add native function");

        let kernel32 = create_module_ref(1, "kernel32.dll");
        let pinvoke = create_method("GetProcessId");
        unified
            .cil
            .add_method(
                String::from("GetProcessId"),
                &Token::new(0x0A000001),
                pinvoke,
                &kernel32,
            )
            .expect("Failed to add method import");

        let cil_matches = unified.find_by_name("GetProcessId");
        assert_eq!(
            cil_matches.len(),
            1,
            "Non-duplicate CIL import should still be present"
        );
        assert!(
            matches!(&cil_matches[0], ImportEntry::Cil(_)),
            "Should be a CIL import entry"
        );

        let native_matches = unified.find_by_name("GetLastError");
        assert_eq!(native_matches.len(), 1);
    }

    /// Deduplicating the name cache does not hide the CIL side in the DLL source.
    #[test]
    fn test_deduplication_dll_source_still_both() {
        let mut unified = UnifiedImportContainer::new();
        unified
            .add_native_function("kernel32.dll", "GetLastError")
            .expect("Failed to add native function");

        let kernel32 = create_module_ref(1, "kernel32.dll");
        let pinvoke = create_method("GetLastError");
        unified
            .cil
            .add_method(
                String::from("GetLastError"),
                &Token::new(0x0A000001),
                pinvoke,
                &kernel32,
            )
            .expect("Failed to add method import");

        let dependencies = unified.get_all_dll_dependencies();
        let kernel32_dep = dependencies
            .iter()
            .find(|dep| dep.name == "kernel32.dll")
            .expect("kernel32.dll should be in dependencies");
        assert!(
            matches!(kernel32_dep.source, DllSource::Both(_)),
            "DLL source should be Both even when name cache is deduplicated. Got: {:?}",
            kernel32_dep.source
        );
    }
}