use crate::{cilassembly::CilAssembly, Result};
/// Fluent builder that collects native export definitions and applies them to a
/// `CilAssembly` in one step.
///
/// Three kinds of entries are collected separately: named functions, ordinal-only
/// functions, and named forwarders. A running counter supplies ordinals for the
/// `*_auto` methods.
#[derive(Debug, Clone)]
pub struct NativeExportsBuilder {
    /// DLL name supplied at construction or through the `dll_name` setter.
    dll_name: String,
    /// Named function exports stored as `(name, ordinal, address)`.
    // NOTE(review): the address is presumably an RVA — confirm against the assembly API.
    functions: Vec<(String, u16, u32)>,
    /// Ordinal-only function exports stored as `(ordinal, address)`.
    ordinal_functions: Vec<(u16, u32)>,
    /// Forwarder exports stored as `(name, ordinal, target)`.
    forwarders: Vec<(String, u16, String)>,
    /// Ordinal handed out by the next `*_auto` call; starts at 1.
    next_ordinal: u16,
}
impl NativeExportsBuilder {
    /// Creates an empty builder for the given DLL name.
    ///
    /// Automatic ordinal assignment starts at 1.
    pub fn new(dll_name: impl Into<String>) -> Self {
        Self {
            dll_name: dll_name.into(),
            functions: Vec::new(),
            ordinal_functions: Vec::new(),
            forwarders: Vec::new(),
            next_ordinal: 1,
        }
    }

    /// Moves the auto-assignment counter past an explicitly chosen `ordinal`.
    ///
    /// The counter never moves backwards. The increment saturates so that an explicit
    /// ordinal of `u16::MAX` cannot overflow the counter.
    fn reserve_ordinal(&mut self, ordinal: u16) {
        self.next_ordinal = self.next_ordinal.max(ordinal.saturating_add(1));
    }

    /// Returns the next automatic ordinal and advances the counter.
    ///
    /// Once the counter reaches `u16::MAX` it stays there, so every further automatic
    /// assignment yields `u16::MAX` instead of overflowing.
    fn take_next_ordinal(&mut self) -> u16 {
        let ordinal = self.next_ordinal;
        self.next_ordinal = ordinal.saturating_add(1);
        ordinal
    }

    /// Adds a named function export with an explicit ordinal.
    ///
    /// If `ordinal` is at or beyond the current automatic counter, the counter is moved
    /// to `ordinal + 1` (saturating at `u16::MAX`).
    #[must_use]
    pub fn add_function(mut self, name: impl Into<String>, ordinal: u16, address: u32) -> Self {
        self.functions.push((name.into(), ordinal, address));
        self.reserve_ordinal(ordinal);
        self
    }

    /// Adds a named function export using the next automatic ordinal.
    #[must_use]
    pub fn add_function_auto(mut self, name: impl Into<String>, address: u32) -> Self {
        let ordinal = self.take_next_ordinal();
        self.functions.push((name.into(), ordinal, address));
        self
    }

    /// Adds an ordinal-only function export with an explicit ordinal.
    ///
    /// If `ordinal` is at or beyond the current automatic counter, the counter is moved
    /// to `ordinal + 1` (saturating at `u16::MAX`).
    #[must_use]
    pub fn add_function_by_ordinal(mut self, ordinal: u16, address: u32) -> Self {
        self.ordinal_functions.push((ordinal, address));
        self.reserve_ordinal(ordinal);
        self
    }

    /// Adds an ordinal-only function export using the next automatic ordinal.
    #[must_use]
    pub fn add_function_by_ordinal_auto(mut self, address: u32) -> Self {
        let ordinal = self.take_next_ordinal();
        self.ordinal_functions.push((ordinal, address));
        self
    }

    /// Adds a named forwarder export with an explicit ordinal.
    ///
    /// `target` is stored verbatim and handed to the assembly during [`Self::build`].
    /// If `ordinal` is at or beyond the current automatic counter, the counter is moved
    /// to `ordinal + 1` (saturating at `u16::MAX`).
    #[must_use]
    pub fn add_forwarder(
        mut self,
        name: impl Into<String>,
        ordinal: u16,
        target: impl Into<String>,
    ) -> Self {
        self.forwarders.push((name.into(), ordinal, target.into()));
        self.reserve_ordinal(ordinal);
        self
    }

    /// Adds a named forwarder export using the next automatic ordinal.
    #[must_use]
    pub fn add_forwarder_auto(
        mut self,
        name: impl Into<String>,
        target: impl Into<String>,
    ) -> Self {
        let ordinal = self.take_next_ordinal();
        self.forwarders.push((name.into(), ordinal, target.into()));
        self
    }

    /// Replaces the DLL name stored in the builder.
    #[must_use]
    pub fn dll_name(mut self, dll_name: impl Into<String>) -> Self {
        self.dll_name = dll_name.into();
        self
    }

    /// Applies every collected export to `assembly`, consuming the builder.
    ///
    /// Entries are registered in three passes: named functions, then ordinal-only
    /// functions, then forwarders. Within each pass the insertion order is kept.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the assembly. Entries registered before the
    /// failure are not rolled back by this method.
    // NOTE(review): the stored `dll_name` is not passed to the assembly here — confirm
    // whether the assembly is expected to receive it through another path.
    pub fn build(self, assembly: &mut CilAssembly) -> Result<()> {
        for (name, ordinal, address) in &self.functions {
            assembly.add_native_export_function(name, *ordinal, *address)?;
        }
        for (ordinal, address) in &self.ordinal_functions {
            assembly.add_native_export_function_by_ordinal(*ordinal, *address)?;
        }
        for (name, ordinal, target) in &self.forwarders {
            assembly.add_native_export_forwarder(name, *ordinal, target)?;
        }
        Ok(())
    }
}
impl Default for NativeExportsBuilder {
fn default() -> Self {
Self::new("Unknown.dll")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cilassembly::CilAssembly;
    use std::path::PathBuf;

    /// Runs `check` against the `WindowsBase.dll` sample assembly.
    ///
    /// When the sample cannot be loaded the closure is never invoked, so the calling
    /// test passes without exercising `build`.
    fn with_sample_assembly(check: impl FnOnce(&mut CilAssembly)) {
        let sample =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
        if let Ok(mut assembly) = CilAssembly::from_path(&sample) {
            check(&mut assembly);
        }
    }

    /// Two named functions with explicit ordinals build successfully.
    #[test]
    fn test_native_exports_builder_basic() {
        with_sample_assembly(|assembly| {
            let outcome = NativeExportsBuilder::new("TestLibrary.dll")
                .add_function("MyFunction", 1, 0x1000)
                .add_function("AnotherFunction", 2, 0x2000)
                .build(assembly);
            assert!(outcome.is_ok());
        });
    }

    /// An ordinal-only export combined with a named export builds successfully.
    #[test]
    fn test_native_exports_builder_with_ordinals() {
        with_sample_assembly(|assembly| {
            let outcome = NativeExportsBuilder::new("TestLibrary.dll")
                .add_function_by_ordinal(100, 0x1000)
                .add_function("NamedFunction", 101, 0x2000)
                .build(assembly);
            assert!(outcome.is_ok());
        });
    }

    /// A regular function plus name-based and ordinal-based forwarder targets build.
    #[test]
    fn test_native_exports_builder_with_forwarders() {
        with_sample_assembly(|assembly| {
            let outcome = NativeExportsBuilder::new("TestLibrary.dll")
                .add_function("RegularFunction", 1, 0x1000)
                .add_forwarder("ForwardedFunc", 2, "kernel32.dll.GetCurrentProcessId")
                .add_forwarder("OrdinalForward", 3, "user32.dll.#120")
                .build(assembly);
            assert!(outcome.is_ok());
        });
    }

    /// Automatic ordinals are shared across all three entry kinds and start at 1.
    #[test]
    fn test_native_exports_builder_auto_ordinals() {
        let builder = NativeExportsBuilder::new("TestLibrary.dll")
            .add_function_auto("Function1", 0x1000)
            .add_function_auto("Function2", 0x2000)
            .add_function_by_ordinal_auto(0x3000)
            .add_forwarder_auto("Forwarder1", "kernel32.dll.GetTick");

        let function_ordinals: Vec<u16> = builder.functions.iter().map(|entry| entry.1).collect();
        let ordinal_only: Vec<u16> = builder
            .ordinal_functions
            .iter()
            .map(|entry| entry.0)
            .collect();
        let forwarder_ordinals: Vec<u16> =
            builder.forwarders.iter().map(|entry| entry.1).collect();

        assert_eq!(function_ordinals, [1, 2]);
        assert_eq!(ordinal_only, [3]);
        assert_eq!(forwarder_ordinals, [4]);
        assert_eq!(builder.next_ordinal, 5);
    }

    /// Explicit ordinals push the automatic counter forward but never pull it back.
    #[test]
    fn test_native_exports_builder_mixed_ordinals() {
        let builder = NativeExportsBuilder::new("TestLibrary.dll")
            .add_function("Function1", 10, 0x1000)
            .add_function_auto("Function2", 0x2000)
            .add_function("Function3", 5, 0x3000)
            .add_function_auto("Function4", 0x4000);

        let assigned: Vec<u16> = builder.functions.iter().map(|entry| entry.1).collect();
        assert_eq!(assigned, [10, 11, 5, 12]);
        assert_eq!(builder.next_ordinal, 13);
    }

    /// The `dll_name` setter overrides the name given to `new`.
    #[test]
    fn test_native_exports_builder_dll_name_change() {
        let builder = NativeExportsBuilder::new("Original.dll")
            .dll_name("Changed.dll")
            .add_function("MyFunction", 1, 0x1000);
        assert_eq!(builder.dll_name, "Changed.dll");
    }

    /// A builder without any entries still builds successfully.
    #[test]
    fn test_native_exports_builder_empty() {
        with_sample_assembly(|assembly| {
            let outcome = NativeExportsBuilder::new("EmptyLibrary.dll").build(assembly);
            assert!(outcome.is_ok());
        });
    }

    /// Every builder method can be chained, and each entry lands in the right list.
    #[test]
    fn test_native_exports_builder_fluent_api() {
        let builder = NativeExportsBuilder::new("TestLibrary.dll")
            .add_function("Function1", 1, 0x1000)
            .add_function_auto("Function2", 0x2000)
            .add_function_by_ordinal(10, 0x3000)
            .add_function_by_ordinal_auto(0x4000)
            .add_forwarder("Forwarder1", 20, "kernel32.dll.GetCurrentProcessId")
            .add_forwarder_auto("Forwarder2", "user32.dll.MessageBoxW")
            .dll_name("FinalName.dll");

        assert_eq!(builder.dll_name, "FinalName.dll");
        assert_eq!(builder.functions.len(), 2);
        assert_eq!(builder.ordinal_functions.len(), 2);
        assert_eq!(builder.forwarders.len(), 2);

        let has_function = |wanted: &str, ordinal: u16| {
            builder
                .functions
                .iter()
                .any(|(name, ord, _)| name == wanted && *ord == ordinal)
        };
        assert!(has_function("Function1", 1));
        assert!(has_function("Function2", 2));

        let has_ordinal_ten = builder.ordinal_functions.iter().any(|(ord, _)| *ord == 10);
        assert!(has_ordinal_ten);

        let has_first_forwarder = builder.forwarders.iter().any(|(name, ord, target)| {
            name == "Forwarder1" && *ord == 20 && target == "kernel32.dll.GetCurrentProcessId"
        });
        assert!(has_first_forwarder);

        assert!(builder.next_ordinal > 20);
    }
}